In [1]:
# WASP 최종코드
# 하단에 평가 코드(clip과의 비교)도 포함

In [3]:
# 환경 설정 및 라이브러리 로딩
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from transformers import BertTokenizer, BertModel
import json
import numpy as np
from PIL import Image
import os
from torch.utils.data import Dataset, DataLoader
import random
from tqdm import tqdm
from PIL import Image, UnidentifiedImageError, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [4]:
from transformers import BertTokenizer, BertModel, ViTModel, ViTFeatureExtractor
from torchvision.models.vision_transformer import vit_b_16, ViT_B_16_Weights


class VisionEncoder(nn.Module):
    def __init__(self, model_name='google/vit-base-patch16-224'):
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)

    def forward(self, images):
        outputs = self.vit(pixel_values=images)
        return outputs.last_hidden_state  # [B, 1+P, D]



In [5]:
# 텍스트 인코더: BERT 기반 Transformer
class TextEncoder(nn.Module):
    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.bert = BertModel.from_pretrained(model_name)

    def forward(self, captions):
        tokenized = self.tokenizer(captions, return_tensors="pt", padding=True, truncation=True).to(self.bert.device)
        outputs = self.bert(**tokenized)
        return outputs.last_hidden_state, tokenized
    

In [6]:
# Cross-Attention Block
class CrossAttentionBlock(nn.Module):
    def __init__(self, dim=768, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)

    def forward(self, text_emb, image_patches):
        attn_output, attn_weights = self.attn(text_emb, image_patches, image_patches)
        return attn_output, attn_weights

In [90]:
class VisionLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_encoder = VisionEncoder()
        self.text_encoder = TextEncoder()
        self.cross_attn = CrossAttentionBlock()
        self.proj_image = nn.Linear(768, 512)
        self.proj_text = nn.Linear(768, 512)

    def forward(self, images, captions):
        img_feat_all = self.vision_encoder(images)  # [B, 1+P, D]
        cls_feat = img_feat_all[:, 0]               # [B, D] CLS
        patch_feat = img_feat_all[:, 1:]            # [B, P, D] patches only

        text_emb, tokens = self.text_encoder(captions)  # [B, T, D]
        cross_out, attn_weights = self.cross_attn(text_emb, patch_feat)  # [B, T, D], [B, T, P]

        img_proj = self.proj_image(cls_feat)        # [B, 512]
        text_proj = self.proj_text(text_emb[:, 0])  # [B, 512]

        return img_proj, text_proj, attn_weights, tokens

    def encode_for_inference(self, images, captions): #랜덤 3개 유사도 평가용
        with torch.no_grad():
            img_feat_all = self.vision_encoder(images)
            cls_feat = img_feat_all[:, 0]
            text_emb, _ = self.text_encoder(captions)
            img_proj = self.proj_image(cls_feat)
            text_proj = self.proj_text(text_emb[:, 0])
        return img_proj, text_proj
    
    def encode_tokenized_input(self, images, input_ids, attention_mask): #평균 유사도 평가용
        with torch.no_grad():
            img_feat_all = self.vision_encoder(images)        # (B, 197, 768)
            cls_feat = img_feat_all[:, 0, :]                  # CLS token만 추출 (B, 768)
            img_proj = self.proj_image(cls_feat)              # (B, 512)

            bert_out = self.text_encoder.bert(input_ids=input_ids, attention_mask=attention_mask)
            txt_cls = bert_out.last_hidden_state[:, 0, :]     # (B, 768)
            txt_proj = self.proj_text(txt_cls)                # (B, 512)

            return img_proj, txt_proj
        
     # 텍스트만 인코딩
    def encode_tokenized_input_text_only(self, input_ids, attention_mask):
        with torch.no_grad():
            bert_out = self.text_encoder.bert(input_ids=input_ids, attention_mask=attention_mask)
            txt_cls = bert_out.last_hidden_state[:, 0, :]
            txt_proj = self.proj_text(txt_cls)
            return txt_proj


In [8]:
# Consistency Loss 계산 함수
def compute_consistency_loss(attn_weights, masks, eps=1e-6):
    B, T, H, W = masks.shape
    masks_flat = masks.view(B, T, -1)
    scores = (attn_weights * masks_flat).sum(dim=-1)
    scores = torch.clamp(scores, min=eps, max=1.0)
    return -torch.log(scores).mean()

In [9]:
# CLIP Contrastive Loss

def clip_contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    
    logits = image_embeds @ text_embeds.T / temperature
    labels = torch.arange(len(image_embeds)).to(image_embeds.device)
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)
    return (loss_i2t + loss_t2i) / 2

In [10]:
from collections import Counter
from sklearn.model_selection import StratifiedShuffleSplit
import numpy as np
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)

json_path = "coco_token_bbox_matched.json"
image_root = "/raid/kyscap251/team2/train2017/train2017"

# 1. Load JSON once
with open(json_path, "r") as f:
    all_data = json.load(f)

# 2. 유효 이미지와 대표 label 추출
valid_indices = []
labels = []

def get_dominant_label(matches):
    if not matches:
        return "none"
    return Counter([m["label"] for m in matches]).most_common(1)[0][0]

for i, entry in enumerate(all_data):
    image_id = entry["image_id"]
    image_path = os.path.join(image_root, f"{image_id:012d}.jpg")
    if os.path.exists(image_path):
        valid_indices.append(i)
        labels.append(get_dominant_label(entry["matches"]))

# 3. stratified split
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=SEED)
subset_rel_idx, _ = next(splitter.split(np.zeros(len(labels)), labels))

# 4. 전체 JSON 기준으로 실제 subset 인덱스로 변환
subset_json_indices = [valid_indices[i] for i in subset_rel_idx]



In [11]:
from PIL import Image, UnidentifiedImageError
from collections import Counter
from torch.utils.data import Dataset


class CocoVLMDataset(Dataset):
    def __init__(self, json_path, image_root, transform=None, patch_size=16, max_tokens=10, subset_indices=None):
        with open(json_path, 'r') as f:
            all_data = json.load(f)
        self.image_root = image_root
        self.transform = transform
        self.patch_size = patch_size
        self.max_tokens = max_tokens
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        self.data = []
        selected = subset_indices if subset_indices is not None else range(len(all_data))

        for i in selected:
            entry = all_data[i]
            image_id = entry["image_id"]
            image_path = os.path.join(self.image_root, f"{image_id:012d}.jpg")
            if os.path.exists(image_path):
                self.data.append(entry)

        print(f"유효 이미지 수: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        image_id = entry["image_id"]
        image_path = os.path.join(self.image_root, f"{image_id:012d}.jpg")

        try:
            image = Image.open(image_path).convert("RGB")
        except (FileNotFoundError, UnidentifiedImageError, OSError):
            print(f"[WARN] 이미지 불러오기 실패: {image_path}")
            return self.__getitem__((idx + 1) % len(self))

        if self.transform:
            image_tensor = self.transform(image)
        else:
            image_tensor = image

        captions = entry["captions"]
        matches = entry["matches"][:self.max_tokens]
        caption = captions[0]

        # Tokenize caption
        caption_tokens = self.tokenizer(caption, return_tensors="pt", padding="max_length",
                                        truncation=True, max_length=30)["input_ids"].squeeze(0)

        # Binary mask 생성
        H, W = 224 // self.patch_size, 224 // self.patch_size
        mask_tensor = torch.zeros((self.max_tokens, H, W))

        for i, match in enumerate(matches):
            x, y, w, h = match["bbox"]
            x1 = int(x // self.patch_size)
            y1 = int(y // self.patch_size)
            x2 = int((x + w) // self.patch_size)
            y2 = int((y + h) // self.patch_size)
            mask_tensor[i, y1:y2+1, x1:x2+1] = 1.0

        return {
            'image': image_tensor,
            'caption': caption,
            'mask': mask_tensor,
            'image_id': image_id
        }




In [12]:
def coco_collate_fn(batch):
    images = torch.stack([item['image'] for item in batch])
    captions = [item['caption'] for item in batch] 
    masks = torch.stack([item['mask'] for item in batch])
    image_ids = [item['image_id'] for item in batch]
    return images, captions, masks, image_ids


In [14]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

dataset = CocoVLMDataset(
    json_path=json_path,
    image_root=image_root,
    transform=transform,
    subset_indices=subset_json_indices
)

dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=coco_collate_fn
)


유효 이미지 수: 54966


In [15]:
class LambdaScheduler: #람다 점진적 높이기
    def __init__(self, alpha=0.2, threshold=0.01, anneal_rate=0.01, max_lambda=0.5):
        self.ema = None
        self.alpha = alpha
        self.threshold = threshold
        self.anneal_rate = anneal_rate
        self.max_lambda = max_lambda
        self.apply = False
        self.step = 0
        self.current_lambda = 0.0

    def update(self, attn_weights, return_diff=False):
        """
        attn_weights: Tensor of shape [B, T, P] (cross-attn weights)
        """
        attn_mean = attn_weights.mean(dim=1).mean(dim=0)  # [P]
        if self.ema is None:
            self.ema = attn_mean
        else:
            self.ema = self.alpha * attn_mean + (1 - self.alpha) * self.ema

        diff = torch.abs(self.ema - attn_mean).mean().item()

        if not self.apply and diff < self.threshold:
            self.apply = True
            print(f"[LambdaScheduler] Consistency loss ON (diff={diff:.6f})")

        if self.apply:
            self.current_lambda = min(self.max_lambda, self.anneal_rate * self.step)
            self.step += 1
        else:
            self.current_lambda = 0.0

        if return_diff:
            return self.current_lambda, diff
        else:
            return self.current_lambda

In [21]:
def kl_divergence_attention(prev_attn, curr_attn, eps=1e-6):
    """
    prev_attn, curr_attn: [B, T, P] - softmax된 attention map
    """
    prev = torch.clamp(prev_attn, min=eps)
    curr = torch.clamp(curr_attn, min=eps)
    kl = (prev * (prev.log() - curr.log())).sum(dim=-1).mean()  # 평균 over T, B
    return kl

In [23]:
import shutil

def train_warmup(model, dataloader, optimizer, device,
                 lambda_cons=0.5, warmup_epochs=10, start_epoch=0):
    model.train()
    bbox_scheduler = LambdaScheduler(alpha=0.2, threshold=0.01, anneal_rate=0.005, max_lambda=lambda_cons)

    acc_history = []
    prev_acc = None
    drop_count = 0
    high_point_epoch = None

    for epoch in range(start_epoch, start_epoch + warmup_epochs):
        total_loss, total_acc = 0.0, 0.0
        num_batches = 0
        curr_attn_dict = {}

        progress = tqdm(dataloader, desc=f"Warm-up Epoch {epoch}", leave=False)
        for batch in progress:
            images, captions, masks, image_ids = batch
            images, masks = images.to(device), masks.to(device)

            img_proj, txt_proj, attn_weights, _ = model(images, captions)

            T_attn = attn_weights.size(1)
            T_mask = masks.size(1)
            T_common = min(T_attn, T_mask)

            attn_weights_slice = attn_weights[:, :T_common, :]
            masks_slice = masks[:, :T_common]

            loss_contrastive = clip_contrastive_loss(img_proj, txt_proj)
            lambda_bbox, _ = bbox_scheduler.update(attn_weights_slice, return_diff=True)
            loss_bbox_cons = compute_consistency_loss(attn_weights_slice, masks_slice)
            loss_bbox_cons = torch.clamp(loss_bbox_cons, max=1.0)
            loss = loss_contrastive + lambda_bbox * loss_bbox_cons

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            for i, img_id in enumerate(image_ids):
                if isinstance(img_id, torch.Tensor):
                    img_id = img_id.item()
                curr_attn_dict[str(img_id)] = attn_weights[i].detach().cpu()

            with torch.no_grad():
                sim_matrix = F.cosine_similarity(img_proj.unsqueeze(1), txt_proj.unsqueeze(0), dim=-1)
                pred = sim_matrix.argmax(dim=1)
                labels = torch.arange(sim_matrix.size(0)).to(device)
                acc = (pred == labels).float().mean().item()

            total_loss += loss.item()
            total_acc += acc
            num_batches += 1

            progress.set_postfix({"loss": loss.item(), "acc": f"{acc:.3f}"})

        avg_loss = total_loss / num_batches
        avg_acc = total_acc / num_batches
        acc_history.append(avg_acc)
        print(f"[WARM-UP] Epoch {epoch} - Avg Loss: {avg_loss:.4f}, Avg Accuracy: {avg_acc:.4f}")

        torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'attn_dict': curr_attn_dict},
                   f"temp_checkpoint_epoch{epoch}.pth")

        if prev_acc is not None and prev_acc > avg_acc:
            drop_count += 1
        else:
            drop_count = 0

        prev_acc = avg_acc

        if drop_count >= 2:
            high_point_epoch = epoch - 2
            print(f"[WARM-UP 종료 감지] Accuracy 하락 → 고점 epoch: {high_point_epoch}")
            shutil.copyfile(f"temp_checkpoint_epoch{high_point_epoch}.pth",
                            f"checkpoint_epoch{high_point_epoch}_stable.pth")
            break


In [24]:
# 1. 기기 설정 및 초기화
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisionLanguageModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# 2. Warm-up 학습 함수 실행
train_warmup(
    model=model,
    dataloader=dataloader,  # 사용 중인 dataloader 그대로 입력
    optimizer=optimizer,
    device=device,
    lambda_cons=0.5,         # bbox consistency weight 
    warmup_epochs=10,       # 충분한 최대 epoch 설정
    start_epoch=0
)


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Warm-up Epoch 0:   0%| | 1/13742 [00:00<2:18:37,  1.65it/s, loss=1.46, acc=0.250

[LambdaScheduler] Consistency loss ON (diff=0.000000)


Warm-up Epoch 0:   1%| | 129/13742 [00:14<28:35,  7.93it/s, loss=0.795, acc=1.00

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000401218.jpg


Warm-up Epoch 0:   9%| | 1172/13742 [02:04<21:41,  9.66it/s, loss=1.01, acc=0.75

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000322691.jpg


Warm-up Epoch 0:  21%|▏| 2923/13742 [05:06<18:24,  9.79it/s, loss=0.636, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000325021.jpg


Warm-up Epoch 0:  30%|▎| 4165/13742 [07:14<16:37,  9.60it/s, loss=0.797, acc=0.7

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000389624.jpg


Warm-up Epoch 0:  39%|▍| 5396/13742 [09:22<14:27,  9.62it/s, loss=0.8, acc=0.750

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000478812.jpg


Warm-up Epoch 0:  42%|▍| 5773/13742 [10:01<13:37,  9.75it/s, loss=0.544, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000455533.jpg


Warm-up Epoch 0:  46%|▍| 6364/13742 [11:01<12:30,  9.84it/s, loss=0.502, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000429834.jpg


Warm-up Epoch 0:  60%|▌| 8264/13742 [14:18<09:24,  9.71it/s, loss=0.548, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000329261.jpg


Warm-up Epoch 0:  74%|▋| 10137/13742 [17:31<06:12,  9.68it/s, loss=0.64, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000453622.jpg


Warm-up Epoch 0:  81%|▊| 11071/13742 [19:07<04:41,  9.49it/s, loss=0.544, acc=1.

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000017867.jpg


                                                                                

[WARM-UP] Epoch 0 - Avg Loss: 0.6769, Avg Accuracy: 0.9340


Warm-up Epoch 1:   7%| | 985/13742 [01:42<22:12,  9.58it/s, loss=0.528, acc=1.00

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000401218.jpg


Warm-up Epoch 1:  11%| | 1497/13742 [02:34<21:18,  9.58it/s, loss=0.504, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000455533.jpg


Warm-up Epoch 1:  22%|▏| 3048/13742 [05:15<18:26,  9.67it/s, loss=0.836, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000017867.jpg


Warm-up Epoch 1:  40%|▍| 5441/13742 [09:21<14:07,  9.80it/s, loss=0.524, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000325021.jpg


Warm-up Epoch 1:  65%|▋| 8936/13742 [15:23<08:18,  9.64it/s, loss=0.802, acc=0.7

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000453622.jpg


Warm-up Epoch 1:  77%|▊| 10569/13742 [18:13<05:24,  9.77it/s, loss=0.501, acc=1.

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000389624.jpg


Warm-up Epoch 1:  79%|▊| 10828/13742 [18:40<05:03,  9.61it/s, loss=0.734, acc=1.

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000478812.jpg


Warm-up Epoch 1:  93%|▉| 12811/13742 [22:05<01:36,  9.67it/s, loss=0.537, acc=1.

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000329261.jpg


Warm-up Epoch 1:  96%|▉| 13257/13742 [22:51<00:50,  9.62it/s, loss=0.653, acc=1.

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000429834.jpg


Warm-up Epoch 1:  99%|▉| 13548/13742 [23:21<00:19,  9.73it/s, loss=0.502, acc=1.

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000322691.jpg


                                                                                

[WARM-UP] Epoch 1 - Avg Loss: 0.6301, Avg Accuracy: 0.9497


Warm-up Epoch 2:   9%| | 1275/13742 [02:12<21:24,  9.70it/s, loss=0.519, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000325021.jpg


Warm-up Epoch 2:  20%|▏| 2680/13742 [04:37<18:49,  9.79it/s, loss=0.518, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000455533.jpg


Warm-up Epoch 2:  22%|▏| 3004/13742 [05:11<18:26,  9.70it/s, loss=0.523, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000401218.jpg


Warm-up Epoch 2:  31%|▎| 4253/13742 [07:19<16:15,  9.73it/s, loss=0.531, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000017867.jpg


Warm-up Epoch 2:  54%|▌| 7355/13742 [12:41<11:02,  9.64it/s, loss=0.5, acc=1.000

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000429834.jpg


Warm-up Epoch 2:  55%|▌| 7603/13742 [13:06<10:29,  9.75it/s, loss=0.58, acc=1.00

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000389624.jpg


Warm-up Epoch 2:  63%|▋| 8629/13742 [14:53<08:46,  9.71it/s, loss=0.771, acc=0.7

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000478812.jpg


Warm-up Epoch 2:  81%|▊| 11103/13742 [19:09<04:33,  9.66it/s, loss=0.603, acc=0.

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000453622.jpg


Warm-up Epoch 2:  94%|▉| 12885/13742 [22:15<01:29,  9.60it/s, loss=1.04, acc=0.7

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000329261.jpg


Warm-up Epoch 2:  98%|▉| 13464/13742 [23:14<00:28,  9.82it/s, loss=0.524, acc=1.

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000322691.jpg


                                                                                

[WARM-UP] Epoch 2 - Avg Loss: 0.6374, Avg Accuracy: 0.9493


Warm-up Epoch 3:   8%| | 1152/13742 [02:00<22:18,  9.40it/s, loss=0.615, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000401218.jpg


Warm-up Epoch 3:  14%|▏| 1975/13742 [03:25<20:32,  9.55it/s, loss=0.556, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000478812.jpg


Warm-up Epoch 3:  22%|▏| 2986/13742 [05:10<18:17,  9.80it/s, loss=0.617, acc=0.7

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000453622.jpg


Warm-up Epoch 3:  43%|▍| 5877/13742 [10:09<13:30,  9.71it/s, loss=0.635, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000429834.jpg


Warm-up Epoch 3:  54%|▌| 7402/13742 [12:48<11:14,  9.40it/s, loss=0.528, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000455533.jpg


Warm-up Epoch 3:  54%|▌| 7456/13742 [12:53<11:07,  9.42it/s, loss=0.509, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000017867.jpg


Warm-up Epoch 3:  80%|▊| 11037/13742 [19:04<04:40,  9.65it/s, loss=0.542, acc=1.

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000389624.jpg


Warm-up Epoch 3:  83%|▊| 11406/13742 [19:42<04:04,  9.57it/s, loss=0.591, acc=1.

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000322691.jpg


Warm-up Epoch 3:  89%|▉| 12162/13742 [21:00<02:43,  9.67it/s, loss=0.928, acc=1.

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000325021.jpg


Warm-up Epoch 3:  91%|▉| 12460/13742 [21:31<02:12,  9.71it/s, loss=0.502, acc=1.

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000329261.jpg


                                                                                

[WARM-UP] Epoch 3 - Avg Loss: 0.6520, Avg Accuracy: 0.9457
[WARM-UP 종료 감지] Accuracy 하락 → 고점 epoch: 1


In [49]:
shutil.copyfile("checkpoint_epoch1_stable.pth", "checkpoint_self_start.pth")

'checkpoint_self_start.pth'

In [50]:
def load_self_supervised_start(path="checkpoint_self_start.pth", lr=5e-5):
    checkpoint = torch.load(path)
    model = VisionLanguageModel().to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    model.load_state_dict(checkpoint['model'])
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    optimizer.load_state_dict(checkpoint['optimizer'])
    prev_attn_dict = checkpoint['attn_dict']
    return model, optimizer, prev_attn_dict

In [51]:
def train_self_supervised(model, dataloader, optimizer, device,
                          prev_attn_dict, lambda_self=0.1, num_epochs=10, start_epoch=0):
    model.train()
    acc_history = []
    lambda_self_ㅏscheduler = LambdaScheduler(alpha=0.2, threshold=0.01, anneal_rate=0.01, max_lambda=lambda_self)

    for epoch in range(start_epoch, start_epoch + num_epochs):
        total_loss, total_acc, num_batches = 0.0, 0.0, 0
        curr_attn_dict = {}

        progress = tqdm(dataloader, desc=f"Self Epoch {epoch}", leave=False)
        for batch_idx, batch in enumerate(progress):
            images, captions, _, image_ids = batch
            images = images.to(device)

            img_proj, txt_proj, attn_weights, _ = model(images, captions)
            T_common = attn_weights.size(1)
            attn_weights_slice = attn_weights[:, :T_common, :]

            loss_contrastive = clip_contrastive_loss(img_proj, txt_proj)

            loss_self, count = 0.0, 0
            for i, img_id in enumerate(image_ids):
                if isinstance(img_id, torch.Tensor):
                    img_id = img_id.item()
                img_id_str = str(img_id)

                if img_id_str in prev_attn_dict:
                    prev_attn = prev_attn_dict[img_id_str].to(device)
                    curr_attn = attn_weights[i]

                    T_common = min(prev_attn.size(0), curr_attn.size(0))
                    prev_attn_crop = prev_attn[:T_common, :]
                    curr_attn_crop = curr_attn[:T_common, :]

                    prev_soft = F.softmax(prev_attn_crop, dim=-1)
                    curr_soft = F.softmax(curr_attn_crop, dim=-1)
                    loss_self += kl_divergence_attention(prev_soft, curr_soft)

                    diff = F.mse_loss(curr_attn_crop, prev_attn_crop, reduction='mean').item()
                    alpha = max(0.1, min(0.9, 1.0 - diff))  # 차이 클수록 alpha 작게

                    updated_attn = alpha * curr_attn_crop + (1 - alpha) * prev_attn_crop
                    curr_attn_dict[img_id_str] = updated_attn.detach().cpu()

                    count += 1

            loss_self = loss_self / count if count > 0 else torch.tensor(0.0, device=device)
            lambda_self_val, _ = lambda_self_scheduler.update(attn_weights_slice, return_diff=True)
            loss = loss_contrastive + lambda_self_val * loss_self

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                sim_matrix = F.cosine_similarity(img_proj.unsqueeze(1), txt_proj.unsqueeze(0), dim=-1)
                pred = sim_matrix.argmax(dim=1)
                labels = torch.arange(sim_matrix.size(0)).to(device)
                acc = (pred == labels).float().mean().item()

            total_loss += loss.item()
            total_acc += acc
            num_batches += 1
            progress.set_postfix({"loss": loss.item(), "acc": f"{acc:.3f}"})

        avg_loss = total_loss / num_batches
        avg_acc = total_acc / num_batches
        acc_history.append(avg_acc)

        print(f"[SELF] Epoch {epoch} - Avg Loss: {avg_loss:.4f}, Avg Accuracy: {avg_acc:.4f}")

        torch.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'attn_dict': curr_attn_dict
        }, f"temp_checkpoint_self_epoch{epoch}.pth")


In [42]:
#  고점 체크포인트 불러오기
model, optimizer, prev_attn_dict = load_self_supervised_start()

#  self-consistency 학습용 train_model 함수 실행
train_self_supervised(
    model=model,
    dataloader=dataloader,
    optimizer=optimizer,
    device=device,
    lambda_self=0.1,
    num_epochs=5, # 원하는 만큼 self-consistency 학습 횟수
    start_epoch=2,
    prev_attn_dict=prev_attn_dict
)


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Self Epoch 2:   0%|    | 2/13742 [00:00<27:49,  8.23it/s, loss=0.385, acc=0.750]

[LambdaScheduler] Consistency loss ON (diff=0.000000)


Self Epoch 2:   2%| | 212/13742 [00:22<23:36,  9.55it/s, loss=0.0194, acc=1.000]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000455533.jpg
[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000017867.jpg


Self Epoch 2:  20%|▏| 2700/13742 [04:43<19:29,  9.44it/s, loss=0.00392, acc=1.00

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000401218.jpg


Self Epoch 2:  21%|▍ | 2929/13742 [05:07<19:11,  9.39it/s, loss=0.34, acc=0.750]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000429834.jpg


Self Epoch 2:  22%|▏| 3041/13742 [05:19<18:45,  9.51it/s, loss=0.0704, acc=1.000

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000325021.jpg


Self Epoch 2:  26%|▎| 3613/13742 [06:19<17:33,  9.61it/s, loss=0.0157, acc=1.000

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000322691.jpg


Self Epoch 2:  31%|▎| 4313/13742 [07:32<16:15,  9.67it/s, loss=0.0014, acc=1.000

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000389624.jpg


Self Epoch 2:  61%|▌| 8427/13742 [14:43<09:12,  9.62it/s, loss=0.559, acc=0.500]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000329261.jpg


Self Epoch 2:  86%|▊| 11869/13742 [20:45<03:16,  9.55it/s, loss=0.000907, acc=1.

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000478812.jpg


Self Epoch 2:  86%|▊| 11886/13742 [20:46<03:14,  9.56it/s, loss=0.0317, acc=1.00

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000453622.jpg


                                                                                

[SELF] Epoch 2 - Avg Loss: 0.1123, Avg Accuracy: 0.9567


Self Epoch 3:   9%| | 1206/13742 [02:07<21:57,  9.51it/s, loss=0.00921, acc=1.00

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000429834.jpg


Self Epoch 3:  32%|▎| 4383/13742 [07:41<16:40,  9.36it/s, loss=0.0332, acc=1.000

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000017867.jpg


Self Epoch 3:  33%|▎| 4480/13742 [07:51<16:30,  9.35it/s, loss=0.014, acc=1.000]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000455533.jpg


Self Epoch 3:  35%|▎| 4812/13742 [08:27<15:52,  9.38it/s, loss=0.00064, acc=1.00

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000478812.jpg


Self Epoch 3:  39%|▍| 5357/13742 [09:25<15:07,  9.24it/s, loss=0.0107, acc=1.000

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000325021.jpg


Self Epoch 3:  59%|▌| 8061/13742 [14:10<10:04,  9.40it/s, loss=0.609, acc=0.500]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000401218.jpg


Self Epoch 3:  67%|▋| 9167/13742 [16:06<07:57,  9.59it/s, loss=0.000969, acc=1.0

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000322691.jpg


Self Epoch 3:  90%|▉| 12308/13742 [21:36<02:31,  9.44it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000329261.jpg


Self Epoch 3:  92%|▉| 12628/13742 [22:10<01:58,  9.37it/s, loss=1.39, acc=0.500]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000389624.jpg


Self Epoch 3:  93%|▉| 12728/13742 [22:21<01:48,  9.39it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000453622.jpg


                                                                                

[SELF] Epoch 3 - Avg Loss: 0.3203, Avg Accuracy: 0.8391


Self Epoch 4:  11%|▏ | 1550/13742 [02:44<21:33,  9.42it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000455533.jpg


Self Epoch 4:  11%|▏ | 1575/13742 [02:47<21:23,  9.48it/s, loss=1.38, acc=0.500]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000478812.jpg


Self Epoch 4:  21%|▍ | 2906/13742 [05:08<19:04,  9.47it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000389624.jpg


Self Epoch 4:  29%|▌ | 3974/13742 [07:02<17:12,  9.46it/s, loss=1.38, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000325021.jpg


Self Epoch 4:  32%|▋ | 4351/13742 [07:42<16:36,  9.42it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000401218.jpg


Self Epoch 4:  35%|▋ | 4823/13742 [08:31<15:45,  9.43it/s, loss=1.38, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000329261.jpg


Self Epoch 4:  46%|▉ | 6337/13742 [11:12<12:59,  9.50it/s, loss=1.39, acc=0.000]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000453622.jpg


Self Epoch 4:  76%|▊| 10378/13742 [18:16<05:47,  9.69it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000322691.jpg


Self Epoch 4:  76%|▊| 10422/13742 [18:21<05:44,  9.63it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000017867.jpg


Self Epoch 4:  81%|▊| 11176/13742 [19:40<04:26,  9.63it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000429834.jpg


                                                                                

[SELF] Epoch 4 - Avg Loss: 1.3864, Avg Accuracy: 0.2486


Self Epoch 5:   4%|   | 559/13742 [00:58<22:56,  9.58it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000478812.jpg


Self Epoch 5:  19%|▍ | 2665/13742 [04:39<19:08,  9.64it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000453622.jpg


Self Epoch 5:  34%|▋ | 4672/13742 [08:08<15:44,  9.61it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000322691.jpg


Self Epoch 5:  41%|▊ | 5664/13742 [09:52<13:56,  9.65it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000455533.jpg


Self Epoch 5:  71%|█▍| 9733/13742 [17:05<07:11,  9.29it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000017867.jpg


Self Epoch 5:  83%|▊| 11454/13742 [20:09<04:05,  9.34it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000389624.jpg


Self Epoch 5:  87%|▊| 12024/13742 [21:10<03:05,  9.25it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000325021.jpg


Self Epoch 5:  88%|▉| 12101/13742 [21:19<02:55,  9.35it/s, loss=1.39, acc=0.000]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000401218.jpg


Self Epoch 5:  88%|▉| 12156/13742 [21:25<02:49,  9.37it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000429834.jpg


Self Epoch 5:  97%|▉| 13394/13742 [23:37<00:38,  9.15it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000329261.jpg


                                                                                

[SELF] Epoch 5 - Avg Loss: 1.3863, Avg Accuracy: 0.2491


Self Epoch 6:   1%|    | 82/13742 [00:08<24:08,  9.43it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000478812.jpg


Self Epoch 6:   4%|   | 554/13742 [00:59<23:35,  9.32it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000453622.jpg


Self Epoch 6:  12%|▏ | 1669/13742 [02:57<21:50,  9.21it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000455533.jpg


Self Epoch 6:  43%|▊ | 5963/13742 [10:34<14:12,  9.13it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000017867.jpg


Self Epoch 6:  48%|▉ | 6541/13742 [11:36<13:01,  9.21it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000325021.jpg


Self Epoch 6:  79%|▊| 10831/13742 [19:18<05:09,  9.40it/s, loss=1.39, acc=0.000]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000322691.jpg


Self Epoch 6:  79%|▊| 10894/13742 [19:25<05:02,  9.41it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000329261.jpg


Self Epoch 6:  85%|▊| 11630/13742 [20:44<03:43,  9.43it/s, loss=1.39, acc=0.250]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000401218.jpg


Self Epoch 6:  87%|▊| 11932/13742 [21:16<03:13,  9.35it/s, loss=1.39, acc=0.000]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000429834.jpg


Self Epoch 6:  92%|▉| 12652/13742 [22:34<01:56,  9.33it/s, loss=1.39, acc=0.500]

[WARN] 이미지 불러오기 실패: /raid/kyscap251/team2/train2017/train2017/000000389624.jpg


                                                                                

[SELF] Epoch 6 - Avg Loss: 1.3863, Avg Accuracy: 0.2509


In [54]:
shutil.copyfile("temp_checkpoint_self_epoch2.pth", "end_checkpoint_self.pth")

'end_checkpoint_self.pth'

In [55]:
print("평가")

평가


In [64]:
from torch.utils.data import Dataset
from PIL import Image
from PIL import UnidentifiedImageError
from torch.utils.data import Dataset


class SimpleCocoCaptionDataset(Dataset):
    def __init__(self, caption_json_path, image_root, transform, tokenizer):
        with open(caption_json_path, 'r') as f:
            data = json.load(f)

        self.imgid2caption = {}
        for entry in data:
            img_id = entry['image_id']
            captions = entry['captions']
            if captions:
                self.imgid2caption[img_id] = captions[0]  # 첫 번째 캡션만 사용

        self.image_ids = list(self.imgid2caption.keys())
        self.image_root = image_root
        self.transform = transform
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        img_path = os.path.join(self.image_root, f"{img_id:012d}.jpg")

        try:
            image = self.transform(Image.open(img_path).convert("RGB"))
        except (FileNotFoundError, UnidentifiedImageError):
            print(f"[WARN] 이미지 불러오기 실패: {img_path}")
            # 다음 샘플로 넘어가기
            return self.__getitem__((idx + 1) % len(self))

        caption = self.imgid2caption[img_id]
        encoding = self.tokenizer(
            caption,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=32
        )
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        return image, input_ids, attention_mask



In [65]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

train_dataset = SimpleCocoCaptionDataset(
    caption_json_path="coco_token_bbox_matched.json", 
    image_root="/shared/home/kyscap251/Team2/val2017",
    transform=transform,
    tokenizer=tokenizer
)


eval_loader = DataLoader(eval_subset, batch_size=32, shuffle=False,)

In [None]:
### clip모델의 코사인 유사도 평균

In [109]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import os, json
from tqdm import tqdm
import open_clip
from collections import defaultdict

# 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# OpenCLIP 모델 및 전처리
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name='ViT-B-32',
    pretrained='laion2b_s34b_b79k'
)
model = model.to(device)
model.eval()

tokenizer = open_clip.get_tokenizer('ViT-B-32')

# 데이터셋 준비
caption_json_path = "annotations/captions_val2017.json"
image_root = "val2017"

with open(caption_json_path, 'r') as f:
    coco_data = json.load(f)

###

imgid2captions = defaultdict(list)
for ann in coco_data['annotations']:
    img_id = ann['image_id']
    imgid2captions[img_id].append(ann['caption'])

# 나중에 사용할 때
caption = imgid2captions[img_id][0]  # 첫 번째 캡션만 사용

##3333

# COCO 이미지 ID 정렬 및 존재 확인
image_ids = sorted(list(imgid2caption.keys()))
image_paths = {img_id: os.path.join(image_root, f"{img_id:012d}.jpg") for img_id in image_ids}
valid_ids = [img_id for img_id in image_ids if os.path.exists(image_paths[img_id])]

# 평가 개수 제한
max_samples = 5000
valid_ids = valid_ids[:max_samples]


all_similarities = []  # 모든 정답쌍 유사도 저장

for i, img_id in enumerate(tqdm(valid_ids, desc="Checking cosine similarity")):
    img_path = image_paths[img_id]
    caption = imgid2caption[img_id]

    try:
        image = preprocess(Image.open(img_path).convert("RGB")).unsqueeze(0).to(device)
    except Exception as e:
        continue

    text = tokenizer([caption]).to(device)

    with torch.no_grad():
        image_feat = model.encode_image(image)
        text_feat = model.encode_text(text)

    image_feat = F.normalize(image_feat, dim=-1)
    text_feat = F.normalize(text_feat, dim=-1)
    sim = F.cosine_similarity(image_feat, text_feat, dim=-1).item()
    all_similarities.append(sim)

# 평균 유사도 출력
avg_sim = sum(all_similarities) / len(all_similarities)
print(f"\n[OpenCLIP] 전체 정답쌍 평균 cosine similarity: {avg_sim:.4f}")


Checking cosine similarity: 100%|███████████| 1000/1000 [01:36<00:00, 10.38it/s]


[OpenCLIP] 전체 정답쌍 평균 cosine similarity: 0.3112





In [None]:
### WASP의 정답 쌍 코사인 유사도 평균

In [111]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from transformers import BertTokenizer
from PIL import Image
from collections import defaultdict
from tqdm import tqdm
import os, json
import random

# ----------------------------
# 모델 및 환경 설정
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


model = VisionLanguageModel().to(device)
ckpt = torch.load("end_checkpoint_self.pth", map_location=device)
model.load_state_dict(ckpt["model"])
model.eval()

# ----------------------------
# tokenizer 및 transform 정의
# ----------------------------
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ----------------------------
# 데이터 로딩
# ----------------------------
COCO_IMAGE_DIR = "/shared/home/kyscap251/Team2/val2017"
COCO_ANN_PATH = "val_coco_token_bbox_matched.json"

with open(COCO_ANN_PATH, 'r') as f:
    coco_data = json.load(f)

id_to_captions = defaultdict(list)
id_to_filename = {}
for ann in coco_data:
    img_id = ann['image_id']
    if ann['captions']:
        id_to_captions[img_id].extend(ann['captions'])
        id_to_filename[img_id] = f"{img_id:012d}.jpg"

# ----------------------------
# 유효 이미지 추리기
# ----------------------------
valid_ids = [img_id for img_id in id_to_filename if os.path.exists(os.path.join(COCO_IMAGE_DIR, id_to_filename[img_id]))]
selected_ids = valid_ids[:5000]  # 평가 개수 제한

# ----------------------------
# 정답쌍 유사도 측정
# ----------------------------
all_similarities = []

for i, img_id in enumerate(tqdm(selected_ids, desc="Checking cosine similarity (VLM)")):
    img_path = os.path.join(COCO_IMAGE_DIR, id_to_filename[img_id])
    caption = id_to_captions[img_id][0]  # 첫 번째 정답 캡션만 사용

    try:
        image = transform(Image.open(img_path).convert("RGB")).unsqueeze(0).to(device)
    except Exception as e:
        print(f"[SKIP] 이미지 오류: {img_path} | {e}")
        continue

    tokens = tokenizer(caption, return_tensors="pt", padding="max_length", truncation=True, max_length=32)
    input_ids = tokens['input_ids'].to(device)
    attention_mask = tokens['attention_mask'].to(device)

    with torch.no_grad():
        img_feat, txt_feat = model.encode_tokenized_input(image, input_ids, attention_mask)

    img_feat = F.normalize(img_feat, dim=-1)
    txt_feat = F.normalize(txt_feat, dim=-1)
    sim = F.cosine_similarity(img_feat, txt_feat, dim=-1).item()
    all_similarities.append(sim)

# ----------------------------
# 평균 유사도 출력
# ----------------------------
avg_sim = sum(all_similarities) / len(all_similarities)
print(f"\n[VisionLanguageModel] 전체 정답쌍 평균 cosine similarity: {avg_sim:.4f}")


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Checking cosine similarity (VLM): 100%|███████| 500/500 [00:48<00:00, 10.33it/s]


[VisionLanguageModel] 전체 정답쌍 평균 cosine similarity: 0.5326





In [None]:
### WASP의 cosine smiliarity 오답 쌍

In [120]:
batch_size = 32
negative_similarities = []

for i in range(0, len(selected_ids), batch_size):
    batch_ids = selected_ids[i:i+batch_size]
    images = []
    neg_captions = []

    for img_id in batch_ids:
        img_path = os.path.join(COCO_IMAGE_DIR, id_to_filename[img_id])
        if not os.path.exists(img_path):
            continue
        try:
            image = transform(Image.open(img_path).convert("RGB"))
        except:
            continue
        images.append(image)

        # 랜덤한 다른 이미지의 캡션
        while True:
            neg_id = random.choice(selected_ids)
            if neg_id != img_id and id_to_captions[neg_id]:
                break
        neg_captions.append(id_to_captions[neg_id][0])

    if not images or not neg_captions:
        continue

    images = torch.stack(images).to(device)
    tokens = tokenizer(neg_captions, return_tensors="pt", padding="max_length", truncation=True, max_length=32)
    input_ids = tokens['input_ids'].to(device)
    attention_mask = tokens['attention_mask'].to(device)

    with torch.no_grad():
        img_feats, txt_feats = model.encode_tokenized_input(images, input_ids, attention_mask)

    img_feats = F.normalize(img_feats, dim=-1)
    txt_feats = F.normalize(txt_feats, dim=-1)
    sims = F.cosine_similarity(img_feats, txt_feats, dim=-1).tolist()
    negative_similarities.extend(sims)

avg_neg_sim = sum(negative_similarities) / len(negative_similarities)
print(f"\n[VisionLanguageModel] 정답이 아닌 쌍 평균 cosine similarity (5000개 배치 처리): {avg_neg_sim:.4f}")



[VisionLanguageModel] 정답이 아닌 쌍 평균 cosine similarity (5000개 배치 처리): 0.0958


In [None]:
### clip의 image retrieval

In [121]:
import torch
import open_clip
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm


class CocoClipEvaluationDataset(Dataset):
    def __init__(self, caption_json_path, image_root, transform):
        with open(caption_json_path, 'r') as f:
            data = json.load(f)
        self.imgid2caption = {}
        for ann in data['annotations']:
            img_id = ann['image_id']
            self.imgid2caption[img_id] = ann['caption']
        self.image_ids = list(self.imgid2caption.keys())
        self.image_root = image_root
        self.transform = transform

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        img_path = os.path.join(self.image_root, f"{img_id:012d}.jpg")
        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)
        caption = self.imgid2caption[img_id]
        return image, caption  # 2개 값만 반환



def openclip_retrieval_evaluate(model, val_loader, tokenizer, device, max_samples=5000):
    model.eval()
    image_embeds_list = []
    text_embeds_list = []
    n_samples = 0

    with torch.no_grad():
        # 2개 값만 언패킹 (images, texts)
        for images, texts in tqdm(val_loader, desc="Encoding image-text pairs"):
            images = images.to(device)
            # 텍스트 토크나이징
            text_tokens = tokenizer(texts).to(device)
            
            # 임베딩 추출
            image_embeds = model.encode_image(images)
            text_embeds = model.encode_text(text_tokens)
            
            image_embeds_list.append(image_embeds)
            text_embeds_list.append(text_embeds)
            n_samples += images.size(0)
            if n_samples >= max_samples:
                break

    image_embeds_all = torch.cat(image_embeds_list, dim=0)[:max_samples]
    text_embeds_all = torch.cat(text_embeds_list, dim=0)[:max_samples]

    image_embeds_all = F.normalize(image_embeds_all, dim=-1)
    text_embeds_all = F.normalize(text_embeds_all, dim=-1)

    sim_matrix = torch.matmul(image_embeds_all, text_embeds_all.T)

    def compute_recall(sim_matrix, k):
        correct = 0
        for i in range(sim_matrix.size(0)):
            topk = sim_matrix[i].topk(k).indices
            if i in topk:
                correct += 1
        return correct / sim_matrix.size(0)

    recall1 = compute_recall(sim_matrix, 1)
    recall5 = compute_recall(sim_matrix, 5)
    recall10 = compute_recall(sim_matrix, 10)

    print("\n[Image-Text Retrieval Results (Closed-domain)]")
    print(f"Recall@1: {recall1:.4f}")
    print(f"Recall@5: {recall5:.4f}")
    print(f"Recall@10: {recall10:.4f}")

    return recall1, recall5, recall10


In [113]:
# CLIP 전용 transform
_, _, preprocess_clip = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')

# 데이터셋 및 데이터로더 생성
clip_val_dataset = CocoClipEvaluationDataset(
    caption_json_path="annotations/captions_val2017.json",
    image_root="val2017",
    transform=preprocess_clip
)

clip_val_loader = DataLoader(
    clip_val_dataset,
    batch_size=32,
    shuffle=False
)


In [114]:
# 모델 및 토크나이저 초기화
model_clip, _, _ = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer_clip = open_clip.get_tokenizer('ViT-B-32')
device_clip = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_clip = model_clip.to(device_clip)

# 평가 실행
recall1, recall5, recall10 = openclip_retrieval_evaluate(
    model_clip,
    clip_val_loader,
    tokenizer_clip,
    device_clip,
    max_samples=5000
)

Encoding image-text pairs:  99%|█████████████▉| 156/157 [08:08<00:03,  3.13s/it]



[Image-Text Retrieval Results (Closed-domain)]
Recall@1: 0.4108
Recall@5: 0.6770
Recall@10: 0.7716


In [None]:
### 내 모델의 image retrival 측정

In [116]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import BertTokenizer
from PIL import Image
import os, json, random
from tqdm import tqdm
from collections import defaultdict

# ----------------------------
# 모델 및 환경 설정
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VisionLanguageModel().to(device)
ckpt = torch.load("end_checkpoint_self.pth", map_location=device)
model.load_state_dict(ckpt["model"])
model.eval()

# ----------------------------
# tokenizer 및 transform 정의
# ----------------------------
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ----------------------------
# 데이터셋 정의
# ----------------------------
class CustomCocoDataset(Dataset):
    def __init__(self, image_dir, ann_path, transform, tokenizer, max_samples=500):
        with open(ann_path, 'r') as f:
            data = json.load(f)

        self.imgid2caption = {}
        self.imgid2filename = {}
        for entry in data:
            img_id = entry['image_id']
            captions = entry['captions']
            if captions:
                self.imgid2caption[img_id] = captions[0]
                self.imgid2filename[img_id] = f"{img_id:012d}.jpg"

        valid_ids = [img_id for img_id in self.imgid2filename if os.path.exists(os.path.join(image_dir, self.imgid2filename[img_id]))]
        self.selected_ids = valid_ids[:max_samples]

        self.image_dir = image_dir
        self.transform = transform
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.selected_ids)

    def __getitem__(self, idx):
        img_id = self.selected_ids[idx]
        img_path = os.path.join(self.image_dir, self.imgid2filename[img_id])
        image = self.transform(Image.open(img_path).convert("RGB"))

        caption = self.imgid2caption[img_id]
        tokens = self.tokenizer(caption, return_tensors="pt", padding="max_length", truncation=True, max_length=32)
        input_ids = tokens['input_ids'].squeeze(0)
        attention_mask = tokens['attention_mask'].squeeze(0)

        return image, input_ids, attention_mask

# ----------------------------
# 평가 함수 정의
# ----------------------------
def evaluate_recall(model, dataloader, device):
    image_feats, text_feats = [], []

    model.eval()
    with torch.no_grad():
        for images, input_ids, attention_mask in tqdm(dataloader, desc="Encoding"):
            images = images.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            img_feat, txt_feat = model.encode_tokenized_input(images, input_ids, attention_mask)
            image_feats.append(img_feat)
            text_feats.append(txt_feat)

    image_feats = F.normalize(torch.cat(image_feats, dim=0), dim=-1)
    text_feats = F.normalize(torch.cat(text_feats, dim=0), dim=-1)
    sim_matrix = torch.matmul(image_feats, text_feats.T)

    def compute_recall(sim_matrix, k):
        correct = 0
        for i in range(sim_matrix.size(0)):
            topk = sim_matrix[i].topk(k).indices
            if i in topk:
                correct += 1
        return correct / sim_matrix.size(0)

    recall1 = compute_recall(sim_matrix, 1)
    recall5 = compute_recall(sim_matrix, 5)
    recall10 = compute_recall(sim_matrix, 10)

    print("\n[Image-Text Retrieval Results (Your Model)]")
    print(f"Recall@1: {recall1:.4f}")
    print(f"Recall@5: {recall5:.4f}")
    print(f"Recall@10: {recall10:.4f}")

    return recall1, recall5, recall10

# ----------------------------
# 실행
# ----------------------------
dataset = CustomCocoDataset(
    image_dir="/shared/home/kyscap251/Team2/val2017",
    ann_path="val_coco_token_bbox_matched.json",
    transform=transform,
    tokenizer=tokenizer,
    max_samples=5000
)

loader = DataLoader(dataset, batch_size=32, shuffle=False)

recall1, recall5, recall10 = evaluate_recall(model, loader, device)


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Encoding: 100%|█████████████████████████████████| 16/16 [00:48<00:00,  3.02s/it]


[Image-Text Retrieval Results (Your Model)]
Recall@1: 0.0980
Recall@5: 0.2800
Recall@10: 0.4040





In [None]:
### image captioning open clip

In [123]:
import torch
import torch.nn.functional as F
import open_clip
from PIL import Image
from torchvision import transforms
import os, json, random
from tqdm import tqdm
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer


# ----------------------------
# 환경 설정
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CAPTION_JSON_PATH = "annotations/captions_val2017.json"
IMAGE_DIR = "val2017"
max_images = 500  # 평가할 이미지 수
caption_pool_size = 5000  # caption pool 크기

# ----------------------------
# OpenCLIP 모델 및 전처리
# ----------------------------
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
model = model.to(device)
model.eval()
tokenizer = open_clip.get_tokenizer('ViT-B-32')

# ----------------------------
# 캡션 및 이미지 정보 로딩
# ----------------------------
with open(CAPTION_JSON_PATH, 'r') as f:
    coco_data = json.load(f)

imgid2caption = {}
caption_pool = []

for ann in coco_data['annotations']:
    img_id = ann['image_id']
    caption = ann['caption']
    if img_id not in imgid2caption:
        imgid2caption[img_id] = caption
    caption_pool.append(caption)

# caption pool 샘플링
caption_pool = list(set(caption_pool))
random.shuffle(caption_pool)
caption_pool = caption_pool[:caption_pool_size]

# ----------------------------
# caption pool 임베딩
# ----------------------------
with torch.no_grad():
    tokenized_caps = tokenizer(caption_pool).to(device)
    caption_feats = model.encode_text(tokenized_caps)
    caption_feats = F.normalize(caption_feats, dim=-1)

# ----------------------------
# 이미지 인코딩 및 가장 유사한 캡션 추출
# ----------------------------
image_ids = list(imgid2caption.keys())
random.shuffle(image_ids)
image_ids = image_ids[:max_images]

generated = {}  # {img_id: [retrieved_caption]}
references = {}  # {img_id: [gt_caption]}

for img_id in tqdm(image_ids, desc="Generating captions"):
    img_path = os.path.join(IMAGE_DIR, f"{img_id:012d}.jpg")
    try:
        image = preprocess(Image.open(img_path).convert("RGB")).unsqueeze(0).to(device)
    except Exception as e:
        continue

    with torch.no_grad():
        image_feat = model.encode_image(image)
        image_feat = F.normalize(image_feat, dim=-1)

    sims = torch.matmul(image_feat, caption_feats.T).squeeze(0)  # (pool_size,)
    best_idx = sims.argmax().item()
    retrieved_caption = caption_pool[best_idx]

    generated[str(img_id)] = [retrieved_caption]
    references[str(img_id)] = [imgid2caption[img_id]]

# ----------------------------
# CIDEr 점수 계산
# ----------------------------
def wrap_captions(d: dict) -> dict:
    return {
        str(k): [{"caption": c} for c in v] for k, v in d.items()
    }

# Tokenizer용 포맷으로 변환
gts = wrap_captions(references)
res = wrap_captions(generated)

# 평가
tokenizer = PTBTokenizer()
gts_tok = tokenizer.tokenize(gts)
res_tok = tokenizer.tokenize(res)

cider_scorer = Cider()
score, _ = cider_scorer.compute_score(gts_tok, res_tok)

print(f"\n[Indirect Captioning Evaluation - CIDEr Metric (Open-domain)]")
print(f"CIDEr: {score:.4f}")



Generating captions: 100%|████████████████████| 500/500 [00:46<00:00, 10.78it/s]
PTBTokenizer tokenized 6175 tokens at 28956.54 tokens per second.



[Indirect Captioning Evaluation - CIDEr Metric (Open-domain)]
CIDEr: 1.1378


PTBTokenizer tokenized 6187 tokens at 63048.88 tokens per second.


In [None]:
## WASP image captioning

In [131]:
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import BertTokenizer
from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
from pycocoevalcap.cider.cider import Cider
import os, json, random
from collections import defaultdict
from tqdm import tqdm

# ----------------------------
# 환경 설정
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
COCO_IMAGE_DIR = "/shared/home/kyscap251/Team2/val2017"
COCO_ANN_PATH = "val_coco_token_bbox_matched.json"
max_images = 500
caption_pool_size = 5000

# ----------------------------
# 모델 로딩
# ----------------------------
model = VisionLanguageModel().to(device)
ckpt = torch.load("end_checkpoint_self.pth", map_location=device)
model.load_state_dict(ckpt["model"])
model.eval()

# ----------------------------
# tokenizer 및 transform 정의
# ----------------------------
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# ----------------------------
# COCO 캡션 데이터 로딩
# ----------------------------
with open(COCO_ANN_PATH, 'r') as f:
    coco_data = json.load(f)

id_to_captions = defaultdict(list)
id_to_filename = {}

for ann in coco_data:
    img_id = ann['image_id']
    if ann['captions']:
        id_to_captions[img_id].extend(ann['captions'])
        id_to_filename[img_id] = f"{img_id:012d}.jpg"

# caption pool 구성
caption_pool = list(set([cap for caps in id_to_captions.values() for cap in caps]))
random.shuffle(caption_pool)
caption_pool = caption_pool[:caption_pool_size]

# caption pool 임베딩
with torch.no_grad():
    all_caption_feats = []
    for i in range(0, len(caption_pool), 32):
        caps = caption_pool[i:i+32]
        tokenized = tokenizer(caps, return_tensors="pt", padding="max_length", truncation=True, max_length=32)
        input_ids = tokenized["input_ids"].to(device)
        attention_mask = tokenized["attention_mask"].to(device)
        text_feat = model.encode_tokenized_input_text_only(input_ids, attention_mask)

        all_caption_feats.append(text_feat)

    caption_feats = torch.cat(all_caption_feats, dim=0)

# ----------------------------
# 이미지 인코딩 + caption 매칭
# ----------------------------
valid_ids = [img_id for img_id in id_to_filename if os.path.exists(os.path.join(COCO_IMAGE_DIR, id_to_filename[img_id]))]
random.shuffle(valid_ids)
selected_ids = valid_ids[:max_images]

generated = {}   # {img_id: [retrieved_caption]}
references = {}  # {img_id: [gt_caption]}

for img_id in tqdm(selected_ids, desc="Generating captions"):
    img_path = os.path.join(COCO_IMAGE_DIR, id_to_filename[img_id])
    try:
        image = transform(Image.open(img_path).convert("RGB")).unsqueeze(0).to(device)
    except Exception:
        continue

    tokens = tokenizer(id_to_captions[img_id][0], return_tensors="pt", padding="max_length", truncation=True, max_length=32)
    input_ids = tokens["input_ids"].to(device)
    attention_mask = tokens["attention_mask"].to(device)

    with torch.no_grad():
        img_feat, _ = model.encode_tokenized_input(image, input_ids, attention_mask)
        img_feat = F.normalize(img_feat, dim=-1)

    sims = torch.matmul(img_feat, caption_feats.T).squeeze(0)
    best_idx = sims.argmax().item()
    retrieved_caption = caption_pool[best_idx]

    generated[str(img_id)] = [retrieved_caption]
    references[str(img_id)] = [id_to_captions[img_id][0]]

# ----------------------------
# CIDEr 계산
# ----------------------------
def wrap_captions(d):
    return {str(k): [{"caption": c} for c in v] for k, v in d.items()}

gts = wrap_captions(references)
res = wrap_captions(generated)

ptb_tokenizer = PTBTokenizer()
gts_tok = ptb_tokenizer.tokenize(gts)
res_tok = ptb_tokenizer.tokenize(res)

cider_scorer = Cider()
score, _ = cider_scorer.compute_score(gts_tok, res_tok)

print(f"\n[Indirect Captioning Evaluation - CIDEr Metric (Open-domain)]")
print(f"CIDEr: {score:.4f}")


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Generating captions: 100%|████████████████████| 500/500 [00:50<00:00,  9.86it/s]
PTBTokenizer tokenized 6137 tokens at 69041.10 tokens per second.



[Indirect Captioning Evaluation - CIDEr Metric (Open-domain)]
CIDEr: 0.4171


PTBTokenizer tokenized 6096 tokens at 57335.21 tokens per second.


## 