### Load the DBpedia training split (features and labels) from pickle (`.pkl`) files

In [25]:
import pandas as pd
import numpy as np

# Features (x) and labels (y) of training split 1, read in that order.
# NOTE: pd.read_pickle unpickles arbitrary Python objects and can execute code --
# only load these files from a trusted source.
x_split1, y_split1 = (
    pd.read_pickle(pkl_path)
    for pkl_path in (
        'data/dbpedia_train_x_split1.pkl',
        'data/dbpedia_train_y_split1.pkl',
    )
)

In [26]:
# Inspect the training features: the rendered output shows 300 numeric columns (0-299),
# with rows indexed by the same IDs that appear in y_split1.
x_split1

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,290,291,292,293,294,295,296,297,298,299
38018,-0.005324,0.205206,0.008683,-0.309079,-0.109366,-0.258832,0.276842,0.086448,0.055326,0.220757,...,-0.169310,0.138412,-0.108493,-0.044117,0.055317,0.187381,-0.098321,-0.217412,-0.220616,0.182205
10389,0.255242,0.188798,-0.135879,-0.207213,-0.214796,-0.272929,0.200540,0.154624,-0.109058,0.141196,...,-0.181943,-0.035217,-0.184698,0.020663,0.115205,0.156353,-0.224457,-0.114330,-0.198991,-0.007543
37016,0.113450,-0.009364,-0.131690,-0.256048,-0.153683,-0.168570,0.108866,0.078384,-0.041223,0.151630,...,-0.084727,0.006616,-0.070600,-0.026646,0.144471,0.141350,-0.141484,-0.089492,-0.131767,0.109087
12221,0.104002,0.051275,-0.143594,-0.222141,-0.112486,-0.314900,0.167336,0.121801,0.031601,0.242447,...,-0.085333,0.088219,-0.251650,0.030851,0.090480,0.315288,-0.136301,-0.181215,-0.329499,0.237604
6690,0.156203,0.004777,-0.010680,-0.233068,-0.241265,-0.173814,0.221157,0.192151,-0.029584,0.185003,...,-0.111471,0.133750,-0.208221,0.016590,0.082757,0.259925,-0.131668,-0.198523,-0.185508,0.178424
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
521492,0.321894,0.364019,-0.028605,-0.211936,0.001405,-0.210255,0.225507,0.304995,-0.010479,0.030767,...,0.055579,0.000221,-0.190833,0.020330,0.090201,0.270010,-0.156809,-0.136904,-0.194771,0.115096
536501,0.202290,-0.006832,0.077930,-0.336992,-0.316965,-0.079169,0.241228,0.240695,0.034314,0.152834,...,-0.028027,0.117409,-0.205209,0.070434,0.028975,0.287230,-0.089631,-0.252135,-0.178748,0.224248
551031,0.167308,0.015951,-0.031282,-0.265090,-0.227614,-0.096121,0.032917,0.175551,0.083510,0.168240,...,-0.084468,0.134981,-0.147401,0.013449,0.119402,0.338470,-0.111547,-0.412538,-0.243889,0.096841
530237,0.238980,0.103001,-0.042302,-0.272961,-0.263882,-0.077211,0.211424,0.158734,0.089401,0.133178,...,-0.069092,0.073285,-0.195081,0.074679,0.113065,0.228045,-0.052555,-0.251632,-0.215212,0.143994


In [27]:
# Inspect the labels: a single `class` column; values in the rendered output range from 1 to 14.
# NOTE(review): labels look 1-based, while nn.CrossEntropyLoss expects 0-based class indices -- confirm before training on them.
y_split1

Unnamed: 0,class
38018,1
10389,1
37016,1
12221,1
6690,1
...,...
521492,14
536501,14
551031,14
530237,14


In [28]:
# A torch implementation of VAT.
import contextlib
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

@contextlib.contextmanager
def _disable_tracking_bn_stats(model):
    def switch_attr(m):
        if hasattr(m, 'track_running_stats'):
            m.track_running_stats ^= True
    model.apply(switch_attr)
    yield
    model.apply(switch_attr)
    
def normalize(d):
    """Return ``d`` with each row scaled to unit L2 norm.

    Expects a 2-D tensor: the norm is taken over ``dim=1`` and reshaped to a
    column so it broadcasts across each row. The tiny epsilon keeps all-zero
    rows at zero instead of producing NaN.

    The input is left unmodified (a new tensor is returned), so this is safe on
    tensors the caller still uses and on leaf tensors that require grad.
    """
    row_norms = torch.sqrt(torch.sum(d ** 2, dim=1)).view(-1, 1)
    return d / (row_norms + 1e-16)

def _l2_normalize(d):
    d_reshaped = d.view(d.shape[0], -1, *(1 for _ in range(d.dim() - 2)))
    d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-16
    return d

def _kl_div(p,q):
    '''
    D_KL(p||q) = Sum(p log p - p log q)
    '''
    logp = torch.nn.functional.log_softmax(p,dim=1)
    logq = torch.nn.functional.log_softmax(q,dim=1)
    p = torch.exp(logp)
    return (p*(logp-logq)).sum(dim=1,keepdim=True).mean()

class VATLoss(nn.Module):
    """Virtual Adversarial Training (VAT) smoothness loss.

    ``forward(model, x)`` estimates, by power iteration, the perturbation of
    ``x`` (of L2 norm ``eps`` per sample) that most changes the model's predicted
    distribution, and returns the KL divergence between the prediction on the
    clean input and on the perturbed input. No labels are needed.
    """

    def __init__(self, xi=.0001, eps=0.1, ip=2):
        """
        :xi: finite-difference step: size of the probe perturbation ``xi * d``
             used while searching for the adversarial direction.
        :eps: L2 norm (per sample) of the final adversarial perturbation.
        :ip: number of power-iteration steps used to approximate r_vadv.
        """
        super(VATLoss, self).__init__()
        self.xi = xi
        self.eps = eps
        self.ip = ip

    def forward(self, model, x):
        """Return the VAT loss (0-dim tensor) of ``model`` at input batch ``x``.

        :model: module mapping ``x`` to logits with classes along dim 1.
        :x: floating-point input batch; perturbations are sampled with the same
            shape, dtype and device as ``x``.
        """
        # Clean prediction is the fixed target: no gradient flows through it.
        with torch.no_grad():
            pred = model(x)

        # Random unit direction (same device/dtype as x) to seed the power iteration.
        d = _l2_normalize(torch.randn_like(x))

        # The extra forward passes should not update BatchNorm running statistics.
        with _disable_tracking_bn_stats(model):
            for _ in range(self.ip):
                d.requires_grad_()
                pred_hat = model(x + self.xi * d)
                # KL(clean || perturbed): the order used in the VAT definition.
                adv_distance = _kl_div(pred, pred_hat)
                # Differentiate w.r.t. d only, so parameter .grad fields (which may
                # hold gradients the caller already accumulated) are left untouched.
                grad_d, = torch.autograd.grad(adv_distance, d)
                d = _l2_normalize(grad_d)

            # grad_d carries no graph, so r_adv is treated as a constant.
            r_adv = d * self.eps
            pred_hat = model(x + r_adv)
            lds = _kl_div(pred, pred_hat)
        return lds

In [29]:
class Net(nn.Module):
    """Four-layer fully connected classifier returning unnormalised logits.

    Layer sizes: ``in_features -> hidden -> hidden -> hidden -> n_classes``, with
    ReLU after each of the first three layers. The defaults reproduce the
    original 300 -> 100 -> 100 -> 100 -> 2 network.

    NOTE(review): the y_split1 labels shown in this notebook reach 14, but the
    default head has only 2 outputs; pass ``n_classes`` to match the data (and
    make labels 0-based for ``nn.CrossEntropyLoss``) -- confirm.
    """

    def __init__(self, in_features=300, hidden=100, n_classes=2):
        super(Net, self).__init__()
        self.linear1 = nn.Linear(in_features, hidden)
        self.linear2 = nn.Linear(hidden, hidden)
        self.linear3 = nn.Linear(hidden, hidden)
        self.linear4 = nn.Linear(hidden, n_classes)

    def forward(self, x):
        """Map a ``(..., in_features)`` float tensor to ``(..., n_classes)`` logits."""
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        return self.linear4(x)

In [30]:
def train(u, X_labeled=None, y_labeled=None, vat_weight=4.):
    """Run one optimisation step on the global ``model`` with the global ``optimizer``.

    Loss = cross-entropy on the labelled set + ``vat_weight`` * VAT loss on ``u``.

    :u: tensor batch used for the (label-free) VAT term; converted to float.
    :X_labeled: labelled features; defaults to the global ``Xl_new``.
    :y_labeled: integer class labels for ``X_labeled``; defaults to the global ``yl_new``.
    :vat_weight: weight of the VAT term (defaults to the original 4.0).
    :returns: ``(classification_loss, lds)`` as tensors.
    """
    if X_labeled is None:
        X_labeled = Xl_new
    if y_labeled is None:
        y_labeled = yl_new

    # Ensure training mode: an evaluation helper may have left the model in eval mode.
    model.train()

    vat_loss = VATLoss(ip=2, xi=0.5, eps=0.5)
    cross_entropy = nn.CrossEntropyLoss()
    lds = vat_loss(model, u.float())
    # as_tensor avoids a copy/warning when the inputs are already tensors.
    output = model(torch.as_tensor(X_labeled).float())
    # CrossEntropyLoss requires int64 class indices.
    classification_loss = cross_entropy(output, torch.as_tensor(y_labeled).long())
    loss = classification_loss + vat_weight * lds
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return classification_loss, lds

def valid(win=None):
    """Predict on the global ``X_all`` and scatter-plot the predicted class-1 probability.

    Points are drawn at their first two feature coordinates and coloured by
    P(class 1); the labelled points ``X0l`` / ``X1l`` are overlaid as squares /
    triangles.

    :win: passed through unchanged (kept for backward compatibility).
    :returns: ``(win, val)``, where ``val`` is ``X_all`` with the class-1
        probability appended as a final column.

    NOTE(review): relies on ``plt`` (matplotlib.pyplot) and the globals
    ``model``, ``X_all``, ``X0l``, ``X1l`` being defined in another cell. Plotting
    only the first two columns is meaningful for 2-D inputs -- confirm for the
    300-dimensional features loaded in this notebook.
    """
    was_training = model.training
    model.eval()
    try:
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            logits = model(torch.from_numpy(X_all).float())
        z_pred = F.softmax(logits, dim=1).numpy()[:, 1]
    finally:
        # Restore the caller's mode instead of leaving the model in eval mode.
        model.train(was_training)
    val = np.c_[X_all, z_pred]

    fig, ax = plt.subplots(figsize=(10, 10))
    h = ax.scatter(X_all[:, 0], X_all[:, 1], c=z_pred, vmin=0, vmax=1, cmap='seismic')
    ax.scatter(X0l[:, 0], X0l[:, 1], c='C0', marker='s', s=100)
    ax.scatter(X1l[:, 0], X1l[:, 1], c='C1', marker='v', s=100)
    ax.set(xlabel='feature 0', ylabel='feature 1', title='Predicted P(class 1)')
    fig.colorbar(h)
    return win, val

In [None]:
# Training configuration.
SEED = 0          # seeds weight init, shuffling and VAT perturbations for reproducibility
N_EPOCHS = 150
BATCH_SIZE = 121
LOG_EVERY = 200   # print metrics every LOG_EVERY optimisation steps

torch.manual_seed(SEED)
i_total_step = 0
model = Net()
optimizer = optim.Adam(model.parameters(), lr=1e-3, eps=1e-3)
gamma = 4e1  # weight of the VAT term in the total loss

# Legacy aliases kept because later cells may use them.
var = torch.autograd.Variable
ftn = torch.FloatTensor
ltn = torch.LongTensor
dataset = torch.utils.data.TensorDataset(ftn(X_all), ltn(y_all))

# Loop-invariant objects, built once instead of once per batch.
# (A shuffling DataLoader draws a new permutation every time it is iterated.)
data_loader = torch.utils.data.DataLoader(dataset, BATCH_SIZE, shuffle=True)
vat_loss = VATLoss()
cross_entropy = nn.CrossEntropyLoss()
Xl_tensor = torch.as_tensor(Xl).float()
yl_tensor = torch.as_tensor(yl)

for epoch in range(N_EPOCHS):
    # Only the inputs are used: the VAT term needs no labels.
    for u, _ in data_loader:
        i_total_step += 1
        lds = vat_loss(model, u.float())
        output = model(Xl_tensor)
        classification_loss = cross_entropy(output, yl_tensor)
        loss = classification_loss + gamma * lds
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i_total_step % LOG_EVERY == 0:
            # Accuracy on the labelled set, from the logits computed before this step.
            accuracy = accuracy_score(yl, np.argmax(output.detach().numpy(), axis=1))
            print("CrossEntropyLoss: %f" % classification_loss.item())
            print("VATLoss: %f" % (gamma * lds.item()))
            print("Accuracy: %f" % accuracy)
            print("---------------------------------")

In [None]:
# Evaluate on the full dataset with gradients off and the model in eval mode.
was_training = model.training
model.eval()
with torch.no_grad():
    logits = model(torch.as_tensor(X_all, dtype=torch.float32))
model.train(was_training)
# argmax over classes is identical to thresholding P(class 1) at 0.5 for a
# 2-class head (ties go to class 0 in both), and stays correct for more classes.
z_pred = logits.argmax(dim=1).numpy()
print("{0} vat acc".format(accuracy_score(y_all, z_pred)))