# Import Dependencies

In [4]:
import torch
import timm
import torchvision
import torchvision.transforms as T
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from sklearn.metrics import f1_score
import time
import numpy as np
import random



# Personalised config

In [5]:
def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Applies ``seed`` to Python's ``random`` module, NumPy's global generator
    and PyTorch (through ``torch.manual_seed`` and ``torch.cuda.manual_seed``).

    Args:
        seed: Integer seed handed to each generator. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# Fix the seed once, before any data pipeline or model is created.
seed_everything(42)
# Hyperparameters and data-loading settings read by the cells below.
config = dict(
    # Training schedule and optimiser settings
    batch_size=64,
    epochs=100,
    learning_rate=5e-4,
    weight_decay=0.05,
    warmup_epochs=3,
    # Model / input settings
    drop_path_rate=0.1,
    img_size=224,
    num_classes=10,
    # DataLoader settings
    num_workers=4,
)


# Load Data and Model

In [6]:
# Per-channel statistics handed to T.Normalize in both pipelines.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training pipeline: random crop / flip / AutoAugment, then tensor + normalise.
train_steps = [
    T.RandomResizedCrop(config["img_size"]),
    T.RandomHorizontalFlip(),
    T.AutoAugment(policy=T.AutoAugmentPolicy.IMAGENET),
    T.ToTensor(),
    T.Normalize(mean=norm_mean, std=norm_std),
]
transform_train = T.Compose(train_steps)

# Validation pipeline: deterministic resize + centre crop, then tensor + normalise.
val_steps = [
    T.Resize(256),
    T.CenterCrop(config["img_size"]),
    T.ToTensor(),
    T.Normalize(mean=norm_mean, std=norm_std),
]
transform_val = T.Compose(val_steps)

# ImageFolder derives the class labels from the sub-directory names.
train_dataset = ImageFolder(
    root='/kaggle/input/imgwoofdeit/imagewoof2/train',
    transform=transform_train,
)
val_dataset = ImageFolder(
    root='/kaggle/input/imgwoofdeit/imagewoof2/val',
    transform=transform_val,
)

# Both loaders share batch size and worker count; only training is shuffled.
loader_kwargs = dict(
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)



In [7]:
# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the timm `deit3_small_patch16_224` architecture with randomly
# initialised weights (pretrained=False). The head size and stochastic-depth
# rate come from `config` so the values declared there are actually applied
# instead of being silently ignored.
model = timm.create_model(
    "deit3_small_patch16_224",
    pretrained=False,
    num_classes=config["num_classes"],
    drop_path_rate=config["drop_path_rate"],
)

model = model.to(device)

# Cross-entropy loss on the raw class scores; AdamW with the learning rate
# and weight decay taken from `config`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=config["learning_rate"],
    weight_decay=config["weight_decay"],
)

def cosine_schedule_with_warmup(optimizer, warmup_epochs, total_epochs):
    """Build an epoch-indexed LR schedule: linear warmup, then cosine decay.

    The returned scheduler multiplies each parameter group's base learning
    rate by a factor that depends on how many times ``step()`` was called.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        warmup_epochs: Number of epochs over which the factor ramps linearly
            up to 1.0.
        total_epochs: Epoch index at which the cosine decay reaches 0.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping ``optimizer``.
    """
    # Never divide by zero when total_epochs <= warmup_epochs.
    decay_epochs = max(1, total_epochs - warmup_epochs)

    def lr_lambda(current_epoch):
        if current_epoch < warmup_epochs:
            # "+ 1" so epoch 0 trains at base_lr / warmup_epochs. LambdaLR
            # evaluates this at epoch 0 on construction, and a factor of 0
            # would make the whole first epoch a no-op.
            return float(current_epoch + 1) / float(max(1, warmup_epochs))
        # Clamp so the factor stays at 0 instead of rising again if the
        # scheduler is stepped past total_epochs.
        progress = min(1.0, (current_epoch - warmup_epochs) / decay_epochs)
        return float(0.5 * (1.0 + np.cos(np.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Per-epoch schedule; warmup length and horizon both come from `config`.
scheduler = cosine_schedule_with_warmup(
    optimizer,
    warmup_epochs=config["warmup_epochs"],
    total_epochs=config["epochs"],
)



# Training Loop

In [8]:
from tqdm import tqdm

def evaluate(model, dataloader):
    """Compute the top-1 accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode and is left in eval mode on return.
    Batches are moved to the device the model's parameters live on, so the
    function works on CPU-only hosts as well as on GPU.

    Args:
        model: ``torch.nn.Module`` whose output holds one score per class
            along dim 1.
        dataloader: Iterable yielding ``(images, labels)`` tensor batches.

    Returns:
        Fraction of correctly classified samples as a float, or ``0.0`` when
        the dataloader yields no samples.
    """
    model.eval()

    # Follow the model's own device rather than assuming CUDA is present.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    correct, total = 0, 0

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(target_device), labels.to(target_device)
            outputs = model(images)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # An empty dataloader would otherwise raise ZeroDivisionError.
    return correct / total if total else 0.0

# Early-stopping state: best validation accuracy seen so far and the number
# of consecutive epochs that failed to beat it.
best_acc = 0.0
patience = 5  # Early stopping patience
counter = 0

for epoch in range(config["epochs"]):
    model.train()
    total_loss = 0
    correct, total = 0, 0

    train_loader_tqdm = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config["epochs"]}', leave=False)

    for images, labels in train_loader_tqdm:
        # Use the notebook-level `device` so the CPU fallback chosen at model
        # creation also applies here (a hard-coded .cuda() raises without a GPU).
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Running totals for the epoch-level loss / accuracy report.
        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    # The schedule is indexed by epoch, so it is stepped once per epoch.
    scheduler.step()

    val_acc = evaluate(model, val_loader)

    print(f"Train Loss: {total_loss/len(train_loader):.4f}, Train Acc: {correct/total:.4f}, Val Acc: {val_acc:.4f}")

    if val_acc > best_acc:
        # New best: checkpoint the weights and reset the patience counter.
        best_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")
        counter = 0  #  Reset patience counter on improvement
        print(" Model improved and saved.")
    else:
        counter += 1
        print(f" No improvement. EarlyStopping counter: {counter}/{patience}")
        if counter >= patience:
            print(" Early stopping triggered.")
            break


                                                             

KeyboardInterrupt: 

# Evaluation

In [11]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image
import os

# Transformation for test images.
# Uses the same deterministic preprocessing as the validation pipeline
# (resize to 256, then centre-crop to the model input size). Squashing the
# image straight to 224x224 distorts the aspect ratio and makes test accuracy
# incomparable with the validation accuracy used to pick the checkpoint.
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(config["img_size"]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Load labelled test set (labels come from the sub-directory names).
test_dataset = datasets.ImageFolder(
    '/kaggle/input/imgwoofdeit/imagewoof2/test',
    transform=test_transform
)

# Decode images in worker processes, as the train/val loaders do, so the
# evaluation loop is not bottlenecked on single-process image loading.
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=config["num_workers"],
)

# Inverse of class_to_idx: integer label -> class (folder) name.
idx_to_class = {v: k for k, v in test_dataset.class_to_idx.items()}


In [12]:
from sklearn.metrics import classification_report, confusion_matrix
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

# Load best model.
# map_location lets a checkpoint saved on GPU load on a CPU-only host;
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict needs.
state_dict = torch.load(
    '/kaggle/working/best_model.pth',
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)
model.eval()

# Full list of label ids, so the report and confusion matrix always cover
# every class even if one never appears in the targets or predictions.
class_ids = list(range(len(test_dataset.classes)))

all_preds = []
all_targets = []
correct = total = 0

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Evaluating"):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        # argmax over raw scores; softmax is monotonic, so applying it first
        # would not change the predicted class.
        preds = torch.argmax(outputs, dim=1)

        all_preds.extend(preds.cpu().tolist())
        all_targets.extend(labels.cpu().tolist())

        correct += (preds == labels).sum().item()
        total += labels.size(0)

# Top-1 Accuracy
top1_acc = correct / total

print("\n" + "="*60)
print(f"Top-1 Accuracy on Test Set: {top1_acc:.4f}")
print("="*60)

# Classification Report
print("\nClassification Report:")
print(classification_report(
    all_targets,
    all_preds,
    labels=class_ids,
    target_names=test_dataset.classes,
))

# Confusion Matrix (rows: true class, columns: predicted class)
print("\nConfusion Matrix:")
conf_mat = confusion_matrix(all_targets, all_preds, labels=class_ids)
print(conf_mat)


  model.load_state_dict(torch.load('/kaggle/working/best_model.pth'))
Evaluating:  18%|█▊        | 11/62 [00:05<00:23,  2.20it/s]


KeyboardInterrupt: 