In [1]:
import os
import json
import torch
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
from tqdm import tqdm
import numpy as np

In [2]:
with open('../data/flickr_train_images.json', 'r') as f:
    train_images = json.load(f)

image_dir = '../data/Flickr8k_Dataset/Flicker8k_Dataset/'  

In [3]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

In [4]:
from torchvision.models import resnet50, ResNet50_Weights

weights = ResNet50_Weights.DEFAULT
resnet = resnet50(weights=weights)
resnet = torch.nn.Sequential(*list(resnet.children())[:-1])  # Remove final FC layer
resnet.eval()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
resnet.to(device)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [5]:
feature_dir = '../features/'
os.makedirs(feature_dir, exist_ok=True)

for img_name in tqdm(train_images):
    img_path = os.path.join(image_dir, img_name)
    image = Image.open(img_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        feature = resnet(image).squeeze().cpu().numpy()  
    
    feature_path = os.path.join(feature_dir, img_name.replace('.jpg', '.npy'))
    np.save(feature_path, feature)


100%|██████████████████████████████████████████████████████████████████████████████| 6000/6000 [13:16<00:00,  7.53it/s]


In [6]:
feature_dir = '../features_spatial/'
os.makedirs(feature_dir, exist_ok=True)

for img_name in tqdm(train_images):
    img_path = os.path.join(image_dir, img_name)
    image = Image.open(img_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        spatial_features = resnet(image)
        batch_size, channels, height, width = spatial_features.shape
        spatial_features = spatial_features.view(channels, -1).transpose(0, 1)  
        spatial_features = spatial_features.cpu().numpy()
    
    feature_path = os.path.join(feature_dir, img_name.replace('.jpg', '_spatial.npy'))
    np.save(feature_path, spatial_features)

100%|██████████████████████████████████████████████████████████████████████████████| 6000/6000 [14:02<00:00,  7.12it/s]
