In [5]:
!pip install torch torchvision tqdm
!pip install --upgrade ipywidgets
!pip install --upgrade tqdm




You should consider upgrading via the 'C:\Users\zelkh\AppData\Local\Programs\Python\Python310\python.exe -m pip install --upgrade pip' command.




You should consider upgrading via the 'C:\Users\zelkh\AppData\Local\Programs\Python\Python310\python.exe -m pip install --upgrade pip' command.




You should consider upgrading via the 'C:\Users\zelkh\AppData\Local\Programs\Python\Python310\python.exe -m pip install --upgrade pip' command.


In [8]:
# MODEL (with training recipe upgrades)

# MODEL + TRAIN (CIFAR-10, poly-softplus, SGD+cosine, MixUp, EMA)

import math, copy, numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm.notebook import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR, SequentialLR, LinearLR

# ----------------------------------
# 1) DATA
# ----------------------------------
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))
])

batch_size   = 128
num_workers  = 2
use_cuda     = torch.cuda.is_available()
device       = torch.device('cuda' if use_cuda else 'cpu')
pin_mem_flag = True if use_cuda else False

trainset   = torchvision.datasets.CIFAR10(root='./data', train=True,  download=True, transform=transform)
testset    = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True,
                                          num_workers=num_workers, pin_memory=pin_mem_flag)
testloader  = torch.utils.data.DataLoader(testset,  batch_size=batch_size, shuffle=False,
                                          num_workers=num_workers, pin_memory=pin_mem_flag)

# ----------------------------------
# 2) POLY ACT (Softplus-like, degree 4) with learnable affine
# ----------------------------------
class PolyAct4(nn.Module):
    """
    y = s * (A x^4 + B x^3 + C x^2 + D x + E) + b, where x is clamped to [clamp_min, clamp_max].
    A..E are fixed buffers (from your fit); s,b are tiny learnable scalars to absorb residual scale/bias.
    """
    def __init__(self, A, B, C, D, E, clamp_min=-6.0, clamp_max=6.0, init_scale=1.0, init_bias=0.0):
        super().__init__()
        self.register_buffer('A', torch.tensor(float(A)))
        self.register_buffer('B', torch.tensor(float(B)))
        self.register_buffer('C', torch.tensor(float(C)))
        self.register_buffer('D', torch.tensor(float(D)))
        self.register_buffer('E', torch.tensor(float(E)))
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max
        self.scale = nn.Parameter(torch.tensor(float(init_scale)))
        self.bias  = nn.Parameter(torch.tensor(float(init_bias)))

    def forward(self, x):
        x = torch.clamp(x, min=self.clamp_min, max=self.clamp_max)
        y = self.A * x**4 + self.B * x**3 + self.C * x**2 + self.D * x + self.E
        return self.scale * y + self.bias

# ----------------------------------
# 3) MODEL (VGG-ish)
# ----------------------------------
def conv3x3(in_ch, out_ch):
    # BN follows -> bias not needed
    return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False)

class CIFAR_CNN(nn.Module):
    def __init__(self, activation_fn: nn.Module):
        super().__init__()
        act = activation_fn  # shared instance (keeps HE-compat simplicity)

        self.block1 = nn.Sequential(
            conv3x3(3, 48),  nn.BatchNorm2d(48),  act,
            conv3x3(48, 48), nn.BatchNorm2d(48),  act,
            nn.AvgPool2d(2, 2)   # 32 -> 16
        )

        self.block2 = nn.Sequential(
            conv3x3(48, 96),  nn.BatchNorm2d(96),  act,
            conv3x3(96, 96),  nn.BatchNorm2d(96),  act,
            nn.AvgPool2d(2, 2)   # 16 -> 8
        )

        self.block3 = nn.Sequential(
            conv3x3(96, 192),  nn.BatchNorm2d(192), act,
            conv3x3(192, 192), nn.BatchNorm2d(192), act
        )

        # optional light downsample to stabilize scales
        self.ds3 = nn.AvgPool2d(2, 2)  # 8 -> 4

        self.block4 = nn.Sequential(
            conv3x3(192, 256), nn.BatchNorm2d(256), act,
            conv3x3(256, 256), nn.BatchNorm2d(256), act
        )

        self.gap     = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1     = nn.Linear(256, 512, bias=False)  # BN follows
        self.fc_bn   = nn.BatchNorm1d(512)
        self.head_act= activation_fn
        self.dropout = nn.Dropout(0.3)
        self.fc2     = nn.Linear(512, 10)

        # init
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.ds3(x)
        x = self.block4(x)
        x = self.gap(x).flatten(1)
        x = self.fc1(x)
        x = self.fc_bn(x)
        x = self.head_act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# ----------------------------------
# 4) Build model, loss, optimizer, sched
# ----------------------------------
poly4 = PolyAct4(
    A=-2.22734060e-03,
    B=3.70901207e-17,
    C=1.15754806e-01,
    D=5.00000000e-01,   # <-- corrected (0.5)
    E=7.00459235e-01,
    clamp_min=-6.0, clamp_max=6.0,
    init_scale=1.0, init_bias=0.0
)

model = CIFAR_CNN(activation_fn=poly4).to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

base_lr        = 0.1
optimizer      = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9,
                           weight_decay=5e-4, nesterov=True)
num_epochs     = 350
warmup_epochs  = 5
warmup_sched   = LinearLR(optimizer, start_factor=1e-3, end_factor=1.0, total_iters=warmup_epochs)
cosine_sched   = CosineAnnealingLR(optimizer, T_max=num_epochs - warmup_epochs, eta_min=1e-4)
scheduler      = SequentialLR(optimizer, schedulers=[warmup_sched, cosine_sched],
                              milestones=[warmup_epochs])

scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

# ----------------------------------
# 5) MixUp utilities (set alpha=0.0 to disable)
# ----------------------------------
MIXUP_ALPHA = 0.2

def mixup_data(x, y, alpha=MIXUP_ALPHA):
    if alpha <= 0.0:
        return x, y, 1.0, None
    lam = np.random.beta(alpha, alpha)
    bs  = x.size(0)
    idx = torch.randperm(bs, device=x.device)
    mixed_x = lam * x + (1 - lam) * x[idx]
    y_a, y_b = y, y[idx]
    return mixed_x, (y_a, y_b), lam, idx

def mixup_criterion(crit, pred, ytuple_lam):
    y_a, y_b, lam = ytuple_lam
    return lam * crit(pred, y_a) + (1 - lam) * crit(pred, y_b)

# ----------------------------------
# 6) EMA (deepcopy to avoid constructor args)
# ----------------------------------
class ModelEMA:
    def __init__(self, model, decay=0.999):
        self.ema = copy.deepcopy(model).to(device)
        for p in self.ema.parameters():
            p.requires_grad_(False)
        self.decay = decay

    @torch.no_grad()
    def update(self, model):
        d = self.decay
        msd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:
                v.mul_(d).add_((1.0 - d) * msd[k])
            else:
                v.copy_(msd[k])

    def state_dict(self):
        return self.ema.state_dict()

ema = ModelEMA(model, decay=0.999)

# ----------------------------------
# 7) Eval helper (uses given model, e.g., ema.ema)
# ----------------------------------
@torch.no_grad()
def evaluate(eval_model):
    eval_model.eval()
    tot_loss, correct, total = 0.0, 0, 0
    for inputs, labels in testloader:
        inputs = inputs.to(device, non_blocking=pin_mem_flag)
        labels = labels.to(device, non_blocking=pin_mem_flag)
        with torch.cuda.amp.autocast(enabled=use_cuda):
            outputs = eval_model(inputs)
            loss = criterion(outputs, labels)
        tot_loss += loss.item()
        pred = outputs.argmax(dim=1)
        correct += (pred == labels).sum().item()
        total   += labels.size(0)
    return tot_loss / len(testloader), 100.0 * correct / total

# ----------------------------------
# 8) TRAIN
# ----------------------------------
for epoch in range(num_epochs):
    model.train()
    running = 0.0
    loop = tqdm(trainloader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)
    for x, y in loop:
        x = x.to(device, non_blocking=pin_mem_flag)
        y = y.to(device, non_blocking=pin_mem_flag)

        x, (ya, yb), lam, _ = mixup_data(x, y, alpha=MIXUP_ALPHA)

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_cuda):
            out  = model(x)
            loss = mixup_criterion(criterion, out, (ya, yb, lam))
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        ema.update(model)

        running += loss.item()
        loop.set_postfix(train_loss=loss.item(), lr=optimizer.param_groups[0]['lr'])

    scheduler.step()
    train_loss = running / len(trainloader)

    val_loss, val_acc = evaluate(ema.ema)   # eval EMA weights
    print(f"Epoch [{epoch+1}/{num_epochs}] | Train Loss: {train_loss:.4f} | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

# ----------------------------------
# 9) FINAL EVAL + SAVE (EMA)
# ----------------------------------
final_loss, final_acc = evaluate(ema.ema)
print(f"\nFinal Test Accuracy (EMA): {final_acc:.2f}%")

torch.save(ema.state_dict(), 'test_cifar10_trained_model_ema.pth')
print("Saved: test_cifar10_trained_model_ema.pth ✅")




Files already downloaded and verified
Files already downloaded and verified


  scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)


Epoch 1/350:   0%|          | 0/391 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=use_cuda):
  with torch.cuda.amp.autocast(enabled=use_cuda):


Epoch [1/350] | Train Loss: 2.1761 | Val Loss: 2.3675 | Val Acc: 10.00%


Epoch 2/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [2/350] | Train Loss: 1.8069 | Val Loss: 2.3036 | Val Acc: 10.00%


Epoch 3/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [3/350] | Train Loss: 1.6028 | Val Loss: 2.3051 | Val Acc: 11.13%


Epoch 4/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [4/350] | Train Loss: 1.5010 | Val Loss: 2.3048 | Val Acc: 10.00%


Epoch 5/350:   0%|          | 0/391 [00:00<?, ?it/s]



Epoch [5/350] | Train Loss: 1.4462 | Val Loss: 2.2841 | Val Acc: 10.02%


Epoch 6/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [6/350] | Train Loss: 1.4332 | Val Loss: 2.1638 | Val Acc: 24.13%


Epoch 7/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [7/350] | Train Loss: 1.4140 | Val Loss: 1.5374 | Val Acc: 70.75%


Epoch 8/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [8/350] | Train Loss: 1.3694 | Val Loss: 1.1615 | Val Acc: 73.69%


Epoch 9/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [9/350] | Train Loss: 1.3846 | Val Loss: 1.1385 | Val Acc: 72.35%


Epoch 10/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [10/350] | Train Loss: 1.3525 | Val Loss: 1.2092 | Val Acc: 71.13%


Epoch 11/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [11/350] | Train Loss: 1.3389 | Val Loss: 1.2066 | Val Acc: 71.67%


Epoch 12/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [12/350] | Train Loss: 1.3598 | Val Loss: 1.1953 | Val Acc: 72.56%


Epoch 13/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [13/350] | Train Loss: 1.3365 | Val Loss: 1.1743 | Val Acc: 74.14%


Epoch 14/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [14/350] | Train Loss: 1.3216 | Val Loss: 1.1303 | Val Acc: 77.67%


Epoch 15/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [15/350] | Train Loss: 1.2858 | Val Loss: 1.0820 | Val Acc: 78.81%


Epoch 16/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [16/350] | Train Loss: 1.2919 | Val Loss: 1.0082 | Val Acc: 80.41%


Epoch 17/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [17/350] | Train Loss: 1.3111 | Val Loss: 0.9688 | Val Acc: 80.62%


Epoch 18/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [18/350] | Train Loss: 1.2921 | Val Loss: 0.9538 | Val Acc: 81.46%


Epoch 19/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [19/350] | Train Loss: 1.2829 | Val Loss: 0.9505 | Val Acc: 81.58%


Epoch 20/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [20/350] | Train Loss: 1.2620 | Val Loss: 0.9541 | Val Acc: 81.78%


Epoch 21/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [21/350] | Train Loss: 1.2875 | Val Loss: 0.9489 | Val Acc: 82.01%


Epoch 22/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [22/350] | Train Loss: 1.3018 | Val Loss: 0.9499 | Val Acc: 82.08%


Epoch 23/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [23/350] | Train Loss: 1.2525 | Val Loss: 0.9482 | Val Acc: 82.69%


Epoch 24/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [24/350] | Train Loss: 1.2798 | Val Loss: 0.9473 | Val Acc: 82.81%


Epoch 25/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [25/350] | Train Loss: 1.2591 | Val Loss: 0.9617 | Val Acc: 82.48%


Epoch 26/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [26/350] | Train Loss: 1.2789 | Val Loss: 0.9847 | Val Acc: 82.49%


Epoch 27/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [27/350] | Train Loss: 1.2655 | Val Loss: 0.9626 | Val Acc: 82.72%


Epoch 28/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [28/350] | Train Loss: 1.2704 | Val Loss: 0.9748 | Val Acc: 82.39%


Epoch 29/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [29/350] | Train Loss: 1.2770 | Val Loss: 0.9230 | Val Acc: 83.12%


Epoch 30/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [30/350] | Train Loss: 1.2754 | Val Loss: 0.9226 | Val Acc: 83.01%


Epoch 31/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [31/350] | Train Loss: 1.2843 | Val Loss: 0.9583 | Val Acc: 83.03%


Epoch 32/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [32/350] | Train Loss: 1.2897 | Val Loss: 0.9479 | Val Acc: 82.90%


Epoch 33/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [33/350] | Train Loss: 1.2455 | Val Loss: 0.9289 | Val Acc: 83.07%


Epoch 34/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [34/350] | Train Loss: 1.2444 | Val Loss: 0.9477 | Val Acc: 83.18%


Epoch 35/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [35/350] | Train Loss: 1.2710 | Val Loss: 0.9610 | Val Acc: 83.12%


Epoch 36/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [36/350] | Train Loss: 1.2644 | Val Loss: 0.9055 | Val Acc: 83.55%


Epoch 37/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [37/350] | Train Loss: 1.2677 | Val Loss: 0.9258 | Val Acc: 83.57%


Epoch 38/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [38/350] | Train Loss: 1.2660 | Val Loss: 0.9364 | Val Acc: 83.49%


Epoch 39/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [39/350] | Train Loss: 1.2553 | Val Loss: 0.9277 | Val Acc: 83.29%


Epoch 40/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [40/350] | Train Loss: 1.2483 | Val Loss: 0.9258 | Val Acc: 83.78%


Epoch 41/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [41/350] | Train Loss: 1.2664 | Val Loss: 0.9120 | Val Acc: 83.44%


Epoch 42/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [42/350] | Train Loss: 1.2360 | Val Loss: 0.9345 | Val Acc: 83.67%


Epoch 43/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [43/350] | Train Loss: 1.2365 | Val Loss: 0.9346 | Val Acc: 83.39%


Epoch 44/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [44/350] | Train Loss: 1.2414 | Val Loss: 0.9323 | Val Acc: 83.80%


Epoch 45/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [45/350] | Train Loss: 1.2442 | Val Loss: 0.9507 | Val Acc: 83.64%


Epoch 46/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [46/350] | Train Loss: 1.2362 | Val Loss: 0.9232 | Val Acc: 84.06%


Epoch 47/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [47/350] | Train Loss: 1.2678 | Val Loss: 0.9035 | Val Acc: 84.05%


Epoch 48/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [48/350] | Train Loss: 1.2365 | Val Loss: 0.9005 | Val Acc: 84.00%


Epoch 49/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [49/350] | Train Loss: 1.2641 | Val Loss: 0.9291 | Val Acc: 83.87%


Epoch 50/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [50/350] | Train Loss: 1.2377 | Val Loss: 0.9183 | Val Acc: 84.01%


Epoch 51/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [51/350] | Train Loss: 1.2454 | Val Loss: 0.8939 | Val Acc: 84.27%


Epoch 52/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [52/350] | Train Loss: 1.2631 | Val Loss: 0.9126 | Val Acc: 83.96%


Epoch 53/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [53/350] | Train Loss: 1.2495 | Val Loss: 0.9085 | Val Acc: 84.25%


Epoch 54/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [54/350] | Train Loss: 1.2506 | Val Loss: 0.9153 | Val Acc: 84.34%


Epoch 55/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [55/350] | Train Loss: 1.2706 | Val Loss: 0.9236 | Val Acc: 84.08%


Epoch 56/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [56/350] | Train Loss: 1.2393 | Val Loss: 0.9235 | Val Acc: 84.13%


Epoch 57/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [57/350] | Train Loss: 1.2343 | Val Loss: 0.9345 | Val Acc: 83.89%


Epoch 58/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [58/350] | Train Loss: 1.2478 | Val Loss: 0.9150 | Val Acc: 83.97%


Epoch 59/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [59/350] | Train Loss: 1.2433 | Val Loss: 0.8985 | Val Acc: 84.01%


Epoch 60/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [60/350] | Train Loss: 1.1913 | Val Loss: 0.8914 | Val Acc: 84.31%


Epoch 61/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [61/350] | Train Loss: 1.2399 | Val Loss: 0.8929 | Val Acc: 84.31%


Epoch 62/350:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [62/350] | Train Loss: 1.2249 | Val Loss: 0.9100 | Val Acc: 84.45%


Epoch 63/350:   0%|          | 0/391 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [10]:
# CIFAR-10 | PreAct-ResNet-18 + Polynomial Softplus (deg-4)
# Option B: AdamW + cosine, aug schedule (light→strong→light), MixUp ramp→off, EMA, AMP, BN calibration

import os, math, copy, random
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm.auto import tqdm

# =========================
# Config
# =========================
@dataclass
class CFG:
    epochs: int = 350
    batch_size: int = 128
    lr: float = 1e-3
    weight_decay: float = 3e-4
    momentum: float = 0.9
    amp: bool = torch.cuda.is_available()

    # polynomial activation clamp
    act_clip: float = 4.5  # tighter than 6.0 to reduce approx error

    # augmentation schedule
    phase1_epochs: int = 20      # LIGHT aug (no MixUp)
    phase3_ratio: float = 0.10   # last 10% epochs LIGHT aug again
    mixup_target_alpha: float = 0.2
    mixup_ramp_epochs: int = 5   # ramp 0 -> target at start of phase 2

    # EMA
    ema_decay: float = 0.999

    # saving
    save_std: str = "resnet_polyB_last.pth"
    save_ema: str = "resnet_polyB_ema.pth"

    seed: int = 42

cfg = CFG()

# =========================
# Repro & device
# =========================
def set_seed(s=42):
    random.seed(s); torch.manual_seed(s); torch.cuda.manual_seed_all(s)
set_seed(cfg.seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pin_mem = torch.cuda.is_available()
torch.backends.cudnn.benchmark = True

# =========================
# Data (light & strong pipelines)
# =========================
# LIGHT aug (start + end)
transform_train_light = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914,0.4822,0.4465),
                         (0.2023,0.1994,0.2010)),
])

# STRONG aug (middle)
try:
    from torchvision.transforms import TrivialAugmentWide
    strong_policy = [TrivialAugmentWide()]
except Exception:
    strong_policy = []  # fallback cleanly if not available

transform_train_strong = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    *strong_policy,
    transforms.ToTensor(),
    transforms.Normalize((0.4914,0.4822,0.4465),
                         (0.2023,0.1994,0.2010)),
    transforms.RandomErasing(p=0.25),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914,0.4822,0.4465),
                         (0.2023,0.1994,0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root="./data", train=True,  download=True, transform=transform_train_light)
testset  = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=cfg.batch_size, shuffle=True,
    num_workers=2, pin_memory=pin_mem, persistent_workers=True
)
testloader  = torch.utils.data.DataLoader(
    testset, batch_size=cfg.batch_size, shuffle=False,
    num_workers=2, pin_memory=pin_mem, persistent_workers=True
)

# =========================
# Polynomial softplus approx (degree-4) — use your 91% coefficients
# =========================
class PolyAct4(nn.Module):
    """
    y = s * (A x^4 + B x^3 + C x^2 + D x + E) + b, with x clamped to [-clip, clip].
    A..E fixed; s,b learnable (can be folded for inference).
    """
    def __init__(self, A, B, C, D, E, clip=6.0, init_scale=1.0, init_bias=0.0):
        super().__init__()
        self.register_buffer('A', torch.tensor(float(A)))
        self.register_buffer('B', torch.tensor(float(B)))
        self.register_buffer('C', torch.tensor(float(C)))
        self.register_buffer('D', torch.tensor(float(D)))
        self.register_buffer('E', torch.tensor(float(E)))
        self.clip = float(clip)
        self.scale = nn.Parameter(torch.tensor(float(init_scale)))
        self.bias  = nn.Parameter(torch.tensor(float(init_bias)))

    def forward(self, x):
        x = torch.clamp(x, -self.clip, self.clip)
        x2 = x * x; x3 = x2 * x; x4 = x2 * x2
        y = self.A * x4 + self.B * x3 + self.C * x2 + self.D * x + self.E
        return self.scale * y + self.bias

def make_poly():
    # per-layer instance (important)
    return PolyAct4(
        A=-0.00068481,
        B=-1.59833239e-17,
        C=0.0887234775,
        D=0.5,
        E=0.738099333,
        clip=cfg.act_clip,
        init_scale=1.0, init_bias=0.0
    )

# =========================
# PreAct-ResNet-18 (BN -> Poly -> Conv), CIFAR stem
# =========================
class PreActBlock(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1, make_act=make_poly):
        super().__init__()
        self.bn1  = nn.BatchNorm2d(in_planes)
        self.act1 = make_act()
        self.conv1= nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)

        self.bn2  = nn.BatchNorm2d(planes)
        self.act2 = make_act()
        self.conv2= nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        self.downsample = None
        if stride != 1 or in_planes != planes*self.expansion:
            self.downsample = nn.Conv2d(in_planes, planes*self.expansion, kernel_size=1, stride=stride, bias=False)

    def forward(self, x):
        out = self.act1(self.bn1(x))
        sc_in = out
        out = self.conv1(out)
        out = self.conv2(self.act2(self.bn2(out)))
        shortcut = x if self.downsample is None else self.downsample(sc_in)
        return out + shortcut

class PreActResNet(nn.Module):
    def __init__(self, block, layers, num_classes=10, make_act=make_poly):
        super().__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

        self.layer1 = self._make_layer(block, 64,  layers[0], stride=1, make_act=make_act)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, make_act=make_act)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, make_act=make_act)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, make_act=make_act)

        self.bn_last  = nn.BatchNorm2d(512 * block.expansion)
        self.act_last = make_act()

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc      = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))

    def _make_layer(self, block, planes, blocks, stride, make_act):
        layers = [block(self.in_planes, planes, stride=stride, make_act=make_act)]
        self.in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_planes, planes, stride=1, make_act=make_act))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.act_last(self.bn_last(x))
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

def PreActResNet18(num_classes=10):
    return PreActResNet(PreActBlock, [2,2,2,2], num_classes=num_classes, make_act=make_poly)

model = PreActResNet18(num_classes=10).to(device)

# =========================
# Optimizer: AdamW with WD hygiene (no WD on BN/affine or PolyAct scale/bias)
# =========================
def build_param_groups(m):
    decay, no_decay = [], []
    for n, p in m.named_parameters():
        if not p.requires_grad:
            continue
        if p.ndim == 1 or n.endswith(".bias") or "bn" in n.lower() or n.endswith(".scale"):
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": cfg.weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

optimizer = optim.AdamW(build_param_groups(model), lr=cfg.lr)

# Cosine LR over full training
scheduler = CosineAnnealingLR(optimizer, T_max=cfg.epochs, eta_min=1e-5)

# AMP (new API)
scaler = torch.amp.GradScaler('cuda', enabled=cfg.amp)

# EMA
class ModelEMA:
    def __init__(self, model, decay=0.999):
        self.ema = copy.deepcopy(model).to(device).eval()
        for p in self.ema.parameters():
            p.requires_grad_(False)
        self.decay = decay
    @torch.no_grad()
    def update(self, model):
        msd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:
                v.copy_(v * self.decay + msd[k] * (1.0 - self.decay))
            else:
                v.copy_(msd[k])
    def state_dict(self):
        return self.ema.state_dict()

ema = ModelEMA(model, decay=cfg.ema_decay)

# =========================
# MixUp utilities
# =========================
def mixup_data(x, y, alpha):
    if alpha <= 0.0:
        return x, y, y, 1.0
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    bs  = x.size(0)
    idx = torch.randperm(bs, device=x.device)
    mixed_x = lam * x + (1 - lam) * x[idx]
    y_a, y_b = y, y[idx]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(crit, pred, y_a, y_b, lam):
    return lam * crit(pred, y_a) + (1 - lam) * crit(pred, y_b)

criterion = nn.CrossEntropyLoss()  # no label smoothing when using MixUp

# =========================
# Eval helper (per-sample mean loss)
# =========================
@torch.no_grad()
def evaluate(eval_model):
    eval_model.eval()
    total, correct, loss_sum = 0, 0, 0.0
    for x, y in testloader:
        x = x.to(device, non_blocking=pin_mem)
        y = y.to(device, non_blocking=pin_mem)
        with torch.amp.autocast('cuda', enabled=cfg.amp):
            logits = eval_model(x)
            loss = criterion(logits, y)
        loss_sum += loss.item() * y.size(0)
        pred = logits.argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    return loss_sum / total, 100.0 * correct / total

# BN calibration (refresh running stats on EMA before final eval)
@torch.no_grad()
def calibrate_bn(model, loader, iters=200):
    model.train()
    n = 0
    for x, _ in loader:
        x = x.to(device, non_blocking=pin_mem)
        model(x)
        n += 1
        if n >= iters:
            break
    model.eval()

# =========================
# Training
# =========================
phase1 = cfg.phase1_epochs
phase3 = max(1, int(cfg.phase3_ratio * cfg.epochs))
phase2 = max(1, cfg.epochs - phase1 - phase3)  # middle

best_acc = 0.0

for epoch in range(1, cfg.epochs + 1):
    # --- Augmentation schedule ---
    if epoch <= phase1:
        trainset.transform = transform_train_light
        mix_alpha = 0.0
        phase_name = "LIGHT"
    elif epoch <= phase1 + phase2:
        trainset.transform = transform_train_strong
        # MixUp ramp during the first part of phase 2
        ramp_ep = min(cfg.mixup_ramp_epochs, phase2)
        e_in_phase2 = epoch - phase1
        if e_in_phase2 <= ramp_ep:
            mix_alpha = cfg.mixup_target_alpha * (e_in_phase2 / ramp_ep)
        else:
            mix_alpha = cfg.mixup_target_alpha
        phase_name = "STRONG"
    else:
        trainset.transform = transform_train_light
        mix_alpha = 0.0
        phase_name = "LIGHT (final)"

    model.train()
    running_loss = 0.0

    loop = tqdm(trainloader, desc=f"Epoch {epoch}/{cfg.epochs} [{phase_name}]",
                leave=False)
    for x, y in loop:
        x = x.to(device, non_blocking=pin_mem)
        y = y.to(device, non_blocking=pin_mem)

        if mix_alpha > 0.0:
            x, y_a, y_b, lam = mixup_data(x, y, mix_alpha)
        else:
            y_a = y_b = y; lam = 1.0

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda', enabled=cfg.amp):
            logits = model(x)
            loss = mixup_criterion(criterion, logits, y_a, y_b, lam)

        scaler.scale(loss).backward()
        # (No grad clipping here; remove cap for higher ceiling)
        scaler.step(optimizer)
        scaler.update()

        ema.update(model)

        running_loss += loss.item() * x.size(0)
        loop.set_postfix(loss=float(loss),
                         lr=optimizer.param_groups[0]['lr'],
                         mixup_alpha=round(mix_alpha, 3))

    scheduler.step()
    train_loss = running_loss / len(trainset)

    # Evaluate raw model briefly at the very start, then EMA
    eval_model = model if epoch <= 3 else ema.ema
    val_loss, val_acc = evaluate(eval_model)

    print(f"Epoch [{epoch}/{cfg.epochs}] | Train Loss: {train_loss:.4f} | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | MixUp α: {mix_alpha:.3f}")

    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), cfg.save_std)
        torch.save(ema.state_dict(), cfg.save_ema)

# =========================
# Final eval: calibrate BN on EMA, then evaluate
# =========================
if os.path.exists(cfg.save_ema):
    ema_state = torch.load(cfg.save_ema, map_location=device)
    ema.ema.load_state_dict(ema_state)
    calibrate_bn(ema.ema, trainloader, iters=200)
    loss_e, acc_e = evaluate(ema.ema)
    print(f"\nEMA Model (BN-calibrated) -> Acc: {acc_e:.2f}% | Loss: {loss_e:.4f}")
else:
    print("\nNo EMA checkpoint found.")

if os.path.exists(cfg.save_std):
    model.load_state_dict(torch.load(cfg.save_std, map_location=device))
    loss_m, acc_m = evaluate(model)
    print(f"Last-saved Model          -> Acc: {acc_m:.2f}% | Loss: {loss_m:.4f}")
else:
    print("No standard model checkpoint found.")

print(f"\nSaved paths:\n - {cfg.save_std}\n - {cfg.save_ema}")



Files already downloaded and verified
Files already downloaded and verified


Epoch 1/350 [LIGHT]:   0%|          | 0/391 [00:20<?, ?it/s]

Epoch [1/350] | Train Loss: 1.6640 | Val Loss: 1.7936 | Val Acc: 36.33% | MixUp α: 0.000


Epoch 2/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [2/350] | Train Loss: 1.3105 | Val Loss: 1.2515 | Val Acc: 54.42% | MixUp α: 0.000


Epoch 3/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [3/350] | Train Loss: 1.0903 | Val Loss: 1.9050 | Val Acc: 42.44% | MixUp α: 0.000


Epoch 4/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [4/350] | Train Loss: 0.9496 | Val Loss: 2.4728 | Val Acc: 9.99% | MixUp α: 0.000


Epoch 5/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [5/350] | Train Loss: 0.8335 | Val Loss: 2.4952 | Val Acc: 10.59% | MixUp α: 0.000


Epoch 6/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [6/350] | Train Loss: 0.7507 | Val Loss: 2.3532 | Val Acc: 14.33% | MixUp α: 0.000


Epoch 7/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [7/350] | Train Loss: 0.6842 | Val Loss: 1.9519 | Val Acc: 29.08% | MixUp α: 0.000


Epoch 8/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [8/350] | Train Loss: 0.6363 | Val Loss: 1.4485 | Val Acc: 50.07% | MixUp α: 0.000


Epoch 9/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [9/350] | Train Loss: 0.5903 | Val Loss: 1.0347 | Val Acc: 65.72% | MixUp α: 0.000


Epoch 10/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [10/350] | Train Loss: 0.5512 | Val Loss: 0.7697 | Val Acc: 75.29% | MixUp α: 0.000


Epoch 11/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [11/350] | Train Loss: 0.5200 | Val Loss: 0.6153 | Val Acc: 80.07% | MixUp α: 0.000


Epoch 12/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [12/350] | Train Loss: 0.4860 | Val Loss: 0.5278 | Val Acc: 82.72% | MixUp α: 0.000


Epoch 13/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [13/350] | Train Loss: 0.4617 | Val Loss: 0.4722 | Val Acc: 84.38% | MixUp α: 0.000


Epoch 14/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [14/350] | Train Loss: 0.4326 | Val Loss: 0.4323 | Val Acc: 85.61% | MixUp α: 0.000


Epoch 15/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [15/350] | Train Loss: 0.4199 | Val Loss: 0.4049 | Val Acc: 86.32% | MixUp α: 0.000


Epoch 16/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [16/350] | Train Loss: 0.3955 | Val Loss: 0.3827 | Val Acc: 87.24% | MixUp α: 0.000


Epoch 17/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [17/350] | Train Loss: 0.3792 | Val Loss: 0.3648 | Val Acc: 87.87% | MixUp α: 0.000


Epoch 18/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [18/350] | Train Loss: 0.3584 | Val Loss: 0.3505 | Val Acc: 88.18% | MixUp α: 0.000


Epoch 19/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [19/350] | Train Loss: 0.3406 | Val Loss: 0.3381 | Val Acc: 88.66% | MixUp α: 0.000


Epoch 20/350 [LIGHT]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [20/350] | Train Loss: 0.3270 | Val Loss: 0.3272 | Val Acc: 88.99% | MixUp α: 0.000


Epoch 21/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [21/350] | Train Loss: 0.4548 | Val Loss: 0.3151 | Val Acc: 89.15% | MixUp α: 0.040


Epoch 22/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [22/350] | Train Loss: 0.5502 | Val Loss: 0.3051 | Val Acc: 89.65% | MixUp α: 0.080


Epoch 23/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [23/350] | Train Loss: 0.5841 | Val Loss: 0.2993 | Val Acc: 89.87% | MixUp α: 0.120


Epoch 24/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [24/350] | Train Loss: 0.6639 | Val Loss: 0.2959 | Val Acc: 90.28% | MixUp α: 0.160


Epoch 25/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [25/350] | Train Loss: 0.7547 | Val Loss: 0.2963 | Val Acc: 90.45% | MixUp α: 0.200


Epoch 26/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [26/350] | Train Loss: 0.7229 | Val Loss: 0.2952 | Val Acc: 90.69% | MixUp α: 0.200


Epoch 27/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [27/350] | Train Loss: 0.6712 | Val Loss: 0.2949 | Val Acc: 90.73% | MixUp α: 0.200


Epoch 28/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [28/350] | Train Loss: 0.6578 | Val Loss: 0.2936 | Val Acc: 90.96% | MixUp α: 0.200


Epoch 29/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [29/350] | Train Loss: 0.6615 | Val Loss: 0.2945 | Val Acc: 91.04% | MixUp α: 0.200


Epoch 30/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [30/350] | Train Loss: 0.6790 | Val Loss: 0.2948 | Val Acc: 91.13% | MixUp α: 0.200


Epoch 31/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [31/350] | Train Loss: 0.6636 | Val Loss: 0.2926 | Val Acc: 91.25% | MixUp α: 0.200


Epoch 32/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [32/350] | Train Loss: 0.5965 | Val Loss: 0.2857 | Val Acc: 91.63% | MixUp α: 0.200


Epoch 33/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [33/350] | Train Loss: 0.5840 | Val Loss: 0.2790 | Val Acc: 91.71% | MixUp α: 0.200


Epoch 34/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [34/350] | Train Loss: 0.5963 | Val Loss: 0.2748 | Val Acc: 91.80% | MixUp α: 0.200


Epoch 35/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [35/350] | Train Loss: 0.6040 | Val Loss: 0.2723 | Val Acc: 92.00% | MixUp α: 0.200


Epoch 36/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [36/350] | Train Loss: 0.5945 | Val Loss: 0.2703 | Val Acc: 92.12% | MixUp α: 0.200


Epoch 37/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [37/350] | Train Loss: 0.6084 | Val Loss: 0.2688 | Val Acc: 92.37% | MixUp α: 0.200


Epoch 38/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [38/350] | Train Loss: 0.5385 | Val Loss: 0.2625 | Val Acc: 92.55% | MixUp α: 0.200


Epoch 39/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [39/350] | Train Loss: 0.5648 | Val Loss: 0.2595 | Val Acc: 92.47% | MixUp α: 0.200


Epoch 40/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [40/350] | Train Loss: 0.5867 | Val Loss: 0.2589 | Val Acc: 92.60% | MixUp α: 0.200


Epoch 41/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [41/350] | Train Loss: 0.5622 | Val Loss: 0.2562 | Val Acc: 92.65% | MixUp α: 0.200


Epoch 42/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [42/350] | Train Loss: 0.5320 | Val Loss: 0.2526 | Val Acc: 92.76% | MixUp α: 0.200


Epoch 43/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [43/350] | Train Loss: 0.5517 | Val Loss: 0.2526 | Val Acc: 92.87% | MixUp α: 0.200


Epoch 44/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [44/350] | Train Loss: 0.5115 | Val Loss: 0.2492 | Val Acc: 92.91% | MixUp α: 0.200


Epoch 45/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [45/350] | Train Loss: 0.5196 | Val Loss: 0.2472 | Val Acc: 92.96% | MixUp α: 0.200


Epoch 46/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [46/350] | Train Loss: 0.5254 | Val Loss: 0.2459 | Val Acc: 92.98% | MixUp α: 0.200


Epoch 47/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [47/350] | Train Loss: 0.5635 | Val Loss: 0.2467 | Val Acc: 93.01% | MixUp α: 0.200


Epoch 48/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [48/350] | Train Loss: 0.5083 | Val Loss: 0.2444 | Val Acc: 93.09% | MixUp α: 0.200


Epoch 49/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [49/350] | Train Loss: 0.5553 | Val Loss: 0.2455 | Val Acc: 93.24% | MixUp α: 0.200


Epoch 50/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [50/350] | Train Loss: 0.5224 | Val Loss: 0.2451 | Val Acc: 93.39% | MixUp α: 0.200


Epoch 51/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [51/350] | Train Loss: 0.4862 | Val Loss: 0.2423 | Val Acc: 93.38% | MixUp α: 0.200


Epoch 52/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [52/350] | Train Loss: 0.4957 | Val Loss: 0.2397 | Val Acc: 93.39% | MixUp α: 0.200


Epoch 53/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [53/350] | Train Loss: 0.5184 | Val Loss: 0.2397 | Val Acc: 93.30% | MixUp α: 0.200


Epoch 54/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [54/350] | Train Loss: 0.5223 | Val Loss: 0.2411 | Val Acc: 93.27% | MixUp α: 0.200


Epoch 55/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [55/350] | Train Loss: 0.5000 | Val Loss: 0.2405 | Val Acc: 93.36% | MixUp α: 0.200


Epoch 56/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [56/350] | Train Loss: 0.5245 | Val Loss: 0.2416 | Val Acc: 93.40% | MixUp α: 0.200


Epoch 57/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [57/350] | Train Loss: 0.4811 | Val Loss: 0.2392 | Val Acc: 93.45% | MixUp α: 0.200


Epoch 58/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [58/350] | Train Loss: 0.5382 | Val Loss: 0.2407 | Val Acc: 93.51% | MixUp α: 0.200


Epoch 59/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [59/350] | Train Loss: 0.4872 | Val Loss: 0.2406 | Val Acc: 93.49% | MixUp α: 0.200


Epoch 60/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [60/350] | Train Loss: 0.4853 | Val Loss: 0.2396 | Val Acc: 93.59% | MixUp α: 0.200


Epoch 61/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [61/350] | Train Loss: 0.4798 | Val Loss: 0.2383 | Val Acc: 93.71% | MixUp α: 0.200


Epoch 62/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [62/350] | Train Loss: 0.4797 | Val Loss: 0.2375 | Val Acc: 93.71% | MixUp α: 0.200


Epoch 63/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [63/350] | Train Loss: 0.4644 | Val Loss: 0.2366 | Val Acc: 93.74% | MixUp α: 0.200


Epoch 64/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [64/350] | Train Loss: 0.4770 | Val Loss: 0.2364 | Val Acc: 93.66% | MixUp α: 0.200


Epoch 65/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [65/350] | Train Loss: 0.5153 | Val Loss: 0.2397 | Val Acc: 93.57% | MixUp α: 0.200


Epoch 66/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [66/350] | Train Loss: 0.4343 | Val Loss: 0.2365 | Val Acc: 93.59% | MixUp α: 0.200


Epoch 67/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [67/350] | Train Loss: 0.4973 | Val Loss: 0.2378 | Val Acc: 93.68% | MixUp α: 0.200


Epoch 68/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [68/350] | Train Loss: 0.4327 | Val Loss: 0.2352 | Val Acc: 93.71% | MixUp α: 0.200


Epoch 69/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [69/350] | Train Loss: 0.4661 | Val Loss: 0.2362 | Val Acc: 93.69% | MixUp α: 0.200


Epoch 70/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [70/350] | Train Loss: 0.4642 | Val Loss: 0.2360 | Val Acc: 93.77% | MixUp α: 0.200


Epoch 71/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [71/350] | Train Loss: 0.4357 | Val Loss: 0.2342 | Val Acc: 93.81% | MixUp α: 0.200


Epoch 72/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [72/350] | Train Loss: 0.4798 | Val Loss: 0.2357 | Val Acc: 93.84% | MixUp α: 0.200


Epoch 73/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [73/350] | Train Loss: 0.4843 | Val Loss: 0.2376 | Val Acc: 93.87% | MixUp α: 0.200


Epoch 74/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [74/350] | Train Loss: 0.4200 | Val Loss: 0.2336 | Val Acc: 93.90% | MixUp α: 0.200


Epoch 75/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [75/350] | Train Loss: 0.4891 | Val Loss: 0.2357 | Val Acc: 93.87% | MixUp α: 0.200


Epoch 76/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [76/350] | Train Loss: 0.4375 | Val Loss: 0.2345 | Val Acc: 93.78% | MixUp α: 0.200


Epoch 77/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [77/350] | Train Loss: 0.4041 | Val Loss: 0.2318 | Val Acc: 93.69% | MixUp α: 0.200


Epoch 78/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [78/350] | Train Loss: 0.4130 | Val Loss: 0.2302 | Val Acc: 93.74% | MixUp α: 0.200


Epoch 79/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [79/350] | Train Loss: 0.4393 | Val Loss: 0.2296 | Val Acc: 93.76% | MixUp α: 0.200


Epoch 80/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [80/350] | Train Loss: 0.4432 | Val Loss: 0.2312 | Val Acc: 93.76% | MixUp α: 0.200


Epoch 81/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [81/350] | Train Loss: 0.4667 | Val Loss: 0.2334 | Val Acc: 93.86% | MixUp α: 0.200


Epoch 82/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [82/350] | Train Loss: 0.4675 | Val Loss: 0.2334 | Val Acc: 93.89% | MixUp α: 0.200


Epoch 83/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [83/350] | Train Loss: 0.4951 | Val Loss: 0.2362 | Val Acc: 93.90% | MixUp α: 0.200


Epoch 84/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [84/350] | Train Loss: 0.4173 | Val Loss: 0.2341 | Val Acc: 93.96% | MixUp α: 0.200


Epoch 85/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [85/350] | Train Loss: 0.4451 | Val Loss: 0.2343 | Val Acc: 94.00% | MixUp α: 0.200


Epoch 86/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [86/350] | Train Loss: 0.4658 | Val Loss: 0.2354 | Val Acc: 94.01% | MixUp α: 0.200


Epoch 87/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [87/350] | Train Loss: 0.4057 | Val Loss: 0.2334 | Val Acc: 94.07% | MixUp α: 0.200


Epoch 88/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [88/350] | Train Loss: 0.4611 | Val Loss: 0.2342 | Val Acc: 94.04% | MixUp α: 0.200


Epoch 89/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [89/350] | Train Loss: 0.4437 | Val Loss: 0.2345 | Val Acc: 94.07% | MixUp α: 0.200


Epoch 90/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [90/350] | Train Loss: 0.4417 | Val Loss: 0.2342 | Val Acc: 94.14% | MixUp α: 0.200


Epoch 91/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [91/350] | Train Loss: 0.4104 | Val Loss: 0.2324 | Val Acc: 94.06% | MixUp α: 0.200


Epoch 92/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [92/350] | Train Loss: 0.4586 | Val Loss: 0.2346 | Val Acc: 94.08% | MixUp α: 0.200


Epoch 93/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [93/350] | Train Loss: 0.3945 | Val Loss: 0.2326 | Val Acc: 94.19% | MixUp α: 0.200


Epoch 94/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [94/350] | Train Loss: 0.4478 | Val Loss: 0.2317 | Val Acc: 94.09% | MixUp α: 0.200


Epoch 95/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [95/350] | Train Loss: 0.4423 | Val Loss: 0.2322 | Val Acc: 94.17% | MixUp α: 0.200


Epoch 96/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [96/350] | Train Loss: 0.4024 | Val Loss: 0.2313 | Val Acc: 94.14% | MixUp α: 0.200


Epoch 97/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [97/350] | Train Loss: 0.4694 | Val Loss: 0.2344 | Val Acc: 94.04% | MixUp α: 0.200


Epoch 98/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [98/350] | Train Loss: 0.4505 | Val Loss: 0.2360 | Val Acc: 94.07% | MixUp α: 0.200


Epoch 99/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [99/350] | Train Loss: 0.4165 | Val Loss: 0.2342 | Val Acc: 94.11% | MixUp α: 0.200


Epoch 100/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [100/350] | Train Loss: 0.4268 | Val Loss: 0.2331 | Val Acc: 94.08% | MixUp α: 0.200


Epoch 101/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [101/350] | Train Loss: 0.4468 | Val Loss: 0.2336 | Val Acc: 94.18% | MixUp α: 0.200


Epoch 102/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [102/350] | Train Loss: 0.4318 | Val Loss: 0.2334 | Val Acc: 94.19% | MixUp α: 0.200


Epoch 103/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [103/350] | Train Loss: 0.4336 | Val Loss: 0.2335 | Val Acc: 94.18% | MixUp α: 0.200


Epoch 104/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [104/350] | Train Loss: 0.4100 | Val Loss: 0.2329 | Val Acc: 94.20% | MixUp α: 0.200


Epoch 105/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [105/350] | Train Loss: 0.4012 | Val Loss: 0.2319 | Val Acc: 94.25% | MixUp α: 0.200


Epoch 106/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [106/350] | Train Loss: 0.4114 | Val Loss: 0.2316 | Val Acc: 94.23% | MixUp α: 0.200


Epoch 107/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [107/350] | Train Loss: 0.4078 | Val Loss: 0.2301 | Val Acc: 94.25% | MixUp α: 0.200


Epoch 108/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [108/350] | Train Loss: 0.3679 | Val Loss: 0.2276 | Val Acc: 94.28% | MixUp α: 0.200


Epoch 109/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [109/350] | Train Loss: 0.4220 | Val Loss: 0.2289 | Val Acc: 94.34% | MixUp α: 0.200


Epoch 110/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [110/350] | Train Loss: 0.3772 | Val Loss: 0.2261 | Val Acc: 94.33% | MixUp α: 0.200


Epoch 111/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [111/350] | Train Loss: 0.4169 | Val Loss: 0.2263 | Val Acc: 94.25% | MixUp α: 0.200


Epoch 112/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [112/350] | Train Loss: 0.4169 | Val Loss: 0.2278 | Val Acc: 94.23% | MixUp α: 0.200


Epoch 113/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [113/350] | Train Loss: 0.3941 | Val Loss: 0.2270 | Val Acc: 94.29% | MixUp α: 0.200


Epoch 114/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [114/350] | Train Loss: 0.4371 | Val Loss: 0.2286 | Val Acc: 94.43% | MixUp α: 0.200


Epoch 115/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [115/350] | Train Loss: 0.3791 | Val Loss: 0.2266 | Val Acc: 94.33% | MixUp α: 0.200


Epoch 116/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [116/350] | Train Loss: 0.4328 | Val Loss: 0.2275 | Val Acc: 94.31% | MixUp α: 0.200


Epoch 117/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [117/350] | Train Loss: 0.3949 | Val Loss: 0.2271 | Val Acc: 94.38% | MixUp α: 0.200


Epoch 118/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [118/350] | Train Loss: 0.3905 | Val Loss: 0.2264 | Val Acc: 94.42% | MixUp α: 0.200


Epoch 119/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [119/350] | Train Loss: 0.4290 | Val Loss: 0.2263 | Val Acc: 94.35% | MixUp α: 0.200


Epoch 120/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [120/350] | Train Loss: 0.3861 | Val Loss: 0.2247 | Val Acc: 94.32% | MixUp α: 0.200


Epoch 121/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [121/350] | Train Loss: 0.4148 | Val Loss: 0.2252 | Val Acc: 94.39% | MixUp α: 0.200


Epoch 122/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [122/350] | Train Loss: 0.4231 | Val Loss: 0.2262 | Val Acc: 94.39% | MixUp α: 0.200


Epoch 123/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [123/350] | Train Loss: 0.3807 | Val Loss: 0.2250 | Val Acc: 94.43% | MixUp α: 0.200


Epoch 124/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [124/350] | Train Loss: 0.3891 | Val Loss: 0.2250 | Val Acc: 94.50% | MixUp α: 0.200


Epoch 125/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [125/350] | Train Loss: 0.4076 | Val Loss: 0.2253 | Val Acc: 94.35% | MixUp α: 0.200


Epoch 126/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [126/350] | Train Loss: 0.4090 | Val Loss: 0.2258 | Val Acc: 94.37% | MixUp α: 0.200


Epoch 127/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [127/350] | Train Loss: 0.4343 | Val Loss: 0.2267 | Val Acc: 94.45% | MixUp α: 0.200


Epoch 128/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [128/350] | Train Loss: 0.3864 | Val Loss: 0.2257 | Val Acc: 94.38% | MixUp α: 0.200


Epoch 129/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [129/350] | Train Loss: 0.4213 | Val Loss: 0.2265 | Val Acc: 94.34% | MixUp α: 0.200


Epoch 130/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [130/350] | Train Loss: 0.4481 | Val Loss: 0.2286 | Val Acc: 94.48% | MixUp α: 0.200


Epoch 131/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [131/350] | Train Loss: 0.3804 | Val Loss: 0.2265 | Val Acc: 94.48% | MixUp α: 0.200


Epoch 132/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [132/350] | Train Loss: 0.4015 | Val Loss: 0.2260 | Val Acc: 94.44% | MixUp α: 0.200


Epoch 133/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [133/350] | Train Loss: 0.4275 | Val Loss: 0.2269 | Val Acc: 94.45% | MixUp α: 0.200


Epoch 134/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [134/350] | Train Loss: 0.4022 | Val Loss: 0.2269 | Val Acc: 94.43% | MixUp α: 0.200


Epoch 135/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [135/350] | Train Loss: 0.4581 | Val Loss: 0.2275 | Val Acc: 94.41% | MixUp α: 0.200


Epoch 136/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [136/350] | Train Loss: 0.3994 | Val Loss: 0.2266 | Val Acc: 94.30% | MixUp α: 0.200


Epoch 137/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [137/350] | Train Loss: 0.3927 | Val Loss: 0.2272 | Val Acc: 94.28% | MixUp α: 0.200


Epoch 138/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [138/350] | Train Loss: 0.4009 | Val Loss: 0.2274 | Val Acc: 94.26% | MixUp α: 0.200


Epoch 139/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [139/350] | Train Loss: 0.3896 | Val Loss: 0.2263 | Val Acc: 94.30% | MixUp α: 0.200


Epoch 140/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [140/350] | Train Loss: 0.4326 | Val Loss: 0.2271 | Val Acc: 94.34% | MixUp α: 0.200


Epoch 141/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [141/350] | Train Loss: 0.3955 | Val Loss: 0.2275 | Val Acc: 94.21% | MixUp α: 0.200


Epoch 142/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [142/350] | Train Loss: 0.4258 | Val Loss: 0.2280 | Val Acc: 94.18% | MixUp α: 0.200


Epoch 143/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [143/350] | Train Loss: 0.3746 | Val Loss: 0.2265 | Val Acc: 94.25% | MixUp α: 0.200


Epoch 144/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [144/350] | Train Loss: 0.3779 | Val Loss: 0.2245 | Val Acc: 94.38% | MixUp α: 0.200


Epoch 145/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [145/350] | Train Loss: 0.3815 | Val Loss: 0.2235 | Val Acc: 94.34% | MixUp α: 0.200


Epoch 146/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [146/350] | Train Loss: 0.3753 | Val Loss: 0.2233 | Val Acc: 94.40% | MixUp α: 0.200


Epoch 147/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [147/350] | Train Loss: 0.4093 | Val Loss: 0.2248 | Val Acc: 94.39% | MixUp α: 0.200


Epoch 148/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [148/350] | Train Loss: 0.3677 | Val Loss: 0.2236 | Val Acc: 94.39% | MixUp α: 0.200


Epoch 149/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [149/350] | Train Loss: 0.4019 | Val Loss: 0.2236 | Val Acc: 94.45% | MixUp α: 0.200


Epoch 150/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [150/350] | Train Loss: 0.4068 | Val Loss: 0.2236 | Val Acc: 94.49% | MixUp α: 0.200


Epoch 151/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [151/350] | Train Loss: 0.4367 | Val Loss: 0.2247 | Val Acc: 94.42% | MixUp α: 0.200


Epoch 152/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [152/350] | Train Loss: 0.4480 | Val Loss: 0.2261 | Val Acc: 94.48% | MixUp α: 0.200


Epoch 153/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [153/350] | Train Loss: 0.4024 | Val Loss: 0.2238 | Val Acc: 94.50% | MixUp α: 0.200


Epoch 154/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [154/350] | Train Loss: 0.3486 | Val Loss: 0.2211 | Val Acc: 94.57% | MixUp α: 0.200


Epoch 155/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [155/350] | Train Loss: 0.3834 | Val Loss: 0.2205 | Val Acc: 94.57% | MixUp α: 0.200


Epoch 156/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [156/350] | Train Loss: 0.4001 | Val Loss: 0.2205 | Val Acc: 94.46% | MixUp α: 0.200


Epoch 157/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [157/350] | Train Loss: 0.3960 | Val Loss: 0.2204 | Val Acc: 94.42% | MixUp α: 0.200


Epoch 158/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [158/350] | Train Loss: 0.4142 | Val Loss: 0.2200 | Val Acc: 94.54% | MixUp α: 0.200


Epoch 159/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [159/350] | Train Loss: 0.3813 | Val Loss: 0.2179 | Val Acc: 94.62% | MixUp α: 0.200


Epoch 160/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [160/350] | Train Loss: 0.4194 | Val Loss: 0.2197 | Val Acc: 94.67% | MixUp α: 0.200


Epoch 161/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [161/350] | Train Loss: 0.3718 | Val Loss: 0.2185 | Val Acc: 94.63% | MixUp α: 0.200


Epoch 162/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [162/350] | Train Loss: 0.4178 | Val Loss: 0.2202 | Val Acc: 94.58% | MixUp α: 0.200


Epoch 163/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [163/350] | Train Loss: 0.3909 | Val Loss: 0.2204 | Val Acc: 94.54% | MixUp α: 0.200


Epoch 164/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [164/350] | Train Loss: 0.3429 | Val Loss: 0.2195 | Val Acc: 94.60% | MixUp α: 0.200


Epoch 165/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [165/350] | Train Loss: 0.3669 | Val Loss: 0.2187 | Val Acc: 94.64% | MixUp α: 0.200


Epoch 166/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [166/350] | Train Loss: 0.4081 | Val Loss: 0.2207 | Val Acc: 94.53% | MixUp α: 0.200


Epoch 167/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [167/350] | Train Loss: 0.3843 | Val Loss: 0.2207 | Val Acc: 94.54% | MixUp α: 0.200


Epoch 168/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [168/350] | Train Loss: 0.3140 | Val Loss: 0.2179 | Val Acc: 94.46% | MixUp α: 0.200


Epoch 169/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [169/350] | Train Loss: 0.3938 | Val Loss: 0.2204 | Val Acc: 94.46% | MixUp α: 0.200


Epoch 170/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [170/350] | Train Loss: 0.3602 | Val Loss: 0.2200 | Val Acc: 94.54% | MixUp α: 0.200


Epoch 171/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [171/350] | Train Loss: 0.3509 | Val Loss: 0.2188 | Val Acc: 94.47% | MixUp α: 0.200


Epoch 172/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [172/350] | Train Loss: 0.3961 | Val Loss: 0.2184 | Val Acc: 94.54% | MixUp α: 0.200


Epoch 173/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [173/350] | Train Loss: 0.3845 | Val Loss: 0.2192 | Val Acc: 94.55% | MixUp α: 0.200


Epoch 174/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [174/350] | Train Loss: 0.3797 | Val Loss: 0.2192 | Val Acc: 94.45% | MixUp α: 0.200


Epoch 175/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [175/350] | Train Loss: 0.3958 | Val Loss: 0.2197 | Val Acc: 94.42% | MixUp α: 0.200


Epoch 176/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [176/350] | Train Loss: 0.3681 | Val Loss: 0.2191 | Val Acc: 94.45% | MixUp α: 0.200


Epoch 177/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [177/350] | Train Loss: 0.3537 | Val Loss: 0.2175 | Val Acc: 94.39% | MixUp α: 0.200


Epoch 178/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [178/350] | Train Loss: 0.3837 | Val Loss: 0.2174 | Val Acc: 94.44% | MixUp α: 0.200


Epoch 179/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [179/350] | Train Loss: 0.3874 | Val Loss: 0.2177 | Val Acc: 94.47% | MixUp α: 0.200


Epoch 180/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [180/350] | Train Loss: 0.3763 | Val Loss: 0.2167 | Val Acc: 94.59% | MixUp α: 0.200


Epoch 181/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [181/350] | Train Loss: 0.3742 | Val Loss: 0.2178 | Val Acc: 94.56% | MixUp α: 0.200


Epoch 182/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [182/350] | Train Loss: 0.3867 | Val Loss: 0.2182 | Val Acc: 94.57% | MixUp α: 0.200


Epoch 183/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [183/350] | Train Loss: 0.3877 | Val Loss: 0.2185 | Val Acc: 94.50% | MixUp α: 0.200


Epoch 184/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [184/350] | Train Loss: 0.3875 | Val Loss: 0.2178 | Val Acc: 94.54% | MixUp α: 0.200


Epoch 185/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [185/350] | Train Loss: 0.3907 | Val Loss: 0.2180 | Val Acc: 94.48% | MixUp α: 0.200


Epoch 186/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [186/350] | Train Loss: 0.3728 | Val Loss: 0.2165 | Val Acc: 94.57% | MixUp α: 0.200


Epoch 187/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [187/350] | Train Loss: 0.3756 | Val Loss: 0.2182 | Val Acc: 94.47% | MixUp α: 0.200


Epoch 188/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [188/350] | Train Loss: 0.4028 | Val Loss: 0.2188 | Val Acc: 94.58% | MixUp α: 0.200


Epoch 189/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [189/350] | Train Loss: 0.3864 | Val Loss: 0.2181 | Val Acc: 94.60% | MixUp α: 0.200


Epoch 190/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [190/350] | Train Loss: 0.3847 | Val Loss: 0.2174 | Val Acc: 94.64% | MixUp α: 0.200


Epoch 191/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [191/350] | Train Loss: 0.3468 | Val Loss: 0.2164 | Val Acc: 94.66% | MixUp α: 0.200


Epoch 192/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch [192/350] | Train Loss: 0.4004 | Val Loss: 0.2187 | Val Acc: 94.54% | MixUp α: 0.200


Epoch 193/350 [STRONG]:   0%|          | 0/391 [00:00<?, ?it/s]

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x000002DC0EEA69E0>>
Traceback (most recent call last):
  File "C:\Users\zelkh\research\lib\site-packages\ipykernel\ipkernel.py", line 781, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


KeyboardInterrupt: 

In [29]:
import os
import csv

def save_weights_to_csv(model, directory="my_model_weights"):
    # Create the directory if it does not exist
    if not os.path.exists(directory):
        os.makedirs(directory)
    
    # Iterate over each parameter in the model
    for name, param in model.named_parameters():
        # Convert the parameter tensor to a flattened numpy array
        weights = param.data.cpu().numpy().flatten()
        
        # Replace dots in the parameter name with underscores and form the full file path
        filename = f"{directory}/{name.replace('.', '_')}.csv"
        
        # Open the CSV file and write the weights
        with open(filename, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(weights)  # Write weights as a single row
        
        print(f"Weights saved to {filename}")
    
    print("Weight extraction complete!")

# Call this function with your model variable
save_weights_to_csv(model)

Weights saved to my_model_weights/fc1_weight.csv
Weights saved to my_model_weights/fc1_bias.csv
Weights saved to my_model_weights/fc2_weight.csv
Weights saved to my_model_weights/fc2_bias.csv
Weight extraction complete!


In [36]:
import torch
import os
import csv

# Create directory for weights if it doesn't exist
os.makedirs('my_model_weights', exist_ok=True)

# Function to save weights to CSV
def save_to_csv(tensor, filename):
    with open(filename, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        # Flatten the tensor and write as a single row
        writer.writerow(tensor.flatten().numpy())

# Load the trained model
activation_fn = PolyAct4()
model = CIFAR_CNN(activation_fn=activation_fn)
model.load_state_dict(torch.load('cifar10_trained_model.pth'))

# Method to get layer by name
def get_layer_by_name(model, name):
    modules = name.split('.')
    m = model
    for mod in modules:
        m = getattr(m, mod)
    return m

# Extract and save layer weights
block_layers = [
    'block1.0', 'block1.1', 'block1.3', 'block1.4',
    'block2.0', 'block2.1', 'block2.3', 'block2.4',
    'block3.0', 'block3.1', 'block3.3', 'block3.4',
    'block4.0', 'block4.1', 'block4.3', 'block4.4'
]

for layer_name in block_layers:
    try:
        layer = get_layer_by_name(model, layer_name)
        if isinstance(layer, (nn.Conv2d, nn.BatchNorm2d)):
            if hasattr(layer, 'weight'):
                save_to_csv(layer.weight.data, f'my_model_weights/{layer_name.replace(".", "_")}_weight.csv')
            if hasattr(layer, 'bias'):
                save_to_csv(layer.bias.data, f'my_model_weights/{layer_name.replace(".", "_")}_bias.csv')
            if isinstance(layer, nn.BatchNorm2d):
                save_to_csv(layer.running_mean, f'my_model_weights/{layer_name.replace(".", "_")}_running_mean.csv')
                save_to_csv(layer.running_var, f'my_model_weights/{layer_name.replace(".", "_")}_running_var.csv')
    except Exception as e:
        print(f"Error processing {layer_name}: {e}")

# Extract and save FC layer weights and BatchNorm
fc_layers = ['fc1', 'fc_bn', 'fc2']
for layer_name in fc_layers:
    try:
        layer = get_layer_by_name(model, layer_name)
        if isinstance(layer, nn.Linear):
            save_to_csv(layer.weight.data, f'my_model_weights/{layer_name}_weight.csv')
            save_to_csv(layer.bias.data, f'my_model_weights/{layer_name}_bias.csv')
        elif isinstance(layer, nn.BatchNorm1d):
            save_to_csv(layer.weight.data, f'my_model_weights/{layer_name}_weight.csv')
            save_to_csv(layer.bias.data, f'my_model_weights/{layer_name}_bias.csv')
            save_to_csv(layer.running_mean, f'my_model_weights/{layer_name}_running_mean.csv')
            save_to_csv(layer.running_var, f'my_model_weights/{layer_name}_running_var.csv')
    except Exception as e:
        print(f"Error processing {layer_name}: {e}")

print("Weight extraction complete!")

  model.load_state_dict(torch.load('cifar10_trained_model.pth'))


Weight extraction complete!


In [37]:
import os
import csv
import numpy as np

def preview_csv(filename, max_elements=10):
    with open(filename, 'r') as csvfile:
        reader = csv.reader(csvfile)
        data = list(reader)[0]  # First (and only) row
        data = [float(x) for x in data]
        
        print(f"File: {filename}")
        print(f"Total elements: {len(data)}")
        print("First few elements:", data[:max_elements])
        print("Shape:", len(data))
        print("Min:", min(data))
        print("Max:", max(data))
        print("Mean:", np.mean(data))
        print("Std Dev:", np.std(data))
        print("-" * 50)

# Directory containing the weights
weight_dir = 'my_model_weights'

# List all CSV files
csv_files = [f for f in os.listdir(weight_dir) if f.endswith('.csv')]

# Sort files for easier reading
csv_files.sort()

# Preview each CSV file
for filename in csv_files:
    full_path = os.path.join(weight_dir, filename)
    preview_csv(full_path)

File: my_model_weights/block1_0_bias.csv
Total elements: 48
First few elements: [-0.14158371, 0.086564094, 0.03638618, 0.15626998, 0.0087235095, -0.01983237, 0.077582106, -0.23901737, -0.066613145, 0.156942]
Shape: 48
Min: -0.47316253
Max: 0.2462121
Mean: -0.01763541479375
Std Dev: 0.14122533519858566
--------------------------------------------------
File: my_model_weights/block1_0_weight.csv
Total elements: 1296
First few elements: [-0.408994, -0.27043045, -0.1315946, 0.15104751, 0.6335488, 0.010418372, -0.24804702, 0.08897946, 0.2502869, 0.7613002]
Shape: 1296
Min: -1.820785
Max: 1.0271152
Mean: -0.00042038506412036864
Std Dev: 0.37214929194342206
--------------------------------------------------
File: my_model_weights/block1_1_bias.csv
Total elements: 48
First few elements: [2.239237, -1.6459881, -0.7095814, 2.4580042, -2.64024, -0.07478788, 2.497212, -1.5947889, -0.06567605, 0.6825575]
Shape: 48
Min: -2.7919664
Max: 3.4301686
Mean: -0.23573333333333335
Std Dev: 1.578670837339114


In [23]:
# Assuming `model` is your trained or loaded model
save_weights_to_csv(model, directory="my_model_weights")


Weights saved to my_model_weights/fc1_weight.csv
Weights saved to my_model_weights/fc1_bias.csv
Weights saved to my_model_weights/fc2_weight.csv
Weights saved to my_model_weights/fc2_bias.csv


In [24]:
import torch
import csv
import numpy as np

# Function to save a subset of CIFAR-10 data
def save_cifar10_data_subset(dataloader, percentage=10, filename="cifar10_subset.csv"):
    total_batches = len(dataloader)
    batches_to_save = int((percentage / 100) * total_batches)
    
    with open(filename, 'w', newline='') as file:
        writer = csv.writer(file)
        
        for i, (images, labels) in enumerate(dataloader):
            if i >= batches_to_save:
                break
            # Flatten images and convert them to numpy arrays
            images = images.numpy().reshape(images.shape[0], -1)
            labels = labels.numpy()
            
            # Write each image and label pair to the CSV file
            for img, label in zip(images, labels):
                # Optionally, append label to the image data
                img_with_label = np.append(img, label)
                writer.writerow(img_with_label)
    
    print(f"Saved {batches_to_save} batches of data to {filename}")

# Example usage
save_cifar10_data_subset(testloader, 10)


Saved 1000 batches of data to cifar10_subset.csv


In [3]:
# Install TenSEAL if not done
!pip install tenseal




In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Dataset preparation
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)

# Polynomial activation
class PolyAct4(nn.Module):
    def __init__(self, A, B, C, D, E):
        super().__init__()
        self.A = A
        self.B = B
        self.C = C
        self.D = D
        self.E = E

    def forward(self, x):
        x = torch.clamp(x, min=-6, max=6)
        return self.A * x**4 + self.B * x**3 + self.C * x**2 + self.D * x + self.E

poly4 = PolyAct4(
    A=-0.00068481,
    B=0.0,  # optional: skip B term to simplify
    C=0.0887234775,
    D=0.5,
    E=0.738099333
)

# Simple 2-layer model
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3072, 512)
        self.act = poly4
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training
for epoch in range(5):  # Just 5 epochs to test concept
    model.train()
    for images, labels in trainloader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}/5 done ✅")

# Save weights
torch.save(model.state_dict(), 'simple_model.pth')
print("Simple model saved! ✅")


Files already downloaded and verified
Files already downloaded and verified
Epoch 1/5 done ✅
Epoch 2/5 done ✅
Epoch 3/5 done ✅
Epoch 4/5 done ✅
Epoch 5/5 done ✅
Simple model saved! ✅
