In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import multibind as mb
import numpy as np
import pandas as pd
import torch
import bindome as bd
bd.constants.ANNOTATIONS_DIRECTORY = 'annotations'
# mb.models.MultiBind
import torch.optim as topti
import torch.utils.data as tdata
import matplotlib.pyplot as plt
import logomaker
import os
import scipy
import pickle

# Use a GPU if available, as it should be faster.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device: " + str(device))

  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu


In [4]:
%load_ext line_profiler

# SELEX (one dataset)

In [4]:
n_rounds = 1
data = pd.read_csv('../data/countTable.0.CTCF_r3.tsv.gz', sep='\t', header=None)
data.columns = ['seq'] + [i for i in range(n_rounds+1)]
data.index = data['seq']
del data['seq']
data = data.sample(n=1000)
labels = list(data.columns[:n_rounds + 1])

In [5]:
dataset = mb.datasets.SelexDataset(data, n_rounds=n_rounds, labels=labels)
train = tdata.DataLoader(dataset=dataset, batch_size=256, shuffle=True)

In [6]:
%lprun -f mb.tl.train_iterative model, best_loss = mb.tl.train_iterative(train, device, num_epochs=500, show_logo=False, early_stopping=50, log_each=50)

next w 15 <class 'int'>
# rounds 1
# batches 1
# enr_series True

Kernel to optimize 0

Freezing kernels
setting grad status of kernel at 0 to 1
setting grad status of kernel at 1 to 0
setting grad status of kernel at 2 to 0
setting grad status of kernel at 3 to 0


kernels mask None
optimizing using <class 'torch.optim.adam.Adam'> and <class 'multibind.tl.loss.PoissonLoss'> n_epochs 500 early_stopping 50
lr= 0.01, weight_decay= 0.001, dir weight= 0
Epoch: 51, Loss: 1.074707 , best epoch: 49 secs per epoch: 0.057 s
Epoch: 101, Loss: 0.849744 , best epoch: 99 secs per epoch: 0.055 s
Epoch: 151, Loss: 0.847395 , best epoch: 139 secs per epoch: 0.054 s
Epoch: 201, Loss: 0.847329 , best epoch: 178 secs per epoch: 0.053 s
Epoch: 229, Loss: 0.8473 , best epoch: 178 secs per epoch: 0.053 s
early stop!
total time: 12.017 s
secs per epoch: 0.053 s

Kernel to optimize 1

Freezing kernels
setting grad status of kernel at 0 to 0
setting grad status of kernel at 1 to 1
setting grad status of kernel

Timer unit: 1e-06 s

Total time: 74.3266 s
File: /home/johanna/ICB/multibind/multibind/tl/prediction.py
Function: train_iterative at line 232

Line #      Hits         Time  Per Hit   % Time  Line Contents
   232                                           def train_iterative(
   233                                               train,
   234                                               device,
   235                                               n_kernels=4,
   236                                               w=15,
   237                                               # min_w=10,
   238                                               max_w=20,
   239                                               num_epochs=100,
   240                                               early_stopping=15,
   241                                               log_each=10,
   242                                               opt_kernel_shift=True,
   243                                               opt_kernel_len

In [None]:
#Since natually it takes most of the time to train the model, I'll profile the forward method on its own.

In [10]:
store_rev = train.dataset.store_rev
i, batch = enumerate(train).__next__()
mononuc = batch["mononuc"].to(device)
b = batch["batch"].to(device) if "batch" in batch else None
rounds = batch["rounds"].to(device) if "rounds" in batch else None
countsum = batch["countsum"].to(device) if "countsum" in batch else None
residues = batch["residues"].to(device) if "residues" in batch else None
protein_id = batch["protein_id"].to(device) if "protein_id" in batch else None
inputs = {"mono": mononuc, "batch": b, "countsum": countsum}
if store_rev:
    mononuc_rev = batch["mononuc_rev"].to(device)
    inputs["mono_rev"] = mononuc_rev
if residues is not None:
    inputs["residues"] = residues
if protein_id is not None:
    inputs["protein_id"] = protein_id

In [13]:
%lprun -f model.forward model(**inputs)

Timer unit: 1e-06 s

Total time: 0.001917 s
File: /home/johanna/ICB/multibind/multibind/models/models.py
Function: forward at line 81

Line #      Hits         Time  Per Hit   % Time  Line Contents
    81                                               def forward(self, mono, **kwargs):
    82                                                   # mono_rev=None, di=None, di_rev=None, batch=None, countsum=None, residues=None, protein_id=None):
    83         1          1.0      1.0      0.1          mono_rev = kwargs.get("mono_rev", None)
    84         1          1.0      1.0      0.1          di = kwargs.get("di", None)
    85         1          1.0      1.0      0.1          di_rev = kwargs.get("di_rev", None)
    86         1        120.0    120.0      6.3          mono = self.padding(mono)
    87         1          1.0      1.0      0.1          if mono_rev is None:
    88         1        110.0    110.0      5.7              mono_rev = mb.tl.mono2revmono(mono)
    89                   

# PBM

In [14]:
matlab_path = os.path.join(bd.constants.ANNOTATIONS_DIRECTORY, 'pbm', 'affreg', 'PbmDataHom6_norm.mat')
mat = scipy.io.loadmat(matlab_path)
data = mat['PbmData'][0]
seqs_dna =  data[0][5]
seqs_dna = [s[0][0] for s in seqs_dna]
# load the MSA sequences, one hot encoded
df, signal = bd.datasets.PBM.pbm_homeo_affreg()
# x, y = pickle.load(open('../../data/example_homeo_PbmData.pkl', 'rb'))
# x, y = pickle.load(open('annotations/pbm/example_homeo_PbmData.pkl', 'rb'))

In [15]:
# build a small subsample
# x = x[1:6]
seqs_dna = seqs_dna[0:1000]
signal = signal[0:1, 0:1000]

# shift signal by adding a constant s.t. no negative values are included
signal -= np.min(signal)

# Set up the dataset
df = pd.DataFrame(signal.T)
df['seq'] = seqs_dna
df.index = df['seq']
del df['seq']

In [16]:
dataset = mb.datasets.PBMDataset(df)
train = tdata.DataLoader(dataset=dataset, batch_size=256, shuffle=True)

In [17]:
%lprun -f mb.tl.train_iterative model, best_loss = mb.tl.train_iterative(train, device, num_epochs=500, show_logo=False, early_stopping=50, log_each=50)

next w 15 <class 'int'>
# proteins 1

Kernel to optimize 0

Freezing kernels
setting grad status of kernel at 0 to 1
setting grad status of kernel at 1 to 0
setting grad status of kernel at 2 to 0
setting grad status of kernel at 3 to 0


kernels mask None
optimizing using <class 'torch.optim.adam.Adam'> and <class 'multibind.tl.loss.MSELoss'> n_epochs 500 early_stopping 50
lr= 0.01, weight_decay= 0.001, dir weight= 0
Epoch: 51, Loss: 175706.113281 , best epoch: 49 secs per epoch: 0.075 s
Epoch: 101, Loss: 175599.359375 , best epoch: 99 secs per epoch: 0.075 s
Epoch: 151, Loss: 175561.937500 , best epoch: 149 secs per epoch: 0.075 s
Epoch: 201, Loss: 175543.156250 , best epoch: 199 secs per epoch: 0.075 s
Epoch: 251, Loss: 175532.582031 , best epoch: 248 secs per epoch: 0.075 s
Epoch: 301, Loss: 175526.265625 , best epoch: 295 secs per epoch: 0.075 s
Epoch: 351, Loss: 175521.917969 , best epoch: 347 secs per epoch: 0.076 s
Epoch: 401, Loss: 175517.859375 , best epoch: 386 secs per epoc

Timer unit: 1e-06 s

Total time: 492.297 s
File: /home/johanna/ICB/multibind/multibind/tl/prediction.py
Function: train_iterative at line 232

Line #      Hits         Time  Per Hit   % Time  Line Contents
   232                                           def train_iterative(
   233                                               train,
   234                                               device,
   235                                               n_kernels=4,
   236                                               w=15,
   237                                               # min_w=10,
   238                                               max_w=20,
   239                                               num_epochs=100,
   240                                               early_stopping=15,
   241                                               log_each=10,
   242                                               opt_kernel_shift=True,
   243                                               opt_kernel_len

In [18]:
store_rev = train.dataset.store_rev
i, batch = enumerate(train).__next__()
mononuc = batch["mononuc"].to(device)
b = batch["batch"].to(device) if "batch" in batch else None
rounds = batch["rounds"].to(device) if "rounds" in batch else None
countsum = batch["countsum"].to(device) if "countsum" in batch else None
residues = batch["residues"].to(device) if "residues" in batch else None
protein_id = batch["protein_id"].to(device) if "protein_id" in batch else None
inputs = {"mono": mononuc, "batch": b, "countsum": countsum}
if store_rev:
    mononuc_rev = batch["mononuc_rev"].to(device)
    inputs["mono_rev"] = mononuc_rev
if residues is not None:
    inputs["residues"] = residues
if protein_id is not None:
    inputs["protein_id"] = protein_id

In [19]:
%lprun -f model.forward model(**inputs)

Timer unit: 1e-06 s

Total time: 0.024673 s
File: /home/johanna/ICB/multibind/multibind/models/models.py
Function: forward at line 81

Line #      Hits         Time  Per Hit   % Time  Line Contents
    81                                               def forward(self, mono, **kwargs):
    82                                                   # mono_rev=None, di=None, di_rev=None, batch=None, countsum=None, residues=None, protein_id=None):
    83         1          1.0      1.0      0.0          mono_rev = kwargs.get("mono_rev", None)
    84         1          1.0      1.0      0.0          di = kwargs.get("di", None)
    85         1          1.0      1.0      0.0          di_rev = kwargs.get("di_rev", None)
    86         1        157.0    157.0      0.6          mono = self.padding(mono)
    87         1          1.0      1.0      0.0          if mono_rev is None:
    88         1        185.0    185.0      0.7              mono_rev = mb.tl.mono2revmono(mono)
    89                   

# Genomics Dataset (ATAC) - separated kernels

In [21]:
adata = mb.bindome.datasets.scATAC.PBMCs_10x_v2(datadir='../../atac_poisson_study/data/')
peak_ids = adata.var_names
adata.var['summit'] = ((adata.var['end'] + adata.var['start']) / 2).astype(int)
adata.var['summit.start'] = adata.var['summit'] - 100
adata.var['summit.end'] = adata.var['summit'] + 100
adata.var['k.summit'] = adata.var['chr'] + ':' + adata.var['summit.start'].astype(str) + '-' + adata.var['summit.end'].astype(str)
n_seqs = 1000
seqs = mb.bindome.tl.get_sequences_from_bed(adata.var[['chr', 'summit.start', 'summit.end']].head(n_seqs), genome='hg38', uppercase=True)
keys = set([s[0] for s in seqs])
adata = adata[:,adata.var['k.summit'].isin(keys)]
# seqs = [[s[0], s[1].upper()] for s in seqs[0]]
adata.shape

/tmp/tmppwpw7qtf
genome hg38 False
options
/home/johanna/ICB/annotations/hg38/genome/hg38.fa
True /home/johanna/ICB/annotations/hg38/genome/hg38.fa
running bedtools...
bedtools getfasta -fi /home/johanna/ICB/annotations/hg38/genome/hg38.fa -bed /tmp/tmppwpw7qtf -fo /tmp/tmpxcm0i89j


(10246, 1000)

In [25]:
# remove Ns
seqs = [[s[0], s[1].replace('N', '')] for s in seqs]
counts = adata.X.T
next_data = pd.DataFrame.sparse.from_spmatrix(counts)
next_data = next_data[range(20)].copy()
next_data['seq'] = [s[1] for s in seqs]
var = []
for ri, r in next_data.iterrows():
    if ri % 10000 == 0:
        print(ri, next_data.shape)
    # print(ri, r.values[:-1], r.values[:-1].var())
    var.append(r.values[:-1].var())
    # break
next_data['var'] = var
top_var = next_data[['var']].sort_values('var', ascending=False).index[:1000]

0 (1000, 21)


In [26]:
# next_data = next_data.head(10000)
next_data_sel = next_data.reindex(top_var).reset_index(drop=True)
del next_data_sel['var']
next_data_sel.index = next_data_sel['seq']
del next_data_sel['seq']
df = next_data_sel
dataset = mb.datasets.GenomicsDataset(df)
train = tdata.DataLoader(dataset=dataset, batch_size=256, shuffle=True)

In [27]:
%lprun -f mb.tl.train_iterative model, best_loss = mb.tl.train_iterative(train, device, num_epochs=500, show_logo=False, early_stopping=50, log_each=50)

next w 15 <class 'int'>
# cells 20

Kernel to optimize 0

Freezing kernels
setting grad status of kernel at 0 to 1
setting grad status of kernel at 1 to 0
setting grad status of kernel at 2 to 0
setting grad status of kernel at 3 to 0


kernels mask None
optimizing using <class 'torch.optim.adam.Adam'> and <class 'multibind.tl.loss.MSELoss'> n_epochs 500 early_stopping 50
lr= 0.01, weight_decay= 0.001, dir weight= 0
Epoch: 51, Loss: 1587260.705696 , best epoch: 46 secs per epoch: 2.104 s
Epoch: 101, Loss: 1587257.107595 , best epoch: 88 secs per epoch: 2.114 s
Epoch: 139, Loss: 1587258.2184 , best epoch: 88 secs per epoch: 2.107 s
early stop!
total time: 290.729 s
secs per epoch: 2.107 s

Kernel to optimize 1

Freezing kernels
setting grad status of kernel at 0 to 0
setting grad status of kernel at 1 to 1
setting grad status of kernel at 2 to 0
setting grad status of kernel at 3 to 0


kernels mask None
optimizing using <class 'torch.optim.adam.Adam'> and <class 'multibind.tl.loss.MSEL

Timer unit: 1e-06 s

Total time: 6444.22 s
File: /home/johanna/ICB/multibind/multibind/tl/prediction.py
Function: train_iterative at line 232

Line #      Hits         Time  Per Hit   % Time  Line Contents
   232                                           def train_iterative(
   233                                               train,
   234                                               device,
   235                                               n_kernels=4,
   236                                               w=15,
   237                                               # min_w=10,
   238                                               max_w=20,
   239                                               num_epochs=100,
   240                                               early_stopping=15,
   241                                               log_each=10,
   242                                               opt_kernel_shift=True,
   243                                               opt_kernel_len

In [28]:
store_rev = train.dataset.store_rev
i, batch = enumerate(train).__next__()
mononuc = batch["mononuc"].to(device)
b = batch["batch"].to(device) if "batch" in batch else None
rounds = batch["rounds"].to(device) if "rounds" in batch else None
countsum = batch["countsum"].to(device) if "countsum" in batch else None
residues = batch["residues"].to(device) if "residues" in batch else None
protein_id = batch["protein_id"].to(device) if "protein_id" in batch else None
inputs = {"mono": mononuc, "batch": b, "countsum": countsum}
if store_rev:
    mononuc_rev = batch["mononuc_rev"].to(device)
    inputs["mono_rev"] = mononuc_rev
if residues is not None:
    inputs["residues"] = residues
if protein_id is not None:
    inputs["protein_id"] = protein_id

In [29]:
%lprun -f model.forward model(**inputs)

Timer unit: 1e-06 s

Total time: 0.030239 s
File: /home/johanna/ICB/multibind/multibind/models/models.py
Function: forward at line 81

Line #      Hits         Time  Per Hit   % Time  Line Contents
    81                                               def forward(self, mono, **kwargs):
    82                                                   # mono_rev=None, di=None, di_rev=None, batch=None, countsum=None, residues=None, protein_id=None):
    83         1          1.0      1.0      0.0          mono_rev = kwargs.get("mono_rev", None)
    84         1          0.0      0.0      0.0          di = kwargs.get("di", None)
    85         1          0.0      0.0      0.0          di_rev = kwargs.get("di_rev", None)
    86         1        177.0    177.0      0.6          mono = self.padding(mono)
    87         1          1.0      1.0      0.0          if mono_rev is None:
    88         1        195.0    195.0      0.6              mono_rev = mb.tl.mono2revmono(mono)
    89                   

# Genomics Dataset (ATAC) - joint learning

In [5]:
adata = mb.bindome.datasets.scATAC.PBMCs_10x_v2(datadir='../../atac_poisson_study/data/')
peak_ids = adata.var_names
adata.var['summit'] = ((adata.var['end'] + adata.var['start']) / 2).astype(int)
adata.var['summit.start'] = adata.var['summit'] - 100
adata.var['summit.end'] = adata.var['summit'] + 100
adata.var['k.summit'] = adata.var['chr'] + ':' + adata.var['summit.start'].astype(str) + '-' + adata.var['summit.end'].astype(str)
n_seqs = 1000
seqs = mb.bindome.tl.get_sequences_from_bed(adata.var[['chr', 'summit.start', 'summit.end']].head(n_seqs), genome='hg38', uppercase=True)
keys = set([s[0] for s in seqs])
adata = adata[:,adata.var['k.summit'].isin(keys)]
# seqs = [[s[0], s[1].upper()] for s in seqs[0]]
adata.shape

/tmp/tmp3t23lcnb
genome hg38 False
options
/home/johanna/ICB/annotations/hg38/genome/hg38.fa
True /home/johanna/ICB/annotations/hg38/genome/hg38.fa
running bedtools...
bedtools getfasta -fi /home/johanna/ICB/annotations/hg38/genome/hg38.fa -bed /tmp/tmp3t23lcnb -fo /tmp/tmp7im5j941


(10246, 1000)

In [6]:
# remove Ns
seqs = [[s[0], s[1].replace('N', '')] for s in seqs]
counts = adata.X.T
next_data = pd.DataFrame.sparse.from_spmatrix(counts)
next_data = next_data[range(20)].copy()
next_data['seq'] = [s[1] for s in seqs]
var = []
for ri, r in next_data.iterrows():
    if ri % 10000 == 0:
        print(ri, next_data.shape)
    # print(ri, r.values[:-1], r.values[:-1].var())
    var.append(r.values[:-1].var())
    # break
next_data['var'] = var
top_var = next_data[['var']].sort_values('var', ascending=False).index[:1000]

0 (1000, 21)


In [7]:
# next_data = next_data.head(10000)
next_data_sel = next_data.reindex(top_var).reset_index(drop=True)
del next_data_sel['var']
next_data_sel.index = next_data_sel['seq']
del next_data_sel['seq']
df = next_data_sel
dataset = mb.datasets.GenomicsDataset(df)
train = tdata.DataLoader(dataset=dataset, batch_size=256, shuffle=True)

In [8]:
%lprun -f mb.tl.train_iterative model, best_loss = mb.tl.train_iterative(train, device, joint_learning=True, num_epochs=500, show_logo=False, early_stopping=50, log_each=50)

next w 15 <class 'int'>
# cells 20

Kernel to optimize 0

Freezing kernels
setting grad status of kernel at 0 to 1
setting grad status of kernel at 1 to 0
setting grad status of kernel at 2 to 0
setting grad status of kernel at 3 to 0


kernels mask None
optimizing using <class 'torch.optim.adam.Adam'> and <class 'multibind.tl.loss.MSELoss'> n_epochs 500 early_stopping 50
lr= 0.01, weight_decay= 0.001, dir weight= 0
Epoch: 51, Loss: 1648303.976266 , best epoch: 47 secs per epoch: 4.097 s
Epoch: 101, Loss: 1648304.770570 , best epoch: 95 secs per epoch: 3.895 s
Epoch: 151, Loss: 1648305.180380 , best epoch: 138 secs per epoch: 3.866 s
Epoch: 189, Loss: 1648304.0016 , best epoch: 138 secs per epoch: 3.867 s
early stop!
total time: 727.041 s
secs per epoch: 3.867 s

Kernel to optimize 1

Freezing kernels
setting grad status of kernel at 0 to 0
setting grad status of kernel at 1 to 1
setting grad status of kernel at 2 to 0
setting grad status of kernel at 3 to 0


kernels mask None
optimiz

Timer unit: 1e-06 s

Total time: 8357.71 s
File: /home/johanna/ICB/multibind/multibind/tl/prediction.py
Function: train_iterative at line 232

Line #      Hits         Time  Per Hit   % Time  Line Contents
   232                                           def train_iterative(
   233                                               train,
   234                                               device,
   235                                               n_kernels=4,
   236                                               w=15,
   237                                               # min_w=10,
   238                                               max_w=20,
   239                                               num_epochs=100,
   240                                               early_stopping=15,
   241                                               log_each=10,
   242                                               opt_kernel_shift=True,
   243                                               opt_kernel_len

In [9]:
store_rev = train.dataset.store_rev
i, batch = enumerate(train).__next__()
mononuc = batch["mononuc"].to(device)
b = batch["batch"].to(device) if "batch" in batch else None
rounds = batch["rounds"].to(device) if "rounds" in batch else None
countsum = batch["countsum"].to(device) if "countsum" in batch else None
residues = batch["residues"].to(device) if "residues" in batch else None
protein_id = batch["protein_id"].to(device) if "protein_id" in batch else None
inputs = {"mono": mononuc, "batch": b, "countsum": countsum}
if store_rev:
    mononuc_rev = batch["mononuc_rev"].to(device)
    inputs["mono_rev"] = mononuc_rev
if residues is not None:
    inputs["residues"] = residues
if protein_id is not None:
    inputs["protein_id"] = protein_id

In [10]:
%lprun -f model.forward model(**inputs)

Timer unit: 1e-06 s

Total time: 0.012556 s
File: /home/johanna/ICB/multibind/multibind/models/models.py
Function: forward at line 81

Line #      Hits         Time  Per Hit   % Time  Line Contents
    81                                               def forward(self, mono, **kwargs):
    82                                                   # mono_rev=None, di=None, di_rev=None, batch=None, countsum=None, residues=None, protein_id=None):
    83         1          1.0      1.0      0.0          mono_rev = kwargs.get("mono_rev", None)
    84         1          1.0      1.0      0.0          di = kwargs.get("di", None)
    85         1          1.0      1.0      0.0          di_rev = kwargs.get("di_rev", None)
    86         1        265.0    265.0      2.1          mono = self.padding(mono)
    87         1          1.0      1.0      0.0          if mono_rev is None:
    88         1        395.0    395.0      3.1              mono_rev = mb.tl.mono2revmono(mono)
    89                   

# PBM after improvement

In [11]:
matlab_path = os.path.join(bd.constants.ANNOTATIONS_DIRECTORY, 'pbm', 'affreg', 'PbmDataHom6_norm.mat')
mat = scipy.io.loadmat(matlab_path)
data = mat['PbmData'][0]
seqs_dna =  data[0][5]
seqs_dna = [s[0][0] for s in seqs_dna]
# load the MSA sequences, one hot encoded
df, signal = bd.datasets.PBM.pbm_homeo_affreg()
# x, y = pickle.load(open('../../data/example_homeo_PbmData.pkl', 'rb'))
# x, y = pickle.load(open('annotations/pbm/example_homeo_PbmData.pkl', 'rb'))

In [12]:
# build a small subsample
# x = x[1:6]
seqs_dna = seqs_dna[0:1000]
signal = signal[0:1, 0:1000]

# shift signal by adding a constant s.t. no negative values are included
signal -= np.min(signal)

# Set up the dataset
df = pd.DataFrame(signal.T)
df['seq'] = seqs_dna
df.index = df['seq']
del df['seq']

In [13]:
dataset = mb.datasets.PBMDataset(df)
train = tdata.DataLoader(dataset=dataset, batch_size=256, shuffle=True)

In [14]:
%lprun -f mb.tl.train_iterative model, best_loss = mb.tl.train_iterative(train, device, num_epochs=500, show_logo=False, early_stopping=50, log_each=50)

next w 15 <class 'int'>
# proteins 1

Kernel to optimize 0

Freezing kernels
setting grad status of kernel at 0 to 1
setting grad status of kernel at 1 to 0
setting grad status of kernel at 2 to 0
setting grad status of kernel at 3 to 0


kernels mask None
optimizing using <class 'torch.optim.adam.Adam'> and <class 'multibind.tl.loss.MSELoss'> n_epochs 500 early_stopping 50
lr= 0.01, weight_decay= 0.001, dir weight= 0
Epoch: 51, Loss: 196404.035156 , best epoch: 49 secs per epoch: 0.079 s
Epoch: 101, Loss: 196289.984375 , best epoch: 99 secs per epoch: 0.077 s
Epoch: 151, Loss: 196249.402344 , best epoch: 149 secs per epoch: 0.077 s
Epoch: 201, Loss: 196230.250000 , best epoch: 199 secs per epoch: 0.077 s
Epoch: 251, Loss: 196219.617188 , best epoch: 245 secs per epoch: 0.078 s
Epoch: 301, Loss: 196212.281250 , best epoch: 299 secs per epoch: 0.078 s
Epoch: 351, Loss: 196208.464844 , best epoch: 343 secs per epoch: 0.078 s
Epoch: 401, Loss: 196205.296875 , best epoch: 396 secs per epoc

Timer unit: 1e-06 s

Total time: 273.429 s
File: /home/johanna/ICB/multibind/multibind/tl/prediction.py
Function: train_iterative at line 232

Line #      Hits         Time  Per Hit   % Time  Line Contents
   232                                           def train_iterative(
   233                                               train,
   234                                               device,
   235                                               n_kernels=4,
   236                                               w=15,
   237                                               # min_w=10,
   238                                               max_w=20,
   239                                               num_epochs=100,
   240                                               early_stopping=15,
   241                                               log_each=10,
   242                                               opt_kernel_shift=True,
   243                                               opt_kernel_len

In [15]:
store_rev = train.dataset.store_rev
i, batch = enumerate(train).__next__()
mononuc = batch["mononuc"].to(device)
b = batch["batch"].to(device) if "batch" in batch else None
rounds = batch["rounds"].to(device) if "rounds" in batch else None
countsum = batch["countsum"].to(device) if "countsum" in batch else None
residues = batch["residues"].to(device) if "residues" in batch else None
protein_id = batch["protein_id"].to(device) if "protein_id" in batch else None
inputs = {"mono": mononuc, "batch": b, "countsum": countsum}
if store_rev:
    mononuc_rev = batch["mononuc_rev"].to(device)
    inputs["mono_rev"] = mononuc_rev
if residues is not None:
    inputs["residues"] = residues
if protein_id is not None:
    inputs["protein_id"] = protein_id

In [16]:
%lprun -f model.forward model(**inputs)

Timer unit: 1e-06 s

Total time: 0.002318 s
File: /home/johanna/ICB/multibind/multibind/models/models.py
Function: forward at line 81

Line #      Hits         Time  Per Hit   % Time  Line Contents
    81                                               def forward(self, mono, **kwargs):
    82                                                   # mono_rev=None, di=None, di_rev=None, batch=None, countsum=None, residues=None, protein_id=None):
    83         1          1.0      1.0      0.0          mono_rev = kwargs.get("mono_rev", None)
    84         1          0.0      0.0      0.0          di = kwargs.get("di", None)
    85         1          0.0      0.0      0.0          di_rev = kwargs.get("di_rev", None)
    86         1        165.0    165.0      7.1          mono = self.padding(mono)
    87         1          1.0      1.0      0.0          if mono_rev is None:
    88         1        175.0    175.0      7.5              mono_rev = mb.tl.mono2revmono(mono)
    89                   