# In [None]:
import json
import lightning as L
import numpy as np
import torch
import easyocr
from pathlib import Path
from torch import nn
from torch.nn import functional as F
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torchmetrics.functional import f1_score
from lightning.pytorch import callbacks as L_callbacks
from transformers import AutoModel, AutoImageProcessor
from sentence_transformers import SentenceTransformer
from tqdm.notebook import tqdm
from PIL import Image

DATA_DIR = "/kaggle/input/vimmsd-uit2024"
# Position in this list is the integer class id used for training and
# prediction.
CLASS_NAMES = ["not-sarcasm", "image-sarcasm", "text-sarcasm", "multi-sarcasm"]
IMAGE_PROCESSOR = AutoImageProcessor.from_pretrained(
    "google/vit-base-patch16-224", use_fast=True
)
# OCR is executed inside `collate_fn`, i.e. inside DataLoader worker processes
# (the loaders below use num_workers > 0). With the fork start method (the
# Linux default), a worker cannot use a CUDA context created in the parent
# ("Cannot re-initialize CUDA in forked subprocess"), so the reader is pinned
# to the CPU. If GPU OCR is wanted, precompute the OCR text once in the main
# process instead of doing it in the collate function.
OCR = easyocr.Reader(["vi", "en"], gpu=False)


class DscTrainDataset(Dataset):
    """Labelled training split of the ViMMSD data.

    Each item is a new dict with keys ``"image"`` (RGB ``PIL.Image`` loaded
    from disk), ``"caption"`` and ``"label"`` (index of the label in
    ``class_names``).
    """

    def __init__(self, data_dir: str = DATA_DIR, class_names: list[str] = CLASS_NAMES):
        super().__init__()

        self.data_dir = Path(data_dir)
        images_dir = self.data_dir.joinpath("training-images", "train-images")
        data_file = self.data_dir.joinpath("vimmsd-train.json")

        # Class name -> position in `class_names`.
        self.label2id = {name: class_id for class_id, name in enumerate(class_names)}

        # JSON is UTF-8 by spec; be explicit so the (Vietnamese) captions do
        # not depend on the platform's default encoding.
        with open(data_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        mapped_data = []
        for key in tqdm(
            data, desc="Mapping image file name to file path", total=len(data)
        ):
            entry = data[key]
            mapped_data.append(
                {
                    "image": images_dir.joinpath(entry["image"]).as_posix(),
                    "caption": entry["caption"],
                    "label": self.label2id[entry["label"]],
                }
            )

        # Records keep the image *path*; images are decoded lazily per access.
        self.data = mapped_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return a copy of record ``index`` with its image loaded as RGB."""
        item = self.data[index]
        # Build a new dict instead of writing the decoded image back into
        # `self.data`: overwriting the stored path would make the next access
        # to this index call `Image.open` on a PIL image (and keep every
        # decoded image alive in memory).
        with Image.open(item["image"]) as image:
            rgb_image = image.convert("RGB")
        return {**item, "image": rgb_image}


class DscPredictDataset(Dataset):
    """Unlabelled public-test split of the ViMMSD data.

    Items are enumerated in the key order of ``vimmsd-public-test.json``.
    Each item is a new dict with keys ``"image"`` (RGB ``PIL.Image`` loaded
    from disk) and ``"caption"``.
    """

    def __init__(self, data_dir: str = DATA_DIR):
        super().__init__()

        self.data_dir = Path(data_dir)
        images_dir = self.data_dir.joinpath("public-test-images", "dev-images")
        data_file = self.data_dir.joinpath("vimmsd-public-test.json")

        # JSON is UTF-8 by spec; be explicit so the (Vietnamese) captions do
        # not depend on the platform's default encoding.
        with open(data_file, "r", encoding="utf-8") as f:
            data = json.load(f)

        mapped_data = []
        for key in tqdm(
            data, desc="Mapping image file name to file path", total=len(data)
        ):
            entry = data[key]
            mapped_data.append(
                {
                    "image": images_dir.joinpath(entry["image"]).as_posix(),
                    "caption": entry["caption"],
                }
            )

        # Records keep the image *path*; images are decoded lazily per access.
        self.data = mapped_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return a copy of record ``index`` with its image loaded as RGB."""
        item = self.data[index]
        # Build a new dict instead of writing the decoded image back into
        # `self.data`: overwriting the stored path would make the next access
        # to this index call `Image.open` on a PIL image (and keep every
        # decoded image alive in memory).
        with Image.open(item["image"]) as image:
            rgb_image = image.convert("RGB")
        return {**item, "image": rgb_image}


def collate_fn(batch: list[dict[str, object]], image_processor=None, ocr=None):
    """Collate dataset items into model features plus a list of labels.

    OCR is run on every image, and the recognised text fragments of one image
    are joined with newlines.

    Args:
        batch: items from ``DscTrainDataset`` / ``DscPredictDataset``; each has
            ``"image"`` and ``"caption"`` and optionally ``"label"``.
        image_processor: callable taking a list of image arrays and
            ``return_tensors="pt"``. Defaults to the module's
            ``IMAGE_PROCESSOR``.
        ocr: object whose ``readtext(image)`` yields
            ``(bounding_box, text, confidence)`` triples. Defaults to the
            module's ``OCR``.

    Returns:
        ``(features, labels)``. ``features`` has ``"images"`` (the processor
        output), ``"image_texts"`` and ``"captions"`` (lists of str).
        ``labels`` has one entry per item, ``None`` for items without a label.
    """
    # Resolve the heavy module-level defaults at call time; passing them
    # explicitly behaves exactly as before.
    if image_processor is None:
        image_processor = IMAGE_PROCESSOR
    if ocr is None:
        ocr = OCR

    images = []
    image_texts = []
    captions = []
    labels = []

    for item in batch:
        image = np.array(item["image"])
        images.append(image)

        # Keep only the recognised text of each (bbox, text, confidence)
        # triple. The previous `list(lambda ..., iterable)` raised TypeError,
        # since `list` accepts a single argument.
        image_texts.append(
            "\n".join(text for _bbox, text, _confidence in ocr.readtext(image))
        )
        captions.append(item["caption"])
        labels.append(item.get("label"))

    pixel_inputs = image_processor(images, return_tensors="pt")

    return (
        {"images": pixel_inputs, "image_texts": image_texts, "captions": captions},
        labels,
    )


class CombinedSarcasmClassifier(L.LightningModule):
    """Sarcasm classifier over an image, its OCR text and its caption.

    The image encoder's first-token embedding and the sentence embeddings of
    the OCR text and of the caption are concatenated and mapped by one linear
    layer to one logit per entry of ``CLASS_NAMES``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        text_encoder = SentenceTransformer(
            "jinaai/jina-embeddings-v3", trust_remote_code=True
        )
        # The text encoder is only used through `SentenceTransformer.encode`,
        # which computes embeddings without gradient tracking, so it never
        # received gradients. Freeze it explicitly so that is visible and its
        # parameters stay out of the optimizer. Fine-tuning it would require
        # calling the model on tokenized features instead of `encode`.
        text_encoder.requires_grad_(False)
        self.text_encoder = text_encoder

        image_encoder = AutoModel.from_pretrained("google/vit-base-patch16-224")
        image_encoder.train()
        self.image_encoder = image_encoder

        # Input width is inferred on the first forward pass.
        self.fc = nn.LazyLinear(len(CLASS_NAMES))

    def setup(self, stage=None):
        self.text_encoder.to(self.device, self.dtype)
        self.image_encoder.to(self.device, self.dtype)

    def _encode_text(self, texts):
        """Embed a list of strings with the frozen sentence encoder as a 2-D tensor."""
        embeddings = self.text_encoder.encode(
            texts, convert_to_tensor=True, show_progress_bar=False
        )
        return embeddings.view(embeddings.shape[0], -1)

    def forward(self, images, image_texts, captions):
        """Return logits with one row per sample and one column per class.

        Args:
            images: keyword inputs for the image encoder (the image-processor
                output built by ``collate_fn``).
            image_texts: OCR text of each image.
            captions: caption of each sample.
        """
        # First token of the last hidden state is used as the image embedding.
        image_emb = self.image_encoder(**images).last_hidden_state[:, 0]
        image_emb = image_emb.view(image_emb.shape[0], -1)

        ocr_emb = self._encode_text(image_texts)
        caption_emb = self._encode_text(captions)

        embeddings = torch.cat([image_emb, ocr_emb, caption_emb], dim=1)
        return self.fc(embeddings)

    def _shared_step(self, batch, stage):
        """Compute and log ``{stage}_loss`` / ``{stage}_f1`` for a labelled batch."""
        features, targets = batch
        logits = self.forward(**features)
        targets = torch.tensor(targets, device=self.device)
        loss = F.cross_entropy(logits, targets)
        # NOTE(review): this uses torchmetrics' default `average`; pass
        # `average="macro"` if the target metric is macro-F1 — TODO confirm.
        f1 = f1_score(
            F.softmax(logits, dim=1),
            targets,
            task="multiclass",
            num_classes=len(CLASS_NAMES),
        )
        self.log_dict(
            {f"{stage}_loss": loss, f"{stage}_f1": f1},
            prog_bar=True,
            batch_size=logits.shape[0],
        )
        return loss

    def training_step(self, batch, _):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, _):
        self._shared_step(batch, "val")

    def predict_step(self, batch, _):
        """Return the predicted class index for every sample in the batch."""
        features, _labels = batch
        # argmax over logits directly: softmax is monotonic, so it adds nothing.
        return self.forward(**features).argmax(dim=1)

    def configure_optimizers(self):
        # Only parameters that can receive gradients (text encoder is frozen).
        trainable = [p for p in self.parameters() if p.requires_grad]
        return AdamW(trainable, lr=5e-5)


train_ds = DscTrainDataset()
# Seeded split so the validation set (and therefore early stopping and the
# checkpoint chosen as "best") is the same across runs.
train_ds, val_ds = torch.utils.data.random_split(
    train_ds, [0.85, 0.15], generator=torch.Generator().manual_seed(42)
)
# DataLoader does not shuffle by default; reshuffle the training subset every
# epoch.
train_dataloader = DataLoader(
    train_ds, collate_fn=collate_fn, batch_size=128, shuffle=True, num_workers=3
)
val_dataloader = DataLoader(
    val_ds, collate_fn=collate_fn, batch_size=128, num_workers=3
)

# NOTE(review): this casts the weights themselves to bfloat16, so AdamW also
# updates them in bf16. Consider `L.Trainer(precision="bf16-mixed")` (float32
# master weights) if small updates appear to be lost.
model = CombinedSarcasmClassifier().to(torch.bfloat16)

callbacks = [
    L_callbacks.EarlyStopping(monitor="val_loss", mode="min", min_delta=0.001),
    # Without a checkpoint that monitors val_loss, `ckpt_path="best"` in
    # `trainer.predict` resolves to the last saved epoch, not the best one.
    L_callbacks.ModelCheckpoint(monitor="val_loss", mode="min"),
]
trainer = L.Trainer(max_epochs=-1, callbacks=callbacks)
trainer.fit(model, train_dataloader, val_dataloader)

predict_ds = DscPredictDataset()
predict_dataloader = DataLoader(
    predict_ds, collate_fn=collate_fn, batch_size=128, num_workers=3
)
predictions = trainer.predict(dataloaders=predict_dataloader, ckpt_path="best")

# Sample ids in the order DscPredictDataset enumerates them (the JSON key
# order). The loader is not shuffled, so the flattened predictions line up
# with them one-to-one. The previous `step * batch.shape[0] + index` scheme
# misnumbered every sample of a short final batch.
with open(
    predict_ds.data_dir.joinpath("vimmsd-public-test.json"), "r", encoding="utf-8"
) as f:
    sample_ids = [str(key) for key in json.load(f)]

predicted_labels = [
    CLASS_NAMES[class_id] for class_id in torch.cat(predictions).tolist()
]
if len(predicted_labels) != len(sample_ids):
    raise RuntimeError(
        f"Got {len(predicted_labels)} predictions for {len(sample_ids)} samples"
    )

results = dict(zip(sample_ids, predicted_labels))
results = {"results": results, "phase": "dev"}

with open("/kaggle/working/results.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)

# :