In [1]:
import torch
import lightning.pytorch as ptl
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

import boda

# Set up

## Pick modules
Pick modules to define:
1. The data, how it's preprocessed and train/val/test split
2. The model, the architecture setup, loss function, etc.
3. The graph, how the data is used to train the model (i.e. training loop)

In [2]:
data_module = boda.data.SeqDataModule
model_module= boda.model.BassetBranched
graph_module= boda.graph.CNNBasicTraining

AttributeError: module 'boda.data' has no attribute 'SeqDataModule'

## Initalize Data and Model
I added chr1 to test and chr2 to val to speed up this example. I also removed the reverse complmentat data augmentation.

In [3]:
data = data_module(
    datafile_path='MPRA_ALL_HD_v2.txt', 
    synth_val_pct=0.0, synth_test_pct=99.98,
    val_chrs=['2','19','21','X'], test_chrs=['1','7','13'], 
    activity_columns=['HepG2_mean', 'SKNSH_mean'],
    batch_size=1024, padded_seq_len=600, 
    use_reverse_complements=False, 
    duplication_cutoff=2.0, 
    num_workers=8
)

model = model_module(
    n_outputs=2, 
    n_linear_layers=1, linear_channels=1000,
    linear_activation='ReLU', linear_dropout_p=0.12, 
    n_branched_layers=3, branched_channels=140, 
    branched_activation='ReLU', branched_dropout_p=0.56, 
    loss_criterion='L1KLmixed', kl_scale=5.0
)

## Append Graph to Model
Augment the model class to append functions from the graph module. A downside to this structure is that you need to make sure all relevent Graph args are defined (even if None is an acceptable default). This is because the `__init__` block in the Graph class doesn't run.

In [4]:
graph_args = {
    'optimizer': 'Adam', 
    'optimizer_args': {
        'lr': 0.0033, 'betas':[0.9, 0.999], 
        'weight_decay': 3.43e-4, 'amsgrad': True
    },
    'scheduler': 'CosineAnnealingWarmRestarts', 
    'scheduler_monitor': None, 
    'scheduler_interval': 'step',
    'scheduler_args': {
        'T_0': 4096,
    }
}

model.__class__ = type(
    'BODA_module',
    (model_module,graph_module),
    graph_args
)

In [5]:
model(torch.randn(10,4,600))

tensor([[0.0464, 0.0279],
        [0.0462, 0.0282],
        [0.0464, 0.0282],
        [0.0460, 0.0285],
        [0.0468, 0.0284],
        [0.0457, 0.0284],
        [0.0461, 0.0283],
        [0.0462, 0.0284],
        [0.0463, 0.0283],
        [0.0461, 0.0279]], grad_fn=<PermuteBackward0>)

## Lightning trainer
Normally we train for more epochs, but reduced in this example

In [6]:
checkpoint_callback = ModelCheckpoint(
    save_top_k=1, 
    monitor='prediction_mean_spearman', 
    mode='max'
)

stopping_callback = EarlyStopping(
    monitor='prediction_mean_spearman', 
    patience=5,
    mode='max'
)

trainer = ptl.Trainer(
    accelerator='gpu', devices=1, 
    min_epochs=5, max_epochs=20, 
    precision=16, callbacks= [
        checkpoint_callback,
        stopping_callback
    ]
)

Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


## Train model

In [7]:
!gsutil cp gs://tewhey-public-data/CODA_resources/MPRA_ALL_HD_v2.txt ./

Copying gs://tewhey-public-data/CODA_resources/MPRA_ALL_HD_v2.txt...
| [1 files][311.6 MiB/311.6 MiB]                                                
Operation completed over 1 objects/311.6 MiB.                                    


In [8]:
trainer.fit(model, data)

--------------------------------------------------

HepG2 | top cut value: 10.76, bottom cut value: -5.73
SKNSH | top cut value: 11.34, bottom cut value: -6.38

Number of examples discarded from top: 0
Number of examples discarded from bottom: 0

Number of examples available: 693349

--------------------------------------------------

Padding sequences... 

Creating train/val/test datasets with tokenized sequences... 

--------------------------------------------------

Number of examples in train: 500043 (72.12%)
Number of examples in val:   105586 (15.23%)
Number of examples in test:  134597 (19.41%)

Excluded from train: -46877 (-6.76)%
--------------------------------------------------


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name      | Type           | Params
----------------------------------------------
0  | pad1      | ConstantPad1d  | 0     
1  | conv1     | Conv1dNorm     | 23.7 K
2  | pad2      | ConstantPad1d  | 0     
3  | conv2     | Conv1dNorm     | 660 K 
4  | pad3      | ConstantPad1d  | 0     
5  | conv3     | Conv1dNorm     | 280 K 
6  | pad4      | ConstantPad1d  | 0     
7  | maxpool_3 | MaxPool1d      | 0     
8  | maxpool_4 | MaxPool1d      | 0     
9  | linear1   | LinearNorm     | 2.6 M 
10 | branched  | BranchedLinear | 359 K 
11 | output    | GroupedLinear  | 282   
12 | nonlin    | ReLU           | 0     
13 | dropout   | Dropout        | 0     
14 | criterion | L1KLmixed      | 0     
----------------------------------------------
3.9 M     Trainable params
0         Non-trainable params
3.9 M     Total params
7.855     Total estimated model params size (MB)


Found 3927422 parameters


Sanity Checking: 0it [00:00, ?it/s]


---------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 0.00000 | arithmetic_mean_loss: 0.78512 | harmonic_mean_loss: 1.27266 | prediction_mean_spearman: 0.02915 | entropy_spearman: -0.02169 |
---------------------------------------------------------------------------------------------------------------------------------------------------------



  "reduction: 'mean' divides the total loss by both the batch size and the support size."
  f"You called `self.log({self.meta.name!r}, ...)` in your `{self.meta.fx}` but the value needs to"


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 0.00000 | arithmetic_mean_loss: 0.90933 | harmonic_mean_loss: 1.05306 | prediction_mean_spearman: 0.58503 | entropy_spearman: 0.08964 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 1.00000 | arithmetic_mean_loss: 0.64500 | harmonic_mean_loss: 0.81748 | prediction_mean_spearman: 0.67216 | entropy_spearman: 0.17194 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 2.00000 | arithmetic_mean_loss: 0.65669 | harmonic_mean_loss: 0.68760 | prediction_mean_spearman: 0.69919 | entropy_spearman: 0.16933 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 3.00000 | arithmetic_mean_loss: 0.62214 | harmonic_mean_loss: 0.76318 | prediction_mean_spearman: 0.71376 | entropy_spearman: 0.24398 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 4.00000 | arithmetic_mean_loss: 0.60689 | harmonic_mean_loss: 0.58931 | prediction_mean_spearman: 0.71935 | entropy_spearman: 0.25877 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 5.00000 | arithmetic_mean_loss: 0.57465 | harmonic_mean_loss: 0.60783 | prediction_mean_spearman: 0.74736 | entropy_spearman: 0.25879 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 6.00000 | arithmetic_mean_loss: 0.56855 | harmonic_mean_loss: 0.57389 | prediction_mean_spearman: 0.75293 | entropy_spearman: 0.27140 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 7.00000 | arithmetic_mean_loss: 0.56753 | harmonic_mean_loss: 0.56449 | prediction_mean_spearman: 0.75242 | entropy_spearman: 0.27225 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 8.00000 | arithmetic_mean_loss: 0.60934 | harmonic_mean_loss: 0.62266 | prediction_mean_spearman: 0.73117 | entropy_spearman: 0.18729 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 9.00000 | arithmetic_mean_loss: 0.59494 | harmonic_mean_loss: 0.60001 | prediction_mean_spearman: 0.73648 | entropy_spearman: 0.19145 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


---------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 10.00000 | arithmetic_mean_loss: 0.56524 | harmonic_mean_loss: 0.57423 | prediction_mean_spearman: 0.74612 | entropy_spearman: 0.22832 |
---------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]


---------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 11.00000 | arithmetic_mean_loss: 0.59007 | harmonic_mean_loss: 0.56649 | prediction_mean_spearman: 0.74463 | entropy_spearman: 0.25843 |
---------------------------------------------------------------------------------------------------------------------------------------------------------



In [9]:
import tempfile
import re
import sys
import os

def set_best(my_model, callbacks):
    """
    Set the best model checkpoint for the provided model.

    This function sets the state of the provided model to the state of the best checkpoint,
    as determined by the `ModelCheckpoint` callback.

    Args:
        my_model (nn.Module): The model to be updated.
        callbacks (dict): Dictionary of callbacks, including 'model_checkpoint'.

    Returns:
        nn.Module: The updated model.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        try:
            best_path = callbacks['model_checkpoint'].best_model_path
            get_epoch = re.search('epoch=(\d*)', best_path).group(1)
            if 'gs://' in best_path:
                subprocess.call(['gsutil','cp',best_path,tmpdirname])
                best_path = os.path.join( tmpdirname, os.path.basename(best_path) )
            print(f'Best model stashed at: {best_path}', file=sys.stderr)
            print(f'Exists: {os.path.isfile(best_path)}', file=sys.stderr)
            ckpt = torch.load( best_path )
            my_model.load_state_dict( ckpt['state_dict'] )
            print(f'Setting model from epoch: {get_epoch}', file=sys.stderr)
        except KeyError:
            print('Setting most recent model', file=sys.stderr)
    return my_model

model = set_best(model, {'model_checkpoint': checkpoint_callback})

Best model stashed at: /home/ubuntu/boda2/tutorials/lightning_logs/version_4/checkpoints/epoch=6-step=3423.ckpt
Exists: True
Setting model from epoch: 6
