In [14]:
import os
import sys

sys.path.append(os.path.join("..", os.getcwd()))

import hydra
import pandas as pd
import pytorch_lightning as pl
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics import classification_report
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from src.dataset import CSVDataset
from src.multihead_attn import TransformerEncoder
from torch.utils.data import DataLoader

from src.utils import remove_duplicate_strings

# device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device("cpu")


In [2]:
bi_encoder_model_name = "pritamdeka/PubMedBERT-mnli-snli-scinli-stsb"

In [36]:
class Retriever(nn.Module):
    """Given a list of evidences and a claim, this returns the top-k evidences"""
    def __init__(self, cfg: DictConfig,train_evidences=[], val_evidences=[], test_evidences=[]):
        super().__init__()
        self.bi_encoder = SentenceTransformer(cfg.bi_encoder_model_name, cache_folder="../cache")
        # self.bi_encoder.eval()
        self.bi_encoder.requires_grad_(False)
        self.k = cfg.k
        self.data = pd.read_csv(cfg.csv_file)
        self.train_evidence_pool = train_evidences
        self.train_evidence_embeddings = self.bi_encoder.encode(train_evidences, convert_to_tensor=True)
        self.val_evidence_pool = val_evidences
        self.val_evidence_embeddings = self.bi_encoder.encode(val_evidences, convert_to_tensor=True)
        self.test_evidence_pool = test_evidences
        self.test_evidence_embeddings = self.bi_encoder.encode(test_evidences, convert_to_tensor=True)
    def tokenize_and_embed(self, data):
        # data -> [b]
        return self.bi_encoder.encode([data], convert_to_tensor=True)
    
    def set_encoder_training(self, mode):
        self.bi_encoder.train(mode)
    
    def forward(self, x, mode="train"):
        # x -> b, claims
        x = self.bi_encoder.encode(x, convert_to_tensor=True)
        # scores -> b, num_evidences, each row is the cosine similarity b/w the claim
        # and all the evidences
        if mode == "train":
            self.evidence_embeddings = self.train_evidence_embeddings
            self.evidence_pool = self.train_evidence_pool
        elif mode == "val":
            self.evidence_embeddings = self.val_evidence_embeddings
            self.evidence_pool = self.val_evidence_pool
        elif mode == "test":
            self.evidence_embeddings = self.test_evidence_embeddings
            self.evidence_pool = self.test_evidence_pool
        cos_sim = torch.mm(x, self.evidence_embeddings.T)
        scores, indices = torch.topk(cos_sim, self.k, dim=1)
        evidences = [[self.evidence_pool[i] for i in row] for row in indices]
        return scores, indices, evidences

In [37]:
class Ranker(nn.Module):
    def __init__(self, cfg: DictConfig):
        super().__init__()
        # self.cross_encoder = AutoModelForSequenceClassification.from_pretrained(cfg.cross_encoder_model_name, cache_dir="../cache")
        self.cross_encoder = SentenceTransformer(cfg.bi_encoder_model_name, cache_folder="../cache", device="cpu")
        self.cross_encoder.train(True)
        # self.cross_encoder.classifier = nn.Identity()  # remove the last classifier layer
        # self.tokenizer = AutoTokenizer.from_pretrained(cfg.cross_encoder_model_name)
        hidden_size = 768
        self.scorer = nn.Linear(hidden_size, 1)
        
    def forward(self, x, evidence_pool):
        # x -> b, claim
        # evidence_pool -> list of evidence strings
        # create claim embedding pair
        embeddings = []
        for i, claim in enumerate(x):
            evidences = evidence_pool[i]
            claim_pairs = [f"[CLS] {claim} [SEP] {evidence} [SEP]" for evidence in evidences]
            encoded = self.cross_encoder.encode(claim_pairs, convert_to_tensor=True)
            embeddings.append(encoded)
        # Convert list of tensors to a single tensor
        embeddings_tensor = torch.stack(embeddings)
        # print(f"{type(embeddings_tensor)}")
        return embeddings_tensor

In [38]:
class Classifier(nn.Module):
    def __init__(self, input_dim, out_classes=3):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, out_classes)
        )
    def forward(self, x):
        x = torch.mean(x, dim=1)
        print(f"classifier after mean {x.shape}")
        x = self.mlp(x)
        # print("classifier", x, x.shape)
        return x, F.softmax(x, dim=-1)

In [97]:
from omegaconf import OmegaConf
cfg = OmegaConf.load("./config/config.yaml")

train_dataset = CSVDataset(file_path="/Users/rateria/Code/cs-5787-final-project/data/csv/train.csv")
train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True)

val_dataset = CSVDataset(file_path="/Users/rateria/Code/cs-5787-final-project/data/csv/val.csv")
val_dataloader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False)

test_dataset = CSVDataset(file_path="/Users/rateria/Code/cs-5787-final-project/data/csv/test.csv")
test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False)

In [35]:
retriever = Retriever(cfg, train_evidences=train_dataset.get_evidences(), val_evidences=val_dataset.get_evidences(), test_evidences=test_dataset.get_evidences())
ranker = Ranker(cfg)
attn = TransformerEncoder(num_layers=4, input_dim=768, num_heads=1, dim_feedforward=128)

KeyboardInterrupt: 

In [39]:
classifier = Classifier(768)

In [30]:
for batch in train_dataloader:
    print(batch)
    x, y = batch
    # rav = RAV()
    # print(rav)
    scores, indices, evidences = retriever(x)
    print(scores.shape, indices.shape, len(evidences))
    # print(f"\n{evidences}\n")
    # print("initial device", x.device)
    h = ranker(x, evidences)
    h = h.to("cpu")
    print(h.shape, h.device)
    enc_out = attn(h)
    print(enc_out.shape)
    logits, out = classifier(enc_out)
    print(out, out.shape)
    loss = nn.CrossEntropyLoss()(logits, y)
    print(f"loss: {loss}")
    break

[('Nateglinide is particularly effective for controlling postprandial hyperglycemia due to its short duration of action.', 'While screening tests like CH50 and AH50 are used to evaluate complement function, their effectiveness in predicting clinical outcomes may vary.', 'The combination of Lansoprazole, Amoxicillin, and Clarithromycin is an effective treatment for Helicobacter pylori infections and related duodenal ulcers.', 'The vaccine plays a crucial role in public health efforts to control infectious diseases.', 'Awareness of tinea nigra is crucial for proper management and treatment of the condition.', 'Diagnosis of acute pericarditis typically involves analyzing ECG changes and performing echocardiography.', 'While accurate diagnosis might help some patients, it is unclear how much it influences overall mental health outcomes.', 'The management of asthma during pregnancy does not require careful consideration of pharmacological interventions.', 'Patient history is not a reliable 

RuntimeError: Tensor for argument #1 'mat1' is on CPU, but expected it to be on GPU (while checking arguments for mm)

In [116]:
import pytorch_lightning as pl
import torch
import torch.nn as nn

class MyLightningModule(pl.LightningModule):
    def __init__(self, retriever, ranker, attn, classifier, learning_rate=1e-3):
        """
        Args:
            retriever: Module for retrieving evidence.
            ranker: Module for ranking evidence.
            attn: Attention module.
            classifier: Classification module.
            learning_rate: Learning rate for the optimizer.
        """
        super().__init__()
        self.retriever = retriever
        self.ranker = ranker
        self.attn = attn
        self.classifier = classifier
        self.loss_fn = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate
        
        # Freeze the retriever
        for param in self.retriever.parameters():
            param.requires_grad = False
        
        self.retriever.eval()
        
        # for param in self.ranker.parameters():
        #     param.requires_grad = False
        
        # self.ranker.eval()
        
        self.save_hyperparameters()

    def forward(self, x):
        # Forward pass to process the input
        scores, indices, evidences = self.retriever(x)
        h = self.ranker(x, evidences)
        enc_out = self.attn(h)
        out = self.classifier(enc_out)
        return out

    def training_step(self, batch, batch_idx):
        # Single training step
        x, y = batch
        scores, indices, evidences = self.retriever(x, mode="train")
        h = self.ranker(x, evidences)
        enc_out = self.attn(h)
        logits, out = self.classifier(enc_out)

        # Compute loss
        loss = self.loss_fn(logits, y)
        print(f"loss: {loss}")
        self.log("train_loss", loss)  # Log the loss for monitoring
        return loss

    def configure_optimizers(self):
        # Optimizer configuration
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def validation_step(self, batch, batch_idx):
        # Validation step (optional, similar to training_step)
        x, y = batch
        scores, indices, evidences = self.retriever(x, mode="val")
        h = self.ranker(x, evidences)
        enc_out = self.attn(h)
        logits, out = self.classifier(enc_out)

        # Compute loss
        loss = self.loss_fn(logits, y)
        self.log("val_loss", loss)  # Log the validation loss
        return loss


In [117]:
# retriever = Retriever(cfg)
# ranker = Ranker(cfg)
# attn = TransformerEncoder(num_layers=4, input_dim=768, num_heads=1, dim_feedforward=128)
# classifier = Classifier(768)

In [118]:
import wandb
from lightning.pytorch.loggers import WandbLogger
wandb.init(project="rav")
wandb_logger = WandbLogger(project="rav")

In [119]:
model = MyLightningModule(retriever, ranker, attn, classifier, learning_rate=1e-3)

/Users/rateria/miniconda3/lib/python3.12/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'retriever' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['retriever'])`.
/Users/rateria/miniconda3/lib/python3.12/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'ranker' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['ranker'])`.
/Users/rateria/miniconda3/lib/python3.12/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'attn' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['attn'])`.
/Users/rateria/miniconda3/lib/python3.12/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'classifier' is an instance of `nn.Module` and is alrea

In [122]:
trainer = pl.Trainer(max_epochs=10, logger=wandb_logger, accelerator="mps", devices=1)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [123]:
trainer.fit(model, train_dataloader)

/Users/rateria/miniconda3/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory ./rav/gw1s3g9a/checkpoints exists and is not empty.

  | Name       | Type               | Params | Mode 
----------------------------------------------------------
0 | retriever  | Retriever          | 109 M  | eval 
1 | ranker     | Ranker             | 109 M  | eval 
2 | attn       | TransformerEncoder | 10.3 M | train
3 | classifier | Classifier         | 98.8 K | train
4 | loss_fn    | CrossEntropyLoss   | 0      | train
----------------------------------------------------------
10.4 M    Trainable params
218 M     Non-trainable params
229 M     Total params
917.263   Total estimated model params size (MB)
56        Modules in train mode
465       Modules in eval mode


Epoch 0:   0%|          | 0/59 [00:00<?, ?it/s] 


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

### Rough Stuff Here

In [11]:
loss = nn.CrossEntropyLoss()

In [13]:
loss(torch.tensor([1, 0, 1], dtype=torch.float32), torch.tensor([0, 1, 0], dtype=torch.float32))

tensor(1.8620)

In [74]:
import torch
import torch.nn as nn

# Define CrossEntropyLoss
criterion = nn.CrossEntropyLoss()

# Predictions (logits): Shape (N, C)
logits = torch.tensor([[2.0, 1.0, 0.1]])  # Batch of size 1, 3 classes

# Target labels: Shape (N)
targets = torch.tensor([2])  # Class 0 is the correct label

# Compute loss
loss = criterion(logits, targets)
print(f"Loss: {loss.item()}")

Loss: 2.3170299530029297
