In [1]:
import math
import torch
import gpytorch
import tqdm
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

import time

from gpytorch.likelihoods import _GaussianLikelihoodBase
from sklearn import metrics
import matplotlib.pyplot as plt

from torch import Tensor

from gpytorch.distributions import MultivariateNormal, base_distributions
from gpytorch.lazy import ZeroLazyTensor
from gpytorch.utils.warnings import GPInputWarning
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.likelihoods.noise_models import FixedGaussianNoise, HomoskedasticNoise, Noise
from typing import Any, Optional
from gpytorch.mlls._approximate_mll import _ApproximateMarginalLogLikelihood

from gpytorch.constraints import GreaterThan
from gpytorch.distributions import base_distributions
from gpytorch.functions import add_diag
from gpytorch.lazy import (
    BlockDiagLazyTensor,
    DiagLazyTensor,
    KroneckerProductLazyTensor,
    MatmulLazyTensor,
    RootLazyTensor,
    lazify,
)
from gpytorch.likelihoods import Likelihood, _GaussianLikelihoodBase
from gpytorch.utils.warnings import OldVersionWarning
from gpytorch.likelihoods.noise_models import MultitaskHomoskedasticNoise


In [2]:
train_T = 65000
test_T = 15000
N = 200
M = 25
batch_size = 1500

C_den = torch.zeros(5,5)
C_den[0,1:] = 1

sub_no = C_den.shape[0]
num_tasks = sub_no * 2

In [3]:
class MultitaskGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, num_tasks, M):
        # Let's use a different set of inducing points for each task
        inducing_points = torch.rand(num_tasks, M, 1)

        # We have to mark the CholeskyVariationalDistribution as batch
        # so that we learn a variational distribution for each task
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_tasks])
        )

        variational_strategy = gpytorch.variational.IndependentMultitaskVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
        )

        super().__init__(variational_strategy)

        # The mean and covariance modules should be marked as batch
        # so we learn a different set of hyperparameters
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_tasks]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_tasks])),
            batch_shape=torch.Size([num_tasks])
        )
        
    def forward(self, x):
        # The forward function should be written as if we were dealing with each output
        # dimension in batch
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

class CustomLikelihood( _GaussianLikelihoodBase):
    def __init__(self, C_den, sub_no, N, num_tasks,
        rank=0,
        task_correlation_prior=None,
        batch_shape=torch.Size(),
        noise_prior=None,
        noise_constraint=None):
        if noise_constraint is None:
            noise_constraint = GreaterThan(1e-4)

        noise_covar = MultitaskHomoskedasticNoise(
            num_tasks=num_tasks, noise_prior=noise_prior, noise_constraint=noise_constraint, batch_shape=batch_shape
        )
        super().__init__(noise_covar=noise_covar)
        if rank != 0:
            if rank > num_tasks:
                raise ValueError(f"Cannot have rank ({rank}) greater than num_tasks ({num_tasks})")
            tidcs = torch.tril_indices(num_tasks, rank, dtype=torch.long)
            self.tidcs = tidcs[:, 1:]  # (1, 1) must be 1.0, no need to parameterize this
            task_noise_corr = torch.randn(*batch_shape, self.tidcs.size(-1))
            self.register_parameter("task_noise_corr", torch.nn.Parameter(task_noise_corr))
            if task_correlation_prior is not None:
                self.register_prior(
                    "MultitaskErrorCorrelationPrior", task_correlation_prior, lambda: self._eval_corr_matrix
                )
        elif task_correlation_prior is not None:
            raise ValueError("Can only specify task_correlation_prior if rank>0")
        self.num_tasks = num_tasks
        self.rank = rank
                
        self.C_den = C_den
        self.sub_no = sub_no
        self.N = N
        
        # Between Subunit Parameters
        self.W_log = nn.Parameter(torch.randn(self.sub_no) , requires_grad=True) # POSITIVE

        ### Subunit Output Parameters ###
        self.V_o = nn.Parameter(torch.randn(1), requires_grad=True)
        self.Theta = nn.Parameter(torch.zeros(self.sub_no), requires_grad=True)
    
    @property
    def noise(self):
        return self.raw_noise_constraint.transform(self.raw_noise)

    @noise.setter
    def noise(self, value):
        self._set_noise(value)

    def _set_noise(self, value):
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_noise)
        self.initialize(raw_noise=self.raw_noise_constraint.inverse_transform(value))

    def _shaped_noise_covar(self, base_shape, *params):
        if len(base_shape) >= 2:
            *batch_shape, n, _ = base_shape
        else:
            *batch_shape, n = base_shape

        # compute the noise covariance
        if len(params) > 0:
            shape = None
        else:
            shape = base_shape if len(base_shape) == 1 else base_shape[:-1]
        noise_covar = self.noise_covar(*params, shape=shape)

        if self.rank > 0:
            # if rank > 0, compute the task correlation matrix
            # TODO: This is inefficient, change repeat so it can repeat LazyTensors w/ multiple batch dimensions
            task_corr = self._eval_corr_matrix()
            exp_shape = torch.Size([*batch_shape, n]) + task_corr.shape[-2:]
            task_corr_exp = lazify(task_corr.unsqueeze(-3).expand(exp_shape))
            noise_sem = noise_covar.sqrt()
            task_covar_blocks = MatmulLazyTensor(MatmulLazyTensor(noise_sem, task_corr_exp), noise_sem)
        else:
            # otherwise tasks are uncorrelated
            task_covar_blocks = noise_covar

        if len(batch_shape) == 1:
            # TODO: Properly support general batch shapes in BlockDiagLazyTensor (no shape arithmetic)
            tcb_eval = task_covar_blocks.evaluate()
            task_covar = BlockDiagLazyTensor(lazify(tcb_eval), block_dim=-3)
        else:
            task_covar = BlockDiagLazyTensor(task_covar_blocks)

        return task_covar
        
    def expected_log_prob(self, target: Tensor, input: MultivariateNormal, S_e, S_i, *params: Any, **kwargs: Any) -> Tensor:
        #mean, variance = input.mean, input.variance
        #noise = self._shaped_noise_covar(mean.shape, *params, **kwargs).diag()
        # Potentially reshape the noise to deal with the multitask case
        #noise = noise.view(*noise.shape[:-1], *input.event_shape)
        
     
        
        all_F = input.mean.T + torch.sqrt(input.variance.T)
        all_F = all_F * 0.01
        T = S_e.shape[0]
        
        F_e = all_F[:self.sub_no].unsqueeze(1)
        F_i = all_F[self.sub_no:].unsqueeze(1)
        #flip_F_e = torch.flip(F_e, [2])
        #flip_F_i = torch.flip(F_i, [2])
        flip_F_e = F_e
        flip_F_i = F_i
        
        pad_S_e = torch.zeros(T + self.N-1, self.sub_no).cuda()
        pad_S_i = torch.zeros(T + self.N-1, self.sub_no).cuda()
        pad_S_e[-T:] = pad_S_e[-T:] + S_e
        pad_S_i[-T:] = pad_S_i[-T:] + S_i
        pad_S_e = pad_S_e.T.unsqueeze(0)
        pad_S_i = pad_S_i.T.unsqueeze(0)

        filtered_e = F.conv1d(pad_S_e, flip_F_e, padding=0, groups=self.sub_no).squeeze(0).T
        filtered_i = F.conv1d(pad_S_i, flip_F_i, padding=0, groups=self.sub_no).squeeze(0).T

        syn_in = filtered_e + filtered_i

        #----- Combine Subunits -----#

        sub_out = torch.zeros(T, self.sub_no).cuda()
        
        for s in range(self.sub_no):
            sub_idx = -s-1
            leaf_idx = torch.where(self.C_den[sub_idx] == 1)[0]

            if torch.numel(leaf_idx) == 0:
                nonlin_out = torch.tanh(syn_in[:,sub_idx] + self.Theta[sub_idx]) # (T_data,) 
                sub_out[:,sub_idx] = sub_out[:,sub_idx] + nonlin_out
            else:
                leaf_in = sub_out[:,leaf_idx] * torch.exp(self.W_log[leaf_idx]) # (T_data,)
                nonlin_in = syn_in[:,sub_idx] + torch.sum(leaf_in, 1) + self.Theta[sub_idx]# (T_data,)
                nonlin_out = torch.tanh(nonlin_in)
                sub_out[:,sub_idx] = sub_out[:,sub_idx] + nonlin_out
        
        final_voltage = sub_out[:,0]*torch.exp(self.W_log[0]) + self.V_o

        #res = (target - final_voltage) ** 2
        #res = res.mul(-0.5)
        res = torch.var(target - final_voltage)
        
        return res, final_voltage
    
class VariationalELBO(_ApproximateMarginalLogLikelihood):
    def _log_likelihood_term(self, variational_dist_f, target, S_e, S_i, **kwargs):
        error, pred = self.likelihood.expected_log_prob(target, variational_dist_f, S_e, S_i, **kwargs)
        
        return error.sum(-1), pred

    def forward(self, approximate_dist_f, target, S_e, S_i, **kwargs):
        r"""
        Computes the Variational ELBO given :math:`q(\mathbf f)` and `\mathbf y`.
        Calling this function will call the likelihood's `expected_log_prob` function.
        Args:
            :attr:`approximate_dist_f` (:obj:`gpytorch.distributions.MultivariateNormal`):
                :math:`q(\mathbf f)` the outputs of the latent function (the :obj:`gpytorch.models.ApproximateGP`)
            :attr:`target` (`torch.Tensor`):
                :math:`\mathbf y` The target values
            :attr:`**kwargs`:
                Additional arguments passed to the likelihood's `expected_log_prob` function.
        """
        # Get likelihood term and KL term
        num_batch = approximate_dist_f.event_shape[0]
        log_likelihood, pred = self._log_likelihood_term(approximate_dist_f, target, S_e, S_i,**kwargs)
        log_likelihood = log_likelihood.div(num_batch)
        
        kl_divergence = self.model.variational_strategy.kl_divergence().div(self.num_data / self.beta)

        # Add any additional registered loss terms
        added_loss = torch.zeros_like(log_likelihood)
        had_added_losses = False
        for added_loss_term in self.model.added_loss_terms():
            added_loss.add_(added_loss_term.loss())
            had_added_losses = True

        # Log prior term
        log_prior = torch.zeros_like(log_likelihood)
        for _, prior, closure, _ in self.named_priors():
            log_prior.add_(prior.log_prob(closure()).sum().div(self.num_data))

        if self.combine_terms:
            return log_likelihood - kl_divergence + log_prior - added_loss , pred
            #return log_likelihood , pred
        else:
            if had_added_losses:
                return log_likelihood, kl_divergence, log_prior.div(self.num_data), added_loss
            else:
                return log_likelihood, kl_divergence, log_prior.div(self.num_data)

In [4]:
Ensyn = torch.tensor([0, 106, 213, 211, 99])
Insyn = torch.tensor([1, 22, 36, 42, 19])

E_no = torch.sum(Ensyn)
I_no = torch.sum(Insyn)

C_syn_e = torch.zeros(sub_no, E_no)
C_syn_i = torch.zeros(sub_no, I_no)

E_count = 0
for s in range(sub_no):
    C_syn_e[s,E_count:E_count+Ensyn[s]] = 1
    E_count += Ensyn[s]

I_count = 0
for s in range(sub_no):
    C_syn_i[s,I_count:I_count+Insyn[s]] = 1
    I_count += Insyn[s]

In [5]:
V_ref = np.load("/media/hdd01/sklee/L23_inputs/vdata_NMDA_ApN0.5_13_Adend_r0_o2_i2_g_b4.npy").flatten()

train_V_ref = V_ref[:train_T]
test_V_ref = V_ref[train_T:train_T+test_T]
test_V_ref = torch.from_numpy(test_V_ref).cuda()
train_V_ref = torch.from_numpy(train_V_ref).cuda()

raw_E_neural = np.load("/media/hdd01/sklee/L23_inputs/Espikes_NMDA_ApN0.5_13_Adend_r0_o2_i2_g_b4_neural.npy")
raw_I_neural = np.load("/media/hdd01/sklee/L23_inputs/Ispikes_NMDA_ApN0.5_13_Adend_r0_o2_i2_g_b4_neural.npy")

E_neural = torch.matmul(torch.from_numpy(raw_E_neural).double(), C_syn_e.T.double())
I_neural = torch.matmul(torch.from_numpy(raw_I_neural).double(), C_syn_i.T.double())

train_S_E = E_neural[:train_T].cuda()
train_S_I = I_neural[:train_T].cuda()
test_S_E = E_neural[train_T:train_T+test_T].double().cuda()
test_S_I = I_neural[train_T:train_T+test_T].double().cuda()

repeat_no = 1
batch_no = (train_V_ref.shape[0] - batch_size) * repeat_no
train_idx = np.empty((repeat_no, train_V_ref.shape[0] - batch_size))
for i in range(repeat_no):
    part_idx = np.arange(train_V_ref.shape[0] - batch_size)
    np.random.shuffle(part_idx)
    train_idx[i] = part_idx
train_idx = train_idx.flatten()
train_idx = torch.from_numpy(train_idx)



In [6]:
model = MultitaskGPModel(num_tasks, M)
likelihood = CustomLikelihood(C_den.cuda(), sub_no, N, num_tasks)

num_epochs = 10000
model.cuda().train()
likelihood.cuda().train()

optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': likelihood.parameters()},
], lr = 0.005)

#lr = 0.00004

train_x = torch.arange(N).cuda()

In [None]:
mll = VariationalELBO(likelihood, model, num_data=train_V_ref.shape[0])
#mll = VariationalELBO(likelihood, model, num_data=N)
#epochs_iter = tqdm.tqdm_notebook(range(num_epochs), desc="Epoch")

count = 0
while True:
    model.train()
    likelihood.train()
    # Within each iteration, we will go over each minibatch of data
    optimizer.zero_grad()
    output = model(train_x)
    loss, pred = mll(output, train_V_ref, train_S_E, train_S_I)
    #epochs_iter.set_postfix(loss=loss.item())
    loss.backward()
    optimizer.step()
    
    if count%200 == 0:
        model.eval()
        likelihood.eval()
        test_output = model(train_x)
        test_loss, test_pred = mll(test_output, test_V_ref, test_S_E, test_S_I)
        
        
        test_score = metrics.explained_variance_score(y_true=test_V_ref.cpu().detach().numpy(),
                                                      y_pred=test_pred.cpu().detach().numpy(),
                                                      multioutput='uniform_average')
        print(count, test_score, time.time() - s)
        s = time.time()
    count += 1

0 8.118211741070525e-05 1607544391.9688473


In [None]:
plt.plot(test_V_ref.cpu().detach().numpy()[1000:4000])
plt.plot(test_pred.cpu().detach().numpy()[1000:4000]-61)

In [None]:
plt.figure(figsize = (10,5))
F = test_output.mean.T + torch.sqrt(test_output.variance.T)
plt.plot(F[6].cpu().detach().numpy())