In [1]:
# import shutil
# #
# shutil.rmtree("../datas/encoded/")

In [2]:
import os
import json
import torch
import numpy as np
import japanese_clip as ja_clip
from PIL import Image
from tqdm import tqdm
from transformers import MLukeTokenizer, LukeModel

In [3]:
# Pick the compute device: use the GPU whenever torch reports one, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the rinna Japanese CLIP checkpoint. The call hands back two objects that
# are unpacked into the model and its image preprocessing callable; float16 is
# requested via torch_dtype and the files are cached under /tmp/japanese_clip.
clip_model, clip_preprocesser = ja_clip.load(
    "rinna/japanese-clip-vit-b-16",
    cache_dir="/tmp/japanese_clip",
    torch_dtype=torch.float16,
    device=device,
)

# Tokenizer used later for the CLIP text inputs.
clip_tokenizer = ja_clip.load_tokenizer()

# Trailing expression so the notebook displays which device was selected.
device

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


'cuda'

In [4]:
class SentenceLukeJapanese:
    """Sentence encoder built on ``sonoisa/sentence-luke-japanese-base-lite``.

    Sentences are tokenized, run through the LUKE model, and reduced to one
    vector per sentence by attention-mask-weighted mean pooling over tokens.
    """

    def __init__(self, device = None):
        """Load the tokenizer and model and move the model to ``device``.

        Args:
            device: Anything accepted by ``torch.device`` (e.g. ``"cuda"``,
                ``"cpu"``). When ``None``, ``"cuda"`` is used if available,
                otherwise ``"cpu"``.
        """
        self.tokenizer = MLukeTokenizer.from_pretrained("sonoisa/sentence-luke-japanese-base-lite")
        # NOTE(review): weights are requested as float16 regardless of the
        # selected device; confirm half precision is acceptable on CPU-only hosts.
        self.model = LukeModel.from_pretrained("sonoisa/sentence-luke-japanese-base-lite",
                                               torch_dtype = torch.float16)
        self.model.eval()

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model.to(self.device)

    def _mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over the sequence axis, ignoring padding.

        Args:
            model_output: Model output whose first element holds the per-token
                embeddings.
            attention_mask: Mask with one entry per token; it is broadcast to
                the embedding shape and used as the averaging weight.

        Returns:
            One pooled vector per input row.
        """
        token_embeddings = model_output[0] #First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp the denominator so an all-zero mask cannot divide by zero.
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    @torch.no_grad()
    def encode(self, sentences, batch_size = 256):
        """Encode ``sentences`` into a 2-D tensor of sentence embeddings on the CPU.

        Args:
            sentences: Sequence of strings to embed.
            batch_size: Number of sentences tokenized and run through the model
                at a time.

        Returns:
            Tensor with one row per sentence, in input order. An empty input
            yields an empty tensor of shape ``(0, hidden_size)`` instead of
            raising.
        """
        # torch.cat/torch.stack reject an empty list, so answer the empty case
        # directly with a correctly shaped empty result.
        if len(sentences) == 0:
            return torch.empty((0, self.model.config.hidden_size))

        batch_embeddings = []
        for batch_idx in range(0, len(sentences), batch_size):
            batch = sentences[batch_idx:batch_idx + batch_size]

            encoded_input = self.tokenizer.batch_encode_plus(batch, padding="longest",
                                           truncation=True, return_tensors="pt").to(self.device)
            model_output = self.model(**encoded_input)
            pooled = self._mean_pooling(model_output, encoded_input["attention_mask"])

            # Move each batch off the device right away so device memory does
            # not grow with the number of sentences.
            batch_embeddings.append(pooled.to('cpu'))

        # Concatenate whole batches rather than splitting into rows and re-stacking.
        return torch.cat(batch_embeddings)

# Build the LUKE sentence encoder once at notebook scope. With no argument the
# constructor selects "cuda" when available and falls back to "cpu".
luke_model = SentenceLukeJapanese()

In [5]:
# Input locations: per-item JSON files and the matching images.
DATA_DIR = "../datas/boke_data_assemble/"
IMAGE_DIR = "../datas/boke_image/"

# Output tree for the encoded features. Nested directories are derived from
# their parent constant so the layout is defined in exactly one place.
TARGET_DIR = "../datas/encoded/"
TARGET_CLIP_IMAGE_FEATURE_DIR = f"{TARGET_DIR}clip_image_feature/"
TARGET_CLIP_SENTENCE_FEATURE_DIR = f"{TARGET_DIR}clip_sentence_feature/"
TARGET_CLIP_BOKE_FEATURE_DIR = f"{TARGET_CLIP_SENTENCE_FEATURE_DIR}boke/"
TARGET_CLIP_CAPTION_FEATURE_DIR = f"{TARGET_CLIP_SENTENCE_FEATURE_DIR}caption/"
TARGET_LUKE_SENTENCE_FEATURE_DIR = f"{TARGET_DIR}luke_sentence_feature/"
TARGET_LUKE_BOKE_FEATURE_DIR = f"{TARGET_LUKE_SENTENCE_FEATURE_DIR}boke/"
TARGET_LUKE_CAPTION_FEATURE_DIR = f"{TARGET_LUKE_SENTENCE_FEATURE_DIR}caption/"

# Every output directory, parents listed before their children.
tmp = [TARGET_DIR,
       TARGET_CLIP_IMAGE_FEATURE_DIR,
       TARGET_CLIP_SENTENCE_FEATURE_DIR, TARGET_CLIP_BOKE_FEATURE_DIR, TARGET_CLIP_CAPTION_FEATURE_DIR,
       TARGET_LUKE_SENTENCE_FEATURE_DIR, TARGET_LUKE_BOKE_FEATURE_DIR, TARGET_LUKE_CAPTION_FEATURE_DIR]
for D in tmp:
    # makedirs(exist_ok=True) is safe to re-run, has no check-then-create race,
    # and also creates missing parent directories.
    os.makedirs(D, exist_ok=True)

# Displayed as the cell result: number of entries in the data and image directories.
len(os.listdir(DATA_DIR)), len(os.listdir(IMAGE_DIR))

(668970, 668982)

画像のパス

../datas/encoded/clip_image_feature/871086.npy

文章のパス

../datas/encoded/clip_sentence_feature/caption/871086.npy

../datas/encoded/clip_sentence_feature/boke/871086/5526.npy



In [6]:
# Encode every item in DATA_DIR: one CLIP image feature, plus CLIP and LUKE
# text features for each boke and for the image caption.
for JP in tqdm(os.listdir(DATA_DIR)):
    # File names are "<number>.<ext>"; the number also names the image and outputs.
    N = int(JP.split(".")[0])

    # Resume support: the LUKE caption feature is the last file written for an
    # item, so its presence means the item was fully processed. (The previous
    # check looked directly under luke_sentence_feature/, where nothing is ever
    # saved, so finished items were never skipped.)
    if os.path.exists(f"{TARGET_LUKE_CAPTION_FEATURE_DIR}{N}.npy"):
        continue

    # Items without a matching image cannot be encoded.
    image_path = f"{IMAGE_DIR}{N}.jpg"
    if not os.path.exists(image_path): continue

    # Read as UTF-8 explicitly so decoding does not depend on the host locale.
    with open(f"{DATA_DIR}{JP}", "r", encoding="utf-8") as f:
        a = json.load(f)

    bokes = [A["boke"] for A in a["bokes"]]
    caption = a["image_information"]["ja_caption"]
    # The caption goes last so it can be split off with [-1] after encoding.
    sentences = bokes + [caption]

    encoded_sentences = ja_clip.tokenize(
        texts = sentences,
        max_seq_len = 77,
        device = device,
        tokenizer = clip_tokenizer,
    )
    # Close the image file as soon as it has been preprocessed; the loop visits
    # every file in DATA_DIR and must not leak file handles.
    with Image.open(image_path) as image:
        preprcessed_image = clip_preprocesser(image).unsqueeze(0).to(device)
    with torch.no_grad():
        clip_image_feature = clip_model.get_image_features(preprcessed_image)
        clip_image_feature = clip_image_feature.cpu().numpy()[0]
        clip_sentence_features = clip_model.get_text_features(**encoded_sentences)
        clip_sentence_features = clip_sentence_features.cpu().numpy()
    clip_boke_features = clip_sentence_features[:-1]
    clip_caption_feature = clip_sentence_features[-1]

    luke_sentence_features = luke_model.encode(sentences).cpu().numpy()
    luke_boke_features = luke_sentence_features[:-1]
    luke_caption_feature = luke_sentence_features[-1]

    # Per-item directories holding one file per boke, named by its index.
    TMP_CLIP_BOKE_FEATURE_DIR = f"{TARGET_CLIP_BOKE_FEATURE_DIR}{N}/"
    os.makedirs(TMP_CLIP_BOKE_FEATURE_DIR, exist_ok=True)
    TMP_LUKE_BOKE_FEATURE_DIR = f"{TARGET_LUKE_BOKE_FEATURE_DIR}{N}/"
    os.makedirs(TMP_LUKE_BOKE_FEATURE_DIR, exist_ok=True)

    # np.save appends ".npy" to each of the names below.
    for i in range(len(bokes)):
        np.save(f"{TMP_CLIP_BOKE_FEATURE_DIR}{i}", clip_boke_features[i])
        np.save(f"{TMP_LUKE_BOKE_FEATURE_DIR}{i}", luke_boke_features[i])

    np.save(f"{TARGET_CLIP_IMAGE_FEATURE_DIR}{N}", clip_image_feature)
    np.save(f"{TARGET_CLIP_CAPTION_FEATURE_DIR}{N}", clip_caption_feature)
    # Keep this save last: it doubles as the "item finished" marker checked above.
    np.save(f"{TARGET_LUKE_CAPTION_FEATURE_DIR}{N}", luke_caption_feature)

100%|██████████| 668970/668970 [3:41:02<00:00, 50.44it/s]  
