In [1]:
import torch
import torch.nn as nn
import torchvision
import tqdm
import time
import math

In [2]:
from gpytorch.variational import VariationalStrategy
from gpytorch.models.deep_gps import DeepGPLayer, DeepGP

In [3]:
import sys
import urllib.request
import os
from scipy.io import loadmat
from math import floor
import pandas as pd
import numpy as np

# Pick the compute device. Change `requested_gpu` to choose which GPU to use.
requested_gpu = 4
if torch.cuda.is_available():
    # Fall back to the first GPU when the requested ordinal does not exist on this machine;
    # moving tensors to a non-existent ordinal fails with "invalid device ordinal".
    gpu = requested_gpu if requested_gpu < torch.cuda.device_count() else 0
    device = f'cuda:{gpu}'
else:
    device = 'cpu'
# device = 'cpu'  # uncomment to force the CPU

# Load data.
# Regression: UCI datasets (image datasets for classification are not handled in this cell).
dataset = 'energy'  # set here to choose which dataset to train on
if dataset == 'elevators':
    # D=18, N=13768
    data = torch.Tensor(loadmat('data/elevators.mat')['data'])
    X = data[:, :-1]  # the last column is the label
    y = data[:, -1]
elif dataset == 'energy':
    # D=8, N=768
    data = pd.read_excel('data/energy.xlsx').values
    X = torch.Tensor(data[:, :8])  # the last two columns are the two labels
    y1 = torch.Tensor(data[:, 8])
    y2 = torch.Tensor(data[:, 9])
    y = y1  # change here to run on the other label
elif dataset == 'protein':
    # D=9, N=45730
    data = pd.read_csv('data/protein.csv').values
    X = torch.Tensor(data[:, 1:])
    y = torch.Tensor(data[:, 0])  # the first column is the label
elif dataset == 'power':
    # D=4, N=9568
    sheet_num = 0  # the power data has five sheets; change this from 0 to 4
    data = pd.read_excel('data/power.xlsx', sheet_name=sheet_num).values
    X = torch.Tensor(data[:, :4])
    y = torch.Tensor(data[:, 4])
elif dataset == 'concrete':
    # D=8, N=1030
    data = pd.read_excel('data/concrete.xls', header=0).values
    X = torch.Tensor(data[:, :8])
    y = torch.Tensor(data[:, 8])
elif dataset == 'year_msd':
    # D=90, N=515345
    # Not implemented yet. Intended loading code, kept for reference:
    #   data = pd.read_csv('data/year_msd.txt', header=None, delimiter=',').values
    #   # (alternative: data = np.loadtxt('data/year_msd.txt', delimiter=','))
    #   X = torch.Tensor(data[:, 1:])
    #   y = torch.Tensor(data[:, 0])  # the first column is the label
    raise NotImplementedError('It has not been implemented yet')
elif dataset == 'boston':
    # D=13, N=506
    # NOTE: sklearn.datasets.load_boston was removed in scikit-learn 1.2, so this branch
    # needs an older scikit-learn.
    from sklearn.datasets import load_boston
    boston = load_boston()
    X = torch.Tensor(boston.data)
    y = torch.Tensor(boston.target)
elif dataset == 'kin8nm':
    # D=8, N=8192
    data = pd.read_csv('data/kin8nm.csv', header=None).values
    X = torch.Tensor(data[:, :8])
    y = torch.Tensor(data[:, 8])
elif dataset == 'yacht':
    # D=6, N=308
    data = pd.read_csv('data/yacht.csv', header=None).values
    X = torch.Tensor(data[:, :6])
    y = torch.Tensor(data[:, 6])
elif dataset == 'qsar':
    # D=8, N=546
    data = pd.read_csv('data/qsar.csv', header=None).values
    X = torch.Tensor(data[:, :8])
    y = torch.Tensor(data[:, 8])
else:
    # Fail here rather than with a NameError on `X` further down.
    raise ValueError(f'Unknown dataset: {dataset!r}')
    
# Pre-processing: min-max scale every feature to [-1, 1].
X = X - X.min(0)[0]  # X.min(0)[0]: per-feature minimum
feature_range = X.max(0)[0]
# A constant feature has range 0 and would become 0/0 = NaN. Divide it by 1 instead
# (it then maps to -1), which leaves every other feature unchanged.
feature_range = torch.where(feature_range > 0, feature_range, torch.ones_like(feature_range))
X = 2 * (X / feature_range) - 1
# X -= X.mean(0)
# X /= X.std(0)
# Standardize the target to N(0, 1). This is done out-of-place on purpose: `y` can alias
# another tensor (e.g. `y1` for the energy dataset), and in-place ops would modify it too.
y = (y - y.mean()) / y.std()
# NOTE(review): the scaling statistics above are computed on all rows, test rows included.

# Split data: the first `ratio` of the rows go to training and the rest to testing.
# NOTE(review): rows are not shuffled, so the test set is the tail of the file order.
ratio = 0.9  # fraction of rows used for training
train_n = int(floor(ratio * len(X)))
train_x = X[:train_n, :].contiguous()
train_y = y[:train_n].contiguous()

test_x = X[train_n:, :].contiguous()
test_y = y[train_n:].contiguous()

# move data to cuda or cpu
train_x, train_y, test_x, test_y = train_x.to(device), train_y.to(device), test_x.to(device), test_y.to(device)

def get_parameter_number(model):
    """Count the parameter elements of ``model``.

    Returns a dict with ``'Total'`` (elements of every tensor yielded by ``model.parameters()``)
    and ``'Trainable'`` (only the tensors with ``requires_grad=True``).
    """
    total_num = 0
    trainable_num = 0
    for param in model.parameters():
        count = param.numel()
        total_num += count
        if param.requires_grad:
            trainable_num += count
    return {'Total': total_num, 'Trainable': trainable_num}

In [4]:
from torch.utils.data import TensorDataset, DataLoader
# Mini-batch size per dataset.
# 'concrete': 128 was also tried; 512 is the value in effect (the old if-chain assigned
# 128 and then immediately overwrote it with 512).
loader_batch_size_by_dataset = {
    'elevators': 1024,
    'energy': 64,
    'protein': 2048,
    'power': 1024,
    'concrete': 512,
    'year_msd': 8192,
    'boston': 64,
    'kin8nm': 1024,
    'yacht': 128,
    'qsar': 128,
}
if dataset not in loader_batch_size_by_dataset:
    raise ValueError(f'No loader batch size configured for dataset {dataset!r}')
loader_batch_size = loader_batch_size_by_dataset[dataset]

# Training batches are reshuffled every epoch; the test loader keeps file order.
train_dataset = TensorDataset(train_x, train_y)
train_loader = DataLoader(train_dataset, batch_size=loader_batch_size, shuffle=True)
test_dataset = TensorDataset(test_x, test_y)
test_loader = DataLoader(test_dataset, batch_size=loader_batch_size)

In [5]:
# Logging switches
print_time = False  # print per-module timing
print_norm = False  # print the norm of f_u (its ideal value after training is 0)
print_metric = True  # print evaluation metrics after every epoch
print_param = True  # print the parameter count of each network as it is built
print_loss = False  # print every individual loss term during training

# Sampling / architecture switches
sample_times = 10  # number of q(u) samples used to estimate expectations over q(u)
concat_type = False  # True: concatenate the noise e with z inside the generator G
expect_mean = False  # True: compute the mean of q(f) by averaging over the q(u) samples

class ToyDeepGPHiddenLayer(DeepGPLayer):  # input_dims: size of feature dim
    """One layer of the deep GP, backed by the project's modified ``VariationalStrategy``.

    Every configuration attribute is assigned *before* the strategy is constructed, because the
    strategy is handed ``self``. Presumably it reads these attributes while building its
    generator/discriminator networks (TODO confirm against the VariationalStrategy fork).

    Besides its arguments, this layer reads the notebook globals ``concat_type``, ``print_time``,
    ``print_param``, ``device``, ``sample_times`` and ``expect_mean``.

    Args:
        input_dims: number of input features of this layer.
        output_dims: number of GP outputs, or ``None`` for a single-output layer.
        num_inducing: number of inducing points.
        mean_type: accepted for API compatibility; this class does not use it.
        stein: forwarded as ``stein_type``; also selects the multi-output inducing-point layout.
        noise_add: whether noise is added to the generator input.
        noise_share: whether the noise is shared across layers.
        noise_dim: noise size; ``None`` means 32.
        multi_head: whether the shared noise is transformed per head.
        vector_type: whether u is generated by a vector-based network.
    """

    def __init__(self, input_dims, output_dims, num_inducing=128, mean_type='constant', stein=False, noise_add=True,
                 noise_share=False, noise_dim=1, multi_head=False, vector_type=False):
        # TODO: adjust the size of inducing_points to [inducing_size, input_dim] for each layer
        if output_dims is None:
            # single output: one [num_inducing, input_dims] set, no batch dimension
            inducing_shape = (num_inducing, input_dims)
            batch_shape = torch.Size([])
        elif stein is False:
            # one set of inducing points per output: [output_dims, num_inducing, input_dims]
            inducing_shape = (output_dims, num_inducing, input_dims)
            batch_shape = torch.Size([output_dims])
        else:
            # Stein variant: a single [num_inducing, input_dims] set, batched over the outputs
            inducing_shape = (num_inducing, input_dims)
            batch_shape = torch.Size([output_dims])
        inducing_points = torch.randn(inducing_shape, requires_grad=True)  # standard-normal init

        self.input_dim = input_dims
        self.output_dim = output_dims
        self.inducing_size = num_inducing

        self.noise_add = noise_add
        self.noise_share = noise_share
        # noise_dim is supplied by the outer model; fall back to 32 when it is missing
        self.noise_dim = 32 if noise_dim is None else noise_dim
        # concatenating noise with z only makes sense when noise is added at all
        self.concat_type = concat_type if self.noise_add else False

        self.multi_head = multi_head
        self.learn_inducing_locations = True  # treat the inducing locations z as parameters

        self.hutch_times = 1  # number of Hutchinson probes for the trace estimate
        self.bottleneck_trick = True  # decompose the Jacobian to reduce variance
        self.rademacher_type = False  # True: draw Rademacher probes for the trace estimate
        self.diagonal_type = True  # approximate the full kernel matrix by its diagonal
        self.vector_type = vector_type  # generate u with a vector-based network

        # settings mirrored from notebook globals
        self.print_time = print_time
        self.print_param = print_param
        self.device = device
        self.sample_times = sample_times  # samples used for expectations over q(u)
        self.expect_mean = expect_mean

        # No explicit variational distribution is passed; the strategy maps q(U) to q(F).
        variational_strategy = VariationalStrategy(
            self,
            inducing_points,
            None,
            stein_type=stein,
            batch_shape=batch_shape,
        )

        super().__init__(variational_strategy, input_dims, output_dims)
        self.variational_strategy = variational_strategy

    def forward(self, x):
        # All computation happens inside the variational strategy; this stub returns None.
        return None

    # called from the forward function of the DeepGP model
    def __call__(self, x, train_type=True, disc_type=True, hyp_type=True, trace_layer_list=None, norm_layer_list=None, u_layer_list=None,
                 fiu_layer_list=None, glogpfu_layer_list=None, shared_list=None, transformed_list=None, *other_inputs,
                 **kwargs):
        """Forward the layer arguments to ``DeepGPLayer.__call__`` with ``are_samples=False``.

        Extra positional (``*other_inputs``) and keyword (``**kwargs``) arguments are accepted but
        not forwarded.
        """
        forwarded = dict(
            are_samples=False,
            train_type=train_type,
            disc_type=disc_type,
            hyp_type=hyp_type,
            trace_layer_list=trace_layer_list,
            norm_layer_list=norm_layer_list,
            u_layer_list=u_layer_list,
            fiu_layer_list=fiu_layer_list,
            glogpfu_layer_list=glogpfu_layer_list,
            shared_list=shared_list,
            transformed_list=transformed_list,
        )
        return super().__call__(x, **forwarded)

In [6]:
num_output_dims = 10  # output dim of each hidden layer
vector_type = True  # Set True to generate u using a vector-based network
learn_likelihood_covariance = True  # Set True to treat the likelihood sigma as a parameter
task_dim = None
train_x_shape = train_x.shape[-1]  # number of input features

# Per-dataset (noise_dim, num_layer_inducing); fine-tune these to obtain better results.
layer_config_by_dataset = {
    'elevators': (32, 128),
    'energy': (32, 128),
    'protein': (32, 128),
    'power': (32, 128),
    'concrete': (32, 128),
    'year_msd': (32, 128),
    'boston': (32, 128),
    'kin8nm': (32, 128),
    'yacht': (32, 128),
    'qsar': (32, 128),
}
if dataset not in layer_config_by_dataset:
    # Fail here instead of with a NameError inside DeepGP.__init__.
    raise ValueError(f'No layer configuration for dataset {dataset!r}')
noise_dim, num_layer_inducing = layer_config_by_dataset[dataset]
    
    
# The class below reuses the name `DeepGP`. Bind gpytorch's base class to a private alias so
# that re-running this cell subclasses gpytorch's DeepGP again. Otherwise it would subclass
# the notebook-defined DeepGP from the previous run, whose __init__ requires `train_x_shape`
# and makes `super().__init__()` fail.
from gpytorch.models.deep_gps import DeepGP as _GPyTorchDeepGP


class DeepGP(_GPyTorchDeepGP):  # define the noise outside the layer
    """Two-hidden-layer deep GP whose layers use generator-based variational strategies.

    When noise is shared across layers, it is sampled here in ``forward``. With
    ``multi_head=True`` it is also passed through ``back_bone``.

    Reads these notebook globals when constructed: ``vector_type``, ``noise_dim``,
    ``print_param``, ``learn_likelihood_covariance``, ``num_output_dims`` and
    ``num_layer_inducing``. Reads ``sample_times``, ``device`` and ``expect_mean`` when called.

    Args:
        train_x_shape: number of input features.
        stein: forwarded to every layer.
        noise_add: whether the layers add noise to their generator input.
        noise_share: share the noise samples across layers (requires ``noise_add``).
        multi_head: transform the shared noise with ``back_bone`` (requires ``noise_share``).

    Raises:
        ValueError: on an invalid ``noise_add``/``noise_share``/``multi_head`` combination.
    """

    def __init__(self, train_x_shape, stein=False, noise_add=True, noise_share=False, multi_head=False):  # L=2
        if noise_add is False and noise_share:
            raise ValueError('Only can share noise across layer when noise is added')
        if noise_share is False and multi_head:
            raise ValueError('Only can use Multi-head Mechanism when noise is shared across layer')
        super().__init__()
        self.vector_type = vector_type
        self.noise_add = noise_add
        self.noise_share = noise_share  # share the noise across layers or not
        # fine-tune this term or learn it directly as a prior
        self.noise_dim = noise_dim
        self.multi_head = multi_head  # transform the shared noise or not
        # Maps each shared noise sample to its "transformed" counterpart (multi-head only).
        self.back_bone = nn.Sequential(
            nn.Linear(in_features=self.noise_dim, out_features=32),
            nn.Sigmoid(),
            nn.Linear(in_features=32, out_features=self.noise_dim)
        )
        if self.multi_head and print_param:
            print('Generator backbone:', get_parameter_number(self.back_bone))

        # Likelihood scale sigma, initialised to 1.0. It is unconstrained, so it may go negative.
        lls_sigma = torch.Tensor([1.0])
        if learn_likelihood_covariance:
            self.register_parameter("lls_sigma", torch.nn.Parameter(lls_sigma))
        else:
            self.register_buffer("lls_sigma", lls_sigma)

        def make_layer(input_dims, output_dims, mean_type):
            # Layers differ only in their in/out sizes and mean type.
            return ToyDeepGPHiddenLayer(
                input_dims=input_dims,
                output_dims=output_dims,
                num_inducing=num_layer_inducing,
                mean_type=mean_type,
                stein=stein,
                noise_add=noise_add,
                noise_share=self.noise_share,
                noise_dim=self.noise_dim,
                multi_head=self.multi_head,
                vector_type=self.vector_type,
            )

        # Keep this construction order: each layer draws random inducing points when built.
        hidden_layer = make_layer(train_x_shape, num_output_dims, 'linear')
        hidden_layer2 = make_layer(hidden_layer.output_dim, num_output_dims, 'linear')
        # Deeper variants: build hidden_layer3/4 the same way (input_dims = previous layer's
        # output_dim), register them and their strategy networks below, add them to the layer
        # tuple in `forward`, and feed the last hidden layer's output_dim to last_layer.
        last_layer = make_layer(hidden_layer2.output_dim, None, 'constant')

        self.hidden_layer = hidden_layer
        self.hidden_layer2 = hidden_layer2
        self.last_layer = last_layer

        # register all networks here!!! (these attribute names appear in state_dict keys)
        self.hidden_strategy_generator = self.hidden_layer.variational_strategy.generator
        self.hidden_strategy_discriminator = self.hidden_layer.variational_strategy.discriminator
        self.hidden_kernel_method = self.hidden_layer.variational_strategy.kernel_method
        self.hidden2_strategy_generator = self.hidden_layer2.variational_strategy.generator
        self.hidden2_strategy_discriminator = self.hidden_layer2.variational_strategy.discriminator
        self.hidden2_kernel_method = self.hidden_layer2.variational_strategy.kernel_method
        self.last_strategy_generator = self.last_layer.variational_strategy.generator
        self.last_strategy_discriminator = self.last_layer.variational_strategy.discriminator
        self.last_kernel_method = self.last_layer.variational_strategy.kernel_method

    def _sample_shared_noise(self):
        """Return ``(shared_list, transformed_list)`` with one entry per q(u) sample.

        Both are ``None`` unless noise is added *and* shared. ``transformed_list`` is ``None``
        unless ``multi_head`` is set.
        """
        if not (self.noise_add and self.noise_share):
            return None, None
        if self.vector_type:
            noise_shape = (self.noise_dim,)
        else:
            noise_shape = (num_layer_inducing, self.noise_dim)
        shared_list = []
        transformed_list = [] if self.multi_head else None
        for _ in range(sample_times):
            shared_noise = torch.randn(noise_shape, requires_grad=True).to(device)
            shared_list.append(shared_noise)
            if self.multi_head:
                transformed_list.append(self.back_bone(shared_noise))
        return shared_list, transformed_list

    def forward(self, inputs):
        """Run the model on a packed input tuple.

        ``inputs`` is ``(x, train_type, disc_type, hyp_type, trace_layer_list, norm_layer_list,
        u_layer_list, fiu_layer_list, glogpfu_layer_list)``. Use ``train_type=False`` for
        evaluation. Returns whatever ``last_layer`` returns; ``predict`` unpacks it as a 6-tuple.
        """
        (inputs, train_type, disc_type, hyp_type, trace_layer_list, norm_layer_list,
         u_layer_list, fiu_layer_list, glogpfu_layer_list) = inputs

        # Set the trainable flags BEFORE any noise goes through back_bone, so that this call's
        # autograd graph reflects this phase rather than the previous call's phase.
        #   generator step (train, not disc): back_bone trainable
        #   hyper-parameter step (not train, hyp): lls_sigma trainable
        #   discriminator step / test stage: both frozen
        self.back_bone.requires_grad_(bool(train_type and not disc_type))
        self.lls_sigma.requires_grad_(bool(not train_type and hyp_type))

        shared_list, transformed_list = self._sample_shared_noise()

        hidden_rep = inputs
        for layer in (self.hidden_layer, self.hidden_layer2):
            # Each layer appends to the bookkeeping lists and hands them to the next layer.
            (hidden_rep, trace_layer_list, norm_layer_list, u_layer_list,
             fiu_layer_list, glogpfu_layer_list) = layer(
                hidden_rep, train_type=train_type, disc_type=disc_type, hyp_type=hyp_type,
                trace_layer_list=trace_layer_list, norm_layer_list=norm_layer_list,
                u_layer_list=u_layer_list, fiu_layer_list=fiu_layer_list,
                glogpfu_layer_list=glogpfu_layer_list, shared_list=shared_list,
                transformed_list=transformed_list)

        # the last layer yields a distribution over f, not a deterministic vector
        return self.last_layer(hidden_rep, train_type=train_type, disc_type=disc_type, hyp_type=hyp_type,
                               trace_layer_list=trace_layer_list, norm_layer_list=norm_layer_list,
                               u_layer_list=u_layer_list, fiu_layer_list=fiu_layer_list,
                               glogpfu_layer_list=glogpfu_layer_list, shared_list=shared_list,
                               transformed_list=transformed_list)

    def predict(self, test_loader):
        """Evaluate on ``test_loader``.

        Returns the concatenated q(f) means and per-point test log-likelihoods.

        NOTE(review): the log-likelihood is log N(y | f + noise, sigma^2), where f is one sample
        of q(f) and the noise is sigma-scaled. It is not the predictive density
        N(y | mean, var + sigma^2), so it is a noisy single-sample quantity.
        """
        mus = []
        lls = []
        for x_batch, y_batch in test_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            # train/disc/hyp all False: test stage, nothing is set trainable
            output_batch, _, _, _, _, _ = self.forward((x_batch, False, False, False, [], [], [], [], []))
            mean_batch, covar_batch = output_batch
            f_batch = torch.distributions.MultivariateNormal(loc=mean_batch, covariance_matrix=covar_batch).rsample(
                    torch.Size([]))  # sample f from q(f)
            f_batch = f_batch.squeeze(0)  # size: [batch_size]
            preds = self.regression_likelihood(f_batch)  # change here when dealing with classification
            mus.append(mean_batch.squeeze(0))
            lls.append(self.test_regression_lls(preds, y_batch))

        return torch.cat(mus, dim=0), torch.cat(lls, dim=0)

    def regression_likelihood_loss(self, y_batch, output_batch, u_layer_list, fiu_layer_list, glogpfu_layer_list):
        """Likelihood term of the loss, computed without the gpytorch likelihood framework.

        Steps:
          1. Draw one sample f ~ N(mean_batch, covariance_batch).
          2. For every layer l and sample k, compute the vector-Jacobian product
             (y - f)^T df/du.
          3. Contract it with ``fiu_layer_list[l][k]`` via ``torch.mm``.
          4. Average over the ``sample_times`` samples, sum over layers, and divide by sigma^2.

        Assumes a single-task model with final output dim 1. ``glogpfu_layer_list`` is accepted
        for interface compatibility but is not used. When each u is a 1-D vector, the result
        has shape [1, 1].
        """
        mean_batch, covariance_batch = output_batch
        f_batch = torch.distributions.MultivariateNormal(
            loc=mean_batch, covariance_matrix=covariance_batch).rsample(torch.Size([]))
        f_batch = torch.reshape(f_batch, (y_batch.shape[0],))
        distance_batch = y_batch - f_batch

        def grad_log_lik_wrt(u):
            # (y - f)^T df/du: the gradient of the Gaussian log-likelihood w.r.t. u, up to the
            # 1/sigma^2 factor applied below. create_graph keeps the loss differentiable.
            return torch.autograd.grad(f_batch, u, grad_outputs=distance_batch,
                                       retain_graph=True, create_graph=True)[0]

        glogpyffu_layer_list = []  # per layer: second term of the first loss term
        for l in range(len(u_layer_list)):
            if expect_mean:
                glogpyf_l_list = [grad_log_lik_wrt(u_layer_list[l][k]) for k in range(sample_times)]
            else:
                # q(f) is generated from the first sample of u only, so the gradient is the same
                # for every k. Compute it once instead of sample_times times.
                glogpyf_l_list = [grad_log_lik_wrt(u_layer_list[l][0])] * sample_times
            glogpyffu_l_list = [
                torch.mm(glogpyf_l_list[k].unsqueeze(0), fiu_layer_list[l][k].unsqueeze(1))
                for k in range(sample_times)
            ]
            # average over sample_times
            glogpyffu_layer_list.append(torch.stack(glogpyffu_l_list, dim=0).mean(0))
        glogpyffu_loss = torch.stack(glogpyffu_layer_list, dim=0).sum(0)

        sigma2 = torch.pow(self.lls_sigma.to(device), 2)
        return torch.div(glogpyffu_loss, sigma2)

    # TODO: Still need to fix it
    def test_regression_lls(self, y_predict, y_label):  # use for regression
        """Per-point Gaussian log-likelihood log N(y_label | y_predict, sigma^2).

        Assumes 1-D inputs of size [batch_size] (output dim 1) and returns a tensor of that size.
        """
        sigma = self.lls_sigma.to(device)
        # -0.5 * log(2 * pi)
        first_term = torch.mul(torch.log(torch.Tensor([2 * math.pi])), -0.5).to(device)
        # -log|sigma|: lls_sigma is an unconstrained parameter and can become negative during
        # training. The Gaussian depends only on sigma^2, so use |sigma| rather than yield NaN.
        second_term = torch.mul(torch.log(torch.abs(sigma)), -1)
        res = y_label - y_predict
        res2 = torch.mul(res, res)
        sigma2 = torch.pow(sigma, 2)
        third_factor_term = torch.div(torch.Tensor([-1]).to(device), torch.mul(sigma2, 2))
        third_term = torch.mul(res2, third_factor_term)  # -(y - y_hat)^2 / (2 sigma^2)
        return third_term.add(first_term).add(second_term)

    def regression_likelihood(self, f_batch):
        """Return y = f + sigma * eps with eps ~ N(0, I), for a 1-D ``f_batch`` of size [batch_size]."""
        dim = f_batch.shape[0]
        sigma_batch = self.lls_sigma.to(device)
        noise_batch = torch.randn(dim).to(device)
        noise_final_batch = torch.mul(noise_batch, sigma_batch)
        return f_batch + noise_final_batch

In [7]:
# NOTE: if an error points at a strange (often wrong) line, restart the kernel and re-run.
# multi_head=True: the shared noise is transformed by a network and concatenated with each
# layer's own inducing points.
model = DeepGP(
    train_x_shape,
    stein=True,
    noise_add=True,
    noise_share=True,
    multi_head=True,
)
print('Test whether you have cuda to run the process or not:', torch.cuda.is_available())
model = model.to(device)
print(get_parameter_number(model))
# Only parameters that require grad right now are handed to Adam.
optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=0.01)

Generator backbone: {'Total': 2112, 'Trainable': 2112}
Generator: {'Total': 44320, 'Trainable': 44320}
Discriminator: {'Total': 42256, 'Trainable': 42256}
Kernel method: {'Total': 10, 'Trainable': 10}
Generator: {'Total': 44576, 'Trainable': 44576}
Discriminator: {'Total': 42256, 'Trainable': 42256}
Kernel method: {'Total': 12, 'Trainable': 12}
Generator: {'Total': 6560, 'Trainable': 6560}
Discriminator: {'Total': 4240, 'Trainable': 4240}
Kernel method: {'Total': 12, 'Trainable': 12}
Test whether you have cuda to run the process or not: True
{'Total': 186355, 'Trainable': 186355}


In [8]:
from gpytorch.settings import num_likelihood_samples

def _sample_f(output_batch):
    """Draw one reparameterized sample of f from q(f) = N(mean, covar).

    ``output_batch`` is the ``(mean, covar)`` pair returned by ``model.forward``.
    Its size is ``([1, batch_size], [1, batch_size, batch_size])`` according to
    the original training code, so the leading dim is squeezed to ``[batch_size]``.
    """
    mean_batch, covar_batch = output_batch
    f_batch = torch.distributions.MultivariateNormal(
        loc=mean_batch, covariance_matrix=covar_batch).rsample(torch.Size([]))
    return f_batch.squeeze(0)


def _stein_loss(x_batch, y_batch, update_disc, lamda, lamda_like):
    """Forward one minibatch and assemble the adversarial Stein loss.

    loss  = |score + trace| - lamda * norm
    score = -sum_layers(glogpfu) + lamda_like * model.regression_likelihood_loss(...)

    ``update_disc`` is passed to ``model.forward`` as the discriminator/generator
    selector flag; it also picks the 'DISC' / 'GEN' prefix of the debug prints.
    Returns ``(loss, output_batch)``.
    """
    tag = 'DISC' if update_disc else 'GEN'
    # model.forward receives empty per-layer lists and returns them filled.
    output_batch, trace_list, norm_list, u_list, fiu_list, glogpfu_list = model.forward(
        (x_batch, True, update_disc, False, [], [], [], [], []))
    # second loss term: traces summed over layers
    trace_loss = torch.stack(trace_list, dim=0).sum(0)
    # third loss term: norm of f_u summed over layers (subtracted from the loss)
    norm_loss = torch.stack(norm_list, dim=0).sum(0)
    if update_disc and print_norm:
        print('DISC Norm of f_u:', norm_loss.item())
    norm_loss = lamda * norm_loss
    # first loss term, prior part: -sum of the per-layer grad-log-p(u)·f_u values
    glogpu_loss = -1 * torch.stack(glogpfu_list, dim=0).sum(0)
    # first loss term, likelihood part
    begin = time.time()
    glogpyffu_loss = lamda_like * model.regression_likelihood_loss(
        y_batch, output_batch, u_list, fiu_list, glogpfu_list)
    end = time.time()
    if print_time:
        print('Compute likelihood time:', str(end - begin), 's')
    score_loss = glogpu_loss + glogpyffu_loss
    original_loss = torch.abs(score_loss + trace_loss)
    loss = original_loss - norm_loss
    if print_loss:
        print(f'{tag} prior loss:', glogpu_loss.item())
        print(f'{tag} likelihood loss:', glogpyffu_loss.item())
        print(f'{tag} trace loss:', trace_loss.item())
        print(f'{tag} original loss:', original_loss.item())
        print(f'{tag} norm loss:', norm_loss.item())
        print(f'{tag} total loss:', loss.item())
    return loss, output_batch


def _adversarial_step(x_batch, y_batch, update_disc, lamda, lamda_like):
    """One optimizer step of the discriminator (update_disc=True) or generator phase.

    Returns ``(loss_value, rmse, lls)``. The metrics are computed without autograd
    so the tensors collected by the caller do not keep each minibatch's graph alive.
    """
    optimizer.zero_grad()
    loss, output_batch = _stein_loss(x_batch, y_batch, update_disc, lamda, lamda_like)
    # The discriminator maximizes the loss (ascent via negation); the generator minimizes it.
    torch.autograd.backward(-loss if update_disc else loss)
    optimizer.step()
    with torch.no_grad():
        y_predict = model.regression_likelihood(_sample_f(output_batch))
        rmse = torch.mean(torch.pow(y_predict - y_batch, 2)).sqrt()
        lls = model.test_regression_lls(y_predict, y_batch).mean(0)
    return loss.item(), rmse, lls


def _hyper_step(x_batch, y_batch):
    """One optimizer step that minimizes the minibatch RMSE of noisy predictions.

    Returns ``(loss_value, rmse, lls)`` with graph-free metric tensors.
    """
    optimizer.zero_grad()
    # Flag tuple (False, False, True) is interpreted by model.forward; presumably it
    # selects the hyper-parameter update path (TODO confirm against DeepGP.forward).
    output_batch, _, _, _, _, _ = model.forward((x_batch, False, False, True, [], [], [], [], []))
    y_predict = model.regression_likelihood(_sample_f(output_batch))
    loss = torch.mean(torch.pow(y_predict - y_batch, 2)).sqrt()  # rmse in the mini-batch
    if print_loss:
        print('HYP total loss:', loss.item())
    torch.autograd.backward(loss)
    optimizer.step()
    with torch.no_grad():
        lls = model.test_regression_lls(y_predict, y_batch).mean(0)
    return loss.item(), loss.detach(), lls


def _run_phase(tag, num_iters, iter_desc, batch_desc, num_samples, step_fn):
    """Make ``num_iters`` passes over ``train_loader`` calling ``step_fn`` per minibatch.

    ``step_fn(x_batch, y_batch)`` performs one optimizer step and returns
    ``(loss_value, rmse, lls)``. Per-pass and per-phase averages of RMSE and NLL are
    printed when ``print_metric`` is set; a separator is printed when ``print_loss`` is.
    """
    rmse_iter_list = []  # average rmse of each pass
    lls_iter_list = []  # average log-likelihood of each pass
    for _ in tqdm.notebook.tqdm(range(num_iters), desc=iter_desc, leave=False):
        minibatch_iter = tqdm.notebook.tqdm(train_loader, desc=batch_desc, leave=False)
        rmse_batch_list = []
        lls_batch_list = []
        for x_batch, y_batch in minibatch_iter:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            with num_likelihood_samples(num_samples):
                loss_value, rmse_batch, lls_batch = step_fn(x_batch, y_batch)
            minibatch_iter.set_postfix(loss=loss_value)
            rmse_batch_list.append(rmse_batch)
            lls_batch_list.append(lls_batch)
        rmse_iter = torch.stack(rmse_batch_list, dim=0).mean(0)
        lls_iter = torch.stack(lls_batch_list, dim=0).mean(0)
        if print_metric:
            print(f"ITER {tag}_RMSE: {rmse_iter.item()}, ITER {tag}_NLL: {-lls_iter.item()}")
        rmse_iter_list.append(rmse_iter)
        lls_iter_list.append(lls_iter)
    rmse_epoch = torch.stack(rmse_iter_list, dim=0).mean(0)
    lls_epoch = torch.stack(lls_iter_list, dim=0).mean(0)
    if print_metric:
        print(f"EPOCH {tag}_RMSE: {rmse_epoch.item()}, EPOCH {tag}_NLL: {-lls_epoch.item()}")
    if print_loss:
        print('-----------------------------------')


def train(num_epochs=3, num_disc=2, num_gen=2, num_hyp=10, num_samples=10, lamda=100, lamda_like=1):
    '''Alternate discriminator, generator and hyper-parameter training.

    Each epoch runs, in this order and only when the count is non-zero:
      * ``num_disc`` passes over ``train_loader`` ascending the adversarial Stein loss,
      * ``num_gen`` passes descending the same loss,
      * ``num_hyp`` passes descending the minibatch RMSE of noisy predictions.

    Args:
        num_epochs: number of epochs; must be >= 1.
        num_disc: discriminator passes per epoch (0 skips the phase).
        num_gen: generator passes per epoch (0 skips the phase).
        num_hyp: hyper-parameter passes per epoch (0 skips the phase).
        num_samples: value for gpytorch's ``num_likelihood_samples`` setting.
        lamda: weight of the f_u norm term subtracted from the loss.
        lamda_like: weight of the likelihood term in the score loss.

    Raises:
        ValueError: if ``num_epochs < 1`` or a per-phase count is negative.

    Reads the notebook globals ``model``, ``optimizer``, ``train_loader``, ``device``
    and the ``print_norm`` / ``print_time`` / ``print_loss`` / ``print_metric`` switches.
    '''
    if num_epochs < 1:
        raise ValueError('At least one epoch need to be run!!!')
    for name, count in (('num_disc', num_disc), ('num_gen', num_gen), ('num_hyp', num_hyp)):
        if count < 0:
            raise ValueError(f'{name} must be non-negative, got {count}')

    for _ in tqdm.notebook.tqdm(range(num_epochs), desc="Epoch"):
        print('New epoch!!!')
        if num_disc != 0:
            _run_phase('DISC', num_disc, 'Train Discriminator', 'Minibatch-Disc', num_samples,
                       lambda x, y: _adversarial_step(x, y, True, lamda, lamda_like))
        if num_gen != 0:
            _run_phase('GEN', num_gen, 'Train Generator', 'Minibatch-Gen', num_samples,
                       lambda x, y: _adversarial_step(x, y, False, lamda, lamda_like))
        if num_hyp != 0:
            _run_phase('HYP', num_hyp, 'Train Hyper-parameter', 'Minibatch-Hyper', num_samples,
                       _hyper_step)

In [9]:
# Anomaly detection adds extra checks and stores forward tracebacks for autograd ops,
# which helps debug NaN / in-place errors but slows execution. Scope it to the
# training runs instead of switching it on globally, so later cells (e.g. the
# evaluation stage) are not slowed down by a leftover debugging setting.
with torch.autograd.set_detect_anomaly(True):
    # training stage
    print('----------Train Disc and Gen----------')
    train(num_epochs=1, num_disc=1, num_gen=1, num_hyp=0, num_samples=10, lamda=100, lamda_like=1)
    print('----------Train Hyp----------')
    train(num_epochs=1, num_disc=0, num_gen=0, num_hyp=1, num_samples=10, lamda=100, lamda_like=1)
print('Training stage finish!!!')

----------Train Disc and Gen----------


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

New epoch!!!


Train Discriminator:   0%|          | 0/1 [00:00<?, ?it/s]

Minibatch-Disc:   0%|          | 0/11 [00:00<?, ?it/s]

ITER DISC_RMSE: 1.597905158996582, ITER DISC_NLL: 2.202610492706299
EPOCH DISC_RMSE: 1.597905158996582, EPOCH DISC_NLL: 2.202610492706299


Train Generator:   0%|          | 0/1 [00:00<?, ?it/s]

Minibatch-Gen:   0%|          | 0/11 [00:00<?, ?it/s]

ITER GEN_RMSE: 1.6173499822616577, ITER GEN_NLL: 2.236367702484131
EPOCH GEN_RMSE: 1.6173499822616577, EPOCH GEN_NLL: 2.236367702484131
----------Train Hyp----------


Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

New epoch!!!


Train Hyper-parameter:   0%|          | 0/1 [00:00<?, ?it/s]

Minibatch-Hyper:   0%|          | 0/11 [00:00<?, ?it/s]

ITER HYP_RMSE: 1.595822811126709, ITER HYP_NLL: 2.2999420166015625
EPOCH HYP_RMSE: 1.595822811126709, EPOCH HYP_NLL: 2.2999420166015625
Training stage finish!!!


In [10]:
test_itr = 10  # number of repeated evaluations used for the mean and spread of the RMSE

# evaluate stage
model.eval()

rmse_list = []
for r in range(test_itr):
    predictive_means, test_lls = model.predict(test_loader)
    # NLL can be computed in several ways, so only RMSE is aggregated across repeats.
    rmse = (predictive_means - test_y).pow(2).mean().sqrt()
    print(f"RMSE: {rmse.item()}, NLL: {-test_lls.mean().item()}")
    rmse_list.append(rmse)

rmse_vector = torch.stack(rmse_list)
rmse_mean, rmse_std = rmse_vector.mean(dim=0), rmse_vector.std(dim=0)
print(f"RMSE Mean: {rmse_mean.item()}, Std: {rmse_std.item()}")

RMSE: 1.1482781171798706, NLL: 2.356001615524292
RMSE: 1.2388978004455566, NLL: 2.2036736011505127
RMSE: 1.2384344339370728, NLL: 2.247401714324951
RMSE: 1.2900049686431885, NLL: 2.410691499710083
RMSE: 1.0954948663711548, NLL: 2.610196113586426
RMSE: 1.2637896537780762, NLL: 2.4080896377563477
RMSE: 1.2467528581619263, NLL: 2.7060325145721436
RMSE: 1.2367781400680542, NLL: 2.2750627994537354
RMSE: 1.2568289041519165, NLL: 2.654531955718994
RMSE: 1.2341171503067017, NLL: 2.366696834564209
RMSE Mean: 1.2249376773834229, Std: 0.05818536505103111
