In [24]:
# https://github.com/llSourcell/pytorch_in_5_minutes/blob/master/demo.py

import torch
from torch.autograd import Variable
import pandas as pd
from random import randint
from ast import literal_eval
import numpy as np
import torch.nn as nn
from torchviz import make_dot

In [279]:
from torch.autograd import Function
from torch.nn.modules.distance import PairwiseDistance

class TripletLoss(Function):
    
    def __init__(self, alpha):
        super(TripletLoss, self).__init__()
        self.alpha = alpha
        self.pdist  = PairwiseDistance(2)
        
    def forward(self, anchor, positive, negative):
        pos_dist   = self.pdist.forward(anchor, positive).pow(2)
        neg_dist   = self.pdist.forward(anchor, negative).pow(2)
        hinge_dist = torch.clamp(self.alpha + pos_dist - neg_dist, min = 0.0)
        loss       = torch.mean(hinge_dist)
        return loss

In [15]:
%%time
#read in relevant data
trainData = torch.from_numpy(np.loadtxt('data/trainData.txt', dtype=np.float32))
queryData = torch.from_numpy(np.loadtxt('data/queryData.txt', dtype=np.float32))
df =  pd.read_pickle("./data/KNN.pkl")

CPU times: user 17.3 s, sys: 8.3 s, total: 25.6 s
Wall time: 26.7 s


In [300]:
# BATCH_SIZE is batch size; INPUT_D is input dimension; OUTPUT_D is output dimension; 
BATCH_SIZE, INPUT_D, HIDDEN_D, OUTPUT_D = 100, 192, 128, 128
ALPHA = 0.5
LEARNING_RATE = 1
K = 5

def init_model():
    print("--- Initialising Model Params --- ")
    # Create random Tensor for trainable features, and wrap them in Variables.
    # requires_grad=True indicates that we want to compute gradients wrt these Variables during the backward pass.
    anchors = Variable(torch.randn(OUTPUT_D, INPUT_D).type(torch.FloatTensor), requires_grad=True)
    # weights = Variable(torch.randn(HIDDEN_D, OUTPUT_D).type(torch.FloatTensor), requires_grad=True)

    # set b0 to be mean value
    aggregate = torch.zeros(OUTPUT_D)
    for point in queryData:
        w0 = torch.norm(point.t() - anchors, 2, 1)
        aggregate += w0

    biases = Variable((aggregate/queryData.shape[0]).reshape(-1, 1), requires_grad=True)
    print("--- Done. Begining training ---")
    return anchors, biases

In [302]:
def generateTripplet(index):
    point = queryData[index].reshape(-1, 1)
    pos = trainData[df.iloc[index].KNN[randint(0,K)]].reshape(-1, 1)
    neg = trainData[df.iloc[index].KNN[randint(K, len(df)-1)]].reshape(-1, 1)
    return point, pos, neg

In [288]:
sigmoid = nn.Sigmoid()
def forward_pass(query):
    return sigmoid(torch.norm(query.t() - anchors, 2, 1).reshape(-1, 1) - biases)

In [301]:
#training
anchors, biases = init_model()

for epoch in range(100):
    
    #generate batch and compute collective loss for batch
    batch_indicies = np.random.choice(queryData.shape[0], BATCH_SIZE, replace=False)
    loss = 0
    for index in batch_indicies:
        query, pos, neg = generateTripplet(1)
        queryMapped, posMapped, negMapped = [forward_pass(x) for x in [query, pos, neg]]
        triplet_loss = TripletLoss(ALPHA).forward(queryMapped, posMapped, negMapped)
        loss += triplet_loss
    
    print(epoch, loss.data)

        
    loss.backward()

    # Update params using gradient descent
    biases.data -= LEARNING_RATE * biases.grad.data
    anchors.data -= LEARNING_RATE * anchors.grad.data

    # Manually zero the gradients 
    biases.grad.data.zero_()
    anchors.grad.data.zero_()

0 tensor(41.4934)
1 tensor(46.4123)
2 tensor(41.4779)
3 tensor(46.9142)
4 tensor(36.1964)
5 tensor(48.4893)
6 tensor(44.7107)
7 tensor(43.6143)
8 tensor(39.9611)
9 tensor(40.5532)
10 tensor(37.3945)
11 tensor(38.8403)
12 tensor(44.3418)
13 tensor(35.8196)
14 tensor(43.3604)
15 tensor(44.6226)
16 tensor(41.1356)
17 tensor(44.8081)
18 tensor(43.1640)
19 tensor(48.7209)
20 tensor(40.4168)
21 tensor(37.3136)
22 tensor(42.7325)
23 tensor(42.7834)
24 tensor(42.8964)
25 tensor(43.2495)
26 tensor(40.1771)
27 tensor(45.3841)
28 tensor(42.8846)
29 tensor(40.0212)
30 tensor(40.4052)
31 tensor(38.5106)
32 tensor(44.9346)
33 tensor(35.2529)
34 tensor(43.5309)
35 tensor(42.1153)
36 tensor(44.6548)
37 tensor(35.9885)
38 tensor(41.5562)
39 tensor(41.4608)
40 tensor(39.7521)
41 tensor(39.3504)
42 tensor(40.2714)
43 tensor(42.0676)
44 tensor(33.2751)
45 tensor(36.6382)
46 tensor(35.3337)
47 tensor(40.9315)
48 tensor(32.6623)
49 tensor(40.0415)
50 tensor(37.0852)
51 tensor(42.2910)
52 tensor(35.8593)
53 

In [286]:
forward_pass(query) - forward_pass(neg)

tensor([[-9.7731e-01],
        [-9.8997e-01],
        [-9.9090e-01],
        [-9.8967e-01],
        [-9.9262e-01],
        [-8.8068e-01],
        [-9.9247e-01],
        [-8.0227e-01],
        [-9.4777e-01],
        [-9.8801e-01],
        [-8.6084e-01],
        [-9.9064e-01],
        [-8.1740e-01],
        [-9.7572e-01],
        [-8.5580e-01],
        [-9.6548e-01],
        [-9.8604e-01],
        [-9.8229e-01],
        [-7.0966e-01],
        [-9.8856e-01],
        [-8.3321e-01],
        [-7.5851e-01],
        [-9.9034e-01],
        [-8.9640e-01],
        [-7.1946e-01],
        [-9.7243e-01],
        [-9.5354e-01],
        [-9.9274e-01],
        [ 9.6548e-01],
        [-9.7705e-01],
        [-9.9076e-01],
        [-7.7688e-01],
        [-9.8993e-01],
        [-9.8569e-01],
        [-9.1570e-01],
        [-1.3076e-03],
        [-9.0260e-01],
        [-9.0607e-01],
        [-9.8849e-01],
        [ 9.9072e-05],
        [-5.2665e-06],
        [-9.7986e-01],
        [-9.8623e-01],
        [-9