In [1]:
import torch
import lightning.pytorch as ptl
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

import boda

# Set up

## Pick modules
Pick modules to define:
1. The data, how it's preprocessed and train/val/test split
2. The model, the architecture setup, loss function, etc.
3. The graph, how the data is used to train the model (i.e. training loop)

In [2]:
data_module = boda.data.SeqDataModule
model_module= boda.model.BassetBranched
graph_module= boda.graph.CNNBasicTraining

## Dummy dataset generation for testing purposes

In [3]:
import random
import csv

random.seed(42)

# Function to generate random DNA sequence
def generate_dna_sequence(length):
    return ''.join(random.choice('ACGT') for _ in range(length))

# Function to generate fake numerical score
def generate_numerical_score():
    return random.uniform(-10, 10)

# Number of sequences in the dataset
num_sequences = 200

# Length of DNA sequences
sequence_length = 200

header = ["Sequence", "Random/Fake Score"]  # Define the header

## TRAIN
# Generating dummy dataset
dummy_train = []
for _ in range(num_sequences):
    sequence = generate_dna_sequence(sequence_length)
    score = generate_numerical_score()
    dummy_train.append((sequence, score))

traintsv_file = "dummy_train.tsv"
with open(traintsv_file, 'w', newline='') as file:
    writer = csv.writer(file, delimiter='\t')
    writer.writerow(header)  # Write the header row
    for sequence, score in dummy_train:
        writer.writerow([sequence, score])

## TEST
# Generating dummy dataset
dummy_test = []
for _ in range(num_sequences):
    sequence = generate_dna_sequence(sequence_length)
    score = generate_numerical_score()
    dummy_test.append((sequence, score))

testtsv_file = "dummy_test.tsv"
with open(testtsv_file, 'w', newline='') as file:
    writer = csv.writer(file, delimiter='\t')
    writer.writerow(header)  # Write the header row
    for sequence, score in dummy_test:
        writer.writerow([sequence, score])

## VALIDATE
# Generating dummy dataset
dummy_val = []
for _ in range(num_sequences):
    sequence = generate_dna_sequence(sequence_length)
    score = generate_numerical_score()
    dummy_val.append((sequence, score))

valtsv_file = "dummy_val.tsv"
with open(valtsv_file, 'w', newline='') as file:
    writer = csv.writer(file, delimiter='\t')
    writer.writerow(header)  # Write the header row
    for sequence, score in dummy_val:
        writer.writerow([sequence, score])

print(f"Dummy train with {num_sequences} sequences saved to '{traintsv_file}'.")
print(f"Dummy test with {num_sequences} sequences saved to '{testtsv_file}'.")
print(f"Dummy val with {num_sequences} sequences saved to '{valtsv_file}'.")

Dummy train with 200 sequences saved to 'dummy_train.tsv'.
Dummy test with 200 sequences saved to 'dummy_test.tsv'.
Dummy val with 200 sequences saved to 'dummy_val.tsv'.


## Initalize Data and Model
I added chr1 to test and chr2 to val to speed up this example. I also removed the reverse complmentat data augmentation.

In [4]:
data = data_module(
    train_file = "/home/ubuntu/boda2/analysis/AR001__rotation/dummy_train.tsv",
    test_file = "/home/ubuntu/boda2/analysis/AR001__rotation/dummy_test.tsv",
    val_file = "/home/ubuntu/boda2/analysis/AR001__rotation/dummy_val.tsv",
    right_flank = boda.common.constants.MPRA_DOWNSTREAM[:200],
    left_flank = boda.common.constants.MPRA_UPSTREAM[-200:],
    use_revcomp = True,
    skip_header=True
)

model = model_module(
    n_outputs=2, 
    n_linear_layers=1, linear_channels=1000,
    linear_activation='ReLU', linear_dropout_p=0.12, 
    n_branched_layers=3, branched_channels=140, 
    branched_activation='ReLU', branched_dropout_p=0.56, 
    loss_criterion='L1KLmixed', kl_scale=5.0
)

## Append Graph to Model
Augment the model class to append functions from the graph module. A downside to this structure is that you need to make sure all relevent Graph args are defined (even if None is an acceptable default). This is because the `__init__` block in the Graph class doesn't run.

In [5]:
graph_args = {
    'optimizer': 'Adam', 
    'optimizer_args': {
        'lr': 0.0033, 'betas':[0.9, 0.999], 
        'weight_decay': 3.43e-4, 'amsgrad': True
    },
    'scheduler': 'CosineAnnealingWarmRestarts', 
    'scheduler_monitor': None, 
    'scheduler_interval': 'step',
    'scheduler_args': {
        'T_0': 4096,
    }
}

model.__class__ = type(
    'BODA_module',
    (model_module,graph_module),
    graph_args
)

In [6]:
graph = graph_module(**graph_args)
graph.training_step

<bound method CNNBasicTraining.training_step of CNNBasicTraining()>

In [7]:
model.training_step

<bound method CNNBasicTraining.training_step of BODA_module(
  (pad1): ConstantPad1d(padding=(9, 9), value=0.0)
  (conv1): Conv1dNorm(
    (conv): Conv1d(4, 300, kernel_size=(19,), stride=(1,))
    (bn_layer): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pad2): ConstantPad1d(padding=(5, 5), value=0.0)
  (conv2): Conv1dNorm(
    (conv): Conv1d(300, 200, kernel_size=(11,), stride=(1,))
    (bn_layer): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pad3): ConstantPad1d(padding=(3, 3), value=0.0)
  (conv3): Conv1dNorm(
    (conv): Conv1d(200, 200, kernel_size=(7,), stride=(1,))
    (bn_layer): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pad4): ConstantPad1d(padding=(1, 1), value=0.0)
  (maxpool_3): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (maxpool_4): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (

In [8]:
ptl.__version__

'1.9.5'

In [9]:
model(torch.randn(10,4,600))

tensor([[-0.0712,  0.0120],
        [-0.0711,  0.0116],
        [-0.0715,  0.0123],
        [-0.0714,  0.0116],
        [-0.0713,  0.0117],
        [-0.0713,  0.0120],
        [-0.0711,  0.0118],
        [-0.0710,  0.0117],
        [-0.0713,  0.0117],
        [-0.0716,  0.0119]], grad_fn=<PermuteBackward0>)

In [10]:
boda.data.__file__

'/home/ubuntu/boda2/boda/data/__init__.py'

## Lightning trainer
Normally we train for more epochs, but reduced in this example

In [11]:
checkpoint_callback = ModelCheckpoint(
    save_top_k=1, 
    monitor='prediction_mean_spearman', 
    mode='max'
)

stopping_callback = EarlyStopping(
    monitor='prediction_mean_spearman', 
    patience=5,
    mode='max'
)

trainer = ptl.Trainer(
    accelerator='gpu', devices=1, 
    min_epochs=5, max_epochs=20, 
    precision=16, callbacks= [
        checkpoint_callback,
        stopping_callback
    ]
)

Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


## Train model

In [12]:
trainer.fit(model, data)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name      | Type           | Params
----------------------------------------------
0  | pad1      | ConstantPad1d  | 0     
1  | conv1     | Conv1dNorm     | 23.7 K
2  | pad2      | ConstantPad1d  | 0     
3  | conv2     | Conv1dNorm     | 660 K 
4  | pad3      | ConstantPad1d  | 0     
5  | conv3     | Conv1dNorm     | 280 K 
6  | pad4      | ConstantPad1d  | 0     
7  | maxpool_3 | MaxPool1d      | 0     
8  | maxpool_4 | MaxPool1d      | 0     
9  | linear1   | LinearNorm     | 2.6 M 
10 | branched  | BranchedLinear | 359 K 
11 | output    | GroupedLinear  | 282   
12 | nonlin    | ReLU           | 0     
13 | dropout   | Dropout        | 0     
14 | criterion | L1KLmixed      | 0     
----------------------------------------------
3.9 M     Trainable params
0         Non-trainable params
3.9 M     Total params
7.855     Total estimated model params size (MB)


Found 3927422 parameters


Sanity Checking: 0it [00:00, ?it/s]




---------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 0.00000 | arithmetic_mean_loss: 9.63547 | harmonic_mean_loss: 45.93113 | prediction_mean_spearman: 0.29549 | entropy_spearman: 0.55940 |
---------------------------------------------------------------------------------------------------------------------------------------------------------



  return F.l1_loss(input, target, reduction=self.reduction)
  "reduction: 'mean' divides the total loss by both the batch size and the support size."
  f"You called `self.log({self.meta.name!r}, ...)` in your `{self.meta.fx}` but the value needs to"


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]


----------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 0.00000 | arithmetic_mean_loss: 8.53853 | harmonic_mean_loss: 34.11116 | prediction_mean_spearman: -0.02668 | entropy_spearman: 0.01304 |
----------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


----------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 1.00000 | arithmetic_mean_loss: 8.51866 | harmonic_mean_loss: 33.75887 | prediction_mean_spearman: 0.00712 | entropy_spearman: -0.00428 |
----------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


---------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 2.00000 | arithmetic_mean_loss: 8.52140 | harmonic_mean_loss: 33.79566 | prediction_mean_spearman: 0.01159 | entropy_spearman: 0.02561 |
---------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


----------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 3.00000 | arithmetic_mean_loss: 8.52013 | harmonic_mean_loss: 33.81696 | prediction_mean_spearman: -0.03085 | entropy_spearman: 0.08156 |
----------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


-----------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 4.00000 | arithmetic_mean_loss: 8.52349 | harmonic_mean_loss: 33.87461 | prediction_mean_spearman: -0.00494 | entropy_spearman: -0.04054 |
-----------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


----------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 5.00000 | arithmetic_mean_loss: 8.52471 | harmonic_mean_loss: 33.89266 | prediction_mean_spearman: -0.00451 | entropy_spearman: 0.05743 |
----------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


----------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 6.00000 | arithmetic_mean_loss: 8.52635 | harmonic_mean_loss: 33.91988 | prediction_mean_spearman: -0.00711 | entropy_spearman: 0.01276 |
----------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


----------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 7.00000 | arithmetic_mean_loss: 8.52860 | harmonic_mean_loss: 33.95908 | prediction_mean_spearman: -0.01607 | entropy_spearman: 0.20422 |
----------------------------------------------------------------------------------------------------------------------------------------------------------



In [13]:
import tempfile
import re
import sys
import os

def set_best(my_model, callbacks):
    """
    Set the best model checkpoint for the provided model.

    This function sets the state of the provided model to the state of the best checkpoint,
    as determined by the `ModelCheckpoint` callback.

    Args:
        my_model (nn.Module): The model to be updated.
        callbacks (dict): Dictionary of callbacks, including 'model_checkpoint'.

    Returns:
        nn.Module: The updated model.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        try:
            best_path = callbacks['model_checkpoint'].best_model_path
            get_epoch = re.search('epoch=(\d*)', best_path).group(1)
            if 'gs://' in best_path:
                subprocess.call(['gsutil','cp',best_path,tmpdirname])
                best_path = os.path.join( tmpdirname, os.path.basename(best_path) )
            print(f'Best model stashed at: {best_path}', file=sys.stderr)
            print(f'Exists: {os.path.isfile(best_path)}', file=sys.stderr)
            ckpt = torch.load( best_path )
            my_model.load_state_dict( ckpt['state_dict'] )
            print(f'Setting model from epoch: {get_epoch}', file=sys.stderr)
        except KeyError:
            print('Setting most recent model', file=sys.stderr)
    return my_model

model = set_best(model, {'model_checkpoint': checkpoint_callback})

Best model stashed at: /home/ubuntu/boda2/analysis/AR001__rotation/lightning_logs/version_5/checkpoints/epoch=2-step=120.ckpt
Exists: True
Setting model from epoch: 2


## Test model

In [14]:
test_path = data.test_file

In [15]:
with open(test_path,'r') as f:
    f.readline()
    seq_tensor = torch.stack([ boda.common.utils.dna2tensor(line.split()[0]) for line in f ])

In [16]:
seq_tensor

tensor([[[0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 1., 1., 0.],
         [1., 0., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.]],

        [[0., 0., 1.,  ..., 0., 0., 1.],
         [1., 0., 0.,  ..., 0., 1., 0.],
         [0., 1., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 1., 0.,  ..., 0., 1., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.]],

        ...,

        [[1., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 1., 1.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 0., 0.,  ..., 0., 0., 0.],
         [0., 1., 0.,  ..., 0., 1., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 1., 0., 0.]],

        [[1., 0., 0.,  ..., 1., 1., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 1., 1.,  ..., 0., 0