In [None]:
import os
import random
import numpy as np
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as T
from transformers import DistilBertTokenizer
from models import CLIPModel  # models.py 내의 CLIPModel (forward에 return_logits 옵션 추가)
from dataset import Flickr8kDataset
from torch.utils.tensorboard import SummaryWriter


In [None]:
# ========================
# 1. Seed every RNG (for reproducibility)
# ========================
def seed_everything(value: int) -> None:
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU, plus all GPUs when CUDA is present)."""
    random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(value)


seed = 42
seed_everything(seed)

# ========================
# 2. Device and TensorBoard setup
# ========================
# Use the GPU whenever CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Every scalar logged below (train/val/test loss and accuracy) is written to this run directory.
writer = SummaryWriter(log_dir="runs/clip_experiment")


# ========================
# 3. Tokenizer and image preprocessing
# ========================
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# Per-channel normalization statistics. NOTE(review): these look like OpenAI CLIP's published
# image constants — confirm they match what the image encoder in models.CLIPModel expects.
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
IMAGE_SIZE = (224, 224)

_preprocessing_steps = [
    T.Resize(IMAGE_SIZE),                       # fixed square input resolution
    T.ToTensor(),                               # convert to a tensor before normalizing
    T.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
]
transform = T.Compose(_preprocessing_steps)


In [None]:

# ========================
# 4. Dataset and DataLoader creation
# ========================
# All three Flickr8k splits share the image folder, caption file, preprocessing and loader
# settings; only the split file and whether to shuffle differ, so build them with one helper.
BATCH_SIZE = 32
NUM_WORKERS = 4
MAX_CAPTION_LENGTH = 40


def make_flickr8k_split(split_file, shuffle):
    """Build the Flickr8k dataset for one split and wrap it in a DataLoader.

    Uses the module-level `transform` and `tokenizer` defined in the preprocessing cell.

    Args:
        split_file: Path of the split file passed to Flickr8kDataset
            (e.g. 'Flickr_8k.trainImages.txt').
        shuffle: Whether the DataLoader reshuffles samples each epoch.

    Returns:
        A ``(dataset, loader)`` tuple.
    """
    dataset = Flickr8kDataset(
        img_folder='images',
        caption_file='captions.txt',
        split_file=split_file,
        transform=transform,
        tokenizer=tokenizer,
        max_length=MAX_CAPTION_LENGTH,
    )
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=NUM_WORKERS)
    return dataset, loader


# Only the training split is shuffled; validation/test iterate in a fixed order.
train_dataset, train_loader = make_flickr8k_split('Flickr_8k.trainImages.txt', shuffle=True)
val_dataset, val_loader = make_flickr8k_split('Flickr_8k.devImages.txt', shuffle=False)
test_dataset, test_loader = make_flickr8k_split('Flickr_8k.testImages.txt', shuffle=False)


In [None]:

# ========================
# 5. Model and optimizer
# ========================
model = CLIPModel()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Optional sanity check: print the tensor shapes of one test batch.
# Expected roughly: image [32, 3, 224, 224], input_ids [32, 40], mask [32, 40].
for batch in test_loader:
    for _key, _label in (("image", "Image shape:"),
                         ("input_ids", "Input IDs shape:"),
                         ("mask", "Mask shape:")):
        print(_label, batch[_key].shape)
    break

# ========================
# 6. Training and validation loop
# ========================
num_epochs = 5
best_val_loss = float('inf')

for epoch in range(num_epochs):
    # --- Training phase ---
    model.train()
    train_loss = 0.0
    total_train_batches = 0
    for batch in train_loader:
        inputs = {key: batch[key].to(device) for key in ("image", "input_ids", "mask")}

        loss = model(inputs)  # in training mode the model returns only the loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        total_train_batches += 1
    train_loss /= total_train_batches
    writer.add_scalar("Loss/train", train_loss, epoch)

    # --- Validation phase ---
    model.eval()
    val_loss = 0.0
    total_val_batches = 0
    correct_top1 = 0
    correct_top5 = 0
    total_samples = 0
    with torch.no_grad():
        for batch in val_loader:
            inputs = {key: batch[key].to(device) for key in ("image", "input_ids", "mask")}

            # return_logits=True makes the model return the logits alongside the loss.
            loss, logits = model(inputs, return_logits=True)
            val_loss += loss.item()
            total_val_batches += 1

            batch_size = logits.size(0)
            total_samples += batch_size
            # In-batch retrieval: image i is scored as correct only when matched to text i.
            # NOTE(review): if the dataset yields one item per caption, the same image can occur
            # twice in a batch and an equally valid off-diagonal match counts as a miss — confirm.
            labels = torch.arange(batch_size, device=logits.device)
            # Top-1: the highest-scoring text for each image must be its own caption.
            correct_top1 += (logits.argmax(dim=1) == labels).sum().item()
            # Top-5: the true caption must be among the 5 best-scoring texts. topk raises when k
            # exceeds the number of candidates (e.g. a final batch of < 5 samples), so clamp k.
            k = min(5, logits.size(1))
            topk_preds = logits.topk(k, dim=1).indices  # shape: [batch_size, k]
            correct_top5 += (topk_preds == labels.unsqueeze(1)).any(dim=1).sum().item()

    val_loss /= total_val_batches
    top1_acc = 100.0 * correct_top1 / total_samples
    top5_acc = 100.0 * correct_top5 / total_samples
    writer.add_scalar("Loss/val", val_loss, epoch)
    writer.add_scalar("Accuracy/val_top1", top1_acc, epoch)
    writer.add_scalar("Accuracy/val_top5", top5_acc, epoch)

    print(f"[Epoch {epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Top1: {top1_acc:.2f}% | Val Top5: {top5_acc:.2f}%")

    # --- Checkpoint: keep the weights with the lowest validation loss so far ---
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pth")
        print("Best model saved.")



In [None]:

# ========================
# 7. Test loop
# ========================

# --- Test phase ---
# Restore the checkpoint with the lowest validation loss. map_location places the saved
# tensors on the current device regardless of where they were saved from.
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.eval()
test_loss = 0.0
correct_top1 = 0
correct_top5 = 0
total_samples = 0
with torch.no_grad():
    for batch in test_loader:
        inputs = {key: batch[key].to(device) for key in ("image", "input_ids", "mask")}
        loss, logits = model(inputs, return_logits=True)
        test_loss += loss.item()

        batch_size = logits.size(0)
        total_samples += batch_size
        # In-batch retrieval: image i is scored as correct only when matched to text i.
        labels = torch.arange(batch_size, device=logits.device)
        correct_top1 += (logits.argmax(dim=1) == labels).sum().item()
        # topk raises when k exceeds the number of candidates (e.g. a final batch of
        # fewer than 5 samples), so clamp k to the available candidates.
        k = min(5, logits.size(1))
        topk_preds = logits.topk(k, dim=1).indices  # shape: [batch_size, k]
        correct_top5 += (topk_preds == labels.unsqueeze(1)).any(dim=1).sum().item()

test_loss /= len(test_loader)
test_top1_acc = 100.0 * correct_top1 / total_samples
test_top5_acc = 100.0 * correct_top5 / total_samples
writer.add_scalar("Loss/test", test_loss, num_epochs)
writer.add_scalar("Accuracy/test_top1", test_top1_acc, num_epochs)
writer.add_scalar("Accuracy/test_top5", test_top5_acc, num_epochs)
print(f"Test Loss: {test_loss:.4f} | Test Top1 Accuracy: {test_top1_acc:.2f}% | Test Top5 Accuracy: {test_top5_acc:.2f}%")

writer.close()