In [1]:
import numpy as np

# Load the features
images = np.load("coco_features/images.npy")
captions = np.load("coco_features/captions.npy", allow_pickle=True)

print("Image features:", images.shape)   # (N, 2048)
print("Captions:", len(captions))        # N arrays of shape (M, 768)

Image features: (5000, 2048)
Captions: 5000


In [2]:
# Align text features (average across captions for each image)
text_features = []
for caps in captions:
    if len(caps) > 0:
        text_features.append(np.mean(caps, axis=0))
    else:
        text_features.append(np.zeros(768))  # if no captions
text_features = np.array(text_features)

print("Text features:", text_features.shape)  # (N, 768)

Text features: (5000, 768)


In [3]:
fused_features = np.concatenate([images, text_features], axis=1)
print("Fused features:", fused_features.shape)  # (N, 2048+768=2816)

Fused features: (5000, 2816)


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# Convert to torch tensors
X = torch.tensor(fused_features, dtype=torch.float32)

# Example: dummy labels (replace with your own task labels, e.g. categories, sentiments, etc.)
y = torch.randint(0, 10, (X.shape[0],))  # 10-class classification

dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# Simple classifier
class FusionClassifier(nn.Module):
    def __init__(self, input_dim=2816, hidden_dim=512, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, num_classes)
    
    def forward(self, x):
        x = self.relu(self.fc1(x))
        return self.fc2(x)

model = FusionClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop (simple)
for epoch in range(5):
    for batch_X, batch_y in loader:
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")


Epoch 1, Loss: 2.3222
Epoch 2, Loss: 2.3590
Epoch 3, Loss: 2.3018
Epoch 4, Loss: 2.3114
Epoch 5, Loss: 2.2720
