In [4]:
# Full OPM Implementation with Confidence Checks Before/After Unlearning

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch_geometric.nn import GCNConv
import numpy as np
import copy
import random
import time

# Seed for reproducibility
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

class Config:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = 128
        self.test_batch_size = 1000
        self.epochs = 5
        self.lr = 0.01
        self.momentum = 0.9
        self.mask_lr = 0.001
        self.mask_epochs = 50
        self.kl_coeff = 1.0
        self.temperature = 0.5
        self.top_k_ratio = 0.2
        self.num_forget = 10
        self.prior = 0.5
        self.num_params = 0  # filled later
        self.edge_index = None
        self.overlap_threshold = 0.9  # cosine similarity threshold
        self.max_attempts = 50

config = Config()

# Dataset and DataLoader
print("Starting imports and dataset preparation...")
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=config.test_batch_size, shuffle=False)
print("Datasets loaded.")

# Model
class SimpleClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# GCN for mask generation
class GCNMaskNet(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, 32)
        self.conv2 = GCNConv(32, 1)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = torch.sigmoid(self.conv2(x, edge_index))
        return x

# Helper functions
def copy_model(model):
    return copy.deepcopy(model)

def cosine_similarity(x, y):
    return F.cosine_similarity(x.view(1, -1), y.view(1, -1))

def gumbel_softmax_sample(logits, temperature):
    noise = torch.rand_like(logits)
    gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20)
    return torch.sigmoid((logits + gumbel) / temperature)

def kl_divergence(p, prior):
    p = torch.clamp(p, 1e-6, 1 - 1e-6)
    prior = torch.full_like(p, prior)
    return (p * (p / prior).log() + (1 - p) * ((1 - p) / (1 - prior)).log()).mean()

def apply_mask_to_model(model, mask, config):
    offset = 0
    for param in model.parameters():
        if param.requires_grad:
            numel = param.numel()
            param.data *= mask[offset:offset+numel].view(param.shape)
            offset += numel
    print("✅ Mask applied to model weights.")

def train_classifier(model, loader, config):
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum)
    for epoch in range(config.epochs):
        total_loss = 0
        for data, target in loader:
            data, target = data.to(config.device), target.to(config.device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"[Classifier Epoch {epoch+1}] Loss: {total_loss:.4f}")

def evaluate_model(model, loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(config.device), target.to(config.device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    acc = correct / len(loader.dataset)
    print(f"🎯 Evaluation Accuracy: {acc:.4f}")
    return acc

def is_unlearning_feasible_margin(model, x, y, threshold=0.0):
    output = model(x.unsqueeze(0))
    probs = F.softmax(output, dim=1)
    top2 = torch.topk(probs, 2).values.squeeze()
    margin = top2[0] - top2[1]
    print(f"→ Margin: {margin.item():.4f} | Feasible: {'Yes' if margin > threshold else 'No'}")
    return margin > threshold

def check_mask_overlap(new_mask, previous_masks, threshold=0.9):
    if not previous_masks:
        return True
    for i, prev in enumerate(previous_masks):
        sim = cosine_similarity(new_mask, prev).item()
        print(f"→ Cosine similarity with previous mask {i}: {sim:.4f}")
        if sim > threshold:
            print("⛔ Overlap too high. Rejecting mask.\n")
            return False
    print("✅ Mask is unique (overlap below threshold).\n")
    return True

def build_graph_from_model(model):
    sizes = [param.numel() for param in model.parameters()]
    total_params = sum(sizes)
    config.num_params = total_params
    x = torch.ones((total_params, 1)).to(config.device)
    edge_index = []
    idx = 0
    for size in sizes:
        for i in range(size - 1):
            edge_index.append([idx + i, idx + i + 1])
            edge_index.append([idx + i + 1, idx + i])
        idx += size
    edge_index = torch.tensor(edge_index).t().contiguous().to(config.device)
    config.edge_index = edge_index
    print(f"Graph built with {total_params} nodes and {edge_index.shape[1]} edges.")

def train_mask_generator(mask_net, classifier, forget_data, config):
    mask_net.train()
    optimizer = torch.optim.Adam(mask_net.parameters(), lr=config.mask_lr)

    x_forget, y_forget = forget_data
    x_forget = x_forget.to(config.device).unsqueeze(0)
    y_forget = y_forget.to(config.device)

    x_init = torch.ones(config.num_params, 1).to(config.device)
    edge_index = config.edge_index.to(config.device)

    print(f"\n🔧 Starting GCN training to forget data point (label: {y_forget.item()})")
    start_time = time.time()

    for epoch in range(config.mask_epochs):
        optimizer.zero_grad()

        mask_logits = mask_net(x_init, edge_index).squeeze()
        mask_probs = torch.sigmoid(mask_logits)
        mask_sampled = gumbel_softmax_sample(mask_logits, temperature=config.temperature)

        temp_model = copy_model(classifier).to(config.device)
        apply_mask_to_model(temp_model, mask_sampled, config)

        output = temp_model(x_forget)
        ce_loss = F.cross_entropy(output, y_forget.unsqueeze(0))
        kl = kl_divergence(mask_probs, config.prior)

        loss = ce_loss + config.kl_coeff * kl
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"[Epoch {epoch+1:3d}/{config.mask_epochs}] Loss: {loss.item():.4f} | CE: {ce_loss.item():.4f} | KL: {kl.item():.4f}")

    end_time = time.time()
    print(f"✅ GCN training complete in {end_time - start_time:.2f}s\n")
    return mask_logits.detach()

def main():
    print("Starting training process...")
    classifier = SimpleClassifier().to(config.device)
    train_classifier(classifier, train_loader, config)

    acc_before = evaluate_model(classifier, test_loader)
    print(f"\nBefore Unlearning:\nAccuracy: {acc_before:.4f}")

    # Build graph once after training
    build_graph_from_model(classifier)

    accepted_masks = []
    forget_x, forget_y, forget_idx, forget_mask = None, None, None, None

    for attempt in range(1, config.max_attempts + 1):
        idx = random.randint(0, len(train_dataset) - 1)
        x_i, y_i = train_dataset[idx]
        x_i = x_i.to(config.device)
        y_i = torch.tensor(y_i).to(config.device)
        print(f"\nAttempt {attempt}: Trying forget point idx {idx} with label {y_i.item()}")

        if not is_unlearning_feasible_margin(classifier, x_i, y_i):
            print("❌ Not feasible by margin, trying another point...")
            continue

        print("✅ Forget point is feasible for unlearning.")

        mask_net = GCNMaskNet(1).to(config.device)
        mask_logits = train_mask_generator(mask_net, classifier, (x_i, y_i), config)

        if not check_mask_overlap(mask_logits, accepted_masks, config.overlap_threshold):
            print("❌ Mask overlap too high, trying next point...")
            continue

        print(f"✅ Forget point idx {idx} accepted for unlearning.")
        accepted_masks.append(mask_logits)
        forget_x, forget_y, forget_idx, forget_mask = x_i, y_i, idx, mask_logits
        break
    else:
        print("❌ No feasible forget point found within max attempts passing overlap checks.")
        return

    print("\nApplying mask to classifier weights for unlearning...")
    # Confidence before applying mask
    output_before = classifier(forget_x.unsqueeze(0))
    conf_before = F.softmax(output_before, dim=1)[0, forget_y].item()
    print(f"Confidence before unlearning on forget point: {conf_before:.4f}")

    apply_mask_to_model(classifier, forget_mask, config)

    # Confidence after applying mask
    output_after = classifier(forget_x.unsqueeze(0))
    conf_after = F.softmax(output_after, dim=1)[0, forget_y].item()
    print(f"Confidence after unlearning on forget point: {conf_after:.4f}")

    acc_after = evaluate_model(classifier, test_loader)
    print(f"\nAfter Unlearning:\nAccuracy: {acc_after:.4f}")

if __name__ == "__main__":
    main()


Starting imports and dataset preparation...
Datasets loaded.
Starting training process...
[Classifier Epoch 1] Loss: 204.2144
[Classifier Epoch 2] Loss: 72.7412
[Classifier Epoch 3] Loss: 49.0540
[Classifier Epoch 4] Loss: 36.0197
[Classifier Epoch 5] Loss: 27.8276
🎯 Evaluation Accuracy: 0.9776

Before Unlearning:
Accuracy: 0.9776
Graph built with 235146 nodes and 470280 edges.

Attempt 1: Trying forget point idx 80 with label 9
→ Margin: 0.0797 | Feasible: Yes
✅ Forget point is feasible for unlearning.

🔧 Starting GCN training to forget data point (label: 9)
✅ Mask applied to model weights.
[Epoch   1/50] Loss: 3.0776 | CE: 3.0482 | KL: 0.0294
✅ Mask applied to model weights.
✅ Mask applied to model weights.
✅ Mask applied to model weights.
✅ Mask applied to model weights.
✅ Mask applied to model weights.
✅ Mask applied to model weights.
✅ Mask applied to model weights.
✅ Mask applied to model weights.
✅ Mask applied to model weights.
[Epoch  10/50] Loss: 2.5271 | CE: 2.4999 | KL: 0.0

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch_geometric.nn import GCNConv
import numpy as np
import copy
import random
import time

# Seed for reproducibility
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

class Config:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = 128
        self.test_batch_size = 1000
        self.epochs = 5
        self.lr = 0.01
        self.momentum = 0.9
        self.mask_lr = 0.001
        self.mask_epochs = 50
        self.kl_coeff = 1.0
        self.temperature = 0.5
        self.top_k_ratio = 0.2
        self.num_forget = 10
        self.prior = 0.5
        self.num_params = 0  # filled later
        self.edge_index = None
        self.overlap_threshold = 0.9
        self.max_attempts = 50

config = Config()

# Dataset and DataLoader
print("Starting imports and dataset preparation...")
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=config.test_batch_size, shuffle=False)
print("Datasets loaded.")

# Model definition
class SimpleClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# GCN for mask generation
class GCNMaskNet(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, 32)
        self.conv2 = GCNConv(32, 1)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = torch.sigmoid(self.conv2(x, edge_index))
        return x

# Helper functions
def copy_model(model):
    return copy.deepcopy(model)

def cosine_similarity(x, y):
    return F.cosine_similarity(x.view(1, -1), y.view(1, -1))

def gumbel_softmax_sample(logits, temperature):
    noise = torch.rand_like(logits)
    gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20)
    return torch.sigmoid((logits + gumbel) / temperature)

def kl_divergence(p, prior):
    p = torch.clamp(p, 1e-6, 1 - 1e-6)
    prior = torch.full_like(p, prior)
    return (p * (p / prior).log() + (1 - p) * ((1 - p) / (1 - prior)).log()).mean()

def apply_mask_to_model(model, mask, config):
    offset = 0
    for param in model.parameters():
        if param.requires_grad:
            numel = param.numel()
            param.data *= mask[offset:offset+numel].view(param.shape)
            offset += numel
    print("✅ Mask applied to model weights.")

def train_classifier(model, loader, config):
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum)
    for epoch in range(config.epochs):
        total_loss = 0
        for data, target in loader:
            data, target = data.to(config.device), target.to(config.device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"[Classifier Epoch {epoch+1}] Loss: {total_loss:.4f}")

def evaluate_model(model, loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(config.device), target.to(config.device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    acc = correct / len(loader.dataset)
    print(f"🎯 Evaluation Accuracy: {acc:.4f}")
    return acc

def is_unlearning_feasible_margin(model, x, y, threshold=0.0):
    output = model(x.unsqueeze(0))
    probs = F.softmax(output, dim=1)
    top2 = torch.topk(probs, 2).values.squeeze()
    margin = top2[0] - top2[1]
    print(f"→ Margin: {margin.item():.4f} | Feasible: {'Yes' if margin > threshold else 'No'}")
    return margin > threshold

def check_mask_overlap(new_mask, previous_masks, threshold=0.9):
    if not previous_masks:
        return True
    for i, prev in enumerate(previous_masks):
        sim = cosine_similarity(new_mask, prev).item()
        print(f"→ Cosine similarity with previous mask {i}: {sim:.4f}")
        if sim > threshold:
            print("⛔ Overlap too high. Rejecting mask.\n")
            return False
    print("✅ Mask is unique (overlap below threshold).\n")
    return True

def build_graph_from_model(model):
    sizes = [param.numel() for param in model.parameters()]
    total_params = sum(sizes)
    config.num_params = total_params
    x = torch.ones((total_params, 1)).to(config.device)
    edge_index = []
    idx = 0
    for size in sizes:
        for i in range(size - 1):
            edge_index.append([idx + i, idx + i + 1])
            edge_index.append([idx + i + 1, idx + i])
        idx += size
    edge_index = torch.tensor(edge_index).t().contiguous().to(config.device)
    config.edge_index = edge_index
    print(f"Graph built with {total_params} nodes and {edge_index.shape[1]} edges.")

def train_mask_generator(mask_net, classifier, forget_data, config):
    mask_net.train()
    optimizer = torch.optim.Adam(mask_net.parameters(), lr=config.mask_lr)

    x_forget, y_forget = forget_data
    x_forget = x_forget.to(config.device).unsqueeze(0)
    y_forget = y_forget.to(config.device)

    x_init = torch.ones(config.num_params, 1).to(config.device)
    edge_index = config.edge_index.to(config.device)

    print(f"\n🔧 Starting GCN training to forget data point (label: {y_forget.item()})")
    start_time = time.time()

    for epoch in range(config.mask_epochs):
        optimizer.zero_grad()

        mask_logits = mask_net(x_init, edge_index).squeeze()
        mask_probs = torch.sigmoid(mask_logits)
        mask_sampled = gumbel_softmax_sample(mask_logits, temperature=config.temperature)

        temp_model = copy_model(classifier).to(config.device)
        apply_mask_to_model(temp_model, mask_sampled, config)

        output = temp_model(x_forget)
        ce_loss = F.cross_entropy(output, y_forget.unsqueeze(0))
        kl = kl_divergence(mask_probs, config.prior)

        loss = ce_loss + config.kl_coeff * kl
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"[Epoch {epoch+1:3d}/{config.mask_epochs}] Loss: {loss.item():.4f} | CE: {ce_loss.item():.4f} | KL: {kl.item():.4f}")

    end_time = time.time()
    print(f"✅ GCN training complete in {end_time - start_time:.2f}s\n")
    return mask_logits.detach()

def main():
    print("Starting training process...")
    classifier = SimpleClassifier().to(config.device)
    train_classifier(classifier, train_loader, config)

    acc_before = evaluate_model(classifier, test_loader)
    print(f"\nBefore Unlearning:\nAccuracy: {acc_before:.4f}")

    build_graph_from_model(classifier)

    accepted_masks = []
    forget_x, forget_y, forget_idx, forget_mask = None, None, None, None

    for attempt in range(1, config.max_attempts + 1):
        idx = 80  # random.randint(0, len(train_dataset) - 1) for random point
        x_i, y_i = train_dataset[idx]
        x_i = x_i.to(config.device)
        y_i = torch.tensor(y_i).to(config.device)
        print(f"\nAttempt {attempt}: Trying forget point idx {idx} with label {y_i.item()}")

        if not is_unlearning_feasible_margin(classifier, x_i, y_i):
            print("❌ Not feasible by margin, trying another point...")
            continue

        print("✅ Forget point is feasible for unlearning.")

        mask_net = GCNMaskNet(1).to(config.device)
        mask_logits = train_mask_generator(mask_net, classifier, (x_i, y_i), config)

        if not check_mask_overlap(mask_logits, accepted_masks, config.overlap_threshold):
            print("❌ Mask overlap too high, trying next point...")
            continue

        print(f"✅ Forget point idx {idx} accepted for unlearning.")
        accepted_masks.append(mask_logits)
        forget_x, forget_y, forget_idx, forget_mask = x_i, y_i, idx, mask_logits
        break
    else:
        print("❌ No feasible forget point found within max attempts passing overlap checks.")
        return

    print("\nApplying mask to classifier weights for unlearning...")

    # Confidence BEFORE unlearning (per class)
    output_before = classifier(forget_x.unsqueeze(0))
    probs_before = F.softmax(output_before, dim=1).squeeze().tolist()
    print(f"\n📊 Confidence BEFORE unlearning on forget point (label {forget_y.item()}):")
    for cls, prob in enumerate(probs_before):
        print(f"  Class {cls}: {prob:.4f}")
    conf_before = probs_before[forget_y.item()]
    print(f"→ Target Class {forget_y.item()} Confidence: {conf_before:.4f}\n")

    apply_mask_to_model(classifier, forget_mask, config)

    # Confidence AFTER unlearning (per class)
    output_after = classifier(forget_x.unsqueeze(0))
    probs_after = F.softmax(output_after, dim=1).squeeze().tolist()
    print(f"\n📊 Confidence AFTER unlearning on forget point (label {forget_y.item()}):")
    for cls, prob in enumerate(probs_after):
        print(f"  Class {cls}: {prob:.4f}")
    conf_after = probs_after[forget_y.item()]
    print(f"→ Target Class {forget_y.item()} Confidence: {conf_after:.4f}\n")

    acc_after = evaluate_model(classifier, test_loader)
    print(f"\nAfter Unlearning:\nAccuracy: {acc_after:.4f}")

if __name__ == "__main__":
    main()


Starting imports and dataset preparation...
Datasets loaded.
Starting training process...
[Classifier Epoch 1] Loss: 204.2144
[Classifier Epoch 2] Loss: 72.7412
[Classifier Epoch 3] Loss: 49.0540
[Classifier Epoch 4] Loss: 36.0197
[Classifier Epoch 5] Loss: 27.8276
🎯 Evaluation Accuracy: 0.9776

Before Unlearning:
Accuracy: 0.9776
Graph built with 235146 nodes and 470280 edges.

Attempt 1: Trying forget point idx 80 with label 9
→ Margin: 0.0797 | Feasible: Yes
✅ Forget point is feasible for unlearning.

🔧 Starting GCN training to forget data point (label: 9)
✅ Mask applied to model weights.
[Epoch   1/50] Loss: 3.0776 | CE: 3.0482 | KL: 0.0294
✅ Mask applied to model weights.
✅ Mask applied to model weights.
✅ Mask applied to model weights.
✅ Mask applied to model weights.
✅ Mask applied to model weights.
✅ Mask applied to model weights.
✅ Mask applied to model weights.
✅ Mask applied to model weights.
✅ Mask applied to model weights.
[Epoch  10/50] Loss: 2.5271 | CE: 2.4999 | KL: 0.0