In [1]:
import os
import timm
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
from torch import nn, optim
from sklearn.metrics import classification_report
from tqdm import tqdm

# Configs
DATA_DIR = '/kaggle/input/waste-classification/combined_dataset'
BATCH_SIZE = 16
IMAGE_SIZE = 224  # keep in sync with the input resolution implied by MODEL_NAME
EPOCHS = 15
MODEL_NAME = 'vit_base_patch16_224'
# os.cpu_count() returns None when the CPU count cannot be determined, but
# DataLoader's num_workers must be an int; fall back to 0 (load batches in
# the main process) in that case.
NUM_WORKERS = os.cpu_count() or 0

# Preprocessing shared by every image: resize to a fixed IMAGE_SIZE x
# IMAGE_SIZE square, convert to a float tensor, then normalise each of the
# three channels as (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Load Dataset
# ImageFolder treats each sub-directory of DATA_DIR as one class;
# class_names lists them in label-index order.
dataset = ImageFolder(DATA_DIR, transform=transform)
class_names = dataset.classes

# Train-val split (80/20). A seeded generator makes the partition identical on
# every run, so validation metrics are comparable across runs and a model
# reloaded in a new session is evaluated on the same held-out images.
SPLIT_SEED = 42
val_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# NOTE(review): both subsets share ``transform`` (no training-time
# augmentation); consider a separate augmented transform for training.
# Pinned host memory speeds up host->GPU copies; it is pointless without CUDA.
_pin_memory = torch.cuda.is_available()
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=_pin_memory,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=_pin_memory,
)

# Pick the compute device first: the GPU if CUDA is available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained backbone from timm with a fresh classification head sized to the
# number of dataset classes, placed on the chosen device.
model = timm.create_model(
    MODEL_NAME,
    pretrained=True,
    num_classes=len(class_names),
)
model = model.to(device)

# Optimisation: cross-entropy over class indices, AdamW with a fixed LR.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=1e-4)

# Training Loop
def train():
    """Fine-tune ``model`` for ``EPOCHS`` epochs, validating after each one.

    Reads the module-level ``model``, ``train_loader``, ``criterion``,
    ``optimizer`` and ``device``. After every epoch it prints the mean of the
    per-batch training losses and then calls ``evaluate()``.
    """
    batches_per_epoch = len(train_loader)
    for epoch_idx in range(EPOCHS):
        model.train()
        running_loss = 0
        progress = tqdm(train_loader, desc=f"Epoch {epoch_idx+1}/{EPOCHS}")
        for batch_imgs, batch_labels in progress:
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)

            optimizer.zero_grad()
            batch_loss = criterion(model(batch_imgs), batch_labels)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        print(f"Loss: {running_loss / batches_per_epoch:.4f}")
        evaluate()

# Evaluation
def evaluate():
    """Run ``model`` over the validation split and print a per-class report.

    Reads the module-level ``model``, ``val_loader``, ``device`` and
    ``class_names``. Leaves the model in eval mode.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for imgs, labels in val_loader:
            # Labels are only consumed by sklearn on the CPU, so only the
            # images are moved to the compute device.
            preds = model(imgs.to(device)).argmax(dim=1)
            y_true.extend(labels.numpy())
            y_pred.extend(preds.cpu().numpy())
    # Pass the full label set explicitly. With a random 20% split, or a model
    # that never predicts some class, not every class index appears in
    # y_true/y_pred. Without ``labels`` sklearn then rejects a target_names
    # list whose length differs from the number of observed labels.
    # zero_division=0 reports empty classes as 0 without emitting warnings.
    print(
        classification_report(
            y_true,
            y_pred,
            labels=list(range(len(class_names))),
            target_names=class_names,
            zero_division=0,
        )
    )

# Run. The guard is true both in a notebook kernel and when executed as a
# script, but keeps a plain ``import`` of this module from starting training.
if __name__ == "__main__":
    train()


ModuleNotFoundError: No module named 'timm'