# Experiment: Unlearning Different Data Modalities and Predictive Uncertainty (ResNet-18)

We will compare how unlearning the same proportion (5%) of three different data modalities affects predictive uncertainty (ECE and Brier Score):

- Random instances (5% random training samples)
- Gaussian-noise instances (5% of samples with added Gaussian noise)
- Modified-label instances (5% of samples with randomly flipped labels)

Dataset/model: CIFAR-10 with ResNet-18
Unlearning method: First-order based (single-step gradient removal)
Uncertainty method: Temperature scaling (evaluate ECE and Brier Score)

Runs: 3 repeats per modality. The notebook will produce a results table with mean ± std for ECE and BS.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import random
import copy
import time
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader, Subset, Dataset
from torch import optim
from torchvision import transforms, datasets, models
from scipy.special import softmax
from scipy.optimize import minimize
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

if torch.cuda.is_available():
    cudnn.benchmark = True
    device = "cuda:0"
else:
    device = "cpu"
print('Device:', device)

def set_random_seed(seed=42):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

set_random_seed(42)

Device: cpu


In [2]:
data_mean = (0.4914, 0.4822, 0.4465)
data_std = (0.2023, 0.1994, 0.2010)
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(data_mean, data_std),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(data_mean, data_std),
])

train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

cali_indices, test_indices = train_test_split(range(len(test_set)), test_size=0.5, stratify=test_set.targets)
cali_data = Subset(test_set, cali_indices)
test_data = Subset(test_set, test_indices)

batch_size = 128
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
cali_loader = DataLoader(cali_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

print(f'Train: {len(train_data)}, Calib: {len(cali_data)}, Test: {len(test_data)}')

100.0%


Train: 50000, Calib: 5000, Test: 5000


In [3]:
# ResNet-18 for CIFAR-10
class ResNet18CIFAR(nn.Module):
    def __init__(self, num_classes=10, dropout_rate=0.5):
        super(ResNet18CIFAR, self).__init__()
        self.resnet = models.resnet18(pretrained=False)
        self.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.resnet.maxpool = nn.Identity()
        self.resnet.fc = nn.Linear(512, num_classes)
        self.dropout = nn.Dropout(p=dropout_rate)
        
    def forward(self, x, dropout=False):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)
        x = self.resnet.avgpool(x)
        x = x.view(x.size(0), -1)
        if dropout:
            x = self.dropout(x)
        x = self.resnet.fc(x)
        return x

def train(model, train_loader, loss_func, optimizer, epochs=10):
    model.train()
    for epoch in range(epochs):
        running_loss, n_batches, total, correct = 0.0, 0, 0, 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        if (epoch+1) % max(1, epochs//5) == 0:
            print(f'Epoch {epoch+1}/{epochs}, loss: {running_loss/n_batches:.4f}, acc: {100*correct/total:.2f}%')

def test(model, data_loader):
    model.eval()
    total, correct = 0, 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

# ECE and BS
def one_hot_encode(labels, num_classes=None):
    if num_classes is None:
        num_classes = len(np.unique(labels))
    return np.eye(num_classes)[labels]

def get_calibration_error(probs, labels, bin_upper_bounds, num_bins):
    if np.size(probs) == 0:
        return 0
    bin_indices = np.digitize(probs, bin_upper_bounds)
    sums = np.bincount(bin_indices, weights=probs, minlength=num_bins).astype(np.float64)
    counts = np.bincount(bin_indices, minlength=num_bins) + np.finfo(sums.dtype).eps
    confs = sums / counts
    accs = np.bincount(bin_indices, weights=labels, minlength=num_bins) / counts
    calibration_errors = accs - confs
    weighting = counts / float(len(probs.flatten()))
    weighted_calibration_error = calibration_errors * weighting
    return np.sum(np.abs(weighted_calibration_error))

def ECE(probs, labels, num_bins=10):
    num_classes = probs.shape[1]
    labels_matrix = one_hot_encode(labels, probs.shape[1])
    bin_upper_bounds = np.histogram_bin_edges([], bins=num_bins, range=(0.0, 1.0))[1:]
    labels_matrix = labels_matrix[range(len(probs)), np.argmax(probs, axis=1)]
    probs_matrix = probs[range(len(probs)), np.argmax(probs, axis=1)]
    calibration_error = get_calibration_error(probs_matrix.flatten(), labels_matrix.flatten(), bin_upper_bounds, num_bins)
    return calibration_error

def BS(probs, labels):
    n_samples, n_classes = probs.shape
    labels_matrix = one_hot_encode(labels, n_classes)
    brier_score = np.sum((probs - labels_matrix) ** 2) / n_samples
    return brier_score

# Temperature scaling
class TemperatureScaling():
    def __init__(self, temp=1, maxiter=50, solver="BFGS"):
        self.temp = temp
        self.maxiter = maxiter
        self.solver = solver
    def _loss_fun(self, x, probs, true):
        scaled_probs = self.predict(probs, x)
        loss = log_loss(y_true=true, y_pred=scaled_probs)
        return loss
    def fit(self, logits, true):
        true = true.flatten()
        opt = minimize(self._loss_fun, x0=1.0, args=(logits, true), options={'maxiter': self.maxiter}, method=self.solver)
        self.temp = opt.x[0]
        return opt
    def predict(self, logits, temp=None):
        if temp is None:
            return softmax(logits / self.temp, axis=1)
        else:
            return softmax(logits / temp, axis=1)

def get_outputs(model, data_loader, dropout=False):
    model.eval()
    all_labels, all_logits = [], []
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images, dropout=dropout)
            all_labels.append(labels.detach().cpu().numpy())
            all_logits.append(outputs.detach().cpu().numpy())
    all_labels = np.concatenate(all_labels, axis=0)
    all_logits = np.concatenate(all_logits, axis=0)
    return all_labels, all_logits

In [4]:
# First-order unlearning
def get_grad_diff(model, unlearn_loader):
    loss_func = nn.CrossEntropyLoss(reduction="sum")
    model.train()
    grads = []
    for i, (images, labels) in enumerate(unlearn_loader):
        images, labels = images.to(device), labels.to(device)
        result_z = model(images)
        loss_z = loss_func(result_z, labels)
        loss_diff = -loss_z
        differentiable_params = [p for p in model.parameters() if p.requires_grad]
        gradients = torch.autograd.grad(loss_diff, differentiable_params, retain_graph=False)
        grads.append(gradients)
    grads = list(zip(*grads))
    for i in range(len(grads)):
        tmp = grads[i][0]
        for j in range(1, len(grads[i])):
            tmp = torch.add(tmp, grads[i][j])
        grads[i] = tmp
    return grads

def first_order_unlearn(model, unlearn_loader, tau=2e-5):
    net_unlearn = copy.deepcopy(model)
    diff = get_grad_diff(net_unlearn, unlearn_loader)
    d_theta = diff
    net_unlearn.eval()
    with torch.no_grad():
        for p in net_unlearn.parameters():
            if p.requires_grad:
                new_p = p - tau * d_theta.pop(0)
                p.copy_(new_p)
    return net_unlearn

In [None]:
# Modality generators
class NoisyDataset(Dataset):
    """Wraps a dataset and applies Gaussian noise to specified indices."""
    def __init__(self, base_dataset, noise_indices=set(), sigma=0.1):
        self.base = base_dataset
        self.noise_indices = set(noise_indices)
        self.sigma = sigma
    def __len__(self):
        return len(self.base)
    def __getitem__(self, idx):
        x, y = self.base[idx]
        if idx in self.noise_indices:
            noise = torch.randn_like(x) * self.sigma
            x = x + noise
            x = torch.clamp(x, -3.0, 3.0)
        return x, y

class LabelFlippedDataset(Dataset):
    """Wraps a dataset and flips labels for specified indices to random labels."""
    def __init__(self, base_dataset, flip_indices=set(), num_classes=10):
        self.base = base_dataset
        self.flip_indices = set(flip_indices)
        self.num_classes = num_classes
    def __len__(self):
        return len(self.base)
    def __getitem__(self, idx):
        x, y = self.base[idx]
        if idx in self.flip_indices:
            new_y = int(random.choice([c for c in range(self.num_classes) if c != y]))
            return x, new_y
        return x, y

def run_modality_experiment(remove_prop=0.05, modality='random', num_runs=3, epochs=3, lr=0.01, tau=2e-5):
    """
    modality: 'random', 'noise', or 'label'
    Returns: dict with lists of ECE and BS (temperature-scaled) for each run
    """
    ece_list = []
    bs_list = []

    num_samples = len(train_data)
    remove_count = int(num_samples * remove_prop)

    for run in range(num_runs):
        print(f"Run {run+1}/{num_runs} — modality: {modality} — removing {remove_count} samples")
        set_random_seed(100 + run)

        # Train base model on full training set
        model = ResNet18CIFAR().to(device)
        loss_func = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
        train(model, train_loader, loss_func, optimizer, epochs=epochs)

        # Prepare indices to remove based on modality
        if modality == 'random':
            remove_indices = set(random.sample(range(num_samples), remove_count))
            train_mod = train_data
        elif modality == 'noise':
            noisy_indices = set(random.sample(range(num_samples), int(num_samples * 0.2)))
            remove_indices = set(random.sample(list(noisy_indices), remove_count))
            train_mod = NoisyDataset(train_data, noise_indices=noisy_indices, sigma=0.2)
        elif modality == 'label':
            flipped_indices = set(random.sample(range(num_samples), int(num_samples * 0.2)))
            remove_indices = set(random.sample(list(flipped_indices), remove_count))
            train_mod = LabelFlippedDataset(train_data, flip_indices=flipped_indices, num_classes=10)
        else:
            raise ValueError('Unknown modality')

        # Create DataLoader for unlearn indices
        unlearn_subset = Subset(train_mod, list(remove_indices))
        retain_indices = [i for i in range(num_samples) if i not in remove_indices]
        retain_subset = Subset(train_mod, retain_indices)

        unlearn_loader = DataLoader(unlearn_subset, batch_size=128, shuffle=False)
        retain_loader = DataLoader(retain_subset, batch_size=128, shuffle=True)

        # Apply first-order unlearning
        model_unlearned = first_order_unlearn(model, unlearn_loader, tau=tau)

        # Evaluate using temperature scaling
        cali_labels, cali_logits = get_outputs(model_unlearned, cali_loader, dropout=False)
        test_labels, test_logits = get_outputs(model_unlearned, test_loader, dropout=False)

        ts = TemperatureScaling()
        try:
            ts.fit(cali_logits, cali_labels)
        except Exception as e:
            print('Temperature scaling fit failed:', e)
        test_probs_ts = ts.predict(test_logits)

        ece_val = ECE(test_probs_ts, test_labels)
        bs_val = BS(test_probs_ts, test_labels)

        print(f'  Run result — ECE: {ece_val:.4f}, BS: {bs_val:.4f}')

        ece_list.append(ece_val)
        bs_list.append(bs_val)

    return {'ECE': ece_list, 'BS': bs_list}

smoke_params = {'remove_prop': 0.05, 'num_runs': 1, 'epochs': 1}
print('Finished')

Ready — use run_modality_experiment() to run the full experiments.


In [None]:
modalities = ['random', 'noise', 'label']
all_stats = {}

for mod in modalities:
    print('\n' + '='*80)
    print(f'Running modality: {mod}')
    print('='*80)
    res = run_modality_experiment(remove_prop=0.05, modality=mod, num_runs=3, epochs=3, lr=0.01, tau=2e-5)

    ece_vals = np.array(res['ECE'])
    bs_vals = np.array(res['BS'])

    ece_mean = ece_vals.mean()
    ece_std = ece_vals.std()
    bs_mean = bs_vals.mean()
    bs_std = bs_vals.std()

    all_stats[mod] = {
        'ECE_vals': ece_vals.tolist(),
        'BS_vals': bs_vals.tolist(),
        'ECE_mean': float(ece_mean),
        'ECE_std': float(ece_std),
        'BS_mean': float(bs_mean),
        'BS_std': float(bs_std)
    }

print('\nRESULTS SUMMARY (mean ± std) — remove 5% per modality')
print(f"{'Modality':<12} {'ECE (mean±std)':<25} {'BS (mean±std)':<25}")
print('-'*70)
for mod in modalities:
    e_me, e_st = all_stats[mod]['ECE_mean'], all_stats[mod]['ECE_std']
    b_me, b_st = all_stats[mod]['BS_mean'], all_stats[mod]['BS_std']
    print(f"{mod:<12} {e_me:7.4f} ± {e_st:<7.4f}    {b_me:7.4f} ± {b_st:<7.4f}")

all_stats


Running modality: random
Run 1/3 — modality: random — removing 2500 samples




Epoch 1/8, loss: 1.4203, acc: 47.76%
Epoch 2/8, loss: 0.9031, acc: 68.03%
Epoch 3/8, loss: 0.6437, acc: 77.45%
Epoch 4/8, loss: 0.4496, acc: 84.40%
Epoch 5/8, loss: 0.2906, acc: 90.02%
Epoch 6/8, loss: 0.1822, acc: 93.81%
Epoch 7/8, loss: 0.1315, acc: 95.42%
Epoch 8/8, loss: 0.0813, acc: 97.32%
  Run result — ECE: 0.0169, BS: 0.3552
Run 2/3 — modality: random — removing 2500 samples




Epoch 1/8, loss: 1.4228, acc: 47.59%
Epoch 2/8, loss: 0.8931, acc: 68.20%
Epoch 3/8, loss: 0.6374, acc: 77.65%
Epoch 4/8, loss: 0.4443, acc: 84.58%
Epoch 5/8, loss: 0.2870, acc: 90.12%
Epoch 6/8, loss: 0.1830, acc: 93.65%
Epoch 7/8, loss: 0.1239, acc: 95.72%
Epoch 8/8, loss: 0.0818, acc: 97.29%
  Run result — ECE: 0.0133, BS: 0.3470
Run 3/3 — modality: random — removing 2500 samples




Epoch 1/8, loss: 1.4158, acc: 48.38%
Epoch 2/8, loss: 0.9188, acc: 67.18%
Epoch 3/8, loss: 0.6515, acc: 77.12%
Epoch 4/8, loss: 0.4639, acc: 83.73%
Epoch 5/8, loss: 0.3003, acc: 89.56%
Epoch 6/8, loss: 0.1938, acc: 93.26%
Epoch 7/8, loss: 0.1268, acc: 95.68%
Epoch 8/8, loss: 0.0964, acc: 96.71%
  Run result — ECE: 0.0108, BS: 0.3144

Running modality: noise
Run 1/3 — modality: noise — removing 2500 samples




Epoch 1/8, loss: 1.4203, acc: 47.76%
Epoch 2/8, loss: 0.9031, acc: 68.03%
Epoch 3/8, loss: 0.6437, acc: 77.45%
Epoch 4/8, loss: 0.4496, acc: 84.40%
Epoch 5/8, loss: 0.2906, acc: 90.02%
Epoch 6/8, loss: 0.1822, acc: 93.81%
Epoch 7/8, loss: 0.1315, acc: 95.42%
Epoch 8/8, loss: 0.0813, acc: 97.32%
  Run result — ECE: 0.0144, BS: 0.3812
Run 2/3 — modality: noise — removing 2500 samples




Epoch 1/8, loss: 1.4228, acc: 47.59%
Epoch 2/8, loss: 0.8931, acc: 68.20%
Epoch 3/8, loss: 0.6374, acc: 77.65%
Epoch 4/8, loss: 0.4443, acc: 84.58%
Epoch 5/8, loss: 0.2870, acc: 90.12%
Epoch 6/8, loss: 0.1830, acc: 93.65%
Epoch 7/8, loss: 0.1239, acc: 95.72%
Epoch 8/8, loss: 0.0818, acc: 97.29%
  Run result — ECE: 0.0160, BS: 0.3739
Run 3/3 — modality: noise — removing 2500 samples




Epoch 1/8, loss: 1.4158, acc: 48.38%
Epoch 2/8, loss: 0.9188, acc: 67.18%
Epoch 3/8, loss: 0.6515, acc: 77.12%
Epoch 4/8, loss: 0.4639, acc: 83.73%
Epoch 5/8, loss: 0.3003, acc: 89.56%
Epoch 6/8, loss: 0.1938, acc: 93.26%
Epoch 7/8, loss: 0.1268, acc: 95.68%
Epoch 8/8, loss: 0.0964, acc: 96.71%
  Run result — ECE: 0.0097, BS: 0.3239

Running modality: label
Run 1/3 — modality: label — removing 2500 samples




Epoch 1/8, loss: 1.4203, acc: 47.76%
Epoch 2/8, loss: 0.9031, acc: 68.03%


In [1]:
# Display names mapped to experiment keys
modalities_map = [
    ('Instances', 'random'),
    ('Gaussian Noises', 'noise'),
    ('Modified Labels', 'label'),
]

if 'all_stats' not in globals():
    print("No results available in `all_stats`. Run the experiment cell first to populate results.")
else:
    print('\nRESULTS SUMMARY (remove 5% per modality) — ResNet-18')
    print(f"{'Modality':<20} {'ECE (mean ± std)':<30} {'BS (mean ± std)':<30}")
    print('-' * 80)
    for display_name, key in modalities_map:
        # Look up by canonical key (used when experiments were run), fallback to display name
        s = all_stats.get(key, all_stats.get(display_name))
        if s is None:
            print(f"{display_name:<20} {'N/A':<30} {'N/A':<30}")
            continue
        if 'ECE_mean' in s and 'ECE_std' in s:
            e_me, e_st = s['ECE_mean'], s['ECE_std']
        elif 'ECE' in s and isinstance(s['ECE'], dict):
            e_me, e_st = s['ECE']['mean'], s['ECE']['std']
        else:
            e_vals = np.array(s.get('ECE_vals', s.get('ECE', [])))
            e_me, e_st = (float(np.mean(e_vals)), float(np.std(e_vals))) if len(e_vals)>0 else (None, None)

        if 'BS_mean' in s and 'BS_std' in s:
            b_me, b_st = s['BS_mean'], s['BS_std']
        elif 'BS' in s and isinstance(s['BS'], dict):
            b_me, b_st = s['BS']['mean'], s['BS']['std']
        else:
            b_vals = np.array(s.get('BS_vals', s.get('BS', [])))
            b_me, b_st = (float(np.mean(b_vals)), float(np.std(b_vals))) if len(b_vals)>0 else (None, None)

        if None in (e_me, e_st, b_me, b_st):
            print(f"{display_name:<20} {'N/A':<30} {'N/A':<30}")
        else:
            print(f"{display_name:<20} {e_me:7.4f} ± {e_st:<7.4f}    {b_me:7.4f} ± {b_st:<7.4f}")

No results available in `all_stats`. Run the experiment cell first to populate results.
