In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from transformers import ViTForImageClassification
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from google.colab import drive

In [None]:
# Side length (in pixels) that the transforms below crop/resize every image to.
IMAGE_SIZE = 224

# DataLoader settings: samples per batch and number of loader worker processes.
BATCH_SIZE = 8
NUM_WORKERS = 2

# Per-channel statistics handed to transforms.Normalize (one entry per RGB channel).
MEAN = [0.5] * 3
STD = [0.5] * 3

In [None]:
# Training pipeline: random geometric and photometric augmentation, then
# conversion to a tensor and per-channel normalisation with MEAN / STD.
train_transforms = transforms.Compose([
    # Random crop covering 80-100% of the image area, resized to IMAGE_SIZE.
    transforms.RandomResizedCrop(size=IMAGE_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    # Rotation angle is drawn from [-15, +15] degrees.
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
])

# Validation pipeline: deterministic resize + centre crop only, followed by the
# same tensor conversion and normalisation as the training pipeline.
val_transforms = transforms.Compose([
    # An int size resizes the shorter edge to IMAGE_SIZE, keeping aspect ratio;
    # the centre crop then yields a square IMAGE_SIZE x IMAGE_SIZE image.
    transforms.Resize(size=IMAGE_SIZE),
    transforms.CenterCrop(size=IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
])

In [None]:
# Mount Google Drive so the dataset folders under MyDrive become readable.
drive.mount('/content/drive')

# One ImageFolder per split; each applies its own transform pipeline.
train_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/dataset-dapa/train/', transform=train_transforms)
val_dataset   = datasets.ImageFolder(root='/content/drive/MyDrive/dataset-dapa/val/',   transform=val_transforms)

# Each ImageFolder builds its own folder-name -> label-index mapping from the
# sub-directories it finds. If the two splits do not contain exactly the same
# class folders, the same integer label would mean different classes in train
# and val, and validation metrics would be silently wrong. Fail loudly instead.
if train_dataset.class_to_idx != val_dataset.class_to_idx:
    raise ValueError(
        "train/ and val/ do not contain the same class folders: "
        f"{train_dataset.classes} vs {val_dataset.classes}"
    )

Mounted at /content/drive


In [None]:
# Pinned (page-locked) host memory only speeds up host-to-GPU copies. The
# notebook falls back to CPU when CUDA is missing, and in that case requesting
# pinned memory does nothing useful and makes PyTorch emit a warning, so only
# enable it when a GPU is actually available.
use_pinned_memory = torch.cuda.is_available()

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=use_pinned_memory,
)

# Validation batches keep the dataset order.
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=use_pinned_memory,
)

In [None]:
# Smoke test: pull a single training batch and show its tensor shapes.
if __name__ == "__main__":
    images, labels = next(iter(train_loader))
    print(f"Batch shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")

# Size the classification head from the dataset rather than a hard-coded
# literal, so the number of logits always matches the label indices that
# ImageFolder produces (one class per sub-directory).
num_classes = len(train_dataset.classes)

# Load the pretrained checkpoint. Its classifier weights have a different
# shape than the one requested here, so ignore_mismatched_sizes=True lets the
# load proceed with a freshly initialised head.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=num_classes,
    ignore_mismatched_sizes=True,
)

# Freeze the ViT backbone; only the classifier assigned below stays trainable.
model.vit.requires_grad_(False)

# Replace the head with dropout + linear. in_features is read from the
# existing head before it is overwritten by this assignment.
model.classifier = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.classifier.in_features, num_classes),
)

Batch shape: torch.Size([8, 3, 224, 224])
Labels shape: torch.Size([8])


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([9]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([9, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Run on the GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

# Loss for integer class labels against raw logits.
criterion = nn.CrossEntropyLoss()

# Only the classifier head's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(params=model.classifier.parameters(), lr=1e-3)

# Plateau scheduler: monitors a quantity that should decrease ('min') and
# multiplies the learning rate by `factor` once `patience` is exceeded.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=3,
)

In [None]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and report epoch metrics.

    Args:
        model: Module whose call result exposes a ``.logits`` attribute.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        criterion: Callable ``criterion(logits, labels)`` returning a scalar
            loss tensor.
        device: Device that each batch is moved to before the forward pass.

    Returns:
        Tuple ``(loss, accuracy)``: the batch losses weighted by batch size and
        divided by the number of samples seen, and the fraction of samples
        whose argmax prediction equals the label.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.train()

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        batch_size = batch_images.size(0)

        optimizer.zero_grad()
        logits = model(batch_images).logits
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight each batch loss by its size so that a smaller final batch
        # does not skew the per-sample average.
        loss_sum += batch_loss.item() * batch_size
        predictions = logits.argmax(dim=1)
        num_correct += (predictions == batch_labels).sum().item()
        num_seen += batch_size

    return loss_sum / num_seen, num_correct / num_seen

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        model: Module whose call result exposes a ``.logits`` attribute.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Callable ``criterion(logits, labels)`` returning a scalar
            loss tensor.
        device: Device that each batch is moved to before the forward pass.

    Returns:
        Tuple ``(loss, accuracy)``: the batch losses weighted by batch size and
        divided by the number of samples seen, and the fraction of samples
        whose argmax prediction equals the label.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            batch_size = batch_images.size(0)

            logits = model(batch_images).logits
            batch_loss = criterion(logits, batch_labels)

            # Weight each batch loss by its size so that a smaller final
            # batch does not skew the per-sample average.
            loss_sum += batch_loss.item() * batch_size
            predictions = logits.argmax(dim=1)
            num_correct += (predictions == batch_labels).sum().item()
            num_seen += batch_size

    return loss_sum / num_seen, num_correct / num_seen

In [None]:
num_epochs = 15
best_val_loss = float('inf')  # lowest validation loss seen so far

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc     = validate(model, val_loader, criterion, device)

    # ReduceLROnPlateau never changes the learning rate on its own: it has to
    # be fed the monitored metric once per epoch. Without this call the
    # scheduler configured earlier has no effect at all.
    scheduler.step(val_loss)

    print(f"Epoch {epoch+1}/{num_epochs}: "
          f"Train loss {train_loss:.4f}, acc {train_acc:.4f} | "
          f"Val   loss {val_loss:.4f}, acc {val_acc:.4f}")

    # Checkpoint only on improvement, so the file on disk always holds the
    # weights with the lowest validation loss observed so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "ViT-15epoch.pth")

Epoch 1/15: Train loss 0.5317, acc 0.8266 | Val   loss 0.3019, acc 0.9061
Epoch 2/15: Train loss 0.2654, acc 0.9122 | Val   loss 0.2595, acc 0.9071
Epoch 3/15: Train loss 0.2303, acc 0.9207 | Val   loss 0.2075, acc 0.9313
Epoch 4/15: Train loss 0.2061, acc 0.9336 | Val   loss 0.1888, acc 0.9364
Epoch 5/15: Train loss 0.1970, acc 0.9330 | Val   loss 0.2004, acc 0.9303
Epoch 6/15: Train loss 0.1905, acc 0.9332 | Val   loss 0.1761, acc 0.9374
Epoch 7/15: Train loss 0.1914, acc 0.9295 | Val   loss 0.1753, acc 0.9414
Epoch 8/15: Train loss 0.1787, acc 0.9367 | Val   loss 0.1553, acc 0.9384
Epoch 9/15: Train loss 0.1801, acc 0.9369 | Val   loss 0.1978, acc 0.9263
Epoch 10/15: Train loss 0.1733, acc 0.9382 | Val   loss 0.1935, acc 0.9293
Epoch 11/15: Train loss 0.1690, acc 0.9390 | Val   loss 0.1827, acc 0.9313
Epoch 12/15: Train loss 0.1540, acc 0.9470 | Val   loss 0.1454, acc 0.9485
Epoch 13/15: Train loss 0.1692, acc 0.9406 | Val   loss 0.1633, acc 0.9404
Epoch 14/15: Train loss 0.1750, ac

In [None]:
import os
import shutil

save_path = '/content/drive/MyDrive/ViT-15epoch.pth'

# The training loop writes the lowest-validation-loss weights to this local
# file. The in-memory `model` holds the *last* epoch's weights, which are not
# necessarily the best ones, so export the checkpoint file when it exists.
best_checkpoint = "ViT-15epoch.pth"

if os.path.isfile(best_checkpoint):
    shutil.copyfile(best_checkpoint, save_path)
else:
    # No checkpoint on local disk: fall back to the current in-memory weights.
    torch.save(model.state_dict(), save_path)
print("Model saved successfully to Google Drive!")

Model saved successfully to Google Drive!


In [None]:
import os
import shutil

# Destination folder on the mounted Drive; created on first use.
save_dir = '/content/drive/MyDrive/models'
os.makedirs(save_dir, exist_ok=True)

save_path = os.path.join(save_dir, 'ViT-15epoch.pth')

# The training loop writes the lowest-validation-loss weights to this local
# file. The in-memory `model` holds the *last* epoch's weights, which are not
# necessarily the best ones, so export the checkpoint file when it exists.
best_checkpoint = "ViT-15epoch.pth"

if os.path.isfile(best_checkpoint):
    shutil.copyfile(best_checkpoint, save_path)
else:
    # No checkpoint on local disk: fall back to the current in-memory weights.
    torch.save(model.state_dict(), save_path)
print(f"Model saved at: {save_path}")

Model saved at: /content/drive/MyDrive/models/ViT-15epoch.pth
