In [None]:
import torch
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.special
import scipy.stats
import scipy.ndimage
import sklearn.metrics
import pyfaidx
import pyBigWig
import tqdm
import project as proj
import importlib
# Called with no iterable, this only renders a standalone (empty) notebook
# progress bar; `run_epoch` below builds its own console `tqdm.tqdm` bar.
# NOTE(review): `tqdm_notebook` appears to be deprecated in recent tqdm
# releases in favour of `tqdm.notebook.tqdm` / `tqdm.auto.tqdm` — confirm for
# the installed version, and consider dropping this call.
tqdm.tqdm_notebook()

In [None]:
# Re-execute the `project` module so edits to it are picked up without
# restarting the kernel. Objects created before the reload (e.g. models or
# loaders from earlier cells) still reference the old definitions.
importlib.reload(proj)

In [None]:
# Reference genome and genome-wide annotation inputs.
ref_fasta_path = './data/hg38.fasta'
chrom_sizes_path = './data/hg38.canon.chrom.sizes'

# Sample-specific inputs: called peaks and the read-coverage bigWig.
open_regions_path = './data/vcm_peaks.bed'
reads_path = './data/vcm_reads.bw'

# NOTE(review): this bigWig handle is not used by any cell in this notebook and
# is never closed — confirm it is still needed.
reads = pyBigWig.open(reads_path)

# Pre-split region sets; the val/test paths are passed to `proj.DataLoader`
# in the evaluation cells.
train_regions_npy_path = "./data/train_regions.npy"
val_regions_npy_path = "./data/val_regions.npy"
test_regions_npy_path = "./data/test_regions.npy"

In [None]:
# Adapted from `run_epoch` in Training.ipynb. Differences: eval mode switches the
# model to eval mode and only tracks gradients when the attribution prior needs
# them, and the returned batches are detached CPU copies so a full pass does not
# keep every input tensor alive on the GPU.
def run_epoch(data_loader, mode, model, epoch_num, batch_size, att_prior_loss_weight=0,
              num_tasks=1, counts_loss_weight=25, freq_limit=200,
              limit_softness=0.2, att_prior_grad_smooth_sigma=3,
              input_length=1346, input_depth=4, profile_length=1000, optimizer=None,
              return_data=False, cuda=0, shuffle=True):
    """Run one pass over `data_loader`, training or evaluating `model`.

    Args:
        data_loader: iterable of `(input_seqs, profiles)` batches that also
            supports `len()` and `shuffle_data()`.
        mode: "train" (requires `optimizer`) or "eval" (requires
            `optimizer is None`).
        model: called as `model(input_seqs)` returning
            `(logit_pred_profs, log_pred_counts)`; must also provide
            `correctness_loss(...)` and `fourier_att_prior_loss(...)`.
        att_prior_loss_weight: weight of the Fourier attribution-prior loss.
            When 0, the prior and its input-gradient computation are skipped.
        counts_loss_weight, freq_limit, limit_softness,
        att_prior_grad_smooth_sigma: forwarded to the model's loss methods.
        optimizer: stepped once per batch in "train" mode.
        cuda: device argument forwarded to `proj.place_tensor`.
        shuffle: if True, call `data_loader.shuffle_data()` before iterating.
        epoch_num, batch_size, num_tasks, input_length, input_depth,
        profile_length, return_data: not used by this body; kept so the
            signature matches the Training.ipynb version.

    Returns:
        `(batch_losses, corr_losses, att_losses, prof_losses, count_losses,
        input_seqs_array, profiles_array)`: five lists of per-batch Python
        floats, then the per-batch input tensors and profile tensors (profiles
        reshaped to `(batch, 1, -1, 1)`), detached and on the CPU.
    """
    assert mode in ("train", "eval")
    if mode == "train":
        assert optimizer is not None
        model.train()  # Switch to training mode
    else:
        assert optimizer is None
        model.eval()  # Disable train-only behaviour while scoring

    # Gradients are needed to train, and in eval mode only when the attribution
    # prior is computed from input gradients; otherwise skip building graphs.
    track_grads = mode == "train" or att_prior_loss_weight > 0

    if shuffle:
        data_loader.shuffle_data()

    t_iter = tqdm.tqdm(
        data_loader, total=len(data_loader), desc="\tLoss: ---"
    )

    batch_losses, corr_losses, att_losses = [], [], []
    prof_losses, count_losses = [], []
    input_seqs_array, profiles_array = [], []

    # Scoped (rather than global) grad mode, restored when the epoch ends.
    with torch.set_grad_enabled(track_grads):
        for input_seqs, profiles in t_iter:
            input_seqs = proj.place_tensor(torch.tensor(input_seqs), cuda).float()
            profiles = proj.place_tensor(torch.tensor(profiles), cuda).float()
            profiles = profiles.view(profiles.shape[0], 1, -1, 1)

            # Keep detached CPU copies so the returned batches don't accumulate
            # in device memory over the whole epoch.
            input_seqs_array.append(input_seqs.detach().cpu())
            profiles_array.append(profiles.detach().cpu())

            if mode == "train":
                optimizer.zero_grad()
            elif att_prior_loss_weight > 0:
                # Not training, but input gradients are computed below, so clear
                # any gradients left on the model first.
                model.zero_grad()

            if att_prior_loss_weight > 0:
                input_seqs.requires_grad = True
                logit_pred_profs, log_pred_counts = model(input_seqs)

                # Mean-center the logits along dim 2 before taking gradients.
                # NOTE(review): presumably dim 2 is the profile-position axis
                # that `profile_logits_to_log_probs` normalizes over — confirm.
                norm_logit_pred_profs = logit_pred_profs - torch.mean(
                    logit_pred_profs, dim=2, keepdim=True
                )
                # Detached weights, so no gradient flows through them.
                # NOTE(review): these are log-probabilities despite the name;
                # kept as-is to match how the prior was applied in training.
                pred_prof_probs = proj.profile_logits_to_log_probs(
                    logit_pred_profs
                ).detach()
                weighted_norm_logits = norm_logit_pred_profs * pred_prof_probs

                input_grads, = torch.autograd.grad(
                    weighted_norm_logits, input_seqs,
                    grad_outputs=proj.place_tensor(
                        torch.ones(weighted_norm_logits.size()), cuda
                    ),
                    retain_graph=True,
                    # The attribution loss is only backpropagated when training,
                    # so only then is a differentiable graph of the grads needed.
                    create_graph=(mode == "train"),
                )

                input_grads = input_grads * input_seqs  # gradient x input
                # Every example is treated as a positive for the prior loss.
                status = proj.place_tensor(
                    torch.tensor(np.ones(input_grads.shape[0])), cuda
                )
                input_seqs.requires_grad = False

                corr_loss, prof_loss, count_loss = model.correctness_loss(
                    profiles, logit_pred_profs, log_pred_counts,
                    counts_loss_weight, return_separate_losses=True
                )
                att_loss = model.fourier_att_prior_loss(
                    status, input_grads, freq_limit,
                    limit_softness, att_prior_grad_smooth_sigma
                )
                loss = corr_loss + att_prior_loss_weight * att_loss
            else:
                logit_pred_profs, log_pred_counts = model(input_seqs)
                loss, prof_loss, count_loss = model.correctness_loss(
                    profiles, logit_pred_profs, log_pred_counts,
                    counts_loss_weight, return_separate_losses=True
                )
                corr_loss = loss
                att_loss = torch.zeros(1)

            if mode == "train":
                loss.backward()  # Compute gradient
                optimizer.step()  # Update weights

            batch_losses.append(loss.item())
            corr_losses.append(corr_loss.item())
            att_losses.append(att_loss.item())
            prof_losses.append(prof_loss.item())
            count_losses.append(count_loss.item())
            t_iter.set_description(
                "\tLoss: %6.4f" % loss.item()
            )

    return batch_losses, corr_losses, att_losses, prof_losses, count_losses, input_seqs_array, profiles_array
    

In [None]:
## Restore the exp5 checkpoint and report mean losses on the validation and
## test splits.
# Both splits are scored with the same weights and eval mode never updates
# parameters, so the checkpoint is restored once and reused for the test pass
# (`test_model` is an alias of `val_model`) instead of holding two identical
# copies on the GPU.
val_loader = proj.DataLoader(val_regions_npy_path, batch_size=1)
val_model = proj.restore_model("./trained_models/exp5_epoch_17.pt", 1)
val_model.cuda(1)
val_model.eval()

(val_batch_losses, val_corr_losses, val_att_losses, val_prof_losses,
 val_count_losses, val_input_seqs_array, val_profiles_array) = run_epoch(
    val_loader, 'eval', val_model, epoch_num=0, batch_size=1, cuda=1, shuffle=False
)
print("VALIDATION")
for loss_name, loss_values in (
    ("overall", val_batch_losses),
    ("correctness", val_corr_losses),
    ("attribution", val_att_losses),
    ("profile", val_prof_losses),
    ("count", val_count_losses),
):
    print(loss_name)
    print(np.nanmean(loss_values))

print("------------------")

test_loader = proj.DataLoader(test_regions_npy_path, batch_size=1)
test_model = val_model
(test_batch_losses, test_corr_losses, test_att_losses, test_prof_losses,
 test_count_losses, test_input_seqs_array, test_profiles_array) = run_epoch(
    test_loader, 'eval', test_model, epoch_num=0, batch_size=1, cuda=1, shuffle=False
)
print("TESTING")
for loss_name, loss_values in (
    ("overall", test_batch_losses),
    ("correctness", test_corr_losses),
    ("attribution", test_att_losses),
    ("profile", test_prof_losses),
    ("count", test_count_losses),
):
    print(loss_name)
    print(np.nanmean(loss_values))

In [None]:
## Restore the exp6 "prior" checkpoint and report mean losses, including the
## attribution-prior term, on the validation and test splits.
# Weight of the Fourier attribution-prior term in the reported overall loss.
# NOTE(review): presumably the weight exp6 was trained with — confirm.
EXP6_ATT_PRIOR_LOSS_WEIGHT = 230

# Both splits are scored with the same weights and eval mode never updates
# parameters, so the checkpoint is restored once and reused for the test pass
# (`test_model` is an alias of `val_model`) instead of holding two identical
# copies on the GPU.
val_loader = proj.DataLoader(val_regions_npy_path, batch_size=1)
val_model = proj.restore_model("./trained_models/exp6_prior_epoch_20.pt", 1)
val_model.cuda(1)
val_model.eval()

print("WITH PRIOR")

(val_batch_losses, val_corr_losses, val_att_losses, val_prof_losses,
 val_count_losses, val_input_seqs_array, val_profiles_array) = run_epoch(
    val_loader, 'eval', val_model, epoch_num=0, batch_size=1, cuda=1,
    shuffle=False, att_prior_loss_weight=EXP6_ATT_PRIOR_LOSS_WEIGHT
)
print("VALIDATION")
for loss_name, loss_values in (
    ("overall", val_batch_losses),
    ("correctness", val_corr_losses),
    ("attribution", val_att_losses),
    ("profile", val_prof_losses),
    ("count", val_count_losses),
):
    print(loss_name)
    print(np.nanmean(loss_values))

print("------------------")

test_loader = proj.DataLoader(test_regions_npy_path, batch_size=1)
test_model = val_model
(test_batch_losses, test_corr_losses, test_att_losses, test_prof_losses,
 test_count_losses, test_input_seqs_array, test_profiles_array) = run_epoch(
    test_loader, 'eval', test_model, epoch_num=0, batch_size=1, cuda=1,
    shuffle=False, att_prior_loss_weight=EXP6_ATT_PRIOR_LOSS_WEIGHT
)
print("TESTING")
for loss_name, loss_values in (
    ("overall", test_batch_losses),
    ("correctness", test_corr_losses),
    ("attribution", test_att_losses),
    ("profile", test_prof_losses),
    ("count", test_count_losses),
):
    print(loss_name)
    print(np.nanmean(loss_values))