In [None]:
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import h5py
%matplotlib inline
import os
import torch
print(torch.cuda.is_available())
from torch.utils.data import DataLoader, Dataset
from scipy import stats

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
# config: file locations, the sites to score, and the sequence associated with each site

MODEL_PATH = 'model/changenet_trained_weights.pth'
# The literal substring 'site' is substituted with each SITE_LIST entry at load time.
DATA_PATH = 'data/toy_dataset_site.hdf5'
SITE_LIST = [
    'CCR5_s8',
    'LAG3_s9',
    'TRAC_s1',
    'CTLA4_s9',
    'AAVS1_s14',
]

# NOTE(review): these two look like training hyper-parameters; confirm whether
# anything in the notebook still reads them.
BATCH_SIZE = 1024
NUM_EPOCHS = 300

# Site key -> 23-character sequence string, looked up per prediction row.
gRNA = {
    'AAVS1_s14': 'GGGGCCACTAGGGACAGGATTGG',
    'CTLA4_s9': 'GGACTGAGGGCCATGGACACGGG',
    'TRAC_s1': 'GTCAGGGTTCTGGATATCTGTGG',
    'LAG3_s9': 'GAAGGCTGAGATCCTGGAGGGGG',
    'CXCR4_s8': 'GTCCCCTGAGCCCATTTCCTCGG',
    'CCR5_s8': 'GGACAGTAAGAAGGAAAAACAGG',
}

## Loading dataset

In [None]:
# Loading dataset
class SeqData(Dataset):
    """Map-style dataset that serves rows of ``X`` as float32 tensors.

    Args:
        X: indexable container of numeric samples; ``X[i]`` is one sample.
        seq: optional companion object stored on the instance as ``self.seq``.
            It is not returned by ``__getitem__``.
    """

    def __init__(self, X, seq=None):
        self.X, self.seq = X, seq
        # Cached once so __len__ does not re-query the container.
        self.length = len(X)

    def __getitem__(self, i):
        """Return sample ``i`` copied into a new ``torch.float32`` tensor."""
        return torch.tensor(self.X[i], dtype=torch.float32)

    def __len__(self):
        """Return the number of samples recorded at construction time."""
        return self.length


# Read every site's HDF5 file and gather all mismatch groups into flat,
# row-aligned arrays: X (features), seq, mismatches and source_site.
print('Loading dataset...')
X, seq, mismatches, len_list, source_site = [], [], [], [], []

for site in SITE_LIST:
    # One file per site: the 'site' substring of DATA_PATH is swapped for the site key.
    data_path = DATA_PATH.replace('site', site)
    with h5py.File(data_path, 'r') as f:
        # Groups are keyed by mismatch count, '0' through '6'.
        for num_mismatches in range(7):
            group = f[str(num_mismatches)]
            X_s = np.array(group['X']).astype(np.float32)
            X_s[:, 20, :4] = 0.25  # Set the first base of PAM to N
            seq_s = np.array(group['seq']).astype(str)
            n_rows = len(X_s)

            print(num_mismatches, X_s.shape)
            X.append(X_s)
            seq.append(seq_s)
            # Per-row labels so every output array has one entry per sample.
            mismatches.append([num_mismatches] * n_rows)
            source_site.append([site] * n_rows)
            len_list.append(n_rows)
            del X_s, seq_s, group, n_rows

# Collapse the per-group chunks into single arrays.
X = np.concatenate(X)
seq = np.concatenate(seq)
mismatches = np.concatenate(mismatches)
source_site = np.concatenate(source_site)

print(X.shape, mismatches.shape, source_site.shape)


In [None]:
# Wrap the loaded arrays and iterate them in their stored order (shuffle=False),
# so batch outputs can be concatenated back into the same row order as `seq`.
dts = SeqData(X, seq)

loader_test = DataLoader(
    dts,
    batch_size=8192,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)


## Loading trained model

In [None]:
def get_free_gpu():
    """Return the index of the GPU that currently has the most free memory.

    Asks ``nvidia-smi`` for the free memory of each GPU (one number per line,
    in nvidia-smi's enumeration order) and returns the position of the largest
    value. The output is read directly from the subprocess, so no scratch file
    is written to the working directory and no file handle is left open.

    Returns:
        int: zero-based index of the GPU with the most free memory.

    Raises:
        ValueError: if ``nvidia-smi`` is missing or fails, reports no GPUs, or
            prints a value that is not an integer. ``ValueError`` is kept
            because that is what the previous implementation raised when no
            GPU information was available.
    """
    import subprocess  # local import: the notebook's import cell does not provide it

    query = ['nvidia-smi', '--query-gpu=memory.free', '--format=csv,noheader,nounits']
    try:
        report = subprocess.run(query, capture_output=True, text=True, check=True).stdout
    except (OSError, subprocess.SubprocessError) as err:
        raise ValueError('could not query free GPU memory via nvidia-smi') from err

    # One line per GPU; skip blank lines so a trailing newline is harmless.
    memory_available = [int(line) for line in report.splitlines() if line.strip()]
    if not memory_available:
        raise ValueError('nvidia-smi reported no GPUs')
    return int(np.argmax(memory_available))

# Optional GPU selection (disabled): uncomment the next two lines to run on the
# CUDA device that get_free_gpu() picks as having the most free memory.
# id = get_free_gpu()
# device = torch.device("cuda:%d" % id)
# As written, the model and every batch are placed on the CPU.
device = 'cpu'

class CHANGENET(nn.Module):
    """1-D convolutional regressor producing one scalar per input sequence.

    Six Conv1d -> BatchNorm1d -> LeakyReLU -> Dropout stages (all
    length-preserving thanks to ``padding = kernel_size // 2``) are followed by
    a flatten and a three-layer fully connected head. Input is expected as
    ``(batch, 8, 23)``; output is ``(batch, 1)``.

    All sub-modules live in ``self.layers`` in execution order, so parameter
    names follow the ``layers.<index>.<param>`` pattern.
    """

    def __init__(self):
        super().__init__()
        self.seq_length = 23
        width = 128
        drop_p = 0.2

        # (in_channels, out_channels, kernel_size) for each convolutional stage.
        conv_specs = (
            (8, width, 3),
            (width, width, 3),
            (width, width, 3),
            (width, width, 5),
            (width, width, 5),
            (width, 10, 5),
        )

        stack = []
        for n_in, n_out, k in conv_specs:
            stack += [
                nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=k, padding=k // 2),
                nn.BatchNorm1d(n_out, track_running_stats=True),
                nn.LeakyReLU(),
                nn.Dropout(drop_p),
            ]

        # Head: the last conv stage emits 10 channels over seq_length positions.
        stack += [
            nn.Flatten(),
            nn.Linear(self.seq_length * 10, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 32),
            nn.LeakyReLU(),
            nn.Linear(32, 1),
        ]
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        """Apply every module in ``self.layers`` in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x


# Build the network on the target device and restore the trained weights.
model = CHANGENET()
model.to(device)
mseloss = nn.MSELoss()
# map_location remaps the stored tensors onto `device`; without it a checkpoint
# saved from a CUDA session cannot be deserialized on a CPU-only machine.
checkpoint = torch.load(MODEL_PATH, map_location=device)
# The checkpoint is a dict; the weights are stored under 'model_state_dict'.
model.load_state_dict(checkpoint['model_state_dict'])

## Running inference

In [None]:
# Put BatchNorm and Dropout layers into inference mode.
model.eval()

pred_list = []

# no_grad: predictions only, so skip autograd bookkeeping (less memory, faster).
# The loop variable is deliberately not called `X`, so the dataset array `X`
# survives this cell and earlier cells can be re-run safely.
with torch.no_grad():
    for batch in loader_test:
        # Conv1d expects (batch, channels, length): swap the last two axes.
        batch = torch.transpose(batch.to(device), 1, 2)
        output = model(batch)
        pred_list.append(output.cpu().numpy().reshape(-1,))

# One flat prediction per dataset row, in loader order.
pred_list = np.concatenate(pred_list)

d = {'seq': seq,
     'gRNA': source_site,
     'mismatches': mismatches,
     'gRNA_seq': [gRNA[s] for s in source_site],
     'log2FC_pred': pred_list
}

df = pd.DataFrame(data=d)
# exist_ok avoids the check-then-create race and works if the folder already exists.
os.makedirs('results/', exist_ok=True)
df.to_csv('results/prediction_results_toy_dataset.csv')



