In [1]:
import os
import json
import numpy as np
import argparse
import matplotlib.pyplot as plt
import japanize_matplotlib
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader

import japanese_clip as ja_clip
from transformers import MLukeTokenizer, LukeModel

In [2]:
class BokeJudgeModel(nn.Module):
    """MLP that scores a (prompt image, caption) pair from pre-computed features.

    The three feature vectors are concatenated, passed through three
    leaky-ReLU hidden layers of equal width, and squashed to (0, 1) by a sigmoid.
    """

    def __init__(self, cif_dim = 512, csf_dim = 512, lsf_dim = 768, feature_dim = 1024):
        """
            cif_dim: dimensionality of the CLIP image features
            csf_dim: dimensionality of the CLIP text features
            lsf_dim: dimensionality of the Sentence-LUKE text features
            feature_dim: width of every hidden layer
        """
        super().__init__()
        self.cif_dim, self.csf_dim, self.lsf_dim = cif_dim, csf_dim, lsf_dim

        # The first layer consumes the concatenation of all three feature vectors.
        concat_dim = cif_dim + csf_dim + lsf_dim

        # Attribute names are the state_dict keys of saved checkpoints, so they stay as-is.
        self.fc1 = nn.Linear(concat_dim, feature_dim)
        self.fc2 = nn.Linear(feature_dim, feature_dim)
        self.fc3 = nn.Linear(feature_dim, feature_dim)
        self.output_layer = nn.Linear(feature_dim, 1)

    def forward(self, cif, csf, lsf):
        """
            cif: CLIP image features
            csf: CLIP text features
            lsf: Sentence-LUKE text features

            Returns the sigmoid of the single output unit (one value per row).
        """
        hidden = torch.cat((cif, csf, lsf), dim = 1)

        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden))

        logit = self.output_layer(hidden)
        return torch.sigmoid(logit)

class SentenceLukeJapanese:
    """Sentence encoder built on ``sonoisa/sentence-luke-japanese-base-lite``.

    Sentences are tokenized, run through the LUKE model, and reduced to one
    vector per sentence by attention-mask-weighted mean pooling.
    """

    def __init__(self, device = None):
        """
            device: device to run the model on; when None, "cuda" is used if
                    available and "cpu" otherwise
        """
        model_name = "sonoisa/sentence-luke-japanese-base-lite"
        self.tokenizer = MLukeTokenizer.from_pretrained(model_name)
        # NOTE(review): weights are loaded in float16 regardless of the target device;
        # confirm half precision is supported/acceptable when running on CPU.
        self.model = LukeModel.from_pretrained(model_name,
                                               torch_dtype = torch.float16)
        self.model.eval()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model.to(self.device)

    def _mean_pooling(self, model_output, attention_mask):
        """Average the token embeddings over the positions where attention_mask is 1."""
        token_embeddings = model_output[0]  # first element holds the per-token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp the denominator so a fully masked row cannot divide by zero.
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    @torch.no_grad()
    def encode(self, sentences, batch_size = 256):
        """Encode sentences into a 2-D tensor with one row per sentence (on CPU).

            sentences: sequence of strings to encode
            batch_size: number of sentences sent through the model at once (>= 1)

            Returns a tensor of shape (len(sentences), hidden_size); an empty
            (0, hidden_size) tensor when `sentences` is empty.
            Raises ValueError when batch_size is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # torch.cat/torch.stack reject an empty list, so handle "nothing to encode" explicitly.
        if len(sentences) == 0:
            return torch.empty((0, self.model.config.hidden_size))

        batch_embeddings = []
        for batch_idx in range(0, len(sentences), batch_size):
            batch = sentences[batch_idx:batch_idx + batch_size]

            encoded_input = self.tokenizer.batch_encode_plus(batch, padding="longest",
                                           truncation=True, return_tensors="pt").to(self.device)
            model_output = self.model(**encoded_input)
            sentence_embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"])

            # Move each batch to CPU right away so results do not pile up on the accelerator.
            batch_embeddings.append(sentence_embeddings.to('cpu'))

        # Concatenate whole batches instead of unpacking and re-stacking single rows.
        return torch.cat(batch_embeddings, dim = 0)

class BokeJugeAI:
    """Scores how well a caption (boke) fits a prompt image.

    Japanese CLIP image/text features and Sentence-LUKE text features are
    computed for the pair and fed to a trained BokeJudgeModel.
    """

    def __init__(self, weight_path, feature_dim):
        """
            weight_path: path to the trained weights (state_dict) of the judge model
            feature_dim: hidden-layer width passed to BokeJudgeModel
        """
        # Choose the device first so the checkpoint can be mapped onto it while loading.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Judge model. map_location lets a checkpoint saved on GPU load on a CPU-only
        # machine; weights_only=True restricts unpickling to tensors and plain
        # containers, which is all a state_dict needs.
        self.boke_judge_model = BokeJudgeModel(feature_dim = feature_dim)
        state_dict = torch.load(weight_path, map_location = self.device, weights_only = True)
        self.boke_judge_model.load_state_dict(state_dict)
        self.boke_judge_model.to(self.device)
        self.boke_judge_model.eval()

        # CLIP
        # NOTE(review): CLIP is loaded with torch_dtype=float16; confirm the preprocessed
        # image dtype is accepted by the half-precision model at inference time.
        self.clip_model, self.clip_preprocesser = ja_clip.load("rinna/japanese-clip-vit-b-16",
                                             cache_dir="/tmp/japanese_clip",
                                             torch_dtype = torch.float16,
                                             device = self.device)
        self.clip_tokenizer = ja_clip.load_tokenizer()

        # Sentence-LUKE, placed on the same device as the other models.
        self.luke_model = SentenceLukeJapanese(device = self.device)

    def __call__(self, image_path, sentence):
        """
            image_path: path to the prompt image to judge
            sentence: caption (boke) to judge against the image

            Returns the judge model's sigmoid output as a NumPy array.
        """
        # Convert the caption and image to CLIP features.
        tokenized_sentences = ja_clip.tokenize(
            texts = [sentence],
            max_seq_len = 77,
            device = self.device,
            tokenizer = self.clip_tokenizer,
            )
        # The context manager closes the file once preprocessing has produced the tensor.
        with Image.open(image_path) as image:
            preprcessed_image = self.clip_preprocesser(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            clip_image_features = self.clip_model.get_image_features(preprcessed_image)
            clip_sentence_features = self.clip_model.get_text_features(**tokenized_sentences)

        # Convert the caption to Sentence-LUKE features (returned on CPU by encode).
        luke_sentence_feature = self.luke_model.encode([sentence])

        # Run the judge model on the combined features.
        with torch.no_grad():
            outputs = self.boke_judge_model(clip_image_features,
                                        clip_sentence_features,
                                        luke_sentence_feature.to(self.device))

        return outputs.cpu().numpy()

In [3]:
# Build the judge.
# NOTE(review): weight_path is an empty placeholder; the recorded output of this cell is a
# FileNotFoundError, so point it at the trained weights file before running.
boke_judge_AI = BokeJugeAI(
    weight_path="",
    feature_dim=1024,
)

  self.boke_judge_model.load_state_dict(torch.load(weight_path))


FileNotFoundError: [Errno 2] No such file or directory: ''