In [2]:
import os

val = "val2017"
train_path = "/raid/kyscap251/team2/train2017/train2017"
val_path = "/shared/home/kyscap251/Team2/val2017"
# folder_path = "/raid/kyscap251/team2/val2017/val2017"
test = "train2017"

val_items = os.listdir(val_path)
train_items = os.listdir(train_path)

print("val", val_items[:5])
print("train", train_items[:5])

# 파일만 필터링
files = [f for f in os.listdir(val_path)
         if os.path.isfile(os.path.join(val_path, f))]

print(f"val파일 개수: {len(files)}")

filess = [ff for ff in os.listdir(train_path)
         if os.path.isfile(os.path.join(train_path, ff))]
print(f"train파일 개수: {len(filess)}")

val ['000000433103.jpg', '000000129113.jpg', '000000196843.jpg', '000000252507.jpg', '000000258541.jpg']
train ['000000254879.jpg', '000000316649.jpg', '000000430989.jpg', '000000286349.jpg', '000000458365.jpg']
val파일 개수: 5000
train파일 개수: 109932


In [1]:
# 환경 설정 및 라이브러리 로딩
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from transformers import BertTokenizer, BertModel
import json
import numpy as np
from PIL import Image
import os
from torch.utils.data import Dataset, DataLoader
import random
from tqdm import tqdm
from PIL import Image, UnidentifiedImageError, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# # 이미지 인코더: ViT
# class VisionEncoder(nn.Module):
#     def __init__(self, output_dim=768):
#         super().__init__()
#         self.vit = models.vit_b_16(pretrained=True)
#         self.vit.heads = nn.Identity()  # classification head 제거
#         self.output_dim = output_dim

#     def forward(self, images):
#         patch_feats = self.vit(images)  # [B, D]
#         return patch_feats
    

from transformers import BertTokenizer, BertModel, ViTModel, ViTFeatureExtractor
from torchvision.models.vision_transformer import vit_b_16, ViT_B_16_Weights


class VisionEncoder(nn.Module):
    def __init__(self, model_name='google/vit-base-patch16-224'):
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)

    def forward(self, images):
        outputs = self.vit(pixel_values=images)
        return outputs.last_hidden_state  # [B, 1+P, D]



In [3]:
# 텍스트 인코더: BERT 기반 Transformer
class TextEncoder(nn.Module):
    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.bert = BertModel.from_pretrained(model_name)

    def forward(self, captions):
        tokenized = self.tokenizer(captions, return_tensors="pt", padding=True, truncation=True).to(self.bert.device)
        outputs = self.bert(**tokenized)
        return outputs.last_hidden_state, tokenized
    

In [4]:
# Cross-Attention Block
class CrossAttentionBlock(nn.Module):
    def __init__(self, dim=768, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)

    def forward(self, text_emb, image_patches):
        attn_output, attn_weights = self.attn(text_emb, image_patches, image_patches)
        return attn_output, attn_weights

In [5]:
# # 모델
# class VisionLanguageModel(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.vision_encoder = VisionEncoder()
#         self.text_encoder = TextEncoder()
#         self.cross_attn = CrossAttentionBlock()
#         self.proj_image = nn.Linear(768, 512)
#         self.proj_text = nn.Linear(768, 512)

#     def forward(self, images, captions):
#         img_feat = self.vision_encoder(images)
#         text_emb, tokens = self.text_encoder(captions)
#         patch_feat = img_feat.unsqueeze(1)  # dummy patch feature (B, 1, D)
#         cross_out, attn_weights = self.cross_attn(text_emb, patch_feat)
#         img_proj = self.proj_image(img_feat)
#         text_proj = self.proj_text(text_emb[:, 0])  # CLS token 기준
#         return img_proj, text_proj, attn_weights, tokens

#     def encode_for_inference(self, images, captions):
#         with torch.no_grad():
#             img_feat = self.vision_encoder(images)
#             text_emb, _ = self.text_encoder(captions)
#             img_proj = self.proj_image(img_feat)
#             text_proj = self.proj_text(text_emb[:, 0])  # CLS token 기준
#         return img_proj, text_proj
    
# -------------------------------
# Vision-Language Model
# -------------------------------
class VisionLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_encoder = VisionEncoder()
        self.text_encoder = TextEncoder()
        self.cross_attn = CrossAttentionBlock()
        self.proj_image = nn.Linear(768, 512)
        self.proj_text = nn.Linear(768, 512)

    def forward(self, images, captions):
        img_feat_all = self.vision_encoder(images)  # [B, 1+P, D]
        cls_feat = img_feat_all[:, 0]               # [B, D] CLS
        patch_feat = img_feat_all[:, 1:]            # [B, P, D] patches only

        text_emb, tokens = self.text_encoder(captions)  # [B, T, D]
        cross_out, attn_weights = self.cross_attn(text_emb, patch_feat)  # [B, T, D], [B, T, P]

        img_proj = self.proj_image(cls_feat)        # [B, 512]
        text_proj = self.proj_text(text_emb[:, 0])  # [B, 512]

        return img_proj, text_proj, attn_weights, tokens

    def encode_for_inference(self, images, captions): #랜덤 3개 유사도 평가용
        with torch.no_grad():
            img_feat_all = self.vision_encoder(images)
            cls_feat = img_feat_all[:, 0]
            text_emb, _ = self.text_encoder(captions)
            img_proj = self.proj_image(cls_feat)
            text_proj = self.proj_text(text_emb[:, 0])
        return img_proj, text_proj
    
    def encode_tokenized_input(self, images, input_ids, attention_mask): #평균 유사도 평가용
        with torch.no_grad():
            img_feat_all = self.vision_encoder(images)        # (B, 197, 768)
            cls_feat = img_feat_all[:, 0, :]                  # CLS token만 추출 (B, 768)
            img_proj = self.proj_image(cls_feat)              # (B, 512)

            bert_out = self.text_encoder.bert(input_ids=input_ids, attention_mask=attention_mask)
            txt_cls = bert_out.last_hidden_state[:, 0, :]     # (B, 768)
            txt_proj = self.proj_text(txt_cls)                # (B, 512)

            return img_proj, txt_proj


In [6]:
# Consistency Loss 계산 함수
def compute_consistency_loss(attn_weights, masks, eps=1e-6):
    B, T, H, W = masks.shape
    masks_flat = masks.view(B, T, -1)
    scores = (attn_weights * masks_flat).sum(dim=-1)
    scores = torch.clamp(scores, min=eps, max=1.0)
    return -torch.log(scores).mean()

In [7]:
# CLIP Contrastive Loss

def clip_contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    logits = image_embeds @ text_embeds.T / temperature
    labels = torch.arange(len(image_embeds)).to(image_embeds.device)
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)
    return (loss_i2t + loss_t2i) / 2

In [8]:
# 전처리된 JSON 로딩 및 binary mask 생성
class CocoVLMDataset(Dataset):
    def __init__(self, json_path, image_root, transform=None, patch_size=16, max_tokens=10):
        with open(json_path, 'r') as f:
            all_data = json.load(f)
        self.image_root = image_root
        self.transform = transform
        self.patch_size = patch_size
        self.max_tokens = max_tokens
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        
        self.data = []
        for entry in all_data:
            image_id = entry["image_id"]
            image_path = os.path.join(self.image_root, f"{image_id:012d}.jpg")
            if os.path.exists(image_path):
                self.data.append(entry)

        print(f"유효 이미지 수: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        image_id = entry["image_id"]
        image_path = os.path.join(self.image_root, f"{image_id:012d}.jpg")
        try:
            image = Image.open(image_path).convert("RGB")
        except (FileNotFoundError, UnidentifiedImageError, OSError):
#             print(f"[WARN] 이미지 불러오기 실패: {image_path}")
            return self.__getitem__((idx + 1) % len(self))  # 다음 인덱스로 재시도
        if self.transform:
            image = self.transform(image)
        captions = entry["captions"]
        matches = entry["matches"][:self.max_tokens]
        caption = captions[0]
        H, W = 224 // self.patch_size, 224 // self.patch_size
        masks = torch.zeros((self.max_tokens, H, W))
        for i, match in enumerate(matches):
            x, y, w, h = match["bbox"]
            x1 = int(x // self.patch_size)
            y1 = int(y // self.patch_size)
            x2 = int((x + w) // self.patch_size)
            y2 = int((y + h) // self.patch_size)
            masks[i, y1:y2+1, x1:x2+1] = 1.0
        return image, caption, masks


In [9]:
# Collate Function

def coco_collate_fn(batch):
    images = torch.stack([item[0] for item in batch])
    captions = [item[1] for item in batch]
    masks = torch.stack([item[2] for item in batch])
    return images, captions, masks


In [10]:
# DataLoader 생성
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

dataset = CocoVLMDataset(
    json_path="coco_token_bbox_matched.json",
    image_root="/raid/kyscap251/team2/train2017/train2017",
#     image_root = "/shared/home/kyscap251/Team2/val2017",
    transform=transform
)

dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=coco_collate_fn)


유효 이미지 수: 109932


In [11]:
class LambdaScheduler: #람다 점진적 높이기
    def __init__(self, alpha=0.2, threshold=0.01, anneal_rate=0.01, max_lambda=0.5):
        self.ema = None
        self.alpha = alpha
        self.threshold = threshold
        self.anneal_rate = anneal_rate
        self.max_lambda = max_lambda
        self.apply = False
        self.step = 0
        self.current_lambda = 0.0

    def update(self, attn_weights, return_diff=False):
        """
        attn_weights: Tensor of shape [B, T, P] (cross-attn weights)
        """
        attn_mean = attn_weights.mean(dim=1).mean(dim=0)  # [P]
        if self.ema is None:
            self.ema = attn_mean
        else:
            self.ema = self.alpha * attn_mean + (1 - self.alpha) * self.ema

        diff = torch.abs(self.ema - attn_mean).mean().item()

        if not self.apply and diff < self.threshold:
            self.apply = True
            print(f"[LambdaScheduler] Consistency loss ON (diff={diff:.6f})")

        if self.apply:
            self.current_lambda = min(self.max_lambda, self.anneal_rate * self.step)
            self.step += 1
        else:
            self.current_lambda = 0.0

        if return_diff:
            return self.current_lambda, diff
        else:
            return self.current_lambda

In [12]:
def kl_divergence_attention(prev_attn, curr_attn, eps=1e-6):
    """
    prev_attn, curr_attn: [B, T, P] - softmax된 attention map
    """
    prev = torch.clamp(prev_attn, min=eps)
    curr = torch.clamp(curr_attn, min=eps)
    kl = (prev * (prev.log() - curr.log())).sum(dim=-1).mean()  # 평균 over T, B
    return kl

In [13]:
class EMA:
    def __init__(self, alpha=0.2):
        self.alpha = alpha
        self.ema = None

    def update(self, value):
        with torch.no_grad():
            if self.ema is None:
                self.ema = value.detach()
            else:
                self.ema = self.alpha * value.detach() + (1 - self.alpha) * self.ema
        return self.ema

    def reset(self):
        self.ema = None


In [14]:
def train_model(model, dataloader, optimizer, device,
                lambda_cons=0.1,
                lambda_self=0.1,
                warmup_epochs=3,
                num_epochs=5,
                start_epoch=0,
                prev_attn_dict=None,
                ema_tracker=None):

    model.train()

    # EMA 초기화
    if ema_tracker is None:
        ema_tracker = EMA(alpha=0.2)
    use_ema = False

    # Lambda scheduler는 warm-up 동안만 사용
    lambda_scheduler = LambdaScheduler(alpha=0.2, threshold=0.01, anneal_rate=0.01, max_lambda=0.5)

    for epoch in range(start_epoch, start_epoch + num_epochs):
        total_loss, total_acc = 0, 0
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{start_epoch + num_epochs}", leave=False)

        if epoch == warmup_epochs:
            ema_tracker.reset()
            use_ema = True
            print(f"[Epoch {epoch}] EMA 기준 self-consistency 시작")

        for step, (images, captions, masks) in enumerate(progress):
            images, masks = images.to(device), masks.to(device)
            img_proj, txt_proj, attn_weights, _ = model(images, captions)

            T_mask = masks.shape[1]
            attn_weights_matched = attn_weights[:, :T_mask, :]

            # ▶ Contrastive Loss (공통)
            loss_contrastive = clip_contrastive_loss(img_proj, txt_proj)

            # ▶ Warm-up Phase
            if epoch < warmup_epochs:
                lambda_val, diff = lambda_scheduler.update(attn_weights_matched, return_diff=True)
                loss_consistency = compute_consistency_loss(attn_weights_matched, masks)
                loss = loss_contrastive + lambda_val * loss_consistency

            # ▶ Self-consistency Phase
            else:
                ema_attn = ema_tracker.update(attn_weights_matched)
                loss_self_cons = kl_divergence_attention(ema_attn, attn_weights_matched.detach())
                loss = loss_contrastive + lambda_self * loss_self_cons

            # ▶ Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # ▶ Accuracy
            with torch.no_grad():
                sim_matrix = F.cosine_similarity(img_proj.unsqueeze(1), txt_proj.unsqueeze(0), dim=-1)
                pred = sim_matrix.argmax(dim=1)
                labels = torch.arange(sim_matrix.size(0)).to(device)
                acc = (pred == labels).float().mean().item()

            total_loss += loss.item()
            total_acc += acc

            if step % 50 == 0 or step == len(dataloader) - 1:
                progress.set_postfix({
                    "loss": f"{loss.item():.4f}",
                    "acc": f"{acc:.3f}",
                    "λ": f"{lambda_cons:.3f}" if epoch < warmup_epochs else f"{lambda_self:.3f}"
                })

        # ▶ Save checkpoint after warm-up phase
        if epoch + 1 == warmup_epochs:
            torch.save({
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'attn_dict': attn_weights_matched.detach().clone()
            }, f'checkpoint_epoch{epoch+1}.pth')
            print(f"[Checkpoint] Saved at epoch {epoch+1}")

        print(f"Epoch {epoch+1} - Avg Loss: {total_loss / len(dataloader):.4f}, Avg Acc: {total_acc / len(dataloader):.4f}")

    return ema_tracker

In [15]:
# 디바이스 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 모델과 옵티마이저 초기화
model = VisionLanguageModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Warm-up 전용 학습 (예: 2 epoch 동안)
train_model(
    model, dataloader, optimizer, device,
    lambda_cons=0.05,         # warm-up용 consistency loss 계수
    lambda_self=0.0,          # 후기용 self-consistency는 사용 안 함
    warmup_epochs=2,
    num_epochs=2              # warm-up만 돌리므로 일치시킴
)


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/2:   0%| | 1/27483 [00:01<10:06:19,  1.32s/it, loss=1.3262, acc=0.500, λ

[LambdaScheduler] Consistency loss ON (diff=0.000000)


                                                                                

Epoch 1 - Avg Loss: 6.4150, Avg Acc: 0.9363


                                                                                

[Checkpoint] Saved at epoch 2
Epoch 2 - Avg Loss: 6.8522, Avg Acc: 0.6867


<__main__.EMA at 0x7f41086cd9a0>

In [19]:
checkpoint = torch.load('checkpoint_epoch2.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
prev_attn_map = checkpoint['attn_dict']

train_model(
    model, dataloader, optimizer, device,
    lambda_cons=0.0,               # warm-up 끝났으므로 사용 안 함
    lambda_self=0.1,
    warmup_epochs=0,
    num_epochs=5,                  # 후속 학습
    start_epoch=2,                 # 이어서
    prev_attn_dict=prev_attn_map   # 필요시 활용
)

  checkpoint = torch.load('checkpoint_epoch2.pth')
                                                                                

Epoch 3 - Avg Loss: 1.4403, Avg Acc: 0.2495


                                                                                

Epoch 4 - Avg Loss: 1.3967, Avg Acc: 0.2507


                                                                                

Epoch 5 - Avg Loss: 1.3962, Avg Acc: 0.2501


Epoch 6/7:  80%|▊| 21858/27483 [2:55:26<43:56,  2.13it/s, loss=1.3868, acc=0.250IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Epoch 7/7:  25%|▎| 6958/27483 [56:42<2:37:58,  2.17it/s, loss=1.3884, acc=0.250,IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

                                                                                

Epoch 7 - Avg Loss: 1.3881, Avg Acc: 0.2494




<__main__.EMA at 0x7f4111cd2120>

In [12]:
# 전체 Epoch 학습 루프 (Matched Token만 사용하도록 ㅏ수정)
def train_model(model, dataloader, optimizer, device, lambda_cons=0.05, num_epochs=5):
    model.train()
#     ////////////
    lambda_scheduler = LambdaScheduler(alpha=0.2, threshold=0.01, anneal_rate=0.01, max_lambda=0.5)
#     ////////////////////
#     lambda_log = [] #λ 변화 기록용 리스트
#     diff_log = [] #diff값 기록용 리스트
    
    for epoch in range(num_epochs):
        total_loss = 0
        total_acc = 0
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
        
        for images, captions, masks in progress:
            images, masks = images.to(device), masks.to(device)
            img_proj, txt_proj, attn_weights, _ = model(images, captions)

            # matched token 수만큼 attention weight 슬라이싱
            T_mask = masks.shape[1]
            attn_weights_matched = attn_weights[:, :T_mask, :] #bbox에 해당하는 token만 consistency loss 계산에 사용

            
#             //////////////
            # 현재 λ 값을 EMA 기반으로 계산 및 기록
            lambda_cons, diff = lambda_scheduler.update(attn_weights_matched, return_diff=True)
#             lambda_log.append(lambda_cons)
#             diff_log.append(diff)
# /////////////////////////////////////
            
            # 손실 계산
            loss_contrastive = clip_contrastive_loss(img_proj, txt_proj)
            loss_consistency = compute_consistency_loss(attn_weights_matched, masks)
            loss = loss_contrastive + lambda_cons * loss_consistency

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # cosine similarity + accuracy 계산
            with torch.no_grad():
                sim_matrix = F.cosine_similarity(img_proj.unsqueeze(1), txt_proj.unsqueeze(0), dim=-1)
                sims = torch.diag(sim_matrix)
                sim_mean = sims.mean().item()
                sim_std = sims.std().item()

                pred = sim_matrix.argmax(dim=1)
                labels = torch.arange(sim_matrix.size(0)).to(device)
                acc = (pred == labels).float().mean().item()

            total_loss += loss.item()
            total_acc += acc
            progress.set_postfix({"loss": loss.item(), "cos_sim": f"{sim_mean:.3f}±{sim_std:.3f}", "acc": f"{acc:.3f}"})

#         # scheduler 업데이트
#         if scheduler is not None:
#             scheduler.step()
#             print(f"[Epoch {epoch+1}] LR: {scheduler.get_last_lr()[0]:.2e}")
        avg_loss = total_loss / len(dataloader)
        avg_acc = total_acc / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs} - Avg Loss: {avg_loss:.4f}, Avg Accuracy: {avg_acc:.4f}")


In [13]:
# # 모델 학습 실행

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = VisionLanguageModel().to(device)
# optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# train_model(model, dataloader, optimizer, device, lambda_cons=0.05, num_epochs=3)


config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                                                                 

Epoch 1/3 - Avg Loss: 0.9415, Avg Accuracy: 0.8782


                                                                                                 

Epoch 2/3 - Avg Loss: 0.7531, Avg Accuracy: 0.9510


                                                                                                 

Epoch 3/3 - Avg Loss: 0.7454, Avg Accuracy: 0.9536




In [1]:
# 랜덤 3개 추론
def run_batch_inference(model, val_image_dir, caption_json_path, transform, device, sample_size=3):
    with open(caption_json_path, 'r') as f:
        coco_captions = json.load(f)

    # image_id → caption 매핑
    imgid2caption = {}
    for ann in coco_captions['annotations']:
        imgid = ann['image_id']
        if imgid not in imgid2caption:
            imgid2caption[imgid] = []
        imgid2caption[imgid].append(ann['caption'])

    # 랜덤 샘플링 (image_id 3개)
    img_ids = random.sample(list(imgid2caption.keys()), sample_size)
    captions = [imgid2caption[i][0] for i in img_ids]
    image_paths = [os.path.join(val_image_dir, f"{i:012d}.jpg") for i in img_ids]
    images_tensor = torch.stack([
        transform(Image.open(p).convert("RGB")) for p in image_paths
    ]).to(device)

    # 인코딩
    model.eval()
    image_embeds, text_embeds = model.encode_for_inference(images_tensor, captions)

    # 코사인 유사도
    sim_matrix = F.cosine_similarity(image_embeds.unsqueeze(1), text_embeds.unsqueeze(0), dim=-1)

    # 출력
    print("image_embeds shape:", image_embeds.shape)
    print("text_embeds shape :", text_embeds.shape)
    print("\n\U0001F4CA Cosine Similarity Matrix:\n")
    for i, img_id in enumerate(img_ids):
        print(f"\U0001F5BC️ {img_id:012d}.jpg")
        for j, cap in enumerate(captions):
            print(f"  \"{cap}\" → similarity: {sim_matrix[i, j]:.4f}")
        print()

In [None]:
run_batch_inference(
    model,
    val_image_dir="val2017",
    caption_json_path="annotations/captions_val2017.json",
    transform=transform,
    device=device,
    sample_size=3 #랜덤으로 이미지 세장
)

In [21]:
from torch.utils.data import Dataset
from PIL import Image
import os
import json

class SimpleCocoCaptionDataset(Dataset):
    def __init__(self, caption_json_path, image_root, transform, tokenizer):
        with open(caption_json_path, 'r') as f:
            data = json.load(f)

        self.imgid2caption = {}
        for ann in data['annotations']:
            img_id = ann['image_id']
            if img_id not in self.imgid2caption:
                self.imgid2caption[img_id] = ann['caption']

        self.image_ids = list(self.imgid2caption.keys())
        self.image_root = image_root
        self.transform = transform
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        img_path = os.path.join(self.image_root, f"{img_id:012d}.jpg")
        image = self.transform(Image.open(img_path).convert("RGB"))
        caption = self.imgid2caption[img_id]

        encoding = self.tokenizer(caption, return_tensors="pt", padding="max_length", truncation=True, max_length=32)
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        return image, input_ids, attention_mask


In [22]:
from torch.utils.data import DataLoader

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

val_dataset = SimpleCocoCaptionDataset(
    caption_json_path="annotations/captions_val2017.json",
    image_root="val2017",
    transform=transform,
    tokenizer=tokenizer
)

val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


In [23]:
# 정답,오답쌍 평균 유사도 계산 추론
def evaluate_mean_similarity(model, val_loader, device):
    model.eval()
    total_correct_sim = 0.0
    total_incorrect_sim = 0.0
    correct_count = 0
    incorrect_count = 0

    with torch.no_grad():
        for images, input_ids, attention_mask in tqdm(val_loader, desc="Evaluating"):
            images = images.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            image_feat, text_feat = model.encode_tokenized_input(images, input_ids, attention_mask)  # (B, 512)

            sim_matrix = F.cosine_similarity(
                image_feat.unsqueeze(1),  # (B, 1, D)
                text_feat.unsqueeze(0),  # (1, B, D)
                dim=-1
            )  # (B, B)

            B = sim_matrix.size(0)
            correct_sims = sim_matrix.diag()
            total_correct_sim += correct_sims.sum().item()
            correct_count += B

            mask = ~torch.eye(B, dtype=torch.bool, device=device)
            incorrect_sims = sim_matrix[mask]
            total_incorrect_sim += incorrect_sims.sum().item()
            incorrect_count += incorrect_sims.numel()

    mean_correct = total_correct_sim / correct_count
    mean_incorrect = total_incorrect_sim / incorrect_count

    print(f"\nEvaluation Results:")
    print(f" - Mean Correct Sim    : {mean_correct:.4f}")
    print(f" - Mean Incorrect Sim  : {mean_incorrect:.4f}")
    return mean_correct, mean_incorrect


In [24]:
mean_correct, mean_incorrect = evaluate_mean_similarity(model, val_loader, device)

Evaluating: 100%|█████████████████████████████| 157/157 [01:19<00:00,  1.97it/s]


Evaluation Results:
 - Mean Correct Sim    : -0.0597
 - Mean Incorrect Sim  : -0.0597





In [13]:
print("후기단계 시작")

후기단계 시작


In [22]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisionLanguageModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)


train_model(
    model, dataloader, optimizer, device,
    lambda_cons=0.05,
    lambda_self=0.1,
    warmup_epochs=2,
    num_epochs=2
)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                                                                           

Epoch 1 - Avg Loss: 0.9405, Avg Accuracy: 0.8812


                                                                                                           

Epoch 2 - Avg Loss: 0.7593, Avg Accuracy: 0.9466


                                                                                                           

Epoch 3 - Avg Loss: 0.7222, Avg Accuracy: 0.9638


In [24]:
checkpoint = torch.load('checkpoint_epoch3.pth')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
prev_attn_map = checkpoint['attn_map']

train_model(
    model, dataloader, optimizer, device,
    lambda_cons=0.0,
    lambda_self=0.1,
    warmup_epochs=0,
    num_epochs=7,
    initial_attn_map=prev_attn_map,
    start_epoch=3
)

                                                                                                           

Epoch 4 - Avg Loss: 0.1864, Avg Accuracy: 0.9576


                                                                                                           

Epoch 5 - Avg Loss: 0.1822, Avg Accuracy: 0.9608


                                                                                                           

Epoch 6 - Avg Loss: 0.1588, Avg Accuracy: 0.9644


                                                                                                           

Epoch 7 - Avg Loss: 0.1380, Avg Accuracy: 0.9682


                                                                                                           

Epoch 8 - Avg Loss: 0.1259, Avg Accuracy: 0.9700


                                                                                                           

Epoch 9 - Avg Loss: 0.1208, Avg Accuracy: 0.9704


                                                                                                           

Epoch 10 - Avg Loss: 0.1105, Avg Accuracy: 0.9768


