In [1]:
from explainn import tools
from explainn import networks
from explainn import train
from explainn import test
from explainn import interpretation

from Bio import SeqIO
import torch
import os
import torch.nn as nn
import numpy as np
from sklearn.model_selection import train_test_split
import pandas as pd
import logomaker
import math
import pickle
import h5py
import copy
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
# Training configuration. NOTE: the training cells below override num_epochs (to 20).
num_epochs = 10
batch_size = 100
learning_rate = 0.01

# Seed numpy and torch (all devices) so torch-based weight initialisation is repeatable
# across re-runs; some GPU kernels may still be non-deterministic.
SEED = 40
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Building the dataset from DanQ outputs

In [4]:
# Trim 25 bp from each end of every AI-TAC peak sequence, upper-case it and write the
# result to a new FASTA (presumably 251 bp -> 201 bp, per the output file name).
# Both files are opened in a `with` block so the input handle is closed as well.
with open("../../AI-TAC/ai_tac_data/sequences.fasta") as in_handle, \
        open("../new_aitac_201_sequences.fasta", "w") as out_file:
    for record in SeqIO.parse(in_handle, 'fasta'):
        trimmed = str(record.seq)[25:-25].upper()
        out_file.write(">" + record.id + "\n")
        out_file.write(trimmed + "\n")

In [3]:
# The 81 cell-type names (one per model output, num_classes=81). This list fixes the
# order of `lineage_names` below and of the label columns selected from `y` later on.
cell_type_names = ["LTHSC.34-.BM","LTHSC.34+.BM","MMP3.48+.BM","STHSC.150-.BM",
                   "MMP4.135+.BM","proB.CLP.BM","proB.FrA.BM","proB.FrBC.BM",
                   "B.FrE.BM","B1b.PC","B.T1.Sp","B.T2.Sp","B.T3.Sp","B.Sp",
                   "B.Fem.Sp","B.Fo.Sp","B.MZ.Sp","B.GC.CB.Sp","B.GC.CC.Sp",
                   "B.PB.Sp","B.PC.Sp","B.PC.BM","B.mem.Sp","preT.DN1.Th",
                   "preT.DN2a.Th","preT.DN2b.Th","preT.DN3.Th","T.DN4.Th",
                   "T.ISP.Th","T.DP.Th","T.4.Th","T.8.Th","T.4.Nve.Sp","T.4.Nve.Fem.Sp",
                   "T.4.Sp.aCD3+CD40.18hr","Treg.4.FP3+.Nrplo.Co","Treg.4.25hi.Sp",
                   "T.8.Nve.Sp","T8.TN.P14.Sp","T8.IEL.LCMV.d7.Gut","T8.TE.LCMV.d7.Sp",
                   "T8.MP.LCMV.d7.Sp","T8.Tcm.LCMV.d180.Sp","T8.Tem.LCMV.d180.Sp",
                   "NKT.Sp","NKT.Sp.LPS.3hr","NKT.Sp.LPS.18hr","NKT.Sp.LPS.3d",
                   "Tgd.g2+d17.24a+.Th","Tgd.g2+d17.LN","Tgd.g2+d1.24a+.Th",
                   "Tgd.g2+d1.LN","Tgd.g1.1+d1.24a+.Th","Tgd.g1.1+d1.LN","Tgd.Sp",
                   "NK.27+11b-.BM","NK.27+11b+.BM","NK.27-11b+.BM","NK.27+11b-.Sp",
                   "NK.27+11b+.Sp","NK.27-11b+.Sp","ILC2.SI","ILC3.NKp46-CCR6-.SI",
                   "ILC3.NKp46+.SI","ILC3.CCR6+.SI","GN.BM","GN.Sp","GN.Thio.PC",
                   "Mo.6C+II-.Bl","Mo.6C-II-.Bl","MF.PC","MF.Fem.PC",
                   "MF.226+II+480lo.PC","MF.102+480+.PC","MF.RP.Sp","MF.Alv.Lu",
                   "MF.pIC.Alv.Lu","MF.microglia.CNS","DC.4+.Sp","DC.8+.Sp","DC.pDC.Sp"]

# Lineage label ("lineageModel" column of Table S4) for each cell type, in the
# order of cell_type_names.
cell_line = pd.read_excel('../../AI-TAC/TableS4_lineageCells.xlsx', header=2, index_col=0)
lineage_names = cell_line.loc[cell_type_names, "lineageModel"].tolist()

# Column names of the ATAC intensity table: the header row without its first field
# (presumably the peak-ID column), with surrounding whitespace/newline stripped.
with open('../../AI-TAC/ATAC_Data_Intensity_FilteredPeaksLogQuantile.txt', 'r') as tsv:
    header_fields = tsv.readline().split('\t')

columns = [field.strip() for field in header_fields[1:]]

In [None]:
# Get the DanQ model predictions for every trimmed sequence (bash cell below)

In [None]:
%%bash
# NEW EXPERIMENT: run every trained DanQ model over the trimmed AI-TAC sequences.
# Paths are relative to the notebook's working directory and agree with the Python
# cells, which write ../new_aitac_201_sequences.fasta and read back
# ../DanQ_scan_results_350_models/<TF>/test.tsv.gz.
MODEL_DIR=/mnt/md1/home/oriol/ExplaiNN/results/ReMap/DanQ/mm10-dbTF
OUT_DIR=../DanQ_scan_results_350_models
FASTA=../new_aitac_201_sequences.fasta

mkdir -p "$OUT_DIR"

# Glob instead of parsing `ls`; the trailing slash restricts the loop to directories.
for model_path in "$MODEL_DIR"/*/; do
    tf=$(basename "$model_path")
    mkdir -p "$OUT_DIR/$tf"

    # no sigmoid
    # The model is loaded from the same directory that was enumerated above.
    python predict-danq.py -o "$OUT_DIR/$tf/test.tsv.gz" "$MODEL_DIR/$tf/best_model.pth.tar" "$FASTA"
done

In [4]:
# One sub-directory of DanQ results per trained model, named after its TF.
# NOTE(review): os.listdir order is arbitrary. It defines the column order of Xs, and
# therefore the input order the models below are trained on; do not sort it without
# retraining, or saved checkpoints will no longer line up with their inputs.
myTFs = os.listdir("../DanQ_scan_results_350_models/")

len(myTFs)

350

In [5]:
# Collect the DanQ outputs for every sequence: one column per TF model (its "Mean"
# output, renamed to the TF), one row per sequence ID, in myTFs order.
def read_danq_scores(tf_name):
    """Return the "Mean" DanQ output of model `tf_name` as a Series named `tf_name`."""
    scores = pd.read_csv(f"../DanQ_scan_results_350_models/{tf_name}/test.tsv.gz",
                         sep="\t", index_col=0)
    return scores["Mean"].rename(tf_name)


# Iterate over all of myTFs (not a hard-coded 350) and concatenate once, instead of
# growing the frame inside the loop (which copies it on every iteration).
Xs = pd.concat([read_danq_scores(tf_name) for tf_name in tqdm(myTFs)], axis=1)

100%|██████████| 349/349 [03:45<00:00,  1.55it/s]


In [6]:
# Sanity checks: one column per DanQ model, and the outer-join concat introduced no
# missing values (i.e. every result file covers the same sequence IDs).
assert Xs.shape[1] == len(myTFs)
assert not Xs.isna().any().any(), "DanQ result files disagree on sequence IDs"
Xs.shape

(327927, 350)

In [7]:
# Load the AI-TAC arrays. The sequence and target arrays are cast to float32 so they
# can be wrapped as float tensors later; peak names are kept as loaded.
x, y = (np.load(path).astype(np.float32)
        for path in ("../../AI-TAC/ai_tac_data/one_hot_seqs.npy",
                     "../../AI-TAC/ai_tac_data/cell_type_array.npy"))
peak_names = np.load("../../AI-TAC/ai_tac_data/peak_names.npy")

In [8]:
# First five rows (sequences) of the DanQ score matrix.
Xs.iloc[:5]

Unnamed: 0_level_0,FEZF2,ZFAT,DDIT3,RUNX2,MITF,ZBTB16,GLI1,LYL1,DMRTB1,HNF1A,...,POU5F1,STAT4,CEBPA,MYCN,MXI1,TLX1,ASCL2,ATOH1,TGIF1,TFE3
SeqId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ImmGenATAC1219.peak_3,-0.400502,0.077805,-2.1078,0.279974,-5.009469,-1.752214,0.017303,-1.164692,-0.84927,-1.547905,...,1.27859,-3.962104,-1.276736,1.626089,-1.491597,-1.799529,-1.357136,-1.561578,-6.378598,-2.638047
ImmGenATAC1219.peak_4,-0.514117,-2.256453,-0.937865,-2.08807,-4.940352,-2.538595,-1.91698,-0.680535,-1.956003,-1.229827,...,-1.69917,-3.818219,-1.74184,-1.551899,-2.48617,-0.437989,-1.279083,-1.156343,-4.821173,-3.798195
ImmGenATAC1219.peak_6,0.004966,-0.582273,1.567085,-0.872777,2.592524,-1.416038,-0.605857,-1.377699,-1.698155,-0.585176,...,0.624094,-0.024136,0.107289,-0.442143,-1.055012,1.127333,1.369699,-0.844233,-2.878908,-1.775934
ImmGenATAC1219.peak_7,1.498347,-1.122641,-3.315898,-2.597305,0.901148,-0.157733,-0.425205,-1.432771,-1.210408,-0.854235,...,2.388976,-3.13966,-1.701328,-0.968636,-0.172606,-1.706625,-0.357936,-1.53842,-5.509221,-4.474012
ImmGenATAC1219.peak_8,-0.770809,-1.809848,-1.56681,-0.422452,-5.720912,-1.887979,-1.740428,-1.53967,-1.887203,-1.098496,...,-2.099256,-1.882061,0.918624,-0.786474,-2.211233,-0.910751,0.357631,-1.528017,-4.384786,-2.674277


In [9]:
# Inputs are the per-peak DanQ scores (one column per TF model); targets are the
# per-cell-type values in `y`; peak names travel with the split for later lookup.
# NOTE(review): assumes Xs rows, y and peak_names share the same peak order — confirm.
target_labels = list(Xs)
Xs_values = Xs.values

(train_data, eval_data,
 train_labels, eval_labels,
 train_names, eval_names) = train_test_split(Xs_values, y, peak_names,
                                             test_size=0.1, random_state=40)

# Data loaders serve batches in a fixed order (shuffle=False; the original carried a
# "#True" hint). train_test_split itself shuffles the rows once by default.
train_dataset, eval_dataset = (
    torch.utils.data.TensorDataset(torch.from_numpy(features).float(), torch.from_numpy(labels))
    for features, labels in ((train_data, train_labels), (eval_data, eval_labels))
)
train_loader, eval_loader = (
    torch.utils.data.DataLoader(dataset=ds, batch_size=batch_size, shuffle=False)
    for ds in (train_dataset, eval_dataset)
)

# Training models

In [11]:
# Train a `SingleLayer` network (350 DanQ scores -> 81 cell-type outputs) with
# tools.pearson_loss as the criterion; results are written to ../PaP_Linear_350_no_sigmoid/.
model = networks.SingleLayer(num_inputs=350, num_classes=81).to(device)

criterion = tools.pearson_loss
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# exist_ok=True: re-running the cell must not fail with FileExistsError.
os.makedirs("../PaP_Linear_350_no_sigmoid/", exist_ok=True)

# Overrides the num_epochs set in the configuration cell.
num_epochs = 20


model, train_error, test_error = train.train_explainn(train_loader, eval_loader, 
                                                      model, device, criterion, 
                                                      optimizer, num_epochs,
                                                      "../PaP_Linear_350_no_sigmoid/", "",
                                                      verbose=True, trim_weights=False)

In [None]:
# Train a `NAMLayer` network (350 DanQ scores, 200 hidden units, 81 cell-type outputs)
# with tools.pearson_loss; results are written to ../PaP_Linear_350_no_sigmoid_NAM/.
nam_model = networks.NAMLayer(num_inputs=350, num_hidden=200, num_classes=81).to(device)

criterion = tools.pearson_loss
optimizer = torch.optim.Adam(nam_model.parameters(), lr=learning_rate)

# exist_ok=True: re-running the cell must not fail with FileExistsError.
os.makedirs("../PaP_Linear_350_no_sigmoid_NAM/", exist_ok=True)

# Overrides the num_epochs set in the configuration cell.
num_epochs = 20

# (The former `model = model.double()` was removed: it converted the unrelated linear
# model to float64 and had no effect on nam_model, which trains on float32 inputs.)
nam_model, train_error, test_error = train.train_explainn(train_loader, eval_loader, 
                                                          nam_model, device, criterion, 
                                                          optimizer, num_epochs,
                                                          "../PaP_Linear_350_no_sigmoid_NAM/", "",
                                                          verbose=True, trim_weights=False)

In [16]:
# Reload the trained NAM and evaluate it on the held-out split: for every peak, the
# Pearson correlation between its 81 predicted and 81 observed cell-type values.
nam_model = networks.NAMLayer(num_inputs=350, num_hidden=200, num_classes=81).to(device)
# map_location lets a checkpoint written during a GPU run be restored on a CPU-only machine.
nam_model.load_state_dict(torch.load("../PaP_Linear_350_no_sigmoid_NAM/model_epoch_19_.pth",
                                     map_location=device))
nam_model.eval(); 

# Gather batch outputs in a list and concatenate once (repeated torch.cat is quadratic).
with torch.no_grad():
    batch_predictions = [nam_model(seqs.to(device)) for seqs, _ in eval_loader]
predictions = torch.cat(batch_predictions, 0).cpu().numpy()

# Per-peak correlation (NaN where undefined) and per-peak label variance (used as weights).
# Local names avoid clobbering the one-hot array `x` and the builtin `vars`.
peak_correlations = np.array([np.corrcoef(predictions[i, :], eval_labels[i, :])[0, 1]
                              for i in range(len(predictions))])
label_vars = np.var(eval_labels, axis=1)
is_defined = ~np.isnan(peak_correlations)

# Variance-weighted mean correlation over peaks with a defined correlation. NaNs must
# be excluded first: a single NaN would otherwise make the whole dot product NaN.
weighted_cor = (np.dot(peak_correlations[is_defined], label_vars[is_defined])
                / np.sum(label_vars[is_defined]))

nan_cors = list(peak_correlations[~is_defined])
correlations = list(peak_correlations[is_defined])

print(np.mean(correlations))
print(np.where(np.array(correlations) >= 0.75)[0].shape)

0.3410467614680922
(1541,)


# UMAP visualization

In [17]:
# Reload the trained NAM and record the output of its `linear` sub-module for every
# held-out peak. NOTE: the cell further down recomputes `running_activations` on the
# well-predicted subset, overwriting this result before the UMAP is fitted.
model = networks.NAMLayer(num_inputs=350, num_hidden=200, num_classes=81).to(device)
# map_location lets a checkpoint written during a GPU run be restored on a CPU-only machine.
model.load_state_dict(torch.load("../PaP_Linear_350_no_sigmoid_NAM/model_epoch_19_.pth",
                                 map_location=device))
model.eval(); 

batch_activations = []
with torch.no_grad():
    for seq, _ in tqdm(eval_loader, total=len(eval_loader)):
        # `linear` is fed a trailing singleton dim: (batch, 350) -> (batch, 350, 1).
        act = model.linear(seq.unsqueeze(-1).to(device))
        batch_activations.append(act.cpu().numpy())

running_activations = np.concatenate(batch_activations, axis=0)

100%|██████████| 328/328 [00:02<00:00, 151.11it/s]


In [19]:
# Score every peak with the NAM and keep the "well predicted" ones: peaks whose
# predicted and observed cell-type profiles correlate with Pearson r > 0.75.
batch_size = 100
dataset = torch.utils.data.TensorDataset(torch.from_numpy(Xs_values).float(), torch.from_numpy(y).float())
data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

batch_outputs = []
batch_labels = []
with torch.no_grad():
    for seq, lbl in tqdm(data_loader):
        batch_outputs.append(model(seq.to(device)).cpu().numpy())
        batch_labels.append(lbl.numpy())

running_outputs = np.concatenate(batch_outputs, axis=0)
running_labels = np.concatenate(batch_labels, axis=0)

# Keep one entry per peak (NaN where r is undefined) so positions stay aligned with the
# rows of Xs_values / y. Dropping NaNs *before* selecting would shift every later index
# onto the wrong peak.
peak_correlations = np.array([np.corrcoef(running_outputs[i, :], running_labels[i, :])[0, 1]
                              for i in range(len(running_outputs))])
correlations = list(peak_correlations[~np.isnan(peak_correlations)])

#~17K for CAM...
# NaN > 0.75 is False, so undefined correlations are simply not selected. flatnonzero
# always returns a 1-D array (argwhere(...).squeeze() is 0-d when there is a single hit).
idx = np.flatnonzero(peak_correlations > 0.75)
idx.shape

100%|██████████| 3280/3280 [00:21<00:00, 151.43it/s]


(14887,)

In [22]:
# Keep only the well-predicted peaks (rows `idx`) and batch them, in order, for inference.
x2, y2 = Xs_values[idx, :], y[idx, :]

dataset = torch.utils.data.TensorDataset(*(torch.from_numpy(arr).float() for arr in (x2, y2)))
data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

In [23]:
# Output of the NAM's `linear` sub-module for the well-predicted peaks; this replaces
# the eval-split activations and is what the UMAP below is fitted on. Each batch gets
# a trailing singleton dim before `linear`: (batch, 350) -> (batch, 350, 1).
with torch.no_grad():
    per_batch = [
        model.linear(seq.unsqueeze(-1).to(device)).cpu().numpy()
        for seq, _ in tqdm(data_loader, total=len(data_loader))
    ]

running_activations = np.array([row for batch in per_batch for row in batch])

100%|██████████| 149/149 [00:01<00:00, 147.66it/s]


In [24]:
import umap

# Record which umap-learn release produced the embedding below, for reproducibility.
umap_version = umap.__version__
umap_version

'0.5.2'

In [25]:
import umap

# 2-D UMAP of the NAM activations of the well-predicted peaks. A fixed random_state
# makes the embedding reproducible (umap-learn then disables parallelism, so it is slower).
reducer = umap.UMAP(random_state=40)
embedding = reducer.fit_transform(running_activations)

colors = {'B':"#2a7fffff", 'abT':"#ffff00ff", 
          'gdT':"#00ffffff", 'innate.lym':"#00ff00ff", 
          'myeloid':"#ff5555ff", 'stem':"#ff00ffff"}

color_lineage = [colors[lineage] for lineage in lineage_names]

# Observed values per cell type (columns in cell_type_names order, so that column j
# belongs to lineage_names[j]), restricted to the selected peaks.
cell_type_labels_df = (pd.DataFrame(data=y, index=Xs.index, columns=columns)
                       .loc[:, cell_type_names]
                       .iloc[idx, :])

# Collapse cell types into lineages: for each lineage in `colors` (in that order),
# the maximum over its cell types.
lineage_array = np.array(lineage_names)
all_labels_df = pd.concat(
    {lineage: cell_type_labels_df.iloc[:, np.flatnonzero(lineage_array == lineage)].max(axis='columns')
     for lineage in colors},
    axis=1,
)

ems = pd.DataFrame(data=embedding, index=all_labels_df.index, columns=["UMAP1", "UMAP2"])
ems = pd.concat([ems, all_labels_df], axis=1)

ems.shape

(14887, 8)

In [27]:
# DanQ scores of the selected peaks, indexed like the label frame so they can be
# appended column-wise to the UMAP table.
selected_scores = Xs.iloc[idx, :].to_numpy()
all_data_df = pd.DataFrame(selected_scores, columns=Xs.columns, index=all_labels_df.index)

ems = pd.concat([ems, all_data_df], axis=1)

In [30]:
import plotly.graph_objects as go

# Colour the UMAP embedding by a single DanQ model's output (shown as logits).
tf_name = "SPIB"

marker_style = dict(size=2, color=ems[tf_name], colorscale="Viridis", showscale=True)
scatter = go.Scatter(x=ems["UMAP1"], y=ems["UMAP2"], mode='markers', name='', marker=marker_style)

fig = go.Figure()
fig.add_trace(scatter)

fig.update_layout(title_text=f'{tf_name} logits',
                  xaxis_title='UMAP1',
                  yaxis_title='UMAP2',
                  plot_bgcolor='rgba(0,0,0,0)', paper_bgcolor='rgba(0,0,0,0)')
# Draw solid black axis lines on both axes.
for update_axes in (fig.update_xaxes, fig.update_yaxes):
    update_axes(showline=True, linewidth=2, linecolor='black')

fig.show()