In [None]:
from tools import *
from models import *
import plotly.graph_objects as go
import plotly.figure_factory as ff
from Bio.SeqUtils import GC
import pickle

import warnings
warnings.filterwarnings('ignore')

In [None]:
# Constants and hyperparameters (TODO: move to a YAML config)

# Device configuration: use the first CUDA device when torch reports one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Mini-batch size hyperparameter.
batch_size = 100

In [None]:
# Load the multi-TF dataset and merge its train/valid/test partitions so that
# every sequence can be scored in one pass.
# NOTE: the file handle is left open because a later cell reads
# data['target_labels'] from it.
data = h5py.File("../data/tf_peaks_50_partial.h5", 'r')

# `[:]` reads each HDF5 dataset into a NumPy array in one slice before the
# tensor is built, instead of handing the h5py dataset object to torch.Tensor.
x = torch.Tensor(data['train_in'][:])
y = torch.Tensor(data['valid_in'][:])
z = torch.Tensor(data['test_in'][:])

x_lab = torch.Tensor(data['train_out'][:])
y_lab = torch.Tensor(data['valid_out'][:])
z_lab = torch.Tensor(data['test_out'][:])

# Concatenate along the first (sample) axis: train, then valid, then test.
res = torch.cat((x, y, z), dim=0)
res_lab = torch.cat((x_lab, y_lab, z_lab), dim=0)

all_dataset = torch.utils.data.TensorDataset(res, res_lab)
# shuffle=False keeps the loader order identical to the row order of res/res_lab,
# which later cells rely on when indexing res with positions derived from the
# loader output. batch_size comes from the hyperparameter cell instead of a
# second hard-coded 100.
dataloader = torch.utils.data.DataLoader(all_dataset,
                                         batch_size=batch_size, shuffle=False,
                                         num_workers=0)

In [None]:
# The label names are decoded from bytes to str as they are read from the
# 'target_labels' dataset of the open HDF5 file.
target_labels = [raw_label.decode("utf-8") for raw_label in data['target_labels']]

In [None]:
# Build the multi-TF network (argument 50 is presumably the number of output
# classes -- confirm against models.ConvNetDeep) and restore trained weights.
model = ConvNetDeep(50).to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint that was
# saved from a GPU still loads when the notebook falls back to the CPU.
model.load_state_dict(torch.load("../weights_multimodel_partial/model_epoch_6_.pth",
                                 map_location=device))
model.eval();

#copy trained model weights to motif extraction model
motif_model = motifCNN(model, 50).to(device)
motif_model.load_state_dict(model.state_dict())
motif_model.eval();

## Extraction of PWMs

In [None]:
# Run the full model over every batch of `dataloader`, collecting the input
# sequences, the labels and the sigmoid-transformed outputs.
sigmoid = nn.Sigmoid()
sequences, running_labels, running_outputs = [], [], []

with torch.no_grad():
    for batch_seqs, batch_labels in dataloader:
        sequences.extend(batch_seqs.numpy())
        running_labels.extend(batch_labels.numpy())

        logits = model(batch_seqs.to(device))
        # for BCEWithLogits: the raw outputs are passed through a sigmoid here
        probabilities = sigmoid(logits.detach().cpu())
        running_outputs.extend(probabilities.numpy())

# Convert the per-sample lists into arrays (one row per sequence).
sequences = np.array(sequences)
running_labels = np.array(running_labels)
running_outputs = np.array(running_outputs)

In [None]:
# Round the model outputs to the nearest integer to obtain hard calls.
pred_full_round = running_outputs.round()

In [None]:
# Element-wise agreement between rounded predictions and labels.
arr_comp = np.equal(pred_full_round, running_labels)
# Keep the rows where every output is correct. `.all(axis=1)` replaces the
# hard-coded ">= 50" (the number of outputs of the 50-class model), and
# np.flatnonzero always returns a 1-D index array -- argwhere(...).squeeze()
# collapses to a 0-d scalar when exactly one row matches, which would drop the
# sample axis when `res` is indexed with it.
idx = np.flatnonzero(arr_comp.all(axis=1)) # author's note: 43563

In [None]:
# Restrict inputs and labels to the rows selected in `idx`.
res2 = res[idx, :, :]
res_lab2 = res_lab[idx, :]

dataset = torch.utils.data.TensorDataset(res2, res_lab2)
# shuffle=False keeps the loader order aligned with res2/res_lab2; batch_size is
# taken from the hyperparameter cell rather than a second hard-coded 100.
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size, shuffle=False,
                                          num_workers=2)

In [None]:
# get first layer activations and predictions with leave-one-filter-out
# NOTE(review): get_motifs is not defined in this notebook (presumably it comes
# from the star import of `tools`); the exact structure of the two returned
# values should be confirmed there.
predictions, activations = get_motifs(data_loader, motif_model, device)

In [None]:
# Destination of the MEME-format motif file; the same path is read back further
# down when the motifs are parsed for the information-content calculation.
output_file_path = "../data/motifs_for_multimodel.meme"

# NOTE(review): get_memes is not defined in this notebook (presumably from
# `tools`); it is assumed to write the motifs to output_file_path -- confirm.
get_memes(activations, res2, res_lab2, output_file_path)

## Filter importance

In [None]:
# Directory that receives the filter-importance outputs.
output_folder = "../data/filter_importance/"
# exist_ok=True makes the cell idempotent and removes the gap between checking
# for the directory and creating it.
os.makedirs(output_folder, exist_ok=True)

In [None]:
# Run the filter-importance computation for the current model on the selected
# sequences; the label list and its length are passed along with the output
# folder.
# NOTE(review): compute_filter_importance is defined outside this notebook;
# presumably it writes its results (e.g. average_impacts.csv, read in the next
# cell) into output_folder -- confirm.
compute_filter_importance(model, data_loader, target_labels, len(target_labels),
                          output_folder)

In [None]:
# Load the averaged filter impacts; the first CSV column becomes the row index.
# Later cells treat rows as filters and columns as TF names -- presumably the
# layout written by compute_filter_importance; confirm.
impacts = pd.read_csv("../data/filter_importance/average_impacts.csv", index_col=0)

In [None]:
#Heatmap of filter importance
# One cell per (row, column) entry of `impacts`: columns on x, index on y.
impact_heatmap = go.Heatmap(
    z=impacts,
    x=list(impacts),
    y=impacts.index,
    hoverongaps=False,
)
fig = go.Figure(data=impact_heatmap)

# Title, font and a fixed 1000x1000 canvas set in a single layout update.
fig.update_layout(
    title='Filter impacts',
    font=dict(family="Courier New, monospace", size=10, color="black"),
    autosize=False,
    width=1000,
    height=1000,
)

fig.show()

In [None]:
# Map each TF name (upper-cased first field) to its class (last field) from the
# whitespace-separated clusters file.
tf_classes = {}

with open("../data/clusters.txt", "r") as f:
    for line in f:
        line_parts = line.split()
        # Skip '#' comment lines, and blank lines too: a blank line has no
        # fields, so indexing line_parts would raise IndexError.
        if line.startswith("#") or not line_parts:
            continue

        tf_name = line_parts[0]
        tf_class = line_parts[-1]
        tf_classes[tf_name.upper()] = tf_class

tf_classes = pd.Series(tf_classes).sort_values(ascending=True)

# Keep only the TFs that are columns of `impacts`, ordered by class, and reorder
# the impact columns accordingly.
tf_classes_df = tf_classes[list(impacts)].sort_values(ascending=True)
impacts_sorted = impacts[tf_classes_df.index]

In [None]:
# TOMTOM results table (tab-separated; lines starting with '#' are skipped).
tomtom_results = pd.read_csv("../data/tomtom_multimodel.tsv", sep="\t", comment="#")
# Smallest q-value per query. The column is selected before aggregating so only
# "q-value" is reduced, instead of taking the minimum of every column
# (including the string ones) and discarding all but one afterwards.
filters_with_min_q = tomtom_results.groupby('Query_ID')["q-value"].min()
# Display the queries whose best match has q-value < 0.01.
filters_with_min_q[filters_with_min_q < 0.01].index

In [None]:
# Manually curated from tomtom.html: filter name -> TF name
# (an empty string means no TF name was recorded for that filter).
manual_filter_annotations = {
    "filter10": "CREB1",
    "filter13": "JUN",
    "filter14": "JUND",
    "filter15": "USF2",
    "filter17": "",
    "filter2": "JUN",
    "filter23": "",
    "filter28": "CTCF",
    "filter3": "USF2",
    "filter32": "",
    "filter43": "JUN",
    "filter45": "FOXA1",
    "filter49": "CTCF",
    "filter5": "GATA2",
    "filter50": "",
    "filter53": "",
    "filter54": "",
    "filter58": "CEBPB",
    "filter59": "USF2",
    "filter60": "NR2F2",
    "filter61": "JUND",
    "filter63": "",
    "filter65": "FOXA1",
    "filter68": "CTCF",
    "filter70": "CTCF",
    "filter73": "JUND",
    "filter74": "",
    "filter75": "JUN",
    "filter76": "",
    "filter78": "USF2",
    "filter79": "JUN",
    "filter83": "",
    "filter86": "",
    "filter87": "CTCF",
    "filter92": "NR2F1",
    "filter94": "MYC",
    "filter95": "MAX",
    "filter98": "FOXA1",
    "filter99": "CTCF",
}

filter_TFs = pd.Series(manual_filter_annotations)

# Persist the annotation table next to the other data files.
with open('../data/multimodel_filter_TFs.pickle', 'wb') as pickle_handle:
    pickle.dump(filter_TFs, pickle_handle)

In [None]:
# Per-filter table: best TOMTOM q-value ("Q") and the manually assigned TF name
# ("TFs"), aligned on the filter name.
filters_info = pd.DataFrame({"Q": filters_with_min_q, "TFs": filter_TFs})
# Fill only the text column. A frame-wide fillna("") would also write strings
# into "Q" wherever a q-value is missing, turning a numeric column that is later
# passed to np.log10 into a mixed-type one.
filters_info["TFs"] = filters_info["TFs"].fillna("")

In [None]:
#from Chendi Wang
# Parse the MEME file into
#   motifs      -- float array (n_motifs, pad_length, 4), zero-padded per motif
#   motif_names -- array with the second token of every "MOTIF ..." line
# The reader expects each "MOTIF" line to be followed directly by the matrix
# header line and then by `w` rows of four values.
MIN_PAD_LENGTH = 19  # rows every motif is padded to (the original buffer size)

motif_names = []
motif_matrices = []
with open('../data/motifs_for_multimodel.meme') as fp:
    meme_lines = iter(fp)
    for line in meme_lines:
        fields = line.split()
        if not fields or fields[0] != 'MOTIF':
            continue
        #add motif name to separate list
        motif_names.append(fields[1])

        #get length of motif: the value following the "w=" key; fall back to
        #the fixed token position used before if the key is not found
        header = next(meme_lines).split()
        width_pos = header.index("w=") + 1 if "w=" in header else 5
        motif_length = int(float(header[width_pos]))

        #read in motif, converting every entry to float explicitly
        rows = [[float(value) for value in next(meme_lines).split()]
                for _ in range(motif_length)]
        motif_matrices.append(np.array(rows, dtype=float).reshape(motif_length, 4))

# Pad shorter motifs with zero rows. The buffer grows if a motif is longer than
# MIN_PAD_LENGTH, where a fixed 19-row buffer would raise an IndexError.
pad_length = max([MIN_PAD_LENGTH] + [m.shape[0] for m in motif_matrices])
motifs = np.zeros((len(motif_matrices), pad_length, 4))
for motif_idx, matrix in enumerate(motif_matrices):
    motifs[motif_idx, :matrix.shape[0], :] = matrix
motif_names = np.array(motif_names)

In [None]:
#calculate IC (from Chendi Wang)
# Per position: sum_b p_b*log2(p_b) - sum_b q_b*log2(q_b), with q the uniform
# background; the motif IC is the sum over positions.
#set background frequencies of nucleotides
bckgrnd = [0.25, 0.25, 0.25, 0.25]
#compute information content of each motif
info_content = []
position_ic = []
epsilon = 1e-11  # keeps log2 finite for zero entries
# The background term does not depend on the motif, so compute it once.
background_term = np.sum(np.multiply(bckgrnd, np.log2(bckgrnd)))
for i in range(motifs.shape[0]):
    pwm = motifs[i, :, :]
    position_wise_ic = np.subtract(
        np.sum(np.multiply(pwm, np.log2(pwm + epsilon)), axis=1),
        background_term)
    # All-zero rows are the padding added when the motifs were read. Left
    # unmasked, each evaluates to 0 - (-2) = 2 bits and inflates the IC of
    # motifs shorter than the padded length.
    position_wise_ic[pwm.sum(axis=1) == 0] = 0.0
    position_ic.append(position_wise_ic)
    info_content.append(np.sum(position_wise_ic, axis=0))
info_content = np.stack(info_content)

In [None]:
# Filters that appear as queries in the TOMTOM results.
matched_filters = filters_with_min_q.index

# Information content per motif name, restricted to those filters.
filter_content = pd.Series(info_content, index=motif_names)[matched_filters]
filters_info["IC"] = filter_content

# Largest impact of each filter across all columns of `impacts`.
filter_impact = impacts.max(axis=1)[matched_filters]
filters_info["Impact"] = filter_impact

In [None]:
# Scatter of information content (x) against log10 filter impact (y); points are
# labelled with the assigned TF name and coloured by -log10 of the q-value.
fig = go.Figure()

fig.add_trace(go.Scatter(
    x=filters_info["IC"],
    y=np.log10(filters_info["Impact"]),
    mode="markers+text",
    name="Lines, Markers and Text",
    text=filters_info["TFs"],
    textposition="top center",
    marker=dict(
        size=8,
        color=np.log10(filters_info["Q"])*-1, #set color equal to a variable
        colorscale='Bluered', # one of plotly colorscales
        showscale=True,
        colorbar=dict(
            # nested title.text/side/font replace the deprecated
            # titleside/titlefont attributes
            title=dict(
                text="-log10(TOMTOM q.value)",
                side="right",
                font=dict(size=18)
            )
        )
    )
))

fig.update_layout(
    font=dict(
        family="Arial",
        size=12,
        color="black"
    ),
    yaxis_title='Filter influence (log10)',
    xaxis_title='Information content',
    # transparent plot and paper backgrounds
    plot_bgcolor='rgba(0,0,0,0)', paper_bgcolor='rgba(0,0,0,0)')
# title_font replaces the deprecated `titlefont` axis attribute
fig.update_xaxes(showline=True, linewidth=2, linecolor='black', title_font=dict(size=18))
fig.update_yaxes(showline=True, linewidth=2, linecolor='black', title_font=dict(size=18))

fig.show()

## Inspecting individual models

In [None]:
# TF whose individual model is inspected in this section.
TF_name = "HNF4A"

# Per-TF dataset; from here on `data`, `res`, `res_lab` and `dataloader` refer to
# this file instead of the multi-TF one loaded earlier.
data = h5py.File("../TRAIN_DATA_INDIV_INTER/"+TF_name +"/h5_files/" +
                 TF_name + "_tl.h5", 'r')

# `[:]` reads each HDF5 dataset into a NumPy array in one slice before the
# tensor is built, instead of handing the h5py dataset object to torch.Tensor.
x = torch.Tensor(data['train_in'][:])
y = torch.Tensor(data['valid_in'][:])
z = torch.Tensor(data['test_in'][:])

x_lab = torch.Tensor(data['train_out'][:])
y_lab = torch.Tensor(data['valid_out'][:])
z_lab = torch.Tensor(data['test_out'][:])

# Concatenate along the first (sample) axis: train, then valid, then test.
res = torch.cat((x, y, z), dim=0)
res_lab = torch.cat((x_lab, y_lab, z_lab), dim=0)

all_dataset = torch.utils.data.TensorDataset(res, res_lab)
# shuffle=False keeps the loader order identical to the row order of res/res_lab;
# batch_size comes from the hyperparameter cell instead of a hard-coded 100.
dataloader = torch.utils.data.DataLoader(all_dataset,
                                         batch_size=batch_size, shuffle=False,
                                         num_workers=0)

In [None]:
# Single-TF network (argument 1 is presumably the number of outputs -- confirm
# against models.ConvNetDeep); replaces the multi-TF `model` from here on.
model = ConvNetDeep(1).to(device)

# map_location remaps the stored tensors onto `device`, so a checkpoint that was
# saved from a GPU still loads when the notebook falls back to the CPU.
model.load_state_dict(torch.load("../MODEL_WEIGHTS_INDIV_INTER/" + TF_name +
                                 "_real_indiv_weights_TL/" + TF_name + "_tl_weights/" +
                                 "model_epoch_4_.pth",
                                 map_location=device))
model.eval();

#copy trained model weights to motif extraction model
motif_model = motifCNN(model, 1).to(device)
motif_model.load_state_dict(model.state_dict())
motif_model.eval();

In [None]:
# Run the full model over every batch of `dataloader`, collecting the input
# sequences, the labels and the sigmoid-transformed outputs.
sigmoid = nn.Sigmoid()
sequences, running_labels, running_outputs = [], [], []

with torch.no_grad():
    for batch_seqs, batch_labels in dataloader:
        sequences.extend(batch_seqs.numpy())
        running_labels.extend(batch_labels.numpy())

        logits = model(batch_seqs.to(device))
        # for BCEWithLogits: the raw outputs are passed through a sigmoid here
        probabilities = sigmoid(logits.detach().cpu())
        running_outputs.extend(probabilities.numpy())

# Convert the per-sample lists into arrays (one row per sequence).
sequences = np.array(sequences)
running_labels = np.array(running_labels)
running_outputs = np.array(running_outputs)

In [None]:
# Round the model outputs to the nearest integer to obtain hard calls.
pred_full_round = running_outputs.round()

In [None]:
# Element-wise agreement between rounded predictions and labels.
arr_comp = np.equal(pred_full_round, running_labels)
# Rows with at least one correct output (the model built above has a single
# output). np.flatnonzero always yields a 1-D index array, whereas
# argwhere(...).squeeze() collapses to a 0-d scalar when exactly one row
# matches, which would drop the sample axis when `res` is indexed with it.
idx = np.flatnonzero(np.sum(arr_comp, axis=1) >= 1) # author's note: 160819
idx.shape

In [None]:
# Restrict inputs and labels to the rows selected in `idx`.
res2 = res[idx, :, :]
res_lab2 = res_lab[idx, :]

dataset = torch.utils.data.TensorDataset(res2, res_lab2)
# shuffle=False keeps the loader order aligned with res2/res_lab2; batch_size is
# taken from the hyperparameter cell rather than a hard-coded 100.
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size, shuffle=False,
                                          num_workers=2)

# Per-TF output directory; exist_ok=True makes the cell safe to re-run.
output_folder = "../data/filter_importance/" + TF_name + "/"
os.makedirs(output_folder, exist_ok=True)

# NOTE(review): target_labels still holds the labels decoded from the multi-TF
# file, while `model` here is ConvNetDeep(1). Confirm whether a single-entry
# label list (e.g. [TF_name]) and a class count of 1 were intended.
compute_filter_importance(model, data_loader, target_labels, len(target_labels),
                          output_folder)

In [None]:
# Same call as for the multi-TF model, now on the single-TF loader and motif
# model. NOTE(review): get_motifs is defined outside this notebook; the exact
# structure of the two returned values should be confirmed there.
predictions, activations = get_motifs(data_loader, motif_model, device)

In [None]:
# Make sure the per-TF output directory exists; exist_ok=True replaces the
# separate existence check and makes the cell safe to re-run.
os.makedirs(output_folder, exist_ok=True)

# Alternative output name kept from the original notebook:
#get_memes(activations, res2, res_lab2, output_folder + TF_name + "_noTL.meme")
get_memes(activations, res2, res_lab2, output_folder + TF_name + "_50_TL.meme")