In [24]:
# https://github.com/llSourcell/pytorch_in_5_minutes/blob/master/demo.py

import torch
from torch.autograd import Variable
import pandas as pd
from random import randint
from ast import literal_eval
import numpy as np
import torch.nn as nn
from torchviz import make_dot

In [13]:
from torch.autograd import Function
from torch.nn.modules.distance import PairwiseDistance

class TripletLoss(Function):
    
    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.pdist  = PairwiseDistance(2)
        
    def forward(self, anchor, positive, negative):
        pos_dist   = self.pdist.forward(anchor, positive).pow(2)
        neg_dist   = self.pdist.forward(anchor, negative).pow(2)
        hinge_dist = torch.clamp(self.margin + pos_dist - neg_dist, min = 0.0)
        loss       = torch.mean(hinge_dist)
        return loss

In [15]:
%%time
#read in relevant data
trainData = torch.from_numpy(np.loadtxt('data/trainData.txt', dtype=np.float32))
queryData = torch.from_numpy(np.loadtxt('data/queryData.txt', dtype=np.float32))
df =  pd.read_pickle("./data/KNN.pkl")

CPU times: user 17.3 s, sys: 8.3 s, total: 25.6 s
Wall time: 26.7 s


In [238]:
# BATCH_SIZE is batch size; INPUT_D is input dimension; OUTPUT_D is output dimension; 
BATCH_SIZE, INPUT_D, HIDDEN_D, OUTPUT_D = 1000, 192, 128, 128
ALPHA = 0.5
LEARNING_RATE = 1e-3
K = 5

# Create random Tensor for trainable features, and wrap them in Variables.
# requires_grad=True indicates that we want to compute gradients wrt these Variables during the backward pass.
anchors = Variable(torch.randn(OUTPUT_D, INPUT_D).type(torch.FloatTensor), requires_grad=True)
# weights = Variable(torch.randn(HIDDEN_D, OUTPUT_D).type(torch.FloatTensor), requires_grad=True)

# set b0 to be mean value
aggregate = torch.zeros(OUTPUT_D)
for point in queryData:
    w0 = torch.norm(point.t() - anchors, 2, 1)
    aggregate += w0
    
biases = Variable((aggregate/queryData.shape[0]).reshape(-1, 1), requires_grad=True)

In [17]:
def generateTripplet(index):
    point = queryData[index].reshape(-1, 1)
    pos = trainData[df.iloc[index].KNN[randint(0,K)]].reshape(-1, 1)
    neg = trainData[df.iloc[index].KNN[randint(K, len(df)-1)]].reshape(-1, 1)
    return point, pos, neg

In [217]:
sigmoid = nn.Sigmoid()
def forward_pass(query):
    return sigmoid(torch.norm(query.t() - anchors, 2, 1).reshape(-1, 1) - biases)

In [239]:
#training
for epoch in range(100):
    
    #generate batch and compute collective loss for batch
    batch_indicies = np.random.choice(queryData.shape[0], BATCH_SIZE, replace=False)
    loss = 0
    for index in batch_indicies:
        query, pos, neg = generateTripplet(index)
        queryMapped, posMapped, negMapped = [forward_pass(x) for x in [query, pos, neg]]
        triplet_loss = TripletLoss(ALPHA).forward(queryMapped, posMapped, negMapped)
        loss += triplet_loss
    
    print(epoch, loss.data)

    # Use autograd to compute the backward pass. This call will compute the
    # gradient of loss with respect to all Variables with requires_grad=True.
    # After this we can call var.grad on variables
    loss.backward()

    # Update weights using gradient descent
    biases.data -= LEARNING_RATE * biases.grad.data
    anchors.data -= LEARNING_RATE * anchors.grad.data

    # Manually zero the gradients 
    biases.grad.data.zero_()
    anchors.grad.data.zero_()

0 tensor(507.2083)
1 tensor(521.0921)
2 tensor(520.4835)
3 tensor(517.2171)
4 tensor(511.0168)
5 tensor(510.6352)
6 tensor(533.2737)
7 tensor(522.0214)
8 tensor(519.0345)
9 tensor(527.9480)
10 tensor(528.4173)
11 tensor(506.4541)
12 tensor(513.2556)
13 tensor(515.3956)
14 tensor(525.0693)
15 tensor(510.0119)
16 tensor(525.0049)
17 tensor(518.9528)
18 tensor(512.3215)
19 tensor(518.0714)
20 tensor(518.8181)
21 tensor(512.0618)
22 tensor(514.3668)
23 tensor(508.6679)
24 tensor(518.6509)
25 tensor(517.7710)
26 tensor(517.2629)
27 tensor(516.8866)
28 tensor(506.8255)
29 tensor(519.1831)
30 tensor(530.7878)
31 tensor(518.4875)
32 tensor(516.9802)
33 tensor(519.2198)
34 tensor(510.8516)
35 tensor(522.6309)
36 tensor(530.8637)
37 tensor(510.6067)
38 tensor(519.1215)
39 tensor(522.8857)
40 tensor(521.7526)
41 tensor(526.3864)
42 tensor(505.8412)
43 tensor(522.1140)
44 tensor(515.7787)
45 tensor(518.8862)
46 tensor(526.0958)
47 tensor(519.3914)
48 tensor(503.7811)
49 tensor(504.5177)
50 tensor(