# Configurable Training Pipeline with Checkpointing


In [12]:
import torch

# Sanity check: report whether PyTorch can see a CUDA device.
cuda_ready = torch.cuda.is_available()
print(cuda_ready)

True


## Install Required Libraries


In [13]:
!pip install pyyaml



## Create Required Directories


In [14]:
import os

# Create the output folders up front; exist_ok makes re-running this cell a no-op.
for folder in ("checkpoints", "logs"):
    os.makedirs(folder, exist_ok=True)

## Create YAML Configuration File


In [15]:
%%writefile config.yaml
# Run metadata.
project:
  name: cifar10_training

# Filesystem locations for run artifacts.
paths:
  checkpoint_dir: checkpoints
  # Directory that holds metrics.csv.
  log_dir: logs
  # Path of a saved checkpoint to resume from; null starts training at epoch 0.
  resume_checkpoint: null

# Optimization settings: number of epochs, DataLoader batch size, Adam learning rate.
training:
  epochs: 10
  batch_size: 64
  learning_rate: 0.001

model:
  # Number of output classes passed to SimpleCNN.
  num_classes: 10

# Compute device requested for the run.
device: cuda

Overwriting config.yaml


## Import Required Modules


In [16]:
import yaml
import csv
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os

## Load Configuration and Set Device


In [17]:
# Read the run configuration from config.yaml (safe_load: plain data only).
with open("config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Honor the `device` entry of the config (it was previously ignored).
# Default to "cuda" when the key is missing, and fall back to CPU when a
# CUDA device is requested but none is available.
requested_device = str(config.get("device") or "cuda")
if requested_device.startswith("cuda") and not torch.cuda.is_available():
    print(f"Requested device '{requested_device}' is not available; falling back to CPU")
    requested_device = "cpu"

device = torch.device(requested_device)
print("Using device:", device)

Using device: cuda


## CNN Model Definition


In [18]:
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier.

    Two (3x3 convolution -> ReLU -> 2x2 max-pool) stages with 32 and 64
    channels feed a two-layer fully connected head that outputs one logit
    per class.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels of the input images. Defaults to 3,
            the value that used to be hard-coded.
        input_size: Height/width of the (square) input images. Defaults to
            32, which reproduces the previously hard-coded ``64 * 8 * 8``
            flattened feature size.

    Raises:
        ValueError: If ``input_size`` is too small to survive both pooling
            stages.
    """

    def __init__(self, num_classes: int, in_channels: int = 3, input_size: int = 32):
        super().__init__()

        # The padded 3x3 convolutions keep the spatial size; each
        # MaxPool2d(2) floor-halves it, so two stages give input_size // 2 // 2.
        pooled_size = input_size // 2 // 2
        if pooled_size < 1:
            raise ValueError(
                f"input_size={input_size} is too small for two 2x2 pooling stages"
            )

        # Attribute names and Sequential indices are kept as-is so that
        # state_dict keys of existing checkpoints still match.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * pooled_size * pooled_size, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for image batch ``x``."""
        x = self.conv(x)
        # Flatten everything except the batch dimension for the linear head.
        x = x.view(x.size(0), -1)
        return self.fc(x)

# Build the network with the configured class count, then move it to the
# selected device.
classifier = SimpleCNN(config["model"]["num_classes"])
model = classifier.to(device)

## Dataset Preparation (CIFAR-10)


In [19]:
# Preprocessing: convert images to tensors, then normalize each of the three
# channels with mean 0.5 and std 0.5.
channel_mean = (0.5, 0.5, 0.5)
channel_std = (0.5, 0.5, 0.5)
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
)

# CIFAR-10 training split, downloaded into ./data when not already present.
train_data = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)

# Shuffled mini-batches; the batch size comes from the config.
batch_size = config["training"]["batch_size"]
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

## Optimizer, Loss Function, and Metrics Logging Setup


In [20]:
# Loss function and optimizer for the classifier.
criterion = nn.CrossEntropyLoss()
learning_rate = config["training"]["learning_rate"]
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Metrics live in <log_dir>/metrics.csv.
log_file = os.path.join(config["paths"]["log_dir"], "metrics.csv")

# Write the header row only when the file is new, so an existing log is kept.
log_exists = os.path.exists(log_file)
if not log_exists:
    with open(log_file, "w", newline="") as log_handle:
        csv.writer(log_handle).writerow(["epoch", "loss", "accuracy"])


## Load Checkpoint to Resume Training


In [21]:
# Epoch index at which training starts; stays 0 unless a checkpoint is loaded.
start_epoch = 0

# A null/empty (or missing) resume_checkpoint entry means "train from scratch".
resume_path = config["paths"].get("resume_checkpoint")
if resume_path:
    # map_location remaps the saved tensors onto the active device, so a
    # checkpoint written on a GPU can also be resumed on a CPU-only machine.
    checkpoint = torch.load(resume_path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    # The checkpoint stores the last completed epoch; continue with the next one.
    start_epoch = checkpoint["epoch"] + 1
    print(f"Resumed from epoch {start_epoch}")

## Training Loop, Metrics Logging, and Checkpoint Saving


In [22]:
# Take the checkpoint directory from the config instead of a hard-coded
# "checkpoints/" prefix, and make sure it exists before the first save.
checkpoint_dir = config["paths"].get("checkpoint_dir", "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(start_epoch, config["training"]["epochs"]):
    model.train()
    total_loss = 0  # sum of per-batch mean losses
    correct = 0     # correctly classified training samples this epoch
    total = 0       # training samples seen this epoch

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Standard optimization step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Predicted class = index of the largest logit along dim 1.
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    # Training accuracy in percent; the loss is the mean of the per-batch losses.
    accuracy = 100 * correct / total
    avg_loss = total_loss / len(train_loader)

    # Append one row per epoch to the metrics CSV.
    with open(log_file, "a", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([epoch, avg_loss, accuracy])

    # Save model and optimizer state after every epoch so training can resume.
    checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch}.pth")
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict()
    }, checkpoint_path)

    print(f"Epoch {epoch}: Loss={avg_loss:.4f}, Accuracy={accuracy:.2f}%")

Epoch 0: Loss=1.2811, Accuracy=54.01%
Epoch 1: Loss=0.9156, Accuracy=67.53%
Epoch 2: Loss=0.7447, Accuracy=73.77%
Epoch 3: Loss=0.6077, Accuracy=78.65%
Epoch 4: Loss=0.4789, Accuracy=83.21%
Epoch 5: Loss=0.3614, Accuracy=87.40%
Epoch 6: Loss=0.2553, Accuracy=91.08%
Epoch 7: Loss=0.1703, Accuracy=94.17%
Epoch 8: Loss=0.1301, Accuracy=95.59%
Epoch 9: Loss=0.0984, Accuracy=96.65%
