# Perform OOD evaluation of single-nucleotide variant effects using CAGI5 challenge data

In [1]:
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, TensorDataset, DataLoader
import pandas as pd
import pytorch_lightning as pl
import h5py
from tqdm import tqdm
import glob
import os, pickle
import numpy as np
from argparse import Namespace
from einops import rearrange
from scipy.stats import pearsonr



from genomic_augmentations import models, supervised

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def batch_np(whole_dataset, batch_size):
    """
    batch a np array for passing to a model without running out of memory
    :param whole_dataset: np array dataset
    :param batch_size: batch size
    :return: generator of np batches
    """
    for i in range(0, whole_dataset.shape[0], batch_size):
        yield whole_dataset[i:i + batch_size]


In [4]:
vcf_data = 'CAGI_onehot.h5'
f =  h5py.File(vcf_data, "r")
alt_3k = f['alt'][()]
ref_3k = f['ref'][()]
f.close()

window_size = 600
L = alt_3k.shape[1]
alt_3k_crop = alt_3k[:,(L//2-window_size//2):(L//2+window_size//2), :]
ref_3k_crop = ref_3k[:,(L//2-window_size//2):(L//2+window_size//2), :]
ref_3k_crop = rearrange(ref_3k_crop, 'a b c -> a c b')
alt_3k_crop = rearrange(alt_3k_crop, 'a b c -> a c b')


### Load trained model (and its `Config` file)

In [5]:
models_dir = "/home/nick/Results/Basset/Models"
model_supervised = "Basset"
dataset = "Basset"
trials = 5

In [7]:
# list of all models
all_models = glob.glob(f'{models_dir}/{model_supervised}_{dataset}*_Model.ckpt')

h5f_output = h5py.File('predict_one_shot.h5', 'w')
# Determine paths to trained model checkpoint and its config file
for checkpoint_path in tqdm(all_models):
    all_ref_preds = []
    all_alt_preds = []
    config_path = checkpoint_path.replace('Finetune_', '').replace('_Model.ckpt', '_Config.p')
    model_label = checkpoint_path.replace('_Model.ckpt', '').split('/')[-1]
    # Load model config (from which loss function and augmentation type and hyperparameters are drawn)
    config_supervised_dict = pickle.load( open(config_path, "rb") )
    config_supervised = Namespace(**config_supervised_dict)
    # Set loss function
    losstype_lower = config_supervised.loss.lower()
    if losstype_lower == "bce": 
        loss = torch.nn.BCELoss()
    elif losstype_lower == "mse":
        loss = torch.nn.MSELoss()
    else:
        raise ValueError("unrecognized loss function type: %s" % loss)
    # Load model for inference using checkpoint
    numclasses_Basset = 164
    model_untrained = supervised.Basset(numclasses_Basset) # denotes model architecture on which to load checkpoint

    if "insert" in config_supervised.augs:
        print('insert')
        model_inference = models.SupervisedModelWithPadding.load_from_checkpoint(checkpoint_path=checkpoint_path, 
                                                                                 model_untrained=model_untrained, loss_criterion=loss,
                                                                                 insert_max=config_supervised.insert_max).to('cuda')
    else: # all other data augmentations do not affect expected input sequence length
        model_inference = models.SupervisedModel.load_from_checkpoint(checkpoint_path=checkpoint_path, 
                                                                      model_untrained=model_untrained, loss_criterion=loss).to('cuda')

    model_inference.eval();

    batch_size = 16
    for ref, alt in zip(batch_np(ref_3k_crop, batch_size), batch_np(alt_3k_crop, batch_size)):

        # reference allele predictions
        ref = torch.from_numpy(ref).float().to('cuda')
        ref_pred = model_inference(ref).to('cpu')
        all_ref_preds.append(ref_pred)
        
        # alternative allele predictions
        alt = torch.from_numpy(alt).float().to('cuda')
        alt_pred = model_inference(alt).to('cpu')
        all_alt_preds.append(alt_pred)
    all_ref_preds = np.concatenate(all_ref_preds)
    all_alt_preds = np.concatenate(all_alt_preds)

    # save outputs
    h5f_output.create_dataset('ref_'+model_label, data=all_ref_preds)
    h5f_output.create_dataset('alt_'+model_label, data=all_alt_preds)
h5f_output.close()

  rank_zero_deprecation(


insert


  0%|                                                    | 0/85 [00:04<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 12.00 MiB (GPU 0; 10.76 GiB total capacity; 9.61 GiB already allocated; 13.44 MiB free; 9.62 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [90]:
# experimental values
cagi_df = pd.read_csv('final_cagi_metadata.csv',
                      index_col=0).reset_index()
experimental_log_fold_change = cagi_df['6'].values

# predicted values
h5f_output = h5py.File('predict_one_shot.h5', 'r')
model_keys = list(h5f_output.keys())

for i in range(len(model_keys)):
    example_log_fold_change = h5f_output[model_keys[i]][:]
    print(i, pearsonr(experimental_log_fold_change, example_log_fold_change[:,0]))
    