In [1]:
import os
import sys
import time
import yaml
import shutil
import argparse
import tarfile
import random
import tempfile
import subprocess

import torch
import pytorch_lightning as ptl
from pytorch_lightning import Trainer

import hypertune

import boda


In [2]:
# Record the framework versions this notebook was executed with.
for framework in (torch, ptl):
    print(framework.__version__)

1.9.0+cu102
1.3.0


In [3]:
def main(args):
    """Build the data/model objects selected in ``args`` and fit the model.

    This is the notebook variant of the runner: it does not call
    ``_save_model``; instead it returns everything ``_save_model`` needs so
    the artifacts can be written later from another cell.

    Args:
        args: mapping of argument-group name to namespace. ``'Main args'``
            supplies the ``data_module``/``model_module``/``graph_module``
            names, ``'pl.Trainer'`` is handed to
            ``Trainer.from_argparse_args``, and the whole mapping is passed to
            each module's ``process_args``.

    Returns:
        tuple: ``(data_module, model_module, graph_module, model, trainer,
        args)`` where the first three are the resolved classes.
    """
    selection = args['Main args']
    data_module = getattr(boda.data, selection.data_module)
    model_module = getattr(boda.model, selection.model_module)
    graph_module = getattr(boda.graph, selection.graph_module)

    data_kwargs = vars(data_module.process_args(args))
    data = data_module(**data_kwargs)

    model_kwargs = vars(model_module.process_args(args))
    model = model_module(**model_kwargs)

    # Re-class the instantiated model as a dynamically created subclass of
    # (model, graph); the graph's processed arguments become class attributes.
    graph_attrs = vars(graph_module.process_args(args))
    combined_cls = type('BODA_module', (model_module, graph_module), graph_attrs)
    model.__class__ = combined_cls

    trainer = Trainer.from_argparse_args(args['pl.Trainer'])
    os.makedirs(trainer.default_root_dir, exist_ok=True)
    trainer.fit(model, data)

    # Artifact saving is intentionally left to the caller, e.g.
    # _save_model(data_module, model_module, graph_module, model, trainer, args)
    return data_module, model_module, graph_module, model, trainer, args

In [4]:
def _save_model(data_module, model_module, graph_module, 
                model, trainer, args):
    """Checkpoint the trained model and publish the run directory as a tarball.

    A ``torch_checkpoint.pt`` file is written into the trainer's
    ``default_root_dir``; that directory is then packed (under the archive
    name ``artifacts``) into ``model_artifacts__<timestamp>__<tag>.tar.gz`` and
    delivered to ``args['Main args'].artifact_path``.

    Args:
        data_module: data class; its ``__name__`` and ``process_args(args)``
            are stored in the checkpoint.
        model_module: model class, stored the same way.
        graph_module: graph class, stored the same way.
        model: trained model whose ``state_dict()`` is stored.
        trainer: unused here; kept so the signature matches ``main``'s return.
        args: mapping of argument-group name to namespace. Reads
            ``args['pl.Trainer'].default_root_dir`` and
            ``args['Main args'].artifact_path``.

    Raises:
        subprocess.CalledProcessError: if the ``gsutil cp`` upload fails.
    """
    local_dir = args['pl.Trainer'].default_root_dir
    artifact_path = args['Main args'].artifact_path
    save_dict = {
        'data_module'  : data_module.__name__,
        'data_hparams' : data_module.process_args(args),
        'model_module' : model_module.__name__,
        'model_hparams': model_module.process_args(args),
        'graph_module' : graph_module.__name__,
        'graph_hparams': graph_module.process_args(args),
        'model_state_dict': model.state_dict(),
        'timestamp'    : time.strftime("%Y%m%d_%H%M%S"),
        'random_tag'   : random.randint(100000,999999)
    }
    os.makedirs(local_dir, exist_ok=True)
    torch.save(save_dict, os.path.join(local_dir,'torch_checkpoint.pt'))

    filename=f'model_artifacts__{save_dict["timestamp"]}__{save_dict["random_tag"]}.tar.gz'
    # Stage the archive in a real temporary directory. The previous code
    # overwrote this path with a hard-coded '/tmp/output', which may not exist
    # and left the tarball behind after every run.
    with tempfile.TemporaryDirectory() as tmpdirname:
        archive_path = os.path.join(tmpdirname, filename)
        with tarfile.open(archive_path, 'w:gz') as tar:
            tar.add(local_dir, arcname='artifacts')

        # Only treat the destination as a bucket when it starts with the
        # scheme, not when 'gs://' merely appears somewhere in the path.
        if artifact_path.startswith('gs://'):
            cloud_target = artifact_path.rstrip('/') + '/' + filename
            subprocess.check_call(
                ['gsutil', 'cp', archive_path, cloud_target]
            )
        else:
            os.makedirs(artifact_path, exist_ok=True)
            shutil.copy(archive_path, artifact_path)


In [5]:
def model_fn(model_dir):
    """Rebuild a model from the ``torch_checkpoint.pt`` found in ``model_dir``.

    Args:
        model_dir: directory containing ``torch_checkpoint.pt`` as written by
            ``_save_model``.

    Returns:
        The model instance with its saved state dict loaded.

    Raises:
        AttributeError: if the stored model class name cannot be found in
            ``boda.model`` or on the ``boda`` package itself.
    """
    checkpoint = torch.load(os.path.join(model_dir,'torch_checkpoint.pt'))

    # main() resolves model classes from boda.model, so look there first and
    # keep the package-level lookup as a fallback for existing callers.
    module_name = checkpoint['model_module']
    model_module = getattr(getattr(boda, 'model', boda), module_name, None)
    if model_module is None:
        model_module = getattr(boda, module_name)

    # _save_model stores the raw process_args() result, which main() unpacks
    # with vars(); such an object cannot be **-expanded directly. Accept both
    # that form and a plain dict.
    hparams = checkpoint['model_hparams']
    if not isinstance(hparams, dict):
        hparams = vars(hparams)

    model        = model_module(**hparams)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f'Loaded model from {checkpoint["timestamp"]}')
    return model


# Process runtime arguments

## Command line args to use for testing

In [6]:
# Storage locations used while testing the trainer from this notebook.
# NOTE(review): none of these names is referenced by the cells shown below;
# the test command string repeats the literal values instead.
basset_weights = 'gs://syrgoth/my-model.epoch_5-step_19885.pkl'
old_data = 'gs://syrgoth/data/MPRA_UKBB_BODA.txt'
# Same URI the test command passes as --datafile_path.
new_data = 'gs://syrgoth/data/MPRA_ALL_v3.txt'
# Same local path the test command passes as --parent_weights.
# Presumably basset_weights is copied here before training - TODO confirm.
local_stash = '/tmp/temp_model.pkl'

In [7]:
# Test command line, grouped by the component that consumes each option.
# Every fragment ends with a single space, so a plain concatenation yields the
# final string (including its trailing space).
cmd_str = ''.join([
    # Data module
    '--data_module MPRA_DataModule ',
    '--datafile_path  gs://syrgoth/data/MPRA_ALL_v3.txt ',
    '--batch_size  1991 --padded_seq_len 600 --num_workers 1 ',
    '--synth_seed 102202 ',
    # Model module
    '--model_module BassetBranched ',
    '--n_linear_layers 1 --linear_channels 1000 --linear_dropout_p 3.691822111811164e-1 ',
    '--branched_dropout_p 3.163309575564524e-1 ',
    '--n_branched_layers 3 --branched_channels 100 --n_outputs  3 --loss_criterion L1KLmixed ',
    # Graph module
    '--graph_module CNNTransferLearning ',
    '--optimizer Adam --amsgrad True ',
    '--lr 4.621469184528976e-4 --weight_decay 6.465866656156007e-5 ',
    '--beta1 9.154886174667547e-1 --beta2 9.064388107548405e-1 ',
    '--loss_criterion L1KLmixed --kl_scale 5.555669898051376 ',
    '--parent_weights /tmp/temp_model.pkl ',
    '--frozen_epochs 0 ',
    # Trainer and output locations
    '--gpus 1 --min_epochs 10 --max_epochs 10 --precision 16 --default_root_dir /tmp/output/artifacts ',
    '--artifact_path gs://syrgoth/aip_ui_test ',
])

# Show the equivalent shell invocation.
f'python /home/ubuntu/boda2/src/main.py {cmd_str}'

'python /home/ubuntu/boda2/src/main.py --data_module MPRA_DataModule --datafile_path  gs://syrgoth/data/MPRA_ALL_v3.txt --batch_size  1991 --padded_seq_len 600 --num_workers 1 --synth_seed 102202 --model_module BassetBranched --n_linear_layers 1 --linear_channels 1000 --linear_dropout_p 3.691822111811164e-1 --branched_dropout_p 3.163309575564524e-1 --n_branched_layers 3 --branched_channels 100 --n_outputs  3 --loss_criterion L1KLmixed --graph_module CNNTransferLearning --optimizer Adam --amsgrad True --lr 4.621469184528976e-4 --weight_decay 6.465866656156007e-5 --beta1 9.154886174667547e-1 --beta2 9.064388107548405e-1 --loss_criterion L1KLmixed --kl_scale 5.555669898051376 --parent_weights /tmp/temp_model.pkl --frozen_epochs 0 --gpus 1 --min_epochs 10 --max_epochs 10 --precision 16 --default_root_dir /tmp/output/artifacts --artifact_path gs://syrgoth/aip_ui_test '

In [8]:
# Print the command as one `--flag=value` line per option. Tokens at even
# positions are flags (printed without a newline); odd positions are values.
for position, token in enumerate(cmd_str.split()):
    is_flag = position % 2 == 0
    print(token if is_flag else "=" + token, end='' if is_flag else '\n')

--data_module=MPRA_DataModule
--datafile_path=gs://syrgoth/data/MPRA_ALL_v3.txt
--batch_size=1991
--padded_seq_len=600
--num_workers=1
--synth_seed=102202
--model_module=BassetBranched
--n_linear_layers=1
--linear_channels=1000
--linear_dropout_p=3.691822111811164e-1
--branched_dropout_p=3.163309575564524e-1
--n_branched_layers=3
--branched_channels=100
--n_outputs=3
--loss_criterion=L1KLmixed
--graph_module=CNNTransferLearning
--optimizer=Adam
--amsgrad=True
--lr=4.621469184528976e-4
--weight_decay=6.465866656156007e-5
--beta1=9.154886174667547e-1
--beta2=9.064388107548405e-1
--loss_criterion=L1KLmixed
--kl_scale=5.555669898051376
--parent_weights=/tmp/temp_model.pkl
--frozen_epochs=0
--gpus=1
--min_epochs=10
--max_epochs=10
--precision=16
--default_root_dir=/tmp/output/artifacts
--artifact_path=gs://syrgoth/aip_ui_test


## Set base args for script

Basic arguments to identify which submodules are used and where data will be saved

In [9]:
# Base parser: the built-in -h/--help is disabled here, so it is not consumed
# during this first parsing pass.
parser = argparse.ArgumentParser(description="BODA trainer", add_help=False)
group = parser.add_argument_group('Main args')

# Mandatory selectors for the three BODA submodules.
group.add_argument(
    '--data_module',
    type=str,
    required=True,
    help='BODA data module to process dataset.',
)
group.add_argument(
    '--model_module',
    type=str,
    required=True,
    help='BODA model module to fit dataset.',
)
group.add_argument(
    '--graph_module',
    type=str,
    required=True,
    help='BODA graph module to define computations.',
)

# Optional locations.
group.add_argument(
    '--artifact_path',
    type=str,
    default='/opt/ml/checkpoints/',
    help='Path where model artifacts are deposited.',
)
group.add_argument('--pretrained_weights', type=str, help='Pretrained weights.')



_StoreAction(option_strings=['--pretrained_weights'], dest='pretrained_weights', nargs=None, const=None, default=None, type=<class 'str'>, choices=None, help='Pretrained weights.', metavar=None)

In [10]:
# First pass: only the main arguments are registered so far; everything else
# lands in leftover_args. A bare split() already ignores trailing whitespace.
known_args, leftover_args = parser.parse_known_args(cmd_str.split())

In [11]:
# Display the namespace produced by the first parsing pass.
known_args

Namespace(artifact_path='gs://syrgoth/aip_ui_test', data_module='MPRA_DataModule', graph_module='CNNTransferLearning', model_module='BassetBranched', pretrained_weights=None)

## Extract first-order submodule args

Get submodule specific arguments.

In [12]:
# Resolve the classes named on the command line.
Data, Model, Graph = (
    getattr(boda.data, known_args.data_module),
    getattr(boda.model, known_args.model_module),
    getattr(boda.graph, known_args.graph_module),
)

# Let each class register its own options, in data -> model -> graph order.
parser = Graph.add_graph_specific_args(
    Model.add_model_specific_args(
        Data.add_data_specific_args(parser)
    )
)

# Second pass over the same command line with the enlarged parser.
known_args, leftover_args = parser.parse_known_args(cmd_str.split())


In [13]:
# Display the namespace after the submodule-specific options were registered.
known_args

Namespace(activity_columns=['K562_mean', 'HepG2_mean', 'SKNSH_mean'], artifact_path='gs://syrgoth/aip_ui_test', batch_size=1991, branched_activation='ReLU', branched_channels=100, branched_dropout_p=0.3163309575564524, chr_column='chr', conv1_channels=300, conv1_kernel_size=19, conv2_channels=200, conv2_kernel_size=11, conv3_channels=200, conv3_kernel_size=7, criterion_reduction='mean', data_module='MPRA_DataModule', data_project=['BODA', 'UKBB', 'GTEX'], datafile_path='gs://syrgoth/data/MPRA_ALL_v3.txt', exclude_chr_train=[''], frozen_epochs=0, graph_module='CNNTransferLearning', kl_scale=5.555669898051376, linear_activation='ReLU', linear_channels=1000, linear_dropout_p=0.3691822111811164, loss_criterion='L1KLmixed', model_module='BassetBranched', mse_scale=1.0, n_branched_layers=3, n_linear_layers=1, n_outputs=3, normalize=False, num_workers=1, optimizer='Adam', padded_seq_len=600, parent_weights='/tmp/temp_model.pkl', pretrained_weights=None, project_column='data_project', schedule

## Extract second-order submodule args

Get another set of submodule-specific arguments based on the preliminary choices (e.g., optional arguments for the optimizer of choice).

In [14]:
# Register options that depend on the values parsed so far, again in
# data -> model -> graph order.
parser = Graph.add_conditional_args(
    Model.add_conditional_args(
        Data.add_conditional_args(parser, known_args),
        known_args,
    ),
    known_args,
)

# Add the Lightning Trainer options, then re-enable help, which was disabled
# when the base parser was created.
parser = Trainer.add_argparse_args(parser)
parser.add_argument('--help', '-h', action='help')

# Final strict parse, then regroup the flat namespace via organize_args.
args = boda.common.utils.organize_args(parser, parser.parse_args(cmd_str.split()))



In [15]:
# Display the arguments as organized by boda.common.utils.organize_args.
args

{'positional arguments': Namespace(),
 'optional arguments': Namespace(help=None),
 'Main args': Namespace(artifact_path='gs://syrgoth/aip_ui_test', data_module='MPRA_DataModule', graph_module='CNNTransferLearning', model_module='BassetBranched', pretrained_weights=None),
 'Data Module args': Namespace(activity_columns=['K562_mean', 'HepG2_mean', 'SKNSH_mean'], batch_size=1991, chr_column='chr', data_project=['BODA', 'UKBB', 'GTEX'], datafile_path='gs://syrgoth/data/MPRA_ALL_v3.txt', exclude_chr_train=[''], normalize=False, num_workers=1, padded_seq_len=600, project_column='data_project', sequence_column='nt_sequence', std_multiple_cut=6.0, synth_chr='synth', synth_seed=102202, synth_test_pct=10.0, synth_val_pct=10.0, test_chrs=['7', '13'], up_cutoff_move=3.0, val_chrs=['17', '19', '21', 'X']),
 'Model Module args': Namespace(branched_activation='ReLU', branched_channels=100, branched_dropout_p=0.3163309575564524, conv1_channels=300, conv1_kernel_size=19, conv2_channels=200, conv2_kern

# Run training
Use the modified `main` runner.

In [16]:
# Run training and keep every returned object so later cells can inspect the
# fitted model or pass the same values to _save_model. `args` is rebound to
# the sixth returned element.
data_module, model_module, graph_module, model, trainer, args = main(args)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.


--------------------------------------------------

K562 | top cut value: 10.47, bottom cut value: -6.47
HepG2 | top cut value: 9.67, bottom cut value: -5.75
SKNSH | top cut value: 10.33, bottom cut value: -6.5

Number of examples discarded from top: 0
Number of examples discarded from bottom: 6

Number of examples available: 792635

--------------------------------------------------

Padding sequences...
Tokenizing sequences...
Creating train/val/test datasets...
--------------------------------------------------

Number of examples in train: 625065 (78.86%)
Number of examples in val:   101161 (12.76%)
Number of examples in test:  66409 (8.38%)

Excluded from train: 0 (0.0)%
--------------------------------------------------


Key conv1.conv.weight successfully matched
Key conv1.conv.bias successfully matched
Key conv1.bn_layer.weight successfully matched
Key conv1.bn_layer.bias successfully matched
Key conv1.bn_layer.running_mean successfully matched
Key conv1.bn_layer.running_var successfully matched
Key conv1.bn_layer.num_batches_tracked successfully matched
Key conv2.conv.weight successfully matched
Key conv2.conv.bias successfully matched
Key conv2.bn_layer.weight successfully matched
Key conv2.bn_layer.bias successfully matched
Key conv2.bn_layer.running_mean successfully matched
Key conv2.bn_layer.running_var successfully matched
Key conv2.bn_layer.num_batches_tracked successfully matched
Key conv3.conv.weight successfully matched
Key conv3.conv.bias successfully matched
Key conv3.bn_layer.weight successfully matched
Key conv3.bn_layer.bias successfully matched
Key conv3.bn_layer.running_mean successfully matched
Key conv3.bn_layer.running_var successfully matched
Key conv3.bn_layer.num_batches_tracke

Found 3929103 parameters


Validation sanity check: 0it [00:00, ?it/s]

  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)
  "reduction: 'mean' divides the total loss by both the batch size and the support size."



--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 0.00000 | arithmetic_mean_loss: 0.50375 | harmonic_mean_loss: 0.43654 | prediction_mean_spearman: 0.77823 | entropy_spearman: 0.50766 |
--------------------------------------------------------------------------------------------------------------------------------------------------------





Training: 0it [00:00, ?it/s]

starting epoch 0


Validating: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 0.00000 | arithmetic_mean_loss: 0.50983 | harmonic_mean_loss: 0.42794 | prediction_mean_spearman: 0.80991 | entropy_spearman: 0.51533 |
--------------------------------------------------------------------------------------------------------------------------------------------------------

starting epoch 1


Validating: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 1.00000 | arithmetic_mean_loss: 0.50825 | harmonic_mean_loss: 0.42035 | prediction_mean_spearman: 0.80919 | entropy_spearman: 0.51008 |
--------------------------------------------------------------------------------------------------------------------------------------------------------

starting epoch 2


Validating: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 2.00000 | arithmetic_mean_loss: 0.51222 | harmonic_mean_loss: 0.42052 | prediction_mean_spearman: 0.80786 | entropy_spearman: 0.51097 |
--------------------------------------------------------------------------------------------------------------------------------------------------------

starting epoch 3


Validating: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 3.00000 | arithmetic_mean_loss: 0.51210 | harmonic_mean_loss: 0.42193 | prediction_mean_spearman: 0.80642 | entropy_spearman: 0.51042 |
--------------------------------------------------------------------------------------------------------------------------------------------------------

starting epoch 4


Validating: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 4.00000 | arithmetic_mean_loss: 0.51816 | harmonic_mean_loss: 0.42398 | prediction_mean_spearman: 0.80587 | entropy_spearman: 0.50536 |
--------------------------------------------------------------------------------------------------------------------------------------------------------

starting epoch 5


Validating: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 5.00000 | arithmetic_mean_loss: 0.51694 | harmonic_mean_loss: 0.42618 | prediction_mean_spearman: 0.80385 | entropy_spearman: 0.50908 |
--------------------------------------------------------------------------------------------------------------------------------------------------------

starting epoch 6


Validating: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 6.00000 | arithmetic_mean_loss: 0.51939 | harmonic_mean_loss: 0.42559 | prediction_mean_spearman: 0.80307 | entropy_spearman: 0.50531 |
--------------------------------------------------------------------------------------------------------------------------------------------------------

starting epoch 7


Validating: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 7.00000 | arithmetic_mean_loss: 0.52374 | harmonic_mean_loss: 0.42695 | prediction_mean_spearman: 0.80284 | entropy_spearman: 0.50442 |
--------------------------------------------------------------------------------------------------------------------------------------------------------

starting epoch 8


Validating: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 8.00000 | arithmetic_mean_loss: 0.52303 | harmonic_mean_loss: 0.43784 | prediction_mean_spearman: 0.79989 | entropy_spearman: 0.50189 |
--------------------------------------------------------------------------------------------------------------------------------------------------------

starting epoch 9


Validating: 0it [00:00, ?it/s]


--------------------------------------------------------------------------------------------------------------------------------------------------------
| current_epoch: 9.00000 | arithmetic_mean_loss: 0.52473 | harmonic_mean_loss: 0.42819 | prediction_mean_spearman: 0.80028 | entropy_spearman: 0.50071 |
--------------------------------------------------------------------------------------------------------------------------------------------------------

