In [1]:
import torch.nn.functional as F
import torch
import torch.nn as nn
import sys
sys.path.insert(1, '/Users/derekhuang/Documents/Research/sequence_similarity_search/classes')
sys.path.insert(1, '/Users/derekhuang/Documents/Research/fast-soft-sort/fast_soft_sort')
from data_classes import BERTDataset
from dist_perm import DistPerm
import utils
from pytorch_ops import soft_rank
from sklearn.neighbors import NearestNeighbors
import numpy as np

class AnchorNet(nn.Module):
    def __init__(self, num_anchs, d, k):
        super(AnchorNet, self).__init__()
        self.anchors = nn.Linear(d, num_anchs)
        self.softmax = nn.Softmax(dim=1)
        self.k = k

    def forward(self, data, query):
        data_rank = torch.clamp(soft_rank(self.anchors(data), direction="DESCENDING", regularization_strength=.001), max=k)
#         data_rank = soft_rank(self.anchors(data), direction="DESCENDING")
        
        query_rank = torch.clamp(soft_rank(self.anchors(query), direction="DESCENDING", regularization_strength=.001), max=k)
#         query_rank = soft_rank(self.anchors(query), direction="DESCENDING")
        
        out = self.softmax(torch.matmul(query_rank, data_rank.T))
        return out



In [2]:
n = 20000
D = 128
num_queries = 3200
num_anchors = 128
R = 100
k = 32

In [3]:
file_q = '../datasets/Q.pt'
file_db = '../datasets/D.pt'
data_source = BERTDataset(file_q, file_db, n)
db = data_source.generate_db()
data = np.array(db).astype(np.float32)
queries = data_source.generate_queries(num_queries)
quers = np.array(queries).astype(np.float32)

torch.Size([685571, 128])
torch.Size([100, 32, 128])


In [4]:
index_l2 = NearestNeighbors()
index_l2.fit(data)
true = torch.tensor(index_l2.kneighbors(quers, n_neighbors=1)[1])
truth = torch.nn.functional.one_hot(true.squeeze(), n)

In [5]:
import torch, itertools
from torch.utils.data import TensorDataset, random_split, DataLoader

def dataset_split(dataset, train_frac):
    length = len(dataset)
   
    train_length = int(length * train_frac)
    valid_length = int((length - train_length) / 2)
    test_length  = length - train_length - valid_length

    sets = random_split(dataset, (train_length, valid_length, test_length))
    dataset = {name: set for name, set in zip(('train', 'valid', 'test'), sets)}
    return dataset

In [6]:
# Fits fine in mem
batch_size=3200
train_split = .8 

query_data = []
for i in range(queries.shape[0]):
    query_data.append([queries[i], true[i]])

query_datasets = dataset_split(query_data, train_split)
query_loader = torch.utils.data.DataLoader(dataset=query_datasets['train'], 
                                           batch_size=batch_size, 
                                           shuffle=True)


test_loader = torch.utils.data.DataLoader(dataset=query_datasets['test'], 
                                          batch_size=batch_size, 
                                          shuffle=False)

model = AnchorNet(num_anchors, D, k)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=.001)  


In [None]:
for epoch in range(1000):
    for i, (q, l) in enumerate(query_loader):  
        outputs = model(db,queries)
        loss = criterion(outputs, true.squeeze())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
#     print('Loss: {:.4f}'.format(loss.item()))
    if epoch % 1 == 0:
        with torch.no_grad():
            correct = 0
            total = 0
            for q, l in test_loader:
                outputs = model(db,q)
                _, predicted = torch.max(outputs.data, 1)
#                 print(predicted.shape)
                total += l.size(0)
#                 print(l.size(0))
                correct += (predicted == l.flatten()).sum().item()

            print ('Epoch [{}], Loss: {:.4f}, recall_test: {:.4f}'
                   .format(epoch+1, loss.item(), correct / total))

# # Save the model checkpoint
# torch.save(model.state_dict(), 'model.ckpt')

Epoch [1], Loss: 9.4589, recall_test: 0.4656
Epoch [2], Loss: 9.4561, recall_test: 0.4656
Epoch [3], Loss: 9.4539, recall_test: 0.4906
Epoch [4], Loss: 9.4478, recall_test: 0.4844
Epoch [5], Loss: 9.4424, recall_test: 0.4844
Epoch [6], Loss: 9.4368, recall_test: 0.4906
Epoch [7], Loss: 9.4317, recall_test: 0.4906
Epoch [8], Loss: 9.4301, recall_test: 0.5000
Epoch [9], Loss: 9.4298, recall_test: 0.5031
Epoch [10], Loss: 9.4259, recall_test: 0.5062
Epoch [11], Loss: 9.4219, recall_test: 0.5188
Epoch [12], Loss: 9.4209, recall_test: 0.5281
Epoch [13], Loss: 9.4141, recall_test: 0.5375
Epoch [14], Loss: 9.4087, recall_test: 0.5375
Epoch [15], Loss: 9.4069, recall_test: 0.5344
Epoch [16], Loss: 9.4076, recall_test: 0.5250
Epoch [17], Loss: 9.4085, recall_test: 0.5281
Epoch [18], Loss: 9.4002, recall_test: 0.5344
Epoch [19], Loss: 9.3956, recall_test: 0.5344
Epoch [20], Loss: 9.3935, recall_test: 0.5375
Epoch [21], Loss: 9.3879, recall_test: 0.5312
Epoch [22], Loss: 9.3884, recall_test: 0.53