### Two datasets with a similar motif

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# (except those excluded via %aimport) before each cell is executed, so edits
# to library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

: 

In [None]:
import mubind as mb
import numpy as np
import pandas as pd

import torch
import torch.optim as topti
import torch.utils.data as tdata
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import r2_score

# Pick the compute device once: the first CUDA device when PyTorch can see
# one, otherwise the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

: 

In [None]:
import random

# Seed every random number generator used in this notebook so that
# "Restart Kernel -> Run All" reproduces the same run. Seeding only Python's
# `random` leaves NumPy and PyTorch (weight initialisation, shuffling)
# unseeded.
SEED = 500
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

n_batch = 3
# One entry per batch (three batches here). Per the original author's note,
# later batches have many times more reads than earlier ones.
# NOTE(review): exact meaning of `batch_sizes` is defined by
# mb.tl.create_simulated_data - confirm there.
batch_sizes = [5, 500, 5000]
train1, test1 = mb.tl.create_simulated_data(motif='GATA', n_batch=n_batch, n_trials=1000, seqlen=10, batch_sizes=batch_sizes) # multiplier=100)
# Second, longer-sequence dataset - currently disabled.
# train2, test2 = mb.tl.create_simulated_data(motif='GATA', batch=1, n_trials=5000, seqlen=25) #  multiplier=5000)

: 

### When the targets (y) are log-transformed, training converges to the GATA motif on short sequences.

In [None]:
# Log-transform the training targets in place, exactly once.
# The assignment overwrites its own input, so re-running this cell without the
# guard would apply log() a second time. The flag lives on the dataset object
# itself, so a freshly simulated dataset (which has no flag) is transformed
# again as expected.
if not getattr(train1.dataset, '_target_is_log', False):
    train1.dataset.target = np.log(train1.dataset.target)
    train1.dataset._target_is_log = True

: 

### Library sizes

In [None]:
# Collect the (log-transformed) target and the batch label of every training
# example in one frame, then draw one box per batch.
df = pd.DataFrame({
    'y': train1.dataset.target,
    'batch': train1.dataset.batch,
})
sns.boxplot(data=df, x='batch', y='y')

: 

In [None]:
# Build the model with dinucleotide features switched off and one dataset
# slot per simulated batch, then move it to the selected device.
# NOTE(review): `w=7` is presumably the motif/kernel width - confirm in DinucMulti.
net2 = mb.models.DinucMulti(use_dinuc=False, n_datasets=n_batch, w=7)
net2 = net2.to(device)
# Sampling uniformly from the zero-width interval [1, 1] sets every dataset
# weight to 1; the filled tensor is the cell's displayed output.
net2.dataset.weight.data.uniform_(1, 1)

: 

In [None]:
# Forward pass of the (still untrained) model over the full training set.
mononuc = torch.Tensor(train1.dataset.mononuc).to(device)
dinuc = torch.Tensor(train1.dataset.dinuc).to(device)
# `b` and `y_true` are kept on the CPU because later cells hand them to
# scikit-learn / matplotlib, which cannot consume CUDA tensors.
b = torch.Tensor(train1.dataset.batch).to(torch.int64)
y_true = torch.Tensor(train1.dataset.target)

# The model lives on `device`, so every tensor it receives must be there too;
# only the copy inside `inputs` is moved.
inputs = (mononuc, dinuc, b.to(device))
with torch.no_grad():  # inference only - no autograd graph needed
    # .cpu() is required before .numpy() when the model runs on a GPU.
    y_pred = net2(inputs).detach().cpu().numpy()

: 

In [None]:
# Sanity-check shapes: batch labels, mononucleotide input, and the output of
# the mononucleotide convolution.
# The input gets an extra axis at position 1 and is moved to `device`, where
# net2 lives, so the call also works on a GPU.
conv_input = torch.unsqueeze(torch.Tensor(train1.dataset.mononuc), 1).to(device)
with torch.no_grad():  # shape inspection only
    conv_output_shape = net2.conv_mono(conv_input).shape
# Only a cell's last expression is displayed, so all three shapes are
# returned together.
train1.dataset.batch.shape, train1.dataset.mononuc.shape, conv_output_shape

: 

In [None]:
# Train net2 on the simulated data with a Poisson loss and Adam (with weight
# decay). Whatever train_network returns is appended item by item to `l2`.
# NOTE(review): presumably the per-log-step loss history - confirm in mb.tl.train_network.
criterion = mb.tl.PoissonLoss()
optimiser = topti.Adam(
    net2.parameters(),
    lr=0.001,
    weight_decay=0.0001,
)
l2 = []
l2.extend(
    mb.tl.train_network(
        net2,
        train1,
        device,
        optimiser,
        criterion,
        num_epochs=5000,
        log_each=100,
    )
)

: 

In [None]:
# Line plot of the values collected in `l2` during training, drawn on the
# current axes.
history_ax = plt.gca()
history_ax.plot(l2)

: 

In [None]:
# Check the batch effects: display the `weight` parameter of `net2.dataset`
# after training. It was set to all ones before training, so values that
# moved away from 1 were changed by the optimiser.
# NOTE(review): presumably one scaling factor per batch - confirm in DinucMulti.
net2.dataset.weight

: 

In [None]:
# Render the learned mononucleotide filter as a sequence logo.
# (Earlier drafts called mb.tl.create_logo / mb.tl.create_heatmap on another
# model object; those calls are not used here.)
import logomaker

# Drop singleton axes from the conv kernel, bring it to host memory and label
# the four rows with the nucleotides.
# NOTE(review): assumes the squeezed kernel has exactly 4 rows ordered A, C, G, T - confirm against the encoder.
weights = pd.DataFrame(
    net2.conv_mono.weight.squeeze().cpu().detach().numpy(),
    index=['A', 'C', 'G', 'T'],
)
# logomaker gets the transposed frame, i.e. nucleotides as columns.
crp_logo = logomaker.Logo(weights.T, shade_below=.5, fade_below=.5)

: 

In [None]:
from sklearn.metrics import r2_score

: 

In [None]:
# Display the dataset `weight` parameter of net2 once more, for reference
# next to the prediction-quality checks that follow.
net2.dataset.weight

: 

### Check the quality of the predictions, across datasets

In [None]:
# Forward pass of the trained model over the full training set.
mononuc = torch.Tensor(train1.dataset.mononuc).to(device)
dinuc = torch.Tensor(train1.dataset.dinuc).to(device)
# `b` and `y_true` are kept on the CPU: the following cells pass them to
# r2_score and plt.scatter, which cannot consume CUDA tensors.
b = torch.Tensor(train1.dataset.batch).to(torch.int64)
y_true = torch.Tensor(train1.dataset.target)
# The model lives on `device`, so the batch indices it receives must be there
# too; only the copy inside `inputs` is moved.
inputs = (mononuc, dinuc, b.to(device))
with torch.no_grad():  # inference only - no autograd graph needed
    # .cpu() is required before .numpy() when the model runs on a GPU.
    y_pred = net2(inputs).detach().cpu().numpy()

: 

In [None]:
# Coefficient of determination between the stored targets and the model's
# predictions over the whole training set.
r2_score(y_true=y_true, y_pred=y_pred)

: 

In [None]:
# Observed vs. predicted values, one point per training example, coloured by
# batch index.
scatter_ax = plt.gca()
plt.scatter(y_true, y_pred, s=5, c=b)
scatter_ax.set_xlabel('observed')
scatter_ax.set_ylabel('predicted')

: 