In [49]:
import os

val = "val2017"
train_path = "/raid/kyscap251/team2/train2017/train2017"
val_path = "/shared/home/kyscap251/Team2/val2017"
# folder_path = "/raid/kyscap251/team2/val2017/val2017"
test = "train2017"

val_items = os.listdir(val_path)
train_items = os.listdir(train_path)

print("val", val_items[:5])
print("train", train_items[:5])

# 파일만 필터링
files = [f for f in os.listdir(val_path)
         if os.path.isfile(os.path.join(val_path, f))]

print(f"val파일 개수: {len(files)}")

filess = [ff for ff in os.listdir(train_path)
         if os.path.isfile(os.path.join(train_path, ff))]
print(f"train파일 개수: {len(filess)}")

val ['000000433103.jpg', '000000129113.jpg', '000000196843.jpg', '000000252507.jpg', '000000258541.jpg']
train ['000000427548.jpg', '000000367442.jpg', '000000574946.jpg', '000000215255.jpg', '000000016119.jpg']
val파일 개수: 5000
train파일 개수: 109933


In [50]:
# 환경 설정 및 라이브러리 로딩
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from transformers import BertTokenizer, BertModel
import json
import numpy as np
from PIL import Image
import os
from torch.utils.data import Dataset, DataLoader
import random
from tqdm import tqdm
from PIL import Image, UnidentifiedImageError, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [51]:
# # 이미지 인코더: ViT
# class VisionEncoder(nn.Module):
#     def __init__(self, output_dim=768):
#         super().__init__()
#         self.vit = models.vit_b_16(pretrained=True)
#         self.vit.heads = nn.Identity()  # classification head 제거
#         self.output_dim = output_dim

#     def forward(self, images):
#         patch_feats = self.vit(images)  # [B, D]
#         return patch_feats
    

from transformers import BertTokenizer, BertModel, ViTModel, ViTFeatureExtractor
from torchvision.models.vision_transformer import vit_b_16, ViT_B_16_Weights


class VisionEncoder(nn.Module):
    def __init__(self, model_name='google/vit-base-patch16-224'):
        super().__init__()
        self.vit = ViTModel.from_pretrained(model_name)

    def forward(self, images):
        outputs = self.vit(pixel_values=images)
        return outputs.last_hidden_state  # [B, 1+P, D]



In [52]:
# 텍스트 인코더: BERT 기반 Transformer
class TextEncoder(nn.Module):
    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.bert = BertModel.from_pretrained(model_name)

    def forward(self, captions):
        tokenized = self.tokenizer(captions, return_tensors="pt", padding=True, truncation=True).to(self.bert.device)
        outputs = self.bert(**tokenized)
        return outputs.last_hidden_state, tokenized
    

In [53]:
# Cross-Attention Block
class CrossAttentionBlock(nn.Module):
    def __init__(self, dim=768, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)

    def forward(self, text_emb, image_patches):
        attn_output, attn_weights = self.attn(text_emb, image_patches, image_patches)
        return attn_output, attn_weights

In [54]:
# # 모델
# class VisionLanguageModel(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.vision_encoder = VisionEncoder()
#         self.text_encoder = TextEncoder()
#         self.cross_attn = CrossAttentionBlock()
#         self.proj_image = nn.Linear(768, 512)
#         self.proj_text = nn.Linear(768, 512)

#     def forward(self, images, captions):
#         img_feat = self.vision_encoder(images)
#         text_emb, tokens = self.text_encoder(captions)
#         patch_feat = img_feat.unsqueeze(1)  # dummy patch feature (B, 1, D)
#         cross_out, attn_weights = self.cross_attn(text_emb, patch_feat)
#         img_proj = self.proj_image(img_feat)
#         text_proj = self.proj_text(text_emb[:, 0])  # CLS token 기준
#         return img_proj, text_proj, attn_weights, tokens

#     def encode_for_inference(self, images, captions):
#         with torch.no_grad():
#             img_feat = self.vision_encoder(images)
#             text_emb, _ = self.text_encoder(captions)
#             img_proj = self.proj_image(img_feat)
#             text_proj = self.proj_text(text_emb[:, 0])  # CLS token 기준
#         return img_proj, text_proj
    
# -------------------------------
# Vision-Language Model
# -------------------------------
class VisionLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_encoder = VisionEncoder()
        self.text_encoder = TextEncoder()
        self.cross_attn = CrossAttentionBlock()
        self.proj_image = nn.Linear(768, 512)
        self.proj_text = nn.Linear(768, 512)

    def forward(self, images, captions):
        img_feat_all = self.vision_encoder(images)  # [B, 1+P, D]
        cls_feat = img_feat_all[:, 0]               # [B, D] CLS
        patch_feat = img_feat_all[:, 1:]            # [B, P, D] patches only

        text_emb, tokens = self.text_encoder(captions)  # [B, T, D]
        cross_out, attn_weights = self.cross_attn(text_emb, patch_feat)  # [B, T, D], [B, T, P]

        img_proj = self.proj_image(cls_feat)        # [B, 512]
        text_proj = self.proj_text(text_emb[:, 0])  # [B, 512]

        return img_proj, text_proj, attn_weights, tokens

    def encode_for_inference(self, images, captions): #랜덤 3개 유사도 평가용
        with torch.no_grad():
            img_feat_all = self.vision_encoder(images)
            cls_feat = img_feat_all[:, 0]
            text_emb, _ = self.text_encoder(captions)
            img_proj = self.proj_image(cls_feat)
            text_proj = self.proj_text(text_emb[:, 0])
        return img_proj, text_proj
    
    def encode_tokenized_input(self, images, input_ids, attention_mask): #평균 유사도 평가용
        with torch.no_grad():
            img_feat_all = self.vision_encoder(images)        # (B, 197, 768)
            cls_feat = img_feat_all[:, 0, :]                  # CLS token만 추출 (B, 768)
            img_proj = self.proj_image(cls_feat)              # (B, 512)

            bert_out = self.text_encoder.bert(input_ids=input_ids, attention_mask=attention_mask)
            txt_cls = bert_out.last_hidden_state[:, 0, :]     # (B, 768)
            txt_proj = self.proj_text(txt_cls)                # (B, 512)

            return img_proj, txt_proj


In [55]:
# Consistency Loss 계산 함수
def compute_consistency_loss(attn_weights, masks, eps=1e-6):
    B, T, H, W = masks.shape
    masks_flat = masks.view(B, T, -1)
    scores = (attn_weights * masks_flat).sum(dim=-1)
    scores = torch.clamp(scores, min=eps, max=1.0)
    return -torch.log(scores).mean()

In [56]:
# CLIP Contrastive Loss

def clip_contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    logits = image_embeds @ text_embeds.T / temperature
    labels = torch.arange(len(image_embeds)).to(image_embeds.device)
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)
    return (loss_i2t + loss_t2i) / 2

In [57]:
# 전처리된 JSON 로딩 및 binary mask 생성
class CocoVLMDataset(Dataset):
    def __init__(self, json_path, image_root, transform=None, patch_size=16, max_tokens=10):
        with open(json_path, 'r') as f:
            all_data = json.load(f)
        self.image_root = image_root
        self.transform = transform
        self.patch_size = patch_size
        self.max_tokens = max_tokens
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        
        self.data = []
        for entry in all_data:
            image_id = entry["image_id"]
            image_path = os.path.join(self.image_root, f"{image_id:012d}.jpg")
            if os.path.exists(image_path):
                self.data.append(entry)

        print(f"유효 이미지 수: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        image_id = entry["image_id"]
        image_path = os.path.join(self.image_root, f"{image_id:012d}.jpg")
        try:
            image = Image.open(image_path).convert("RGB")
        except (FileNotFoundError, UnidentifiedImageError, OSError):
            print(f"[WARN] 이미지 불러오기 실패: {image_path}")
            return self.__getitem__((idx + 1) % len(self))  # 다음 인덱스로 재시도
        if self.transform:
            image = self.transform(image)
        captions = entry["captions"]
        matches = entry["matches"][:self.max_tokens]
        caption = captions[0]
        H, W = 224 // self.patch_size, 224 // self.patch_size
        masks = torch.zeros((self.max_tokens, H, W))
        for i, match in enumerate(matches):
            x, y, w, h = match["bbox"]
            x1 = int(x // self.patch_size)
            y1 = int(y // self.patch_size)
            x2 = int((x + w) // self.patch_size)
            y2 = int((y + h) // self.patch_size)
            masks[i, y1:y2+1, x1:x2+1] = 1.0
        return image, caption, masks


In [58]:
# Collate Function

def coco_collate_fn(batch):
    images = torch.stack([item[0] for item in batch])
    captions = [item[1] for item in batch]
    masks = torch.stack([item[2] for item in batch])
    return images, captions, masks


In [59]:
# DataLoader 생성
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

dataset = CocoVLMDataset(
    json_path="val_coco_token_bbox_matched.json",
#    image_root="/raid/kyscap251/team2/val2017/val2017",
    image_root = "/shared/home/kyscap251/Team2/val2017",
    transform=transform
)

dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=coco_collate_fn)


유효 이미지 수: 5000


In [60]:
class LambdaScheduler: #람다 점진적 높이기
    def __init__(self, alpha=0.2, threshold=0.01, anneal_rate=0.01, max_lambda=0.5):
        self.ema = None
        self.alpha = alpha
        self.threshold = threshold
        self.anneal_rate = anneal_rate
        self.max_lambda = max_lambda
        self.apply = False
        self.step = 0
        self.current_lambda = 0.0

    def update(self, attn_weights, return_diff=False):
        """
        attn_weights: Tensor of shape [B, T, P] (cross-attn weights)
        """
        attn_mean = attn_weights.mean(dim=1).mean(dim=0)  # [P]
        if self.ema is None:
            self.ema = attn_mean
        else:
            self.ema = self.alpha * attn_mean + (1 - self.alpha) * self.ema

        diff = torch.abs(self.ema - attn_mean).mean().item()

        if not self.apply and diff < self.threshold:
            self.apply = True
            print(f"[LambdaScheduler] Consistency loss ON (diff={diff:.6f})")

        if self.apply:
            self.current_lambda = min(self.max_lambda, self.anneal_rate * self.step)
            self.step += 1
        else:
            self.current_lambda = 0.0

        if return_diff:
            return self.current_lambda, diff
        else:
            return self.current_lambda

In [61]:
# 전체 Epoch 학습 루프 (Matched Token만 사용하도록 수정)
def train_model(model, dataloader, optimizer, device, lambda_cons=0.05, num_epochs=5):
    model.train()
#     ////////////
    lambda_scheduler = LambdaScheduler(alpha=0.2, threshold=0.01, anneal_rate=0.01, max_lambda=0.5)
#     ////////////////////
#     lambda_log = [] #λ 변화 기록용 리스트
#     diff_log = [] #diff값 기록용 리스트
    
    for epoch in range(num_epochs):
        total_loss = 0
        total_acc = 0
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
        
        for images, captions, masks in progress:
            images, masks = images.to(device), masks.to(device)
            img_proj, txt_proj, attn_weights, _ = model(images, captions)

            # matched token 수만큼 attention weight 슬라이싱
            T_mask = masks.shape[1]
            attn_weights_matched = attn_weights[:, :T_mask, :] #bbox에 해당하는 token만 consistency loss 계산에 사용

            
#             //////////////
            # 현재 λ 값을 EMA 기반으로 계산 및 기록
            lambda_cons, diff = lambda_scheduler.update(attn_weights_matched, return_diff=True)
#             lambda_log.append(lambda_cons)
#             diff_log.append(diff)
# /////////////////////////////////////
            
            # 손실 계산
            loss_contrastive = clip_contrastive_loss(img_proj, txt_proj)
            loss_consistency = compute_consistency_loss(attn_weights_matched, masks)
            loss = loss_contrastive + lambda_cons * loss_consistency

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # cosine similarity + accuracy 계산
            with torch.no_grad():
                sim_matrix = F.cosine_similarity(img_proj.unsqueeze(1), txt_proj.unsqueeze(0), dim=-1)
                sims = torch.diag(sim_matrix)
                sim_mean = sims.mean().item()
                sim_std = sims.std().item()

                pred = sim_matrix.argmax(dim=1)
                labels = torch.arange(sim_matrix.size(0)).to(device)
                acc = (pred == labels).float().mean().item()

            total_loss += loss.item()
            total_acc += acc
            progress.set_postfix({"loss": loss.item(), "cos_sim": f"{sim_mean:.3f}±{sim_std:.3f}", "acc": f"{acc:.3f}"})

#         # scheduler 업데이트
#         if scheduler is not None:
#             scheduler.step()
#             print(f"[Epoch {epoch+1}] LR: {scheduler.get_last_lr()[0]:.2e}")
        avg_loss = total_loss / len(dataloader)
        avg_acc = total_acc / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs} - Avg Loss: {avg_loss:.4f}, Avg Accuracy: {avg_acc:.4f}")


In [62]:
# # 모델 학습 실행

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = VisionLanguageModel().to(device)
# optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# train_model(model, dataloader, optimizer, device, lambda_cons=0.05, num_epochs=3)


In [63]:
# 랜덤 3개 추론
def run_batch_inference(model, val_image_dir, caption_json_path, transform, device, sample_size=3):
    with open(caption_json_path, 'r') as f:
        coco_captions = json.load(f)

    # image_id → caption 매핑
    imgid2caption = {}
    for ann in coco_captions['annotations']:
        imgid = ann['image_id']
        if imgid not in imgid2caption:
            imgid2caption[imgid] = []
        imgid2caption[imgid].append(ann['caption'])

    # 랜덤 샘플링 (image_id 3개)
    img_ids = random.sample(list(imgid2caption.keys()), sample_size)
    captions = [imgid2caption[i][0] for i in img_ids]
    image_paths = [os.path.join(val_image_dir, f"{i:012d}.jpg") for i in img_ids]
    images_tensor = torch.stack([
        transform(Image.open(p).convert("RGB")) for p in image_paths
    ]).to(device)

    # 인코딩
    model.eval()
    image_embeds, text_embeds = model.encode_for_inference(images_tensor, captions)

    # 코사인 유사도
    sim_matrix = F.cosine_similarity(image_embeds.unsqueeze(1), text_embeds.unsqueeze(0), dim=-1)

    # 출력
    print("image_embeds shape:", image_embeds.shape)
    print("text_embeds shape :", text_embeds.shape)
    print("\n\U0001F4CA Cosine Similarity Matrix:\n")
    for i, img_id in enumerate(img_ids):
        print(f"\U0001F5BC️ {img_id:012d}.jpg")
        for j, cap in enumerate(captions):
            print(f"  \"{cap}\" → similarity: {sim_matrix[i, j]:.4f}")
        print()

In [64]:
run_batch_inference(
    model,
    val_image_dir="val2017",
    caption_json_path="annotations/captions_val2017.json",
    transform=transform,
    device=device,
    sample_size=3 #랜덤으로 이미지 세장
)

image_embeds shape: torch.Size([3, 512])
text_embeds shape : torch.Size([3, 512])

📊 Cosine Similarity Matrix:

🖼️ 000000185472.jpg
  "The train is approaching and a man is getting off his bicycle." → similarity: 0.2968
  "A filtered image of a microwave available to use in a store. " → similarity: -0.1456
  "A fat orange cat on a couch beside a TV remote" → similarity: -0.1866

🖼️ 000000222455.jpg
  "The train is approaching and a man is getting off his bicycle." → similarity: -0.1226
  "A filtered image of a microwave available to use in a store. " → similarity: 0.7122
  "A fat orange cat on a couch beside a TV remote" → similarity: 0.0761

🖼️ 000000271728.jpg
  "The train is approaching and a man is getting off his bicycle." → similarity: -0.0907
  "A filtered image of a microwave available to use in a store. " → similarity: 0.1085
  "A fat orange cat on a couch beside a TV remote" → similarity: 0.6556



In [65]:
from torch.utils.data import Dataset
from PIL import Image
import os
import json

class SimpleCocoCaptionDataset(Dataset):
    def __init__(self, caption_json_path, image_root, transform, tokenizer):
        with open(caption_json_path, 'r') as f:
            data = json.load(f)

        self.imgid2caption = {}
        for ann in data['annotations']:
            img_id = ann['image_id']
            if img_id not in self.imgid2caption:
                self.imgid2caption[img_id] = ann['caption']

        self.image_ids = list(self.imgid2caption.keys())
        self.image_root = image_root
        self.transform = transform
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        img_path = os.path.join(self.image_root, f"{img_id:012d}.jpg")
        image = self.transform(Image.open(img_path).convert("RGB"))
        caption = self.imgid2caption[img_id]

        encoding = self.tokenizer(caption, return_tensors="pt", padding="max_length", truncation=True, max_length=32)
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        return image, input_ids, attention_mask


In [66]:
from torch.utils.data import DataLoader

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# val_dataset = SimpleCocoCaptionDataset(
#     caption_json_path="annotations/captions_val2017.json",
#     image_root="val2017",
#     transform=transform,
#     tokenizer=tokenizer
# )

# val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

val_dataset = SimpleCocoCaptionDataset(
    caption_json_path="annotations/captions_val2017.json",
    image_root="val2017",
    transform=transform,
    tokenizer=tokenizer
)

val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


In [67]:
# 정답,오답쌍 평균 유사도 계산 추론
def evaluate_mean_similarity(model, val_loader, device):
    model.eval()
    total_correct_sim = 0.0
    total_incorrect_sim = 0.0
    correct_count = 0
    incorrect_count = 0

    with torch.no_grad():
        for images, input_ids, attention_mask in tqdm(val_loader, desc="Evaluating"):
            images = images.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            image_feat, text_feat = model.encode_tokenized_input(images, input_ids, attention_mask)  # (B, 512)

            sim_matrix = F.cosine_similarity(
                image_feat.unsqueeze(1),  # (B, 1, D)
                text_feat.unsqueeze(0),  # (1, B, D)
                dim=-1
            )  # (B, B)

            B = sim_matrix.size(0)
            correct_sims = sim_matrix.diag()
            total_correct_sim += correct_sims.sum().item()
            correct_count += B

            mask = ~torch.eye(B, dtype=torch.bool, device=device)
            incorrect_sims = sim_matrix[mask]
            total_incorrect_sim += incorrect_sims.sum().item()
            incorrect_count += incorrect_sims.numel()

    mean_correct = total_correct_sim / correct_count
    mean_incorrect = total_incorrect_sim / incorrect_count

    print(f"\nEvaluation Results:")
    print(f" - Mean Correct Sim    : {mean_correct:.4f}")
    print(f" - Mean Incorrect Sim  : {mean_incorrect:.4f}")
    return mean_correct, mean_incorrect


In [68]:
mean_correct, mean_incorrect = evaluate_mean_similarity(model, val_loader, device)

Evaluating: 100%|█████████████████████████████| 157/157 [08:47<00:00,  3.36s/it]


Evaluation Results:
 - Mean Correct Sim    : 0.6390
 - Mean Incorrect Sim  : 0.0512





In [69]:
print("후기단계 시작")

후기단계 시작


In [70]:
def kl_divergence_attention(prev_attn_scalar, curr_attn, eps=1e-6):
    """
    prev_attn_scalar: float
    curr_attn: [B, T, P]
    """
    curr_mean = torch.clamp(curr_attn, min=eps).mean().item()
    return abs(prev_attn_scalar - curr_mean)

In [71]:
def update_ema(prev_ema, new_val, decay=0.9):
    """
    new_val: [B, T, P] → scalar
    prev_ema: scalar
    """
    new_val_mean = new_val.mean().item()  # float
    if prev_ema is None:
        return new_val_mean
    return decay * prev_ema + (1 - decay) * new_val_mean


In [72]:
# # === Contrastive Loss ===
# def clip_contrastive_loss(image_embeds, text_embeds, temperature=0.07):
#     image_embeds = F.normalize(image_embeds, dim=-1)
#     text_embeds = F.normalize(text_embeds, dim=-1)
#     logits = image_embeds @ text_embeds.T / temperature
#     labels = torch.arange(len(image_embeds)).to(image_embeds.device)
#     loss_i2t = F.cross_entropy(logits, labels)
#     loss_t2i = F.cross_entropy(logits.T, labels)
#     return (loss_i2t + loss_t2i) / 2

# # === Consistency Loss ===
# def compute_consistency_loss(attn_weights, masks, eps=1e-6):
#     B, T, H, W = masks.shape
#     masks_flat = masks.view(B, T, -1)
#     scores = (attn_weights * masks_flat).sum(dim=-1)
#     scores = torch.clamp(scores, min=eps, max=1.0)
#     return -torch.log(scores).mean()

In [73]:
# import torch
# import torch.nn.functional as F
# from torch.optim import AdamW
# from tqdm import tqdm

# # === 1. EMA 기반 안정 시점 추정 함수 ===
# def train_model_with_ema(model, dataloader, optimizer, device, lambda_cons=0.05, ema_decay=0.9, num_epochs=10, std_threshold=0.01):
#     model.train()
#     ema_attn_scalar = None
#     ema_scores_history = []
#     stable_epoch = 0
    
#     for epoch in range(num_epochs):
#         total_loss, total_acc = 0, 0
#         progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
#         ema_epoch_scores = []
        
#         for images, captions, masks in progress:
#             images, masks = images.to(device), masks.to(device)
#             img_proj, txt_proj, attn_weights, _ = model(images, captions)
#             loss_contrastive = clip_contrastive_loss(img_proj, txt_proj)
#             T_mask = masks.shape[1]
            
#             loss_bbox_cons = compute_consistency_loss(attn_weights[:, :T_mask, :], masks)
#             curr_attn_mean = attn_weights.mean().item()
#             ema_epoch_scores.append(curr_attn_mean)
#             loss = loss_contrastive + lambda_cons * loss_bbox_cons
            
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
            
#             with torch.no_grad():
#                 sim_matrix = F.cosine_similarity(img_proj.unsqueeze(1), txt_proj.unsqueeze(0), dim=-1)
#                 pred = sim_matrix.argmax(dim=1)
#                 labels = torch.arange(sim_matrix.size(0)).to(device)
#                 acc = (pred == labels).float().mean().item()
                
#             total_loss += loss.item()
#             total_acc += acc
#             progress.set_postfix({"loss": loss.item(), "acc": f"{acc:.3f}"})
            
#         if ema_attn_scalar is None:
#             ema_attn_scalar = np.mean(ema_epoch_scores)
            
#         else:
#             ema_attn_scalar = ema_decay * ema_attn_scalar + (1 - ema_decay) * np.mean(ema_epoch_scores)
#         ema_scores_history.append(ema_attn_scalar)
        
#         if len(ema_scores_history) > 3:
#             ema_scores_history.pop(0)
            
#         if len(ema_scores_history) == 3:
#             std_ema = np.std(ema_scores_history)
            
#             if std_ema < std_threshold:
#                 stable_epoch = epoch
                
#                 print(f"[Warm-up 종료 기준 충족] Epoch: {epoch} (std_ema={std_ema:.6f})")
#         print(f"[Epoch {epoch+1}] Avg Loss: {total_loss / len(dataloader):.4f}, Avg Acc: {total_acc / len(dataloader):.4f}")
#     print(f"\nEMA 기반 최종 Warm-up 종료 기준 (3회 std_ema < {std_threshold}): {stable_epoch}")
#     return stable_epoch

# import numpy as np
# import torch
# import torch.nn.functional as F
# from tqdm import tqdm
# from torch.optim import AdamW

# def train_model_with_ema(model, dataloader, optimizer, device,
#                          lambda_cons=0.05, ema_decay=0.9,
#                          num_epochs=10, std_threshold=0.01):
#     model.train()
    
#     ema_attn_scalar = None
#     ema_scores_history = []
#     stable_epoch = 0
    
#     for epoch in range(num_epochs):
#         total_loss, total_acc = 0, 0
#         progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
        
#         for images, captions, masks in progress:
#             images, masks = images.to(device), masks.to(device)
            
#             # 모델 forward
#             img_proj, txt_proj, attn_weights, _ = model(images, captions)
#             loss_contrastive = clip_contrastive_loss(img_proj, txt_proj)
#             T_mask = masks.shape[1]
#             loss_bbox_cons = compute_consistency_loss(attn_weights[:, :T_mask, :], masks)
#             loss = loss_contrastive + lambda_cons * loss_bbox_cons
            
#             # Optimizer update
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
            
#             # Cosine similarity accuracy 계산
#             with torch.no_grad():
#                 sim_matrix = F.cosine_similarity(img_proj.unsqueeze(1), txt_proj.unsqueeze(0), dim=-1)
#                 pred = sim_matrix.argmax(dim=1)
#                 labels = torch.arange(sim_matrix.size(0)).to(device)
#                 acc = (pred == labels).float().mean().item()
                
#             total_loss += loss.item()
#             total_acc += acc
#             progress.set_postfix({"loss": loss.item(), "acc": f"{acc:.3f}"})
            
#             # === EMA 업데이트 (배치 단위로!) ===
#             curr_attn_mean = attn_weights.mean().item()
#             if ema_attn_scalar is None:
#                 ema_attn_scalar = curr_attn_mean
#             else:
#                 ema_attn_scalar = ema_decay * ema_attn_scalar + (1 - ema_decay) * curr_attn_mean
        
#         # === Epoch마다 EMA 기록 및 안정화 판단 ===
#         ema_scores_history.append(ema_attn_scalar)
#         if len(ema_scores_history) > 3:
#             ema_scores_history.pop(0)
        
#         if len(ema_scores_history) == 3:
#             std_ema = np.std(ema_scores_history)
#             if std_ema < std_threshold:
#                 stable_epoch = epoch
#                 print(f"[Warm-up 종료 기준 충족] Epoch: {epoch} (std_ema={std_ema:.6f})")
        
#         # Epoch별 평균 로그
#         avg_loss = total_loss / len(dataloader)
#         avg_acc = total_acc / len(dataloader)
#         print(f"[Epoch {epoch+1}] Avg Loss: {avg_loss:.4f}, Avg Acc: {avg_acc:.4f}")
    
#     print(f"\nEMA 기반 최종 Warm-up 종료 기준 (3회 std_ema < {std_threshold}): {stable_epoch}")
#     return stable_epoch


In [74]:
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

def train_model_with_best_epoch(model, dataloader, optimizer, device,
                                lambda_cons=0.05, num_epochs=5):
    model.train()
    
    best_loss = float('inf')
    best_acc = 0.0
    best_epoch = 0
    
    for epoch in range(num_epochs):
        total_loss, total_acc = 0, 0
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
        
        for images, captions, masks in progress:
            images, masks = images.to(device), masks.to(device)
            
            # 모델 forward
            img_proj, txt_proj, attn_weights, _ = model(images, captions)
            loss_contrastive = clip_contrastive_loss(img_proj, txt_proj)
            T_mask = masks.shape[1]
            loss_bbox_cons = compute_consistency_loss(attn_weights[:, :T_mask, :], masks)
            loss = loss_contrastive + lambda_cons * loss_bbox_cons
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Cosine similarity accuracy 계산
            with torch.no_grad():
                sim_matrix = F.cosine_similarity(img_proj.unsqueeze(1), txt_proj.unsqueeze(0), dim=-1)
                pred = sim_matrix.argmax(dim=1)
                labels = torch.arange(sim_matrix.size(0)).to(device)
                acc = (pred == labels).float().mean().item()
            
            total_loss += loss.item()
            total_acc += acc
            progress.set_postfix({"loss": loss.item(), "acc": f"{acc:.3f}"})
        
        # Epoch별 평균
        avg_loss = total_loss / len(dataloader)
        avg_acc = total_acc / len(dataloader)
        print(f"[Epoch {epoch+1}] Avg Loss: {avg_loss:.4f}, Avg Acc: {avg_acc:.4f}")
        
        # 가장 성능 좋은 epoch를 업데이트
        if avg_loss < best_loss or (avg_loss == best_loss and avg_acc > best_acc):
            best_loss = avg_loss
            best_acc = avg_acc
            best_epoch = epoch
    
    print(f"\n Best Warm-up Epoch: {best_epoch} (Loss: {best_loss:.4f}, Acc: {best_acc:.4f})")
    return best_epoch


In [75]:
class EMA:
    def __init__(self, alpha=0.2):
        self.alpha = alpha
        self.ema = None

    def update(self, value):
        with torch.no_grad():
            v = value.mean().item()
            if self.ema is None:
                self.ema = v
            else:
                self.ema = self.alpha * v + (1 - self.alpha) * self.ema
        return self.ema

    def reset(self):
        self.ema = None

In [76]:
def train_model(model, dataloader, optimizer, device,
                lambda_self=0.1, warmup_epochs=3, num_epochs=5,
                start_epoch=0, ema_tracker=None):

    model.train()

    if ema_tracker is None:
        ema_tracker = EMA(alpha=0.2)
    use_ema = False
    lambda_scheduler = LambdaScheduler(alpha=0.2, threshold=0.01, anneal_rate=0.01, max_lambda=0.5)

    for epoch in range(start_epoch, start_epoch + num_epochs):
        total_loss, total_acc = 0, 0
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{start_epoch + num_epochs}", leave=False)

        if epoch == warmup_epochs:
            ema_tracker.reset()
            use_ema = True
            print(f"[Epoch {epoch}] EMA 기준 self-consistency 시작")

        for images, captions, masks in progress:
            images, masks = images.to(device), masks.to(device)
            img_proj, txt_proj, attn_weights, _ = model(images, captions)

            T_mask = masks.shape[1]
            attn_weights_matched = attn_weights[:, :T_mask, :]
            loss_contrastive = clip_contrastive_loss(img_proj, txt_proj)

            if epoch < warmup_epochs:
                lambda_cons, _ = lambda_scheduler.update(attn_weights_matched, return_diff=True)
                loss_consistency = compute_consistency_loss(attn_weights_matched, masks)
                loss = loss_contrastive + lambda_cons * loss_consistency
            else:
                ema_attn = ema_tracker.update(attn_weights_matched)
                loss_self_cons = kl_divergence_attention(ema_attn, attn_weights_matched.detach())
                loss = loss_contrastive + lambda_self * loss_self_cons

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                sim_matrix = F.cosine_similarity(img_proj.unsqueeze(1), txt_proj.unsqueeze(0), dim=-1)
                pred = sim_matrix.argmax(dim=1)
                labels = torch.arange(sim_matrix.size(0)).to(device)
                acc = (pred == labels).float().mean().item()

            total_loss += loss.item()
            total_acc += acc
            progress.set_postfix({"loss": loss.item(), "acc": f"{acc:.3f}"})

        print(f"Epoch {epoch+1} - Avg Loss: {total_loss / len(dataloader):.4f}, Avg Acc: {total_acc / len(dataloader):.4f}")

    return ema_tracker



In [77]:
# # 3. 전체 실행

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = VisionLanguageModel().to(device)
# optimizer = AdamW(model.parameters(), lr=5e-5)

# # Step 1: attention 안정 시점 추정
# stable_epoch = train_model_with_ema(
#     model, dataloader, optimizer, device,
#     lambda_cons=0.05,
#     ema_decay=0.9,
#     num_epochs=10
# )

# # Step 2: 이후 self-consistency 중심 학습
# ema_tracker = train_model(
#     model, dataloader, optimizer, device,
#     lambda_self=0.1,
#     warmup_epochs=stable_epoch,
#     num_epochs=10,
#     start_epoch=stable_epoch
# )

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisionLanguageModel().to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)

# Warm-up 단계
stable_epoch = train_model_with_best_epoch(
    model, dataloader, optimizer, device,
    lambda_cons=0.05,
    num_epochs=5
)

# 이후 단계 (Self-consistency)
ema_tracker = train_model(
    model, dataloader, optimizer, device,
    lambda_self=0.1,
    warmup_epochs=stable_epoch,
    num_epochs=10,
    start_epoch=stable_epoch
)


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                                                

[Epoch 1] Avg Loss: 0.9357, Avg Acc: 0.8904


                                                                                

[Epoch 2] Avg Loss: 0.7615, Avg Acc: 0.9484


                                                                                

[Epoch 3] Avg Loss: 0.7248, Avg Acc: 0.9634


                                                                                

[Epoch 4] Avg Loss: 0.7193, Avg Acc: 0.9656


                                                                                

[Epoch 5] Avg Loss: 0.7178, Avg Acc: 0.9670

✅ Best Warm-up Epoch: 4 (Loss: 0.7178, Acc: 0.9670)


Epoch 5/14:   0%|                                      | 0/1250 [00:00<?, ?it/s]

[Epoch 4] EMA 기준 self-consistency 시작


                                                                                

Epoch 5 - Avg Loss: 0.0916, Avg Acc: 0.9638


Epoch 6/14:  99%|███▉| 1240/1250 [24:14<00:09,  1.05it/s, loss=0.013, acc=1.000]