In [None]:
import json
import os
import csv
from PIL import Image

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from transformers import BertModel, BertTokenizer
import torch.optim as optim
from tqdm import tqdm
from collections import defaultdict

#===========================================================
# Configuration
#===========================================================
# --- Input / output locations ---
TRAIN_JSON_PATH = './train.jsonl'
TRAIN_IMAGES_DIR = './train_images/train_images'
TEST_JSON_PATH = './test.jsonl'
TEST_IMAGES_JSON_PATH = './test_images.jsonl'
TEST_IMAGES_DIR = './test_images/test_images'
OUTPUT_CSV = 'submission.csv'

# --- Training hyper-parameters ---
BATCH_SIZE = 16
EPOCHS = 5
LR = 1e-4
TEMPERATURE = 0.07  # CLIP-style temperature: the similarity logits are divided by it

# --- Retrieval ---
TOP_K = 30  # upper bound on the number of photos retrieved per dialogue

#===========================================================
# Shared helper functions
#===========================================================
def load_jsonl(file_path):
    """Read a JSON Lines file and return its records as a list.

    Every line that is not blank after stripping whitespace is parsed with
    ``json.loads``; blank lines are skipped. The file is read as UTF-8.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        return [json.loads(raw_line) for raw_line in handle if raw_line.strip()]

def merge_dialogue_messages(dialogue):
    """Join all non-blank messages of a dialogue into one string.

    ``dialogue`` is an iterable of turns, each indexable by ``'message'``.
    Messages that are empty or whitespace-only are dropped; the remaining
    ones are joined, unmodified and in order, with a single space. An empty
    string is returned when nothing remains.
    """
    kept_messages = []
    for turn in dialogue:
        text = turn['message']
        if text.strip() != '':
            kept_messages.append(text)
    # Joining an empty list already yields "", so no special case is needed.
    return " ".join(kept_messages)

#===========================================================
# Image and text preprocessing (with data augmentation)
#===========================================================
# Image preprocessing pipeline: resize to 224x224, randomly flip horizontally,
# convert to a tensor, then normalize each channel.
# NOTE(review): the flip is drawn at random on every call (p=0.5), so any
# dataset given this transform yields non-deterministic tensors.
# The mean/std values are presumably the ImageNet statistics expected by the
# pretrained image backbone -- confirm.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),  # data augmentation: random horizontal flip
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])
])

# Tokenizer loaded from the 'bert-base-uncased' checkpoint; it is the default
# tokenizer argument of encode_text().
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def encode_text(text, tokenizer=bert_tokenizer, max_length=128):
    """Tokenize ``text`` into fixed-length PyTorch tensors.

    The tokenizer is called with truncation enabled and padding set to
    ``'max_length'``, so the encoded sequence is ``max_length`` tokens long.
    The tokenizer's output (requested with ``return_tensors='pt'``) is
    returned unchanged.
    """
    tokenizer_options = dict(
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_tensors='pt',
    )
    return tokenizer(text, **tokenizer_options)

#===========================================================
# Dataset definitions
#===========================================================
class DialogueImageTrainDataset(Dataset):
    """Training dataset that pairs each dialogue with its photo.

    Each record of ``data`` must provide the keys ``'dialogue'``,
    ``'photo_path'``, ``'photo_id'`` and ``'dialogue_id'``. ``photo_path`` is
    resolved relative to ``image_dir``.
    """

    def __init__(self, data, image_dir, tokenizer, transform):
        self.data, self.image_dir = data, image_dir
        self.tokenizer, self.transform = tokenizer, transform

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ids, transformed image and tokenized dialogue of record ``idx``.

        Raises:
            FileNotFoundError: if the record's image file does not exist.
        """
        record = self.data[idx]
        # Collapse all dialogue turns into a single text string.
        merged_text = merge_dialogue_messages(record['dialogue'])

        relative_path = record['photo_path']
        photo_id = record['photo_id']

        # Load the picture as RGB and run it through the image transform.
        full_path = os.path.join(self.image_dir, relative_path)
        if not os.path.exists(full_path):
            raise FileNotFoundError(f"Image not found: {full_path}")
        pixels = self.transform(Image.open(full_path).convert('RGB'))

        # Tokenize the text; squeeze(0) drops the leading dimension the
        # tokenizer adds for a single input.
        encoded = encode_text(merged_text, self.tokenizer)

        sample = {
            'dialogue_id': record['dialogue_id'],
            'photo_id': photo_id,
            'image': pixels,
        }
        sample['input_ids'] = encoded['input_ids'].squeeze(0)
        sample['attention_mask'] = encoded['attention_mask'].squeeze(0)
        return sample

class DialogueTestDataset(Dataset):
    """Test dataset holding dialogues only (used as queries when retrieving test images).

    Each record of ``data`` must provide the keys ``'dialogue'`` and
    ``'dialogue_id'``.
    """

    def __init__(self, data, tokenizer):
        self.data, self.tokenizer = data, tokenizer

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the dialogue id and tokenized dialogue text of record ``idx``."""
        record = self.data[idx]
        # Collapse all dialogue turns into a single text string, then tokenize.
        encoded = encode_text(merge_dialogue_messages(record['dialogue']), self.tokenizer)

        sample = {'dialogue_id': record['dialogue_id']}
        # squeeze(0) drops the leading dimension the tokenizer adds for a single input.
        sample['input_ids'] = encoded['input_ids'].squeeze(0)
        sample['attention_mask'] = encoded['attention_mask'].squeeze(0)
        return sample

class TestImageDataset(Dataset):
    """Dataset over the test images, used to compute a feature for every test photo.

    Each record of ``data`` must provide the keys ``'photo_id'`` and
    ``'photo_path'``; ``photo_path`` is resolved relative to ``image_dir``.
    """

    def __init__(self, data, image_dir, transform):
        self.data, self.image_dir, self.transform = data, image_dir, transform

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the photo id and transformed image of record ``idx``.

        Raises:
            FileNotFoundError: if the record's image file does not exist.
        """
        record = self.data[idx]
        photo_id = record['photo_id']
        full_path = os.path.join(self.image_dir, record['photo_path'])
        if not os.path.exists(full_path):
            raise FileNotFoundError(f"Image not found: {full_path}")
        # Load as RGB, then apply the image transform.
        pixels = self.transform(Image.open(full_path).convert('RGB'))
        return {'photo_id': photo_id, 'image': pixels}

#===========================================================
# Model definition (dual-encoder example)
#===========================================================
class DualEncoder(nn.Module):
    """Two-tower model that embeds text and images into one shared space.

    Text goes through a pretrained BERT model and images through a pretrained
    ResNet-50 whose final ``fc`` layer is replaced by an identity. Each tower
    ends in its own linear projection to ``hidden_dim`` (2048) features.
    """

    def __init__(self, text_model_name='bert-base-uncased'):
        super().__init__()
        self.text_encoder = BertModel.from_pretrained(text_model_name)

        # Drop the classification head so the backbone returns pooled features.
        backbone = models.resnet50(pretrained=True)
        backbone.fc = nn.Identity()
        self.image_encoder = backbone

        self.hidden_dim = 2048
        bert_width = self.text_encoder.config.hidden_size
        self.text_proj = nn.Linear(bert_width, self.hidden_dim)
        self.image_proj = nn.Linear(self.hidden_dim, self.hidden_dim)

    def encode_text(self, input_ids, attention_mask):
        """Project the hidden state at sequence position 0 into the shared space."""
        bert_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        first_token_state = bert_output.last_hidden_state[:, 0, :]
        return self.text_proj(first_token_state)

    def encode_image(self, image):
        """Project the image backbone's features into the shared space."""
        return self.image_proj(self.image_encoder(image))

    def forward(self, input_ids, attention_mask, image):
        """Return the ``(text_embedding, image_embedding)`` pair for a batch."""
        return self.encode_text(input_ids, attention_mask), self.encode_image(image)

#===========================================================
# Load data
#===========================================================
train_data = load_jsonl(TRAIN_JSON_PATH)
test_data = load_jsonl(TEST_JSON_PATH)
test_images_data = load_jsonl(TEST_IMAGES_JSON_PATH)

# Training pairs use the augmenting transform and are reshuffled every epoch.
train_dataset = DialogueImageTrainDataset(train_data, TRAIN_IMAGES_DIR, bert_tokenizer, image_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Test dialogues are served one per batch, in file order.
test_dataset = DialogueTestDataset(test_data, bert_tokenizer)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Test images must be embedded deterministically, so they get their own
# transform: same resize and normalization as training, but without the random
# horizontal flip (an augmentation that would make the gallery embeddings
# differ from run to run).
test_image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
test_img_dataset = TestImageDataset(test_images_data, TEST_IMAGES_DIR, test_image_transform)
test_img_loader = DataLoader(test_img_dataset, batch_size=16, shuffle=False)

#===========================================================
# Train the model (CLIP-style contrastive learning, in both directions)
#===========================================================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DualEncoder().to(device)

optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

# Cosine-annealed learning rate, stepped once per epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

model.train()
for epoch in range(EPOCHS):
    epoch_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{EPOCHS}"):
        input_ids = batch['input_ids'].to(device)            # (batch, seq_len)
        attention_mask = batch['attention_mask'].to(device)  # (batch, seq_len)
        images = batch['image'].to(device)                   # (batch, C, H, W)

        optimizer.zero_grad()

        # Forward pass: one embedding per dialogue and one per image.
        text_emb, image_emb = model(input_ids, attention_mask, images)
        # L2-normalize the embeddings. normalize() clamps the norm with a
        # small eps, so an all-zero embedding cannot turn the loss into NaN
        # (a plain division by the norm would).
        text_emb_norm = nn.functional.normalize(text_emb, dim=1)
        image_emb_norm = nn.functional.normalize(image_emb, dim=1)

        # (batch, batch) cosine-similarity matrix scaled by the temperature.
        logits = torch.matmul(text_emb_norm, image_emb_norm.T) / TEMPERATURE
        # Row i's matching image is column i; build the targets on `device`.
        labels = torch.arange(logits.size(0), device=device)

        # Contrastive loss in both directions (text->image and image->text).
        logits_t2i = logits
        logits_i2t = logits.T
        loss_t2i = criterion(logits_t2i, labels)
        loss_i2t = criterion(logits_i2t, labels)

        # Average the two directions.
        loss = (loss_t2i + loss_i2t) / 2

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    scheduler.step()  # advance the learning-rate schedule

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    avg_loss = epoch_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {avg_loss}")

# Switch to evaluation mode and save the trained weights.
model.eval()
torch.save(model.state_dict(), "dual_encoder_model.pth")
print("Model training completed and saved.")

#===========================================================
# Inference (build the predictions for the submission file)
#===========================================================
model.eval()


def _as_python_list(values):
    """Return a collated batch field as a list of plain Python values.

    DataLoader's default collation returns numeric ids as a tensor and string
    ids as a list. Converting tensors with ``tolist()`` gives hashable Python
    numbers, so ids behave the same as dict keys in both cases.
    """
    if torch.is_tensor(values):
        return values.tolist()
    return list(values)


# Embed every test image once; the embeddings stay on `device`.
test_image_embeddings = {}
with torch.no_grad():
    for batch in test_img_loader:
        images = batch['image'].to(device)
        photo_ids = _as_python_list(batch['photo_id'])
        img_emb = model.encode_image(images)  # (batch, hidden_dim)
        # Normalize exactly as in training; eps-clamped, so no NaN on a zero vector.
        img_emb = nn.functional.normalize(img_emb, dim=1)
        for pid, emb in zip(photo_ids, img_emb):
            test_image_embeddings[pid] = emb

predictions = []

all_img_ids = list(test_image_embeddings.keys())
if len(all_img_ids) > 0:
    # (num_imgs, hidden_dim); rows are already L2-normalized.
    all_img_embs = torch.stack([test_image_embeddings[i] for i in all_img_ids]).to(device)
    # topk cannot return more entries than there are images.
    K = min(TOP_K, len(all_img_ids))

    with torch.no_grad():
        for batch in test_loader:
            # Works for any batch size and for numeric or string dialogue ids.
            dialogue_ids = _as_python_list(batch['dialogue_id'])
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            text_emb = model.encode_text(input_ids, attention_mask)  # (batch, hidden_dim)
            text_emb_norm = nn.functional.normalize(text_emb, dim=1)

            similarity = torch.matmul(text_emb_norm, all_img_embs.T)  # (batch, num_imgs)
            _, topk_indices = torch.topk(similarity, K, dim=1)         # (batch, K)

            # Record one (dialogue_id, photo_id) pair per retrieved photo,
            # best match first.
            for dialogue_id, row in zip(dialogue_ids, topk_indices.cpu().tolist()):
                for img_index in row:
                    predictions.append((dialogue_id, all_img_ids[img_index]))
else:
    print("Warning: No test images found. No predictions can be made.")

#===========================================================
# Group the predicted photo_ids by dialogue_id before writing the output
#===========================================================
# Each (dialogue_id, photo_id) pair in `predictions` is appended, in order,
# to the list kept for its dialogue_id.
dialogue_to_photos = defaultdict(list)
for d_id, p_id in predictions:
    dialogue_to_photos[d_id].append(p_id)

def generate_submission(dialogue_to_photos, output_path):
    """Write the predictions to a CSV file, one row per dialogue.

    Args:
        dialogue_to_photos: mapping from a dialogue id to the sequence of
            photo ids predicted for it.
        output_path: path of the CSV file to create (overwritten if present).

    The file starts with the header ``dialogue_id,photo_id``. In each row the
    photo ids are joined with single spaces into one field. Ids are converted
    with ``str()`` first, so integer photo ids are written instead of raising
    ``TypeError`` in ``str.join``.
    """
    with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['dialogue_id', 'photo_id'])
        writer.writerows(
            [d_id, " ".join(map(str, p_ids))]
            for d_id, p_ids in dialogue_to_photos.items()
        )

# Write the submission only when at least one dialogue received predictions.
if not dialogue_to_photos:
    print("No predictions were made.")
else:
    generate_submission(dialogue_to_photos, OUTPUT_CSV)
    print(f"Submission file saved to {OUTPUT_CSV}")

#===========================================================
# Compute Recall@30 (optional; needs ground-truth photo_ids in the test data)
#===========================================================
# Map each dialogue to its ground-truth photo. Entries without a 'photo_id'
# carry no ground truth and are skipped, so unlabeled test data reaches the
# "No ground truth available" message below instead of raising KeyError.
dialogue_to_gt = {
    item['dialogue_id']: item['photo_id']
    for item in test_data
    if item.get('photo_id') is not None
}

# A dialogue counts as a hit when its ground-truth photo is among the photos
# predicted for it; dialogues without predictions count as misses.
total_count = len(dialogue_to_gt)
correct_count = sum(
    1
    for d_id, gt_pid in dialogue_to_gt.items()
    if gt_pid in dialogue_to_photos.get(d_id, [])
)

if total_count > 0:
    recall_at_30 = correct_count / total_count
    print(f"Recall@30: {recall_at_30}")
else:
    print("No ground truth available for recall calculation.")


Training Epoch 1/5: 100%|██████████| 313/313 [02:53<00:00,  1.81it/s]


Epoch 1/5, Loss: 1.862026016171367


Training Epoch 2/5: 100%|██████████| 313/313 [02:50<00:00,  1.83it/s]


Epoch 2/5, Loss: 1.1678033562513968


Training Epoch 3/5: 100%|██████████| 313/313 [02:50<00:00,  1.84it/s]


Epoch 3/5, Loss: 0.690227624255057


Training Epoch 4/5: 100%|██████████| 313/313 [02:49<00:00,  1.84it/s]


Epoch 4/5, Loss: 0.3440193438801331


Training Epoch 5/5:  14%|█▍        | 44/313 [00:23<02:27,  1.83it/s]

In [None]:
#===========================================================
# Compute Recall@30
#===========================================================
# Assumes every entry in test_data has a ground-truth photo_id
dialogue_to_gt = {}
for item in test_data:
    d_id = item['dialogue_id']
    gt_pid = item['photo_id']
    dialogue_to_gt[d_id] = gt_pid

correct_count = 0
total_count = 0

for d_id, gt_pid in dialogue_to_gt.items():
    total_count += 1
    predicted_pids = dialogue_to_photos.get(d_id, [])
    if gt_pid in predicted_pids:
        correct_count += 1

recall_at_30 = correct_count / total_count if t