<a href="https://colab.research.google.com/github/stefanagheorghita/CARN/blob/main/Homework%203.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import json

def _fixed(value):
    """Sweep parameter held constant across every run of the grid."""
    return {"value": value}


def _grid(*values):
    """Sweep parameter whose listed values are enumerated by the grid search."""
    return {"values": list(values)}


# Each entry bundles an optimizer name with its own learning rate, so the grid
# explores these (optimizer, lr) pairs rather than every optimizer x lr combination.
_optimizer_choices = [
    {"optimizer": "adamw", "learning_rate": 0.0005},
    {"optimizer": "adamw", "learning_rate": 0.001},
    {"optimizer": "sgd", "learning_rate": 0.01},
]

# W&B grid sweep definition, serialized to sweep_config.json for the launcher cell.
# NOTE(review): "patience" does not appear to be read by sweep_train (it looks up
# 'scheduler_patience'); t_0 / t_mult only matter for cosineannealingwarmrestarts.
sweep_config = {
    "method": "grid",
    "metric": {"name": "val_accuracy", "goal": "maximize"},
    "parameters": {
        "dataset": _fixed("CIFAR100"),
        "data_path": _fixed("./data"),
        "model_name": _grid("resnet18_resize"),
        "num_classes": _fixed(100),
        "batch_size": _grid(64),
        "num_epochs": _grid(100),
        "optimizer_config": {"values": _optimizer_choices},
        "weight_decay": _fixed(0.0005),
        "momentum": _fixed(0.9),
        "nesterov": _fixed(True),
        "patience": _fixed(3),
        "stop_mode": _fixed("max"),
        "min_delta": _fixed(0.0001),
        "scheduler": _grid("cosineannealinglr"),
        "t_max": _grid(100),
        "eta_min": _fixed(0.00001),
        "augmentation_scheme": _grid("randaugment", "combined"),
        "use_cutmix": _grid(True),
        "use_mixup": _grid(True),
        "alpha": _fixed(1.0),
        "t_0": _fixed(10),
        "t_mult": _fixed(2),
        "warmup": _fixed(5),
        "patience_early_stopping": _fixed(10),
        "pretrained": _fixed(True),
    },
}

with open("sweep_config.json", "w") as config_file:
    json.dump(sweep_config, config_file)


In [None]:

import json

import yaml
from torch import nn, Tensor
import torch.nn.functional as F
import random

import torch.backends.cudnn

# Enable cuDNN autotuning, which helps when input shapes stay constant between
# batches. A bare module-level `benchmark = True` has no effect by itself, so the
# flag is forwarded to torch.backends.cudnn explicitly.
benchmark = True
torch.backends.cudnn.benchmark = benchmark
class PreActBlock(nn.Module):
    """Pre-activation residual block: BN -> ReLU -> conv3x3 -> BN -> ReLU -> conv3x3.

    A 1x1 projection ``shortcut`` conv is registered only when the block changes the
    stride or the channel count; otherwise the raw input is added back unchanged.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        pre = F.relu(self.bn1(x), inplace=True)
        # The projection consumes the pre-activated tensor; the identity path uses raw x.
        identity = self.shortcut(pre) if hasattr(self, "shortcut") else x
        hidden = F.relu(self.bn2(self.conv1(pre)), inplace=True)
        result = self.conv2(hidden)
        result += identity
        return result


class PreActBottleneck(nn.Module):
    """Pre-activation bottleneck: (BN-ReLU-conv1x1) -> (BN-ReLU-conv3x3) -> (BN-ReLU-conv1x1).

    The last 1x1 conv widens the output to ``expansion * planes`` channels. A 1x1
    projection ``shortcut`` is registered only when stride or channel count changes.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)

        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        pre = F.relu(self.bn1(x), inplace=True)
        # Projection (when present) sees the pre-activated input; identity uses raw x.
        identity = x if not hasattr(self, "shortcut") else self.shortcut(pre)
        h = self.conv1(pre)
        h = self.conv2(F.relu(self.bn2(h), inplace=True))
        h = self.conv3(F.relu(self.bn3(h), inplace=True))
        h += identity
        return h


class PreActResNet_C10(nn.Module):
    """Pre-activation ResNet for CIFAR-sized images.

    The stem is a single 3x3 stride-1 conv with no max-pool. It is followed by four
    stages of ``block`` with 64/128/256/512 base planes; stages 2-4 use stride 2.
    Features are global-average-pooled and fed to a linear classifier.

    Args:
        block: residual block class exposing an ``expansion`` attribute and an
            ``(in_planes, planes, stride)`` constructor.
        num_blocks: number of blocks in each of the four stages.
        num_classes: number of output logits.

    NOTE(review): layer4's output is pooled directly, with no final BN+ReLU (the
    pre-activation design usually adds one). Confirm this is intended; adding it
    would change the state_dict.
    """

    def __init__(self, block, num_blocks, num_classes):
        super(PreActResNet_C10, self).__init__()
        # Running channel count, updated by _make_layer as stages are stacked.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one applies ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. For 32x32 inputs the final map is 4x4, so this equals
        # avg_pool2d(out, 4), but it also keeps the classifier input at 512 * expansion
        # features for other input resolutions.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return out


def PreActResNet18_C10(num_classes):
    """Build the 18-layer pre-activation ResNet: four stages of two PreActBlocks each."""
    stage_depths = [2] * 4
    return PreActResNet_C10(block=PreActBlock, num_blocks=stage_depths, num_classes=num_classes)


import torch
from torch import nn


class MLP(torch.nn.Module):
    """One-hidden-layer perceptron: flatten -> 784 -> 256 (ReLU) -> ``num_classes``.

    Expects inputs with 784 features once flattened (e.g. 1x28x28 MNIST images).
    """

    def __init__(self, num_classes):
        super().__init__()
        hidden_width = 256
        self.layers = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(28 * 28, hidden_width),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_width, num_classes),
        )

    def forward(self, x):
        return self.layers(x)


class CustomMLP(nn.Module):
    """Fully connected network with configurable hidden widths and activation.

    Args:
        input_size: number of input features.
        hidden_layers: iterable of hidden-layer widths, in order (may be empty).
        output_size: number of output features.
        activation_fn: activation class, instantiated once after each hidden Linear.
    """

    def __init__(self, input_size, hidden_layers, output_size, activation_fn=nn.ReLU):
        super().__init__()
        widths = [input_size, *hidden_layers]
        modules = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), activation_fn()]
        # Output projection has no activation: the network emits raw scores.
        modules.append(nn.Linear(widths[-1], output_size))
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        return self.network(x)


# Standalone example network: 784 -> 256 -> 128 -> 10 with sigmoid activations.
# NOTE(review): this module-level `model` does not appear to be used by the sweep,
# which builds its own model through get_model — confirm before relying on it.
model = CustomMLP(
    784,
    [256, 128],
    10,
    activation_fn=nn.Sigmoid,
)

from torch import nn



class LeNet(nn.Module):
    """LeNet-5 style CNN sized for 28x28 inputs.

    Spatial sizes work out as follows:
    - conv1 (5x5, padding 2) keeps 28x28, then max-pool gives 14x14.
    - conv2 (5x5, no padding) gives 10x10, then max-pool gives 5x5.
    - That yields the 16*5*5 features fc1 expects.

    Args:
        num_classes: number of output logits.
        in_channels: number of input image channels (get_model passes 1 for MNIST).
    """

    def __init__(self, num_classes=10, in_channels=1):
        super(LeNet, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)

        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Keep the batch dimension. view(-1, 400) silently regrouped samples for
        # non-28x28 inputs; flatten makes fc1 reject them with a shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


from typing import Literal, cast

import torch
import torchvision
from torch import nn
from torchvision.datasets import MNIST, CIFAR10, CIFAR100
from torchvision.transforms.v2 import CutMix, MixUp

from torchvision.transforms import v2
from timm import create_model
import os
import pickle
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import AutoAugment, AutoAugmentPolicy

import torch.optim.lr_scheduler as lr_scheduler





def get_device():
    """Return ``"cuda"`` when torch reports an available CUDA device, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def cache_dataset(dataset_class, data_dir, cache_dir='./cache', train=True):
    """Load a dataset split from a pickle cache, building and caching it on a miss.

    Args:
        dataset_class: constructor called as
            ``dataset_class(root=data_dir, train=train, download=True)``. Its
            ``__name__`` names the cache file.
        data_dir: root directory passed to ``dataset_class`` on a cache miss.
        cache_dir: directory holding ``<ClassName>_<train|test>.pkl``; created if missing.
        train: select the training split (``True``) or the test split.

    Returns:
        The unpickled or freshly constructed dataset object.

    Note:
        The cache key is only the class name and split, so ``data_dir`` is ignored
        once a cache file exists. The cache is read with ``pickle``, which can run
        arbitrary code on load: only point ``cache_dir`` at files this function wrote.
    """
    os.makedirs(cache_dir, exist_ok=True)
    subset = 'train' if train else 'test'
    cache_path = os.path.join(cache_dir, f'{dataset_class.__name__}_{subset}.pkl')

    if os.path.exists(cache_path):
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    data = dataset_class(root=data_dir, train=train, download=True)
    # Dump to a temporary file and atomically move it into place. A failed or
    # interrupted dump then never leaves a truncated .pkl behind to break later loads.
    tmp_path = f'{cache_path}.{os.getpid()}.tmp'
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(data, f)
        os.replace(tmp_path, cache_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return data


def get_data_augmentation(scheme="basic", dataset="CIFAR"):
    """Return ``(train_transform, test_transform)`` for a dataset family and augmentation scheme.

    Every pipeline ends with ``ToImage -> ToDtype(float32, scale=True) -> Normalize``.
    The test pipeline applies no random augmentation. It reuses the training pipeline's
    normalization statistics, plus the 64x64 resize for ``combined_resize``, so
    evaluation inputs match what the network was trained on.

    Args:
        scheme: augmentation scheme name.
            MNIST supports ``basic``, ``random_flip``, ``random_crop_flip``,
            ``randaugment`` and ``combined``.
            CIFAR supports those plus ``autoaugment``, ``combined2``,
            ``combined_resize`` and ``combined_resize2``.
        dataset: ``"MNIST"`` or ``"CIFAR"``.

    Raises:
        ValueError: for an unknown dataset, or a scheme not supported for that dataset.
    """

    def tensor_steps():
        # Fresh instances for each pipeline so train and test never share transform objects.
        return [v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]

    # Deterministic resize applied to evaluation images. Only schemes whose training
    # pipeline changes the spatial resolution set this.
    test_resize = []

    if dataset == "MNIST":
        mean, std = (0.5,), (0.5,)
        if scheme == "basic":
            augment = []
        elif scheme == "random_flip":
            augment = [v2.RandomHorizontalFlip()]
        elif scheme == "random_crop_flip":
            augment = [v2.RandomResizedCrop(28, scale=(0.8, 1.0)), v2.RandomHorizontalFlip()]
        elif scheme == "randaugment":
            augment = [v2.RandAugment()]
        elif scheme == "combined":
            augment = [v2.RandomResizedCrop(28, scale=(0.8, 1.0)), v2.RandomRotation(15),
                       v2.RandomHorizontalFlip(), v2.RandAugment()]
        else:
            raise ValueError(f"Augmentation scheme '{scheme}' not supported for MNIST.")

    elif dataset == "CIFAR":
        mean, std = (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)
        if scheme == "basic":
            augment = []
        elif scheme == "random_flip":
            augment = [v2.RandomHorizontalFlip()]
        elif scheme == "random_crop_flip":
            augment = [v2.RandomCrop(32, padding=4), v2.RandomHorizontalFlip()]
        elif scheme == "randaugment":
            augment = [v2.RandAugment()]
        elif scheme == "autoaugment":
            augment = [AutoAugment(policy=AutoAugmentPolicy.CIFAR10)]
        elif scheme == "combined":
            augment = [v2.RandomResizedCrop(32, scale=(0.8, 1.0)), v2.RandomHorizontalFlip(), v2.RandomRotation(15),
                       v2.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1), v2.RandAugment()]
        elif scheme == "combined2":
            # NOTE(review): this applies two AutoAugment passes (CIFAR10 policy first, then
            # v2.AutoAugment's default policy). Its single-value Normalize appears to
            # broadcast over the RGB channels. Kept as-is; confirm it is intended.
            augment = [AutoAugment(policy=AutoAugmentPolicy.CIFAR10), v2.RandomCrop(32, padding=4),
                       v2.RandomHorizontalFlip(), v2.ColorJitter(brightness=0.2, contrast=0.2),
                       v2.RandomRotation(15), v2.AutoAugment()]
            mean, std = (0.5,), (0.5,)
        elif scheme == "combined_resize":
            augment = [v2.Resize((64, 64)), v2.RandomResizedCrop(64, scale=(0.8, 1.0)), v2.RandomRotation(15),
                       v2.RandomHorizontalFlip(), v2.RandAugment()]
            mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
            test_resize = [v2.Resize((64, 64))]
        elif scheme == "combined_resize2":
            augment = [v2.RandomRotation(10), v2.RandomResizedCrop(32, scale=(0.9, 1.1)), v2.RandomHorizontalFlip(),
                       v2.RandomAffine(degrees=0, shear=10), v2.RandomCrop(32, padding=3)]
            mean, std = (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)
        else:
            raise ValueError(f"Augmentation scheme '{scheme}' not supported for CIFAR.")

    else:
        raise ValueError(f"Dataset '{dataset}' not supported.")

    train_transform = v2.Compose(augment + tensor_steps() + [v2.Normalize(mean, std)])
    test_transform = v2.Compose(test_resize + tensor_steps() + [v2.Normalize(mean, std)])
    return train_transform, test_transform


class CutMixMixUp:
    """Batch-level augmentation that applies either CutMix or MixUp to each batch.

    Args:
        num_classes: number of classes used to one-hot encode the labels.
        alpha_cutmix: ``alpha`` passed to ``CutMix``.
        alpha_mixup: ``alpha`` passed to ``MixUp``.

    Call it with an ``(images, labels)`` batch whose labels are integer class indices.
    The returned labels are soft (mixed one-hot) targets.
    """

    def __init__(self, num_classes, alpha_cutmix=1.0, alpha_mixup=1.0):
        self.cutmix = CutMix(num_classes=num_classes, alpha=alpha_cutmix)
        self.mixup = MixUp(num_classes=num_classes, alpha=alpha_mixup)
        # Pick one of the two per call. Chaining them fails: CutMix emits 2-D soft labels,
        # and MixUp only accepts 1-D integer labels of shape (batch_size,).
        self.cutmix_or_mixup = v2.RandomChoice([self.cutmix, self.mixup])

    def __call__(self, batch):
        return self.cutmix_or_mixup(batch)


def load_data(dataset_class, data_dir, batch_size=64, cache_dir='./cache', scheme="basic", custom_transforms=None,
              shuffle=True, use_cutmix=False, use_mixup=False, alpha=1.0, num_workers=4, pin_memory=None):
    """Build train/test DataLoaders for MNIST, CIFAR10 or CIFAR100.

    Args:
        dataset_class: torchvision dataset class. MNIST, CIFAR10 and CIFAR100 get their
            transforms from get_data_augmentation unless ``custom_transforms`` is given.
        data_dir: download/root directory for the dataset.
        batch_size: batch size for both loaders.
        cache_dir: directory of the pickled dataset cache (see cache_dataset).
        scheme: augmentation scheme passed to get_data_augmentation.
        custom_transforms: optional ``(train_transform, test_transform)`` pair that
            overrides the built-in transforms.
        shuffle: shuffle the training loader; the test loader is never shuffled.
        use_cutmix, use_mixup, alpha: accepted for call-site compatibility but unused
            here; batch-level mixing is done in train_model.
        num_workers: worker processes per loader.
        pin_memory: pin host memory so host-to-GPU copies can run asynchronously.
            ``None`` enables it exactly when CUDA is available.

    Returns:
        ``(train_loader, test_loader)``.

    Raises:
        ValueError: if ``dataset_class`` is unknown and no ``custom_transforms`` are given.
    """
    if custom_transforms:
        train_transform, test_transform = custom_transforms
    elif dataset_class in [MNIST]:
        train_transform, test_transform = get_data_augmentation(scheme=scheme, dataset="MNIST")
    elif dataset_class in [CIFAR10, CIFAR100]:
        train_transform, test_transform = get_data_augmentation(scheme=scheme, dataset="CIFAR")
    else:
        raise ValueError("Unknown dataset, please specify valid transforms.")

    try:
        train_data = cache_dataset(dataset_class, data_dir, cache_dir, train=True)
        test_data = cache_dataset(dataset_class, data_dir, cache_dir, train=False)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        raise

    # The (possibly cached) dataset objects get this call's transforms attached.
    train_data.transform = train_transform
    test_data.transform = test_transform

    if pin_memory is None:
        # train_model copies batches with non_blocking=True, which only overlaps the copy
        # with compute when the source tensors are in pinned memory.
        pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=shuffle,
                              num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=pin_memory)

    return train_loader, test_loader


def get_model(dataset, model_name, num_classes, input_size=None, hidden_layers=None, pretrained = True):
    """Construct a classifier for the given dataset.

    Args:
        dataset: ``'CIFAR10'``, ``'CIFAR100'`` or ``'MNIST'``.
        model_name: the supported names depend on the dataset.
            CIFAR: ``resnet18``, ``resnet50`` (timm), their ``_resize`` variants,
            ``resnet18_cifar10`` or ``preactresnet18``.
            MNIST (case-insensitive): ``MLP``, ``LeNet`` or ``CustomMLP``.
        num_classes: number of output classes.
        input_size, hidden_layers: only used by ``CustomMLP``.
        pretrained: load timm pretrained weights for resnet18/resnet50. With a ``_resize``
            name it also prepends a bilinear upsample to 224x224; the wrapped model's
            state_dict keys then gain a ``1.`` prefix. ``resnet18_cifar10`` always uses
            ``pretrained=False``, and ``preactresnet18`` ignores this flag.

    Returns:
        An ``nn.Module``.

    Raises:
        ValueError: for an unsupported dataset or model name.
    """
    if dataset == 'CIFAR10' or dataset == 'CIFAR100':
        if model_name in ['resnet18', 'resnet18_resize']:
            model = create_model("resnet18", pretrained=pretrained, num_classes=num_classes)
        elif model_name in ['resnet50', 'resnet50_resize']:
            model = create_model("resnet50", pretrained=pretrained, num_classes=num_classes)
        elif model_name == 'resnet18_cifar10':
            model = create_model("hf_hub:edadaltocg/resnet18_cifar10", pretrained=False, num_classes=num_classes)
            # Replace the stem with a 3x3 stride-1 conv suited to 32x32 inputs.
            model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        elif model_name == 'preactresnet18':
            model = PreActResNet18_C10(num_classes=num_classes)
        else:
            raise ValueError(
                f"Model '{model_name}' is not supported for CIFAR.")
        if pretrained and model_name in ['resnet18_resize', 'resnet50_resize']:
            model = nn.Sequential(
                nn.Upsample(size=(224, 224), mode='bilinear', align_corners=False),
                model
            )

    elif dataset == 'MNIST':
        if model_name.upper() == "MLP":
            model = MLP(num_classes=num_classes)
        elif model_name.upper() == "LENET":
            model = LeNet(num_classes=num_classes, in_channels=1)
        elif model_name.upper() == "CUSTOMMLP":
            model = CustomMLP(input_size=input_size, hidden_layers=hidden_layers, output_size=num_classes)
        else:
            raise ValueError(
                f"Model '{model_name}' is not supported for MNIST. Choose 'MLP', 'LeNet' or 'CustomMLP'.")

    else:
        raise ValueError(f"Dataset '{dataset}' is not supported. Choose 'CIFAR10', 'CIFAR100', or 'MNIST'.")

    return model


import torch.optim as optim


def get_optimizer(optimizer_name, model_parameters, lr=0.001, momentum=0.9, weight_decay=0.0, nesterov=True):
    """Create a ``torch.optim`` optimizer from a case-insensitive name.

    Each name forwards only the hyperparameters listed; the others are ignored:
      - ``sgd``: lr
      - ``sgd_momentum``: lr, momentum
      - ``sgd_nesterov``: lr, momentum, nesterov
      - ``sgd_weight_decay``: lr, momentum, weight_decay
      - ``adam`` / ``adamw``: lr, weight_decay
      - ``rmsprop``: lr, momentum, weight_decay

    Raises:
        ValueError: if the name is not one of the above.
    """
    optimizer_name = optimizer_name.lower()
    builders = {
        'sgd': lambda: optim.SGD(model_parameters, lr=lr),
        'sgd_momentum': lambda: optim.SGD(model_parameters, lr=lr, momentum=momentum),
        'sgd_nesterov': lambda: optim.SGD(model_parameters, lr=lr, momentum=momentum, nesterov=nesterov),
        'sgd_weight_decay': lambda: optim.SGD(model_parameters, lr=lr, momentum=momentum,
                                              weight_decay=weight_decay),
        'adam': lambda: optim.Adam(model_parameters, lr=lr, weight_decay=weight_decay),
        'adamw': lambda: optim.AdamW(model_parameters, lr=lr, weight_decay=weight_decay),
        'rmsprop': lambda: optim.RMSprop(model_parameters, lr=lr, momentum=momentum,
                                         weight_decay=weight_decay),
    }
    build = builders.get(optimizer_name)
    if build is None:
        raise ValueError(f"Optimizer '{optimizer_name}' not supported.")
    return build()


def get_scheduler(optimizer, scheduler_name, **kwargs):
    """Create a learning-rate scheduler for ``optimizer`` from a case-insensitive name.

    Recognized names and the keyword arguments each reads (defaults in parentheses):
      - ``steplr``: step_size (10), gamma (0.1)
      - ``reducelronplateau``: mode ('min'; must be 'min' or 'max'), factor (0.1), patience (10)
      - ``cosineannealinglr``: t_max (50), eta_min (0)
      - ``cosineannealingwarmrestarts``: t_0 (10), t_mult (1), eta_min (0)
      - ``exponentiallr``: gamma (0.9)
      - ``linearlr``: start_factor (1.0), end_factor (0.0), total_iters (100)
      - ``none``: no scheduler; returns ``None``

    Keyword arguments a scheduler does not read are ignored.

    Raises:
        ValueError: for an unknown scheduler name or an invalid ReduceLROnPlateau mode.
    """
    scheduler_name = scheduler_name.lower()
    arg = kwargs.get

    def plateau():
        mode_str = arg('mode', 'min')
        if mode_str not in ['min', 'max']:
            raise ValueError("Invalid mode for ReduceLROnPlateau: must be 'min' or 'max'")
        return lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=cast(Literal["min", "max"], mode_str),
            factor=arg('factor', 0.1),
            patience=arg('patience', 10),
        )

    builders = {
        'steplr': lambda: lr_scheduler.StepLR(
            optimizer, step_size=arg('step_size', 10), gamma=arg('gamma', 0.1)),
        'reducelronplateau': plateau,
        'cosineannealinglr': lambda: lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=arg('t_max', 50), eta_min=arg('eta_min', 0)),
        'cosineannealingwarmrestarts': lambda: lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=arg('t_0', 10), T_mult=arg('t_mult', 1), eta_min=arg('eta_min', 0)),
        'exponentiallr': lambda: lr_scheduler.ExponentialLR(optimizer, gamma=arg('gamma', 0.9)),
        'linearlr': lambda: lr_scheduler.LinearLR(
            optimizer,
            start_factor=arg('start_factor', 1.0),
            end_factor=arg('end_factor', 0.0),
            total_iters=arg('total_iters', 100),
        ),
        'none': lambda: None,
    }

    if scheduler_name not in builders:
        raise ValueError(f"Scheduler '{scheduler_name}' not supported.")
    return builders[scheduler_name]()


def early_stopping(current_score, best_score, patience_counter, patience_early_stopping, min_delta=0.0, mode="min"):
    """Update early-stopping state with a new validation score.

    Args:
        current_score: this epoch's monitored value.
        best_score: best value so far, or ``None`` before the first observation.
        patience_counter: consecutive epochs without sufficient improvement.
        patience_early_stopping: stop once the counter reaches this value.
        min_delta: an improvement must exceed this margin to reset the counter.
        mode: ``"min"`` if lower is better, ``"max"`` if higher is better.

    Returns:
        ``(should_stop, best_score, patience_counter)``.

    Raises:
        ValueError: for an unknown ``mode``. The check is skipped on the first call,
        when ``best_score`` is ``None``.
    """
    if best_score is None:
        # The first observation only seeds the reference score.
        return False, current_score, patience_counter

    if mode not in ("min", "max"):
        raise ValueError("Mode should be 'min' or 'max'")
    gain = (best_score - current_score) if mode == "min" else (current_score - best_score)

    if gain > min_delta:
        best_score, patience_counter = current_score, 0
    else:
        patience_counter += 1

    return patience_counter >= patience_early_stopping, best_score, patience_counter


def validate_model(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` in eval mode without gradient tracking.

    Returns:
        ``(val_loss, val_accuracy)``. ``val_loss`` is the batch-size-weighted sum of
        ``criterion`` values divided by ``len(test_loader.dataset)``. ``val_accuracy`` is
        top-1 accuracy in percent.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_inputs)
            loss_sum += criterion(logits, batch_labels).item() * batch_inputs.size(0)
            predicted = logits.max(dim=1).indices
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()

    return loss_sum / len(test_loader.dataset), 100 * n_correct / n_seen


from torch.amp import autocast, GradScaler


def train_model(model, train_loader, test_loader, device, num_epochs, optimizer, num_classes, scheduler_mode=None,
                scheduler=None, patience_early_stopping=5, min_delta=0.0, early_stop_mode="min", learning_rate=0.1,
                warmup=0, grad_alpha=1.0, use_cutmix=True, use_mixup=True, alpha=1.0,
                checkpoint_dir="/kaggle/working"):
    """Train ``model`` with mixed precision, per-epoch validation, W&B logging and early stopping.

    Each epoch trains on ``train_loader`` (optionally with CutMix/MixUp), validates with
    validate_model and logs to W&B. It saves ``model.state_dict()`` whenever validation
    accuracy improves, steps the scheduler, and checks early stopping.

    Args:
        model, train_loader, test_loader, device: model, data loaders and device string.
        num_epochs: maximum number of epochs.
        optimizer: optimizer over ``model``'s parameters.
        num_classes: number of classes, used for CutMix/MixUp one-hot labels.
        scheduler_mode: for ReduceLROnPlateau, ``'min'`` steps on val loss and ``'max'``
            on val accuracy. ``None`` falls back to the scheduler's own ``mode``.
        scheduler: optional LR scheduler; non-plateau schedulers step once per epoch.
        patience_early_stopping, min_delta: early-stopping patience and margin.
        early_stop_mode: ``'min'`` monitors val loss, anything else val accuracy.
        learning_rate: target LR reached at the end of linear warmup.
        warmup: number of warmup epochs. Epoch ``e < warmup`` uses
            ``learning_rate * (e + 1) / warmup``.
        grad_alpha: currently unused (gradient clipping is disabled).
        use_cutmix, use_mixup: batch mixing. When both are set, one is chosen per batch.
        alpha: ``alpha`` for CutMix and MixUp.
        checkpoint_dir: directory for ``best_model_<random 4 digits>.pth``; created if missing.

    Expects an active W&B run, since ``wandb.watch`` and ``wandb.log`` are called.

    NOTE(review): the scheduler is also stepped during warmup epochs, and the warmup
    overwrites its LR each epoch. Confirm the combined schedule is the intended one.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler(device)

    best_val_score = None
    patience_counter = 0
    best_val_accuracy = 0.0
    wandb.watch(model, log="all", log_freq=10)

    alpha = float(alpha)
    cutmix = v2.CutMix(num_classes=num_classes, alpha=alpha)
    mixup = v2.MixUp(num_classes=num_classes, alpha=alpha)
    cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])

    # Create the checkpoint directory up front, so a bad path fails before training
    # rather than at the first improvement.
    os.makedirs(checkpoint_dir, exist_ok=True)
    rand = random.randint(1000, 9999)
    file_path = os.path.join(checkpoint_dir, f"best_model_{rand}.pth")

    if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau) and scheduler_mode is None:
        # Without this fallback neither branch below matches, and the plateau scheduler
        # would never be stepped.
        scheduler_mode = scheduler.mode

    for epoch in range(num_epochs):
        print("Epoch ", epoch)
        model.train()
        running_loss = 0.0
        correct, total = 0, 0

        if epoch < warmup:
            lr_scale = min(1.0, float(epoch + 1) / warmup)
            for pg in optimizer.param_groups:
                pg['lr'] = learning_rate * lr_scale

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            if use_cutmix and use_mixup:
                inputs, labels = cutmix_or_mixup(inputs, labels)
            elif use_cutmix:
                inputs, labels = cutmix(inputs, labels)
            elif use_mixup:
                inputs, labels = mixup(inputs, labels)

            optimizer.zero_grad()
            with autocast(device):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            # grad_alpha is reserved for gradient clipping, which is currently disabled.
            # Enabling it needs scaler.unscale_(optimizer) before clip_grad_norm_.
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            # Mixed labels are soft (N, num_classes); score against their dominant class.
            target = labels.argmax(dim=1) if use_cutmix or use_mixup else labels
            correct += predicted.eq(target).sum().item()

        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = 100 * correct / total

        val_loss, val_accuracy = validate_model(model, test_loader, criterion, device)
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), file_path)
            print(f"Best model saved with accuracy: {best_val_accuracy:.2f}%")
        print(f"Epoch {epoch + 1}/{num_epochs} - ")
        print(f"Train Loss: {train_loss:.4f} - Train Accuracy: {train_accuracy:.2f}% - ")
        print(f"Val Loss: {val_loss:.4f} - Val Accuracy: {val_accuracy:.2f}%")
        wandb.log({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_accuracy': train_accuracy,
            'val_loss': val_loss,
            'val_accuracy': val_accuracy
        })

        val_score = val_loss if early_stop_mode == "min" else val_accuracy

        if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
            if scheduler_mode == "max":
                scheduler.step(val_accuracy)
            elif scheduler_mode == "min":
                scheduler.step(val_loss)
        elif scheduler:
            scheduler.step()

        early_stop, best_val_score, patience_counter = early_stopping(
            current_score=val_score,
            best_score=best_val_score,
            patience_counter=patience_counter,
            patience_early_stopping=patience_early_stopping,
            min_delta=min_delta,
            mode=early_stop_mode
        )

        if early_stop:
            print("Early stopping triggered. Stopping training.")
            break

    print("Training complete.")


import wandb


def sweep_train():
    """Run one W&B sweep trial configured entirely from ``wandb.config``.

    Builds the loaders, model, optimizer and scheduler from the sweep parameters, then
    calls train_model. The W&B run is finished even if any step raises.

    Raises:
        ValueError: if ``config.dataset`` is not CIFAR10, CIFAR100 or MNIST, or if
        a downstream builder rejects a parameter.
    """
    wandb.init()
    config = wandb.config
    try:
        dataset_classes = {
            "CIFAR100": datasets.CIFAR100,
            "CIFAR10": datasets.CIFAR10,
            "MNIST": datasets.MNIST,
        }
        if config.dataset not in dataset_classes:
            # Reject unknown names before any download; get_model would reject them anyway.
            raise ValueError(
                f"Dataset '{config.dataset}' is not supported. Choose 'CIFAR10', 'CIFAR100', or 'MNIST'.")
        dataset_class = dataset_classes[config.dataset]

        train_loader, test_loader = load_data(
            dataset_class=dataset_class,
            data_dir=config.data_path,
            batch_size=config.batch_size,
            scheme=config.augmentation_scheme,
            use_cutmix=config.use_cutmix,
            use_mixup=config.use_mixup,
            alpha=config.alpha
        )

        model = get_model(
            dataset=config.dataset,
            model_name=config.model_name,
            num_classes=config.num_classes,
            # Forward the sweep's "pretrained" flag so it actually takes effect.
            pretrained=config.get('pretrained', True),
        )
        optimizer_name = config.optimizer_config["optimizer"]
        learning_rate = config.optimizer_config["learning_rate"]

        optimizer = get_optimizer(
            optimizer_name=optimizer_name,
            model_parameters=model.parameters(),
            lr=learning_rate,
            momentum=config.momentum,
            weight_decay=config.weight_decay,
            nesterov=config.nesterov
        )

        # One mode for both the plateau scheduler's construction and how train_model steps it.
        scheduler_mode = config.get('scheduler_mode', 'min')
        scheduler = get_scheduler(
            optimizer=optimizer,
            scheduler_name=config.scheduler,
            t_max=config.get('t_max', 200),
            eta_min=config.get('eta_min', 0),
            t_0=config.get('t_0', 10),
            t_mult=config.get('t_mult', 1),
            step_size=config.get('step_size', 10),
            gamma=config.get('gamma', 0.1),
            mode=scheduler_mode,
            patience=config.get('scheduler_patience', 10),
            factor=config.get('factor', 0.1)
        )

        train_model(
            model=model,
            train_loader=train_loader,
            test_loader=test_loader,
            device=get_device(),
            num_epochs=config.num_epochs,
            optimizer=optimizer,
            scheduler=scheduler,
            scheduler_mode=scheduler_mode,
            patience_early_stopping=config.patience_early_stopping,
            min_delta=config.min_delta,
            early_stop_mode=config.stop_mode,
            learning_rate=learning_rate,
            num_classes=config.num_classes,
            use_cutmix=config.use_cutmix,
            use_mixup=config.use_mixup,
            alpha=config.alpha,
            warmup=config.warmup,
        )
    finally:
        wandb.finish()

def load_config(file_path):
    """Load a configuration mapping from a ``.json``, ``.yaml`` or ``.yml`` file.

    The extension check is case-insensitive. Files are decoded as UTF-8 (the encoding
    RFC 8259 requires for JSON) rather than the platform's locale default.

    Raises:
        ValueError: for any other file extension.
    """
    ext = os.path.splitext(file_path)[-1].lower()
    if ext == ".json":
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    if ext in {".yaml", ".yml"}:
        with open(file_path, 'r', encoding='utf-8') as f:
            return yaml.safe_load(f)
    raise ValueError("Unsupported file format. Use JSON or YAML.")




if __name__ == "__main__":
    # Launch the W&B sweep only when executed as a script or notebook cell, not when
    # this module is imported for its helpers.
    config_file_path = "sweep_config.json"
    sweep_config = load_config(config_file_path)

    sweep_id = wandb.sweep(sweep_config, project="training-cifar100")
    wandb.agent(sweep_id, sweep_train)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Create sweep with ID: z9ns4r8w
Sweep URL: https://wandb.ai/gheorghitastefana-alexandru-ioan-cuza-university-iasi/training-cifar100/sweeps/z9ns4r8w


[34m[1mwandb[0m: Agent Starting Run: cg9to7ic with config:
[34m[1mwandb[0m: 	alpha: 1
[34m[1mwandb[0m: 	augmentation_scheme: randaugment
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	data_path: ./data
[34m[1mwandb[0m: 	dataset: CIFAR100
[34m[1mwandb[0m: 	eta_min: 1e-05
[34m[1mwandb[0m: 	min_delta: 0.0001
[34m[1mwandb[0m: 	model_name: resnet18_resize
[34m[1mwandb[0m: 	momentum: 0.9
[34m[1mwandb[0m: 	nesterov: True
[34m[1mwandb[0m: 	num_classes: 100
[34m[1mwandb[0m: 	num_epochs: 20
[34m[1mwandb[0m: 	optimizer_config: {'learning_rate': 0.0005, 'optimizer': 'adamw'}
[34m[1mwandb[0m: 	patience: 3
[34m[1mwandb[0m: 	patience_early_stopping: 10
[34m[1mwandb[0m: 	pretrained: True
[34m[1mwandb[0m: 	scheduler: cosineannealinglr
[34m[1mwandb[0m: 	stop_mode: max
[34m[1mwandb[0m: 	t_0: 10
[34m[1mwandb[0m: 	t_max: 100
[34m[1mwandb[0m: 	t_mult: 2
[34m[1mwandb[0m: 	use_cutmix: True
[34m[1mwandb[0m: 	use_mixup: True
[34m[1mwa

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113503688888058, max=1.0…

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:05<00:00, 29583936.81it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


model.safetensors:   0%|          | 0.00/46.8M [00:00<?, ?B/s]

Epoch  0
Best model saved with accuracy: 33.45%
Epoch 1/20 - 
Train Loss: 4.2947 - Train Accuracy: 9.22% - 
Val Loss: 2.9145 - Val Accuracy: 33.45%
Epoch  1
Best model saved with accuracy: 62.73%
Epoch 2/20 - 
Train Loss: 3.2017 - Train Accuracy: 34.38% - 
Val Loss: 1.4617 - Val Accuracy: 62.73%
Epoch  2
Best model saved with accuracy: 69.37%
Epoch 3/20 - 
Train Loss: 2.7711 - Train Accuracy: 45.31% - 
Val Loss: 1.2885 - Val Accuracy: 69.37%
Epoch  3
Best model saved with accuracy: 73.10%
Epoch 4/20 - 
Train Loss: 2.5074 - Train Accuracy: 52.06% - 
Val Loss: 1.0508 - Val Accuracy: 73.10%
Epoch  4
Best model saved with accuracy: 75.66%
Epoch 5/20 - 
Train Loss: 2.4448 - Train Accuracy: 54.29% - 
Val Loss: 1.0654 - Val Accuracy: 75.66%
Epoch  5
Best model saved with accuracy: 76.42%
Epoch 6/20 - 
Train Loss: 2.3356 - Train Accuracy: 56.78% - 
Val Loss: 1.0818 - Val Accuracy: 76.42%
Epoch  6
Best model saved with accuracy: 77.44%
Epoch 7/20 - 
Train Loss: 2.2430 - Train Accuracy: 58.65% -

VBox(children=(Label(value='0.175 MB of 0.175 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
train_accuracy,▁▄▅▆▆▆▇▇▇▇▇▇▇█▇█████
train_loss,█▅▄▃▃▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁
val_accuracy,▁▅▆▇▇▇▇▇▇███████████
val_loss,█▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,20.0
train_accuracy,69.732
train_loss,1.8718
val_accuracy,81.88
val_loss,0.7749


[34m[1mwandb[0m: Agent Starting Run: 5fu60cj0 with config:
[34m[1mwandb[0m: 	alpha: 1
[34m[1mwandb[0m: 	augmentation_scheme: randaugment
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	data_path: ./data
[34m[1mwandb[0m: 	dataset: CIFAR100
[34m[1mwandb[0m: 	eta_min: 1e-05
[34m[1mwandb[0m: 	min_delta: 0.0001
[34m[1mwandb[0m: 	model_name: resnet18_resize
[34m[1mwandb[0m: 	momentum: 0.9
[34m[1mwandb[0m: 	nesterov: True
[34m[1mwandb[0m: 	num_classes: 100
[34m[1mwandb[0m: 	num_epochs: 20
[34m[1mwandb[0m: 	optimizer_config: {'learning_rate': 0.001, 'optimizer': 'adamw'}
[34m[1mwandb[0m: 	patience: 3
[34m[1mwandb[0m: 	patience_early_stopping: 10
[34m[1mwandb[0m: 	pretrained: True
[34m[1mwandb[0m: 	scheduler: cosineannealinglr
[34m[1mwandb[0m: 	stop_mode: max
[34m[1mwandb[0m: 	t_0: 10
[34m[1mwandb[0m: 	t_max: 100
[34m[1mwandb[0m: 	t_mult: 2
[34m[1mwandb[0m: 	use_cutmix: True
[34m[1mwandb[0m: 	use_mixup: True
[34m[1mwan

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111366556666932, max=1.0)…

Epoch  0
Best model saved with accuracy: 53.27%
Epoch 1/20 - 
Train Loss: 3.8850 - Train Accuracy: 18.01% - 
Val Loss: 1.8687 - Val Accuracy: 53.27%
Epoch  1
Best model saved with accuracy: 69.15%
Epoch 2/20 - 
Train Loss: 2.8968 - Train Accuracy: 41.68% - 
Val Loss: 1.2508 - Val Accuracy: 69.15%
Epoch  2
Best model saved with accuracy: 72.66%
Epoch 3/20 - 
Train Loss: 2.6161 - Train Accuracy: 49.04% - 
Val Loss: 1.0854 - Val Accuracy: 72.66%
Epoch  3
Best model saved with accuracy: 74.01%
Epoch 4/20 - 
Train Loss: 2.4926 - Train Accuracy: 52.30% - 
Val Loss: 1.0253 - Val Accuracy: 74.01%
Epoch  4
Best model saved with accuracy: 75.14%
Epoch 5/20 - 
Train Loss: 2.4600 - Train Accuracy: 53.72% - 
Val Loss: 1.0467 - Val Accuracy: 75.14%
Epoch  5
Best model saved with accuracy: 76.64%
Epoch 6/20 - 
Train Loss: 2.3495 - Train Accuracy: 56.48% - 
Val Loss: 0.9633 - Val Accuracy: 76.64%
Epoch  6
Best model saved with accuracy: 77.02%
Epoch 7/20 - 
Train Loss: 2.2972 - Train Accuracy: 57.99% 

VBox(children=(Label(value='0.176 MB of 0.176 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
train_accuracy,▁▄▅▆▆▆▆▇▇▇▇▇▇▇▇█████
train_loss,█▅▄▃▃▃▂▂▂▂▂▂▂▁▂▁▁▁▁▁
val_accuracy,▁▅▆▆▇▇▇▇▇▇▇█████████
val_loss,█▄▃▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁

0,1
epoch,20.0
train_accuracy,69.236
train_loss,1.89344
val_accuracy,80.61
val_loss,0.89949


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Sweep Agent: Exiting.
