# 🧠 EGRR — Entropy-Gated Recursive Residual Network

**Target**: >70% Top-1 accuracy on CIFAR-100 with <500K parameters

### Three Core Mechanisms
1. **Symmetric Shared-Weight 1×1 Conv** — `W = L + Lᵀ` for spectral stability
2. **Entropy-Gated Dynamic Dilation** — local variance selects dilation rate `d ∈ {1, 2, 4}`
3. **Iteration-Specific Normalization** — per-recursion affine params `(γ_t, β_t)`

### Recursive Update Rule
```
h_t = h_{t-1} + ReLU6(IS-Norm_t(DWConv_gated(SymConv(h_{t-1}))))
```

---

## 0. Setup & GPU Check

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as grad_checkpoint
import torchvision
import torchvision.transforms as transforms
import numpy as np
import math
import time
import os

# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"üñ•Ô∏è  Device: {device}")
if device.type == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    # The device-properties field is `total_memory` (in bytes). The previous
    # `total_mem` attribute does not exist and raised AttributeError.
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    # Let cuDNN auto-tune its convolution algorithms.
    torch.backends.cudnn.benchmark = True

## 1. Configuration

In [None]:
# ══════════════════════════════════════════════
# Configuration — all hyperparameters
# ══════════════════════════════════════════════

class Config:
    """Single namespace holding every hyperparameter used by the notebook.

    All settings are plain class attributes. ``STAGES`` is a read-only
    property that resolves to the stage table matching ``MODEL_VARIANT``.
    Each stage entry is a ``(channels, T_recursions, stride)`` tuple.
    """

    # ---- Model -----------------------------------------------------------
    MODEL_VARIANT = "base"  # "base" or "deep"
    NUM_CLASSES = 100
    STEM_CHANNELS = 32

    # Stage tables: (channels, T_recursions, stride).
    # Resolution comments follow the original notes for a 32x32 input.
    STAGES_BASE = [
        (32, 3, 1),    # Stage 1: 32x32 -> 32x32
        (64, 4, 2),    # Stage 2: 32x32 -> 16x16
        (64, 4, 1),    # Stage 3: 16x16
        (128, 5, 2),   # Stage 4: 16x16 -> 8x8
        (128, 5, 1),   # Stage 5: 8x8
        (128, 5, 1),   # Stage 6: 8x8
        (256, 6, 2),   # Stage 7: 8x8 -> 4x4
        (256, 6, 1),   # Stage 8: 4x4
    ]
    # Same channel/stride layout as the base table with doubled recursions.
    STAGES_DEEP = [
        (32, 6, 1),
        (64, 8, 2),
        (64, 8, 1),
        (128, 10, 2),
        (128, 10, 1),
        (128, 10, 1),
        (256, 12, 2),
        (256, 12, 1),
    ]
    WIDTH_MULT = 1.22

    # ---- Entropy gate ----------------------------------------------------
    DILATION_RATES = [1, 2, 4]
    GUMBEL_TAU_START = 1.0
    GUMBEL_TAU_END = 0.1

    # ---- Optimisation ----------------------------------------------------
    BATCH_SIZE = 128
    EPOCHS = 200
    LEARNING_RATE = 0.1
    MOMENTUM = 0.9
    WEIGHT_DECAY = 5e-4
    LABEL_SMOOTHING = 0.1
    LR_MIN = 0.0
    LR_WARMUP_EPOCHS = 5
    DEPTH_WARMUP_END_EPOCH = 20
    CUTOUT_LENGTH = 8
    USE_AUTOAUGMENT = True

    # ---- Memory ----------------------------------------------------------
    USE_AMP = True
    USE_GRADIENT_CHECKPOINTING = True

    @property
    def STAGES(self):
        """Return the stage table selected by ``MODEL_VARIANT``."""
        if self.MODEL_VARIANT == "deep":
            return self.STAGES_DEEP
        return self.STAGES_BASE


cfg = Config()
print(f"Model variant: {cfg.MODEL_VARIANT}")
print(f"Stages: {len(cfg.STAGES)}, Virtual depth: {sum(T for _, T, _ in cfg.STAGES)}")

---
## 2. Core Modules

### 2.1 Symmetric Shared-Weight 1×1 Convolution

Parameterized as `W = L + Lᵀ` where L is lower-triangular.  
**Why?** Symmetric matrices have real eigenvalues → stabilizes recursive weight sharing.  
**Bonus:** Only `C(C+1)/2` unique parameters instead of `C²`.

In [None]:
class SymmetricConv1x1(nn.Module):
    """Pointwise (1x1) convolution whose channel-mixing matrix is symmetric.

    The effective kernel is ``W = tril(L) + tril(L)^T``. Only the lower
    triangle (diagonal included) of the stored parameter ``L`` is used, so
    ``W`` is symmetric by construction and therefore has real eigenvalues.
    The diagonal of ``L`` contributes twice to ``W``.

    ``L`` is stored as a dense ``(C, C)`` parameter. Its strict upper triangle
    is masked out and never affects the output, which leaves
    ``C * (C + 1) / 2`` effective degrees of freedom (see ``unique_params``).

    The assembled kernel is cached for inference. The cache is used only
    when the module is in eval mode and autograd is disabled, and it is keyed
    on the parameter's version counter, device and dtype. In-place updates,
    ``load_state_dict`` and device/dtype moves therefore rebuild it instead of
    silently serving stale weights.

    Args:
        channels: Number of input channels, which is also the number of
            output channels.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels
        self.L = nn.Parameter(torch.empty(channels, channels))
        # Inference cache: assembled 4-D kernel and the state it was built from.
        self._cached_weight = None
        self._cache_key = None
        self._init_orthogonal()
        self.register_buffer("tril_mask", torch.tril(torch.ones(channels, channels)))

    def _invalidate_cache(self):
        """Forget the cached inference kernel."""
        self._cached_weight = None
        self._cache_key = None

    def _init_orthogonal(self):
        """Fill ``L`` with half of the Q factor of a random Gaussian matrix."""
        Q = torch.linalg.qr(torch.randn(self.channels, self.channels))[0]
        with torch.no_grad():
            self.L.copy_(Q / 2.0)
        self._invalidate_cache()

    def get_weight(self):
        """Return the symmetric ``(C, C)`` matrix ``tril(L) + tril(L)^T``."""
        L_lower = self.L * self.tril_mask
        return L_lower + L_lower.transpose(0, 1)

    def train(self, mode=True):
        """Switch train/eval mode and drop the cached kernel on every switch."""
        super().train(mode)
        self._invalidate_cache()
        return self

    def forward(self, x):
        """Apply the symmetric 1x1 convolution to ``x``.

        Args:
            x: Input tensor accepted by ``F.conv2d`` with ``channels`` input
                channels.

        Returns:
            The convolved tensor; spatial dimensions are unchanged.
        """
        # Caching is only safe when no gradient has to reach ``L``. A kernel
        # cached with an autograd graph cannot be backpropagated through
        # twice, and one cached without a graph would block gradients.
        if self.training or torch.is_grad_enabled():
            return F.conv2d(x, self.get_weight().unsqueeze(-1).unsqueeze(-1))

        # ``_version`` is bumped by every in-place change to the parameter
        # (optimizer step, ``load_state_dict``); device and dtype cover moves.
        key = (self.L._version, self.L.device, self.L.dtype)
        if self._cached_weight is None or self._cache_key != key:
            self._cached_weight = self.get_weight().detach().unsqueeze(-1).unsqueeze(-1)
            self._cache_key = key
        return F.conv2d(x, self._cached_weight)

    @property
    def unique_params(self):
        """Number of entries of ``L`` that influence ``W``: ``C (C + 1) / 2``."""
        return self.channels * (self.channels + 1) // 2


# -- Quick test --
sym = SymmetricConv1x1(64)
W = sym.get_weight()
print(f"‚úÖ SymmetricConv1x1: W shape={W.shape}, symmetric={torch.allclose(W, W.T, atol=1e-7)}")
print(f"   Unique params: {sym.unique_params} vs full: {64*64}")

### 2.2 Iteration-Specific Normalization (IS-Norm)

Each recursion step `t` gets its own affine `(γ_t, β_t)`, but **shares** running mean/var.  
Lets the network "shift gears" at each recursion without full BN cost.  
**Cost**: Only `2 × T × C` extra params (negligible).

In [None]:
class ISNorm(nn.Module):
    """Batch normalization with one affine pair per recursion step.

    The running mean/variance buffers are shared by all steps, while the scale
    ``gamma`` and shift ``beta`` are stored as ``(num_iterations,
    num_features)`` tables and indexed by the step ``t`` passed to
    ``forward``. Normalization itself is delegated to ``F.batch_norm``.

    Args:
        num_features: Number of channels being normalized.
        num_iterations: Number of recursion steps, i.e. rows in the affine
            tables.
        eps: Value added to the variance for numerical stability.
        momentum: Running-statistics momentum handed to ``F.batch_norm``.
    """

    def __init__(self, num_features: int, num_iterations: int,
                 eps: float = 1e-5, momentum: float = 0.1):
        super().__init__()
        self.num_features, self.num_iterations = num_features, num_iterations
        self.eps, self.momentum = eps, momentum

        # Statistics shared across every recursion step.
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))

        # Step-specific affine tables, one row per recursion step.
        table_shape = (num_iterations, num_features)
        self.gamma = nn.Parameter(torch.ones(table_shape))
        self.beta = nn.Parameter(torch.zeros(table_shape))

    def forward(self, x, t: int):
        """Normalize ``x`` using the affine parameters of recursion step ``t``.

        Args:
            x: Input tensor with ``num_features`` channels.
            t: Recursion step index in ``[0, num_iterations)``.

        Returns:
            The normalized tensor, same shape as ``x``.
        """
        assert 0 <= t < self.num_iterations
        is_training = self.training
        if is_training:
            # Bookkeeping counter only; the momentum itself stays fixed.
            self.num_batches_tracked += 1
        scale, shift = self.gamma[t], self.beta[t]
        return F.batch_norm(
            x, self.running_mean, self.running_var, scale, shift,
            is_training, self.momentum, self.eps,
        )


# -- Quick test --
norm = ISNorm(64, num_iterations=4)
x_test = torch.randn(2, 64, 8, 8)
norm.train()
for t in range(4):
    out = norm(x_test, t)
print(f"‚úÖ ISNorm: T=4, output shape={out.shape}, params={sum(p.numel() for p in norm.parameters())}")

### 2.3 Entropy-Gated Dynamic Dilation

Uses **local variance** as a differentiable entropy proxy.  
A lightweight decision head (`1×1 Conv → ReLU → 1×1 Conv`) maps variance to 3-way logits.  
- **Training**: Gumbel-Softmax (τ annealed from 1.0 → 0.1)
- **Inference**: Hard argmax → skip unused dilations

In [None]:
class EntropyGate(nn.Module):
    """Choose a dilation branch per sample from local activation variance.

    Pipeline: local variance ``V = AvgPool(x^2) - AvgPool(x)^2`` is fed to a
    two-layer 1x1-conv decision head, the result is globally average-pooled
    into per-sample logits, and the logits become gate weights.

    * Training: soft weights from Gumbel-Softmax with temperature ``tau``.
    * Eval: hard one-hot weights from ``argmax``.

    Args:
        channels: Number of channels of the input feature map.
        pool_size: Window size of the local average pooling.
        num_dilations: Number of gate outputs (one per dilation branch).
        tau: Initial Gumbel-Softmax temperature; change it via ``set_tau``.
    """

    def __init__(self, channels: int, pool_size: int = 3,
                 num_dilations: int = 3, tau: float = 1.0):
        super().__init__()
        self.channels = channels
        self.num_dilations = num_dilations
        self.tau = tau

        # Stride-1 pooling with "same" padding keeps the spatial size.
        # count_include_pad=False averages border windows over valid pixels only.
        self.avg_pool = nn.AvgPool2d(pool_size, stride=1,
                                     padding=pool_size // 2,
                                     count_include_pad=False)
        self.decision_head = nn.Sequential(
            nn.Conv2d(channels, channels // 4, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // 4, num_dilations, 1, bias=True),
        )
        self.global_pool = nn.AdaptiveAvgPool2d(1)

    def compute_local_variance(self, x):
        """Return the windowed variance ``E[x^2] - E[x]^2``, clamped at zero.

        The clamp removes small negative values produced by floating-point
        cancellation. The result has the same shape as ``x``.
        """
        ex = self.avg_pool(x)
        e_x2 = self.avg_pool(x.square())
        return (e_x2 - ex.square()).clamp_(min=0.0)

    def forward(self, x):
        """Compute gate weights for ``x``.

        Args:
            x: Feature map with ``channels`` channels. The shape comments
                assume a batched ``(N, C, H, W)`` input.

        Returns:
            Tensor of shape ``(N, num_dilations, 1, 1)`` whose entries sum to
            one along dim 1: soft weights in training, one-hot float weights
            in eval.
        """
        variance = self.compute_local_variance(x)
        logits = self.decision_head(variance)
        del variance  # release the full-resolution map as early as possible
        logits = self.global_pool(logits).squeeze(-1).squeeze(-1)  # (N, num_dilations)

        if self.training:
            # Sample in float32. Under autocast the logits can be fp16, where
            # the Gumbel noise (log of an exponential sample) can overflow to
            # inf/NaN. For float32 logits both casts are no-ops.
            gate_weights = F.gumbel_softmax(
                logits.float(), tau=self.tau, hard=False, dim=-1
            ).to(logits.dtype)
        else:
            gate_weights = F.one_hot(
                logits.argmax(dim=-1), num_classes=self.num_dilations
            ).float()

        return gate_weights.unsqueeze(-1).unsqueeze(-1)  # (N, num_dilations, 1, 1)

    def set_tau(self, tau):
        """Set the Gumbel-Softmax temperature used in training mode."""
        self.tau = tau


# -- Quick test --
gate = EntropyGate(64)
gate.eval()
with torch.no_grad():
    w = gate(torch.randn(4, 64, 16, 16))
print(f"‚úÖ EntropyGate: weights shape={w.shape}, sum={w.squeeze().sum(dim=-1)}")

### 2.4 Shared Depthwise Conv & EGRR Block

**SharedDepthwiseConv**: Single 3×3 depthwise kernel reused at all dilation rates.  
Memory-optimized: accumulates weighted sum in-place, skips unused dilations at inference.

**EGRRBlock**: The core recursive block with optional gradient checkpointing.  
```
h_t = h_{t-1} + ReLU6(IS-Norm_t(DWConv_gated(SymConv(h_{t-1}))))
```

In [None]:
class SharedDepthwiseConv(nn.Module):
    """Depthwise convolution whose single kernel is reused at several dilations.

    Every dilation rate convolves the input with the same ``weight``/``bias``.
    The per-dilation outputs are blended with the supplied gate weights.

    Args:
        channels: Number of channels (also the number of depthwise groups).
        kernel_size: Spatial kernel size. Padding is ``d * (kernel_size // 2)``,
            which preserves the spatial size for odd kernel sizes.
        dilation_rates: Dilation rates to evaluate; defaults to ``[1, 2, 4]``.
    """

    def __init__(self, channels, kernel_size=3, dilation_rates=None):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.dilation_rates = dilation_rates or [1, 2, 4]
        self.weight = nn.Parameter(torch.randn(channels, 1, kernel_size, kernel_size) * 0.02)
        self.bias = nn.Parameter(torch.zeros(channels))

    def forward(self, x, gate_weights):
        """Return the gate-weighted sum of the dilated depthwise convolutions.

        Args:
            x: Input feature map with ``channels`` channels.
            gate_weights: Per-sample branch weights; column ``i`` along dim 1
                scales the output of ``dilation_rates[i]`` (shape
                ``(N, K, 1, 1)`` in this notebook).

        Returns:
            The blended tensor. In eval mode, branches whose gate column is
            entirely zero are skipped. If every branch is skipped (all-zero
            gate or empty batch), a zero tensor shaped like ``x`` is returned
            instead of ``None``.
        """
        if self.training:
            # Every branch must run so gradients reach all gate outputs.
            active = [True] * len(self.dilation_rates)
        else:
            # One device-to-host transfer for all branches instead of one
            # ``.item()`` synchronization per dilation.
            nonzero = gate_weights.detach().ne(0)
            active = nonzero.transpose(0, 1).flatten(1).any(dim=1).tolist()

        result = None
        for i, (d, is_active) in enumerate(zip(self.dilation_rates, active)):
            if not is_active:
                continue  # unused at inference: skip the convolution entirely
            padding = d * (self.kernel_size // 2)
            out = F.conv2d(x, self.weight, self.bias,
                           stride=1, padding=padding, dilation=d,
                           groups=self.channels)
            weighted = out * gate_weights[:, i:i + 1]
            result = weighted if result is None else result + weighted

        if result is None:
            # No branch selected: the weighted sum is identically zero.
            result = torch.zeros_like(x)
        return result


class EGRRBlock(nn.Module):
    """Entropy-gated recursive residual block.

    After an optional downsample + 1x1 projection, the block applies the same
    weights ``active_iterations`` times:

        h_t = h_{t-1} + ReLU6(ISNorm_t(DWConv_gated(SymConv(h_{t-1}))))

    The gate weights for the depthwise convolution are computed from
    ``h_{t-1}`` itself (not from the SymConv output).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels used inside the recursive loop and returned.
        num_iterations: Maximum number of recursion steps ``T``.
        stride: Spatial reduction factor. Values greater than 1 average-pool
            the input by that factor before the projection.
        kernel_size: Kernel size of the shared depthwise convolution.
        dilation_rates: Dilation rates of the gated depthwise convolution;
            defaults to ``[1, 2, 4]``.
        pool_size: Window size of the entropy gate's local variance pooling.
    """

    def __init__(self, in_channels, out_channels, num_iterations=4,
                 stride=1, kernel_size=3, dilation_rates=None, pool_size=3):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_iterations = num_iterations
        self.stride = stride
        self.use_projection = (in_channels != out_channels) or (stride != 1)

        # Channel projection (and spatial reduction when stride > 1).
        if self.use_projection:
            self.projection = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            # Pool by the stride itself so any stride > 1 really downsamples.
            # Previously only stride == 2 had an effect; other values were
            # accepted but silently ignored.
            self.downsample = nn.AvgPool2d(stride, stride) if stride > 1 else nn.Identity()
        else:
            self.projection = nn.Identity()
            self.downsample = nn.Identity()

        # The three mechanisms shared by every recursion step.
        self.sym_conv = SymmetricConv1x1(out_channels)
        self.entropy_gate = EntropyGate(out_channels, pool_size,
                                        len(dilation_rates or [1, 2, 4]))
        self.shared_dw_conv = SharedDepthwiseConv(out_channels, kernel_size,
                                                  dilation_rates)
        self.is_norm = ISNorm(out_channels, num_iterations)
        self.activation = nn.ReLU6(inplace=True)

        self._active_iterations = num_iterations
        self._use_gradient_checkpointing = False

    @property
    def active_iterations(self):
        """Number of recursion steps currently executed by ``forward``."""
        return self._active_iterations

    @active_iterations.setter
    def active_iterations(self, value):
        # Never run more steps than there are per-iteration affine rows.
        self._active_iterations = min(value, self.num_iterations)

    def _recursive_step(self, h, t):
        """Run recursion step ``t`` on state ``h`` and return the new state."""
        z = self.sym_conv(h)
        gate_weights = self.entropy_gate(h)
        z = self.shared_dw_conv(z, gate_weights)
        z = self.is_norm(z, t)
        z = self.activation(z)
        return h + z

    def forward(self, x):
        """Downsample/project ``x`` and apply the active recursion steps.

        Args:
            x: Input feature map with ``in_channels`` channels.

        Returns:
            Feature map with ``out_channels`` channels.
        """
        # ``downsample`` is an Identity whenever no reduction is configured.
        x = self.downsample(x)
        h = self.projection(x)

        T = self._active_iterations
        use_ckpt = self.training and self._use_gradient_checkpointing and T > 1

        # NOTE(review): checkpointing re-runs each step during backward, so
        # ISNorm's running statistics and batch counter are presumably updated
        # twice per training step. Confirm whether that matters here.
        for t in range(T):
            if use_ckpt:
                h = grad_checkpoint(self._recursive_step, h, t,
                                    use_reentrant=False)
            else:
                h = self._recursive_step(h, t)
        return h


# -- Quick test --
block = EGRRBlock(32, 64, num_iterations=3, stride=2, dilation_rates=[1, 2, 4])
out = block(torch.randn(2, 32, 16, 16))
print(f"‚úÖ EGRRBlock: 32‚Üí64, stride=2, T=3 ‚Üí output={out.shape}")
print(f"   Params: {sum(p.numel() for p in block.parameters()):,}")

### 2.5 Complete EGRR Network

Pyramidal structure: 8 stages with decreasing resolution and increasing channels.  
Stem → 8 EGRR Blocks → Global AvgPool → Dropout → Linear(100)

In [None]:
class EGRRNet(nn.Module):
    """Full EGRR classifier: conv stem, a stack of EGRR blocks, linear head.

    Args:
        num_classes: Size of the output logit vector.
        stages: Iterable of ``(base_channels, T_recursions, stride)`` tuples;
            defaults to ``cfg.STAGES_BASE``.
        stem_channels: Output channels of the 3x3 stem convolution.
        width_mult: Multiplier applied to every stage's base channel count
            (the result is rounded to a multiple of 8, minimum 8).
        dilation_rates: Dilation rates passed to every block; defaults to
            ``[1, 2, 4]``.
    """

    def __init__(self, num_classes=100, stages=None, stem_channels=32,
                 width_mult=1.5, dilation_rates=None):
        super().__init__()
        self.num_classes = num_classes
        dilation_rates = [1, 2, 4] if dilation_rates is None else dilation_rates
        stages = cfg.STAGES_BASE if stages is None else stages

        # Stem: 3x3 conv (stride 1, padding 1) -> BN -> ReLU6.
        self.stem = nn.Sequential(
            nn.Conv2d(3, stem_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(stem_channels),
            nn.ReLU6(inplace=True),
        )

        # Chain the blocks: each block's output width feeds the next one.
        blocks = []
        prev_width = stem_channels
        for base_width, recursions, stride in stages:
            width = self._scale(base_width, width_mult)
            blocks.append(EGRRBlock(prev_width, width, recursions, stride,
                                    dilation_rates=dilation_rates))
            prev_width = width
        self.stages = nn.ModuleList(blocks)

        # Head: global average pool -> flatten -> dropout -> linear classifier.
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.1),
            nn.Linear(prev_width, num_classes),
        )
        self.last_channels = prev_width

    @staticmethod
    def _scale(base_c, width_mult):
        """Scale ``base_c`` by ``width_mult``, rounded to a multiple of 8 (min 8)."""
        return max(8, int(round(base_c * width_mult / 8) * 8))

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        features = self.stem(x)
        for block in self.stages:
            features = block(features)
        return self.head(features)

    def set_active_iterations(self, max_t):
        """Request ``max_t`` recursion steps in every stage (each stage caps it)."""
        for block in self.stages:
            block.active_iterations = max_t

    def set_gumbel_tau(self, tau):
        """Propagate the Gumbel-Softmax temperature to every entropy gate."""
        for block in self.stages:
            block.entropy_gate.set_tau(tau)

    def set_gradient_checkpointing(self, enable=True):
        """Turn gradient checkpointing on or off in every stage."""
        for block in self.stages:
            block._use_gradient_checkpointing = enable

    def count_parameters(self):
        """Total number of trainable parameter elements."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def parameter_breakdown(self):
        """Return parameter counts for the stem, the head and each stage.

        The result has keys ``"stem"``, ``"head"``, ``"stages"`` (mapping
        ``"stage_<i>"`` to per-component counts plus ``"subtotal"``) and
        ``"total"`` (trainable parameters only).
        """
        def numel(module):
            return sum(p.numel() for p in module.parameters())

        components = ("sym_conv", "entropy_gate", "shared_dw_conv", "is_norm")
        per_stage = {}
        for idx, block in enumerate(self.stages, start=1):
            entry = {name: numel(getattr(block, name)) for name in components}
            entry["projection"] = numel(block.projection) if block.use_projection else 0
            entry["subtotal"] = sum(entry.values())
            per_stage[f"stage_{idx}"] = entry

        return {
            "stem": numel(self.stem),
            "head": numel(self.head),
            "stages": per_stage,
            "total": self.count_parameters(),
        }


# -- Build & verify --
# Instantiate the configured network on the selected device.
model = EGRRNet(
    cfg.NUM_CLASSES, cfg.STAGES, cfg.STEM_CHANNELS,
    cfg.WIDTH_MULT, cfg.DILATION_RATES,
).to(device)

total = model.count_parameters()
bd = model.parameter_breakdown()

# Parameter report: header, per-component rows, footer, budget verdict.
print(f"\n{'='*60}", f"EGRR Network ‚Äî {total:,} parameters", f"{'='*60}", sep="\n")
print(f"  Stem:   {bd['stem']:>8,}")
for name in bd['stages']:
    info = bd['stages'][name]
    print(f"  {name}: {info['subtotal']:>8,}")
print(f"  Head:   {bd['head']:>8,}", f"  Total:  {bd['total']:>8,}", f"{'='*60}", sep="\n")
print(f"\n‚úÖ Under 500K budget: {total:,} < 500,000 ‚Üí {'PASS' if total < 500_000 else 'FAIL'}")

# Smoke-test a forward pass in eval mode on a random 32x32 RGB batch.
model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 3, 32, 32).to(device))
print(f"‚úÖ Forward: (2, 3, 32, 32) ‚Üí {out.shape}")

---
## 3. Utilities

In [None]:
# ══════════════════════════════════════════════
# Training utilities
# ══════════════════════════════════════════════

def init_weights(model):
    """Re-initialize the standard layers of ``model`` in place.

    * ``nn.Conv2d`` / ``nn.Linear``: Kaiming-normal weights
      (``mode="fan_out"``, ``nonlinearity="relu"``) and zero bias.
    * ``nn.BatchNorm2d`` / ``nn.GroupNorm``: unit scale, zero shift.

    Only modules of these types are touched. Raw ``nn.Parameter`` attributes
    of other module classes are left as they are.

    Args:
        model: Module whose sub-modules are traversed via ``model.modules()``.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            # Norm layers built with affine=False have weight/bias set to None.
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)


class AverageMeter:
    """Track the latest value and the weighted running average of a metric.

    Attributes:
        val: Most recent value passed to ``update``.
        sum: Weighted sum of all values seen since the last ``reset``.
        count: Total weight (e.g. number of samples) seen since ``reset``.
        avg: ``sum / count``, or ``0.0`` while ``count`` is zero.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = self.avg = self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed with weight ``n``.

        Args:
            val: New measurement (typically a per-batch mean).
            n: Weight of the measurement, usually the batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # An empty batch (n == 0) on a fresh meter must not divide by zero.
        self.avg = self.sum / self.count if self.count else 0.0


def accuracy(output, target, topk=(1, 5)):
    """Compute top-k accuracies, in percent, for a batch of logits.

    Args:
        output: Score tensor of shape ``(N, num_classes)``.
        target: Integer class labels of shape ``(N,)``.
        topk: The ``k`` values to evaluate.

    Returns:
        A list of Python floats, one percentage per entry of ``topk``.
    """
    with torch.no_grad():
        batch = target.size(0)
        # (N, maxk) indices of the highest-scoring classes, best first.
        ranked = output.topk(max(topk), dim=1, largest=True, sorted=True).indices
        # hits[n, j] is True when the j-th ranked class of sample n is the label.
        hits = ranked.eq(target.view(-1, 1))
        percent_per_hit = 100.0 / batch
        return [
            hits[:, :k].float().sum().mul_(percent_per_hit).item()
            for k in topk
        ]


class Cutout:
    """Zero out one randomly placed square patch of an image tensor.

    The patch centre is drawn uniformly over the image with
    ``np.random.randint``. The patch extends ``length // 2`` pixels in each
    direction and is clipped at the image border.

    Args:
        length: Nominal side length of the patch; ``<= 0`` disables the
            transform.
    """

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        """Return ``img`` (a ``(C, H, W)`` tensor) with one patch zeroed."""
        if self.length <= 0:
            return img

        height, width = img.size(1), img.size(2)
        half = self.length // 2

        # Draw the row first, then the column (keeps the RNG stream order).
        cy = np.random.randint(height)
        cx = np.random.randint(width)
        top, bottom = max(cy - half, 0), min(cy + half, height)
        left, right = max(cx - half, 0), min(cx + half, width)

        # Multiplicative mask: 1 keeps a pixel, 0 erases it in every channel.
        keep = torch.ones(height, width, dtype=torch.float32)
        keep[top:bottom, left:right] = 0.0
        return img * keep.expand_as(img)


def get_active_iterations(epoch, warmup_start=0, warmup_end=20, max_T=10):
    """Depth warm-up schedule: number of recursion steps to run at ``epoch``.

    Returns 1 up to and including ``warmup_start``, ``max_T`` from
    ``warmup_end`` onwards, and a linearly growing value (rounded up, at
    least 1) in between.
    """
    if epoch <= warmup_start:
        return 1
    if epoch >= warmup_end:
        return max_T

    # Strictly inside the warm-up window: interpolate linearly and round up.
    window = warmup_end - warmup_start
    progress = (epoch - warmup_start) / window
    steps = int(math.ceil(progress * max_T))
    return steps if steps > 1 else 1


def get_gumbel_tau(epoch, total_epochs, tau_start=1.0, tau_end=0.1):
    """Exponentially anneal the Gumbel temperature from start to end value.

    The temperature is ``tau_start * (tau_end / tau_start) ** progress`` with
    ``progress = epoch / total_epochs`` capped at 1, and never drops below
    ``tau_end``.
    """
    horizon = max(1, total_epochs)  # guards against total_epochs == 0
    progress = min(1.0, epoch / horizon)
    decay_ratio = tau_end / tau_start
    annealed = tau_start * decay_ratio ** progress
    return max(tau_end, annealed)


def cosine_lr(optimizer, epoch, total_epochs, lr_max, lr_min=0.0, warmup_epochs=5):
    """Set and return the learning rate for ``epoch``.

    The rate ramps linearly up to ``lr_max`` during the first
    ``warmup_epochs`` epochs and then follows a half cosine from ``lr_max``
    down to ``lr_min`` at ``total_epochs``. The value is written to every
    parameter group of ``optimizer``.
    """
    in_warmup = epoch < warmup_epochs
    if in_warmup:
        lr = lr_max * (epoch + 1) / warmup_epochs
    else:
        progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
        cosine_term = 1 + math.cos(math.pi * progress)
        lr = lr_min + 0.5 * (lr_max - lr_min) * cosine_term

    for group in optimizer.param_groups:
        group["lr"] = lr
    return lr


print("‚úÖ Utilities loaded")

---
## 4. Data Loaders

In [None]:
def get_dataloaders(batch_size=128, num_workers=2):
    """Build the CIFAR-100 train and test data loaders.

    Training pipeline: random crop (padding 4) -> horizontal flip ->
    optional AutoAugment -> ToTensor -> Normalize -> optional Cutout.
    Test pipeline: ToTensor -> Normalize.
    The dataset is downloaded into ``./data`` when it is not present.

    Args:
        batch_size: Batch size for both loaders.
        num_workers: Worker processes per loader.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles and
        drops the last incomplete batch.
    """
    channel_mean = (0.5071, 0.4867, 0.4408)
    channel_std = (0.2675, 0.2565, 0.2761)

    # Augmentations that run before tensor conversion.
    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    if cfg.USE_AUTOAUGMENT:
        train_steps.append(transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10))
    train_steps.extend([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    # Cutout runs on the normalized tensor.
    if cfg.CUTOUT_LENGTH > 0:
        train_steps.append(Cutout(cfg.CUTOUT_LENGTH))

    test_steps = [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]

    train_ds = torchvision.datasets.CIFAR100(
        './data', train=True, download=True,
        transform=transforms.Compose(train_steps))
    test_ds = torchvision.datasets.CIFAR100(
        './data', train=False, download=True,
        transform=transforms.Compose(test_steps))

    # Pinned host memory is only useful when batches are copied to a GPU.
    shared = dict(batch_size=batch_size, num_workers=num_workers,
                  pin_memory=device.type == 'cuda')
    train_loader = torch.utils.data.DataLoader(
        train_ds, shuffle=True, drop_last=True, **shared)
    test_loader = torch.utils.data.DataLoader(
        test_ds, shuffle=False, **shared)
    return train_loader, test_loader


train_loader, test_loader = get_dataloaders(cfg.BATCH_SIZE)
print(f"‚úÖ CIFAR-100: {len(train_loader.dataset)} train, {len(test_loader.dataset)} test")

---
## 5. Training Loop

Features:
- **Depth warm-up**: T increases from 1 → max over first 20 epochs
- **Gumbel-τ annealing**: 1.0 → 0.1 (exponential decay)
- **Cosine LR** with 5-epoch linear warmup
- **AMP** mixed precision on CUDA
- **Gradient checkpointing** for memory efficiency

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device,
                    scaler=None, use_amp=False):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Network to optimize; switched to train mode here.
        loader: Iterable yielding ``(images, targets)`` batches.
        criterion: Loss function called as ``criterion(logits, targets)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.
        scaler: Optional gradient scaler; when given, backward and step go
            through it.
        use_amp: Whether the forward pass runs under autocast.

    Returns:
        ``(mean_loss, top1_percent, top5_percent)`` averaged over all samples.
    """
    model.train()
    loss_meter, top1_meter, top5_meter = AverageMeter(), AverageMeter(), AverageMeter()
    # autocast is only requested for 'cuda' or 'cpu'; anything else maps to 'cpu'.
    autocast_device = device.type if device.type in ('cuda', 'cpu') else 'cpu'

    for images, targets in loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        batch = images.size(0)

        with torch.amp.autocast(device_type=autocast_device, enabled=use_amp):
            logits = model(images)
            loss = criterion(logits, targets)

        optimizer.zero_grad(set_to_none=True)
        if not scaler:
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()
        else:
            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to true gradients.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            scaler.step(optimizer)
            scaler.update()

        acc1, acc5 = accuracy(logits.float(), targets)
        loss_meter.update(loss.item(), batch)
        top1_meter.update(acc1, batch)
        top5_meter.update(acc5, batch)

    return loss_meter.avg, top1_meter.avg, top5_meter.avg


def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    The model is switched to eval mode.

    Returns:
        ``(mean_loss, top1_percent, top5_percent)`` averaged over all samples.
    """
    model.eval()
    loss_meter, top1_meter, top5_meter = AverageMeter(), AverageMeter(), AverageMeter()
    with torch.no_grad():
        for images, targets in loader:
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            batch = images.size(0)

            logits = model(images)
            loss = criterion(logits, targets)
            acc1, acc5 = accuracy(logits, targets)

            loss_meter.update(loss.item(), batch)
            top1_meter.update(acc1, batch)
            top5_meter.update(acc5, batch)
    return loss_meter.avg, top1_meter.avg, top5_meter.avg

print("‚úÖ Training functions defined")

### 5.1 Run Training

In [None]:
# ══════════════════════════════════════════════
# Train!
# ══════════════════════════════════════════════

# Build a fresh model and apply the standard initialization.
model = EGRRNet(
    cfg.NUM_CLASSES, cfg.STAGES, cfg.STEM_CHANNELS,
    cfg.WIDTH_MULT, cfg.DILATION_RATES,
).to(device)
init_weights(model)

if cfg.USE_GRADIENT_CHECKPOINTING:
    model.set_gradient_checkpointing(True)
    print("Gradient checkpointing: ENABLED")

# Deepest recursion count over all stages; each stage caps it at its own T.
max_T = max(stage[1] for stage in cfg.STAGES)
criterion = nn.CrossEntropyLoss(label_smoothing=cfg.LABEL_SMOOTHING)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=cfg.LEARNING_RATE,
    momentum=cfg.MOMENTUM,
    weight_decay=cfg.WEIGHT_DECAY,
    nesterov=True,
)

# Mixed precision is only enabled on CUDA.
use_amp = cfg.USE_AMP and device.type == 'cuda'
scaler = torch.amp.GradScaler('cuda') if use_amp else None
print(f"AMP: {'ENABLED' if use_amp else 'DISABLED'}")

best_acc = 0.0
history = []

EPOCHS = cfg.EPOCHS  # Change this for faster testing
# EPOCHS = 10  # Uncomment for quick test run

print(f"\nüöÄ Training for {EPOCHS} epochs on {device}...\n")

for epoch in range(EPOCHS):
    t0 = time.time()

    # Per-epoch schedules: recursion depth, Gumbel temperature, learning rate.
    active_T = get_active_iterations(epoch, 0, cfg.DEPTH_WARMUP_END_EPOCH, max_T)
    model.set_active_iterations(active_T)
    tau = get_gumbel_tau(epoch, EPOCHS, cfg.GUMBEL_TAU_START, cfg.GUMBEL_TAU_END)
    model.set_gumbel_tau(tau)
    lr = cosine_lr(optimizer, epoch, EPOCHS, cfg.LEARNING_RATE,
                   cfg.LR_MIN, cfg.LR_WARMUP_EPOCHS)

    # One training pass followed by a full evaluation pass.
    train_loss, train_acc, _ = train_one_epoch(
        model, train_loader, criterion, optimizer, device, scaler, use_amp)
    test_loss, test_acc, test_acc5 = evaluate(
        model, test_loader, criterion, device)

    elapsed = time.time() - t0
    is_best = test_acc > best_acc
    if is_best:
        best_acc = test_acc

    history.append({
        "epoch": epoch,
        "train_acc": train_acc,
        "test_acc": test_acc,
        "test_acc5": test_acc5,
    })

    # Log the first epoch, every 10th epoch, and every new best result.
    if epoch == 0 or is_best or (epoch + 1) % 10 == 0:
        star = " ‚òÖ" if is_best else ""
        print(f"Ep {epoch+1:3d}/{EPOCHS}  T={active_T}  œÑ={tau:.3f}  lr={lr:.5f}  "
              f"Train={train_acc:.1f}%  Test={test_acc:.1f}% (Top5={test_acc5:.1f}%)  "
              f"{elapsed:.1f}s{star}")

    # Checkpoint whenever the test accuracy improves.
    if is_best:
        checkpoint_payload = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "best_acc": best_acc,
            "stages": cfg.STAGES,
            "width_mult": cfg.WIDTH_MULT,
        }
        torch.save(checkpoint_payload, "best_egrr.pth")

print(f"\nüèÜ Training complete! Best Top-1: {best_acc:.2f}%")

---
## 6. Training Curves

In [None]:
import matplotlib.pyplot as plt

# Plot accuracy curves only when at least one epoch has been recorded.
if history:
    epochs = [h['epoch'] for h in history]
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: train vs. test top-1 accuracy.
    for key, label in (('train_acc', 'Train'), ('test_acc', 'Test')):
        ax1.plot(epochs, [h[key] for h in history], label=label, alpha=0.8)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy (%)')
    ax1.set_title('Top-1 Accuracy')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Right panel: test top-5 accuracy.
    ax2.plot(epochs, [h['test_acc5'] for h in history], color='green', alpha=0.8)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.set_title('Top-5 Accuracy')
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
    print(f"Best Top-1: {best_acc:.2f}%")

---
## 7. Architecture Tests

Quick verification of all core invariants.

In [None]:
# Architecture smoke tests. They run top to bottom and share state: the model
# `m` built for Test 1 is reused by Tests 2 and 7. All tensors stay on the CPU.
print("Running architecture verification tests...\n")

# Test 1: Parameter count
# Builds the configured network and checks the trainable-parameter budget.
m = EGRRNet(stages=cfg.STAGES, width_mult=cfg.WIDTH_MULT,
            stem_channels=cfg.STEM_CHANNELS, dilation_rates=cfg.DILATION_RATES)
total = m.count_parameters()
assert total < 500_000, f"FAIL: {total:,} >= 500K"
print(f"‚úÖ Test 1: Parameter count = {total:,} < 500K")

# Test 2: Forward pass shape
# Eval-mode forward on a random batch must give finite (2, 100) logits.
m.eval()
with torch.no_grad():
    out = m(torch.randn(2, 3, 32, 32))
assert out.shape == (2, 100)
assert torch.isfinite(out).all()
print(f"‚úÖ Test 2: Forward (2,3,32,32) ‚Üí {out.shape}")

# Test 3: IS-Norm shapes
# Every iteration index must preserve the input shape (training mode).
n = ISNorm(64, 4); n.train()
x = torch.randn(2, 64, 8, 8)
for t in range(4):
    assert n(x, t).shape == x.shape
print("‚úÖ Test 3: IS-Norm shapes correct for all T")

# Test 4: Entropy gate sums to 1
# In eval mode the gate emits one-hot weights of shape (N, 3, 1, 1).
g = EntropyGate(64); g.eval()
with torch.no_grad():
    w = g(torch.randn(4, 64, 16, 16))
assert w.shape == (4, 3, 1, 1)
assert torch.allclose(w.squeeze().sum(dim=-1), torch.ones(4), atol=1e-5)
print("‚úÖ Test 4: Entropy gate weights sum to 1")

# Test 5: Symmetric kernel
# The assembled 1x1 kernel matrix must equal its own transpose.
s = SymmetricConv1x1(64)
W = s.get_weight()
assert (W - W.T).abs().max() < 1e-7
print("‚úÖ Test 5: W == W·µÄ (exact symmetry)")

# Test 6: Gradient flow
# Backprop through a 3-step block must reach both the input and the shared
# symmetric-conv parameter.
b = EGRRBlock(32, 32, 3, 1, dilation_rates=[1, 2, 4])
x = torch.randn(2, 32, 8, 8, requires_grad=True)
b(x).sum().backward()
assert x.grad is not None and (x.grad.abs() > 0).any()
assert b.sym_conv.L.grad is not None
print("‚úÖ Test 6: Gradients flow through recursive loop")

# Test 7: Depth warm-up
# A request of 1 is honored; an oversized request is capped at each stage's T.
m.set_active_iterations(1)
assert all(s.active_iterations == 1 for s in m.stages)
m.set_active_iterations(100)
assert all(s.active_iterations == s.num_iterations for s in m.stages)
print("‚úÖ Test 7: Depth warm-up capping works")

print("\nüéâ All tests passed!")