# Train and evaluate the model on NN6HIID

In [0]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.stats import norm
import scipy
from scipy.special import softmax
import torch
import pandas as pd
from tqdm import trange
import random
from time import sleep
import gpytorch

In [None]:
# Base path string, left empty here. It is not referenced in the visible cells;
# presumably a notebook-level prefix for locating project files -- TODO confirm.
PATH = ''
from evaluation import Evaluator
from datasets import NN6HIID

In [0]:
# Total memory of CUDA device 0 converted from bytes to GiB (bytes / 1024**3).
# Kept as a bare expression so the notebook cell displays the value.
torch.cuda.get_device_properties(0).total_memory / 1024 ** 3

11.17303466796875

In [0]:
# Binary cross-entropy on raw logits, summed (not averaged) over all elements;
# lossPW calls this with its kernel-expansion output as the logits.
loss_bce = torch.nn.BCEWithLogitsLoss(reduction='sum')

def lossPY(alphay, weighty, y, z, w, s, cov):
  """Penalised squared-error loss between y and a kernel expansion in (z, w, s).

  The fitted value is ``f = weighty + lz + lw + ls`` where each ``l*`` is a
  row-wise kernel-weighted sum ``sum_j K[i, j] * alphay[j, c]``.

  Args:
    alphay: coefficient tensor; columns 0, 1, 2 and 3 are read (z term,
      w term weighted by ``1 - w``, w term weighted by ``w``, s term).
    weighty: bias added to every fitted value.
    y: targets; compared against ``f`` reshaped to a column.
    z: input of the z kernel. Its kernel matrix is always recomputed from
      the module-level ``maternz`` kernel, even when ``cov`` is given.
    w: input of the w kernel and, flattened, the mixing weight between
      columns 1 and 2 of ``alphay``.
    s: input of the s kernel.
    cov: ``None`` to evaluate the w and s kernels on the fly with the
      module-level ``maternw`` / ``materns``, otherwise a 4-tuple whose
      second and fourth entries are the precomputed Kw and Ks.

  Returns:
    Scalar tensor ``sum((y - f)**2) + 1e-3 * sum(f**2)``.
  """
  Kz = maternz(z, z).evaluate()

  # Identity check: ``cov == None`` would turn into an elementwise comparison
  # if cov were ever a tensor/array instead of a tuple.
  if cov is None:
    Kw = maternw(w, w).evaluate()
    Ks = materns(s, s).evaluate()
  else:
    _, Kw, _, Ks = cov

  w_flat = w.reshape(-1)
  lz = torch.sum(Kz*alphay[:,0],dim=1)
  # Two coefficient columns for w: one active with weight (1 - w), one with weight w.
  lw = (1-w_flat)*torch.sum(Kw*alphay[:,1],dim=1) + w_flat*torch.sum(Kw*alphay[:,2],dim=1)
  ls = torch.sum(Ks*alphay[:,3],dim=1)

  f = weighty + lz + lw + ls

  return torch.sum((y - f.reshape(-1,1))**2) + 1e-3*torch.sum(f**2)

def QZ(alphaq, weightq, y, x, w, s, cov):
  """Evaluate the kernel expansion in (y, x, w, s) that trainModel uses as the mean of z.

  Args:
    alphaq: coefficient tensor; columns 0..4 are read (y term, x term,
      w term weighted by ``1 - w``, w term weighted by ``w``, s term).
    weightq: bias added to every output.
    y, x, w, s: kernel inputs; only used to build kernel matrices when
      ``cov`` is ``None``. ``w`` is additionally flattened and used as the
      mixing weight between columns 2 and 3 of ``alphaq``.
    cov: ``None`` to evaluate the module-level kernels on the fly, otherwise
      the precomputed matrices in the order ``(Ky, Kw, Kx, Ks)``.

  Returns:
    Column tensor (shape ``(-1, 1)``) of
    ``weightq + ly + lx + lw + ls``.
  """
  # Identity check: ``cov == None`` would turn into an elementwise comparison
  # if cov were ever a tensor/array instead of a tuple.
  if cov is None:
    Ky = materny(y, y).evaluate()
    Kx = maternx(x, x).evaluate()
    Kw = maternw(w, w).evaluate()
    Ks = materns(s, s).evaluate()
  else:
    # Note the tuple order: Kw comes before Kx.
    Ky, Kw, Kx, Ks = cov

  w_flat = w.reshape(-1)
  # Each term is a row-wise kernel-weighted sum: sum_j K[i, j] * alphaq[j, c].
  ly = torch.sum(Ky*alphaq[:,0],dim=1)
  lx = torch.sum(Kx*alphaq[:,1],dim=1)
  lw = (1-w_flat)*torch.sum(Kw*alphaq[:,2],dim=1) + w_flat*torch.sum(Kw*alphaq[:,3],dim=1)
  ls = torch.sum(Ks*alphaq[:,4],dim=1)

  z = weightq + ly + lx + lw + ls
  return z.reshape(-1,1)
  
def lossQZ(alphaq, weightq, z, y, x, w, s, cov):
  """Squared-error term between z and the kernel expansion in (y, x, w, s).

  The fitted value ``f`` is built exactly as in ``QZ``:
  ``f = weightq + ly + lx + lw + ls``.

  Args:
    alphaq: coefficient tensor; columns 0..4 are read (y term, x term,
      w term weighted by ``1 - w``, w term weighted by ``w``, s term).
    weightq: bias added to every fitted value.
    z: targets; compared against ``f`` reshaped to a column.
    y, x, w, s: kernel inputs; only used to build kernel matrices when
      ``cov`` is ``None``. ``w`` is additionally flattened and used as the
      mixing weight between columns 2 and 3 of ``alphaq``.
    cov: ``None`` to evaluate the module-level kernels on the fly, otherwise
      the precomputed matrices in the order ``(Ky, Kw, Kx, Ks)``.

  Returns:
    Scalar tensor ``sum((z - f)**2) - 1e-3 * sum(f**2)``.
  """
  # Identity check: ``cov == None`` would turn into an elementwise comparison
  # if cov were ever a tensor/array instead of a tuple.
  if cov is None:
    Ky = materny(y, y).evaluate()
    Kx = maternx(x, x).evaluate()
    Kw = maternw(w, w).evaluate()
    Ks = materns(s, s).evaluate()
  else:
    # Note the tuple order: Kw comes before Kx.
    Ky, Kw, Kx, Ks = cov

  w_flat = w.reshape(-1)
  ly = torch.sum(Ky*alphaq[:,0],dim=1)
  lx = torch.sum(Kx*alphaq[:,1],dim=1)
  lw = (1-w_flat)*torch.sum(Kw*alphaq[:,2],dim=1) + w_flat*torch.sum(Kw*alphaq[:,3],dim=1)
  ls = torch.sum(Ks*alphaq[:,4],dim=1)

  f = weightq + ly + lx + lw + ls

  # The penalty is subtracted here because trainModel subtracts this whole
  # term from the total loss, which makes the penalty positive there.
  return torch.sum((z - f.reshape(-1,1))**2) - 1e-3*torch.sum(f**2)

def lossPX(alphax, weightx, x, z, s, cov):
  """Penalised squared-error loss between x and a kernel expansion in (z, s).

  Args:
    alphax: coefficient tensor; column 0 weights the z kernel, column 1 the
      s kernel.
    weightx: bias added to every fitted value.
    x: targets; compared against the fitted values reshaped to a column.
    z: input of the z kernel. Its kernel matrix is always recomputed from
      the module-level ``maternz`` kernel, even when ``cov`` is given.
    s: input of the s kernel (only used when ``cov`` is ``None``).
    cov: ``None`` to evaluate the s kernel with the module-level
      ``materns``, otherwise a 4-tuple whose last entry is the precomputed Ks.

  Returns:
    Scalar tensor ``sum((x - f)**2) + 1e-3 * sum(f**2)``.
  """
  Kz = maternz(z, z).evaluate()
  # Identity check: ``cov == None`` would turn into an elementwise comparison
  # if cov were ever a tensor/array instead of a tuple.
  if cov is None:
    Ks = materns(s, s).evaluate()
  else:
    _, _, _, Ks = cov

  # Row-wise kernel-weighted sums: sum_j K[i, j] * alphax[j, c].
  lz = torch.sum(Kz*alphax[:,0],dim=1)
  ls = torch.sum(Ks*alphax[:,1],dim=1)

  f = weightx + lz + ls

  return torch.sum((x - f.reshape(-1,1))**2) + 1e-3*torch.sum(f**2)


def lossPW(alphaw, weightw, w, z, s, cov):
  """Penalised binary cross-entropy between w and logits from a kernel expansion in (z, s).

  Args:
    alphaw: coefficient tensor; column 0 weights the z kernel, column 1 the
      s kernel.
    weightw: bias added to every logit.
    w: targets passed to the module-level ``loss_bce``.
    z: input of the z kernel. Its kernel matrix is always recomputed from
      the module-level ``maternz`` kernel, even when ``cov`` is given.
    s: input of the s kernel (only used when ``cov`` is ``None``).
    cov: ``None`` to evaluate the s kernel with the module-level
      ``materns``, otherwise a 4-tuple whose last entry is the precomputed Ks.

  Returns:
    Scalar tensor ``loss_bce(f, w) + 1e-3 * sum(f**2)`` where ``f`` holds
    the logits.
  """
  Kz = maternz(z, z).evaluate()
  # Identity check: ``cov == None`` would turn into an elementwise comparison
  # if cov were ever a tensor/array instead of a tuple.
  if cov is None:
    Ks = materns(s, s).evaluate()
  else:
    _, _, _, Ks = cov

  # Row-wise kernel-weighted sums: sum_j K[i, j] * alphaw[j, c].
  lz = torch.sum(Kz*alphaw[:,0],dim=1)
  ls = torch.sum(Ks*alphaw[:,1],dim=1)

  f = weightw + lz + ls

  # f is reshaped to a column so it lines up with w in the loss call.
  return loss_bce(f.reshape(-1,1),w) + 1e-3*torch.sum(f**2)

def lossPS(weights, s):
  """Return the sum of squared deviations of ``s`` from ``weights``.

  ``weights`` is broadcast against ``s``; the result is a scalar tensor.
  """
  residual = s - weights
  return residual.pow(2).sum()


def lossPZ(alphaz, weightz, z):
  """Penalised squared-error loss of z against a masked kernel expansion in z itself.

  The z kernel matrix is multiplied elementwise by the module-level
  ``lower_ones`` mask (built in the training script as ones strictly below
  the diagonal and zeros elsewhere), so row i only sums over columns j < i.

  Args:
    alphaz: coefficient tensor; column 0 is read and padded with two
      trailing ones so its length matches the kernel's column count when
      ``alphaz`` has two rows fewer than ``z``.
    weightz: bias; also the target value for ``z[0]``.
    z: values being modelled.

  Returns:
    Scalar tensor ``sum((z[0] - weightz)**2) + sum((z[1:] - f)**2)
    + 1e-3 * sum(f**2)``, where ``f`` is the bias plus all masked row sums
    except the last one.
  """
  masked_kernel = maternz(z, z).evaluate()*lower_ones
  padded_coef = torch.cat((alphaz[:,0], torch.tensor([1.0, 1.0])))
  row_sums = torch.sum(masked_kernel*padded_coef,dim=1)
  # Drop the last row so f has one entry per element of z[1:].
  f = weightz + row_sums[:-1]

  first_term = torch.sum((z[0] - weightz)**2)
  rest_term = torch.sum((z[1:] - f.reshape(-1,1))**2)
  penalty = 1e-3*torch.sum(f**2)
  return first_term + rest_term + penalty


def trainModel(x, y, w, s, cov=None, n_iter=20000):
  """Run ``n_iter`` optimisation steps and return the per-step loss values.

  Relies on module-level state: the coefficient/bias tensors (``alphay``,
  ``weighty``, ``alphaq``, ``weightq``, ``alphax``, ``weightx``, ``alphaw``,
  ``weightw``, ``alphaz``, ``weightz``, ``weights``) and ``optimizer``.

  Args:
    x, y, w, s: training tensors forwarded to the loss functions.
    cov: forwarded unchanged to the loss functions (``None`` or the
      precomputed kernel matrices).
    n_iter: number of optimisation steps.

  Returns:
    NumPy array with one loss value per step.
  """
  history = []
  prog = trange(n_iter, desc='', leave=True)
  for t in prog:
    # Draw z as the QZ output plus standard normal noise.
    mean_z = QZ(alphaq, weightq, y, x, w, s, cov)
    z = mean_z + torch.randn(mean_z.shape)

    term_y = lossPY(alphay, weighty, y, z, w, s, cov)
    term_x = lossPX(alphax, weightx, x, z, s, cov)
    term_w = lossPW(alphaw, weightw, w, z, s, cov)
    term_s = lossPS(weights, s)
    term_z = lossPZ(alphaz, weightz, z)
    term_q = lossQZ(alphaq, weightq, z, y, x, w, s, cov)
    loss = term_y + term_x + term_w + term_s + term_z - term_q

    loss_value = loss.item()
    # Refresh the progress-bar text every 50 steps only.
    if t%50 == 0:
      prog.set_postfix_str("Iter {}:, Loss: {}".format(t,loss_value))
      prog.refresh()
    history.append(loss_value)

    optimizer.zero_grad()
    # retain_graph=True keeps the autograd graph alive between steps;
    # presumably needed because precomputed kernel matrices in cov are
    # shared across iterations -- TODO confirm.
    loss.backward(retain_graph=True)
    optimizer.step()
  return np.asarray(history)

In [0]:
def predZ(alphaq, weightq, y, x, w, s, ynew, xnew, wnew, snew):
  """Evaluate the QZ-style kernel expansion at new points.

  Cross-kernel matrices are computed between the new inputs (rows) and the
  reference inputs ``y, x, w, s`` (columns) with the module-level kernels.

  Args:
    alphaq: coefficient tensor; columns 0..4 are read (y term, x term,
      w term weighted by ``1 - wnew``, w term weighted by ``wnew``, s term).
    weightq: bias added to every output.
    y, x, w, s: reference inputs the coefficients are attached to.
    ynew, xnew, wnew, snew: inputs at which to evaluate the expansion.

  Returns:
    Column tensor (shape ``(-1, 1)``) with one value per new point.
  """
  cross_y = materny(ynew, y).evaluate()
  cross_x = maternx(xnew, x).evaluate()
  cross_w = maternw(wnew, w).evaluate()
  cross_s = materns(snew, s).evaluate()

  wnew_flat = wnew.reshape(-1)
  ly = torch.sum(cross_y*alphaq[:,0],dim=1)
  lx = torch.sum(cross_x*alphaq[:,1],dim=1)
  lw_off = torch.sum(cross_w*alphaq[:,2],dim=1)
  lw_on = torch.sum(cross_w*alphaq[:,3],dim=1)
  lw = (1-wnew_flat)*lw_off + wnew_flat*lw_on
  ls = torch.sum(cross_s*alphaq[:,4],dim=1)

  z = weightq + ly + lx + lw + ls
  return z.reshape(-1,1)


def evalLogPX(alphax, weightx, x, z, s, xpre, zpre, spre):
  """Unit-variance Gaussian log-density of ``xpre`` around the (z, s) kernel expansion.

  Cross-kernel matrices are computed between the evaluation inputs
  ``zpre, spre`` (rows) and the reference inputs ``z, s`` (columns) with the
  module-level ``maternz`` / ``materns`` kernels.

  Args:
    alphax: coefficient tensor; column 0 weights the z kernel, column 1 the
      s kernel.
    weightx: bias added to every fitted value.
    x: unused; kept so the signature stays unchanged.
    z, s: reference inputs the coefficients are attached to.
    xpre: values whose log-density is evaluated.
    zpre, spre: inputs at which the expansion is evaluated.

  Returns:
    Flattened NumPy array of ``-0.5*log(2*pi) - 0.5*(xpre - f)**2``.
  """
  Kz = maternz(zpre, z).evaluate()
  Ks = materns(spre, s).evaluate()

  lz = torch.sum(Kz*alphax[:,0],dim=1)
  ls = torch.sum(Ks*alphax[:,1],dim=1)

  f = weightx + lz + ls

  # Fixed: the original referenced the undefined name ``Xpre`` here, which
  # raised NameError; the intended operand is the ``xpre`` parameter.
  return (-0.5*np.log(2*np.pi) - 0.5*(xpre - f.reshape(-1,1))**2).detach().numpy().flatten()

def sampleY(z, w, s, znew, wnew, snew, alphay, weighty, weights_mixture, T, n):
  """Draw ``n`` noisy samples around the lossPY-style kernel expansion at new points.

  Cross-kernel matrices are computed between the new inputs (rows) and the
  reference inputs ``z, w, s`` (columns) with the module-level kernels.

  Args:
    z, w, s: reference inputs the coefficients are attached to.
    znew, wnew, snew: inputs at which the expansion is evaluated; ``wnew``
      is also flattened and used as the mixing weight between columns 1
      and 2 of ``alphay``.
    alphay: coefficient tensor; columns 0..3 are read.
    weighty: bias added to every fitted value.
    weights_mixture: unused; kept so the signature stays unchanged.
    T: length of each noise vector; must match the number of new points.
    n: number of samples to draw.

  Returns:
    NumPy array with one row per sample (``n`` rows of length ``T``).
  """
  cross_z = maternz(znew, z).evaluate()
  cross_w = maternw(wnew, w).evaluate()
  cross_s = materns(snew, s).evaluate()

  wnew_flat = wnew.reshape(-1)
  lz = torch.sum(cross_z*alphay[:,0],dim=1)
  lw_off = torch.sum(cross_w*alphay[:,1],dim=1)
  lw_on = torch.sum(cross_w*alphay[:,2],dim=1)
  lw = (1-wnew_flat)*lw_off + wnew_flat*lw_on
  ls = torch.sum(cross_s*alphay[:,3],dim=1)
  f = weighty + lz + lw + ls

  draws = []
  for _ in range(n):
    # Standard normal noise from NumPy's global RNG, cast to float32.
    eps = torch.from_numpy(np.random.normal(0,1,T)).float()
    noisy = f + eps
    draws.append(noisy.detach().numpy())
  return np.asarray(draws)

In [9]:
# When False, the y/w/x/s kernel matrices are computed once per replication and
# passed to the loss functions through `cov`; when True, `cov` is None, the
# kernels are re-evaluated inside every loss call, and the kernel parameters
# are added to the optimiser.
learn_hyper = False

dataset = NN6HIID(replications=10)

# One entry per replication.
train_stats = []
test_stats = []
loss_lst = []
for i, (train, valid, test, contfeats, binfeats) in enumerate(dataset.get_train_valid_test()):
  # Unpack the training split: W, Y and Y_cf are flattened to 1-D; mu stacks
  # train[1][1] and train[1][2] as rows 0 and 1; X and S are columns 0 and 1
  # of train[0][0].
  W, Y, Y_cf, mu, X, S = train[0][1].reshape(-1), train[0][2].reshape(-1), train[1][0].reshape(-1),\
                          np.concatenate((train[1][1],train[1][2]),axis=1).T, train[0][0][:,0], train[0][0][:,1]
  T = len(Y)
  x = torch.from_numpy(X.reshape(T,-1)).float()
  y = torch.from_numpy(Y.reshape(-1,1)).float()
  w = torch.from_numpy(W.reshape(-1,1)).float()
  s = torch.from_numpy(S.reshape(T,-1)).float()

  # Same unpacking for the test split.
  Wte, Yte, Y_cfte, mute, Xte, Ste = test[0][1].reshape(-1), test[0][2].reshape(-1), test[1][0].reshape(-1), \
                                      np.concatenate((test[1][1],test[1][2]), axis=1).T, test[0][0][:,0], test[0][0][:,1]
  Tte = len(Yte)
  xte = torch.from_numpy(Xte.reshape(Tte,-1)).float()
  yte = torch.from_numpy(Yte.reshape(-1,1)).float()
  wte = torch.from_numpy(Wte.reshape(-1,1)).float()
  ste = torch.from_numpy(Ste.reshape(Tte,-1)).float()

  # Mask read by lossPZ: ones strictly below the diagonal, zeros on and
  # above it (np.triu_indices includes the diagonal).
  lower_ones = torch.ones(len(y), len(y))
  lower_ones[np.triu_indices(len(y))] = 0

  # Kernels, one per variable. NOTE: despite the `matern*` names these are
  # RBF kernels; the names are kept because the loss/prediction functions
  # look them up as globals.
  maternx = gpytorch.kernels.RBFKernel()
  maternx.initialize(lengthscale=10.0)
  materns = gpytorch.kernels.RBFKernel()
  materns.initialize(lengthscale=10.0)
  materny = gpytorch.kernels.RBFKernel()
  materny.initialize(lengthscale=10.0)
  maternz = gpytorch.kernels.RBFKernel()
  maternz.initialize(lengthscale=10.0)
  maternw = gpytorch.kernels.RBFKernel()
  maternw.initialize(lengthscale=10.0)

  # Kernel matrices on the training inputs.
  Ky = materny(y, y).evaluate()
  Kx = maternx(x, x).evaluate()
  Kw = maternw(w, w).evaluate()
  Ks = materns(s, s).evaluate()
  if not learn_hyper:
    # Order matters: the loss functions unpack cov as (Ky, Kw, Kx, Ks).
    cov = (Ky, Kw, Kx, Ks)
  else:
    cov = None

  # Declare parameters (uniform random initialisation). The loss functions
  # and trainModel read these as globals.
  alphay = torch.rand(len(y), 4, requires_grad=True)
  weighty = torch.rand(1, requires_grad=True)

  alphaq = torch.rand(len(y), 5, requires_grad=True)
  weightq = torch.rand(1, requires_grad=True)

  alphax = torch.rand(len(x), 2, requires_grad=True)
  weightx = torch.rand(1, requires_grad=True)

  alphaw = torch.rand(len(w), 2, requires_grad=True)
  weightw = torch.rand(1, requires_grad=True)

  weights = torch.rand(1, requires_grad=True)

  # Two rows fewer than the data: lossPZ pads the column with two ones.
  alphaz = torch.rand(len(y)-2, 1, requires_grad=True)
  weightz = torch.rand(1, requires_grad=True)

  # Train
  params = [alphay, alphaq, alphax, alphaw, alphaz, weighty, weightq, weightx, weightw, weightz, weights]
  if learn_hyper:
    params = params + list(maternx.parameters()) + list(materns.parameters()) + list(maternw.parameters()) \
                    + list(materny.parameters()) + list(maternz.parameters())
  learning_rate = 1e-3
  optimizer = torch.optim.Adam(params, lr=learning_rate)
  loss_train = trainModel(x, y, w, s, cov=cov, n_iter=20000)
  loss_lst.append(loss_train)

  # Sample Ypred: evaluate z on train+test inputs; the result is transposed
  # so that row 0 holds all T+Tte values.
  z_samples = predZ(alphaq, weightq, y, x, w, s,
                   torch.cat((y,yte),dim=0), torch.cat((x,xte),dim=0), torch.cat((w,wte),dim=0), torch.cat((s,ste),dim=0)).transpose(0,1)

  weights_mixture = 1.0

  # Predictions with w forced to 1 (Y_samples1) and to 0 (Y_samples2).
  # NOTE(review): the reference w passed to sampleY is the constant vector,
  # not the observed w used during training -- confirm this is intended.
  Y_samples1 = sampleY(z_samples[0,:T].reshape(-1,1), torch.ones((T,1)), s.reshape(-1,1),
                       z_samples[0,:].reshape(-1,1), torch.ones((T+Tte,1)), torch.cat((s,ste),dim=0).reshape(-1,1),
                       alphay, weighty, weights_mixture, T=T+Tte, n=1000)
  Y_samples2 = sampleY(z_samples[0,:T].reshape(-1,1), torch.zeros((T,1)), s.reshape(-1,1),
                       z_samples[0,:].reshape(-1,1), torch.zeros((T+Tte,1)), torch.cat((s,ste),dim=0).reshape(-1,1),
                       alphay, weighty, weights_mixture, T=T+Tte, n=1000)
  # Average over the 1000 samples.
  Ypred1 = np.mean(Y_samples1,axis=0)
  Ypred0 = np.mean(Y_samples2,axis=0)

  # Evaluate: the first T predictions belong to the training split, the
  # remaining Tte to the test split.
  evaluator_train = Evaluator(y=Y, t=W, y_cf=Y_cf, mu0=mu[0,:], mu1=mu[1,:])
  stat = evaluator_train.calc_stats(Ypred1[:T],Ypred0[:T])
  train_stats.append(stat)
  print('Train:', stat)

  evaluator_test = Evaluator(y=Yte, t=Wte, y_cf=Y_cfte, mu0=mute[0,:], mu1=mute[1,:])
  stat = evaluator_test.calc_stats(Ypred1[T:],Ypred0[T:])
  test_stats.append(stat)
  print('Test:', stat)

  sleep(0.5)

100%|██████████| 20000/20000 [02:46<00:00, 119.91it/s, Iter 19950:, Loss: 1407.383544921875]


Train: (2.32205140020506, 0.13687134200811402, 0.6099071237834433)
Test: (2.5080620624540706, 0.03828618589063737, 0.25742936930190413)


100%|██████████| 20000/20000 [02:47<00:00, 119.68it/s, Iter 19950:, Loss: 1301.2186279296875]


Train: (2.1769023404860586, 0.21766751048915367, 0.5026229659643505)
Test: (2.3362419294960115, 0.34256972098891847, 0.6853854767931379)


100%|██████████| 20000/20000 [02:46<00:00, 119.95it/s, Iter 19950:, Loss: 1452.2835693359375]


Train: (2.3642882763977737, 0.20087323509337462, 0.47051581290948047)
Test: (2.4006114744070683, 0.16756968259519844, 0.34078091642459596)


100%|██████████| 20000/20000 [02:46<00:00, 119.79it/s, Iter 19950:, Loss: 1457.6578369140625]


Train: (2.3352562681913436, 0.04263382943313232, 0.6318924788062724)
Test: (2.185096527781718, 0.1393521621759657, 0.17230172851643635)


100%|██████████| 20000/20000 [02:46<00:00, 120.15it/s, Iter 19950:, Loss: 1480.9921875]


Train: (2.3697974572742115, 0.4096216969299773, 0.7789572219756145)
Test: (2.542828014465751, 0.2984414565492113, 0.8015826169342647)


100%|██████████| 20000/20000 [02:46<00:00, 120.03it/s, Iter 19950:, Loss: 1465.07958984375]


Train: (2.602244027488276, 0.7623659318487914, 0.9846436815435424)
Test: (2.22536649211966, 0.7423142239609968, 0.9201294333323565)


100%|██████████| 20000/20000 [02:47<00:00, 119.57it/s, Iter 19950:, Loss: 1181.9786376953125]


Train: (1.6548823362473517, 1.0489315810219964, 1.165331786327093)
Test: (1.2400924915379998, 0.9421399092545482, 0.962764743228176)


100%|██████████| 20000/20000 [02:47<00:00, 119.60it/s, Iter 19950:, Loss: 1411.203857421875]


Train: (2.2797659298389545, 0.23476613112212164, 0.555189471040765)
Test: (2.930573787796994, 0.11368762023925783, 0.1188971770452119)


100%|██████████| 20000/20000 [02:46<00:00, 120.15it/s, Iter 19950:, Loss: 1397.4659423828125]


Train: (2.272269226067076, 0.5367287568611929, 0.7313590919565229)
Test: (2.535890714839683, 0.5552563163198618, 0.6450948445115718)


100%|██████████| 20000/20000 [02:46<00:00, 119.88it/s, Iter 19950:, Loss: 1430.377197265625]


Train: (2.286082742868457, 0.688167418649285, 0.7687578836561558)
Test: (3.2014691702462916, 0.6288471937678173, 0.6716435027047826)


In [10]:
# Mean and standard error of each training metric across replications; the
# divisor follows the number of collected results instead of a hard-coded 10.
np.mean(train_stats,axis=0), np.std(train_stats,axis=0,ddof=1)/np.sqrt(len(train_stats))

(array([2.266354  , 0.42786274, 0.71991775]),
 array([0.07626466, 0.1023953 , 0.06919216]))

In [11]:
# Mean and standard error of each test metric across replications; the
# divisor follows the number of collected results instead of a hard-coded 10.
np.mean(test_stats,axis=0), np.std(test_stats,axis=0,ddof=1)/np.sqrt(len(test_stats))

(array([2.41062327, 0.39684645, 0.55760098]),
 array([0.16327916, 0.09637058, 0.09832666]))