In [31]:
import os
import pickle
import argparse
import pandas as pd
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from transformers import BartTokenizer, BartForConditionalGeneration, ViTModel, ViTFeatureExtractor, logging
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

# For evaluation metrics
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
from rouge_score import rouge_scorer
from bert_score import score as bert_score

import nltk
nltk.download('wordnet')

# random_seed = 42
# torch.manual_seed(random_seed)
# np.random.seed(random_seed)

[nltk_data] Downloading package wordnet to
[nltk_data]     /home/ritika22408/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [32]:
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Silence transformers warnings
logging.set_verbosity_error()

In [33]:
class MultimodalSarcasmDataset(Dataset):
    def __init__(self, df_path, desc_pickle, obj_pickle, image_dir, tokenizer, max_length=256, transform=None):
        """
        df_path: path to TSV file with columns [pid, text, explanation, sarcasm_target]
        desc_pickle: pickle file containing image descriptions (dictionary: pid -> description string)
        obj_pickle: pickle file containing detected objects (dictionary: pid -> object string or list)
        image_dir: directory with images named by pid (e.g. pid.jpg or png)
        tokenizer: BART tokenizer
        max_length: maximum token length for input sequence
        transform: torchvision transforms for image preprocessing
        """
        self.df = pd.read_csv(df_path, sep="\t")
        with open(desc_pickle, "rb") as f:
            self.desc_dict = pickle.load(f)
        with open(obj_pickle, "rb") as f:
            self.obj_dict = pickle.load(f)
        self.image_dir = image_dir
        self.tokenizer = tokenizer
        self.max_length = max_length
        # if no transform provided, define one for ViT (assumes 224x224 images)
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Get the sample row
        row = self.df.iloc[idx]
        pid = str(row["pid"])
        text = str(row["text"])
        explanation = str(row["explanation"])
        sarcasm_target = str(row["target_of_sarcasm"])
        # Get image description and detected objects from pickles
        image_desc = self.desc_dict.get(pid, "")
        detected_obj = self.obj_dict.get(pid, "")
        # Ensure all components are strings and concatenate them
        input_text = sarcasm_target + " " + text + " " + str(image_desc) + " " + str(detected_obj)

        # Load image (assuming image file named <pid>.jpg; adjust extension if needed)
        image_path = os.path.join(self.image_dir, pid + ".jpg")
        if not os.path.exists(image_path):
            image_path = os.path.join(self.image_dir, pid + ".png")
        try:
            image = Image.open(image_path).convert("RGB")
        except Exception as e:
            raise FileNotFoundError(f"Image for pid {pid} not found: {e}")
        image = self.transform(image)

        return {
            "input_text": input_text,
            "target_text": explanation,
            "image": image
        }

In [34]:
# ==============================
# Custom Collate Function
# ==============================
def collate_fn(batch, tokenizer, max_length=256, target_max_length=64):
    input_texts = [item["input_text"] for item in batch]
    target_texts = [item["target_text"] for item in batch]
    images = torch.stack([item["image"] for item in batch])

    # Tokenize inputs and targets
    inputs = tokenizer(input_texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
    targets = tokenizer(target_texts, padding=True, truncation=True, max_length=target_max_length, return_tensors="pt")
    # Replace pad tokens in targets with -100 for loss computation
    targets_input_ids = targets.input_ids.masked_fill(targets.input_ids == tokenizer.pad_token_id, -100)

    batch_dict = {
        "input_ids": inputs.input_ids,
        "attention_mask": inputs.attention_mask,
        "decoder_input_ids": targets.input_ids,  # for teacher forcing
        "labels": targets_input_ids,
        "pixel_values": images
    }
    return batch_dict

In [35]:
# ==============================
# Model Definition
# ==============================
class MultimodalSarcasmExplanationModel(nn.Module):
    def __init__(self, bart_model_name="facebook/bart-base", vit_model_name="google/vit-base-patch16-224"):
        super(MultimodalSarcasmExplanationModel, self).__init__()
        self.bart = BartForConditionalGeneration.from_pretrained(bart_model_name)
        self.vit = ViTModel.from_pretrained(vit_model_name)
        # The fusion layer projects the ViT output to BART's d_model if needed.
        self.fusion_layer = nn.Linear(self.vit.config.hidden_size, self.bart.config.d_model)

    def forward(self, input_ids, attention_mask, decoder_input_ids, pixel_values, labels=None):
        # Encode the text input with BART encoder
        encoder_outputs = self.bart.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
        text_features = encoder_outputs.last_hidden_state  # shape: (batch, seq_len, d_model)

        # Extract image features using ViT
        vit_outputs = self.vit(pixel_values=pixel_values)
        image_feature = vit_outputs.last_hidden_state[:, 0, :]  # global image representation
        image_feature_proj = self.fusion_layer(image_feature)   # project to d_model

        # Shared Fusion: add the projected image feature to each token embedding of text features.
        fused_features = text_features + image_feature_proj.unsqueeze(1)

        # Prepare custom encoder output to pass to the decoder
        fused_encoder_outputs = BaseModelOutput(last_hidden_state=fused_features)

        # Decode using BART decoder with the fused encoder outputs.
        outputs = self.bart.model.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=fused_encoder_outputs.last_hidden_state,
            encoder_attention_mask=attention_mask,
            return_dict=True
        )
        sequence_output = outputs.last_hidden_state
        # Compute logits using the shared language modeling head of BART
        lm_logits = self.bart.lm_head(sequence_output) + self.bart.final_logits_bias

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, self.bart.config.vocab_size), labels.view(-1))

        return Seq2SeqLMOutput(loss=loss, logits=lm_logits, encoder_last_hidden_state=fused_encoder_outputs.last_hidden_state)


In [36]:
# ==============================
# Training Function
# ==============================
def train(model, dataloader, optimizer, tokenizer, device):
    model.train()
    total_loss = 0
    for batch in dataloader:
        optimizer.zero_grad()
        # Move batch to device
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        decoder_input_ids = batch["decoder_input_ids"].to(device)
        labels = batch["labels"].to(device)
        pixel_values = batch["pixel_values"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                        decoder_input_ids=decoder_input_ids, pixel_values=pixel_values, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    return avg_loss

In [37]:
# ==============================
# Validation & Evaluation Functions
# ==============================
def generate_explanations(model, dataloader, tokenizer, device, max_length=64):
    model.eval()
    generated_texts = []
    ground_truths = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            pixel_values = batch["pixel_values"].to(device)
            # Encode text and fuse image features
            encoder_outputs = model.bart.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
            text_features = encoder_outputs.last_hidden_state

            vit_outputs = model.vit(pixel_values=pixel_values)
            image_feature = vit_outputs.last_hidden_state[:, 0, :]
            image_feature_proj = model.fusion_layer(image_feature)

            fused_features = text_features + image_feature_proj.unsqueeze(1)
            fused_encoder_outputs = BaseModelOutput(last_hidden_state=fused_features)

            generated_ids = model.bart.generate(
                encoder_outputs=fused_encoder_outputs,
                attention_mask=attention_mask,
                max_length=max_length,
                num_beams=4,                # using beam search with a beam width of 4
                repetition_penalty=2.5,     # increases penalty for repeated tokens
                no_repeat_ngram_size=3,     # prevents any 3-gram from repeating
                do_sample=True,             # enables sampling (stochastic generation)
                top_k=50,                   # limits the next-token choices to the top 50 tokens
                top_p=0.9,                  # alternatively, you could use nucleus sampling
                temperature=1.25,            # a temperature above 1 can add diversity
                early_stopping=True
            )

            decoded_preds = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            generated_texts.extend(decoded_preds)

            # Replace -100 with the pad token id before decoding ground truth labels
            labels = batch["labels"].to(device).clone()
            labels[labels == -100] = tokenizer.pad_token_id
            targets = tokenizer.batch_decode(labels, skip_special_tokens=True)
            ground_truths.extend(targets)
    return generated_texts, ground_truths

In [38]:
def compute_metrics(preds, targets):
    rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    rouge1_scores, rouge2_scores, rougeL_scores = [], [], []
    bleu_scores, meteor_scores = [], []
    for pred, target in zip(preds, targets):
        scores = rouge.score(target, pred)
        rouge1_scores.append(scores['rouge1'].fmeasure)
        rouge2_scores.append(scores['rouge2'].fmeasure)
        rougeL_scores.append(scores['rougeL'].fmeasure)

        smoothing = SmoothingFunction().method1
        bleu = sentence_bleu([target.split()], pred.split(), smoothing_function=smoothing)
        bleu_scores.append(bleu)

        # Split both prediction and target into tokens for meteor_score
        meteor = meteor_score([target.split()], pred.split())
        meteor_scores.append(meteor)

    P, R, F1 = bert_score(preds, targets, lang="en", verbose=False)

    metrics = {
        "ROUGE-1": np.mean(rouge1_scores),
        "ROUGE-2": np.mean(rouge2_scores),
        "ROUGE-L": np.mean(rougeL_scores),
        "BLEU": np.mean(bleu_scores),
        "METEOR": np.mean(meteor_scores),
        "BERTScore_F1": F1.mean().item()
    }
    return metrics

In [39]:
# ==============================
# Inference Function
# ==============================
def inference(model, test_df_path, desc_pickle, obj_pickle, image_dir, tokenizer, device, output_file="sarcasm_explanations.txt"):
    test_dataset = MultimodalSarcasmDataset(test_df_path, desc_pickle, obj_pickle, image_dir, tokenizer)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False,
                             collate_fn=lambda batch: collate_fn(batch, tokenizer))
    generated_texts, _ = generate_explanations(model, test_loader, tokenizer, device)
    with open(output_file, "w", encoding="utf-8") as f:
        for line in generated_texts:
            f.write(line + "\n")
    print(f"Generated sarcasm explanations saved to {output_file}")

In [40]:
# ==============================
# Main Function
# ==============================
def main(args):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    # Create training and validation datasets
    train_dataset = MultimodalSarcasmDataset(
        df_path=os.path.join(args.data_dir, "train_df.tsv"),
        desc_pickle=os.path.join(args.data_dir, "D_train.pkl"),
        obj_pickle=os.path.join(args.data_dir, "O_train.pkl"),
        image_dir=os.path.join(args.data_dir, "images"),
        tokenizer=tokenizer
    )
    val_dataset = MultimodalSarcasmDataset(
        df_path=os.path.join(args.data_dir, "val_df.tsv"),
        desc_pickle=os.path.join(args.data_dir, "D_val.pkl"),
        obj_pickle=os.path.join(args.data_dir, "O_val.pkl"),
        image_dir=os.path.join(args.data_dir, "images"),
        tokenizer=tokenizer
    )
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              collate_fn=lambda batch: collate_fn(batch, tokenizer))
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            collate_fn=lambda batch: collate_fn(batch, tokenizer))

    model = MultimodalSarcasmExplanationModel()
    model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-4)

    best_val_loss = float("inf")
    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch+1}/{args.epochs}")
        train_loss = train(model, train_loader, optimizer, tokenizer, device)
        print(f"Training Loss: {train_loss:.4f}")

        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                decoder_input_ids = batch["decoder_input_ids"].to(device)
                labels = batch["labels"].to(device)
                pixel_values = batch["pixel_values"].to(device)
                outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                                decoder_input_ids=decoder_input_ids, pixel_values=pixel_values, labels=labels)
                total_val_loss += outputs.loss.item()
        avg_val_loss = total_val_loss / len(val_loader)
        print(f"Validation Loss: {avg_val_loss:.4f}")

        # Generate sample explanations for monitoring
        gen_texts, gt_texts = generate_explanations(model, val_loader, tokenizer, device)
        print("\nSample Generated Explanations (Validation):")
        for i in range(min(3, len(gen_texts))):
            print(f"GT: {gt_texts[i]}")
            print(f"Pred: {gen_texts[i]}")
            print("-" * 40)

        metrics = compute_metrics(gen_texts, gt_texts)
        print("\nEvaluation Metrics on Validation Set:")
        for k, v in metrics.items():
            print(f"{k}: {v:.4f}")

        # checkpoint_path = os.path.join(args.checkpoint_dir, f"model_epoch_{epoch+1}.pt")
        # os.makedirs(args.checkpoint_dir, exist_ok=True)
        # torch.save(model.state_dict(), checkpoint_path)
        # print(f"Model checkpoint saved to {checkpoint_path}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # best_model_path = os.path.join(args.checkpoint_dir, "best_model.pt")
            # torch.save(model.state_dict(), best_model_path)
            # print("Best model updated.")

    if args.mode == "test":
        test_df_path = os.path.join(args.data_dir, "val_df.tsv")
        inference(model, test_df_path,
                  desc_pickle=os.path.join(args.data_dir, "D_val.pkl"),
                  obj_pickle=os.path.join(args.data_dir, "O_val.pkl"),
                  image_dir=os.path.join(args.data_dir, "images"),
                  tokenizer=tokenizer,
                  device=device,
                  output_file=args.output_file)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Multimodal Sarcasm Explanation (MuSE) Task")
    parser.add_argument("--data_dir", type=str, default="MORE-PLUS-DATASET", help="Path to the dataset folder")
    parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory to save model checkpoints")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
    parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--mode", type=str, choices=["train", "test"], default="train", help="Mode: train or test")
    parser.add_argument("--output_file", type=str, default="sarcasm_explanations.txt", help="Output file for test predictions")

    args, unknown = parser.parse_known_args()
    main(args)

Using device: cuda:0



Epoch 1/5
Training Loss: 2.8785
Validation Loss: 0.1758

Sample Generated Explanations (Validation):
GT: the author is pissed at <user> for not getting network in malad.
Pred: this this this this these these thesethesethesethese These These ThesethesetheseThesethesethesethethethethemthemthemmmm m m m M M MM M M & & &&&&^^^ ^^^### # # ### "# "# "# # #
----------------------------------------
GT: nothing worst than waiting for an hour on the tarmac for a gate to come open in snowy, windy chicago.
Pred: aaa a a a an an an An An AnAn An AnanananANAN ANANANANANIANIANIanianianiatiatiatiifiifiifiFiFiFifififiFiFi FiFiFiwiFiFiWiFiFiFIFiFi fiFiFiBiFi
----------------------------------------
GT: nobody likes getting one hour of their life sucked away.
Pred: springspringspring spring spring spring springs springs springs sprung sprung sprung sprang sprang sprang popped popped popped pop pop pop Pop Pop PopPop Pop Poppop Pop Pop POP POP POPOPOPOPopopop op op op Opopopopsopopppopoppopopumpopoppoopo