In [None]:
import os
import shutil

import torch
import torch.nn as nn
import torchvision.models as models
from google.colab import drive
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import BeitFeatureExtractor, BeitForImageClassification
from transformers import DeiTForImageClassification

In [None]:
IMAGE_SIZE = 224
BATCH_SIZE = 2
NUM_WORKERS = 2
MEAN = [0.5, 0.5, 0.5]
STD = [0.5, 0.5, 0.5]

In [None]:
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD)
])

val_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD)
])

In [None]:
drive.mount('/content/drive')
train_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/dataset-dapa/train/', transform=train_transforms)
val_dataset   = datasets.ImageFolder(root='/content/drive/MyDrive/dataset-dapa/val/',   transform=val_transforms)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True
)


val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

In [None]:
if __name__ == "__main__":
    images, labels = next(iter(train_loader))
    print(f"Batch shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")

model = BeitForImageClassification.from_pretrained(
    "microsoft/beit-base-patch16-224-pt22k-ft22k",
    num_labels=9,
    ignore_mismatched_sizes=True
)

model.beit.requires_grad_(False)

model.classifier = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.classifier.in_features, 9)
)


Batch shape: torch.Size([2, 3, 224, 224])
Labels shape: torch.Size([2])


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of BeitForImageClassification were not initialized from the model checkpoint at microsoft/beit-base-patch16-224-pt22k-ft22k and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([21841, 768]) in the checkpoint and torch.Size([9, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([21841]) in the checkpoint and torch.Size([9]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=1e-3)

In [None]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(images)
        logits = outputs.logits

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += images.size(0)

    return running_loss / total, correct / total

def validate(model, loader, criterion, device):
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            logits = outputs.logits

            loss = criterion(logits, labels)
            val_loss += loss.item() * images.size(0)
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += images.size(0)

    return val_loss / total, correct / total

In [None]:
num_epochs = 15
best_val_loss = float('inf')

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc     = validate(model, val_loader, criterion, device)

    print(f"Epoch {epoch+1}/{num_epochs}: "
          f"Train loss {train_loss:.4f}, acc {train_acc:.4f} | "
          f"Val   loss {val_loss:.4f}, acc {val_acc:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "BEiT-15epoch.pth")

Epoch 1/15: Train loss 0.3556, acc 0.8833 | Val   loss 0.1505, acc 0.9515
Epoch 2/15: Train loss 0.1890, acc 0.9345 | Val   loss 0.2175, acc 0.9323
Epoch 3/15: Train loss 0.1717, acc 0.9423 | Val   loss 0.1411, acc 0.9556
Epoch 4/15: Train loss 0.1727, acc 0.9399 | Val   loss 0.1098, acc 0.9687
Epoch 5/15: Train loss 0.1766, acc 0.9438 | Val   loss 0.1291, acc 0.9616
Epoch 6/15: Train loss 0.1587, acc 0.9479 | Val   loss 0.0976, acc 0.9707
Epoch 7/15: Train loss 0.1379, acc 0.9535 | Val   loss 0.1095, acc 0.9707
Epoch 8/15: Train loss 0.1518, acc 0.9498 | Val   loss 0.1699, acc 0.9505
Epoch 9/15: Train loss 0.1701, acc 0.9449 | Val   loss 0.1455, acc 0.9606
Epoch 10/15: Train loss 0.1705, acc 0.9492 | Val   loss 0.1655, acc 0.9596
Epoch 11/15: Train loss 0.1893, acc 0.9466 | Val   loss 0.1072, acc 0.9768
Epoch 12/15: Train loss 0.1714, acc 0.9498 | Val   loss 0.1408, acc 0.9758
Epoch 13/15: Train loss 0.1487, acc 0.9531 | Val   loss 0.1101, acc 0.9697
Epoch 14/15: Train loss 0.1569, ac

In [None]:
save_path = '/content/drive/MyDrive/BEiT-15epoch.pth'

torch.save(model.state_dict(), save_path)
print("Model saved successfully to Google Drive!")

Model saved successfully to Google Drive!


In [None]:
import os

save_dir = '/content/drive/MyDrive/models'
os.makedirs(save_dir, exist_ok=True)

save_path = os.path.join(save_dir, 'BEiT-15epoch.pth')

torch.save(model.state_dict(), save_path)
print(f"Model saved at: {save_path}")

Model saved at: /content/drive/MyDrive/models/BEiT-15epoch.pth
