In [1]:
import torch
from torchvision import transforms
from PIL import Image
from transformers import ViTForImageClassification, ViTFeatureExtractor

In [2]:
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')



In [3]:
device = 'mps'
model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [4]:
def classify_scene(image_path):
    image = Image.open(image_path).convert('RGB')
    inputs = feature_extractor(images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_idx = logits.argmax(-1).item()
    predicted_label = model.config.id2label[predicted_class_idx]
    return predicted_label

In [5]:
KEYFRAMES_DIR = '/Users/VoThinhPhat/Desktop/data/batch1/keyframes/keyframes_L01/L01_V001/'

# Danh sách tên file keyframes
keyframes = ['00725.jpg', '00751.jpg', '00776.jpg', '00800.jpg', '00828.jpg']

In [6]:
for frame in keyframes:
    image_path = f"{KEYFRAMES_DIR}/{frame}"
    label = classify_scene(image_path)
    print(f"Keyframe: {frame} --> Cảnh: {label}")

Keyframe: 00725.jpg --> Cảnh: crane
Keyframe: 00751.jpg --> Cảnh: web site, website, internet site, site
Keyframe: 00776.jpg --> Cảnh: web site, website, internet site, site
Keyframe: 00800.jpg --> Cảnh: sports car, sport car
Keyframe: 00828.jpg --> Cảnh: web site, website, internet site, site
