# Train and evaluate the model on NN6HIID

In [0]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.stats import norm
import scipy
from scipy.special import softmax
import torch
import pandas as pd
from tqdm import trange
import random
from time import sleep
import gpytorch

In [None]:
# NOTE(review): PATH is not referenced by any code visible in this notebook.
PATH = ''
# Project module; used below as Evaluator(y=..., t=..., y_cf=..., mu0=..., mu1=...).calc_stats(...)
from evaluation import Evaluator
# Project module providing the NN6HIID dataset replications.
from datasets import NN6HIID

In [0]:
# Informational: total memory of CUDA device 0 in GiB (requires a visible CUDA device).
# NOTE(review): none of the visible code moves tensors to the GPU.
torch.cuda.get_device_properties(0).total_memory/1024/1024/1024

11.17303466796875

In [0]:
# Binary cross-entropy on logits, summed (not averaged) over samples; used by lossPW.
loss_bce = torch.nn.BCEWithLogitsLoss(reduction='sum')

def lossPY(alphay, weighty, y, z, w, s, cov):
  """Penalised squared-error loss of the outcome model p(y | z, w, s).

  The outcome mean is an additive kernel expansion over the training points:

    f = weighty + Kz @ alphay[:, 0]
        + (1 - w) * (Kw @ alphay[:, 1]) + w * (Kw @ alphay[:, 2])
        + Ks @ alphay[:, 3]

  Uses the module-level kernels maternz / maternw / materns built in the
  training cell.

  Args:
    alphay: kernel coefficients, one row per training point, 4 columns.
    weighty: bias term.
    y: (T, 1) observed outcomes.
    z: latent sample; its Gram matrix is always recomputed because z is
      redrawn on every training step.
    w: (T, 1) treatment indicator (presumably 0/1 -- it is passed to
      Evaluator as `t`).
    s: input of the s-kernel (only read when cov is None).
    cov: None to evaluate the kernels here, or the precomputed Gram-matrix
      tuple (Ky, Kw, Kx, Ks).

  Returns:
    Scalar tensor sum((y - f)**2) + 1e-3 * sum(f**2).
  """
  Kz = maternz(z, z).evaluate()

  if cov is None:
    Kw = maternw(w, w).evaluate()
    Ks = materns(s, s).evaluate()
  else:
    _, Kw, _, Ks = cov

  # torch.sum(K * a, dim=1) is K @ a: the coefficient vector broadcasts over K's columns.
  lz = torch.sum(Kz*alphay[:,0],dim=1)
  # Separate coefficient columns for the w == 0 and w == 1 arms.
  lw = (1-w.reshape(-1))*torch.sum(Kw*alphay[:,1],dim=1) + w.reshape(-1)*torch.sum(Kw*alphay[:,2],dim=1)
  ls = torch.sum(Ks*alphay[:,3],dim=1)

  f = weighty + lz + lw + ls

  return torch.sum((y - f.reshape(-1,1))**2) + 1e-3*torch.sum(f**2)

def QZ(alphaq, weightq, y, x, w, s, cov):
  """Mean of the approximate posterior q(z | y, x, w, s) at the training points.

    mean = weightq + Ky @ alphaq[:, 0] + Kx @ alphaq[:, 1]
           + (1 - w) * (Kw @ alphaq[:, 2]) + w * (Kw @ alphaq[:, 3])
           + Ks @ alphaq[:, 4]

  Args:
    alphaq: kernel coefficients, one row per training point, 5 columns.
    weightq: bias term.
    y, x, s: inputs of the respective kernels (only read when cov is None).
    w: (T, 1) treatment indicator; also selects the w-arm coefficient column.
    cov: None to evaluate the module-level kernels (materny, maternx,
      maternw, materns) here, or the precomputed tuple (Ky, Kw, Kx, Ks) --
      note that order.

  Returns:
    (T, 1) tensor of posterior means.
  """
  if cov is None:
    Ky = materny(y, y).evaluate()
    Kx = maternx(x, x).evaluate()
    Kw = maternw(w, w).evaluate()
    Ks = materns(s, s).evaluate()
  else:
    Ky, Kw, Kx, Ks = cov

  # torch.sum(K * a, dim=1) is K @ a: the coefficient vector broadcasts over K's columns.
  ly = torch.sum(Ky*alphaq[:,0],dim=1)
  lx = torch.sum(Kx*alphaq[:,1],dim=1)
  lw = (1-w.reshape(-1))*torch.sum(Kw*alphaq[:,2],dim=1) + w.reshape(-1)*torch.sum(Kw*alphaq[:,3],dim=1)
  ls = torch.sum(Ks*alphaq[:,4],dim=1)

  z = weightq + ly + lx + lw + ls
  return z.reshape(-1,1)
  
def lossQZ(alphaq, weightq, z, y, x, w, s, cov):
  """Squared-error term of q(z | y, x, w, s) for a latent sample z.

  Recomputes the same posterior mean f as QZ and returns
  sum((z - f)**2) - 1e-3 * sum(f**2).

  The regulariser is *subtracted* here because trainModel subtracts lossQZ
  from the total loss, so it enters the objective as +1e-3 * sum(f**2),
  matching the other loss terms. When z is drawn as QZ(...) + noise (as in
  trainModel), z - f is just that noise, so the first term carries no
  gradient to alphaq / weightq.

  Args:
    alphaq, weightq: posterior-mean parameters (see QZ).
    z: (T, 1) latent sample.
    y, x, s: kernel inputs (only read when cov is None).
    w: (T, 1) treatment indicator.
    cov: None, or the precomputed tuple (Ky, Kw, Kx, Ks).

  Returns:
    Scalar tensor.
  """
  if cov is None:
    Ky = materny(y, y).evaluate()
    Kx = maternx(x, x).evaluate()
    Kw = maternw(w, w).evaluate()
    Ks = materns(s, s).evaluate()
  else:
    Ky, Kw, Kx, Ks = cov

  ly = torch.sum(Ky*alphaq[:,0],dim=1)
  lx = torch.sum(Kx*alphaq[:,1],dim=1)
  lw = (1-w.reshape(-1))*torch.sum(Kw*alphaq[:,2],dim=1) + w.reshape(-1)*torch.sum(Kw*alphaq[:,3],dim=1)
  ls = torch.sum(Ks*alphaq[:,4],dim=1)

  f = weightq + ly + lx + lw + ls

  return torch.sum((z - f.reshape(-1,1))**2) - 1e-3*torch.sum(f**2)

def lossPX(alphax, weightx, x, z, s, cov):
  """Penalised squared-error loss of the covariate model p(x | z, s).

    f = weightx + Kz @ alphax[:, 0] + Ks @ alphax[:, 1]

  Kz always comes from the module-level kernel maternz, because z changes on
  every call. Ks comes from materns, or from cov.

  Args:
    alphax: kernel coefficients, one row per training point, 2 columns.
    weightx: bias term.
    x: (T, 1) observed covariate.
    z: latent sample.
    s: input of the s-kernel (only read when cov is None).
    cov: None, or the precomputed tuple (Ky, Kw, Kx, Ks); only Ks is used.

  Returns:
    Scalar tensor sum((x - f)**2) + 1e-3 * sum(f**2).
  """
  Kz = maternz(z, z).evaluate()
  if cov is None:
    Ks = materns(s, s).evaluate()
  else:
    _, _, _, Ks = cov

  lz = torch.sum(Kz*alphax[:,0],dim=1)
  ls = torch.sum(Ks*alphax[:,1],dim=1)

  f = weightx + lz + ls

  return torch.sum((x - f.reshape(-1,1))**2) + 1e-3*torch.sum(f**2)


def lossPW(alphaw, weightw, w, z, s, cov):
  """Penalised logistic loss of the treatment model p(w | z, s).

  The logits are

    f = weightw + Kz @ alphaw[:, 0] + Ks @ alphaw[:, 1]

  They are scored against w with the module-level `loss_bce`, a summed
  BCE-with-logits loss.

  Args:
    alphaw: kernel coefficients, one row per training point, 2 columns.
    weightw: bias term.
    w: (T, 1) treatment targets.
    z: latent sample (its Gram matrix is recomputed every call).
    s: input of the s-kernel (only read when cov is None).
    cov: None, or the precomputed tuple (Ky, Kw, Kx, Ks); only Ks is used.

  Returns:
    Scalar tensor loss_bce(f, w) + 1e-3 * sum(f**2).
  """
  Kz = maternz(z, z).evaluate()
  if cov is None:
    Ks = materns(s, s).evaluate()
  else:
    _, _, _, Ks = cov

  lz = torch.sum(Kz*alphaw[:,0],dim=1)
  ls = torch.sum(Ks*alphaw[:,1],dim=1)

  f = weightw + lz + ls

  return loss_bce(f.reshape(-1,1),w) + 1e-3*torch.sum(f**2)

def lossPS(weights, s):
  """Sum of squared deviations of s from the single learned constant `weights`."""
  residual = s - weights
  return residual.pow(2).sum()


def lossPZ(alphaz, weightz, z):
  """Penalised squared-error loss of the sequential prior on z.

  z[0] is regressed on the constant weightz. For t >= 0, z[t + 1] is
  regressed on

    f[t] = weightz + sum_{j < t} k(z_t, z_j) * alphaz_[j]

  The restriction j < t comes from the module-level `lower_ones` mask.
  NOTE(review): this assumes lower_ones is the (T, T) strictly
  lower-triangular mask built in the training cell.

  Args:
    alphaz: (T - 2, 1) coefficients.
    weightz: bias term.
    z: (T, 1) latent sample.

  Returns:
    Scalar tensor: (z[0] - weightz)**2 + sum((z[1:] - f)**2) + 1e-3 * sum(f**2).
  """
  Kz = maternz(z, z).evaluate()*lower_ones
  # Pad alphaz to length T. The two padding entries sit in columns that the
  # strictly lower-triangular mask zeroes in every kept row (the last row is
  # dropped below), so their value never matters. new_ones keeps alphaz's
  # dtype and device.
  alphaz_ = torch.cat((alphaz[:,0], alphaz.new_ones(2)))
  f = weightz + torch.sum(Kz*alphaz_,dim=1)[:-1]
  return torch.sum((z[0] - weightz)**2) + torch.sum((z[1:] - f.reshape(-1,1))**2) + 1e-3*torch.sum(f**2)


def trainModel(x, y, w, s, cov=None, n_iter=20000):
  """Run n_iter single-sample gradient steps on the combined model loss.

  Each step draws z = QZ(...) + N(0, I) noise. It then sums the model terms
  lossPY + lossPX + lossPW + lossPS + lossPZ, subtracts lossQZ, and takes one
  step of the module-level `optimizer`. The parameters (alphay, alphaq, ...)
  and the optimizer are module-level globals created in the training cell.

  Args:
    x, y, w, s: training tensors.
    cov: None, or the precomputed Gram-matrix tuple (Ky, Kw, Kx, Ks).
    n_iter: number of optimisation steps.

  Returns:
    numpy array holding the loss value of every step.
  """
  history = []
  progress = trange(n_iter, desc='', leave=True)
  for step in progress:
    z_mean = QZ(alphaq, weightq, y, x, w, s, cov)
    z = z_mean + torch.randn(z_mean.shape)

    model_terms = (lossPY(alphay, weighty, y, z, w, s, cov)
                   + lossPX(alphax, weightx, x, z, s, cov)
                   + lossPW(alphaw, weightw, w, z, s, cov)
                   + lossPS(weights, s)
                   + lossPZ(alphaz, weightz, z))
    loss = model_terms - lossQZ(alphaq, weightq, z, y, x, w, s, cov)
    current = loss.item()

    if step % 50 == 0:
      progress.set_postfix_str("Iter {}:, Loss: {}".format(step, current))
      progress.refresh()

    history.append(current)
    optimizer.zero_grad()
    # retain_graph: a precomputed cov may still carry autograd history (it is
    # reused on every step), and a plain backward() would free that shared
    # graph so the next step's backward() would fail.
    loss.backward(retain_graph=True)
    optimizer.step()
  return np.asarray(history)

In [0]:
def predZ(alphaq, weightq, y, x, w, s, ynew, xnew, wnew, snew):
  """Evaluate the q(z) mean at new points, using the training points as kernel centres.

  Same expansion as QZ, but with cross-kernels k(new, train) built from the
  module-level kernels. Row i of the result corresponds to row i of the
  *new* inputs.

  Returns:
    (len(ynew), 1) tensor.
  """
  cross_y = materny(ynew, y).evaluate()
  cross_x = maternx(xnew, x).evaluate()
  cross_w = maternw(wnew, w).evaluate()
  cross_s = materns(snew, s).evaluate()

  treat = wnew.reshape(-1)
  term_y = (cross_y * alphaq[:, 0]).sum(dim=1)
  term_x = (cross_x * alphaq[:, 1]).sum(dim=1)
  # Coefficient column 2 applies to the w == 0 arm and column 3 to the w == 1 arm.
  term_w = (1 - treat) * (cross_w * alphaq[:, 2]).sum(dim=1) + treat * (cross_w * alphaq[:, 3]).sum(dim=1)
  term_s = (cross_s * alphaq[:, 4]).sum(dim=1)

  mean = weightq + term_y + term_x + term_w + term_s
  return mean.unsqueeze(1)


def evalLogPX(alphax, weightx, x, z, s, xpre, zpre, spre):
  """Per-point unit-variance Gaussian log-density of xpre under the fitted p(x | z, s).

  The mean at the evaluation points is

    f = weightx + Kz(zpre, z) @ alphax[:, 0] + Ks(spre, s) @ alphax[:, 1]

  It uses the training points (z, s) as kernel centres.

  Args:
    alphax, weightx: fitted covariate-model parameters (see lossPX).
    x: training covariate (unused; kept for call compatibility).
    z, s: training inputs used as kernel centres.
    xpre: (N, 1) covariate values to score.
    zpre, spre: inputs at the N evaluation points.

  Returns:
    1-D numpy array of N log-densities log N(xpre_i; f_i, 1).
  """
  Kz = maternz(zpre, z).evaluate()
  Ks = materns(spre, s).evaluate()

  lz = torch.sum(Kz*alphax[:,0],dim=1)
  ls = torch.sum(Ks*alphax[:,1],dim=1)

  f = weightx + lz + ls

  log_norm = -0.5 * float(np.log(2 * np.pi))
  # Score the xpre argument (previously an undefined `Xpre` name was used).
  log_density = log_norm - 0.5 * (xpre - f.reshape(-1, 1)) ** 2
  return log_density.detach().numpy().flatten()

def sampleY(z, w, s, znew, wnew, snew, alphay, weighty, weights_mixture, T, n):
  """Draw n noisy outcome samples at new points from the fitted p(y | z, w, s).

  The mean at the new points is

    f = weighty + Kz @ alphay[:, 0]
        + (1 - wnew) * (Kw @ alphay[:, 1]) + wnew * (Kw @ alphay[:, 2])
        + Ks @ alphay[:, 3]

  The kernels are cross-kernels k(new, train) built from the module-level
  kernels. Each sample is f plus unit Gaussian noise drawn from NumPy's
  global RNG.

  Args:
    z, w, s: training inputs used as kernel centres.
    znew, wnew, snew: inputs at the new points.
    alphay, weighty: fitted outcome-model parameters (see lossPY).
    weights_mixture: unused; kept for call compatibility.
    T: length of each noise draw; must match len(znew).
    n: number of samples.

  Returns:
    float32 numpy array of shape (n, T); row i is f + eps_i.
  """
  # Sampling only: no autograd graph is needed.
  with torch.no_grad():
    Kz = maternz(znew, z).evaluate()
    Kw = maternw(wnew, w).evaluate()
    Ks = materns(snew, s).evaluate()

    lz = torch.sum(Kz*alphay[:,0],dim=1)
    lw = (1-wnew.reshape(-1))*torch.sum(Kw*alphay[:,1],dim=1) + wnew.reshape(-1)*torch.sum(Kw*alphay[:,2],dim=1)
    ls = torch.sum(Ks*alphay[:,3],dim=1)
    f = weighty + lz + lw + ls

    # A single (n, T) draw consumes the legacy global RNG in the same order as
    # n consecutive draws of size T.
    noise = torch.from_numpy(np.random.normal(0, 1, (n, T))).float()
    Y_samples = f + noise
  return Y_samples.numpy()

In [9]:
learn_hyper = False  # True: also optimise the RQ kernel hyper-parameters


def _make_rq_kernel(lengthscale=10.0, alpha=2.0):
  """Return a fresh gpytorch RQKernel initialised to `lengthscale` and `alpha`."""
  kernel = gpytorch.kernels.RQKernel()
  kernel.initialize(lengthscale=lengthscale)
  kernel.initialize(alpha=alpha)
  return kernel


def _unpack_split(split):
  """Extract the (W, Y, Y_cf, mu, X, S) numpy arrays from one NN6HIID split.

  The indexing mirrors how the arrays are consumed below:
  - W is passed to Evaluator as the treatment `t`;
  - Y as `y`;
  - Y_cf as `y_cf`;
  - mu[0] / mu[1] as `mu0` / `mu1`.
  X and S are columns 0 and 1 of split[0][0].
  NOTE(review): mu is (2, n) only if split[1][1] and split[1][2] are (n, 1)
  columns -- confirm.
  """
  W = split[0][1].reshape(-1)
  Y = split[0][2].reshape(-1)
  Y_cf = split[1][0].reshape(-1)
  mu = np.concatenate((split[1][1], split[1][2]), axis=1).T
  X = split[0][0][:, 0]
  S = split[0][0][:, 1]
  return W, Y, Y_cf, mu, X, S


def _to_tensors(X, Y, W, S):
  """Convert one split's arrays to float32 tensors (x, y, w, s), each with one row per sample."""
  n = len(Y)
  return (torch.from_numpy(X.reshape(n, -1)).float(),
          torch.from_numpy(Y.reshape(-1, 1)).float(),
          torch.from_numpy(W.reshape(-1, 1)).float(),
          torch.from_numpy(S.reshape(n, -1)).float())


dataset = NN6HIID(replications=10)

train_stats = []
test_stats = []
loss_lst = []
for i, (train, valid, test, contfeats, binfeats) in enumerate(dataset.get_train_valid_test()):
  W, Y, Y_cf, mu, X, S = _unpack_split(train)
  T = len(Y)
  x, y, w, s = _to_tensors(X, Y, W, S)

  Wte, Yte, Y_cfte, mute, Xte, Ste = _unpack_split(test)
  Tte = len(Yte)
  xte, yte, wte, ste = _to_tensors(Xte, Yte, Wte, Ste)

  # (T, T) strictly lower-triangular mask read by lossPZ: row t only sees columns j < t.
  lower_ones = torch.tril(torch.ones(T, T), diagonal=-1)

  # One RQ kernel per variable; the loss/prediction functions read these module-level names.
  maternx = _make_rq_kernel()
  materns = _make_rq_kernel()
  materny = _make_rq_kernel()
  maternz = _make_rq_kernel()
  maternw = _make_rq_kernel()

  if learn_hyper:
    # The loss functions re-evaluate the kernels on every call, so the
    # hyper-parameters receive gradients.
    cov = None
  else:
    # Fixed hyper-parameters: evaluate the Gram matrices once, without autograd
    # history, so backward() never walks into (or accumulates gradients in)
    # the kernel parameters.
    with torch.no_grad():
      cov = (materny(y, y).evaluate(), maternw(w, w).evaluate(),
             maternx(x, x).evaluate(), materns(s, s).evaluate())

  # Declare parameters. The creation order determines the values drawn from
  # torch's global RNG.
  alphay = torch.rand(len(y), 4, requires_grad=True)
  weighty = torch.rand(1, requires_grad=True)

  alphaq = torch.rand(len(y), 5, requires_grad=True)
  weightq = torch.rand(1, requires_grad=True)

  alphax = torch.rand(len(x), 2, requires_grad=True)
  weightx = torch.rand(1, requires_grad=True)

  alphaw = torch.rand(len(w), 2, requires_grad=True)
  weightw = torch.rand(1, requires_grad=True)

  weights = torch.rand(1, requires_grad=True)

  alphaz = torch.rand(len(y)-2, 1, requires_grad=True)
  weightz = torch.rand(1, requires_grad=True)

  # Train
  params = [alphay, alphaq, alphax, alphaw, alphaz, weighty, weightq, weightx, weightw, weightz, weights]
  if learn_hyper:
    params = params + list(maternx.parameters()) + list(materns.parameters()) + list(maternw.parameters()) \
                    + list(materny.parameters()) + list(maternz.parameters())
  learning_rate = 1e-3
  optimizer = torch.optim.Adam(params, lr=learning_rate)
  loss_train = trainModel(x, y, w, s, cov=cov, n_iter=20000)
  loss_lst.append(loss_train)

  # Predict: posterior mean of z on train+test points, then sample outcomes
  # with w forced to 1 and to 0.
  with torch.no_grad():
    s_all = torch.cat((s, ste), dim=0)
    z_samples = predZ(alphaq, weightq, y, x, w, s,
                      torch.cat((y, yte), dim=0), torch.cat((x, xte), dim=0),
                      torch.cat((w, wte), dim=0), s_all).transpose(0, 1)
    weights_mixture = 1.0

    Y_samples1 = sampleY(z_samples[0, :T].reshape(-1, 1), torch.ones((T, 1)), s.reshape(-1, 1),
                         z_samples[0, :].reshape(-1, 1), torch.ones((T+Tte, 1)), s_all.reshape(-1, 1),
                         alphay, weighty, weights_mixture, T=T+Tte, n=1000)
    Y_samples2 = sampleY(z_samples[0, :T].reshape(-1, 1), torch.zeros((T, 1)), s.reshape(-1, 1),
                         z_samples[0, :].reshape(-1, 1), torch.zeros((T+Tte, 1)), s_all.reshape(-1, 1),
                         alphay, weighty, weights_mixture, T=T+Tte, n=1000)
  Ypred1 = np.mean(Y_samples1, axis=0)
  Ypred0 = np.mean(Y_samples2, axis=0)

  # Evaluate: the first T predictions are the training points, the rest the test points.
  evaluator_train = Evaluator(y=Y, t=W, y_cf=Y_cf, mu0=mu[0,:], mu1=mu[1,:])
  stat = evaluator_train.calc_stats(Ypred1[:T], Ypred0[:T])
  train_stats.append(stat)
  print('Train:', stat)

  evaluator_test = Evaluator(y=Yte, t=Wte, y_cf=Y_cfte, mu0=mute[0,:], mu1=mute[1,:])
  stat = evaluator_test.calc_stats(Ypred1[T:], Ypred0[T:])
  test_stats.append(stat)
  print('Test:', stat)

  # NOTE(review): presumably gives tqdm/print output time to flush between replications.
  sleep(0.5)

100%|██████████| 20000/20000 [04:24<00:00, 75.73it/s, Iter 19950:, Loss: 1407.8463134765625]


Train: (2.323383229664898, 0.1320915263342859, 0.6088513710403669)
Test: (2.5050190238652235, 0.03350303235670182, 0.2567642730933663)


100%|██████████| 20000/20000 [04:26<00:00, 75.12it/s, Iter 19950:, Loss: 1301.3209228515625]


Train: (2.1773865226991664, 0.2199615740572689, 0.5036189509806217)
Test: (2.3415740739447948, 0.3448604466969263, 0.6865329231343948)


100%|██████████| 20000/20000 [04:25<00:00, 75.24it/s, Iter 19950:, Loss: 1452.56689453125]


Train: (2.364851920388114, 0.20399270378233947, 0.47185563136410646)
Test: (2.4066058502518253, 0.1706924891442707, 0.3423272982601154)


100%|██████████| 20000/20000 [04:33<00:00, 73.14it/s, Iter 19950:, Loss: 1457.5726318359375]


Train: (2.333249764198014, 0.04197031052749267, 0.6318489143376229)
Test: (2.1847971516864604, 0.14001830368597545, 0.1728394869453863)


100%|██████████| 20000/20000 [04:42<00:00, 70.72it/s, Iter 19950:, Loss: 1480.5194091796875]


Train: (2.3713664767837703, 0.4019741825867156, 0.7749635880860521)
Test: (2.540163252837788, 0.2907958495545824, 0.7987696169435032)


100%|██████████| 20000/20000 [04:39<00:00, 71.52it/s, Iter 19950:, Loss: 1464.310302734375]


Train: (2.6030505593833793, 0.7677775567572387, 0.9888397132285065)
Test: (2.2300418098730113, 0.7477251336137067, 0.9244985885546868)


100%|██████████| 20000/20000 [04:40<00:00, 71.33it/s, Iter 19950:, Loss: 1411.7667236328125]


Train: (2.362110262211889, 0.5036525549905022, 0.71512696652692)
Test: (1.91966687290752, 0.3968604063858958, 0.4436047438856305)


100%|██████████| 20000/20000 [04:40<00:00, 71.33it/s, Iter 19950:, Loss: 1411.7451171875]


Train: (2.2793734884022823, 0.2341042811465357, 0.5549098200220691)
Test: (2.9293511836896493, 0.11302481658935548, 0.11826341991375572)


100%|██████████| 20000/20000 [04:38<00:00, 71.80it/s, Iter 19950:, Loss: 1400.3580322265625]


Train: (2.275703141254723, 0.5384806565804312, 0.732646219292078)
Test: (2.5395521273418207, 0.5570094081319956, 0.6466054019950281)


100%|██████████| 20000/20000 [04:35<00:00, 72.72it/s, Iter 19950:, Loss: 1430.8450927734375]


Train: (2.2868208803480394, 0.6866567985320975, 0.7674055757476916)
Test: (3.2033631580388056, 0.6273337126276806, 0.6702267149936533)


In [10]:
# Mean and standard error (sample std / sqrt(number of replications)) of each training metric.
np.mean(train_stats, axis=0), np.std(train_stats, axis=0, ddof=1) / np.sqrt(len(train_stats))

(array([2.33772962, 0.37306621, 0.67500668]),
 array([0.03474301, 0.07734605, 0.04867791]))

In [11]:
# Mean and standard error (sample std / sqrt(number of replications)) of each test metric.
np.mean(test_stats, axis=0), np.std(test_stats, axis=0, ddof=1) / np.sqrt(len(test_stats))

(array([2.48001345, 0.34218236, 0.50604325]),
 array([0.1164812 , 0.07569232, 0.0878334 ]))