In [1]:
import torch
import lightning.pytorch as ptl
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

import boda

# Set up

## Pick modules
Pick modules to define:
1. The data, how it's preprocessed and train/val/test split
2. The model, the architecture setup, loss function, etc.
3. The graph, how the data is used to train the model (i.e. training loop)

In [2]:
# Select the three building blocks: data pipeline, network architecture, training graph.
data_module = boda.data.SeqDataModule
model_module = boda.model.BassetBranched
graph_module = boda.graph.CNNBasicTraining

## Initialize Data and Model
I added chr1 to test and chr2 to val to speed up this example. I also removed the reverse-complement data augmentation.

In [6]:
# Directory holding the dummy train/test/val TSV files.
# The original paths started with "home/ubuntu/..." (no leading slash), so they
# were resolved relative to the kernel's working directory; anchor them at the
# filesystem root instead. Adjust this constant for your own machine.
data_dir = "/home/ubuntu/boda2/analysis/AR001__rotation"

data = data_module(
    train_file = f"{data_dir}/dummy_train.tsv",
    test_file = f"{data_dir}/dummy_test.tsv",
    val_file = f"{data_dir}/dummy_val.tsv",
    # Flanks are the 200 characters of each MPRA constant adjacent to the insert:
    # the start of the downstream constant and the end of the upstream constant.
    right_flank = boda.common.constants.MPRA_DOWNSTREAM[:200],
    left_flank = boda.common.constants.MPRA_UPSTREAM[-200:]
)

# Architecture and loss settings, collected in one mapping and unpacked into the
# model constructor.
model_kwargs = dict(
    n_outputs=2,
    # shared linear trunk
    n_linear_layers=1,
    linear_channels=1000,
    linear_activation='ReLU',
    linear_dropout_p=0.12,
    # branched layers
    n_branched_layers=3,
    branched_channels=140,
    branched_activation='ReLU',
    branched_dropout_p=0.56,
    # loss
    loss_criterion='L1KLmixed',
    kl_scale=5.0,
)

model = model_module(**model_kwargs)

## Append Graph to Model
Augment the model class to append functions from the graph module. A downside to this structure is that you need to make sure all relevant Graph args are defined (even if None is an acceptable default). This is because the `__init__` block in the Graph class doesn't run.

In [7]:
# Optimizer and scheduler settings are built separately, then combined into the
# attribute mapping that becomes the namespace of the dynamically created class.
optimizer_args = {
    'lr': 0.0033,
    'betas': [0.9, 0.999],
    'weight_decay': 3.43e-4,
    'amsgrad': True,
}
scheduler_args = {'T_0': 4096}

graph_args = dict(
    optimizer='Adam',
    optimizer_args=optimizer_args,
    scheduler='CosineAnnealingWarmRestarts',
    scheduler_monitor=None,
    scheduler_interval='step',
    scheduler_args=scheduler_args,
)

# Re-class the existing model instance: the new type inherits from both the
# model and the graph classes and carries `graph_args` as class attributes.
model.__class__ = type('BODA_module', (model_module, graph_module), graph_args)

In [8]:
# Smoke test: push one random batch of shape (10, 4, 600) through the model.
dummy_batch = torch.randn(10, 4, 600)
model(dummy_batch)

tensor([[ 0.0077, -0.0493],
        [ 0.0076, -0.0493],
        [ 0.0079, -0.0494],
        [ 0.0078, -0.0494],
        [ 0.0074, -0.0494],
        [ 0.0076, -0.0495],
        [ 0.0077, -0.0495],
        [ 0.0075, -0.0490],
        [ 0.0075, -0.0491],
        [ 0.0077, -0.0493]], grad_fn=<PermuteBackward0>)

## Lightning trainer
Normally we train for more epochs, but the number is reduced in this example.

In [9]:
# Both callbacks watch the same validation metric, where higher is better.
monitored_metric = 'prediction_mean_spearman'

# Keep only the single best checkpoint according to the monitored metric.
checkpoint_callback = ModelCheckpoint(
    monitor=monitored_metric,
    mode='max',
    save_top_k=1,
)

# Stop once the monitored metric has failed to improve for 5 checks.
stopping_callback = EarlyStopping(
    monitor=monitored_metric,
    mode='max',
    patience=5,
)

trainer_callbacks = [checkpoint_callback, stopping_callback]

trainer = ptl.Trainer(
    accelerator='gpu',
    devices=1,
    precision=16,
    min_epochs=5,
    max_epochs=20,
    callbacks=trainer_callbacks,
)

Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


## Train model

In [10]:
# Run training with the trainer configured above; `checkpoint_callback` and
# `stopping_callback` were registered on it when it was constructed.
trainer.fit(model, data)

Missing logger folder: /home/ubuntu/boda2/analysis/AR001__rotation/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name      | Type           | Params
----------------------------------------------
0  | pad1      | ConstantPad1d  | 0     
1  | conv1     | Conv1dNorm     | 23.7 K
2  | pad2      | ConstantPad1d  | 0     
3  | conv2     | Conv1dNorm     | 660 K 
4  | pad3      | ConstantPad1d  | 0     
5  | conv3     | Conv1dNorm     | 280 K 
6  | pad4      | ConstantPad1d  | 0     
7  | maxpool_3 | MaxPool1d      | 0     
8  | maxpool_4 | MaxPool1d      | 0     
9  | linear1   | LinearNorm     | 2.6 M 
10 | branched  | BranchedLinear | 359 K 
11 | output    | GroupedLinear  | 282   
12 | nonlin    | ReLU           | 0     
13 | dropout   | Dropout        | 0     
14 | criterion | L1KLmixed      | 0     
----------------------------------------------
3.9 M     Trainable params
0         Non-trainable params
3.9 M     Total params
7.855     Total estimated model params size 

Found 3927422 parameters


Sanity Checking: 0it [00:00, ?it/s]


-------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 0.00000 | arithmetic_mean_loss: nan | harmonic_mean_loss: 3210.60815 | prediction_mean_spearman: 0.04286 | entropy_spearman: 0.34737 |
-------------------------------------------------------------------------------------------------------------------------------------------------------



  return F.l1_loss(input, target, reduction=self.reduction)
  "reduction: 'mean' divides the total loss by both the batch size and the support size."
  f"You called `self.log({self.meta.name!r}, ...)` in your `{self.meta.fx}` but the value needs to"


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

NaN or Inf found in input tensor.
NaN or Inf found in input tensor.



---------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 0.00000 | arithmetic_mean_loss: nan | harmonic_mean_loss: 18327.17773 | prediction_mean_spearman: 0.12816 | entropy_spearman: -0.06330 |
---------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]

NaN or Inf found in input tensor.
NaN or Inf found in input tensor.



------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 1.00000 | arithmetic_mean_loss: nan | harmonic_mean_loss: 877.63416 | prediction_mean_spearman: 0.04358 | entropy_spearman: 0.04048 |
------------------------------------------------------------------------------------------------------------------------------------------------------



NaN or Inf found in input tensor.


Validation: 0it [00:00, ?it/s]

NaN or Inf found in input tensor.
NaN or Inf found in input tensor.



-------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 2.00000 | arithmetic_mean_loss: nan | harmonic_mean_loss: 1131.41187 | prediction_mean_spearman: 0.06560 | entropy_spearman: 0.13588 |
-------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]

NaN or Inf found in input tensor.
NaN or Inf found in input tensor.



-------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 3.00000 | arithmetic_mean_loss: nan | harmonic_mean_loss: 3053.36255 | prediction_mean_spearman: 0.05055 | entropy_spearman: 0.06719 |
-------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]

NaN or Inf found in input tensor.
NaN or Inf found in input tensor.



--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 4.00000 | arithmetic_mean_loss: nan | harmonic_mean_loss: 1082.36609 | prediction_mean_spearman: -0.00865 | entropy_spearman: 0.07530 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



Validation: 0it [00:00, ?it/s]

NaN or Inf found in input tensor.
NaN or Inf found in input tensor.



--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 5.00000 | arithmetic_mean_loss: nan | harmonic_mean_loss: 1737.79443 | prediction_mean_spearman: -0.01997 | entropy_spearman: 0.02380 |
--------------------------------------------------------------------------------------------------------------------------------------------------------



In [9]:
import tempfile
import re
import sys
import os

def set_best(my_model, callbacks):
    """
    Load the weights of the best checkpoint into the provided model.

    The checkpoint path is read from
    ``callbacks['model_checkpoint'].best_model_path``. A path containing
    ``gs://`` is first copied into a temporary directory with ``gsutil cp``.
    When there is no ``'model_checkpoint'`` entry, the recorded path is empty,
    or the checkpoint has no ``'state_dict'`` entry, the model is returned
    unchanged and 'Setting most recent model' is printed to stderr.

    Args:
        my_model (nn.Module): The model to be updated (modified in place).
        callbacks (dict): Dictionary of callbacks, including 'model_checkpoint'.

    Returns:
        nn.Module: The same model object that was passed in.

    Raises:
        subprocess.CalledProcessError: If the ``gsutil cp`` download fails.
    """
    # Local import: the surrounding notebook never imports `subprocess`, and it
    # is only needed for gs:// checkpoints.
    import subprocess

    try:
        best_path = callbacks['model_checkpoint'].best_model_path
    except KeyError:
        best_path = None
    if not best_path:
        # No callback registered, or no checkpoint has been recorded yet.
        print('Setting most recent model', file=sys.stderr)
        return my_model

    # The epoch is only used for the log line below, so a file name without an
    # `epoch=<n>` part must not abort the restore.
    epoch_match = re.search(r'epoch=(\d+)', best_path)
    get_epoch = epoch_match.group(1) if epoch_match else 'unknown'

    with tempfile.TemporaryDirectory() as tmpdirname:
        if 'gs://' in best_path:
            # check_call raises on a failed copy instead of letting torch.load
            # fail later on a file that was never downloaded.
            subprocess.check_call(['gsutil', 'cp', best_path, tmpdirname])
            best_path = os.path.join(tmpdirname, os.path.basename(best_path))
        print(f'Best model stashed at: {best_path}', file=sys.stderr)
        print(f'Exists: {os.path.isfile(best_path)}', file=sys.stderr)
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        # map_location='cpu' lets a GPU-saved checkpoint load on any host;
        # load_state_dict copies the values onto the model's own device.
        ckpt = torch.load(best_path, map_location='cpu')

    try:
        state_dict = ckpt['state_dict']
    except KeyError:
        # Same fallback as before: keep the current weights.
        print('Setting most recent model', file=sys.stderr)
        return my_model

    my_model.load_state_dict(state_dict)
    print(f'Setting model from epoch: {get_epoch}', file=sys.stderr)
    return my_model

# Swap the in-memory weights for those of the best checkpoint recorded during training.
best_callbacks = {'model_checkpoint': checkpoint_callback}
model = set_best(model, best_callbacks)

Best model stashed at: /home/ubuntu/boda2/tutorials/lightning_logs/version_4/checkpoints/epoch=6-step=3423.ckpt
Exists: True
Setting model from epoch: 6
