0. Install Package

In [None]:
!pip install snntorch torchvision
acc_log = {}

Collecting snntorch
  Downloading snntorch-0.9.4-py2.py3-none-any.whl.metadata (15 kB)
Downloading snntorch-0.9.4-py2.py3-none-any.whl (125 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: snntorch
Successfully installed snntorch-0.9.4


1. Rate Encoding

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import snntorch as snn
from snntorch import spikegen

# -----------------------------
# Hyperparameters
# -----------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

num_classes = 10
batch_size = 128
num_epochs = 30
lr = 2e-3

# Length of the spike train in time steps (30-50 is an option if GPU time allows).
T = 20
# Multiplier applied to the firing probability before sampling (values <= 1 are safest).
rate_scale = 1.0

# LIF membrane decay derived from the time constant: beta = exp(-1 / tau).
tau = 2.0
beta = torch.exp(torch.tensor(-1.0 / tau)).item()

# -----------------------------
# CIFAR-10 datasets and loaders
# -----------------------------
# Per-channel statistics used to standardise the images.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2470, 0.2435, 0.2616)

# Tensor conversion + normalization, common to the train and test pipelines.
_tensor_and_normalize = [
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
]

# Training adds random crop (with 4-pixel padding) and horizontal flip augmentation.
train_tf = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        *_tensor_and_normalize,
    ]
)
test_tf = transforms.Compose(list(_tensor_and_normalize))

_dataset_kwargs = dict(root="./data", download=True)
train_set = torchvision.datasets.CIFAR10(train=True, transform=train_tf, **_dataset_kwargs)
test_set = torchvision.datasets.CIFAR10(train=False, transform=test_tf, **_dataset_kwargs)

# Only the training loader shuffles and drops the final partial batch.
_loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **_loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, drop_last=False, **_loader_kwargs)

# -----------------------------
# Rate encoder (Poisson/Bernoulli per timestep)
# -----------------------------
def rate_encode_cifar(x_img, T, rate_scale=1.0):
    """
    Convert a batch of normalized images into a rate-coded spike train.

    x_img      : normalized image batch, [B, 3, 32, 32].
    T          : number of time steps to generate.
    rate_scale : factor applied to the firing probability before clamping.

    Each pixel is squashed with (tanh(x) + 1) / 2, multiplied by rate_scale and
    clamped to [0, 1]; the result is handed to spikegen.rate as the per-step
    firing probability.

    Returns spk_in: [T, B, 3, 32, 32]
    """
    # Smooth squash of the normalized values into the unit interval.
    firing_prob = (torch.tanh(x_img) + 1) * 0.5
    # Scaling can push values past 1, so clamp back to a valid probability.
    firing_prob = (firing_prob * rate_scale).clamp(0.0, 1.0)

    return spikegen.rate(firing_prob, num_steps=T)

# -----------------------------
# A more complex Conv SNN for CIFAR-10
# -----------------------------
class ConvSNN_CIFAR10(nn.Module):
    """
    Convolutional spiking network for CIFAR-10.

    Three Conv-BN-LIF-MaxPool stages (64, 128, 256 channels) are followed by a
    spiking classifier head (Linear-LIF-Linear-LIF).

    beta        : membrane decay given to every snn.Leaky neuron.
    num_classes : number of output neurons.
    """

    def __init__(self, beta=0.95, num_classes=10):
        super().__init__()

        # Stage 1: 3 -> 64 channels.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.lif1 = snn.Leaky(beta=beta)

        # Stage 2: 64 -> 128 channels.
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.lif2 = snn.Leaky(beta=beta)

        # Stage 3: 128 -> 256 channels.
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(256)
        self.lif3 = snn.Leaky(beta=beta)

        # One 2x2 max-pool and one dropout layer are reused across the network.
        self.pool = nn.MaxPool2d(2, 2)
        self.drop = nn.Dropout(p=0.2)

        # Head. Three poolings take 32x32 down to 4x4, hence 256 * 4 * 4 features.
        self.fc1 = nn.Linear(256 * 4 * 4, 512, bias=True)
        self.lif4 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(512, num_classes, bias=True)
        self.lif5 = snn.Leaky(beta=beta)

    def forward(self, spk_in):
        """
        Unroll the network over the time axis of a spike train.

        spk_in: [T, B, 3, 32, 32]
        returns the output-layer spikes of every step, stacked to [T, B, num_classes]
        """
        _, batch, _, _, _ = spk_in.shape

        conv_stages = (
            (self.conv1, self.bn1, self.lif1),
            (self.conv2, self.bn2, self.lif2),
            (self.conv3, self.bn3, self.lif3),
        )

        # Fresh membrane state for every LIF layer at the start of a sequence.
        conv_mems = [lif.init_leaky() for _, _, lif in conv_stages]
        mem_hidden = self.lif4.init_leaky()
        mem_out = self.lif5.init_leaky()

        out_spikes = []
        for frame in spk_in:
            feat = frame
            # Conv -> BN -> LIF -> pool; each membrane carries over to the next step.
            for idx, (conv, bn, lif) in enumerate(conv_stages):
                spikes, conv_mems[idx] = lif(bn(conv(feat)), conv_mems[idx])
                feat = self.pool(spikes)

            feat = self.drop(feat).view(batch, -1)

            hidden_spk, mem_hidden = self.lif4(self.fc1(feat), mem_hidden)
            step_spk, mem_out = self.lif5(self.fc2(hidden_spk), mem_out)
            out_spikes.append(step_spk)

        return torch.stack(out_spikes, dim=0)

# Build the network on the target device, then its optimizer and a cosine LR
# schedule that spans the whole run (stepped once per epoch by the training loop).
model = ConvSNN_CIFAR10(beta=beta, num_classes=num_classes)
model = model.to(device)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    weight_decay=1e-4,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
)

# -----------------------------
# Train / Eval
# -----------------------------
@torch.no_grad()
def evaluate():
    """
    Measure top-1 accuracy of the global `model` on `test_loader`.

    Every batch is rate-encoded with the global T / rate_scale, and the class
    with the highest output spike count over time is taken as the prediction.
    The model is left in eval mode. Returns a fraction in [0, 1].
    """
    model.eval()
    n_hit = 0
    n_seen = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        out_spikes = model(rate_encode_cifar(images, T=T, rate_scale=rate_scale))
        # Summing over the time axis gives one spike count per class.
        predicted = out_spikes.sum(dim=0).argmax(dim=1)

        n_hit += (predicted == labels).sum().item()
        n_seen += labels.numel()
    # max(..., 1) avoids a division by zero on an empty loader.
    return n_hit / max(n_seen, 1)

for epoch in range(1, num_epochs + 1):
    model.train()
    running_loss, running_correct, running_total = 0.0, 0, 0

    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)
        n_batch = x.size(0)

        # Forward: encode to spikes, run the SNN, read out per-class spike counts [B,10].
        logits = model(rate_encode_cifar(x, T=T, rate_scale=rate_scale)).sum(dim=0)
        loss = F.cross_entropy(logits, y)

        # Backward, with the gradient norm clipped to 1.0 before the update.
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Sample-weighted accumulation so the epoch averages are per example.
        running_loss += loss.item() * n_batch
        running_correct += (logits.argmax(dim=1) == y).sum().item()
        running_total += n_batch

    # The cosine schedule advances once per epoch.
    scheduler.step()
    test_acc = evaluate()

    print(" | ".join((
        f"Epoch {epoch:02d}",
        f"train loss {running_loss/running_total:.4f}",
        f"train acc {running_correct/running_total:.4f}",
        f"test acc {test_acc:.4f}",
    )))

print("Done.")


100%|██████████| 170M/170M [00:14<00:00, 11.9MB/s]


Epoch 01 | train loss 2.4983 | train acc 0.2045 | test acc 0.3295
Epoch 02 | train loss 1.9103 | train acc 0.3293 | test acc 0.3630
Epoch 03 | train loss 1.7461 | train acc 0.3814 | test acc 0.4260
Epoch 04 | train loss 1.6328 | train acc 0.4317 | test acc 0.4234
Epoch 05 | train loss 1.5236 | train acc 0.4686 | test acc 0.5193
Epoch 06 | train loss 1.4446 | train acc 0.4998 | test acc 0.5516
Epoch 07 | train loss 1.3614 | train acc 0.5312 | test acc 0.5562
Epoch 08 | train loss 1.2745 | train acc 0.5602 | test acc 0.5828
Epoch 09 | train loss 1.2033 | train acc 0.5850 | test acc 0.5769
Epoch 10 | train loss 1.1669 | train acc 0.6021 | test acc 0.6475
Epoch 11 | train loss 1.0970 | train acc 0.6264 | test acc 0.6095
Epoch 12 | train loss 1.0506 | train acc 0.6389 | test acc 0.6880
Epoch 13 | train loss 1.0007 | train acc 0.6533 | test acc 0.6748
Epoch 14 | train loss 0.9650 | train acc 0.6659 | test acc 0.6865
Epoch 15 | train loss 0.9165 | train acc 0.6837 | test acc 0.6978
Epoch 16 |

2. TTFS Encoding

In [None]:
# Stronger SNN for CIFAR-10 with TTFS (latency) encoding (non-hard variant).
# Changes in this version:
# 1) threshold_latency lowered from 0.01 to 0.001 (reduces tail saturation / late-bin pile-up).
# 2) readout_mode = "mean": average the membrane logits across time to reduce test variance.
# 3) Label smoothing in the CE loss (label_smoothing=0.1) to reduce overfitting and stabilize generalization.
# 4) The best checkpoint by test accuracy is saved to best_ttfs_resnet_snn.pt.

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import snntorch as snn
from snntorch import spikegen

# -----------------------------
# Hyperparameters
# -----------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

num_classes = 10
batch_size = 128
num_epochs = 80
lr = 2e-3
weight_decay = 5e-4

# Length of the encoding window in time steps.
T = 30

# Histogram resolution and number of training batches used to build the intensity CDF.
num_bins = 512
max_cdf_batches = 300

# Options forwarded to spikegen.latency by the TTFS encoder (they shape the time spread).
normalize_in_latency = True
linear_latency = True
tau_latency = 8.0
threshold_latency = 0.001  # previously 0.01

# LIF membrane decay derived from the neuron time constant: beta = exp(-1 / tau_neuron).
tau_neuron = 2.0
beta = torch.exp(torch.tensor(-1.0 / tau_neuron)).item()

# How the output membrane is reduced over time: "last" or "mean".
readout_mode = "mean"

# Label-smoothing factor used in the cross-entropy loss.
label_smoothing = 0.1

# File that receives the best checkpoint.
ckpt_path = "best_ttfs_resnet_snn.pt"

# -----------------------------
# CIFAR-10 datasets and loaders
# -----------------------------
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2470, 0.2435, 0.2616)

# Tensor conversion + normalization, common to the train and test pipelines.
_tensor_and_normalize = [
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
]

# Training adds random crop (with 4-pixel padding) and horizontal flip augmentation.
train_tf = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        *_tensor_and_normalize,
    ]
)
test_tf = transforms.Compose(list(_tensor_and_normalize))

_dataset_kwargs = dict(root="./data", download=True)
train_set = torchvision.datasets.CIFAR10(train=True, transform=train_tf, **_dataset_kwargs)
test_set = torchvision.datasets.CIFAR10(train=False, transform=test_tf, **_dataset_kwargs)

# Only the training loader shuffles and drops the final partial batch.
_loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **_loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, drop_last=False, **_loader_kwargs)

# -----------------------------
# CDF building + Equalized TTFS encoder (still uses spikegen.latency)
# -----------------------------
# Channel statistics shaped [1, 3, 1, 1] so they broadcast over [B, 3, H, W] batches
# when the normalization is undone.
_cifar_mean_t = torch.tensor(cifar_mean).reshape(1, 3, 1, 1)
_cifar_std_t = torch.tensor(cifar_std).reshape(1, 3, 1, 1)

@torch.no_grad()
def build_cdf_from_trainloader(train_loader, num_bins=512, max_batches=300, device="cuda"):
    """
    Estimate the empirical CDF of pixel intensities from training batches.

    train_loader : iterable of (normalized image batch, label) pairs.
    num_bins     : number of equal-width histogram bins over [0, 1].
    max_batches  : stop after this many batches; None consumes the whole loader.
    device       : device on which the histogram is accumulated.

    Each batch is de-normalized with the CIFAR channel statistics and clamped
    to [0, 1] before it is histogrammed.

    Returns (bin_edges, cdf): num_bins + 1 evenly spaced edges on [0, 1] and
    the cumulative sum of the normalized histogram (length num_bins).
    """
    ch_mean = _cifar_mean_t.to(device)
    ch_std = _cifar_std_t.to(device)
    counts = torch.zeros(num_bins, device=device)

    for batch_no, (x_norm, _) in enumerate(train_loader):
        if max_batches is not None and batch_no >= max_batches:
            break

        # Back to raw intensities in [0, 1].
        pixels = x_norm.to(device, non_blocking=True).mul(ch_std).add(ch_mean).clamp(0.0, 1.0)
        counts += torch.histc(pixels.flatten(), bins=num_bins, min=0.0, max=1.0)

    # The small constant keeps the division finite if nothing was counted.
    pmf = counts / (counts.sum() + 1e-12)
    edges = torch.linspace(0.0, 1.0, steps=num_bins + 1, device=device)
    return edges, pmf.cumsum(dim=0)

@torch.no_grad()
def ttfs_encode_cifar_equalized_latency(
    x_img_norm, T, bin_edges, cdf,
    normalize=True, linear=True, tau=8.0, threshold=0.001
):
    """
    Latency-encode a normalized image batch after histogram equalization.

    x_img_norm     : normalized image batch.
    T              : number of time steps handed to spikegen.latency.
    bin_edges, cdf : lookup table as returned by build_cdf_from_trainloader.
    normalize, linear, tau, threshold : forwarded unchanged to spikegen.latency.

    Pixels are de-normalized, clamped to [0, 1], mapped to their CDF value
    (kept inside [1e-4, 1 - 1e-4]), and that value is what gets latency-coded.

    Returns the spike tensor from spikegen.latency, [T,B,3,32,32].
    """
    dev = x_img_norm.device
    x_unit = (x_img_norm * _cifar_std_t.to(dev) + _cifar_mean_t.to(dev)).clamp(0.0, 1.0)

    # Interior edges only, so the resulting indices address the CDF bins directly.
    last_bin = cdf.numel() - 1
    bin_idx = torch.bucketize(x_unit, bin_edges[1:-1], right=False).clamp(0, last_bin)
    equalized = cdf[bin_idx].clamp(1e-4, 1.0 - 1e-4)

    return spikegen.latency(
        equalized,
        num_steps=T,
        normalize=normalize,
        linear=linear,
        tau=tau,
        threshold=threshold,
    )

# Build the equalization lookup table once, up front; the encoder reuses it for
# every training and evaluation batch.
print("Building intensity CDF from training data ...")
bin_edges, cdf = build_cdf_from_trainloader(train_loader, num_bins, max_cdf_batches, device)
print("CDF built.")

# -----------------------------
# Spiking ResNet-style blocks
# -----------------------------
class SpkBasicBlock(nn.Module):
    """
    Spiking residual block.

    Main path: Conv-BN -> LIF -> Conv-BN. The shortcut is added to the main
    path before the second LIF. When the stride or the channel count changes,
    the shortcut is a 1x1 Conv-BN projection; otherwise it is the input itself.
    """

    def __init__(self, in_ch, out_ch, stride=1, beta=0.95):
        super().__init__()
        # First conv carries the (optional) spatial downsampling.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.lif1 = snn.Leaky(beta=beta)

        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.lif2 = snn.Leaky(beta=beta)

        # A projection is needed whenever input and output shapes would differ.
        needs_projection = stride != 1 or in_ch != out_ch
        self.short = (
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
            if needs_projection
            else None
        )

    def init_state(self):
        """Return fresh membrane states for (lif1, lif2)."""
        mem_first = self.lif1.init_leaky()
        mem_second = self.lif2.init_leaky()
        return mem_first, mem_second

    def forward_step(self, x_spk, mem1, mem2):
        """
        Advance the block by one time step.

        x_spk      : input to the block for this step.
        mem1, mem2 : membrane states of lif1 / lif2 from the previous step.

        Returns (output spikes, updated mem1, updated mem2).
        """
        spk1, mem1 = self.lif1(self.bn1(self.conv1(x_spk)), mem1)

        residual = self.bn2(self.conv2(spk1))
        shortcut = self.short(x_spk) if self.short is not None else x_spk

        spk2, mem2 = self.lif2(residual + shortcut, mem2)
        return spk2, mem1, mem2

class SpikingResNetCIFAR(nn.Module):
    """
    ResNet-style spiking network.

    Stem (Conv-BN-LIF) -> two blocks at 64 channels -> two at 128 (first one
    strided) -> two at 256 (first one strided) -> global average pool ->
    dropout -> FC -> output LIF. Class scores are read from the output
    membrane, either at the last step or averaged over all steps.
    """

    def __init__(self, beta=0.95, num_classes=10):
        super().__init__()

        # Stem.
        self.conv0 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)
        self.bn0 = nn.BatchNorm2d(64)
        self.lif0 = snn.Leaky(beta=beta)

        # Group 1: 64 channels, full resolution.
        self.b1_0 = SpkBasicBlock(64, 64, stride=1, beta=beta)
        self.b1_1 = SpkBasicBlock(64, 64, stride=1, beta=beta)

        # Group 2: 128 channels, stride-2 entry block (32 -> 16).
        self.b2_0 = SpkBasicBlock(64, 128, stride=2, beta=beta)
        self.b2_1 = SpkBasicBlock(128, 128, stride=1, beta=beta)

        # Group 3: 256 channels, stride-2 entry block (16 -> 8).
        self.b3_0 = SpkBasicBlock(128, 256, stride=2, beta=beta)
        self.b3_1 = SpkBasicBlock(256, 256, stride=1, beta=beta)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.drop = nn.Dropout(p=0.3)

        self.fc = nn.Linear(256, num_classes)
        self.lif_out = snn.Leaky(beta=beta)

    def forward(self, spk_in, readout_mode="mean"):
        """
        Unroll the network over the leading time axis of `spk_in`.

        spk_in       : 5-D spike tensor, time first and batch second.
        readout_mode : "last" returns the output membrane of the final step;
                       any other value returns its mean over time.

        Returns class scores with one row per batch element.
        """
        _, batch, _, _, _ = spk_in.shape

        res_blocks = (
            self.b1_0, self.b1_1,
            self.b2_0, self.b2_1,
            self.b3_0, self.b3_1,
        )

        # Fresh membrane state for every LIF layer at the start of a sequence.
        mem_stem = self.lif0.init_leaky()
        block_mems = [blk.init_state() for blk in res_blocks]
        mem_out = self.lif_out.init_leaky()

        membrane_trace = []
        for frame in spk_in:
            x, mem_stem = self.lif0(self.bn0(self.conv0(frame)), mem_stem)

            # Each block consumes and returns its own pair of membranes.
            for i, blk in enumerate(res_blocks):
                x, mem_a, mem_b = blk.forward_step(x, *block_mems[i])
                block_mems[i] = (mem_a, mem_b)

            pooled = self.avgpool(x).view(batch, -1)
            # Output spikes are discarded; only the membrane is recorded.
            _, mem_out = self.lif_out(self.fc(self.drop(pooled)), mem_out)
            membrane_trace.append(mem_out)

        membrane_trace = torch.stack(membrane_trace, dim=0)  # [T,B,10]
        return membrane_trace[-1] if readout_mode == "last" else membrane_trace.mean(dim=0)

# -----------------------------
# Model / Optim
# -----------------------------
# Build the network on the target device, then its optimizer and a cosine LR
# schedule that spans the whole run (stepped once per epoch by the training loop).
model = SpikingResNetCIFAR(beta=beta, num_classes=num_classes)
model = model.to(device)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
)

# -----------------------------
# Sanity check: spike count + time histogram
# -----------------------------
@torch.no_grad()
def sanity_check_ttfs():
    """
    Encode one training batch and print quick diagnostics of the TTFS code.

    Prints the minimum and maximum number of spikes per input element, and the
    normalized histogram of first-spike times (first and last 10 bins).
    """
    images, _ = next(iter(train_loader))
    spikes = ttfs_encode_cifar_equalized_latency(
        images.to(device),
        T=T,
        bin_edges=bin_edges,
        cdf=cdf,
        normalize=normalize_in_latency,
        linear=linear_latency,
        tau=tau_latency,
        threshold=threshold_latency,
    )

    # Number of spikes each input element emits over the whole window.
    per_input = spikes.sum(dim=0)
    print("TTFS spike-count per input (min,max) =", per_input.min().item(), per_input.max().item())

    # argmax along time picks the step of the first (largest) entry per element.
    first_spike_step = spikes.argmax(dim=0).flatten()
    time_hist = torch.bincount(first_spike_step, minlength=T).float()
    time_hist = time_hist / time_hist.sum()
    print("TTFS time histogram (first 10 bins):", time_hist[:10].tolist())
    print("TTFS time histogram (last  10 bins):", time_hist[-10:].tolist())

sanity_check_ttfs()

# -----------------------------
# Train / Eval + Best checkpoint
# -----------------------------
@torch.no_grad()
def evaluate():
    """
    Measure top-1 accuracy of the global `model` on `test_loader`.

    Every batch is TTFS-encoded with the global latency settings and CDF
    table, and scored with the global `readout_mode`. The model is left in
    eval mode. Returns a fraction in [0, 1].
    """
    # Encoder settings are identical for every batch, so collect them once.
    encoder_kwargs = dict(
        T=T,
        bin_edges=bin_edges,
        cdf=cdf,
        normalize=normalize_in_latency,
        linear=linear_latency,
        tau=tau_latency,
        threshold=threshold_latency,
    )

    model.eval()
    n_hit = 0
    n_seen = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        spikes = ttfs_encode_cifar_equalized_latency(images, **encoder_kwargs)
        predicted = model(spikes, readout_mode=readout_mode).argmax(dim=1)

        n_hit += (predicted == labels).sum().item()
        n_seen += labels.numel()
    # max(..., 1) avoids a division by zero on an empty loader.
    return n_hit / max(n_seen, 1)

# Best test accuracy seen so far; -1 guarantees the first epoch is checkpointed.
best_acc = -1.0

for epoch in range(1, num_epochs + 1):
    model.train()
    running_loss, running_correct, running_total = 0.0, 0, 0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        # Equalized TTFS encoding with the same settings that evaluate() uses.
        spk_in = ttfs_encode_cifar_equalized_latency(
            x, T=T, bin_edges=bin_edges, cdf=cdf,
            normalize=normalize_in_latency,
            linear=linear_latency,
            tau=tau_latency,
            threshold=threshold_latency
        )

        logits = model(spk_in, readout_mode=readout_mode)

        # Cross-entropy with label smoothing.
        loss = F.cross_entropy(logits, y, label_smoothing=label_smoothing)

        # Backward, with the gradient norm clipped to 1.0 before the update.
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Sample-weighted accumulation so the epoch averages are per example.
        running_loss += loss.item() * x.size(0)
        running_correct += (logits.argmax(dim=1) == y).sum().item()
        running_total += x.size(0)

    # The cosine schedule advances once per epoch.
    scheduler.step()
    test_acc = evaluate()

    # Keep only the checkpoint with the best test accuracy.
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                # Needed to resume training with the same LR trajectory.
                "scheduler_state_dict": scheduler.state_dict(),
                "test_acc": test_acc,
                # The encoder's equalization table is part of the model's input
                # pipeline. It is built from shuffled, augmented batches, so a
                # rebuilt table would differ; persist the one used in this run.
                "bin_edges": bin_edges.detach().cpu(),
                "cdf": cdf.detach().cpu(),
                "config": {
                    "T": T,
                    "tau_latency": tau_latency,
                    "threshold_latency": threshold_latency,
                    "normalize_in_latency": normalize_in_latency,
                    "linear_latency": linear_latency,
                    "tau_neuron": tau_neuron,
                    "beta": beta,
                    "readout_mode": readout_mode,
                    "label_smoothing": label_smoothing,
                    "num_bins": num_bins,
                    "max_cdf_batches": max_cdf_batches,
                    "batch_size": batch_size,
                    "lr": lr,
                    "weight_decay": weight_decay,
                }
            },
            ckpt_path
        )

    print(f"Epoch {epoch:02d} | "
          f"train loss {running_loss/running_total:.4f} | "
          f"train acc {running_correct/running_total:.4f} | "
          f"test acc {test_acc:.4f} | "
          f"best {best_acc:.4f}")

print(f"Done. Best test acc = {best_acc:.4f}. Saved to {os.path.abspath(ckpt_path)}")


Building intensity CDF from training data ...
CDF built.
TTFS spike-count per input (min,max) = 1.0 1.0
TTFS time histogram (first 10 bins): [0.014630635268986225, 0.0406595878303051, 0.0410257987678051, 0.0372314453125, 0.0333506278693676, 0.03489939495921135, 0.03621165081858635, 0.036590576171875, 0.0348917655646801, 0.03798167034983635]
TTFS time histogram (last  10 bins): [0.02952066995203495, 0.03033447265625, 0.03035481832921505, 0.02818807028234005, 0.03160349652171135, 0.1603902131319046, 0.0, 0.0, 0.0, 0.0]
Epoch 01 | train loss 1.8287 | train acc 0.3562 | test acc 0.1573 | best 0.1573
Epoch 02 | train loss 1.6114 | train acc 0.4832 | test acc 0.1928 | best 0.1928
Epoch 03 | train loss 1.5061 | train acc 0.5427 | test acc 0.2610 | best 0.2610
Epoch 04 | train loss 1.4415 | train acc 0.5779 | test acc 0.2831 | best 0.2831
Epoch 05 | train loss 1.3876 | train acc 0.6059 | test acc 0.1777 | best 0.2831
Epoch 06 | train loss 1.3381 | train acc 0.6305 | test acc 0.2848 | best 0.28

3. ISI Encoder

In [None]:
# CIFAR-10 Spiking ResNet-style SNN + fixed-K ISI (no-endcaps, strict K)
# Stabilized version:
#   - BatchNorm -> GroupNorm (much more stable for SNN time-unroll)
#   - lr 2e-3 -> 1e-3
#   - Output head: remove lif_out; readout uses mean/last of logits directly

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import snntorch as snn

# -----------------------------
# Hyperparameters
# -----------------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

num_classes = 10
batch_size = 128
num_epochs = 80
lr = 1e-3
weight_decay = 5e-4

# ISI encoder settings: window length, spikes per input, maximum slope, quantile margin.
T = 30
K = 4
alpha_max = 2.0
eps_q = 1e-3

# LIF membrane decay derived from the neuron time constant: beta = exp(-1 / tau_neuron).
tau_neuron = 2.0
beta = torch.exp(torch.tensor(-1.0 / tau_neuron)).item()

# Readout over time ("mean" or "last") and label-smoothing factor for the loss.
readout_mode = "mean"
label_smoothing = 0.1

# File that receives the best checkpoint.
ckpt_path = "best_isi_resnet_snn_GN_logitsreadout.pt"

# -----------------------------
# CIFAR-10 datasets and loaders
# -----------------------------
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2470, 0.2435, 0.2616)

# Tensor conversion + normalization, common to the train and test pipelines.
_tensor_and_normalize = [
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
]

# Training adds random crop (with 4-pixel padding) and horizontal flip augmentation.
train_tf = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        *_tensor_and_normalize,
    ]
)
test_tf = transforms.Compose(list(_tensor_and_normalize))

_dataset_kwargs = dict(root="./data", download=True)
train_set = torchvision.datasets.CIFAR10(train=True, transform=train_tf, **_dataset_kwargs)
test_set = torchvision.datasets.CIFAR10(train=False, transform=test_tf, **_dataset_kwargs)

# Only the training loader shuffles and drops the final partial batch.
_loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **_loader_kwargs)
test_loader = DataLoader(test_set, shuffle=False, drop_last=False, **_loader_kwargs)

# Channel statistics shaped [1, 3, 1, 1] so they broadcast over [B, 3, H, W] batches
# when the normalization is undone.
_cifar_mean_t = torch.tensor(cifar_mean).reshape(1, 3, 1, 1)
_cifar_std_t = torch.tensor(cifar_std).reshape(1, 3, 1, 1)

@torch.no_grad()
def cifar_to_unit_interval(x_norm: torch.Tensor) -> torch.Tensor:
    """Undo the CIFAR channel normalization of `x_norm` and clamp the result to [0, 1]."""
    dev = x_norm.device
    restored = x_norm * _cifar_std_t.to(dev) + _cifar_mean_t.to(dev)
    return restored.clamp(0.0, 1.0)

# -----------------------------
# Strict fixed-K encoder (from your working version)
# -----------------------------
@torch.no_grad()
def isi_fixedK_no_endcaps_strict(
    x_img_unit: torch.Tensor, T: int, K: int, alpha_max: float = 2.0, eps: float = 1e-3
) -> torch.Tensor:
    """
    Encode every input element as exactly K spikes at distinct time steps.

    For an element with intensity x in [0, 1], a slope
    alpha = (2x - 1) * alpha_max defines weights proportional to
    exp(alpha * (t - (T - 1) / 2)) over the T time steps. The K spike times
    are the steps at which the cumulative weight first reaches the K quantiles
    linspace(eps, 1 - eps, K). Quantiles that land on an already used step are
    moved to the nearest free step at or after the target, or, if none exists,
    the nearest free step before it, so each element always gets K spikes.

    Args:
        x_img_unit: tensor whose first dimension is the batch; values are
            clamped to [0, 1]. Any layout/dtype is accepted (it is flattened
            per sample and converted to float32 internally).
        T: number of time steps (>= 2).
        K: number of spikes per element (1 <= K <= T).
        alpha_max: slope magnitude reached at x = 0 and x = 1.
        eps: margin that keeps the quantiles away from 0 and 1.

    Returns:
        float32 tensor of shape [T, *x_img_unit.shape] holding 0/1 spikes.

    Raises:
        ValueError: if T < 2, K < 1 or K > T.
    """
    # Explicit raises rather than assert, so validation survives `python -O`.
    if T < 2 or K < 1:
        raise ValueError(f"T={T} must be >= 2 and K={K} must be >= 1.")
    if K > T:
        raise ValueError(f"K={K} must satisfy K<=T={T}.")

    one_hot = torch.nn.functional.one_hot
    device = x_img_unit.device
    B = x_img_unit.size(0)
    # reshape (not view) tolerates non-contiguous inputs; float32 keeps the
    # cumulative weights and the quantiles in the same dtype for searchsorted.
    x = x_img_unit.reshape(B, -1).to(torch.float32).clamp(0.0, 1.0)
    N = x.size(1)

    j = torch.arange(T, device=device, dtype=torch.float32).view(1, 1, T)
    mid = (T - 1) / 2.0

    alpha = ((x * 2.0 - 1.0) * alpha_max).unsqueeze(-1)  # [B, N, 1]

    # softmax(z) == exp(z) / sum(exp(z)), but does not overflow when
    # alpha_max * T is large.
    w = torch.softmax(alpha * (j - mid), dim=-1)  # [B, N, T]
    c = torch.cumsum(w, dim=-1)

    q = torch.linspace(eps, 1.0 - eps, steps=K, device=device, dtype=torch.float32)
    # searchsorted wants contiguous values; materialize the expanded view once.
    q = q.view(1, 1, K).expand(B, N, K).contiguous()

    t_idx = torch.searchsorted(c, q).clamp(0, T - 1).long()
    t_idx, _ = torch.sort(t_idx, dim=-1)

    used = torch.zeros(B, N, T, device=device, dtype=torch.bool)
    t_fixed = torch.full_like(t_idx, -1)

    # Pass 1: accept every target step that is still free; mark collisions with -1.
    for k in range(K):
        tk = t_idx[..., k]
        free = ~used.gather(dim=2, index=tk.unsqueeze(-1)).squeeze(-1)
        t_fixed[..., k] = torch.where(free, tk, torch.full_like(tk, -1))
        if free.any():
            used[free] |= one_hot(tk[free], num_classes=T).bool()

    # Pass 2: relocate the collided spikes to the nearest free step.
    ar = torch.arange(T, device=device).view(1, 1, T)
    for k in range(K):
        need = (t_fixed[..., k] < 0)
        if not need.any():
            continue

        tk = t_idx[..., k].unsqueeze(-1)
        avail = ~used

        # First free step at or after the target.
        forward_mask = avail & (ar >= tk)
        fwd_pos = forward_mask.float().argmax(dim=-1)
        fwd_exists = forward_mask.any(dim=-1)

        # Last free step at or before the target (argmax on the flipped mask).
        backward_mask = avail & (ar <= tk)
        bwd_pos = (T - 1) - torch.flip(backward_mask, dims=[-1]).float().argmax(dim=-1)

        # K <= T guarantees at least one free step, so when no forward slot
        # exists a backward one does.
        chosen = torch.where(fwd_exists, fwd_pos, bwd_pos).long()
        t_fixed[..., k] = torch.where(need, chosen, t_fixed[..., k])
        used[need] |= one_hot(chosen[need], num_classes=T).bool()

    # Scatter the K chosen steps of every element into a time-major spike tensor.
    spk_flat = torch.zeros(T, B, N, device=device, dtype=torch.float32)
    b_idx = torch.arange(B, device=device).view(B, 1, 1).expand(B, N, K)
    n_idx = torch.arange(N, device=device).view(1, N, 1).expand(B, N, K)
    spk_flat[t_fixed, b_idx, n_idx] = 1.0

    return spk_flat.view(T, B, *x_img_unit.shape[1:])

# -----------------------------
# GN helper
# -----------------------------
def GN(ch, groups=16):
    """
    Build a GroupNorm layer over `ch` channels.

    The group count starts at min(groups, ch) and is decreased until it
    divides `ch` evenly (stopping at 1 at the latest).
    """
    num_groups = min(groups, ch)
    while ch % num_groups and num_groups > 1:
        num_groups -= 1
    return nn.GroupNorm(num_groups, ch)

# -----------------------------
# Spiking ResNet-style blocks (BN -> GN)
# -----------------------------
class SpkBasicBlockGN(nn.Module):
    """
    Spiking residual block with GroupNorm.

    Main path: Conv-GN -> LIF -> Conv-GN. The shortcut is added to the main
    path before the second LIF. When the stride or the channel count changes,
    the shortcut is a 1x1 Conv-GN projection; otherwise it is the input itself.
    """

    def __init__(self, in_ch, out_ch, stride=1, beta=0.95):
        super().__init__()
        # First conv carries the (optional) spatial downsampling.
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.gn1 = GN(out_ch)
        self.lif1 = snn.Leaky(beta=beta)

        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.gn2 = GN(out_ch)
        self.lif2 = snn.Leaky(beta=beta)

        # A projection is needed whenever input and output shapes would differ.
        needs_projection = stride != 1 or in_ch != out_ch
        self.short = (
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                GN(out_ch),
            )
            if needs_projection
            else None
        )

    def init_state(self):
        """Return fresh membrane states for (lif1, lif2)."""
        mem_first = self.lif1.init_leaky()
        mem_second = self.lif2.init_leaky()
        return mem_first, mem_second

    def forward_step(self, x_spk, mem1, mem2):
        """
        Advance the block by one time step.

        x_spk      : input to the block for this step.
        mem1, mem2 : membrane states of lif1 / lif2 from the previous step.

        Returns (output spikes, updated mem1, updated mem2).
        """
        spk1, mem1 = self.lif1(self.gn1(self.conv1(x_spk)), mem1)

        residual = self.gn2(self.conv2(spk1))
        shortcut = self.short(x_spk) if self.short is not None else x_spk

        spk2, mem2 = self.lif2(residual + shortcut, mem2)
        return spk2, mem1, mem2

class SpikingResNetCIFAR_GN(nn.Module):
    """
    ResNet-style spiking network with GroupNorm and a non-spiking readout.

    Stem (Conv-GN-LIF) -> two blocks at 64 channels -> two at 128 (first one
    strided) -> two at 256 (first one strided) -> global average pool ->
    dropout -> FC. The FC outputs are used directly as per-step logits and
    reduced over time (last step or mean).
    """

    def __init__(self, beta=0.95, num_classes=10):
        super().__init__()
        # Stem.
        self.conv0 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)
        self.gn0 = GN(64)
        self.lif0 = snn.Leaky(beta=beta)

        # Group 1: 64 channels.
        self.b1_0 = SpkBasicBlockGN(64, 64, stride=1, beta=beta)
        self.b1_1 = SpkBasicBlockGN(64, 64, stride=1, beta=beta)

        # Group 2: 128 channels, stride-2 entry block.
        self.b2_0 = SpkBasicBlockGN(64, 128, stride=2, beta=beta)
        self.b2_1 = SpkBasicBlockGN(128, 128, stride=1, beta=beta)

        # Group 3: 256 channels, stride-2 entry block.
        self.b3_0 = SpkBasicBlockGN(128, 256, stride=2, beta=beta)
        self.b3_1 = SpkBasicBlockGN(256, 256, stride=1, beta=beta)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.drop = nn.Dropout(p=0.3)
        self.fc = nn.Linear(256, num_classes)

    def forward(self, spk_in, readout_mode="mean"):
        """
        Unroll the network over the leading time axis of `spk_in`.

        spk_in       : 5-D spike tensor, time first and batch second.
        readout_mode : "last" returns the logits of the final step; any other
                       value returns their mean over time.

        Returns class logits with one row per batch element.
        """
        _, batch, _, _, _ = spk_in.shape

        res_blocks = (
            self.b1_0, self.b1_1,
            self.b2_0, self.b2_1,
            self.b3_0, self.b3_1,
        )

        # Fresh membrane state for every LIF layer at the start of a sequence.
        mem_stem = self.lif0.init_leaky()
        block_mems = [blk.init_state() for blk in res_blocks]

        per_step_logits = []
        for frame in spk_in:
            x, mem_stem = self.lif0(self.gn0(self.conv0(frame)), mem_stem)

            # Each block consumes and returns its own pair of membranes.
            for i, blk in enumerate(res_blocks):
                x, mem_a, mem_b = blk.forward_step(x, *block_mems[i])
                block_mems[i] = (mem_a, mem_b)

            pooled = self.avgpool(x).view(batch, -1)
            # Direct logits: there is no output LIF in this variant.
            per_step_logits.append(self.fc(self.drop(pooled)))

        per_step_logits = torch.stack(per_step_logits, dim=0)  # [T,B,10]
        if readout_mode == "last":
            return per_step_logits[-1]
        return per_step_logits.mean(dim=0)

# -----------------------------
# Sanity check
# -----------------------------
@torch.no_grad()
def sanity_check_isi():
    """
    Encode one training batch and print quick diagnostics of the ISI code.

    Prints the minimum and maximum number of spikes per input element, and the
    share of all spikes that falls on each time step (first and last 10 bins).
    """
    images, _ = next(iter(train_loader))
    unit_images = cifar_to_unit_interval(images.to(device))

    spikes = isi_fixedK_no_endcaps_strict(unit_images, T=T, K=K, alpha_max=alpha_max, eps=eps_q)

    # Number of spikes each input element emits over the whole window.
    per_input = spikes.sum(dim=0)
    print("ISI spike-count per input (min,max) =", per_input.min().item(), per_input.max().item())

    # Total spikes per time step, normalized to sum to one.
    per_step = spikes.reshape(T, -1).sum(dim=1).float()
    per_step = per_step / (per_step.sum() + 1e-12)
    print("ISI time histogram (first 10 bins):", per_step[:10].tolist())
    print("ISI time histogram (last  10 bins):", per_step[-10:].tolist())

print(f"ISI encoder sanity check (T={T}, K={K}) ...")
sanity_check_isi()

# -----------------------------
# Train / Eval
# -----------------------------
# Build the network on the target device, then its optimizer and a cosine LR
# schedule that spans the whole run (stepped once per epoch by the training loop).
model = SpikingResNetCIFAR_GN(beta=beta, num_classes=num_classes)
model = model.to(device)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=num_epochs,
)

@torch.no_grad()
def evaluate():
    """
    Measure top-1 accuracy of the global `model` on `test_loader`.

    Every batch is mapped back to [0, 1], ISI-encoded with the global
    T / K / alpha_max / eps_q, and scored with the global `readout_mode`.
    The model is left in eval mode. Returns a fraction in [0, 1].
    """
    model.eval()
    n_hit = 0
    n_seen = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        spikes = isi_fixedK_no_endcaps_strict(
            cifar_to_unit_interval(images), T=T, K=K, alpha_max=alpha_max, eps=eps_q
        )
        predicted = model(spikes, readout_mode=readout_mode).argmax(dim=1)

        n_hit += (predicted == labels).sum().item()
        n_seen += labels.numel()
    # max(..., 1) avoids a division by zero on an empty loader.
    return n_hit / max(n_seen, 1)

# Best test accuracy seen so far; -1.0 guarantees the first epoch is checkpointed.
best_acc = -1.0

# Main training loop: one pass over train_loader per epoch, then a full test evaluation.
for epoch in range(1, num_epochs + 1):
    model.train()
    # Sample-weighted running sums used for the per-epoch train loss / accuracy printout.
    running_loss, running_correct, running_total = 0.0, 0, 0

    for x_norm, y in train_loader:
        x_norm, y = x_norm.to(device), y.to(device)

        # Undo the dataset normalization, then encode each batch into a spike train.
        x_unit = cifar_to_unit_interval(x_norm)
        spk_in = isi_fixedK_no_endcaps_strict(x_unit, T=T, K=K, alpha_max=alpha_max, eps=eps_q)

        logits = model(spk_in, readout_mode=readout_mode)
        loss = F.cross_entropy(logits, y, label_smoothing=label_smoothing)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        # Clip the global gradient norm to 0.5 before the optimizer update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)  # slightly tighter
        optimizer.step()

        # loss is a batch mean, so weight it by the batch size before accumulating.
        running_loss += loss.item() * x_norm.size(0)
        # Train accuracy comes from the same forward pass (model still in train mode).
        running_correct += (logits.argmax(dim=1) == y).sum().item()
        running_total += x_norm.size(0)

    # Learning-rate schedule advances once per epoch.
    scheduler.step()
    test_acc = evaluate()

    # Keep only the checkpoint with the highest test accuracy (file is overwritten).
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "test_acc": test_acc,
                # Hyper-parameters stored alongside the weights for later reference.
                "config": {
                    "T": T,
                    "K": K,
                    "alpha_max": alpha_max,
                    "eps_q": eps_q,
                    "beta": beta,
                    "readout_mode": readout_mode,
                    "label_smoothing": label_smoothing,
                    "lr": lr,
                }
            },
            ckpt_path
        )

    print(f"Epoch {epoch:02d} | "
          f"train loss {running_loss/running_total:.4f} | "
          f"train acc {running_correct/running_total:.4f} | "
          f"test acc {test_acc:.4f} | "
          f"best {best_acc:.4f}")

print(f"Done. Best test acc = {best_acc:.4f}. Saved to {os.path.abspath(ckpt_path)}")


ISI encoder sanity check (T=30, K=4) ...
ISI spike-count per input (min,max) = 4.0 4.0
ISI time histogram (first 10 bins): [0.16699981689453125, 0.14074642956256866, 0.13408024609088898, 0.051362354308366776, 0.022551218047738075, 0.0182367954403162, 0.015087127685546875, 0.013635635375976562, 0.009203593246638775, 0.0091705322265625]
ISI time histogram (last  10 bins): [0.008059819228947163, 0.0074615478515625, 0.0105044050142169, 0.01173273753374815, 0.013326008804142475, 0.017501195892691612, 0.017867406830191612, 0.07789293676614761, 0.08326593786478043, 0.10898780822753906]
Epoch 01 | train loss 2.2162 | train acc 0.1514 | test acc 0.1994 | best 0.1994
Epoch 02 | train loss 2.0496 | train acc 0.2236 | test acc 0.1994 | best 0.1994
Epoch 03 | train loss 2.0144 | train acc 0.2481 | test acc 0.2686 | best 0.2686
Epoch 04 | train loss 1.9860 | train acc 0.2704 | test acc 0.2349 | best 0.2686
Epoch 05 | train loss 1.9396 | train acc 0.2994 | test acc 0.1969 | best 0.2686
Epoch 06 | tra

4. TTFS-Phase

In [None]:
# ============================================================
# CIFAR-10 Spiking ResNet-style SNN (GroupNorm) +
# TTFS-Phase Encoder (STRICT SMO maxima-only, 1 spike per pixel)
# Method B: per-image rank-balanced binning (nearly uniform over M maxima bins)
#
# Key properties:
# - One spike per pixel (no increase in spike count)
# - Strict phase-lock: t ∈ {phi0 + k*P}
# - Per image, spike times are distributed ~uniformly across M maxima bins
# - Stable beta=0.95, GN ResNet-style SNN, logits readout mean/last
# ============================================================

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import snntorch as snn

# -----------------------------
# Config
# -----------------------------
# Run on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Optimization hyper-parameters (AdamW + cosine schedule, see the Train / Eval section).
batch_size = 128
num_epochs = 80
lr = 1e-3
weight_decay = 5e-4

# Temporal grid of the encoder: spikes are only placed at t = phi0 + k*P with t < T.
T = 60      # number of time steps in the spike tensor
P = 3       # spacing between allowed spike times
phi0 = 0    # offset of the first allowed spike time (encoder asserts 0 <= phi0 < P)

# Passed as `beta` to every snn.Leaky neuron in the network.
beta = 0.95

readout_mode = "mean"     # "mean" or "last"
label_smoothing = 0.1
num_classes = 10
# Destination of the best-accuracy checkpoint written by the training loop.
ckpt_path = "best_ttfs_phase_rankbalance_resnet_snn_GN_beta095.pt"

# encoder settings
USE_CDF_U = False   # keep False first; rank-balance already equalizes within each image
JITTER = 1e-6       # break ties for identical pixel values

# CDF settings (only used if USE_CDF_U=True)
CDF_BINS = 4096           # histogram resolution over the [0, 1] intensity range
CDF_MAX_BATCHES = 400     # upper bound on training batches scanned for the histogram

# -----------------------------
# Data
# -----------------------------
# Per-channel statistics used by Normalize and inverted by cifar_to_unit_interval().
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std  = (0.2470, 0.2435, 0.2616)

# Training transform: random crop with padding + horizontal flip, then normalization.
train_tf = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

# Test transform: normalization only, no augmentation.
test_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

# Datasets are downloaded to ./data on first use.
train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=train_tf)
test_set  = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=test_tf)

# Training batches are shuffled and the last partial batch is dropped; test keeps every sample.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True,
                          num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_set,  batch_size=batch_size, shuffle=False, drop_last=False,
                          num_workers=2, pin_memory=True)

# Statistics reshaped to (1, 3, 1, 1) so they broadcast over [B, 3, H, W] batches.
_cifar_mean_t = torch.tensor(cifar_mean).view(1, 3, 1, 1)
_cifar_std_t  = torch.tensor(cifar_std).view(1, 3, 1, 1)

@torch.no_grad()
def cifar_to_unit_interval(x_norm: torch.Tensor) -> torch.Tensor:
    """Invert the CIFAR normalization and clamp the result to [0, 1].

    Computes ``x_norm * std + mean`` with the module-level per-channel
    statistics (moved to the input's device), then clamps.
    """
    target = x_norm.device
    restored = x_norm * _cifar_std_t.to(target) + _cifar_mean_t.to(target)
    return torch.clamp(restored, min=0.0, max=1.0)

# -----------------------------
# Optional: build CDF (only if USE_CDF_U=True)
# -----------------------------
@torch.no_grad()
def build_cdf_from_cifar_loader(loader, num_bins=4096, device="cpu", max_batches=400):
    """Estimate the empirical intensity CDF of the images produced by ``loader``.

    Args:
        loader: iterable of ``(x_norm, label)`` batches; ``x_norm`` is mapped
            back to [0, 1] with ``cifar_to_unit_interval``.
        num_bins: number of histogram bins covering [0, 1].
        device: device on which the histogram is accumulated.
        max_batches: stop after this many batches (``None`` scans everything).

    Returns:
        float32 CPU tensor of length ``num_bins`` holding the normalized CDF.

    Raises:
        RuntimeError: if no pixel was accumulated.
    """
    counts = torch.zeros(num_bins, dtype=torch.float64, device=device)
    top_bin = num_bins - 1

    for batch_no, (x_norm, _) in enumerate(loader):
        if max_batches is not None and batch_no >= max_batches:
            break
        pixels = cifar_to_unit_interval(x_norm.to(device, non_blocking=True)).flatten()
        bin_idx = torch.clamp((pixels * top_bin).long(), 0, top_bin)
        # bin_idx never exceeds top_bin, so the bincount has exactly num_bins entries.
        counts += torch.bincount(bin_idx, minlength=num_bins).to(torch.float64)

    n_pixels = counts.sum()
    if n_pixels.item() <= 0:
        raise RuntimeError("CDF build failed: empty histogram.")

    cdf = torch.cumsum(counts / n_pixels, dim=0)
    # Renormalize so the last entry is 1 despite accumulated rounding error.
    cdf = cdf / cdf[-1].clamp_min(1e-12)
    return cdf.to(dtype=torch.float32, device="cpu")

# The CDF lookup table is only built when the encoder is configured to rank by CDF value;
# otherwise cdf_cpu stays None and the encoder ranks raw intensities.
cdf_cpu = None
if USE_CDF_U:
    print(f"Building CIFAR intensity CDF (bins={CDF_BINS}, max_batches={CDF_MAX_BATCHES}) ...")
    # Built from the (augmented, shuffled) training loader and kept on the CPU.
    cdf_cpu = build_cdf_from_cifar_loader(train_loader, num_bins=CDF_BINS, device="cpu", max_batches=CDF_MAX_BATCHES)
    print("CDF built.")

# ============================================================
# TTFS-Phase Encoder (Method B): per-image rank-balanced binning
# ============================================================
@torch.no_grad()
def ttfs_phase_rank_balance_cifar(
    x_norm: torch.Tensor,          # [B,3,32,32] normalized
    T: int,
    P: int,
    phi0: int,
    use_cdf_u: bool = False,
    cdf_cpu: torch.Tensor | None = None,
    num_bins: int = 4096,
    jitter: float = 1e-6,
) -> torch.Tensor:
    """
    Encode each pixel/channel as a single spike on the grid t = phi0 + k*P.

    Within every image the elements are ranked by intensity (or by their CDF
    value when ``use_cdf_u`` is set) and the ranks are split evenly over the
    M = floor((T-1-phi0)/P)+1 grid slots, so each slot receives ~N/M spikes.
    Brighter elements get earlier slots (smaller k).

    Args:
        x_norm: normalized image batch.
        T: number of time steps of the output.
        P: spacing between allowed spike times.
        phi0: first allowed spike time; must satisfy 0 <= phi0 < P.
        use_cdf_u: rank by CDF lookup instead of raw intensity.
        cdf_cpu: lookup table required when ``use_cdf_u`` is True.
        num_bins: number of intensity bins used to index ``cdf_cpu``.
        jitter: std of the Gaussian noise added to the scores to break ties
            (skipped when <= 0).

    Returns:
        float32 spike tensor of shape [T, B, *x_norm.shape[1:]].
    """
    assert 0 <= phi0 < P
    dev = x_norm.device
    n_img = x_norm.size(0)

    # Back to [0,1], one row of N values per image.
    unit = cifar_to_unit_interval(x_norm)
    flat = unit.view(n_img, -1)
    n_elem = flat.size(1)

    # Score that decides the ordering inside each image.
    if not use_cdf_u:
        score = flat
    else:
        assert cdf_cpu is not None
        bin_idx = torch.clamp((unit * (num_bins - 1)).long(), 0, num_bins - 1)
        score = cdf_cpu.to(dev)[bin_idx].view(n_img, -1).clamp(0.0, 1.0)

    # Tiny noise so equal intensities do not all land in the same slot.
    if jitter > 0:
        score = score + jitter * torch.randn_like(score)

    # rank[b, n] = position of element n in the descending ordering (0 = brightest).
    order = torch.argsort(score, dim=1, descending=True)
    positions = torch.arange(n_elem, device=dev).expand(n_img, n_elem)
    rank = torch.empty_like(order).scatter_(1, order, positions)

    # Number of grid slots, e.g. T=60, P=3 -> 20.
    n_slots = int(((T - 1 - phi0) // P) + 1)
    slot = torch.floor(rank.float() * n_slots / float(n_elem)).long().clamp(0, n_slots - 1)

    # Slot index -> time step on the phase-locked grid.
    spike_time = (phi0 + slot * P).long()

    spikes = torch.zeros(T, n_img, n_elem, device=dev, dtype=torch.float32)
    rows = torch.arange(n_img, device=dev).view(n_img, 1)
    cols = torch.arange(n_elem, device=dev).view(1, n_elem)
    # The index tensors broadcast to [B, N]: one write per element.
    spikes[spike_time, rows, cols] = 1.0

    return spikes.view(T, n_img, *x_norm.shape[1:])

# -----------------------------
# Sanity check utilities
# -----------------------------
@torch.no_grad()
def phase_lock_ratio(spk_in: torch.Tensor, P: int, phi0: int) -> float:
    """Return the fraction of spike mass that lies on the grid t = phi0 + k*P.

    Args:
        spk_in: spike tensor whose first dimension is time; any layout
            (contiguous or not) is accepted.
        P: grid period.
        phi0: grid offset.

    Returns:
        ``on_grid_spikes / all_spikes`` as a Python float, or 0.0 when the
        tensor contains no spikes.
    """
    Tsteps = spk_in.size(0)
    # reshape instead of view: view raises on non-contiguous (e.g. permuted) inputs.
    spk_flat = spk_in.reshape(Tsteps, -1)
    total = spk_flat.sum()
    if total.item() == 0:
        return 0.0
    t = torch.arange(Tsteps, device=spk_in.device)
    # 1.0 for time steps congruent to phi0 modulo P, 0.0 elsewhere.
    on_grid = ((t - phi0) % P == 0).float()
    on = (spk_flat * on_grid[:, None]).sum()
    return (on / total).item()

@torch.no_grad()
def maxima_bin_hist(spk_in: torch.Tensor, T: int, P: int, phi0: int):
    """
    Return the normalized spike occupancy of the grid slots k = 0..M-1.

    Slot k corresponds to time step t = phi0 + k*P and
    M = floor((T-1-phi0)/P)+1. Counts are divided by the total number of
    spikes in the tensor (clamped to at least 1), so off-grid spikes lower
    the sum of the returned histogram.

    Args:
        spk_in: spike tensor whose first dimension is time and equals ``T``;
            any layout (contiguous or not) is accepted.
        T: number of time steps.
        P: grid period.
        phi0: grid offset.

    Returns:
        CPU tensor of length M.
    """
    Tsteps = spk_in.size(0)
    assert Tsteps == T
    M = int(((T - 1 - phi0) // P) + 1)

    # count spikes at each t (reshape, not view, so permuted inputs work too)
    t_count = spk_in.reshape(T, -1).sum(dim=1)  # [T]
    total = t_count.sum().clamp_min(1.0)

    # collect maxima times
    ks = torch.arange(M, device=spk_in.device)
    t_max = (phi0 + ks * P).long()
    k_count = t_count[t_max] / total
    return k_count.detach().cpu()

@torch.no_grad()
def sanity_check_encoder():
    """Encode one training batch and print encoder diagnostics.

    Prints the min/max spike count per element, the phase-lock ratio and the
    first/last ten entries of the grid-slot occupancy histogram.
    """
    batch, _ = next(iter(train_loader))
    spikes = ttfs_phase_rank_balance_cifar(
        batch.to(device),
        T=T,
        P=P,
        phi0=phi0,
        use_cdf_u=USE_CDF_U,
        cdf_cpu=cdf_cpu,
        num_bins=CDF_BINS,
        jitter=JITTER,
    )

    # Spike count of every element = sum over the time axis.
    per_elem = spikes.sum(dim=0)
    print("TTFS-Phase spike-count per pixel (min,max) =", per_elem.min().item(), per_elem.max().item())

    print("Phase-lock ratio (should be 1.0):", phase_lock_ratio(spikes, P=P, phi0=phi0))

    occupancy = maxima_bin_hist(spikes, T=T, P=P, phi0=phi0)  # one entry per grid slot
    n_slots = len(occupancy)
    print(f"Maxima-bin occupancy (M={n_slots}) first 10:", occupancy[:10].tolist())
    print(f"Maxima-bin occupancy (M={n_slots}) last  10:", occupancy[-10:].tolist())

# Run the encoder diagnostics once, before the model is built or trained.
print(f"TTFS-Phase rank-balance encoder sanity check (T={T}, P={P}, phi0={phi0}) ...")
sanity_check_encoder()

# -----------------------------
# GN helper
# -----------------------------
def GN(ch, groups=16):
    """Build a GroupNorm layer for ``ch`` channels.

    The group count is the largest value not exceeding ``min(groups, ch)``
    that divides ``ch`` evenly (falling back to a single group).
    """
    n_groups = min(groups, ch)
    while True:
        if ch % n_groups == 0 or n_groups <= 1:
            break
        n_groups -= 1
    return nn.GroupNorm(n_groups, ch)

# -----------------------------
# Spiking ResNet-style blocks (GN)
# -----------------------------
class SpkBasicBlockGN(nn.Module):
    """Residual block of two conv -> GroupNorm -> snn.Leaky stages.

    The skip path is the identity unless the stride or channel count changes,
    in which case a 1x1 conv + GroupNorm projection is used. The skip is
    added before the second spiking neuron.
    """

    def __init__(self, in_ch, out_ch, stride=1, beta=0.95):
        super().__init__()
        # First stage carries the (optional) spatial down-sampling.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.gn1 = GN(out_ch)
        self.lif1 = snn.Leaky(beta=beta)

        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False)
        self.gn2 = GN(out_ch)
        self.lif2 = snn.Leaky(beta=beta)

        needs_projection = stride != 1 or in_ch != out_ch
        if needs_projection:
            self.short = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                GN(out_ch),
            )
        else:
            self.short = None

    def init_state(self):
        """Return fresh membrane states for (lif1, lif2)."""
        first = self.lif1.init_leaky()
        second = self.lif2.init_leaky()
        return first, second

    def forward_step(self, x_spk, mem1, mem2):
        """Advance the block by one time step.

        Returns the output spikes and the updated membrane states of both
        neurons.
        """
        spk1, mem1 = self.lif1(self.gn1(self.conv1(x_spk)), mem1)

        residual = self.gn2(self.conv2(spk1))
        if self.short is None:
            identity = x_spk
        else:
            identity = self.short(x_spk)

        spk2, mem2 = self.lif2(residual + identity, mem2)
        return spk2, mem1, mem2

class SpikingResNetCIFAR_GN(nn.Module):
    """Spiking ResNet-style classifier: stem + three stages of two blocks each.

    Channel widths are 64 -> 128 -> 256, with stride-2 down-sampling at the
    start of stages 2 and 3. Per-step logits come straight from a linear
    layer on the pooled spikes (no output neuron).
    """

    def __init__(self, beta=0.95, num_classes=10):
        super().__init__()
        # Stem
        self.conv0 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.gn0 = GN(64)
        self.lif0 = snn.Leaky(beta=beta)

        # Stage 1 (64 channels)
        self.b1_0 = SpkBasicBlockGN(64, 64, stride=1, beta=beta)
        self.b1_1 = SpkBasicBlockGN(64, 64, stride=1, beta=beta)

        # Stage 2 (128 channels, down-sampled)
        self.b2_0 = SpkBasicBlockGN(64, 128, stride=2, beta=beta)
        self.b2_1 = SpkBasicBlockGN(128, 128, stride=1, beta=beta)

        # Stage 3 (256 channels, down-sampled)
        self.b3_0 = SpkBasicBlockGN(128, 256, stride=2, beta=beta)
        self.b3_1 = SpkBasicBlockGN(256, 256, stride=1, beta=beta)

        # Head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.drop = nn.Dropout(p=0.3)
        self.fc = nn.Linear(256, num_classes)

    def forward(self, spk_in, readout_mode="mean"):
        """Run the network over a 5-D spike tensor whose first two dims are (time, batch).

        Returns the logits of the final step when ``readout_mode == "last"``,
        otherwise the mean of the per-step logits.
        """
        Tsteps, B, _, _, _ = spk_in.shape

        blocks = (self.b1_0, self.b1_1, self.b2_0, self.b2_1, self.b3_0, self.b3_1)

        # Fresh membrane state for the stem and for both neurons of every block.
        stem_mem = self.lif0.init_leaky()
        block_mems = [blk.init_state() for blk in blocks]

        per_step_logits = []
        for step in range(Tsteps):
            x, stem_mem = self.lif0(self.gn0(self.conv0(spk_in[step])), stem_mem)

            for i, blk in enumerate(blocks):
                mem_a, mem_b = block_mems[i]
                x, mem_a, mem_b = blk.forward_step(x, mem_a, mem_b)
                block_mems[i] = (mem_a, mem_b)

            pooled = self.avgpool(x).view(B, -1)
            per_step_logits.append(self.fc(self.drop(pooled)))  # direct logits

        stacked = torch.stack(per_step_logits, dim=0)  # [T,B,10]
        if readout_mode == "last":
            return stacked[-1]
        return stacked.mean(dim=0)

# -----------------------------
# Train / Eval
# -----------------------------
# Network, optimizer and LR schedule shared by evaluate() and the training loop below.
model = SpikingResNetCIFAR_GN(beta=beta, num_classes=num_classes).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
# Cosine schedule spans the whole run (T_max = num_epochs); it is stepped once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

@torch.no_grad()
def evaluate():
    """Return top-1 accuracy of the global ``model`` over ``test_loader``.

    Every batch is TTFS-Phase encoded with the module-level encoder settings
    and classified with the configured ``readout_mode``. Leaves the model in
    eval mode. Returns 0.0 when the loader yields no samples.
    """
    model.eval()
    hits = seen = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        spikes = ttfs_phase_rank_balance_cifar(
            images,
            T=T,
            P=P,
            phi0=phi0,
            use_cdf_u=USE_CDF_U,
            cdf_cpu=cdf_cpu,
            num_bins=CDF_BINS,
            jitter=JITTER,
        )

        predictions = model(spikes, readout_mode=readout_mode).argmax(dim=1)
        hits += (predictions == labels).sum().item()
        seen += labels.numel()

    # max(..., 1) guards the division when the loader is empty.
    return hits / max(seen, 1)

# Best test accuracy seen so far; -1.0 guarantees the first epoch is checkpointed.
best_acc = -1.0
# Main training loop: one pass over train_loader per epoch, then a full test evaluation.
for epoch in range(1, num_epochs + 1):
    model.train()
    # Sample-weighted running sums used for the per-epoch train loss / accuracy printout.
    running_loss, running_correct, running_total = 0.0, 0, 0

    for x_norm, y in train_loader:
        x_norm, y = x_norm.to(device), y.to(device)

        # Encode the normalized batch into a [T, B, ...] spike tensor.
        spk_in = ttfs_phase_rank_balance_cifar(
            x_norm, T=T, P=P, phi0=phi0,
            use_cdf_u=USE_CDF_U, cdf_cpu=cdf_cpu, num_bins=CDF_BINS, jitter=JITTER
        )

        logits = model(spk_in, readout_mode=readout_mode)
        loss = F.cross_entropy(logits, y, label_smoothing=label_smoothing)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        # Clip the global gradient norm to 0.5 before the optimizer update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()

        # loss is a batch mean, so weight it by the batch size before accumulating.
        running_loss += loss.item() * x_norm.size(0)
        # Train accuracy comes from the same forward pass (model still in train mode).
        running_correct += (logits.argmax(dim=1) == y).sum().item()
        running_total += x_norm.size(0)

    # Learning-rate schedule advances once per epoch.
    scheduler.step()

    test_acc = evaluate()
    # Keep only the checkpoint with the highest test accuracy (file is overwritten).
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "test_acc": test_acc,
                # Hyper-parameters stored alongside the weights for later reference.
                "config": {
                    "T": T, "P": P, "phi0": phi0,
                    "beta": beta,
                    "readout_mode": readout_mode,
                    "label_smoothing": label_smoothing,
                    "lr": lr,
                    "encoder": "ttfs_phase_rank_balance",
                    "use_cdf_u": USE_CDF_U,
                    "jitter": JITTER,
                },
            },
            ckpt_path,
        )

    print(
        f"Epoch {epoch:02d} | "
        f"train loss {running_loss/running_total:.4f} | "
        f"train acc {running_correct/running_total:.4f} | "
        f"test acc {test_acc:.4f} | "
        f"best {best_acc:.4f}"
    )

print(f"Done. Best test acc = {best_acc:.4f}. Saved to {os.path.abspath(ckpt_path)}")


TTFS-Phase rank-balance encoder sanity check (T=60, P=3, phi0=0) ...
TTFS-Phase spike-count per pixel (min,max) = 1.0 1.0
Phase-lock ratio (should be 1.0): 1.0
Maxima-bin occupancy (M=20) first 10: [0.0501302070915699, 0.0501302070915699, 0.0498046875, 0.0501302070915699, 0.0498046875, 0.0501302070915699, 0.0501302070915699, 0.0498046875, 0.0501302070915699, 0.0498046875]
Maxima-bin occupancy (M=20) last  10: [0.0501302070915699, 0.0501302070915699, 0.0498046875, 0.0501302070915699, 0.0498046875, 0.0501302070915699, 0.0501302070915699, 0.0498046875, 0.0501302070915699, 0.0498046875]
Epoch 01 | train loss 2.1903 | train acc 0.1731 | test acc 0.2051 | best 0.2051
Epoch 02 | train loss 2.0934 | train acc 0.2217 | test acc 0.2393 | best 0.2393
Epoch 03 | train loss 2.0402 | train acc 0.2489 | test acc 0.2605 | best 0.2605
Epoch 04 | train loss 1.9953 | train acc 0.2732 | test acc 0.2919 | best 0.2919
Epoch 05 | train loss 1.9685 | train acc 0.2902 | test acc 0.2679 | best 0.2919
Epoch 06 |

5. ISI-Phase

In [None]:
# ============================================================
# CIFAR-10 Spiking ResNet-style SNN (GroupNorm) +
# ISI-Phase Encoder (STRICT phase-locked time grid, fixed-K, no spike-count increase)
#
# Key properties:
# - K spikes per pixel/channel (no increase in spike count)
# - Strict phase-lock: t ∈ {phi0 + k*P}
# - ISI-like distribution over maxima bins (M bins), then mapped to time grid
# - Stable beta=0.95, GN ResNet-style SNN, logits readout mean/last
# ============================================================

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms

import snntorch as snn


# -----------------------------
# Config
# -----------------------------
# Run on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Optimization hyper-parameters (AdamW + cosine schedule, see the Train / Eval section).
batch_size = 128
num_epochs = 80
lr = 1e-3
weight_decay = 5e-4

# Temporal grid of the encoder: spikes are only placed at t = phi0 + k*P with t < T.
T = 60              # total unroll steps
P = 3               # phase period
phi0 = 0            # phase offset, must satisfy 0 <= phi0 < P

# Encoder shape parameters.
K = 4               # K spikes per pixel/channel (must satisfy K <= M)
alpha_max = 2.0     # largest |alpha|; alpha = (2x-1)*alpha_max tilts the per-pixel bin weights
eps_q = 1e-3        # quantiles are taken on [eps_q, 1-eps_q]

# Passed as `beta` to every snn.Leaky neuron in the network.
beta = 0.95

readout_mode = "mean"     # "mean" or "last"
label_smoothing = 0.1
num_classes = 10
# Checkpoint file name encodes the main settings of this run.
ckpt_path = f"best_isiphase_timegrid_resnet_snn_GN_beta{beta}_T{T}_P{P}_phi{phi0}_K{K}.pt"


# -----------------------------
# Data
# -----------------------------
# Per-channel statistics used by Normalize and inverted by cifar_to_unit_interval().
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std  = (0.2470, 0.2435, 0.2616)

# Training transform: random crop with padding + horizontal flip, then normalization.
train_tf = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

# Test transform: normalization only, no augmentation.
test_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

# Datasets are downloaded to ./data on first use.
train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=train_tf)
test_set  = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=test_tf)

# Training batches are shuffled and the last partial batch is dropped; test keeps every sample.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True,
                          num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_set,  batch_size=batch_size, shuffle=False, drop_last=False,
                          num_workers=2, pin_memory=True)

# Statistics reshaped to (1, 3, 1, 1) so they broadcast over [B, 3, H, W] batches.
_cifar_mean_t = torch.tensor(cifar_mean).view(1, 3, 1, 1)
_cifar_std_t  = torch.tensor(cifar_std).view(1, 3, 1, 1)

@torch.no_grad()
def cifar_to_unit_interval(x_norm: torch.Tensor) -> torch.Tensor:
    """Invert the CIFAR normalization and clamp the result to [0, 1].

    Computes ``x_norm * std + mean`` with the module-level per-channel
    statistics (moved to the input's device), then clamps.
    """
    target = x_norm.device
    restored = x_norm * _cifar_std_t.to(target) + _cifar_mean_t.to(target)
    return torch.clamp(restored, min=0.0, max=1.0)


# ============================================================
# ISI-Phase Encoder (STRICT phase-lock on time grid)
# ============================================================
@torch.no_grad()
def isi_phase_fixedK_strict_timegrid(
    x_img_unit: torch.Tensor,     # [B,3,H,W] in [0,1]
    T: int,
    K: int,
    P: int,
    phi0: int = 0,
    alpha_max: float = 2.0,
    eps: float = 1e-3,
) -> torch.Tensor:
    """
    Return spikes [T,B,C,H,W], with exactly K spikes per pixel/channel.
    All spikes lie on the phase-locked grid: t = phi0 + k*P.

    For every pixel an exponentially tilted distribution over the M grid bins
    (k=0..M-1) is built with tilt alpha = (2x-1)*alpha_max, so bright pixels
    put their mass on late bins and dark pixels on early bins. K evenly spaced
    quantiles in [eps, 1-eps] of its CDF select the bins; colliding picks are
    moved to the first free bin at or after the pick, or, if none exists, to
    the last free bin before it.

    The bin weights are computed with a softmax, which equals
    exp(z)/sum(exp(z)) but cannot overflow when alpha_max*(M-1)/2 is large
    (a plain float32 exp overflows to inf beyond ~88 and yields NaN weights).

    Requirements:
      - 0 <= phi0 < P
      - M = floor((T-1-phi0)/P)+1
      - K <= M

    Raises:
      ValueError: if K > M.
    """
    assert 0 <= phi0 < P
    assert T >= 2 and K >= 1 and P >= 1

    device = x_img_unit.device
    B, C, H, W = x_img_unit.shape

    # reshape (not view) so non-contiguous inputs are accepted as well
    x = x_img_unit.clamp(0.0, 1.0).reshape(B, -1)  # [B,N]
    N = x.size(1)

    # number of allowed phase-locked bins
    M = int(((T - 1 - phi0) // P) + 1)
    if K > M:
        raise ValueError(
            f"K={K} must satisfy K<=M={M}. (M = floor((T-1-phi0)/P)+1)"
        )

    # Build ISI-like distribution over bins k=0..M-1
    k_grid = torch.arange(M, device=device, dtype=torch.float32).view(1, 1, M)
    mid = (M - 1) / 2.0

    alpha = ((x * 2.0 - 1.0) * alpha_max).unsqueeze(-1)   # [B,N,1]

    # Numerically stable normalization of exp(alpha * (k - mid)).
    w = torch.softmax(alpha * (k_grid - mid), dim=-1)      # [B,N,M]
    cdf = torch.cumsum(w, dim=-1)                          # [B,N,M]

    # Quantiles -> pick K bins (searchsorted needs matching dtypes)
    q = torch.linspace(eps, 1.0 - eps, steps=K, device=device, dtype=torch.float32)
    q = q.to(cdf.dtype).view(1, 1, K).expand(B, N, K)

    k_idx = torch.searchsorted(cdf.contiguous(), q.contiguous()).clamp(0, M - 1).long()  # [B,N,K]
    k_idx, _ = torch.sort(k_idx, dim=-1)

    # Strict uniqueness in bins: first pass keeps every pick whose bin is still free
    used = torch.zeros(B, N, M, device=device, dtype=torch.bool)
    k_fixed = torch.full_like(k_idx, -1)   # -1 marks a pick that collided

    for kk in range(K):
        k0 = k_idx[..., kk]
        free = ~used.gather(dim=2, index=k0.unsqueeze(-1)).squeeze(-1)
        k_fixed[..., kk] = torch.where(free, k0, torch.full_like(k0, -1))
        if free.any():
            used[free] |= F.one_hot(k0[free], num_classes=M).bool()

    # Second pass: relocate collided picks (forward first, then backward)
    ar = torch.arange(M, device=device).view(1, 1, M)
    for kk in range(K):
        need = (k_fixed[..., kk] < 0)
        if not need.any():
            continue

        k0 = k_idx[..., kk].unsqueeze(-1)
        avail = ~used

        # first free bin at or after the original pick
        forward_mask = avail & (ar >= k0)
        fwd_pos = forward_mask.float().argmax(dim=-1)
        fwd_exists = forward_mask.any(dim=-1)

        # last free bin at or before the original pick (argmax on the flipped mask)
        backward_mask = avail & (ar <= k0)
        bwd_pos_rev = torch.flip(backward_mask, dims=[-1]).float().argmax(dim=-1)
        bwd_pos = (M - 1) - bwd_pos_rev

        chosen = torch.where(fwd_exists, fwd_pos, bwd_pos).long()
        k_fixed[..., kk] = torch.where(need, chosen, k_fixed[..., kk])
        used[need] |= F.one_hot(chosen[need], num_classes=M).bool()

    # Map bins -> time indices (phase-locked)
    t_idx = (phi0 + k_fixed * P).long()  # [B,N,K] in [0..T-1]

    # Build spike tensor
    spk_flat = torch.zeros(T, B, N, device=device, dtype=torch.float32)
    b_idx = torch.arange(B, device=device).view(B, 1, 1).expand(B, N, K)
    n_idx = torch.arange(N, device=device).view(1, N, 1).expand(B, N, K)
    spk_flat[t_idx, b_idx, n_idx] = 1.0

    return spk_flat.view(T, B, C, H, W)


# -----------------------------
# Sanity check utilities
# -----------------------------
@torch.no_grad()
def phase_lock_ratio(spk_in: torch.Tensor, P: int, phi0: int) -> float:
    """Return the fraction of spike mass that lies on the grid t = phi0 + k*P.

    Args:
        spk_in: spike tensor whose first dimension is time; any layout
            (contiguous or not) is accepted.
        P: grid period.
        phi0: grid offset.

    Returns:
        ``on_grid_spikes / max(all_spikes, 1)`` as a Python float, so an
        all-zero tensor yields 0.0.
    """
    Tsteps = spk_in.size(0)
    # reshape instead of view: view raises on non-contiguous (e.g. permuted) inputs.
    spk_flat = spk_in.reshape(Tsteps, -1)
    # Clamp the denominator so an empty spike tensor does not divide by zero.
    total = spk_flat.sum().clamp_min(1.0)
    t = torch.arange(Tsteps, device=spk_in.device)
    # 1.0 for time steps congruent to phi0 modulo P, 0.0 elsewhere.
    on_grid = ((t - phi0) % P == 0).float()
    on = (spk_flat * on_grid[:, None]).sum()
    return (on / total).item()

@torch.no_grad()
def maxima_bin_hist(spk_in: torch.Tensor, T: int, P: int, phi0: int):
    """
    Return the normalized spike occupancy of the grid slots k = 0..M-1.

    Slot k corresponds to time step t = phi0 + k*P and
    M = floor((T-1-phi0)/P)+1. Counts are divided by the total number of
    spikes in the tensor (clamped to at least 1), so off-grid spikes lower
    the sum of the returned histogram.

    Args:
        spk_in: spike tensor whose first dimension is time and equals ``T``;
            any layout (contiguous or not) is accepted.
        T: number of time steps.
        P: grid period.
        phi0: grid offset.

    Returns:
        CPU tensor of length M.
    """
    Tsteps = spk_in.size(0)
    assert Tsteps == T
    M = int(((T - 1 - phi0) // P) + 1)

    # spikes per time step (reshape, not view, so permuted inputs work too)
    t_count = spk_in.reshape(T, -1).sum(dim=1)  # [T]
    total = t_count.sum().clamp_min(1.0)

    # pick the counts at the on-grid time steps only
    ks = torch.arange(M, device=spk_in.device)
    t_max = (phi0 + ks * P).long()
    k_count = t_count[t_max] / total
    return k_count.detach().cpu()

@torch.no_grad()
def sanity_check_encoder():
    """Encode one training batch and print encoder diagnostics.

    Prints the min/max spike count per element, the phase-lock ratio and the
    first/last ten entries of the grid-slot occupancy histogram.
    """
    batch, _ = next(iter(train_loader))
    unit_batch = cifar_to_unit_interval(batch.to(device))

    spikes = isi_phase_fixedK_strict_timegrid(
        unit_batch,
        T=T,
        K=K,
        P=P,
        phi0=phi0,
        alpha_max=alpha_max,
        eps=eps_q,
    )

    # Spike count of every element = sum over the time axis -> [B,3,H,W].
    per_elem = spikes.sum(dim=0)
    print("ISI-Phase spike-count per pixel (min,max) =", per_elem.min().item(), per_elem.max().item())
    print("Phase-lock ratio (should be 1.0):", phase_lock_ratio(spikes, P=P, phi0=phi0))

    occupancy = maxima_bin_hist(spikes, T=T, P=P, phi0=phi0)
    n_slots = len(occupancy)
    print(f"Maxima-bin occupancy (M={n_slots}) first 10:", occupancy[:10].tolist())
    print(f"Maxima-bin occupancy (M={n_slots}) last  10:", occupancy[-10:].tolist())

# Run the encoder diagnostics once, before the model is built or trained.
print(f"ISI-Phase(timegrid) sanity check (T={T}, K={K}, P={P}, phi0={phi0}) ...")
sanity_check_encoder()


# -----------------------------
# GN helper
# -----------------------------
def GN(ch, groups=16):
    """Build a GroupNorm layer for ``ch`` channels.

    The group count is the largest value not exceeding ``min(groups, ch)``
    that divides ``ch`` evenly (falling back to a single group).
    """
    n_groups = min(groups, ch)
    while True:
        if ch % n_groups == 0 or n_groups <= 1:
            break
        n_groups -= 1
    return nn.GroupNorm(n_groups, ch)


# -----------------------------
# Spiking ResNet-style blocks (GN)
# -----------------------------
class SpkBasicBlockGN(nn.Module):
    """Residual block of two conv -> GroupNorm -> snn.Leaky stages.

    The skip path is the identity unless the stride or channel count changes,
    in which case a 1x1 conv + GroupNorm projection is used. The skip is
    added before the second spiking neuron.
    """

    def __init__(self, in_ch, out_ch, stride=1, beta=0.95):
        super().__init__()
        # First stage carries the (optional) spatial down-sampling.
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.gn1 = GN(out_ch)
        self.lif1 = snn.Leaky(beta=beta)

        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False)
        self.gn2 = GN(out_ch)
        self.lif2 = snn.Leaky(beta=beta)

        needs_projection = stride != 1 or in_ch != out_ch
        if needs_projection:
            self.short = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                GN(out_ch),
            )
        else:
            self.short = None

    def init_state(self):
        """Return fresh membrane states for (lif1, lif2)."""
        first = self.lif1.init_leaky()
        second = self.lif2.init_leaky()
        return first, second

    def forward_step(self, x_spk, mem1, mem2):
        """Advance the block by one time step.

        Returns the output spikes and the updated membrane states of both
        neurons.
        """
        spk1, mem1 = self.lif1(self.gn1(self.conv1(x_spk)), mem1)

        residual = self.gn2(self.conv2(spk1))
        if self.short is None:
            identity = x_spk
        else:
            identity = self.short(x_spk)

        spk2, mem2 = self.lif2(residual + identity, mem2)
        return spk2, mem1, mem2


class SpikingResNetCIFAR_GN(nn.Module):
    """Spiking ResNet-style classifier: stem + three stages of two blocks each.

    Channel widths are 64 -> 128 -> 256, with stride-2 down-sampling at the
    start of stages 2 and 3. Per-step logits come straight from a linear
    layer on the pooled spikes (no output neuron).
    """

    def __init__(self, beta=0.95, num_classes=10):
        super().__init__()
        # Stem
        self.conv0 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.gn0 = GN(64)
        self.lif0 = snn.Leaky(beta=beta)

        # Stage 1 (64 channels)
        self.b1_0 = SpkBasicBlockGN(64, 64, stride=1, beta=beta)
        self.b1_1 = SpkBasicBlockGN(64, 64, stride=1, beta=beta)

        # Stage 2 (128 channels, down-sampled)
        self.b2_0 = SpkBasicBlockGN(64, 128, stride=2, beta=beta)
        self.b2_1 = SpkBasicBlockGN(128, 128, stride=1, beta=beta)

        # Stage 3 (256 channels, down-sampled)
        self.b3_0 = SpkBasicBlockGN(128, 256, stride=2, beta=beta)
        self.b3_1 = SpkBasicBlockGN(256, 256, stride=1, beta=beta)

        # Head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.drop = nn.Dropout(p=0.3)
        self.fc = nn.Linear(256, num_classes)

    def forward(self, spk_in, readout_mode="mean"):
        """Run the network over a 5-D spike tensor whose first two dims are (time, batch).

        Returns the logits of the final step when ``readout_mode == "last"``,
        otherwise the mean of the per-step logits.
        """
        Tsteps, B, _, _, _ = spk_in.shape

        blocks = (self.b1_0, self.b1_1, self.b2_0, self.b2_1, self.b3_0, self.b3_1)

        # Fresh membrane state for the stem and for both neurons of every block.
        stem_mem = self.lif0.init_leaky()
        block_mems = [blk.init_state() for blk in blocks]

        per_step_logits = []
        for step in range(Tsteps):
            x, stem_mem = self.lif0(self.gn0(self.conv0(spk_in[step])), stem_mem)

            for i, blk in enumerate(blocks):
                mem_a, mem_b = block_mems[i]
                x, mem_a, mem_b = blk.forward_step(x, mem_a, mem_b)
                block_mems[i] = (mem_a, mem_b)

            pooled = self.avgpool(x).view(B, -1)
            per_step_logits.append(self.fc(self.drop(pooled)))

        stacked = torch.stack(per_step_logits, dim=0)  # [T,B,10]
        if readout_mode == "last":
            return stacked[-1]
        return stacked.mean(dim=0)


# -----------------------------
# Train / Eval
# -----------------------------
# Network, optimizer and LR schedule shared by evaluate() and the training loop below.
model = SpikingResNetCIFAR_GN(beta=beta, num_classes=num_classes).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
# Cosine schedule spans the whole run (T_max = num_epochs); it is stepped once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

@torch.no_grad()
def evaluate():
    """Return top-1 accuracy of the global ``model`` over ``test_loader``.

    Every batch is mapped back to [0, 1], ISI-Phase encoded with the
    module-level encoder settings and classified with the configured
    ``readout_mode``. Leaves the model in eval mode. Returns 0.0 when the
    loader yields no samples.
    """
    model.eval()
    hits = seen = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        unit_images = cifar_to_unit_interval(images)
        spikes = isi_phase_fixedK_strict_timegrid(
            unit_images,
            T=T,
            K=K,
            P=P,
            phi0=phi0,
            alpha_max=alpha_max,
            eps=eps_q,
        )

        predictions = model(spikes, readout_mode=readout_mode).argmax(dim=1)
        hits += (predictions == labels).sum().item()
        seen += labels.numel()

    # max(..., 1) guards the division when the loader is empty.
    return hits / max(seen, 1)

# Train for num_epochs epochs, evaluating on the test set after each one and
# checkpointing whenever the test accuracy improves.
# Starting below any reachable accuracy guarantees the first epoch is saved.
best_acc = -1.0
for epoch in range(1, num_epochs + 1):
    model.train()
    # Per-epoch accumulators: summed loss, correct predictions, samples seen.
    running_loss, running_correct, running_total = 0.0, 0, 0

    for x_norm, y in train_loader:
        x_norm, y = x_norm.to(device), y.to(device)

        # Helpers defined outside this excerpt; presumably the first undoes the
        # dataset normalization and the second produces the spike train the
        # model consumes — TODO confirm against their definitions.
        x_unit = cifar_to_unit_interval(x_norm)
        spk_in = isi_phase_fixedK_strict_timegrid(
            x_unit, T=T, K=K, P=P, phi0=phi0, alpha_max=alpha_max, eps=eps_q
        )

        logits = model(spk_in, readout_mode=readout_mode)
        loss = F.cross_entropy(logits, y, label_smoothing=label_smoothing)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        # Clip the global gradient norm to 0.5 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()

        # loss is weighted by the batch size so the epoch figure printed below
        # is a per-sample average; accuracy comes from this training-mode pass.
        running_loss += loss.item() * x_norm.size(0)
        running_correct += (logits.argmax(dim=1) == y).sum().item()
        running_total += x_norm.size(0)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    test_acc = evaluate()
    if test_acc > best_acc:
        best_acc = test_acc
        # Save the best-so-far weights together with the optimizer state and
        # the hyper-parameters of this run.
        # NOTE: the scheduler state is not part of the checkpoint.
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "test_acc": test_acc,
                "config": {
                    "T": T, "P": P, "phi0": phi0,
                    "K": K, "alpha_max": alpha_max, "eps_q": eps_q,
                    "beta": beta,
                    "readout_mode": readout_mode,
                    "label_smoothing": label_smoothing,
                    "lr": lr,
                    "encoder": "isi_phase_fixedK_strict_timegrid",
                },
            },
            ckpt_path,
        )

    # One summary line per epoch: average train loss/accuracy, test accuracy, best so far.
    print(
        f"Epoch {epoch:02d} | "
        f"train loss {running_loss/running_total:.4f} | "
        f"train acc {running_correct/running_total:.4f} | "
        f"test acc {test_acc:.4f} | "
        f"best {best_acc:.4f}"
    )

print(f"Done. Best test acc = {best_acc:.4f}. Saved to {os.path.abspath(ckpt_path)}")


100%|██████████| 170M/170M [00:05<00:00, 29.2MB/s]


ISI-Phase(timegrid) sanity check (T=60, K=4, P=3, phi0=0) ...
ISI-Phase spike-count per pixel (min,max) = 4.0 4.0
Phase-lock ratio (should be 1.0): 1.0
Maxima-bin occupancy (M=20) first 10: [0.17098236083984375, 0.14102935791015625, 0.1348133087158203, 0.0566864013671875, 0.024326324462890625, 0.02075704000890255, 0.01550420094281435, 0.0155398054048419, 0.011348724365234375, 0.012067794799804688]
Maxima-bin occupancy (M=20) last  10: [0.011987686157226562, 0.01106135081499815, 0.013695399276912212, 0.014286041259765625, 0.017553329467773438, 0.018564224243164062, 0.021310806274414062, 0.08214187622070312, 0.08818117529153824, 0.11816278845071793]
Epoch 01 | train loss 2.2075 | train acc 0.1681 | test acc 0.2309 | best 0.2309
Epoch 02 | train loss 2.1110 | train acc 0.2184 | test acc 0.2342 | best 0.2342
Epoch 03 | train loss 2.0508 | train acc 0.2373 | test acc 0.2519 | best 0.2519
Epoch 04 | train loss 2.0108 | train acc 0.2492 | test acc 0.2609 | best 0.2609
Epoch 05 | train loss 1.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# -----------------------------
# Config
# -----------------------------
# Prefer the GPU when CUDA is available, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Mini-batch size shared by the train and test DataLoaders below.
batch_size = 128
# Number of passes over the training set made by the training loop.
epochs = 20
# Learning rate and weight-decay coefficient handed to the Adam optimizer below.
lr = 1e-3
weight_decay = 5e-4

# -----------------------------
# Data
# -----------------------------
# Per-channel CIFAR-10 statistics. They are named once (instead of repeating
# the tuples in each transform) and use the same values as the other
# experiments in this file, so the CNN baseline sees identically scaled inputs
# and the shared train_loader/test_loader names stay consistent across cells.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2470, 0.2435, 0.2616)

# Training pipeline: random flip and padded random crop for augmentation,
# then tensor conversion and per-channel normalization.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

# Test pipeline: no augmentation, same normalization as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

# Datasets are downloaded into ./data on first use.
train_ds = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform_train
)
test_ds = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform_test
)

# Only the training loader shuffles; neither loader drops the last partial batch.
train_loader = DataLoader(train_ds, batch_size=batch_size,
                          shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=batch_size,
                         shuffle=False, num_workers=2, pin_memory=True)

# -----------------------------
# Model
# -----------------------------
class SimpleCNN(nn.Module):
    """Small non-spiking CNN baseline for 3x32x32 inputs with 10 output logits.

    ``features`` holds two conv stages (two 3x3 conv + ReLU pairs followed by
    2x2 max pooling each); ``classifier`` is a two-layer MLP with dropout.
    """

    def __init__(self):
        super().__init__()

        # Two stages of (conv, ReLU, conv, ReLU, max-pool). Each pooling halves
        # the spatial size: 32x32 -> 16x16 -> 8x8.
        stage_layers = []
        in_channels = 3
        for out_channels in (32, 64):
            stage_layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2),
            ])
            in_channels = out_channels
        self.features = nn.Sequential(*stage_layers)

        # MLP head on the flattened 64 x 8 x 8 feature map.
        flat_dim = 64 * 8 * 8
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_dim, 512),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        return self.classifier(self.features(x))


# Instantiate the baseline CNN and move its parameters to the selected device.
# NOTE: this rebinds the module-level `model` (and `optimizer` below) names
# that the spiking-network cell also uses.
model = SimpleCNN().to(device)

# Cross-entropy on raw logits with the default (batch-mean) reduction.
criterion = nn.CrossEntropyLoss()
# Adam with the lr / weight_decay values from the config section of this cell.
optimizer = optim.Adam(model.parameters(),
                       lr=lr,
                       weight_decay=weight_decay)

# -----------------------------
# Evaluation function
# -----------------------------
def evaluate(model, loader):
    """Compute the per-sample average loss and top-1 accuracy of ``model``.

    The model is switched to eval mode and left there; callers that keep
    training must call ``model.train()`` again. Gradients are disabled for the
    whole pass. The module-level ``criterion`` and ``device`` are used.

    Args:
        model: network called as ``model(x)``; its output is reduced with
            ``argmax(dim=1)`` to obtain predictions.
        loader: iterable yielding ``(x, y)`` batches.

    Returns:
        ``(mean_loss, accuracy)`` as floats. An empty ``loader`` yields
        ``(0.0, 0.0)`` rather than raising ``ZeroDivisionError``.
    """
    model.eval()
    correct = 0
    total = 0
    loss_sum = 0.0

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criterion(logits, y)

            # criterion is expected to return a batch mean; weighting it by the
            # batch size keeps the final average exact with a short last batch.
            loss_sum += loss.item() * x.size(0)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)

    # Clamp the denominator so an empty loader cannot divide by zero.
    denom = max(total, 1)
    return loss_sum / denom, correct / denom


# -----------------------------
# Training loop
# -----------------------------
# Train the baseline CNN for `epochs` epochs, reporting metrics after each one.
for epoch in range(1, epochs + 1):
    # evaluate() leaves the model in eval mode, so re-enable training mode
    # (dropout active) at the start of every epoch.
    model.train()

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        # Standard step: clear gradients, forward, loss, backward, update.
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

    # Train metrics come from a separate eval-mode pass over train_loader, so
    # they are measured on the augmented training samples that loader yields.
    train_loss, train_acc = evaluate(model, train_loader)
    test_loss, test_acc = evaluate(model, test_loader)

    print(f"Epoch {epoch:02d} | "
          f"Train Loss {train_loss:.4f} Acc {train_acc*100:.2f}% | "
          f"Test Loss {test_loss:.4f} Acc {test_acc*100:.2f}%")

print("Training finished.")


100%|██████████| 170M/170M [00:13<00:00, 12.9MB/s]


Epoch 01 | Train Loss 1.3424 Acc 51.08% | Test Loss 1.2543 Acc 54.45%
Epoch 02 | Train Loss 1.2559 Acc 54.43% | Test Loss 1.1932 Acc 57.10%
Epoch 03 | Train Loss 0.9955 Acc 64.32% | Test Loss 0.9434 Acc 65.46%
Epoch 04 | Train Loss 0.9080 Acc 68.02% | Test Loss 0.8504 Acc 69.56%
Epoch 05 | Train Loss 0.8134 Acc 72.00% | Test Loss 0.7704 Acc 73.47%
Epoch 06 | Train Loss 0.7496 Acc 73.44% | Test Loss 0.7105 Acc 74.87%
Epoch 07 | Train Loss 0.7348 Acc 74.09% | Test Loss 0.7015 Acc 75.64%
Epoch 08 | Train Loss 0.6965 Acc 75.32% | Test Loss 0.6842 Acc 75.84%
Epoch 09 | Train Loss 0.6916 Acc 75.62% | Test Loss 0.6639 Acc 76.70%
Epoch 10 | Train Loss 0.6759 Acc 76.24% | Test Loss 0.6539 Acc 77.03%
Epoch 11 | Train Loss 0.6245 Acc 78.29% | Test Loss 0.6306 Acc 78.11%
Epoch 12 | Train Loss 0.6250 Acc 78.15% | Test Loss 0.6166 Acc 78.99%
Epoch 13 | Train Loss 0.6047 Acc 78.73% | Test Loss 0.5997 Acc 79.55%
Epoch 14 | Train Loss 0.5895 Acc 79.47% | Test Loss 0.5977 Acc 79.56%
Epoch 15 | Train Los