## Training


**NOTE:** The steps described in this notebook are not necessary for replicating Selenobot results. Pre-trained models, as well as the completed training, testing, and validation datasets, are available in a Google Cloud bucket; instructions for downloading them can be found in the `testing.ipynb` and `training.ipynb` notebooks.

If you want to run this code, set the `DATA_DIR` variable below to the absolute path of the directory where the data will be stored on your machine. The path must end with a trailing slash, because file names are appended to it directly.


In [None]:
DATA_DIR = '/home/prichter/Documents/data/testing/'

In [None]:
import sys
# Add the selenobot/ subdirectory to the module search path, so that the modules in this directory are visible from the notebook.
sys.path.append('../selenobot/')

In [None]:
from dataset import Dataset, get_dataloader
from embedders import AacEmbedder, LengthEmbedder
from classifiers import Classifier, SimpleClassifier
import pandas as pd
from typing import NoReturn
import numpy as np
import pickle
import torch  # Needed for torch.save when storing model weights below.

## Downloading training and validation data

The training and validation data used in this project are available for download from a Google Cloud bucket, and can be accessed using the URLs below. These datasets can also be generated from scratch by following the procedure in the `setup.ipynb` notebook.

In [None]:
# Download the training data from Google Cloud.
! curl https://storage.googleapis.com/selenobot-data/train.csv -o '{DATA_DIR}train.csv'
# Download the validation data from Google Cloud.
! curl https://storage.googleapis.com/selenobot-data/val.csv -o '{DATA_DIR}val.csv'

## Instantiating `Dataset`s

Datasets are PyTorch objects which store data and provide a consistent interface for accessing that data. For the Selenobot project, we define a custom `Dataset` class which provides extra functionality for embedding sequence data. A `Dataset` takes a pandas `DataFrame` and an `Embedder` object (either a `LengthEmbedder` or an `AacEmbedder`) as input. The embedder is applied to the sequences in the `DataFrame`, and the resulting vectors are stored in the `embeddings` attribute of the `Dataset`. If no embedder is specified, the embeddings are assumed to be already present in the input `DataFrame`, and are extracted into the `embeddings` attribute. This is necessary when using PLM embeddings, which cannot be generated *ad hoc* due to computational cost.

For training, three `Datasets` are instantiated using the training data. Each `Dataset` uses a different embedder: one to train the AAC classifier, one to train the length classifier, and one to train the Selenobot (PLM embedding-based) classifier. For each of these `Dataset`s, a corresponding `Dataset` (with the same embedder) is created to store the validation data. 
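To make the idea of an embedder concrete, the following is a minimal pure-Python sketch of a length embedding and an amino acid composition (AAC) embedding. It is not the code in `embedders.py`: the function names and the 21-symbol alphabet (the 20 standard amino acids plus a catch-all `X`) are assumptions made for illustration only.

```python
from collections import Counter

# Hypothetical alphabet: the 20 standard amino acids plus 'X' for any other symbol.
ALPHABET = 'ACDEFGHIKLMNPQRSTVWYX'

def length_embedding(seq):
    '''Embed a sequence as a one-dimensional vector holding its length.'''
    return [float(len(seq))]

def aac_embedding(seq):
    '''Embed a sequence as the fraction of residues matching each alphabet symbol.'''
    # Symbols outside the alphabet (for example 'U') are counted under 'X'.
    counts = Counter(aa if aa in ALPHABET else 'X' for aa in seq)
    total = max(len(seq), 1)  # Avoid dividing by zero on an empty sequence.
    return [counts[aa] / total for aa in ALPHABET]

seqs = ['MKVA', 'MAAU']
length_vectors = [length_embedding(s) for s in seqs]
aac_vectors = [aac_embedding(s) for s in seqs]
```

Each sequence maps to a fixed-size vector regardless of its length, which is what allows the embedded sequences to be stacked and passed to a classifier.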


In [None]:
# Load the validation and training data into pandas DataFrames.
train_df = pd.read_csv(f'{DATA_DIR}train.csv')
val_df = pd.read_csv(f'{DATA_DIR}val.csv')

In [4]:
len_train_dataset = Dataset(train_df, embedder=LengthEmbedder())
len_val_dataset = Dataset(val_df, embedder=LengthEmbedder())

aac_train_dataset = Dataset(train_df, embedder=AacEmbedder())
aac_val_dataset = Dataset(val_df, embedder=AacEmbedder())

sel_train_dataset = Dataset(train_df, embedder=None)
sel_val_dataset = Dataset(val_df, embedder=None)

## Instantiating `DataLoader`s

Model training is mediated by PyTorch `DataLoaders`, which are objects that facilitate batch training. In order to address the imbalance in the training data (many more full-length proteins than truncated selenoproteins), we implemented a custom PyTorch `BatchSampler`, called `BalancedBatchSampler`, which is defined in the `dataset.py` file. `BalancedBatchSampler` is used in conjunction with the `DataLoaders` to ensure that each batch has an equal number of truncated selenoproteins and non-selenoproteins. It does so by repeatedly resampling from the training data, using the algorithm described in the `dataset.py` file.

`DataLoaders` are created using the `get_dataloader` function, which takes a `Dataset` as input, and returns a balanced-batch `DataLoader` object with the specified batch size. For all models in this investigation, we used batches of size 1024. No other batch sizes were tested. 
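The sketch below illustrates one way balanced batches can be produced. It is not the `BalancedBatchSampler` algorithm from `dataset.py`, which is not reproduced here. It assumes binary labels (1 for truncated selenoproteins, 0 otherwise) and samples each class with replacement; the function name and its parameters are hypothetical.

```python
import random

def balanced_batches(labels, batch_size, n_batches, seed=0):
    '''Yield lists of indices, each with batch_size // 2 positives and batch_size // 2 negatives.'''
    rng = random.Random(seed)
    pos = [i for i, y in enumerate(labels) if y == 1]
    neg = [i for i, y in enumerate(labels) if y == 0]
    half = batch_size // 2
    for _ in range(n_batches):
        # Sample with replacement so the minority class can fill its half of every batch.
        batch = rng.choices(pos, k=half) + rng.choices(neg, k=half)
        rng.shuffle(batch)
        yield batch

# A toy label vector with 3 positives and 97 negatives.
labels = [1] * 3 + [0] * 97
batches = list(balanced_batches(labels, batch_size=8, n_batches=5))
```

Because the minority class is resampled, the same positive example can appear several times within a batch and across batches. This is the cost of presenting the model with an even class ratio.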

In [None]:
# Store the batch size.
batch_size = 1024

In [None]:
# Instantiate DataLoaders for each training Dataset.
len_dataloader = get_dataloader(len_train_dataset, batch_size=batch_size)
aac_dataloader = get_dataloader(aac_train_dataset, batch_size=batch_size)
sel_dataloader = get_dataloader(sel_train_dataset, batch_size=batch_size)

## Training the models

Two types of classifiers are used in this investigation: the `Classifier` (for AAC and PLM embeddings) and the `SimpleClassifier` (for length embeddings). Each of these models is defined in the `classifiers.py` file. Both classes implement a `fit` method, which takes a `DataLoader` and a validation `Dataset` as input, as well as other parameters such as the learning rate and number of epochs. This method trains the model using the `DataLoader`, and returns a `Reporter` object, which stores information regarding model performance during the training process.

Each model is trained for `10` epochs, which was found to be sufficient for convergence, with the learning rate held constant at `0.001`. Upon completion, the weights of each trained model are saved to a PTH file using `torch.save`, and the `Reporter` objects are saved using the Python `pickle` module.
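The overall shape of a `fit` call (iterate over batches for a fixed number of epochs, record performance on the training and validation data after each epoch, and return the record) can be sketched without PyTorch. Everything below is a hypothetical stand-in: the `Reporter` class, the one-feature logistic model, and the toy learning rate are assumptions chosen to keep the example small, and none of it reflects the contents of `classifiers.py`.

```python
import math

class Reporter:
    '''Hypothetical stand-in that records one loss value per epoch.'''
    def __init__(self):
        self.train_losses = []
        self.val_losses = []

def predict(w, b, x):
    '''Logistic model on a single feature.'''
    return 1.0 / (1.0 + math.exp(-(w * x + b)))

def bce(w, b, data):
    '''Mean binary cross-entropy over (x, y) pairs.'''
    total = 0.0
    for x, y in data:
        p = predict(w, b, x)
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(data)

def fit(batches, val_data, epochs=10, lr=0.5):
    '''Train with per-batch gradient descent; lr=0.5 is a toy value, not the notebook setting.'''
    w, b = 0.0, 0.0
    reporter = Reporter()
    train_data = [pair for batch in batches for pair in batch]
    for _ in range(epochs):
        for batch in batches:
            grad_w = sum((predict(w, b, x) - y) * x for x, y in batch) / len(batch)
            grad_b = sum(predict(w, b, x) - y for x, y in batch) / len(batch)
            w, b = w - lr * grad_w, b - lr * grad_b
        # Record performance once per epoch, after all batches have been seen.
        reporter.train_losses.append(bce(w, b, train_data))
        reporter.val_losses.append(bce(w, b, val_data))
    return (w, b), reporter

batches = [[(0.0, 0), (1.0, 1)], [(0.2, 0), (0.9, 1)]]
val_data = [(0.1, 0), (0.8, 1)]
params, reporter = fit(batches, val_data)
```

Comparing `reporter.train_losses` with `reporter.val_losses` across epochs is the usual way to judge whether the chosen number of epochs is sufficient for convergence.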


In [8]:
# Instantiate each model with the appropriate dimensions.
len_model = SimpleClassifier(latent_dim=1)
aac_model = Classifier(latent_dim=21, hidden_dim=8)
sel_model = Classifier(latent_dim=1024, hidden_dim=512)

In [None]:
# Store variables for the learning rate and number of epochs.
lr = 0.001
epochs = 10

In [None]:
# Train the length-based classifier.
len_reporter = len_model.fit(len_dataloader, val_dataset=len_val_dataset, epochs=epochs, lr=lr)
# Save the Reporter object using pickle.
with open(f'{DATA_DIR}len_reporter.pkl', 'wb') as f:
    pickle.dump(len_reporter, f)
# Save the weights of the trained model.
torch.save(len_model.state_dict(), f'{DATA_DIR}len_model_weights.pth')

In [None]:
# Train the AAC-based classifier.
aac_reporter = aac_model.fit(aac_dataloader, val_dataset=aac_val_dataset, epochs=epochs, lr=lr)
# Save the Reporter object using pickle.
with open(f'{DATA_DIR}aac_reporter.pkl', 'wb') as f:
    pickle.dump(aac_reporter, f)
# Save the weights of the trained model.
torch.save(aac_model.state_dict(), f'{DATA_DIR}aac_model_weights.pth')

In [None]:
# Train the Selenobot (PLM embedding-based) classifier.
sel_reporter = sel_model.fit(sel_dataloader, val_dataset=sel_val_dataset, epochs=epochs, lr=lr)
# Save the Reporter object using pickle.
with open(f'{DATA_DIR}sel_reporter.pkl', 'wb') as f:
    pickle.dump(sel_reporter, f)
# Save the weights of the trained model.
torch.save(sel_model.state_dict(), f'{DATA_DIR}sel_model_weights.pth')

## Evaluating performance

In [None]:

import os
from configparser import ConfigParser
from time import perf_counter
import requests

# Base URL of the Google Cloud bucket, matching the URLs used in the download cell above.
BUCKET = 'https://storage.googleapis.com/selenobot-data/'
# NOTE: write_file is not defined in this notebook; it is assumed to be provided by the setup code.

def download(filename:str, directory:str, config:ConfigParser, stream:bool=False) -> ConfigParser:
    '''Download a file from the Google Cloud bucket, writing the information to the specified directory.
    Also adds the resulting filepath to the config file.
    
    args:
        - filename: The name of the file in the Google Cloud bucket.
        - directory: The directory location where the downloaded file will be stored. 
        - config: The configuration file object.
        - stream: Whether or not to stream the downloaded file to avoid excessive RAM usage. 
    returns:
        - The configuration object, updated with the path to the downloaded file.
    '''
    path = os.path.join(directory, filename)
    if not os.path.exists(path): # Skip the download step if file is already present. 
        t1 = perf_counter()
        response = requests.get(BUCKET + filename, stream=stream)
        path = write_file(path, response, stream=stream)
        t2 = perf_counter()
        print(f'setup.main.download: Downloaded {filename} to {directory} in {np.round(t2 - t1, 2)} seconds.')
        
    key = filename.split('.')[0] + '_path' # Create a key for the path in the config file by removing the file extension.
    config['paths'][key] = path # Record the path returned by write_file, or the existing path if the download was skipped.
    return config
