In [2]:
import pandas as pd
import numpy as np
import pickle

In [3]:
# Tab-separated table holding the text side of the training split
df = pd.read_csv("train_df.tsv", sep="\t")

# Pickled mapping from image id to its description.
# NOTE: unpickling can execute arbitrary code; only load files you trust.
with open("D_train.pkl", "rb") as desc_handle:
    image_descriptions = pickle.load(desc_handle)

# Peek at the first five entries instead of dumping the whole mapping
print("Image Descriptions (First 5 Records):")
for image_id in list(image_descriptions)[:5]:
    print(f"{image_id}: {image_descriptions[image_id]}")

# Pickled mapping from image id to its object-detection record
with open("O_train.pkl", "rb") as obj_handle:
    detected_objects = pickle.load(obj_handle)

print("\nDetected Objects (First 5 Records):")
for image_id in list(detected_objects)[:5]:
    print(f"{image_id}: {detected_objects[image_id]}")

Image Descriptions (First 5 Records):
931874353976938497: people sitting on the floor in a large room with a wall
880425829246922752: two twee screens of donald trump and donald trump
690915881082343424: there are two shovels that are standing in the snow
915228456757059585: arafed view of a passenger plane with a flat screen tv
494194068998468686_25639236: cars are driving down the highway on a cloudy day

Detected Objects (First 5 Records):
931874353976938497: {'classes': ['person', 'backpack', 'handbag', 'backpack', 'backpack', 'cell phone', 'person', 'person', 'person', 'cup', 'chair', 'person', 'person', 'person', 'person', 'person', 'person', 'person', 'person', 'person', 'chair', 'person', 'person', 'person', 'backpack', 'backpack', 'person', 'person', 'person', 'backpack', 'person', 'person', 'person', 'person', 'person', 'person'], 'confidence_scores': [0.0945774, 0.0975185, 0.111666, 0.117207, 0.118647, 0.120145, 0.121325, 0.217946, 0.224531, 0.226831, 0.280922, 0.29603, 0.33

In [4]:
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
import pandas as pd
import pickle
import os
from transformers import BartTokenizer


class MuSEDataset(Dataset):
    ''' Dataset that joins each text row with its image, image description and detected objects.

    Every item is a dict with:
        text_input_ids / text_attention_mask: tokenized "text" column.
        image: transformed image tensor, or an all-zero tensor when the file
            is missing or unreadable.
        image_description: entry of the description pickle for this pid, or
            the string "No description available".
        detected_objects: entry of the detected-objects pickle for this pid,
            or the string "No objects detected". Note that stored entries and
            the fallback have different types, so batching needs a collate
            function that does not recurse into this field.
        sarcasm_target_input_ids / sarcasm_target_attention_mask: tokenized
            "target_of_sarcasm" column ("" when the cell is empty).
    '''

    # Suffixes tried, in order, when looking for the image file of a pid.
    # The pid itself carries no extension; the empty suffix keeps the original
    # exact-name lookup first. NOTE(review): the other suffixes are a guess at
    # how the images are stored on disk -- adjust to the real layout.
    IMAGE_EXTENSIONS = ("", ".jpg", ".jpeg", ".png")

    # Shape of the placeholder image; matches the default 224x224 RGB transform.
    # A custom transform with another output size will not match this shape.
    BLANK_IMAGE_SHAPE = (3, 224, 224)

    def __init__(self, text_file, image_desc_file, obj_file, image_folder, tokenizer, transform=None, max_length=None):
        ''' Initialize the MuSEDataset class.

        Args:
            text_file: path to a tab-separated file with "pid", "text" and
                "target_of_sarcasm" columns.
            image_desc_file: path to a pickle mapping pid -> description.
            obj_file: path to a pickle mapping pid -> detected-objects record.
            image_folder: directory containing the image files.
            tokenizer: callable tokenizer returning "input_ids" and
                "attention_mask" tensors.
            transform: optional image transform; defaults to resize to
                224x224, tensor conversion and ImageNet normalization.
            max_length: optional token length used for padding/truncation.
                None keeps the tokenizer's own default.
        '''
        self.text_data = pd.read_csv(text_file, sep="\t")

        # NOTE: pickle.load can execute arbitrary code; only use trusted files.
        with open(image_desc_file, "rb") as f:
            self.image_descriptions = pickle.load(f)

        with open(obj_file, "rb") as f:
            self.detected_objects = pickle.load(f)

        self.image_folder = image_folder
        self.tokenizer = tokenizer
        self.max_length = max_length

        self.transform = transform if transform else transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalization - mean and std dev for RGB channels
        ])

    def __len__(self):
        ''' Return the number of rows in the text table. '''
        return len(self.text_data)

    def _resolve_image_path(self, image_name):
        ''' Return the first existing file for image_name, or None if there is none. '''
        base_path = os.path.join(self.image_folder, image_name)
        for suffix in self.IMAGE_EXTENSIONS:
            candidate = base_path + suffix
            if os.path.isfile(candidate):
                return candidate
        return None

    def _load_image(self, image_name):
        ''' Load and transform the image for image_name, falling back to a blank tensor. '''
        image_path = self._resolve_image_path(image_name)
        if image_path is None:
            missing_path = os.path.join(self.image_folder, image_name)
            print(f"Warning: Image {missing_path} not found, using blank image.")
            return torch.zeros(self.BLANK_IMAGE_SHAPE)

        try:
            # The context manager releases the file handle as soon as the
            # pixels have been converted and transformed.
            with Image.open(image_path) as raw_image:
                return self.transform(raw_image.convert("RGB"))
        except OSError as err:
            # Covers corrupt/truncated files and a file vanishing after the
            # existence check; one bad image should not abort a whole epoch.
            print(f"Warning: Image {image_path} could not be read ({err}), using blank image.")
            return torch.zeros(self.BLANK_IMAGE_SHAPE)

    def _tokenize(self, value):
        ''' Tokenize one string, padded/truncated to max_length (or the tokenizer default). '''
        return self.tokenizer(
            value,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        ''' Retrieve the data for a given index. '''

        # Get data
        row = self.text_data.iloc[idx]
        # Empty cells are read as NaN; the tokenizer needs a string
        text = str(row["text"]) if pd.notna(row["text"]) else ""
        image_name = str(row["pid"])  # Convert to string to match dictionary keys
        sarcasm_target = str(row["target_of_sarcasm"]) if pd.notna(row["target_of_sarcasm"]) else ""

        # Load and preprocess image
        image = self._load_image(image_name)

        # Get image description and detected objects (handle missing keys)
        img_desc = self.image_descriptions.get(image_name, "No description available")
        detected_objs = self.detected_objects.get(image_name, "No objects detected")

        # Tokenize text
        text_inputs = self._tokenize(text)
        sarcasm_target_inputs = self._tokenize(sarcasm_target)

        # The tokenizer returns a leading batch dimension of 1; drop it so the
        # DataLoader can stack items itself.
        return {
            "text_input_ids": text_inputs["input_ids"].squeeze(0),
            "text_attention_mask": text_inputs["attention_mask"].squeeze(0),
            "image": image,
            "image_description": img_desc,
            "detected_objects": detected_objs,
            "sarcasm_target_input_ids": sarcasm_target_inputs["input_ids"].squeeze(0),
            "sarcasm_target_attention_mask": sarcasm_target_inputs["attention_mask"].squeeze(0),
        }


# Pretrained BART-base tokenizer, shared by the text and the sarcasm-target fields
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

# Training split: text table, description pickle, detected-objects pickle, image directory
dataset = MuSEDataset(
    text_file="train_df.tsv",
    image_desc_file="D_train.pkl",
    obj_file="O_train.pkl",
    image_folder="images",
    tokenizer=tokenizer,
)


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# from transformers import ViTModel, ViTFeatureExtractor

# # Create ViT class that will get image embeddings
# # (the Hugging Face checkpoint is published under the "google" organisation)
# feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
# vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224")


In [6]:
from torch.utils.data import DataLoader

def collate_muse_batch(samples):
    ''' Merge dataset items into one batch without tensorizing ragged fields.

    Tensor fields are stacked along a new leading batch dimension. Every other
    field (image description, detected-objects record) is returned as a plain
    list with one entry per sample. The default collate would instead recurse
    into the detected-objects records and fail: their lists differ in length
    between samples, and the dataset's missing-key fallback for that field is
    a string rather than a dict.
    '''
    batch = {}
    for key in samples[0]:
        values = [sample[key] for sample in samples]
        batch[key] = torch.stack(values) if torch.is_tensor(values[0]) else values
    return batch


batch_size = 16
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_muse_batch,
)


In [None]:
import torch.nn as nn
from torchvision.models import vit_b_16
import torch

# Use CUDA when it is available, otherwise stay on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ViT-B/16 with ImageNet-1k pretrained weights
vit = vit_b_16(weights="IMAGENET1K_V1")
# Swap the classification head for a pass-through so the model outputs
# features instead of class logits
vit.heads = nn.Identity()
# Place the model on the selected device (GPU or CPU)
vit = vit.to(device)

False


In [None]:
# Example: Process one batch
# Feature extraction only: put the model in inference mode and skip autograd
# bookkeeping, so no computation graph is kept alive for the forward pass.
vit.eval()
with torch.no_grad():
    for batch in dataloader:
        batch["image"] = batch["image"].to(device)  # Move images to the model's device
        image_embeddings = vit(batch["image"])  # Extract features
        print(image_embeddings.shape)  # Expected output: (batch_size, 768)
        break  # Stop after one batch