In [3]:
import os
import torch
import numpy as np
from tqdm import tqdm
from PIL import Image

import torchvision.models as models
import torchvision.transforms as transforms
from transformers import BertTokenizer, BertModel
from pycocotools.coco import COCO

In [4]:
# Pick the compute device once for the whole notebook: the GPU when CUDA
# reports itself as available, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [5]:
# Load pretrained ResNet50.
# torchvision >= 0.13 selects checkpoints through the `weights=` enum and
# deprecates `pretrained=True`. Use the enum when this torchvision provides it
# and fall back to the legacy flag otherwise, so the cell runs on both.
# IMAGENET1K_V1 is the checkpoint the legacy `pretrained=True` flag loads.
if hasattr(models, "ResNet50_Weights"):
    resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:
    resnet = models.resnet50(pretrained=True)
# Keep every child module except the last one (the final FC layer), so the
# network outputs pooled features instead of class scores.
resnet = torch.nn.Sequential(*(list(resnet.children())[:-1]))
# Inference only: switch to eval mode and move the weights to `device`.
resnet = resnet.eval().to(device)

# Load pretrained BERT (tokenizer and encoder come from the same checkpoint).
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased")
bert = bert.eval().to(device)

# Image preprocessing: fixed 224x224 resize, conversion to a tensor, then
# per-channel normalisation with the mean/std values listed below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

In [6]:
def extract_image_features(img_path):
    """Compute the ResNet feature vector for a single image file.

    The file at ``img_path`` is opened with PIL and converted to RGB, run
    through the notebook-level ``transform``, given a leading batch axis,
    moved to ``device`` and passed through ``resnet`` with gradient tracking
    disabled.

    Args:
        img_path: Path of the image file to read.

    Returns:
        A NumPy array on the CPU holding the flattened network output
        (shape [1, 2048] per the inline shape notes).
    """
    rgb = Image.open(img_path).convert("RGB")
    batch = transform(rgb)
    batch = batch.unsqueeze(0)  # add the batch dimension
    batch = batch.to(device)
    with torch.no_grad():
        pooled = resnet(batch)  # [1, 2048, 1, 1]
    flat = torch.flatten(pooled, start_dim=1)  # [1, 2048]
    return flat.cpu().numpy()

def extract_text_features(caption):
    """Compute the BERT feature vector for a caption.

    ``caption`` is tokenized with the notebook-level ``tokenizer`` (padding
    and truncation enabled, PyTorch tensors returned), moved to ``device``
    and encoded by ``bert`` with gradient tracking disabled. The hidden state
    at sequence position 0 is used as the feature.

    Args:
        caption: Text passed straight to the tokenizer.

    Returns:
        A NumPy array on the CPU (shape [1, 768] per the inline shape note).
    """
    encoded = tokenizer(caption, return_tensors="pt", padding=True, truncation=True)
    encoded = encoded.to(device)
    with torch.no_grad():
        hidden = bert(**encoded).last_hidden_state
    first_token = hidden[:, 0, :]  # keep only sequence position 0
    return first_token.cpu().numpy()  # [1, 768]


In [7]:
def extract_all_features_npy(coco, img_dir, save_dir="features"):
    """Extract image and caption features for every image in ``coco`` and save them.

    For each id returned by ``coco.getImgIds()`` the image file is run through
    ``extract_image_features`` and every annotation carrying a ``"caption"``
    key through ``extract_text_features``. If anything in that per-image work
    raises, the image is reported and skipped as a whole, so row ``i`` of
    every saved file always refers to the same image.

    Files written to ``save_dir`` (created when missing):
        images.npy     one image feature vector per row.
        captions.npy   1-D object array; element ``i`` is the array of caption
                       features for image ``i`` (load with ``allow_pickle=True``).
        image_ids.npy  the image id behind each row, so rows stay traceable
                       even when some images were skipped.

    Args:
        coco: Object exposing ``getImgIds``, ``loadImgs``, ``getAnnIds`` and
            ``loadAnns`` (a ``pycocotools.coco.COCO`` instance in this notebook).
        img_dir: Directory that contains the image files named by
            ``file_name`` in the image records.
        save_dir: Output directory for the ``.npy`` files.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    os.makedirs(save_dir, exist_ok=True)

    img_ids = coco.getImgIds()
    print(f"Extracting features for {len(img_ids)} images...")

    kept_img_ids = []
    all_image_features = []
    all_caption_features = []

    for img_id in tqdm(img_ids):
        img_info = coco.loadImgs(img_id)[0]
        img_path = os.path.join(img_dir, img_info['file_name'])

        try:
            # ---- IMAGE FEATURES ----
            img_features = extract_image_features(img_path).squeeze(0)

            # ---- CAPTION FEATURES ----
            ann_ids = coco.getAnnIds(imgIds=img_id)
            anns = coco.loadAnns(ann_ids)
            cap_features = np.array([
                extract_text_features(ann['caption']).squeeze(0)
                for ann in anns
                if "caption" in ann
            ])
        except Exception as e:
            # Deliberate best-effort: one unreadable image or bad caption
            # should not abort the whole run.
            print(f"Skipping {img_id} due to error: {e}")
            continue

        # Append only after BOTH modalities succeeded. Appending the image
        # feature before the captions were computed would leave an orphan
        # image row on a caption failure and shift every later row.
        kept_img_ids.append(img_id)
        all_image_features.append(img_features)
        all_caption_features.append(cap_features)

    image_array = np.array(all_image_features)

    # Fill a pre-sized 1-D object array element by element.
    # np.array(list_of_arrays, dtype=object) picks its result shape from the
    # data: when every image has the same number of captions it builds an
    # N-D object array of scalars instead of a 1-D array of arrays.
    caption_array = np.empty(len(all_caption_features), dtype=object)
    for row, feats in enumerate(all_caption_features):
        caption_array[row] = feats

    np.save(os.path.join(save_dir, "images.npy"), image_array)
    np.save(os.path.join(save_dir, "captions.npy"), caption_array)
    np.save(os.path.join(save_dir, "image_ids.npy"), np.array(kept_img_ids))

    print(f"Saved image features to {save_dir}/images.npy")
    print(f"Saved caption features to {save_dir}/captions.npy")
    print(f"Saved image ids to {save_dir}/image_ids.npy")

In [8]:
# Example paths (change these to your setup)
annFile = "/home/BTECH_7TH_SEM/MS-COCO/annotations_trainval2017/annotations/captions_val2017.json"
img_dir = "/home/BTECH_7TH_SEM/MS-COCO/val2017"
save_dir = "/home/BTECH_7TH_SEM/Desktop/MML-RL-and-NLP/MML/coco_features"
# exist_ok=True creates the directory only when it is missing, in one call,
# instead of a separate os.path.exists() check followed by makedirs.
os.makedirs(save_dir, exist_ok=True)

# Load the caption annotations, then extract and save features for every image.
coco = COCO(annFile)
extract_all_features_npy(coco, img_dir, save_dir)


loading annotations into memory...
Done (t=0.38s)
creating index...
index created!
Extracting features for 5000 images...


100%|██████████| 5000/5000 [03:39<00:00, 22.77it/s]


Saved image features to /home/BTECH_7TH_SEM/Desktop/MML-RL-and-NLP/MML/coco_features/images.npy
Saved caption features to /home/BTECH_7TH_SEM/Desktop/MML-RL-and-NLP/MML/coco_features/captions.npy


In [9]:
# Load saved features
# Read from `save_dir`, the directory the extraction step wrote to, instead
# of a path relative to the current working directory.
images = np.load(os.path.join(save_dir, "images.npy"))
# captions.npy is an object array, so it needs allow_pickle=True.
# NOTE: unpickling can execute arbitrary code - only load files you produced
# yourself.
captions = np.load(os.path.join(save_dir, "captions.npy"), allow_pickle=True)

print("Images shape:", images.shape)
print("Number of caption sets:", len(captions))
print("Captions for first image shape:", captions[0].shape if len(captions) > 0 else None)

Images shape: (5000, 2048)
Number of caption sets: 5000
Captions for first image shape: (5, 768)
