In [None]:
import pytorch_lightning as L
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

import re
import os
import cv2
import copy
import math
import warnings
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F
import random
from torchmetrics.text import ROUGEScore
import spacy
spacy = spacy.load("en_core_web_sm")
import collections
from torch.nn.utils.rnn import pad_sequence

from torchvision.models import swin_t, Swin_T_Weights
try:
    from torchvision.transforms.v2 import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

# from swin_transformer import swin_t,swin_l,swin_b,swin_s
from torchvision.datasets.utils import download_and_extract_archive
from llama3_transformer_block import *
torch.set_float32_matmul_precision('high')
warnings.filterwarnings("ignore")

%matplotlib inline
plt.rcParams['axes.facecolor'] = 'lightgray'
plt.rcParams['mathtext.fontset'] = 'cm'
plt.rcParams['font.family'] = 'STIXGeneral'

In [31]:
x = swin_t()

In [32]:
x

SwinTransformer(
  (features): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): Permute()
      (2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    )
    (1): Sequential(
      (0): SwinTransformerBlock(
        (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=96, out_features=288, bias=True)
          (proj): Linear(in_features=96, out_features=96, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
        (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=96, out_features=384, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=384, out_features=96, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (1): SwinTransformerBlock(
       

In [33]:
os.makedirs("experiment", exist_ok=True)
os.makedirs("experiment/training", exist_ok=True)
os.makedirs("experiment/dataset", exist_ok=True)
os.makedirs("experiment/model", exist_ok=True)
EXPERIMENT_DIR = "experiment/"

In [34]:
ANNOTATION_PATH = "experiment/dataset/Flickr8k.token_copy.txt"
IMAGE_PATH = "experiment/dataset/Flicker8k_Dataset"

In [35]:
METRIC_TO_MONITOR = "val_loss"
METRIC_MODE       = "min"

In [36]:
SEED = int(np.random.randint(2147483647))
print(f"Random seed: {SEED}")

Random seed: 298238261


In [37]:
START_TOKEN = "<sos>"
END_TOKEN = "<eos>"
PAD_TOKEN = "<pad>"
OOV_TOKEN = "<unk>"

In [38]:
MAX_SEQUENCE = 30
IMAGE_SIZE = 224

In [39]:
TEMPERATURE = 0.1
TOP_P = 0.9

In [40]:
NUM_HEAD = 32
NUM_KV_HEAD = 8
NUM_LAYER = 1
EMBED_DIM = 640
HEAD_DIM = EMBED_DIM // NUM_HEAD
ROPE_BASE = 10000
MLP_SCALE = 3.5
DROPOUT = math.sin(math.sqrt(math.e * math.pi))
EPS_NORM = 1e-5

In [41]:
MAX_EPOCH = 42
BATCH_SIZE = 32
LEARNING_RATE = 3.1e-4
REDUCE_LR_FACTOR = 0.69
Vocab_size = 5000

In [42]:
MILESTONES = 1. / math.sqrt(MAX_EPOCH) * (
    np.array(
        [m for m in range(1, int(math.sqrt(MAX_EPOCH)))]
    )
)

## **Dataset**

In [43]:
# DATASET_URL = {
#     "image" : (
#         "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip",
#         "Flickr8k_Dataset.zip",
#     ),
#     "text"   : (
#         "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip",
#         "Flickr8k_text.zip",
#     ),
# }

In [44]:
# for dat in DATASET_URL.values():
#     url, filename = dat
#     download_and_extract_archive(
#         url,
#         "experiment/dataset",
#         filename=filename,
#     )
#     os.remove(os.path.join("experiment/dataset", filename))

#### **Image Transform**

In [45]:
class ToRGB(object):
    def __call__(self, image):
        assert 'PIL' in str(type(image)), "Expected PIL Image"
        return image.convert("RGB")

In [46]:
TRANSFORM = Compose(
    [
        Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=BICUBIC),
        ToRGB(),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

TRANSFORM_AUGMENTATION = Compose(
    [
        Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=BICUBIC),
        ToRGB(),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        RandomHorizontalFlip(),
        RandomAutocontrast(p=0.25),
        RandomAffine(
            degrees=22.5,
            scale=(math.sqrt(0.5 * math.pi), math.sqrt(math.pi)),
            shear=5.,
        ),
    ]
)

#### **Tokenizer**

In [47]:
class Tokenizer(object):
    def __init__(self, freq_threshold=1):
        self.encoder = collections.defaultdict(lambda : 3)
        self.encoder[END_TOKEN] = 2
        self.encoder[START_TOKEN] = 1
        self.encoder[PAD_TOKEN] = 0

        self.freq_threshold = freq_threshold

        frequencies = dict()
        idx = len(self.encoder)

        with open(ANNOTATION_PATH) as captions:
            sentence_list = [
                line.rstrip("\n").split("\t")[-1].strip().lower()
                for line in captions.readlines()
            ]

        for sentence in sentence_list:
            tokenized_sentence = [
                tok.text.lower() for tok in spacy.tokenizer(sentence.strip())
            ]
            for word in tokenized_sentence:
                if word not in frequencies:
                    frequencies[word] = 1
                else:
                    frequencies[word] += 1

                if frequencies[word] == self.freq_threshold:
                    idx += 1
                    self.encoder[word] = idx

        self.decoder = dict()
        self.decoder[0] = PAD_TOKEN
        self.decoder[1] = START_TOKEN
        self.decoder[2] = END_TOKEN
        self.decoder[3] = OOV_TOKEN
        for k, v in self.encoder.items():
            if v not in self.decoder:
                self.decoder[v] = k

    def __len__(self):
        assert len(self.encoder) == len(self.decoder)
        return len(self.encoder)

    def encode(self, text):
        numericalized_token = list()
        del_oov_token = False

        for token in spacy.tokenizer(text.strip()):
            if token.text.lower() not in self.encoder:
                del_oov_token = True

            numericalized_token.append(self.encoder[token.text.lower()])

            if del_oov_token:
                del self.encoder[token.text.lower()]
                del_oov_token = False

        return numericalized_token

    def decode(self, tokens):
        return " ".join([self.decoder[token] for token in tokens])

In [48]:
Tokenizer = Tokenizer()
print(f"Vocab size: {len(Tokenizer.decoder)}")

Vocab size: 1899


In [49]:
def tokenize(text):
    sos_token = Tokenizer.encoder[START_TOKEN]
    eos_token = Tokenizer.encoder[END_TOKEN]
    tokens = [sos_token] + Tokenizer.encode(text) + [eos_token]

    result = torch.full(
        (MAX_SEQUENCE,),
        Tokenizer.encoder[PAD_TOKEN],
        dtype=torch.long,
    )

    if len(tokens) > MAX_SEQUENCE:
        tokens = tokens[:MAX_SEQUENCE]
        tokens[-1] = eos_token

    result[:len(tokens)] = torch.tensor(tokens)

    return result.long()

#### **Image Captioning Dataset**

In [50]:
class Flickr8KDataset(data.Dataset):
    def __init__(self, split):
        assert split in ["train", "val", "test", "inference"]

        self.images = list()
        self.captions = list()

        with open(ANNOTATION_PATH) as caption_file:
            for line in caption_file.readlines():
                # Image name and captions are separated using a tab
                img_name, caption = line.rstrip("\n").split("\t")

                # Each image is repeated five times for the five different
                # captions. Each image name has a suffix `#(caption_number)`
                img_name = img_name.split("#")[0]
                img_name = os.path.join(IMAGE_PATH, img_name.strip())

                caption = caption.strip()

                # We will remove caption that are either too short to too long
                tokens = [tok.text.lower() for tok in spacy.tokenizer(caption)]

                if img_name.endswith("jpg"):
                    self.images.append(img_name)
                    self.captions.append(caption)

        random.seed(SEED)
        random.shuffle(self.images)

        random.seed(SEED)
        random.shuffle(self.captions)

        n_data = len(self.images)

        if split == "train":
            self.transform = TRANSFORM_AUGMENTATION
            start = 0
            end = int(0.8 * n_data)
        else:
            self.transform = TRANSFORM
            if split == "val":
                start = int(0.8 * n_data)
                end = int(0.9 * n_data)
            else:
                start = int(0.9 * n_data)
                end = n_data

        self.images = self.images[start:end]
        self.captions = self.captions[start:end]

        self.split = split

    def raw_image(self, index):
        assert self.split == "inference"
        return Image.open(self.images[index]).convert("RGB")

    def inference_data(self, index):
        assert self.split == "inference"
        image = self.transform(Image.open(self.images[index]))
        return image

    def __len__(self):
        assert len(self.images) == len(self.captions)
        return len(self.images)

    def __getitem__(self, index):
        image = self.transform(Image.open(self.images[index]))
        caption = tokenize(self.captions[index])
        return image, caption

### **Load**

In [51]:
class CollateFunction(object):
    def __init__(self, pad_idx=Tokenizer.encoder[PAD_TOKEN]):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        images = [item[0].unsqueeze(0) for item in batch]
        images = torch.cat(images, dim=0)

        captions = [item[1] for item in batch] + [torch.randn(MAX_SEQUENCE + 1)]
        captions = pad_sequence(
            captions,
            batch_first=True,
            padding_value=self.pad_idx,
        )[:-1, ...]
        return images, captions.long()

In [52]:
TrainDataset = Flickr8KDataset('train')
ValDataset = Flickr8KDataset('val')
TestDataset = Flickr8KDataset('test')

## **Model**

In [53]:
class AvgMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.scores = []

    def update(self, val):
        self.scores.append(val)

    def show(self):
        if len(self.scores) == 0:
            return 0  
        return torch.tensor(self.scores, dtype=torch.float32).mean()


In [54]:
MODEL_NAME = "SwinLlama3"

In [55]:
class ImageCaptioning(L.LightningModule):
    def __init__(self):
        super().__init__()

        self.batch_size = BATCH_SIZE
        self.max_epoch = MAX_EPOCH
        self.lr = LEARNING_RATE
        self.lr_now = self.lr * 1e3

        MLP = FeedForward(
            gate_proj=nn.Linear(EMBED_DIM, int(EMBED_DIM * MLP_SCALE), bias=False),
            down_proj=nn.Linear(int(EMBED_DIM * MLP_SCALE), EMBED_DIM, bias=False),
            up_proj=nn.Linear(EMBED_DIM, int(EMBED_DIM * MLP_SCALE), bias=False),
        )

        FEATURE_EXTRACTOR = swin_t(weights=Swin_T_Weights.IMAGENET1K_V1)
        # FEATURE_EXTRACTOR = mobilenet_v3_large(
        #     weights=MobileNet_V3_Large_Weights.IMAGENET1K_V2
        # )
        # FEATURE_EXTRACTOR.classifier[2] = nn.Dropout(p=DROPOUT, inplace=True)
        FEATURE_EXTRACTOR.head = nn.Linear(
            in_features=768,
            out_features=EMBED_DIM,
            bias=False,
        )

        SELF_ATTENTION = CausalSelfAttention(
            embed_dim=EMBED_DIM,
            num_heads=NUM_HEAD,
            num_kv_heads=NUM_KV_HEAD,
            head_dim=HEAD_DIM,
            q_proj=nn.Linear(EMBED_DIM, EMBED_DIM, bias=False),
            k_proj=nn.Linear(EMBED_DIM, NUM_KV_HEAD * HEAD_DIM, bias=False),
            v_proj=nn.Linear(EMBED_DIM, NUM_KV_HEAD * HEAD_DIM, bias=False),
            output_proj=nn.Linear(EMBED_DIM, EMBED_DIM, bias=False),
            pos_embeddings=RotaryPositionalEmbedding(
                dim=HEAD_DIM,
                max_seq_len=1,
                base=ROPE_BASE,
            ),
            max_seq_len=1,
            attn_dropout=DROPOUT,
        )
        ENCODER_LAYER = TransformerEncoderLayer(
            attn=SELF_ATTENTION,
            mlp=copy.deepcopy(MLP),
            sa_norm=RMSNorm(dim=EMBED_DIM, eps=EPS_NORM),
            mlp_norm=RMSNorm(dim=EMBED_DIM, eps=EPS_NORM),
        )
        self.encoder = TransformerEncoder(
            feature_extractor=FEATURE_EXTRACTOR,
            layer=ENCODER_LAYER,
            num_layers=NUM_LAYER,
            max_seq_len=MAX_SEQUENCE,
            num_heads=NUM_HEAD,
            head_dim=HEAD_DIM,
            norm=RMSNorm(EMBED_DIM, eps=EPS_NORM),
        )

        # LLaMA 3
        TOKEN_EMBEDDING = nn.Embedding(Vocab_size, EMBED_DIM)
        ROPE = RotaryPositionalEmbedding(
            dim=HEAD_DIM,
            max_seq_len=MAX_SEQUENCE,
            base=ROPE_BASE,
        )
        SELF_ATTENTION_1 = CausalSelfAttention(
            embed_dim=EMBED_DIM,
            num_heads=NUM_HEAD,
            num_kv_heads=NUM_KV_HEAD,
            head_dim=HEAD_DIM,
            q_proj=nn.Linear(EMBED_DIM, EMBED_DIM, bias=False),
            k_proj=nn.Linear(EMBED_DIM, NUM_KV_HEAD * HEAD_DIM, bias=False),
            v_proj=nn.Linear(EMBED_DIM, NUM_KV_HEAD * HEAD_DIM, bias=False),
            output_proj=nn.Linear(EMBED_DIM, EMBED_DIM, bias=False),
            pos_embeddings=ROPE,
            max_seq_len=MAX_SEQUENCE,
            attn_dropout=DROPOUT,
        )
        SELF_ATTENTION_2 = CausalSelfAttention(
            embed_dim=EMBED_DIM,
            num_heads=NUM_HEAD,
            num_kv_heads=NUM_KV_HEAD,
            head_dim=HEAD_DIM,
            q_proj=nn.Linear(EMBED_DIM, EMBED_DIM, bias=False),
            k_proj=nn.Linear(EMBED_DIM, NUM_KV_HEAD * HEAD_DIM, bias=False),
            v_proj=nn.Linear(EMBED_DIM, NUM_KV_HEAD * HEAD_DIM, bias=False),
            output_proj=nn.Linear(EMBED_DIM, EMBED_DIM, bias=False),
            pos_embeddings=ROPE,
            max_seq_len=MAX_SEQUENCE,
            attn_dropout=DROPOUT,
        )
        DECODER_LAYER = TransformerDecoderLayer(
            attn1=SELF_ATTENTION_1,
            attn2=SELF_ATTENTION_2,
            mlp=copy.deepcopy(MLP),
            sa_norm_x1=RMSNorm(dim=EMBED_DIM, eps=EPS_NORM),
            sa_norm_x2=RMSNorm(dim=EMBED_DIM, eps=EPS_NORM),
            mlp_norm=RMSNorm(dim=EMBED_DIM, eps=EPS_NORM),
        )
        OUT_PROJECTION = nn.Linear(EMBED_DIM, Vocab_size, bias=False)
        self.decoder = TransformerDecoder(
            tok_embedding=TOKEN_EMBEDDING,
            layer=DECODER_LAYER,
            num_layers=NUM_LAYER,
            max_seq_len=MAX_SEQUENCE,
            num_heads=NUM_HEAD,
            head_dim=HEAD_DIM,
            norm=RMSNorm(EMBED_DIM, eps=EPS_NORM),
            output=OUT_PROJECTION,
        )

        self.automatic_optimization = False

        self.train_loss = list()
        self.val_loss = list()

        self.train_loss_recorder = AvgMeter()
        self.val_loss_recorder = AvgMeter()

        self.test_rogue = ROUGEScore()

        self.sanity_check_counter = 1

    def forward(self, image, caption):
        image_feature = self.encoder(image)
        return self.decoder(caption, image_feature)

    def captionize(self, image, temperature=TEMPERATURE, top_p=TOP_P):
        assert image.shape[0] == 1

        self.encoder.setup_caches(max_batch_size=1)
        encoder_feat = self.encoder(
            image,
            input_pos=torch.tensor([0], device=self.device),
        )
        self.encoder.clear_caches()

        self.decoder.setup_caches(max_batch_size=1)

        pred_token = Tokenizer.encoder[START_TOKEN]
        token = [pred_token] + [Tokenizer.encoder[PAD_TOKEN]] * (MAX_SEQUENCE)
        for index in range(MAX_SEQUENCE):
            caption = torch.LongTensor([pred_token]).unsqueeze(0).to(self.device)

            pred_token = self.decoder(
                caption,
                encoder_feat,
                input_pos=torch.tensor([index], device=self.device),
            )

            if temperature > 0:
                # Apply temperature -> make distribution softer (?)
                pred_token = (pred_token / temperature).softmax(-1)[0]
                # Sampling
                psort, pidx = torch.sort(pred_token, dim= -1, descending=True)
                psum = torch.cumsum(psort, dim=-1)
                psort[psum - psort > top_p] = 0.
                psort.div_(psort.sum(dim=-1, keepdim=True))
                pred_token = torch.multinomial(psort, num_samples=1)
                pred_token = torch.gather(pidx, -1, pred_token).transpose(0, 1)
            else:
                pred_token = pred_token.softmax(-1).argmax(2)

            pred_token = pred_token.item()
            token[index + 1] = pred_token

            if pred_token == Tokenizer.encoder[END_TOKEN]:
                break

        self.decoder.clear_caches()

        return self.postprocess_text(Tokenizer.decode(token))

    def postprocess_text(self, text):
        text = text.replace(START_TOKEN, "")
        text = text.replace(END_TOKEN, "")
        text = text.replace(PAD_TOKEN, "")
        text = re.sub(r'\s([,.!?])', r'\1', text)
        text = '. '.join(map(lambda s: s.strip().capitalize(), text.split('.')))
        return text

    def training_step(self, batch, batch_nb):
        image, caption = batch

        pred = self(image, caption[:, :-1])
        pred = pred.view(-1, pred.shape[-1])
        caption = caption[:, 1:].reshape(caption.shape[0] * (caption.shape[-1] - 1))

        loss = F.cross_entropy(pred, caption, ignore_index=Tokenizer.encoder[PAD_TOKEN])

        opt = self.optimizers()
        opt.zero_grad()
        self.manual_backward(loss)
        torch.nn.utils.clip_grad_norm_(
            self.parameters(),
            math.log2(math.sqrt(math.e * math.tau) * math.pi),
        )
        opt.step()

        self.log("train_loss", loss, prog_bar=True)
        self.train_loss_recorder.update(loss.data)

    def on_train_epoch_end(self):
        sch = self.lr_schedulers()
        sch.step()
        self.train_loss.append(self.train_loss_recorder.show().data.cpu().numpy())
        self.train_loss_recorder = AvgMeter()

    def validation_step(self, batch, batch_nb):
        image, caption = batch

        pred = self(image, caption[:, :-1])
        pred = pred.view(-1, pred.shape[-1])
        caption = caption[:, 1:].reshape(caption.shape[0] * (caption.shape[-1] - 1))

        loss = F.cross_entropy(pred, caption, ignore_index=Tokenizer.encoder[PAD_TOKEN])

        if self.sanity_check_counter == 0:
            self.log("val_loss", loss, prog_bar=True)
            self.val_loss_recorder.update(loss.data)

    def on_validation_epoch_end(self):
        if self.sanity_check_counter == 0:
            loss = self.val_loss_recorder.show().data.cpu().numpy()
            lr_now_ = self.optimizers().param_groups[0]["lr"]
            if self.lr_now != lr_now_:
                self.lr_now = lr_now_
                str_report = f"[{MODEL_NAME}] Learning Rate Changed: {lr_now_}"
                str_report += f"- Epoch: {self.current_epoch}"
                print(str_report)
            self.val_loss.append(loss)
            self.val_loss_recorder = AvgMeter()
        else:
            self.sanity_check_counter -= 1

    def test_step(self, batch, batch_nb):
        image, caption = batch

        N_BATCH = image.shape[0]

        rogue1_fmeasure = list()

        for id in range(N_BATCH):
            pred = self.captionize(image[id].unsqueeze(0))
            target = self.postprocess_text(
                Tokenizer.decode(
                    caption[id].cpu()
                    .detach()
                    .numpy()
                    .tolist()
                )
            )
            rogue1_fmeasure.append(
                self.test_rogue(pred, target)['rouge1_fmeasure']
                .cpu()
                .detach()
                .numpy()
                .tolist()
            )

        rogue1_fmeasure = np.array(rogue1_fmeasure).mean()
        self.log("ROGUE-1 F-measure", rogue1_fmeasure, prog_bar=True, logger=True)

    def on_train_end(self):
        # Loss
        img_file = f"experiment/training/{MODEL_NAME}_loss_plot.png"
        plt.plot(self.train_loss, color="r", label="train")
        plt.plot(self.val_loss, color="b", label="validation")
        plt.title("Loss Curves")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend()
        plt.grid()
        plt.savefig(img_file)
        plt.clf()
        img = cv2.imread(img_file)
        # cv2_imshow(img)

    def train_dataloader(self):
        return data.DataLoader(
            TrainDataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=CollateFunction(),
            num_workers=0
        )

    def val_dataloader(self):
        return data.DataLoader(
            ValDataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=CollateFunction(),
            num_workers=0
        )

    def test_dataloader(self):
        return data.DataLoader(
            TestDataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=CollateFunction(),
            num_workers=0
        )

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), self.lr)

        lr_scheduler = {
            "scheduler": optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[int(self.max_epoch * ms) for ms in MILESTONES],
                gamma=REDUCE_LR_FACTOR,
            ),
            "name": "lr_scheduler",
        }

        return [optimizer], [lr_scheduler]

In [56]:
MODEL_NAME = "SwinLlama3"
BEST_MODEL_PATH = os.path.join(
    EXPERIMENT_DIR,
    f"model/{MODEL_NAME}_best.ckpt",
)
LAST_MODEL_PATH = os.path.join(
    EXPERIMENT_DIR,
    "model/last.ckpt",
)

## **Training**

In [57]:
def _train_loop():
    seed_everything(SEED, workers=True)

    print(MODEL_NAME)
    model = ImageCaptioning()

    callbacks = list()

    checkpoint = ModelCheckpoint(
        monitor=METRIC_TO_MONITOR,
        dirpath=f"{EXPERIMENT_DIR}/model",
        mode=METRIC_MODE,
        filename=f"{MODEL_NAME}_best",
        save_last=True,
    )
    callbacks.append(checkpoint)

    # if os.path.exists(BEST_MODEL_PATH):
    #     ckpt_path = BEST_MODEL_PATH
    # else:
    ckpt_path = None

    trainer = Trainer(
        accelerator="auto",
        devices=1,
        max_epochs=MAX_EPOCH,
        logger=False,
        callbacks=callbacks,
        log_every_n_steps=1,
    )
    trainer.fit(model, ckpt_path=ckpt_path)

if __name__ == '__main__':
    _train_loop()

Seed set to 298238261


SwinLlama3


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type               | Params | Mode 
----------------------------------------------------------
0 | encoder    | TransformerEncoder | 33.3 M | train
1 | decoder    | TransformerDecoder | 12.8 M | train
2 | test_rogue | ROUGEScore         | 0      | train
----------------------------------------------------------
46.1 M    Trainable params
0         Non-trainable params
46.1 M    Total params
184.356   Total estimated model params size (MB)
223       Modules in train mode
0         Modules in eval mode


Epoch 0: 100%|██████████| 51/51 [00:40<00:00,  1.27it/s, train_loss=3.690] [SwinLlama3] Learning Rate Changed: 0.00031- Epoch: 0
Epoch 6: 100%|██████████| 51/51 [00:42<00:00,  1.20it/s, train_loss=2.410, val_loss=3.510][SwinLlama3] Learning Rate Changed: 0.0002139- Epoch: 6
Epoch 12: 100%|██████████| 51/51 [00:43<00:00,  1.17it/s, train_loss=1.290, val_loss=3.730][SwinLlama3] Learning Rate Changed: 0.00014759099999999998- Epoch: 12
Epoch 19: 100%|██████████| 51/51 [00:41<00:00,  1.22it/s, train_loss=1.050, val_loss=4.040][SwinLlama3] Learning Rate Changed: 0.00010183778999999998- Epoch: 19
Epoch 25: 100%|██████████| 51/51 [00:32<00:00,  1.55it/s, train_loss=0.644, val_loss=4.200][SwinLlama3] Learning Rate Changed: 7.026807509999998e-05- Epoch: 25
Epoch 32: 100%|██████████| 51/51 [00:27<00:00,  1.88it/s, train_loss=0.717, val_loss=4.320][SwinLlama3] Learning Rate Changed: 4.848497181899998e-05- Epoch: 32
Epoch 41: 100%|██████████| 51/51 [00:57<00:00,  0.88it/s, train_loss=0.753, val_los

`Trainer.fit` stopped: `max_epochs=42` reached.


Epoch 41: 100%|██████████| 51/51 [00:59<00:00,  0.85it/s, train_loss=0.753, val_loss=4.420]


<Figure size 640x480 with 0 Axes>

## **Testing**

In [58]:
# def _test_loop():
#     trainer = Trainer(accelerator='auto', logger=False)
#     model = ImageCaptioning()
#     trainer.test(
#         model=model,
#         ckpt_path=LAST_MODEL_PATH if os.path.exists(LAST_MODEL_PATH) else None,
#     )

# _test_loop()

## **Inference**

### **Utils**

In [64]:
INFERENCE_SAMPLE = 9
MAX_CHAR = 50

### **Initialize**

In [69]:
model = ImageCaptioning.load_from_checkpoint(LAST_MODEL_PATH)
model.eval()

InferenceDataset = Flickr8KDataset('inference')

#### **From Dataset**

In [None]:
plt.clf()
fig = plt.figure()
plt.subplots_adjust(
    left=0.1,
    bottom=0.1,
    right=math.sqrt(2),
    top=math.sqrt(3),
    wspace=0.4,
    hspace=0.4,
)

N_SAMPLE = len(InferenceDataset)
SELECTED_SAMPLE = [
    random.randint(0, N_SAMPLE - 1) for _ in range(INFERENCE_SAMPLE)
]

for index, sample_idx in enumerate(SELECTED_SAMPLE):
    image = InferenceDataset.inference_data(sample_idx)
    image = image.to(
        "cuda" if torch.cuda.is_available() else "cpu"
    ).unsqueeze(0)

    caption = model.captionize(image)
    image = np.array(InferenceDataset.raw_image(sample_idx))

    title = [
        f"{caption[(cline * MAX_CHAR):((cline + 1) * MAX_CHAR)]}\n"
        for cline in range(math.ceil(len(caption) / MAX_CHAR))
    ]
    title = "".join(title)

    ax = fig.add_subplot(
        int(math.sqrt(INFERENCE_SAMPLE)),
        int(math.sqrt(INFERENCE_SAMPLE)),
        index + 1,
    )
    ax.imshow(image.copy().astype(np.uint8))
    ax.set_title(title, fontsize=8)
    ax.set_axis_off()

plt.show()