In [1]:
import torch
import lightning.pytorch as ptl
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

import boda

# Set up

## Get data

In [2]:
!gsutil cp gs://tewhey-public-data/CODA_resources/Table_S2__MPRA_dataset.txt ./

Copying gs://tewhey-public-data/CODA_resources/Table_S2__MPRA_dataset.txt...
| [1 files][267.2 MiB/267.2 MiB]                                                
Operation completed over 1 objects/267.2 MiB.                                    


## Pick modules
Pick modules to define:
1. The data, how it's preprocessed and train/val/test split
2. The model, the architecture setup, loss function, etc.
3. The graph, how the data is used to train the model (i.e. training loop)

In [3]:
data_module = boda.data.MPRA_DataModule
model_module= boda.model.BassetBranched
graph_module= boda.graph.CNNBasicTraining

## Initalize Data
I added chr1 to test and chr2 to val to speed up this example. I also removed the reverse complementation data augmentation.

In [4]:
data = data_module(
    datafile_path='Table_S2__MPRA_dataset.txt', 
    sep='\t', sequence_column='sequence',
    synth_val_pct=0.0, synth_test_pct=99.98,
    val_chrs=['2','19','21','X'], test_chrs=['1','7','13'], 
    activity_columns=['HepG2_log2FC', 'SKNSH_log2FC'],
    batch_size=1024, padded_seq_len=600, 
    use_reverse_complements=False, 
    duplication_cutoff=2.0, 
    num_workers=8
)


## Initialize Model

In [5]:
model = model_module(
    n_outputs=2, 
    n_linear_layers=1, linear_channels=1000,
    linear_activation='ReLU', linear_dropout_p=0.12, 
    n_branched_layers=3, branched_channels=140, 
    branched_activation='ReLU', branched_dropout_p=0.56, 
    loss_criterion='L1KLmixed', loss_args={'beta':5.0}
)

## Append Graph to Model
Augment the model class to append functions from the graph module. A downside to this structure is that you need to make sure all relevent Graph args are defined (even if None is an acceptable default). This is because the `__init__` block in the Graph class doesn't run.

In [6]:
graph_args = {
    'optimizer': 'Adam', 
    'optimizer_args': {
        'lr': 0.0033, 'betas':[0.9, 0.999], 
        'weight_decay': 3.43e-4, 'amsgrad': True
    },
    'scheduler': 'CosineAnnealingWarmRestarts', 
    'scheduler_monitor': None, 
    'scheduler_interval': 'step',
    'scheduler_args': {
        'T_0': 4096,
    }
}

graph = graph_module(
    model = model,
    **graph_args
)

In [7]:
model(torch.randn(10,4,600))

tensor([[-0.0587, -0.0574],
        [-0.0590, -0.0574],
        [-0.0588, -0.0574],
        [-0.0593, -0.0574],
        [-0.0588, -0.0573],
        [-0.0589, -0.0571],
        [-0.0589, -0.0573],
        [-0.0588, -0.0573],
        [-0.0589, -0.0575],
        [-0.0588, -0.0572]], grad_fn=<PermuteBackward0>)

## Lightning trainer
Normally we train for more epochs, but reduced in this example. Update `min_epochs` and `max_epochs` accordingly.

In [8]:
checkpoint_callback = ModelCheckpoint(
    save_top_k=1, 
    monitor='prediction_mean_spearman', 
    mode='max'
)

stopping_callback = EarlyStopping(
    monitor='prediction_mean_spearman', 
    patience=5,
    mode='max'
)

trainer = ptl.Trainer(
    accelerator='gpu', devices=1, 
    min_epochs=2, max_epochs=5, # <- we use min_epochs=60, max_epochs=200
    precision=16, callbacks= [
        checkpoint_callback,
        stopping_callback
    ]
)

Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


## Train model

In [9]:
trainer.fit(graph, data)

--------------------------------------------------

HepG2_log2FC | top cut value: 10.87, bottom cut value: -5.92
SKNSH_log2FC | top cut value: 11.55, bottom cut value: -6.7

Number of examples discarded from top: 2
Number of examples discarded from bottom: 7

Number of examples available: 783969

--------------------------------------------------

Padding sequences... 

Creating train/val/test datasets with tokenized sequences... 

--------------------------------------------------

Number of examples in train: 576531 (73.54%)
Number of examples in val:   121889 (15.55%)
Number of examples in test:  138383 (17.65%)

Excluded from train: -52834 (-6.74)%
--------------------------------------------------


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type           | Params
---------------------------------------------
0 | model     | BassetBranched | 3.9 M 
1 | criterion | L1KLmixed      | 0     
---------------------------------------------
3.9 M     Trainable params
0         Non-trainable params
3.9 M     Total params
7.855     Total estimated model params size (MB)


Found 3927422 parameters


Sanity Checking: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 0.00000 | arithmetic_mean_loss: 0.14104 | harmonic_mean_loss: 1.39642 | prediction_mean_spearman: 0.01361 | entropy_spearman: 0.01739 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



  "reduction: 'mean' divides the total loss by both the batch size and the support size."
  f"You called `self.log({self.meta.name!r}, ...)` in your `{self.meta.fx}` but the value needs to"


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 0.00000 | arithmetic_mean_loss: 0.23725 | harmonic_mean_loss: 2.55652 | prediction_mean_spearman: 0.39923 | entropy_spearman: 0.15776 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 1.00000 | arithmetic_mean_loss: 0.12850 | harmonic_mean_loss: 1.01111 | prediction_mean_spearman: 0.60748 | entropy_spearman: 0.07900 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 2.00000 | arithmetic_mean_loss: 0.11812 | harmonic_mean_loss: 0.78516 | prediction_mean_spearman: 0.65084 | entropy_spearman: 0.17695 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 3.00000 | arithmetic_mean_loss: 0.10964 | harmonic_mean_loss: 0.66550 | prediction_mean_spearman: 0.67931 | entropy_spearman: 0.23140 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 4.00000 | arithmetic_mean_loss: 0.10127 | harmonic_mean_loss: 0.62077 | prediction_mean_spearman: 0.69835 | entropy_spearman: 0.17335 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



`Trainer.fit` stopped: `max_epochs=5` reached.


## Reload best epoch and save

In [10]:
import tempfile
import re
import sys
import os

def set_best(my_model, callbacks):
    """
    Set the best model checkpoint for the provided model.

    This function sets the state of the provided model to the state of the best checkpoint,
    as determined by the `ModelCheckpoint` callback.

    Args:
        my_model (nn.Module): The model to be updated.
        callbacks (dict): Dictionary of callbacks, including 'model_checkpoint'.

    Returns:
        nn.Module: The updated model.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        try:
            best_path = callbacks['model_checkpoint'].best_model_path
            get_epoch = re.search('epoch=(\d*)', best_path).group(1)
            if 'gs://' in best_path:
                subprocess.call(['gsutil','cp',best_path,tmpdirname])
                best_path = os.path.join( tmpdirname, os.path.basename(best_path) )
            print(f'Best model stashed at: {best_path}', file=sys.stderr)
            print(f'Exists: {os.path.isfile(best_path)}', file=sys.stderr)
            ckpt = torch.load( best_path )
            my_model.load_state_dict( ckpt['state_dict'] )
            print(f'Setting model from epoch: {get_epoch}', file=sys.stderr)
        except KeyError:
            print('Setting most recent model', file=sys.stderr)
    return my_model

graph = set_best(graph, {'model_checkpoint': checkpoint_callback})

Best model stashed at: /home/ubuntu/boda2/tutorials/lightning_logs/version_13/checkpoints/epoch=4-step=2820.ckpt
Exists: True
Setting model from epoch: 4


In [11]:
torch.save(graph.model.state_dict(), 'example_new_model.pt')

## load the save

In [12]:
new_model = model_module(
    n_outputs=2, 
    n_linear_layers=1, linear_channels=1000,
    linear_activation='ReLU', linear_dropout_p=0.12, 
    n_branched_layers=3, branched_channels=140, 
    branched_activation='ReLU', branched_dropout_p=0.56, 
    loss_criterion='L1KLmixed', loss_args={'beta':5.0}
)

new_model.load_state_dict(torch.load('example_new_model.pt'))
new_model.eval()
new_model.cuda()

BassetBranched(
  (pad1): ConstantPad1d(padding=(9, 9), value=0.0)
  (conv1): Conv1dNorm(
    (conv): Conv1d(4, 300, kernel_size=(19,), stride=(1,))
    (bn_layer): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pad2): ConstantPad1d(padding=(5, 5), value=0.0)
  (conv2): Conv1dNorm(
    (conv): Conv1d(300, 200, kernel_size=(11,), stride=(1,))
    (bn_layer): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pad3): ConstantPad1d(padding=(3, 3), value=0.0)
  (conv3): Conv1dNorm(
    (conv): Conv1d(200, 200, kernel_size=(7,), stride=(1,))
    (bn_layer): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pad4): ConstantPad1d(padding=(1, 1), value=0.0)
  (maxpool_3): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (maxpool_4): MaxPool1d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)
  (linear1): LinearNorm(
    (linear): Linear(in

In [13]:
new_model( torch.randn(5,4,600).cuda() )

tensor([[30.3106, 58.3631],
        [32.2132, 56.4480],
        [30.0093, 56.7785],
        [26.7668, 45.2139],
        [25.1147, 43.6605]], device='cuda:0', grad_fn=<PermuteBackward0>)