In [None]:
import sys
import os
# Make the directory two levels above the working directory importable
# (presumably where the `src` package imported by this notebook lives).
# Guarded so that re-running this cell does not append the same entry again.
project_root = os.path.abspath('../../')
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
import numpy as np
import json
import os
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import torch
from torchvision import transforms
from src.data import MultimodalDataset
from src.models.image_encoder import ImageEncoder


# Exploring the data

In [None]:
captions_path = '../../data/mscoco/annotations/captions_val2017.json'

# Read the annotation file as UTF-8 explicitly: without `encoding`, open()
# falls back to the platform locale, which can mis-decode non-ASCII captions.
with open(captions_path, 'r', encoding='utf-8') as f:
    captions = json.load(f)

In [None]:
# Number of records under the 'annotations' key of the loaded captions file.
len(captions['annotations'])

In [None]:
# Inspect the first annotation record to see which fields it carries.
captions['annotations'][0]

In [None]:
# Inspect the image record at position 2 of the 'images' list.
captions['images'][2]

In [None]:
# image_id referenced by the annotation at position 2; compare it with the id
# of captions['images'][2] to see whether the two lists line up by position.
captions['annotations'][2]['image_id']

In [None]:
idx = 2
image_info = captions['images'][idx]
img = mpimg.imread(os.path.join('../../data/mscoco/val2017', image_info['file_name']))
plt.imshow(img)
plt.axis('off')
plt.show()

# captions['images'] and captions['annotations'] are separate lists, so the
# same position in each need not refer to the same image. Select the captions
# by matching the annotation's image_id against the displayed image's id.
matching_captions = [
    ann['caption']
    for ann in captions['annotations']
    if ann['image_id'] == image_info['id']
]
print('\n'.join(matching_captions))

In [None]:
# Lookup table from image id to the image's file name.
id2file = {image['id']: image['file_name'] for image in captions['images']}

# For every annotation, read its caption and resolve its image_id to a file
# name (None when the id is not in the lookup table).
caption_and_file = (
    (ann['caption'], id2file.get(ann['image_id']))
    for ann in captions['annotations']
)

# Keep one (file name, caption) tuple per annotation whose file name resolved.
image_filename_caption_pairs = [
    (file_name, text) for text, file_name in caption_and_file if file_name
]

In [None]:
# All (file name, caption) pairs that share the file name of the pair at idx.
target_filename = image_filename_caption_pairs[idx][0]
[pair for pair in image_filename_caption_pairs if pair[0] == target_filename]

In [None]:
# Display the image of the pair at position idx, then print that pair's caption.
selected_filename, selected_caption = image_filename_caption_pairs[idx]
img = mpimg.imread(os.path.join('../../data/mscoco/val2017', selected_filename))
plt.imshow(img)
plt.axis('off')
plt.show()
print(selected_caption)

In [None]:
from collections import Counter

# Tally how many (file name, caption) pairs exist for each file name.
filename_counts = Counter(file_name for file_name, _ in image_filename_caption_pairs)

# File names that occur in more than one pair.
duplicates = [name for name, n_pairs in filename_counts.items() if n_pairs > 1]
len(duplicates)


In [None]:
# most_common() orders entries from highest to lowest count, so the last ten
# entries are the file names with the fewest captions.
filename_counts.most_common()[-10:]

# Testing the implementation

In [None]:
img_path = '../../data/mscoco/val2017'
captions_path = '../../data/mscoco/annotations/captions_val2017.json'

# Bind the pipeline to its own name. Assigning it to `transforms` would replace
# the imported torchvision module with a Compose object, so re-running this
# cell (or any later `transforms.X` lookup) would raise AttributeError.
image_transform = transforms.Compose([
    transforms.Resize((256, 256)),      # resize to 256x256
    transforms.CenterCrop((224, 224)),  # then take the central 224x224 region
    transforms.ToTensor(),
])

dataset = MultimodalDataset(img_path, captions_path, transform=image_transform, freq_threshold=1)

In [None]:
# Materialise the first 100 dataset items; each item unpacks as (image, other).
sample_dataset = [dataset[i] for i in range(100)]

image_encoder = ImageEncoder(in_channels=3, image_size=(224, 224), latent_dim=128)
image_encoder.eval()  # Set to evaluation mode

# Inference only: turn autograd off so no graph is built for each forward pass
# and the printed tensors carry no grad_fn. The loop variable has its own name
# so it does not overwrite the notebook-level `idx` used by earlier cells.
with torch.no_grad():
    for sample_idx, (sample_image, _) in enumerate(sample_dataset):
        sample_image = sample_image.unsqueeze(0)  # add a leading batch dimension
        print(sample_image.shape)
        mu, logvar = image_encoder(sample_image)
        print(f"Image {sample_idx+1}: Latent mean (mu): {mu}, Latent log-variance (logvar): {logvar}")