In [13]:
from omegaconf import OmegaConf
cfg = OmegaConf.load("./config/config.yaml")

In [16]:
import os
import sys


sys.path.append(cfg.base_path)
os.environ['TOKENIZERS_PARALLELISM'] = 'false'


In [17]:
# import hydra
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
# from omegaconf import DictConfig
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics import classification_report
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.dataset import CSVDataset
from src.multihead_attn import TransformerEncoder
from src.utils import remove_duplicate_strings

device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device {device}")
# device = torch.device("cpu")

Using device mps


In [18]:
model_name = "pritamdeka/PubMedBERT-mnli-snli-scinli-stsb"

In [19]:
class Retriever(nn.Module):
    """Given a list of evidences and a claim, this returns the top-k evidences"""
    def __init__(self,train_evidences=[], val_evidences=[], test_evidences=[]):
        super().__init__()
        self.bi_encoder = SentenceTransformer(model_name, cache_folder="../cache")
        # self.bi_encoder.eval()
        self.bi_encoder.requires_grad_(False)
        self.k = 10
        # self.data = pd.read_csv()
        self.train_evidence_pool = train_evidences
        self.train_evidence_embeddings = self.bi_encoder.encode(train_evidences, convert_to_tensor=True)
        self.val_evidence_pool = val_evidences
        self.val_evidence_embeddings = self.bi_encoder.encode(val_evidences, convert_to_tensor=True)
        self.test_evidence_pool = test_evidences
        self.test_evidence_embeddings = self.bi_encoder.encode(test_evidences, convert_to_tensor=True)
    def tokenize_and_embed(self, data):
        # data -> [b]
        return self.bi_encoder.encode([data], convert_to_tensor=True)

    def set_encoder_training(self, mode):
        self.bi_encoder.train(mode)

    def forward(self, x, mode="train"):
        # x -> b, claims
        x = self.bi_encoder.encode(x, convert_to_tensor=True)
        # scores -> b, num_evidences, each row is the cosine similarity b/w the claim
        # and all the evidences
        if mode == "train":
            self.evidence_embeddings = self.train_evidence_embeddings
            self.evidence_pool = self.train_evidence_pool
        elif mode == "val":
            self.evidence_embeddings = self.val_evidence_embeddings
            self.evidence_pool = self.val_evidence_pool
        elif mode == "test":
            self.evidence_embeddings = self.test_evidence_embeddings
            self.evidence_pool = self.test_evidence_pool
        cos_sim = torch.mm(x, self.evidence_embeddings.T)
        scores, indices = torch.topk(cos_sim, self.k, dim=1)
        evidences = [[self.evidence_pool[i] for i in row] for row in indices]
        return scores, indices, evidences

In [20]:
class Ranker(nn.Module):
    def __init__(self):
        super().__init__()
        # self.cross_encoder = AutoModelForSequenceClassification.from_pretrained(cfg.cross_encoder_model_name, cache_dir="../cache")
        self.cross_encoder = SentenceTransformer(model_name, cache_folder="../cache")
        self.cross_encoder.train(True)
        # self.cross_encoder.classifier = nn.Identity()  # remove the last classifier layer
        # self.tokenizer = AutoTokenizer.from_pretrained(cfg.cross_encoder_model_name)
        hidden_size = 768
        self.scorer = nn.Linear(hidden_size, 1)

    def forward(self, x, evidence_pool):
        # x -> b, claim
        # evidence_pool -> list of evidence strings
        # create claim embedding pair
        embeddings = []

        for i, claim in enumerate(x):
            evidences = evidence_pool[i]
            claim_pairs = ["CLS"] + [f"[CLS] {claim} [SEP] {evidence} [SEP]" for evidence in evidences]
            encoded = self.cross_encoder.encode(claim_pairs, convert_to_tensor=True)
            embeddings.append(encoded)
        # Convert list of tensors to a single tensor
        embeddings_tensor = torch.stack(embeddings)
        # print(f"{type(embeddings_tensor)}")
        return embeddings_tensor

In [21]:
class Classifier(nn.Module):
    def __init__(self, input_dim, out_classes=3):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, out_classes)
        )
    def forward(self, x):
        # x = torch.mean(x, dim=1)
        # print(f"classifier after mean {x.shape}")
        x = self.mlp(x)
        # print("classifier", x, x.shape)
        return x, F.softmax(x, dim=-1)

In [22]:
# batch_size = 128
# num_workers = 11

# base_path = "/content/drive/My Drive/cs-5787-final-scripts/"

train_dataset = CSVDataset(file_path=os.path.join(cfg.base_path, "data/csv/train.csv"))
train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers)

val_dataset = CSVDataset(file_path=os.path.join(cfg.base_path, "data/csv/val.csv"))
val_dataloader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)

test_dataset = CSVDataset(file_path=os.path.join(cfg.base_path, "data/csv/test.csv"))
test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)

In [9]:
retriever = Retriever(train_evidences=train_dataset.get_evidences(), val_evidences=val_dataset.get_evidences(), test_evidences=test_dataset.get_evidences())

In [10]:
ranker = Ranker()
attn = TransformerEncoder(num_layers=4, input_dim=768, num_heads=1, dim_feedforward=128)

In [11]:
classifier = Classifier(768)

In [23]:
# sanity check
for batch in train_dataloader:
    print(batch)
    x, y = batch
    # rav = RAV()
    # print(rav)
    scores, indices, evidences = retriever(x)
    print(scores.shape, indices.shape, len(evidences))
    # print(f"\n{evidences}\n")
    # print("initial device", x.device)
    h = ranker(x, evidences)
    h = h.to("cpu")
    print("ranker", h.shape, h.device)
    enc_out = attn(h)
    print(enc_out.shape)
    logits, out = classifier(enc_out[:, 0, :])
    print(out, out.shape)
    loss = nn.CrossEntropyLoss()(logits, y)
    print(f"loss: {loss}")
    break

[('Levetiracetam has no known side effects and is completely safe for all patients.', 'The need for a standardized approach may be important, yet some clinicians argue that flexibility in practice can also be beneficial.', 'Lidocaine and Prilocaine are not used in pediatric procedures due to safety concerns.', 'Fluconazole is frequently used in antifungal therapy, but its efficacy can vary depending on the specific type of fungal infection being treated.', "The necessity of addressing both risk assessment and modification can vary depending on the specific surgical procedure and the patient's overall health.", 'Postoperative complications such as infections and delirium are uncommon in the elderly population following hip fracture surgery.', 'Clinical trials show that ACEIs and ARBs do not improve left ventricular function after a myocardial infarction.', 'Traditional methods like bare-metal stents are more effective than ICRT in preventing in-stent restenosis.', 'Although the summary 

In [24]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
from sklearn.metrics import precision_score, recall_score, accuracy_score

class MyLightningModule(pl.LightningModule):
    def __init__(self, retriever, ranker, attn, classifier, learning_rate=1e-3):
        """
        Args:
            retriever: Module for retrieving evidence.
            ranker: Module for ranking evidence.
            attn: Attention module.
            classifier: Classification module.
            learning_rate: Learning rate for the optimizer.
        """
        super().__init__()
        self.retriever = retriever
        self.ranker = ranker
        self.attn = attn
        self.classifier = classifier
        self.loss_fn = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate

        # freeze the retriever
        for param in self.retriever.parameters():
            param.requires_grad = False

        self.retriever.eval()

        # freeze the ranker
        # for param in self.ranker.parameters():
        #     param.requires_grad = False

        # self.ranker.eval()

        self.save_hyperparameters()

    def forward(self, x):
        # Forward pass to process the input
        scores, indices, evidences = self.retriever(x)
        h = self.ranker(x, evidences)
        enc_out = self.attn(h)
        out = self.classifier(enc_out)
        return out

    def training_step(self, batch, batch_idx):
        # Single training step
        x, y = batch
        scores, indices, evidences = self.retriever(x, mode="train")
        h = self.ranker(x, evidences)
        enc_out = self.attn(h)
        logits, out = self.classifier(enc_out[:, 0, :])

        # Compute loss
        loss = self.loss_fn(logits, y)
        print(f"loss: {loss}")
        self.log("train_loss", loss)  # Log the loss for monitoring
        return loss

    def configure_optimizers(self):
        # Optimizer configuration
        # return torch.optim.SGD(self.parameters(), lr=self.learning_rate)
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def validation_step(self, batch, batch_idx):
        # Validation step
        x, y = batch
        scores, indices, evidences = self.retriever(x, mode="val")
        h = self.ranker(x, evidences)
        enc_out = self.attn(h)
        logits, out = self.classifier(enc_out[:, 0, :])

        # Compute loss
        loss = self.loss_fn(logits, y)
        self.log("val_loss", loss)  # Log the validation loss

        # Compute predictions
        preds = torch.argmax(logits, dim=1)

        # Convert to CPU for metrics calculation
        y_cpu = y.cpu().numpy()
        preds_cpu = preds.cpu().numpy()

        # Calculate accuracy, precision, and recall
        accuracy = accuracy_score(y_cpu, preds_cpu)
        precision = precision_score(y_cpu, preds_cpu, average='weighted', zero_division=0)
        recall = recall_score(y_cpu, preds_cpu, average='weighted', zero_division=0)

        # Log metrics
        self.log("val_accuracy", accuracy, prog_bar=True)
        self.log("val_precision", precision, prog_bar=True)
        self.log("val_recall", recall, prog_bar=True)

        res = {"val_loss": loss, "val_accuracy": accuracy, "val_precision": precision, "val_recall": recall}
        print(res)
        return loss


In [25]:
# retriever = Retriever(cfg)
# ranker = Ranker(cfg)
# attn = TransformerEncoder(num_layers=4, input_dim=768, num_heads=1, dim_feedforward=128)
# classifier = Classifier(768)

In [26]:
# import wandb
# from lightning.pytorch.loggers import WandbLogger
# wandb.init(project="rav")
# wandb_logger = WandbLogger(project="rav")

In [27]:
model = MyLightningModule(retriever, ranker, attn, classifier, learning_rate=cfg.lr)

/Users/rateria/miniconda3/lib/python3.12/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'retriever' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['retriever'])`.
/Users/rateria/miniconda3/lib/python3.12/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'ranker' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['ranker'])`.
/Users/rateria/miniconda3/lib/python3.12/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'attn' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['attn'])`.
/Users/rateria/miniconda3/lib/python3.12/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'classifier' is an instance of `nn.Module` and is alrea

In [28]:
# trainer = pl.Trainer(max_epochs=100, accelerator="gpu", devices=1)
trainer = pl.Trainer(max_epochs=100, accelerator="mps", devices=1)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [29]:
trainer.fit(model, train_dataloader, val_dataloader)


  | Name       | Type               | Params | Mode 
----------------------------------------------------------
0 | retriever  | Retriever          | 109 M  | eval 
1 | ranker     | Ranker             | 109 M  | train
2 | attn       | TransformerEncoder | 10.3 M | train
3 | classifier | Classifier         | 98.8 K | train
4 | loss_fn    | CrossEntropyLoss   | 0      | train
----------------------------------------------------------
119 M     Trainable params
109 M     Non-trainable params
229 M     Total params
917.263   Total estimated model params size (MB)
58        Modules in train mode
463       Modules in eval mode
/Users/rateria/miniconda3/lib/python3.12/site-packages/pytorch_lightning/core/saving.py:363: Skipping 'retriever' parameter because it is not possible to safely dump to YAML.
/Users/rateria/miniconda3/lib/python3.12/site-packages/pytorch_lightning/core/saving.py:363: Skipping 'ranker' parameter because it is not possible to safely dump to YAML.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/rateria/miniconda3/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:419: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]{'val_loss': tensor(1.0975, device='mps:0'), 'val_accuracy': 0.34375, 'val_precision': np.float64(0.2373046875), 'val_recall': np.float64(0.34375)}
Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:16<00:16,  0.06it/s]{'val_loss': tensor(1.1035, device='mps:0'), 'val_accuracy': 0.3359375, 'val_precision': np.float64(0.23308185034305318), 'val_recall': np.float64(0.3359375)}
                                                                           

/Users/rateria/miniconda3/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:419: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 0:   0%|          | 0/230 [00:00<?, ?it/s] loss: 1.092132806777954
Epoch 0:   0%|          | 1/230 [00:12<48:53,  0.08it/s, v_num=5]loss: 1.0531566143035889
Epoch 0:   1%|          | 2/230 [00:24<45:53,  0.08it/s, v_num=5]loss: 1.0986642837524414
Epoch 0:   1%|▏         | 3/230 [00:35<44:53,  0.08it/s, v_num=5]

### Rough Stuff Here

In [None]:
loss = nn.CrossEntropyLoss()

In [None]:
loss(torch.tensor([1, 0, 1], dtype=torch.float32), torch.tensor([0, 1, 0], dtype=torch.float32))

tensor(1.8620)

In [9]:
import torch
import torch.nn as nn

# Define CrossEntropyLoss
criterion = nn.CrossEntropyLoss()

# Predictions (logits): Shape (N, C)
logits = torch.tensor([[2.50, 1.0, 0.1]])  # Batch of size 1, 3 classes

# Target labels: Shape (N)
targets = torch.tensor([0])  # Class 0 is the correct label

# Compute loss
loss = criterion(logits, targets)
print(f"Loss: {loss.item()}")

Loss: 0.2729603350162506
