
# ViTによるドライバー眠気検知（Drowsy vs Non Drowsy）
このNotebookは、既存のEAR（Eye Aspect Ratio）ベースのCNN/NNパイプラインを**Vision Transformer (ViT)** に置き換え、**画像から直接**「Drowsy / Non Drowsy」を二値分類します。  
Hugging Face Transformers の `ViTForImageClassification` を使用し、画像フォルダから学習します。

## 構成
- ライブラリ準備（必要に応じてインストール）
- データ読み込み（フォルダ構成: `DATA_DIR/Drowsy`, `DATA_DIR/Non Drowsy`）
- ViTモデル構築・学習
- 評価
- モデル保存・読み込み
- 画像単体推論
- **Webカメラ（または動画）からのリアルタイム推論（OpenCV）**


In [15]:

# 必要に応じてインストール（オフライン環境ではスキップしてください）
# %pip install torch torchvision transformers pillow opencv-python scikit-learn tqdm


In [16]:

import os
import math
import time
from pathlib import Path
from dataclasses import dataclass

import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

from transformers import AutoImageProcessor, ViTForImageClassification

from PIL import Image
import numpy as np

# 評価用
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

# 推論・表示
import cv2

# ===== User settings (adjust as needed) =====
# Dataset root. Expected layout: DATA_DIR/Drowsy, DATA_DIR/Non Drowsy
DATA_DIR = "./Driver Drowsiness Dataset (DDD)"
OUTPUT_DIR = "./vit_drowsiness_model"
MODEL_NAME = "google/vit-base-patch16-224-in21k"
NUM_EPOCHS = 5
BATCH_SIZE = 16
VAL_RATIO = 0.2
RANDOM_SEED = 42
IMAGE_SIZE = 224
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 1e-4
# DataLoader worker processes. This constant used to be assigned twice
# (2, then 0); only the last assignment ever took effect, so the effective
# value 0 is kept as the single definition.
NUM_WORKERS = 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
# Make sure the checkpoint directory exists before training starts.
os.makedirs(OUTPUT_DIR, exist_ok=True)


Device: cpu


In [17]:

# Image preprocessing: reuse the normalisation statistics exposed by the
# image processor of the pretrained checkpoint.
processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
mean = processor.image_mean
std = processor.image_std

train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

val_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# ImageFolder layout: DATA_DIR/<class name>/<image files>.
# Two views of the same folder are built so the training and validation
# subsets can carry different transforms. Assigning
# `subset.dataset.transform` on a random_split() result would instead change
# the transform of the single dataset shared by BOTH subsets, silently
# switching off the training augmentation.
full_dataset = datasets.ImageFolder(root=DATA_DIR, transform=train_transform)
eval_dataset = datasets.ImageFolder(root=DATA_DIR, transform=val_transform)

# Class names come from the folder names and may contain spaces
# (e.g. 'Drowsy' and 'Non Drowsy').
class_names = full_dataset.classes
print("Detected classes:", class_names)

# Train/validation split, driven by one seeded permutation of sample indices.
val_len = int(len(full_dataset) * VAL_RATIO)
train_len = len(full_dataset) - val_len
split_generator = torch.Generator().manual_seed(RANDOM_SEED)
shuffled_indices = torch.randperm(len(full_dataset), generator=split_generator).tolist()
train_indices = shuffled_indices[:train_len]
val_indices = shuffled_indices[train_len:]

# Cap the training set size. The indices are already a seeded random
# permutation, so taking a prefix is a reproducible random subsample.
max_train_samples = 2000
if len(train_indices) > max_train_samples:
    train_indices = train_indices[:max_train_samples]

train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
val_dataset = torch.utils.data.Subset(eval_dataset, val_indices)

# Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
use_pinned_memory = device.type == "cuda"
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=use_pinned_memory)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=use_pinned_memory)

id2label = {i: name for i, name in enumerate(class_names)}
label2id = {name: i for i, name in id2label.items()}
print("id2label:", id2label)


Detected classes: ['Drowsy', 'Non Drowsy']
id2label: {0: 'Drowsy', 1: 'Non Drowsy'}


In [18]:

# Instantiate the ViT classifier from the pretrained checkpoint, sizing the
# label space from the detected classes, and place it on the target device.
model = ViTForImageClassification.from_pretrained(
    MODEL_NAME, num_labels=len(class_names), id2label=id2label, label2id=label2id
)
model = model.to(device)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)


Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:

def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module whose forward output exposes a ``.logits`` tensor.
        loader: Iterable yielding ``(images, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable applied as ``criterion(logits, targets)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss and the
        fraction of correct predictions. Both are ``0.0`` when ``loader``
        yields no samples.
    """
    # The progress bar is purely cosmetic, so a missing tqdm installation
    # must not abort training.
    try:
        from tqdm import tqdm
    except ImportError:
        batches = loader
    else:
        batches = tqdm(loader, desc='Training', leave=False)

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, targets in batches:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images).logits
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # The criterion returns a batch mean; weight it by the batch size so
        # the epoch average stays exact when the last batch is smaller.
        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        correct += preds.eq(targets).sum().item()
        total += targets.size(0)

    epoch_loss = running_loss / total if total > 0 else 0.0
    epoch_acc = correct / total if total > 0 else 0.0
    return epoch_loss, epoch_acc


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Module whose forward output exposes a ``.logits`` tensor.
        loader: Iterable yielding ``(images, targets)`` batches.
        criterion: Loss callable applied as ``criterion(logits, targets)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc, report, cm)``: sample-weighted mean loss,
        accuracy, the scikit-learn classification report string and the
        confusion matrix. Class names are read from the notebook-level
        ``id2label`` mapping.
    """
    # The progress bar is purely cosmetic, so a missing tqdm installation
    # must not abort evaluation.
    try:
        from tqdm import tqdm
    except ImportError:
        batches = loader
    else:
        batches = tqdm(loader, desc='Evaluating', leave=False)

    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_targets = []
    all_preds = []

    for images, targets in batches:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        outputs = model(images).logits
        loss = criterion(outputs, targets)

        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        correct += preds.eq(targets).sum().item()
        total += targets.size(0)

        all_targets.extend(targets.cpu().tolist())
        all_preds.extend(preds.cpu().tolist())

    epoch_loss = running_loss / total if total > 0 else 0.0
    epoch_acc = correct / total if total > 0 else 0.0

    # Pass the full label set explicitly: without `labels`, the report raises
    # when a class is missing from both targets and predictions (its size no
    # longer matches `target_names`), and the confusion matrix would shrink.
    label_ids = list(range(len(id2label)))
    report = classification_report(
        all_targets,
        all_preds,
        labels=label_ids,
        target_names=[id2label[i] for i in label_ids],
        digits=4,
        zero_division=0,
    )
    cm = confusion_matrix(all_targets, all_preds, labels=label_ids)
    return epoch_loss, epoch_acc, report, cm


# Train for NUM_EPOCHS epochs; after each one, evaluate on the validation
# loader and checkpoint the model whenever validation accuracy improves.
best_val_acc = 0.0
for epoch in range(1, NUM_EPOCHS + 1):
    start = time.time()
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, report, cm = evaluate(model, val_loader, criterion, device)
    elapsed = time.time() - start

    summary = (
        f"[Epoch {epoch}/{NUM_EPOCHS}] "
        f"train_loss: {train_loss:.4f}  train_acc: {train_acc:.4f} | "
        f"val_loss: {val_loss:.4f}  val_acc: {val_acc:.4f} | {elapsed:.1f}s"
    )
    print(summary)
    print("Classification Report:\n", report)
    print("Confusion Matrix:\n", cm)

    # Nothing to save unless this epoch beat the best accuracy so far.
    if not val_acc > best_val_acc:
        continue

    best_val_acc = val_acc
    # Persist the best model together with its image processor.
    model.save_pretrained(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)
    print(f"Saved best model to {OUTPUT_DIR} (val_acc={best_val_acc:.4f})")


                                                           

[Epoch 1/5] train_loss: 0.1783  train_acc: 0.9450 | val_loss: 0.0206  val_acc: 0.9988 | 2215.8s
Classification Report:
               precision    recall  f1-score   support

      Drowsy     0.9984    0.9993    0.9989      4469
  Non Drowsy     0.9992    0.9982    0.9987      3889

    accuracy                         0.9988      8358
   macro avg     0.9988    0.9988    0.9988      8358
weighted avg     0.9988    0.9988    0.9988      8358

Confusion Matrix:
 [[4466    3]
 [   7 3882]]
Saved best model to ./vit_drowsiness_model (val_acc=0.9988)


                                                           

[Epoch 2/5] train_loss: 0.0123  train_acc: 0.9995 | val_loss: 0.0085  val_acc: 0.9996 | 2045.6s
Classification Report:
               precision    recall  f1-score   support

      Drowsy     1.0000    0.9993    0.9997      4469
  Non Drowsy     0.9992    1.0000    0.9996      3889

    accuracy                         0.9996      8358
   macro avg     0.9996    0.9997    0.9996      8358
weighted avg     0.9996    0.9996    0.9996      8358

Confusion Matrix:
 [[4466    3]
 [   0 3889]]
Saved best model to ./vit_drowsiness_model (val_acc=0.9996)


                                                         

KeyboardInterrupt: 

In [20]:

# Load the saved model (after training, or in a separate session).
# OUTPUT_DIR is created unconditionally during setup, so its mere existence
# does not prove that a checkpoint was ever written. Look for the config file
# that save_pretrained() stores next to the weights instead.
has_checkpoint = os.path.isfile(os.path.join(OUTPUT_DIR, "config.json"))
model_source = OUTPUT_DIR if has_checkpoint else MODEL_NAME
if not has_checkpoint:
    # The base checkpoint has no trained classification head for this task.
    print("Warning: no fine-tuned checkpoint found in", OUTPUT_DIR,
          "- falling back to", MODEL_NAME)

loaded_processor = AutoImageProcessor.from_pretrained(model_source)
loaded_model = ViTForImageClassification.from_pretrained(model_source)
loaded_model.to(device)
loaded_model.eval()
print("Loaded model from:", model_source)


  return torch.load(checkpoint_file, map_location="cpu")


Loaded model from: ./vit_drowsiness_model


In [21]:

@torch.no_grad()
def predict_image(img_path, model, processor, device):
    """Classify a single image file.

    Args:
        img_path: Path of the image to classify.
        model: Classifier whose forward output exposes ``.logits``; the
            caller is expected to have put it in eval mode.
        processor: Image processor providing ``image_mean`` / ``image_std``
            for normalisation.
        device: Device on which inference is run.

    Returns:
        ``(pred_label, probs)``: the predicted label (looked up in
        ``model.config.id2label`` when available, otherwise the class index
        as a string) and the per-class softmax probabilities as a NumPy array.
    """
    # Context manager so the underlying file handle is released promptly.
    with Image.open(img_path) as raw:
        img = raw.convert("RGB")

    # Same Resize -> ToTensor -> Normalize pipeline as the validation
    # transform. PIL's own `Image.resize` uses a different default resampling
    # filter than `transforms.Resize`, which would make inference inputs
    # differ slightly from what the model saw during training.
    preprocess = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
    ])
    pixel_values = preprocess(img).unsqueeze(0).to(device)

    logits = model(pixel_values).logits
    probs = torch.softmax(logits, dim=-1).squeeze(0).detach().cpu().numpy()
    pred_id = int(np.argmax(probs))
    pred_label = model.config.id2label[pred_id] if hasattr(model.config, "id2label") else str(pred_id)
    return pred_label, probs

# 使用例
# test_img = "/path/to/test.jpg"
# label, prob = predict_image(test_img, loaded_model, loaded_processor, device)
# print(label, prob)


In [23]:

def put_text_with_bg(frame, text, org=(10,30)):
    """Draw white ``text`` over a filled black box directly onto ``frame``.

    ``org`` is the ``(x, y)`` position handed to ``cv2.putText``; the box is
    sized from the measured text extent plus a 5-pixel margin on every side.
    """
    face = cv2.FONT_HERSHEY_SIMPLEX
    size = 0.8
    stroke = 2
    margin = 5

    (text_w, text_h), base = cv2.getTextSize(text, face, size, stroke)
    left, bottom = org

    box_top_left = (left - margin, bottom - text_h - margin)
    box_bottom_right = (left + text_w + margin, bottom + base + margin)
    cv2.rectangle(frame, box_top_left, box_bottom_right, (0, 0, 0), -1)
    cv2.putText(frame, text, (left, bottom), face, size, (255, 255, 255), stroke, cv2.LINE_AA)

@torch.no_grad()
def infer_from_capture(model, processor, device, source=0, window_title="ViT Drowsiness Detection", conf_display=True):
    """Classify frames from a camera or video and show the result in a window.

    Args:
        model: Classifier whose forward output exposes ``.logits``; the
            caller is expected to have put it in eval mode.
        processor: Image processor providing ``image_mean`` / ``image_std``
            for normalisation.
        device: Device on which inference is run.
        source: Value passed to ``cv2.VideoCapture`` (camera index or path).
        window_title: Title of the OpenCV display window.
        conf_display: When true, append the predicted-class probability to
            the overlay text.

    The loop ends when the source stops delivering frames or 'q' is pressed.
    The capture and all OpenCV windows are released on exit, including on
    error.
    """
    cap = cv2.VideoCapture(source)
    if not cap.isOpened():
        print("Failed to open video source:", source)
        return

    # Built once instead of per frame. Same Resize -> ToTensor -> Normalize
    # pipeline as the validation transform; PIL's own `Image.resize` uses a
    # different default resampling filter than `transforms.Resize`.
    preprocess = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
    ])
    label_names = getattr(model.config, "id2label", None)

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV delivers BGR frames; the preprocessing expects RGB.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            tensor = preprocess(Image.fromarray(rgb)).unsqueeze(0).to(device)

            logits = model(tensor).logits
            probs_t = torch.softmax(logits, dim=-1).squeeze(0).detach().cpu().numpy()
            pred_id = int(np.argmax(probs_t))
            pred_label = label_names[pred_id] if label_names is not None else str(pred_id)
            conf = float(probs_t[pred_id])

            # Overlay the prediction on the original frame.
            text = f"{pred_label} ({conf:.2f})" if conf_display else pred_label
            put_text_with_bg(frame, text, (10, 30))

            cv2.imshow(window_title, frame)
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()

# Usage example:
# Run this after training has finished and the best model has been saved to the output folder.
infer_from_capture(loaded_model, loaded_processor, device, source=0)





## メモ・ヒント
- データ構成は `DATA_DIR/<class_name>/*.jpg` 形式を前提としています。クラス名はフォルダ名が使われます（例: `Drowsy`, `Non Drowsy`）。
- 画像数が少ない場合は**データ拡張**（RandomResizedCrop, ColorJitter等）の追加をご検討ください。
- 転移学習を高速化したい場合は、最終層以外のパラメータを一時的に凍結して学習し、後半でUnfreezeする戦略も有効です。
- OpenCVのカメラ入力は環境によって異なります。`source=0` でWebカメラを使用し、動画ファイルのパスを指定すれば動画ファイルに対しても動作します。
- 精度が十分でない場合、`MODEL_NAME` を `vit-large` 系に変える、エポック数や学習率の調整、画像中心切り出し（顔検出＋トリミング）などを検討してください。
- EARベースとの**ハイブリッド**（顔検出→目領域拡大→ViTなど）で頑健性が上がる場合もあります。
