In [1]:
import torch
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.special
import scipy.stats
import scipy.ndimage
import sklearn.metrics
import pyfaidx
import pyBigWig
import tqdm
import project as proj
import importlib
tqdm.tqdm_notebook()

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  tqdm.tqdm_notebook()


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

|<bar/>| 0/? [00:00<?, ?it/s]

In [2]:
reads_path = './data/vcm_reads.bw'
open_regions_path = './data/vcm_peaks.bed'
chrom_sizes_path = './data/hg38.canon.chrom.sizes'
ref_fasta_path = './data/hg38.fasta'
reads = pyBigWig.open(reads_path)

In [3]:
def create_label_file(filepath, set_chroms, regions_df):
    set_regions = regions_df[regions_df['chrom'].isin(set_chroms)].copy().values
    print("saving array of shape " + str(set_regions.shape) + " in file: " + filepath)
    np.save(filepath, set_regions)
    
chrom_sizes = {}
with open(chrom_sizes_path, "r") as f:
    for line in f:
        chrom, size = line.strip().split()
        if len(chrom) > 5 or chrom in ("chrY", "chrM"):
            continue
        chrom_sizes[chrom] = int(size)
    
test_chroms = ['chr1']
val_chroms = ['chr2']
train_chroms = [chrom for chrom in chrom_sizes.keys() if chrom not in (test_chroms + val_chroms)]    
open_regions = pd.read_csv(open_regions_path, sep="\t", header=None, names=["chrom", "start", "end"])
train_regions_npy_path = "./data/train_regions.npy"
val_regions_npy_path = "./data/val_regions.npy"
test_regions_npy_path = "./data/test_regions.npy"  
create_label_file(train_regions_npy_path, train_chroms, open_regions)
create_label_file(val_regions_npy_path, val_chroms, open_regions)
create_label_file(test_regions_npy_path, test_chroms, open_regions)

saving array of shape (70577, 3) in file: ./data/train_regions.npy
saving array of shape (6424, 3) in file: ./data/val_regions.npy
saving array of shape (8385, 3) in file: ./data/test_regions.npy


In [11]:
# Functions adapted from https://github.com/amtseng/fourier_attribution_priors/blob/master/src/model/train_profile_model.py

def run_epoch(data_loader, mode, model, epoch_num, batch_size, att_prior_loss_weight=0,
              num_tasks=1, counts_loss_weight = 25, freq_limit = 200,
              limit_softness = 0.2, att_prior_grad_smooth_sigma = 3,
              input_length=1346, input_depth=4, profile_length=1000,optimizer=None, 
              return_data=False):
    
    
    assert mode in ("train", "eval")
    if mode == "train":
        assert optimizer is not None
    else:
        assert optimizer is None 
    
    data_loader.shuffle_data()
    
    num_batches = len(data_loader)
    t_iter = tqdm.tqdm(
        data_loader, total=num_batches, desc="\tLoss: ---"
    )
    
    
    if mode == "train":
        model.train()  # Switch to training mode
        torch.set_grad_enabled(True)
        
    batch_losses, corr_losses, att_losses = [], [], []
    prof_losses, count_losses = [], []
    
    input_seqs_array = []
    profiles_array = []
    for input_seqs, profiles in t_iter:
        input_seqs = proj.place_tensor(torch.tensor(input_seqs)).float()

        profiles = proj.place_tensor(torch.tensor(profiles)).float()
        profiles = profiles.view(profiles.shape[0],1,-1,1)
        
        input_seqs_array.append(input_seqs)
        profiles_array.append(profiles)
        
        if mode == "train":
            optimizer.zero_grad()
        elif att_prior_loss_weight > 0:
            # Not training mode, but we still need to zero out weights because
            # we are computing the input gradients
            model.zero_grad()
        
        if att_prior_loss_weight > 0:
            input_seqs.requires_grad=True
            logit_pred_profs, log_pred_counts = model(input_seqs)

            norm_logit_pred_profs = logit_pred_profs - torch.mean(logit_pred_profs, dim=2, keepdim=True) 
            pred_prof_probs = profile_logits_to_log_probs(logit_pred_profs).detach()
            weighted_norm_logits = norm_logit_pred_profs * pred_prof_probs

            input_grads, = torch.autograd.grad(
                weighted_norm_logits, input_seqs,
                grad_outputs=proj.place_tensor(
                    torch.ones(weighted_norm_logits.size())
                ),
                retain_graph=True, create_graph=True
            )
            
            input_grads = input_grads*input_seqs
            status = proj.place_tensor(torch.tensor(np.ones(input_grads.shape[0])))
            input_seqs.requires_grad = False
            
            corr_loss, prof_loss, count_loss = model.correctness_loss(
                profiles, logit_pred_profs, log_pred_counts, 
                counts_loss_weight, return_separate_losses=True
            )
            att_loss = model.fourier_att_prior_loss(
                status, input_grads, freq_limit,
                limit_softness, att_prior_grad_smooth_sigma
            )
            loss = corr_loss + att_prior_loss_weight*att_loss
            
        else:
            logit_pred_profs, log_pred_counts = model(input_seqs)
            status, input_grads = None, None
            loss, prof_loss, count_loss = model.correctness_loss(
                profiles, logit_pred_profs, log_pred_counts, 
                counts_loss_weight, return_separate_losses=True
            )
            corr_loss = loss
            att_loss = torch.zeros(1)
            
        if mode == "train":
            loss.backward()  # Compute gradient
            optimizer.step()  # Update weights through backprop
        
        batch_losses.append(loss.item())
        corr_losses.append(corr_loss.item())
        att_losses.append(att_loss.item())
        prof_losses.append(prof_loss.item())
        count_losses.append(count_loss.item())
        t_iter.set_description(
            "\tLoss: %6.4f" % loss.item()
        )
        
    return batch_losses, corr_losses, att_losses, prof_losses, count_losses, input_seqs_array, profiles_array
    

def train_model(train_loader, val_loader, model, batch_size, lr, num_epochs, charts_path,
               model_name, att_prior_loss_weight=0, early_stopping = False, early_stop_hist_len = 3, 
                early_stop_min_delta = .001, torch_seed = 3342):
    
    writer = SummaryWriter(charts_path)
    
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(device)
    model.to(device)
    
    torch.manual_seed(torch_seed)
    
    optimizer = torch.optim.Adam(model.parameters(), lr)
    
    if early_stopping:
        val_epoch_loss_hist = []
    
    best_val_epoch_loss, best_model_state = float("inf"), None
    
    for epoch in range(num_epochs):
        if torch.cuda.is_available:
            torch.cuda.empty_cache()
            
        t_batch_losses, t_corr_losses, t_att_losses, t_prof_losses, t_count_losses, input_seqs_array, profiles_array = run_epoch(
            train_loader, "train", model, epoch, batch_size, att_prior_loss_weight, optimizer=optimizer
        )
        
        train_epoch_loss = np.nanmean(t_batch_losses)
        print(
            "Train epoch %d: average loss = %6.10f" % (
                epoch + 1, train_epoch_loss
            )
        )
        
            
        v_batch_losses, v_corr_losses, v_att_losses, v_prof_losses, v_count_losses, input_seqs_array, profiles_array = run_epoch(
                val_loader, "eval", model, epoch, batch_size, att_prior_loss_weight
        )
        
        val_epoch_loss = np.nanmean(v_batch_losses)
        print(
            "Valid epoch %d: average loss = %6.10f" % (
                epoch + 1, val_epoch_loss
            )
        )
        
    
        writer.add_scalars("Loss", {"train": train_epoch_loss, "val": val_epoch_loss}, epoch)
        writer.add_scalars("Correctness_Loss", {"train": np.nanmean(t_corr_losses), "val": np.nanmean(v_corr_losses)}, epoch)
        writer.add_scalars("Attribution_Prior_Loss", {"train": np.nanmean(t_att_losses), "val": np.nanmean(v_att_losses)}, epoch)
        writer.add_scalars("Profile_Loss", {"train": np.nanmean(t_prof_losses), "val": np.nanmean(v_prof_losses)}, epoch)
        writer.add_scalars("Counts_Loss", {"train": np.nanmean(t_count_losses), "val": np.nanmean(v_count_losses)}, epoch)


        model_path = "%s_epoch_%d.pt" % (model_name, epoch + 1+19)
        save_model(model, model_path)
    
        if np.isnan(train_epoch_loss) and np.isnan(val_epoch_loss):
            print("Both NaN")
            break
    

In [None]:
batch_size = 64
learning_rate = 0.001
num_epochs = 25
charts_path = "runs/exp10"
model_path = "trained_models/exp10"
train_model(proj.DataLoader(train_regions_npy_path, batch_size=batch_size), proj.DataLoader(val_regions_npy_path,batch_size=batch_size),
            proj.ProfilePredictor(), batch_size, learning_rate, num_epochs, charts_path, model_path)

In [None]:
batch_size = 64
learning_rate = 0.001
num_epochs = 25
charts_path = "runs/exp10_prior"
model_path = "trained_models/exp10_prior"
att_loss_w = 280

train_model(proj.DataLoader(train_regions_npy_path, batch_size=batch_size), proj.DataLoader(val_regions_npy_path,batch_size=batch_size),
            proj.ProfilePredictor(), batch_size, learning_rate, num_epochs, charts_path, model_path, att_prior_loss_weight=att_loss_w)