In [None]:
import os
import random
import numpy as np
import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as T
from transformers import DistilBertTokenizer
from models import CLIPModel  # models.py 내의 CLIPModel (forward에 return_logits 옵션 추가)
from dataset import Flickr8kDataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import matplotlib.pyplot as plt


In [3]:
# ========================
# 1. Seeds (for reproducibility)
# ========================
# One seed value drives every RNG this notebook touches, applied in this order:
# Python's `random`, NumPy's legacy global generator, then torch (CPU).
# NOTE(review): cuDNN is not switched to deterministic mode here, so GPU runs
# may still differ slightly between executions.
seed = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
del _seed_rng

_cuda_ok = torch.cuda.is_available()
if _cuda_ok:
    # Also seed every visible CUDA device.
    torch.cuda.manual_seed_all(seed)

# ========================
# 2. Device selection
# ========================
device = torch.device("cuda") if _cuda_ok else torch.device("cpu")

# ========================
# 3. Tokenizer and image preprocessing
# ========================
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# Per-channel normalization statistics (the values published with OpenAI CLIP's
# image preprocessing), applied after resizing to 224x224 and tensor conversion.
_clip_mean = [0.48145466, 0.4578275, 0.40821073]
_clip_std = [0.26862954, 0.26130258, 0.27577711]
transform = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=_clip_mean, std=_clip_std),
    ]
)


In [4]:

# ========================
# 4. Dataset and DataLoader construction
# ========================
# All three Flickr8k splits share the image folder, caption file, preprocessing
# and tokenizer; only the split list (and whether to shuffle) differs, so one
# helper builds each split instead of three copy-pasted constructor calls.
IMG_FOLDER = 'images'
CAPTION_FILE = 'captions.txt'
MAX_LENGTH = 40   # token length passed to Flickr8kDataset
BATCH_SIZE = 32
# NOTE(review): the traceback captured under the training cell shows a DataLoader
# worker process (multiprocessing "spawn") re-importing dataset.py, whose class
# body calls DistilBertTokenizer.from_pretrained, and being interrupted inside the
# Hub download. If workers stall on your machine, try NUM_WORKERS = 0.
NUM_WORKERS = 4


def _make_split(split_file, shuffle):
    """Build the Flickr8k dataset for one split and wrap it in a DataLoader.

    Args:
        split_file: path of the split list, e.g. 'Flickr_8k.trainImages.txt'.
        shuffle: whether the loader reshuffles each epoch (True for train only).

    Returns:
        A ``(dataset, loader)`` tuple.
    """
    dataset = Flickr8kDataset(
        img_folder=IMG_FOLDER,
        caption_file=CAPTION_FILE,
        split_file=split_file,
        transform=transform,
        tokenizer=tokenizer,
        max_length=MAX_LENGTH,
    )
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=NUM_WORKERS)
    return dataset, loader


train_dataset, train_loader = _make_split('Flickr_8k.trainImages.txt', shuffle=True)
val_dataset, val_loader = _make_split('Flickr_8k.devImages.txt', shuffle=False)
test_dataset, test_loader = _make_split('Flickr_8k.testImages.txt', shuffle=False)


In [None]:

# ========================
# 5. Model and optimizer
# ========================
model = CLIPModel().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Optional sanity check: print the tensor shapes of one batch.
for batch in test_loader:
    print("Image shape:", batch["image"].shape)       # e.g. [32, 3, 224, 224]
    print("Input IDs shape:", batch["input_ids"].shape) # e.g. [32, 40]
    print("Mask shape:", batch["mask"].shape)           # e.g. [32, 40]
    break

# Per-epoch loss history (one entry per completed epoch, used by the plot cell).
train_loss_history = []
val_loss_history = []

# ========================
# 6. Training / validation loop with early stopping
# ========================
num_epochs = 30
PATIENCE = 5     # stop after this many consecutive epochs without val-loss improvement
LOG_EVERY = 10   # refresh the tqdm postfix with the current loss every N steps
best_val_loss = float('inf')
# Must exist before the loop: it used to be first assigned only in the
# "improved" branch, so a non-improving first epoch (e.g. a NaN val loss,
# since NaN < inf is False) raised NameError on `patience_counter += 1`.
patience_counter = 0


def _to_model_inputs(batch):
    """Move one DataLoader batch to `device` and pack it into the dict passed to the model."""
    return {
        "image": batch["image"].to(device),
        "input_ids": batch["input_ids"].to(device),
        "mask": batch["mask"].to(device),
    }


def _topk_hits(logits, k):
    """Count rows i whose column i is among that row's top-k scores.

    Uses the same labelling convention as the top-1 metric: the positive for
    row i is assumed to be column i of `logits` (TODO confirm against
    CLIPModel). `k` is clamped to the number of columns so a final batch with
    fewer than k candidates does not make `topk` raise.
    """
    labels = torch.arange(logits.size(0), device=logits.device)
    topk_idx = logits.topk(min(k, logits.size(1)), dim=1).indices
    return (topk_idx == labels.unsqueeze(1)).any(dim=1).sum().item()


for epoch in range(num_epochs):
    # --- Training phase ---
    model.train()
    train_loss = 0.0
    total_train_batches = 0
    pbar = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]")
    for step, batch in enumerate(pbar, start=1):
        loss = model(_to_model_inputs(batch))  # in train mode the model returns only the loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        total_train_batches += 1

        if step % LOG_EVERY == 0:
            pbar.set_postfix({"Loss": loss.item()})
    train_loss /= total_train_batches
    train_loss_history.append(train_loss)

    # --- Validation phase ---
    model.eval()
    val_loss = 0.0
    total_val_batches = 0
    correct_top1 = 0
    correct_top5 = 0
    total_samples = 0
    with torch.no_grad():
        for batch in val_loader:
            # return_logits=True makes the model return (loss, logits)
            loss, logits = model(_to_model_inputs(batch), return_logits=True)
            val_loss += loss.item()
            total_val_batches += 1

            batch_size = logits.size(0)
            total_samples += batch_size
            labels = torch.arange(batch_size, device=logits.device)
            # Top-1: the highest-scoring text for image i must be text i.
            correct_top1 += (logits.argmax(dim=1) == labels).sum().item()
            # Top-5: text i must appear among image i's five best-scoring texts.
            correct_top5 += _topk_hits(logits, 5)

    val_loss /= total_val_batches
    top1_acc = 100.0 * correct_top1 / total_samples
    top5_acc = 100.0 * correct_top5 / total_samples
    val_loss_history.append(val_loss)

    print(f"[Epoch {epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Top1: {top1_acc:.2f}% | Val Top5: {top5_acc:.2f}%")

    # Checkpoint whenever the validation loss improves; otherwise count a strike.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pth")
        print("Best model saved.")
        patience_counter = 0
    else:
        patience_counter += 1
        print(f"No improvement for {patience_counter} epoch(s).")

    # Early stopping after PATIENCE consecutive non-improving epochs.
    if patience_counter >= PATIENCE:
        print(f"Validation loss did not improve for {PATIENCE} consecutive epochs. Early stopping triggered.")
        break



Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/anaconda3/envs/py310/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/opt/anaconda3/envs/py310/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
  File "/Users/yang/Desktop/CLIP_implementation/dataset.py", line 9, in <module>
    class Flickr8kDataset(Dataset):
  File "/Users/yang/Desktop/CLIP_implementation/dataset.py", line 21, in Flickr8kDataset
    tokenizer=DistilBertTokenizer.from_pretrained('distilbert-base-uncased'),
  File "/opt/anaconda3/envs/py310/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 1971, in from_pretrained
    resolved_config_file = cached_file(
  File "/opt/anaconda3/envs/py310/lib/python3.10/site-packages/transformers/utils/hub.py", line 342, in cached_file
    resolved_file = hf_hub_download(
  File "/opt/anaconda3/envs/py3

KeyboardInterrupt: 

In [None]:

# ========================
# 7. Test evaluation
# ========================
# Reload the checkpoint with the lowest validation loss. map_location remaps the
# stored tensors onto `device`, so a checkpoint written on a GPU can be restored
# in a CPU-only kernel (plain torch.load would try to deserialize onto CUDA).
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.eval()
test_loss = 0.0
correct_top1 = 0
correct_top5 = 0
total_samples = 0
with torch.no_grad():
    for batch in test_loader:
        inputs = {
            "image": batch["image"].to(device),
            "input_ids": batch["input_ids"].to(device),
            "mask": batch["mask"].to(device),
        }
        loss, logits = model(inputs, return_logits=True)
        test_loss += loss.item()

        batch_size = logits.size(0)
        total_samples += batch_size
        # Positive for row i is assumed to be column i (same convention as validation).
        labels = torch.arange(batch_size, device=logits.device)
        correct_top1 += (logits.argmax(dim=1) == labels).sum().item()
        # Clamp k so a final batch with fewer than 5 candidates does not make topk raise.
        k = min(5, logits.size(1))
        top5_preds = logits.topk(k, dim=1).indices
        correct_top5 += (top5_preds == labels.unsqueeze(1)).any(dim=1).sum().item()

test_loss /= len(test_loader)
test_top1_acc = 100.0 * correct_top1 / total_samples
test_top5_acc = 100.0 * correct_top5 / total_samples

print(f"Test Loss: {test_loss:.4f} | Test Top1 Accuracy: {test_top1_acc:.2f}% | Test Top5 Accuracy: {test_top5_acc:.2f}%")



In [None]:
# Plot the training / validation loss per epoch.
# The x-axis is derived from the recorded history rather than num_epochs: when
# early stopping fires, fewer than num_epochs values exist, and plotting them
# against range(1, num_epochs + 1) raises a length-mismatch ValueError.
# Both histories are appended once per completed epoch, so they have equal length.
epochs_run = range(1, len(train_loss_history) + 1)
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(epochs_run, train_loss_history, label="Train Loss")
ax.plot(epochs_run, val_loss_history, label="Validation Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Loss per Epoch")
ax.legend()
ax.grid(True)
plt.show()