<a href="https://colab.research.google.com/github/yagiyuki/clip-study-playground/blob/main/CLIP_acceleration_batch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CLIPで大量に画像分類するときのバッチ処理による高速化実装


In [6]:
%%bash
# Directory that holds every image to classify; later cells glob data/*.
DATA_DIR=data
# -p: succeed silently when the directory already exists, so the cell can be re-run.
mkdir -p "$DATA_DIR"

# Before running the next cell, upload an image file named sample.jpeg into
# the data directory; it is the source image that the next cell duplicates.

In [14]:
%%bash
# Duplicate data/sample.jpeg 999 times so that data/ holds exactly 1000 images
# in total: the original plus sample_1.jpeg ... sample_999.jpeg.
set -e

SRC=data/sample.jpeg
N_COPIES=999

# Fail fast with one clear message instead of one cp error per iteration.
if [ ! -f "$SRC" ]; then
  echo "error: $SRC not found; upload it into data/ first" >&2
  exit 1
fi

# Remove copies left by earlier runs so the total stays exact even if
# N_COPIES changed (the pattern does not match sample.jpeg itself).
rm -f data/sample_*.jpeg

for i in $(seq 1 "$N_COPIES"); do
  cp "$SRC" "data/sample_${i}.jpeg"
done

In [15]:
import glob

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face Hub id of the Japanese CLIP checkpoint used throughout this notebook.
HF_MODEL_PATH = "line-corporation/clip-japanese-base"

# Text tokenizer and image preprocessor for the checkpoint; trust_remote_code=True
# lets transformers execute the custom code the checkpoint ships with.
tokenizer = AutoTokenizer.from_pretrained(
    HF_MODEL_PATH,
    trust_remote_code=True,
)
processor = AutoImageProcessor.from_pretrained(
    HF_MODEL_PATH,
    trust_remote_code=True,
)

# The model itself, moved onto the selected device once up front.
model = AutoModel.from_pretrained(
    HF_MODEL_PATH,
    trust_remote_code=True,
).to(device)

## ベースライン

In [16]:
%%time
# 単純な推論ループ
with torch.no_grad():
    text = tokenizer(["ベンチプレス", "スクワット", "デッドリフト"]).to(device)
    text_feats = model.get_text_features(**text)
    for path in glob.glob('data/*'):
        img = Image.open(path).convert("RGB")
        inputs = processor(images=[img], return_tensors="pt").to(device)

        # テキスト特徴量を計算
        img_feats  = model.get_image_features(inputs.pixel_values)
        probs = (img_feats @ text_feats.T).softmax(dim=-1)
        # …結果処理…
        #print(path, probs)

CPU times: user 30.3 s, sys: 128 ms, total: 30.4 s
Wall time: 30.5 s


## 実装 1 – 手動スライス

In [17]:
%%time

batch_size = 50

with torch.no_grad():
    text = tokenizer(["ベンチプレス", "スクワット", "デッドリフト"]).to(device)
    text_feats = model.get_text_features(**text)

    image_paths = glob.glob('data/*')
    for i in range(0, len(image_paths), batch_size):
        # ここだけが追加：複数画像を一度に処理
        batch = image_paths[i:i+batch_size]
        imgs = [Image.open(p).convert("RGB") for p in batch]
        inputs = processor(images=imgs, return_tensors="pt").to(device)

        img_feats = model.get_image_features(inputs.pixel_values)

        probs = (img_feats @ text_feats.T).softmax(dim=-1)
        # …結果処理…
        #for path, p in zip(batch, probs):
        #    print(path, p)

CPU times: user 16.9 s, sys: 271 ms, total: 17.2 s
Wall time: 17.2 s


## 実装 2 – DataLoader にパスを直接渡す

In [18]:
%%time
from torch.utils.data import DataLoader

# 画像パス一覧
image_paths = glob.glob('data/*')

# DataLoader で自動的に path のリストをバッチ化
dataloader = DataLoader(
    image_paths,
    batch_size=50,
    shuffle=False
)

with torch.no_grad():
    for batch_paths in dataloader:
        # batch_paths は文字列のリスト
        imgs = [Image.open(p).convert("RGB") for p in batch_paths]
        inputs = processor(images=imgs, return_tensors="pt").to(device)
        img_feats = model.get_image_features(inputs.pixel_values)
        probs = (img_feats @ text_feats.T).softmax(dim=-1)
        # …結果処理…
        #for path, p in zip(batch_paths, probs):
        #    print(path, p)

CPU times: user 17.5 s, sys: 69.6 ms, total: 17.5 s
Wall time: 17.6 s


In [19]:
%%bash
# Report the machine's memory usage in human-readable units after the runs above.
free -h



               total        used        free      shared  buff/cache   available
Mem:            12Gi       2.9Gi       2.2Gi        16Mi       7.6Gi       9.5Gi
Swap:             0B          0B          0B
