In [None]:
import torch
import torch.nn as nn
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torchvision import models, transforms
import json
from PIL import Image
import torch.nn.functional as F
import math
import os
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import wandb

In [None]:
# Directory prefix that is joined with each image file name when the
# (image path, caption) pairs are built for the data loaders.
IMAGES_FOLDER = '/kaggle/input/flickr8k/Images'

In [None]:
# Pre-computed artefacts, read in this order: word->index map, validation
# captions, training captions, index->word map, embedding matrix.
# The `i2w` loaded here is rebuilt from `w2i` a few cells further down.
w2i, val_caption, train_caption, i2w, embedding_matrix = (
    torch.load(artifact_path)
    for artifact_path in (
        "/kaggle/input/w2i/pytorch/default/1/w2i.pt",
        "/kaggle/input/val_captions/pytorch/default/1/val_captions.pt",
        "/kaggle/input/train_captions/pytorch/default/1/train_captions.pt",
        "/kaggle/input/i2w/pytorch/default/1/i2w.pt",
        "/kaggle/input/embedding_matrix/pytorch/default/1/embedding_matrix.pt",
    )
)

==================================================================================================
CREATE LOADER
===================================================================================================


In [None]:
# Re-index the vocabulary so that the smallest id becomes 0 (the ids are used
# directly as class labels and the model is built with vocab_size=len(w2i)).
# The guard makes the cell idempotent: running it a second time used to shift
# every id again and produce an invalid index of -1.
# NOTE(review): `embedding_matrix` rows are looked up with these shifted ids and
# the datasets pad with 0, which is now also a real word id -- confirm that the
# embedding matrix was built for the same 0-based numbering.
if w2i and min(w2i.values()) >= 1:
    w2i = {word: idx - 1 for word, idx in w2i.items()}

In [None]:
# Inverse vocabulary rebuilt from the re-indexed `w2i`; this replaces the
# `i2w` object that was loaded from disk earlier in the notebook.
i2w = dict(zip(w2i.values(), w2i.keys()))

In [None]:
def convert_sentences_to_index(caption):
    """Encode the LAST caption of `caption` as a list of vocabulary ids.

    Args:
        caption: iterable of caption strings belonging to one image.

    Returns:
        list[int]: ids (looked up in the module-level `w2i`) of the words of
        the last caption, split on single spaces; words missing from `w2i`
        are skipped. An empty list is returned when no caption is given
        (previously this raised ``UnboundLocalError``).

    NOTE(review): only the final caption is used, every earlier one is
    ignored. That mirrors the historical behaviour that the dataset class
    relies on (it expects one flat id list per image) -- confirm that dropping
    the other captions is intended.
    """
    captions = list(caption)
    if not captions:
        return []
    # Earlier captions were encoded and then thrown away; encode only the one
    # that is actually returned.
    last_caption = captions[-1]
    return [w2i[word] for word in last_caption.split(' ') if word in w2i]
        

In [None]:
def create_image_caption_pair(dataset, image_path):
    """Turn an ``{image file name: captions}`` mapping into a list of pairs.

    Each entry of the result is ``[full_image_path, encoded_caption]`` where the
    path is ``image_path + '/' + file name`` and the encoding comes from
    `convert_sentences_to_index`. Entries keep the iteration order of `dataset`.
    """
    return [
        [image_path + '/' + file_name, convert_sentences_to_index(raw_captions)]
        for file_name, raw_captions in dataset.items()
    ]

In [None]:
def get_max_len(data_set):
    """Return the length of the longest encoded caption in `data_set`.

    Every item is expected to carry its caption at index 1. The result is 0
    when `data_set` is empty.
    """
    return max((len(entry[1]) for entry in data_set), default=0)

In [None]:
class FlickrDataset(Dataset):
    """Image/caption dataset built from ``[image_path, encoded_caption]`` pairs.

    Training mode (``training=True``): every caption of length L is expanded
    into L-1 samples ``(image, prefix, next_token)``; `__getitem__` returns the
    transformed image, the prefix left-padded with zeros to ``max_len`` and the
    next token id as a scalar tensor.

    Evaluation mode (``training=False``): one sample per pair; `__getitem__`
    returns the transformed image and the caption as the original Python list.

    Args:
        data: sequence of items whose index 0 is an image path and index 1 is
            a list of token ids.
        training: selects between the two modes described above.
    """

    def __init__(self, data, training=True):
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # The images feed an ImageNet-pretrained VGG16, whose torchvision
            # weights expect inputs normalised with these channel statistics.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        self.training = training
        self.max_len = get_max_len(data)
        self.samples = []
        for sample in data:
            image = sample[0]
            caption = sample[1]
            if training:
                # One sample per (prefix -> next token) step of the caption.
                for i in range(1, len(caption)):
                    in_seq, out_seq = caption[:i], caption[i]
                    self.samples.append((image, in_seq, out_seq))
            else:
                self.samples.append((image, caption))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if self.training:
            img_path, in_seq, out_token = self.samples[idx]
        else:
            img_path, caption = self.samples[idx]

        # Image.open is lazy; the context manager closes the underlying file
        # as soon as the pixel data has been converted.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        if not self.training:
            return image, caption

        in_seq = torch.tensor(in_seq, dtype=torch.long)
        padded_seq = torch.zeros(self.max_len, dtype=torch.long)
        length = min(len(in_seq), self.max_len)
        if length > 0:
            # Guarded because `[-0:]` would address the whole tensor.
            # Left-pad: the prefix occupies the last `length` positions.
            padded_seq[-length:] = in_seq[-length:]

        out_token = torch.tensor(out_token, dtype=torch.long)
        return image, padded_seq, out_token

In [None]:
def get_dataset(training_set, validation_set, image_path, batch_size=16, shuffle_train=True):
    """Build the training and validation data loaders.

    Args:
        training_set: mapping ``{image file name: captions}`` used for training.
        validation_set: same kind of mapping, used for validation.
        image_path: directory prefix prepended to every image file name.
        batch_size: batch size of the training loader (default 16, the value
            that used to be hard-coded).
        shuffle_train: whether the training loader reshuffles every epoch.
            Defaults to True: the training samples are generated as consecutive
            prefixes of the same caption, so an unshuffled loader would fill
            each batch with near-duplicates of a single image.

    Returns:
        tuple: ``(train_loader, val_loader)``. The validation loader always
        yields one unshuffled sample per batch.
    """
    train_set = create_image_caption_pair(training_set, image_path)
    val_set = create_image_caption_pair(validation_set, image_path)

    train_data = FlickrDataset(train_set)
    val_data = FlickrDataset(val_set, training=False)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=shuffle_train)
    val_loader = DataLoader(val_data, batch_size=1, shuffle=False)
    return train_loader, val_loader



In [None]:
# Build both loaders from the caption mappings loaded at the top of the notebook.
train_loader, val_loader = get_dataset(
    training_set=train_caption,
    validation_set=val_caption,
    image_path=IMAGES_FOLDER,
)

# Number of batches per epoch: training first, validation second.
print(len(train_loader), len(val_loader))

==================================================================================================
CREATE MODEL
===================================================================================================


In [None]:
class VGG(nn.Module):
    """Frozen VGG16 image encoder.

    Re-uses the convolutional stack and the average-pooling stage of
    torchvision's pretrained VGG16 and flattens the pooled map, so `forward`
    returns one feature vector per image. The convolutional weights are
    excluded from gradient updates.
    """

    def __init__(self):
        super().__init__()
        backbone = models.vgg16(pretrained=True)
        self.features = backbone.features
        self.avgpooling = backbone.avgpool
        self.flatten = nn.Flatten()

        # Freeze every convolutional parameter in one call.
        self.features.requires_grad_(False)

    def forward(self, x):
        """Encode a batch of images into flattened feature vectors."""
        pooled = self.avgpooling(self.features(x))
        return self.flatten(pooled)
    


In [None]:
class KANLinear(nn.Module):
    """Linear-like layer whose output is the sum of two branches.

    * base branch: ``base_activation`` applied to the input, followed by a
      linear map with ``base_weight``;
    * spline branch: B-spline bases of the input (see `b_splines`) followed by
      a linear map with ``spline_weight`` (optionally multiplied element-wise
      by ``spline_scaler``).

    Args:
        in_features: size of the last input dimension.
        out_features: size of the last output dimension.
        grid_size: number of grid intervals inside ``grid_range``.
        spline_order: order of the B-spline bases; the grid is extended by
            this many knots on each side.
        scale_noise: magnitude of the random noise used to initialise the
            spline coefficients.
        scale_base: factor folded into the ``a`` argument of the Kaiming
            initialisation of ``base_weight``.
        scale_spline: spline scale; used as a plain multiplier when the
            standalone scaler is disabled, otherwise folded into the Kaiming
            initialisation of ``spline_scaler``.
        enable_standalone_scale_spline: whether a learnable per-connection
            ``spline_scaler`` is created.
        base_activation: activation class (instantiated without arguments).
        grid_eps: blend factor between the uniform and the data-adaptive grid
            in `update_grid`.
        grid_range: ``[low, high]`` range covered by the initial uniform grid.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot vector with step h, extended by `spline_order` knots on
        # both sides and replicated for every input feature:
        # shape (in_features, grid_size + 2 * spline_order + 1).
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        # Registered as a buffer: part of the module state, not a parameter.
        self.register_buffer("grid", grid)

        # Parameters are allocated uninitialised here and filled in by
        # reset_parameters() at the end of the constructor.
        self.base_weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise ``base_weight``, ``spline_weight`` and ``spline_scaler``.

        The spline coefficients are obtained by fitting (via `curve2coeff`)
        small uniform noise sampled at the ``grid_size + 1`` knots that lie
        inside the original grid range.
        """
        nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Uniform noise in [-0.5, 0.5) scaled by scale_noise / grid_size.
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    # Drop the `spline_order` extension knots on both sides.
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        The bases start as order-0 indicator functions of the knot intervals
        and are raised one order per loop iteration up to ``spline_order``.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        # Order 0: 1 where x falls in the half-open interval [knot_i, knot_i+1).
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            # Each pass blends neighbouring lower-order bases and shortens the
            # last dimension by one.
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        The coefficients are the least-squares solution (``torch.linalg.lstsq``)
        of ``bases(x) @ coeff = y``, solved independently per input feature.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """``spline_weight`` multiplied by ``spline_scaler`` (broadcast over the
        coefficient axis) when the standalone scaler is enabled, else unchanged."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """Apply the layer to `x`, whose last dimension must be ``in_features``.

        Leading dimensions are flattened for the computation and restored in
        the result, which has ``out_features`` as its last dimension.
        """
        assert x.size(-1) == self.in_features
        original_shape = x.shape
        x = x.reshape(-1, self.in_features)

        base_output = F.linear(self.base_activation(x), self.base_weight)
        # Bases and coefficients are flattened over (in_features, coeff) so the
        # spline branch is a single matrix product.
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        output = base_output + spline_output
        
        output = output.reshape(*original_shape[:-1], self.out_features)
        return output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """Re-fit the knot grid to the distribution of `x` (no gradients).

        The current per-input spline outputs are evaluated first, the grid is
        replaced by a blend of a uniform grid and a grid taken from evenly
        spaced order statistics of `x`, and the spline coefficients are then
        re-fitted so that the layer reproduces the previously evaluated outputs
        on the new grid.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            margin: padding added on both ends of the observed value range when
                building the uniform grid.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        # grid_size + 1 evenly spaced order statistics per input feature.
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        # Uniform grid spanning [min - margin, max + margin] of each feature.
        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Extend the blended grid by `spline_order` uniformly spaced knots on
        # each side, matching the layout built in __init__.
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        # NOTE(review): the targets were computed with the *scaled* weights but
        # the fit is written into the unscaled `spline_weight` -- confirm this
        # is intended when the standalone scaler is enabled.
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a simplified stand-in for the original L1 regularization as stated
        in the paper, since the original one requires computing absolutes and
        entropy from the expanded (batch, in_features, out_features) intermediate
        tensor, which is hidden behind the F.linear function if we want a
        memory-efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors' implementation also includes this term in addition to
        the sample-based regularization.

        Returns:
            ``regularize_activation * sum(l1) + regularize_entropy * entropy``
            where ``l1`` is the mean absolute spline coefficient per
            (out, in) connection and the entropy is taken over ``l1`` normalised
            by its sum.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(nn.Module):
    """Stack of `KANLinear` layers applied one after another.

    ``layers_hidden`` lists the feature sizes: consecutive entries give the
    input and output width of each layer. All remaining constructor arguments
    are forwarded unchanged to every `KANLinear`.
    """

    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super().__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Settings shared by every layer of the stack.
        shared_kwargs = dict(
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )
        self.layers = nn.ModuleList(
            [
                KANLinear(width_in, width_out, **shared_kwargs)
                for width_in, width_out in zip(layers_hidden[:-1], layers_hidden[1:])
            ]
        )

    def forward(self, x: torch.Tensor, update_grid=False):
        """Run `x` through every layer; optionally refresh each layer's grid
        with the activations it is about to receive."""
        hidden = x
        for kan_layer in self.layers:
            if update_grid:
                kan_layer.update_grid(hidden)
            hidden = kan_layer(hidden)
        return hidden

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Sum of the per-layer regularization losses."""
        penalties = [
            kan_layer.regularization_loss(regularize_activation, regularize_entropy)
            for kan_layer in self.layers
        ]
        return sum(penalties)

In [None]:
# Ensure the pretrained embeddings are a float32 tensor. `torch.as_tensor`
# accepts both arrays and tensors without the copy-construct UserWarning that
# `torch.tensor(existing_tensor)` emits, which also keeps the cell quiet when
# it is re-run. `.detach()` preserves the old guarantee that the result
# carries no autograd history.
embedding_matrix = torch.as_tensor(embedding_matrix, dtype=torch.float).detach()

In [None]:
class VGGLSTM(nn.Module):
    """Next-word predictor: VGG image features followed by an LSTM decoder.

    The flattened image features are projected to ``embed_size`` and placed in
    front of the embedded caption prefix; the LSTM state at the last time step
    is mapped to vocabulary logits. Word embeddings come from the module-level
    `embedding_matrix` and are kept frozen.

    Args:
        input_feature_size: length of the feature vector produced by `VGG`.
        vocab_size: number of output classes.
        embed_size: LSTM input width (the image projection is sized to it).
        hidden_size: LSTM hidden width.
    """

    def __init__(self, input_feature_size, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.vgg = VGG()
        self.feature_projector = nn.Linear(
            in_features=input_feature_size, out_features=embed_size
        )
        self.embed = nn.Embedding.from_pretrained(embedding_matrix, freeze=True)
        self.lstm = nn.LSTM(
            input_size=embed_size, hidden_size=hidden_size, batch_first=True
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, image, captions):
        """Return logits of shape [batch, vocab_size] for the next word."""
        # [batch, 1, embed_size]: the image acts as the first sequence element.
        image_step = self.feature_projector(self.vgg(image)).unsqueeze(1)
        # [batch, seq_len, embed_size]
        word_steps = self.embed(captions)

        sequence = torch.cat([image_step, word_steps], dim=1)
        hidden_states, _ = self.lstm(sequence)
        # Only the final time step feeds the classifier.
        return self.fc(hidden_states[:, -1, :])
# Instantiate the captioning model.
# * input_feature_size: length of the flattened VGG16 feature vector
#   (512 channels x 7 x 7 after the average-pooling stage).
# * embed_size: taken from the embedding matrix instead of a hard-coded 200,
#   because the LSTM input width must equal the width of the pretrained
#   embeddings that are concatenated with the projected image feature.
model = VGGLSTM(
    input_feature_size=512 * 7 * 7,
    vocab_size=len(w2i),
    embed_size=embedding_matrix.shape[1],
    hidden_size=512
)

# class VGGKANLSTM(nn.Module):
#     def __init__(self, layers_hidden, vocab_size, embed_size, hidden_size, grid_size=5, spline_order=3):
#         super(VGGKANLSTM, self).__init__()
#         self.vgg = VGG()  # VGG encoder
#         self.kan = KAN(layers_hidden, grid_size=grid_size, spline_order=spline_order)  # KAN layer
#         self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)  # LSTM for sequence generation
#         self.fc = nn.Linear(hidden_size, vocab_size)  # Output layer for word prediction
#         self.feature_projector = nn.Linear(layers_hidden[-1], embed_size)
#         self.embed = nn.Embedding.from_pretrained(embedding_matrix, freeze=True)
        
#     def forward(self, image, captions):

#         features = self.vgg(image) 
#         # print(f"after vgg: {features.shape}")
#         refined_features = self.kan(features) 
#         # print(f"after kan: {refined_features.shape}")

#         refined_features = self.feature_projector(refined_features)  # [1, 256]
#         refined_features = refined_features.unsqueeze(1)  # [1, 1, 256]

#         captions_embed = self.embed(captions) 
#         # print(f"captions embed: {captions_embed.shape}")

#         lstm_input = torch.cat((refined_features, captions_embed), dim=1)
#         # print(f"lstm input: {lstm_input.shape}")

#         lstm_out, (hn, cn) = self.lstm(lstm_input)  

#         out = self.fc(lstm_out[:, -1, :])
        
#         return out


# Example usage
# model = VGGKANLSTM(layers_hidden=[512 * 7 * 7, 1024, 512], vocab_size=len(w2i), embed_size=200, hidden_size=512)
# Move the model to the GPU when one is available and fall back to the CPU
# otherwise; a hard-coded "cuda" raises on CPU-only hosts and disagrees with
# the device fallback used by the training cell.
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

In [None]:
# Total number of scalar weights in the model (frozen ones included).
print(
    'Number of network parameters:',
    sum(weights.numel() for weights in model.parameters()),
)

In [None]:
# SECURITY: the API key used to be hard-coded in this cell. Treat that key as
# leaked and revoke it in the Weights & Biases settings. The key is now read
# from the WANDB_API_KEY environment variable (on Kaggle, populate it from a
# notebook secret); when it is unset, wandb falls back to its own credential
# lookup.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

==================================================================================================
CREATE LOSS FUNCTION
===================================================================================================


In [None]:
# Optimisation setup: Adam over every parameter the model exposes, driven by a
# cross-entropy objective on the next-token logits.
learning_rate = 1e-3
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

==================================================================================================
CREATE TRAINER
===================================================================================================


In [None]:
import os
import torch
import wandb

def trainer(model, train_loader, val_loader, optimizer, loss_func,
            start_epoch, end_epoch, save_every=1, checkpoint_path=None,
            save_path="model.pt", wb_tracking=True, device=None):
    """Train `model` for epochs ``start_epoch .. end_epoch - 1``.

    Args:
        model: module called as ``model(image, padded_prefix)``.
        train_loader: iterable with ``len()`` yielding
            ``(image, padded_prefix, label)`` batches.
        val_loader: accepted for interface compatibility; not used here.
        optimizer: optimizer stepping the parameters of `model`.
        loss_func: callable ``loss_func(outputs, label)`` returning a scalar
            tensor.
        start_epoch: first epoch index (overridden when a checkpoint is
            resumed).
        end_epoch: exclusive upper bound of the epoch range.
        save_every: a checkpoint is written every `save_every` epochs and
            after the final epoch.
        checkpoint_path: optional checkpoint to resume from; ignored when
            empty or when the file does not exist.
        save_path: file the checkpoint dictionary is written to.
        wb_tracking: when True, losses are logged to Weights & Biases.
        device: device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-less model).

    Returns:
        None. Checkpoints are dictionaries with the keys ``epoch``,
        ``model_state_dict`` and ``optimizer_state_dict``.
    """
    if device is None:
        # Follow the model instead of depending on a notebook-level global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    if wb_tracking:
        wandb.init(
            entity="CS338",
            project="image_captioning",
            name="test",
            config={
                "epochs": end_epoch,
                "optimizer": optimizer.__class__.__name__,
                "learning_rate": optimizer.param_groups[0]["lr"],
                "loss_func": loss_func.__class__.__name__,
            },
            settings=wandb.Settings(init_timeout=30)
        )

    if checkpoint_path and os.path.exists(checkpoint_path):
        # map_location lets a checkpoint saved on another device be resumed.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"] + 1
        print(f"Resuming training from epoch {start_epoch}")

    # Guard against division by zero when the loader is empty.
    num_batches = max(len(train_loader), 1)

    for epoch in range(start_epoch, end_epoch):
        model.train()
        epoch_loss = 0.0
        print(f"Start training epoch {epoch}")
        for image, pads, label in train_loader:
            image = image.to(device)
            pads = pads.to(device)
            label = label.to(device)

            optimizer.zero_grad()
            outputs = model(image, pads)
            # Use the loss that was passed in, not a global `criterion`.
            loss = loss_func(outputs, label)

            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            epoch_loss += step_loss

            if wb_tracking:
                wandb.log({"train_step_loss": step_loss})

        avg_loss = epoch_loss / num_batches
        if wb_tracking:
            wandb.log({
                "train_epoch_loss": avg_loss,
                "epoch": epoch
            })

        if (epoch + 1) % save_every == 0 or epoch == end_epoch - 1:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, save_path)
            print(f"Checkpoint saved at epoch {epoch} to {save_path}")


In [None]:
# Prefer the GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Run configuration.
start = 0
end = 30
save_every = 1
# NOTE(review): `batch_size` is only assigned here; the training loader built
# earlier in the notebook uses its own batch size.
batch_size = 1
# An empty path means "do not resume from a checkpoint".
checkpoint_path = ""

# NOTE(review): this is a relative path (no leading slash), so the directory is
# created under the current working directory -- confirm that is intended.
save_dir = 'kaggle/working'
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "best_metric_model.pth")

print(len(train_loader))
trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    loss_func=criterion,
    start_epoch=start,
    end_epoch=end,
    save_every=save_every,
    checkpoint_path=checkpoint_path,
    save_path=save_path,
)

In [None]:
# def generate_caption_beam_search(model, image, i2w, w2i, transform=None, beam_width=3, max_len=20, device="cuda"):
#     model.to(device)
#     model.eval()

#     with torch.no_grad():
#         image_tensor = image.to(device)  # no transform is applied here
#         features = model.vgg(image_tensor)
#         refined_features = model.kan(features)
#         refined_features = model.feature_projector(refined_features).unsqueeze(1)

#         start_token = w2i["startseq"]
#         end_token = w2i["endseq"]
#         beam = [(0.0, [start_token], None)]

#         for _ in range(max_len):
#             new_beam = []
#             for log_prob, caption_idx, hidden in beam:
#                 if caption_idx[-1] == end_token:
#                     new_beam.append((log_prob, caption_idx, hidden))
#                     continue

#                 input_seq = torch.tensor([caption_idx[1:]], dtype=torch.long).to(device) if len(caption_idx) > 1 else torch.tensor([[start_token]], dtype=torch.long).to(device)
#                 caption_embed = model.embed(input_seq)
#                 lstm_input = torch.cat((refined_features, caption_embed), dim=1)

#                 lstm_out, hidden_out = model.lstm(lstm_input, hidden)
#                 output = model.fc(lstm_out[:, -1, :])
#                 log_probs = torch.nn.functional.log_softmax(output, dim=1).squeeze(0)

#                 top_log_probs, top_indices = torch.topk(log_probs, beam_width)

#                 for i in range(beam_width):
#                     next_idx = top_indices[i].item()
#                     total_log_prob = log_prob + top_log_probs[i].item()
#                     new_caption_idx = caption_idx + [next_idx]
#                     new_beam.append((total_log_prob, new_caption_idx, hidden_out))

#             beam = sorted(new_beam, key=lambda x: x[0], reverse=True)[:beam_width]

#         best_caption = beam[0][1]
#         caption_words = []
#         for idx in best_caption[1:]:
#             word = i2w[idx]
#             if word == "endseq":
#                 break
#             caption_words.append(word)

#         return " ".join(caption_words)


In [None]:
# from tqdm import tqdm
# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# def evaluate_bleu(model, val_data, i2w, w2i, transform, device="cuda"):
#     smoothie = SmoothingFunction().method4
#     model.eval()
#     scores = []

#     for image_tensor, caption in tqdm(val_data, desc="Evaluating BLEU"):
#         image = image_tensor.unsqueeze(0).to(device)  # add the batch dimension
#         reference = [[i2w[idx] for idx in caption[1:]]]  # drop the <start> token
#         prediction = generate_caption_beam_search(model, image, i2w, w2i, beam_width=3, device=device)
#         candidate = prediction.split()

#         bleu_score = sentence_bleu(reference, candidate, weights=(0.5, 0.5), smoothing_function=smoothie)
#         scores.append(bleu_score)

#     return sum(scores) / len(scores)



# val_bleu = evaluate_bleu(model, val_data=val_loader.dataset, i2w=i2w, w2i=w2i, transform=val_loader.dataset.transform)
# print("BLEU Score:", val_bleu)
