<a href="https://colab.research.google.com/github/captaincapsaicin/slc22a5/blob/main/slc22a5_train_potts.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install mogwai

In [1]:
# takes ~3 mins
# Installs mogwai (provides mogwai.models / mogwai.data_loading used below).
# NOTE(review): this tracks the `stage-lightning-exception` branch, which can
# move over time; consider pinning a specific commit hash for reproducibility.

!pip install -q git+https://github.com/nickbhat/mogwai.git@stage-lightning-exception

[K     |████████████████████████████████| 2.3 MB 26.8 MB/s 
[K     |████████████████████████████████| 584 kB 54.3 MB/s 
[K     |████████████████████████████████| 31.8 MB 1.2 MB/s 
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
[K     |████████████████████████████████| 596 kB 50.9 MB/s 
[K     |████████████████████████████████| 408 kB 53.2 MB/s 
[K     |████████████████████████████████| 136 kB 43.5 MB/s 
[K     |████████████████████████████████| 1.1 MB 61.9 MB/s 
[K     |████████████████████████████████| 271 kB 61.8 MB/s 
[K     |████████████████████████████████| 144 kB 47.8 MB/s 
[K     |████████████████████████████████| 94 kB 3.4 MB/s 
[?25h  Building wheel for mogwai-protein (setup.py) ... [?25l[?25hdone
  Building wheel for biotite (PEP 517) ... [?25l[?25hdone


In [2]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from tqdm import tqdm


# Download alignments and wildtype

In [3]:
# wildtype
!wget http://s3.amazonaws.com/songlabdata/slc22a5/O76082.fasta

# alignments
!wget http://s3.amazonaws.com/songlabdata/slc22a5/alignments/mammals_30.a3m
!wget http://s3.amazonaws.com/songlabdata/slc22a5/alignments/hhblits.a3m
!wget http://s3.amazonaws.com/songlabdata/slc22a5/alignments/vertebrates_100.a3m
!wget http://s3.amazonaws.com/songlabdata/slc22a5/alignments/eve.a3m
!wget http://s3.amazonaws.com/songlabdata/slc22a5/alignments/deepsequence.a3m

--2022-05-04 15:59:35--  http://s3.amazonaws.com/songlabdata/slc22a5/O76082.fasta
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.65.70
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.65.70|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 669 [binary/octet-stream]
Saving to: ‘O76082.fasta’


2022-05-04 15:59:35 (67.6 MB/s) - ‘O76082.fasta’ saved [669/669]

--2022-05-04 15:59:35--  http://s3.amazonaws.com/songlabdata/slc22a5/alignments/mammals_30.a3m
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.65.70
Connecting to s3.amazonaws.com (s3.amazonaws.com)|52.217.65.70|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 18796 (18K) [binary/octet-stream]
Saving to: ‘mammals_30.a3m’


2022-05-04 15:59:36 (3.89 MB/s) - ‘mammals_30.a3m’ saved [18796/18796]

--2022-05-04 15:59:36--  http://s3.amazonaws.com/songlabdata/slc22a5/alignments/hhblits.a3m
Resolving s3.amazonaws.com (s3.amazonaws.com)... 52.217.65.70
Connecting to s3.

# Train model

In [4]:
import mogwai.models as models
import mogwai.data_loading as data_loading
import os

def torch_to_numpy(state_dict, keys):
    """Convert selected entries of a PyTorch state dict to NumPy arrays.

    Args:
        state_dict: Mapping from name to ``torch.Tensor`` (for example the
            result of ``module.state_dict()``).
        keys: Iterable of names to extract from ``state_dict``.

    Returns:
        A new dict mapping each requested key to a NumPy array.

    Raises:
        KeyError: if a requested key is not present in ``state_dict``.
    """
    # ``Tensor.numpy()`` refuses tensors that live on a GPU or that still track
    # gradients; detaching and moving to CPU first makes the conversion work in
    # both cases (and is effectively a no-op for plain CPU tensors).
    return {key: state_dict[key].detach().cpu().numpy() for key in keys}


def get_outfile_path(alignment_path):
    """Return the ``.npz`` path used to store the Potts model for an alignment.

    The alignment file's extension (if any) is dropped and
    ``_potts_state_dict.npz`` is appended, so the output sits next to the
    alignment, e.g. ``'deepsequence.a3m'`` -> ``'deepsequence_potts_state_dict.npz'``.
    """
    # splitext only strips an extension from the final path component, so dots
    # in directory names are preserved and extension-less paths are accepted
    # (rsplit('.', 1) mangled the former and raised ValueError on the latter).
    root, _ext = os.path.splitext(alignment_path)
    return root + '_potts_state_dict.npz'

In [5]:
def train_potts(alignment_path, max_steps, gpus, min_steps=50, batch_size=4096):
    """Fit a Gremlin (Potts) model to an MSA and save its weights as ``.npz``.

    Args:
        alignment_path: Path to the alignment file (e.g. an ``.a3m``).
        max_steps: Maximum number of optimisation steps, forwarded to
            ``pl.Trainer``.
        gpus: Forwarded to ``pl.Trainer(gpus=...)``; ``None`` trains on CPU,
            ``[0]`` uses the first GPU.
        min_steps: Minimum number of steps forwarded to ``pl.Trainer``
            (default 50, the previously hard-coded value).
        batch_size: Batch size for ``MSADataModule`` (default 4096, the
            previously hard-coded value).

    Returns:
        The path of the written ``.npz`` file, which holds the model's
        ``weight`` and ``bias`` entries.
    """
    print('Reading {}'.format(alignment_path))
    # Load msa
    msa_dm = data_loading.MSADataModule(alignment_path, batch_size=batch_size)
    msa_dm.setup()

    # Initialize model
    num_seqs, msa_length, msa_counts = msa_dm.get_stats()
    model = models.Gremlin(num_seqs, msa_length, msa_counts)

    # Initialize Trainer
    trainer = pl.Trainer(min_steps=min_steps, max_steps=max_steps, gpus=gpus)

    # Train model
    trainer.fit(model, msa_dm)

    # Save model. Move it to CPU first: converting to NumPy fails on CUDA
    # tensors if the model is still on a GPU after fitting.
    model_dict = torch_to_numpy(model.cpu().state_dict(), ['weight', 'bias'])

    outfile = get_outfile_path(alignment_path)
    np.savez(outfile, **model_dict)
    print('Saved model state dict to {}'.format(outfile))
    return outfile

In [None]:
# this takes about 10 mins.
alignment_path = 'deepsequence.a3m'
max_steps = 500
# Use the first GPU when the runtime has one, otherwise train on CPU
# (None). Override manually with e.g. gpus = None to force CPU.
gpus = [0] if torch.cuda.is_available() else None
train_potts(alignment_path=alignment_path, max_steps=max_steps, gpus=gpus)

Reading deepsequence.a3m


Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/queues.py", line 232, in _feed
    close()
Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/queues.py", line 232, in _feed
    close()
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 361, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor
Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 361, in _close
    _close(self._handle)
  File "/usr/lib/python3.7/multiprocessing/connection.py", li

Training: 0it [00:00, ?it/s]