# Inference Example
This notebook demonstrates how to load a trained Vision Transformer (ViT) model from a checkpoint and run inference on a single image.

In [None]:
from PIL import Image
import torch
from src.models.vit import LitViT
from src.dataload.dataset import feature_extractor, ids2label

# Location of the trained checkpoint file to restore.
path_to_checkpoint = '/path/to/checkpoint.ckpt'  # Update this path

# Restore the model from the checkpoint. The label metadata is supplied
# explicitly as keyword arguments:
#   - num_labels:    number of entries in ids2label
#   - id2label:      the ids2label mapping itself
#   - label2id:      the inverse mapping (label -> id)
#   - class_weights: a tensor of ones, one entry per label
#     (presumably only used by the training loss — TODO confirm)
model = LitViT.load_from_checkpoint(
    path_to_checkpoint,
    num_labels=len(ids2label),
    id2label=ids2label,
    label2id=dict(zip(ids2label.values(), ids2label.keys())),
    class_weights=torch.ones(len(ids2label)),
)

# Switch to evaluation mode before running inference.
model.eval()

# Path to the image to classify
path_to_image = 'test_casual.jpg'  # Update this path

# Load and preprocess the image.
# Image.open() is lazy and keeps the underlying file open; the context manager
# closes it as soon as convert() has produced an in-memory RGB copy, so no
# file handle is left dangling for the rest of the notebook session.
with Image.open(path_to_image) as raw_image:
    image = raw_image.convert('RGB')

# Turn the PIL image into model inputs; return_tensors='pt' requests PyTorch
# tensors. The result is indexed by 'pixel_values' at inference time.
inputs = feature_extractor(images=image, return_tensors='pt')

# Run inference
# The restored model may live on a different device than the freshly created
# input tensor (e.g. a checkpoint saved from a GPU run), which would raise a
# device-mismatch error. Move the input to wherever the model's parameters are.
model_device = next(model.parameters()).device
pixel_values = inputs['pixel_values'].to(model_device)

# Gradients are not needed for prediction; disabling them saves memory.
with torch.no_grad():
    # NOTE(review): assumes the model's forward returns a logits tensor for a
    # single image, so argmax over the last axis yields one index — confirm.
    outputs = model(pixel_values)
    predicted_class_idx = outputs.argmax(-1).item()

# Map the predicted class index back to its human-readable label.
predicted_label = ids2label[predicted_class_idx]
print(f'Predicted style: {predicted_label}')