In [7]:
import re
import random
import time
from statistics import mode
from collections import Counter

from PIL import Image
import numpy as np
import pandas as pd
from sklearn.utils import resample

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

import torchvision
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from tqdm import tqdm

from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig

model_path = "/workspace/models/InternVL-Chat-V1-5"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# 4bit量子化の設定
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# 学習済みモデルの読み込み
intern_model = AutoModel.from_pretrained(model_path, quantization_config=bnb_config, device_map="auto", torch_dtype=torch.bfloat16).eval()

Loading checkpoint shards:   0%|          | 0/11 [00:00<?, ?it/s]

In [4]:
df = pd.read_json("/workspace/data/train.json")

In [8]:
import json
import matplotlib.pyplot as plt
with open("/workspace/data/train.json") as f:
    training_json = json.load(f)

question_length = []
for question in training_json["answer"].values():
    print(question)
    break

# plt.hist(question_length, bins=100)
# print(max(question_length))

KeyError: 'answer'

In [81]:
df = pd.read_json("/workspace/data/train.json")

In [5]:
df

Unnamed: 0,image,question,answers
0,train_00000.jpg,What is this?,"[{'answer_confidence': 'yes', 'answer': 'beef ..."
1,train_00001.jpg,maybe it's because you're pushing it down instead,"[{'answer_confidence': 'yes', 'answer': 'unans..."
2,train_00002.jpg,What color is this item?,"[{'answer_confidence': 'yes', 'answer': 'grey'..."
3,train_00003.jpg,Can you tell me if this is like body wash or l...,"[{'answer_confidence': 'maybe', 'answer': 'lot..."
4,train_00004.jpg,Is it a paper?,"[{'answer_confidence': 'yes', 'answer': 'no'},..."
...,...,...,...
19868,train_19868.jpg,What's on this card please?,"[{'answer_confidence': 'yes', 'answer': 'unans..."
19869,train_19869.jpg,I can't tell what it is that I'm holding.,"[{'answer_confidence': 'yes', 'answer': 'finge..."
19870,train_19870.jpg,What does it say on this shirt?,"[{'answer_confidence': 'yes', 'answer': 'hands..."
19871,train_19871.jpg,I'm looking for the model number of this print...,"[{'answer_confidence': 'yes', 'answer': 'unans..."


In [85]:
mode_answer_dict = {}
for _, row in df.iterrows():
    answers = [process_text(ans['answer']) for ans in row['answers']]
    mode_answer = Counter(answers).most_common(1)[0][0]
    if mode_answer not in mode_answer_dict:
        mode_answer_dict[mode_answer] = 1
    else:
        mode_answer_dict[mode_answer] += 1

for word, num in mode_answer_dict.items():
    if num < 20:
        mode_answer_dict[word] = 0

mode_answer_dict = {k: v for k, v in mode_answer_dict.items() if v != 0}

mode_answer_dict = dict(sorted(mode_answer_dict.items(), key=lambda x: x[1], reverse=True))
print(mode_answer_dict)

{'unanswerable': 7559, 'no': 481, 'yes': 476, 'white': 300, 'grey': 266, 'black': 227, 'blue': 195, 'red': 115, 'brown': 99, 'pink': 91, 'keyboard': 89, 'green': 73, 'laptop': 68, 'purple': 64, 'dog': 63, 'soup': 57, 'ph1': 53, 'yellow': 48, 'coca cola': 41, 'lotion': 40, 'cell ph1': 40, 'wine': 40, 'remote': 37, 'nothing': 36, 'tv': 35, 'corn': 35, 'orange': 34, 'computer screen': 33, 'pepsi': 33, 'coffee': 31, 'chair': 30, 'chicken': 29, 'computer': 28, 'green beans': 28, 'beer': 28, 'tan': 27, 'shampoo': 26, 'pen': 26, 'water bottle': 26, 'cup': 26, 'hand sanitizer': 25, 'remote control': 25, 'cereal': 24, 'black white': 23, '20': 22, 'm1y': 22, 'cat': 21, 'beans': 21, 'dr pepper': 20, 'flowers': 20, 'door': 20}


In [87]:
# unanswerable の sumを計算
unanswerable_sum = 0
for k, v in mode_answer_dict.items():
    if k != "unanswerable":
        unanswerable_sum += v
print(unanswerable_sum)

3717


In [75]:
mode_answer

['unanswerable',
 'no',
 'yes',
 'white',
 'grey',
 'black',
 'blue',
 'red',
 'brown',
 'pink',
 'keyboard',
 'green',
 'laptop',
 'purple',
 'dog',
 'soup',
 'ph1',
 'yellow',
 'coca cola',
 'lotion',
 'cell ph1',
 'wine',
 'remote',
 'nothing',
 'tv',
 'corn',
 'orange',
 'computer screen',
 'pepsi',
 'coffee',
 'chair',
 'chicken',
 'computer',
 'green beans',
 'beer',
 'tan',
 'shampoo',
 'pen',
 'water bottle',
 'cup',
 'hand sanitizer',
 'remote control',
 'cereal',
 'black white',
 '20',
 'm1y',
 'cat',
 'beans',
 'dr pepper',
 'flowers',
 'door']

In [44]:
from sklearn.utils import resample
import torch
from torch.utils.data import Dataset
from collections import Counter
import random

In [93]:
def process_text(text):
    text = text.lower()
    num_word_to_digit = {
        'zero': '0', 'one': '1', 'two': '2', 'three': '3', 'four': '4',
        'five': '5', 'six': '6', 'seven': '7', 'eight': '8', 'nine': '9',
        'ten': '10'
    }
    for word, digit in num_word_to_digit.items():
        text = text.replace(word, digit)
    text = re.sub(r'(?<!\d)\.(?!\d)', '', text)
    text = re.sub(r'\b(a|an|the)\b', '', text)
    contractions = {
        "dont": "don't", "isnt": "isn't", "arent": "aren't", "wont": "won't",
        "cant": "can't", "wouldnt": "wouldn't", "couldnt": "couldn't"
    }
    for contraction, correct in contractions.items():
        text = text.replace(contraction, correct)
    text = re.sub(r"[^\w\s':]", ' ', text)
    text = re.sub(r'\s+,', ',', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class VQADataset(torch.utils.data.Dataset):
    def __init__(self, df_path, image_dir, model, tokenizer, answer=True):
        self.image_dir = image_dir
        self.df = pd.read_json(df_path)
        self.answer = answer

        self.answer2idx = {}
        self.idx2answer = {}

        if self.answer:

            # Training_dataに含まれるAnswerを全て取得
            for answers in self.df["answers"]:
                for answer in answers:
                    word = answer["answer"]
                    word = process_text(word)
                    if word not in self.answer2idx:
                        self.answer2idx[word] = len(self.answer2idx)
            # 追加でClass_mappingに含まれるAnswerを取得
            class_mapping = pd.read_csv("/workspace/class_mapping.csv")
            self.idx2answer = {}
            for word, idx in zip(class_mapping["answer"], class_mapping["class_id"]):
                word = process_text(word)
                self.answer2idx[word] = idx

            self.idx2answer = {v: k for k, v in self.answer2idx.items()}

        self.model = model
        self.tokenizer = tokenizer

    def update_dict(self, dataset):
        self.answer2idx = dataset.answer2idx
        self.idx2answer = dataset.idx2answer

    def extract_text_features(self, text):
        inputs = self.tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=303)
        input_ids = inputs.input_ids.to(self.model.device)
        with torch.no_grad():
            text_features = self.model.language_model.model.tok_embeddings(input_ids)
        return text_features

    def extract_image_features(self, pixel_values):
        with torch.no_grad():
            image_features = self.model.vision_model.embeddings(pixel_values)
        return image_features

    def __getitem__(self, idx):
        image = Image.open(f"{self.image_dir}/{self.df['image'][idx]}")
        input_size = 224
        max_num = 1
        transform = build_transform(input_size=input_size)
        images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
        pixel_values = [transform(image) for image in images]
        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(torch.bfloat16)
        pixel_values = self.extract_image_features(pixel_values)

        question = self.df["question"][idx]
        # 質問文の前処理
        question = process_text(question)
        question = self.extract_text_features(question)

        if self.answer:
            answers = [self.answer2idx[process_text(answer["answer"])] for answer in self.df["answers"][idx]]
            mode_answer_idx = mode(answers)
            return pixel_values, question, torch.Tensor(answers), int(mode_answer_idx)
        else:
            return pixel_values, question

    def __len__(self):
        return len(self.df)

def VQA_criterion(batch_pred: torch.Tensor, batch_answers: torch.Tensor):
    total_acc = 0.
    for pred, answers in zip(batch_pred, batch_answers):
        acc = 0.
        for i in range(len(answers)):
            num_match = 0
            for j in range(len(answers)):
                if i == j:
                    continue
                if pred == answers[j]:
                    num_match += 1
            acc += min(num_match / 3, 1)
        total_acc += acc / 10
    return total_acc / len(batch_pred)

class VQAModel(nn.Module):
    def __init__(self, n_answer: int):
        super().__init__()

        # vision branch
        self.vision_conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 3), stride=(3, 3), padding=1)
        self.vision_bn1 = nn.BatchNorm2d(16)
        self.vision_conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 3), stride=(3, 3), padding=1)
        self.vision_bn2 = nn.BatchNorm2d(32)
        self.vision_conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), stride=(3, 3), padding=1)
        self.vision_bn3 = nn.BatchNorm2d(64)
        self.vision_pool = nn.MaxPool2d(kernel_size=(3, 3), stride=(3, 3))

        # text branch
        self.text_conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=(2, 2), stride=(2, 2), padding=1)
        self.text_bn1 = nn.BatchNorm2d(4)
        self.text_conv2 = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=(2, 2), stride=(2, 2), padding=1)
        self.text_bn2 = nn.BatchNorm2d(8)
        self.text_conv3 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=(2, 2), stride=(2, 2), padding=1)
        self.text_bn3 = nn.BatchNorm2d(16)
        self.text_conv4 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(2, 2), stride=(2, 2), padding=1)
        self.text_bn4 = nn.BatchNorm2d(32)
        self.text_pool= nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        # combined
        self.fc1 = nn.Linear(1600, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, n_answer)

        # 重みの初期化
        self._initialize_weights()


    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, vision_input, text_input):
        # vision branch
        #print("vision branch")
        v = F.relu(self.vision_bn1(self.vision_bn1(self.vision_conv1(vision_input))))
        v = F.relu(self.vision_bn2(self.vision_bn2(self.vision_conv2(v))))
        v = self.vision_pool(v)
        v = F.relu(self.vision_bn3(self.vision_conv3(v)))
        v = self.vision_pool(v)
        v = v.view(v.size(0), -1)
        #print(v.shape)

        #print("text branch")
        t = F.relu(self.text_bn1(self.text_conv1(text_input)))
        t = self.text_pool(t)
        t = F.relu(self.text_bn2(self.text_conv2(t)))
        t = self.text_pool(t)
        t = F.relu(self.text_bn3(self.text_conv3(t)))
        t = self.text_pool(t)
        t = F.relu(self.text_bn4(self.text_conv4(t)))
        t = self.text_pool(t)
        t = t.view(t.size(0), -1)

        # combined
        combined = torch.cat((v, t), dim=1)
        x = F.relu(self.bn1(self.fc1(combined)))
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.fc3(x)

        return x
    
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    total_acc = 0
    simple_acc = 0
    start = time.time()
    for image, question, answers, mode_answer in tqdm(dataloader, desc="Batch", leave=False):
        image, question, answers, mode_answer = \
            image.to(device, dtype=torch.bfloat16), question.to(device, dtype=torch.bfloat16), answers.to(device), mode_answer.to(device)
        #print(image.shape, question.shape, answers.shape, mode_answer.shape)
        pred = model(image, question)
        #print(pred.shape, mode_answer.shape)
        loss = criterion(pred, mode_answer.squeeze().to(torch.long))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_acc += VQA_criterion(pred.argmax(1), answers) 
        simple_acc += (pred.argmax(1) == mode_answer).float().mean().item() 

    return total_loss / len(dataloader), total_acc / len(dataloader), simple_acc / len(dataloader), time.time() - start

def eval(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    total_acc = 0
    simple_acc = 0
    start = time.time()
    for image, question, answers, mode_answer in dataloader:
        image, question, answers, mode_answer = \
            image.to(device, dtype=torch.bfloat16), question.to(device, dtype=torch.bfloat16), answers.to(device), mode_answer.to(device)
        pred = model(image, question)
        loss = criterion(pred, mode_answer.squeeze().to(torch.long))

        total_loss += loss.item()
        total_acc += VQA_criterion(pred.argmax(1), answers)  
        simple_acc += (pred.argmax(1) == mode_answer).float().mean().item()  

    return total_loss / len(dataloader), total_acc / len(dataloader), simple_acc / len(dataloader), time.time() - start

In [94]:
set_seed(42)
device = "cuda" if torch.cuda.is_available() else "cpu"

train_dataset = VQADataset(df_path="/workspace/data/train.json", image_dir="/workspace/data/train", model=intern_model, tokenizer=tokenizer, answer=True)
test_dataset = VQADataset(df_path="/workspace/data/valid.json", image_dir="/workspace/data/valid", model=intern_model, tokenizer=tokenizer, answer=False)
test_dataset.update_dict(train_dataset)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

model = VQAModel(n_answer=len(train_dataset.answer2idx)).to(device)
model = model.to(torch.bfloat16)  # モデル全体をbfloat16に変換


num_epoch = 100
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

for epoch in tqdm(range(num_epoch), desc="Epoch"):
    train_loss, train_acc, train_simple_acc, train_time = train(model, train_loader, optimizer, criterion, device)
    print(f"【{epoch + 1}/{num_epoch}】\n"
            f"train time: {train_time:.2f} [s]\n"
            f"train loss: {train_loss:.4f}\n"
            f"train acc: {train_acc:.4f}\n"
            f"train simple acc: {train_simple_acc:.4f}")
    
    if True:
        model.eval()
        submission = []
        for image, question in test_loader:
            image, question = image.to(device, dtype=torch.bfloat16), question.to(device, dtype=torch.bfloat16)
            pred = model(image, question)
            pred = pred.argmax(1).cpu().item()
            submission.append(pred)

        submission = [train_dataset.idx2answer[id] for id in submission]
        submission = np.array(submission)
        torch.save(model.state_dict(), f"/workspace/submissions/my_main_3_2/model/model_{epoch}.pth")
        np.save(f"/workspace/submissions/my_main_3_2/npy/submission_{epoch}.npy", submission)

Epoch:   0%|          | 0/100 [00:00<?, ?it/s]

【1/100】
train time: 124.10 [s]
train loss: 5.5000
train acc: 0.3044
train simple acc: 0.2398


Epoch:   1%|          | 1/100 [04:14<6:59:31, 254.26s/it]

【2/100】
train time: 124.03 [s]
train loss: 2.5051
train acc: 0.4448
train simple acc: 0.3582


Epoch:   2%|▏         | 2/100 [08:27<6:54:24, 253.72s/it]

【3/100】
train time: 123.30 [s]
train loss: 2.0915
train acc: 0.5036
train simple acc: 0.4201


Epoch:   3%|▎         | 3/100 [12:38<6:47:56, 252.33s/it]

【4/100】
train time: 122.24 [s]
train loss: 1.7551
train acc: 0.5907
train simple acc: 0.5322


Epoch:   4%|▍         | 4/100 [16:49<6:43:02, 251.90s/it]

【5/100】
train time: 123.80 [s]
train loss: 1.3991
train acc: 0.6641
train simple acc: 0.6275


Epoch:   5%|▌         | 5/100 [21:02<6:39:15, 252.16s/it]

【6/100】
train time: 124.28 [s]
train loss: 1.0712
train acc: 0.7629
train simple acc: 0.7422


Epoch:   6%|▌         | 6/100 [25:16<6:36:09, 252.87s/it]

【7/100】
train time: 123.16 [s]
train loss: 0.7754
train acc: 0.8441
train simple acc: 0.8441


Epoch:   7%|▋         | 7/100 [29:28<6:31:37, 252.67s/it]

【8/100】
train time: 123.81 [s]
train loss: 0.5306
train acc: 0.9139
train simple acc: 0.9301


Epoch:   8%|▊         | 8/100 [33:41<6:27:40, 252.83s/it]

【9/100】
train time: 123.02 [s]
train loss: 0.3344
train acc: 0.9507
train simple acc: 0.9722


Epoch:   9%|▉         | 9/100 [37:53<6:22:50, 252.42s/it]

【10/100】
train time: 124.56 [s]
train loss: 0.2271
train acc: 0.9566
train simple acc: 0.9802


Epoch:  10%|█         | 10/100 [42:07<6:19:39, 253.11s/it]

【11/100】
train time: 124.08 [s]
train loss: 0.1584
train acc: 0.9643
train simple acc: 0.9899


Epoch:  11%|█         | 11/100 [46:21<6:15:41, 253.28s/it]

【12/100】
train time: 123.81 [s]
train loss: 0.1245
train acc: 0.9661
train simple acc: 0.9931


Epoch:  12%|█▏        | 12/100 [50:35<6:11:35, 253.36s/it]

【13/100】
train time: 122.96 [s]
train loss: 0.0945
train acc: 0.9674
train simple acc: 0.9952


Epoch:  13%|█▎        | 13/100 [54:47<6:07:01, 253.12s/it]

【14/100】
train time: 124.16 [s]
train loss: 0.0818
train acc: 0.9680
train simple acc: 0.9959


Epoch:  14%|█▍        | 14/100 [59:01<6:02:57, 253.23s/it]

【15/100】
train time: 124.04 [s]
train loss: 0.0776
train acc: 0.9676
train simple acc: 0.9947


Epoch:  15%|█▌        | 15/100 [1:03:14<5:58:52, 253.33s/it]

【16/100】
train time: 124.21 [s]
train loss: 0.0652
train acc: 0.9684
train simple acc: 0.9968


Epoch:  16%|█▌        | 16/100 [1:07:28<5:54:58, 253.55s/it]

【17/100】
train time: 124.12 [s]
train loss: 0.0583
train acc: 0.9691
train simple acc: 0.9977


Epoch:  17%|█▋        | 17/100 [1:11:43<5:51:01, 253.75s/it]

【18/100】
train time: 123.75 [s]
train loss: 0.0592
train acc: 0.9693
train simple acc: 0.9970


Epoch:  18%|█▊        | 18/100 [1:15:56<5:46:38, 253.64s/it]

【19/100】
train time: 123.31 [s]
train loss: 0.0503
train acc: 0.9694
train simple acc: 0.9986


Epoch:  19%|█▉        | 19/100 [1:20:09<5:42:08, 253.43s/it]

【20/100】
train time: 124.43 [s]
train loss: 0.0478
train acc: 0.9687
train simple acc: 0.9970


Epoch:  20%|██        | 20/100 [1:24:23<5:38:06, 253.58s/it]

【21/100】
train time: 122.62 [s]
train loss: 0.0464
train acc: 0.9695
train simple acc: 0.9977


Epoch:  21%|██        | 21/100 [1:28:35<5:33:18, 253.15s/it]

【22/100】
train time: 122.54 [s]
train loss: 0.0464
train acc: 0.9688
train simple acc: 0.9963


Epoch:  22%|██▏       | 22/100 [1:32:47<5:28:30, 252.70s/it]

【23/100】
train time: 124.04 [s]
train loss: 0.0709
train acc: 0.9668
train simple acc: 0.9954


Epoch:  23%|██▎       | 23/100 [1:37:01<5:24:45, 253.06s/it]

【24/100】
train time: 124.11 [s]
train loss: 0.0547
train acc: 0.9677
train simple acc: 0.9966


Epoch:  24%|██▍       | 24/100 [1:41:14<5:20:49, 253.28s/it]

【25/100】
train time: 124.00 [s]
train loss: 0.0424
train acc: 0.9686
train simple acc: 0.9970


Epoch:  25%|██▌       | 25/100 [1:45:28<5:16:40, 253.34s/it]

【26/100】
train time: 124.21 [s]
train loss: 0.0400
train acc: 0.9682
train simple acc: 0.9970


Epoch:  26%|██▌       | 26/100 [1:49:42<5:12:40, 253.52s/it]

【27/100】
train time: 123.88 [s]
train loss: 0.0502
train acc: 0.9674
train simple acc: 0.9943


Epoch:  27%|██▋       | 27/100 [1:53:55<5:08:26, 253.51s/it]

【28/100】
train time: 123.86 [s]
train loss: 0.0492
train acc: 0.9672
train simple acc: 0.9956


Epoch:  28%|██▊       | 28/100 [1:58:08<5:04:07, 253.44s/it]

【29/100】
train time: 124.50 [s]
train loss: 0.0352
train acc: 0.9690
train simple acc: 0.9979


Epoch:  29%|██▉       | 29/100 [2:02:23<5:00:07, 253.63s/it]

【30/100】
train time: 122.17 [s]
train loss: 0.0358
train acc: 0.9678
train simple acc: 0.9977


Epoch:  30%|███       | 30/100 [2:06:34<4:55:13, 253.05s/it]

【31/100】
train time: 122.67 [s]
train loss: 0.0323
train acc: 0.9695
train simple acc: 0.9977


Epoch:  31%|███       | 31/100 [2:10:46<4:50:42, 252.79s/it]

【32/100】
train time: 123.97 [s]
train loss: 0.0312
train acc: 0.9696
train simple acc: 0.9972


Epoch:  32%|███▏      | 32/100 [2:15:00<4:46:46, 253.04s/it]

【33/100】
train time: 123.69 [s]
train loss: 0.0286
train acc: 0.9702
train simple acc: 0.9984


Epoch:  33%|███▎      | 33/100 [2:19:13<4:42:39, 253.13s/it]

【34/100】
train time: 124.13 [s]
train loss: 0.0275
train acc: 0.9687
train simple acc: 0.9977


Epoch:  34%|███▍      | 34/100 [2:23:27<4:38:29, 253.18s/it]

【35/100】
train time: 122.07 [s]
train loss: 0.0274
train acc: 0.9693
train simple acc: 0.9984


Epoch:  35%|███▌      | 35/100 [2:27:36<4:33:05, 252.09s/it]

【36/100】
train time: 124.04 [s]
train loss: 0.0263
train acc: 0.9698
train simple acc: 0.9982


Epoch:  36%|███▌      | 36/100 [2:31:50<4:29:32, 252.69s/it]

【37/100】
train time: 123.00 [s]
train loss: 0.0253
train acc: 0.9702
train simple acc: 0.9986


Epoch:  37%|███▋      | 37/100 [2:36:03<4:25:12, 252.58s/it]

【38/100】
train time: 123.35 [s]
train loss: 0.0258
train acc: 0.9693
train simple acc: 0.9982


Epoch:  38%|███▊      | 38/100 [2:40:16<4:21:04, 252.66s/it]

【39/100】
train time: 123.60 [s]
train loss: 0.0260
train acc: 0.9692
train simple acc: 0.9982


Epoch:  39%|███▉      | 39/100 [2:44:28<4:16:45, 252.55s/it]

【40/100】
train time: 124.51 [s]
train loss: 0.0243
train acc: 0.9699
train simple acc: 0.9986


Epoch:  40%|████      | 40/100 [2:48:41<4:12:37, 252.63s/it]

【41/100】
train time: 124.04 [s]
train loss: 0.0236
train acc: 0.9688
train simple acc: 0.9982


Epoch:  41%|████      | 41/100 [2:52:55<4:08:48, 253.02s/it]

【42/100】
train time: 124.44 [s]
train loss: 0.0244
train acc: 0.9694
train simple acc: 0.9984


Epoch:  42%|████▏     | 42/100 [2:57:09<4:05:02, 253.50s/it]

【43/100】
train time: 124.30 [s]
train loss: 0.0235
train acc: 0.9697
train simple acc: 0.9986


Epoch:  43%|████▎     | 43/100 [3:01:23<4:00:54, 253.58s/it]

【44/100】
train time: 124.17 [s]
train loss: 0.0249
train acc: 0.9684
train simple acc: 0.9982


Epoch:  44%|████▍     | 44/100 [3:05:37<3:56:48, 253.73s/it]

【45/100】
train time: 124.25 [s]
train loss: 0.0230
train acc: 0.9693
train simple acc: 0.9979


Epoch:  45%|████▌     | 45/100 [3:09:51<3:52:45, 253.93s/it]

【46/100】
train time: 124.12 [s]
train loss: 0.0227
train acc: 0.9690
train simple acc: 0.9984


Epoch:  46%|████▌     | 46/100 [3:14:05<3:48:31, 253.92s/it]

【47/100】
train time: 123.52 [s]
train loss: 0.0234
train acc: 0.9696
train simple acc: 0.9989


Epoch:  47%|████▋     | 47/100 [3:18:18<3:43:58, 253.55s/it]

【48/100】
train time: 123.98 [s]
train loss: 0.0262
train acc: 0.9689
train simple acc: 0.9977


Epoch:  48%|████▊     | 48/100 [3:22:32<3:39:46, 253.58s/it]

【49/100】
train time: 123.89 [s]
train loss: 0.0294
train acc: 0.9693
train simple acc: 0.9975


Epoch:  49%|████▉     | 49/100 [3:26:45<3:35:34, 253.61s/it]

【50/100】
train time: 124.04 [s]
train loss: 0.0239
train acc: 0.9687
train simple acc: 0.9982


Epoch:  50%|█████     | 50/100 [3:30:59<3:31:21, 253.64s/it]

【51/100】
train time: 124.31 [s]
train loss: 0.0222
train acc: 0.9689
train simple acc: 0.9986


Epoch:  51%|█████     | 51/100 [3:35:13<3:27:18, 253.84s/it]

【52/100】
train time: 123.79 [s]
train loss: 0.0220
train acc: 0.9692
train simple acc: 0.9979


Epoch:  52%|█████▏    | 52/100 [3:39:27<3:22:57, 253.70s/it]

【53/100】
train time: 124.20 [s]
train loss: 0.0215
train acc: 0.9694
train simple acc: 0.9984


Epoch:  53%|█████▎    | 53/100 [3:43:41<3:18:45, 253.74s/it]

【54/100】
train time: 124.08 [s]
train loss: 0.0204
train acc: 0.9692
train simple acc: 0.9984


Epoch:  54%|█████▍    | 54/100 [3:47:53<3:14:07, 253.22s/it]

【55/100】
train time: 122.05 [s]
train loss: 0.0197
train acc: 0.9695
train simple acc: 0.9989


Epoch:  55%|█████▌    | 55/100 [3:52:04<3:09:26, 252.58s/it]

【56/100】
train time: 123.20 [s]
train loss: 0.0192
train acc: 0.9696
train simple acc: 0.9986


Epoch:  56%|█████▌    | 56/100 [3:56:15<3:04:53, 252.12s/it]

【57/100】
train time: 122.46 [s]
train loss: 0.0190
train acc: 0.9692
train simple acc: 0.9986


Epoch:  57%|█████▋    | 57/100 [4:00:26<3:00:36, 252.01s/it]

【58/100】
train time: 123.84 [s]
train loss: 0.0189
train acc: 0.9699
train simple acc: 0.9989


Epoch:  58%|█████▊    | 58/100 [4:04:40<2:56:41, 252.42s/it]

【59/100】
train time: 123.97 [s]
train loss: 0.0209
train acc: 0.9699
train simple acc: 0.9982


Epoch:  59%|█████▉    | 59/100 [4:08:53<2:52:41, 252.71s/it]

【60/100】
train time: 123.80 [s]
train loss: 0.0220
train acc: 0.9687
train simple acc: 0.9979


Epoch:  60%|██████    | 60/100 [4:13:06<2:48:33, 252.83s/it]

【61/100】
train time: 123.85 [s]
train loss: 0.0260
train acc: 0.9684
train simple acc: 0.9972


Epoch:  61%|██████    | 61/100 [4:17:20<2:44:30, 253.09s/it]

【62/100】
train time: 124.24 [s]
train loss: 0.0220
train acc: 0.9695
train simple acc: 0.9986


Epoch:  62%|██████▏   | 62/100 [4:21:35<2:40:34, 253.53s/it]

【63/100】
train time: 124.05 [s]
train loss: 0.0202
train acc: 0.9693
train simple acc: 0.9986


Epoch:  63%|██████▎   | 63/100 [4:25:48<2:36:22, 253.57s/it]

【64/100】
train time: 123.80 [s]
train loss: 0.0198
train acc: 0.9690
train simple acc: 0.9989


Epoch:  64%|██████▍   | 64/100 [4:30:02<2:32:07, 253.54s/it]

【65/100】
train time: 124.36 [s]
train loss: 0.0242
train acc: 0.9682
train simple acc: 0.9977


Epoch:  65%|██████▌   | 65/100 [4:34:16<2:28:00, 253.74s/it]

【66/100】
train time: 124.02 [s]
train loss: 0.0195
train acc: 0.9701
train simple acc: 0.9984


Epoch:  66%|██████▌   | 66/100 [4:38:30<2:23:49, 253.82s/it]

【67/100】
train time: 123.84 [s]
train loss: 0.0179
train acc: 0.9697
train simple acc: 0.9986


Epoch:  67%|██████▋   | 67/100 [4:42:43<2:19:28, 253.60s/it]

【68/100】
train time: 123.53 [s]
train loss: 0.0169
train acc: 0.9695
train simple acc: 0.9986


Epoch:  68%|██████▊   | 68/100 [4:47:44<2:15:24, 253.88s/it]


KeyboardInterrupt: 

In [77]:
model.eval()
submission = []
for image, question in test_loader:
    image, question = image.to(device, dtype=torch.bfloat16), question.to(device, dtype=torch.bfloat16)
    pred = model(image, question)
    pred = pred.argmax(1).cpu().item()
    submission.append(pred)

submission = [train_dataset.idx2answer[id] for id in submission]
submission = np.array(submission)
# torch.save(model.state_dict(), f"/workspace/submissions/my_main_3_2/model/model_test.pth")
# np.save(f"/workspace/submissions/my_main_3_2/npy/submission_test.npy", submission)

In [None]:
for image, question, answers, mode_answer in train_loader:
        image, question, answers, mode_answer = \
            image.to(device, dtype=torch.bfloat16), question.to(device, dtype=torch.bfloat16), answers.to(device), mode_answer.to(device)
        #print(image.shape, question.shape, answers.shape, mode_answer.shape)
        pred = model(image, question)

In [80]:
ans_dict = {}
for ans in submission:
    if ans not in ans_dict:
        ans_dict[ans] = 1
    else:
        ans_dict[ans] += 1
ans_dict = dict(sorted(ans_dict.items(), key=lambda x: x[1], reverse=True))
print(ans_dict)

{'unanswerable': 3268, 'no': 1682, 'yes': 11, 'white': 8}


In [49]:
# 各画像のモードアンサーを計算
mode_answers = []

for idx, row in df.iterrows():
    answers = [ans['answer'] for ans in row['answers']]
    mode_answer = Counter(answers).most_common(1)[0][0]
    mode_answers.append(mode_answer)

# モードアンサーの分布を取得
mode_answer_counts = Counter(mode_answers)

In [53]:
mode_answer_counts

Counter({'unanswerable': 7565,
         'no': 481,
         'yes': 476,
         'white': 300,
         'grey': 266,
         'black': 227,
         'blue': 195,
         'red': 115,
         'brown': 99,
         'pink': 91,
         'keyboard': 89,
         'green': 73,
         'laptop': 68,
         'purple': 64,
         'dog': 63,
         'soup': 57,
         'phone': 53,
         'yellow': 48,
         'coca cola': 41,
         'lotion': 40,
         'cell phone': 40,
         'wine': 40,
         'remote': 37,
         'nothing': 36,
         'tv': 35,
         'corn': 35,
         'orange': 34,
         'computer screen': 33,
         'pepsi': 33,
         'coffee': 31,
         'chair': 30,
         'chicken': 29,
         'computer': 28,
         'green beans': 28,
         'beer': 28,
         'tan': 27,
         'shampoo': 26,
         'pen': 26,
         'water bottle': 26,
         'cup': 26,
         'hand sanitizer': 25,
         'remote control': 25,
         'cereal

In [54]:
sum(mode_answer_counts.values())

19873