In [None]:
import torch
import faiss
import numpy as np
from transformers import AutoTokenizer
from model import SimCSEModel
from dataset import SimCSEDataset

In [None]:
# Pretrained model identifier; passed to both the tokenizer and SimCSEModel.
MODEL_NAME = "bert-base-uncased"
# Maximum token length; passed to the dataset and used to truncate queries in search().
MAX_LEN = 32
# File whose contents are loaded into the model via load_state_dict.
CHECKPOINT_PATH = './checkpoint/model_epoch_3.pth'

# Tokenizer for MODEL_NAME; shared by the dataset and by search().
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [None]:
# Prepare dataset and dataloader
# shuffle=False keeps batches in dataset order, so row i of the stacked
# embeddings corresponds to dataset[i] when search results are printed.
dataset = SimCSEDataset('data/simple_corpus.txt', tokenizer, max_len=MAX_LEN)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)

In [None]:
# Load model
# Use the GPU when one is visible, otherwise fall back to the CPU.
model = SimCSEModel(MODEL_NAME).to("cuda" if torch.cuda.is_available() else "cpu")
# map_location="cpu" deserializes the checkpoint into host memory regardless of
# the device it was saved from; load_state_dict then copies each tensor into
# the model's parameters on whatever device they live on.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted
# source (newer PyTorch releases accept weights_only=True to restrict this).
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
# Inference only: put the module in evaluation mode.
model.eval()

In [None]:
def get_embeddings(dataloader):
    """Encode every batch from ``dataloader`` and stack the results.

    Uses the notebook-level ``model``. Inputs are moved to the device the
    model's parameters are on, so the function works whether the model was
    placed on a GPU or left on the CPU.

    Args:
        dataloader: Iterable whose items unpack into
            ``(input_ids, attention_mask)`` tensors.

    Returns:
        A NumPy array with the per-batch model outputs stacked along axis 0,
        in iteration order.

    Raises:
        ValueError: If ``dataloader`` yields no batches (raised by
            ``np.vstack`` on an empty list).
    """
    # Follow the model's device instead of assuming CUDA.
    device = next(model.parameters()).device
    embeddings = []
    with torch.no_grad():  # inference only; no autograd graph needed
        for input_ids, attention_mask in dataloader:
            emb = model(input_ids.to(device), attention_mask=attention_mask.to(device))
            # Move each batch to host memory right away so results do not
            # accumulate on the accelerator.
            embeddings.append(emb.cpu().numpy())
    return np.vstack(embeddings)

In [None]:
# Get all embeddings
# Runs the model over the whole dataloader once; the resulting array is what
# the Faiss index below is built from.
embeddings = get_embeddings(dataloader)

In [None]:
# Faiss index creation
# IndexFlatL2 is an exact (brute-force) index. Vectors get sequential ids in
# insertion order, so an id returned by index.search is a row number in
# `embeddings`.
# NOTE(review): no normalization is applied to the embeddings in this notebook,
# so L2 ranking may differ from cosine-similarity ranking — confirm L2 is the
# intended metric for this model.
dim = embeddings.shape[1]  # dimension of the embeddings
index = faiss.IndexFlatL2(dim)  # Use L2 distance for similarity
index.add(embeddings)  # Add embeddings to the index


In [None]:
def search(query, k=5):
    """Return the ``k`` nearest indexed embeddings to ``query``.

    The query is tokenized with the notebook-level ``tokenizer`` (truncated to
    ``MAX_LEN`` tokens), encoded with ``model``, and looked up in ``index``.

    Args:
        query: Text to search for; passed directly to the tokenizer.
        k: Number of neighbours to retrieve.

    Returns:
        ``(distances, indices)`` exactly as returned by ``index.search``: one
        row per query, ``k`` columns each. Faiss fills unused slots with index
        -1 when fewer than ``k`` vectors are indexed.
    """
    # Follow the model's device instead of assuming CUDA.
    device = next(model.parameters()).device

    query_tokens = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LEN)
    query_input_ids = query_tokens['input_ids'].to(device)
    query_attention_mask = query_tokens['attention_mask'].to(device)

    # Without no_grad the output stays attached to the autograd graph and
    # Tensor.numpy() raises RuntimeError on tensors that require grad.
    with torch.no_grad():
        query_embedding = model(query_input_ids, attention_mask=query_attention_mask).cpu().numpy()

    # Search in the Faiss index
    distances, indices = index.search(query_embedding, k)
    return distances, indices

In [None]:
query = "I love you"
distances, indices = search(query)

# Row 0 holds the results for the single query above.
# Faiss pads the result with index -1 when fewer than k vectors are indexed;
# skip those slots so dataset[-1] is not silently printed as a hit.
# NOTE(review): dataset[idx] is printed as-is; if the dataset yields token
# tensors rather than text, decode them with the tokenizer for readable output.
for rank, (idx, dist) in enumerate(zip(indices[0], distances[0]), start=1):
    if idx < 0:
        continue
    # int() converts the NumPy integer to a plain Python index.
    print(f"Rank {rank}: {dataset[int(idx)]} | Distance: {dist:.4f}")