In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import multibind as mb
import numpy as np
import pandas as pd
import torch
import bindome as bd
bd.constants.ANNOTATIONS_DIRECTORY = 'annotations'
# mb.models.MultiBind
import torch.optim as topti
import torch.utils.data as tdata
import matplotlib.pyplot as plt
import logomaker
import seaborn as sns
from sklearn.metrics import r2_score

# Run on the first CUDA device when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Using device: " + str(device))

  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu


In [6]:
from matplotlib import rcParams
rcParams['figure.figsize'] = 5, 1

In [6]:
# Load the CTCF table through bindome's ProBound dataset accessor (flank_length=0).
df = mb.bindome.datasets.ProBound.ctcf(flank_length=0)
# Work on a 2000-row random subsample. The fixed random_state makes the same rows
# come back after every kernel restart, so results are comparable between runs.
data = df.sample(n=2000, random_state=0)

In [7]:
# Wrap the sampled table in a SelexDataset and serve it in shuffled mini-batches of 256.
n_rounds = 1
dataset = mb.datasets.SelexDataset(data, n_rounds=n_rounds)
train = tdata.DataLoader(
    dataset=dataset,
    batch_size=256,
    shuffle=True,
)

## Optimizing the last steps of `forward`

In [7]:
# Build the model and move it to the selected device. n_rounds reuses the value
# that was passed to SelexDataset so the model and the data cannot drift apart.
model = mb.models.DinucSelex(
    use_dinuc=False, kernels=[0, 14, 12], n_rounds=n_rounds, n_batches=1
).to(device)
# Adam over all model parameters, with a small weight decay.
optimiser = topti.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
# optimiser = topti.LBFGS(model.parameters())
criterion = mb.tl.PoissonLoss()

In [9]:
%load_ext line_profiler

In [16]:
# Take the first (index, batch) pair from the loader to use as profiling input.
i, batch = next(enumerate(train))
mononuc = batch["mononuc"].to(device)

# "batch" and "countsum" are optional keys; use None when a key is absent.
b = None
if "batch" in batch:
    b = batch["batch"].to(device)
countsum = None
if "countsum" in batch:
    countsum = batch["countsum"].to(device)
inputs = (mononuc, b, countsum)

# Reset the gradients tracked by the optimiser.
optimiser.zero_grad()

In [24]:
%lprun -f model.forward model(inputs)

Timer unit: 1e-06 s

Total time: 0.006034 s
File: /home/johanna/ICB/multibind/multibind/models/models.py
Function: forward at line 105

Line #      Hits         Time  Per Hit   % Time  Line Contents
   105                                               def forward(self, x, min_value=1e-15):
   106                                                   # Create the forward pass through the network.
   107         1          1.0      1.0      0.0          mono, batch, countsum = x
   108                                           
   109                                                   # convert mono to dinuc
   110                                                   # print(mono.shape)
   111                                           
   112                                                   # print(mono.shape, di.shape)
   113                                                   # assert False
   114                                           
   115                                                 

## Testing `flip` vs. advanced indexing for `_mono2revmono`

In [64]:
# Draw one (index, batch) pair from the loader as input for the timing cells.
i, batch = next(enumerate(train))
mononuc = batch["mononuc"].to(device)

# Optional entries: left as None when the loader does not provide them.
b = None
if "batch" in batch:
    b = batch["batch"].to(device)
countsum = None
if "countsum" in batch:
    countsum = batch["countsum"].to(device)
inputs = (mononuc, b, countsum)

In [65]:
def _mono2revmono_flip(x):
    return torch.flip(x, [2])[:, [3, 2, 1, 0], :]

In [66]:
%timeit _mono2revmono_flip(mononuc)

62.5 µs ± 8.34 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [67]:
def _mono2revmono_flip2(x):
    return torch.flip(x, [1, 2])

In [68]:
%timeit _mono2revmono_flip2(mononuc)

18.6 µs ± 135 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [69]:
def _mono2revmono_index(x):
    n = mononuc.shape[2]
    reverse_index = torch.arange(n - 1, -1, -1)
    # reverse_index
    compl_bases = torch.tensor([3, 2, 1, 0])
    compl_bases = compl_bases.repeat(30, 1).T
    reverse_index = reverse_index.repeat(4, 1)
    return mononuc[:, compl_bases, reverse_index]

In [70]:
%timeit _mono2revmono_index(mononuc)

47 µs ± 2.73 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [71]:
def _mono2revmono_index2(x):
    n = mononuc.shape[2]
    reverse_index = torch.arange(n - 1, -1, -1, device=device)
    # reverse_index
    compl_bases = torch.tensor([3, 2, 1, 0])
    compl_bases = compl_bases.expand(30, 4).T
    reverse_index = reverse_index.expand(4, 30)
    return mononuc[:, compl_bases, reverse_index]

In [72]:
%timeit _mono2revmono_index2(mononuc)

38.4 µs ± 2.86 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [73]:
%lprun -f _mono2revmono_index2 _mono2revmono_index2(mononuc)

Timer unit: 1e-06 s

Total time: 0.000598 s
File: /tmp/ipykernel_11638/2604124128.py
Function: _mono2revmono_index2 at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
     1                                           def _mono2revmono_index2(x):
     2         1         10.0     10.0      1.7      n = mononuc.shape[2]
     3         1        354.0    354.0     59.2      reverse_index = torch.arange(n - 1, -1, -1, device=device)
     4                                               # reverse_index
     5         1         25.0     25.0      4.2      compl_bases = torch.tensor([3, 2, 1, 0])
     6         1         63.0     63.0     10.5      compl_bases = compl_bases.expand(30, 4).T
     7         1          5.0      5.0      0.8      reverse_index = reverse_index.expand(4, 30)
     8         1        141.0    141.0     23.6      return mononuc[:, compl_bases, reverse_index]

In [74]:
import time
import torch

# Benchmark fixture: a random (batch_size, n) tensor and the number of timed repetitions.
n, batch_size, ntrials = 1024, 256, 1000
x = torch.randn(batch_size, n)

In [75]:
# Time ntrials reversals of the last axis with Tensor.flip.
# A plain loop is used so each result is released right away; collecting all
# ntrials outputs in a list would keep every result tensor alive for the whole
# measurement and fold allocation pressure into the timing.
start = time.perf_counter()
for _trial in range(ntrials):
    x.flip(-1)
end = time.perf_counter()
print('Flip time (CPU): {}s'.format(end - start))

# Same reversal via advanced indexing with a precomputed descending index.
reverse_index = torch.arange(n - 1, -1, -1)
start = time.perf_counter()
for _trial in range(ntrials):
    x[..., reverse_index]
end = time.perf_counter()
print('Advanced indexing time (CPU): {}s'.format(end - start))

Flip time (CPU): 0.0904085500005749s
Advanced indexing time (CPU): 0.12989644399931422s


In [59]:
# Repeat the flip vs. advanced-indexing comparison on the GPU. The cell is
# skipped on machines without CUDA, where moving tensors to 'cuda' would raise
# and break a top-to-bottom run of the notebook.
if torch.cuda.is_available():
    # NOTE: this rebinds x / reverse_index to CUDA copies, so re-running the CPU
    # timing cell afterwards would measure CUDA tensors.
    x = x.to('cuda')
    reverse_index = reverse_index.to('cuda')

    # CUDA work is queued asynchronously: synchronise before starting and before
    # stopping the clock so the interval covers the actual device work.
    torch.cuda.synchronize()
    start = time.perf_counter()
    # Discard each result instead of collecting them, so the outputs are not
    # all kept alive in device memory during the measurement.
    for _trial in range(ntrials):
        x.flip(-1)
    torch.cuda.synchronize()
    end = time.perf_counter()
    print('Flip time (CUDA): {}s'.format(end - start))

    start = time.perf_counter()
    for _trial in range(ntrials):
        x[..., reverse_index]
    torch.cuda.synchronize()
    end = time.perf_counter()
    print('Advanced indexing time (CUDA): {}s'.format(end - start))
else:
    print('CUDA not available; skipping the GPU timings.')