In [1]:
import kagglehub

# Download latest version
dataset_path = kagglehub.dataset_download("techsash/waste-classification-data")

print("Path to dataset files:", dataset_path)

Using Colab cache for faster access to the 'waste-classification-data' dataset.
Path to dataset files: /kaggle/input/waste-classification-data


In [2]:
import pathlib

TEST_DATA = pathlib.Path(dataset_path) / "DATASET" / "TEST"
TRAIN_DATA = pathlib.Path(dataset_path) / "DATASET" / "TRAIN"

In [3]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder(root=TRAIN_DATA, transform=train_transform)
val_dataset = datasets.ImageFolder(root=TEST_DATA, transform=val_transform)

batch_size = 32
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2
)

print(len(train_loader), len(val_loader))

706 79


In [4]:
from tqdm import tqdm


def train_one_epoch(model, dataloader, optimizer, loss_fn, device):
    model.train()
    model.to(device)

    total_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in tqdm(dataloader, desc="Training"):
        inputs = inputs.to(device)
        targets = targets.to(device).float().unsqueeze(1)

        # Forward pass
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics
        total_loss += loss.item()
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        correct += (predicted == targets).sum().item()
        total += targets.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = 100 * correct / total

    return {'loss': avg_loss, 'accuracy': accuracy}

In [5]:
def validate_one_epoch(model, dataloader, loss_fn, device):
    model.eval()
    model.to(device)

    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Validation"):
            inputs = inputs.to(device)
            targets = targets.to(device).float().unsqueeze(1)

            # Forward pass
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            # Metrics
            total_loss += loss.item()
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            correct += (predicted == targets).sum().item()
            total += targets.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = 100 * correct / total

    return {'loss': avg_loss, 'accuracy': accuracy}

In [6]:
def train_model(model, train_loader, val_loader, optimizer, loss_fn, device, epochs):
    history = {
        'train_loss': [],
        'train_accuracy': [],
        'val_loss': [],
        'val_accuracy': []
    }

    best_val_accuracy = 0.0

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        print("-" * 50)

        train_metrics = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
        print(f"Train Loss: {train_metrics['loss']:.4f}, Train Accuracy: {train_metrics['accuracy']:.2f}%")

        val_metrics = validate_one_epoch(model, val_loader, loss_fn, device)
        print(f"Val Loss: {val_metrics['loss']:.4f}, Val Accuracy: {val_metrics['accuracy']:.2f}%")

        history['train_loss'].append(train_metrics['loss'])
        history['train_accuracy'].append(train_metrics['accuracy'])
        history['val_loss'].append(val_metrics['loss'])
        history['val_accuracy'].append(val_metrics['accuracy'])

        if val_metrics['accuracy'] > best_val_accuracy:
            best_val_accuracy = val_metrics['accuracy']
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"✓ Saved best model with validation accuracy: {best_val_accuracy:.2f}%")

    return history

In [7]:
import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

loss_fn = nn.BCEWithLogitsLoss()

Using device: cuda


In [8]:
from torchvision import models
from torchvision.models import ResNet152_Weights

model = models.resnet152(weights=ResNet152_Weights.IMAGENET1K_V1)

for param in model.parameters():
    param.requires_grad = False

model.fc = nn.Linear(model.fc.in_features, 1)

optimizer_resnet = torch.optim.Adam(model.fc.parameters(), lr=1e-3)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Frozen parameters: {total_params - trainable_params:,}")

Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to /root/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth


100%|██████████| 230M/230M [00:01<00:00, 200MB/s]


  Total parameters: 58,145,857
  Trainable parameters: 2,049
  Frozen parameters: 58,143,808


In [9]:
from torchvision.models import swin_b, Swin_B_Weights

model_swin = swin_b(weights=Swin_B_Weights.IMAGENET1K_V1)

for param in model_swin.parameters():
    param.requires_grad = False

model_swin.head = nn.Linear(model_swin.head.in_features, 1)

optimizer_swin = torch.optim.Adam(model_swin.head.parameters(), lr=1e-3)

trainable_params = sum(p.numel() for p in model_swin.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model_swin.parameters())
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Frozen parameters: {total_params - trainable_params:,}")

Downloading: "https://download.pytorch.org/models/swin_b-68c6b09e.pth" to /root/.cache/torch/hub/checkpoints/swin_b-68c6b09e.pth


100%|██████████| 335M/335M [00:02<00:00, 154MB/s]


  Total parameters: 86,744,249
  Trainable parameters: 1,025
  Frozen parameters: 86,743,224


In [10]:
from torchvision.models import vit_b_16, ViT_B_16_Weights

model_vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

for param in model_vit.parameters():
    param.requires_grad = False

model_vit.heads.head = nn.Linear(model_vit.heads.head.in_features, 1)

optimizer_vit = torch.optim.Adam(model_vit.heads.head.parameters(), lr=1e-3)

trainable_params = sum(p.numel() for p in model_vit.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model_vit.parameters())
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Frozen parameters: {total_params - trainable_params:,}")

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth


100%|██████████| 330M/330M [00:01<00:00, 188MB/s]


  Total parameters: 85,799,425
  Trainable parameters: 769
  Frozen parameters: 85,798,656


In [11]:
model_dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')

for param in model_dinov2.parameters():
    param.requires_grad = False

# Add classification head
model_dinov2.head = nn.Linear(768, 1)  # DINOv2 ViT-B/14 has 768 dimensional embeddings

optimizer_dinov2 = torch.optim.Adam(model_dinov2.head.parameters(), lr=1e-3)

trainable_params = sum(p.numel() for p in model_dinov2.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model_dinov2.parameters())
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Frozen parameters: {total_params - trainable_params:,}")

Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip




Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vitb14_pretrain.pth


100%|██████████| 330M/330M [00:01<00:00, 262MB/s]


  Total parameters: 86,581,249
  Trainable parameters: 769
  Frozen parameters: 86,580,480


In [12]:
print("=" * 60)
print("Training ResNet152")
print("=" * 60)
history_resnet = train_model(model, train_loader, val_loader, optimizer_resnet, loss_fn, device, epochs=10)

Training ResNet152

Epoch 1/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [02:47<00:00,  4.22it/s]


Train Loss: 0.2239, Train Accuracy: 91.54%


Validation: 100%|██████████| 79/79 [00:18<00:00,  4.34it/s]


Val Loss: 0.1907, Val Accuracy: 92.68%
✓ Saved best model with validation accuracy: 92.68%

Epoch 2/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [02:47<00:00,  4.22it/s]


Train Loss: 0.1751, Train Accuracy: 93.58%


Validation: 100%|██████████| 79/79 [00:18<00:00,  4.37it/s]


Val Loss: 0.1748, Val Accuracy: 93.00%
✓ Saved best model with validation accuracy: 93.00%

Epoch 3/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [02:47<00:00,  4.21it/s]


Train Loss: 0.1709, Train Accuracy: 93.68%


Validation: 100%|██████████| 79/79 [00:18<00:00,  4.39it/s]


Val Loss: 0.1373, Val Accuracy: 94.87%
✓ Saved best model with validation accuracy: 94.87%

Epoch 4/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [02:47<00:00,  4.20it/s]


Train Loss: 0.1624, Train Accuracy: 94.14%


Validation: 100%|██████████| 79/79 [00:17<00:00,  4.39it/s]


Val Loss: 0.1791, Val Accuracy: 92.88%

Epoch 5/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [02:47<00:00,  4.20it/s]


Train Loss: 0.1634, Train Accuracy: 93.94%


Validation: 100%|██████████| 79/79 [00:18<00:00,  4.32it/s]


Val Loss: 0.2162, Val Accuracy: 91.21%

Epoch 6/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [02:48<00:00,  4.20it/s]


Train Loss: 0.1574, Train Accuracy: 94.25%


Validation: 100%|██████████| 79/79 [00:18<00:00,  4.36it/s]


Val Loss: 0.1558, Val Accuracy: 93.91%

Epoch 7/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [02:48<00:00,  4.19it/s]


Train Loss: 0.1589, Train Accuracy: 94.11%


Validation: 100%|██████████| 79/79 [00:18<00:00,  4.38it/s]


Val Loss: 0.1579, Val Accuracy: 93.91%

Epoch 8/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [02:47<00:00,  4.20it/s]


Train Loss: 0.1550, Train Accuracy: 94.23%


Validation: 100%|██████████| 79/79 [00:18<00:00,  4.33it/s]


Val Loss: 0.1330, Val Accuracy: 94.95%
✓ Saved best model with validation accuracy: 94.95%

Epoch 9/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [02:48<00:00,  4.20it/s]


Train Loss: 0.1539, Train Accuracy: 94.29%


Validation: 100%|██████████| 79/79 [00:18<00:00,  4.29it/s]


Val Loss: 0.1742, Val Accuracy: 93.04%

Epoch 10/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [02:47<00:00,  4.20it/s]


Train Loss: 0.1477, Train Accuracy: 94.59%


Validation: 100%|██████████| 79/79 [00:18<00:00,  4.33it/s]

Val Loss: 0.1829, Val Accuracy: 92.76%





In [13]:
print("\n" + "=" * 60)
print("Training Vit-B/16")
print("=" * 60)
history_vit = train_model(model_vit, train_loader, val_loader, optimizer_vit, loss_fn, device, epochs=10)


Training Vit-B/16

Epoch 1/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:48<00:00,  2.45it/s]


Train Loss: 0.1952, Train Accuracy: 93.28%


Validation: 100%|██████████| 79/79 [00:27<00:00,  2.91it/s]


Val Loss: 0.2478, Val Accuracy: 89.14%
✓ Saved best model with validation accuracy: 89.14%

Epoch 2/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:47<00:00,  2.46it/s]


Train Loss: 0.1347, Train Accuracy: 95.26%


Validation: 100%|██████████| 79/79 [00:26<00:00,  2.94it/s]


Val Loss: 0.1972, Val Accuracy: 91.48%
✓ Saved best model with validation accuracy: 91.48%

Epoch 3/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:47<00:00,  2.46it/s]


Train Loss: 0.1254, Train Accuracy: 95.59%


Validation: 100%|██████████| 79/79 [00:26<00:00,  2.95it/s]


Val Loss: 0.1896, Val Accuracy: 91.88%
✓ Saved best model with validation accuracy: 91.88%

Epoch 4/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:47<00:00,  2.46it/s]


Train Loss: 0.1223, Train Accuracy: 95.71%


Validation: 100%|██████████| 79/79 [00:26<00:00,  2.94it/s]


Val Loss: 0.1963, Val Accuracy: 91.21%

Epoch 5/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:47<00:00,  2.46it/s]


Train Loss: 0.1181, Train Accuracy: 95.79%


Validation: 100%|██████████| 79/79 [00:26<00:00,  2.94it/s]


Val Loss: 0.1927, Val Accuracy: 91.52%

Epoch 6/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:46<00:00,  2.46it/s]


Train Loss: 0.1165, Train Accuracy: 95.93%


Validation: 100%|██████████| 79/79 [00:26<00:00,  2.93it/s]


Val Loss: 0.1902, Val Accuracy: 91.52%

Epoch 7/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:47<00:00,  2.46it/s]


Train Loss: 0.1149, Train Accuracy: 95.98%


Validation: 100%|██████████| 79/79 [00:26<00:00,  2.94it/s]


Val Loss: 0.1864, Val Accuracy: 91.88%

Epoch 8/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:47<00:00,  2.46it/s]


Train Loss: 0.1134, Train Accuracy: 95.98%


Validation: 100%|██████████| 79/79 [00:26<00:00,  2.93it/s]


Val Loss: 0.1851, Val Accuracy: 91.96%
✓ Saved best model with validation accuracy: 91.96%

Epoch 9/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:46<00:00,  2.46it/s]


Train Loss: 0.1112, Train Accuracy: 96.02%


Validation: 100%|██████████| 79/79 [00:26<00:00,  2.94it/s]


Val Loss: 0.1877, Val Accuracy: 91.76%

Epoch 10/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:47<00:00,  2.46it/s]


Train Loss: 0.1112, Train Accuracy: 96.02%


Validation: 100%|██████████| 79/79 [00:26<00:00,  2.94it/s]

Val Loss: 0.1833, Val Accuracy: 91.88%





In [14]:
print("\n" + "=" * 60)
print("Training Swin-B")
print("=" * 60)
history_swin = train_model(model_swin, train_loader, val_loader, optimizer_swin, loss_fn, device, epochs=10)


Training Swin-B

Epoch 1/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:23<00:00,  2.68it/s]


Train Loss: 0.2076, Train Accuracy: 93.21%


Validation: 100%|██████████| 79/79 [00:29<00:00,  2.71it/s]


Val Loss: 0.2054, Val Accuracy: 90.69%
✓ Saved best model with validation accuracy: 90.69%

Epoch 2/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:23<00:00,  2.68it/s]


Train Loss: 0.1475, Train Accuracy: 95.05%


Validation: 100%|██████████| 79/79 [00:29<00:00,  2.68it/s]


Val Loss: 0.2016, Val Accuracy: 90.85%
✓ Saved best model with validation accuracy: 90.85%

Epoch 3/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:23<00:00,  2.68it/s]


Train Loss: 0.1396, Train Accuracy: 95.12%


Validation: 100%|██████████| 79/79 [00:29<00:00,  2.70it/s]


Val Loss: 0.1833, Val Accuracy: 92.56%
✓ Saved best model with validation accuracy: 92.56%

Epoch 4/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:23<00:00,  2.68it/s]


Train Loss: 0.1357, Train Accuracy: 95.24%


Validation: 100%|██████████| 79/79 [00:29<00:00,  2.71it/s]


Val Loss: 0.2053, Val Accuracy: 91.01%

Epoch 5/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:23<00:00,  2.68it/s]


Train Loss: 0.1351, Train Accuracy: 95.23%


Validation: 100%|██████████| 79/79 [00:29<00:00,  2.71it/s]


Val Loss: 0.2035, Val Accuracy: 90.85%

Epoch 6/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:23<00:00,  2.68it/s]


Train Loss: 0.1349, Train Accuracy: 95.11%


Validation: 100%|██████████| 79/79 [00:29<00:00,  2.70it/s]


Val Loss: 0.1966, Val Accuracy: 91.36%

Epoch 7/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:23<00:00,  2.68it/s]


Train Loss: 0.1299, Train Accuracy: 95.54%


Validation: 100%|██████████| 79/79 [00:29<00:00,  2.70it/s]


Val Loss: 0.1847, Val Accuracy: 92.04%

Epoch 8/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:23<00:00,  2.68it/s]


Train Loss: 0.1306, Train Accuracy: 95.32%


Validation: 100%|██████████| 79/79 [00:29<00:00,  2.70it/s]


Val Loss: 0.1743, Val Accuracy: 92.52%

Epoch 9/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:23<00:00,  2.68it/s]


Train Loss: 0.1302, Train Accuracy: 95.31%


Validation: 100%|██████████| 79/79 [00:29<00:00,  2.69it/s]


Val Loss: 0.1822, Val Accuracy: 92.04%

Epoch 10/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [04:22<00:00,  2.68it/s]


Train Loss: 0.1269, Train Accuracy: 95.45%


Validation: 100%|██████████| 79/79 [00:29<00:00,  2.71it/s]

Val Loss: 0.1915, Val Accuracy: 91.76%





In [None]:
print("\n" + "=" * 60)
print("Training DINOv2")
print("=" * 60)
history_dinov2 = train_model(model_dinov2, train_loader, val_loader, optimizer_dinov2, loss_fn, device, epochs=10)


Training DINOv2

Epoch 1/10
--------------------------------------------------


Training: 100%|██████████| 706/706 [05:19<00:00,  2.21it/s]


Train Loss: 0.1320, Train Accuracy: 95.26%


Validation: 100%|██████████| 79/79 [00:35<00:00,  2.20it/s]


Val Loss: 0.2118, Val Accuracy: 91.05%
✓ Saved best model with validation accuracy: 91.05%

Epoch 2/10
--------------------------------------------------


Training:   3%|▎         | 18/706 [00:08<05:08,  2.23it/s]

In [None]:
import pandas as pd

results = pd.DataFrame({
    'Model': ['ResNet152', 'ViT-B/16', 'Swin-B', 'DINOv2'],
    'Final Train Acc': [
        history_resnet['train_accuracy'][-1],
        history_vit['train_accuracy'][-1],
        history_swin['train_accuracy'][-1],
        history_dinov2['train_accuracy'][-1]
    ],
    'Final Val Acc': [
        history_resnet['val_accuracy'][-1],
        history_vit['val_accuracy'][-1],
        history_swin['val_accuracy'][-1],
        history_dinov2['val_accuracy'][-1]
    ],
    'Best Val Acc': [
        max(history_resnet['val_accuracy']),
        max(history_vit['val_accuracy']),
        max(history_swin['val_accuracy']),
        max(history_dinov2['val_accuracy'])
    ]
})

print("\n" + "=" * 60)
print("Results Comparison")
print("=" * 60)
print(results.to_string(index=False))
print("\nBest model:", results.loc[results['Best Val Acc'].idxmax(), 'Model'])

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

models_data = {
    'ResNet152': history_resnet,
    'ViT-B/16': history_vit,
    'Swin-B': history_swin,
    'DINOv2': history_dinov2
}

colors = {'ResNet152': 'blue', 'ViT-B/16': 'green', 'Swin-B': 'red', 'DINOv2': 'orange'}

ax1 = axes[0, 0]
for model_name, history in models_data.items():
    epochs_range = range(1, len(history['train_loss']) + 1)
    ax1.plot(epochs_range, history['train_loss'],
             marker='o', label=model_name, color=colors[model_name], linewidth=2)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('Training Loss', fontsize=14, fontweight='bold')
ax1.legend()
ax1.grid(True, alpha=0.3)

ax2 = axes[0, 1]
for model_name, history in models_data.items():
    epochs_range = range(1, len(history['val_loss']) + 1)
    ax2.plot(epochs_range, history['val_loss'],
             marker='o', label=model_name, color=colors[model_name], linewidth=2)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Loss', fontsize=12)
ax2.set_title('Validation Loss', fontsize=14, fontweight='bold')
ax2.legend()
ax2.grid(True, alpha=0.3)

ax3 = axes[1, 0]
for model_name, history in models_data.items():
    epochs_range = range(1, len(history['train_accuracy']) + 1)
    ax3.plot(epochs_range, history['train_accuracy'],
             marker='o', label=model_name, color=colors[model_name], linewidth=2)
ax3.set_xlabel('Epoch', fontsize=12)
ax3.set_ylabel('Accuracy (%)', fontsize=12)
ax3.set_title('Training Accuracy', fontsize=14, fontweight='bold')
ax3.legend()
ax3.grid(True, alpha=0.3)

ax4 = axes[1, 1]
for model_name, history in models_data.items():
    epochs_range = range(1, len(history['val_accuracy']) + 1)
    ax4.plot(epochs_range, history['val_accuracy'],
             marker='o', label=model_name, color=colors[model_name], linewidth=2)
ax4.set_xlabel('Epoch', fontsize=12)
ax4.set_ylabel('Accuracy (%)', fontsize=12)
ax4.set_title('Validation Accuracy', fontsize=14, fontweight='bold')
ax4.legend()
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.suptitle('Training Progress Comparison', y=1.01, fontsize=16, fontweight='bold')
plt.show()