<a href="https://colab.research.google.com/github/songlab-cal/slc22a5/blob/main/slc22a5_train_potts.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install mogwai

In [1]:
# takes ~3 mins

!pip install -q git+https://github.com/nickbhat/mogwai.git@stage-lightning-exception

[K     |████████████████████████████████| 2.3 MB 6.2 MB/s 
[K     |████████████████████████████████| 584 kB 35.0 MB/s 
[K     |████████████████████████████████| 31.8 MB 28 kB/s 
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
[K     |████████████████████████████████| 408 kB 42.2 MB/s 
[K     |████████████████████████████████| 136 kB 42.5 MB/s 
[K     |████████████████████████████████| 596 kB 45.8 MB/s 
[K     |████████████████████████████████| 1.1 MB 41.5 MB/s 
[K     |████████████████████████████████| 94 kB 3.3 MB/s 
[K     |████████████████████████████████| 271 kB 43.4 MB/s 
[K     |████████████████████████████████| 144 kB 47.1 MB/s 
[?25h  Building wheel for mogwai-protein (setup.py) ... [?25l[?25hdone
  Building wheel for biotite (PEP 517) ... [?25l[?25hdone


In [2]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from tqdm import tqdm


# Download alignments and wildtype

In [3]:
# wildtype
!wget -q http://s3.amazonaws.com/songlabdata/slc22a5/O76082.fasta

# alignments
!wget -q http://s3.amazonaws.com/songlabdata/slc22a5/alignments/mammals_30.a3m
!wget -q http://s3.amazonaws.com/songlabdata/slc22a5/alignments/hhblits.a3m
!wget -q http://s3.amazonaws.com/songlabdata/slc22a5/alignments/vertebrates_100.a3m
!wget -q http://s3.amazonaws.com/songlabdata/slc22a5/alignments/eve.a3m
!wget -q http://s3.amazonaws.com/songlabdata/slc22a5/alignments/deepsequence.a3m

# Train model

In [4]:
import mogwai.models as models
import mogwai.data_loading as data_loading
import os

def torch_to_numpy(state_dict, keys):
    """Convert the selected entries of a torch state dict to numpy arrays.

    Args:
        state_dict: Mapping from parameter name to ``torch.Tensor``.
        keys: Iterable of names to extract; each must be present in
            ``state_dict``.

    Returns:
        A new dict mapping each requested key to a ``numpy.ndarray``.

    Raises:
        KeyError: If a requested key is missing from ``state_dict``.
    """
    # Tensor.numpy() raises for tensors that require grad or that live on a
    # CUDA device, so drop autograd tracking and move to host memory first.
    # Both calls are no-ops for plain CPU tensors.
    return {key: state_dict[key].detach().cpu().numpy() for key in keys}


def get_outfile_path(alignment_path):
    """Return the ``.npz`` path where the model for an alignment is saved.

    The file extension of ``alignment_path`` (if any) is replaced by
    ``_potts_state_dict.npz``; the directory part is kept unchanged.

    Args:
        alignment_path: Path to the alignment file, e.g. ``'deepsequence.a3m'``.

    Returns:
        The output path as a string, e.g. ``'deepsequence_potts_state_dict.npz'``.
    """
    # splitext only looks at the final path component, so a path without an
    # extension (or with a dot only in a directory name) is handled correctly.
    prefix, _ = os.path.splitext(alignment_path)
    return prefix + '_potts_state_dict.npz'

In [5]:
def train_potts(alignment_path, max_steps, gpus):
    """Fit a Gremlin model on an alignment and save its parameters as ``.npz``.

    Args:
        alignment_path: Path to the alignment file handed to
            ``data_loading.MSADataModule``.
        max_steps: Forwarded to ``pl.Trainer(max_steps=...)``.
        gpus: Forwarded to ``pl.Trainer(gpus=...)``.

    The ``'weight'`` and ``'bias'`` entries of the trained model's state dict
    are written with ``np.savez`` to the path given by ``get_outfile_path``.
    """
    print('Reading {}'.format(alignment_path))

    # Build the data module for the alignment (fixed batch size of 4096).
    datamodule = data_loading.MSADataModule(alignment_path, batch_size=4096)
    datamodule.setup()

    # The model is constructed from the statistics reported by the data module.
    n_sequences, n_columns, column_counts = datamodule.get_stats()
    potts_model = models.Gremlin(n_sequences, n_columns, column_counts)

    # Train for at least 50 and at most `max_steps` steps.
    fitter = pl.Trainer(min_steps=50, max_steps=max_steps, gpus=gpus)
    fitter.fit(potts_model, datamodule)

    # Export only the two parameters of interest as numpy arrays.
    exported = torch_to_numpy(potts_model.state_dict(), ['weight', 'bias'])
    destination = get_outfile_path(alignment_path)
    np.savez(destination, **exported)
    print('Saved model state dict to {}'.format(destination))

In [7]:
# Train a Potts model on the deepsequence.a3m alignment downloaded above.
# This takes about 10 mins.
alignment_path = 'deepsequence.a3m'
max_steps = 500  # forwarded by train_potts to pl.Trainer(max_steps=...)
gpus = [0]  # forwarded to pl.Trainer(gpus=...). CPU training is extremely slow, not recommended.
train_potts(alignment_path=alignment_path, max_steps=max_steps, gpus=gpus)

Reading deepsequence.a3m


  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type | Params
------------------------------
------------------------------
124 M     Trainable params
0         Non-trainable params
124 M     Total params
496.443   Total estimated model params size (MB)
  cpuset_checked))


Training: 0it [00:00, ?it/s]

Saved model state dict to deepsequence_potts_state_dict.npz
