In [1]:
import numpy as np, os, sys
sys.path.append("..")

import torch
from torch import nn
from torch.nn import functional as F 

from src.utils.synthetic_seqdata import download_data, load_data, sequence_string_to_one_hot
from src.models import DeepBindCNN
from src.trainer import Trainer
from src.utils.datasets import DNASequenceDataset
from sklearn.metrics import roc_auc_score, roc_curve
from src.utils import metrics
from src.explain import Explainer

import matplotlib as mpl 
from matplotlib import pyplot as plt 
%matplotlib inline

In [2]:
# get data 
savedir = "./data"
# _=download_data(savedir)
Xs, Ys = load_data(savedir=savedir)
len(Xs['train']), Ys, len(Ys['train']), len(Ys['valid']), len(Ys['test'])

(14000,
 {'train': array([1., 1., 0., ..., 1., 0., 1.], dtype=float32),
  'valid': array([0., 1., 1., ..., 0., 0., 1.], dtype=float32),
  'test': array([0., 1., 0., ..., 1., 1., 1.], dtype=float32)},
 14000,
 2000,
 4000)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [4]:
config = {
    "batch_size": 32,
    "learning_rate": 0.001,
    "architecture": "deepbind",
    "dataset": "synthetic data",
    "epochs": 35,
    "patience": 3,
    }

In [5]:
# set up datasets
datasets = {}
for k in Xs:
    datasets[k] = DNASequenceDataset(sequences=Xs[k], labels=Ys[k], alphabet="ACGT")

# set up dataloaders 
loaders = {}
for k, dataset in datasets.items():
    if k == 'train':
        loaders[k] = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)
    else:
        loaders[k] = torch.utils.data.DataLoader(dataset, batch_size=config['batch_size'])
    
# set up the model, lossfn, optimizer, trainer 
model = DeepBindCNN(input_size=4, output_size=1, kernel_size=3)
lossfn = nn.CrossEntropyLoss()

model.load_state_dict(torch.load('best_model.pt'))
print(model)


DeepBindCNN(
  (conv1): Conv1d(4, 16, kernel_size=(3,), stride=(1,))
  (relu): ReLU()
  (pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc): Linear(in_features=16, out_features=1, bias=True)
)


In [20]:
explainer = Explainer(model)

# Select a random sample from the test dataset
sample_index = np.random.randint(len(datasets['test']))
input_sequence, target_label = datasets['test'][sample_index]
saliency_scores = explainer.saliency_map(input_sequence.unsqueeze(0))
saliency_scores

tensor([[[ 0.0000,  0.0000, -0.0583, -0.0000, -0.0000,  0.0000,  0.0504,
          -0.1874,  0.0000,  0.0000, -0.0000, -0.0963, -0.0000, -0.0000,
          -0.0000, -0.0376,  0.0516, -0.0602, -0.0000,  0.0000,  0.0000,
          -0.0000, -0.1013,  0.0000, -0.0000,  0.0000,  0.0000, -0.0504,
          -0.1115, -0.0607,  0.0000, -0.0000,  0.0000,  0.0000, -0.1307,
          -0.0000,  0.0000, -0.0722, -0.0000,  0.0000,  0.0000, -0.1384,
          -0.0000,  0.0000,  0.0000, -0.1384, -0.0000,  0.0000, -0.0682,
          -0.0000, -0.0000, -0.0000,  0.0000,  0.0178, -0.1629,  0.0000,
          -0.0000,  0.0000, -0.0000, -0.0000, -0.0000, -0.0000,  0.0000,
          -0.0673,  0.0000,  0.0000, -0.0000, -0.1339, -0.0000, -0.0000,
           0.0000, -0.0000, -0.0000, -0.0963,  0.0000, -0.0000, -0.0000,
          -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,  0.0000, -0.0000,
          -0.0000, -0.0000,  0.0000,  0.0000,  0.0000, -0.0000,  0.0000,
          -0.0272, -0.0000,  0.0000,  0.0000, -0.00

In [19]:
explainer = Explainer(model)
saliency_scores = []
for inputs, labels in loaders['test']:
    inputs, labels = inputs.to(device), labels.to(device)
    saliency_scores.append(explainer.saliency_map(inputs))

In [16]:
len(saliency_scores)
saliency_scores[0]

tensor([[[ 0.0000, -0.0000, -0.0000,  ...,  0.0000, -0.0000,  0.0000],
         [ 0.0359, -0.0000,  0.0000,  ...,  0.0000, -0.0000,  0.0000],
         [-0.0000,  0.0137, -0.0000,  ..., -0.0000,  0.0317, -0.0000],
         [-0.0000,  0.0000, -0.0284,  ..., -0.0861,  0.0000, -0.0388]],

        [[ 0.0000, -0.0000,  0.0000,  ...,  0.0000,  0.0449, -0.0907],
         [ 0.0000, -0.0083, -0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0277, -0.0000, -0.0022,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0000,  0.0000, -0.0000,  ..., -0.1491,  0.0000,  0.0000]],

        [[-0.0000, -0.0000,  0.0000,  ...,  0.0000, -0.0219, -0.0000],
         [ 0.0000, -0.0000, -0.0000,  ..., -0.0000, -0.0000,  0.0000],
         [-0.0882,  0.0235,  0.0000,  ...,  0.0000, -0.0000,  0.0008],
         [-0.0000,  0.0000,  0.0008,  ..., -0.0830,  0.0000, -0.0000]],

        ...,

        [[-0.0086,  0.0000, -0.0000,  ..., -0.0838, -0.0098,  0.0000],
         [-0.0000,  0.0000, -0.0000,  ..., -0.0000,  0.00