In [1]:
import torch.nn.functional as F
import torch
import torch.nn as nn
import sys
sys.path.insert(1, '/Users/derekhuang/Documents/Research/sequence_similarity_search/classes')
sys.path.insert(1, '/Users/derekhuang/Documents/Research/fast-soft-sort/fast_soft_sort')
from data_classes import BERTDataset
from dist_perm import DistPerm
import utils
from pytorch_ops import soft_rank
from sklearn.neighbors import NearestNeighbors
import numpy as np

class AnchorNet(nn.Module):
    def __init__(self, num_anchs, d):
        super(AnchorNet, self).__init__()
        self.anchors = nn.Linear(d, num_anchs)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, data, query):

        data_rank = soft_rank(self.anchors(data), direction="DESCENDING")
        query_rank = soft_rank(self.anchors(query), direction="DESCENDING")
        out = self.softmax(torch.matmul(query_rank, data_rank.T))
        return out



In [5]:

n = 10000
D = 128
num_queries = 500
num_anchors = 128
R = 100
k = 64

file_q = '../datasets/Q.pt'
file_db = '../datasets/D.pt'
data_source = BERTDataset(file_q, file_db, n)
db = data_source.generate_db()
data = np.array(db).astype(np.float32)
queries = data_source.generate_queries(num_queries)
quers = np.array(queries).astype(np.float32)
index_l2 = NearestNeighbors()
index_l2.fit(data)
true = torch.tensor(index_l2.kneighbors(quers, n_neighbors=1)[1])
truth = torch.nn.functional.one_hot(true.squeeze(), n)


In [6]:
model = AnchorNet(num_anchors, D)
print(model.forward(db,queries).shape)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=.1)  

index_dp = DistPerm(num_anchors, k=k, dist='dot')
index_dp.fit(db)
dp_hashes = index_dp.add(db)
found_dp = index_dp.search(queries, R)
found_dp = found_dp.numpy()

torch.Size([500, 10000])


In [10]:
for epoch in range(1000):
    outputs = model(db,queries)
    loss = criterion(outputs, true.squeeze())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    score = utils.mean_avg_precision(np.argmax(outputs.detach().numpy()[...,np.newaxis], axis=1), true)[0]
    print ('Epoch [{}], Loss: {:.4f}, MAP: {:.4f}'.format(epoch+1, loss.item(), score))

# # Save the model checkpoint
# torch.save(model.state_dict(), 'model.ckpt')

Epoch [1], Loss: 8.3656, MAP: 0.8440
Epoch [2], Loss: 8.3656, MAP: 0.8440
Epoch [3], Loss: 8.3656, MAP: 0.8440
Epoch [4], Loss: 8.3656, MAP: 0.8440
Epoch [5], Loss: 8.3656, MAP: 0.8440
Epoch [6], Loss: 8.3656, MAP: 0.8440
Epoch [7], Loss: 8.3656, MAP: 0.8440
Epoch [8], Loss: 8.3656, MAP: 0.8440
Epoch [9], Loss: 8.3656, MAP: 0.8440
Epoch [10], Loss: 8.3656, MAP: 0.8440
Epoch [11], Loss: 8.3656, MAP: 0.8440


KeyboardInterrupt: 

In [None]:
lin_weights = model.anchors.state_dict()['weight']
print(lin_weights.shape)