<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Before-you-run-this-notebook" data-toc-modified-id="Before-you-run-this-notebook-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Before you run this notebook</a></span></li><li><span><a href="#Setting-up-training-/-validation-data" data-toc-modified-id="Setting-up-training-/-validation-data-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Setting up training / validation data</a></span></li></ul></div>

# Before you run this notebook

Make sure that you have all of the appropriate packages installed.
See the DeepBLAST README for installation instructions.
If you are using a GPU, also make sure that CUDA and cuDNN are correctly configured.
In the Slurm environment that was used to run this notebook, the following modules were loaded:

```
module load gcc cudnn cuda
```
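Before running the cells below, a quick check like the following can confirm that the packages this notebook imports are available and whether PyTorch can see a GPU. This is a minimal sketch that uses only the standard library. The package list is an assumption: it was assembled from the imports in this notebook plus `sentencepiece`, which `T5Tokenizer` relies on, and it may not match the DeepBLAST README exactly.

```python
import importlib.util

# Assumed list: top-level packages imported later in this notebook,
# plus sentencepiece, which T5Tokenizer depends on.
REQUIRED = ["pandas", "torch", "pytorch_lightning", "transformers", "sentencepiece", "deepblast"]


def missing_packages(names):
    """Return the subset of `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(REQUIRED)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages found.")

# If PyTorch is installed, report whether it can see a GPU and cuDNN.
if importlib.util.find_spec("torch") is not None:
    import torch

    print("CUDA available:", torch.cuda.is_available())
    print("cuDNN available:", torch.backends.cudnn.is_available())
```

A missing package here would otherwise surface later as a `ModuleNotFoundError` partway through the notebook.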

This notebook is expected to take more than 15 minutes to run.

# Setting up training / validation data

Here we show how to train DeepBLAST on a small subset of structural alignments generated by TM-align.

In [3]:
import os
import pandas as pd

# TM-align output: one whitespace-separated row per aligned pair of chains.
fname = '../data/tm_align_output_10k.tab'
cols = [
    'chain1_name', 'chain2_name', 'tmscore1', 'tmscore2', 'rmsd',
    'chain1', 'chain2', 'alignment'
]
# Raw string avoids an invalid-escape warning for the regex separator.
align_df = pd.read_table(fname, header=None, sep=r'\s+', names=cols)
n_alignments = align_df.shape[0]

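The cell above assumes every row of the TM-align output has exactly eight whitespace-separated fields. The sketch below is a standard-library illustration of that row layout, which can help when checking a new input file by hand. The function name `parse_tm_align_line` is hypothetical, and the example row is made up for illustration: its alignment field is a placeholder, not DeepBLAST's actual alignment encoding.

```python
# Column layout used when reading the TM-align output above.
COLS = [
    "chain1_name", "chain2_name", "tmscore1", "tmscore2", "rmsd",
    "chain1", "chain2", "alignment",
]


def parse_tm_align_line(line):
    """Split one row into a dict keyed by COLS (hypothetical helper).

    Like the whitespace separator passed to pd.read_table above,
    any run of spaces or tabs separates fields.
    """
    fields = line.split()
    if len(fields) != len(COLS):
        raise ValueError(f"expected {len(COLS)} fields, got {len(fields)}: {line!r}")
    row = dict(zip(COLS, fields))
    # The score columns are numeric; the remaining columns stay as strings.
    for key in ("tmscore1", "tmscore2", "rmsd"):
        row[key] = float(row[key])
    return row


# Made-up row for illustration only; the alignment field is a placeholder.
example = "chainA chainB 0.85 0.80 1.20 MKTAYIA MKSAYIA ALIGNMENT_PLACEHOLDER"
print(parse_tm_align_line(example))
```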

The structural alignments are split by row order into training (80%), test (10%), and validation (10%) sets, and each split is written to disk.

In [None]:
# 80% train, 10% test, 10% validation (any remainder goes to validation).
parts = n_alignments // 10
train_df = align_df.iloc[:parts * 8]
test_df = align_df.iloc[parts * 8:parts * 9]
valid_df = align_df.iloc[parts * 9:]

# Save the splits to disk as headerless, tab-separated files.
os.makedirs('data', exist_ok=True)

train_df.to_csv('data/train.txt', sep='\t', index=False, header=False)
test_df.to_csv('data/test.txt', sep='\t', index=False, header=False)
valid_df.to_csv('data/valid.txt', sep='\t', index=False, header=False)

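The integer division in the split cell means the three sets are not always exactly 80/10/10. The small sketch below reproduces the same index arithmetic so the split sizes can be checked for any input size. `split_sizes` is a hypothetical helper that mirrors the cell above.

```python
def split_sizes(n_alignments):
    """Sizes of the train / test / validation splits made by the cell above."""
    parts = n_alignments // 10
    n_train = parts * 8
    n_test = parts * 9 - parts * 8
    n_valid = n_alignments - parts * 9  # includes the remainder of the division
    return n_train, n_test, n_valid


print(split_sizes(10_000))  # (8000, 1000, 1000)
print(split_sizes(10_005))  # (8000, 1000, 1005): the remainder goes to validation
```

Because the split follows row order, any ordering in the TM-align output carries over into the splits. Shuffling first, for example with `align_df.sample(frac=1, random_state=0)`, is one option if that matters for your data.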

Next, make sure the output directory that will hold the trained model and its diagnostics exists.

In [None]:
from deepblast.trainer import DeepBLAST
from pytorch_lightning import Trainer

output_dir = 'struct_results'
os.makedirs(output_dir, exist_ok=True)


Before creating the DeepBLAST model, we need to download the pretrained ProtTrans language model (ProtT5-XL-UniRef50) and its tokenizer.

In [None]:
from transformers import T5EncoderModel, T5Tokenizer

# ProtT5 tokenizer and encoder from the Rostlab ProtTrans release (downloaded and cached on first use).
tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")

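For reference, the Rostlab ProtT5 model card formats input sequences by mapping the rare residues U, Z, O, and B to X and separating residues with spaces before tokenizing. DeepBLAST receives the tokenizer directly and is assumed here to handle this formatting internally. The sketch below, with the hypothetical helper `format_for_prot_t5`, only illustrates that input convention.

```python
import re


def format_for_prot_t5(sequence):
    """Format a protein sequence following the ProtT5 model card convention:
    replace rare residues (U, Z, O, B) with X, then space-separate residues."""
    sequence = re.sub(r"[UZOB]", "X", sequence)
    return " ".join(sequence)


print(format_for_prot_t5("MKTAYIAKQUB"))  # M K T A Y I A K Q X X
```

A string formatted this way can then be passed to the tokenizer, e.g. `tokenizer(format_for_prot_t5(seq), return_tensors="pt")`.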

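Now let's create the DeepBLAST model. Note that `model` still refers to the ProtT5 encoder when it is passed as `lm`; the name is then reassigned to the DeepBLAST model.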

In [None]:
model = DeepBLAST(
    train_pairs=f'{os.getcwd()}/data/train.txt',  # training data
    test_pairs=f'{os.getcwd()}/data/test.txt',    # test data
    valid_pairs=f'{os.getcwd()}/data/valid.txt',  # validation data
    output_directory=output_dir,                  # output directory storing model + diagnostics
    batch_size=10,                                # number of alignments per training batch
    num_workers=30,                               # number of cores for manipulating training data
    layers=2,                                     # number of CNN layers for blosum matrix estimation
    learning_rate=5e-5,                           # learning rate
    loss='cross_entropy',                         # type of loss function 
    lm=model,                                     # pretrained language model 
    tokenizer=tokenizer                           # tokenizer for residues
)


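We can now train the model. With `limit_train_batches=5` and `limit_val_batches=1`, each epoch uses only 5 training batches and 1 validation batch, so this is a short demonstration run rather than full training.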

In [None]:
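trainer = Trainer(
    max_epochs=5,               # number of passes over the (limited) training data
    gpus=1,                     # older PyTorch Lightning API; 2.x uses accelerator='gpu', devices=1
    check_val_every_n_epoch=1,  # run validation after every epoch
    limit_train_batches=5,      # use only 5 training batches per epoch
    limit_val_batches=1         # use only 1 validation batch per validation run
)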

trainer.fit(model)

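To put the size of this run in perspective, the number of training alignments it can touch follows directly from the settings above. In PyTorch Lightning an integer `limit_train_batches` caps the number of batches per epoch (a float would be a fraction of the data). The helper name `max_training_alignments` is hypothetical.

```python
def max_training_alignments(max_epochs, limit_train_batches, batch_size):
    """Upper bound on alignments seen in training when limit_train_batches is an int."""
    return max_epochs * limit_train_batches * batch_size


# Values used in the DeepBLAST and Trainer cells above.
print(max_training_alignments(max_epochs=5, limit_train_batches=5, batch_size=10))  # 250
```

If the input file holds 10,000 alignments, as its name suggests, the training split has 8,000 rows, so this run uses only a small fraction of it. Raising or removing `limit_train_batches` is the knob for a longer run.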

The model diagnostics, including the losses, the accuracy, and the alignment results, can be visualized directly in TensorBoard.

In [None]:
!ls lightning_logs

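By default, PyTorch Lightning's TensorBoard logger writes each run to a new `lightning_logs/version_N` directory, so re-running the training cell accumulates runs. To look at only the most recent run, a helper like the hypothetical `latest_version_dir` below can locate it; its result can then be passed to `--logdir`. Versions are compared numerically because a plain string sort would put `version_10` before `version_9`.

```python
import os
import re


def latest_version_dir(log_root="lightning_logs"):
    """Return the highest-numbered version_N directory under log_root, or None."""
    if not os.path.isdir(log_root):
        return None
    versions = []
    for name in os.listdir(log_root):
        match = re.fullmatch(r"version_(\d+)", name)
        if match and os.path.isdir(os.path.join(log_root, name)):
            versions.append((int(match.group(1)), name))
    if not versions:
        return None
    # max() compares the integer version first, so version_10 beats version_9.
    return os.path.join(log_root, max(versions)[1])


print(latest_version_dir())
```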

In [None]:
%load_ext tensorboard


In [None]:
%tensorboard --logdir lightning_logs
