In [1]:
%load_ext autoreload
%autoreload 2
import torch
from networks import AnchorNet, EmbeddingNet, TripletNet
from datasets import TripletAudio
import numpy as np
import pandas as pd

#### Checks that the KNN ranking of datapoints is correct, i.e. dist(anchor, neighbour i) < dist(anchor, neighbour j) for i < j

In [2]:
def dist(x, y):
    """Return the squared Euclidean distance between ``x`` and ``y``.

    The element-wise difference is squared and summed over every element;
    no square root is taken, so the result is only suitable for comparing
    distances, not for reporting them.
    """
    delta = x - y
    return delta.pow(2).sum()

In [3]:
# Number of randomly sampled rows to verify, and the seed that makes the
# sample (and therefore any failure) reproducible after a kernel restart.
NUM_ROWS_TO_CHECK = 10
SAMPLE_SEED = 0

train_data = torch.from_numpy(np.loadtxt('data/trainData.txt', dtype=np.float32))
train_KNN = pd.read_csv('data/trainKNN.csv', index_col=0)
# check num elems are same in DF and dataset
assert train_data.shape[0] == train_KNN.shape[0], (
    f"row count mismatch: {train_data.shape[0]} data rows vs "
    f"{train_KNN.shape[0]} KNN rows"
)
# Randomly pick rows to check. The sample size is capped at the number of
# available rows because sampling without replacement raises otherwise.
num_rows = train_KNN.shape[0]
rng = np.random.default_rng(SAMPLE_SEED)
selected_indicies = rng.choice(
    num_rows, size=min(NUM_ROWS_TO_CHECK, num_rows), replace=False
)
# ensure that the ordering of neighbours is correct for that row
for index in selected_indicies:
    row = train_KNN.iloc[index]
    anchor = train_data[index]
    # Each entry of the row is used as a positional index into train_data.
    neighbours = row.tolist()
    for prev, cur in zip(neighbours, neighbours[1:]):
        prev_dist = dist(anchor, train_data[prev])
        cur_dist = dist(anchor, train_data[cur])
        # Consecutive neighbours must be strictly further from the anchor.
        assert prev_dist < cur_dist, (
            f"row {index}: neighbour {prev} (squared dist {float(prev_dist)}) "
            f"is not closer than neighbour {cur} (squared dist {float(cur_dist)})"
        )
print("All Good!")

All Good!


#### Generates triplets and ensures that the anchor-pos distance is smaller than the anchor-neg distance

In [4]:
from datasets import TripletAudio
# NOTE(review): the meaning of the positional arguments (False, 5, 5, 5) is
# defined by TripletAudio, which is not visible here -- confirm against datasets.py.
dataset = TripletAudio(False, 5, 5, 5)
for i, triplet in enumerate(dataset):
    # Flatten each element to a column so the three can be subtracted element-wise.
    anchor, pos, neg = [x.reshape(-1,1) for x in triplet]
    pos_sq_dist = ((anchor - pos) ** 2).sum()
    neg_sq_dist = ((anchor - neg) ** 2).sum()
    # The positive must be strictly closer to the anchor than the negative.
    assert pos_sq_dist - neg_sq_dist < 0, (
        f"triplet {i}: anchor-pos squared dist {float(pos_sq_dist)} is not "
        f"smaller than anchor-neg squared dist {float(neg_sq_dist)}"
    )
print("All Good!")

All Good!


#### Runs two checks for the embedding process:

In [5]:
# Each case pairs the input points with the value expected at position [0][0]
# of the AnchorNet output for those points.
embedding_cases = [
    (torch.tensor([[2.],[3.]]), 5**(1/2) - 1),    # test 1
    (torch.tensor([[12.],[8.]]), 170**(1/2) - 1),  # test 2
]
for case_number, (train_data, expected_value) in enumerate(embedding_cases, start=1):
    anchor = AnchorNet(train_data, 2, 1, True)
    # NOTE(review): `emb` is never read; the construction is kept in case
    # EmbeddingNet has side effects on `anchor` -- confirm against networks.py.
    emb = EmbeddingNet(anchor)
    output = anchor.forward(train_data)[0][0]
    expected_output = torch.tensor(expected_value, dtype=output.dtype)
    # Compare with a tolerance: the expected value is computed in Python
    # double precision, so bit-exact equality with the network output is
    # not guaranteed.
    assert torch.allclose(output, expected_output), (
        f"test {case_number}: expected {float(expected_output)}, got {float(output)}"
    )
print("All Good!")

All Good!
