In [1]:
import os
import sys

sys.path.append(os.path.join("..", os.getcwd()))

import hydra
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics import classification_report
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from src.dataset import CSVDataset
from torch.utils.data import DataLoader

from src.utils import remove_duplicate_strings

device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"


  Referenced from: <0F72FEF0-4DF1-3E8A-90BA-513122A1950F> /Users/rateria/anaconda3/envs/d2l/lib/python3.9/site-packages/torchvision/image.so
  warn(


In [2]:
bi_encoder_model_name = "pritamdeka/PubMedBERT-mnli-snli-scinli-stsb"
cross_encoder_model_path = "/content/drive/MyDrive/fine_tuned_cross_encoder"

class Retriever(nn.Module):
    """Given a list of evidences and a claim, this returns the top-k evidences"""
    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.bi_encoder = SentenceTransformer(cfg.bi_encoder_model_name)
        self.k = cfg.k
        self.data = pd.read_csv(cfg.csv_file)
        self.evidence_pool = remove_duplicate_strings(self.data["Evidence"].dropna().tolist())
        self.evidence_embeddings = self.bi_encoder.encode(self.evidence_pool, convert_to_tensor=True)
        
    def tokenize_and_embed(self, data):
        # data -> [b]
        return self.bi_encoder.encode([data], convert_to_tensor=True)
    
    def set_encoder_training(self, mode):
        self.bi_encoder.train(mode)
    
    def forward(self, x):
        # x -> b, claims
        x = self.bi_encoder.encode(x, convert_to_tensor=True)
        # scores -> b, num_evidences, each row is the cosine similarity b/w the claim
        # and all the evidences
        cos_sim = torch.mm(x, self.evidence_embeddings.T)
        scores, indices = torch.topk(cos_sim, self.k, dim=1)
        evidences = [[self.evidence_pool[i] for i in row] for row in indices]
        return scores, indices, evidences
   

In [3]:
class Ranker(nn.Module):
    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cross_encoder = AutoModelForSequenceClassification.from_pretrained(cfg.cross_encoder_model_name)
        self.cross_encoder.classifier = nn.Identity()  # remove the last classifier layer
        self.mlp = nn.Sequential(
            nn.Linear(self.cross_encoder.config.hidden_size, 128),
            nn.ReLU(),
            nn.Linear(128, 3)
        )
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.cross_encoder_model_name)
        
    def forward(self, x, evidence_pool):
        # x -> b, claims
        # evidence_pool -> list of evidence strings
        # create claim embedding pair
        pairs = [[f"{claim} [SEP] {evidence}" for evidence in evidence_pool] for claim in x]
        tokenized = self.tokenizer(*pairs, padding=True, truncation="longest_first", return_tensors="pt", max_length=100)
        h = self.cross_encoder(**tokenized, return_dict=False)
        s_hat = torch.softmax(self.mlp(h), dim=1)
        return s_hat


In [6]:
from omegaconf import OmegaConf
cfg = OmegaConf.load("./config/config.yaml")

dataset = CSVDataset(file_path="/Users/rateria/Code/cs-5787-final-project/data/csv/generated_claim_triplets_with_topics.csv")
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)
retriever = Retriever(cfg)
ranker = Ranker(cfg)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at pritamdeka/PubMedBERT-mnli-snli-scinli-stsb and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
batch = next(iter(data_loader))
x,y = batch

In [17]:
ranker.cross_encoder(x, dataset.data['Evidence'][:10])

AttributeError: 'tuple' object has no attribute 'size'

In [12]:
batch[0]

("There were no limitations that influenced the agent's ability to operate continuously.",)