In [1]:
import torch.nn.functional as F
import torch
import torch.nn as nn
import sys
sys.path.insert(1, '/Users/derekhuang/Documents/Research/sequence_similarity_search/classes')
sys.path.insert(1, '/Users/derekhuang/Documents/Research/fast-soft-sort/fast_soft_sort')
from data_classes import BERTDataset
from dist_perm import DistPerm
import utils
import pytorch_ops
# import torchsort
from sklearn.neighbors import NearestNeighbors
import numpy as np

# Pick the compute device once for the whole notebook: CUDA when PyTorch
# reports a usable GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

def rank(array):
    """Return differentiable ("soft") descending ranks of `array`.

    The ranks come from fast-soft-sort's ``pytorch_ops.soft_rank`` with
    ``direction="DESCENDING"`` and a very small regularization strength
    (1e-5).

    Args:
        array: tensor of scores (the callers in this notebook pass the 2-D
            output of a linear layer). It may live on any device.

    Returns:
        Tensor of soft ranks, placed on the same device as `array`.
    """
    # The soft-rank op is always evaluated on a CPU copy, exactly as the
    # original CPU-only branch did.  Previously the function returned None
    # whenever CUDA was available, which broke every caller on a GPU machine;
    # now the result is simply moved back to the input's device.
    # (A GPU-native alternative would be torchsort.soft_rank(-array, ...),
    # which is not imported in this notebook.)
    soft_ranks = pytorch_ops.soft_rank(
        array.cpu(), direction="DESCENDING", regularization_strength=.00001
    )
    return soft_ranks.to(array.device)

class AnchorNet(nn.Module):
    """Learned anchor projections whose (truncated) rank order is used to match
    queries against a database.

    A single linear layer maps every input vector to `num_anchs` anchor
    scores.  Vectors are then described by the ranks of those scores, with
    every rank capped at `k`.

    Args:
        num_anchs: number of anchors (output width of the linear layer).
        d: dimensionality of the input vectors.
        k: rank cap; ranks larger than `k` are clamped to `k`.

    NOTE(review): `forward` already returns softmax probabilities, while the
    notebook feeds its output to ``nn.CrossEntropyLoss``, which applies
    log-softmax again. Confirm whether raw scores should be returned instead.
    NOTE(review): `forward` uses descending soft ranks (see `rank`), whereas
    `evaluate` ranks by ascending ``argsort``. Confirm this asymmetry is
    intended.
    """

    def __init__(self, num_anchs, d, k):
        super().__init__()
        self.anchors = nn.Linear(d, num_anchs)
        self.softmax = nn.Softmax(dim=1)
        self.k = k

    def forward(self, data, query):
        """Differentiable matching scores between queries and database rows.

        Args:
            data: database vectors, one row per document.
            query: query vectors, one row per query.

        Returns:
            Tensor with one row per query and one column per database row;
            each row is a softmax over the dot products of the clamped soft
            rank vectors.
        """
        # Use the instance's own k.  The original read a notebook-level
        # global `k`, which raises NameError outside this notebook and
        # silently ignores the constructor argument inside it.
        data_rank = torch.clamp(rank(self.anchors(data)), max=self.k)
        query_rank = torch.clamp(rank(self.anchors(query)), max=self.k)
        return self.softmax(torch.matmul(query_rank, data_rank.T))

    def evaluate(self, data, query):
        """Hard (non-differentiable) retrieval using truncated rank vectors.

        For every row, the `k` anchors with the smallest scores receive ranks
        0..k-1 and all remaining anchors receive rank `k`.  Database rows are
        then ordered by L1 distance between rank vectors.

        Args:
            data: database vectors, one row per document.
            query: query vectors, one row per query.

        Returns:
            A 3-tuple ``(indices, query_ranks, distances)``:
              * indices: per query, the database row indices of the closest
                rows (up to 10, nearest first), shape (n_query, n_returned).
              * query_ranks: the truncated rank vectors of the queries.
              * distances: the matching L1 distances, shape
                (n_returned, n_query) -- note this one is NOT transposed.
        """
        q_dist = self.anchors(query)
        d_dist = self.anchors(data)

        # Build every helper tensor on the model's device; the original
        # created them on the CPU, which fails once the model is on CUDA.
        dev = q_dist.device
        # Cannot rank more anchors than exist.
        top = min(self.k, q_dist.shape[1])

        query_ranks = torch.full(q_dist.shape, float(self.k), dtype=torch.float, device=dev)
        data_ranks = torch.full(d_dist.shape, float(self.k), dtype=torch.float, device=dev)

        # Column indices of the `top` smallest scores in each row.
        query_ids = torch.argsort(q_dist, dim=1)[:, :top]
        data_ids = torch.argsort(d_dist, dim=1)[:, :top]

        # Row selectors shaped (rows, 1) so they broadcast against (rows, top).
        q_rows = torch.arange(query.shape[0], device=dev)[:, None]
        d_rows = torch.arange(data.shape[0], device=dev)[:, None]
        positions = torch.arange(top, dtype=torch.float, device=dev)
        query_ranks[q_rows, query_ids] = positions
        data_ranks[d_rows, data_ids] = positions

        db_dists = torch.cdist(data_ranks, query_ranks, p=1).float()
        # Return at most 10 neighbours, but never more than the database holds
        # (torch.topk raises if asked for more entries than exist).
        n_returned = min(10, db_dists.shape[0])
        closest = torch.topk(db_dists, n_returned, dim=0, largest=False)
        return closest.indices.transpose(0, 1), query_ranks, closest.values



cpu


In [2]:
# Hyperparams

n = 100_000            # size argument passed to BERTDataset
D = 128                # input width of AnchorNet's linear layer
num_queries = 3_200    # number of queries requested from the dataset
num_anchors = 128      # output width of AnchorNet's linear layer
k = 12                 # rank cap used by AnchorNet

In [3]:
# Load data
# Reads the query and database files through the project's BERTDataset helper
# and keeps float32 NumPy copies of both.
# NOTE(review): the exact contents returned by generate_db / generate_queries
# are defined in data_classes, not here -- confirm shapes and dtypes there.

file_q = './Q.pt'
file_db = './D.pt'
data_source = BERTDataset(file_q, file_db, n)
db = data_source.generate_db()
# float32 copy of the database; note that the split cell further down passes
# the raw `db`, not this `data` array, to dataset_split.
data = np.array(db).astype(np.float32)
queries = data_source.generate_queries(num_queries)
# float32 copy of the queries; this is the array that gets split into
# train/val/test further down.
quers = np.array(queries).astype(np.float32)

In [4]:
# Helper function for splitting into datasets 

import torch, itertools
from torch.utils.data import TensorDataset, random_split, DataLoader
from sklearn.model_selection import train_test_split

def dataset_split(dataset, train_frac):
    """Randomly partition `dataset` into train / val / test subsets.

    Args:
        dataset: any sized, indexable collection accepted by
            ``torch.utils.data.random_split``.
        train_frac: fraction of the items assigned to the train subset
            (the count is truncated to an integer).

    Returns:
        dict with keys 'train', 'val' and 'test'.  The items left after the
        train subset are divided between val and test; when that remainder
        is odd, test receives the extra item.
    """
    total = len(dataset)

    n_train = int(total * train_frac)
    n_val = int((total - n_train) / 2)
    n_test = total - n_train - n_val

    parts = random_split(dataset, (n_train, n_val, n_test))
    return dict(zip(('train', 'val', 'test'), parts))


In [5]:
# Build a labelled query DataLoader for one batch of documents: every query is
# paired with the index (within that batch) of its nearest document.
def return_loader(d, ret, batch_size=320):
    """Pair a fixed set of queries with their true nearest neighbour in `d`.

    Args:
        d: batch of document vectors.  It is handed to scikit-learn's
            NearestNeighbors, so callers pass it before moving it to `device`.
        ret: which fixed query split to use -- 'query' (the training
            queries), 'test' or 'val'.
        batch_size: batch size of the returned loader (default 320, the value
            that used to be hard-coded).

    Returns:
        DataLoader yielding ``(queries, labels)`` batches, where each label
        is the row index in `d` of the query's nearest neighbour
        (``n_neighbors=1``, so labels have a trailing dimension of size 1).

    Raises:
        ValueError: if `ret` is not one of 'query', 'test', 'val'.
    """
    # Resolve the query source first so a typo fails immediately with a clear
    # message, instead of passing None to kneighbors later on.
    if ret == 'query':
        source = query_data_loader
    elif ret == 'test':
        source = query_data_test_loader
    elif ret == 'val':
        source = query_data_val_loader
    else:
        raise ValueError(
            "ret must be 'query', 'test' or 'val', got {!r}".format(ret)
        )

    index_l2 = NearestNeighbors()
    index_l2.fit(d)

    # Only the first batch of the query loader is used.
    q = next(iter(source))

    # Ground truth: index of the closest document for every query.
    true = torch.tensor(index_l2.kneighbors(q, n_neighbors=1)[1])
    query_data = [[q[i], true[i]] for i in range(q.shape[0])]

    # train_frac=1 puts every pair into the 'train' subset; going through
    # dataset_split keeps the original behaviour of randomly permuting the
    # pairs before batching.
    labelled = dataset_split(query_data, 1)['train']
    return torch.utils.data.DataLoader(dataset=labelled,
                                       batch_size=batch_size,
                                       shuffle=False)

In [6]:
# Fits fine in mem
batch_size = 3200
train_split = .8

# Queries are split first, then documents (same order as before, so the random
# splits draw from the RNG in the same sequence).
query_datasets = dataset_split(quers, train_split)
doc_datasets = dataset_split(db, train_split)

# Fixed query loaders.  return_loader only reads the first batch of each, and
# none of them shuffles.
query_data_loader = DataLoader(
    query_datasets['train'], batch_size=batch_size, shuffle=False
)
query_data_test_loader = DataLoader(
    query_datasets['test'], batch_size=batch_size, shuffle=False
)
query_data_val_loader = DataLoader(
    query_datasets['val'], batch_size=batch_size, shuffle=False
)

# Document loaders, 5000 documents per batch.  Only the training documents
# are reshuffled on every pass.
docs_loader = DataLoader(doc_datasets['train'], batch_size=5000, shuffle=True)
docs_test_loader = DataLoader(doc_datasets['test'], batch_size=5000, shuffle=False)
docs_val_loader = DataLoader(doc_datasets['val'], batch_size=5000, shuffle=False)



In [7]:
# Model, loss and optimizer.
lr = 5e-4
model = AnchorNet(num_anchors, D, k)
model = model.to(device)
# NOTE(review): AnchorNet.forward already ends in a softmax, and
# CrossEntropyLoss applies log-softmax to its input as well -- confirm this
# double normalisation is intended.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)

In [8]:
# Baseline before training: fraction of test queries whose top-ranked
# candidate (model.evaluate) equals the true nearest document.
correct, total = 0, 0
with torch.no_grad():
    for d in docs_test_loader:
        # Labels are built from this document batch before it is moved to
        # `device`.
        test_loader = return_loader(d, ret='test')
        d = d.to(device)
        for q, l in test_loader:
            q, l = q.to(device), l.to(device)
            outputs = model.evaluate(d, q)
            # First column = closest database row for each query.
            predicted = outputs[0][:, 0]
            matches = predicted.to(device) == l.flatten()
            total += l.size(0)
            correct += matches.sum().item()

print('recall_test: {:.4f}'.format(correct / total))

Test Loss: 8.1256, recall_test: 0.3547


In [9]:
# train
def _evaluate_split(doc_batches, split, compute_loss=False):
    """Top-1 retrieval accuracy of `model.evaluate` on one held-out split.

    Args:
        doc_batches: loader yielding batches of document vectors.
        split: query split name forwarded to return_loader ('test' or 'val').
        compute_loss: when True, also run the differentiable forward pass and
            keep the criterion value of the LAST query batch (this mirrors
            what the training log has always reported).

    Returns:
        (n_correct, n_total, last_loss); last_loss is None unless
        compute_loss is True and at least one batch was seen.
    """
    n_correct, n_total, last_loss = 0, 0, None
    for docs in doc_batches:
        # Labels are indices into this document batch; build them before the
        # batch is moved to `device`.
        loader = return_loader(docs, ret=split)
        docs = docs.to(device)
        for queries_batch, labels in loader:
            queries_batch = queries_batch.to(device)
            # flatten() (not squeeze()) keeps a 1-D target even for a batch
            # of one query.
            labels = labels.to(device).flatten()
            if compute_loss:
                last_loss = criterion(model(docs, queries_batch), labels)
            top_ids = model.evaluate(docs, queries_batch)[0]
            # A hit = the closest database row is the true nearest neighbour.
            n_correct += (top_ids[:, 0].to(device) == labels).sum().item()
            n_total += labels.size(0)
    return n_correct, n_total, last_loss


for epoch in range(1000):
#   Train step: get batch of train documents, then generate the dataloader containing
#   the train queries and the correct data labels
    train_correct = 0
    train_total = 0
    for step, d in enumerate(docs_loader):
        query_loader = return_loader(d, ret='query')
        d = d.to(device)
        for i, (q, l) in enumerate(query_loader):
            q = q.to(device)
            # 1-D class targets; flatten() is safe for a batch of size 1,
            # where squeeze() would produce a 0-d tensor.
            l = l.to(device).flatten()
            outputs = model(d, q)
            loss = criterion(outputs, l)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Training accuracy is taken from the soft forward pass (argmax of
            # the scores), unlike the held-out metrics, which use
            # model.evaluate.
            _, predicted = torch.max(outputs.data, 1)
            train_total += l.size(0)
            train_correct += (predicted == l).sum().item()

#   Eval step: repeat but with the test documents and test queries
    if epoch % 5 == 0:
        with torch.no_grad():
            correct, total, test_loss = _evaluate_split(
                docs_test_loader, 'test', compute_loss=True)
            correct_val, total_val, _ = _evaluate_split(docs_val_loader, 'val')

            # "recall" here = fraction of queries whose top-1 result is the
            # true nearest neighbour.  Loss values are those of the last batch.
            print ('Epoch [{}], Loss: {:.4f}, Test Loss: {:.4f}, recall_train: {:.4f}, recall_test: {:.4f} recall_val: {:.4f}'
                    .format(epoch+1, loss.item(), test_loss.item(), train_correct / train_total, correct / total, correct_val / total_val))

# # Save the model checkpoint
    # if epoch % 20 == 0:
    #   torch.save(model.state_dict(), "/content/gdrive/MyDrive/ckpt/model_epoch_{}_anchor_{}_k_{}_lr_{}.pt".format(epoch, num_anchors, k, lr))

Epoch [1], Loss: 8.1558, Test Loss: 8.1436, recall_train: 0.3626, recall_test: 0.3391 recall_val: 0.3828
Epoch [2], Loss: 8.0732, Test Loss: 8.1298, recall_train: 0.3819, recall_test: 0.3578 recall_val: 0.4484
Epoch [3], Loss: 8.0337, Test Loss: 8.1129, recall_train: 0.4188, recall_test: 0.3766 recall_val: 0.4359
Epoch [4], Loss: 8.1161, Test Loss: 8.1130, recall_train: 0.4191, recall_test: 0.3781 recall_val: 0.4688
Epoch [5], Loss: 8.0201, Test Loss: 8.0792, recall_train: 0.4360, recall_test: 0.4250 recall_val: 0.4547
Epoch [6], Loss: 8.0584, Test Loss: 8.1136, recall_train: 0.4333, recall_test: 0.4031 recall_val: 0.4313
Epoch [7], Loss: 8.0330, Test Loss: 8.0777, recall_train: 0.4420, recall_test: 0.4172 recall_val: 0.4500
Epoch [8], Loss: 8.0429, Test Loss: 8.1260, recall_train: 0.4581, recall_test: 0.3891 recall_val: 0.4625
Epoch [9], Loss: 8.0847, Test Loss: 8.1178, recall_train: 0.4570, recall_test: 0.3844 recall_val: 0.4469
Epoch [10], Loss: 8.0816, Test Loss: 8.1045, recall_tra

Epoch [79], Loss: 7.9396, Test Loss: 7.9996, recall_train: 0.5534, recall_test: 0.5062 recall_val: 0.5469
Epoch [80], Loss: 7.9441, Test Loss: 7.9979, recall_train: 0.5596, recall_test: 0.5125 recall_val: 0.5641
Epoch [81], Loss: 7.9477, Test Loss: 8.0076, recall_train: 0.5512, recall_test: 0.4953 recall_val: 0.5437
Epoch [82], Loss: 7.9541, Test Loss: 8.0082, recall_train: 0.5445, recall_test: 0.4813 recall_val: 0.5484
Epoch [83], Loss: 8.0131, Test Loss: 7.9815, recall_train: 0.5444, recall_test: 0.5141 recall_val: 0.5625
Epoch [84], Loss: 7.9484, Test Loss: 7.9990, recall_train: 0.5497, recall_test: 0.5344 recall_val: 0.5453
Epoch [85], Loss: 7.9461, Test Loss: 8.0126, recall_train: 0.5481, recall_test: 0.5172 recall_val: 0.5422
Epoch [86], Loss: 7.9178, Test Loss: 7.9996, recall_train: 0.5481, recall_test: 0.5312 recall_val: 0.5328
Epoch [87], Loss: 7.9803, Test Loss: 8.0156, recall_train: 0.5527, recall_test: 0.5266 recall_val: 0.5578
Epoch [88], Loss: 7.9304, Test Loss: 8.0261, r

KeyboardInterrupt: 