In [None]:
!pip install --upgrade pip
!pip install --force-reinstall --no-cache-dir numpy==1.23.5
# Install the latest torch/torchvision/torchaudio builds from the CUDA 11.8 wheel index.
# NOTE: no version is pinned here; pin explicit versions if a specific feature (e.g. RMSNorm) or reproducibility is required.
!pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install --upgrade opacus matplotlib

Collecting pip
  Downloading pip-25.1.1-py3-none-any.whl.metadata (3.6 kB)
Downloading pip-25.1.1-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m17.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.1.1
Collecting numpy==1.23.5
  Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)
Downloading numpy-1.23.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.1/17.1 MB[0m [31m176.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.0.2
    Uninstalling numpy-2.0.2:
      Successfully uninstalled numpy-2.0.2
[31mER

Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting torch
  Downloading https://download.pytorch.org/whl/cu118/torch-2.7.0%2Bcu118-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (28 kB)
Collecting torchvision
  Downloading https://download.pytorch.org/whl/cu118/torchvision-0.22.0%2Bcu118-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio
  Downloading https://download.pytorch.org/whl/cu118/torchaudio-2.7.0%2Bcu118-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.6 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading https://download.pytorch.org/whl/sympy-1.13.3-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.8.89 (from torch)
  Downloading https://download.pytorch.org/whl/cu118/nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-manylinux1_x86_64.whl (23.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.2/23.2 MB[0m [31m134.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu11==11.8

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, ConcatDataset, Subset

# Opacus
from opacus.validators import ModuleValidator
from opacus.grad_sample import GradSampleModule
from opacus.accountants import RDPAccountant

################################
# 0. Hyperparameters
################################
# Seed shared by the Python, PyTorch and NumPy RNGs (see the seeding calls below).
seed = 0
# Batch size used by both the training and the test DataLoader.
batch_size      = 1000
# Initial learning rate of the SGD optimizer.
lr              = 0.1
# Momentum of the SGD optimizer (the "outer" momentum).
outer_momentum  = 0.9
# Weight of the previously stored per-sample vector in the per-sample ("inner")
# momentum update; the fresh gradient gets weight (1 - inner_momentum).
inner_momentum  = 0.08
# Noise multiplier: scales the Gaussian noise added to the summed gradient and is
# also the value reported to the RDP accountant.
noise_mult      = 1.5
# Delta passed to the accountant when querying epsilon.
delta           = 1e-5
num_epochs      = 50
# Number of times the indexed training set is concatenated with itself.
self_aug_factor = 3
# Maximum number of per-sample momentum vectors kept in momentum_dict.
M               = 10000

DEFAULT_CLIP    = 1     # clip value used for regular samples
OUTLIER_CLIP    = 0.5     # clip value used for samples flagged by the error-rate rule
HIGH_ERR_THRESHOLD  = 0.9 # running error-rate threshold (0.9 = 90%) compared per sample in the training loop
DROP_AFTER_FRAC     = 0.6  # the error-rate rule is only consulted after this fraction of all training steps

# Epochs at which MultiStepLR multiplies the learning rate by SCHEDULE_GAMMA.
SCHEDULE_MILESTONES = [30, 45]
SCHEDULE_GAMMA       = 0.1

# Seed every RNG in use before any dataset or model object is created.
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

################################
# 1. Dataset
################################
# Per-channel normalization statistics (presumably the CIFAR-10 train-set mean/std — TODO confirm).
mean, std = [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]
# Training pipeline: random crop and flip, tensor conversion, random erasing on the
# tensor, and finally normalization (erasing therefore happens before normalization).
train_tf = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.5, scale=(0.02,0.2)),
    transforms.Normalize(mean, std),
])
# Evaluation pipeline: no augmentation, only tensor conversion and normalization.
test_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Both splits live under ./data; download=True requests a download when needed.
train_ds = datasets.CIFAR10(root="./data", train=True,  download=True, transform=train_tf)
test_ds  = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_tf)
# Number of training images; also sizes the per-sample bookkeeping arrays in the training loop.
n = len(train_ds)

class IndexedSubset(Subset):
    """Subset whose items also carry the index they were requested with.

    Indexing yields ``(x, y, idx)``: ``(x, y)`` is the pair returned by
    ``Subset`` for position ``idx``, and ``idx`` is that position inside this
    subset (it equals the wrapped dataset's index only when the subset's
    indices are the identity mapping).
    """

    def __getitem__(self, idx):
        features, label = super().__getitem__(idx)
        return (features, label, idx)

# Expose every training image together with its index (0..n-1).
train_sub = IndexedSubset(train_ds, range(n))
# Repeat the same indexed dataset self_aug_factor times. The repeats share sample
# indices; since train_tf is random, repeats of one image are presumably drawn as
# different augmentations — verify.
train_full = ConcatDataset([train_sub]*self_aug_factor)
# drop_last=True discards a trailing incomplete batch, so every training batch
# holds exactly batch_size items.
train_loader = DataLoader(train_full, batch_size=batch_size, shuffle=True, drop_last=True)
# The test loader yields plain (x, y) batches in dataset order.
test_loader  = DataLoader(test_ds,   batch_size=batch_size, shuffle=False)

################################
# 2. ResNet20
################################
class BasicBlock(nn.Module):
    """Residual block: two 3x3 convolutions, each followed by GroupNorm(8 groups).

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution (and of the projection shortcut).

    The shortcut is the identity (an empty ``nn.Sequential``) unless the block
    changes resolution or width, in which case it is a strided 1x1 convolution
    followed by GroupNorm.
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.gn1 = nn.GroupNorm(num_groups=8, num_channels=planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.gn2 = nn.GroupNorm(num_groups=8, num_channels=planes)

        needs_projection = (stride != 1) or (in_planes != planes)
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.GroupNorm(num_groups=8, num_channels=planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.gn1(self.conv1(x)))
        residual = self.gn2(self.conv2(hidden))
        skip = self.shortcut(x)
        # Add onto a copy so the normalized activation itself is never modified.
        return F.relu(residual.clone() + skip)

class ResNet20(nn.Module):
    """GroupNorm ResNet with a 3x3 stem, three stages of three BasicBlocks, and a linear head.

    Stage widths are 16, 32 and 64 channels; the second and third stages start
    with a stride-2 block. Features are globally average-pooled before the
    ``num_classes``-way linear classifier.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Running input width, read and updated by _make_layer as stages are built.
        self.in_planes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.gn1 = nn.GroupNorm(num_groups=8, num_channels=16)
        self.layer1 = self._make_layer(16, 3, stride=1)
        self.layer2 = self._make_layer(32, 3, stride=2)
        self.layer3 = self._make_layer(64, 3, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    def _make_layer(self, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        block_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in block_strides:
            stage.append(BasicBlock(self.in_planes, planes, block_stride))
            # Every later block (and the next stage) consumes `planes` channels.
            self.in_planes = planes
        return nn.Sequential(*stage)

    def forward(self, x):
        features = F.relu(self.gn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)
        pooled = self.avgpool(features)
        return self.fc(torch.flatten(pooled, 1))

def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over ``loader`` as a percentage.

    ``loader`` must yield ``(X, y)`` batches. The model is switched to eval mode
    for the pass and put back into train mode afterwards if it was training
    when this function was called.
    """
    restore_train_mode = model.training
    model.eval()

    n_hits, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predicted = model(inputs).argmax(dim=1)
            n_hits += (predicted == targets).sum().item()
            n_seen += targets.size(0)

    if restore_train_mode:
        model.train()
    return 100.0 * n_hits / n_seen

################################
# 3. Build DP ResNet
################################
def build_model():
    """Create a ResNet20 on ``device`` and wrap it in an Opacus ``GradSampleModule``.

    The network is first checked with ``ModuleValidator.validate`` (non-strict);
    if any problems are reported it is passed through ``ModuleValidator.fix``
    and moved back onto ``device`` before wrapping.
    """
    base_net = ResNet20().to(device)
    if ModuleValidator.validate(base_net, strict=False):
        base_net = ModuleValidator.fix(base_net).to(device)
    return GradSampleModule(base_net)

################################
# 4. Sliding Window Momentum
################################
class LRUOrderedDict(OrderedDict):
    def __init__(self, *args, maxsize=10000, **kwargs):
        self.maxsize = maxsize
        super().__init__(*args, **kwargs)

    def __getitem__(self,key):
        val = super().__getitem__(key)
        self.move_to_end(key)
        return val

    def __setitem__(self,key,val):
        if key in self:
            self.move_to_end(key)
        super().__setitem__(key,val)
        if len(self) > self.maxsize:
            self.popitem(last=False)

# Module-level store of per-sample momentum vectors, keyed by sample index and
# capped at M entries (least-recently-used entries are evicted first).
momentum_dict = LRUOrderedDict(maxsize=M)

#############################################
# 5. Aggregated Grad Computation (inner momentum)
#############################################
def compute_inner_momentum_grads_idxed(dp_model, X, y, idxs):
    """Compute one momentum-smoothed gradient vector per unique sample in the batch.

    Steps:
      1. Run forward/backward so every parameter exposes ``grad_sample``.
      2. Flatten the per-sample gradients of all parameters into one
         ``(batch, D)`` matrix (parameters in ``dp_model.parameters()`` order).
      3. Average the rows that share a sample id (duplicates come from the
         ``self_aug_factor`` repetition of the dataset).
      4. Blend with the vector stored in ``momentum_dict``:
         ``new_v = inner_momentum * old_v + (1 - inner_momentum) * g``.
         The stored copy is kept in half precision on the CPU.

    Args:
        dp_model: model whose parameters carry ``grad_sample`` after backward
            (e.g. an Opacus ``GradSampleModule``).
        X: input batch.
        y: target labels for ``X``.
        idxs: per-row sample ids, aligned with the rows of ``X``.

    Returns:
        ``(batch_v, unique_count)``: ``batch_v`` is a list with one flat
        gradient vector per unique sample id, ordered by first occurrence in
        ``idxs``; ``unique_count`` is ``len(batch_v)``.

    Raises:
        RuntimeError: if no parameter exposes ``grad_sample`` after backward.
    """
    dp_model.zero_grad()
    logits = dp_model(X)
    loss = F.cross_entropy(logits, y)
    loss.backward()

    bs = X.size(0)

    # Collect one (bs, numel) view per parameter and concatenate once, instead
    # of growing every sample's vector with a separate torch.cat per parameter.
    # NOTE(review): parameters without grad_sample are skipped, while the
    # optimizer step slices the flat vector over *all* parameters — confirm
    # every parameter of the model is covered.
    per_param_rows = []
    for p in dp_model.parameters():
        gs = getattr(p, "grad_sample", None)
        if gs is None:
            continue
        per_param_rows.append(gs.detach().reshape(bs, -1))
        p.grad_sample = None
    if not per_param_rows:
        raise RuntimeError("no parameter exposes grad_sample; is the model wrapped for per-sample gradients?")
    flat_grads = torch.cat(per_param_rows, dim=1)
    del per_param_rows  # release the per-parameter gradient buffers

    # Group batch rows by sample id; dict order preserves first occurrence.
    rows_by_sample = {}
    for row, raw_id in enumerate(idxs.tolist()):
        rows_by_sample.setdefault(int(raw_id), []).append(row)

    batch_v = []
    for sample_id, rows in rows_by_sample.items():
        g_i = flat_grads[rows].mean(dim=0)

        if sample_id in momentum_dict:
            # Stored as fp16 on the CPU; bring it next to the fresh gradient.
            old_v = momentum_dict[sample_id].to(flat_grads.device)
        else:
            old_v = torch.zeros_like(g_i)

        new_v = inner_momentum * old_v + (1.0 - inner_momentum) * g_i
        momentum_dict[sample_id] = new_v.half().cpu()
        batch_v.append(new_v)

    return batch_v, len(rows_by_sample)

################################
# 6. Single-Sum Clipping & Noise
################################
def clip_and_add_noise_outliers_batch(batch_v, clip_vals):
    """Clip the summed batch gradient, add Gaussian noise, and average it.

    1. Sum all vectors in ``batch_v`` into one vector.
    2. If the L2 norm of that sum exceeds ``sum(clip_vals)``, rescale the sum
       down to (approximately) that norm.
    3. Add Gaussian noise with std ``noise_mult * max(clip_vals)``.
    4. Divide by ``len(batch_v)``.

    NOTE(review): clipping is applied to the aggregate, not to each vector —
    confirm this matches the sensitivity assumed by the privacy accounting.

    Args:
        batch_v: non-empty list of flat per-sample gradient vectors of equal size.
        clip_vals: non-empty list of clip thresholds (in this file the values
            are ``DEFAULT_CLIP`` or ``OUTLIER_CLIP``).

    Returns:
        The noised, averaged gradient as one flat tensor.

    Raises:
        ValueError: if ``batch_v`` or ``clip_vals`` is empty.
    """
    count = len(batch_v)
    if count == 0:
        raise ValueError("batch_v must contain at least one gradient vector")
    if len(clip_vals) == 0:
        raise ValueError("clip_vals must contain at least one clip threshold")

    grad_sum = torch.stack(batch_v, dim=0).sum(dim=0)
    norm_ = grad_sum.norm(2)

    # The aggregate is bounded by the sum of the individual clip thresholds.
    threshold = sum(clip_vals)
    if norm_ > threshold:
        # The epsilon keeps the division finite for a degenerate zero norm.
        grad_sum = grad_sum * (threshold / (norm_ + 1e-6))

    # Noise is calibrated to the largest clip threshold present in the batch.
    cmax = max(clip_vals)
    noise = torch.randn_like(grad_sum) * (noise_mult * cmax)

    # Average over the number of distinct samples in the batch.
    return (grad_sum + noise) / count

def outer_step_outliers_batch(dp_model, optimizer, batch_v, clip_vals):
    """Write the clipped, noised batch gradient into ``.grad`` and step the optimizer.

    The flat vector returned by ``clip_and_add_noise_outliers_batch`` is cut
    into consecutive chunks following ``dp_model.parameters()`` order, and each
    chunk is reshaped to its parameter before ``optimizer.step()`` is called.
    """
    noisy_grad = clip_and_add_noise_outliers_batch(batch_v, clip_vals)

    offset = 0
    for param in dp_model.parameters():
        end = offset + param.numel()
        param.grad = noisy_grad[offset:end].view_as(param)
        offset = end

    optimizer.step()

################################
# 7. Train => HighErr => clip=0.5 (batch-sum style)
################################
def train_clip_outliers_batchsum():
    """Train the DP ResNet with batch-sum clipping and a stricter clip for outliers.

    - During the first ``DROP_AFTER_FRAC`` fraction of all steps every sample
      uses ``DEFAULT_CLIP``.
    - Afterwards, a sample whose running error rate is at least
      ``HIGH_ERR_THRESHOLD`` uses ``OUTLIER_CLIP``; all others keep
      ``DEFAULT_CLIP``.
    - Each step applies single-sum clipping plus noise and records one
      accountant step with ``sample_rate = unique samples in batch / n``.

    Prints loss, test accuracy and epsilon after every epoch and once more at
    the end. Clears ``momentum_dict`` before training starts.
    """
    momentum_dict.clear()

    dp_net = build_model().to(device)
    optimizer = optim.SGD(dp_net.parameters(), lr=lr, momentum=outer_momentum, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=SCHEDULE_MILESTONES, gamma=SCHEDULE_GAMMA)
    accountant = RDPAccountant()

    # Running per-sample statistics: how often each sample was seen / classified correctly.
    sample_correct_count = np.zeros(n, dtype=int)
    sample_total_count = np.zeros(n, dtype=int)

    total_steps = num_epochs * len(train_loader)
    step_count = 0

    for epoch in range(1, num_epochs + 1):
        dp_net.train()
        losses = []

        for X, y, idxs in train_loader:
            step_count += 1
            X = X.to(device)
            y = y.to(device)

            # Gradient-free pass: batch loss and per-row correctness under the
            # current weights (the weights only change in the outer step below).
            with torch.no_grad():
                out = dp_net(X)
                correct_vec = (out.argmax(dim=1) == y)
                loss = F.cross_entropy(out, y)
            losses.append(loss.item())

            # Update the per-sample statistics in one shot. np.add.at accumulates
            # repeated indices, which plain fancy-index "+=" would not.
            sample_ids = [int(s) for s in idxs.tolist()]
            id_array = np.asarray(sample_ids, dtype=np.int64)
            np.add.at(sample_total_count, id_array, 1)
            np.add.at(sample_correct_count, id_array, correct_vec.cpu().numpy().astype(int))

            # One momentum vector per unique sample, ordered by first occurrence.
            batch_v, unique_count = compute_inner_momentum_grads_idxed(dp_net, X, y, idxs)

            # Build one clip value per entry of batch_v.
            if step_count <= DROP_AFTER_FRAC * total_steps:
                clip_vals = [DEFAULT_CLIP] * len(batch_v)
            else:
                # dict.fromkeys reproduces the first-occurrence order used for
                # batch_v, so clip_vals[k] belongs to batch_v[k].
                clip_vals = []
                for sample_id in dict.fromkeys(sample_ids):
                    seen = sample_total_count[sample_id]
                    is_outlier = (
                        seen > 0
                        and 1.0 - (sample_correct_count[sample_id] / seen) >= HIGH_ERR_THRESHOLD
                    )
                    # Samples without statistics fall back to the default clip.
                    clip_vals.append(OUTLIER_CLIP if is_outlier else DEFAULT_CLIP)

            # single-sum clipping + noise
            outer_step_outliers_batch(dp_net, optimizer, batch_v, clip_vals)

            # Fraction of the base training set that contributed to this step.
            sample_rate = unique_count / n
            accountant.step(noise_multiplier=noise_mult, sample_rate=sample_rate)

        scheduler.step()
        acc = evaluate(dp_net, test_loader)
        eps = accountant.get_epsilon(delta)
        print(f"[Epoch={epoch:02d}] step={step_count} Loss={np.mean(losses):.3f} "
              f"Acc={acc:.2f}% eps={eps:.2f}")

    final_acc = evaluate(dp_net, test_loader)
    final_eps = accountant.get_epsilon(delta)
    print(f"\nDone. (Batch-sum clipping, outliers=0.5 after 60%). "
          f"Final Acc={final_acc:.2f}% eps={final_eps:.2f}")


# Entry point: run the full training loop when this module is the main program.
if __name__=="__main__":
    train_clip_outliers_batchsum()


100%|██████████| 170M/170M [00:10<00:00, 15.8MB/s]


[Epoch=01] step=150 Loss=2.078 Acc=28.45% eps=0.86
[Epoch=02] step=300 Loss=1.780 Acc=36.36% eps=1.17
[Epoch=03] step=450 Loss=1.609 Acc=40.09% eps=1.43
[Epoch=04] step=600 Loss=1.455 Acc=48.25% eps=1.65
[Epoch=05] step=750 Loss=1.356 Acc=52.17% eps=1.85
[Epoch=06] step=900 Loss=1.248 Acc=52.73% eps=2.03
[Epoch=07] step=1050 Loss=1.171 Acc=57.29% eps=2.20
[Epoch=08] step=1200 Loss=1.099 Acc=61.45% eps=2.36
[Epoch=09] step=1350 Loss=1.023 Acc=60.52% eps=2.51
[Epoch=10] step=1500 Loss=0.972 Acc=66.14% eps=2.65
[Epoch=11] step=1650 Loss=0.951 Acc=65.97% eps=2.79
[Epoch=12] step=1800 Loss=0.882 Acc=68.81% eps=2.92
[Epoch=13] step=1950 Loss=0.839 Acc=68.02% eps=3.05
[Epoch=14] step=2100 Loss=0.814 Acc=69.59% eps=3.18
[Epoch=15] step=2250 Loss=0.771 Acc=70.38% eps=3.30
[Epoch=16] step=2400 Loss=0.759 Acc=66.15% eps=3.42
[Epoch=17] step=2550 Loss=0.751 Acc=69.31% eps=3.53
[Epoch=18] step=2700 Loss=0.721 Acc=73.79% eps=3.64
[Epoch=19] step=2850 Loss=0.712 Acc=73.13% eps=3.75
[Epoch=20] step=30