# Train and evaluate the model on NN4H

In [0]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.stats import norm
import scipy
from scipy.special import softmax
import torch
import pandas as pd
from tqdm import trange
import random
from time import sleep
import gpytorch

In [None]:
PATH = ''
from evaluation import Evaluator
from datasets import NN4H

In [6]:
torch.cuda.get_device_properties(0).total_memory/1024/1024/1024

15.8992919921875

In [0]:
loss_bce = torch.nn.BCEWithLogitsLoss(reduction='sum')

def lossPY(alphay, weighty, y, z, w, s, cov):
  Kz = maternz(z, z).evaluate()
  
  if cov == None:
    Kw = maternw(w, w).evaluate()
    Ks = materns(s, s).evaluate()
  else:
    _, Kw, _, Ks = cov
  
  lz = torch.sum(Kz*alphay[:,0],dim=1)
  lw = (1-w.reshape(-1))*torch.sum(Kw*alphay[:,1],dim=1) + w.reshape(-1)*torch.sum(Kw*alphay[:,2],dim=1)
  ls = torch.sum(Ks*alphay[:,3],dim=1)
  
  f = weighty + lz + lw + ls

  return torch.sum((y - f.reshape(-1,1))**2) + 1e-3*torch.sum(f**2)

def QZ(alphaq, weightq, y, x, w, s, cov):
  if cov == None:
    Ky = materny(y, y).evaluate()
    Kx = maternx(x, x).evaluate()
    Kw = maternw(w, w).evaluate()
    Ks = materns(s, s).evaluate()
  else:
    Ky, Kw, Kx, Ks = cov

  ly = torch.sum(Ky*alphaq[:,0],dim=1)
  lx = torch.sum(Kx*alphaq[:,1],dim=1)
  lw = (1-w.reshape(-1))*torch.sum(Kw*alphaq[:,2],dim=1) + w.reshape(-1)*torch.sum(Kw*alphaq[:,3],dim=1)
  ls = torch.sum(Ks*alphaq[:,4],dim=1)
  
  z = weightq + ly + lx + lw + ls
  return z.reshape(-1,1)
  
def lossQZ(alphaq, weightq, z, y, x, w, s, cov):
  if cov == None:
    Ky = materny(y, y).evaluate()
    Kx = maternx(x, x).evaluate()
    Kw = maternw(w, w).evaluate()
    Ks = materns(s, s).evaluate()
  else:
    Ky, Kw, Kx, Ks = cov
  
  ly = torch.sum(Ky*alphaq[:,0],dim=1)
  lx = torch.sum(Kx*alphaq[:,1],dim=1)
  lw = (1-w.reshape(-1))*torch.sum(Kw*alphaq[:,2],dim=1) + w.reshape(-1)*torch.sum(Kw*alphaq[:,3],dim=1)
  ls = torch.sum(Ks*alphaq[:,4],dim=1)
  
  f = weightq + ly + lx + lw + ls
  
  return torch.sum((z - f.reshape(-1,1))**2) - 1e-3*torch.sum(f**2)

def lossPX(alphax, weightx, x, z, s, cov):
  Kz = maternz(z, z).evaluate()
  if cov == None:
    Ks = materns(s, s).evaluate()
  else:
    _, _, _, Ks = cov
  
  lz = torch.sum(Kz*alphax[:,0],dim=1)
  ls = torch.sum(Ks*alphax[:,1],dim=1)
  
  f = weightx + lz + ls
  
  return torch.sum((x - f.reshape(-1,1))**2) + 1e-3*torch.sum(f**2)


def lossPW(alphaw, weightw, w, z, s, cov):
  Kz = maternz(z, z).evaluate()
  if cov == None:
    Ks = materns(s, s).evaluate()
  else:
    _, _, _, Ks = cov
    
  lz = torch.sum(Kz*alphaw[:,0],dim=1)
  ls = torch.sum(Ks*alphaw[:,1],dim=1)
  
  f = weightw + lz + ls

  return loss_bce(f.reshape(-1,1),w) + 1e-3*torch.sum(f**2)

def lossPS(weights, s):
  return torch.sum((s - weights)**2)


def lossPZ(alphaz, weightz, z):
  Kz = maternz(z, z).evaluate()*lower_ones
  alphaz_ = torch.cat((alphaz[:,0], torch.tensor([1.0, 1.0])))
  f = weightz + torch.sum(Kz*alphaz_,dim=1)[:-1]
  return torch.sum((z[0] - weightz)**2) + torch.sum((z[1:] - f.reshape(-1,1))**2) + 1e-3*torch.sum(f**2)


def trainModel(x, y, w, s, cov=None, n_iter=20000):
  loss_lst = []
  prog = trange(n_iter, desc='', leave=True)
  for t in prog:
      mean_z = QZ(alphaq, weightq, y, x, w, s, cov)
      z = mean_z + torch.randn(mean_z.shape)

      loss = lossPY(alphay, weighty, y, z, w, s, cov) + lossPX(alphax, weightx, x, z, s, cov) +\
              lossPW(alphaw, weightw, w, z, s, cov) + lossPS(weights, s) +\
                lossPZ(alphaz, weightz, z) - lossQZ(alphaq, weightq, z, y, x, w, s, cov)

      if t%50 == 0:
        prog.set_postfix_str("Iter {}:, Loss: {}".format(t,loss.item()))
        prog.refresh()
        
      loss_lst.append(loss.item())
      optimizer.zero_grad()
      loss.backward(retain_graph=True)
      optimizer.step()
  return np.asarray(loss_lst)

In [0]:
def predZ(alphaq, weightq, y, x, w, s, ynew, xnew, wnew, snew):
  Ky = materny(ynew, y).evaluate()
  Kx = maternx(xnew, x).evaluate()
  Kw = maternw(wnew, w).evaluate()
  Ks = materns(snew, s).evaluate()

  ly = torch.sum(Ky*alphaq[:,0],dim=1)
  lx = torch.sum(Kx*alphaq[:,1],dim=1)
  lw = (1-wnew.reshape(-1))*torch.sum(Kw*alphaq[:,2],dim=1) + wnew.reshape(-1)*torch.sum(Kw*alphaq[:,3],dim=1)
  ls = torch.sum(Ks*alphaq[:,4],dim=1)
  
  z = weightq + ly + lx + lw + ls
  return z.reshape(-1,1)


def evalLogPX(alphax, weightx, x, z, s, xpre, zpre, spre):
  Kz = maternz(zpre, z).evaluate()
  Ks = materns(spre, s).evaluate()
  
  lz = torch.sum(Kz*alphax[:,0],dim=1)
  ls = torch.sum(Ks*alphax[:,1],dim=1)
  
  f = weightx + lz + ls
  
  return (-0.5*np.log(2*np.pi) - 0.5*(Xpre - f.reshape(-1,1))**2).detach().numpy().flatten()

def sampleY(z, w, s, znew, wnew, snew, alphay, weighty, weights_mixture, T, n):
  Y_samples = []
  
  Kz = maternz(znew, z).evaluate()
  Kw = maternw(wnew, w).evaluate()
  Ks = materns(snew, s).evaluate()

  lz = torch.sum(Kz*alphay[:,0],dim=1)
  lw = (1-wnew.reshape(-1))*torch.sum(Kw*alphay[:,1],dim=1) + wnew.reshape(-1)*torch.sum(Kw*alphay[:,2],dim=1)
  ls = torch.sum(Ks*alphay[:,3],dim=1)
  f = weighty + lz + lw + ls
  for i in range(n):
    noise = torch.from_numpy(np.random.normal(0,1,T)).float()

    Y = f + noise
    Y_samples.append(Y.detach().numpy())
  return np.asarray(Y_samples)

In [10]:
learn_hyper = False

dataset = NN4H(replications=10)

train_stats = []
test_stats = []
loss_lst = []
for i, (train, valid, test, contfeats, binfeats) in enumerate(dataset.get_train_valid_test()):
  W, Y, Y_cf, mu, X, S = train[0][1].reshape(-1), train[0][2].reshape(-1), train[1][0].reshape(-1),\
                          np.concatenate((train[1][1],train[1][2]),axis=1).T, train[0][0][:,0], train[0][0][:,1]
  T = len(Y)
  x = torch.from_numpy(X.reshape(T,-1)).float() 
  y = torch.from_numpy(Y.reshape(-1,1)).float()
  w = torch.from_numpy(W.reshape(-1,1)).float()
  s = torch.from_numpy(S.reshape(T,-1)).float()
  
  Wte, Yte, Y_cfte, mute, Xte, Ste = test[0][1].reshape(-1), test[0][2].reshape(-1), test[1][0].reshape(-1), \
                                      np.concatenate((test[1][1],test[1][2]), axis=1).T, test[0][0][:,0], test[0][0][:,1]
  Tte = len(Yte)
  xte = torch.from_numpy(Xte.reshape(Tte,-1)).float() 
  yte = torch.from_numpy(Yte.reshape(-1,1)).float()
  wte = torch.from_numpy(Wte.reshape(-1,1)).float()
  ste = torch.from_numpy(Ste.reshape(Tte,-1)).float()

  lower_ones = torch.ones(len(y), len(y))
  lower_ones[np.triu_indices(len(y))] = 0
  
  # Compute kernel matrices
  maternx = gpytorch.kernels.RBFKernel()
  maternx.initialize(lengthscale=10.0)
  materns = gpytorch.kernels.RBFKernel()
  materns.initialize(lengthscale=10.0)
  materny = gpytorch.kernels.RBFKernel()
  materny.initialize(lengthscale=10.0)
  maternz = gpytorch.kernels.RBFKernel()
  maternz.initialize(lengthscale=10.0)
  maternw = gpytorch.kernels.RBFKernel()
  maternw.initialize(lengthscale=10.0)
  
  Ky = materny(y, y).evaluate()
  Kx = maternx(x, x).evaluate()
  Kw = maternw(w, w).evaluate()
  Ks = materns(s, s).evaluate()
  if learn_hyper == False:
    cov = (Ky, Kw, Kx, Ks)
  else:
    cov = None
  
  # Declare parameters
  alphay = torch.rand(len(y), 4, requires_grad=True)
  weighty = torch.rand(1, requires_grad=True)

  alphaq = torch.rand(len(y), 5, requires_grad=True)
  weightq = torch.rand(1, requires_grad=True)

  alphax = torch.rand(len(x), 2, requires_grad=True)
  weightx = torch.rand(1, requires_grad=True)

  alphaw = torch.rand(len(w), 2, requires_grad=True)
  weightw = torch.rand(1, requires_grad=True)

  weights = torch.rand(1, requires_grad=True)

  alphaz = torch.rand(len(y)-2, 1, requires_grad=True)
  weightz = torch.rand(1, requires_grad=True)  
  
  # Train
  params = [alphay, alphaq, alphax, alphaw, alphaz, weighty, weightq, weightx, weightw, weightz, weights]
  if learn_hyper == True:
    params = params + list(maternx.parameters()) + list(materns.parameters()) + list(maternw.parameters()) \
                    + list(materny.parameters()) + list(maternz.parameters())
  learning_rate = 1e-3
  optimizer = torch.optim.Adam(params, lr=learning_rate)
  loss_train = trainModel(x, y, w, s, cov=cov, n_iter=20000)
  loss_lst.append(loss_train)

  # Sample Ypred
  z_samples = predZ(alphaq, weightq, y, x, w, s,
                   torch.cat((y,yte),dim=0), torch.cat((x,xte),dim=0), torch.cat((w,wte),dim=0), torch.cat((s,ste),dim=0)).transpose(0,1)
  weights_mixture = 1.0

  Y_samples1 = sampleY(z_samples[0,:T].reshape(-1,1), torch.ones((T,1)), s.reshape(-1,1),
                       z_samples[0,:].reshape(-1,1), torch.ones((T+Tte,1)), torch.cat((s,ste),dim=0).reshape(-1,1),
                       alphay, weighty, weights_mixture, T=T+Tte, n=1000)
  Y_samples2 = sampleY(z_samples[0,:T].reshape(-1,1), torch.zeros((T,1)), s.reshape(-1,1),
                       z_samples[0,:].reshape(-1,1), torch.zeros((T+Tte,1)), torch.cat((s,ste),dim=0).reshape(-1,1),
                       alphay, weighty, weights_mixture, T=T+Tte, n=1000)
  Ypred1 = np.mean(Y_samples1,axis=0)
  Ypred0 = np.mean(Y_samples2,axis=0)
  
  # Evaluate
  evaluator_train = Evaluator(y=Y, t=W, y_cf=Y_cf, mu0=mu[0,:], mu1=mu[1,:])
  stat = evaluator_train.calc_stats(Ypred1[:T],Ypred0[:T])
  train_stats.append(stat)
  print('Train:', stat)

  evaluator_test = Evaluator(y=Yte, t=Wte, y_cf=Y_cfte, mu0=mute[0,:], mu1=mute[1,:])
  stat = evaluator_test.calc_stats(Ypred1[T:],Ypred0[T:])
  test_stats.append(stat)
  print('Test:', stat)
  
  sleep(0.5)

100%|██████████| 20000/20000 [03:08<00:00, 106.23it/s, Iter 19950:, Loss: 5984.4853515625]


Train: (1.0470800399628617, 0.10789595611572267, 0.11634540805478527)
Test: (1.2859209989352334, 0.11746893890380683, 0.12488514337412471)


100%|██████████| 20000/20000 [03:05<00:00, 107.81it/s, Iter 19950:, Loss: 9914.228515625]


Train: (0.8737095459210746, 0.16566667406655267, 0.1734661080293692)
Test: (0.5709836664046991, 0.16363678939819248, 0.16967411789148346)


100%|██████████| 20000/20000 [03:04<00:00, 108.28it/s, Iter 19950:, Loss: 11184.146484375]


Train: (1.1311618214624748, 0.1224247074890128, 0.13166714420649395)
Test: (1.0116367820745888, 0.12133418090820136, 0.1281623263295995)


100%|██████████| 20000/20000 [03:01<00:00, 110.05it/s, Iter 19950:, Loss: 7459.76025390625]


Train: (0.9675927478808345, 0.5405239009094256, 0.5421937934469528)
Test: (0.6176014436114533, 0.5382379435730016, 0.5402696525070274)


100%|██████████| 20000/20000 [03:03<00:00, 109.20it/s, Iter 19950:, Loss: 9066.5498046875]


Train: (1.1673833155985047, 0.04908265106201348, 0.06605027150711114)
Test: (1.249498005446018, 0.04619301788330166, 0.06068615584282313)


100%|██████████| 20000/20000 [03:02<00:00, 109.53it/s, Iter 19950:, Loss: 8199.138671875]


Train: (1.0360925501926355, 0.19912634857177736, 0.20394010657272527)
Test: (1.2695356347375086, 0.2060557461547825, 0.2109712134247456)


100%|██████████| 20000/20000 [02:57<00:00, 112.44it/s, Iter 19950:, Loss: 6623.99560546875]


Train: (1.3388149918095251, 0.7831336117553711, 0.7843923981846673)
Test: (1.0056408587394314, 0.7849169827270508, 0.7862171930885274)


100%|██████████| 20000/20000 [03:01<00:00, 110.29it/s, Iter 19950:, Loss: 10597.43359375]


Train: (1.0265954401564272, 0.15534792907714667, 0.16232943518129192)
Test: (1.473331407927453, 0.1499320126342747, 0.15391723777615296)


100%|██████████| 20000/20000 [03:00<00:00, 110.82it/s, Iter 19950:, Loss: 8051.5068359375]


Train: (0.8891896850456132, 0.27850712768554864, 0.28194917889933524)
Test: (0.9679666872026551, 0.27106226913452325, 0.2749126015940061)


100%|██████████| 20000/20000 [02:57<00:00, 112.55it/s, Iter 19950:, Loss: 8047.927734375]


Train: (1.324777784004254, 0.504612147855644, 0.5130282747911616)
Test: (1.5289416605022386, 0.48943673141479493, 0.49235790359357295)


In [11]:
np.mean(train_stats,axis=0), np.std(train_stats,axis=0)

(array([1.08023979, 0.29063211, 0.29753621]),
 array([0.15317229, 0.22655285, 0.22359091]))

In [12]:
np.mean(test_stats,axis=0), np.std(test_stats,axis=0)

(array([1.09810571, 0.28882746, 0.29420535]),
 array([0.30914417, 0.22516108, 0.22251909]))