In [1]:
import os
import json
import numpy as np
import shutil
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from transformers import ViTFeatureExtractor, ViTModel

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

2025-04-27 18:46:22.742389: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-27 18:46:22.750729: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-27 18:46:22.753174: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-27 18:46:22.759390: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [37]:
NO_WORD_IMAGE_ONLY = True
REAL_IMAGE_ONLY = True
REAL_IMAGE_THRESHOLD = 0.75
NO_UNIQUE_NOUN_SENTENCE_ONLY = True

VOCAB_SIZE = 8192
MAX_SENTENCE_LENGTH = 31
MIN_SENTENCE_LENGTH = 4
MIN_FREEQENCY = 16

LERNING_RATE = 0.001
BATCH_SIZE = 32
EPOCHS = 25
HIDDEN_DIM = 512

result_dir = f"../results/GUMI_T/{NO_WORD_IMAGE_ONLY}_{REAL_IMAGE_ONLY}_{REAL_IMAGE_THRESHOLD}_{NO_UNIQUE_NOUN_SENTENCE_ONLY}_{VOCAB_SIZE}_{MAX_SENTENCE_LENGTH}_{MIN_SENTENCE_LENGTH}_{MIN_FREEQENCY}_{LERNING_RATE}_{BATCH_SIZE}_{EPOCHS}_{HIDDEN_DIM}/"
if os.path.exists(result_dir):
    shutil.rmtree(result_dir)
os.makedirs(result_dir)
with open(f"{result_dir}training_config.json", "w") as f:
    json.dump({
        "NO_WORD_IMAGE_ONLY": NO_WORD_IMAGE_ONLY,
        "REAL_IMAGE_ONLY": REAL_IMAGE_ONLY,
        "REAL_IMAGE_THRESHOLD": REAL_IMAGE_THRESHOLD,
        "NO_UNIQUE_NOUN_SENTENCE_ONLY": NO_UNIQUE_NOUN_SENTENCE_ONLY,

        "VOCAB_SIZE": VOCAB_SIZE,
        "MAX_SENTENCE_LENGTH": MAX_SENTENCE_LENGTH,
        "MIN_SENTENCE_LENGTH": MIN_SENTENCE_LENGTH,
        "MIN_FREEQENCY": MIN_FREEQENCY,

        "LERNING_RATE": LERNING_RATE,
        "BATCH_SIZE": BATCH_SIZE,
        "EPOCHS": EPOCHS,
        "HIDDEN_DIM": HIDDEN_DIM
    }, f)

IMAGE_FEATURE_DIR = "../data/image_features/ViT/"
os.makedirs(IMAGE_FEATURE_DIR, exist_ok = True)

In [3]:
IMAGE_DIR = "../datas/Bokete_Dataset/boke_image/"
data_dir = f"../datas/{NO_WORD_IMAGE_ONLY}_{REAL_IMAGE_ONLY}_{REAL_IMAGE_THRESHOLD}_{NO_UNIQUE_NOUN_SENTENCE_ONLY}_{VOCAB_SIZE}_{MAX_SENTENCE_LENGTH}_{MIN_SENTENCE_LENGTH}_{MIN_FREEQENCY}/"

train_inputs_1 = np.load(f"{data_dir}train_inputs_1.npy")
train_inputs_2 = np.load(f"{data_dir}train_inputs_2.npy")
train_teacher_signals = np.load(f"{data_dir}train_teacher_signals.npy")
test_inputs_1 = np.load(f"{data_dir}test_inputs_1.npy")
test_inputs_2 = np.load(f"{data_dir}test_inputs_2.npy")
test_teacher_signals = np.load(f"{data_dir}test_teacher_signals.npy")

#
train_inputs_1.shape, train_inputs_2.shape, train_teacher_signals.shape, test_inputs_1.shape, test_inputs_2.shape, test_teacher_signals.shape

((1388020,), (1388020, 32), (1388020, 32), (13791,), (13791, 32), (13791, 32))

In [None]:
image_numbers = set(train_inputs_1.tolist() + test_inputs_1.tolist())
tmp_image_numbers = list()
for IN in image_numbers:
    if os.path.exists(f"{IMAGE_FEATURE_DIR}{IN}.npy"):
        continue
    tmp_image_numbers.append(IN)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "google/vit-base-patch16-224-in21k"
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
vit = ViTModel.from_pretrained(model_name)
vit = vit.to(device)
vit.eval()

bs = 512
for idx in tqdm(range(0, len(tmp_image_numbers), bs)):
    images = [Image.open(f"{IMAGE_DIR}{IN}.jpg").convert("RGB") for IN in tmp_image_numbers[idx:idx + bs]]
    images = feature_extractor(images, return_tensors="pt").to(device)
    
    with torch.no_grad():
        outputs = vit(**images)
    image_features = outputs.last_hidden_state[:, 1:, :].cpu().numpy()

    for i, IN in enumerate(tmp_image_numbers[idx:idx + bs]):
        np.save(f"{IMAGE_FEATURE_DIR}{IN}.npy", image_features[i])

  0%|          | 0/432 [00:04<?, ?it/s]


In [39]:
image_features.shape

(512, 196, 768)

In [17]:
def make_image_dataloader(inputs_1, inputs_2, test_teacher_signals):
    class LoadImageDataset(Dataset):
        def __init__(self, inputs_1, inputs_2, test_teacher_signals):
            """
                inputs_1: 画像の番号からなるリスト
                inputs_2: 入力文からなるリスト
                test_teacher_signals: 教師信号からなるリスト
            """
            if len(inputs_1) != len(inputs_2):
                raise ValueError("inputs_1 and inputs_2 must have the same length.")
            if len(inputs_1) != len(test_teacher_signals):
                raise ValueError("inputs_1 and test_teacher_signals must have the same length.")

            self.inputs_1 = inputs_1
            self.inputs_2 = inputs_2
            self.test_teacher_signals = test_teacher_signals

        def __len__(self):
            return len(self.inputs_1)

        def __getitem__(self, idx):
            image_feature = torch.Tensor( np.load(f"{IMAGE_FEATURE_DIR}{self.inputs_1[idx]}.npy") ).to(torch.float32)
            input_sentence = torch.Tensor( self.inputs_2[idx] ).to(torch.int32)
            teacher_signal = torch.Tensor( self.test_teacher_signals[idx] ).to(torch.int32)
            
            return image_feature, input_sentence, teacher_signal

    dataset = LoadImageDataset(inputs_1, inputs_2, test_teacher_signals)
    dataloader = DataLoader(
        dataset, 
        batch_size = BATCH_SIZE, 
        num_workers = int(os.cpu_count() * 0.6), 
        shuffle = True
    )

    return dataloader

train_dataloader = make_image_dataloader(train_inputs_1, train_inputs_2, train_teacher_signals)
test_dataloader = make_image_dataloader(test_inputs_1, test_inputs_2, test_teacher_signals)

#
i1, i2, t = next(iter(train_dataloader))
i1.shape, i2.shape, t.shape

(torch.Size([32, 768]), torch.Size([32, 32]), torch.Size([32, 32]))

In [31]:
class SelfAttentionBlock(nn.Module):
    def __init__(self, vocab_size, hidden_dim, input_length):
        super(SelfAttentionBlock, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.input_length = input_length

        self.self_attention = nn.MultiheadAttention(
            embed_dim = hidden_dim, 
            num_heads = 8, 
            batch_first = True
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.layer_norm = nn.LayerNorm(hidden_dim)
    
    def forward(self, x):

        batch_size, _, _ = x.size()
        attn_mask = torch.triu(torch.ones(self.input_length, self.input_length), diagonal = 1).bool()
        attn_mask = attn_mask.to(x.device)
        # attn_mask = attn_mask.unsqueeze(0).expand(batch_size, -1, -1)

        attn_output, attn_output_weights  = self.self_attention(x, x, x, attn_mask = attn_mask)
        x = self.layer_norm( x + attn_output )
        x = nn.LeakyReLU()( self.fc(x) )
        return x, attn_output_weights

class CrossAttentionBlock(nn.Module):
    def __init__(self, vocab_size, hidden_dim, input_length):
        super(CrossAttentionBlock, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.input_length = input_length

        self.cross_attention = nn.MultiheadAttention(
            embed_dim = hidden_dim, 
            num_heads = 8, 
            batch_first = True
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.layer_norm = nn.LayerNorm(hidden_dim)
    
    def forward(self, x, y):



        attn_output, attn_output_weights  = self.cross_attention(x, y, y)
        x = self.layer_norm( x + attn_output )
        x = nn.LeakyReLU()( self.fc(x) )
        return x, attn_output_weights

In [34]:
class SentenceDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_dim, input_length):
        super(SentenceDecoder, self).__init__()

        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.sa_1 = SelfAttentionBlock(vocab_size, hidden_dim, input_length)
        self.ca_1 = CrossAttentionBlock(vocab_size, hidden_dim, input_length)
        self.sa_2 = SelfAttentionBlock(vocab_size, hidden_dim, input_length)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, y):
        """
            x: 文章
            y: 画像の特徴量
        """
        x = self.embedding(x)

        x, sa_1_weights = self.sa_1(x)
        print(x.shape, y.shape)
        x, ca_1_weights = self.ca_1(x, y)
        x, sa_2_weights = self.sa_2(x)

        x = self.fc(x)
        return x, sa_1_weights, ca_1_weights, sa_2_weights

In [35]:
model = SentenceDecoder(VOCAB_SIZE, HIDDEN_DIM, MAX_SENTENCE_LENGTH + 1)
model = model.to(device)

model(i2.to(device), i1.to(device))

torch.Size([32, 32, 8192]) torch.Size([32, 768])


AssertionError: For batched (3-D) `query`, expected `key` and `value` to be 3-D but found 2-D and 2-D tensors respectively

In [21]:
i1

tensor([[ 0.1283, -0.0771, -0.1870,  ...,  0.0586, -0.0634, -0.0775],
        [-0.0104, -0.2622,  0.3650,  ..., -0.3836, -0.2408, -0.0771],
        [-0.0657, -0.1895, -0.0658,  ..., -0.1393,  0.0847, -0.0212],
        ...,
        [-0.1642, -0.1845, -0.0161,  ...,  0.1604, -0.0831,  0.1903],
        [ 0.0532, -0.1612, -0.1201,  ...,  0.1068, -0.0672,  0.1242],
        [ 0.1493, -0.1318, -0.1077,  ...,  0.0862, -0.4769, -0.1203]])

In [20]:
i2

tensor([[   1, 3353, 1886,  ...,    0,    0,    0],
        [   1,  116,  111,  ...,    0,    0,    0],
        [   1, 2783, 2772,  ...,    0,    0,    0],
        ...,
        [   1, 4264, 6637,  ...,    0,    0,    0],
        [   1, 2562,   49,  ...,    0,    0,    0],
        [   1, 5498, 2531,  ...,    0,    0,    0]], dtype=torch.int32)