# Project: Semantic Search with Transformers

## Task 1: Import the Libraries

In [1]:
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from sklearn import preprocessing
import faiss
import numpy as np
import json

import os

  from .autonotebook import tqdm as notebook_tqdm


## Task 2: Load the Data

In [2]:
from torch.utils.data import Dataset, DataLoader

In [3]:
with open("arxivData.json", 'r') as f:
    data = json.load(f)

In [4]:
class ArxivDataset(Dataset):
    
    def __init__(self, fpath, transform=None):
        super().__init__()
        self.transform = transform
        with open(fpath, 'r') as f:
            self.data = json.load(f)
        self.id2idx = preprocessing.LabelEncoder()
        self.id2idx.fit_transform([self.data[idx]['id'] for idx in range(len(self.data))])
    
    def __len__(self):

        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        
        arxiv_id = item['id'] 
        author = item['author']
        year, month, day = item['year'], item['month'], item['day']
        link = item['link']
        summary = item['summary']
        tag = item['tag']

        if self.transform:
            inputs = self.transform(summary)
        else:
            inputs = summary

        return inputs, arxiv_id, author, year, month, day, link, tag

In [5]:
dataset = ArxivDataset(fpath="arxivData.json")
data_loader = DataLoader(dataset, batch_size=16)

## Task 3: Retrieve the Model

In [6]:
model = SentenceTransformer('sentence-transformers/distilbert-base-nli-stsb-mean-tokens')

In [7]:
if torch.cuda.is_available():
    model = model.to(torch.device("cuda"))

print(model.device)

mps:0


## Task 4: Generate or Load the Embeddings

In [8]:
embeddings = model.encode(next(iter(data_loader))[0][0])

In [9]:
embeddings

array([-1.09980330e-01,  1.64144143e-01,  6.77780747e-01,  5.53460531e-02,
       -5.35669886e-02,  3.31018567e-01,  3.55745703e-01, -4.42227423e-01,
       -1.04353681e-01, -1.33925772e+00, -6.32905960e-02,  9.93281960e-01,
       -4.75986868e-01,  2.11562842e-01,  2.54529744e-01,  2.63086885e-01,
        1.14939737e+00, -1.39734000e-01, -1.43783480e-01,  8.60705897e-02,
        9.56531882e-01,  1.09346434e-01, -1.67078167e-01,  6.77421212e-01,
       -6.76153228e-02, -4.19904813e-02,  5.63696623e-01,  9.84705031e-01,
        4.20293927e-01, -1.94894642e-01,  2.00043604e-01, -8.29738498e-01,
       -2.93100387e-01, -1.27244920e-01,  3.93142968e-01,  8.00250292e-01,
       -3.63039196e-01, -4.19262350e-01, -3.02979589e-01, -7.83954263e-01,
        2.23565817e-01,  3.50907475e-01, -2.59092506e-02,  1.58727154e-01,
       -6.72377467e-01,  3.67021143e-01, -3.11773300e-01,  7.09333122e-01,
       -8.87612760e-01, -5.24797678e-01, -4.38391060e-01, -2.48128951e-01,
       -2.39898294e-01, -

## Task 5: Data Preparation and Helper Methods

In [10]:
dataset.id2idx.transform(next(iter(data_loader))[1])

array([36693, 18198, 19318, 27779, 31468, 32183, 36310, 21232, 21534,
       27752, 13830, 37032, 18446, 18514, 18715, 23572])

In [None]:
nth_b = 0
embed_list = []
for b in data_loader:
    
    nth_b += 1
    embed_list.append(model.encode(b[0]))

    if nth_b % 5000 == 0:
        print(f"current batch number = {nth_b}")
        break

In [12]:
xb = np.concatenate(embed_list)

## Task 6: Set up the Index

In [13]:
d = embeddings.shape[0]
nlists = 100
metric = faiss.METRIC_INNER_PRODUCT

quantizer = faiss.IndexFlatIP(d)
faiss_index = faiss.IndexIVFFlat(quantizer, d, nlists, metric)

In [14]:
assert not faiss_index.is_trained

In [15]:
faiss_index.train(xb)

: 

In [None]:
assert faiss_index.is_trained

## Task 7: Search with a Summary

## Task 8: Search with a Prompt


# End