In [1]:
import os
import sys
import time
import yaml
import shutil
import argparse
import tarfile
import random
import tempfile
import subprocess

import torch
import pytorch_lightning as ptl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

import boda


In [2]:
print(torch.__version__)
print(ptl.__version__)

1.9.0+cu102
1.3.0


In [3]:
def main(args, callbacks=None):
    data_module = getattr(boda.data, args['Main args'].data_module)
    model_module= getattr(boda.model, args['Main args'].model_module)
    graph_module= getattr(boda.graph, args['Main args'].graph_module)

    data = data_module(**vars(data_module.process_args(args)))
    model= model_module(**vars(model_module.process_args(args)))

    model.__class__ = type(
        'BODA_module',
        (model_module,graph_module),
        vars(graph_module.process_args(args))
    )

    os.makedirs('/tmp/output/artifacts', exist_ok=True)
    trainer = Trainer.from_argparse_args(args['pl.Trainer'], callbacks=callbacks)
    
    trainer.fit(model, data)
    
    #_save_model(data_module, model_module, graph_module, 
    #            model, trainer, args)
    
    return data_module, model_module, graph_module, model, trainer, args

In [4]:
def _save_model(data_module, model_module, graph_module, 
                model, trainer, args):
    local_dir = args['pl.Trainer'].default_root_dir
    save_dict = {
        'data_module'  : data_module.__name__,
        'data_hparams' : data_module.process_args(args),
        'model_module' : model_module.__name__,
        'model_hparams': model_module.process_args(args),
        'graph_module' : graph_module.__name__,
        'graph_hparams': graph_module.process_args(args),
        'model_state_dict': model.state_dict(),
        'timestamp'    : time.strftime("%Y%m%d_%H%M%S"),
        'random_tag'   : random.randint(100000,999999)
    }
    torch.save(save_dict, os.path.join(local_dir,'torch_checkpoint.pt'))
    
    filename=f'model_artifacts__{save_dict["timestamp"]}__{save_dict["random_tag"]}.tar.gz'
    with tempfile.TemporaryDirectory() as tmpdirname:
        tmpdirname = '/tmp/output'
        with tarfile.open(os.path.join(tmpdirname,filename), 'w:gz') as tar:
            tar.add(local_dir,arcname='artifacts')

        if 'gs://' in args['Main args'].artifact_path:
            clound_target = os.path.join(args['Main args'].artifact_path,filename)
            subprocess.check_call(
                ['gsutil', 'cp', os.path.join(tmpdirname,filename), clound_target]
            )
        else:
            os.makedirs(args['Main args'].artifact_path, exist_ok=True)
            shutil.copy(os.path.join(tmpdirname,filename), args['Main args'].artifact_path)


In [5]:
def model_fn(model_dir):
    checkpoint = torch.load(os.path.join(model_dir,'torch_checkpoint.pt'))
    model_module = getattr(boda, checkpoint['model_module'])
    model        = model_module(**checkpoint['model_hparams'])
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f'Loaded model from {checkpoint["timestamp"]}')
    return model


# Process runtime arguments

## Command line args to use for testing

In [6]:
cmd_str = '--data_module MPRA_DataModule ' +\
            '--datafile_path  gs://syrgoth/data/MPRA_ALL_v3.txt ' +\
            '--batch_size  1991 --padded_seq_len 600 --num_workers 1 ' +\
            '--synth_seed 102202 ' +\
          '--model_module BassetBranched ' +\
            '--n_linear_layers 1 --linear_channels 1000 --linear_dropout_p 3.691822111811164e-1 ' +\
            '--branched_dropout_p 3.163309575564524e-1 ' +\
            '--n_branched_layers 3 --branched_channels 100 --n_outputs  3 --loss_criterion L1KLmixed ' +\
          '--graph_module CNNTransferLearning ' +\
            '--optimizer Adam --amsgrad True ' +\
            '--lr 4.621469184528976e-4 --weight_decay 6.465866656156007e-5 ' +\
            '--beta1 9.154886174667547e-1 --beta2 9.064388107548405e-1 ' +\
            '--loss_criterion L1KLmixed --kl_scale 5.555669898051376 ' +\
            '--parent_weights gs://syrgoth/aip_ui_test/model_artifacts__20211016_050110__476387.tar.gz ' +\
            '--frozen_epochs 0 ' +\
          '--gpus 1 --min_epochs 3 --max_epochs 3 --precision 16 --default_root_dir /tmp/output/artifacts ' +\
          '--artifact_path gs://syrgoth/aip_ui_test '

'python /home/ubuntu/boda2/src/main.py ' + cmd_str

'python /home/ubuntu/boda2/src/main.py --data_module MPRA_DataModule --datafile_path  gs://syrgoth/data/MPRA_ALL_v3.txt --batch_size  1991 --padded_seq_len 600 --num_workers 1 --synth_seed 102202 --model_module BassetBranched --n_linear_layers 1 --linear_channels 1000 --linear_dropout_p 3.691822111811164e-1 --branched_dropout_p 3.163309575564524e-1 --n_branched_layers 3 --branched_channels 100 --n_outputs  3 --loss_criterion L1KLmixed --graph_module CNNTransferLearning --optimizer Adam --amsgrad True --lr 4.621469184528976e-4 --weight_decay 6.465866656156007e-5 --beta1 9.154886174667547e-1 --beta2 9.064388107548405e-1 --loss_criterion L1KLmixed --kl_scale 5.555669898051376 --parent_weights gs://syrgoth/aip_ui_test/model_artifacts__20211016_050110__476387.tar.gz --frozen_epochs 0 --gpus 1 --min_epochs 3 --max_epochs 3 --precision 16 --default_root_dir /tmp/output/artifacts --artifact_path gs://syrgoth/aip_ui_test '

## Set base args for script

Basic arguments to identify which submodules are used and where data will be saved

In [7]:
parser = argparse.ArgumentParser(description="BODA trainer", add_help=False)
group = parser.add_argument_group('Main args')
group.add_argument('--data_module', type=str, required=True, help='BODA data module to process dataset.')
group.add_argument('--model_module',type=str, required=True, help='BODA model module to fit dataset.')
group.add_argument('--graph_module',type=str, required=True, help='BODA graph module to define computations.')
group.add_argument('--artifact_path', type=str, default='/opt/ml/checkpoints/', help='Path where model artifacts are deposited.')
group.add_argument('--pretrained_weights', type=str, help='Pretrained weights.')



_StoreAction(option_strings=['--pretrained_weights'], dest='pretrained_weights', nargs=None, const=None, default=None, type=<class 'str'>, choices=None, help='Pretrained weights.', metavar=None)

In [8]:
known_args, leftover_args = parser.parse_known_args(
    cmd_str.rstrip().split()
)

In [9]:
known_args

Namespace(artifact_path='gs://syrgoth/aip_ui_test', data_module='MPRA_DataModule', graph_module='CNNTransferLearning', model_module='BassetBranched', pretrained_weights=None)

## Extract first-order submodule args

Get submodule specific arguments.

In [10]:
Data  = getattr(boda.data,  known_args.data_module)
Model = getattr(boda.model, known_args.model_module)
Graph = getattr(boda.graph, known_args.graph_module)

parser = Data.add_data_specific_args(parser)
parser = Model.add_model_specific_args(parser)
parser = Graph.add_graph_specific_args(parser)

known_args, leftover_args = parser.parse_known_args(
    cmd_str.rstrip().split()
)


In [11]:
known_args

Namespace(activity_columns=['K562_mean', 'HepG2_mean', 'SKNSH_mean'], artifact_path='gs://syrgoth/aip_ui_test', batch_size=1991, branched_activation='ReLU', branched_channels=100, branched_dropout_p=0.3163309575564524, chr_column='chr', conv1_channels=300, conv1_kernel_size=19, conv2_channels=200, conv2_kernel_size=11, conv3_channels=200, conv3_kernel_size=7, criterion_reduction='mean', data_module='MPRA_DataModule', data_project=['BODA', 'UKBB', 'GTEX'], datafile_path='gs://syrgoth/data/MPRA_ALL_v3.txt', exclude_chr_train=[''], frozen_epochs=0, graph_module='CNNTransferLearning', kl_scale=5.555669898051376, linear_activation='ReLU', linear_channels=1000, linear_dropout_p=0.3691822111811164, loss_criterion='L1KLmixed', model_module='BassetBranched', mse_scale=1.0, n_branched_layers=3, n_linear_layers=1, n_outputs=3, normalize=False, num_workers=1, optimizer='Adam', padded_seq_len=600, parent_weights='gs://syrgoth/aip_ui_test/model_artifacts__20211016_050110__476387.tar.gz', pretrained_

## Extract second-order submodule args

Get another set of submodule specific arguments based preliminary choices. (i.e., optional arguments for optimizer of choice)

In [12]:
parser = Data.add_conditional_args(parser, known_args)
parser = Model.add_conditional_args(parser, known_args)
parser = Graph.add_conditional_args(parser, known_args)

parser = Trainer.add_argparse_args(parser)
parser.add_argument('--help', '-h', action='help')
args = parser.parse_args(
    cmd_str.rstrip().split()
)

args = boda.common.utils.organize_args(parser, args)



In [13]:
args

{'positional arguments': Namespace(),
 'optional arguments': Namespace(help=None),
 'Main args': Namespace(artifact_path='gs://syrgoth/aip_ui_test', data_module='MPRA_DataModule', graph_module='CNNTransferLearning', model_module='BassetBranched', pretrained_weights=None),
 'Data Module args': Namespace(activity_columns=['K562_mean', 'HepG2_mean', 'SKNSH_mean'], batch_size=1991, chr_column='chr', data_project=['BODA', 'UKBB', 'GTEX'], datafile_path='gs://syrgoth/data/MPRA_ALL_v3.txt', exclude_chr_train=[''], normalize=False, num_workers=1, padded_seq_len=600, project_column='data_project', sequence_column='nt_sequence', std_multiple_cut=6.0, synth_chr='synth', synth_seed=102202, synth_test_pct=10.0, synth_val_pct=10.0, test_chrs=['7', '13'], up_cutoff_move=3.0, val_chrs=['17', '19', '21', 'X']),
 'Model Module args': Namespace(branched_activation='ReLU', branched_channels=100, branched_dropout_p=0.3163309575564524, conv1_channels=300, conv1_kernel_size=19, conv2_channels=200, conv2_kern

# Run training
use modified `main` runner.

In [14]:
checkpoint_callback = ModelCheckpoint(
                save_top_k=1,
                monitor='entropy_spearman', 
                mode='max'
            )
data_module, model_module, graph_module, model, trainer, args = main(args, callbacks=[checkpoint_callback])

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.


ParserError: Error tokenizing data. C error: Calling read(nbytes) on source failed. Try engine='python'.

In [None]:
_save_model(data_module, model_module, graph_module, 
            model, trainer, args)

In [16]:
vars(trainer.logger.experiment)

{'log_dir': '/tmp/output/artifacts/lightning_logs/version_1',
 'purge_step': None,
 'max_queue': 10,
 'flush_secs': 120,
 'filename_suffix': '',
 'file_writer': None,
 'all_writers': None,
 'default_bins': [-9.920775621859783e+19,
  -9.018886928963438e+19,
  -8.198988117239489e+19,
  -7.453625561126807e+19,
  -6.776023237388005e+19,
  -6.160021124898186e+19,
  -5.600019204452896e+19,
  -5.090926549502632e+19,
  -4.628115045002392e+19,
  -4.2073773136385384e+19,
  -3.824888466944125e+19,
  -3.477171333585568e+19,
  -3.1610648487141528e+19,
  -2.873695317012866e+19,
  -2.6124502881935143e+19,
  -2.3749548074486493e+19,
  -2.1590498249533174e+19,
  -1.962772568139379e+19,
  -1.7843386983085265e+19,
  -1.6221260893713875e+19,
  -1.4746600812467157e+19,
  -1.3406000738606506e+19,
  -1.2187273398733187e+19,
  -1.1079339453393805e+19,
  -1.0072126775812549e+19,
  -9.156478887102316e+18,
  -8.324071715547559e+18,
  -7.567337923225053e+18,
  -6.879398112022775e+18,
  -6.253998283657068e+18,
  -

In [20]:
vars(checkpoint_callback)

{'monitor': 'entropy_spearman',
 'verbose': False,
 'save_last': None,
 'save_top_k': 1,
 'save_weights_only': False,
 'auto_insert_metric_name': True,
 '_last_global_step_saved': 3899,
 'current_score': tensor(0.1598, device='cuda:0'),
 'best_k_models': {'/tmp/output/artifacts/lightning_logs/version_1/checkpoints/epoch=4-step=3899.ckpt': tensor(0.1598, device='cuda:0')},
 'kth_best_model_path': '/tmp/output/artifacts/lightning_logs/version_1/checkpoints/epoch=4-step=3899.ckpt',
 'best_model_score': tensor(0.1598, device='cuda:0'),
 'best_model_path': '/tmp/output/artifacts/lightning_logs/version_1/checkpoints/epoch=4-step=3899.ckpt',
 'last_model_path': '',
 'kth_value': tensor(0.1598, device='cuda:0'),
 'mode': 'max',
 '_fs': <fsspec.implementations.local.LocalFileSystem at 0x7f614c6e2750>,
 'dirpath': '/tmp/output/artifacts/lightning_logs/version_1/checkpoints',
 'filename': None,
 '_every_n_val_epochs': 1,
 '_every_n_train_steps': 0,
 '_period': 1,
 '_save_function': <bound method 

In [24]:
print(
f"{checkpoint_callback.monitor}: {checkpoint_callback.best_model_score.item()}",
f"global step: {checkpoint_callback._last_global_step_saved + 1}"
)

entropy_spearman: 0.15984028577804565 global step: 3900


In [17]:
trainer.logger.experiment()

TypeError: 'SummaryWriter' object is not callable

In [25]:
use_callbacks = None
if use_callbacks is None:
    print('ehllo')

ehllo
