In [None]:
import os

import torch
import torch.nn as nn
import torchvision.models as models
from google.colab import drive
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ViTForImageClassification

In [None]:
# Seed PyTorch's global RNG up front so a Restart & Run All repeats the same
# shuffling order and the same random initialisation of the new classifier head.
SEED = 42
torch.manual_seed(SEED)

# Tunable configuration shared by the cells below.
IMAGE_SIZE = 224        # side length (pixels) of the square crops fed to the model
BATCH_SIZE = 8          # samples per DataLoader batch
NUM_WORKERS = 2         # DataLoader worker processes
# Per-channel statistics passed to transforms.Normalize.
# NOTE(review): presumably chosen to match the pretrained checkpoint's own
# preprocessing -- confirm against its image-processor config.
MEAN = [0.5, 0.5, 0.5]
STD = [0.5, 0.5, 0.5]

In [None]:
def _to_normalized_tensor():
    """Return fresh ToTensor + Normalize steps, the common tail of both pipelines."""
    return [transforms.ToTensor(), transforms.Normalize(MEAN, STD)]


# Training pipeline: random crop/flip/rotation/colour augmentation, then the
# shared tensor conversion and normalisation.
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
    ]
    + _to_normalized_tensor()
)

# Validation pipeline: deterministic resize + centre crop, no augmentation.
val_transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
    ]
    + _to_normalized_tensor()
)

In [None]:
# Mount Google Drive so the dataset folders below are readable from the Colab VM.
drive.mount('/content/drive')

# One sub-folder per class under each root; ImageFolder turns folder names into labels.
train_dataset = datasets.ImageFolder(
    root='/content/drive/MyDrive/dataset-dapa/train/',
    transform=train_transforms,
)
val_dataset = datasets.ImageFolder(
    root='/content/drive/MyDrive/dataset-dapa/val/',
    transform=val_transforms,
)

# ImageFolder assigns label ids independently for each root, so a missing or
# extra class folder in one split would silently shift the ids and make the
# validation metrics meaningless. Fail loudly instead.
if train_dataset.class_to_idx != val_dataset.class_to_idx:
    raise ValueError(
        "Train/val class folders differ: "
        f"{train_dataset.classes} vs {val_dataset.classes}"
    )

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Options common to both loaders; only the shuffling policy differs.
_loader_options = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Training batches are reshuffled every epoch; validation keeps dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

In [None]:
if __name__ == "__main__":
    # Smoke test: pull one batch and confirm the tensor shapes coming out of the pipeline.
    images, labels = next(iter(train_loader))
    print(f"Batch shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")

# Size the classification head from the dataset instead of repeating a literal,
# so adding or removing a class folder cannot leave the head out of sync.
num_classes = len(train_dataset.classes)

# ignore_mismatched_sizes lets the checkpoint's original head be discarded and
# re-created for `num_classes` outputs.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=num_classes,
    ignore_mismatched_sizes=True,
)

# Freeze the pretrained backbone; only the head defined below receives gradients.
model.vit.requires_grad_(False)

# Replace the linear head with dropout + linear, reusing the head's input width.
model.classifier = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.classifier.in_features, num_classes),
)

Batch shape: torch.Size([8, 3, 224, 224])
Labels shape: torch.Size([8])


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([9]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([9, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Run on the GPU when one is visible, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model.to(device)

# Multi-class classification loss applied to the raw logits.
criterion = nn.CrossEntropyLoss()

# Only the classifier head's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(
    model.classifier.parameters(),
    lr=1e-3,
)

# Plateau scheduler configured to shrink the LR by 10x after 3 epochs without
# improvement of the monitored ("min" mode) metric.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=3,
)

In [None]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and report epoch metrics.

    Args:
        model: module whose forward output exposes a ``.logits`` attribute.
        loader: iterable yielding ``(images, labels)`` batches.
        optimizer: optimizer that is zeroed and stepped once per batch.
        criterion: callable ``criterion(logits, labels)`` returning a scalar loss tensor.
        device: device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy)`` over every sample seen; the loss is
        weighted by batch size so a smaller final batch is not over-counted.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        batch_size = images.size(0)

        optimizer.zero_grad()
        logits = model(images).logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Accumulate sample-weighted loss and the number of correct top-1 predictions.
        loss_sum += loss.item() * batch_size
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += batch_size

    return loss_sum / n_seen, n_correct / n_seen

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        model: module whose forward output exposes a ``.logits`` attribute.
        loader: iterable yielding ``(images, labels)`` batches.
        criterion: callable ``criterion(logits, labels)`` returning a scalar loss tensor.
        device: device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy)`` over every sample seen; the loss is
        weighted by batch size.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            batch_size = images.size(0)

            logits = model(images).logits
            batch_loss = criterion(logits, labels)

            loss_sum += batch_loss.item() * batch_size
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += batch_size

    return loss_sum / n_seen, n_correct / n_seen

In [None]:
num_epochs = 30
best_val_loss = float('inf')

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc     = validate(model, val_loader, criterion, device)

    # ReduceLROnPlateau only adjusts the learning rate when it is fed the
    # monitored metric; without this call the scheduler is never consulted.
    scheduler.step(val_loss)

    print(f"Epoch {epoch+1}/{num_epochs}: "
          f"Train loss {train_loss:.4f}, acc {train_acc:.4f} | "
          f"Val   loss {val_loss:.4f}, acc {val_acc:.4f}")

    # Keep a checkpoint of the weights with the lowest validation loss so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "ViT.pth")

Epoch 1/30: Train loss 0.5262, acc 0.8281 | Val   loss 0.3018, acc 0.8939
Epoch 2/30: Train loss 0.2613, acc 0.9133 | Val   loss 0.2469, acc 0.9101
Epoch 3/30: Train loss 0.2324, acc 0.9196 | Val   loss 0.2005, acc 0.9364
Epoch 4/30: Train loss 0.2082, acc 0.9261 | Val   loss 0.1869, acc 0.9354
Epoch 5/30: Train loss 0.1880, acc 0.9330 | Val   loss 0.2202, acc 0.9263
Epoch 6/30: Train loss 0.2048, acc 0.9293 | Val   loss 0.1760, acc 0.9354
Epoch 7/30: Train loss 0.1745, acc 0.9328 | Val   loss 0.1576, acc 0.9424
Epoch 8/30: Train loss 0.1692, acc 0.9377 | Val   loss 0.1533, acc 0.9475
Epoch 9/30: Train loss 0.1712, acc 0.9351 | Val   loss 0.1789, acc 0.9444
Epoch 10/30: Train loss 0.1827, acc 0.9347 | Val   loss 0.1821, acc 0.9404
Epoch 11/30: Train loss 0.1589, acc 0.9401 | Val   loss 0.1723, acc 0.9424
Epoch 12/30: Train loss 0.1649, acc 0.9401 | Val   loss 0.1611, acc 0.9434
Epoch 13/30: Train loss 0.1722, acc 0.9397 | Val   loss 0.1500, acc 0.9404
Epoch 14/30: Train loss 0.1725, ac

In [None]:
save_path = '/content/drive/MyDrive/ViT.pth'

# The training loop writes "ViT.pth" whenever validation loss improves, while
# `model` holds the weights from the last epoch that ran. Restore the best
# checkpoint first so the copy exported to Drive is the best one, not the last.
best_checkpoint = "ViT.pth"
if os.path.exists(best_checkpoint):
    model.load_state_dict(torch.load(best_checkpoint, map_location=device))
else:
    print(f"Warning: {best_checkpoint} not found; exporting the current in-memory weights.")

torch.save(model.state_dict(), save_path)
print("Model saved successfully to Google Drive!")

Model saved successfully to Google Drive!


In [None]:
# Export the weights into a dedicated models/ folder on Drive, creating it if needed.
save_dir = '/content/drive/MyDrive/models'
os.makedirs(save_dir, exist_ok=True)

save_path = os.path.join(save_dir, 'ViT.pth')

# `model` holds the last-epoch weights; the best-validation weights live in the
# local "ViT.pth" written by the training loop. Restore them before exporting.
best_checkpoint = "ViT.pth"
if os.path.exists(best_checkpoint):
    model.load_state_dict(torch.load(best_checkpoint, map_location=device))
else:
    print(f"Warning: {best_checkpoint} not found; exporting the current in-memory weights.")

torch.save(model.state_dict(), save_path)
print(f"Model saved at: {save_path}")

Model saved at: /content/drive/MyDrive/models/ViT.pth
