In [16]:
import numpy as np
import json
import random
import torch
import torchvision.transforms as transforms
from transformers import BertTokenizer, BertTokenizerFast, BertModel
import timm
from torch.utils.data import DataLoader
import torch.nn as nn
from PIL import Image
import requests
from io import BytesIO
import re
import matplotlib.pyplot as plt
import os
from google.colab import drive

In [17]:
# Mount Google Drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [18]:
# Load the tokenizer and models
bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
image_encoder = timm.create_model('vit_base_patch16_224', pretrained=True)
image_encoder.eval()  # Set model to evaluation mode
text_encoder = BertModel.from_pretrained('bert-base-uncased')



In [19]:
# Load COCO-AB dataset from Google Drive path
coco_ab_json_path = '/content/drive/My Drive/COCO-AB/coco_ab_v1_0.json'

In [20]:
# Check if the file exists
if not os.path.exists(coco_ab_json_path):
    raise FileNotFoundError(f"The file {coco_ab_json_path} does not exist. Please check the path.")

In [21]:
# JSON 파일 읽기
with open(coco_ab_json_path, 'r') as f:
    coco_ab_data = json.load(f)

# 데이터 구조의 일부만 확인 (처음 5개의 항목만 출력)
print(coco_ab_data[:5])

[{'imageWidth': 450, 'originalImageWidth': 480, 'mouseTracking': [], 'originalImageHeight': 360, 'actionHistories': [{'actionType': 'add', 'pointTo': {'x': 0.7883333333333333, 'y': 0.8377777777777777}, 'timeAt': 155755, 'iconType': 'cake'}, {'pointFrom': {'x': 0.7883333333333333, 'y': 0.8377777777777777}, 'actionType': 'move', 'pointTo': {'x': 0.7016666666666667, 'y': 0.5444444444444444}, 'timeAt': 158789, 'iconType': 'cake'}], 'categoryHistories': [{'timeAt': 144081, 'categoryIndex': 1, 'categoryName': 'Animal', 'usingKeyboard': False}, {'timeAt': 144161, 'categoryIndex': 2, 'categoryName': 'Vehicle', 'usingKeyboard': False}, {'timeAt': 144372, 'categoryIndex': 3, 'categoryName': 'OutdoorObj', 'usingKeyboard': False}, {'timeAt': 144495, 'categoryIndex': 4, 'categoryName': 'Sports', 'usingKeyboard': False}, {'timeAt': 144693, 'categoryIndex': 5, 'categoryName': 'Kitchenware', 'usingKeyboard': False}, {'timeAt': 146670, 'categoryIndex': 6, 'categoryName': 'Food', 'usingKeyboard': False}

In [22]:
# Define contrastive loss for multimodal alignment
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, image_features, text_features):
        # Normalize features
        image_features = nn.functional.normalize(image_features, p=2, dim=1)
        text_features = nn.functional.normalize(text_features, p=2, dim=1)

        # Calculate similarity scores
        similarity_matrix = torch.matmul(image_features, text_features.T)
        positive_pairs = torch.diag(similarity_matrix)
        negative_pairs = similarity_matrix - torch.eye(similarity_matrix.size(0)).to(similarity_matrix.device) * 1e12

        # Contrastive loss computation
        positive_loss = 1 - positive_pairs
        negative_loss = torch.clamp(self.margin - negative_pairs, min=0).mean()

        loss = positive_loss.mean() + negative_loss
        return loss

In [23]:
# Dataset class for COCO-AB with annotation byproducts
class COCOABDataset(torch.utils.data.Dataset):
    def __init__(self, coco_data, transform=None):
        # Assuming 'actionHistories' or 'categoryHistories' is a key in the data
        self.data = [item for item in coco_data if 'actionHistories' in item or 'categoryHistories' in item]
        self.transform = transform
        self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image_url = item.get('url', None)

        # 기본 이미지 초기화
        img = Image.new('RGB', (224, 224), color='black')

        if image_url:
            if image_url.startswith('/') or image_url.startswith('./'):  # 로컬 경로일 경우
                try:
                    img = Image.open(image_url).convert('RGB')
                except Exception as e:
                    print(f"Error loading local image: {image_url}, Error: {e}")
            elif re.match(r'(http|https)://', image_url):  # 웹 URL일 경우
                try:
                    response = requests.get(image_url)
                    img = Image.open(BytesIO(response.content)).convert('RGB')
                except requests.exceptions.RequestException as e:
                    print(f"Error loading image from URL: {image_url}, Error: {e}")

        if self.transform:
            img = self.transform(img)

        # Convert annotation byproducts to text
        if 'actionHistories' in item and len(item['actionHistories']) > 0:
            action = item['actionHistories'][0]
            click_info_text = f"Action taken at point ({action['pointTo']['x']}, {action['pointTo']['y']}) of type {action['iconType']}."
        else:
            click_info_text = "No specific action recorded."

        # Tokenize the text with padding
        tokenized_text = self.tokenizer(click_info_text, padding='max_length', truncation=True, max_length=64, return_tensors='pt')

        return img, tokenized_text

In [24]:
# Set up dataset and data loader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [25]:
# Since the loaded JSON data is a list, we use it directly
dataset = COCOABDataset(coco_ab_data, transform=transform)


In [26]:
# Check if dataset is empty
if len(dataset) == 0:
    raise ValueError("The dataset is empty. Please check the data loading process.")

train_loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)  # Reduced num_workers to avoid warning


In [27]:
# Initialize contrastive loss
contrastive_loss = ContrastiveLoss()

In [28]:
# Training loop (simplified)
optimizer = torch.optim.Adam(list(image_encoder.parameters()) + list(text_encoder.parameters()), lr=1e-4)


In [29]:
for epoch in range(1):  # Example: 10 epochs
    for batch in train_loader:
        images, tokenized_texts = batch

        # Extract features
        image_features = image_encoder(images).view(images.size(0), -1)
        text_features = text_encoder(**{key: val.squeeze(1) for key, val in tokenized_texts.items()}).last_hidden_state[:, 0, :]

        # Ensure dimensions match for matrix multiplication
        image_features = image_features[:, :text_features.size(1)]

        # Calculate loss
        loss = contrastive_loss(image_features, text_features)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

KeyboardInterrupt: 