In [1]:
!pip install wandb



In [2]:
import wandb

In [3]:
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mnipkha21[0m ([33mnipkha21-free-university-of-tbilisi-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
import sys
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import wandb
from tqdm import tqdm
import time
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [7]:
from data_utils import load_split_data
from models import SimpleCNN, get_model
from training_utils import ModelTrainer, get_optimizer, get_scheduler
from evaluation import ModelEvaluator

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
class CNNFacialExpressionDataset(Dataset):
    def __init__(self, data, transform=None, augment=False, is_test=False):
        self.data = data
        self.transform = transform
        self.augment = augment
        self.is_test = is_test

        if transform is None:
            if augment and not is_test:
                self.transform = self._get_augmentation_transform()
            else:
                self.transform = self._get_basic_transform()

    def _get_basic_transform(self):
        return A.Compose([
            A.Normalize(mean=[0.485], std=[0.229]),
            ToTensorV2()
        ])

    def _get_augmentation_transform(self):
        return A.Compose([
            A.HorizontalFlip(p=0.5),
            A.Rotate(limit=15, p=0.3, border_mode=cv2.BORDER_CONSTANT, value=0),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
            A.Blur(blur_limit=3, p=0.1),
            A.CoarseDropout(max_holes=1, max_height=8, max_width=8,
                           min_holes=1, min_height=4, min_width=4, p=0.2),
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=10, p=0.3),
            A.Normalize(mean=[0.485], std=[0.229]),
            ToTensorV2()
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        pixels = self.data.iloc[idx]['pixels']
        image = np.array([int(pixel) for pixel in pixels.split()]).reshape(48, 48)

        image = np.stack([image] * 3, axis=-1).astype(np.uint8)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        else:
            image = torch.FloatTensor(image).unsqueeze(0) / 255.0

        if self.is_test:
            return image
        else:
            emotion = int(self.data.iloc[idx]['emotion'])
            return image, emotion

In [10]:
class ImprovedSimpleCNN(nn.Module):
    def __init__(self, num_classes=7, dropout_rate=0.3, use_batch_norm=True):
        super(ImprovedSimpleCNN, self).__init__()

        self.use_batch_norm = use_batch_norm

        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32) if use_batch_norm else nn.Identity()

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64) if use_batch_norm else nn.Identity()

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128) if use_batch_norm else nn.Identity()

        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256) if use_batch_norm else nn.Identity()

        self.pool = nn.MaxPool2d(2, 2)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((3, 3))

        self.fc1 = nn.Linear(256 * 3 * 3, 512)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(512, 128)
        self.dropout2 = nn.Dropout(dropout_rate / 2)
        self.fc3 = nn.Linear(128, num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        if x.shape[1] == 3:
            x = x[:, 0:1, :, :]

        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.pool(F.relu(self.bn4(self.conv4(x))))

        x = self.adaptive_pool(x)

        x = x.view(-1, 256 * 3 * 3)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)

        return x

In [11]:
class CNNTrainer:
    def __init__(self, model, train_loader, val_loader, criterion, optimizer,
                 scheduler=None, device='cuda', experiment_name='experiment', run_name='run'):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device

        wandb.init(
            project="facial-expression-recognition",
            group=experiment_name,
            name=run_name,
            config={
                "data_split_method": "predefined_stratified",
                "train_samples": len(train_loader.dataset),
                "val_samples": len(val_loader.dataset),
                "split_random_state": 42
            },
            reinit=True
        )
        wandb.watch(self.model, log='all', log_freq=100)

        self.history = {
            'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [],
            'learning_rates': [], 'epoch_times': []
        }
        self.best_val_acc = 0.0
        self.best_model_state = None

    def train_epoch(self):
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc="Training")
        for batch_idx, (data, targets) in enumerate(pbar):
            data, targets = data.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(data)
            loss = self.criterion(outputs, targets)

            loss.backward()

            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

            pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })

        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = 100. * correct / total

        return epoch_loss, epoch_acc

    def validate_epoch(self):
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        all_predictions = []
        all_targets = []

        with torch.no_grad():
            for data, targets in tqdm(self.val_loader, desc="Validation"):
                data, targets = data.to(self.device), targets.to(self.device)

                outputs = self.model(data)
                loss = self.criterion(outputs, targets)

                running_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

                all_predictions.extend(predicted.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        epoch_loss = running_loss / len(self.val_loader)
        epoch_acc = 100. * correct / total

        return epoch_loss, epoch_acc, np.array(all_predictions), np.array(all_targets)

    def train(self, epochs, early_stopping_patience=10):
        print(f"Starting CNN training for {epochs} epochs...")

        patience_counter = 0

        for epoch in range(epochs):
            epoch_start = time.time()

            train_loss, train_acc = self.train_epoch()

            val_loss, val_acc, val_preds, val_targets = self.validate_epoch()

            current_lr = self.optimizer.param_groups[0]['lr']
            if self.scheduler:
                if isinstance(self.scheduler, ReduceLROnPlateau):
                    self.scheduler.step(val_loss)
                else:
                    self.scheduler.step()

            if val_acc > self.best_val_acc:
                self.best_val_acc = val_acc
                self.best_model_state = self.model.state_dict().copy()
                patience_counter = 0
            else:
                patience_counter += 1

            epoch_time = time.time() - epoch_start

            self.history['train_loss'].append(train_loss)
            self.history['train_acc'].append(train_acc)
            self.history['val_loss'].append(val_loss)
            self.history['val_acc'].append(val_acc)
            self.history['learning_rates'].append(current_lr)
            self.history['epoch_times'].append(epoch_time)

            log_dict = {
                'epoch': epoch + 1,
                'train_loss': train_loss,
                'train_accuracy': train_acc,
                'val_loss': val_loss,
                'val_accuracy': val_acc,
                'learning_rate': current_lr,
                'epoch_time': epoch_time,
                'best_val_accuracy': self.best_val_acc
            }

            if (epoch + 1) % 15 == 0:
                emotion_map = {0: 'Angry', 1: 'Disgust', 2: 'Fear', 3: 'Happy', 4: 'Sad', 5: 'Surprise', 6: 'Neutral'}
                log_dict['confusion_matrix'] = wandb.plot.confusion_matrix(
                    probs=None, y_true=val_targets, preds=val_preds,
                    class_names=list(emotion_map.values())
                )

            wandb.log(log_dict)

            print(f"Epoch {epoch+1}/{epochs}:")
            print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
            print(f"  LR: {current_lr:.6f}, Time: {epoch_time:.2f}s")
            print("-" * 50)

            if patience_counter >= early_stopping_patience:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

        if self.best_model_state:
            self.model.load_state_dict(self.best_model_state)
            print(f"Loaded best model with validation accuracy: {self.best_val_acc:.2f}%")

        return self.history

In [12]:
wandb.login()

print("=== LOADING PRE-SPLIT DATA ===")

train_df, val_df, test_df = load_split_data('drive/MyDrive/data')

emotion_map = {0: 'Angry', 1: 'Disgust', 2: 'Fear', 3: 'Happy', 4: 'Sad', 5: 'Surprise', 6: 'Neutral'}

=== LOADING PRE-SPLIT DATA ===


In [None]:
print("=== DATA AUGMENTATION EXPERIMENTS ===")

augmentation_strategies = {
    'no_augmentation': {
        'augment': False,
        'description': 'No data augmentation'
    },
}

In [16]:
augmentation_results = {}

for aug_name, aug_config in augmentation_strategies.items():
    print(f"\nTesting augmentation strategy: {aug_name}")
    print(f"Description: {aug_config['description']}")

    if 'transform' in aug_config:
        train_dataset = CNNFacialExpressionDataset(
            train_df, transform=aug_config['transform'], augment=False
        )
        val_dataset = CNNFacialExpressionDataset(
            val_df, transform=A.Compose([
                A.Normalize(mean=[0.485], std=[0.229]),
                ToTensorV2()
            ]), augment=False
        )
    else:
        train_dataset = CNNFacialExpressionDataset(
            train_df, augment=aug_config['augment']
        )
        val_dataset = CNNFacialExpressionDataset(
            val_df, augment=False
        )

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)

    model = SimpleCNN(num_classes=7, dropout_rate=0.5)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    trainer = CNNTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        experiment_name="Simple_CNN_Training",
        run_name=f"Augmentation_{aug_name}"
    )

    history = trainer.train(epochs=10, early_stopping_patience=5)

    augmentation_results[aug_name] = {
        'best_val_acc': trainer.best_val_acc,
        'final_train_acc': history['train_acc'][-1],
        'final_val_acc': history['val_acc'][-1],
        'overfitting_score': history['train_acc'][-1] - history['val_acc'][-1]
    }

    wandb.finish()

    print(f"Best validation accuracy: {trainer.best_val_acc:.2f}%")

print("\n=== AUGMENTATION ANALYSIS ===")
best_augmentation = max(augmentation_results.keys(),
                       key=lambda x: augmentation_results[x]['best_val_acc'])
print(f"Best augmentation strategy: {best_augmentation}")

for aug_name, results in augmentation_results.items():
    print(f"{aug_name}: {results['best_val_acc']:.2f}% "
          f"(overfitting: {results['overfitting_score']:.2f}%)")


Testing augmentation strategy: no_augmentation
Description: No data augmentation


Starting CNN training for 10 epochs...


Training: 100%|██████████| 314/314 [02:12<00:00,  2.36it/s, Loss=1.5715, Acc=37.74%]
Validation: 100%|██████████| 68/68 [00:12<00:00,  5.34it/s]


Epoch 1/10:
  Train Loss: 1.5900, Train Acc: 37.74%
  Val Loss: 1.4204, Val Acc: 44.79%
  LR: 0.001000, Time: 145.54s
--------------------------------------------------


Training: 100%|██████████| 314/314 [02:09<00:00,  2.42it/s, Loss=1.2098, Acc=49.32%]
Validation: 100%|██████████| 68/68 [00:13<00:00,  5.08it/s]


Epoch 2/10:
  Train Loss: 1.3352, Train Acc: 49.32%
  Val Loss: 1.2719, Val Acc: 51.36%
  LR: 0.001000, Time: 143.25s
--------------------------------------------------


Training: 100%|██████████| 314/314 [02:18<00:00,  2.27it/s, Loss=1.3596, Acc=54.37%]
Validation: 100%|██████████| 68/68 [00:13<00:00,  5.09it/s]


Epoch 3/10:
  Train Loss: 1.2092, Train Acc: 54.37%
  Val Loss: 1.2086, Val Acc: 53.82%
  LR: 0.001000, Time: 151.95s
--------------------------------------------------


Training: 100%|██████████| 314/314 [02:17<00:00,  2.28it/s, Loss=1.0403, Acc=58.55%]
Validation: 100%|██████████| 68/68 [00:11<00:00,  5.82it/s]


Epoch 4/10:
  Train Loss: 1.1012, Train Acc: 58.55%
  Val Loss: 1.1663, Val Acc: 55.91%
  LR: 0.001000, Time: 149.52s
--------------------------------------------------


Training: 100%|██████████| 314/314 [02:17<00:00,  2.28it/s, Loss=0.9871, Acc=63.09%]
Validation: 100%|██████████| 68/68 [00:12<00:00,  5.28it/s]


Epoch 5/10:
  Train Loss: 0.9873, Train Acc: 63.09%
  Val Loss: 1.1854, Val Acc: 56.37%
  LR: 0.001000, Time: 150.34s
--------------------------------------------------


Training: 100%|██████████| 314/314 [02:17<00:00,  2.29it/s, Loss=0.9181, Acc=67.60%]
Validation: 100%|██████████| 68/68 [00:13<00:00,  5.18it/s]


Epoch 6/10:
  Train Loss: 0.8794, Train Acc: 67.60%
  Val Loss: 1.1838, Val Acc: 56.74%
  LR: 0.001000, Time: 150.23s
--------------------------------------------------


Training: 100%|██████████| 314/314 [02:17<00:00,  2.28it/s, Loss=0.6850, Acc=71.87%]
Validation: 100%|██████████| 68/68 [00:13<00:00,  5.12it/s]


Epoch 7/10:
  Train Loss: 0.7586, Train Acc: 71.87%
  Val Loss: 1.2657, Val Acc: 56.68%
  LR: 0.001000, Time: 150.87s
--------------------------------------------------


Training: 100%|██████████| 314/314 [02:16<00:00,  2.30it/s, Loss=0.8289, Acc=77.17%]
Validation: 100%|██████████| 68/68 [00:13<00:00,  5.16it/s]


Epoch 8/10:
  Train Loss: 0.6334, Train Acc: 77.17%
  Val Loss: 1.3005, Val Acc: 56.40%
  LR: 0.001000, Time: 149.76s
--------------------------------------------------


Training: 100%|██████████| 314/314 [02:16<00:00,  2.29it/s, Loss=0.4448, Acc=80.74%]
Validation: 100%|██████████| 68/68 [00:11<00:00,  5.85it/s]


Epoch 9/10:
  Train Loss: 0.5293, Train Acc: 80.74%
  Val Loss: 1.3803, Val Acc: 57.21%
  LR: 0.001000, Time: 148.49s
--------------------------------------------------


Training: 100%|██████████| 314/314 [02:16<00:00,  2.30it/s, Loss=0.2924, Acc=84.11%]
Validation: 100%|██████████| 68/68 [00:13<00:00,  5.16it/s]

Epoch 10/10:
  Train Loss: 0.4374, Train Acc: 84.11%
  Val Loss: 1.4438, Val Acc: 57.05%
  LR: 0.001000, Time: 149.70s
--------------------------------------------------
Loaded best model with validation accuracy: 57.21%





0,1
best_val_accuracy,▁▅▆▇██████
epoch,▁▂▃▃▄▅▆▆▇█
epoch_time,▃▁█▆▇▇▇▆▅▆
learning_rate,▁▁▁▁▁▁▁▁▁▁
train_accuracy,▁▃▄▄▅▆▆▇▇█
train_loss,█▆▆▅▄▄▃▂▂▁
val_accuracy,▁▅▆▇██████
val_loss,▇▄▂▁▁▁▄▄▆█

0,1
best_val_accuracy,57.20919
epoch,10.0
epoch_time,149.69551
learning_rate,0.001
train_accuracy,84.11048
train_loss,0.43737
val_accuracy,57.04667
val_loss,1.44376


Best validation accuracy: 57.21%

=== AUGMENTATION ANALYSIS ===
Best augmentation strategy: no_augmentation
no_augmentation: 57.21% (overfitting: 27.06%)


In [18]:
augmentation_results = {}

for aug_name, aug_config in augmentation_strategies.items():
    print(f"\nTesting augmentation strategy: {aug_name}")
    print(f"Description: {aug_config['description']}")

    if 'transform' in aug_config:
        train_dataset = CNNFacialExpressionDataset(
            train_df, transform=aug_config['transform'], augment=False
        )
        val_dataset = CNNFacialExpressionDataset(
            val_df, transform=A.Compose([
                A.Normalize(mean=[0.485], std=[0.229]),
                ToTensorV2()
            ]), augment=False
        )
    else:
        train_dataset = CNNFacialExpressionDataset(
            train_df, augment=aug_config['augment']
        )
        val_dataset = CNNFacialExpressionDataset(
            val_df, augment=False
        )

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)

    model = SimpleCNN(num_classes=7, dropout_rate=0.5)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    trainer = CNNTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        experiment_name="Simple_CNN_Training",
        run_name=f"Augmentation_{aug_name}"
    )

    history = trainer.train(epochs=10, early_stopping_patience=5)

    augmentation_results[aug_name] = {
        'best_val_acc': trainer.best_val_acc,
        'final_train_acc': history['train_acc'][-1],
        'final_val_acc': history['val_acc'][-1],
        'overfitting_score': history['train_acc'][-1] - history['val_acc'][-1]
    }

    wandb.finish()

    print(f"Best validation accuracy: {trainer.best_val_acc:.2f}%")

print("\n=== AUGMENTATION ANALYSIS ===")
best_augmentation = max(augmentation_results.keys(),
                       key=lambda x: augmentation_results[x]['best_val_acc'])
print(f"Best augmentation strategy: {best_augmentation}")

for aug_name, results in augmentation_results.items():
    print(f"{aug_name}: {results['best_val_acc']:.2f}% "
          f"(overfitting: {results['overfitting_score']:.2f}%)")


Testing augmentation strategy: advanced_augmentation
Description: Advanced: Flip + Rotation + Brightness + Noise


Starting CNN training for 10 epochs...


Training: 100%|██████████| 314/314 [02:11<00:00,  2.38it/s, Loss=1.6199, Acc=31.11%]
Validation: 100%|██████████| 68/68 [00:11<00:00,  5.78it/s]


Epoch 1/10:
  Train Loss: 1.7157, Train Acc: 31.11%
  Val Loss: 1.5453, Val Acc: 40.86%
  LR: 0.001000, Time: 143.68s
--------------------------------------------------


Training:  84%|████████▍ | 263/314 [01:53<00:22,  2.31it/s, Loss=1.4774, Acc=39.37%]


KeyboardInterrupt: 

In [19]:
best_augmentation = 'no_augmentation'  # es iyo itogshi sauketeso agar davrune bolomde :))

In [20]:
print("=== CNN ARCHITECTURE COMPARISON ===")

architectures = {
    'simple_cnn_low_dropout': {
        'model': SimpleCNN(num_classes=7, dropout_rate=0.3),
        'description': 'Basic CNN with lower dropout'
    },
    'improved_cnn': {
        'model': ImprovedSimpleCNN(num_classes=7, dropout_rate=0.3, use_batch_norm=True),
        'description': 'Improved CNN with BatchNorm'
    },
}

architecture_results = {}

best_aug_config = augmentation_strategies[best_augmentation]

for arch_name, arch_config in architectures.items():
    print(f"\nTesting architecture: {arch_name}")
    print(f"Description: {arch_config['description']}")

    if 'transform' in best_aug_config:
        train_dataset = CNNFacialExpressionDataset(
            train_df, transform=best_aug_config['transform'], augment=False
        )
        val_dataset = CNNFacialExpressionDataset(
            val_df, transform=A.Compose([
                A.Normalize(mean=[0.485], std=[0.229]),
                ToTensorV2()
            ]), augment=False
        )
    else:
        train_dataset = CNNFacialExpressionDataset(
            train_df, augment=best_aug_config['augment']
        )
        val_dataset = CNNFacialExpressionDataset(
            val_df, augment=False
        )

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)

    model = arch_config['model']

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    trainer = CNNTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        experiment_name="Simple_CNN_Training",
        run_name=f"Architecture_{arch_name}"
    )

    history = trainer.train(epochs=10, early_stopping_patience=7)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    architecture_results[arch_name] = {
        'best_val_acc': trainer.best_val_acc,
        'final_train_acc': history['train_acc'][-1],
        'final_val_acc': history['val_acc'][-1],
        'total_params': total_params,
        'trainable_params': trainable_params,
        'avg_epoch_time': np.mean(history['epoch_times'])
    }

    wandb.finish()

    print(f"Best validation accuracy: {trainer.best_val_acc:.2f}%")
    print(f"Total parameters: {total_params:,}")

print("\n=== ARCHITECTURE ANALYSIS ===")
best_architecture = max(architecture_results.keys(),
                       key=lambda x: architecture_results[x]['best_val_acc'])
print(f"Best architecture: {best_architecture}")

for arch_name, results in architecture_results.items():
    print(f"{arch_name}: {results['best_val_acc']:.2f}% "
          f"({results['total_params']:,} params, "
          f"{results['avg_epoch_time']:.2f}s/epoch)")

=== CNN ARCHITECTURE COMPARISON ===


KeyError: 'no_augmentation'