In [1]:
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.utils as utils

import pandas as pd
import numpy as np
import gzip
import os
from torchsummaryX import summary
import pytorch_warmup as warmup
from pathlib import Path
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
import wandb

from dataloader import *
from model import *

In [8]:
paths = [
 '../iDeepS/datasets/clip/10_PARCLIP_ELAVL1A_hg19',
 # '../iDeepS/datasets/clip/11_CLIPSEQ_ELAVL1_hg19',
 # '../iDeepS/datasets/clip/12_PARCLIP_EWSR1_hg19',
 # '../iDeepS/datasets/clip/13_PARCLIP_FUS_hg19',
 # '../iDeepS/datasets/clip/14_PARCLIP_FUS_mut_hg19',
 # '../iDeepS/datasets/clip/15_PARCLIP_IGF2BP123_hg19',
 # '../iDeepS/datasets/clip/16_ICLIP_hnRNPC_Hela_iCLIP_all_clusters',
 # '../iDeepS/datasets/clip/18_ICLIP_hnRNPL_Hela_group_3975_all-hnRNPL-Hela-hg19_sum_G_hg19--ensembl59_from_2337-2339-741_bedGraph-cDNA-hits-in-genome',
 # '../iDeepS/datasets/clip/19_ICLIP_hnRNPL_U266_group_3986_all-hnRNPL-U266-hg19_sum_G_hg19--ensembl59_from_2485_bedGraph-cDNA-hits-in-genome',
 # '../iDeepS/datasets/clip/17_ICLIP_HNRNPC_hg19',
 # '../iDeepS/datasets/clip/1_PARCLIP_AGO1234_hg19',
 # '../iDeepS/datasets/clip/20_ICLIP_hnRNPlike_U266_group_4000_all-hnRNPLlike-U266-hg19_sum_G_hg19--ensembl59_from_2342-2486_bedGraph-cDNA-hits-in-genome',
 # '../iDeepS/datasets/clip/21_PARCLIP_MOV10_Sievers_hg19',
 # '../iDeepS/datasets/clip/22_ICLIP_NSUN2_293_group_4007_all-NSUN2-293-hg19_sum_G_hg19--ensembl59_from_3137-3202_bedGraph-cDNA-hits-in-genome',
 # '../iDeepS/datasets/clip/23_PARCLIP_PUM2_hg19',
 # '../iDeepS/datasets/clip/24_PARCLIP_QKI_hg19',
 # '../iDeepS/datasets/clip/25_CLIPSEQ_SFRS1_hg19',
 '../iDeepS/datasets/clip/26_PARCLIP_TAF15_hg19',
 # '../iDeepS/datasets/clip/27_ICLIP_TDP43_hg19',
 # '../iDeepS/datasets/clip/28_ICLIP_TIA1_hg19',
 # '../iDeepS/datasets/clip/29_ICLIP_TIAL1_hg19',
 # '../iDeepS/datasets/clip/2_PARCLIP_AGO2MNASE_hg19',
 # '../iDeepS/datasets/clip/30_ICLIP_U2AF65_Hela_iCLIP_ctrl_all_clusters',
 # '../iDeepS/datasets/clip/31_ICLIP_U2AF65_Hela_iCLIP_ctrl+kd_all_clusters',
 # '../iDeepS/datasets/clip/3_HITSCLIP_Ago2_binding_clusters',
 # '../iDeepS/datasets/clip/4_HITSCLIP_Ago2_binding_clusters_2',
 # '../iDeepS/datasets/clip/5_CLIPSEQ_AGO2_hg19',
 # '../iDeepS/datasets/clip/6_CLIP-seq-eIF4AIII_1',
 # '../iDeepS/datasets/clip/7_CLIP-seq-eIF4AIII_2',
 # '../iDeepS/datasets/clip/8_PARCLIP_ELAVL1_hg19',
 # '../iDeepS/datasets/clip/9_PARCLIP_ELAVL1MNASE_hg19'
]

instance_length = 40
instance_stride = 5

batch_size = 1
epochs = 10
lr = 5e-4
weight_decay = 1e-5

In [9]:
def train(path):
    print(f"start running {path}")
    
    wandb.init(project="02710", entity="fanfanwu9898",name = "WEAKRMLSTM" + p.split("/")[-1].split("_")[0])
    wandb.config = {
      "learning_rate": lr,
      "epochs": epochs,
      "batch_size": batch_size,
      "instance_length" : instance_length,
      "instance_stride" : instance_stride,
      "decay": weight_decay
    }
    

    train_data_path = path + "/30000/training_sample_0/sequences.fa.gz"
    valid_data_path = path + "/30000/test_sample_0/sequences.fa.gz"
    train_data = LibriSamples(train_data_path)
    valid_data = LibriSamples(valid_data_path)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size, shuffle=True)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size = batch_size, shuffle=True)
    
    model = WeakRMLSTM().cuda()
    # model = WSCNN().cuda()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.BCELoss(weight=torch.tensor([0.8, 0.2])).cuda()
    num_steps = len(train_loader) * epochs
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
    warmup_scheduler = warmup.RAdamWarmup(optimizer)

    for epoch in range(1, epochs + 1):
        # training
        model.train()
        num_correct = 0
        total_loss = 0
        for i, (x, y) in enumerate(train_loader):
            optimizer.zero_grad()

            x = x.float().cuda()
            y = y.float().cuda()

            outputs, _ = model(x)

            loss = criterion(outputs[0], y[0])

            outputs = torch.argmax(outputs)

            num_correct += int((outputs == torch.argmax(y)).sum())
            total_loss += loss

            loss.backward()
            optimizer.step()
            with warmup_scheduler.dampening():
                scheduler.step()

#         print_content ="Epoch {}/{}: Train Acc {:.04f}%, Train Loss {:.04f}, Learning Rate {:.04f}".format(
#             epoch,
#             epochs,
#             100 * num_correct / (len(train_loader) * batch_size),
#             float(total_loss / len(train_loader)),
#             float(optimizer.param_groups[0]['lr'])
#         )

#         print(print_content)
        train_acc =  100 * num_correct / (len(train_loader) * batch_size)
        train_loss = float(total_loss / len(train_loader))
        learn_rate = float(optimizer.param_groups[0]['lr'])

        # eval
        model.eval()

        num_correct = 0
        total_loss = 0
        predictions =[]
        labels = []
        for i, (x, y) in enumerate(valid_loader):
            x = x.float().cuda()
            y = y.float().cuda()

            with torch.no_grad():
                outputs_probs,_ = model(x)

            outputs = torch.argmax(outputs_probs)

            num_correct += int((outputs == torch.argmax(y)).sum())
            total_loss += loss

            predictions.append(outputs_probs.detach().cpu().numpy()[0])
            labels.append(torch.argmax(y).detach().cpu().numpy())

        dev_acc = 100 * num_correct / len(valid_loader)
        dev_loss = total_loss / len(valid_loader)

        auc_score = roc_auc_score(np.array(labels).flatten(), np.array(predictions)[:, 1])
        # print("Validation Accuracy {:.04f}%, Auc Score {:.04f}%".format(dev_acc, auc_score))
        
        wandb.log({"Train Acc:":train_acc, "Train loss:": train_loss, "Test Acc ":dev_acc,
                   "Test loss":dev_loss, "auROC":auc_score, "lr":lr})

In [10]:
for p in paths:
    train(p)

start running ../iDeepS/datasets/clip/10_PARCLIP_ELAVL1A_hg19





VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Test Acc,▁▁▁▁▁▁▁▁▁▁
Test loss,▂██▂▂▁▁▁▁▁
Train Acc:,▁█████████
Train loss:,█▆▆▅▄▃▃▂▁▁
auROC,▁▂▄▄▇▇████
lr,▁▁▁▁▁▁▁▁▁▁

0,1
Test Acc,80.0
Test loss,0.10578
Train Acc:,80.0
Train loss:,0.24374
auROC,0.58706
lr,0.0005


  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])


start running ../iDeepS/datasets/clip/26_PARCLIP_TAF15_hg19





VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Test Acc,▅▇▇██▅▂▃▂▁
Test loss,▂▁█▁▂▁▁▁▁▂
Train Acc:,▁▃▄▄▅▆▆▇██
Train loss:,█▆▆▅▅▄▃▂▁▁
auROC,▇████▆▃▃▁▁
lr,▁▁▁▁▁▁▁▁▁▁

0,1
Test Acc,88.06
Test loss,0.05529
Train Acc:,92.24333
Train loss:,0.10563
auROC,0.87683
lr,0.0005


  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
  gated_attention = nn.Softmax()(gated_attention)  # torch.Size([1, 13])
