# Machine Unlearning via Information Regularization: MNIST

In [1]:
import random
from copy import deepcopy
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset


## Load MNIST

In [2]:
# Pick the fastest available backend: Apple MPS first (the notebook's original
# target), then CUDA, and finally CPU so the notebook still runs anywhere.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load MNIST dataset: convert each image to a tensor, then standardise it with
# the (mean, std) pair passed to Normalize.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# download=True fetches the files into ./data when they are not already there.
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Only the training split is shuffled; the test split keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

### Create the data set to unlearn and the remaining data set

In [3]:
# Forget set: a random UNLEARN_FRACTION of the training images whose label is
# UNLEARN_LABEL. Everything else in the training set is the "remaining" data.
UNLEARN_LABEL = 3
UNLEARN_FRACTION = 0.75

# Read labels from the dataset's target tensor instead of iterating the
# dataset, which would decode and transform every image just to see its label.
unlearn_label_indices = (train_dataset.targets == UNLEARN_LABEL).nonzero(as_tuple=True)[0].tolist()
label_1_indices = unlearn_label_indices  # legacy (misleading) name kept for any cell that still uses it

selected_unlearn_indices = random.sample(
    unlearn_label_indices, int(len(unlearn_label_indices) * UNLEARN_FRACTION)
)

# Membership tests against a set are O(1); testing against the list made this
# comprehension quadratic in the dataset size.
unlearn_index_set = set(selected_unlearn_indices)
remaining_indices = [i for i in range(len(train_dataset)) if i not in unlearn_index_set]

# Create datasets for unlearning and remaining data
unlearn_dataset = Subset(train_dataset, selected_unlearn_indices)
remaining_dataset = Subset(train_dataset, remaining_indices)

# Create data loaders (the forget loader uses a smaller batch size than the remain loader)
train_unlearn_loader = DataLoader(unlearn_dataset, batch_size=16, shuffle=True)
train_remain_loader = DataLoader(remaining_dataset, batch_size=64, shuffle=True)

## Original: Non-regularized

In [4]:
# Define classifier
class MNISTClassifier(nn.Module):
    """Fully connected classifier: 784 -> 128 -> 64 -> 10.

    ReLU follows the first two linear layers; the last layer's raw scores are
    returned without any normalisation.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Flatten each sample of ``x`` and return its 10 class scores."""
        batch_size = x.shape[0]
        hidden = x.view(batch_size, -1)  # one row per sample
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return self.fc3(hidden)

# Training setup
def train_model(model, train_loader, criterion, optimizer, epochs=10):
    """Train ``model`` in place and print the mean batch loss after every epoch.

    Batches are moved to the device the model's parameters live on, so the
    function does not depend on a notebook-level ``device`` variable that a
    later cell may have rebound.

    Args:
        model: ``nn.Module`` to optimise; left in training mode.
        train_loader: iterable yielding ``(images, labels)`` tensor pairs.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar loss.
        optimizer: optimizer already bound to ``model``'s parameters.
        epochs: number of full passes over ``train_loader``.

    Raises:
        ValueError: if ``train_loader`` yields no batches during an epoch.
    """
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        n_batches = 0
        for images, labels in train_loader:
            images, labels = images.to(target_device), labels.to(target_device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        if n_batches == 0:
            # Fail with a clear message instead of a bare ZeroDivisionError below.
            raise ValueError("train_loader yielded no batches; nothing to train on")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/n_batches:.4f}")

# Evaluation setup
def evaluate_model(model, test_loader):
    """Print the top-1 accuracy of ``model`` over ``test_loader``.

    The model is switched to eval mode and gradients are disabled. Batches are
    moved to the device the model's parameters live on, so the function does
    not depend on a notebook-level ``device`` variable.

    Args:
        model: ``nn.Module`` producing one score per class for each sample.
        test_loader: iterable yielding ``(images, labels)`` tensor pairs.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(target_device), labels.to(target_device)
            outputs = model(images)
            # Index of the highest score along the class dimension.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError below.
        raise ValueError("test_loader yielded no samples; accuracy is undefined")
    print(f'Accuracy: {100 * correct / total:.2f}%')


In [5]:
# Baseline: fit a fresh classifier on the full training set (forget set included).
criterion = nn.CrossEntropyLoss()
model_original = MNISTClassifier()
model_original.to(device)  # nn.Module.to moves the module in place
optimizer = optim.Adam(params=model_original.parameters(), lr=1e-3)

train_model(
    model=model_original,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=100,
)

Epoch 1/100, Loss: 0.2607
Epoch 2/100, Loss: 0.1091
Epoch 3/100, Loss: 0.0775
Epoch 4/100, Loss: 0.0612
Epoch 5/100, Loss: 0.0500
Epoch 6/100, Loss: 0.0416
Epoch 7/100, Loss: 0.0334
Epoch 8/100, Loss: 0.0309
Epoch 9/100, Loss: 0.0271
Epoch 10/100, Loss: 0.0233
Epoch 11/100, Loss: 0.0217
Epoch 12/100, Loss: 0.0174
Epoch 13/100, Loss: 0.0180
Epoch 14/100, Loss: 0.0191
Epoch 15/100, Loss: 0.0160
Epoch 16/100, Loss: 0.0136
Epoch 17/100, Loss: 0.0149
Epoch 18/100, Loss: 0.0123
Epoch 19/100, Loss: 0.0146
Epoch 20/100, Loss: 0.0115
Epoch 21/100, Loss: 0.0143
Epoch 22/100, Loss: 0.0103
Epoch 23/100, Loss: 0.0114
Epoch 24/100, Loss: 0.0122
Epoch 25/100, Loss: 0.0094
Epoch 26/100, Loss: 0.0093
Epoch 27/100, Loss: 0.0108
Epoch 28/100, Loss: 0.0112
Epoch 29/100, Loss: 0.0079
Epoch 30/100, Loss: 0.0097
Epoch 31/100, Loss: 0.0102
Epoch 32/100, Loss: 0.0079
Epoch 33/100, Loss: 0.0095
Epoch 34/100, Loss: 0.0124
Epoch 35/100, Loss: 0.0079
Epoch 36/100, Loss: 0.0077
Epoch 37/100, Loss: 0.0094
Epoch 38/1

In [49]:
# Accuracy of the baseline on, in order: retained data, forget data, test data.
for loader in (train_remain_loader, train_unlearn_loader, test_loader):
    evaluate_model(model_original, loader)

Accuracy: 99.89%
Accuracy: 99.98%
Accuracy: 97.81%


In [6]:
# Save trained model
# Use a folder relative to the notebook's working directory rather than an
# absolute path under one user's home, so the cell runs on any machine.
save_path = Path("Resnet_save") / "model_original.pth"
save_path.parent.mkdir(parents=True, exist_ok=True)  # torch.save does not create missing folders
torch.save(model_original.state_dict(), save_path)

## Re-Training from Scratch

In [None]:
# Retraining reference: fit a fresh classifier on the retained data only.
criterion = nn.CrossEntropyLoss()
model_retrain = MNISTClassifier()
model_retrain.to(device)  # nn.Module.to moves the module in place
optimizer = optim.Adam(params=model_retrain.parameters(), lr=1e-3)

train_model(
    model=model_retrain,
    train_loader=train_remain_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=100,
)


Epoch 1/100, Loss: 0.2671
Epoch 2/100, Loss: 0.1081
Epoch 3/100, Loss: 0.0767
Epoch 4/100, Loss: 0.0589
Epoch 5/100, Loss: 0.0454
Epoch 6/100, Loss: 0.0384
Epoch 7/100, Loss: 0.0306
Epoch 8/100, Loss: 0.0265
Epoch 9/100, Loss: 0.0240
Epoch 10/100, Loss: 0.0216
Epoch 11/100, Loss: 0.0183
Epoch 12/100, Loss: 0.0188
Epoch 13/100, Loss: 0.0158
Epoch 14/100, Loss: 0.0162
Epoch 15/100, Loss: 0.0145
Epoch 16/100, Loss: 0.0136
Epoch 17/100, Loss: 0.0133
Epoch 18/100, Loss: 0.0128
Epoch 19/100, Loss: 0.0128
Epoch 20/100, Loss: 0.0128
Epoch 21/100, Loss: 0.0118
Epoch 22/100, Loss: 0.0081
Epoch 23/100, Loss: 0.0125
Epoch 24/100, Loss: 0.0071
Epoch 25/100, Loss: 0.0092
Epoch 26/100, Loss: 0.0148
Epoch 27/100, Loss: 0.0070
Epoch 28/100, Loss: 0.0094
Epoch 29/100, Loss: 0.0094
Epoch 30/100, Loss: 0.0089
Epoch 31/100, Loss: 0.0071
Epoch 32/100, Loss: 0.0111
Epoch 33/100, Loss: 0.0064
Epoch 34/100, Loss: 0.0116
Epoch 35/100, Loss: 0.0048
Epoch 36/100, Loss: 0.0067
Epoch 37/100, Loss: 0.0110
Epoch 38/1

In [35]:
# Accuracy of the retrained model on, in order: retained data, forget data, test data.
for loader in (train_remain_loader, train_unlearn_loader, test_loader):
    evaluate_model(model_retrain, loader)

Accuracy: 99.94%
Accuracy: 94.48%
Accuracy: 97.62%


## Re-Training from Learned Outcome

In [None]:
from copy import deepcopy

# Warm start: clone the trained baseline (leaving model_original untouched),
# then continue training the clone on the retained data only.
model_retrain_continue = deepcopy(model_original)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model_retrain_continue.parameters(), lr=1e-3)

train_model(
    model=model_retrain_continue,
    train_loader=train_remain_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=100,
)


Epoch 1/100, Loss: 0.0052
Epoch 2/100, Loss: 0.0083
Epoch 3/100, Loss: 0.0049
Epoch 4/100, Loss: 0.0104
Epoch 5/100, Loss: 0.0047
Epoch 6/100, Loss: 0.0068
Epoch 7/100, Loss: 0.0066
Epoch 8/100, Loss: 0.0032
Epoch 9/100, Loss: 0.0063
Epoch 10/100, Loss: 0.0095
Epoch 11/100, Loss: 0.0037
Epoch 12/100, Loss: 0.0056
Epoch 13/100, Loss: 0.0080
Epoch 14/100, Loss: 0.0042
Epoch 15/100, Loss: 0.0061
Epoch 16/100, Loss: 0.0073
Epoch 17/100, Loss: 0.0046
Epoch 18/100, Loss: 0.0084
Epoch 19/100, Loss: 0.0041
Epoch 20/100, Loss: 0.0045
Epoch 21/100, Loss: 0.0071
Epoch 22/100, Loss: 0.0050
Epoch 23/100, Loss: 0.0032
Epoch 24/100, Loss: 0.0085
Epoch 25/100, Loss: 0.0054
Epoch 26/100, Loss: 0.0064
Epoch 27/100, Loss: 0.0079
Epoch 28/100, Loss: 0.0066
Epoch 29/100, Loss: 0.0060
Epoch 30/100, Loss: 0.0028
Epoch 31/100, Loss: 0.0054
Epoch 32/100, Loss: 0.0068
Epoch 33/100, Loss: 0.0059
Epoch 34/100, Loss: 0.0076
Epoch 35/100, Loss: 0.0028
Epoch 36/100, Loss: 0.0041
Epoch 37/100, Loss: 0.0081
Epoch 38/1

In [34]:
# Accuracy of the warm-started model on, in order: retained data, forget data, test data.
for loader in (train_remain_loader, train_unlearn_loader, test_loader):
    evaluate_model(model_retrain_continue, loader)

Accuracy: 99.96%
Accuracy: 97.24%
Accuracy: 97.97%


## Unlearning, gamma = 4

In [10]:
import torch
import torch.nn as nn
import torch.optim as optim

# Detect the backend instead of hard-coding "mps": a fixed string makes every
# later .to(device) call fail on machines without Apple's Metal backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def train_unlearn_model(model, train_unlearn_loader, train_remain_loader, criterion, optimizer, gamma=1.0, epochs=10):
    """Fine-tune ``model`` with cross-entropy plus a mutual-information penalty.

    Every step pairs one batch from ``train_remain_loader`` with one batch from
    ``train_unlearn_loader``. The loaders are zipped, so an epoch stops as soon
    as the shorter one is exhausted. The step loss is

        CE(model(remain), remain labels) + gamma * MI(model(X); Z)

    where ``X``/``Z`` come from ``create_train_labels``: Z=0 marks the
    remain+unlearn images, Z=1 marks the remain images alone.

    Batches are moved to the device the model's parameters live on, so the
    function does not depend on a notebook-level ``device`` variable.

    Args:
        model: ``nn.Module`` to fine-tune in place; left in training mode.
        train_unlearn_loader: iterable of ``(images, labels)`` to be forgotten
            (only the images are used).
        train_remain_loader: iterable of ``(images, labels)`` to be retained.
        criterion: classification loss ``criterion(outputs, labels)``.
        optimizer: optimizer already bound to ``model``'s parameters.
        gamma: weight of the mutual-information term.
        epochs: number of passes over the zipped loaders.

    Returns:
        ``(ce_history, mi_history, total_loss_history)``: three lists holding
        the per-epoch mean of each quantity.

    Raises:
        ValueError: if an epoch yields no batch pairs.
    """
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()

    mi_history = []
    ce_history = []
    total_loss_history = []

    for epoch in range(epochs):
        epoch_total_loss = 0.0
        epoch_ce = 0.0
        epoch_mi = 0.0
        n_batches = 0

        # Zip remain/unlearn loaders; labels of the unlearn batch are not needed.
        for (images_remain, digits_remain), (images_unlearn, _) in zip(train_remain_loader, train_unlearn_loader):
            # Move to device
            images_remain = images_remain.to(target_device)
            digits_remain = digits_remain.to(target_device)
            images_unlearn = images_unlearn.to(target_device)
            optimizer.zero_grad()

            # -----------------
            # 1) Classification loss on "remain" portion
            # -----------------
            outputs_remain = model(images_remain)  # [B_remain, n_classes]
            ce_loss = criterion(outputs_remain, digits_remain)

            # -----------------
            # 2) Build combined set X_train, Z_train for MI
            #    Z=0: remain + unlearn images, Z=1: remain images only
            # -----------------
            X_0, Z_0, X_1, Z_1 = create_train_labels(images_remain, images_unlearn)
            X_train = torch.cat((X_0, X_1), dim=0)   # [2*B_remain + B_unlearn, ...]
            Z_train = torch.cat((Z_0, Z_1), dim=0)   # [2*B_remain + B_unlearn]

            # Shuffle them together (same permutation for images and labels)
            idx = torch.randperm(len(X_train), device=X_train.device)
            X_train_shuffled = X_train[idx]
            Z_train_shuffled = Z_train[idx]

            # Forward pass over the combined set
            Y_train_shuffled = model(X_train_shuffled)

            # 3) Compute multi-class MI( Y; Z )
            mi_val = compute_mutual_information_multiclass(Y_train_shuffled, Z_train_shuffled)

            total_loss = ce_loss + gamma*mi_val
            total_loss.backward()
            optimizer.step()

            epoch_ce += ce_loss.item()
            epoch_mi += mi_val.item()
            epoch_total_loss += total_loss.item()
            n_batches += 1

        if n_batches == 0:
            # Fail with a clear message instead of a bare ZeroDivisionError below.
            raise ValueError("the zipped remain/unlearn loaders yielded no batch pairs")

        # End of epoch stats
        avg_ce = epoch_ce / n_batches
        avg_mi = epoch_mi / n_batches
        avg_loss = epoch_total_loss / n_batches

        ce_history.append(avg_ce)
        mi_history.append(avg_mi)
        total_loss_history.append(avg_loss)

        print(f"Epoch {epoch+1}/{epochs} | "
              f"CE: {avg_ce:.4f}, MI: {avg_mi:.4f}, Total Loss: {avg_loss:.4f}")

    return ce_history, mi_history, total_loss_history

# Define binary labels Z for (X_0, X_1)
def create_train_labels(remain_images, unlearn_images):
    """Build the two groups of images, and their binary tags, used for the MI term.

    Group 0 is the remain batch followed by the unlearn batch; group 1 is the
    remain batch on its own (the very same tensor, not a copy).

    Returns:
        ``(X_0, Z_0, X_1, Z_1)`` where ``Z_0`` is all zeros with one entry per
        row of ``X_0`` and ``Z_1`` is all ones with one entry per row of
        ``X_1``; each tag vector is created on its images' device.
    """
    everything = torch.cat((remain_images, unlearn_images), dim=0)
    tags_everything = torch.zeros(everything.shape[0], device=everything.device)
    tags_remain = torch.ones(remain_images.shape[0], device=remain_images.device)
    return everything, tags_everything, remain_images, tags_remain

def compute_mutual_information_multiclass(predicted_outputs, labels):
    """Estimate I(Y_hat; Z) in nats from soft class predictions and a binary tag.

    The joint p(y=c, z) is the batch average of the softmax probability of
    class ``c`` restricted to samples with that ``z``; the marginal p(y=c) is
    the sum of the two joints and p(z=1) is the mean of ``labels``.

    Args:
        predicted_outputs: 2-D tensor of logits, one row per sample.
        labels: 1-D tensor of 0/1 tags, one per row of ``predicted_outputs``.

    Returns:
        Scalar tensor holding the mutual-information estimate.
    """
    eps = 1e-10

    # Softmax over classes, kept away from exact 0 and 1.
    class_probs = torch.softmax(predicted_outputs, dim=1).clamp(eps, 1.0 - eps)
    n_samples, _ = class_probs.shape

    # Marginal of the binary tag.
    prior_z1 = labels.float().mean()
    prior_z0 = 1.0 - prior_z1

    # Joint p(y=c, z) for each tag value: mask the rows, then average over the batch.
    tagged_one = labels.unsqueeze(1)
    tagged_zero = (1 - labels).unsqueeze(1)
    joint_z1 = (class_probs * tagged_one).sum(dim=0) / n_samples
    joint_z0 = (class_probs * tagged_zero).sum(dim=0) / n_samples

    # Marginal p(y=c).
    marginal_y = joint_z1 + joint_z0

    def _contribution(joint, prior):
        # Per-class p(y,z) * ln( p(y,z) / (p(y) p(z)) ), with eps guarding 0/0 and ln(0).
        return joint * torch.log(joint / (marginal_y * prior + eps) + eps)

    return _contribution(joint_z1, prior_z1).sum() + _contribution(joint_z0, prior_z0).sum()

In [47]:
# Unlearning run: start from a copy of the baseline weights and fine-tune with
# the MI-regularised objective, then report accuracy on each split.
model_unlearn = deepcopy(model_original)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model_unlearn.parameters(), lr=1e-3)

train_unlearn_model(
    model_unlearn,
    train_unlearn_loader,
    train_remain_loader,
    criterion,
    optimizer,
    gamma=4,
    epochs=5,
)

# Order: retained data, forget data, test data.
for loader in (train_remain_loader, train_unlearn_loader, test_loader):
    evaluate_model(model_unlearn, loader)

Epoch 1/5 | CE: 0.0249, MI: 0.0470, Total Loss: 0.2131
Epoch 2/5 | CE: 0.0103, MI: 0.0476, Total Loss: 0.2008
Epoch 3/5 | CE: 0.0136, MI: 0.0474, Total Loss: 0.2030
Epoch 4/5 | CE: 0.0105, MI: 0.0471, Total Loss: 0.1989
Epoch 5/5 | CE: 0.0240, MI: 0.0448, Total Loss: 0.2033
Accuracy: 99.80%
Accuracy: 94.58%
Accuracy: 97.42%
