In [1]:
import sys
import os
import subprocess
import tarfile
import shutil

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits import mplot3d
from Bio import motifs

import boda
from boda.generator.parameters import StraightThroughParameters
from boda.generator import FastSeqProp
from boda.generator.plot_tools import matrix_to_dms, ppm_to_IC, ppm_to_pwm
from boda.model.mpra_basset import MPRA_Basset
from boda.common import constants, utils
from boda.data import MPRA_DataModule

# Make the `src` directory that sits two levels above the working directory
# importable, so that `main` and `pymeme` can be imported right below.
notebook_dir = os.getcwd()
repo_root = os.path.dirname(os.path.dirname(notebook_dir))
boda_src = os.path.join(repo_root, 'src')
sys.path.insert(0, boda_src)

from main import unpack_artifact, model_fn
from pymeme import streme, parse_streme_output

In [2]:
#----------------------- model -----------------------
# Value handed to MPRA_Basset as `extra_hidden_size`.
EXTRA_HIDDEN_SIZE = 250
model = MPRA_Basset(extra_hidden_size=EXTRA_HIDDEN_SIZE)

In [4]:
# Copy the MPRA activity table from the GCS bucket into the working directory;
# later cells open it by its bare filename ('MPRA_UKBB_BODA_v2.txt').
! gsutil cp gs://syrgoth/data/MPRA_UKBB_BODA_v2.txt ./

Copying gs://syrgoth/data/MPRA_UKBB_BODA_v2.txt...
| [1 files][107.1 MiB/107.1 MiB]                                                
Operation completed over 1 objects/107.1 MiB.                                    


In [3]:
%%time
datamodule = MPRA_DataModule(datafile_path='MPRA_UKBB_BODA_v2.txt',
                             data_project=['BODA', 'UKBB'], #['BODA', 'UKBB'],
#                              project_column='data_project',
#                              sequence_column='nt_sequence',
#                              activity_columns=['K562_mean', 'HepG2_mean', 'SKNSH_mean'],
#                              exclude_chr_train=['synth'], #['17', '19', '21', 'X'],
#                              val_chrs=['17', '19', '21', 'X'], #['17', '19', '21', 'X']
#                              test_chrs=[''], #['7','13'],
#                              chr_column='chr',
#                              std_multiple_cut=6.0,
#                              up_cutoff_move=4.0,
#                              synth_chr='synth',
#                              synth_val_pct=10,
#                              synth_test_pct=10,
#                              synth_seed=0,
#                              batch_size=32,
                             padded_seq_len=600 
#                              num_workers=8,
#                              normalize=False
                            )

datamodule.setup()

--------------------------------------------------

K562 | top cut value: 10.95, bottom cut value: -6.0
HepG2 | top cut value: 9.99, bottom cut value: -5.26
SKNSH | top cut value: 10.14, bottom cut value: -5.51

Number of examples discarded from top: 0
Number of examples discarded from bottom: 8

Number of examples available: 358538

--------------------------------------------------

Padding sequences...
Tokenizing sequences...
Creating train/val/test datasets...
--------------------------------------------------

Number of examples in train: 295827 (82.51%)
Number of examples in val:   31453 (8.77%)
Number of examples in test:  31258 (8.72%)

Excluded from train: 0 (0.0)%
--------------------------------------------------
CPU times: user 1min 25s, sys: 4.47 s, total: 1min 30s
Wall time: 1min 29s


In [4]:
# NOTE(review): these imports sit mid-notebook; consider moving them to the
# import cell at the top so dependencies are visible in one place.
import pytorch_lightning as pl
from pytorch_lightning.metrics import functional
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

#-------------------- Train only last layer ------------------------
# Freeze the Basset sub-network. The recorded model summary reports 751 K
# trainable vs 4.9 M non-trainable parameters, matching basset_net's 4.9 M.
model.basset_net.freeze()

# Training settings are stored as attributes on the model / data module.
# The trailing comments repeat the value on each line.
model.epochs = 5
model.learning_rate = 0.05     #0.05
model.weight_decay = 1e-6      #1e-6
model.scheduler = True         #True
datamodule.batch_size = 1024   #1024

# TensorBoard logs go under ./model_logs/MPRAbasset_logs; the learning rate is
# recorded once per epoch.
logger = TensorBoardLogger('model_logs', name='MPRAbasset_logs', log_graph=True)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
# Single-GPU trainer with 16-bit precision, running for model.epochs epochs.
trainer = pl.Trainer(gpus=1, max_epochs=model.epochs, progress_bar_refresh_rate=10,
                     logger=logger, callbacks=[lr_monitor], precision=16)

# NOTE(review): the recorded run stopped during the validation sanity check with
# "NameError: name 'Shannon_entropy' is not defined". That name is not used in
# this cell, so the failure presumably originates inside the model/library
# code - TODO confirm and fix there.
trainer.fit(model, datamodule)
# Test-set evaluation; not reached in the recorded run because fit() raised.
trainer.test()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)

  | Name            | Type       | Params | In sizes  | Out sizes
-----------------------------------------------------------------------
0 | criterion       | MSELoss    | 0      | ?         | ?        
1 | last_activation | Tanh       | 0      | [1, 250]  | [1, 250] 
2 | basset_net      | Basset     | 4.9 M  | ?         | ?        
3 | output_1        | Sequential | 250 K  | [1, 1000] | [1, 1]   
4 | output_2        | Sequential | 250 K  | [1, 1000] | [1, 1]   
5 | output_3        | Sequential | 250 K  | [1, 1000] | [1, 1]   
-----------------------------------------------------------------------
751 K     Trainable params
4.9 M     Non-trainable params
5.6 M     Total params
22.411    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

NameError: name 'Shannon_entropy' is not defined

In [34]:
# Load the same table by hand for inspection: keep only the columns of interest
# and the rows that belong to the selected projects.
datafile_path = 'MPRA_UKBB_BODA_v2.txt'
sequence_column = 'nt_sequence'
activity_columns = ['K562_mean', 'HepG2_mean', 'SKNSH_mean']
chr_column = 'chr'
project_column = 'data_project'
data_project = ['BODA', 'UKBB']

# Column order: sequence, activities, chromosome, project.
columns = [sequence_column] + activity_columns + [chr_column, project_column]
temp_df = utils.parse_file(file_path=datafile_path, columns=columns)

# Boolean mask of rows whose project is one of `data_project`.
in_selected_projects = temp_df[project_column].isin(data_project)
temp_df = temp_df.loc[in_selected_projects].reset_index(drop=True)

In [35]:
# Preview the filtered table (the recorded output is a truncated head/tail view).
temp_df

Unnamed: 0,nt_sequence,K562_mean,HepG2_mean,SKNSH_mean,chr,data_project
0,AAAAAAAAAAAAAAAAAAAAAAAAAAAAGTGAGTAACAAAAACAAC...,1.114658,0.785664,0.715749,12,UKBB
1,AAAAAAAAAAAAAAAAAAAAAAAAAAAAGTGAGTAACAAAAACAAC...,1.228857,0.677087,0.748533,12,UKBB
2,AAAAAAAAAAAAAAAAAAAAAAAAAAACAGAAAAGAAAAGAAACAT...,0.354881,0.381372,-0.435285,12,UKBB
3,AAAAAAAAAAAAAAAAAAAAAAAAAAACAGAAAAGAAAAGAAACAT...,0.360825,0.630168,-0.256892,12,UKBB
4,AAAAAAAAAAAAAAAAAAAAAAAAAAACTAGTCGGGCATGGTGGCG...,0.207887,0.416471,0.289111,20,UKBB
...,...,...,...,...,...,...
358541,TTTTTTTTTTTTGAGACGGAGTCTCACTCTGTCGCCCAAGTTGGAG...,0.962659,0.737992,0.554276,5,BODA
358542,TTTTTTTTTTTTTTACCTTATTTTGACTCAATGTCTGTTTTATCTG...,1.731245,1.093225,0.869407,1,BODA
358543,TTTTTTTTTTTTTTGACAGAGTCTTGCTCTGTCACCCAGGTTGGAG...,-0.300288,-0.470305,-0.591421,4,BODA
358544,TTTTTTTTTTTTTTTGTATTTTTAGTAGAGACAGGGTTTCACCATG...,0.925107,0.901311,0.703593,17,BODA


In [36]:
# Length of every raw sequence. The vectorised string accessor replaces a
# row-wise DataFrame.apply, and the column is taken from `sequence_column`
# (set in the loading cell) instead of repeating the literal name.
temp_df['seq_len'] = temp_df[sequence_column].str.len()

In [37]:
# Rows whose raw sequence is longer than 200 (the recorded output is empty).
temp_df.loc[temp_df['seq_len'].gt(200)]

Unnamed: 0,nt_sequence,K562_mean,HepG2_mean,SKNSH_mean,chr,data_project,seq_len


In [105]:
from functools import partial

# Name of the column that receives the padded sequences, and the target length
# used for this check.
pad_column_name = 'padded_seq'
padded_seq_len=200 

# Bind the column name and target length so the resulting callable takes only
# the row, which is what DataFrame.apply(axis=1) passes to it.
# NOTE(review): row_pad_sequence is defined in a cell placed after this one in
# the notebook (execution counts 104 and 105), so a top-to-bottom run on a fresh
# kernel would presumably fail here with NameError unless that cell is moved up.
padding_fn = partial(row_pad_sequence,
                                  in_column_name=sequence_column,
                                  padded_seq_len=padded_seq_len
                                  )

# Pad every row and store the result in the new column.
temp_df[pad_column_name] = temp_df.apply(padding_fn, axis=1)

In [106]:
# Length of every padded sequence, computed with the vectorised string accessor
# instead of a row-wise DataFrame.apply.
temp_df['padded_seq_len'] = temp_df[pad_column_name].str.len()

In [108]:
# Rows still shorter than 200 after padding (the recorded output is empty).
temp_df.loc[temp_df['padded_seq_len'].lt(200)]

Unnamed: 0,nt_sequence,K562_mean,HepG2_mean,SKNSH_mean,chr,data_project,seq_len,padded_seq,padded_seq_len


In [104]:
def row_pad_sequence(row,
                     in_column_name='nt_sequence',
                     padded_seq_len=400,
                     upStreamSeq=constants.MPRA_UPSTREAM,
                     downStreamSeq=constants.MPRA_DOWNSTREAM):
    """Pad one row's sequence to ``padded_seq_len`` with flanking sequence.

    The missing length is split between the two sides: the upstream side gets
    the last ``floor(missing / 2)`` characters of ``upStreamSeq`` and the
    downstream side gets the first ``ceil(missing / 2)`` characters of
    ``downStreamSeq``, so an odd remainder goes downstream.

    Args:
        row: Object indexable by ``in_column_name`` (e.g. a DataFrame row as
            passed by ``DataFrame.apply(..., axis=1)``).
        in_column_name: Key of the sequence inside ``row``.
        padded_seq_len: Target length of the returned sequence.
        upStreamSeq: Flank whose tail is prepended to the sequence.
        downStreamSeq: Flank whose head is appended to the sequence.

    Returns:
        The padded sequence. A sequence that is already ``padded_seq_len`` long
        or longer is returned unchanged (it is not truncated).

    Raises:
        AssertionError: If either flank is too short to supply its share of
            the padding, or if the padded result does not have the target
            length.
    """
    sequence = row[in_column_name]
    paddingLen = padded_seq_len - len(sequence)
    if paddingLen <= 0:
        # Nothing to add; over-long sequences pass through untouched.
        return sequence

    # Equivalent to the earlier `-paddingLen//2 + paddingLen%2` arithmetic:
    # upstream takes the floor half, downstream takes the rest.
    upLen = paddingLen // 2
    downLen = paddingLen - upLen

    # Check each flank on its own. Comparing against the combined flank length
    # lets a single short flank slip through and fail later with a
    # less helpful message.
    assert upLen <= len(upStreamSeq) and downLen <= len(downStreamSeq), \
        'Not enough padding available'

    # A slice starting at -0 would return the whole flank, so zero is special-cased.
    upPad = upStreamSeq[-upLen:] if upLen else ''
    downPad = downStreamSeq[:downLen]
    paddedSequence = upPad + sequence + downPad
    assert len(paddedSequence) == padded_seq_len, 'Kiubo?'
    return paddedSequence

In [67]:
# Rows whose raw sequence is shorter than 200. The recorded output comes from
# an earlier execution (count 67) and shows padded_seq_len = 400.
temp_df.loc[temp_df['seq_len'].lt(200)]

Unnamed: 0,nt_sequence,K562_mean,HepG2_mean,SKNSH_mean,chr,data_project,seq_len,padded_seq,padded_seq_len
44,AAAAAAAAAAAAAAAAAAAAGAAAGAAAAAAAGAAAGAAAGAAAGA...,0.062643,0.445604,2.195362,3,UKBB,199,AAAAAAAAAAAAAAAAAAAAGAAAGAAAAAAAGAAAGAAAGAAAGA...,400
51,AAAAAAAAAAAAAAAAAAAAGGATTTGAGCTAGAAAATGGGACCAT...,2.954926,3.041289,2.299703,1,UKBB,199,AAAAAAAAAAAAAAAAAAAAGGATTTGAGCTAGAAAATGGGACCAT...,400
75,AAAAAAAAAAAAAAAAAACTTCCCTCTAAATACACACATTAATAAT...,0.720907,0.459670,0.524583,13,UKBB,199,AAAAAAAAAAAAAAAAAACTTCCCTCTAAATACACACATTAATAAT...,400
103,AAAAAAAAAAAAAAAAAGCAAGAGAGATAAAATACATGGTTCTAAA...,0.716638,0.455792,0.456843,1,UKBB,194,AAAAAAAAAAAAAAAAAGCAAGAGAGATAAAATACATGGTTCTAAA...,400
111,AAAAAAAAAAAAAAAACAACAAAATAAAATTCCCAACATGCAGATA...,1.591423,1.276372,0.608198,3,UKBB,199,AAAAAAAAAAAAAAAACAACAAAATAAAATTCCCAACATGCAGATA...,400
...,...,...,...,...,...,...,...,...,...
330198,TTTTTTTTTTTTTTTGAGACGGAGTCTCGCTCTGTCGCCCAGGCTG...,1.100679,0.595327,0.829593,7,UKBB,199,TTTTTTTTTTTTTTTGAGACGGAGTCTCGCTCTGTCGCCCAGGCTG...,400
330225,TTTTTTTTTTTTTTTTGAGATGGAGTACCCATCTGTTGCTCAGGAT...,-0.247592,-0.208786,-0.188732,12,UKBB,199,TTTTTTTTTTTTTTTTGAGATGGAGTACCCATCTGTTGCTCAGGAT...,400
330272,TTTTTTTTTTTTTTTTTTGACGGAGTCTTGCTCTGTTGCCAGGCTG...,-0.096071,-0.005683,0.351095,4,UKBB,186,TTTTTTTTTTTTTTTTTTGACGGAGTCTTGCTCTGTTGCCAGGCTG...,400
330277,TTTTTTTTTTTTTTTTTTGAGACGGAGTCTCACTCTGTTGCCCAGG...,-0.815297,0.447955,0.212442,17,UKBB,198,TTTTTTTTTTTTTTTTTTGAGACGGAGTCTCACTCTGTTGCCCAGG...,400


In [96]:
# Manual walk-through of row_pad_sequence on a single row, using notebook-level
# variables in place of the function's parameters.
padded_seq_len=200
upStreamSeq=constants.MPRA_UPSTREAM
downStreamSeq=constants.MPRA_DOWNSTREAM
    
# Row 44 is listed with seq_len 199 in the short-sequence table, i.e. one
# character short of the target length.
sequence = temp_df.iloc[44]['nt_sequence']

origSeqLen = len(sequence)
# Number of characters that must be added to reach padded_seq_len.
paddingLen = padded_seq_len - origSeqLen

In [101]:
# Same flank arithmetic as row_pad_sequence, run on the notebook-level
# variables prepared in the walk-through cell.
if paddingLen > 0:
    # Start index of the upstream slice; it is 0 when no upstream pad is needed,
    # in which case an empty string is used instead of slicing from 0.
    upStart = -paddingLen//2 + paddingLen%2
    upPad = upStreamSeq[upStart:] if upStart < 0 else ''
    downPad = downStreamSeq[:paddingLen//2 + paddingLen%2]
    paddedSequence = upPad + sequence + downPad
    assert len(paddedSequence) == padded_seq_len, 'Kiubo?'

In [102]:
# Check the padded length; the recorded output is 200, equal to padded_seq_len.
len(paddedSequence)

200

In [99]:
# Inspect the upstream pad. The recorded output (execution count 99, i.e. from
# before the guarded cell numbered 101 ran) is the entire upstream flank, which
# is what slicing from index 0 yields - presumably the reason for the `< 0` guard.
upPad

'ACGAAAATGTTGGATGCTCATACTCGTCCTTTTTCAATATTATTGAAGCATTTATCAGGGTTACTAGTACGTCTCTCAAGGATAAGTAAGTAATATTAAGGTACGGGAGGTATTGGACAGGCCGCAATAAAATATCTTTATTTTCATTACATCTGTGTGTTGGTTTTTTGTGTGAATCGATAGTACTAACATACGCTCTCCATCAAAACAAAACGAAACAAAACAAACTAGCAAAATAGGCTGTCCCCAGTGCAAGTGCAGGTGCCAGAACATTTCTCTGGCCTAACTGGCCGCTTGACG'

In [84]:
# Number of characters missing for the selected row; the recorded output is 1.
paddingLen

1

In [85]:
# Start index used for the upstream slice. Unary minus binds tighter than //,
# so this is (-paddingLen)//2 + paddingLen%2; for paddingLen = 1 the recorded
# result is 0.
-paddingLen//2 + paddingLen%2

0

In [86]:
# With a start index of 0 the slice returns the whole object instead of an
# empty pad. The recorded output is also a 1-element tuple, so upStreamSeq was
# not a plain str at that point.
upStreamSeq[-paddingLen//2 + paddingLen%2:]

('ACGAAAATGTTGGATGCTCATACTCGTCCTTTTTCAATATTATTGAAGCATTTATCAGGGTTACTAGTACGTCTCTCAAGGATAAGTAAGTAATATTAAGGTACGGGAGGTATTGGACAGGCCGCAATAAAATATCTTTATTTTCATTACATCTGTGTGTTGGTTTTTTGTGTGAATCGATAGTACTAACATACGCTCTCCATCAAAACAAAACGAAACAAAACAAACTAGCAAAATAGGCTGTCCCCAGTGCAAGTGCAGGTGCCAGAACATTTCTCTGGCCTAACTGGCCGCTTGACG',)

In [93]:
# Recorded output is a 1-element tuple wrapping the sequence - presumably the
# variable had been assigned with a trailing comma in an earlier execution.
upStreamSeq

('ACGAAAATGTTGGATGCTCATACTCGTCCTTTTTCAATATTATTGAAGCATTTATCAGGGTTACTAGTACGTCTCTCAAGGATAAGTAAGTAATATTAAGGTACGGGAGGTATTGGACAGGCCGCAATAAAATATCTTTATTTTCATTACATCTGTGTGTTGGTTTTTTGTGTGAATCGATAGTACTAACATACGCTCTCCATCAAAACAAAACGAAACAAAACAAACTAGCAAAATAGGCTGTCCCCAGTGCAAGTGCAGGTGCCAGAACATTTCTCTGGCCTAACTGGCCGCTTGACG',)

In [88]:
# Display the downstream flank; the recorded output is a plain str.
downStreamSeq

'CACTGCGGCTCCTGCGATCTAACTGGCCGGTACCTGAGCTCGCTAGCCTCGAGGATATCAAGATCTGGCCTCGGCGGCCAAGCTTAGACACTAGAGGGTATATAATGGAAGCTCGACTTCCAGCTTGGCAATCCGGTACTGTTGGTAAAGCCACCATGGTGAGCAAGGGCGAGGAGCTGTTCACCGGGGTGGTGCCCATCCTGGTCGAGCTGGACGGCGACGTAAACGGCCACAAGTTCAGCGTGTCCGGCGAGGGCGAGGGCGATGCCACCTACGGCAAGCTGACCCTGAAGTTCATCT'

In [89]:
# Display the library constant directly for comparison; the recorded output is
# a plain str, not a tuple.
constants.MPRA_UPSTREAM

'ACGAAAATGTTGGATGCTCATACTCGTCCTTTTTCAATATTATTGAAGCATTTATCAGGGTTACTAGTACGTCTCTCAAGGATAAGTAAGTAATATTAAGGTACGGGAGGTATTGGACAGGCCGCAATAAAATATCTTTATTTTCATTACATCTGTGTGTTGGTTTTTTGTGTGAATCGATAGTACTAACATACGCTCTCCATCAAAACAAAACGAAACAAAACAAACTAGCAAAATAGGCTGTCCCCAGTGCAAGTGCAGGTGCCAGAACATTTCTCTGGCCTAACTGGCCGCTTGACG'

In [94]:
# Rebind upStreamSeq to the constant itself (no trailing comma).
upStreamSeq=constants.MPRA_UPSTREAM

In [95]:
# Confirm the rebinding; the recorded output is a plain str again.
upStreamSeq

'ACGAAAATGTTGGATGCTCATACTCGTCCTTTTTCAATATTATTGAAGCATTTATCAGGGTTACTAGTACGTCTCTCAAGGATAAGTAAGTAATATTAAGGTACGGGAGGTATTGGACAGGCCGCAATAAAATATCTTTATTTTCATTACATCTGTGTGTTGGTTTTTTGTGTGAATCGATAGTACTAACATACGCTCTCCATCAAAACAAAACGAAACAAAACAAACTAGCAAAATAGGCTGTCCCCAGTGCAAGTGCAGGTGCCAGAACATTTCTCTGGCCTAACTGGCCGCTTGACG'