# "Multi-task Gaussian Processes For Multivariate Time Series Imputation"
> "Using the GPyTorch Package"

- toc: true
- branch: master
- badges: true
- comments: true
- author: Zachary Barnes
- categories: [Gaussian Processes]

In [None]:
import math
import torch
import gpytorch
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm
import pyarrow.parquet as pq
import s3fs
from config import BUCKET
from joblib import Parallel, delayed
from utils import train_val_test_split, load_from_s3_to_df

class MultitaskGPModel(gpytorch.models.ExactGP):
    """Exact GP over (input, task) pairs with a Hadamard multi-task kernel.

    Every training point is an input ``x`` paired with an integer task index
    ``i``. The covariance between two points is the elementwise product of an
    RBF kernel on ``x`` and a learned task covariance (``IndexKernel``) on ``i``.

    Args:
        train_x: Tuple ``(x, i)`` of training inputs and task indices, passed
            straight through to ``ExactGP``.
        train_y: Observed values, one per ``(x, i)`` pair.
        likelihood: A GPyTorch likelihood (``MTGP_impute`` uses
            ``GaussianLikelihood``).
        num_tasks: Number of task indices the ``IndexKernel`` can address;
            every task index must lie in ``[0, num_tasks)``. Defaults to 11,
            the length of ``impute_features`` in ``MTGP_impute``.
        rank: Rank of the low-rank factor of the task covariance.
    """

    def __init__(self, train_x, train_y, likelihood, num_tasks=11, rank=1):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.RBFKernel()
        self.task_covar_module = gpytorch.kernels.IndexKernel(
            num_tasks=num_tasks,
            rank=rank,
            var_constraint=gpytorch.constraints.Positive(),
        )

    def forward(self, x, i):
        """Return the prior ``MultivariateNormal`` at inputs ``x`` for tasks ``i``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        covar_i = self.task_covar_module(i)
        # Hadamard (elementwise) product: points are correlated only when they
        # are close in x AND their tasks are correlated.
        covar = covar_x.mul(covar_i)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar)


def _stack_observations(feats):
    """Flatten per-feature series into aligned (time, task, value) tensors of observed entries."""
    xs, tasks, ys = [], [], []
    for task, series in enumerate(feats):
        observed = ~torch.isnan(series)
        # Row position is the time input; float, as the RBF kernel expects.
        xs.append(torch.where(observed)[0].to(torch.float))
        tasks.append(torch.full((int(observed.sum()),), task, dtype=torch.long))
        ys.append(series[observed])
    return torch.cat(xs), torch.cat(tasks), torch.cat(ys)


def _fit_gp(model, likelihood, train_x, train_i, train_y, training_iterations, lr):
    """Fit ``model`` hyperparameters by maximising the exact marginal log-likelihood with Adam."""
    model.train()
    likelihood.train()

    # ExactGP registers the likelihood as a submodule, so model.parameters()
    # already includes the GaussianLikelihood noise parameter.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    for _ in range(training_iterations):
        optimizer.zero_grad()
        loss = -mll(model(train_x, train_i), train_y)
        loss.backward()
        optimizer.step()

    model.eval()
    likelihood.eval()


def MTGP_impute(df, timesteps=None, training_iterations=50, lr=0.1):
    """Impute one encounter's vital-sign / lab series with a multi-task GP.

    Each feature in ``impute_features`` that is present in ``df`` and has at
    least one non-missing value becomes one GP task. The row position
    (0 .. len(df) - 1) is used as the time input. A ``MultitaskGPModel`` is
    fitted jointly to all observed values. Each modelled column is then
    replaced by the absolute value of the GP posterior predictive mean at
    every row.

    Args:
        df: Rows of a single encounter. Row position is treated as the time
            axis, so rows are assumed to be in time order (TODO confirm
            upstream ordering).
        timesteps: Number of time points to predict. It must equal ``len(df)``
            or the column assignment fails. Defaults to ``len(df)``.
        training_iterations: Number of Adam steps used to fit hyperparameters.
        lr: Adam learning rate.

    Returns:
        A copy of ``df`` with the modelled columns overwritten. If no
        feature has any observation, the copy is returned unchanged.
    """
    impute_features = ['heart_rate', 'O2_saturation', 'temperature',
                       'systolic_blood_pressure', 'mean_arterial_pressure',
                       'diastolic_blood_pressure', 'respiratory_rate',
                       'serum_white_blood_count', 'serum_glucose',
                       'pulse_pressure', 'shock_index']

    # Never mutate the caller's frame (e.g. a groupby slice).
    df = df.copy()
    if timesteps is None:
        timesteps = len(df)

    # Keep impute_features' fixed order. A set difference here would make the
    # feature -> task-index mapping depend on string-hash randomisation.
    # Skip features that are absent or entirely missing for this encounter.
    missing_feats = set(df.columns[df.isna().all()])
    non_empty_feats = [f for f in impute_features
                       if f in df.columns and f not in missing_feats]
    if not non_empty_feats:
        return df

    feats = [torch.tensor(df[f].to_numpy(dtype=float, na_value=np.nan), dtype=torch.float)
             for f in non_empty_feats]
    full_train_x, full_train_i, full_train_y = _stack_observations(feats)

    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = MultitaskGPModel((full_train_x, full_train_i), full_train_y, likelihood)
    _fit_gp(model, likelihood, full_train_x, full_train_i, full_train_y,
            training_iterations, lr)

    test_x = torch.arange(timesteps, dtype=torch.float)
    with torch.no_grad(), gpytorch.settings.fast_pred_samples():
        for task, f in enumerate(non_empty_feats):
            test_i = torch.full_like(test_x, fill_value=task, dtype=torch.long)
            pred = likelihood(model(test_x, test_i))
            # NOTE(review): this overwrites observed values too, not only the
            # NaNs. np.absolute folds negative means to positive, presumably
            # because these measurements are non-negative; confirm intent.
            df[f] = np.absolute(pred.mean.detach().numpy())

    return df


def chunk_train(train: pd.DataFrame):
    """Split ``train`` into two consecutive halves by row position.

    The first half holds rows ``0 .. round(n / 2)`` inclusive (Python's
    ``round`` uses banker's rounding). The second half holds the remaining
    rows, so every row lands in exactly one half. The split is positional
    (``iloc``), so it works for any index, not only a default ``RangeIndex``.

    NOTE(review): the split ignores ``encounter_id``. An encounter whose rows
    straddle the midpoint ends up in both halves, and its rows would then be
    imputed separately in each half.

    Args:
        train: Frame to split.

    Returns:
        ``(first_half_train, second_half_train)``.
    """
    cut = round(train.shape[0] / 2) + 1
    return train.iloc[:cut], train.iloc[cut:]

# Load the case/control training split from S3, halve it by rows, then fit one
# multi-task GP per encounter of the first half in parallel and stitch the
# imputed encounters back together.
s3 = s3fs.S3FileSystem()
path = "/data/interim/case_control_train.parquet"
filename = "case_control_train.parquet"
train = load_from_s3_to_df(path=path, filename=filename)

first_half_train, second_half_train = chunk_train(train)

runner = Parallel(n_jobs=12)
impute_first_half_train = runner(
    delayed(MTGP_impute)(encounter_df, len(encounter_df))
    for _, encounter_df in tqdm(first_half_train.groupby('encounter_id'))
)
impute_first_half_train = pd.concat(impute_first_half_train)