In [2]:
# prompt: mount drive

from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
!pip install omegaconf
!pip install lightning



In [4]:
from omegaconf import OmegaConf
cfg = OmegaConf.load("/content/drive/My Drive/cs-5787-final-scripts/config/config.yaml")

In [5]:
import os
import sys


sys.path.append("/content/drive/My Drive/cs-5787-final-scripts/")
os.environ['TOKENIZERS_PARALLELISM'] = 'false'


In [6]:
# import hydra
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
# from omegaconf import DictConfig
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics import classification_report
from torch.utils.data import DataLoader
from tqdm import tqdm
from pytorch_lightning.callbacks import ModelCheckpoint

from src.dataset import CSVDataset
from src.multihead_attn import TransformerEncoder
from src.utils import remove_duplicate_strings

device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device {device}")
# device = torch.device("cpu")

Using device cuda


In [7]:
model_name = "pritamdeka/PubMedBERT-mnli-snli-scinli-stsb"

In [8]:
class Retriever(nn.Module):
    """Given a list of evidences and a claim, this returns the top-k evidences"""
    def __init__(self, evidences=[]):
        super().__init__()
        self.bi_encoder = SentenceTransformer(model_name, cache_folder="../cache")
        # self.bi_encoder.eval()
        self.bi_encoder.requires_grad_(False)
        self.k = 10
        # self.data = pd.read_csv()
        self.evidence_pool = evidences
        self.evidence_embeddings = self.bi_encoder.encode(evidences, convert_to_tensor=True)

    def tokenize_and_embed(self, data):
        # data -> [b]
        return self.bi_encoder.encode([data], convert_to_tensor=True)

    def set_encoder_training(self, mode):
        self.bi_encoder.train(mode)

    def forward(self, x):
        # x -> b, claims
        x = self.bi_encoder.encode(x, convert_to_tensor=True)
        # scores -> b, num_evidences, each row is the cosine similarity b/w the claim
        # and all the evidences

        cos_sim = torch.mm(x, self.evidence_embeddings.T)
        scores, indices = torch.topk(cos_sim, self.k, dim=1)
        evidences = [[self.evidence_pool[i] for i in row] for row in indices]
        return scores, indices, evidences

In [9]:
class Ranker(nn.Module):
    def __init__(self):
        super().__init__()
        # self.cross_encoder = AutoModelForSequenceClassification.from_pretrained(cfg.cross_encoder_model_name, cache_dir="../cache")
        self.cross_encoder = SentenceTransformer(model_name, cache_folder="../cache")
        self.cross_encoder.train(True)
        # self.cross_encoder.classifier = nn.Identity()  # remove the last classifier layer
        # self.tokenizer = AutoTokenizer.from_pretrained(cfg.cross_encoder_model_name)
        hidden_size = 768
        self.scorer = nn.Linear(hidden_size, 1)

    def forward(self, x, evidence_pool):
        # x -> b, claim
        # evidence_pool -> list of evidence strings
        # create claim embedding pair
        embeddings = []

        for i, claim in enumerate(x):
            evidences = evidence_pool[i]
            claim_pairs = ["CLS"] + [f"[CLS] {claim} [SEP] {evidence} [SEP]" for evidence in evidences]
            encoded = self.cross_encoder.encode(claim_pairs, convert_to_tensor=True)
            embeddings.append(encoded)
        # Convert list of tensors to a single tensor
        embeddings_tensor = torch.stack(embeddings)
        # print(f"{type(embeddings_tensor)}")
        return embeddings_tensor

In [10]:
class Classifier(nn.Module):
    def __init__(self, input_dim, out_classes=3):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, out_classes)
        )
    def forward(self, x):
        # x = torch.mean(x, dim=1)
        # print(f"classifier after mean {x.shape}")
        x = self.mlp(x)
        # print("classifier", x, x.shape)
        return x, F.softmax(x, dim=-1)

In [11]:
# batch_size = 128
# num_workers = 11

base_path = "/content/drive/My Drive/cs-5787-final-scripts/"

base_file = os.path.join(base_path, "csv/generated_claim_triplets_with_topics.csv")
df = pd.read_csv(base_file)

df = df[df['Evidence'] != "Agent stopped due to iteration limit or time limit."]
evidences = remove_duplicate_strings(df['Evidence'].to_list())

train_dataset = CSVDataset(file_path=os.path.join(base_path, "csv/train.csv"))
train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers)

val_dataset = CSVDataset(file_path=os.path.join(base_path, "csv/val.csv"))
val_dataloader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)

test_dataset = CSVDataset(file_path=os.path.join(base_path, "csv/test.csv"))
test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)

In [12]:
retriever = Retriever(evidences=evidences)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [13]:
ranker = Ranker()
attn = TransformerEncoder(num_layers=4, input_dim=768, num_heads=1, dim_feedforward=128)

In [14]:
classifier = Classifier(768)

In [15]:
# run a forward pass to sanity check the outputs, shapes
for batch in train_dataloader:
    print(batch)
    x, y = batch
    # rav = RAV()
    # print(rav)
    # x = claim, return 10 evidences pertaining to this claim
    scores, indices, evidences = retriever(x)
    print(scores.shape, indices.shape, len(evidences))
    # print(f"\n{evidences}\n")
    # print("initial device", x.device)

    # takes claims and evidences -> concats embeddings for claim to each of the 10 evidences for this claim
    h = ranker(x, evidences)
    h=h.to("cpu")
    print("ranker", h.shape, h.device)
    enc_out = attn(h)
    print(enc_out.shape)
    logits, out = classifier(enc_out[:, 0, :])
    print(out, out.shape)
    loss = nn.CrossEntropyLoss()(logits, y)
    print(f"loss: {loss}")
    break

In [16]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
from sklearn.metrics import precision_score, recall_score, accuracy_score

class MyLightningModule(pl.LightningModule):
    def __init__(self, retriever, ranker, attn, classifier, learning_rate=1e-3):
        """
        Args:
            retriever: Module for retrieving evidence.
            ranker: Module for ranking evidence.
            attn: Attention module.
            classifier: Classification module.
            learning_rate: Learning rate for the optimizer.
        """
        super().__init__()
        self.retriever = retriever
        self.ranker = ranker
        self.attn = attn
        self.classifier = classifier
        self.loss_fn = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate

        # freeze the retriever
        for param in self.retriever.parameters():
            param.requires_grad = False

        self.retriever.eval()

        # freeze the ranker
        # for param in self.ranker.parameters():
        #     param.requires_grad = False

        # self.ranker.eval()

        self.save_hyperparameters(ignore='retriever')

    def forward(self, x):
        # Forward pass to process the input
        scores, indices, evidences = self.retriever(x)
        h = self.ranker(x, evidences)
        enc_out = self.attn(h)
        logits, out = self.classifier(enc_out)
        return out, evidences

    def training_step(self, batch, batch_idx):
        # Single training step
        x, y = batch
        scores, indices, evidences = self.retriever(x)
        h = self.ranker(x, evidences)
        enc_out = self.attn(h)
        logits, out = self.classifier(enc_out[:, 0, :])

        # Compute loss
        loss = self.loss_fn(logits, y)
        self.log("train_loss", loss)  # Log the loss for monitoring
        return loss

    def configure_optimizers(self):
        # Optimizer configuration
        # return torch.optim.SGD(self.parameters(), lr=self.learning_rate)
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def validation_step(self, batch, batch_idx):
        # Validation step
        x, y = batch
        scores, indices, evidences = self.retriever(x)
        h = self.ranker(x, evidences)
        enc_out = self.attn(h)
        logits, out = self.classifier(enc_out[:, 0, :])

        # Compute loss
        loss = self.loss_fn(logits, y)
        self.log("val_loss", loss)  # Log the validation loss

        # Compute predictions
        preds = torch.argmax(logits, dim=1)

        # Convert to CPU for metrics calculation
        y_cpu = y.cpu().numpy()
        preds_cpu = preds.cpu().numpy()

        # Calculate accuracy, precision, and recall
        accuracy = accuracy_score(y_cpu, preds_cpu)
        precision = precision_score(y_cpu, preds_cpu, average='weighted', zero_division=0)
        recall = recall_score(y_cpu, preds_cpu, average='weighted', zero_division=0)

        # Log metrics
        # self.log("val_accuracy", accuracy, prog_bar=True)
        # self.log("val_precision", precision, prog_bar=True)
        # self.log("val_recall", recall, prog_bar=True)
        # self.log("val_loss", loss, prog_bar=True)
        self.log_dict({"val_accuracy": accuracy, "val_precision": precision, "val_recall": recall, "val_loss": loss})

        res = {"val_loss": loss, "val_accuracy": accuracy, "val_precision": precision, "val_recall": recall}
        # print(res)
        return loss


In [17]:
retriever = Retriever(cfg)
ranker = Ranker(cfg)
attn = TransformerEncoder(num_layers=4, input_dim=768, num_heads=1, dim_feedforward=128)
classifier = Classifier(768)

In [18]:
import wandb
from lightning.pytorch.loggers import WandbLogger
wandb.init(project="rav")
wandb_logger = WandbLogger(project="rav")

In [19]:
model = MyLightningModule(retriever, ranker, attn, classifier, learning_rate=cfg.lr)

In [20]:
# Define the ModelCheckpoint callback to save a checkpoint after every epoch
checkpoint_callback = ModelCheckpoint(
    dirpath=base_path,         # Directory to save checkpoints
    filename="epoch-{epoch}-ddg",      # Filename format (e.g., "epoch-1.ckpt")
    every_n_epochs=1,              # Save after every epoch
)

In [21]:
trainer = pl.Trainer(max_epochs=10, accelerator="gpu", devices=1, default_root_dir=base_path, log_every_n_steps=1, logger=wandb_logger, callbacks=[checkpoint_callback])

In [22]:
trainer.fit(model, train_dataloader, val_dataloader)

## Testing & Inference

In [23]:
model = MyLightningModule.load_from_checkpoint("/content/drive/My Drive/cs-5787-final-scripts/epoch-epoch=8-ddg.ckpt", retriever=retriever)

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'ranker' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['ranker'])`.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'attn' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['attn'])`.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'classifier' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['classifier'])`.


In [24]:
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, precision_score, recall_score, accuracy_score

def custom_testing_loop(model, test_loader, device):
    """
    Custom testing loop for evaluating the model and printing a classification report.

    Args:
        model (MyLightningModule): The trained Lightning module.
        test_loader (DataLoader): DataLoader for the test dataset.
        device (torch.device): The device to run the test on.

    Returns:
        dict: A dictionary containing test metrics (precision, recall, accuracy, and loss).
    """
    model.eval()
    model.to(device)

    all_preds = []
    all_targets = []
    total_loss = 0.0

    with torch.no_grad():
        for batch in test_loader:
            x, y = batch
            # x, y = x.to(device), y.to(device)

            # Forward pass
            scores, indices, evidences = model.retriever(x)
            h = model.ranker(x, evidences)
            enc_out = model.attn(h)
            logits, out = model.classifier(enc_out[:, 0, :])

            # Compute loss
            loss = model.loss_fn(logits.to("cpu"), y.to("cpu"))
            total_loss += loss.item()

            # Compute predictions
            preds = torch.argmax(logits, dim=1)

            # Store predictions and targets for metric calculation
            all_preds.append(preds.cpu())
            all_targets.append(y.cpu())

    # Concatenate all predictions and targets
    all_preds = torch.cat(all_preds).numpy()
    all_targets = torch.cat(all_targets).numpy()

    # Compute metrics
    accuracy = accuracy_score(all_targets, all_preds)
    precision = precision_score(all_targets, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_targets, all_preds, average='weighted', zero_division=0)
    avg_loss = total_loss / len(test_loader)

    # Generate and print classification report
    target_names = [f"Class {i}" for i in range(len(set(all_targets)))]
    class_report = classification_report(all_targets, all_preds, target_names=target_names, zero_division=0)
    print("Classification Report:")
    print(class_report)

    # Return results
    metrics = {
        "test_loss": avg_loss,
        "test_accuracy": accuracy,
        "test_precision": precision,
        "test_recall": recall,
    }

    return metrics

In [63]:
custom_testing_loop(model, test_dataloader, device)

Classification Report:
              precision    recall  f1-score   support

     Class 0       0.95      0.92      0.94      1781
     Class 1       0.94      0.95      0.94      1781
     Class 2       0.97      0.99      0.98      1781

    accuracy                           0.95      5343
   macro avg       0.95      0.95      0.95      5343
weighted avg       0.95      0.95      0.95      5343



{'test_loss': 0.14173173256928012,
 'test_accuracy': 0.9522740033688939,
 'test_precision': 0.9521615178609821,
 'test_recall': 0.9522740033688939}

In [25]:
custom_testing_loop(model, val_dataloader, device)

Classification Report:
              precision    recall  f1-score   support

     Class 0       0.95      0.94      0.95      1603
     Class 1       0.95      0.96      0.95      1603
     Class 2       0.97      0.98      0.98      1603

    accuracy                           0.96      4809
   macro avg       0.96      0.96      0.96      4809
weighted avg       0.96      0.96      0.96      4809



{'test_loss': 0.14110470602386876,
 'test_accuracy': 0.9596589727594095,
 'test_precision': 0.9595803762302402,
 'test_recall': 0.9596589727594095}

In [None]:
import torch
import torch.nn as nn

def get_attention_map(model, claim, retriever, ranker, attn, classifier):
    """
    Function to retrieve the attention map for a given input claim.

    Args:
        model: The custom loaded model containing all components.
        claim (str): The input claim for which attention map is generated.
        retriever: The retriever module to fetch relevant evidences.
        ranker: The ranker module to rank and encode evidence.
        attn: The attention encoder module to process ranked evidence.
        classifier: The classifier module for final prediction.

    Returns:
        attention_map: The attention map showing attention paid to evidence.
        logits: The classifier logits.
        evidences: The list of evidence snippets retrieved for the input claim.
    """
    # Convert the input claim to tensor if necessary
    if not isinstance(claim, list):
        claim = [claim]  # Ensure claim is in list format

    # Retrieve top-k evidences using the retriever
    scores, indices, evidences = retriever(claim)
    print(f"Retrieved Evidences: {evidences}")

    # Generate embeddings using the ranker
    h = ranker(claim, evidences)
    print(f"Ranker Output Shape: {h.shape}, Device: {h.device}")

    # Encode embeddings using the attention encoder
    enc_out = attn(h)
    attention_map = attn.get_attention_maps(h)
    print(f"Attention Encoder Output Shape: {enc_out.shape}")

    # Perform classification
    logits, out = classifier(enc_out[:, 0, :])
    print(f"Classifier Output: {out}, Shape: {out.shape}")
    pred = int(np.argmax(logits.detach().cpu()))
    # Return the attention map, logits, and evidences
    return attention_map, logits, pred, evidences


In [None]:
# claim for evidence retrieved but not a high score given - "Highly active antiretroviral therapy (HAART) has changed HIV from a fatal disease to a manageable chronic condition for adolescents."
# supporting claim -
# contradicting claim -
# ambiguous claim -

### Give a claim to the model

In [None]:
# Example usage:
# Assume retriever, ranker, attn, and classifier are initialized
claim = "Highly active antiretroviral therapy (HAART) has changed HIV from a fatal disease to a manageable chronic condition for adolescents."
attention_map, logits, pred, evidences = get_attention_map(model, claim, model.retriever, model.ranker, model.attn, model.classifier)

Retrieved Evidences: [['HIV infection, primarily caused by HIV-1 and HIV-2, leads to the progressive destruction of CD4+ T lymphocytes, resulting in significant immunosuppression and increased susceptibility to opportunistic infections and malignancies. The epidemiology of HIV has changed dramatically due to the global response, with antiretroviral therapy (ART) now standard for all diagnosed individuals, transforming HIV from a fatal disease into a manageable chronic condition. As of recent estimates, approximately 37.9 million people live with HIV, with many receiving ART, which has reduced mortality but increased the prevalence of non-communicable diseases, including HIV-associated lipodystrophy. This condition, characterized by abnormal fat distribution, can manifest as lipoatrophy or lipohypertrophy and is often linked to both the virus and the side effects of certain antiretroviral medications. Clinically, it can lead to significant psychosocial distress and complicate HIV manage

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def draw_attention_map(attention_matrix, labels=None, title="Attention Correlation Map"):
    """
    Draws an attention correlation map.

    Args:
        attention_matrix (np.ndarray): The attention matrix of shape (n, n).
        labels (list): Optional, list of labels for the rows and columns.
        title (str): Title of the plot.
    """
    if attention_matrix.shape[0] != attention_matrix.shape[1]:
        raise ValueError("Attention matrix must be square (n x n).")

    n = attention_matrix.shape[0]

    # Create the plot
    plt.figure(figsize=(8, 6))
    plt.imshow(attention_matrix, cmap="viridis", interpolation="nearest")

    # Add colorbar
    plt.colorbar(label="Attention Weight")

    # Add labels to rows and columns
    if labels is not None:
        if len(labels) != n:
            raise ValueError("Labels must have the same length as the dimensions of the matrix.")
        plt.xticks(ticks=np.arange(n), labels=labels, rotation=90)
        plt.yticks(ticks=np.arange(n), labels=labels)
    else:
        plt.xticks(ticks=np.arange(n))
        plt.yticks(ticks=np.arange(n))

    # Add title and adjust layout
    plt.title(title)
    plt.xlabel("Query/Key Positions")
    plt.ylabel("Query/Key Positions")
    plt.tight_layout()
    plt.show()

In [None]:
pred, logits, evidences

(0,
 tensor([[ 3.4453,  2.4228, -4.1190]], device='cuda:0',
        grad_fn=<AddmmBackward0>),
 [['HIV infection, primarily caused by HIV-1 and HIV-2, leads to the progressive destruction of CD4+ T lymphocytes, resulting in significant immunosuppression and increased susceptibility to opportunistic infections and malignancies. The epidemiology of HIV has changed dramatically due to the global response, with antiretroviral therapy (ART) now standard for all diagnosed individuals, transforming HIV from a fatal disease into a manageable chronic condition. As of recent estimates, approximately 37.9 million people live with HIV, with many receiving ART, which has reduced mortality but increased the prevalence of non-communicable diseases, including HIV-associated lipodystrophy. This condition, characterized by abnormal fat distribution, can manifest as lipoatrophy or lipohypertrophy and is often linked to both the virus and the side effects of certain antiretroviral medications. Clinically,

### Sort evidences by attention scores

In [None]:
def sort_evidences_by_scores(scores, evidences):
    """
    Sorts a list of evidences based on their corresponding scores in descending order.

    Args:
        scores (list of float): A list of numeric scores.
        evidences (list of str): A list of evidences corresponding to the scores.

    Returns:
        list of tuple: A sorted list of tuples where each tuple contains (evidence, score),
                       sorted by scores in descending order.
    """
    if len(scores) != len(evidences):
        raise ValueError("The number of scores must match the number of evidences.")

    # Combine scores and evidences into tuples and sort by scores in descending order
    sorted_evidences = sorted(zip(scores, evidences), key=lambda x: x[0], reverse=True)

    return [(evidence, score) for score, evidence in sorted_evidences]

In [None]:
sort_evidences_by_scores(attention_map[0][0][0].detach().cpu().numpy()[1:], evidences[0])

[('HIV infection, primarily caused by HIV-1 and HIV-2, leads to the progressive destruction of CD4+ T lymphocytes, resulting in significant immunosuppression and increased susceptibility to opportunistic infections and malignancies. The epidemiology of HIV has changed dramatically due to the global response, with antiretroviral therapy (ART) now standard for all diagnosed individuals, transforming HIV from a fatal disease into a manageable chronic condition. As of recent estimates, approximately 37.9 million people live with HIV, with many receiving ART, which has reduced mortality but increased the prevalence of non-communicable diseases, including HIV-associated lipodystrophy. This condition, characterized by abnormal fat distribution, can manifest as lipoatrophy or lipohypertrophy and is often linked to both the virus and the side effects of certain antiretroviral medications. Clinically, it can lead to significant psychosocial distress and complicate HIV management due to its assoc

### Rough Work Here

In [None]:
loss = nn.CrossEntropyLoss()

In [None]:
loss(torch.tensor([1, 0, 1], dtype=torch.float32), torch.tensor([0, 1, 0], dtype=torch.float32))

tensor(1.8620)

In [None]:
import torch
import torch.nn as nn

# Define CrossEntropyLoss
criterion = nn.CrossEntropyLoss()

# Predictions (logits): Shape (N, C)
logits = torch.tensor([[2.50, 1.0, 0.1]])  # Batch of size 1, 3 classes

# Target labels: Shape (N)
targets = torch.tensor([0])  # Class 0 is the correct label

# Compute loss
loss = criterion(logits, targets)
print(f"Loss: {loss.item()}")

Loss: 0.2729603350162506
