# Prototype for Training NN to Invert ODE

## Imports / Installation

In [None]:
# %load_ext autoreload
# %autoreload 2

import os
import scipy
import numpy as np
import pandas as pd
import sys
import scanpy as sc
import scvelo as scv
import matplotlib.pyplot as plt
import math
import random
import time

In [None]:
# Use this for progress bar help: https://stackoverflow.com/questions/3160699/python-progress-bar
from time import sleep
from tqdm.notebook import tqdm

In [None]:
#import pytorch
import torch
from torch import nn

In [None]:
# Fix the seeds of the PyTorch and NumPy random number generators.
# NOTE(review): Python's built-in `random` module is imported above but is
# not seeded here; if any helper draws from it, runs will still differ --
# confirm against training_tools.
torch.manual_seed(42)
np.random.seed(42)

In [None]:
%load_ext autoreload
%autoreload 2

# import function that makes our data
from training_tools import *

## Construct the Matrix Used for Our Training Data

In [None]:
# Scaling factor for time. Targets are divided by it before training and
# predictions are multiplied by it again when losses are reported, because
# the networks below end in a Sigmoid and therefore output values in (0, 1).
adj_factor = 21

In [None]:
# direction of the data we're fitting
# (the cells below handle the values 0, 1 and 2)
dire = 2

# Model 1 repression genes are notoriously hard to train.
# Set this flag to True when training them: further down it selects
# a wider network, a higher learning rate and its own epoch count.
# It may only be True when dire == 2 (checked in the next cell).
dir_2_model_1 = True

In [None]:
# The model-1 flag is only meaningful for direction 2, so stop right away
# when the two settings contradict each other.
if dir_2_model_1 and dire != 2:
    raise Exception("Don't set dir_2_model_1 to True if you aren't training repression genes!")

In [None]:
# Pick the suffix that identifies this configuration. It is used below to
# build the data folder paths and the file name of the saved weights.
if dire == 0:
    suffix_name = "dir0"
elif dire == 1:
    suffix_name = "dir1"
elif dire == 2:
    suffix_name = "dir2_m1" if dir_2_model_1 else "dir2_m2"
else:
    # Without this branch suffix_name would silently stay undefined and the
    # notebook would only fail later with an unrelated NameError.
    raise ValueError("dire must be 0, 1 or 2, got " + repr(dire))

In [None]:
# Folder that holds the simulated data for the selected configuration;
# the next cell loads the training matrix from it.
read_folder = "./data/simulated_data/" + suffix_name

In [None]:
# Load the training inputs X and the time targets t.
# (X_from_file is presumably provided by the star import from
# training_tools above -- it is not defined in this notebook.)
X, t = X_from_file(read_folder, dire)

In [None]:
# Shift the targets by 1 and divide by adj_factor so they are on the scale of
# the network's Sigmoid output. The visualization section undoes this with
# t_pred * adj_factor - 1.
t = (t + 1) / adj_factor
# alternative without the shift (disabled):
# t = (t / adj_factor)

In [None]:
# Sanity check: show the shapes of the loaded inputs and targets.
print(X.shape)
print(t.shape)

## Prepare Batches

In [None]:
# Number of batches per epoch: passed to make_batches_random and used as the
# length of the inner training loop and of its progress bar.
batches = 800

## Generate Validation Set

In [None]:
# You need to supply your own validation data for this to work (either that
# or remove the validation code entirely). Replace the None placeholder below
# with the path of a folder that X_from_file can read.
read_folder = None

if read_folder is None:
    # Fail with an explanation instead of the SyntaxError the empty
    # assignment used to produce.
    raise ValueError("Set read_folder to the folder containing your validation data "
                     "(or remove the validation code) before running this cell.")

val_X, val_t = X_from_file(read_folder, dire)

In [None]:
# Report whether the training inputs or targets contain any NaN.
# NOTE(review): this cell sits in the validation section but inspects X and t,
# not val_X and val_t -- confirm that this is intended.
print(np.any(np.isnan(X)))
print(np.any(np.isnan(t)))

In [None]:
# Apply the same shift-and-scale to the validation targets that was applied
# to the training targets.
val_t = (val_t + 1) / adj_factor

In [None]:
# Validation inputs as a float tensor with one row per sample ...
n_val_features = val_X.shape[1]
val_X_ten = torch.tensor(val_X, dtype=torch.float, requires_grad=True)
val_X_ten = val_X_ten.reshape(-1, n_val_features)

# ... and the validation targets as a single column.
val_t_ten = torch.tensor(val_t, dtype=torch.float, requires_grad=True)
val_t_ten = val_t_ten.reshape(-1, 1)

## Define Model

In [None]:
def _build_time_predictor(n_features, base_n):
    """Build the fully connected network that maps one input row to a time.

    The network has three hidden ReLU layers of widths ``int(2*base_n)``,
    ``int(1.5*base_n)`` and ``int(1*base_n)``, followed by a single output
    unit with a Sigmoid, so predictions lie in (0, 1).

    Parameters:
        n_features: number of input columns the first layer accepts.
        base_n: base width from which the hidden layer sizes are derived.

    Returns:
        The assembled ``nn.Sequential`` model.
    """
    wide, middle, narrow = int(2*base_n), int(1.5*base_n), int(1*base_n)

    return nn.Sequential(
        nn.Linear(n_features, wide),
        nn.ReLU(),
        nn.Linear(wide, middle),
        nn.ReLU(),
        nn.Linear(middle, narrow),
        nn.ReLU(),
        nn.Linear(narrow, 1),
        nn.Sigmoid()
    )


# The four configurations share one architecture and differ only in the
# number of input columns and in the base width.
if dire == 0:
    # DIR 0:
    n_features, base_n = 21, 75
elif dire == 1:
    # DIR 1:
    n_features, base_n = 16, 32
elif dire == 2 and dir_2_model_1:
    # DIR 2 M1:
    n_features, base_n = 18, 110
elif dire == 2:
    # DIR 2 M2:
    n_features, base_n = 16, 75
else:
    # Previously an unsupported value left ode_model undefined.
    raise ValueError("dire must be 0, 1 or 2, got " + repr(dire))

ode_model = _build_time_predictor(n_features, base_n)


In [None]:
# Used this: https://stackoverflow.com/questions/49433936/how-do-i-initialize-weights-in-pytorch

def init_weights(m):
    """Re-initialise the weight matrix of a linear layer in place.

    Intended to be handed to ``nn.Module.apply``. Modules that are not
    ``nn.Linear`` are left untouched, and biases are never modified.
    When the notebook-level ``dire`` is 0 the weights are drawn with Xavier
    uniform initialisation, otherwise with Kaiming normal initialisation
    for ReLU.
    """
    use_xavier = dire == 0

    if not isinstance(m, nn.Linear):
        return

    if use_xavier:
        nn.init.xavier_uniform_(m.weight)
    else:
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

In [None]:
# Run init_weights on the network and each of its layers
# (nn.Module.apply calls the function on every submodule).
ode_model.apply(init_weights)

## Train

I closely followed this tutorial: https://pytorch.org/tutorials/beginner/examples_nn/polynomial_nn.html?highlight=mse

In [None]:
# Maximum number of training epochs for each neural network.
# Every value is the epoch at which that network had stopped
# improving on the developer's validation data set for five
# epochs in a row.
if dire == 2:
    N = 43 if dir_2_model_1 else 51
elif dire == 1:
    N = 67
elif dire == 0:
    N = 51

# Mean squared error is the loss for both training and validation
criterion = nn.MSELoss(reduction="mean")
val_criterion = nn.MSELoss(reduction="mean")

# Adam learning rate: highest for the model-1 repression genes,
# intermediate for direction 0, lowest for everything else
if dir_2_model_1:
    adam_lr = 1e-3
else:
    adam_lr = 3e-4 if dire == 0 else 1e-4

# set Adam as the training optimizer
optimizer = torch.optim.Adam(ode_model.parameters(), lr=adam_lr)

# lowest validation loss seen so far
min_loss = float('inf')

# histories collected for plotting: mean training loss per epoch, and the
# validation loss together with the epoch at which it was measured
epoch_loss, val_loss_list, val_epoch = [], [], []

# mean training loss over the batches of the current epoch
mean_loss = 0

# progress bar counting finished training epochs
epoch_bar = tqdm(total=N)

# progress bar counting finished batches within the current epoch
batch_bar = tqdm(total=batches)

# print out some hyperparameters
print("Running with lr =", adam_lr, "and base_n =", base_n)

# put the network into training mode
ode_model.train()

# These two calls used to be repeated for every batch although neither
# depends on the batch, so they are issued once up front.
# NOTE: anomaly detection is a debugging aid that slows training down;
# consider switching it off once training is known to be stable.
ode_model.float()
torch.autograd.set_detect_anomaly(True)

# loop through each epoch
for i in range(N):

    # reset the epoch mean loss
    mean_loss = 0
    mean_context_loss = 0

    batch_bar.reset()

    sleep(0.001)

    # randomly select data points and assign them to batches
    x_tens, t_tens = make_batches_random(batches, X, t)

    # loop through each batch
    for j in range(batches):

        # do a forward pass
        t_pred = ode_model(x_tens[j])

        # for loss values, multiply t_pred and t by adj_factor
        # to report them on the original time scale
        scaled_t_pred = torch.mul(t_pred, adj_factor)
        scaled_t = torch.mul(t_tens[j], adj_factor)

        # check loss (MSELoss takes the prediction first, then the target)
        loss = criterion(scaled_t_pred, scaled_t)

        # reset gradients
        optimizer.zero_grad()

        # do a backward pass
        loss.backward()

        # step the optimizer
        optimizer.step()

        batch_bar.update()

        # add the loss of this batch
        # to the running mean batch loss
        mean_loss += loss.item()

    # calculate the mean batch loss
    mean_loss /= batches

    # add loss values to lists for graphing later
    epoch_loss.append(mean_loss)

    # periodically check our progress
    # (can adjust value after modulo in case
    # you want to check progress less regularly)
    if i % 1 == 0:

        ode_model.eval()

        # Validation only needs the loss value, so skip autograd bookkeeping
        # for the forward pass over the whole validation set.
        with torch.no_grad():
            # do a forward pass on validation data
            val_t_pred = ode_model(val_X_ten)

            # check loss on validation data
            val_loss = val_criterion(torch.mul(val_t_pred, adj_factor), torch.mul(val_t_ten, adj_factor))

        val_loss_list.append(val_loss.item())
        val_epoch.append(i)

        print("Epoch", i, "loss:", mean_loss)
        print("Validation loss:", val_loss.item())

        # relative change of the newest validation loss with respect
        # to the best (lowest) validation loss seen before it
        rel_diff = 0

        new_val_N = len(val_loss_list)

        # only possible once an earlier validation loss exists; a best loss
        # of exactly zero is skipped to avoid dividing by zero
        if new_val_N >= 2 and min_loss > 0:
            rel_diff = (val_loss_list[-1] - min_loss) / min_loss
            print(str(rel_diff*100) + "%")

        # keep track of the lowest validation loss
        if val_loss_list[-1] < min_loss:
            min_loss = val_loss_list[-1]
            print("New min!")

        print()

        ode_model.train()

    epoch_bar.update()

ode_model.eval()

In [None]:
# Save only the learned parameters (the state_dict), using the configuration
# suffix as the file name. The target directory is given relative to the
# notebook's working directory.
torch.save(ode_model.state_dict(), "../src/multivelo/neural_nets/" + suffix_name + ".pt")

In [None]:
# number of leading epochs left out of the plot
sub = 10

# scatter the mean training loss of every remaining epoch
shown_epochs = range(len(epoch_loss))[sub:]
shown_losses = epoch_loss[sub:]
plt.scatter(shown_epochs, shown_losses, s=2, label="Training", color="blue")

plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

## Visualize Neural Network

For this section we visualize how well the neural network performs by passing in noiseless c/u/s data and graphing the results.

In [None]:
# Folder with the noiseless counterpart of the selected configuration's data.
read_folder = "./data/simulated_data/" + suffix_name + "_noiseless"

In [None]:
# Genes to draw; passed to X_from_file as subset_gene below
# (presumably indices into the noiseless data set -- TODO confirm).
genes_to_graph = [1,12]

In [None]:
# Load the noiseless data for the selected genes. With subset_gene given,
# three values are unpacked; the plotting cell below uses consecutive entries
# of the third one as the row range belonging to each gene.
graph_X, graph_t, subset_gene = X_from_file(read_folder, dire, subset_gene=genes_to_graph)

In [None]:
# Four panels: predicted against real time, and time against C, U and S.
f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 16))

# reference diagonal on which a perfect prediction would lie
perf_line = np.arange(0, 20)
ax1.plot(perf_line, perf_line, label="perfect fit", linewidth=1)
ax1.set_xlabel("Real")
ax1.set_ylabel("Pred")

ax2.set_ylabel("Time")
ax2.set_xlabel("C")

ax3.set_ylabel("Time")
ax3.set_xlabel("U")

ax4.set_ylabel("Time")
ax4.set_xlabel("S")

# NOTE(review): u, s and the rate parameters appear to be stored as
# log(value + epsilon), so all of them are recovered as exp(x) - epsilon.
# u and s previously used exp(x - epsilon), inconsistent with the rate
# parameters in this same cell -- confirm against X_from_file.
epsilon = 1e-5

for i in range(len(genes_to_graph)):

    # rows that belong to the i-th selected gene
    start, stop = subset_gene[i], subset_gene[i+1]
    subset_X = graph_X[start:stop]
    subset_t = graph_t[start:stop]

    # rate parameters, read from the gene's first row (columns 3 to 6)
    alpha_c, alpha, beta, gamma = (
        str(np.round(np.exp(subset_X[0, col]) - epsilon, 4)) for col in (3, 4, 5, 6)
    )

    c = subset_X[:,0]
    u = np.exp(subset_X[:,1]) - epsilon
    s = np.exp(subset_X[:,2]) - epsilon

    # Pass dtype explicitly, as everywhere else in this notebook: a float64
    # NumPy array would otherwise yield a double tensor that the float
    # network cannot consume. No gradients are needed for plotting.
    with torch.no_grad():
        model_in = torch.tensor(subset_X, dtype=torch.float).reshape(-1, subset_X.shape[1])
        t_pred = ode_model(model_in)

    # undo the (t + 1) / adj_factor scaling applied to the targets
    t_pred = (t_pred.detach().numpy().reshape(-1) * adj_factor) - 1

    ax1.plot(subset_t, t_pred, linewidth=2, label="alpha_c: " + alpha_c)

    clabel = " t - alpha_c: " + alpha_c
    ax2.plot(c, subset_t, label="real" + clabel, linewidth=1)
    ax2.plot(c, t_pred, label="pred" + clabel, linewidth=1)

    ulabel = " t - alpha: " + alpha + " and beta: " + beta
    ax3.plot(u, subset_t, label="real" + ulabel, linewidth=1)
    ax3.plot(u, t_pred, label="pred" + ulabel, linewidth=1)

    slabel = " t - beta: " + beta + " and gamma: " + gamma
    ax4.plot(s, subset_t, label="real" + slabel, linewidth=1)
    ax4.plot(s, t_pred, label="pred" + slabel, linewidth=1)


ax1.legend()
ax2.legend()
ax3.legend()
ax4.legend()

## Visualize Results

In [None]:
# Unconditional guard: raising here stops a top-to-bottom run of the notebook
# before the cells below, which push the whole data set through the model.
raise Exception ("On large datasets, the next steps can crash the kernel! Only proceed if you have enough RAM")

In [None]:
# Pass all the data through the final model and compute the loss on the
# rescaled time axis. Nothing is back-propagated here, so autograd is
# switched off: recording a graph for the whole data set only costs memory,
# which is exactly what the warning above is about.
with torch.no_grad():
    final_pred_t_alldata = ode_model(torch.tensor(X, dtype=torch.float))

    # calculate final loss
    final_train_loss = val_criterion(final_pred_t_alldata,
                                     torch.tensor(t, dtype=torch.float).reshape(-1, 1))

# print the final loss
print(final_train_loss)

In [None]:
# Final loss with predictions and targets multiplied back by adj_factor.
# (The shift of 1 applied to the targets is a constant offset and is not
# undone here.)
true_t_rescaled = torch.tensor(t*adj_factor, dtype=torch.float, requires_grad=True).reshape(-1, 1)
final_train_context_loss = val_criterion(final_pred_t_alldata*adj_factor, true_t_rescaled)

print(final_train_context_loss)

In [None]:
# convert predicted time to numpy
t_pred_for_graph = final_pred_t_alldata.detach().numpy().reshape(-1)

In [None]:
# graph the final results of training the original data
graph_results(t, t_pred_for_graph, X, "test")

In [None]:
# 2-D histogram of predicted time against true time; for a well-trained
# model most of the points fall along the x=y line
fig, ax = plt.subplots()
h = ax.hist2d(t, t_pred_for_graph, bins=50, cmap="PuOr")

# the last item returned by hist2d is the drawn image the colorbar describes
density_image = h[3]
fig.colorbar(density_image, ax=ax)

ax.set_ylabel("Predicted")
ax.set_xlabel("True")

## More Validation

In [None]:
# pass the full HSPC validation set through the model to get predicted time
# (inference only, so no autograd graph is recorded)
with torch.no_grad():
    pred_val_t = ode_model(torch.tensor(val_X, dtype=torch.float))

In [None]:
# calculate the loss against the validation targets, shaped as one column
# (both sides are on the rescaled time axis here)
val_t_target = torch.tensor(val_t, dtype=torch.float, requires_grad=True).reshape(-1, 1)
loss = criterion(pred_val_t, val_t_target)

In [None]:
# print the validation loss as a plain Python number
print(loss.item())

In [None]:
# graph the validation results
# NOTE(review): the predictions are passed here without the reshape(-1) that
# the training-data call to graph_results uses -- confirm that graph_results
# accepts both shapes.
graph_results(val_t, pred_val_t.detach().numpy(), val_X, "test")