In [None]:
# 정확도 상승 기반 warmup 종료시점 판단 적용(EMA랑 다름)
# warm up, self-supervised분리
# diff가 크면 alpha를 커지는 식의 EMA 적용

In [2]:
# 환경 설정 및 라이브러리 로딩
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from transformers import BertTokenizer, BertModel
import json
import numpy as np
from PIL import Image
import os
from torch.utils.data import Dataset, DataLoader
import random
from sklearn.model_selection import StratifiedShuffleSplit
from tqdm import tqdm
from PIL import Image, UnidentifiedImageError, ImageFile
from collections import Counter
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [3]:
from transformers import BertTokenizer, BertModel, ViTModel, ViTFeatureExtractor
from torchvision.models.vision_transformer import vit_b_16, ViT_B_16_Weights

# 이미지 인코더: Vision Transformer 기반

class VisionEncoder(nn.Module):
    def __init__(self, model_name='google/vit-base-patch16-224'):
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)

    def forward(self, images):
        outputs = self.vit(pixel_values=images)
        return outputs.last_hidden_state  # [B, 1+P, D]



In [4]:
# 텍스트 인코더: BERT 기반 Transformer
class TextEncoder(nn.Module):
    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.bert = BertModel.from_pretrained(model_name)

    def forward(self, captions):
        tokenized = self.tokenizer(captions, return_tensors="pt", padding=True, truncation=True).to(self.bert.device)
        outputs = self.bert(**tokenized)
        return outputs.last_hidden_state, tokenized
    

In [5]:
# Cross-Attention Block
class CrossAttentionBlock(nn.Module):
    def __init__(self, dim=768, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)

    def forward(self, text_emb, image_patches):
        attn_output, attn_weights = self.attn(text_emb, image_patches, image_patches)
        return attn_output, attn_weights

In [6]:
# Vision-Language Model
class VisionLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_encoder = VisionEncoder()
        self.text_encoder = TextEncoder()
        self.cross_attn = CrossAttentionBlock()
        self.proj_image = nn.Linear(768, 512)
        self.proj_text = nn.Linear(768, 512)

    def forward(self, images, captions):
        img_feat_all = self.vision_encoder(images)  # [B, 1+P, D]
        cls_feat = img_feat_all[:, 0]               # [B, D] CLS
        patch_feat = img_feat_all[:, 1:]            # [B, P, D] patches only

        text_emb, tokens = self.text_encoder(captions)  # [B, T, D]
        cross_out, attn_weights = self.cross_attn(text_emb, patch_feat)  # [B, T, D], [B, T, P]

        img_proj = self.proj_image(cls_feat)        # [B, 512]
        text_proj = self.proj_text(text_emb[:, 0])  # [B, 512]

        return img_proj, text_proj, attn_weights, tokens

    def encode_for_inference(self, images, captions): #랜덤 3개 유사도 평가용
        with torch.no_grad():
            img_feat_all = self.vision_encoder(images)
            cls_feat = img_feat_all[:, 0]
            text_emb, _ = self.text_encoder(captions)
            img_proj = self.proj_image(cls_feat)
            text_proj = self.proj_text(text_emb[:, 0])
        return img_proj, text_proj
    
    def encode_tokenized_input(self, images, input_ids, attention_mask): #평균 유사도 평가용
        with torch.no_grad():
            img_feat_all = self.vision_encoder(images)        # (B, 197, 768)
            cls_feat = img_feat_all[:, 0, :]                  # CLS token만 추출 (B, 768)
            img_proj = self.proj_image(cls_feat)              # (B, 512)

            bert_out = self.text_encoder.bert(input_ids=input_ids, attention_mask=attention_mask)
            txt_cls = bert_out.last_hidden_state[:, 0, :]     # (B, 768)
            txt_proj = self.proj_text(txt_cls)                # (B, 512)

            return img_proj, txt_proj

In [7]:
# Consistency Loss 계산 함수
def compute_consistency_loss(attn_weights, masks, eps=1e-6):
    B, T, H, W = masks.shape
    masks_flat = masks.view(B, T, -1)
    scores = (attn_weights * masks_flat).sum(dim=-1)
    scores = torch.clamp(scores, min=eps, max=1.0)
    return -torch.log(scores).mean()

In [8]:
# CLIP Contrastive Loss

def clip_contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    logits = image_embeds @ text_embeds.T / temperature
    labels = torch.arange(len(image_embeds)).to(image_embeds.device)
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)
    return (loss_i2t + loss_t2i) / 2

In [9]:
import os
import json
from collections import Counter
from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)

json_path = "coco_token_bbox_matched.json"
image_root = "/raid/kyscap251/team2/train2017/train2017"

# 1. Load JSON once
with open(json_path, "r") as f:
    all_data = json.load(f)

# 2. 유효 이미지와 대표 label 추출
valid_indices = []
labels = []

def get_dominant_label(matches):
    if not matches:
        return "none"
    return Counter([m["label"] for m in matches]).most_common(1)[0][0]

for i, entry in enumerate(all_data):
    image_id = entry["image_id"]
    image_path = os.path.join(image_root, f"{image_id:012d}.jpg")
    if os.path.exists(image_path):
        valid_indices.append(i)
        labels.append(get_dominant_label(entry["matches"]))

# 3. stratified split
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=SEED)
subset_rel_idx, _ = next(splitter.split(np.zeros(len(labels)), labels))

# 4. 전체 JSON 기준으로 실제 subset 인덱스로 변환
subset_json_indices = [valid_indices[i] for i in subset_rel_idx]


In [10]:
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

json_path = "coco_token_bbox_matched.json"
image_root = "/raid/kyscap251/team2/train2017/train2017"

# 1. Load JSON once
with open(json_path, "r") as f:
    all_data = json.load(f)

# 2. 유효 이미지와 대표 label 추출
valid_indices = []
labels = []

def get_dominant_label(matches):
    if not matches:
        return "none"
    return Counter([m["label"] for m in matches]).most_common(1)[0][0]

for i, entry in enumerate(all_data):
    image_id = entry["image_id"]
    image_path = os.path.join(image_root, f"{image_id:012d}.jpg")
    if os.path.exists(image_path):
        valid_indices.append(i)
        labels.append(get_dominant_label(entry["matches"]))

# 3. stratified split
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=SEED)
subset_rel_idx, _ = next(splitter.split(np.zeros(len(labels)), labels))

# 4. 전체 JSON 기준으로 실제 subset 인덱스로 변환
subset_json_indices = [valid_indices[i] for i in subset_rel_idx]

In [11]:
class CocoVLMDataset(Dataset):
    def __init__(self, json_path, image_root, transform=None, patch_size=16, max_tokens=10, subset_indices=None):
        with open(json_path, 'r') as f:
            all_data = json.load(f)
        self.image_root = image_root
        self.transform = transform
        self.patch_size = patch_size
        self.max_tokens = max_tokens
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        self.data = []
        selected = subset_indices if subset_indices is not None else range(len(all_data))

        for i in selected:
            entry = all_data[i]
            image_id = entry["image_id"]
            image_path = os.path.join(self.image_root, f"{image_id:012d}.jpg")
            if os.path.exists(image_path):
                self.data.append(entry)

        print(f"유효 이미지 수: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        image_id = entry["image_id"]
        image_path = os.path.join(self.image_root, f"{image_id:012d}.jpg")

        try:
            image = Image.open(image_path).convert("RGB")
        except (FileNotFoundError, UnidentifiedImageError, OSError):
            # print(f"[WARN] 이미지 불러오기 실패: {image_path}")
            return self.__getitem__((idx + 1) % len(self))  # 다음 인덱스로 재시도

        if self.transform:
            image_tensor = self.transform(image)
        else:
            image_tensor = image

        captions = entry["captions"]
        matches = entry["matches"][:self.max_tokens]
        caption = captions[0]

        # Tokenize caption
        caption_tokens = self.tokenizer(caption, return_tensors="pt", padding="max_length",
                                        truncation=True, max_length=30)["input_ids"].squeeze(0)

        # Binary mask 생성
        H, W = 224 // self.patch_size, 224 // self.patch_size
        mask_tensor = torch.zeros((self.max_tokens, H, W))

        for i, match in enumerate(matches):
            x, y, w, h = match["bbox"]
            x1 = int(x // self.patch_size)
            y1 = int(y // self.patch_size)
            x2 = int((x + w) // self.patch_size)
            y2 = int((y + h) // self.patch_size)
            mask_tensor[i, y1:y2+1, x1:x2+1] = 1.0

        return {
            'image': image_tensor,
            'caption': caption,
            'mask': mask_tensor,
            'image_id': image_id
        }

In [12]:
# Collate Function

def coco_collate_fn(batch):
    images = torch.stack([item['image'] for item in batch])
    captions = [item['caption'] for item in batch]  # ← 텐서 아님, 그냥 리스트로
    masks = torch.stack([item['mask'] for item in batch])
    image_ids = [item['image_id'] for item in batch]
    return images, captions, masks, image_ids


In [13]:
from torch.utils.data import Subset

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

dataset = CocoVLMDataset(
    json_path=json_path,
    image_root=image_root,
    transform=transform,
    subset_indices=subset_json_indices
)

dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=coco_collate_fn
)

유효 이미지 수: 54966


In [14]:
class LambdaScheduler: #람다 점진적 높이기
    def __init__(self, alpha=0.2, threshold=0.01, anneal_rate=0.01, max_lambda=0.5):
        self.ema = None
        self.alpha = alpha
        self.threshold = threshold
        self.anneal_rate = anneal_rate
        self.max_lambda = max_lambda
        self.apply = False
        self.step = 0
        self.current_lambda = 0.0

    def update(self, attn_weights, return_diff=False):
        """
        attn_weights: Tensor of shape [B, T, P] (cross-attn weights)
        """
        attn_mean = attn_weights.mean(dim=1).mean(dim=0)  # [P]
        if self.ema is None:
            self.ema = attn_mean
        else:
            self.ema = self.alpha * attn_mean + (1 - self.alpha) * self.ema

        diff = torch.abs(self.ema - attn_mean).mean().item()

        if not self.apply and diff < self.threshold:
            self.apply = True
            print(f"[LambdaScheduler] Consistency loss ON (diff={diff:.6f})")

        if self.apply:
            self.current_lambda = min(self.max_lambda, self.anneal_rate * self.step)
            self.step += 1
        else:
            self.current_lambda = 0.0

        if return_diff:
            return self.current_lambda, diff
        else:
            return self.current_lambda

In [15]:
def kl_divergence_attention(prev_attn, curr_attn, eps=1e-6):
    """
    prev_attn, curr_attn: [B, T, P] - softmax된 attention map
    """
    prev = torch.clamp(prev_attn, min=eps)
    curr = torch.clamp(curr_attn, min=eps)
    kl = (prev * (prev.log() - curr.log())).sum(dim=-1).mean()  # 평균 over T, B
    return kl

In [16]:
class EMA:
    def __init__(self, alpha=0.2):
        self.alpha = alpha
        self.ema = None

    def update(self, value):
        with torch.no_grad():
            if self.ema is None:
                self.ema = value.detach()
            else:
                self.ema = self.alpha * value.detach() + (1 - self.alpha) * self.ema
        return self.ema

    def reset(self):
        self.ema = None


In [17]:
def update_ema(ema_model, model, alpha):
    with torch.no_grad():
        for ema_param, model_param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(1 - alpha).add_(model_param.data, alpha=alpha)

In [18]:
import shutil

def train_warmup(model, dataloader, optimizer, device,
                 lambda_cons=0.5, warmup_epochs=10, start_epoch=0):
    model.train()
    bbox_scheduler = LambdaScheduler(alpha=0.2, threshold=0.01, anneal_rate=0.005, max_lambda=lambda_cons)

    acc_history = []
    prev_acc = None
    drop_count = 0
    high_point_epoch = None

    for epoch in range(start_epoch, start_epoch + warmup_epochs):
        total_loss, total_acc = 0.0, 0.0
        num_batches = 0
        curr_attn_dict = {}

        progress = tqdm(dataloader, desc=f"Warm-up Epoch {epoch}", leave=False)
        for batch in progress:
            images, captions, masks, image_ids = batch
            images, masks = images.to(device), masks.to(device)

            img_proj, txt_proj, attn_weights, _ = model(images, captions)

            T_attn = attn_weights.size(1)
            T_mask = masks.size(1)
            T_common = min(T_attn, T_mask)

            attn_weights_slice = attn_weights[:, :T_common, :]
            masks_slice = masks[:, :T_common]

            loss_contrastive = clip_contrastive_loss(img_proj, txt_proj)
            lambda_bbox, _ = bbox_scheduler.update(attn_weights_slice, return_diff=True)
            loss_bbox_cons = compute_consistency_loss(attn_weights_slice, masks_slice)
            loss_bbox_cons = torch.clamp(loss_bbox_cons, max=1.0)
            loss = loss_contrastive + lambda_bbox * loss_bbox_cons

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # attention map 저장
            for i, img_id in enumerate(image_ids):
                if isinstance(img_id, torch.Tensor):
                    img_id = img_id.item()
                curr_attn_dict[str(img_id)] = attn_weights[i].detach().cpu()

            with torch.no_grad():
                sim_matrix = F.cosine_similarity(img_proj.unsqueeze(1), txt_proj.unsqueeze(0), dim=-1)
                pred = sim_matrix.argmax(dim=1)
                labels = torch.arange(sim_matrix.size(0)).to(device)
                acc = (pred == labels).float().mean().item()

            total_loss += loss.item()
            total_acc += acc
            num_batches += 1

            progress.set_postfix({"loss": loss.item(), "acc": f"{acc:.3f}"})

        avg_loss = total_loss / num_batches
        avg_acc = total_acc / num_batches
        acc_history.append(avg_acc)
        print(f"[WARM-UP] Epoch {epoch} - Avg Loss: {avg_loss:.4f}, Avg Accuracy: {avg_acc:.4f}")

        torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'attn_dict': curr_attn_dict},
                   f"temp_checkpoint_epoch{epoch}_cjy01.pth")

        if prev_acc is not None and prev_acc > avg_acc:
            drop_count += 1
        else:
            drop_count = 0

        prev_acc = avg_acc

        if drop_count >= 1:
            high_point_epoch = epoch - 1
            print(f"[WARM-UP 종료 감지] Accuracy 하락 → 고점 epoch: {high_point_epoch}")
            shutil.copyfile(f"temp_checkpoint_epoch{high_point_epoch}_cjy01.pth",
                            f"checkpoint_epoch{high_point_epoch}_stable_cjy01.pth")
            break


In [19]:
# 기기 설정 및 초기화
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisionLanguageModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Warm-up 학습 함수 실행
train_warmup(
    model=model,
    dataloader=dataloader,  # 사용 중인 dataloader 그대로 입력
    optimizer=optimizer,
    device=device,
    lambda_cons=0.5,         # bbox consistency weight 
    warmup_epochs=10,       # 충분한 최대 epoch 설정
    start_epoch=0
)


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Warm-up Epoch 0:   0%|                                | 0/13742 [00:00<?, ?it/s]

[LambdaScheduler] Consistency loss ON (diff=0.000000)


                                                                                

[WARM-UP] Epoch 0 - Avg Loss: 0.6770, Avg Accuracy: 0.9318


                                                                                

[WARM-UP] Epoch 1 - Avg Loss: 0.6277, Avg Accuracy: 0.9515


                                                                                

[WARM-UP] Epoch 2 - Avg Loss: 0.6066, Avg Accuracy: 0.9579


                                                                                

[WARM-UP] Epoch 3 - Avg Loss: 0.6115, Avg Accuracy: 0.9562
[WARM-UP 종료 감지] Accuracy 하락 → 고점 epoch: 2


In [20]:
shutil.copyfile("checkpoint_epoch2_stable_cjy01.pth", "checkpoint_self_start_cjy03.pth")

'checkpoint_self_start_cjy03.pth'

In [32]:
def load_self_supervised_start(path="checkpoint_self_start_cjy03.pth", lr=5e-5):
    checkpoint = torch.load(path)
    model = VisionLanguageModel().to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    model.load_state_dict(checkpoint['model'])
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    optimizer.load_state_dict(checkpoint['optimizer'])
    prev_attn_dict = checkpoint['attn_dict']
    return model, optimizer, prev_attn_dict

In [33]:
def train_self_supervised(model, dataloader, optimizer, device,
                          prev_attn_dict, lambda_self=0.1, num_epochs=10, start_epoch=0):
    model.train()
    acc_history = []
    # LambdaScheduler 추가 (self-consistency용)
    lambda_self_scheduler = LambdaScheduler(alpha=0.2, threshold=0.01, anneal_rate=0.01, max_lambda=lambda_self)

    for epoch in range(start_epoch, start_epoch + num_epochs):
        total_loss, total_acc = 0.0, 0.0
        num_batches = 0
        curr_attn_dict = {}

        progress = tqdm(dataloader, desc=f"Self Epoch {epoch}", leave=False)
        for batch_idx, batch in enumerate(progress):
            images, captions, _, image_ids = batch
            images = images.to(device)

            img_proj, txt_proj, attn_weights, _ = model(images, captions)
            T_common = attn_weights.size(1)
            attn_weights_slice = attn_weights[:, :T_common, :]

            loss_contrastive = clip_contrastive_loss(img_proj, txt_proj)

            loss_self = 0.0
            count = 0
            for i, img_id in enumerate(image_ids):
                # image_id 문자열로 변환
                if isinstance(img_id, torch.Tensor):
                    img_id = img_id.item()
                img_id_str = str(img_id)

                if img_id_str in prev_attn_dict:
                    prev_attn = prev_attn_dict[img_id_str].to(device)
                    curr_attn = attn_weights[i]

                    T_curr = curr_attn.size(0)
                    T_prev = prev_attn.size(0)
                    T_common = min(T_curr, T_prev)

                    prev_attn_crop = prev_attn[:T_common, :]
                    curr_attn_crop = curr_attn[:T_common, :]

                    # EMA alpha 결정
                    diff = F.mse_loss(curr_attn_crop, prev_attn_crop, reduction='mean').item()
#                     print(f"Current diff: {diff}")
                    alpha = 0.1 + 0.8 * min(diff, 1.0)  # diff↑ → alpha↑  # 차이 클수록 alpha 크게

                    # EMA 적용
                    updated_attn = alpha * curr_attn_crop + (1 - alpha) * prev_attn_crop

                    # Self loss 계산은 EMA와 prev 기준
                    loss_self += F.mse_loss(updated_attn, prev_attn_crop)

                    # EMA된 attention 저장
                    curr_attn_dict[img_id_str] = updated_attn.detach().cpu()
                    count += 1

            loss_self = loss_self / count if count > 0 else torch.tensor(0.0, device=device)
            # LambdaScheduler로부터 가변 lambda_self 적용
            lambda_self_val, diff = lambda_self_scheduler.update(attn_weights_slice, return_diff=True)
            loss = loss_contrastive + lambda_self_val * loss_self

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                sim_matrix = F.cosine_similarity(img_proj.unsqueeze(1), txt_proj.unsqueeze(0), dim=-1)
                pred = sim_matrix.argmax(dim=1)
                labels = torch.arange(sim_matrix.size(0)).to(device)
                acc = (pred == labels).float().mean().item()

            total_loss += loss.item()
            total_acc += acc
            num_batches += 1

            progress.set_postfix({"loss": loss.item(), "acc": f"{acc:.3f}"})

        avg_loss = total_loss / num_batches
        avg_acc = total_acc / num_batches
        acc_history.append(avg_acc)

        print(f"[SELF] Epoch {epoch} - Avg Loss: {avg_loss:.4f}, Avg Accuracy: {avg_acc:.4f}")
        print(f"[SELF] Epoch {epoch} - Final Self Loss: {loss_self.item():.4f}, Final Match Count: {count}")

        torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'attn_dict': curr_attn_dict},
                   f"temp_checkpoint_self_epoch{epoch}.pth")


In [34]:
import copy

#  고점 체크포인트 불러오기
model, optimizer, prev_attn_dict = load_self_supervised_start()

ema_model = copy.deepcopy(model)

#  self-consistency 학습용 train_model 함수 실행
train_self_supervised(
    model=model,
    dataloader=dataloader,
    optimizer=optimizer,
    device=device,
    lambda_self=0.5,
    num_epochs=5, # 원하는 만큼 self-consistency 학습 횟수
    start_epoch=1,
    prev_attn_dict=prev_attn_dict
)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Self Epoch 1:   0%|  | 2/13742 [00:00<24:17,  9.43it/s, loss=7.94e-5, acc=1.000]

[LambdaScheduler] Consistency loss ON (diff=0.000000)


                                                                                

[SELF] Epoch 1 - Avg Loss: 1.0983, Avg Accuracy: 0.4095
[SELF] Epoch 1 - Final Self Loss: 0.0000, Final Match Count: 2


                                                                                

[SELF] Epoch 2 - Avg Loss: 1.3863, Avg Accuracy: 0.2510
[SELF] Epoch 2 - Final Self Loss: 0.0000, Final Match Count: 2


                                                                                

[SELF] Epoch 3 - Avg Loss: 1.3863, Avg Accuracy: 0.2492
[SELF] Epoch 3 - Final Self Loss: 0.0000, Final Match Count: 2


                                                                                

[SELF] Epoch 4 - Avg Loss: 1.3866, Avg Accuracy: 0.2487
[SELF] Epoch 4 - Final Self Loss: 0.0000, Final Match Count: 2


                                                                                

[SELF] Epoch 5 - Avg Loss: 1.3864, Avg Accuracy: 0.2507
[SELF] Epoch 5 - Final Self Loss: 0.0000, Final Match Count: 2


In [33]:
from torch.utils.data import Dataset
from PIL import Image
import os
import json

class SimpleCocoCaptionDataset(Dataset):
    def __init__(self, caption_json_path, image_root, transform, tokenizer):
        with open(caption_json_path, 'r') as f:
            data = json.load(f)

        self.imgid2caption = {}
        for ann in data['annotations']:
            img_id = ann['image_id']
            if img_id not in self.imgid2caption:
                self.imgid2caption[img_id] = ann['caption']

        self.image_ids = list(self.imgid2caption.keys())
        self.image_root = image_root
        self.transform = transform
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        img_path = os.path.join(self.image_root, f"{img_id:012d}.jpg")
        image = self.transform(Image.open(img_path).convert("RGB"))
        caption = self.imgid2caption[img_id]

        encoding = self.tokenizer(caption, return_tensors="pt", padding="max_length", truncation=True, max_length=32)
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        return image, input_ids, attention_mask


In [34]:
from torch.utils.data import DataLoader

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

val_dataset = SimpleCocoCaptionDataset(
    caption_json_path="annotations/captions_val2017.json",
    image_root="val2017",
    transform=transform,
    tokenizer=tokenizer
)

val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


In [35]:
# 정답,오답쌍 평균 유사도 계산 추론
def evaluate_mean_similarity(model, val_loader, device):
    model.eval()
    total_correct_sim = 0.0
    total_incorrect_sim = 0.0
    correct_count = 0
    incorrect_count = 0

    with torch.no_grad():
        for images, input_ids, attention_mask in tqdm(val_loader, desc="Evaluating"):
            images = images.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            image_feat, text_feat = model.encode_tokenized_input(images, input_ids, attention_mask)  # (B, 512)

            sim_matrix = F.cosine_similarity(
                image_feat.unsqueeze(1),  # (B, 1, D)
                text_feat.unsqueeze(0),  # (1, B, D)
                dim=-1
            )  # (B, B)

            B = sim_matrix.size(0)
            correct_sims = sim_matrix.diag()
            total_correct_sim += correct_sims.sum().item()
            correct_count += B

            mask = ~torch.eye(B, dtype=torch.bool, device=device)
            incorrect_sims = sim_matrix[mask]
            total_incorrect_sim += incorrect_sims.sum().item()
            incorrect_count += incorrect_sims.numel()

    mean_correct = total_correct_sim / correct_count
    mean_incorrect = total_incorrect_sim / incorrect_count

    print(f"\nEvaluation Results:")
    print(f" - Mean Correct Sim    : {mean_correct:.4f}")
    print(f" - Mean Incorrect Sim  : {mean_incorrect:.4f}")
    return mean_correct, mean_incorrect


In [2]:
# [LambdaScheduler] Consistency loss ON (diff=0.000000)
                                                                                
# [SELF] Epoch 1 - Avg Loss: 0.1247, Avg Accuracy: 0.9518

# [SELF] Epoch 2 - Avg Loss: 0.1134, Avg Accuracy: 0.9560

# [SELF] Epoch 3 - Avg Loss: 0.1044, Avg Accuracy: 0.9589

# [SELF] Epoch 4 - Avg Loss: 0.8424, Avg Accuracy: 0.5520

# 이렇게 학습됐던 때에 돌린 결과. 에포크 3에서의 pth가 최종 모델이 되는 것.

In [37]:
shutil.copyfile("temp_checkpoint_self_epoch3.pth", "end_checkpoint_self_cjy01.pth")

'end_checkpoint_self_cjy01.pth'

In [38]:
def load_model_for_evaluation(checkpoint_path, device):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = VisionLanguageModel().to(device)
    model.load_state_dict(checkpoint['model'])  # student 모델 로드
    # 또는 teacher 모델 사용 시: model.load_state_dict(checkpoint['ema_model'])
    return model

# 모델 로드
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_model_for_evaluation("end_checkpoint_self_cjy01.pth", device)

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [39]:
mean_correct, mean_incorrect = evaluate_mean_similarity(model, val_loader, device)

Evaluating: 100%|█████████████████████████████| 157/157 [09:44<00:00,  3.72s/it]


Evaluation Results:
 - Mean Correct Sim    : 0.6022
 - Mean Incorrect Sim  : 0.0790





In [64]:
import torch
import torch.nn.functional as F
from tqdm import tqdm

def retrieval_evaluate(model, val_loader, device, max_samples=5000):
    model.eval()
    image_embeds_list = []
    text_embeds_list = []
    text_list = []
    img_id_list = []

    n_samples = 0

    with torch.no_grad():
        for images, input_ids, attention_mask in tqdm(val_loader, desc="Encoding image-text pairs"):
            images = images.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            # (B, 512)
            image_feat, text_feat = model.encode_tokenized_input(images, input_ids, attention_mask)

            image_embeds_list.append(image_feat)
            text_embeds_list.append(text_feat)
            n_samples += images.size(0)
            if n_samples >= max_samples:
                break

    # (N, 512)
    image_embeds_all = torch.cat(image_embeds_list, dim=0)[:max_samples]
    text_embeds_all = torch.cat(text_embeds_list, dim=0)[:max_samples]

    # 정규화
    image_embeds_all = F.normalize(image_embeds_all, dim=-1)
    text_embeds_all = F.normalize(text_embeds_all, dim=-1)

    # 유사도 행렬 (N, N)
    sim_matrix = torch.matmul(image_embeds_all, text_embeds_all.T)

    def compute_recall(sim_matrix, k):
        correct = 0
        for i in range(sim_matrix.size(0)):
            topk = sim_matrix[i].topk(k).indices
            if i in topk:
                correct += 1
        return correct / sim_matrix.size(0)

    recall1 = compute_recall(sim_matrix, 1)
    recall5 = compute_recall(sim_matrix, 5)
    recall10 = compute_recall(sim_matrix, 10)

    print("\n[Image-Text Retrieval Results (Closed-domain)]")
    print(f"Recall@1: {recall1:.4f}")
    print(f"Recall@5: {recall5:.4f}")
    print(f"Recall@10: {recall10:.4f}")

    return recall1, recall5, recall10

In [65]:
recall1, recall5, recall10 = retrieval_evaluate(model, val_loader, device, max_samples=5000)

Encoding image-text pairs:  99%|█████████████▉| 156/157 [13:56<00:05,  5.36s/it]



[Image-Text Retrieval Results (Closed-domain)]
Recall@1: 0.0454
Recall@5: 0.1540
Recall@10: 0.2468


In [None]:
# CLIP에서의 Retrieval Results 보기 위한 코드

In [60]:
import torch
import open_clip
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm


class CocoClipEvaluationDataset(Dataset):
    def __init__(self, caption_json_path, image_root, transform):
        with open(caption_json_path, 'r') as f:
            data = json.load(f)
        self.imgid2caption = {}
        for ann in data['annotations']:
            img_id = ann['image_id']
            self.imgid2caption[img_id] = ann['caption']
        self.image_ids = list(self.imgid2caption.keys())
        self.image_root = image_root
        self.transform = transform

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        img_path = os.path.join(self.image_root, f"{img_id:012d}.jpg")
        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)
        caption = self.imgid2caption[img_id]
        return image, caption  # 2개 값만 반환



def openclip_retrieval_evaluate(model, val_loader, tokenizer, device, max_samples=5000):
    model.eval()
    image_embeds_list = []
    text_embeds_list = []
    n_samples = 0

    with torch.no_grad():
        # 2개 값만 언패킹 (images, texts)
        for images, texts in tqdm(val_loader, desc="Encoding image-text pairs"):
            images = images.to(device)
            # 텍스트 토크나이징
            text_tokens = tokenizer(texts).to(device)
            
            # 임베딩 추출
            image_embeds = model.encode_image(images)
            text_embeds = model.encode_text(text_tokens)
            
            image_embeds_list.append(image_embeds)
            text_embeds_list.append(text_embeds)
            n_samples += images.size(0)
            if n_samples >= max_samples:
                break

    image_embeds_all = torch.cat(image_embeds_list, dim=0)[:max_samples]
    text_embeds_all = torch.cat(text_embeds_list, dim=0)[:max_samples]

    image_embeds_all = F.normalize(image_embeds_all, dim=-1)
    text_embeds_all = F.normalize(text_embeds_all, dim=-1)

    sim_matrix = torch.matmul(image_embeds_all, text_embeds_all.T)

    def compute_recall(sim_matrix, k):
        correct = 0
        for i in range(sim_matrix.size(0)):
            topk = sim_matrix[i].topk(k).indices
            if i in topk:
                correct += 1
        return correct / sim_matrix.size(0)

    recall1 = compute_recall(sim_matrix, 1)
    recall5 = compute_recall(sim_matrix, 5)
    recall10 = compute_recall(sim_matrix, 10)

    print("\n[Image-Text Retrieval Results (Closed-domain)]")
    print(f"Recall@1: {recall1:.4f}")
    print(f"Recall@5: {recall5:.4f}")
    print(f"Recall@10: {recall10:.4f}")

    return recall1, recall5, recall10


In [62]:
# CLIP 전용 transform
_, _, preprocess_clip = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')

# 데이터셋 및 데이터로더 생성
clip_val_dataset = CocoClipEvaluationDataset(
    caption_json_path="annotations/captions_val2017.json",
    image_root="val2017",
    transform=preprocess_clip
)

clip_val_loader = DataLoader(
    clip_val_dataset,
    batch_size=32,
    shuffle=False
)


In [63]:
# 모델 및 토크나이저 초기화
model_clip, _, _ = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer_clip = open_clip.get_tokenizer('ViT-B-32')
device_clip = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_clip = model_clip.to(device_clip)

# 평가 실행
recall1, recall5, recall10 = openclip_retrieval_evaluate(
    model_clip,
    clip_val_loader,  # 2개 값 반환하는 데이터로더
    tokenizer_clip,
    device_clip,
    max_samples=5000
)

Encoding image-text pairs:  99%|█████████████▉| 156/157 [13:27<00:05,  5.17s/it]



[Image-Text Retrieval Results (Closed-domain)]
Recall@1: 0.4108
Recall@5: 0.6770
Recall@10: 0.7716
