# Train and evaluate the model on NNH6

In [0]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.stats import norm
import scipy
from scipy.special import softmax
import torch
import pandas as pd
from tqdm import trange
import random
from time import sleep
import gpytorch

In [None]:
PATH = ''
from evaluation import Evaluator
from datasets import NN6H

In [0]:
torch.cuda.get_device_properties(0).total_memory/1024/1024/1024

11.17303466796875

In [0]:
loss_bce = torch.nn.BCEWithLogitsLoss(reduction='sum')

def lossPY(alphay, weighty, y, z, w, s, cov):
  Kz = maternz(z, z).evaluate()
  
  if cov == None:
    Kw = maternw(w, w).evaluate()
    Ks = materns(s, s).evaluate()
  else:
    _, Kw, _, Ks = cov
  
  lz = torch.sum(Kz*alphay[:,0],dim=1)
  lw = (1-w.reshape(-1))*torch.sum(Kw*alphay[:,1],dim=1) + w.reshape(-1)*torch.sum(Kw*alphay[:,2],dim=1)
  ls = torch.sum(Ks*alphay[:,3],dim=1)
  
  f = weighty + lz + lw + ls

  return torch.sum((y - f.reshape(-1,1))**2) + 1e-3*torch.sum(f**2)

def QZ(alphaq, weightq, y, x, w, s, cov):
  if cov == None:
    Ky = materny(y, y).evaluate()
    Kx = maternx(x, x).evaluate()
    Kw = maternw(w, w).evaluate()
    Ks = materns(s, s).evaluate()
  else:
    Ky, Kw, Kx, Ks = cov

  ly = torch.sum(Ky*alphaq[:,0],dim=1)
  lx = torch.sum(Kx*alphaq[:,1],dim=1)
  lw = (1-w.reshape(-1))*torch.sum(Kw*alphaq[:,2],dim=1) + w.reshape(-1)*torch.sum(Kw*alphaq[:,3],dim=1)
  ls = torch.sum(Ks*alphaq[:,4],dim=1)
  
  z = weightq + ly + lx + lw + ls
  return z.reshape(-1,1)
  
def lossQZ(alphaq, weightq, z, y, x, w, s, cov):
  if cov == None:
    Ky = materny(y, y).evaluate()
    Kx = maternx(x, x).evaluate()
    Kw = maternw(w, w).evaluate()
    Ks = materns(s, s).evaluate()
  else:
    Ky, Kw, Kx, Ks = cov
  
  ly = torch.sum(Ky*alphaq[:,0],dim=1)
  lx = torch.sum(Kx*alphaq[:,1],dim=1)
  lw = (1-w.reshape(-1))*torch.sum(Kw*alphaq[:,2],dim=1) + w.reshape(-1)*torch.sum(Kw*alphaq[:,3],dim=1)
  ls = torch.sum(Ks*alphaq[:,4],dim=1)
  
  f = weightq + ly + lx + lw + ls
  
  return torch.sum((z - f.reshape(-1,1))**2) - 1e-3*torch.sum(f**2)

def lossPX(alphax, weightx, x, z, s, cov):
  Kz = maternz(z, z).evaluate()
  if cov == None:
    Ks = materns(s, s).evaluate()
  else:
    _, _, _, Ks = cov
  
  lz = torch.sum(Kz*alphax[:,0],dim=1)
  ls = torch.sum(Ks*alphax[:,1],dim=1)
  
  f = weightx + lz + ls
  
  return torch.sum((x - f.reshape(-1,1))**2) + 1e-3*torch.sum(f**2)


def lossPW(alphaw, weightw, w, z, s, cov):
  Kz = maternz(z, z).evaluate()
  if cov == None:
    Ks = materns(s, s).evaluate()
  else:
    _, _, _, Ks = cov
    
  lz = torch.sum(Kz*alphaw[:,0],dim=1)
  ls = torch.sum(Ks*alphaw[:,1],dim=1)
  
  f = weightw + lz + ls

  return loss_bce(f.reshape(-1,1),w) + 1e-3*torch.sum(f**2)

def lossPS(weights, s):
  return torch.sum((s - weights)**2)


def lossPZ(alphaz, weightz, z):
  Kz = maternz(z, z).evaluate()*lower_ones
  alphaz_ = torch.cat((alphaz[:,0], torch.tensor([1.0, 1.0])))
  f = weightz + torch.sum(Kz*alphaz_,dim=1)[:-1]
  return torch.sum((z[0] - weightz)**2) + torch.sum((z[1:] - f.reshape(-1,1))**2) + 1e-3*torch.sum(f**2)


def trainModel(x, y, w, s, cov=None, n_iter=20000):
  loss_lst = []
  prog = trange(n_iter, desc='', leave=True)
  for t in prog:
      mean_z = QZ(alphaq, weightq, y, x, w, s, cov)
      z = mean_z + torch.randn(mean_z.shape)

      loss = lossPY(alphay, weighty, y, z, w, s, cov) + lossPX(alphax, weightx, x, z, s, cov) +\
              lossPW(alphaw, weightw, w, z, s, cov) + lossPS(weights, s) +\
                lossPZ(alphaz, weightz, z) - lossQZ(alphaq, weightq, z, y, x, w, s, cov)

      if t%50 == 0:
        prog.set_postfix_str("Iter {}:, Loss: {}".format(t,loss.item()))
        prog.refresh()
        
      loss_lst.append(loss.item())
      optimizer.zero_grad()
      loss.backward(retain_graph=True)
      optimizer.step()
  return np.asarray(loss_lst)

In [0]:
def predZ(alphaq, weightq, y, x, w, s, ynew, xnew, wnew, snew):
  Ky = materny(ynew, y).evaluate()
  Kx = maternx(xnew, x).evaluate()
  Kw = maternw(wnew, w).evaluate()
  Ks = materns(snew, s).evaluate()

  ly = torch.sum(Ky*alphaq[:,0],dim=1)
  lx = torch.sum(Kx*alphaq[:,1],dim=1)
  lw = (1-wnew.reshape(-1))*torch.sum(Kw*alphaq[:,2],dim=1) + wnew.reshape(-1)*torch.sum(Kw*alphaq[:,3],dim=1)
  ls = torch.sum(Ks*alphaq[:,4],dim=1)
  
  z = weightq + ly + lx + lw + ls
  return z.reshape(-1,1)


def evalLogPX(alphax, weightx, x, z, s, xpre, zpre, spre):
  Kz = maternz(zpre, z).evaluate()
  Ks = materns(spre, s).evaluate()
  
  lz = torch.sum(Kz*alphax[:,0],dim=1)
  ls = torch.sum(Ks*alphax[:,1],dim=1)
  
  f = weightx + lz + ls
  
  return (-0.5*np.log(2*np.pi) - 0.5*(Xpre - f.reshape(-1,1))**2).detach().numpy().flatten()

def sampleY(z, w, s, znew, wnew, snew, alphay, weighty, weights_mixture, T, n):
  Y_samples = []
  
  Kz = maternz(znew, z).evaluate()
  Kw = maternw(wnew, w).evaluate()
  Ks = materns(snew, s).evaluate()

  lz = torch.sum(Kz*alphay[:,0],dim=1)
  lw = (1-wnew.reshape(-1))*torch.sum(Kw*alphay[:,1],dim=1) + wnew.reshape(-1)*torch.sum(Kw*alphay[:,2],dim=1)
  ls = torch.sum(Ks*alphay[:,3],dim=1)
  f = weighty + lz + lw + ls
  for i in range(n):
    noise = torch.from_numpy(np.random.normal(0,1,T)).float()

    Y = f + noise
    Y_samples.append(Y.detach().numpy())
  return np.asarray(Y_samples)

In [8]:
learn_hyper = False

dataset = NN6H(replications=10)

train_stats = []
test_stats = []
loss_lst = []
for i, (train, valid, test, contfeats, binfeats) in enumerate(dataset.get_train_valid_test()):
  W, Y, Y_cf, mu, X, S = train[0][1].reshape(-1), train[0][2].reshape(-1), train[1][0].reshape(-1),\
                          np.concatenate((train[1][1],train[1][2]),axis=1).T, train[0][0][:,0], train[0][0][:,1]
  T = len(Y)
  x = torch.from_numpy(X.reshape(T,-1)).float() 
  y = torch.from_numpy(Y.reshape(-1,1)).float()
  w = torch.from_numpy(W.reshape(-1,1)).float()
  s = torch.from_numpy(S.reshape(T,-1)).float()
  
  Wte, Yte, Y_cfte, mute, Xte, Ste = test[0][1].reshape(-1), test[0][2].reshape(-1), test[1][0].reshape(-1), \
                                      np.concatenate((test[1][1],test[1][2]), axis=1).T, test[0][0][:,0], test[0][0][:,1]
  Tte = len(Yte)
  xte = torch.from_numpy(Xte.reshape(Tte,-1)).float() 
  yte = torch.from_numpy(Yte.reshape(-1,1)).float()
  wte = torch.from_numpy(Wte.reshape(-1,1)).float()
  ste = torch.from_numpy(Ste.reshape(Tte,-1)).float()
  
  lower_ones = torch.ones(len(y), len(y))
  lower_ones[np.triu_indices(len(y))] = 0
  
  # Compute kernel matrices
  maternx = gpytorch.kernels.RBFKernel()
  maternx.initialize(lengthscale=10.0)
  materns = gpytorch.kernels.RBFKernel()
  materns.initialize(lengthscale=10.0)
  materny = gpytorch.kernels.RBFKernel()
  materny.initialize(lengthscale=10.0)
  maternz = gpytorch.kernels.RBFKernel()
  maternz.initialize(lengthscale=10.0)
  maternw = gpytorch.kernels.RBFKernel()
  maternw.initialize(lengthscale=10.0)
  
  Ky = materny(y, y).evaluate()
  Kx = maternx(x, x).evaluate()
  Kw = maternw(w, w).evaluate()
  Ks = materns(s, s).evaluate()
  if learn_hyper == False:
    cov = (Ky, Kw, Kx, Ks)
  else:
    cov = None
  
  # Declare parameters
  alphay = torch.rand(len(y), 4, requires_grad=True)
  weighty = torch.rand(1, requires_grad=True)

  alphaq = torch.rand(len(y), 5, requires_grad=True)
  weightq = torch.rand(1, requires_grad=True)

  alphax = torch.rand(len(x), 2, requires_grad=True)
  weightx = torch.rand(1, requires_grad=True)

  alphaw = torch.rand(len(w), 2, requires_grad=True)
  weightw = torch.rand(1, requires_grad=True)

  weights = torch.rand(1, requires_grad=True)

  alphaz = torch.rand(len(y)-2, 1, requires_grad=True)
  weightz = torch.rand(1, requires_grad=True)  
  
  # Train
  params = [alphay, alphaq, alphax, alphaw, alphaz, weighty, weightq, weightx, weightw, weightz, weights]
  if learn_hyper == True:
    params = params + list(maternx.parameters()) + list(materns.parameters()) + list(maternw.parameters()) \
                    + list(materny.parameters()) + list(maternz.parameters())
  learning_rate = 1e-3
  optimizer = torch.optim.Adam(params, lr=learning_rate)
  loss_train = trainModel(x, y, w, s, cov=cov, n_iter=20000)
  loss_lst.append(loss_train)

  # Sample Ypred
  z_samples = predZ(alphaq, weightq, y, x, w, s,
                   torch.cat((y,yte),dim=0), torch.cat((x,xte),dim=0), torch.cat((w,wte),dim=0), torch.cat((s,ste),dim=0)).transpose(0,1)
  weights_mixture = 1.0

  Y_samples1 = sampleY(z_samples[0,:T].reshape(-1,1), torch.ones((T,1)), s.reshape(-1,1),
                       z_samples[0,:].reshape(-1,1), torch.ones((T+Tte,1)), torch.cat((s,ste),dim=0).reshape(-1,1),
                       alphay, weighty, weights_mixture, T=T+Tte, n=1000)
  Y_samples2 = sampleY(z_samples[0,:T].reshape(-1,1), torch.zeros((T,1)), s.reshape(-1,1),
                       z_samples[0,:].reshape(-1,1), torch.zeros((T+Tte,1)), torch.cat((s,ste),dim=0).reshape(-1,1),
                       alphay, weighty, weights_mixture, T=T+Tte, n=1000)
  Ypred1 = np.mean(Y_samples1,axis=0)
  Ypred0 = np.mean(Y_samples2,axis=0)
  
  # Evaluate
  evaluator_train = Evaluator(y=Y, t=W, y_cf=Y_cf, mu0=mu[0,:], mu1=mu[1,:])
  stat = evaluator_train.calc_stats(Ypred1[:T],Ypred0[:T])
  train_stats.append(stat)
  print('Train:', stat)

  evaluator_test = Evaluator(y=Yte, t=Wte, y_cf=Y_cfte, mu0=mute[0,:], mu1=mute[1,:])
  stat = evaluator_test.calc_stats(Ypred1[T:],Ypred0[T:])
  test_stats.append(stat)
  print('Test:', stat)
  
  sleep(0.5)

100%|██████████| 20000/20000 [02:56<00:00, 113.50it/s, Iter 19950:, Loss: 5992.83251953125]


Train: (1.0587999056848756, 0.1823316670227051, 0.18745555954551035)
Test: (1.3066916514791855, 0.19189368255615058, 0.19652102990343343)


100%|██████████| 20000/20000 [02:56<00:00, 113.38it/s, Iter 19950:, Loss: 9922.2470703125]


Train: (0.8696838029789734, 0.10425863115883693, 0.11625597568008196)
Test: (0.5734634348047454, 0.10222683914184483, 0.1116357236360923)


100%|██████████| 20000/20000 [02:56<00:00, 113.48it/s, Iter 19950:, Loss: 11191.6923828125]


Train: (1.151884114814651, 0.21192465789794745, 0.21739472877064908)
Test: (1.0501331278048278, 0.21083746917724433, 0.21484059069597353)


100%|██████████| 20000/20000 [02:56<00:00, 113.05it/s, Iter 19950:, Loss: 7467.16015625]


Train: (0.9564222749088469, 0.5136689089965838, 0.5154258090066108)
Test: (0.6179454356054885, 0.5113915347290074, 0.513529969938156)


100%|██████████| 20000/20000 [02:57<00:00, 112.46it/s, Iter 19950:, Loss: 9068.7138671875]


Train: (1.1542243741347713, 0.1009203337860125, 0.11017525811126898)
Test: (1.2240790174576168, 0.09803046218872158, 0.10563695751161684)


100%|██████████| 20000/20000 [02:57<00:00, 112.88it/s, Iter 19950:, Loss: 8206.6015625]


Train: (1.0129690418334443, 0.09983407028198243, 0.10912020519950548)
Test: (1.2614269352030745, 0.10676871307372693, 0.1159710758140374)


100%|██████████| 20000/20000 [02:57<00:00, 112.89it/s, Iter 19950:, Loss: 6631.01806640625]


Train: (1.2990844947457612, 0.7398310757446289, 0.7411637128290431)
Test: (0.9816370656505978, 0.7416187382507307, 0.7429945017184215)


100%|██████████| 20000/20000 [02:55<00:00, 113.90it/s, Iter 19950:, Loss: 10605.2919921875]


Train: (1.0347511266105083, 0.18378840454101386, 0.18972639873904582)
Test: (1.4932997706001334, 0.1783748722839329, 0.1817394003019568)


100%|██████████| 20000/20000 [02:56<00:00, 113.13it/s, Iter 19950:, Loss: 8057.81201171875]


Train: (0.9278925671668757, 0.3700829886627215, 0.37268082515134315)
Test: (0.9830194069851704, 0.3626350306701678, 0.36552224075270223)


100%|██████████| 20000/20000 [02:55<00:00, 113.72it/s, Iter 19950:, Loss: 8060.18994140625]


Train: (1.281922270470204, 0.3828484416483686, 0.3938759077325166)
Test: (1.5176142657744616, 0.36767588623046876, 0.3715577918522637)


In [9]:
np.mean(train_stats,axis=0), np.std(test_stats,axis=0,ddof=1)/np.sqrt(10)

(array([1.0747634 , 0.28894892, 0.29532744]),
 array([0.10290512, 0.06651968, 0.06580125]))

In [10]:
np.mean(test_stats,axis=0), np.std(test_stats,axis=0,ddof=1)/np.sqrt(10)

(array([1.10093101, 0.28714532, 0.29199493]),
 array([0.10290512, 0.06651968, 0.06580125]))