In [1]:
# train_model.py

import os
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
from torchvision import models

# Config
data_dir = 'data/train'
num_epochs = 15
batch_size = 16
learning_rate = 0.001
model_path = 'models/best_model.pth'

# Transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Dataset and Dataloader
dataset = ImageFolder(root=data_dir, transform=transform)
class_names = dataset.classes
num_classes = len(class_names)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_ds, val_ds = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size)

# Model
model = models.resnet18(weights='IMAGENET1K_V1')
model.fc = nn.Linear(model.fc.in_features, num_classes)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

best_val_loss = float('inf')
for epoch in range(num_epochs):
    model.train()
    running_loss = 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    val_loss = 0
    model.eval()
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            val_loss += criterion(outputs, labels).item()

    print(f"Epoch {epoch+1}, Loss: {running_loss:.4f}, Val Loss: {val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), model_path)
        print("✅ Model saved!")



Epoch 1, Loss: 34.0803, Val Loss: 22.3365
✅ Model saved!
Epoch 2, Loss: 23.1253, Val Loss: 6.6572
✅ Model saved!
Epoch 3, Loss: 19.3665, Val Loss: 15.2950
Epoch 4, Loss: 15.8863, Val Loss: 7.3244
Epoch 5, Loss: 10.0718, Val Loss: 16.6015
Epoch 6, Loss: 13.7379, Val Loss: 11.8735
Epoch 7, Loss: 8.4297, Val Loss: 9.8502
Epoch 8, Loss: 9.6926, Val Loss: 10.2036
Epoch 9, Loss: 9.3837, Val Loss: 18.7033
Epoch 10, Loss: 11.0052, Val Loss: 8.8330
Epoch 11, Loss: 8.0257, Val Loss: 9.9656
Epoch 12, Loss: 3.6620, Val Loss: 19.9183
Epoch 13, Loss: 6.0231, Val Loss: 15.9902
Epoch 14, Loss: 8.3206, Val Loss: 9.3554
Epoch 15, Loss: 6.3142, Val Loss: 22.2448


In [2]:
pip install torch


Note: you may need to restart the kernel to use updated packages.
