In [1]:
import torch
import open_clip
from open_clip import tokenizer
from deep_translator import GoogleTranslator
from collections import defaultdict
import os
import shutil
import random
from PIL import Image,ImageOps
import matplotlib.pyplot as plt
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fix every random-number-generator seed so repeated runs give the same results
seed = 42

# Seed Python's `random`, NumPy's legacy global RNG and PyTorch, in that order
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)

cuda_ready = torch.cuda.is_available()
if cuda_ready:
    # Seed the current CUDA device, then every CUDA device
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Ask cuDNN for deterministic kernels and turn off its auto-tuner
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [4]:
# Load the RN50x64 CLIP model with the 'openai' pretrained weights.
# create_model_and_transforms returns three values; the second one is not used here.
high_model, _, high_preprocess = open_clip.create_model_and_transforms(
    'RN50x64',
    pretrained='openai',
)

# Move the model to the selected device and switch it to evaluation mode
high_model = high_model.to(device).eval()



In [5]:
# Create a GoogleTranslator instance (Japanese -> English)
translator = GoogleTranslator(source='ja', target='en')

# Japanese prompt for each label. The keys are label identifiers that are
# compared literally later on (e.g. "YES_bird"), so they must never be altered.
descriptions_ja = {
    "YES_bird": "鳥の死骸が落ちている",
    "NO_bird": "鳥の死骸が落ちていない",
    "NO_bird2": "鳥の死骸ではなく物体が落ちている",
    "NO_bird3": "何も落ちていない",
}

# Translate only the prompt texts. Sending the keys through the translator as
# well would waste one request per label and could silently change a key,
# which would break the later `== "YES_bird"` comparison.
descriptions = {
    label: translator.translate(text_ja)
    for label, text_ja in descriptions_ja.items()
}

# Show the translated dictionary
print(descriptions)

{'YES_bird': 'Dead birds are lying around', 'NO_bird': 'There are no dead birds lying around', 'NO_bird2': 'An object has fallen, not a dead bird', 'NO_bird3': 'Nothing has fallen'}


In [6]:
# Label names in dictionary order; index i here matches row i of text_inputs
label_names = list(descriptions)

# Tokenize every English prompt in a single call
text_inputs = open_clip.tokenize([descriptions[name] for name in label_names]).to(device)

# The prompts are the same for every image, so encode and L2-normalize them
# once up front instead of once per image.
with torch.no_grad():
    text_features = high_model.encode_text(text_inputs)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# Folder containing the images to classify
image_folder = 'imgs/test_objects_bird'

# Number of images whose best-matching prompt is YES_bird
yes_bird_count = 0

# Sort the listing so the processing order is deterministic
# (os.listdir returns entries in arbitrary order).
for image_name in sorted(os.listdir(image_folder)):
    image_path = os.path.join(image_folder, image_name)
    if not os.path.isfile(image_path):
        # Sub-directories cannot be opened as images; skip them
        continue

    # Close the file handle as soon as the pixels have been converted to RGB
    with Image.open(image_path) as image_file:
        image = image_file.convert("RGB")
    image_input = high_preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        image_features = high_model.encode_image(image_input)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Dot product of unit vectors = cosine similarity, one score per prompt
        similarity = (image_features @ text_features.T).squeeze(0)

    # Pick the prompt with the highest similarity
    predicted_label = similarity.argmax().item()
    predicted_description = label_names[predicted_label]

    # Count the image when its best match is the YES_bird prompt
    if predicted_description == "YES_bird":
        yes_bird_count += 1
        if image_name == 'bird.jpg':
            print(f"{image_name} は YES_bird に含まれています。")

print(f"合計で {yes_bird_count} 枚が YES_bird に含まれていました。")

bird.jpg は YES_bird に含まれています。
合計で 11 枚が YES_bird に含まれていました。
