# Project: Semantic Search with Transformers

## Task 1: Import the Libraries

In [1]:
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from sklearn import preprocessing
import faiss
import numpy as np
import pickle
import json

import os
try:
    os.chdir('usercode/')
except:
    pass

  from .autonotebook import tqdm as notebook_tqdm


## Task 2: Load the Data

In [2]:
from torch.utils.data import Dataset, DataLoader

In [22]:
with open("arxivData.json", 'r') as f:
    data = json.load(f)

In [3]:
class ArxivDataset(Dataset):
    
    def __init__(self, fpath, transform=None):
        super().__init__()
        self.transform = transform
        with open(fpath, 'r') as f:
            self.data = json.load(f)
        self.id2idx = preprocessing.LabelEncoder()
        self.id2idx.fit_transform([self.data[idx]['id'] for idx in range(len(self.data))])
    
    def __len__(self):

        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        
        arxiv_id = item['id'] 
        author = item['author']
        year, month, day = item['year'], item['month'], item['day']
        link = item['link']
        summary = item['summary']
        tag = item['tag']

        if self.transform:
            inputs = self.transform(summary)
        else:
            inputs = summary

        return inputs, arxiv_id, author, year, month, day, link, tag

In [29]:
dataset = ArxivDataset(fpath="arxivData.json")
data_loader = DataLoader(dataset, batch_size=16)

## Task 3: Retrieve the Model

In [6]:
model = SentenceTransformer('sentence-transformers/distilbert-base-nli-stsb-mean-tokens')

Downloading model_quint8_avx2.onnx:   0%|          | 0.00/67.0M [00:00<?, ?B/s]Downloading model_quint8_avx2.onnx:  16%|█▌        | 10.5M/67.0M [00:00<00:00, 80.0MB/s]Downloading model_quint8_avx2.onnx:  47%|████▋     | 31.5M/67.0M [00:00<00:00, 132MB/s] Downloading model_quint8_avx2.onnx:  78%|███████▊  | 52.4M/67.0M [00:00<00:00, 155MB/s]Downloading model_quint8_avx2.onnx: 100%|██████████| 67.0M/67.0M [00:00<00:00, 146MB/s]
Downloading openvino_model.bin:   0%|          | 0.00/265M [00:00<?, ?B/s]Downloading openvino_model.bin:   4%|▍         | 10.5M/265M [00:00<00:03, 81.9MB/s]Downloading openvino_model.bin:  12%|█▏        | 31.5M/265M [00:00<00:01, 135MB/s] Downloading openvino_model.bin:  20%|█▉        | 52.4M/265M [00:00<00:01, 156MB/s]Downloading openvino_model.bin:  28%|██▊       | 73.4M/265M [00:00<00:01, 168MB/s]Downloading openvino_model.bin:  36%|███▌      | 94.4M/265M [00:00<00:00, 174MB/s]Downloading openvino_model.bin:  43%|████▎     | 115M/265M [00:00<00:00, 

In [7]:
if torch.cuda.is_available():
    model = model.to(torch.device("cuda"))

## Task 4: Generate or Load the Embeddings

In [30]:
embeddings = model.encode(next(iter(data_loader))[0][0])

In [31]:
embeddings

array([-1.09982237e-01,  1.64143652e-01,  6.77781641e-01,  5.53463846e-02,
       -5.35667986e-02,  3.31018150e-01,  3.55745554e-01, -4.42226768e-01,
       -1.04354627e-01, -1.33925748e+00, -6.32903874e-02,  9.93281603e-01,
       -4.75987703e-01,  2.11563453e-01,  2.54530936e-01,  2.63086587e-01,
        1.14939797e+00, -1.39734372e-01, -1.43785283e-01,  8.60709101e-02,
        9.56532657e-01,  1.09347209e-01, -1.67078406e-01,  6.77422166e-01,
       -6.76144660e-02, -4.19904701e-02,  5.63697338e-01,  9.84705448e-01,
        4.20292884e-01, -1.94895357e-01,  2.00043291e-01, -8.29737782e-01,
       -2.93100476e-01, -1.27245232e-01,  3.93142849e-01,  8.00250173e-01,
       -3.63038838e-01, -4.19262141e-01, -3.02977204e-01, -7.83953965e-01,
        2.23564565e-01,  3.50907326e-01, -2.59092115e-02,  1.58728272e-01,
       -6.72377169e-01,  3.67020547e-01, -3.11774343e-01,  7.09333539e-01,
       -8.87611032e-01, -5.24797916e-01, -4.38391596e-01, -2.48129398e-01,
       -2.39898100e-01, -

## Task 5: Data Preparation and Helper Methods

In [32]:
dataset.id2idx.transform(next(iter(data_loader))[1])

array([36693, 18198, 19318, 27779, 31468, 32183, 36310, 21232, 21534,
       27752, 13830, 37032, 18446, 18514, 18715, 23572])

## Task 6: Set up the Index

In [33]:
d = embeddings.shape[0]
nlists = 100
metric = faiss.METRIC_INNER_PRODUCT

quantizer = faiss.IndexFlatIP(d)
faiss_index = faiss.IndexIVFFlat(quantizer, d, nlists, metric)

In [34]:
nth_b = 0
embed_list = []
for b in data_loader:
    
    nth_b += 1
    if nth_b % 1000 == 0:
        print(f"current batch number = {nth_b}")

    embed_list.append(model.encode(b[0]))

In [None]:
faiss_index.train()

## Task 7: Search with a Summary

## Task 8: Search with a Prompt


# End