In [2]:
import torch
import lightning.pytorch as ptl
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

import boda

# Set up

## Pick modules
Pick modules to define:
1. The data, how it's preprocessed and train/val/test split
2. The model, the architecture setup, loss function, etc.
3. The graph, how the data is used to train the model (i.e. training loop)

In [3]:
data_module = boda.data.SeqDataModule
model_module= boda.model.BassetBranched
graph_module= boda.graph.CNNBasicTraining

## Dummy dataset generation for testing purposes

In [4]:
import random
import csv

random.seed(42)

# Function to generate random DNA sequence
def generate_dna_sequence(length):
    return ''.join(random.choice('ACGT') for _ in range(length))

# Function to generate fake numerical score
def generate_numerical_score():
    return random.uniform(-10, 10)

# Number of sequences in the dataset
num_sequences = 200

# Length of DNA sequences
sequence_length = 200

header = ["Sequence", "Random/Fake Score"]  # Define the header

## TRAIN
# Generating dummy dataset
dummy_train = []
for _ in range(num_sequences):
    sequence = generate_dna_sequence(sequence_length)
    score = generate_numerical_score()
    dummy_train.append((sequence, score))

traintsv_file = "dummy_train.tsv"
with open(traintsv_file, 'w', newline='') as file:
    writer = csv.writer(file, delimiter='\t')
    writer.writerow(header)  # Write the header row
    for sequence, score in dummy_train:
        writer.writerow([sequence, score])

## TEST
# Generating dummy dataset
dummy_test = []
for _ in range(num_sequences):
    sequence = generate_dna_sequence(sequence_length)
    score = generate_numerical_score()
    dummy_test.append((sequence, score))

testtsv_file = "dummy_test.tsv"
with open(testtsv_file, 'w', newline='') as file:
    writer = csv.writer(file, delimiter='\t')
    writer.writerow(header)  # Write the header row
    for sequence, score in dummy_test:
        writer.writerow([sequence, score])

## VALIDATE
# Generating dummy dataset
dummy_val = []
for _ in range(num_sequences):
    sequence = generate_dna_sequence(sequence_length)
    score = generate_numerical_score()
    dummy_val.append((sequence, score))

valtsv_file = "dummy_val.tsv"
with open(valtsv_file, 'w', newline='') as file:
    writer = csv.writer(file, delimiter='\t')
    writer.writerow(header)  # Write the header row
    for sequence, score in dummy_val:
        writer.writerow([sequence, score])

print(f"Dummy train with {num_sequences} sequences saved to '{traintsv_file}'.")
print(f"Dummy test with {num_sequences} sequences saved to '{testtsv_file}'.")
print(f"Dummy val with {num_sequences} sequences saved to '{valtsv_file}'.")

Dummy train with 200 sequences saved to 'dummy_train.tsv'.
Dummy test with 200 sequences saved to 'dummy_test.tsv'.
Dummy val with 200 sequences saved to 'dummy_val.tsv'.


## Initalize Data and Model
I added chr1 to test and chr2 to val to speed up this example. I also removed the reverse complmentat data augmentation.

In [5]:
data = data_module(
    train_file = "/home/ubuntu/boda2/analysis/AR001__rotation/dummy_train.tsv",
    test_file = "/home/ubuntu/boda2/analysis/AR001__rotation/dummy_test.tsv",
    val_file = "/home/ubuntu/boda2/analysis/AR001__rotation/dummy_val.tsv",
    right_flank = boda.common.constants.MPRA_DOWNSTREAM[:200],
    left_flank = boda.common.constants.MPRA_UPSTREAM[-200:],
    use_revcomp = True
)

model = model_module(
    n_outputs=2, 
    n_linear_layers=1, linear_channels=1000,
    linear_activation='ReLU', linear_dropout_p=0.12, 
    n_branched_layers=3, branched_channels=140, 
    branched_activation='ReLU', branched_dropout_p=0.56, 
    loss_criterion='L1KLmixed', kl_scale=5.0
)

TypeError: __init__() got an unexpected keyword argument 'use_revcomp'

## Append Graph to Model
Augment the model class to append functions from the graph module. A downside to this structure is that you need to make sure all relevent Graph args are defined (even if None is an acceptable default). This is because the `__init__` block in the Graph class doesn't run.

In [None]:
graph_args = {
    'optimizer': 'Adam', 
    'optimizer_args': {
        'lr': 0.0033, 'betas':[0.9, 0.999], 
        'weight_decay': 3.43e-4, 'amsgrad': True
    },
    'scheduler': 'CosineAnnealingWarmRestarts', 
    'scheduler_monitor': None, 
    'scheduler_interval': 'step',
    'scheduler_args': {
        'T_0': 4096,
    }
}

model.__class__ = type(
    'BODA_module',
    (model_module,graph_module),
    graph_args
)

In [None]:
graph = graph_module(**graph_args)
graph.training_step

In [None]:
model.training_step

In [None]:
ptl.__version__

In [None]:
model(torch.randn(10,4,600))

In [None]:
boda.data.__file__

## Lightning trainer
Normally we train for more epochs, but reduced in this example

In [None]:
checkpoint_callback = ModelCheckpoint(
    save_top_k=1, 
    monitor='prediction_mean_spearman', 
    mode='max'
)

stopping_callback = EarlyStopping(
    monitor='prediction_mean_spearman', 
    patience=5,
    mode='max'
)

trainer = ptl.Trainer(
    accelerator='gpu', devices=1, 
    min_epochs=5, max_epochs=20, 
    precision=16, callbacks= [
        checkpoint_callback,
        stopping_callback
    ]
)

## Train model

In [None]:
trainer.fit(model, data)

In [None]:
import tempfile
import re
import sys
import os

def set_best(my_model, callbacks):
    """
    Set the best model checkpoint for the provided model.

    This function sets the state of the provided model to the state of the best checkpoint,
    as determined by the `ModelCheckpoint` callback.

    Args:
        my_model (nn.Module): The model to be updated.
        callbacks (dict): Dictionary of callbacks, including 'model_checkpoint'.

    Returns:
        nn.Module: The updated model.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        try:
            best_path = callbacks['model_checkpoint'].best_model_path
            get_epoch = re.search('epoch=(\d*)', best_path).group(1)
            if 'gs://' in best_path:
                subprocess.call(['gsutil','cp',best_path,tmpdirname])
                best_path = os.path.join( tmpdirname, os.path.basename(best_path) )
            print(f'Best model stashed at: {best_path}', file=sys.stderr)
            print(f'Exists: {os.path.isfile(best_path)}', file=sys.stderr)
            ckpt = torch.load( best_path )
            my_model.load_state_dict( ckpt['state_dict'] )
            print(f'Setting model from epoch: {get_epoch}', file=sys.stderr)
        except KeyError:
            print('Setting most recent model', file=sys.stderr)
    return my_model

model = set_best(model, {'model_checkpoint': checkpoint_callback})

## Test model

In [None]:
test_path = data.test_file

In [None]:
with open(fn_in,'r') as f:
    seq_tensor = torch.stack([ boda.common.utils.dna2tensor(line.split()[0]) for line in f ])