# Train and evaluate the model on NN2H

In [0]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.stats import norm
import scipy
from scipy.special import softmax
import torch
import pandas as pd
from tqdm import trange
import random
from time import sleep
import gpytorch

In [None]:
# Base path placeholder; it is not referenced anywhere in the visible cells.
PATH = ''
# Project-local modules: Evaluator is used below through calc_stats, and
# NN2H supplies the train/valid/test replications.
from evaluation import Evaluator
from datasets import NN2H

In [7]:
# Total memory of CUDA device 0, scaled down by 1024 three times
# (bytes -> GiB); requires a CUDA device to be present.
torch.cuda.get_device_properties(0).total_memory/1024/1024/1024

15.8992919921875

In [0]:
# Binary cross-entropy on raw logits, summed (not averaged) over all elements.
# Used by lossPW for the w model.
loss_bce = torch.nn.BCEWithLogitsLoss(reduction='sum')

def lossPY(alphay, weighty, y, z, w, s, cov):
  """Regularized squared-error loss of the model for y given (z, w, s).

  The prediction is ``weighty`` plus three kernel-weighted sums (over z, w
  and s). The w term blends two coefficient columns: column 1 is weighted
  by ``1 - w`` and column 2 by ``w``.

  Args:
    alphay: coefficient tensor; columns 0..3 are read (z, w with weight
      ``1 - w``, w with weight ``w``, s).
    weighty: additive bias of the prediction.
    y: targets; compared against the prediction reshaped to one column.
    z: latent inputs. Their kernel matrix is always re-evaluated with the
      module-level ``maternz`` kernel, even when ``cov`` is given.
    w: input whose flattened values weight the two w columns.
    s: additional input for the s kernel.
    cov: ``None`` to evaluate the w and s kernel matrices with the
      module-level ``maternw`` / ``materns`` kernels, otherwise a 4-tuple
      whose 2nd and 4th entries are used as those matrices.

  Returns:
    Scalar tensor: sum of squared residuals plus ``1e-3 * sum(f**2)``.
  """
  Kz = maternz(z, z).evaluate()

  # Identity test is the idiomatic None check (PEP 8) and does not depend
  # on how the object held in `cov` implements `==`.
  if cov is None:
    Kw = maternw(w, w).evaluate()
    Ks = materns(s, s).evaluate()
  else:
    _, Kw, _, Ks = cov

  w_flat = w.reshape(-1)
  lz = torch.sum(Kz*alphay[:,0],dim=1)
  lw = (1-w_flat)*torch.sum(Kw*alphay[:,1],dim=1) + w_flat*torch.sum(Kw*alphay[:,2],dim=1)
  ls = torch.sum(Ks*alphay[:,3],dim=1)

  f = weighty + lz + lw + ls

  return torch.sum((y - f.reshape(-1,1))**2) + 1e-3*torch.sum(f**2)

def QZ(alphaq, weightq, y, x, w, s, cov):
  """Return the mean of z given (y, x, w, s) as a one-column tensor.

  The mean is ``weightq`` plus kernel-weighted sums over y, x, w and s.
  The w term blends two coefficient columns: column 2 is weighted by
  ``1 - w`` and column 3 by ``w``.

  Args:
    alphaq: coefficient tensor; columns 0..4 are read (y, x, w with weight
      ``1 - w``, w with weight ``w``, s).
    weightq: additive bias of the mean.
    y, x, w, s: inputs. Only ``w`` is read when ``cov`` is given; the
      others are used solely to evaluate kernels when ``cov`` is ``None``.
    cov: ``None`` to evaluate all four kernel matrices with the
      module-level ``materny`` / ``maternx`` / ``maternw`` / ``materns``
      kernels, otherwise a 4-tuple ordered ``(Ky, Kw, Kx, Ks)``.

  Returns:
    Tensor of shape ``(-1, 1)`` holding the mean for every row.
  """
  # Identity test is the idiomatic None check (PEP 8) and does not depend
  # on how the object held in `cov` implements `==`.
  if cov is None:
    Ky = materny(y, y).evaluate()
    Kx = maternx(x, x).evaluate()
    Kw = maternw(w, w).evaluate()
    Ks = materns(s, s).evaluate()
  else:
    # Note the tuple order: Kw comes before Kx.
    Ky, Kw, Kx, Ks = cov

  w_flat = w.reshape(-1)
  ly = torch.sum(Ky*alphaq[:,0],dim=1)
  lx = torch.sum(Kx*alphaq[:,1],dim=1)
  lw = (1-w_flat)*torch.sum(Kw*alphaq[:,2],dim=1) + w_flat*torch.sum(Kw*alphaq[:,3],dim=1)
  ls = torch.sum(Ks*alphaq[:,4],dim=1)

  z = weightq + ly + lx + lw + ls
  return z.reshape(-1,1)
  
def lossQZ(alphaq, weightq, z, y, x, w, s, cov):
  """Squared-error term of z against its mean given (y, x, w, s).

  The mean ``f`` is built exactly as in ``QZ``: ``weightq`` plus
  kernel-weighted sums over y, x, w and s, with the w term blending
  columns 2 and 3 of ``alphaq`` by ``1 - w`` and ``w``.

  Args:
    alphaq: coefficient tensor; columns 0..4 are read.
    weightq: additive bias of the mean.
    z: values compared against the mean reshaped to one column.
    y, x, w, s: inputs. Only ``w`` is read when ``cov`` is given.
    cov: ``None`` to evaluate the four kernel matrices with the
      module-level kernels, otherwise a 4-tuple ordered
      ``(Ky, Kw, Kx, Ks)``.

  Returns:
    Scalar tensor: sum of squared residuals MINUS ``1e-3 * sum(f**2)``.
    The penalty sign is negative here because ``trainModel`` subtracts
    this whole term from the total loss, which turns it into a positive
    penalty there.
  """
  # Identity test is the idiomatic None check (PEP 8) and does not depend
  # on how the object held in `cov` implements `==`.
  if cov is None:
    Ky = materny(y, y).evaluate()
    Kx = maternx(x, x).evaluate()
    Kw = maternw(w, w).evaluate()
    Ks = materns(s, s).evaluate()
  else:
    # Note the tuple order: Kw comes before Kx.
    Ky, Kw, Kx, Ks = cov

  w_flat = w.reshape(-1)
  ly = torch.sum(Ky*alphaq[:,0],dim=1)
  lx = torch.sum(Kx*alphaq[:,1],dim=1)
  lw = (1-w_flat)*torch.sum(Kw*alphaq[:,2],dim=1) + w_flat*torch.sum(Kw*alphaq[:,3],dim=1)
  ls = torch.sum(Ks*alphaq[:,4],dim=1)

  f = weightq + ly + lx + lw + ls

  return torch.sum((z - f.reshape(-1,1))**2) - 1e-3*torch.sum(f**2)

def lossPX(alphax, weightx, x, z, s, cov):
  """Regularized squared-error loss of the model for x given (z, s).

  Args:
    alphax: coefficient tensor; column 0 weights the z kernel matrix and
      column 1 the s kernel matrix.
    weightx: additive bias of the prediction.
    x: targets; compared against the prediction reshaped to one column.
    z: latent inputs. Their kernel matrix is always re-evaluated with the
      module-level ``maternz`` kernel, even when ``cov`` is given.
    s: additional input for the s kernel.
    cov: ``None`` to evaluate the s kernel matrix with the module-level
      ``materns`` kernel, otherwise a 4-tuple whose last entry is used.

  Returns:
    Scalar tensor: sum of squared residuals plus ``1e-3 * sum(f**2)``.
  """
  Kz = maternz(z, z).evaluate()
  # Identity test is the idiomatic None check (PEP 8) and does not depend
  # on how the object held in `cov` implements `==`.
  if cov is None:
    Ks = materns(s, s).evaluate()
  else:
    _, _, _, Ks = cov

  lz = torch.sum(Kz*alphax[:,0],dim=1)
  ls = torch.sum(Ks*alphax[:,1],dim=1)

  f = weightx + lz + ls

  return torch.sum((x - f.reshape(-1,1))**2) + 1e-3*torch.sum(f**2)


def lossPW(alphaw, weightw, w, z, s, cov):
  """Regularized binary cross-entropy loss of the model for w given (z, s).

  The logit is ``weightw`` plus kernel-weighted sums over z and s; it is
  scored against ``w`` with the module-level ``loss_bce`` criterion.

  Args:
    alphaw: coefficient tensor; column 0 weights the z kernel matrix and
      column 1 the s kernel matrix.
    weightw: additive bias of the logit.
    w: targets passed to ``loss_bce`` together with the one-column logits.
    z: latent inputs. Their kernel matrix is always re-evaluated with the
      module-level ``maternz`` kernel, even when ``cov`` is given.
    s: additional input for the s kernel.
    cov: ``None`` to evaluate the s kernel matrix with the module-level
      ``materns`` kernel, otherwise a 4-tuple whose last entry is used.

  Returns:
    Scalar tensor: ``loss_bce`` value plus ``1e-3 * sum(f**2)``.
  """
  Kz = maternz(z, z).evaluate()
  # Identity test is the idiomatic None check (PEP 8) and does not depend
  # on how the object held in `cov` implements `==`.
  if cov is None:
    Ks = materns(s, s).evaluate()
  else:
    _, _, _, Ks = cov

  lz = torch.sum(Kz*alphaw[:,0],dim=1)
  ls = torch.sum(Ks*alphaw[:,1],dim=1)

  f = weightw + lz + ls

  return loss_bce(f.reshape(-1,1),w) + 1e-3*torch.sum(f**2)

def lossPS(weights, s):
  """Sum of squared deviations of ``s`` from the constant ``weights``."""
  residual = s - weights
  return residual.pow(2).sum()


def lossPZ(alphaz, weightz, z):
  """Regularized squared-error loss of the model for z.

  The z kernel matrix (module-level ``maternz``) is multiplied elementwise
  by the module-level ``lower_ones`` mask. Column 0 of ``alphaz`` is
  extended with two constant 1.0 entries before weighting the masked
  matrix, and the last row of the weighted sums is dropped so the
  predictions line up with ``z[1:]``. The first element ``z[0]`` is
  compared against the bias ``weightz`` alone.

  Returns:
    Scalar tensor: squared error of ``z[0]`` against ``weightz``, plus the
    summed squared error of ``z[1:]`` against the predictions, plus
    ``1e-3 * sum(f**2)``.
  """
  masked_gram = maternz(z, z).evaluate() * lower_ones
  # Pad the coefficients with two constant entries to match the matrix width.
  padded_alpha = torch.cat((alphaz[:, 0], torch.tensor([1.0, 1.0])))
  row_sums = (masked_gram * padded_alpha).sum(dim=1)
  f = weightz + row_sums[:-1]

  first_term = ((z[0] - weightz) ** 2).sum()
  rest_term = ((z[1:] - f.reshape(-1, 1)) ** 2).sum()
  ridge = 1e-3 * (f ** 2).sum()
  return first_term + rest_term + ridge


def trainModel(x, y, w, s, cov=None, n_iter=20000):
  """Run ``n_iter`` optimizer steps and return the per-step loss values.

  Parameters (``alphay``, ``alphaq``, ...) and ``optimizer`` are read from
  module scope. At every step a z sample is drawn as the ``QZ`` mean plus
  standard normal noise, the five model losses are summed, ``lossQZ`` is
  subtracted, and one optimizer step is taken.

  Args:
    x, y, w, s: training tensors forwarded to the loss functions.
    cov: precomputed kernel matrices forwarded to the loss functions, or
      ``None``.
    n_iter: number of optimization steps.

  Returns:
    NumPy array with the scalar loss recorded at every step.
  """
  history = []
  progress = trange(n_iter, desc='', leave=True)
  for step in progress:
    # Draw z around its mean.
    q_mean = QZ(alphaq, weightq, y, x, w, s, cov)
    z = q_mean + torch.randn(q_mean.shape)

    model_terms = (
        lossPY(alphay, weighty, y, z, w, s, cov)
        + lossPX(alphax, weightx, x, z, s, cov)
        + lossPW(alphaw, weightw, w, z, s, cov)
        + lossPS(weights, s)
        + lossPZ(alphaz, weightz, z)
    )
    loss = model_terms - lossQZ(alphaq, weightq, z, y, x, w, s, cov)
    current = loss.item()

    # Refresh the progress-bar postfix every 50 steps.
    if step % 50 == 0:
      progress.set_postfix_str("Iter {}:, Loss: {}".format(step, current))
      progress.refresh()

    history.append(current)
    optimizer.zero_grad()
    # The graph is retained across steps, as in the original training loop.
    loss.backward(retain_graph=True)
    optimizer.step()
  return np.asarray(history)

In [0]:
def predZ(alphaq, weightq, y, x, w, s, ynew, xnew, wnew, snew):
  """Return the mean of z at new points as a one-column tensor.

  Kernel matrices are evaluated between the new inputs and the training
  inputs with the module-level kernels, then combined with ``alphaq`` the
  same way as in ``QZ``: columns 0, 1 and 4 weight the y, x and s
  matrices, and columns 2 / 3 weight the w matrix scaled by ``1 - wnew``
  / ``wnew``.

  Args:
    alphaq: coefficient tensor; columns 0..4 are read.
    weightq: additive bias of the mean.
    y, x, w, s: training inputs (second kernel argument).
    ynew, xnew, wnew, snew: new inputs (first kernel argument).

  Returns:
    Tensor of shape ``(-1, 1)`` with the mean at every new point.
  """
  gram_y = materny(ynew, y).evaluate()
  gram_x = maternx(xnew, x).evaluate()
  gram_w = maternw(wnew, w).evaluate()
  gram_s = materns(snew, s).evaluate()

  w_flat = wnew.reshape(-1)
  term_y = (gram_y * alphaq[:, 0]).sum(dim=1)
  term_x = (gram_x * alphaq[:, 1]).sum(dim=1)
  term_w = (1 - w_flat) * (gram_w * alphaq[:, 2]).sum(dim=1) \
      + w_flat * (gram_w * alphaq[:, 3]).sum(dim=1)
  term_s = (gram_s * alphaq[:, 4]).sum(dim=1)

  mean_z = weightq + term_y + term_x + term_w + term_s
  return mean_z.reshape(-1, 1)


def evalLogPX(alphax, weightx, x, z, s, xpre, zpre, spre):
  """Evaluate the unit-variance Gaussian log-density of ``xpre``.

  The mean is ``weightx`` plus kernel-weighted sums over z and s, with
  kernel matrices evaluated between the evaluation inputs
  (``zpre``, ``spre``) and the training inputs (``z``, ``s``) using the
  module-level ``maternz`` / ``materns`` kernels.

  Args:
    alphax: coefficient tensor; column 0 weights the z kernel matrix and
      column 1 the s kernel matrix.
    weightx: additive bias of the mean.
    x: unused; kept for interface compatibility.
    z, s: training inputs (second kernel argument).
    xpre: values whose log-density is evaluated.
    zpre, spre: evaluation inputs (first kernel argument).

  Returns:
    Flattened NumPy array of
    ``-0.5*log(2*pi) - 0.5*(xpre - mean)**2`` values.
  """
  Kz = maternz(zpre, z).evaluate()
  Ks = materns(spre, s).evaluate()

  lz = torch.sum(Kz*alphax[:,0],dim=1)
  ls = torch.sum(Ks*alphax[:,1],dim=1)

  f = weightx + lz + ls

  # Use the `xpre` parameter: the previous spelling `Xpre` was an undefined
  # name and raised NameError on every call.
  return (-0.5*np.log(2*np.pi) - 0.5*(xpre - f.reshape(-1,1))**2).detach().numpy().flatten()

def sampleY(z, w, s, znew, wnew, snew, alphay, weighty, weights_mixture, T, n):
  """Draw ``n`` noisy samples of y at new points.

  The mean is ``weighty`` plus kernel-weighted sums over z, w and s
  (kernel matrices between the new and the training inputs, evaluated
  with the module-level kernels). The w term blends columns 1 and 2 of
  ``alphay`` by ``1 - wnew`` and ``wnew``. Each sample adds a fresh
  standard normal vector of length ``T`` drawn with NumPy.

  Args:
    z, w, s: training inputs (second kernel argument).
    znew, wnew, snew: new inputs (first kernel argument).
    alphay: coefficient tensor; columns 0..3 are read.
    weighty: additive bias of the mean.
    weights_mixture: unused; kept for interface compatibility.
    T: length of each noise vector.
    n: number of samples to draw.

  Returns:
    NumPy array stacking the ``n`` sampled vectors along axis 0.
  """
  gram_z = maternz(znew, z).evaluate()
  gram_w = maternw(wnew, w).evaluate()
  gram_s = materns(snew, s).evaluate()

  w_flat = wnew.reshape(-1)
  term_z = (gram_z * alphay[:, 0]).sum(dim=1)
  term_w = (1 - w_flat) * (gram_w * alphay[:, 1]).sum(dim=1) \
      + w_flat * (gram_w * alphay[:, 2]).sum(dim=1)
  term_s = (gram_s * alphay[:, 3]).sum(dim=1)
  mean_y = weighty + term_z + term_w + term_s

  draws = []
  for _ in range(n):
    eps = torch.from_numpy(np.random.normal(0, 1, T)).float()
    draws.append((mean_y + eps).detach().numpy())
  return np.asarray(draws)

In [10]:
# Experiment driver: for every replication of the dataset, build tensors,
# initialise kernels and parameters, train, sample outcomes under w=1 and
# w=0, and record the evaluation statistics for the train and test splits.

# When False, the kernel matrices for y, w, x and s are computed once per
# replication and passed to the loss functions through `cov`. When True,
# `cov` is None, the loss functions re-evaluate those kernels, and the kernel
# parameters are handed to the optimizer as well.
learn_hyper = False

dataset = NN2H(replications=10)

# One entry per replication.
train_stats = []
test_stats = []
loss_lst = []
for i, (train, valid, test, contfeats, binfeats) in enumerate(dataset.get_train_valid_test()):
  # Unpack the training split. `mu` stacks train[1][1] and train[1][2] and is
  # transposed so that mu[0,:] / mu[1,:] can be passed as mu0 / mu1 below.
  # X and S are columns 0 and 1 of train[0][0].
  # NOTE(review): the exact meaning of each split entry is defined by the
  # NN2H loader, which is not visible here - confirm against `datasets`.
  W, Y, Y_cf, mu, X, S = train[0][1].reshape(-1), train[0][2].reshape(-1), train[1][0].reshape(-1),\
                          np.concatenate((train[1][1],train[1][2]),axis=1).T, train[0][0][:,0], train[0][0][:,1]
  T = len(Y)
  # Float tensors with one row per training point.
  x = torch.from_numpy(X.reshape(T,-1)).float() 
  y = torch.from_numpy(Y.reshape(-1,1)).float()
  w = torch.from_numpy(W.reshape(-1,1)).float()
  s = torch.from_numpy(S.reshape(T,-1)).float()
  
  # Same unpacking and tensor conversion for the test split.
  Wte, Yte, Y_cfte, mute, Xte, Ste = test[0][1].reshape(-1), test[0][2].reshape(-1), test[1][0].reshape(-1), \
                                      np.concatenate((test[1][1],test[1][2]), axis=1).T, test[0][0][:,0], test[0][0][:,1]
  Tte = len(Yte)
  xte = torch.from_numpy(Xte.reshape(Tte,-1)).float() 
  yte = torch.from_numpy(Yte.reshape(-1,1)).float()
  wte = torch.from_numpy(Wte.reshape(-1,1)).float()
  ste = torch.from_numpy(Ste.reshape(Tte,-1)).float()

  # Strictly lower-triangular mask: start from all ones and zero out the
  # upper triangle including the diagonal. Read by lossPZ.
  lower_ones = torch.ones(len(y), len(y))
  lower_ones[np.triu_indices(len(y))] = 0
  
  # Compute kernel matrices
  # Despite the `matern*` names, every kernel is an RBF kernel initialised
  # with lengthscale 10.0. They are module-level names read by the loss and
  # prediction functions.
  maternx = gpytorch.kernels.RBFKernel()
  maternx.initialize(lengthscale=10.0)
  materns = gpytorch.kernels.RBFKernel()
  materns.initialize(lengthscale=10.0)
  materny = gpytorch.kernels.RBFKernel()
  materny.initialize(lengthscale=10.0)
  maternz = gpytorch.kernels.RBFKernel()
  maternz.initialize(lengthscale=10.0)
  maternw = gpytorch.kernels.RBFKernel()
  maternw.initialize(lengthscale=10.0)
  
  Ky = materny(y, y).evaluate()
  Kx = maternx(x, x).evaluate()
  Kw = maternw(w, w).evaluate()
  Ks = materns(s, s).evaluate()
  # The tuple order (Ky, Kw, Kx, Ks) is the order the loss functions unpack.
  if learn_hyper == False:
    cov = (Ky, Kw, Kx, Ks)
  else:
    cov = None
  
  # Declare parameters
  # All parameters start from uniform random values in [0, 1).
  alphay = torch.rand(len(y), 4, requires_grad=True)
  weighty = torch.rand(1, requires_grad=True)

  alphaq = torch.rand(len(y), 5, requires_grad=True)
  weightq = torch.rand(1, requires_grad=True)

  alphax = torch.rand(len(x), 2, requires_grad=True)
  weightx = torch.rand(1, requires_grad=True)

  alphaw = torch.rand(len(w), 2, requires_grad=True)
  weightw = torch.rand(1, requires_grad=True)

  weights = torch.rand(1, requires_grad=True)

  # Two rows shorter than the data: lossPZ pads this column with two constant
  # 1.0 entries before using it.
  alphaz = torch.rand(len(y)-2, 1, requires_grad=True)
  weightz = torch.rand(1, requires_grad=True)  
  
  # Train
  params = [alphay, alphaq, alphax, alphaw, alphaz, weighty, weightq, weightx, weightw, weightz, weights]
  if learn_hyper == True:
    params = params + list(maternx.parameters()) + list(materns.parameters()) + list(maternw.parameters()) \
                    + list(materny.parameters()) + list(maternz.parameters())
  learning_rate = 1e-3
  # `optimizer` is read from module scope by trainModel.
  optimizer = torch.optim.Adam(params, lr=learning_rate)
  loss_train = trainModel(x, y, w, s, cov=cov, n_iter=20000)
  loss_lst.append(loss_train)

  # Sample Ypred
  # Mean of z for train and test points together; predZ returns one column,
  # so the transpose yields a single row of length T+Tte.
  z_samples = predZ(alphaq, weightq, y, x, w, s,
                   torch.cat((y,yte),dim=0), torch.cat((x,xte),dim=0), torch.cat((w,wte),dim=0), torch.cat((s,ste),dim=0)).transpose(0,1)
  # Passed to sampleY, which does not use it.
  weights_mixture = 1.0

  # Outcome samples with w forced to 1 everywhere (Y_samples1) and to 0
  # everywhere (Y_samples2), for train and test points together.
  Y_samples1 = sampleY(z_samples[0,:T].reshape(-1,1), torch.ones((T,1)), s.reshape(-1,1),
                       z_samples[0,:].reshape(-1,1), torch.ones((T+Tte,1)), torch.cat((s,ste),dim=0).reshape(-1,1),
                       alphay, weighty, weights_mixture, T=T+Tte, n=1000)
  Y_samples2 = sampleY(z_samples[0,:T].reshape(-1,1), torch.zeros((T,1)), s.reshape(-1,1),
                       z_samples[0,:].reshape(-1,1), torch.zeros((T+Tte,1)), torch.cat((s,ste),dim=0).reshape(-1,1),
                       alphay, weighty, weights_mixture, T=T+Tte, n=1000)
  # Average over the 1000 samples to get one prediction per point.
  Ypred1 = np.mean(Y_samples1,axis=0)
  Ypred0 = np.mean(Y_samples2,axis=0)
  
  # Evaluate
  # The first T predictions belong to the training points, the rest to the
  # test points (matching the concatenation order used above).
  # NOTE(review): the meaning of the values returned by calc_stats is defined
  # in the `evaluation` module - confirm there.
  evaluator_train = Evaluator(y=Y, t=W, y_cf=Y_cf, mu0=mu[0,:], mu1=mu[1,:])
  stat = evaluator_train.calc_stats(Ypred1[:T],Ypred0[:T])
  train_stats.append(stat)
  print('Train:', stat)

  evaluator_test = Evaluator(y=Yte, t=Wte, y_cf=Y_cfte, mu0=mute[0,:], mu1=mute[1,:])
  stat = evaluator_test.calc_stats(Ypred1[T:],Ypred0[T:])
  test_stats.append(stat)
  print('Test:', stat)
  
  # Short pause between replications; presumably lets the progress-bar output
  # settle before the next bar starts - TODO confirm.
  sleep(0.5)

100%|██████████| 20000/20000 [03:07<00:00, 106.64it/s, Iter 19950:, Loss: 6113.04248046875]


Train: (1.1451759972396498, 0.16105867767334026, 0.16683698775649977)
Test: (1.5355199955745895, 0.17062498474121135, 0.17581192804876666)


100%|██████████| 20000/20000 [03:07<00:00, 106.92it/s, Iter 19950:, Loss: 10020.2099609375]


Train: (0.8873145241622713, 0.01750916696390803, 0.06284355880374291)
Test: (0.5648541378896034, 0.014356864929200519, 0.04710242566973514)


100%|██████████| 20000/20000 [03:03<00:00, 108.94it/s, Iter 19950:, Loss: 11351.8974609375]


Train: (1.297593339030988, 0.16137529754638713, 0.16849370280899148)
Test: (1.019039323065104, 0.16027952575683724, 0.16550924615297938)


100%|██████████| 20000/20000 [03:03<00:00, 108.92it/s, Iter 19950:, Loss: 7597.33642578125]


Train: (1.0384006278409774, 0.6736161575317379, 0.6749566830434143)
Test: (0.6963455109500061, 0.6713349685668941, 0.6729659403122986)


100%|██████████| 20000/20000 [03:08<00:00, 106.29it/s, Iter 19950:, Loss: 9186.7138671875]


Train: (1.2096537481598537, 0.26821731185913045, 0.27183353055666226)
Test: (1.2436061880115112, 0.26533292388915974, 0.26823624566355947)


100%|██████████| 20000/20000 [03:08<00:00, 105.84it/s, Iter 19950:, Loss: 8333.4814453125]


Train: (1.162995948706717, 0.303640617370605, 0.3068188380581396)
Test: (1.6424835058050746, 0.3105733528137202, 0.31385536911473844)


100%|██████████| 20000/20000 [03:08<00:00, 106.18it/s, Iter 19950:, Loss: 6738.220703125]


Train: (1.3263591871508937, 0.7950871124267582, 0.7963277496977148)
Test: (0.9162280124346264, 0.796870483398437, 0.7981511783215582)


100%|██████████| 20000/20000 [03:09<00:00, 105.58it/s, Iter 19950:, Loss: 10721.796875]


Train: (1.131660206819456, 0.28906179809570265, 0.29287294588159085)
Test: (1.708083423936851, 0.2836411132812504, 0.2857693911700937)


100%|██████████| 20000/20000 [03:09<00:00, 105.72it/s, Iter 19950:, Loss: 8161.3349609375]


Train: (1.1187271436083097, 0.5405309066772475, 0.5423123569292615)
Test: (1.1575896833974044, 0.5330884323120113, 0.5350571407996827)


100%|██████████| 20000/20000 [03:09<00:00, 105.77it/s, Iter 19950:, Loss: 8146.45166015625]


Train: (1.1749444894193555, 0.2188147803808338, 0.24945411405018628)
Test: (1.1731654097436448, 0.2009689941406263, 0.20798771376487357)


In [11]:
# Mean and standard error of the mean of every training statistic across
# replications. The divisor is the number of collected runs rather than a
# hard-coded 10, so the result stays correct if the replication count changes.
np.mean(train_stats,axis=0), np.std(train_stats,axis=0,ddof=1)/np.sqrt(len(train_stats))

(array([1.14928252, 0.34289118, 0.35327505]),
 array([0.03944428, 0.07818891, 0.07542338]))

In [12]:
# Mean and standard error of the mean of every test statistic across
# replications. The divisor is the number of collected runs rather than a
# hard-coded 10, so the result stays correct if the replication count changes.
np.mean(test_stats,axis=0), np.std(test_stats,axis=0,ddof=1)/np.sqrt(len(test_stats))

(array([1.16569152, 0.34070716, 0.34704466]),
 array([0.12154172, 0.07829925, 0.07657352]))