# In [None]:
import torch
import numpy as np
import sys, copy, math, time, pdb
import os.path
import random
import pdb
import csv
import argparse
import itertools
from itertools import permutations, product
from sklearn.model_selection import train_test_split
import torch.optim as optim
from torchdiffeq import odeint
import itertools
from torch.utils.data import Dataset, DataLoader
import rpy2.robjects as robjects
import h5py
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Model dimensions: N state columns for microbes (C), followed by M columns
# for metabolites (R).
N = 10
M = 10
device = torch.device("cpu")

# Observation-time grids (5, 10, 15, 20 and 25 points on [0, 24]); the
# evaluation loop picks one by the matching "tK" label.
t5 = torch.tensor([0, 2, 6, 14, 24], dtype=torch.float32)
t10 = torch.tensor([0, 1, 2, 6, 9, 12, 15, 18, 21, 24], dtype=torch.float32)
# 0..5 hourly, then every 2 from 7 to 21, then the endpoint 24.
t15 = torch.tensor(list(range(6)) + list(range(7, 22, 2)) + [24], dtype=torch.float32)
# 0..14 hourly, then every 2 from 16 to 24.
t20 = torch.tensor(list(range(15)) + list(range(16, 25, 2)), dtype=torch.float32)
# Every integer from 0 to 24.
t25 = torch.tensor(list(range(25)), dtype=torch.float32)

class CustomDataset(Dataset):
    """Map-style dataset over a rank-3 tensor.

    Item ``idx`` is the pair ``(tensor[idx, 0, :], tensor[idx, :, :])``: the
    first row along axis 1 (presumably the initial time point -- confirm
    against how the tensor is built) and the full 2-D slice for that index.
    """

    def __init__(self, tensor):
        self.tensor = tensor

    def __len__(self):
        return len(self.tensor)

    def __getitem__(self, idx):
        trajectory = self.tensor[idx, :, :]
        initial_state = self.tensor[idx, 0, :]
        return initial_state, trajectory

def loss_fn(pred_y, y):
    """Return the total squared error between ``pred_y`` and ``y``.

    The sum already reduces every element to a 0-d tensor, so the trailing
    ``mean`` is an identity: the value is the *sum* (not the average) of the
    squared residuals.
    """
    residual = y - pred_y
    return residual.square().sum().mean()
        
class ODEFunc(torch.nn.Module):
    """Learnable right-hand side of a consumer-resource ODE, for ``odeint``.

    Each state row holds ``N`` entries ``C`` (the microbe columns) followed by
    the remaining entries ``R`` (``M`` metabolite columns as constructed
    here). ``forward`` evaluates::

        dC/dt = C * ((R @ mu.T) * (1 - lambda_)) - C * m
        dR/dt = rho - R * omega - (C @ mu) * R + (lambda_ * (C @ mu) * R) @ l.T

    and zeroes the derivative of every component that is not strictly positive.
    """

    def __init__(self, N, M, mu_base):
        super().__init__()

        self.N = N
        self.M = M

        # (N, M) interaction matrix coupling C and R.
        self.mu = torch.nn.Parameter(torch.rand(N, M, device=device))

        # NOTE(review): mu_base / mu_mask are stored but never read by
        # forward() or constrained() in this class, so mu is NOT restricted to
        # mu_base's sparsity pattern here. They are plain attributes (not
        # buffers): .to(...) does not move them and state_dict() omits them.
        self.mu_base = mu_base
        self.mu_mask = (mu_base != 0).float()

        # Parameter creation order is kept as-is so torch.rand draws (and
        # therefore seeded initialisations) are unchanged.
        self.lambda_ = torch.nn.Parameter(torch.rand(1, device=device))
        self.m = torch.nn.Parameter(torch.rand(N, device=device))
        self.rho = torch.nn.Parameter(torch.rand(M, device=device))
        self.omega = torch.nn.Parameter(torch.rand(M, device=device))

        nonzero_byproduct = 0.5
        self.l = torch.nn.Parameter(self.generate_byproduct_matrix(M, nonzero_byproduct, device))

    @staticmethod
    def generate_byproduct_matrix(M, nonzero_byproduct, device):
        """Build a random (M, M) matrix with column-normalised non-zeros.

        ``int(M * M * nonzero_byproduct)`` entries are drawn from U(0, 1) and
        scattered at random positions (numpy RNG); the rest are zero. Each
        column is then divided by its sum, so non-empty columns sum to 1 and
        all-zero columns stay zero.
        """
        num_elements = M * M
        num_nonzeros = int(num_elements * nonzero_byproduct)
        values = np.concatenate([np.random.uniform(0, 1, num_nonzeros), np.zeros(num_elements - num_nonzeros)])
        np.random.shuffle(values)
        D = torch.tensor(values.reshape(M, M), dtype=torch.float32, device=device)
        col_sums = torch.sum(D, dim=0)
        # Avoid 0/0 for columns with no non-zero entries.
        col_sums[col_sums == 0] = 1
        return D / col_sums

    def forward(self, t, ys):
        """Return d(ys)/dt for a batch of states.

        Args:
            t: Integration time; unused (the system is autonomous).
            ys: 2-D tensor, one state per row: ``N`` C entries then R entries.

        Returns:
            Tensor shaped like ``ys``. Components with ``ys <= 0`` get a zero
            derivative, so they cannot move once at or below zero.
        """
        N = self.N
        # Whole-batch evaluation instead of a Python loop over rows; the
        # arithmetic (and operator order) matches the per-row formula.
        C, R = ys[:, :N], ys[:, N:]
        dCdt = C * (R @ self.mu.t() * (1 - self.lambda_)) - C * self.m
        consumption = (C @ self.mu) * R
        dRdt = self.rho - R * self.omega - consumption + (self.lambda_ * consumption) @ self.l.t()
        dydt = torch.cat([dCdt, dRdt], dim=1)
        return dydt * (ys > 0).float()

    def constrained(self):
        """Clamp lambda_, m, rho, omega and l to be non-negative, in place.

        ``mu`` is deliberately left untouched (as in the original projection).
        """
        with torch.no_grad():
            for param in (self.lambda_, self.m, self.rho, self.omega, self.l):
                param.clamp_(min=0)

def calculate_r_rmse(true_new_state, predict_new_state, S, N, M):
    """Relative RMSE between true and predicted states, per (subject, microbe).

    ``true_new_state[k, n]`` and ``predict_new_state[k, n]`` must be 2-D
    arrays; columns ``[:N]`` are compared as C and ``[N:N+M]`` as R. The
    relative RMSE is ``rmse(truth - pred) / rms(truth)``, so an all-zero truth
    block yields inf/nan.

    Returns:
        DataFrame with columns Subject, Regulate, C_r_rmse, R_r_rmse and one
        row per (k, n) in subject-major order.
    """
    def _relative_rmse(truth, estimate):
        return np.sqrt(np.mean((truth - estimate) ** 2)) / np.sqrt(np.mean(truth ** 2))

    rows = []
    for k in range(S):
        for n in range(N):
            truth = true_new_state[k, n]
            estimate = predict_new_state[k, n]
            rows.append({
                'Subject': f"Subject{k+1}",
                'Regulate': f"Microbe{n+1}",
                'C_r_rmse': _relative_rmse(truth[:, :N], estimate[:, :N]),
                'R_r_rmse': _relative_rmse(truth[:, N:N+M], estimate[:, N:N+M]),
            })
    return pd.DataFrame(rows)

def calculate_L2_Distance(true_initial_state, predict_new_state, S, N, M):
    """L2 distance between each subject's reference state and its knock-out predictions.

    For subject ``k`` and knocked-out index ``n``:

    * C distance: Frobenius norm of the difference over columns ``[:N]``,
      skipping column ``n`` itself.
    * R distance: Frobenius norm of the difference over columns ``[N:N+M]``.

    ``true_initial_state[k]`` and ``predict_new_state[k][n]`` must be 2-D
    arrays with matching shapes.

    Returns:
        List with one dict per subject: ``{'C_distance': (N,), 'R_distance': (N,)}``.
    """
    results = []
    for k in range(S):
        reference = true_initial_state[k]
        C_ref = reference[:, :N]
        R_ref = reference[:, N:(N+M)]
        knockouts = predict_new_state[k]

        C_distance = np.zeros(N)
        R_distance = np.zeros(N)
        for n in range(N):
            predicted = knockouts[n]
            others = np.arange(N) != n  # drop the knocked-out column
            C_gap = C_ref[:, others] - predicted[:, :N][:, others]
            R_gap = R_ref - predicted[:, N:(N+M)]
            C_distance[n] = np.sqrt(np.sum(C_gap ** 2))
            R_distance[n] = np.sqrt(np.sum(R_gap ** 2))

        results.append({'C_distance': C_distance, 'R_distance': R_distance})
    return results

# Evaluate every trained model in the simulation grid. For each subject count
# (s), sampling scheme (t) and replicate (iter_num): load the saved ODEFunc
# weights, knock out each of the N microbes in turn by zeroing its state
# column, integrate with odeint, and write predictions, parameter-recovery
# RMSEs, trajectory relative RMSEs and L2 knock-out scores to pre_dir_path.

# Observation-time grid for each scheme label used in the directory layout.
time_grids = {"t5": t5, "t10": t10, "t15": t15, "t20": t20, "t25": t25}
readRDS = robjects.r['readRDS']

for s in ["s10", "s15", "s20"]:
    for t in ["t5", "t10", "t15", "t20", "t25"]:
        for iter_num in range(1, 11):
            batch_t = time_grids[t].to(device)

            # os.makedirs, h5py.File, torch.load and np.save do not expand '~'
            # themselves (they would use a literal "./~" directory); expand
            # once so every reader and writer below uses the same location.
            input_file_path = os.path.expanduser(f'~/eNODEconstr/simulation/generation_data/n{N}m{M}/{s}/{t}/iter{iter_num}')
            output_dir_path = os.path.expanduser(f'~/eNODEconstr/simulation/eNODEconstr/n{N}m{M}/training/{s}/{t}/iter{iter_num}')
            pre_dir_path = os.path.expanduser(f'~/eNODEconstr/simulation/eNODEconstr/n{N}m{M}/prediction/{s}/{t}/iter{iter_num}')

            os.makedirs(output_dir_path, exist_ok=True)
            os.makedirs(pre_dir_path, exist_ok=True)

            data_file = os.path.join(input_file_path, 'true_initial_state.h5')
            with h5py.File(data_file, 'r') as hdf:
                data = hdf['true_initial_state'][:]
            # Swap the last two axes; row 0 of axis 1 is then used as each
            # subject's initial state below.
            tensor_data = torch.tensor(data, dtype=torch.float32).permute(0, 2, 1)

            mu_file = os.path.join(input_file_path, 'mu_binary_matrix.csv')
            mu_base = pd.read_csv(mu_file, header=None, index_col=False)
            mu_base = torch.tensor(mu_base.values, dtype=torch.float32, device=device)

            dataset = CustomDataset(tensor_data)
            batch_s = int(int(s[1:]) * 0.2)
            dataloader = DataLoader(dataset, batch_size=batch_s, shuffle=False)
            model = ODEFunc(N, M, mu_base).to(device)
            # map_location lets weights saved from a GPU run load on this device.
            model.load_state_dict(torch.load(os.path.join(output_dir_path, 'best_model.pth'), map_location=device))
            model.eval()
            original_data = tensor_data

            def set_column_to_zero_and_predict(data, column_index):
                """Integrate every subject with state column ``column_index`` zeroed.

                Returns one numpy array per subject: odeint's output for a
                (1, n_features) initial state, or an all-NaN placeholder of the
                same shape when odeint raises AssertionError.
                """
                new_data = data.clone()
                new_data[:, :, column_index] = 0
                predictions = []
                with torch.no_grad():
                    for i in range(len(new_data)):
                        try:
                            input_data = new_data[i, 0, :].float().unsqueeze(0).to(device)
                            pred_y = odeint(model, input_data, batch_t)
                            predictions.append(pred_y.cpu().numpy())
                        except AssertionError as e:
                            print(f"Error with subject index {i}, removing microbe {column_index}: {e}")
                            placeholder = np.full_like(predictions[0], np.nan) if predictions else np.full((len(batch_t), 1, len(new_data[i, 0, :])), np.nan)
                            predictions.append(placeholder)
                return predictions

            predict = []
            for i in tqdm(range(N)):
                predictions = set_column_to_zero_and_predict(original_data, i)
                predict.append(np.array(predictions))

            # (knock-out, subject, time, 1, feature) -> (subject, knock-out, time, feature)
            predict = np.array(predict).squeeze(3).transpose(1, 0, 2, 3)
            predict[predict < 0] = 0
            np.save(f'{pre_dir_path}/predict.npy', predict)

            true_initial_state = np.array(readRDS(f'{input_file_path}/true_initial_state.rds'))
            true_new_state = np.array(readRDS(f'{input_file_path}/true_new_state.rds'))

            predict_new_state = predict

            def read_csv_no_header(filename):
                return pd.read_csv(f'{input_file_path}/{filename}', header=None, index_col=False)

            mu_true = read_csv_no_header('mu_true.csv')
            l_true = read_csv_no_header('l_true.csv')
            rho_true = read_csv_no_header('rho_true.csv')
            m_true = read_csv_no_header('m_true.csv')
            omega_true = read_csv_no_header('omega_true.csv')
            lambda_true = read_csv_no_header('lambda_true.csv')

            def calculate_and_save_rmse(model_param, param_true_df, param_name):
                """Save a learned parameter as TSV and return (rmse, relative rmse) vs. the truth."""
                param_pred_np = model_param.detach().cpu().numpy()
                param_pred_df = pd.DataFrame(param_pred_np)
                param_pred_df.to_csv(f'{pre_dir_path}/{param_name}_pre.txt', sep='\t', index=False, header=False)
                param_pred_np = param_pred_df.to_numpy()
                param_true_np = param_true_df.to_numpy()
                # A 1-D parameter becomes an (n, 1) frame, while its truth CSV
                # may be a single row (1, n); subtracting those would silently
                # broadcast to (n, n). Align vectors of equal length first.
                if (param_true_np.size == param_pred_np.size
                        and min(param_true_np.shape) == 1
                        and min(param_pred_np.shape) == 1):
                    param_true_np = param_true_np.reshape(param_pred_np.shape)
                rmse = np.sqrt(np.mean((param_pred_np - param_true_np) ** 2))
                r_rmse = rmse / np.sqrt(np.mean((param_true_np) ** 2))
                return rmse, r_rmse

            results = {
                'mu': calculate_and_save_rmse(model.mu, mu_true, 'mu'),
                'l': calculate_and_save_rmse(model.l, l_true, 'l'),
                'rho': calculate_and_save_rmse(model.rho, rho_true, 'rho'),
                'm': calculate_and_save_rmse(model.m, m_true, 'm'),
                'omega': calculate_and_save_rmse(model.omega, omega_true, 'omega'),
                'lambda': calculate_and_save_rmse(model.lambda_, lambda_true, 'lambda'),
            }

            results_df = pd.DataFrame(results, index=['rmse', 'r_rmse']).transpose()
            results_df.to_csv(f'{pre_dir_path}/pe_rmse.txt', sep='\t', header=True, index=True)

            S = int(s[1:])
            r_rmse_results = calculate_r_rmse(true_new_state, predict_new_state, S, N, M)
            r_rmse_results.to_csv(f'{pre_dir_path}/traj_r_rmse.txt', sep='\t', index=False)

            # Computed once and reused for both the averaged and per-subject tables.
            L2_Distance_results = calculate_L2_Distance(true_initial_state, predict_new_state, S, N, M)
            microbes = [f"Microbe{i+1}" for i in range(N)]
            subjects = [f"Subject{i+1}" for i in range(S)]

            C_l2_d_avg = np.nanmean([result['C_distance'] for result in L2_Distance_results], axis=0)
            R_l2_d_avg = np.nanmean([result['R_distance'] for result in L2_Distance_results], axis=0)
            C_l2_df = pd.DataFrame({'Regulate': microbes, 'Score': C_l2_d_avg, 'Group': 'Microbe'})
            C_l2_df['Score_normalized'] = C_l2_df['Score'] / (C_l2_df['Score'].sum())
            R_l2_df = pd.DataFrame({'Regulate': microbes, 'Score': R_l2_d_avg, 'Group': 'Metabolite'})
            R_l2_df['Score_normalized'] = R_l2_df['Score'] / (R_l2_df['Score'].sum())
            combined_df = pd.concat([C_l2_df, R_l2_df])
            combined_df.to_csv(f'{pre_dir_path}/score_pre_mean_l2.txt', sep='\t', index=False)

            results = L2_Distance_results
            C_l2_d_list = pd.DataFrame([result['C_distance'] for result in results])
            row_sums = C_l2_d_list.sum(axis=1)
            C_l2_normalized = C_l2_d_list.div(row_sums, axis=0)
            C_l2_normalized.columns = microbes
            C_l2_normalized['Subject'] = subjects
            C_l2_normalized_long = pd.melt(C_l2_normalized, id_vars='Subject', var_name='Regulate', value_name='Score')

            C_l2 = C_l2_d_list
            C_l2.columns = microbes
            C_l2['Subject'] = subjects
            C_l2_long = pd.melt(C_l2, id_vars='Subject', var_name='Regulate', value_name='Score')

            R_l2_d_list = pd.DataFrame([result['R_distance'] for result in results])
            row_sums = R_l2_d_list.sum(axis=1)
            R_l2_normalized = R_l2_d_list.div(row_sums, axis=0)
            R_l2_normalized.columns = microbes
            R_l2_normalized['Subject'] = subjects
            R_l2_normalized_long = pd.melt(R_l2_normalized, id_vars='Subject', var_name='Regulate', value_name='Score')

            R_l2 = R_l2_d_list
            R_l2.columns = microbes
            R_l2['Subject'] = subjects
            R_l2_long = pd.melt(R_l2, id_vars='Subject', var_name='Regulate', value_name='Score')

            C_l2_long['Group'] = 'Microbe'
            R_l2_long['Group'] = 'Metabolite'
            combined_score_l2 = pd.concat([C_l2_long, R_l2_long])
            combined_score_l2.to_csv(f'{pre_dir_path}/score_pre_subject_l2.txt', index=False, sep='\t', header=True)

            C_l2_normalized_long['Group'] = 'Microbe'
            R_l2_normalized_long['Group'] = 'Metabolite'
            combined_score_normalized_l2 = pd.concat([C_l2_normalized_long, R_l2_normalized_long])
            combined_score_normalized_l2.to_csv(f'{pre_dir_path}/score_normalized_pre_subject_l2.txt', index=False, sep='\t', header=True)
