In [11]:
# https://github.com/llSourcell/pytorch_in_5_minutes/blob/master/demo.py

import torch
from torch.autograd import Variable
import pandas as pd
from random import randint
from ast import literal_eval
import numpy as np
import torch.nn as nn
from torchviz import make_dot

In [20]:
from torch.autograd import Function
from torch.nn.modules.distance import PairwiseDistance

class TripletLoss(Function):
    
    def __init__(self, alpha):
        super(TripletLoss, self).__init__()
        self.alpha = alpha
        self.pdist  = PairwiseDistance(2)
        
    def forward(self, anchor, positive, negative):
        pos_dist   = self.pdist.forward(anchor, positive).pow(2)
        neg_dist   = self.pdist.forward(anchor, negative).pow(2)
        hinge_dist = torch.clamp(self.alpha + pos_dist - neg_dist, min = 0.0)
        loss       = torch.mean(hinge_dist)
        return loss

In [21]:
%%time
#read in relevant data
trainData = torch.from_numpy(np.loadtxt('data/trainData.txt', dtype=np.float32))
queryData = torch.from_numpy(np.loadtxt('data/queryData.txt', dtype=np.float32))
df =  pd.read_pickle("./data/KNN.pkl")

CPU times: user 26 s, sys: 30.7 s, total: 56.7 s
Wall time: 1min


In [22]:
def generateTripplet(index):
    point = queryData[index].reshape(-1, 1)
#     pos = trainData[df.iloc[index].KNN[randint(0,K)]].reshape(-1, 1) # pos fom KNN
#     negIndicies = list(range(K,K + 10)) + list(range(df.shape[0]-20, df.shape[0]-1))
#     neg = trainData[df.iloc[index].KNN[np.random.choice(negIndicies)]].reshape(-1, 1)
#     neg = trainData[df.iloc[index].KNN[randint(K, df.shape[0]-1)]].reshape(-1, 1)
    pos = trainData[df.iloc[index].KNN[0]].reshape(-1, 1) # pos fom KNN
    neg = trainData[df.iloc[index].KNN[5004]].reshape(-1, 1)
    return point, pos, neg

In [23]:
sigmoid = nn.Sigmoid()
def forward_pass(query):
    return sigmoid(torch.norm(query.t() - anchors, 2, 1).reshape(-1, 1) - biases)

In [60]:
# BATCH_SIZE is batch size; INPUT_D is input dimension; OUTPUT_D is output dimension; 
BATCH_SIZE, INPUT_D, HIDDEN_D, OUTPUT_D = 10, 192, 128, 128
ALPHA = 0.5
LEARNING_RATE = 1e2
K = 5

def init_model():
    print("--- Initialising Model Params --- ")
    
    anchors = Variable(torch.randn(OUTPUT_D, INPUT_D).type(torch.FloatTensor), requires_grad=True)
    # weights = Variable(torch.randn(HIDDEN_D, OUTPUT_D).type(torch.FloatTensor), requires_grad=True)

    # set biases to be mean value
    aggregate = torch.zeros(OUTPUT_D)
    for point in queryData:
        w0 = torch.norm(point.t() - anchors, 2, 1)
        aggregate += w0

    biases = Variable((aggregate/queryData.shape[0]).reshape(-1, 1), requires_grad=True)
    print("--- Done. Begining training ---")
    return anchors, biases

In [61]:
#training
anchors, biases = init_model()

for epoch in range(100):

    #generate batch and compute collective loss for batch
    # UNSTABLE LEARNING WHEN SAMPLE SIZE > BATCH SIZE
    batch_indicies = np.random.choice(queryData.shape[0], BATCH_SIZE, replace=False) 
    loss = 0
    for index in batch_indicies:
        query, pos, neg = generateTripplet(index)
        queryMapped, posMapped, negMapped = [forward_pass(x) for x in [query, pos, neg]]
        triplet_loss = TripletLoss(ALPHA).forward(queryMapped, posMapped, negMapped)
        loss += triplet_loss
    
    loss /= BATCH_SIZE # computes mean so learning rate remains the same
    
    print(epoch, loss.data)

        
    loss.backward()

    # Update params using gradient descent
    biases.data -= LEARNING_RATE * biases.grad.data
    anchors.data -= LEARNING_RATE * anchors.grad.data

    # Manually zero the gradients 
    biases.grad.data.zero_()
    anchors.grad.data.zero_()

--- Initialising Model Params --- 
--- Done. Begining training ---
0 tensor(0.2355)
1 tensor(0.2663)
2 tensor(0.3368)
3 tensor(0.3510)
4 tensor(0.4320)
5 tensor(0.4403)
6 tensor(0.3730)
7 tensor(0.4323)
8 tensor(0.3196)
9 tensor(0.3957)
10 tensor(0.2928)
11 tensor(0.3656)
12 tensor(0.3768)
13 tensor(0.3730)
14 tensor(0.3298)
15 tensor(0.4228)
16 tensor(0.3350)
17 tensor(0.3694)
18 tensor(0.2834)
19 tensor(0.3947)
20 tensor(0.3734)
21 tensor(0.3021)
22 tensor(0.4551)
23 tensor(0.4408)
24 tensor(0.3470)
25 tensor(0.4468)
26 tensor(0.3623)
27 tensor(0.3312)
28 tensor(0.3105)
29 tensor(0.3379)
30 tensor(0.3874)
31 tensor(0.3624)
32 tensor(0.3141)
33 tensor(0.3365)
34 tensor(0.2891)
35 tensor(0.3846)
36 tensor(0.2998)
37 tensor(0.3924)
38 tensor(0.2779)
39 tensor(0.3754)
40 tensor(0.3768)
41 tensor(0.3728)
42 tensor(0.3337)
43 tensor(0.2802)
44 tensor(0.3923)
45 tensor(0.2336)
46 tensor(0.4020)
47 tensor(0.3124)
48 tensor(0.2705)
49 tensor(0.4131)
50 tensor(0.3653)
51 tensor(0.3607)
52 tens