In [1]:
import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
import torch
from torch.autograd import Variable
from torch.autograd import grad
from nash_advreg import *

In [2]:
def learner_cost_flatten(X, y, w, params):
    """Ridge-regression cost of the learner on (possibly attacked) data.

    Parameters
    ----------
    X : torch.Tensor
        Data, possibly flattened; reshaped to (n, d) with d = w.shape[1] - 1.
    y : torch.Tensor
        Targets, shape (n,) or (n, 1).
    w : torch.Tensor
        Parameters, shape (1, d + 1); w[0, :-1] are the weights, w[0, -1] the bias.
    params : dict
        Reads "lmb", the L2 penalty on the weights (the bias is not penalised).

    Returns
    -------
    torch.Tensor
        Shape (1, 1): sum of squared residuals + lmb * ||weights||^2.
    """
    X = X.view(-1, w.shape[1] - 1)
    weights = w[0, :-1].view(1, -1)
    bias = w[0, -1]
    # A 1-D y of shape (n,) would broadcast against the (n, 1) predictions into
    # an (n, n) matrix and silently sum every pairwise residual; use a column.
    if torch.is_tensor(y) and y.dim() == 1:
        y = y.unsqueeze(1)
    residual = X @ weights.t() + bias - y
    return torch.sum(residual**2) + params["lmb"] * weights @ weights.t()

def attacker_cost_flatten(X, X_clean, c_d, w, params):
    """Attacker's cost: pull the learner's predictions towards a target while staying near the clean data.

    Parameters
    ----------
    X : torch.Tensor
        Attacked data, possibly flattened; reshaped to (n, d) with n = X_clean.shape[0].
    X_clean : torch.Tensor
        Clean data, shape (n, d).
    c_d : torch.Tensor or float
        Weights on the per-sample prediction term: shape (n,) or (n, 1), or a scalar.
    w : torch.Tensor
        Learner parameters, shape (1, d + 1); w[0, -1] is the bias.
    params : dict
        Reads "z_train", the attacker's target predictions, shape (n, 1) or (n,).

    Returns
    -------
    torch.Tensor
        Scalar: sum_i c_d[i] * (x_i . weights + bias - z_i)^2 + ||X_clean - X||^2.
    """
    z = params["z_train"]
    weights = w[0, :-1].view(1, -1)
    bias = w[0, -1]
    X = X.view(X_clean.shape[0], -1)
    # Per-sample vectors must be (n, 1) columns: an (n,) vector times the
    # (n, 1) residuals would broadcast to an (n, n) outer product, giving
    # sum(c_d) * sum(residual^2) at O(n^2) cost instead of the weighted sum.
    if torch.is_tensor(c_d) and c_d.dim() == 1:
        c_d = c_d.unsqueeze(1)
    if torch.is_tensor(z) and z.dim() == 1:
        z = z.unsqueeze(1)
    residual = X @ weights.t() + bias - z
    diff = X_clean - X
    return torch.sum(c_d * residual**2) + torch.sum(diff**2)

def compute_backwar_derivative(X_clean, y, w, c_d, params):
    """Total derivative of the learner's cost w.r.t. ``w`` through the attacker's unrolled inner loop.

    The attacker's response is approximated by ``S = params["inner_epochs"]``
    gradient-descent steps (step size ``params["inner_lr"]``) on
    ``attacker_cost_flatten``, starting from a standard-normal random point
    drawn from the global torch RNG:

        X_{j+1} = Phi(X_j, w) = X_j - ilr * dg/dX(X_j, w),   j = 0 .. S-1

    The derivative of ``learner_cost_flatten(X_S, y, w)`` with respect to ``w``
    (explicit term plus the part flowing through X_S) is obtained by
    reverse-mode differentiation of that loop: the iterates X_0 .. X_{S-1}
    are stored and replayed backwards with vector-Jacobian products, so the
    full unrolled graph never has to be kept in memory.

    Parameters
    ----------
    X_clean : torch.Tensor
        Clean data, shape (n, d); the attacked data has n * d elements.
    y : targets, forwarded unchanged to ``learner_cost_flatten``.
    w : torch.Tensor
        Learner parameters, shape (1, d + 1); must require grad.
    c_d : attacker cost coefficients, forwarded to ``attacker_cost_flatten``.
    params : dict
        Reads "inner_epochs" and "inner_lr", plus whatever the cost functions read.

    Returns
    -------
    torch.Tensor
        Same shape as ``w``.
    """
    S = int(params["inner_epochs"])
    ilr = params["inner_lr"]

    gm = lambda w, X: attacker_cost_flatten(X, X_clean, c_d, w, params)
    fm = lambda w, X: learner_cost_flatten(X, y, w, params)

    X = torch.randn(X_clean.shape[0] * X_clean.shape[1], requires_grad=True)

    # Forward pass: run the attacker's gradient descent, recording the INPUT
    # X_j of every step. The backward pass must linearise step j at the point
    # it was taken from (X_j), not at its output X_{j+1}.
    trajectory = []
    for _ in range(S):
        trajectory.append(X.detach())
        grad_X, = torch.autograd.grad(gm(w, X), X)
        # Re-leaf X so no graph accumulates across steps; the needed
        # derivatives are rebuilt step by step in the backward pass.
        X = (X - ilr * grad_X).detach().requires_grad_()

    # Adjoint of the final iterate: dF/dX_S.
    adjoint, = torch.autograd.grad(fm(w, X), X)
    hypergrad = torch.zeros_like(w)

    # Backward pass: for each step, differentiating (Phi(X_j, w) . adjoint)
    # w.r.t. w gives the step's contribution to dF/dw, and w.r.t. X_j gives
    # the adjoint of X_j (Hessian-vector products of the attacker's cost).
    for state in reversed(trajectory):
        X_j = state.detach().requires_grad_()
        grad_X, = torch.autograd.grad(gm(w, X_j), X_j, create_graph=True)
        step = X_j - ilr * grad_X
        d_w, d_X = torch.autograd.grad(step @ adjoint, (w, X_j))
        hypergrad += d_w
        adjoint = d_X

    # Explicit dependence of the learner's cost on w at the final iterate.
    grad_w, = torch.autograd.grad(fm(w, X), w)
    return grad_w + hypergrad

In [3]:
data = pd.read_csv("../data/winequality-white.csv", sep=";")
X = data.drop(columns="quality")
y = data["quality"]

# n_components equals the number of features: a full rotation onto the
# principal axes, not a dimensionality reduction.
# NOTE(review): PCA is fitted on the whole dataset before the train/test split
# (create_train_test below), so the test rows influence the rotation — confirm intended.
pca = PCA(n_components=X.shape[1], svd_solver='full')
X = pca.fit_transform(X)  # fit_transform already fits; a separate pca.fit() was redundant
##

## Bayesian case: Gamma prior over the attacker's cost coefficients `c_d`

In [4]:
# Gamma prior on the attacker's cost coefficients, specified by mean and variance:
# concentration k = mean^2 / var and rate = mean / var give mean = k / rate, var = k / rate^2.
# The one-element parameter tensors give batch shape (1,), so samples carry a trailing dim of 1.
MEAN = 0.5
VAR = 0.01
m = torch.distributions.Gamma(
    concentration=torch.tensor([MEAN**2 / VAR]),
    rate=torch.tensor([MEAN / VAR]),
)

In [5]:
X_train, y_train, X_test, y_test = create_train_test(X, y)

# Experiment configuration shared by the cells below.
params = {
    # "_rr" settings (not read in the cells shown here) and the L2 penalty
    # "lmb" read by learner_cost_flatten
    "epochs_rr": 1000,
    "lr_rr": 0.01,
    "lmb": 0.0,
    # attacker: per-sample cost coefficients (at the prior mean) and targets z
    "c_d_train": torch.ones([len(y_train), 1]) * MEAN,
    "z_train": torch.zeros([len(y_train), 1]),
    "c_d_test": torch.ones([len(y_test), 1]) * MEAN,
    "z_test": torch.zeros([len(y_test), 1]),
    # outer loop = learner's gradient descent, inner loop = attacker's
    "outer_lr": 10e-6,  # NOTE(review): 10e-6 == 1e-5; confirm this (not 1e-6) is intended
    "inner_lr": 0.01,
    "outer_epochs": 200,
    "inner_epochs": 200,
    # Monte Carlo draws of c_d from the prior
    "n_samples": 10,
    "prior": m,
}

In [6]:
# n_samples independent draws of the per-sample attacker costs. The prior has
# batch shape (1,), so the result has shape (n_samples, len(y_train), 1).
c_d_train = params["prior"].sample((params["n_samples"], len(y_train)))

In [17]:
def train_nash_rr_test(X_clean, y, c_d_train, params, verbose = False):
    """Learn the regression parameters by gradient descent on the prior-averaged hypergradient.

    Each outer epoch averages ``compute_backwar_derivative`` over the first
    ``params["n_samples"]`` draws in ``c_d_train`` and takes one step of size
    ``params["outer_lr"]``.

    Parameters
    ----------
    X_clean : torch.Tensor
        Clean training data, shape (n, d).
    y : training targets, forwarded to ``compute_backwar_derivative``.
    c_d_train : torch.Tensor
        Sampled cost coefficients; each ``c_d_train[j]`` must hold n values
        (e.g. shape (n_samples, n, 1) as drawn from the Gamma prior).
    params : dict
        Reads "outer_lr", "outer_epochs", "n_samples" (plus what the inner routine reads).
    verbose : bool
        Print the epoch index after every outer step.

    Returns
    -------
    torch.Tensor
        Learned parameters, shape (1, d + 1), as a leaf tensor with requires_grad=True.
    """
    lr = params["outer_lr"]
    T = params["outer_epochs"]
    n_samples = params["n_samples"]
    w = torch.randn(1, X_clean.shape[1] + 1, requires_grad=True)

    for i in range(T):
        avg_grad = torch.zeros(1, X_clean.shape[1] + 1)
        for j in range(n_samples):
            # Pass c_d as an (n, 1) column so it weights the (n, 1) residuals
            # element-wise; an (n,) vector would broadcast to an (n, n) product.
            c_d = c_d_train[j].reshape(-1, 1)
            avg_grad += compute_backwar_derivative(X_clean, y, w, c_d, params)
        avg_grad /= n_samples
        # Re-leaf w after the step so the autograd graph does not chain
        # through every previous epoch.
        w = (w - lr * avg_grad).detach().requires_grad_()
        if verbose:
            print( 'epoch {}'.format(i) )
    return w

In [18]:
# NOTE(review): with outer_epochs=200, n_samples=10 and inner_epochs=200 this
# run is very long (the saved output shows it was interrupted after epoch 0).
# Keep the learned weights under a name instead of relying on `_` / Out[N].
w_nash = train_nash_rr_test(X_train, y_train, c_d_train, params, verbose = True)
w_nash

epoch 0


KeyboardInterrupt: 