In [None]:
import torch
import numpy as np
import random
import warnings
import os
import json
from PIL import Image
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode

# Silence library warnings globally for cleaner notebook output.
warnings.filterwarnings("ignore")

# Seed every RNG the notebook touches (Python, legacy NumPy global, torch) so
# re-running the notebook reproduces the same results.
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

# Prefer the GPU when one is visible to torch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

In [None]:
import kagglehub

# Download (or reuse the cached copy of) the mini COCO 2014 captioning dataset;
# `path` is the local directory kagglehub returns.
DATASET_SLUG = "nagasai524/mini-coco2014-dataset-for-image-captioning"
path = kagglehub.dataset_download(DATASET_SLUG)
print(f"Path to dataset files: {path}")

In [None]:
dataset_path = path

json_file = os.path.join(dataset_path, "captions.json")

# JSON text is UTF-8 by specification; be explicit so captions with non-ASCII
# characters decode identically regardless of the platform's locale encoding.
with open(json_file, "r", encoding="utf-8") as f:
    data = json.load(f)

# Accept either a COCO-style object ({"annotations": [...], ...}) or a bare
# list of annotation records; normalise to the list of records.
if isinstance(data, dict) and "annotations" in data:
    annotations = data["annotations"]
else:
    annotations = data

# Fail here with a clear message instead of an obscure AttributeError later,
# when KaggleCOCODataset calls .get() on each record.
if not isinstance(annotations, list):
    raise ValueError(
        f"Expected a list of annotation records in {json_file}, "
        f"got {type(annotations).__name__}"
    )

In [None]:
class KaggleCOCODataset:
    """Index-addressable view over a Kaggle mini-COCO captioning download.

    Images are located by walking ``dataset_path`` and taking the first
    directory (in ``os.walk`` order) that contains ``.jpg``/``.jpeg``/``.png``
    files. Captions are grouped per ``image_id`` from ``annotations``.

    Args:
        dataset_path: root directory of the downloaded dataset.
        annotations: iterable of dict records; each is read with
            ``.get("image_id")`` and ``.get("caption", "")``.

    Attributes:
        image_dir: directory the images were found in, or ``None`` if none.
        captions_by_image: ``image_id -> list`` of captions, in input order.
        image_ids: ids in first-seen order; ``dataset[i]`` uses ``image_ids[i]``.
    """

    _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, dataset_path, annotations):
        self.dataset_path = dataset_path
        self.annotations = annotations

        # First directory holding any image file (extension check is case-insensitive).
        self.image_dir = None
        for root, _dirs, files in os.walk(dataset_path):
            if any(name.lower().endswith(self._IMAGE_EXTENSIONS) for name in files):
                self.image_dir = root
                break

        # dicts keep insertion order, so image_ids follow first appearance.
        self.captions_by_image = {}
        for ann in annotations:
            self.captions_by_image.setdefault(ann.get("image_id"), []).append(
                ann.get("caption", "")
            )

        self.image_ids = list(self.captions_by_image.keys())

        # Sorted list of image filenames for the fallback lookup; built lazily.
        self._dir_listing = None

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``(RGB PIL image resized to 384x384, list of captions)`` for position ``idx``."""
        image_id = self.image_ids[idx]
        image_path = self._resolve_image_path(image_id)

        # The context manager closes the underlying file; convert() returns a new image.
        with Image.open(image_path) as raw:
            img = raw.convert("RGB")
        img = img.resize((384, 384))

        return img, self.captions_by_image[image_id]

    @staticmethod
    def _trailing_number(filename):
        """Return the integer formed by the trailing digits of the file stem, or None."""
        stem = os.path.splitext(filename)[0]
        digits = stem[len(stem.rstrip("0123456789")):]
        return int(digits) if digits else None

    def _resolve_image_path(self, image_id):
        """Map an annotation ``image_id`` to an existing image file in ``image_dir``.

        Tries, in order:
        1. the COCO train2014 filename (int ids only);
        2. ``<image_id>.jpg``;
        3. an image file whose stem equals ``str(image_id)``, or whose trailing
           digit run equals the numeric id.

        Raises:
            FileNotFoundError: if no image directory was found or no file matches.
        """
        if self.image_dir is None:
            raise FileNotFoundError(
                f"No .jpg/.jpeg/.png images found under {self.dataset_path!r}"
            )

        candidates = []
        if isinstance(image_id, int):  # the :012d format only works for ints
            candidates.append(f"COCO_train2014_{image_id:012d}.jpg")
        candidates.append(f"{image_id}.jpg")
        for name in candidates:
            candidate_path = os.path.join(self.image_dir, name)
            if os.path.exists(candidate_path):
                return candidate_path

        # Fallback: compare ids exactly, not as substrings. "9" must not match "..._000000000019.jpg".
        try:
            wanted = int(image_id)
        except (TypeError, ValueError):
            wanted = None

        listing = getattr(self, "_dir_listing", None)
        if listing is None:
            listing = sorted(
                name for name in os.listdir(self.image_dir)
                if name.lower().endswith(self._IMAGE_EXTENSIONS)
            )
            self._dir_listing = listing

        for name in listing:
            stem = os.path.splitext(name)[0]
            if stem == str(image_id) or (
                wanted is not None and self._trailing_number(name) == wanted
            ):
                return os.path.join(self.image_dir, name)

        raise FileNotFoundError(
            f"No image file for image_id {image_id!r} in {self.image_dir!r}"
        )

# Build the notebook-wide dataset used by the helper functions below.
cap = KaggleCOCODataset(dataset_path, annotations)
num_images = len(cap)
print(f"Dataset initialized with {num_images} images")

In [None]:
def get_coco_dataset(id=3, for_attention=False):
    """Fetch one sample from the module-level ``cap`` dataset.

    Args:
        id: positional index into ``cap``. It is passed to
            ``KaggleCOCODataset.__getitem__``, which looks up ``image_ids[id]``.
        for_attention: if True, return only the image converted with
            ``transforms.ToTensor()``.

    Returns:
        The image tensor when ``for_attention`` is True. Otherwise the tuple
        ``(cap, captions, pil_image)``; note the first element is the whole
        dataset object, not a single item.
    """
    pil_image, captions = cap[id]
    if not for_attention:
        return cap, captions, pil_image
    to_tensor = transforms.ToTensor()
    return to_tensor(pil_image)


def get_coco_dataset_for_sat(id, for_attention=False):
    """Return ``(cap, captions, pil_image)`` for position ``id`` of the global ``cap``.

    Args:
        id: positional index into ``cap``.
        for_attention: accepted for signature parity with ``get_coco_dataset``;
            it is not used.
    """
    # Index once: each cap[id] call re-opens and resizes the image from disk,
    # so the previous cap[id][0], cap[id][1] loaded every image twice.
    pil_image, target = cap[id]
    return cap, target, pil_image

In [None]:
def divide_list(lst, num_chunks=5):
    """Split a sliceable sequence into ``num_chunks`` contiguous, ordered slices.

    Chunk lengths differ by at most one, and the earlier chunks receive the
    leftover elements. When ``len(lst) < num_chunks``, the trailing chunks are
    empty slices.

    Raises:
        ZeroDivisionError: if ``num_chunks`` is 0.
    """
    base, extra = divmod(len(lst), num_chunks)
    chunks = []
    cursor = 0
    for index in range(num_chunks):
        width = base + (1 if index < extra else 0)
        chunks.append(lst[cursor:cursor + width])
        cursor += width
    return chunks


def load_image(id, image_size=384, device=device, before=True, dataset="coco"):
    """Load sample ``id`` from the COCO dataset as a tensor on ``device``.

    Args:
        id: positional index forwarded to ``get_coco_dataset``.
        image_size: output height and width (bicubic resize).
        device: target device. The default is the module-level ``device``,
            captured when this function is defined.
        before: if True, return the un-normalised ``PILToTensor`` output with no
            batch dim. Otherwise return ``ToTensor`` + ``Normalize`` output with a
            leading batch dim of 1.
        dataset: only ``"coco"`` is supported.

    Raises:
        ValueError: if ``dataset`` is not ``"coco"``.
    """
    if dataset != "coco":
        raise ValueError(f"Dataset {dataset} not supported. Only 'coco' is available.")
    _dataset, _captions, raw_image = get_coco_dataset(id)

    resize = transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC)

    if before:
        # Only build the pipeline that is actually used.
        pil_pipeline = transforms.Compose([resize, transforms.PILToTensor()])
        return pil_pipeline(raw_image).to(device)

    # NOTE(review): these mean/std triples appear to be the CLIP preprocessing
    # statistics; confirm they match the model that consumes this tensor.
    normalized_pipeline = transforms.Compose([
        resize,
        transforms.ToTensor(),
        transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        ),
    ])
    return normalized_pipeline(raw_image).unsqueeze(0).to(device)

In [None]:
# Smoke-test the helpers above on the first sample. The asserts make a fresh
# "Restart & Run All" fail loudly if any invariant breaks.
_, captions, img = get_coco_dataset(id=0)
print(f"Image type: {type(img)}, Size: {img.size}")
print(f"First caption: {captions[0]}")
assert img.size == (384, 384)  # KaggleCOCODataset resizes every image to 384x384

test_list = list(range(25))
divided = divide_list(test_list, num_chunks=5)
print(f"divide_list test: {len(divided)} chunks")
assert len(divided) == 5 and all(len(chunk) == 5 for chunk in divided)
assert [x for chunk in divided for x in chunk] == test_list  # chunks cover the list in order

tensor_img = load_image(id=0, image_size=384, device=device, before=True)
print(f"load_image output shape: {tensor_img.shape}")
assert tuple(tensor_img.shape) == (3, 384, 384)  # PILToTensor: (C, H, W), no batch dim
assert tensor_img.dtype == torch.uint8

tensor_img_normalized = load_image(id=0, image_size=384, device=device, before=False)
print(f"load_image normalized shape: {tensor_img_normalized.shape}")
assert tuple(tensor_img_normalized.shape) == (1, 3, 384, 384)  # batch dim added by unsqueeze(0)
assert tensor_img_normalized.dtype == torch.float32