In [94]:
import torch.nn.functional as F
import torch
import torch.nn as nn
import sys
sys.path.insert(1, '/Users/derekhuang/Documents/Research/sequence_similarity_search/classes')
sys.path.insert(1, '/Users/derekhuang/Documents/Research/fast-soft-sort/fast_soft_sort')
from data_classes import BERTDataset
from dist_perm import DistPerm
import utils
from pytorch_ops import soft_rank
from sklearn.neighbors import NearestNeighbors
import numpy as np

class AnchorNet(nn.Module):
    """Score every query against every document through learned anchor ranks.

    One linear layer projects a ``d``-dimensional vector onto ``num_anchs``
    anchor scores. Those scores are passed through ``soft_rank`` (imported
    from the project's ``pytorch_ops`` module; presumably a differentiable
    ranking -- confirm against that module), clamped from above at ``k``,
    and the query/document rank vectors are compared with a dot product.
    """

    def __init__(self, num_anchs, d, k):
        """
        Args:
            num_anchs: number of output features of the anchor projection.
            d: number of input features of the anchor projection.
            k: upper bound applied to every rank value in ``forward``.
        """
        super(AnchorNet, self).__init__()
        self.anchors = nn.Linear(d, num_anchs)
        self.softmax = nn.Softmax(dim=1)
        self.k = k

    def _clamped_rank(self, points):
        """Project ``points`` onto the anchors, soft-rank them and cap at ``self.k``."""
        ranks = soft_rank(self.anchors(points), direction="DESCENDING", regularization_strength=.001)
        # Use the instance's own k: reading a module-level `k` here ignored the
        # constructor argument and failed whenever no such global existed.
        return torch.clamp(ranks, max=self.k)

    def forward(self, data, query):
        """Return a (num_queries, num_data) matrix whose rows are softmax-normalised.

        Args:
            data: batch of document vectors, one per row.
            query: batch of query vectors, one per row.
        """
        data_rank = self._clamped_rank(data)
        query_rank = self._clamped_rank(query)
        # NOTE(review): the notebook feeds this output to nn.CrossEntropyLoss,
        # which applies log-softmax itself, so the scores are normalised twice.
        # Left unchanged here because it alters training behaviour -- confirm intent.
        return self.softmax(torch.matmul(query_rank, data_rank.T))



In [101]:
# Experiment configuration read by the cells that follow.
n = 100_000           # third argument handed to BERTDataset
D = 128               # input feature size handed to AnchorNet
num_queries = 3_200   # count requested from generate_queries
num_anchors = 128     # anchor count handed to AnchorNet
R = 100               # not referenced by any of the visible cells
k = 8                 # rank cap handed to AnchorNet

In [102]:
# Query and document files, addressed relative to the notebook's working directory
# (presumably torch-serialised tensors, going by the .pt extension -- confirm).
file_q = '../datasets/Q.pt'
file_db = '../datasets/D.pt'
# BERTDataset is a project class from data_classes; n is passed as its third argument.
data_source = BERTDataset(file_q, file_db, n)
# db is what the later cell splits into train/valid/test documents.
db = data_source.generate_db()
# float32 NumPy copy of the documents; not read again in the visible cells.
data = np.array(db).astype(np.float32)
queries = data_source.generate_queries(num_queries)
# float32 NumPy copy of the queries; this is what the later cell splits into query sets.
quers = np.array(queries).astype(np.float32)

torch.Size([685571, 128])
torch.Size([100, 32, 128])


In [103]:
import torch, itertools
from torch.utils.data import TensorDataset, random_split, DataLoader
from sklearn.model_selection import train_test_split

def dataset_split(dataset, train_frac, generator=None):
    """Randomly partition ``dataset`` into train / valid / test subsets.

    ``train_frac`` of the items (rounded down) go to 'train'. The remainder
    is halved: 'valid' gets the rounded-down half and 'test' gets the rest,
    so the three lengths always sum to ``len(dataset)``.

    Args:
        dataset: any object supporting ``len()`` and integer indexing.
        train_frac: fraction of items for the training subset, in [0, 1].
        generator: optional ``torch.Generator`` forwarded to ``random_split``
            so the partition can be reproduced. When omitted, the global
            torch RNG is used, exactly as before.

    Returns:
        dict with keys 'train', 'valid' and 'test' mapping to the subsets
        produced by ``random_split``.

    Raises:
        ValueError: if ``train_frac`` lies outside [0, 1].
    """
    if not 0 <= train_frac <= 1:
        raise ValueError("train_frac must be within [0, 1], got {!r}".format(train_frac))

    length = len(dataset)
    train_length = int(length * train_frac)
    valid_length = (length - train_length) // 2
    test_length = length - train_length - valid_length
    lengths = (train_length, valid_length, test_length)

    # Only forward the generator when one was supplied, so the default call
    # is unchanged from the original two-argument random_split call.
    if generator is None:
        subsets = random_split(dataset, lengths)
    else:
        subsets = random_split(dataset, lengths, generator=generator)

    return dict(zip(('train', 'valid', 'test'), subsets))


In [104]:
# Fits fine in memory
batch_size = 3200
train_split = 0.8

# Random train/valid/test partitions of the queries and of the documents.
query_datasets = dataset_split(quers, train_split)
doc_datasets = dataset_split(db, train_split)

# Fixed train queries: never shuffled.
query_data_loader = DataLoader(query_datasets['train'], batch_size=batch_size, shuffle=False)

# Fixed test queries: never shuffled.
query_data_test_loader = DataLoader(query_datasets['test'], batch_size=batch_size, shuffle=False)

# Train documents: reshuffled on every pass, 10000 per batch.
docs_loader = DataLoader(doc_datasets['train'], batch_size=10000, shuffle=True)

# Test documents: fixed order, 10000 per batch.
docs_test_loader = DataLoader(doc_datasets['test'], batch_size=10000, shuffle=False)

# Model, loss and optimiser used by the evaluation and training cells.
model = AnchorNet(num_anchors, D, k)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


In [105]:
# Label one of the fixed query sets with its nearest neighbour inside a batch of docs.
# d   = batch of documents to index
# ret = which query set to label: 'query' (train queries) or 'test'
def return_loader(d, ret):
    """Build a DataLoader of (query, nearest-document-index) pairs for ``d``.

    Args:
        d: batch of documents; fitted into a scikit-learn ``NearestNeighbors``
            index built with default settings.
        ret: ``'query'`` to label the first batch of ``query_data_loader``,
            ``'test'`` to label the first batch of ``query_data_test_loader``.

    Returns:
        A DataLoader yielding ``[queries, labels]`` batches, where each label
        is the row index within ``d`` of that query's single nearest neighbour.

    Raises:
        ValueError: if ``ret`` is neither ``'query'`` nor ``'test'``.
    """
    # Validate before fitting the index: an unknown value used to leave the
    # queries as None and fail later with an unrelated AttributeError.
    if ret == 'query':
        source = query_data_loader
    elif ret == 'test':
        source = query_data_test_loader
    else:
        raise ValueError("ret must be 'query' or 'test', got {!r}".format(ret))

    index_l2 = NearestNeighbors()
    index_l2.fit(d)

    # Only the first batch of the chosen loader is used.
    q = next(iter(source))

    # kneighbors returns (distances, indices); keep the indices, one per query.
    true = torch.tensor(index_l2.kneighbors(q, n_neighbors=1)[1])
    query_data = [[q[i], true[i]] for i in range(q.shape[0])]

    # train_frac=1 puts every pair in 'train'; dataset_split still permutes them randomly.
    test_set = dataset_split(query_data, 1)

    # Return the labelled pairs as a dataloader
    data = torch.utils.data.DataLoader(dataset=test_set['train'],
                                       batch_size=batch_size,
                                       shuffle=False)
    return data

In [106]:
# Evaluation pass: for each batch of test docs, label the test queries with their
# true nearest neighbour in that batch and check the model's top-scoring document.
with torch.no_grad():
    correct = 0
    total = 0
    for d in docs_test_loader:
        test_loader = return_loader(d, ret='test')
        for q, l in test_loader:
            # flatten() keeps the target 1-D even when the batch holds a single
            # query; squeeze() would collapse it to a 0-d tensor in that case.
            labels = l.flatten()
            outputs = model(d, q)
            test_loss = criterion(outputs, labels)
            # Top-1 hit: the highest-scoring doc is the true nearest neighbour
            # (recall@1; recall in general is TP / (TP + FN)).
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Printed once per doc batch: test_loss is from the last query batch only,
        # while the hit rate accumulates over all doc batches seen so far.
        print ('Test Loss: {:.4f}, recall_test: {:.4f}'
               .format(test_loss.item(), correct / total))

Test Loss: 8.9974, recall_test: 0.2156


In [55]:
for epoch in range(1000):
    # Train step: take a batch of train documents, then build the dataloader holding
    # the train queries together with their true nearest-neighbour labels in that batch.
    for step, d in enumerate(docs_loader):
        query_loader = return_loader(d, ret='query')
        for i, (q, l) in enumerate(query_loader):
            # flatten() keeps the target 1-D even when the batch holds a single
            # query; squeeze() would collapse it to a 0-d tensor in that case.
            labels = l.flatten()
            outputs = model(d, q)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Loss of the last query batch for this document batch.
        print('Epoch [{}] Step [{}] Loss: {:.4f}'.format(epoch, step, loss.item()))

    # Eval step: repeat with the test documents and test queries.
    # `epoch % 1` is always 0, so this runs after every epoch; raise the modulus to thin it out.
    if epoch % 1 == 0:
        with torch.no_grad():
            correct = 0
            total = 0
            for d in docs_test_loader:
                test_loader = return_loader(d, ret='test')
                for q, l in test_loader:
                    labels = l.flatten()
                    outputs = model(d, q)
                    test_loss = criterion(outputs, labels)
                    # Top-1 hit: the highest-scoring doc is the true nearest neighbour
                    # (recall@1; recall in general is TP / (TP + FN)).
                    _, predicted = torch.max(outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

                # Printed once per test doc batch; the hit rate accumulates across batches.
                print ('Epoch [{}], Loss: {:.4f}, Test Loss: {:.4f}, recall_test: {:.4f}'
                       .format(epoch+1, loss.item(), test_loss.item(), correct / total))

# # Save the model checkpoint
# torch.save(model.state_dict(), 'model.ckpt')

tensor([[-1.1450e-01,  4.7607e-02,  7.6660e-02,  ..., -1.4648e-03,
          6.3293e-02,  6.4758e-02],
        [-1.3733e-04,  2.6465e-01, -5.5969e-02,  ..., -2.3499e-02,
          5.6030e-02, -1.1096e-01],
        [ 1.0323e-02, -3.3813e-02,  1.1823e-01,  ..., -1.1328e-01,
          6.9885e-03,  5.8105e-02],
        ...,
        [ 8.0688e-02, -1.6266e-02, -1.6089e-01,  ..., -8.8257e-02,
         -5.4016e-03,  4.3030e-02],
        [-9.3689e-03, -1.9211e-02, -1.8402e-02,  ...,  8.1482e-02,
         -8.1299e-02, -1.7395e-01],
        [ 6.7200e-02, -2.7328e-02,  7.6294e-03,  ..., -1.0999e-01,
         -6.7993e-02, -3.2787e-03]])
Epoch [0] Step [0] Loss: 8.6280
tensor([[-1.1450e-01,  4.7607e-02,  7.6660e-02,  ..., -1.4648e-03,
          6.3293e-02,  6.4758e-02],
        [-1.3733e-04,  2.6465e-01, -5.5969e-02,  ..., -2.3499e-02,
          5.6030e-02, -1.1096e-01],
        [ 1.0323e-02, -3.3813e-02,  1.1823e-01,  ..., -1.1328e-01,
          6.9885e-03,  5.8105e-02],
        ...,
        [ 8.06

KeyboardInterrupt: 