In [1]:
# Import necessary libraries
import matplotlib.pyplot as plt
import torch
import torchvision
from torch import nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchinfo import summary
from itertools import product
import os

from helper_functions import set_seeds

# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Set random seeds for reproducibility
def set_seeds(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

# Load pretrained ViT model
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)

# Freeze base parameters
for parameter in pretrained_vit.parameters():
    parameter.requires_grad = False

# Define dataset paths
train_dir = r'C:\Users\satya\Desktop\ransom\cnn2\MORE\RGBprocessed_images\train'
test_dir = r'C:\Users\satya\Desktop\ransom\cnn2\MORE\RGBprocessed_images\test'
val_dir = r'C:\Users\satya\Desktop\ransom\cnn2\MORE\RGBprocessed_images\val'



# Get automatic transforms from pretrained ViT weights
pretrained_vit_transforms = pretrained_vit_weights.transforms()

# Function to create dataloaders
def create_dataloaders(train_dir, val_dir, test_dir, transform, batch_size, num_workers=os.cpu_count()):
    # Create datasets
    train_data = datasets.ImageFolder(train_dir, transform=transform)
    val_data = datasets.ImageFolder(val_dir, transform=transform)  # Validation dataset
    test_data = datasets.ImageFolder(test_dir, transform=transform)

    # Get class names
    class_names = train_data.classes

    # Create dataloaders
    train_dataloader = DataLoader(
        train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True
    )
    val_dataloader = DataLoader(
        val_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
    )
    test_dataloader = DataLoader(
        test_data, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
    )

    return train_dataloader, val_dataloader, test_dataloader, class_names

# Define hyperparameter grid
hyperparam_grid = {
    'learning_rate': [1e-3, 1e-4],  # Learning rates to test
    'batch_size': [32, 64],         # Batch sizes to test
    'optimizer': ['Adam', 'SGD'],   # Optimizers to test
    'weight_decay': [0, 1e-4]       # Weight decay (regularization)
}

# Generate all combinations of hyperparameters
grid_combinations = [dict(zip(hyperparam_grid.keys(), values)) 
                     for values in product(*hyperparam_grid.values())]

# Track best model and metrics
best_metrics = {
    'accuracy': 0,
    'hyperparams': None,
    'model_state': None
}

# Grid search loop
for params in grid_combinations:
    print(f"\n\033[1mTesting hyperparameters: {params}\033[0m")
    
    # Set seeds for reproducibility
    set_seeds()
    
    # Create dataloaders with current batch size
    train_dataloader, val_dataloader, test_dataloader, class_names = create_dataloaders(
        train_dir=train_dir,
        val_dir=val_dir,
        test_dir=test_dir,
        transform=pretrained_vit_transforms,
        batch_size=params['batch_size']
    )
    
    # Initialize fresh model for each combination
    model = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)
    for parameter in model.parameters():
        parameter.requires_grad = False
    model.heads = nn.Linear(in_features=768, out_features=len(class_names)).to(device)
    
    # Configure optimizer
    optimizer_config = {
        'params': model.parameters(),
        'lr': params['learning_rate'],
        'weight_decay': params['weight_decay']
    }
    if params['optimizer'] == 'Adam':
        optimizer = torch.optim.Adam(**optimizer_config)
    elif params['optimizer'] == 'SGD':
        optimizer = torch.optim.SGD(**optimizer_config, momentum=0.9)
    
    # Train model with validation
    from going_modular.going_modular import engine
    print(f"Training with {params['optimizer']} optimizer...")
    results = engine.train(
        model=model,
        train_dataloader=train_dataloader,
        test_dataloader=val_dataloader,  # Validate on validation set
        optimizer=optimizer,
        loss_fn=nn.CrossEntropyLoss(),
        epochs=10,
        device=device
    )
    
    # Track best model using validation accuracy
    current_val_acc = max(results['test_acc'])
    if current_val_acc > best_metrics['accuracy']:
        best_metrics['accuracy'] = current_val_acc
        best_metrics['hyperparams'] = params
        best_metrics['model_state'] = model.state_dict()
        print(f"New best validation accuracy: {current_val_acc:.4f}")

# Final evaluation with best model
print("\n\033[1mBest hyperparameters:", best_metrics['hyperparams'])
print(f"Best validation accuracy: {best_metrics['accuracy']:.4f}\033[0m")

# Load best model for test evaluation
best_model = torchvision.models.vit_b_16().to(device)
best_model.heads = nn.Linear(in_features=768, out_features=len(class_names)).to(device)
best_model.load_state_dict(best_metrics['model_state'])

# Evaluate on test set
test_results = engine.train(
    model=best_model,
    train_dataloader=train_dataloader,  # Not used but required by function
    test_dataloader=test_dataloader,
    optimizer=optimizer,  # Not used in evaluation
    loss_fn=nn.CrossEntropyLoss(),
    epochs=1,  # Only evaluation
    device=device
)
print(f"\n\033[1mFinal test accuracy: {max(test_results['test_acc']):.4f}\033[0m")

Using device: cuda

[1mTesting hyperparameters: {'learning_rate': 0.001, 'batch_size': 32, 'optimizer': 'Adam', 'weight_decay': 0}[0m
Training with Adam optimizer...


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.2649 | train_acc: 0.9102 | test_loss: 0.1271 | test_acc: 0.9630
Epoch: 2 | train_loss: 0.1277 | train_acc: 0.9631 | test_loss: 0.0878 | test_acc: 0.9803
Epoch: 3 | train_loss: 0.0957 | train_acc: 0.9760 | test_loss: 0.0708 | test_acc: 0.9873
Epoch: 4 | train_loss: 0.0793 | train_acc: 0.9791 | test_loss: 0.0556 | test_acc: 0.9942
Epoch: 5 | train_loss: 0.0705 | train_acc: 0.9817 | test_loss: 0.0485 | test_acc: 0.9954
Epoch: 6 | train_loss: 0.0622 | train_acc: 0.9844 | test_loss: 0.0393 | test_acc: 0.9942
Epoch: 7 | train_loss: 0.0567 | train_acc: 0.9864 | test_loss: 0.0382 | test_acc: 0.9954
Epoch: 8 | train_loss: 0.0539 | train_acc: 0.9864 | test_loss: 0.0332 | test_acc: 0.9954
Epoch: 9 | train_loss: 0.0525 | train_acc: 0.9867 | test_loss: 0.0338 | test_acc: 0.9954
Epoch: 10 | train_loss: 0.0464 | train_acc: 0.9887 | test_loss: 0.0347 | test_acc: 0.9942
New best validation accuracy: 0.9954

[1mTesting hyperparameters: {'learning_rate': 0.001, 'batch_size': 32,

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.2649 | train_acc: 0.9102 | test_loss: 0.1271 | test_acc: 0.9630
Epoch: 2 | train_loss: 0.1277 | train_acc: 0.9631 | test_loss: 0.0878 | test_acc: 0.9803
Epoch: 3 | train_loss: 0.0957 | train_acc: 0.9760 | test_loss: 0.0708 | test_acc: 0.9873
Epoch: 4 | train_loss: 0.0793 | train_acc: 0.9791 | test_loss: 0.0557 | test_acc: 0.9942
Epoch: 5 | train_loss: 0.0706 | train_acc: 0.9817 | test_loss: 0.0486 | test_acc: 0.9954
Epoch: 6 | train_loss: 0.0623 | train_acc: 0.9844 | test_loss: 0.0393 | test_acc: 0.9942
Epoch: 7 | train_loss: 0.0568 | train_acc: 0.9864 | test_loss: 0.0383 | test_acc: 0.9954
Epoch: 8 | train_loss: 0.0539 | train_acc: 0.9864 | test_loss: 0.0333 | test_acc: 0.9954
Epoch: 9 | train_loss: 0.0526 | train_acc: 0.9867 | test_loss: 0.0340 | test_acc: 0.9954
Epoch: 10 | train_loss: 0.0465 | train_acc: 0.9887 | test_loss: 0.0348 | test_acc: 0.9942

[1mTesting hyperparameters: {'learning_rate': 0.001, 'batch_size': 32, 'optimizer': 'SGD', 'weight_decay': 

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.3571 | train_acc: 0.8590 | test_loss: 0.1883 | test_acc: 0.9398
Epoch: 2 | train_loss: 0.1884 | train_acc: 0.9378 | test_loss: 0.1428 | test_acc: 0.9630
Epoch: 3 | train_loss: 0.1535 | train_acc: 0.9555 | test_loss: 0.1213 | test_acc: 0.9769
Epoch: 4 | train_loss: 0.1338 | train_acc: 0.9631 | test_loss: 0.1058 | test_acc: 0.9803
Epoch: 5 | train_loss: 0.1216 | train_acc: 0.9693 | test_loss: 0.1000 | test_acc: 0.9780
Epoch: 6 | train_loss: 0.1113 | train_acc: 0.9731 | test_loss: 0.0896 | test_acc: 0.9838
Epoch: 7 | train_loss: 0.1028 | train_acc: 0.9761 | test_loss: 0.0842 | test_acc: 0.9803
Epoch: 8 | train_loss: 0.1003 | train_acc: 0.9735 | test_loss: 0.0789 | test_acc: 0.9850
Epoch: 9 | train_loss: 0.0955 | train_acc: 0.9753 | test_loss: 0.0732 | test_acc: 0.9850
Epoch: 10 | train_loss: 0.0895 | train_acc: 0.9791 | test_loss: 0.0748 | test_acc: 0.9884

[1mTesting hyperparameters: {'learning_rate': 0.001, 'batch_size': 32, 'optimizer': 'SGD', 'weight_decay': 

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.3571 | train_acc: 0.8590 | test_loss: 0.1883 | test_acc: 0.9398
Epoch: 2 | train_loss: 0.1884 | train_acc: 0.9378 | test_loss: 0.1428 | test_acc: 0.9630
Epoch: 3 | train_loss: 0.1535 | train_acc: 0.9555 | test_loss: 0.1214 | test_acc: 0.9769
Epoch: 4 | train_loss: 0.1338 | train_acc: 0.9631 | test_loss: 0.1058 | test_acc: 0.9803
Epoch: 5 | train_loss: 0.1216 | train_acc: 0.9693 | test_loss: 0.1001 | test_acc: 0.9780
Epoch: 6 | train_loss: 0.1113 | train_acc: 0.9731 | test_loss: 0.0896 | test_acc: 0.9838
Epoch: 7 | train_loss: 0.1029 | train_acc: 0.9761 | test_loss: 0.0843 | test_acc: 0.9803
Epoch: 8 | train_loss: 0.1003 | train_acc: 0.9735 | test_loss: 0.0789 | test_acc: 0.9850
Epoch: 9 | train_loss: 0.0955 | train_acc: 0.9753 | test_loss: 0.0733 | test_acc: 0.9850
Epoch: 10 | train_loss: 0.0895 | train_acc: 0.9791 | test_loss: 0.0748 | test_acc: 0.9884

[1mTesting hyperparameters: {'learning_rate': 0.001, 'batch_size': 64, 'optimizer': 'Adam', 'weight_decay':

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.3414 | train_acc: 0.8758 | test_loss: 0.1697 | test_acc: 0.9386
Epoch: 2 | train_loss: 0.1634 | train_acc: 0.9496 | test_loss: 0.1158 | test_acc: 0.9688
Epoch: 3 | train_loss: 0.1229 | train_acc: 0.9735 | test_loss: 0.0907 | test_acc: 0.9844
Epoch: 4 | train_loss: 0.1024 | train_acc: 0.9743 | test_loss: 0.0768 | test_acc: 0.9855
Epoch: 5 | train_loss: 0.0894 | train_acc: 0.9772 | test_loss: 0.0671 | test_acc: 0.9900
Epoch: 6 | train_loss: 0.0837 | train_acc: 0.9791 | test_loss: 0.0564 | test_acc: 0.9911
Epoch: 7 | train_loss: 0.0735 | train_acc: 0.9837 | test_loss: 0.0502 | test_acc: 0.9922
Epoch: 8 | train_loss: 0.0669 | train_acc: 0.9854 | test_loss: 0.0470 | test_acc: 0.9911
Epoch: 9 | train_loss: 0.0633 | train_acc: 0.9854 | test_loss: 0.0416 | test_acc: 0.9955
Epoch: 10 | train_loss: 0.0609 | train_acc: 0.9850 | test_loss: 0.0456 | test_acc: 0.9944
New best validation accuracy: 0.9955

[1mTesting hyperparameters: {'learning_rate': 0.001, 'batch_size': 64,

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.3414 | train_acc: 0.8758 | test_loss: 0.1697 | test_acc: 0.9386
Epoch: 2 | train_loss: 0.1634 | train_acc: 0.9496 | test_loss: 0.1158 | test_acc: 0.9688
Epoch: 3 | train_loss: 0.1229 | train_acc: 0.9735 | test_loss: 0.0907 | test_acc: 0.9844
Epoch: 4 | train_loss: 0.1024 | train_acc: 0.9743 | test_loss: 0.0768 | test_acc: 0.9855
Epoch: 5 | train_loss: 0.0894 | train_acc: 0.9772 | test_loss: 0.0671 | test_acc: 0.9900
Epoch: 6 | train_loss: 0.0838 | train_acc: 0.9791 | test_loss: 0.0564 | test_acc: 0.9911
Epoch: 7 | train_loss: 0.0735 | train_acc: 0.9837 | test_loss: 0.0502 | test_acc: 0.9922
Epoch: 8 | train_loss: 0.0670 | train_acc: 0.9854 | test_loss: 0.0471 | test_acc: 0.9911
Epoch: 9 | train_loss: 0.0634 | train_acc: 0.9854 | test_loss: 0.0417 | test_acc: 0.9955
Epoch: 10 | train_loss: 0.0610 | train_acc: 0.9850 | test_loss: 0.0457 | test_acc: 0.9944

[1mTesting hyperparameters: {'learning_rate': 0.001, 'batch_size': 64, 'optimizer': 'SGD', 'weight_decay': 

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.4694 | train_acc: 0.8052 | test_loss: 0.2602 | test_acc: 0.9252
Epoch: 2 | train_loss: 0.2419 | train_acc: 0.9190 | test_loss: 0.1859 | test_acc: 0.9431
Epoch: 3 | train_loss: 0.1971 | train_acc: 0.9372 | test_loss: 0.1576 | test_acc: 0.9487
Epoch: 4 | train_loss: 0.1748 | train_acc: 0.9449 | test_loss: 0.1391 | test_acc: 0.9643
Epoch: 5 | train_loss: 0.1592 | train_acc: 0.9512 | test_loss: 0.1268 | test_acc: 0.9654
Epoch: 6 | train_loss: 0.1480 | train_acc: 0.9564 | test_loss: 0.1192 | test_acc: 0.9688
Epoch: 7 | train_loss: 0.1367 | train_acc: 0.9631 | test_loss: 0.1102 | test_acc: 0.9810
Epoch: 8 | train_loss: 0.1298 | train_acc: 0.9716 | test_loss: 0.1039 | test_acc: 0.9833
Epoch: 9 | train_loss: 0.1230 | train_acc: 0.9738 | test_loss: 0.0994 | test_acc: 0.9844
Epoch: 10 | train_loss: 0.1179 | train_acc: 0.9742 | test_loss: 0.0955 | test_acc: 0.9799

[1mTesting hyperparameters: {'learning_rate': 0.001, 'batch_size': 64, 'optimizer': 'SGD', 'weight_decay': 

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.4694 | train_acc: 0.8052 | test_loss: 0.2602 | test_acc: 0.9252
Epoch: 2 | train_loss: 0.2419 | train_acc: 0.9190 | test_loss: 0.1859 | test_acc: 0.9431
Epoch: 3 | train_loss: 0.1972 | train_acc: 0.9372 | test_loss: 0.1576 | test_acc: 0.9487
Epoch: 4 | train_loss: 0.1748 | train_acc: 0.9449 | test_loss: 0.1391 | test_acc: 0.9643
Epoch: 5 | train_loss: 0.1592 | train_acc: 0.9512 | test_loss: 0.1268 | test_acc: 0.9654
Epoch: 6 | train_loss: 0.1480 | train_acc: 0.9564 | test_loss: 0.1192 | test_acc: 0.9688
Epoch: 7 | train_loss: 0.1367 | train_acc: 0.9631 | test_loss: 0.1102 | test_acc: 0.9810
Epoch: 8 | train_loss: 0.1298 | train_acc: 0.9716 | test_loss: 0.1039 | test_acc: 0.9833
Epoch: 9 | train_loss: 0.1230 | train_acc: 0.9738 | test_loss: 0.0994 | test_acc: 0.9844
Epoch: 10 | train_loss: 0.1179 | train_acc: 0.9742 | test_loss: 0.0955 | test_acc: 0.9799

[1mTesting hyperparameters: {'learning_rate': 0.0001, 'batch_size': 32, 'optimizer': 'Adam', 'weight_decay'

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.6001 | train_acc: 0.7683 | test_loss: 0.4659 | test_acc: 0.9039
Epoch: 2 | train_loss: 0.4086 | train_acc: 0.9082 | test_loss: 0.3378 | test_acc: 0.9236
Epoch: 3 | train_loss: 0.3181 | train_acc: 0.9172 | test_loss: 0.2695 | test_acc: 0.9282
Epoch: 4 | train_loss: 0.2667 | train_acc: 0.9234 | test_loss: 0.2278 | test_acc: 0.9352
Epoch: 5 | train_loss: 0.2333 | train_acc: 0.9287 | test_loss: 0.1995 | test_acc: 0.9433
Epoch: 6 | train_loss: 0.2086 | train_acc: 0.9385 | test_loss: 0.1799 | test_acc: 0.9630
Epoch: 7 | train_loss: 0.1895 | train_acc: 0.9432 | test_loss: 0.1631 | test_acc: 0.9641
Epoch: 8 | train_loss: 0.1788 | train_acc: 0.9492 | test_loss: 0.1493 | test_acc: 0.9641
Epoch: 9 | train_loss: 0.1648 | train_acc: 0.9553 | test_loss: 0.1390 | test_acc: 0.9757
Epoch: 10 | train_loss: 0.1535 | train_acc: 0.9618 | test_loss: 0.1311 | test_acc: 0.9699

[1mTesting hyperparameters: {'learning_rate': 0.0001, 'batch_size': 32, 'optimizer': 'Adam', 'weight_decay'

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.6001 | train_acc: 0.7683 | test_loss: 0.4659 | test_acc: 0.9039
Epoch: 2 | train_loss: 0.4086 | train_acc: 0.9082 | test_loss: 0.3378 | test_acc: 0.9236
Epoch: 3 | train_loss: 0.3181 | train_acc: 0.9172 | test_loss: 0.2695 | test_acc: 0.9282
Epoch: 4 | train_loss: 0.2667 | train_acc: 0.9234 | test_loss: 0.2278 | test_acc: 0.9352
Epoch: 5 | train_loss: 0.2334 | train_acc: 0.9287 | test_loss: 0.1995 | test_acc: 0.9433
Epoch: 6 | train_loss: 0.2086 | train_acc: 0.9385 | test_loss: 0.1799 | test_acc: 0.9630
Epoch: 7 | train_loss: 0.1895 | train_acc: 0.9432 | test_loss: 0.1631 | test_acc: 0.9641
Epoch: 8 | train_loss: 0.1788 | train_acc: 0.9492 | test_loss: 0.1493 | test_acc: 0.9641
Epoch: 9 | train_loss: 0.1648 | train_acc: 0.9553 | test_loss: 0.1390 | test_acc: 0.9757
Epoch: 10 | train_loss: 0.1535 | train_acc: 0.9618 | test_loss: 0.1312 | test_acc: 0.9699

[1mTesting hyperparameters: {'learning_rate': 0.0001, 'batch_size': 32, 'optimizer': 'SGD', 'weight_decay':

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.6437 | train_acc: 0.6905 | test_loss: 0.5201 | test_acc: 0.8912
Epoch: 2 | train_loss: 0.4635 | train_acc: 0.9003 | test_loss: 0.3974 | test_acc: 0.9213
Epoch: 3 | train_loss: 0.3770 | train_acc: 0.9086 | test_loss: 0.3321 | test_acc: 0.9236
Epoch: 4 | train_loss: 0.3278 | train_acc: 0.9125 | test_loss: 0.2916 | test_acc: 0.9259
Epoch: 5 | train_loss: 0.2959 | train_acc: 0.9138 | test_loss: 0.2639 | test_acc: 0.9248
Epoch: 6 | train_loss: 0.2722 | train_acc: 0.9202 | test_loss: 0.2445 | test_acc: 0.9363
Epoch: 7 | train_loss: 0.2543 | train_acc: 0.9225 | test_loss: 0.2282 | test_acc: 0.9387
Epoch: 8 | train_loss: 0.2445 | train_acc: 0.9220 | test_loss: 0.2145 | test_acc: 0.9387
Epoch: 9 | train_loss: 0.2304 | train_acc: 0.9248 | test_loss: 0.2041 | test_acc: 0.9375
Epoch: 10 | train_loss: 0.2197 | train_acc: 0.9302 | test_loss: 0.1955 | test_acc: 0.9398

[1mTesting hyperparameters: {'learning_rate': 0.0001, 'batch_size': 32, 'optimizer': 'SGD', 'weight_decay':

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.6437 | train_acc: 0.6905 | test_loss: 0.5201 | test_acc: 0.8912
Epoch: 2 | train_loss: 0.4635 | train_acc: 0.9003 | test_loss: 0.3974 | test_acc: 0.9213
Epoch: 3 | train_loss: 0.3770 | train_acc: 0.9086 | test_loss: 0.3321 | test_acc: 0.9236
Epoch: 4 | train_loss: 0.3278 | train_acc: 0.9125 | test_loss: 0.2916 | test_acc: 0.9259
Epoch: 5 | train_loss: 0.2959 | train_acc: 0.9138 | test_loss: 0.2639 | test_acc: 0.9248
Epoch: 6 | train_loss: 0.2722 | train_acc: 0.9202 | test_loss: 0.2445 | test_acc: 0.9363
Epoch: 7 | train_loss: 0.2543 | train_acc: 0.9225 | test_loss: 0.2282 | test_acc: 0.9387
Epoch: 8 | train_loss: 0.2445 | train_acc: 0.9220 | test_loss: 0.2145 | test_acc: 0.9387
Epoch: 9 | train_loss: 0.2304 | train_acc: 0.9248 | test_loss: 0.2041 | test_acc: 0.9375
Epoch: 10 | train_loss: 0.2197 | train_acc: 0.9302 | test_loss: 0.1955 | test_acc: 0.9398

[1mTesting hyperparameters: {'learning_rate': 0.0001, 'batch_size': 64, 'optimizer': 'Adam', 'weight_decay'

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.6626 | train_acc: 0.6630 | test_loss: 0.5632 | test_acc: 0.8873
Epoch: 2 | train_loss: 0.5122 | train_acc: 0.8907 | test_loss: 0.4433 | test_acc: 0.9230
Epoch: 3 | train_loss: 0.4191 | train_acc: 0.9082 | test_loss: 0.3653 | test_acc: 0.9263
Epoch: 4 | train_loss: 0.3578 | train_acc: 0.9109 | test_loss: 0.3129 | test_acc: 0.9286
Epoch: 5 | train_loss: 0.3152 | train_acc: 0.9172 | test_loss: 0.2761 | test_acc: 0.9308
Epoch: 6 | train_loss: 0.2838 | train_acc: 0.9214 | test_loss: 0.2478 | test_acc: 0.9408
Epoch: 7 | train_loss: 0.2585 | train_acc: 0.9250 | test_loss: 0.2267 | test_acc: 0.9531
Epoch: 8 | train_loss: 0.2401 | train_acc: 0.9294 | test_loss: 0.2077 | test_acc: 0.9453
Epoch: 9 | train_loss: 0.2231 | train_acc: 0.9368 | test_loss: 0.1939 | test_acc: 0.9453
Epoch: 10 | train_loss: 0.2100 | train_acc: 0.9352 | test_loss: 0.1812 | test_acc: 0.9621

[1mTesting hyperparameters: {'learning_rate': 0.0001, 'batch_size': 64, 'optimizer': 'Adam', 'weight_decay'

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.6626 | train_acc: 0.6630 | test_loss: 0.5632 | test_acc: 0.8873
Epoch: 2 | train_loss: 0.5122 | train_acc: 0.8907 | test_loss: 0.4433 | test_acc: 0.9230
Epoch: 3 | train_loss: 0.4191 | train_acc: 0.9082 | test_loss: 0.3653 | test_acc: 0.9263
Epoch: 4 | train_loss: 0.3578 | train_acc: 0.9109 | test_loss: 0.3129 | test_acc: 0.9286
Epoch: 5 | train_loss: 0.3152 | train_acc: 0.9172 | test_loss: 0.2761 | test_acc: 0.9308
Epoch: 6 | train_loss: 0.2838 | train_acc: 0.9214 | test_loss: 0.2478 | test_acc: 0.9408
Epoch: 7 | train_loss: 0.2585 | train_acc: 0.9250 | test_loss: 0.2267 | test_acc: 0.9531
Epoch: 8 | train_loss: 0.2401 | train_acc: 0.9294 | test_loss: 0.2077 | test_acc: 0.9453
Epoch: 9 | train_loss: 0.2231 | train_acc: 0.9368 | test_loss: 0.1939 | test_acc: 0.9453
Epoch: 10 | train_loss: 0.2100 | train_acc: 0.9352 | test_loss: 0.1812 | test_acc: 0.9621

[1mTesting hyperparameters: {'learning_rate': 0.0001, 'batch_size': 64, 'optimizer': 'SGD', 'weight_decay':

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.7081 | train_acc: 0.4707 | test_loss: 0.6266 | test_acc: 0.6004
Epoch: 2 | train_loss: 0.5784 | train_acc: 0.8524 | test_loss: 0.5167 | test_acc: 0.8996
Epoch: 3 | train_loss: 0.4920 | train_acc: 0.8930 | test_loss: 0.4436 | test_acc: 0.9029
Epoch: 4 | train_loss: 0.4340 | train_acc: 0.9022 | test_loss: 0.3935 | test_acc: 0.9241
Epoch: 5 | train_loss: 0.3928 | train_acc: 0.9095 | test_loss: 0.3567 | test_acc: 0.9252
Epoch: 6 | train_loss: 0.3619 | train_acc: 0.9104 | test_loss: 0.3279 | test_acc: 0.9275
Epoch: 7 | train_loss: 0.3369 | train_acc: 0.9133 | test_loss: 0.3061 | test_acc: 0.9286
Epoch: 8 | train_loss: 0.3188 | train_acc: 0.9116 | test_loss: 0.2874 | test_acc: 0.9286
Epoch: 9 | train_loss: 0.3020 | train_acc: 0.9142 | test_loss: 0.2732 | test_acc: 0.9286
Epoch: 10 | train_loss: 0.2888 | train_acc: 0.9173 | test_loss: 0.2598 | test_acc: 0.9342

[1mTesting hyperparameters: {'learning_rate': 0.0001, 'batch_size': 64, 'optimizer': 'SGD', 'weight_decay':

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.7081 | train_acc: 0.4707 | test_loss: 0.6266 | test_acc: 0.6004
Epoch: 2 | train_loss: 0.5784 | train_acc: 0.8524 | test_loss: 0.5167 | test_acc: 0.8996
Epoch: 3 | train_loss: 0.4920 | train_acc: 0.8930 | test_loss: 0.4436 | test_acc: 0.9029
Epoch: 4 | train_loss: 0.4340 | train_acc: 0.9022 | test_loss: 0.3935 | test_acc: 0.9241
Epoch: 5 | train_loss: 0.3928 | train_acc: 0.9095 | test_loss: 0.3567 | test_acc: 0.9252
Epoch: 6 | train_loss: 0.3619 | train_acc: 0.9104 | test_loss: 0.3279 | test_acc: 0.9275
Epoch: 7 | train_loss: 0.3369 | train_acc: 0.9133 | test_loss: 0.3061 | test_acc: 0.9286
Epoch: 8 | train_loss: 0.3188 | train_acc: 0.9116 | test_loss: 0.2874 | test_acc: 0.9286
Epoch: 9 | train_loss: 0.3021 | train_acc: 0.9142 | test_loss: 0.2732 | test_acc: 0.9286
Epoch: 10 | train_loss: 0.2888 | train_acc: 0.9173 | test_loss: 0.2598 | test_acc: 0.9342

[1mBest hyperparameters: {'learning_rate': 0.001, 'batch_size': 64, 'optimizer': 'Adam', 'weight_decay': 0}

  0%|          | 0/1 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 148.00 MiB. GPU 0 has a total capacity of 4.00 GiB of which 0 bytes is free. Of the allocated memory 4.30 GiB is allocated by PyTorch, and 178.29 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [3]:

# Evaluate on test set
test_results = engine.train(
    model=best_model,
    train_dataloader=train_dataloader,  # Not used but required by function
    test_dataloader=test_dataloader,
    optimizer=optimizer,  # Not used in evaluation
    loss_fn=nn.CrossEntropyLoss(),
    epochs=1,  # Only evaluation
    device=device
)
print(f"\n\033[1mFinal test accuracy: {max(test_results['test_acc']):.4f}\033[0m")

  0%|          | 0/1 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 38.00 MiB. GPU 0 has a total capacity of 4.00 GiB of which 0 bytes is free. Of the allocated memory 4.33 GiB is allocated by PyTorch, and 141.54 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)