In [86]:
import torch.nn.functional as F
import torch
import torch.nn as nn
import sys
sys.path.insert(1, '/Users/derekhuang/Documents/Research/sequence_similarity_search/classes')
sys.path.insert(1, '/Users/derekhuang/Documents/Research/fast-soft-sort/fast_soft_sort')
from data_classes import BERTDataset
from dist_perm import DistPerm
import utils
import pytorch_ops
# import torchsort
from sklearn.neighbors import NearestNeighbors
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

def rank(array):
    if torch.cuda.is_available():
        return pytorch_ops.soft_rank(array.cpu(), direction="DESCENDING", regularization_strength=.001).cuda()
#     else:
#         return torchsort.soft_rank(-1 * array, regularization_strength=.001)

class AnchorNet(nn.Module):
    def __init__(self, num_anchs, d, k):
        super(AnchorNet, self).__init__()
        self.anchors = nn.Linear(d, num_anchs)
        self.softmax = nn.Softmax(dim=1)
        self.k = k

    def forward(self, data, query):
        data_rank = torch.clamp(rank(self.anchors(data)), max=k)
        query_rank = torch.clamp(rank(self.anchors(query)), max=k)
        out = self.softmax(torch.matmul(query_rank, data_rank.T))
        return out



cpu


In [87]:
# Hyperparams

n = 100000
D = 128
num_queries = 3200
num_anchors = 128
k = 12

In [88]:
# Load data

file_q = './Q.pt'
file_db = './D.pt'
data_source = BERTDataset(file_q, file_db, n)
db = data_source.generate_db()
data = np.array(db).astype(np.float32)
queries = data_source.generate_queries(num_queries)
quers = np.array(queries).astype(np.float32)

In [89]:
# Helper function for splitting into datasets 

import torch, itertools
from torch.utils.data import TensorDataset, random_split, DataLoader
from sklearn.model_selection import train_test_split

def dataset_split(dataset, train_frac):
    length = len(dataset)
   
    train_length = int(length * train_frac)
    valid_length = int((length - train_length) / 2)
    test_length  = length - train_length - valid_length

    sets = random_split(dataset, (train_length, valid_length, test_length))
    dataset = {name: set for name, set in zip(('train', 'val', 'test'), sets)}
    return dataset


In [90]:
# Generate the nearest neighbors for each batch of docs. Then combine with the queries 
# to produce the dataloader for queries containing data and label
# d = docs
# ret = train/test
def return_loader(d, ret):
    index_l2 = NearestNeighbors()
    index_l2.fit(d)
    q = None
#   Get the right data
    if ret=='query':
        q = next(iter(query_data_loader))
    elif ret=='test':
        q = next(iter(query_data_test_loader))
    elif ret=='val':
        q = next(iter(query_data_val_loader))
#   Get the true nearest neighbors
    true = torch.tensor(index_l2.kneighbors(q, n_neighbors=1)[1])
    query_data = []
    for i in range(q.shape[0]):
        query_data.append([q[i], true[i]])
    test_set = dataset_split(query_data, 1)
#   Return data as a dataloader
    data = torch.utils.data.DataLoader(dataset=test_set['train'], 
                                           batch_size=800, 
                                           shuffle=False)
    return data

In [91]:
# Fits fine in mem
batch_size=3200
train_split = .8

query_datasets = dataset_split(quers, train_split)
# doc_datasets = dataset_split(db, train_split)
doc_datasets = dataset_split(db, train_split)

# The fixed train queries
query_data_loader = torch.utils.data.DataLoader(dataset=query_datasets['train'], 
                                           batch_size=batch_size, 
                                           shuffle=False)

# The fixed test queries 
query_data_test_loader = torch.utils.data.DataLoader(dataset=query_datasets['test'], 
                                          batch_size=batch_size, 
                                          shuffle=False)

# The fixed val queries 
query_data_val_loader = torch.utils.data.DataLoader(dataset=query_datasets['val'], 
                                          batch_size=batch_size, 
                                          shuffle=False)

# The train docs
docs_loader = torch.utils.data.DataLoader(dataset=doc_datasets['train'], 
                                           batch_size=5000, 
                                           shuffle=False)

# The test docs
docs_test_loader = torch.utils.data.DataLoader(dataset=doc_datasets['test'], 
                                           batch_size=5000, 
                                           shuffle=False)

docs_val_loader = torch.utils.data.DataLoader(dataset=doc_datasets['val'], 
                                           batch_size=5000, 
                                           shuffle=False)



In [92]:
model = AnchorNet(num_anchors, D, k).to(device)
criterion = nn.CrossEntropyLoss()
lr=.0005
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.001)  

In [93]:
# Get the initial performance of random anchors 

with torch.no_grad():
    correct = 0
    total = 0
    for d in docs_val_loader:
    # for d in docs_loader:
        test_loader = return_loader(d, ret='val')
        d = d.to(device)
        for q, l in test_loader:
#             print(q)
            q = q.to(device)
            l = l.to(device)
            outputs = model(d,q)
            test_loss = criterion(outputs, l.squeeze())
#           Recall: TP / (TP + TN)
            _, predicted = torch.max(outputs.data, 1)
            total += l.size(0)
            correct += (predicted == l.flatten()).sum().item()

    print ('Test Loss: {:.4f}, recall_test: {:.4f}'
            .format(test_loss.item(), correct / total))

AssertionError: Torch not compiled with CUDA enabled

In [25]:
# train
for epoch in range(300):
#   Train step: get batch of train documents, then generate the dataloader containing 
#   the train queries and the correct data labels
    train_correct = 0
    train_total = 0
    for step, d in enumerate(docs_loader):
        query_loader = return_loader(d, ret='query')
        d = d.to(device)
        for i, (q, l) in enumerate(query_loader):  
            q = q.to(device)
            l = l.to(device)
            outputs = model(d,q)
            loss = criterion(outputs, l.squeeze())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(outputs.data, 1)
            train_total += l.size(0)
            train_correct += (predicted == l.flatten()).sum().item()
        # print('Epoch [{}] Step [{}] Loss: {:.4f}'.format(epoch, step, loss.item()))
        
#   Eval step: repeat but with the test documents and test queries      
    if epoch % 1 == 0:
        with torch.no_grad():
            correct = 0
            total = 0
            for d in docs_test_loader:
            # for d in docs_loader:
                test_loader = return_loader(d, ret='test')
                d = d.to(device)
                for q, l in test_loader:
                    q = q.to(device)
                    l = l.to(device)
                    outputs = model(d,q)
                    test_loss = criterion(outputs, l.squeeze())
#                   Recall: TP / (TP + TN)
                    _, predicted = torch.max(outputs.data, 1)
                    total += l.size(0)
                    correct += (predicted == l.flatten()).sum().item()

            print ('Epoch [{}], Loss: {:.4f}, Test Loss: {:.4f}, recall_train: {:.4f}, recall_test: {:.4f}'
                    .format(epoch+1, loss.item(), test_loss.item(), train_correct / train_total, correct / total))

# # Save the model checkpoint
#     if epoch % 20 == 0:
#       torch.save(model.state_dict(), "/content/gdrive/MyDrive/ckpt/model_epoch_{}_anchor_{}_k_{}_lr_{}.pt".format(epoch, num_anchors, k, lr))

KeyboardInterrupt: 