<a href="https://colab.research.google.com/github/surfingtheuniverse/Colab-Mono-Forward/blob/main/Copy_of_Beyond_Backprop_FF_CaFo_MF_Colab_Testbed.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Beyond Backprop: FF, CaFo, MF — Colab Testbed

This notebook implements and **tests three BP-free algorithms** from the paper:

- **Forward-Forward (FF)** (Hinton)
- **Cascaded-Forward (CaFo)** (Zhao et al.)
- **Mono-Forward (MF)** (Gong, Li, Abdulla)

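…alongside a **backpropagation (BP) baseline** trained end to end on a native architecture from the paper: the MF-style 2x1000 MLP on MNIST/Fashion-MNIST and the CaFo-style 3-block CNN on CIFAR-10.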

**Datasets supported:** MNIST, Fashion-MNIST, CIFAR-10 (pick in the Config cell). Defaults are chosen for fast demo runs on Colab.

**Notes**
- Training defaults are small for speed. Increase epochs/batch size later for stronger results.
- Optional GPU **energy tracking** via NVML is enabled when possible.
- Each algorithm uses its **native architecture** family from the paper (MLPs for FF and MF, a 3-block CNN for CaFo), downsized in places for Colab speed.
- All methods use **early stopping** on a validation split.
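
The early-stopping rule used by every method (`EarlyStopper` in the Utilities cell, with `mode='max'` on validation accuracy) can be replayed in plain Python to see exactly when a run halts. This is a minimal sketch: `early_stop_epoch` is a hypothetical helper, not part of the notebook, and it only tracks the score, whereas the notebook's stopper also snapshots the weights at each improvement and restores them at the end.

```python
from typing import List, Optional, Tuple


def early_stop_epoch(val_scores: List[float], patience: int) -> Tuple[Optional[int], float]:
    """Hypothetical helper replaying EarlyStopper(mode='max') over per-epoch validation scores.

    Returns (1-indexed epoch at which training stops, or None if it never stops; best score).
    """
    best = float('-inf')
    count = 0
    for epoch, score in enumerate(val_scores, start=1):
        if score > best:            # strict improvement: remember it and reset the counter
            best, count = score, 0
        else:
            count += 1
            if count >= patience:   # `patience` consecutive epochs without improvement
                return epoch, best
    return None, best


print(early_stop_epoch([0.80, 0.85, 0.84, 0.83, 0.90], patience=2))  # (4, 0.85)
```

With `patience=2` the run stops at epoch 4, so the later 0.90 is never reached; the restored weights would be those from epoch 2.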

If you run into CUDA OOM on Colab, lower `batch_size` in the Config cell or shrink the layer widths set in the Train & Evaluate cell.


In [None]:
#@title Setup
import os, math, random, time
import numpy as np
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import torchvision
import torchvision.transforms as T

try:
    import pynvml
    _HAS_NVML = True
except Exception:
    _HAS_NVML = False

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Torch:', torch.__version__, 'CUDA:', torch.cuda.is_available(), 'Device:', DEVICE)


Torch: 2.8.0+cu126 CUDA: True Device: cuda


In [None]:
#@title Config
dataset = 'FashionMNIST'  #@param ['MNIST','FashionMNIST','CIFAR10']
algo_to_run = 'ALL'        #@param ['ALL','FF','CaFo','MF','BP']
batch_size = 128           #@param {type:'integer'}
max_epochs = 5             #@param {type:'integer'}
patience = 2               #@param {type:'integer'}
lr_bp = 3e-4               #@param {type:'number'}
lr_ff = 3e-4               #@param {type:'number'}
lr_mf = 3e-4               #@param {type:'number'}
lr_cafo_pred = 3e-4        #@param {type:'number'}
use_energy_tracking = True #@param {type:'boolean'}

print(f'Config: dataset={dataset}, algo={algo_to_run}, bs={batch_size}, epochs={max_epochs}')


Config: dataset=FashionMNIST, algo=ALL, bs=128, epochs=5


In [None]:
#@title Utilities: EarlyStopping + Energy Meter
class EarlyStopper:
    def __init__(self, patience=3, mode='max'):
        self.patience = patience
        self.mode = mode
        self.best = -1e18 if mode=='max' else 1e18
        self.count = 0
        self.stop = False
        self.best_state = None

    def step(self, value, model):
        improved = (value > self.best) if self.mode=='max' else (value < self.best)
        if improved:
            self.best = value
            self.count = 0
            self.best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            self.count += 1
            if self.count >= self.patience:
                self.stop = True

class EnergyMeter:
    def __init__(self, enable=True):
        self.enable = enable and _HAS_NVML and torch.cuda.is_available()
        self.energy_joules = 0.0
        self._last_t = None
        self._last_w = None
        if self.enable:
            try:
                pynvml.nvmlInit()
                self.handle = pynvml.nvmlDeviceGetHandleByIndex(0)
                self._last_t = time.time()
                self._last_w = pynvml.nvmlDeviceGetPowerUsage(self.handle)  # milliwatts
            except Exception:
                self.enable = False

    def tick(self):
        if not self.enable:
            return
        t = time.time()
        try:
            w_mw = pynvml.nvmlDeviceGetPowerUsage(self.handle) # milliwatts
        except Exception:
            return
        if self._last_t is not None:
            dt = t - self._last_t
            # integrate trapezoidally (mW to W is /1000)
            self.energy_joules += dt * ( (self._last_w + w_mw) * 0.5 / 1000.0 )
        self._last_t, self._last_w = t, w_mw

    def wh(self):
        # Joules to Wh: 1 Wh = 3600 J
        return self.energy_joules / 3600.0
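
`EnergyMeter.tick` samples GPU power (NVML reports milliwatts) once per batch and integrates it with the trapezoidal rule. The arithmetic can be checked on synthetic readings without a GPU; `trapezoid_energy_wh` is a hypothetical helper mirroring that logic, and the samples below are made up for illustration. Since NVML reports whole-GPU power and ticks happen only at batch boundaries, the notebook's figure is an approximation.

```python
from typing import List, Tuple


def trapezoid_energy_wh(samples: List[Tuple[float, float]]) -> float:
    """Hypothetical helper: integrate (time_s, power_mW) samples trapezoidally, return Wh."""
    joules = 0.0
    for (t0, p0_mw), (t1, p1_mw) in zip(samples, samples[1:]):
        avg_w = 0.5 * (p0_mw + p1_mw) / 1000.0  # mean power over the interval, mW -> W
        joules += (t1 - t0) * avg_w             # W * s = J
    return joules / 3600.0                      # 1 Wh = 3600 J


# A constant 100 W for 36 s is 3600 J, i.e. 1 Wh; a 0 -> 200 W ramp has the same mean power.
print(trapezoid_energy_wh([(0.0, 100_000.0), (36.0, 100_000.0)]))
print(trapezoid_energy_wh([(0.0, 0.0), (36.0, 200_000.0)]))
```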


In [None]:
#@title Data: Loaders + Normalization
def get_data(dataset: str, batch_size: int):
    if dataset in ['MNIST','FashionMNIST']:
        mean, std = ((0.1307,), (0.3081,)) if dataset=='MNIST' else ((0.2860,), (0.3530,))
        tfm = T.Compose([T.ToTensor(), T.Normalize(mean, std)])
        ds_cls = torchvision.datasets.MNIST if dataset=='MNIST' else torchvision.datasets.FashionMNIST
        train = ds_cls(root='./data', train=True, download=True, transform=tfm)
        test  = ds_cls(root='./data', train=False, download=True, transform=tfm)
        n_val = int(0.1 * len(train))
        n_tr  = len(train) - n_val
        train, val = random_split(train, [n_tr, n_val])
        in_ch, ncls, img_hw = 1, 10, 28
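    elif dataset=='CIFAR10':
        norm = T.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010))
        tfm_train = T.Compose([T.RandomCrop(32, padding=4), T.RandomHorizontalFlip(), T.ToTensor(), norm])
        tfm_test  = T.Compose([T.ToTensor(), norm])
        # Two views of the same training images: augmented for training, plain for validation,
        # so the held-out split is scored exactly like the test set (no random crops/flips).
        full_aug   = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=tfm_train)
        full_plain = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=tfm_test)
        test  = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=tfm_test)
        n_val = int(0.1 * len(full_aug))
        n_tr  = len(full_aug) - n_val
        perm = torch.randperm(len(full_aug)).tolist()  # one shared shuffle of the indices
        train = torch.utils.data.Subset(full_aug, perm[:n_tr])
        val   = torch.utils.data.Subset(full_plain, perm[n_tr:])
        in_ch, ncls, img_hw = 3, 10, 32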
    else:
        raise ValueError('Unsupported dataset')

    dl_train = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    dl_val   = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    dl_test  = DataLoader(test, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    return dl_train, dl_val, dl_test, in_ch, ncls, img_hw

dl_train, dl_val, dl_test, IN_CH, NCLS, IMG = get_data(dataset, batch_size)
IN_CH, NCLS, IMG


100%|██████████| 26.4M/26.4M [00:02<00:00, 10.3MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 176kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.30MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 14.4MB/s]


(1, 10, 28)
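
For CIFAR-10 the validation images should go through the test-time transform, not the augmented training one. One way to guarantee this is to shuffle the indices once and build two views over the same raw data, each with its own transform. The sketch below shows the idea in plain Python: `split_indices`, `TransformView`, `augment` and `plain` are hypothetical stand-ins for `random_split`/`torch.utils.data.Subset` and the torchvision transforms, and the seeded `random.Random` makes the split reproducible across runs.

```python
import random
from typing import Callable, List, Sequence, Tuple


def split_indices(n: int, val_frac: float, seed: int) -> Tuple[List[int], List[int]]:
    """Shuffle 0..n-1 with a seeded RNG; the last val_frac of the order becomes validation."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_val = int(val_frac * n)
    return idx[: n - n_val], idx[n - n_val:]


class TransformView:
    """Hypothetical stand-in for Subset(dataset_with_transform, indices)."""

    def __init__(self, base: Sequence, indices: List[int], transform: Callable):
        self.base, self.indices, self.transform = base, indices, transform

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int):
        return self.transform(self.base[self.indices[i]])


def augment(v):  # stand-in for RandomCrop + RandomHorizontalFlip + ToTensor + Normalize
    return ('aug', v)


def plain(v):    # stand-in for ToTensor + Normalize (the test-time transform)
    return ('plain', v)


raw = list(range(100))  # stand-in for the raw training images
tr_idx, val_idx = split_indices(len(raw), val_frac=0.1, seed=0)
train_view = TransformView(raw, tr_idx, augment)
val_view = TransformView(raw, val_idx, plain)
```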

In [None]:
#@title Models: BP baselines (MLP + CNN)
class MLP(nn.Module):
    def __init__(self, in_dim, ncls, hidden: List[int]):
        super().__init__()
        layers = []
        last = in_dim
        for h in hidden:
            layers += [nn.Linear(last, h), nn.ReLU(inplace=True)]
            last = h
        self.backbone = nn.Sequential(*layers)
        self.head = nn.Linear(last, ncls)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        z = self.backbone(x)
        return self.head(z)

class SimpleCNN3(nn.Module):
    # Same block shape as the CaFo native CNN: (Conv3x3 -> ReLU -> MaxPool2x2 -> BN) x3
    def __init__(self, in_ch=1, ncls=10, chans=(32,128,512), img_hw=32):
        super().__init__()
        c1,c2,c3 = chans
        self.b1 = nn.Sequential(nn.Conv2d(in_ch,c1,3,padding=1), nn.ReLU(True), nn.MaxPool2d(2), nn.BatchNorm2d(c1))
        self.b2 = nn.Sequential(nn.Conv2d(c1,c2,3,padding=1), nn.ReLU(True), nn.MaxPool2d(2), nn.BatchNorm2d(c2))
        self.b3 = nn.Sequential(nn.Conv2d(c2,c3,3,padding=1), nn.ReLU(True), nn.MaxPool2d(2), nn.BatchNorm2d(c3))
        # Build the head eagerly so its weights exist (and get trained) when the optimizer is created.
        # padding=1 convs keep H and W; each MaxPool2d(2) floors the size by half, so three pools
        # leave img_hw // 8 (e.g. 32 -> 4, 28 -> 3).
        s = img_hw // 8
        self.head = nn.Linear(c3 * s * s, ncls)

    def forward(self, x):
        x = self.b1(x); x = self.b2(x); x = self.b3(x)
        return self.head(x.flatten(1))
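
Both the 3-block CNN head and CaFo's per-block predictors need the flattened feature width after each block. Because the 3x3 convolutions use `padding=1` (size preserved) and each `MaxPool2d(2)` floors odd sizes, the widths follow from integer arithmetic alone. `pooled_hw` and `block_flat_dims` are hypothetical helpers sketching that calculation; the channel tuples match the ones used in this notebook.

```python
from typing import List, Sequence


def pooled_hw(hw: int, n_pools: int = 3) -> int:
    """Spatial size after n 2x2, stride-2 max-pools (odd sizes are floored, ceil_mode=False)."""
    for _ in range(n_pools):
        hw //= 2
    return hw


def block_flat_dims(hw: int, chans: Sequence[int] = (32, 128, 512)) -> List[int]:
    """Flattened width after each Conv3x3(pad=1) -> ReLU -> MaxPool2x2 -> BN block."""
    dims = []
    for i, c in enumerate(chans, start=1):
        s = pooled_hw(hw, i)  # only pooling shrinks the map; the padded conv keeps its size
        dims.append(c * s * s)
    return dims


print(block_flat_dims(28))                    # (Fashion-)MNIST, chans (32, 128, 512)
print(block_flat_dims(32, (32, 128, 256)))    # CIFAR-10, lighter CaFo widths
```

On 28x28 inputs the third block's map is only 3x3, which is why the final predictor sees fewer features than the first two.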


In [None]:
#@title Algorithm 1: Forward-Forward (FF) — MLP
class FFMLP(nn.Module):
    def __init__(self, in_dim, layers: List[int]):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(a, b) for a,b in zip([in_dim]+layers[:-1], layers)])

    @staticmethod
    def length_norm(x, eps=1e-8):
        # normalize each sample vector to unit L2
        return x / (x.norm(dim=1, keepdim=True) + eps)

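    def forward_feats(self, x):
        acts = []
        for L in self.layers:
            x = F.relu(L(x))
            acts.append(x)
            # Length-normalize before the next layer and detach, so each layer is trained only by
            # its own goodness loss (no gradient flows back into earlier layers, as FF requires).
            x = self.length_norm(x).detach()
        return acts  # post-ReLU activations of every layer (before length normalization)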

def embed_label(x, y, ncls=10, img_hw=28):
    # Replace first ncls pixels with one-hot label (as in Hinton's demo)
    b = x.size(0)
    flat = x.view(b, -1)
    oh = torch.zeros(b, ncls, device=x.device)
    oh[torch.arange(b), y] = 1.0
    flat[:, :ncls] = oh
    return flat

def ff_goodness(a):
    return (a*a).sum(dim=1)

def train_ff(ff: FFMLP, dl_train, dl_val, in_dim, ncls, img_hw, max_epochs=5, lr=3e-4, patience=2):
    ff = ff.to(DEVICE)
    opt = torch.optim.AdamW(ff.parameters(), lr=lr)
    stopper = EarlyStopper(patience=patience, mode='max')
    energy = EnergyMeter(enable=use_energy_tracking)

    def epoch_pass(dloader, train=True):
        if train: ff.train()
        else: ff.eval()
        tot, correct = 0, 0
        for x,y in dloader:
            energy.tick()
            x,y = x.to(DEVICE), y.to(DEVICE)
            if train:
                # Positive batch: true label embedded; negative batch: a random wrong label
                pos = embed_label(x.clone(), y, ncls=ncls, img_hw=img_hw)
                neg_labels = (y + torch.randint_like(y, low=1, high=ncls)) % ncls
                neg = embed_label(x.clone(), neg_labels, ncls=ncls, img_hw=img_hw)
                acts_pos = ff.forward_feats(pos)  # embed_label already returns flat (b, d) inputs
                acts_neg = ff.forward_feats(neg)
                # Per-layer logistic loss on goodness: pos should exceed thresh, neg fall below it
                losses = []
                for ap, an in zip(acts_pos, acts_neg):
                    gp, gn = ff_goodness(ap), ff_goodness(an)
                    thresh = ap.size(1)  # threshold = layer width
                    losses.append(F.softplus(-(gp - thresh)).mean() + F.softplus(gn - thresh).mean())
                loss = torch.stack(losses).mean()
                opt.zero_grad(); loss.backward(); opt.step()

            # Inference: try each label, pick max total goodness
            with torch.no_grad():
                b = x.size(0)
                scores = torch.zeros(b, ncls, device=DEVICE)
                flat = x.view(b,-1)
                for c in range(ncls):
                    tmp = embed_label(flat.clone(), torch.full_like(y, c), ncls=ncls, img_hw=img_hw)
                    acts = ff.forward_feats(tmp)
                    scores[:, c] = torch.stack([ff_goodness(a) for a in acts], dim=0).sum(dim=0)
                pred = scores.argmax(dim=1)
                correct += (pred==y).sum().item(); tot += b
        return correct/tot, energy.wh()

    for ep in range(1, max_epochs+1):
        tr_acc, _ = epoch_pass(dl_train, train=True)
        val_acc, wh = epoch_pass(dl_val, train=False)  # wh: cumulative energy since training began
        print(f'[FF] epoch {ep:02d}  train_acc={tr_acc:.3f}  val_acc={val_acc:.3f}  energy_Wh~{wh:.3f}')
        stopper.step(val_acc, ff)
        if stopper.stop:
            print('[FF] Early stop!')
            break
    if stopper.best_state:
        ff.load_state_dict(stopper.best_state)
    return ff

def test_model(model, dl):
    model.eval(); tot=0; corr=0
    with torch.no_grad():
        for x,y in dl:
            x,y = x.to(DEVICE), y.to(DEVICE)
            logits = model(x)
            pred = logits.argmax(dim=1)
            corr += (pred==y).sum().item(); tot += x.size(0)
    return corr/tot
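
The FF pieces above (label embedding, random wrong labels for negatives, the per-layer softplus loss on goodness, and label-trial inference) can be exercised without a GPU. This is a dependency-free sketch on plain lists; `embed_label_list`, `goodness`, `ff_layer_loss`, `sample_wrong_label` and `ff_predict` are hypothetical mirrors of the notebook's tensor code, and `thresh` follows the notebook's convention of setting it to the layer width.

```python
import math
import random
from typing import Callable, List


def embed_label_list(pixels: List[float], label: int, ncls: int = 10) -> List[float]:
    """Overwrite the first ncls values with a one-hot label (the input list is not modified)."""
    out = list(pixels)
    out[:ncls] = [1.0 if k == label else 0.0 for k in range(ncls)]
    return out


def goodness(acts: List[float]) -> float:
    return sum(a * a for a in acts)  # sum of squared activations


def softplus(z: float) -> float:
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))  # numerically stable log(1 + e^z)


def ff_layer_loss(pos_acts: List[float], neg_acts: List[float], thresh: float) -> float:
    """Positive goodness should exceed thresh; negative goodness should fall below it."""
    return softplus(-(goodness(pos_acts) - thresh)) + softplus(goodness(neg_acts) - thresh)


def sample_wrong_label(y: int, ncls: int, rng: random.Random) -> int:
    """Shift by 1..ncls-1 modulo ncls, so the result can never equal y."""
    return (y + rng.randint(1, ncls - 1)) % ncls


def ff_predict(pixels: List[float], total_goodness: Callable[[List[float]], float], ncls: int = 10) -> int:
    """Try every label; return the one whose embedded input scores the highest total goodness."""
    scores = [total_goodness(embed_label_list(pixels, c, ncls)) for c in range(ncls)]
    return max(range(ncls), key=scores.__getitem__)


# Well-separated goodness gives a small loss; the reversed situation gives a large one.
print(ff_layer_loss([3.0, 0.0, 0.0], [0.1, 0.0, 0.0], thresh=3.0))
print(ff_layer_loss([0.1, 0.0, 0.0], [3.0, 0.0, 0.0], thresh=3.0))
```

In the notebook, `total_goodness` is the sum of `ff_goodness` over all layers of the trained `FFMLP`, which is why inference costs `ncls` forward passes per image.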


In [None]:
#@title Algorithm 2: CaFo (Rand-CE variant) — CNN with 3 block predictors
class CaFoPredictor(nn.Module):
    def __init__(self, in_dim, ncls):
        super().__init__()
        self.fc = nn.Linear(in_dim, ncls)
    def forward(self, x):
        return self.fc(x)

class CaFoBackbone(nn.Module):
    def __init__(self, in_ch=1, chans=(32,128,512)):
        super().__init__()
        c1,c2,c3 = chans
        self.b1 = nn.Sequential(nn.Conv2d(in_ch,c1,3,padding=1), nn.ReLU(True), nn.MaxPool2d(2), nn.BatchNorm2d(c1))
        self.b2 = nn.Sequential(nn.Conv2d(c1,c2,3,padding=1), nn.ReLU(True), nn.MaxPool2d(2), nn.BatchNorm2d(c2))
        self.b3 = nn.Sequential(nn.Conv2d(c2,c3,3,padding=1), nn.ReLU(True), nn.MaxPool2d(2), nn.BatchNorm2d(c3))
        for p in self.parameters():
            p.requires_grad_(False)  # Rand-CE keeps blocks frozen

    def forward_blocks(self, x):
        f1 = self.b1(x)
        f2 = self.b2(f1)
        f3 = self.b3(f2)
        return f1, f2, f3

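def train_cafo_rand(dl_train, dl_val, in_ch, ncls, max_epochs=5, lr=3e-4, patience=2, chans=(32,128,512)):
    bb = CaFoBackbone(in_ch=in_ch, chans=chans).to(DEVICE)
    # The blocks are frozen, but BatchNorm still updates its running statistics in train mode.
    # Estimate those statistics once (cumulative average over up to 50 training batches), then
    # switch the backbone to eval mode so predictors see identical features at train/val/test time.
    for m in bb.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.reset_running_stats(); m.momentum = None
    with torch.no_grad():
        for i, (xb, _) in enumerate(dl_train):
            bb.forward_blocks(xb.to(DEVICE))
            if i >= 49:
                break
    bb.eval()
    # Run one batch to infer predictor input dims
    xb, yb = next(iter(dl_train))
    xb = xb.to(DEVICE)
    with torch.no_grad():
        f1,f2,f3 = bb.forward_blocks(xb)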
    p1 = CaFoPredictor(f1.view(f1.size(0), -1).size(1), ncls).to(DEVICE)
    p2 = CaFoPredictor(f2.view(f2.size(0), -1).size(1), ncls).to(DEVICE)
    p3 = CaFoPredictor(f3.view(f3.size(0), -1).size(1), ncls).to(DEVICE)
    preds = [p1,p2,p3]
    opts  = [torch.optim.Adam(p.parameters(), lr=lr) for p in preds]
    stoppers = [EarlyStopper(patience=patience, mode='max') for _ in preds]
    energy = EnergyMeter(enable=use_energy_tracking)

    def train_one_predictor(idx):
        P = preds[idx]; Opt = opts[idx]; Stop = stoppers[idx]
        for ep in range(1, max_epochs+1):
            P.train()
            for x,y in dl_train:
                energy.tick()
                x,y = x.to(DEVICE), y.to(DEVICE)
                with torch.no_grad():
                    feats = bb.forward_blocks(x)[idx]
                feats = feats.view(feats.size(0), -1)
                logits = P(feats)
                loss = F.cross_entropy(logits, y)
                Opt.zero_grad(); loss.backward(); Opt.step()
            # val
            P.eval(); correct=0; tot=0
            with torch.no_grad():
                for x,y in dl_val:
                    x,y = x.to(DEVICE), y.to(DEVICE)
                    feats = bb.forward_blocks(x)[idx]
                    feats = feats.view(feats.size(0), -1)
                    logits = P(feats)
                    pred = logits.argmax(dim=1)
                    correct += (pred==y).sum().item(); tot += x.size(0)
            val_acc = correct/tot
            print(f'[CaFo-P{idx+1}] epoch {ep:02d} val_acc={val_acc:.3f}')
            Stop.step(val_acc, P)
            if Stop.stop:
                print(f'[CaFo-P{idx+1}] Early stop!')
                break
        if Stop.best_state:
            preds[idx].load_state_dict(Stop.best_state)

    for i in range(3):
        train_one_predictor(i)
    return bb, preds

def cafo_infer(bb, preds, dl):
    for p in preds: p.eval()
    bb.eval(); correct=0; tot=0
    with torch.no_grad():
        for x,y in dl:
            x,y = x.to(DEVICE), y.to(DEVICE)
            f1,f2,f3 = bb.forward_blocks(x)
            logits = 0
            for feats,P in zip([f1,f2,f3], preds):
                feats = feats.view(feats.size(0), -1)
                logits = logits + P(feats)
            pred = logits.argmax(dim=1)
            correct += (pred==y).sum().item(); tot += x.size(0)
    return correct/tot
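
`cafo_infer` ensembles the three block predictors by adding their raw logits and taking the argmax, so a predictor with a large logit scale can outvote the others. The sketch below contrasts that rule with averaging softmax probabilities, a common alternative that is swapped in here for comparison only and is not what the notebook does. The helper names are hypothetical and the logits are made-up numbers.

```python
import math
from typing import List


def argmax(v: List[float]) -> int:
    return max(range(len(v)), key=v.__getitem__)


def softmax(v: List[float]) -> List[float]:
    m = max(v)
    e = [math.exp(x - m) for x in v]  # subtract the max for numerical stability
    s = sum(e)
    return [x / s for x in e]


def ensemble_sum_logits(per_block: List[List[float]]) -> int:
    """Rule used by cafo_infer: elementwise sum of logits across predictors, then argmax."""
    return argmax([sum(col) for col in zip(*per_block)])


def ensemble_mean_probs(per_block: List[List[float]]) -> int:
    """Alternative (not used in the notebook): average softmax probabilities, then argmax."""
    probs = [softmax(v) for v in per_block]
    return argmax([sum(col) / len(probs) for col in zip(*probs)])


# One over-confident predictor (block 1) against two moderately confident ones (blocks 2, 3).
logits = [[10.0, 0.0], [0.0, 3.0], [0.0, 3.0]]
print(ensemble_sum_logits(logits), ensemble_mean_probs(logits))  # 0 1
```

The two rules agree whenever the predictors have similar logit scales; they diverge in cases like this one, where summing lets the first block's large margin decide.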


In [None]:
#@title Algorithm 3: Mono-Forward (MF) — MLP with projection matrices per hidden layer
class MFBlock(nn.Module):
    def __init__(self, in_dim, out_dim, ncls):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)
        self.ncls = ncls
        self.M = nn.Linear(out_dim, ncls, bias=False)  # projects activations -> class scores

    def forward(self, x):
        a = F.relu(self.fc(x))
        g = self.M(a)  # goodness scores per class
        return a, g

class MFNet(nn.Module):
    def __init__(self, in_dim, ncls, layers: List[int]):
        super().__init__()
        dims = [in_dim] + layers
        self.blocks = nn.ModuleList([MFBlock(dims[i], dims[i+1], ncls) for i in range(len(layers))])
        self.ncls = ncls

    def forward_all(self, x):
        # Ensure inputs are flattened for MLP use, regardless of caller
        if x.dim() > 2:
            x = x.view(x.size(0), -1)
        acts, scores = [], []
        for b in self.blocks:
            # Detach the block input so each block (fc + M) learns only from its own local loss
            x, g = b(x.detach())
            acts.append(x); scores.append(g)
        return acts, scores

    def forward(self, x):
        # BP-style inference: only final layer's scores
        if x.dim() > 2:
            x = x.view(x.size(0), -1)
        _, scores = self.forward_all(x)
        return scores[-1]

def train_mf(mf: MFNet, dl_train, dl_val, in_dim, max_epochs=5, lr=3e-4, patience=2):
    mf = mf.to(DEVICE)
    opt = torch.optim.Adam([p for p in mf.parameters()], lr=lr)
    stopper = EarlyStopper(patience=patience, mode='max')
    energy = EnergyMeter(enable=use_energy_tracking)

    def run_epoch(dloader, train=True):
        if train: mf.train()
        else: mf.eval()
        tot, corr = 0, 0
        for x,y in dloader:
            energy.tick()
            x,y = x.to(DEVICE), y.to(DEVICE)
            with torch.set_grad_enabled(train):   # no autograd graph during validation
                acts, scores = mf.forward_all(x)  # flattens internally
                # Layer-local cross-entropy on each block's class scores
                losses = [F.cross_entropy(g, y) for g in scores]
                loss = torch.stack(losses).mean()
            if train:
                opt.zero_grad(); loss.backward(); opt.step()
            with torch.no_grad():
                pred = scores[-1].argmax(dim=1)
                corr += (pred==y).sum().item(); tot += x.size(0)
        return corr/tot, energy.wh()

    for ep in range(1, max_epochs+1):
        tr_acc, _ = run_epoch(dl_train, train=True)
        val_acc, wh = run_epoch(dl_val, train=False)
        print(f'[MF] epoch {ep:02d}  train_acc={tr_acc:.3f}  val_acc={val_acc:.3f}  energy_Wh~{wh:.3f}')
        stopper.step(val_acc, mf)
        if stopper.stop:
            print('[MF] Early stop!')
            break
    if stopper.best_state:
        mf.load_state_dict(stopper.best_state)
    return mf
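
Each MF block maps its activations `a` to class scores `g = M a` and is trained with a layer-local cross-entropy. For one sample, the gradient of that loss with respect to `M` is `(softmax(g) - onehot(y)) a^T`; in a strictly local setup this is the only error signal a block receives, and it also reaches the block's own `fc` weights through `a`. The sketch below checks that formula against central finite differences in plain Python. `mf_scores`, `ce_loss`, `grad_M` and `numeric_grad_M` are hypothetical helpers, and `M`, `a`, `y` are arbitrary example values.

```python
import math
from typing import List

Matrix = List[List[float]]


def mf_scores(M: Matrix, a: List[float]) -> List[float]:
    """g = M a, with M shaped (ncls, width) like nn.Linear(width, ncls, bias=False).weight."""
    return [sum(m_kj * a_j for m_kj, a_j in zip(row, a)) for row in M]


def ce_loss(g: List[float], y: int) -> float:
    m = max(g)
    log_z = m + math.log(sum(math.exp(v - m) for v in g))  # stable log-sum-exp
    return log_z - g[y]


def grad_M(M: Matrix, a: List[float], y: int) -> Matrix:
    """Analytic gradient dCE/dM = (softmax(g) - onehot(y)) a^T."""
    g = mf_scores(M, a)
    m = max(g)
    e = [math.exp(v - m) for v in g]
    z = sum(e)
    err = [e_k / z - (1.0 if k == y else 0.0) for k, e_k in enumerate(e)]
    return [[err_k * a_j for a_j in a] for err_k in err]


def numeric_grad_M(M: Matrix, a: List[float], y: int, h: float = 1e-6) -> Matrix:
    """Central finite differences, used only to check grad_M."""
    G = [[0.0] * len(a) for _ in M]
    for k in range(len(M)):
        for j in range(len(a)):
            Mp = [row[:] for row in M]
            Mm = [row[:] for row in M]
            Mp[k][j] += h
            Mm[k][j] -= h
            G[k][j] = (ce_loss(mf_scores(Mp, a), y) - ce_loss(mf_scores(Mm, a), y)) / (2 * h)
    return G


M = [[0.2, -0.1, 0.4], [0.0, 0.3, -0.2], [-0.5, 0.1, 0.1]]  # 3 classes x width 3
a = [1.0, 0.5, 2.0]                                          # post-ReLU activations
y = 2
G = grad_M(M, a, y)
G_num = numeric_grad_M(M, a, y)
print(max(abs(G[k][j] - G_num[k][j]) for k in range(3) for j in range(3)))
```

Because the softmax error terms sum to zero across classes, each column of the gradient sums to zero as well, a quick sanity check on any implementation.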


In [None]:
#@title Train & Evaluate Selected Algorithm(s)
in_dim = IN_CH * IMG * IMG
results = {}

if algo_to_run in ['BP','ALL']:
    print('== BP baseline ==')
    if dataset in ['MNIST','FashionMNIST']:
        # MF native: 2x1000 for MNIST/FashionMNIST
        model = MLP(in_dim, NCLS, hidden=[1000,1000]).to(DEVICE)
    else:
        # CaFo native: 3-block CNN, with the same channel widths as the CaFo run on CIFAR10 below
        model = SimpleCNN3(IN_CH, NCLS, chans=(32,128,256)).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=lr_bp)
    stopper = EarlyStopper(patience=patience, mode='max')
    energy = EnergyMeter(enable=use_energy_tracking)
    for ep in range(1, max_epochs+1):
        model.train()
        for x,y in dl_train:
            energy.tick()
            x,y = x.to(DEVICE), y.to(DEVICE)
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            opt.zero_grad(); loss.backward(); opt.step()
        val_acc = test_model(model, dl_val)
        print(f'[BP] epoch {ep:02d} val_acc={val_acc:.3f}')
        stopper.step(val_acc, model)
        if stopper.stop:
            print('[BP] Early stop!')
            break
    if stopper.best_state:
        model.load_state_dict(stopper.best_state)
    test_acc = test_model(model, dl_test)
    results['BP'] = test_acc
    print('BP test_acc =', test_acc)

if algo_to_run in ['FF','ALL'] and dataset in ['MNIST','FashionMNIST','CIFAR10']:
    print('\n== Forward-Forward ==')
    # FF native: MLP (4x2000 for FashionMNIST per the paper). For Colab speed this demo uses
    # 2x2000 on FashionMNIST and 3x1000 on the other datasets.
    layers = [1000,1000,1000] if dataset!='FashionMNIST' else [2000,2000]  # reduce widths if you hit OOM
    ff = FFMLP(in_dim, layers)
    ff = train_ff(ff, dl_train, dl_val, in_dim, NCLS, IMG, max_epochs=max_epochs, lr=lr_ff, patience=patience)
    # FF inference wrapper for uniform test
    class FFWrapper(nn.Module):
        def __init__(self, ff):
            super().__init__(); self.ff=ff
        def forward(self, x):
            b = x.size(0)
            scores = torch.zeros(b, NCLS, device=x.device)
            flat = x.view(b,-1)
            for c in range(NCLS):
                tmp = embed_label(flat.clone(), torch.full((b,), c, device=x.device, dtype=torch.long), ncls=NCLS, img_hw=IMG)
                acts = self.ff.forward_feats(tmp)
                scores[:, c] = torch.stack([ff_goodness(a) for a in acts], dim=0).sum(dim=0)
            return scores
    test_acc = test_model(FFWrapper(ff).to(DEVICE), dl_test)
    results['FF'] = test_acc
    print('FF test_acc =', test_acc)

if algo_to_run in ['CaFo','ALL']:
    print('\n== CaFo (Rand-CE) ==')
    # Native CNN 3 blocks; predictors trained independently, blocks frozen
    chans = (32,128,512) if dataset!='CIFAR10' else (32,128,256)  # lighter for CIFAR10 on Colab
    bb, preds = train_cafo_rand(dl_train, dl_val, IN_CH, NCLS, max_epochs=max_epochs, lr=lr_cafo_pred, patience=patience, chans=chans)
    test_acc = cafo_infer(bb, preds, dl_test)
    results['CaFo'] = test_acc
    print('CaFo test_acc =', test_acc)

if algo_to_run in ['MF','ALL'] and dataset in ['MNIST','FashionMNIST','CIFAR10']:
    print('\n== Mono-Forward ==')
    # MF native: 2x1000 MLP for MNIST/FashionMNIST, 3x2000 for CIFAR10 (downsized to 2x1024 here for Colab)
    layers = [1000,1000] if dataset in ['MNIST','FashionMNIST'] else [1024,1024]
    mf = MFNet(in_dim, NCLS, layers)
    mf = train_mf(mf, dl_train, dl_val, in_dim, max_epochs=max_epochs, lr=lr_mf, patience=patience)
    test_acc = test_model(mf.to(DEVICE), dl_test)
    results['MF'] = test_acc
    print('MF test_acc =', test_acc)

print('\nSummary:', results)


== BP baseline ==
[BP] epoch 01 val_acc=0.860
[BP] epoch 02 val_acc=0.876
[BP] epoch 03 val_acc=0.877
[BP] epoch 04 val_acc=0.881
[BP] epoch 05 val_acc=0.886
BP test_acc = 0.8736

== Forward-Forward ==
[FF] epoch 01  train_acc=0.400  val_acc=0.620  energy_Wh~0.172
[FF] epoch 02  train_acc=0.667  val_acc=0.662  energy_Wh~0.357
[FF] epoch 03  train_acc=0.724  val_acc=0.724  energy_Wh~0.532
[FF] epoch 04  train_acc=0.749  val_acc=0.763  energy_Wh~0.706
[FF] epoch 05  train_acc=0.771  val_acc=0.769  energy_Wh~0.884
FF test_acc = 0.7596

== CaFo (Rand-CE) ==
[CaFo-P1] epoch 01 val_acc=0.877
[CaFo-P1] epoch 02 val_acc=0.886
[CaFo-P1] epoch 03 val_acc=0.891
[CaFo-P1] epoch 04 val_acc=0.893
[CaFo-P1] epoch 05 val_acc=0.893
[CaFo-P2] epoch 01 val_acc=0.884
[CaFo-P2] epoch 02 val_acc=0.890
[CaFo-P2] epoch 03 val_acc=0.899
[CaFo-P2] epoch 04 val_acc=0.902
[CaFo-P2] epoch 05 val_acc=0.907
[CaFo-P3] epoch 01 val_acc=0.882
[CaFo-P3] epoch 02 val_acc=0.887
[CaFo-P3] epoch 03 val_acc=0.892
[CaFo-P3] e

## What to tweak next
- Increase `max_epochs`, `batch_size`, and model width for stronger results.
- Switch dataset to `CIFAR10` for a harder benchmark.
- Enable a Colab GPU (Runtime → Change runtime type → T4/L4/A100), and keep `use_energy_tracking=True` to log approximate GPU energy.
- For CaFo-DFA or Optuna tuning, extend the notebook (this demo uses CaFo Rand-CE for simplicity and runtime).
