In [4]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.utils as utils

import gzip
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

from torchsummaryX import summary
import pytorch_warmup as warmup

from tqdm import tqdm
from sklearn.metrics import roc_auc_score

from dataloader import *
from model import *

In [5]:
# --- Experiment configuration ------------------------------------------------
# Root directory of the CLIP-seq dataset used in this notebook.
path = '../iDeepS/datasets/clip/11_CLIPSEQ_ELAVL1_hg19'

# NOTE(review): these two values are not referenced by any visible cell; they
# presumably mirror the windowing done inside `dataloader` -- TODO confirm.
instance_length = 40
instance_stride = 5

# The training cell indexes outputs with [0], so it relies on batch_size == 1.
batch_size = 1
epochs = 20
lr = 5e-4
weight_decay = 1e-5

# --- Input files -------------------------------------------------------------
train_data_path = path + "/30000/training_sample_0/sequences.fa.gz"
valid_data_path = path + "/30000/test_sample_0/sequences.fa.gz"
train_structure_path = path + "/30000/training_sample_0/sequence_structures_forgi.out"
validate_structure_path = path + "/30000/test_sample_0/sequence_structures_forgi.out"

# --- Datasets and loaders ----------------------------------------------------
train_data = LibriSamplesWithStructure(train_data_path, train_structure_path)
valid_data = LibriSamplesWithStructure(valid_data_path, validate_structure_path)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Validation must NOT be shuffled: the training cell selects samples for
# attention tracking by their batch index, so the index has to refer to the
# same sequence in every epoch. Accuracy and AUC do not depend on the order.
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)

# Sanity check: print the tensor shapes of a single training batch.
for x, y, structure in train_loader:
    print(x.shape, y.shape, structure.shape)
    break

torch.Size([1, 13, 40, 4]) torch.Size([1, 2]) torch.Size([1, 13, 40, 6])


In [6]:
# Validation-batch indices whose attention weights are recorded after every
# epoch. 6235 appears twice in the list; as a dict key it is stored once.
attention_to_plot = [333, 3447, 6235, 6235, 9390, 1661, 7096, 1406, 124, 5254]
attention_data_dict = {idx: [] for idx in attention_to_plot}

model = WeakRMwithStructure().cuda()
# NOTE(review): weight decay is hard-coded to 0.005 here, so the `weight_decay`
# value from the configuration cell is not used. Left unchanged so results stay
# comparable -- confirm which value is intended.
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.005)
criterion = nn.BCELoss(weight=torch.tensor([0.8, 0.2])).cuda()
# The cosine schedule is stepped once per batch, so its horizon is the total
# number of optimizer steps over all epochs.
num_steps = len(train_loader) * epochs
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
warmup_scheduler = warmup.RAdamWarmup(optimizer)


for epoch in range(1, epochs + 1):
    # ---- training ----
    model.train()
    num_correct = 0
    total_loss = 0.0
    for i, (x, y, structure) in enumerate(train_loader):
        optimizer.zero_grad()

        x = x.float().cuda()
        y = y.float().cuda()
        structure = structure.float().cuda()

        outputs, _ = model(x, structure[:, :13, :, :])

        # Indexing with [0] (and the flat argmax below) relies on batch_size == 1.
        loss = criterion(outputs[0], y[0])

        outputs = torch.argmax(outputs)

        num_correct += int((outputs == torch.argmax(y)).sum())
        # Accumulate a Python float: adding the tensor itself would keep
        # autograd history alive for every step of the epoch.
        total_loss += loss.item()

        loss.backward()
        optimizer.step()
        with warmup_scheduler.dampening():
            scheduler.step()

    train_acc = 100 * num_correct / (len(train_loader) * batch_size)
    train_loss = float(total_loss / len(train_loader))
    learn_rate = float(optimizer.param_groups[0]['lr'])

    # Training summary string; left unprinted so the cell output shows a single
    # validation line per epoch.
    print_content = "Epoch {}/{}: Train Acc {:.04f}%, Train Loss {:.04f}, Learning Rate {:.04f}".format(
        epoch,
        epochs,
        train_acc,
        train_loss,
        learn_rate
    )

    # ---- validation ----
    model.eval()

    num_correct = 0
    total_loss = 0.0
    predictions = []
    labels = []
    with torch.no_grad():
        for i, (x, y, structure) in enumerate(valid_loader):
            x = x.float().cuda()
            y = y.float().cuda()
            structure = structure.float().cuda()

            outputs_probs, attention = model(x, structure[:, :13, :, :])

            # Compute the loss on the validation sample itself; reusing `loss`
            # from the last training step would report a meaningless value.
            loss = criterion(outputs_probs[0], y[0])

            outputs = torch.argmax(outputs_probs)

            num_correct += int((outputs == torch.argmax(y)).sum())
            total_loss += loss.item()

            predictions.append(outputs_probs.detach().cpu().numpy()[0])
            labels.append(torch.argmax(y).detach().cpu().numpy())

            if i in attention_data_dict:
                # Store on the CPU so GPU memory is not held for every epoch.
                attention_data_dict[i].append(attention.detach().cpu())

    dev_acc = 100 * num_correct / len(valid_loader)
    dev_loss = total_loss / len(valid_loader)

    # NOTE(review): roc_auc_score returns a fraction in [0, 1]; the trailing
    # '%' in the message below is kept only so the output format is unchanged.
    auc_score = roc_auc_score(np.array(labels).flatten(), np.array(predictions)[:, 1])
    print("Validation Accuracy {:.04f}%, Auc Score {:.04f}%".format(dev_acc, auc_score))

    # One checkpoint per epoch, written to the current working directory.
    torch.save(model.state_dict(), f"epoch{epoch}.m")


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 87.6600%, Auc Score 0.9040%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 87.6400%, Auc Score 0.9036%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 87.9400%, Auc Score 0.9047%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 87.8100%, Auc Score 0.9039%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 87.9200%, Auc Score 0.9043%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 87.8400%, Auc Score 0.9040%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 87.4900%, Auc Score 0.9050%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 87.3200%, Auc Score 0.9048%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 87.7300%, Auc Score 0.9048%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 87.9400%, Auc Score 0.9053%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 87.9500%, Auc Score 0.9055%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 87.8900%, Auc Score 0.9049%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 88.0200%, Auc Score 0.9054%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 87.8300%, Auc Score 0.9055%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 88.0400%, Auc Score 0.9057%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 88.1800%, Auc Score 0.9065%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 88.1800%, Auc Score 0.9064%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 88.1600%, Auc Score 0.9064%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 88.1400%, Auc Score 0.9061%


  gated_attention = self.softmax(gated_attention)  # torch.Size([1, 13])


Validation Accuracy 88.1200%, Auc Score 0.9061%


In [7]:
# Display the validation-batch indices for which attention is being tracked.
attention_data_dict.keys()

dict_keys([333, 3447, 6235, 9390, 1661, 7096, 1406, 124, 5254])

In [13]:
# Stack the per-epoch attention snapshots of each tracked index into a single
# NumPy array. Entries may be torch tensors (first run) or NumPy arrays (when
# this cell is executed again), so re-running the cell is safe.
for k, v in attention_data_dict.items():
    attention_data_dict[k] = np.array(
        [
            snapshot.detach().cpu().numpy() if torch.is_tensor(snapshot) else np.asarray(snapshot)
            for snapshot in v
        ]
    )

In [26]:
import seaborn as sns
import matplotlib.pyplot as plt
# Set the default figure size to (13, 20) via seaborn's rc settings; the
# heatmaps drawn afterwards pick this up.
sns.set(rc = {'figure.figsize':(13, 20)})

In [27]:
# Render one heatmap per tracked index and write it to "<index>.png" in the
# current working directory.
for sample_idx, attention_history in attention_data_dict.items():
    heatmap_axes = sns.heatmap(attention_history.squeeze())
    heatmap_figure = heatmap_axes.get_figure()
    heatmap_figure.savefig(f"{sample_idx}.png")
    # Clear the current figure so the next heatmap is not drawn on top of it.
    plt.clf()

<Figure size 936x1440 with 0 Axes>