In [None]:
import re
import random
import time
from statistics import mode
from collections import Counter
from PIL import Image
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import torch.optim as optim
from torchvision.models import resnet152, ResNet152_Weights
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel, CLIPProcessor, CLIPModel
import gc
from torch.cuda.amp import GradScaler, autocast
import os
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange
import copy

In [1]:
def set_seed(seed):
    """Seed every RNG used by the notebook and force deterministic cuDNN kernels.

    Seeds Python's ``random``, NumPy's global RNG, the PyTorch CPU RNG, the
    current CUDA device and all CUDA devices (in that order), then disables
    cuDNN autotuning so convolution algorithm selection is reproducible.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def process_text(text):
    """Normalize a question/answer string for vocabulary and class-mapping lookup.

    Steps, applied in order:
      1. lowercase;
      2. spelled-out numbers zero..ten -> digits;
      3. remove periods that have no digit on either side;
      4. remove the standalone articles "a", "an", "the";
      5. re-insert apostrophes for a fixed set of contractions ("dont" -> "don't", ...);
      6. replace every other non-word character except ' and : with a space;
      7. drop whitespace before commas and collapse runs of whitespace.

    NOTE(review): steps 2 and 5 are plain substring replacements, not whole-word
    ones, so e.g. "someone" -> "some1" and "often" -> "of10". Step 6 also turns
    decimal points kept by step 3 into spaces ("3.5" -> "3 5"). Answer lookups
    against the class mapping depend on this exact output, so confirm before
    changing it.
    """
    number_words = (
        ('zero', '0'), ('one', '1'), ('two', '2'), ('three', '3'), ('four', '4'),
        ('five', '5'), ('six', '6'), ('seven', '7'), ('eight', '8'), ('nine', '9'),
        ('ten', '10'),
    )
    contraction_fixes = (
        ("dont", "don't"), ("isnt", "isn't"), ("arent", "aren't"), ("wont", "won't"),
        ("cant", "can't"), ("wouldnt", "wouldn't"), ("couldnt", "couldn't"),
    )

    normalized = text.lower()
    for spelled, digit in number_words:
        normalized = normalized.replace(spelled, digit)

    normalized = re.sub(r'(?<!\d)\.(?!\d)', '', normalized)
    normalized = re.sub(r'\b(a|an|the)\b', '', normalized)

    for bare, fixed in contraction_fixes:
        normalized = normalized.replace(bare, fixed)

    normalized = re.sub(r"[^\w\s':]", ' ', normalized)
    normalized = re.sub(r'\s+,', ',', normalized)
    return re.sub(r'\s+', ' ', normalized).strip()

def collate_fn(batch, answer_pad_value=0):
    """Collate VQADataset samples into padded batch tensors, dropping None samples.

    Args:
        batch: list of samples ``(image, input_ids, attention_mask, answers,
            mode_answer)`` or ``None`` (the dataset returns None for samples
            whose answers are all unmapped).
        answer_pad_value: value used to right-pad the variable-length
            ``answers`` rows. Defaults to 0 for backward compatibility.
            NOTE: 0 is also a valid class id, so padded slots are
            indistinguishable from real class-0 answers. Pass a negative value
            (e.g. -1) together with consumers that ignore negative entries to
            avoid that.

    Returns:
        ``(images, input_ids, attention_masks, padded_answers, mode_answers)``,
        or ``None`` when every sample in ``batch`` was None, so the training
        loop can skip the batch instead of crashing on an empty ``zip``.
    """
    samples = [item for item in batch if item is not None]
    if not samples:
        return None

    images, input_ids, attention_masks, answers, mode_answers = zip(*samples)

    images = torch.stack(images)
    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=0)
    attention_masks = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True, padding_value=0)

    max_len = max(len(a) for a in answers)
    padded_answers = torch.full((len(answers), max_len), answer_pad_value, dtype=torch.long)
    for row, answer_ids in enumerate(answers):
        # Slice assignment copies the data, so no explicit clone is needed.
        padded_answers[row, :len(answer_ids)] = answer_ids

    mode_answers = torch.stack(mode_answers)

    return images, input_ids, attention_masks, padded_answers, mode_answers

class VQADataset(torch.utils.data.Dataset):
    """Training/validation VQA dataset backed by a JSON annotation file.

    With ``answer=True`` each item is ``(image, input_ids, attention_mask,
    answers, mode_answer)``. ``answers`` is a LongTensor with the class ids of
    the annotator answers found in ``class_mapping`` (after ``process_text``),
    and ``mode_answer`` is the most common of them. Items whose answers are all
    unmapped come back as ``None``; presumably the collate function drops them.
    With ``answer=False`` each item is ``(image, input_ids, attention_mask)``.
    """

    def __init__(self, df_path, image_dir, class_mapping, transform=None, answer=True):
        self.transform = transform
        self.image_dir = image_dir
        self.df = pd.read_json(df_path)
        self.answer = answer

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # CLIP processor; NOTE(review): loaded but not used anywhere in this class.
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        self.class_mapping = class_mapping
        self.answer2idx = {row['answer']: row['class_id'] for _, row in class_mapping.iterrows()}
        self.idx2answer = {row['class_id']: row['answer'] for _, row in class_mapping.iterrows()}

        self.question2idx = {}
        self.idx2question = {}

        # NOTE(review): questions are read from a column named "w", while
        # TestVQADataset reads "question"; confirm the training JSON really uses "w".
        for question in self.df["w"]:
            question = process_text(question)
            words = question.split(" ")
            for word in words:
                if word not in self.question2idx:
                    self.question2idx[word] = len(self.question2idx)
        self.idx2question = {v: k for k, v in self.question2idx.items()}

    def _load_item(self, idx):
        """Load sample ``idx`` without any error recovery."""
        # convert("RGB") guarantees 3 channels (matching TestVQADataset): grayscale
        # or RGBA files would otherwise yield tensors that cannot be stacked with
        # the rest of the batch. The context manager closes the underlying file.
        with Image.open(f"{self.image_dir}/{self.df['image'][idx]}") as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)

        question_text = self.df["w"][idx]
        question_enc = self.tokenizer(question_text, padding='max_length', truncation=True, max_length=64, return_tensors="pt")

        input_ids = question_enc['input_ids'].squeeze()
        attention_mask = question_enc['attention_mask'].squeeze()

        if not self.answer:
            return image, input_ids, attention_mask

        answers = [self.answer2idx.get(process_text(answer["answer"]), -1) for answer in self.df["answers"][idx]]
        answers = [answer for answer in answers if answer != -1]
        if not answers:
            # None of the annotator answers exist in the class mapping: skip this sample.
            return None

        mode_answer_idx = Counter(answers).most_common(1)[0][0]

        return image, input_ids, attention_mask, torch.tensor(answers, dtype=torch.long), torch.tensor(mode_answer_idx, dtype=torch.long)

    def __getitem__(self, idx):
        """Return sample ``idx``; if loading it raises, fall forward to the next index.

        Tries ``idx``, then ``idx + 1``, ... (wrapping around) as before, but
        iteratively and at most once per sample: when every sample fails a
        RuntimeError is raised (chained to the last error) instead of recursing
        until RecursionError. Each skipped sample emits a warning so broken files
        are visible instead of silently swallowed.

        Raises:
            IndexError: the dataset is empty.
            RuntimeError: no sample could be loaded.
        """
        import warnings  # not imported by the notebook's first import cell

        n_items = len(self.df)
        if n_items == 0:
            raise IndexError("VQADataset is empty")

        candidate = idx
        last_error = None
        for _ in range(n_items):
            try:
                return self._load_item(candidate)
            except Exception as err:  # deliberate best-effort: skip unreadable samples
                last_error = err
                warnings.warn(f"VQADataset: skipping sample {candidate}: {err!r}")
                candidate = (candidate + 1) % n_items
        raise RuntimeError(f"VQADataset: no loadable sample found starting from index {idx}") from last_error

    def __len__(self):
        return len(self.df)

def VQA_criterion(batch_pred: torch.Tensor, batch_answers: torch.Tensor) -> float:
    """Mean VQA accuracy of a batch of predicted class ids.

    For each sample, the standard leave-one-annotator-out VQA accuracy is used:
    for every annotator ``i``, count how many *other* annotators gave the
    predicted answer, score ``min(count / 3, 1)``, and average over annotators.
    The average is over the sample's actual number of answers rather than a
    hard-coded 10, so rows with fewer answers can still reach 1.0. Results are
    identical to the previous version when a row has exactly 10 answers.

    Args:
        batch_pred: shape ``(batch,)`` predicted class ids.
        batch_answers: shape ``(batch, n_answers)`` annotator class ids.
            Negative entries are treated as padding and ignored.

    Returns:
        Mean accuracy over the batch as a Python float (0.0 for an empty batch).
    """
    # One host transfer up front instead of a device sync per element comparison.
    preds = torch.as_tensor(batch_pred).tolist()
    answer_rows = torch.as_tensor(batch_answers).tolist()
    if not preds:
        return 0.0

    total_acc = 0.
    for pred, row in zip(preds, answer_rows):
        answers = [ans for ans in row if ans >= 0]
        if not answers:
            continue  # no usable answers: contributes 0 accuracy
        hits = answers.count(pred)
        # Excluding annotator i removes one hit exactly when answers[i] == pred.
        acc = sum(min((hits - (ans == pred)) / 3, 1) for ans in answers)
        total_acc += acc / len(answers)

    return total_acc / len(preds)

class MultiModalAttention(nn.Module):
    """Self-attention + linear projection over one fused feature vector per sample.

    Each sample's ``(hidden_size,)`` feature vector is treated as a sequence of
    length 1. Attention is computed within each sample only. Note that with a
    single token the softmax weight is trivially 1, so the attention reduces to
    its value/output projections.

    The attention module is created with ``batch_first=True``. With the PyTorch
    default (``batch_first=False``) the ``(batch, 1, hidden)`` input was read
    as ``(seq=batch, batch=1, hidden)``, which made every sample attend to every
    other sample in the batch, so predictions depended on batch composition.
    ``batch_first`` adds no parameters, so existing state dicts still load. Models
    trained with the old behaviour may score differently under the fixed one.
    """

    def __init__(self, hidden_size, num_heads=8):
        super(MultiModalAttention, self).__init__()
        self.attention = nn.MultiheadAttention(hidden_size, num_heads=num_heads, batch_first=True)
        self.fc = nn.Linear(hidden_size, hidden_size)

    def forward(self, combined_features):
        """Map ``(batch, hidden_size)`` features to ``(batch, hidden_size)``."""
        tokens = combined_features.unsqueeze(1)  # (batch_size, 1, hidden_size)
        attn_output, _ = self.attention(tokens, tokens, tokens, need_weights=False)
        return self.fc(attn_output.squeeze(1))

class VQAModel(nn.Module):
    """Fuses ResNet-152, BERT and CLIP (image + text) features into answer logits.

    The four feature vectors are concatenated (512 + 768 + 512 + 512 = 2304
    dims). The result goes through MultiModalAttention and an MLP head, and a
    linear "residual" head is applied directly to the concatenation. The output
    is the sum of the two heads: logits of shape [batch_size, n_answer].
    """

    def __init__(self, vocab_size: int, n_answer: int):
        """Build the backbones and heads.

        Args:
            vocab_size: NOTE(review): accepted but not used by this class.
            n_answer: number of answer classes (size of the output logits).
        """
        super(VQAModel, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.resnet = resnet152(weights=ResNet152_Weights.IMAGENET1K_V1)
        # Replace the ImageNet classifier with a 512-dim projection.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, 512)

        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

        # Output dimension of each feature extractor
        resnet_feat_dim = 512
        bert_feat_dim = 768
        clip_vision_dim = 512
        clip_text_dim = 512

        hidden_size = resnet_feat_dim + bert_feat_dim + clip_vision_dim + clip_text_dim  # 2304
        self.multi_modal_attention = MultiModalAttention(hidden_size=hidden_size)
        # NOTE(review): BatchNorm1d raises in train mode for a batch of size 1;
        # the training loaders do not use drop_last, so a 1-sample final batch would fail.
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(1024, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(p=0.5),
            nn.Linear(512, n_answer)
        )

        self.residual = nn.Linear(hidden_size, n_answer)  # residual head: concatenated features straight to logits

    def forward(self, image, input_ids, attention_mask):
        """Return answer logits of shape [batch_size, n_answer].

        Args:
            image: image batch fed unchanged to both ResNet and CLIP.
                NOTE(review): no normalization is applied here; confirm the
                caller's transform matches what both backbones expect.
            input_ids / attention_mask: BERT-tokenized question (from the
                dataset's BertTokenizer).
                NOTE(review): the same BERT token ids are also passed to CLIP's
                text encoder, whose vocabulary differs, so the CLIP text
                features are computed from mismatched tokens. Tokenize
                separately with CLIP's tokenizer to fix.
        """
        # Image features from ResNet
        img_feat = self.resnet(image)  # [batch_size, 512]

        # Question features from BERT (pooled [CLS] output)
        question_feat = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output  # [batch_size, 768]

        # Image features from CLIP
        clip_img_feat = self.clip_model.get_image_features(pixel_values=image)  # [batch_size, 512]

        # Text features from CLIP (see NOTE above about the token ids)
        text_inputs = self.clip_model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)  # [batch_size, 512]

        # Concatenate all features
        combined_feat = torch.cat([img_feat, question_feat, clip_img_feat, text_inputs], dim=1)  # [batch_size, 2304]

        # Debug prints for checking feature dimensions (disabled)
        # print(f"img_feat: {img_feat.shape}")
        # print(f"question_feat: {question_feat.shape}")
        # print(f"clip_img_feat: {clip_img_feat.shape}")
        # print(f"text_inputs: {text_inputs.shape}")
        # print(f"combined_feat: {combined_feat.shape}")

        attention_output = self.multi_modal_attention(combined_feat)
        output_fc = self.fc(attention_output)
        output_residual = self.residual(combined_feat)

        return output_fc + output_residual  # sum of the MLP head and the residual linear head


def soft_target(answers, n_classes, device):
    """Turn one sample's annotator answer ids into a normalized class distribution.

    Args:
        answers: 1-D LongTensor (or sequence of ints) of class ids. Negative
            entries are treated as padding and ignored.
        n_classes: length of the returned distribution.
        device: device of the returned tensor.

    Returns:
        Tensor of shape ``(n_classes,)`` in the default float dtype whose entry
        ``c`` is the fraction of (non-padding) answers equal to ``c``. When no
        answers remain, an all-zero vector is returned instead of NaNs.

    Raises:
        IndexError: an answer id is ``>= n_classes``.
    """
    idx = torch.as_tensor(answers, dtype=torch.long, device=device).flatten()
    idx = idx[idx >= 0]
    if idx.numel() and int(idx.max()) >= n_classes:
        raise IndexError(f"soft_target: answer id {int(idx.max())} out of range for {n_classes} classes")

    # One vectorized histogram instead of a Python loop with per-element device indexing.
    counts = torch.bincount(idx, minlength=n_classes).to(torch.get_default_dtype())
    # clamp_min(1) leaves non-empty rows unchanged and maps an empty row to zeros.
    return counts / counts.sum().clamp_min(1)

class KLDivergenceLoss(nn.Module):
    """KL divergence between predicted answer probabilities and the annotator distribution.

    ``forward(pred, answers, n_classes)`` turns each row of ``answers`` into a
    soft target with ``soft_target``, applies ``log_softmax`` to the logits
    ``pred`` along dim 1, and returns ``KLDivLoss`` with ``'batchmean'`` reduction.
    """

    def __init__(self):
        super().__init__()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, pred, answers, n_classes):
        targets = [soft_target(row, n_classes, pred.device) for row in answers]
        return self.kl_loss(F.log_softmax(pred, dim=1), torch.stack(targets))

def train_eval(model, dataloader, optimizer, criterion, device, phase='train', scaler=None):
    """Run one training or evaluation pass over ``dataloader``.

    Args:
        model: VQA model; put in train() or eval() mode according to ``phase``.
        dataloader: yields ``(image, input_ids, attention_mask, answers,
            mode_answer)`` batches, or ``None`` for batches the collate
            function emptied (those are skipped).
        optimizer: stepped only when ``phase == 'train'``.
        criterion: called as ``criterion(pred, answers, n_classes)``.
        device: device every batch tensor is moved to.
        phase: ``'train'`` enables gradients and optimizer steps; any other
            value evaluates.
        scaler: optional GradScaler. Pass the same one on every call to keep the
            calibrated loss scale across epochs; a fresh one is created when None.

    Returns:
        ``(mean loss, mean VQA accuracy, mean exact-match accuracy against the
        mode answer)``, averaged over the batches actually processed. Skipped
        ``None`` batches no longer dilute the averages. All zeros when no batch
        was processed.
    """
    is_train = phase == 'train'
    if is_train:
        model.train()
    else:
        model.eval()
    if scaler is None:
        scaler = GradScaler()

    total_loss = 0.0
    total_acc = 0.0
    simple_acc = 0.0
    n_batches = 0

    with torch.set_grad_enabled(is_train):
        for batch in tqdm(dataloader, desc=f"{phase.capitalize()}ing"):
            if batch is None:
                continue
            image, input_ids, attention_mask, answers, mode_answer = (t.to(device) for t in batch)

            optimizer.zero_grad()
            with autocast():
                pred = model(image, input_ids, attention_mask)
                loss = criterion(pred, answers, pred.size(1))

            if is_train:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

            pred_idx = pred.argmax(1)
            total_loss += loss.item()
            total_acc += VQA_criterion(pred_idx, answers)
            simple_acc += (pred_idx == mode_answer).float().mean().item()
            n_batches += 1

            del image, input_ids, attention_mask, answers, mode_answer, pred, pred_idx, loss

    # Release cached memory once per pass; doing gc.collect()/empty_cache() after
    # every batch stalls each step without lowering peak usage.
    torch.cuda.empty_cache()
    gc.collect()

    if n_batches == 0:
        return 0.0, 0.0, 0.0
    return total_loss / n_batches, total_acc / n_batches, simple_acc / n_batches

def train_and_visualize(model, n_epochs, optimizer, scheduler, dataloader_train, dataloader_valid, device):
    """Train for ``n_epochs``, restore the best checkpoint, and plot the curves.

    Each epoch runs a training pass and a validation pass with a KL-divergence
    loss, prints the metrics, and steps ``scheduler``. The weights with the
    highest validation VQA accuracy are kept, with ties broken by lower
    validation loss, and loaded back into ``model`` at the end. The training and
    validation loss curves and the validation accuracy curve are then shown
    side by side.
    """
    criterion = KLDivergenceLoss()
    model.train()

    history = {"train_loss": [], "valid_loss": [], "valid_acc": []}
    best_acc = 0.0
    best_loss = float('inf')
    best_weights = copy.deepcopy(model.state_dict())

    for epoch in trange(n_epochs, desc="Epochs"):
        train_loss, train_acc, _ = train_eval(model, dataloader_train, optimizer, criterion, device, phase='train')
        valid_loss, valid_acc, _ = train_eval(model, dataloader_valid, optimizer, criterion, device, phase='eval')

        history["train_loss"].append(train_loss)
        history["valid_loss"].append(valid_loss)
        history["valid_acc"].append(valid_acc)

        print(f'Epoch {epoch+1}/{n_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Valid Loss: {valid_loss:.4f}, Valid Acc: {valid_acc:.4f}')

        # Keep the weights with the best validation accuracy (lower loss breaks ties).
        improved = valid_acc > best_acc or (valid_acc == best_acc and valid_loss < best_loss)
        if improved:
            best_acc, best_loss = valid_acc, valid_loss
            best_weights = copy.deepcopy(model.state_dict())

        scheduler.step()

    model.load_state_dict(best_weights)

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    loss_ax.plot(history["train_loss"], label='Train Loss')
    loss_ax.plot(history["valid_loss"], label='Valid Loss')
    loss_ax.legend()
    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Loss Curve')

    acc_ax.plot(history["valid_acc"], label='Valid Accuracy')
    acc_ax.legend()
    acc_ax.set_xlabel('Epochs')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.set_title('Validation Accuracy Curve')

    fig.tight_layout()
    plt.show()


##### 交差検証

In [8]:
from sklearn.model_selection import KFold
import warnings

# Build train/validation DataLoaders over index subsets of one dataset.
def get_train_valid_loader(vqa_dataset, train_indices, valid_indices, batch_size=256, num_workers=None):
    """Return ``(train_loader, valid_loader)`` over ``Subset``s of ``vqa_dataset``.

    Args:
        vqa_dataset: the full dataset; both subsets share it.
            NOTE(review): because both subsets share the same dataset object,
            validation samples also go through the dataset's (augmenting)
            transform. Confirm this is intended.
        train_indices / valid_indices: index arrays selecting each split.
        batch_size: batch size for both loaders.
        num_workers: DataLoader worker count. ``None`` means ``os.cpu_count()``,
            falling back to 0 (main process) when the CPU count cannot be
            determined (``os.cpu_count()`` may return None).
    """
    if num_workers is None:
        num_workers = os.cpu_count() or 0

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True, collate_fn=collate_fn)
    train_loader = DataLoader(Subset(vqa_dataset, train_indices), shuffle=True, **loader_kwargs)
    valid_loader = DataLoader(Subset(vqa_dataset, valid_indices), shuffle=False, **loader_kwargs)

    return train_loader, valid_loader

# Run K-fold cross-validation training
def cross_validation_training(vqa_dataset, class_mapping_df, n_splits=5, num_epoch=15, device=None):
    """Train one VQAModel per K-fold split and return the trained models.

    For each fold a fresh model is trained with AdamW + cosine annealing via
    ``train_and_visualize``, which leaves the best-validation weights loaded.
    The fold is then re-evaluated and its VQA accuracy recorded. Per-fold and
    mean accuracies are printed at the end.

    Args:
        vqa_dataset: dataset to split (only its length is used for splitting).
        class_mapping_df: class mapping; its length sets the number of answers.
        n_splits: number of folds.
        num_epoch: epochs per fold (also the cosine-annealing period).
        device: target device. ``None`` selects CUDA when available, else CPU,
            instead of relying on a notebook-global ``device`` defined in a
            later cell.

    Returns:
        List of the ``n_splits`` trained models (left on ``device``).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    fold_results = []
    models = []

    # KFold only needs the sample count; splitting an index array avoids relying
    # on sklearn's handling of a torch Dataset and yields the same folds.
    sample_indices = np.arange(len(vqa_dataset))
    for fold, (train_indices, valid_indices) in enumerate(kfold.split(sample_indices)):
        print(f"Fold {fold + 1}/{n_splits}")

        train_loader, valid_loader = get_train_valid_loader(vqa_dataset, train_indices, valid_indices)

        model = VQAModel(vocab_size=len(vqa_dataset.question2idx) + 1, n_answer=len(class_mapping_df)).to(device)

        optimizer = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.0001)
        scheduler = CosineAnnealingLR(optimizer, T_max=num_epoch)

        train_and_visualize(model, num_epoch, optimizer, scheduler, train_loader, valid_loader, device)

        # Record this fold's validation accuracy with the restored best weights.
        _, valid_acc, _ = train_eval(model, valid_loader, optimizer, KLDivergenceLoss(), device, phase='eval')
        fold_results.append(valid_acc)

        # Keep the trained model itself. The local name is deleted right below,
        # so a deepcopy only doubled peak memory. Drop leftover gradients so
        # stored models do not hold gradient buffers.
        model.zero_grad(set_to_none=True)
        models.append(model)

        # Free per-fold objects
        del model, optimizer, scheduler, train_loader, valid_loader
        torch.cuda.empty_cache()
        gc.collect()

    for fold, acc in enumerate(fold_results):
        print(f"Fold {fold + 1}: Valid Acc: {acc:.4f}")

    print(f"Average Valid Acc: {np.mean(fold_results):.4f}")

    return models


In [None]:
# Configuration, dataset preparation and the cross-validated training run.
set_seed(42)
device = "cuda" if torch.cuda.is_available() else "cpu"

DATA_ROOT = "/mnt/newdisk"

# NOTE(review): this augmenting transform is used for both the training and
# validation folds (they are Subsets of the same dataset), and it has no
# Normalize step before ResNet/CLIP. Confirm both are intended.
train_transform_steps = [
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
]
transform = transforms.Compose(train_transform_steps)

class_mapping_df = pd.read_csv(f"{DATA_ROOT}/data_annotations_class_mapping.csv")
vqa_dataset = VQADataset(
    df_path=f"{DATA_ROOT}/train.json",
    image_dir=f"{DATA_ROOT}/train",
    class_mapping=class_mapping_df,
    transform=transform,
)

# Cross-validated training: one model per fold.
models = cross_validation_training(vqa_dataset, class_mapping_df, n_splits=5, num_epoch=12)

In [None]:
class TestVQADataset(torch.utils.data.Dataset):
    """Inference-time VQA dataset yielding ``(image, input_ids, attention_mask)``.

    Questions are BERT-tokenized to a fixed length of 64 and images are loaded
    as RGB. The vocabulary and answer dictionaries are first built from this
    split. ``update_dict`` replaces them with another dataset's (typically the
    training set's) so that class ids map back to the same answers.
    """

    def __init__(self, df_path, image_dir, class_mapping, transform=None):
        self.df = pd.read_json(df_path)
        self.image_dir = image_dir
        self.transform = transform

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # NOTE(review): loaded but never used by this class.
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

        self.class_mapping = class_mapping
        answer2idx, idx2answer = {}, {}
        for _, row in class_mapping.iterrows():
            answer2idx[row['answer']] = row['class_id']
            idx2answer[row['class_id']] = row['answer']
        self.answer2idx = answer2idx
        self.idx2answer = idx2answer

        # Word vocabulary in first-seen order over the normalized questions.
        vocab = {}
        for question in self.df["question"]:
            for word in process_text(question).split(" "):
                vocab.setdefault(word, len(vocab))
        self.question2idx = vocab
        self.idx2question = {index: word for word, index in vocab.items()}

    def update_dict(self, dataset):
        """Share ``dataset``'s vocabulary and answer dictionaries (by reference)."""
        self.question2idx = dataset.question2idx
        self.answer2idx = dataset.answer2idx
        self.idx2question = dataset.idx2question
        self.idx2answer = dataset.idx2answer

    def __getitem__(self, idx):
        image = Image.open(f"{self.image_dir}/{self.df.iloc[idx]['image']}").convert("RGB")
        if self.transform:
            image = self.transform(image)

        encoded = self.tokenizer(
            self.df["question"][idx],
            padding='max_length',
            truncation=True,
            max_length=64,
            return_tensors="pt",
        )
        return image, encoded['input_ids'].squeeze(), encoded['attention_mask'].squeeze()

    def __len__(self):
        return len(self.df)

In [None]:
def ensemble_predictions(models, test_loader, device, weights=None):
    """Weighted average of the raw logits each model produces over ``test_loader``.

    Args:
        models: non-empty sequence of trained models (already on ``device``).
        test_loader: yields ``(image, input_ids, attention_mask)`` batches in a
            fixed order (no shuffling) so rows align across models.
        device: device the batches are moved to.
        weights: one weight per model; ``None`` means equal weights.

    Returns:
        ``np.ndarray`` of shape ``(n_samples, n_classes)``, float32, holding
        ``sum(w_i * logits_i) / sum(w_i)``.

    Raises:
        ValueError: ``models`` is empty or ``len(weights) != len(models)``.
            Previously ``zip`` silently dropped the extra models or weights while
            ``np.sum(weights)`` still counted every weight.
    """
    if len(models) == 0:
        raise ValueError("ensemble_predictions: at least one model is required")
    if weights is None:
        weights = [1.0] * len(models)
    if len(weights) != len(models):
        raise ValueError(f"ensemble_predictions: got {len(weights)} weights for {len(models)} models")

    weighted_preds = None
    for model, weight in zip(models, weights):
        model.eval()
        fold_logits = []
        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Predicting", unit="batch"):
                image, input_ids, attention_mask = (t.to(device) for t in batch)
                with torch.cuda.amp.autocast():
                    pred = model(image, input_ids, attention_mask)
                # Logits may be float16 under autocast; upcast so the ensemble is
                # accumulated in float32 rather than half precision.
                fold_logits.append(pred.float().cpu().numpy())

        contribution = np.vstack(fold_logits) * weight
        weighted_preds = contribution if weighted_preds is None else weighted_preds + contribution

    return weighted_preds / float(np.sum(weights))

# Deterministic preprocessing for inference (no augmentation).
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Build the test dataset
test_dataset = TestVQADataset(df_path="/mnt/newdisk/valid.json", image_dir="/mnt/newdisk/valid", class_mapping=class_mapping_df, transform=test_transform)

# Reuse the training dataset's dictionaries so predicted class ids map to the same answers
test_dataset.update_dict(vqa_dataset)

# Create the data loader (must not shuffle: rows are aligned across models)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Equal weight per fold model. Derived from len(models) so it stays correct for
# any n_splits (for 5 models this is exactly the previous [0.2] * 5).
weights = [1.0 / len(models)] * len(models)

# Ensemble prediction
ensemble_preds = ensemble_predictions(models, test_loader, device, weights)

# Index of the highest-scoring class for each sample
pred_idx = np.argmax(ensemble_preds, axis=1)

# Convert predicted class ids to answer strings
submission = [test_dataset.idx2answer[int(idx)] for idx in pred_idx]

# Save the predictions
np.save("submission_ensemble_clip.npy", np.array(submission))

In [10]:
# Persist each fold model's weights so the ensemble can be rebuilt without retraining.
save_dir = "/mnt/newdisk/clip_cross_valid_models"
os.makedirs(save_dir, exist_ok=True)

for fold_idx, fold_model in enumerate(models):
    torch.save(fold_model.state_dict(), os.path.join(save_dir, f"model_fold_{fold_idx}.pt"))