## Vision Transformers on a custom dataset
- This notebook shows how to fine-tune a lightweight Vision Transformer on your own image dataset (e.g. chicken disease detection).
- You will fine-tune ViT-Tiny (DeiT-Tiny), using the `facebook/deit-tiny-patch16-224` checkpoint.

In [None]:
# 1. Install dependencies
!pip install transformers datasets torchvision accelerate --quiet

In [None]:
# 2. Mount Google Drive (only needed if your dataset is stored there)
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# 3. Load image folder dataset
from datasets import load_dataset

# Change path below to your dataset path
data_path = "/content/drive/MyDrive/poultry_dataset"
dataset = load_dataset("imagefolder", data_dir=data_path)
dataset = dataset["train"].train_test_split(test_size=0.2, seed=42)
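
The `imagefolder` loader expects one sub-folder per class, and the folder names become the label names. The sketch below only illustrates that layout with the standard library; the class names (`healthy`, `cocci`, `ncd`) and file names are hypothetical placeholders, and the empty files are not real images. The sorted order shown here is for readability only: the authoritative label order is whatever `dataset["train"].features["label"].names` reports after loading.

```python
import tempfile
from pathlib import Path


def make_demo_layout(root, class_names, images_per_class=2):
    """Create a folder-per-class tree with empty placeholder files."""
    for name in class_names:
        class_dir = Path(root) / name
        class_dir.mkdir(parents=True, exist_ok=True)
        for i in range(images_per_class):
            # Placeholder file only; a real dataset would hold actual images.
            (class_dir / f"img_{i}.jpg").touch()


def list_classes(root):
    """Return the sub-folder names (one per class) in sorted order."""
    return sorted(p.name for p in Path(root).iterdir() if p.is_dir())


with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical class names, used for illustration only.
    make_demo_layout(tmp, ["healthy", "cocci", "ncd"])
    classes = list_classes(tmp)
    counts = {c: len(list((Path(tmp) / c).glob("*.jpg"))) for c in classes}

print(classes)
print(counts)
```

Running the same check on your own `data_path` before calling `load_dataset` is a quick way to spot empty or misspelled class folders.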

In [None]:
# 4. Load ViT processor and model
from transformers import AutoImageProcessor, ViTForImageClassification

checkpoint = "facebook/deit-tiny-patch16-224"
processor = AutoImageProcessor.from_pretrained(checkpoint)

labels = dataset["train"].features["label"].names
id2label = {i: l for i, l in enumerate(labels)}
label2id = {l: i for i, l in enumerate(labels)}

model = ViTForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True
)

In [None]:
# 5. Preprocess dataset
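def preprocess(example):
    # "imagefolder" already decodes each file into a PIL image,
    # so calling Image.open on it would fail; convert it directly.
    image = example["image"].convert("RGB")
    inputs = processor(image, return_tensors="pt")
    # Drop the batch dimension added by the processor: (1, 3, H, W) -> (3, H, W)
    example["pixel_values"] = inputs["pixel_values"].squeeze(0).numpy()
    example["labels"] = example["label"]
    return example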

dataset = dataset.map(preprocess)
dataset.set_format(type="torch", columns=["pixel_values", "labels"])

In [None]:
# 6. Wrap into PyTorch Dataset + DataLoader
from torch.utils.data import Dataset, DataLoader

class TorchImageDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        return {
            "pixel_values": item["pixel_values"],
            "labels": item["labels"]
        }

train_loader = DataLoader(TorchImageDataset(dataset["train"]), batch_size=16, shuffle=True)
test_loader = DataLoader(TorchImageDataset(dataset["test"]), batch_size=16)

In [None]:
# 7. Train
import torch
from torch.optim import AdamW
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)

num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    loop = tqdm(train_loader, leave=True)
    total_loss = 0
    for batch in loop:
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        loop.set_description(f"Epoch {epoch+1}")
        loop.set_postfix(loss=loss.item())

    print(f"Epoch {epoch+1} avg loss: {total_loss / len(train_loader):.4f}")

In [None]:
# 8. Evaluate accuracy
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in test_loader:
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        outputs = model(pixel_values=pixel_values)
        preds = outputs.logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Test accuracy: {correct / total:.4f}")

In [None]:
# 9. Save the fine-tuned model and its processor to one directory
model.save_pretrained("vit-tiny-poultry")
processor.save_pretrained("vit-tiny-poultry")

### Prototype on Gradio

In [None]:
!pip install gradio

In [None]:
from transformers import AutoImageProcessor, ViTForImageClassification
import torch

In [None]:
# Load saved model and its preprocessing config
model_path = "./vit-tiny-poultry"
model = ViTForImageClassification.from_pretrained(model_path)
# ViTFeatureExtractor is deprecated; AutoImageProcessor reads the same saved config
feature_extractor = AutoImageProcessor.from_pretrained(model_path)


In [None]:
# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
def predict(image):
    # Preprocess image (force RGB so grayscale or RGBA uploads also work)
    inputs = feature_extractor(images=image.convert("RGB"), return_tensors="pt").to(device)

    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Get probabilities
    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

    # Return dictionary of {class: probability}
    return {model.config.id2label[i]: float(prob) for i, prob in enumerate(probabilities)}
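
To make the last two steps of `predict` concrete, here is a small pure-Python sketch of what softmax does to a logits vector and how the result is keyed by class name. The logits and the `id2label` mapping are made-up values for illustration, not output from the trained model.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)
    # Subtracting the maximum keeps exp() numerically stable.
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical mapping and logits, for illustration only.
id2label = {0: "cocci", 1: "healthy", 2: "ncd"}
logits = [2.0, 1.0, 0.1]

probs = softmax(logits)
scores = {id2label[i]: p for i, p in enumerate(probs)}
top = max(scores, key=scores.get)

print(scores)
print("Top class:", top)
```

The `{class: probability}` dictionary is the shape `gr.Label` expects, which is why `predict` returns it directly.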


In [None]:
import gradio as gr

In [None]:
# Get class names (same as folder names)
class_names = list(model.config.id2label.values())

In [None]:
# Define Gradio interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Poultry Image"),
    outputs=gr.Label(num_top_classes=3, label="Prediction"),
    examples=[
        ["chicken_healthy.jpg"],  # Replace with actual sample paths
        ["chicken_cocci.jpg"],
        ["chicken_ncd.jpg"]
    ],
    title="🐔 Poultry Classifier (Vision Transformer)",
    description="Upload an image of a chicken dropping to classify it.",
    allow_flagging="never"
)

In [None]:
# Run in Colab (creates a shareable link)
demo.launch(share=True)

### TO DO
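1. Consider other ViT variants:
   - `google/vit-base-patch16-224-in21k` (larger than DeiT-Tiny)
   - `google/vit-small-patch16-224-in21k` (smaller and faster than ViT-Base; check that this checkpoint name exists on the Hugging Face Hub before using it)

2. Evaluate the model with a confusion matrix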
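
For the confusion-matrix item, a minimal pure-Python sketch is shown below. It assumes you have collected integer class ids from the evaluation loop (for example by extending two lists with `labels.tolist()` and `preds.tolist()` per batch); the ids used here are made-up sample data, and the helper names are hypothetical.

```python
def confusion_matrix(true_ids, pred_ids, num_classes):
    """Rows are true classes, columns are predicted classes."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(true_ids, pred_ids):
        matrix[t][p] += 1
    return matrix


def per_class_recall(matrix):
    """Fraction of each true class that was predicted correctly."""
    recalls = []
    for i, row in enumerate(matrix):
        row_total = sum(row)
        recalls.append(row[i] / row_total if row_total else 0.0)
    return recalls


# Made-up ids for three classes, for illustration only.
true_ids = [0, 0, 1, 1, 1, 2, 2, 2]
pred_ids = [0, 1, 1, 1, 2, 2, 2, 0]

cm = confusion_matrix(true_ids, pred_ids, num_classes=3)
recall = per_class_recall(cm)
accuracy = sum(cm[i][i] for i in range(3)) / len(true_ids)

for row in cm:
    print(row)
print("Recall per class:", [round(r, 3) for r in recall])
print("Accuracy:", accuracy)
```

Per-class recall is useful here because overall accuracy can hide a disease class that the model rarely detects.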

### Key comparisons between Transformers and FastAI

1. Directory Structure:
   - Identical folder-per-class structure as FastAI
   - Automatically infers labels from folder names

2. Data Loading:
   - `load_dataset("imagefolder")` replaces FastAI's `DataBlock`
   - Picks up existing splits when the data directory contains `train`/`test` sub-folders; this notebook instead creates an 80/20 split with `train_test_split`

3. Transforms:
   - `AutoImageProcessor` handles resizing and normalization, much like FastAI's `item_tfms` and `Normalize`

4. Training:
   - The `Trainer` class provides a high-level interface similar to FastAI's `Learner`; this notebook uses a manual PyTorch loop instead

5. Model Saving:
   - `save_pretrained` writes the model and the preprocessing config to one directory (like FastAI's `.export()`)
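
As a rough illustration of the training comparison, the sketch below shows how the manual loop from step 7 might be swapped for `Trainer`. It is an untested outline under stated assumptions: the helper names `build_trainer` and `compute_accuracy` and the output directory are hypothetical, the hyperparameters simply mirror the manual loop (batch size 16, 3 epochs, learning rate 5e-5), and the datasets are assumed to be the preprocessed splits with `pixel_values` and `labels` columns.

```python
def compute_accuracy(eval_pred):
    """Metric hook: eval_pred unpacks into (logits, labels) NumPy arrays."""
    logits, labels = eval_pred
    preds = logits.argmax(axis=-1)
    return {"accuracy": float((preds == labels).mean())}


def build_trainer(model, train_dataset, eval_dataset,
                  output_dir="vit-tiny-poultry-trainer"):
    """Hypothetical helper that wires the model and datasets into a Trainer."""
    # Imported lazily so that defining this helper does not require transformers.
    from transformers import Trainer, TrainingArguments

    args = TrainingArguments(
        output_dir=output_dir,            # assumed directory name
        per_device_train_batch_size=16,   # same batch size as the manual loop
        per_device_eval_batch_size=16,
        num_train_epochs=3,
        learning_rate=5e-5,
        report_to="none",                 # disable external experiment loggers
    )
    return Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_accuracy,
    )


# Possible usage inside the notebook, after the preprocessing step:
# trainer = build_trainer(model, dataset["train"], dataset["test"])
# trainer.train()
# print(trainer.evaluate())
```

The trade-off is similar to FastAI's: `Trainer` removes boilerplate such as device placement and the optimizer loop, while the manual loop keeps every step visible and easy to modify.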
