In [2]:
import os

val_path = "val2017"
train_path = "/raid/kyscap251/team2/train2017/train2017"
test = "train2017"

val_items = os.listdir(val_path)
train_items = os.listdir(train_path)

print("val", val_items[:5])
print("train", train_items[:5])

# 파일만 필터링
files = [f for f in os.listdir(train_path)
         if os.path.isfile(os.path.join(train_path, f))]

print(f"파일 개수: {len(files)}")

filess = [f for f in os.listdir(test)
         if os.path.isfile(os.path.join(test, f))]

print(f"파일 개수: {len(filess)}")

val ['000000433103.jpg', '000000129113.jpg', '000000196843.jpg', '000000252507.jpg', '000000258541.jpg']
train ['000000254879.jpg', '000000316649.jpg', '000000430989.jpg', '000000286349.jpg', '000000458365.jpg']
파일 개수: 109932
파일 개수: 109932


In [1]:
# 환경 설정 및 라이브러리 로딩
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from transformers import BertTokenizer, BertModel, ViTModel, ViTFeatureExtractor
from torchvision.models.vision_transformer import vit_b_16, ViT_B_16_Weights
import json
import numpy as np
from PIL import Image
import os
from torch.utils.data import Dataset, DataLoader
import random
from tqdm import tqdm
from PIL import Image, UnidentifiedImageError, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from torch.optim.lr_scheduler import StepLR

In [2]:
# 이미지 인코더: ViT
class VisionEncoder(nn.Module):
    def __init__(self, model_name='google/vit-base-patch16-224'):
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)

    def forward(self, images):
        outputs = self.vit(pixel_values=images)
        return outputs.last_hidden_state  # [B, 1+P, D]

In [3]:
# 텍스트 인코더: BERT 기반 Transformer
class TextEncoder(nn.Module):
    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.bert = BertModel.from_pretrained(model_name)

    def forward(self, captions):
        tokenized = self.tokenizer(captions, return_tensors="pt", padding=True, truncation=True).to(self.bert.device)
        outputs = self.bert(**tokenized)
        return outputs.last_hidden_state, tokenized

In [4]:
# Cross-Attention Block
class CrossAttentionBlock(nn.Module):
    def __init__(self, dim=768, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)

    def forward(self, text_emb, image_patches):
        attn_output, attn_weights = self.attn(text_emb, image_patches, image_patches)
        return attn_output, attn_weights

In [15]:
# 모델 
class VisionLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_encoder = VisionEncoder()
        self.text_encoder = TextEncoder()
        self.cross_attn = CrossAttentionBlock()
        self.proj_image = nn.Linear(768, 512)
        self.proj_text = nn.Linear(768, 512)

    def forward(self, images, captions):
        img_feat_all = self.vision_encoder(images)  # [B, 1+P, D]
        cls_feat = img_feat_all[:, 0]               # [B, D] CLS
        patch_feat = img_feat_all[:, 1:]            # [B, P, D] patches only

        text_emb, tokens = self.text_encoder(captions)  # [B, T, D]
        cross_out, attn_weights = self.cross_attn(text_emb, patch_feat)  # [B, T, D], [B, T, P]

        img_proj = self.proj_image(cls_feat)        # [B, 512]
        text_proj = self.proj_text(text_emb[:, 0])  # [B, 512]

        return img_proj, text_proj, attn_weights, tokens

    def encode_for_inference(self, images, captions):
        with torch.no_grad():
            img_feat_all = self.vision_encoder(images)
            cls_feat = img_feat_all[:, 0]
            text_emb, _ = self.text_encoder(captions)
            img_proj = self.proj_image(cls_feat)
            text_proj = self.proj_text(text_emb[:, 0])
        return img_proj, text_proj

    def encode_tokenized_input(self, images, input_ids, attention_mask):
        with torch.no_grad():
            img_feat = self.vision_encoder(images)
            bert_out = self.text_encoder.bert(input_ids=input_ids, attention_mask=attention_mask)
            cls = bert_out.last_hidden_state[:, 0, :]        # CLS token
            return self.proj_image(img_feat), self.proj_text(cls)


In [6]:
# Consistency Loss 계산 함수
def compute_consistency_loss(attn_weights, masks, eps=1e-6):
    B, T, H, W = masks.shape
    masks_flat = masks.view(B, T, -1)
    scores = (attn_weights * masks_flat).sum(dim=-1)
    scores = torch.clamp(scores, min=eps, max=1.0)
    return -torch.log(scores).mean()

In [7]:
# CLIP Contrastive Loss

def clip_contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    logits = image_embeds @ text_embeds.T / temperature
    labels = torch.arange(len(image_embeds)).to(image_embeds.device)
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)
    return (loss_i2t + loss_t2i) / 2

In [8]:
# 전처리된 JSON 로딩 및 binary mask 생성
class CocoVLMDataset(Dataset):
    def __init__(self, json_path, image_root, transform=None, patch_size=16, max_tokens=10):
        with open(json_path, 'r') as f:
            all_data = json.load(f)
        self.image_root = image_root
        self.transform = transform
        self.patch_size = patch_size
        self.max_tokens = max_tokens
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        
        self.data = []
        for entry in all_data:
            image_id = entry["image_id"]
            image_path = os.path.join(self.image_root, f"{image_id:012d}.jpg")
            if os.path.exists(image_path):
                self.data.append(entry)

        print(f"유효 이미지 수: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        image_id = entry["image_id"]
        image_path = os.path.join(self.image_root, f"{image_id:012d}.jpg")
        try:
            image = Image.open(image_path).convert("RGB")
        except (FileNotFoundError, UnidentifiedImageError, OSError):
            # print(f"[WARN] 이미지 불러오기 실패: {image_path}")
            return self.__getitem__((idx + 1) % len(self))  # 다음 인덱스로 재시도
        if self.transform:
            image = self.transform(image)
        captions = entry["captions"]
        matches = entry["matches"][:self.max_tokens]
        caption = captions[0]
        H, W = 224 // self.patch_size, 224 // self.patch_size
        masks = torch.zeros((self.max_tokens, H, W))
        for i, match in enumerate(matches):
            x, y, w, h = match["bbox"]
            x1 = int(x // self.patch_size)
            y1 = int(y // self.patch_size)
            x2 = int((x + w) // self.patch_size)
            y2 = int((y + h) // self.patch_size)
            masks[i, y1:y2+1, x1:x2+1] = 1.0
        return image, caption, masks


In [9]:
# Collate Function

def coco_collate_fn(batch):
    images = torch.stack([item[0] for item in batch])
    captions = [item[1] for item in batch]
    masks = torch.stack([item[2] for item in batch])
    return images, captions, masks


In [10]:
# DataLoader 생성
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

dataset = CocoVLMDataset(
    json_path="coco_token_bbox_matched.json",
    image_root="/raid/kyscap251/team2/train2017/train2017",
    # image_root = "/shared/home/kyscap251/Team2/val2017",
    transform=transform
)

dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=coco_collate_fn)


유효 이미지 수: 109932


In [11]:
def train_model(model, dataloader, optimizer, device, lambda_cons=0.05, num_epochs=5, scheduler=None):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        total_acc = 0
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
        for images, captions, masks in progress:
            images, masks = images.to(device), masks.to(device)
            img_proj, txt_proj, attn_weights, _ = model(images, captions)

            T_mask = masks.shape[1]
            attn_weights_matched = attn_weights[:, :T_mask, :]

            loss_contrastive = clip_contrastive_loss(img_proj, txt_proj)
            loss_consistency = compute_consistency_loss(attn_weights_matched, masks)
            loss = loss_contrastive + lambda_cons * loss_consistency

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                sim_matrix = F.cosine_similarity(img_proj.unsqueeze(1), txt_proj.unsqueeze(0), dim=-1)
                sims = torch.diag(sim_matrix)
                sim_mean = sims.mean().item()
                sim_std = sims.std().item()

                pred = sim_matrix.argmax(dim=1)
                labels = torch.arange(sim_matrix.size(0)).to(device)
                acc = (pred == labels).float().mean().item()

            total_loss += loss.item()
            total_acc += acc
            progress.set_postfix({"loss": loss.item(), "cos_sim": f"{sim_mean:.3f}±{sim_std:.3f}", "acc": f"{acc:.3f}"})

        # scheduler 업데이트
        if scheduler is not None:
            scheduler.step()
            print(f"[Epoch {epoch+1}] LR: {scheduler.get_last_lr()[0]:.2e}")

        avg_loss = total_loss / len(dataloader)
        avg_acc = total_acc / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs} - Avg Loss: {avg_loss:.4f}, Avg Accuracy: {avg_acc:.4f}")


In [12]:
# 모델 학습 실행

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisionLanguageModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

train_model(model, dataloader, optimizer, device, lambda_cons=0.05, num_epochs=3)


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                                                                           

Epoch 1/3 - Avg Loss: 0.7939, Avg Accuracy: 0.9361


                                                                                                           

Epoch 2/3 - Avg Loss: 0.7532, Avg Accuracy: 0.9510


                                                                                                           

Epoch 3/3 - Avg Loss: 1.7996, Avg Accuracy: 0.3678




In [13]:
# 추론
def run_batch_inference(model, val_image_dir, caption_json_path, transform, device, sample_size=3):
    with open(caption_json_path, 'r') as f:
        coco_captions = json.load(f)

    # image_id → caption 매핑
    imgid2caption = {}
    for ann in coco_captions['annotations']:
        imgid = ann['image_id']
        if imgid not in imgid2caption:
            imgid2caption[imgid] = []
        imgid2caption[imgid].append(ann['caption'])

    # 랜덤 샘플링 (image_id 3개)
    img_ids = random.sample(list(imgid2caption.keys()), sample_size)
    captions = [imgid2caption[i][0] for i in img_ids]
    image_paths = [os.path.join(val_image_dir, f"{i:012d}.jpg") for i in img_ids]
    images_tensor = torch.stack([
        transform(Image.open(p).convert("RGB")) for p in image_paths
    ]).to(device)

    # 인코딩
    model.eval()
    image_embeds, text_embeds = model.encode_for_inference(images_tensor, captions)

    # 코사인 유사도
    sim_matrix = F.cosine_similarity(image_embeds.unsqueeze(1), text_embeds.unsqueeze(0), dim=-1)

    # 출력
    print("image_embeds shape:", image_embeds.shape)
    print("text_embeds shape :", text_embeds.shape)
    print("\n\U0001F4CA Cosine Similarity Matrix:\n")
    for i, img_id in enumerate(img_ids):
        print(f"\U0001F5BC️ {img_id:012d}.jpg")
        for j, cap in enumerate(captions):
            print(f"  \"{cap}\" → similarity: {sim_matrix[i, j]:.4f}")
        print()

In [14]:
run_batch_inference(
    model,
    val_image_dir="val2017",
    caption_json_path="annotations/captions_val2017.json",
    transform=transform,
    device=device,
    sample_size=3 #랜덤으로 이미지 세장
)

image_embeds shape: torch.Size([3, 512])
text_embeds shape : torch.Size([3, 512])

📊 Cosine Similarity Matrix:

🖼️ 000000402433.jpg
  "A pizza sitting on top of a pan in the oven." → similarity: 0.2596
  "Several plates of food are set on a table." → similarity: 0.2596
  "A bathroom with two urinals and a sink." → similarity: 0.2596

🖼️ 000000062554.jpg
  "A pizza sitting on top of a pan in the oven." → similarity: 0.2597
  "Several plates of food are set on a table." → similarity: 0.2597
  "A bathroom with two urinals and a sink." → similarity: 0.2597

🖼️ 000000201775.jpg
  "A pizza sitting on top of a pan in the oven." → similarity: 0.2596
  "Several plates of food are set on a table." → similarity: 0.2596
  "A bathroom with two urinals and a sink." → similarity: 0.2596

