In [1]:
import os
import json
import numpy as np
import shutil
import argparse
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from transformers import ViTImageProcessor, ViTModel
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms

In [4]:
RESULT_DIR = "/home/user/workspace/Master_Thesis/results/Neural_Joking_Machine/False_False_False_0_32_4_31_25_64_0.0001_1024/"

with open(f"{RESULT_DIR}index_to_word.json", "r") as f:
    index_to_word = json.load(f)
index_to_word = {
    int(K): V for K, V in index_to_word.items()
}

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, max_len, feature_dim):
        """
            max_len: 
            feature_dim: 
        """
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, feature_dim)
        position = torch.arange(0, max_len, dtype = torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, feature_dim, 2).float() * -(torch.log(torch.tensor(10000.0)) / feature_dim))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe

class MultiHeadSelfAttentionBlock(nn.Module):
    def __init__(self, feature_dim, num_heads):
        super(MultiHeadSelfAttentionBlock, self).__init__()

        self.self_attention = nn.MultiheadAttention(embed_dim = feature_dim, 
                                                    num_heads = num_heads, batch_first = True)
        self.layer_norm1 = nn.LayerNorm(feature_dim)

        self.fc = nn.Linear(feature_dim, feature_dim)
        self.layer_norm2 = nn.LayerNorm(feature_dim)
    
    def forward(self, x, attn_mask = None):
        # self attention
        x_dash, _ = self.self_attention(x, x, x, attn_mask = attn_mask)
        x = self.layer_norm1(x_dash) + x

        # feed forward
        x_dash = F.leaky_relu( self.fc(x) )
        return self.layer_norm2(x_dash) + x

class MultiHeadCrossAttentionBlock(nn.Module):
    def __init__(self, feature_dim, num_heads):
        super(MultiHeadCrossAttentionBlock, self).__init__()

        self.self_attention = nn.MultiheadAttention(embed_dim = feature_dim, 
                                                    num_heads = num_heads, batch_first = True)
        self.layer_norm1 = nn.LayerNorm(feature_dim)

        self.cross_attention = nn.MultiheadAttention(embed_dim = feature_dim, 
                                                     num_heads = num_heads, batch_first = True)
        self.layer_norm2 = nn.LayerNorm(feature_dim)

        self.fc = nn.Linear(feature_dim, feature_dim)
        self.layer_norm3 = nn.LayerNorm(feature_dim)
    
    def forward(self, src, tgt, attn_mask = None):
        # self attention
        tgt_dash, _ = self.self_attention(tgt, tgt, tgt, attn_mask = attn_mask)
        tgt = self.layer_norm1(tgt_dash) + tgt
        
        # cross attention
        src_dash, _ = self.cross_attention(tgt, src, src)
        tgt = self.layer_norm2(src_dash) + tgt

        # feed forward
        tgt_dash = F.leaky_relu( self.fc(tgt) )
        return self.layer_norm3(tgt_dash) + tgt

# 大喜利生成モデルのクラス
class TransformerBokeGeneratorModel(nn.Module):
    def __init__(self, num_image_patch, image_feature_dim, num_word, sentence_length, feature_dim = 1024, num_heads = 4):
        """
            num_image_patch
            num_word: 学習に用いる単語の総数
            image_feature_dim: 画像の特徴量の次元数
            sentence_length: 入力する文章の単語数
            feature_dim: 単語の埋め込み次元数
            num_heads: 
        """
        super(TransformerBokeGeneratorModel, self).__init__()
        self.num_word = num_word
        self.image_feature_dim = image_feature_dim
        self.sentence_length = sentence_length
        self.feature_dim = feature_dim
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # image encoder
        self.fc1 = nn.Linear(image_feature_dim, feature_dim)
        self.pe1 = PositionalEncoding(num_image_patch, feature_dim)
        self.self_attention = MultiHeadSelfAttentionBlock(feature_dim = feature_dim, num_heads = num_heads)
        
        # sentence decoder
        self.embedding = nn.Embedding(num_word, feature_dim, padding_idx = 0)
        self.pe2 = PositionalEncoding(sentence_length, feature_dim)

        self.cross_attention1 = MultiHeadCrossAttentionBlock(feature_dim = feature_dim, num_heads = num_heads)
        self.cross_attention2 = MultiHeadCrossAttentionBlock(feature_dim = feature_dim, num_heads = num_heads)
        self.cross_attention3 = MultiHeadCrossAttentionBlock(feature_dim = feature_dim, num_heads = num_heads)
        self.fc2 = nn.Linear(feature_dim, num_word)

    def forward(self, image_features, sentences):
        """
            image_features: 画像の特徴量(batch_size, num_patch, image_feature_dim)
            sentences: 入力する文章(batch_size, sentence_length)
        """
        # encode image 
        src = F.leaky_relu( self.fc1( image_features ) )
        src = self.pe1( src )
        src = self.self_attention( src )

        # decode sentence
        tgt = self.embedding( sentences )
        tgt = self.pe2( tgt )

        attn_mask = torch.triu(torch.ones(self.sentence_length, self.sentence_length), diagonal = 1)
        attn_mask = attn_mask.masked_fill(attn_mask == 1, float('-inf')).to(self.device)

        tgt = self.cross_attention1( tgt = tgt, src = src, attn_mask = attn_mask )
        tgt = self.cross_attention2( tgt = tgt, src = src, attn_mask = attn_mask )
        tgt = self.cross_attention3( tgt = tgt, src = src, attn_mask = attn_mask  )
        return self.fc2( tgt )

class GUMI_T:
    def __init__(self, weight_path, index_to_word, sentence_length, feature_dim = 1024, num_heads = 8):
        """
            weight_path: 大喜利適合判定モデルの学習済みの重みのパス
            index_to_word: 単語のID: 単語の辞書(0:<PAD>, 1:<START>, 2:<END>)
            sentence_length: 入力する文章の単語数
            feature_dim: 特徴量次元数
        """
        self.index_to_word = index_to_word
        self.sentence_length = sentence_length

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.boke_generate_model = TransformerBokeGeneratorModel(num_image_patch = 197,
                                      image_feature_dim = 1024,
                                      num_word = len(index_to_word),
                                      sentence_length = sentence_length,
                                      feature_dim = feature_dim,
                                      num_heads = num_heads)
        self.boke_generate_model.load_state_dict(torch.load(weight_path))
        self.boke_generate_model.to(self.device)
        self.boke_generate_model.eval()

        self.vit = ViTModel.from_pretrained('google/vit-large-patch16-224-in21k')
        self.vit = self.vit.to(self.device)
        self.vit.eval()

        # 画像の前処理
        self.image_preprocesser = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

    def __call__(self, image_path, argmax = False, top_k = 5):
        """
            image_path: 大喜利を生成したい画像のパス
            argmax: Trueなら最大確率の単語を選ぶ, FalseならTop-Kサンプリングを行う
            top_k: Top-Kサンプリング時に考慮する単語の数
        """
        image = Image.open(image_path)
        preprocessed_image = self.image_preprocesser(np.array([image]), return_tensors="pt").to(self.device)
        outputs = self.vit( **preprocessed_image )
        image_feature = outputs.last_hidden_state # (1, 197, 1024)
        print(image_feature.shape)

        generated_text = [1] # <START>トークン
        for i in range(1, self.sentence_length):
            tmp = generated_text + [0] * (self.sentence_length - i) # Padding
            tmp = torch.Tensor(np.array(tmp)).unsqueeze(0).to(self.device).to(dtype=torch.int32) # (1, sentence_length)
            pred = self.boke_generate_model(image_feature, tmp) # (1, sentence_length, num_word)
            target_pred = pred[0][i - 1]

            if argmax:
                # 最大確率の単語を選ぶ
                chosen_id = torch.argmax(target_pred).item()
            else:
                # Top-Kサンプリング
                top_k_probs, top_k_indices = torch.topk(target_pred, top_k)
                top_k_probs = torch.nn.functional.softmax(top_k_probs, dim = -1)
                chosen_id = np.random.choice(top_k_indices.detach().cpu().numpy(),
                                             p = top_k_probs.detach().cpu().numpy())

            generated_text.append(chosen_id)
            if chosen_id == 2:
                break

        generated_sentence = ""
        for I in generated_text[1:-1]:
            generated_sentence += self.index_to_word[I]
        return generated_sentence

In [8]:
gumi_t = GUMI_T(weight_path = f"{RESULT_DIR}best_model.pth",
                          index_to_word = index_to_word,
                          sentence_length = 32)

  self.boke_generate_model.load_state_dict(torch.load(weight_path))


In [15]:
image_path = "/home/user/workspace/Master_Thesis/Master_Thesis_programs/image_1.jpg"

fig = plt.figure()
ax = fig.add_subplot()
ax.imshow(Image.open(image_path))
ax.axis("off")

gumi_t(image_path)

'この階段、上がって来いよ！空港で、酔って帰ってきたら。'