In [1]:
# Imports

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

import utils
import khammash_repro

In [2]:
# Load Training Data
# Each array is read from `<path>/training_<key>.npy` and cast to float32.
# NOTE(review): absolute, user-specific path; consider a configurable data directory.

path = "/home/smalani/controlledLearning/training_data"
training_L, training_sp, training_t, training_y = (
    np.load(path + "/training_" + key + ".npy").astype(np.float32)
    for key in ("L", "sp", "t", "y")
)

In [3]:
# Define Dataset

class Dataset(torch.utils.data.Dataset):
    """Sliding-window dataset built from whole trajectories.

    ``t`` and ``L`` are indexed as ``[trajectory, time]`` and ``y`` as
    ``[trajectory, time, ...]`` (trailing axes of ``y`` are kept as they are).
    Every window start ``i`` of every trajectory yields one sample:

    * ``t0``: time at ``i``
    * ``L0`` / ``y0``: values at ``i .. i + time_forecast - 1``
    * ``t1`` / ``L1`` / ``y1``: values at ``i + 1 .. i + time_forecast``
    """

    def __init__(self, training_t, training_y, training_L, time_forecast=1):
        """Pre-compute all windows; see ``split_data`` for the layout.

        Args:
            training_t: time stamps, indexed ``[trajectory, time]``.
            training_y: states, indexed ``[trajectory, time, ...]``.
            training_L: inputs, indexed ``[trajectory, time]``.
            time_forecast: number of steps per window (clamped in ``split_data``).
        """
        # split_data takes (t, L, y); the constructor takes (t, y, L).
        t0, t1, L0, L1, y0, y1 = self.split_data(training_t, training_L, training_y, time_forecast)
        self.t0 = t0
        self.t1 = t1
        self.L0 = L0
        self.L1 = L1
        self.y0 = y0
        self.y1 = y1

    def split_data(self, t, L, y, time_forecast):
        """Cut every trajectory into overlapping windows of ``time_forecast`` steps.

        ``time_forecast`` is clamped to ``[1, total_time_length - 1]``, so an
        over-long request produces exactly one window per trajectory.

        Returns:
            Tuple ``(t0, t1, L0, L1, y0, y1)`` of numpy arrays whose first axis
            enumerates the windows (trajectory-major, then window start).
        """
        total_time_length = t.shape[1]

        # The clamp does not depend on the trajectory, so compute it once.
        time_forecast = np.clip(time_forecast, 1, total_time_length - 1)
        n_windows = total_time_length - time_forecast

        t0, t1, L0, L1, y0, y1 = [], [], [], [], [], []

        for j in range(t.shape[0]):
            for i in range(n_windows):
                current = slice(i, i + time_forecast)            # steps i .. i+tf-1
                shifted = slice(i + 1, i + time_forecast + 1)    # steps i+1 .. i+tf

                t0.append(t[j, i])
                t1.append(t[j, shifted])
                L0.append(L[j, current])
                L1.append(L[j, shifted])
                y0.append(y[j, current])
                y1.append(y[j, shifted])

        return (
            np.array(t0),
            np.array(t1),
            np.array(L0),
            np.array(L1),
            np.array(y0),
            np.array(y1),
        )

    def __len__(self):
        """Number of windows."""
        return self.t0.shape[0]

    def __getitem__(self, idx):
        """Return window ``idx`` as tensors ``(t0, t1, L0, L1, y0, y1)``.

        ``t0`` gets a leading axis of size 1 (it is a scalar per window).
        ``L1`` also gets a leading axis of size 1 whereas ``L0`` does not.
        NOTE(review): the L0/L1 asymmetry looks like a leftover; kept because
        callers may depend on the shape.
        """
        t0 = torch.tensor(self.t0[idx]).unsqueeze(0)
        t1 = torch.tensor(self.t1[idx])
        L0 = torch.tensor(self.L0[idx])
        L1 = torch.tensor(self.L1[idx]).unsqueeze(0)
        y0 = torch.tensor(self.y0[idx])
        y1 = torch.tensor(self.y1[idx])

        return t0, t1, L0, L1, y0, y1

In [4]:
# Define the model
class Model(torch.nn.Module):
    """Learned ODE with three latent states and one scalar state ``phi``.

    The right-hand side (``ODEs``) is

    * ``d(hidden)/dt = sigmoid(MLP([hidden, phi, L0 / 800]))``
    * ``d(phi)/dt    = (growth_network([hidden, phi]) - lambda_c) * phi * (1 - phi)``

    and is integrated with classical RK4 (``RK4``). ``forward`` rolls the
    system over a sequence of time stamps.

    NOTE(review): ``initial_vector`` and ``lambda_c`` are plain tensors with
    ``requires_grad=True`` created lazily on the first ``forward`` call. They
    are not ``nn.Parameter`` objects, so ``model.parameters()`` (and therefore
    the optimizer) never sees them and their values never change. If they are
    meant to be learned, register them as parameters in ``__init__``.
    """

    def __init__(self, input_size, hidden_size, output_size, device, max_step, box='black'):
        """Build the two networks.

        ``input_size``, ``hidden_size`` and ``output_size`` are stored but do
        not size any layer (widths below are hard-coded). ``max_step`` is
        stored and ``box`` is accepted, but neither is used inside this class.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        # Latent dynamics: 5 inputs (3 hidden + phi + scaled L0) -> 3 outputs.
        # Kept as separate layers; `hidden_network` applies the activations.
        self.hidden_fc1 = torch.nn.Linear(5, 64)
        self.hidden_fc2 = torch.nn.Linear(64, 64)
        self.hidden_fc3 = torch.nn.Linear(64, 3)

        # Growth-rate network: 4 inputs (3 hidden + phi) -> 1 output.
        self.growth_network = torch.nn.Sequential(
            torch.nn.Linear(4, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1)
        )

        self.device = device
        self.max_step = max_step

        # Allocated lazily by `initialize_initial_conditions`.
        self.initial_vector = None
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def hidden_network(self, x):
        """Three-layer MLP with ReLU activations and a final sigmoid."""
        x = self.relu(self.hidden_fc1(x))
        x = self.relu(self.hidden_fc2(x))
        return self.sigmoid(self.hidden_fc3(x))

    def initialize_initial_conditions(self, y0):
        """Create the zero initial latent state (one row per sample in ``y0``) and ``lambda_c = 1``."""
        self.initial_vector = torch.zeros((y0.shape[0], 3), device=self.device, requires_grad=True)
        self.lambda_c = torch.tensor(1.0, device=self.device, requires_grad=True)

    def ODEs(self, hidden, phi, L0):
        """Right-hand side of the ODE.

        Args:
            hidden: latent state, 3 columns.
            phi: scalar state, 1 column.
            L0: input, 1 column; divided by 800 before entering the network.

        Returns:
            Tensor with 4 columns: the 3 latent derivatives followed by ``dphi``.
        """
        hidden_derivative = self.hidden_network(torch.cat((hidden, phi, L0/800), dim=1))
        lambda_p = self.growth_network(torch.cat((hidden, phi), dim=1))
        dphi = (lambda_p - self.lambda_c) * phi * (1-phi)

        return torch.cat((hidden_derivative, dphi), dim=1)

    def RK4(self, hidden, phi, L0, dt):
        """One classical Runge-Kutta-4 step of size ``dt`` with ``L0`` held constant."""
        k1 = self.ODEs(hidden, phi, L0)
        k2 = self.ODEs(hidden + dt/2 * k1[...,:3], phi + dt/2 * k1[...,[3]], L0)
        k3 = self.ODEs(hidden + dt/2 * k2[...,:3], phi + dt/2 * k2[...,[3]], L0)
        k4 = self.ODEs(hidden + dt * k3[...,:3], phi + dt * k3[...,[3]], L0)

        hidden_next = hidden + dt/6 * (k1[...,:3] + 2*k2[...,:3] + 2*k3[...,:3] + k4[...,:3])
        phi_next = phi + dt/6 * (k1[...,[3]] + 2*k2[...,[3]] + 2*k3[...,[3]] + k4[...,[3]])

        return hidden_next, phi_next

    def move_to_device(self, t0, t1, L0, y0):
        """Move the four input tensors to ``self.device``."""
        t0 = t0.to(self.device)
        t1 = t1.to(self.device)
        L0 = L0.to(self.device)
        y0 = y0.to(self.device)

        return t0, t1, L0, y0

    def forward(self, t0, t1, L0, y0, autoreg=1):
        """Roll the model over the time stamps in ``t1``.

        Args:
            t0: start time per sample; column 0 is used.
            t1: target times, one column per step.
            L0: input per step, one column per step.
            y0: observed ``phi`` per step, indexed ``[sample, step, 1]``.
            autoreg: clipped to ``[0, 1]``. The observed ``y0`` re-seeds ``phi``
                every ``int(autoreg * n_steps)`` steps; ``0`` re-seeds at every
                step (teacher forcing), ``1`` only at the first step.

        Returns:
            ``(hidden_out, y_out)``: the list of latent states (initial state
            first, then one per step) and the predicted ``phi`` stacked along
            axis 1.
        """
        batch_size = y0.shape[0]
        if self.initial_vector is None:
            self.initialize_initial_conditions(y0)
        elif self.initial_vector.shape[0] != batch_size:
            # The cached initial state was sized for an earlier batch; without
            # this, torch.cat in ODEs fails on any batch of a different size
            # (e.g. a smaller last batch or a validation set).
            self.initial_vector = torch.zeros(
                (batch_size, 3), device=self.device, requires_grad=True
            )
        autoreg = np.clip(autoreg, 0, 1)
        t0, t1, L0, y0 = self.move_to_device(t0, t1, L0, y0)
        past_time = t0[:, [0]]

        n_steps = t1.shape[1]
        # Number of steps between re-seeding phi from data; loop-invariant.
        reset_period = int(autoreg * n_steps)

        y_out = [y0[:, 0]]
        hidden_out = [self.initial_vector]

        for i in range(n_steps):
            dt = t1[:, [i]] - past_time
            past_time = t1[:, [i]]
            # Each interval is covered by two RK4 sub-steps split at a random point.
            dt_split = torch.rand(dt.shape).to(self.device)
            dt1 = dt * dt_split
            dt2 = dt - dt1

            L_in = L0[:, [i]]

            hidden_in = hidden_out[-1]
            if reset_period == 0 or i % reset_period == 0:
                y_in = y0[:, i]
            else:
                y_in = y_out[-1]

            hidden_in, y_in = self.RK4(hidden_in, y_in, L_in, dt1)
            hidden_sol, y_sol = self.RK4(hidden_in, y_in, L_in, dt2)

            y_out.append(y_sol)
            hidden_out.append(hidden_sol)
        # Drop the seed entry so y_out has exactly one slice per step.
        y_out = torch.stack(y_out[1:], dim=1)
        return hidden_out, y_out

In [5]:
# Create Dataset
# Dataset.split_data clamps time_forecast to (trajectory length - 1), so a value
# larger than the trajectory length yields a single maximal window per trajectory.
# Trajectories 0-2 are used for training and trajectory 3 for validation.
time_forecast = 10000
train_dataset = Dataset(training_t[:3,:], training_y[:3,:,:], training_L[:3,:], time_forecast=time_forecast)
val_dataset = Dataset(training_t[[3],:], training_y[[3],:,:], training_L[[3],:], time_forecast=time_forecast)


# Min/Max Normalization
# NOTE(review): no normalization is actually applied in this cell.


# Train Validation Split
# train_size = int(1.0 * len(dataset))
# val_size = len(dataset) - train_size
# train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Define the dataloader
# NOTE(review): the validation loader is shuffled too; harmless for a summed loss
# but unnecessary.
batch_size = 10000
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

# Check if GPU is available
# NOTE(review): the second assignment overrides the CUDA detection, so everything
# runs on CPU regardless of hardware.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# Create the model
# NOTE(review): Model stores these three sizes but hard-codes its layer widths,
# so changing them here has no effect on the network.
input_size = 5
hidden_size = 64
output_size = 4

max_step = 0.04384

model = Model(input_size, hidden_size, output_size, device, max_step).to(device)

# Define the loss function
loss_fn = torch.nn.MSELoss()

# Define the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr= 1e-2)

# Define the lr scheduler
# NOTE(review): the training cell rebinds lr_scheduler to None before its loop,
# so this instance is not the one used during training.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=50, verbose=True)
# lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000, eta_min=1e-6)

In [6]:
# Define the training loop

def train_loop(dataloader, model, loss_fn, optimizer, lr_scheduler=None, autoreg=0):
    """Run one training epoch and return ``(summed_loss, current_lr)``.

    For every batch the model is evaluated twice on column 3 of ``y0``: once
    with the requested ``autoreg`` level and once with ``autoreg=0``. The batch
    loss is ``log(loss_ar) + log(loss_tf)``, where both terms compare against
    column 3 of ``y1``.

    Args:
        dataloader: yields ``(t0, t1, L0, L1, y0, y1)`` batches (``L1`` is unused).
        model: callable as ``model(t0, t1, L0, y0, autoreg)`` returning
            ``(hidden, prediction)``.
        loss_fn: loss applied to ``(prediction, target)``.
        optimizer: optimizer stepped once per batch.
        lr_scheduler: optional scheduler stepped once per epoch.
        autoreg: autoregression level passed to the first forward pass.

    Returns:
        Sum of the per-batch losses (not averaged) and the learning rate of the
        first parameter group after the scheduler step.
    """
    train_loss = 0
    model.train()
    for t0, t1, L0, L1, y0, y1 in dataloader:
        phi0 = y0[..., [3]]

        # Pass 1: requested autoregression level.
        _, pred_ar = model(t0, t1, L0, phi0, autoreg)

        # Put the target where the prediction lives instead of relying on a
        # notebook-level `device` global.
        target = y1[..., [3]].to(pred_ar.device)
        loss_ar = loss_fn(pred_ar, target)

        # Pass 2: autoreg=0, i.e. the model is re-seeded from data at every step.
        _, pred_tf = model(t0, t1, L0, phi0, 0)
        loss_tf = loss_fn(pred_tf, target)

        loss = torch.log(loss_ar) + torch.log(loss_tf)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    if lr_scheduler is not None:
        # Only ReduceLROnPlateau takes the monitored metric; other schedulers
        # would interpret a positional argument as an epoch index.
        if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            lr_scheduler.step(train_loss)
        else:
            lr_scheduler.step()

    lr = optimizer.param_groups[0]['lr']

    return train_loss, lr


def val_loop(dataloader, model, loss_fn, autoreg=0):
    """Evaluate the model without gradients and return the summed log-loss.

    Args:
        dataloader: yields ``(t0, t1, L0, L1, y0, y1)`` batches (``L1`` is unused).
        model: callable as ``model(t0, t1, L0, y0, autoreg)`` returning
            ``(hidden, prediction)``.
        loss_fn: loss applied to ``(prediction, target)``.
        autoreg: autoregression level passed to the model.

    Returns:
        Sum over batches of ``log(loss_fn(prediction, y1[..., [3]]))``.
    """
    val_loss = 0
    model.eval()
    with torch.no_grad():
        for t0, t1, L0, L1, y0, y1 in dataloader:
            # Only column 3 of the state is modelled and scored.
            _, pred = model(t0, t1, L0, y0[..., [3]], autoreg)

            # Follow the prediction's device rather than a notebook global.
            target = y1[..., [3]].to(pred.device)

            val_loss += torch.log(loss_fn(pred, target)).item()

    return val_loss

In [7]:
# Train the model
# Schedule: `tf_epochs` with autoreg=0, then `transition_epochs` ramping autoreg
# linearly from 0 towards 1, then `autoreg_epochs` at autoreg=1.
tf_epochs = 100
transition_epochs = 100
autoreg_epochs = 1000
epochs = tf_epochs + transition_epochs + autoreg_epochs
train_loss_list = []
val_loss_list = []
lr_list = []

# Checkpoint destination; assigned once instead of on every checkpoint.
# NOTE(review): absolute, user-specific path; consider a configurable directory.
path = "/home/smalani/controlledLearning/trained_models/"
checkpoint_every = 100

pbar = tqdm(range(epochs))
# No scheduler during the first two phases; one is created when autoreg reaches 1.
lr_scheduler = None

for t in pbar:
    if t < tf_epochs:
        autoreg = 0
    elif t < tf_epochs + transition_epochs:
        autoreg = (t - tf_epochs) / transition_epochs
    else:
        autoreg = 1
        if lr_scheduler is None:
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=50, verbose=True)
            # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000, eta_min=1e-6)
            # Restart the final phase from the initial learning rate.
            for g in optimizer.param_groups:
                g['lr'] = 1e-2

    train_loss, lr = train_loop(train_dataloader, model, loss_fn, optimizer, lr_scheduler, autoreg)
    # val_loss = val_loop(val_dataloader, model, loss_fn, 1)

    train_loss_list.append(train_loss)
    # val_loss_list.append(val_loss)
    lr_list.append(lr)

    pbar.set_description(f"Epoch {t+1}")
    pbar.set_postfix(train_loss=train_loss, lr=lr, autoreg=autoreg)#, val_loss=val_loss)

    # Checkpoint periodically AND after the last epoch; with only the periodic
    # rule the final save would happen at t=1100 and the last 99 epochs of
    # training would never be written to disk.
    if t % checkpoint_every == 0 or t == epochs - 1:
        torch.save(model, path + "modeltrain_nocontroller_longtraj.pt")

    # pbar.set_postfix(train_loss=train_loss, val_loss=val_loss, lr=lr)

  0%|          | 0/1200 [00:00<?, ?it/s]

IndexError: index 3 is out of bounds for dimension 0 with size 1

In [None]:
# # Plot the loss

# fig = plt.figure(figsize=(5, 4))
# ax = fig.add_subplot(111)
# ax.semilogy(train_loss_list, label='Train Loss')
# ax.semilogy(val_loss_list, label='Validation Loss')
# ax.set_xlabel('Epochs')
# ax.set_ylabel('Loss')
# ax.set_title('Loss vs Epochs')

In [None]:
# Echo the torch device currently bound to the notebook-level `device` name.
device

In [None]:
#  Performance on training data

def learned_model_ode(t, y0, L0):
    """Evaluate the learned right-hand side ``model.ODEs`` on numpy-style input.

    Uses the notebook-level ``model`` and ``device``.

    Args:
        t: time; unused, present so the function has an ``f(t, y, ...)`` signature.
        y0: state with 4 entries on the last axis, ordered as the output of
            ``model.ODEs``: 3 latent values followed by ``phi``. A 1-D state is
            treated as a single sample.
        L0: input value(s); trailing axes are added until it has as many axes
            as ``y0``.

    Returns:
        numpy array of derivatives with singleton axes squeezed out.
    """
    # Cast to float32: the model's layers are float32, while callers such as
    # solve_ivp hand over float64 states and plain Python numbers.
    L0 = torch.as_tensor(L0, dtype=torch.float32).to(device)
    y0 = torch.as_tensor(y0, dtype=torch.float32).to(device)
    if len(y0.shape) == 1:
        y0 = y0.unsqueeze(0)
    while len(y0.shape) > len(L0.shape):
        L0 = L0.unsqueeze(-1)
    # Model.ODEs takes the latent block and phi separately: (hidden, phi, L0).
    dydt = model.ODEs(y0[..., :3], y0[..., [3]], L0)

    return dydt.cpu().detach().numpy().squeeze()
    

# Parity plots of reference vs learned derivatives, one panel per state column.
# NOTE(review): `dataset` is not defined anywhere in the visible notebook (only
# `train_dataset` / `val_dataset` are), so this cell looks stale - confirm.
true_dydt = khammash_repro.ode_fun(0, dataset.y0.T, dataset.L0[:,0]).T * 60
learned_dydt = learned_model_ode(0, dataset.y0, dataset.L0[:,0])

fig = plt.figure(figsize=(5, 4))
for panel in range(4):
    ax = fig.add_subplot(2, 2, panel + 1)
    reference = true_dydt[:, panel]
    ax.scatter(reference, learned_dydt[:, panel], s=1)
    # Dashed identity line spanning the range of the reference values.
    lo, hi = np.min(reference), np.max(reference)
    ax.plot([lo, hi], [lo, hi], 'k--')


In [None]:
# Parity plots of observed vs predicted next states, one panel per state column.
# NOTE(review): `dataset` is not defined in the visible notebook, and
# Model.forward returns a (hidden_out, y_out) tuple, so the chained `.cpu()`
# call below cannot succeed against the current Model - this cell looks stale.
t0 = torch.tensor(dataset.t0).to(device).unsqueeze(-1)
t1 = torch.tensor(dataset.t1).to(device)
L0 = torch.tensor(dataset.L0).to(device)
y0 = torch.tensor(dataset.y0).to(device)

y1_pred = model.forward(t0, t1, L0, y0).cpu().detach().numpy()

fig = plt.figure(figsize=(5, 4))
for panel in range(4):
    ax = fig.add_subplot(2, 2, panel + 1)
    observed = dataset.y1[:, panel]
    ax.scatter(observed, y1_pred[:, panel], s=1)
    # Dashed identity line spanning the range of the observed values.
    lo, hi = np.min([observed]), np.max([observed])
    ax.plot([lo, hi], [lo, hi], 'k--')


In [None]:
# path = "/home/smalani/controlledLearning/trained_models/"
# torch.save(model, path + "modeltrain_nocontroller.pt")

In [None]:
# Load Model and test
# NOTE(review): this reads "modeltrain_nocontroller.pt", whereas the training cell
# writes "modeltrain_nocontroller_longtraj.pt" and the save of the un-suffixed
# file is commented out in the preceding cell - confirm which checkpoint is intended.
# NOTE(review): the file holds a whole pickled module (it was written with
# torch.save(model, ...)); unpickling executes code, so only load trusted files.
# Recent PyTorch releases may require an explicit `weights_only=False` for this
# kind of file - TODO confirm against the installed version.

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

path = "/home/smalani/controlledLearning/trained_models/"
model = torch.load(path + "modeltrain_nocontroller.pt", map_location="cpu")
# Model.forward moves its inputs to `model.device`, so keep it in sync with map_location.
model.device = "cpu"

In [None]:
from scipy.integrate import solve_ivp

def learned_model_ode(t, y0, L0):
    """Right-hand side of the learned model in the ``f(t, y, *args)`` form used by ``solve_ivp``.

    Uses the notebook-level ``model`` and ``device``.
    NOTE(review): a function with the same name is also defined in an earlier
    cell; the later definition wins once both cells have run.

    Args:
        t: time; unused, required by the solver's calling convention.
        y0: state with 4 entries on the last axis, ordered as the output of
            ``model.ODEs``: 3 latent values followed by ``phi``. A 1-D state is
            treated as a single sample.
        L0: input value(s); trailing axes are added until it has as many axes
            as ``y0``.

    Returns:
        numpy array of derivatives with singleton axes squeezed out.
    """
    # The solver passes float64 states and L0 may be a plain Python number;
    # the model's layers are float32, so cast both.
    L0 = torch.as_tensor(L0, dtype=torch.float32).to(device)
    y0 = torch.as_tensor(y0, dtype=torch.float32).to(device)
    if len(y0.shape) == 1:
        y0 = y0.unsqueeze(0)
    while len(y0.shape) > len(L0.shape):
        L0 = L0.unsqueeze(-1)
    # Model.ODEs takes the latent block and phi separately: (hidden, phi, L0).
    dydt = model.ODEs(y0[..., :3], y0[..., [3]], L0)

    return dydt.cpu().detach().numpy().squeeze()

# Integrate the reference and the learned ODE from the same initial state.
# NOTE(review): the first two assignments to x_init are overwritten by the third,
# and `dataset` is not defined in the visible notebook - confirm the intended
# initial condition.
# NOTE(review): an earlier cell multiplies khammash_repro.ode_fun by 60 before
# comparing it with the learned derivative; here it is integrated unscaled -
# confirm the time units match.
x_init = khammash_repro.get_init_cond(L0=200)
x_init[-1] = 0.4
x_init = dataset.y0[0,:]
L0 = 0
t_span = [0, 100]
t_eval = np.linspace(t_span[0], t_span[1], 1000)

# Identical solver settings for both systems.
solver_options = dict(
    args=(L0,),
    t_eval=t_eval,
    method='BDF',
    rtol=1e-10,
    atol=1e-10,
    first_step=1e-10,
)
sol_true = solve_ivp(khammash_repro.ode_fun, t_span, x_init, **solver_options)
sol_learned = solve_ivp(learned_model_ode, t_span, x_init, **solver_options)

# Derivatives at the initial state: learned first, reference second.
print(learned_model_ode(0, x_init, L0))
print(khammash_repro.ode_fun(0, x_init, L0))

In [None]:
# Overlay reference and learned trajectories, one panel per state.
fig, axes = plt.subplots(2, 2, figsize=(5, 4))
for i, ax in enumerate(axes.flat):
    ax.plot(sol_true.t, sol_true.y[i,:], 'k-', label='True')
    ax.plot(sol_learned.t, sol_learned.y[i,:], 'r--', label='Learned')
    ax.set_xlabel('Time')
    ax.set_ylabel('Species {}'.format(i+1))
    ax.legend()