In [None]:
import torch
import numpy as np
import random
import warnings
warnings.filterwarnings("ignore")

# Seed every RNG the notebook relies on (PyTorch, NumPy, Python) with the same
# value so that repeated runs are reproducible.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Use the GPU when PyTorch reports one as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

In [None]:
import kagglehub

# Fetch the mini COCO-2014 captioning dataset through kagglehub. The value
# returned by dataset_download() is kept in `path`; the next cell uses it as
# the dataset root directory.
path = kagglehub.dataset_download("nagasai524/mini-coco2014-dataset-for-image-captioning")
print("Path to dataset files:", path)

In [None]:
import os
import json
from PIL import Image
import torchvision.transforms as transforms

# Root directory of the downloaded dataset (set by the kagglehub cell).
dataset_path = path

json_file = os.path.join(dataset_path, "captions.json")

# Open with an explicit UTF-8 encoding: without it open() falls back to the
# platform locale encoding, which can fail on non-ASCII captions (e.g. on
# Windows).
with open(json_file, "r", encoding="utf-8") as f:
    data = json.load(f)

# The file may be either a COCO-style dict with an "annotations" entry or
# already the bare annotations payload; normalise both to `annotations`.
if isinstance(data, dict) and "annotations" in data:
    annotations = data["annotations"]
else:
    annotations = data

In [None]:
def load_coco_sample(image_ids, dataset_path, annotations, max_images=30, image_size=384):
    """Load only the images needed for experiments, together with their captions.

    Args:
        image_ids: Sequence of integer image ids. Only the first ``max_images``
            entries are loaded.
        dataset_path: Directory that is walked recursively; the first directory
            found that contains a ``.jpg`` file is used as the image folder.
        annotations: Iterable of dicts providing ``'image_id'`` and
            ``'caption'`` keys.
        max_images: Maximum number of ids taken from ``image_ids``
            (default 30, the previously hard-coded limit).
        image_size: Side length in pixels of the square the image is resized
            to (default 384, the size the original code labelled "BLIP size").

    Returns:
        A list of dicts with keys ``'image_id'``, ``'image'`` (an RGB PIL image
        resized to ``image_size`` x ``image_size``) and ``'captions'`` (list of
        caption strings for that id).

    Raises:
        FileNotFoundError: If images were requested but no directory containing
            ``.jpg`` files exists under ``dataset_path``, or if an individual
            image file is missing.
        KeyError: If a requested id has no caption in ``annotations``.
    """
    # Group captions by image_id
    captions_by_image = {}
    for ann in annotations:
        captions_by_image.setdefault(ann['image_id'], []).append(ann['caption'])

    selected_ids = image_ids[:max_images]
    if len(selected_ids) == 0:
        # Nothing to load, so there is no need to locate the image folder.
        return []

    # Find image directory: first folder in the walk that holds a .jpg file.
    image_dir = next(
        (
            root
            for root, _dirs, files in os.walk(dataset_path)
            if any(name.endswith('.jpg') for name in files)
        ),
        None,
    )
    if image_dir is None:
        # Previously this fell through to os.path.join(None, ...) and raised
        # an uninformative TypeError.
        raise FileNotFoundError(f"No .jpg images found under {dataset_path!r}")

    samples = []
    for img_id in selected_ids:
        img_path = os.path.join(image_dir, f"COCO_train2014_{img_id:012d}.jpg")
        # convert()/resize() return new in-memory images, so the source file
        # handle can be closed as soon as the block exits.
        with Image.open(img_path) as src:
            img = src.convert('RGB').resize((image_size, image_size))
        samples.append({
            'image_id': img_id,
            'image': img,
            'captions': captions_by_image[img_id]
        })

    return samples

In [None]:
# Functions from coco.py

def get_coco_dataset(id=3, for_attention=False):
    """Fetch one (image, target) sample from the module-level ``cap`` dataset.

    NOTE(review): ``cap`` is not defined in the visible cells; it is presumably
    a captions dataset created elsewhere -- confirm it exists before this runs.

    Args:
        id: Index used to look up the sample in ``cap``.
        for_attention: When true, only the image is returned, converted with
            ``transforms.ToTensor()``.

    Returns:
        The tensor image when ``for_attention`` is true; otherwise the tuple
        ``(cap, target, img)`` with the image exactly as ``cap`` returned it.
    """
    image, captions = cap[id]
    if not for_attention:
        return cap, captions, image
    to_tensor = transforms.ToTensor()
    return to_tensor(image)


def get_coco_dataset_for_sat(id, for_attention=False):
    """Return ``(cap, target, pil_image)`` for sample ``id`` of the ``cap`` dataset.

    NOTE(review): ``cap`` is a module-level name not defined in the visible
    cells -- confirm where it is created.

    Args:
        id: Index used to look up the sample in ``cap``.
        for_attention: Unused; kept so the signature stays compatible with
            existing callers.

    Returns:
        Tuple ``(cap, target, pil_image)`` where ``pil_image`` is element 0 and
        ``target`` is element 1 of ``cap[id]``.
    """
    # Index the dataset once. The previous code evaluated cap[id] twice, which
    # repeats whatever work the dataset does per lookup.
    sample = cap[id]
    pil_image, target = sample[0], sample[1]
    return cap, target, pil_image

In [None]:
from torchvision.transforms.functional import InterpolationMode

def divide_list(lst, num_chunks=5):
    """Split ``lst`` into ``num_chunks`` consecutive slices of near-equal length.

    When the length is not evenly divisible, the leading chunks each receive
    one extra element. Chunks may be empty if ``lst`` has fewer items than
    ``num_chunks``.

    Args:
        lst: Sliceable sequence to partition.
        num_chunks: Number of chunks to produce.

    Returns:
        A list of ``num_chunks`` slices of ``lst``, in order.
    """
    base, extra = divmod(len(lst), num_chunks)
    chunks = []
    cursor = 0
    for idx in range(num_chunks):
        # The first `extra` chunks absorb the leftover elements, one each.
        width = base + 1 if idx < extra else base
        chunks.append(lst[cursor:cursor + width])
        cursor += width
    return chunks

def load_image(id, image_size=384, device=device, for_blip=True, dataset="coco"):
    """Load one dataset image and return it as a tensor on ``device``.

    Args:
        id: Sample index passed to the dataset accessor.
        image_size: Side length of the square the image is resized to
            (bicubic interpolation).
        device: Target device for the returned tensor. The default is the
            module-level ``device`` as it was when this function was defined.
        for_blip: When true, the image is normalized (BLIP expects normalized
            images) and a leading batch dimension is added with
            ``unsqueeze(0)``. When false, the raw ``ToTensor`` output is
            returned without normalization or a batch dimension.
        dataset: ``"coco"`` to read through ``get_coco_dataset``;
            ``"flicker8k"`` (historical spelling) or ``"flickr8k"`` to read
            through ``flickr8k_dataset``.

    Returns:
        The transformed image tensor, moved to ``device``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset == "coco":
        _, _, raw_image = get_coco_dataset(id)
    elif dataset in ("flicker8k", "flickr8k"):
        # NOTE(review): flickr8k_dataset is not defined in the visible cells.
        _, _, raw_image = flickr8k_dataset(id)
    else:
        # Previously an unknown name left raw_image unbound and surfaced as an
        # UnboundLocalError further down.
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected 'coco' or 'flicker8k'"
        )

    # Resize + ToTensor are shared by both modes; only BLIP adds normalization.
    steps = [
        transforms.Resize((image_size, image_size),
                          interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
    ]
    if for_blip:
        # BLIP expects normalized images
        steps.append(
            transforms.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711)
            )
        )

    image = transforms.Compose(steps)(raw_image)
    if for_blip:
        # Add the batch dimension for the BLIP path only.
        image = image.unsqueeze(0)

    return image.to(device)

In [None]:
# Smoke-test the helpers defined in the cells above.
_, captions, img = get_coco_dataset(id=0)
print(f"Image: {type(img)}, Size: {img.size}")
print(f"Captions: {captions[0]}")

test_list = list(range(25))
divided = divide_list(test_list, num_chunks=5)
print(f"divide_list test: {len(divided)} chunks")

# load_image() has no `before` parameter, so passing before=True raised a
# TypeError; the keyword is dropped and the default for_blip=True path is used.
tensor_img = load_image(id=0, image_size=255, device=device)
print(f"load_image: {tensor_img.shape}")