In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from torch_geometric.data import Data as GeoData
from torch_geometric.nn import GCNConv
import gc
from dataclasses import dataclass
from typing import List, Tuple, Dict

@dataclass
class MaskingConfig:
    epochs: int = 7
    gumbel_tau: float = 1.0      # Gumbel-Softmax temperature
    sparsity_target: int = 150
    max_k: int = 300
    k_increment: int = 25
    kl_weight: float = 0.001
    lagrange_weight: float = 1.0 # Lagrange term for sparsity
    learning_rate: float = 1e-3
    batch_limit: int = 10
    confidence_threshold: float = 0.01

FEATURE_DIM = 28 * 28
HIDDEN_DIM = 300
HIDDEN2_DIM = 100
OUTPUT_DIM = 10
TARGET_LAYERS = ['fc2', 'fc3']

def gumbel_sigmoid(logits, tau=1.0, hard=False, eps=1e-10):
    """ Differentiable sampling of Bernoulli (binary mask) with Gumbel-Softmax trick. """
    U = torch.rand_like(logits)
    gumbel = -torch.log(-torch.log(U + eps) + eps)
    y = torch.sigmoid((logits + gumbel) / tau)
    if hard:
        y_hard = (y > 0.5).float()
        y = (y_hard - y).detach() + y
    return y

class MNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(FEATURE_DIM, HIDDEN_DIM)
        self.fc2 = nn.Linear(HIDDEN_DIM, HIDDEN2_DIM)
        self.fc3 = nn.Linear(HIDDEN2_DIM, OUTPUT_DIM)

    def forward(self, x):
        x = x.view(-1, FEATURE_DIM)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

class GCNMaskGenerator(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 64)
        self.conv2 = GCNConv(64, 1) # Output: predictive logit for Bernoulli mask

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        logits = self.conv2(x, edge_index).squeeze()
        return logits # Logits for mask sampling

def get_target_parameters(model: nn.Module, target_layers: List[str] = None) -> List[Tuple[str, nn.Parameter]]:
    if target_layers is None:
        target_layers = TARGET_LAYERS
    return [
        (name, param) for name, param in model.named_parameters()
        if any(layer in name for layer in target_layers) and 'weight' in name
    ]

def construct_graph_last_layers(model: nn.Module, input_data: torch.Tensor, target: torch.Tensor) -> Tuple[GeoData, Dict[str,int]]:
    model.zero_grad()
    output = model(input_data)
    loss = F.cross_entropy(output, target)
    loss.backward()

    params, grads, activ_in, activ_out, names, idx_map = [], [], [], [], [], {}
    idx = 0

    for name, param in get_target_parameters(model):
        if param.grad is not None:
            shape = param.shape
            param_data = param.detach().flatten().cpu()
            grad_data = param.grad.detach().flatten().cpu()
            in_feat_val = shape[1] if len(shape) > 1 else 1
            out_feat_val = shape[0]
            in_feat = torch.full((param_data.shape[0], 1), float(in_feat_val))
            out_feat = torch.full((param_data.shape[0], 1), float(out_feat_val))
            for i in range(param_data.shape[0]):
                params.append(param_data[i].unsqueeze(0))
                grads.append(grad_data[i].unsqueeze(0))
                activ_in.append(in_feat[i])
                activ_out.append(out_feat[i])
                names.append(name)
                idx_map[f"{name}_{i}"] = idx
                idx += 1

    x = torch.cat([
        torch.stack(params),
        torch.stack(grads),
        torch.stack(activ_in),
        torch.stack(activ_out)
    ], dim=1)

    edge_index = []
    n = len(names)
    for i in range(n):
        for j in range(i+1, min(i+10, n)):
            if names[i].split(".")[0] == names[j].split(".")[0]:
                edge_index.extend([[i, j], [j, i]])
            else:
                prefix_i = names[i].split(".")[0]
                prefix_j = names[j].split(".")[0]
                if (prefix_i in ['fc2', 'fc3']) and (prefix_j in ['fc2', 'fc3']):
                    edge_index.extend([[i, j], [j, i]])
    if len(edge_index) == 0:
        for i in range(n-1):
            edge_index.extend([[i, i+1], [i+1, i]])
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    model.zero_grad()
    return GeoData(x=x, edge_index=edge_index), idx_map

def compute_kl_divergence(q_probs: torch.Tensor, prior_probs: torch.Tensor) -> torch.Tensor:
    eps = 1e-8
    q_probs = q_probs / (q_probs.sum() + eps)
    prior_probs = prior_probs / (prior_probs.sum() + eps)
    q_probs = torch.clamp(q_probs, eps, 1.0)
    prior_probs = torch.clamp(prior_probs, eps, 1.0)
    return (q_probs * (q_probs.log() - prior_probs.log())).sum()

def get_prior_distribution(model):
    weights = []
    for _, param in get_target_parameters(model):
        weights.append(param.detach().flatten().abs().cpu())
    full = torch.cat(weights)
    return full / (full.sum() + 1e-8)

def apply_mask_to_model(model, mask: torch.Tensor, idx_map: Dict[str, int]):
    with torch.no_grad():
        for name, param in get_target_parameters(model):
            param_flat = param.data.view(-1)
            mask_indices = [idx_map.get(f"{name}_{i}", None) for i in range(len(param_flat))]
            mask_tensor = torch.ones_like(param_flat)
            for i, idx in enumerate(mask_indices):
                if idx is not None and idx < len(mask):
                    mask_tensor[i] = 1 - mask[idx]
            param.data.copy_((param_flat * mask_tensor).view(param.shape))

def train_mask_generator(mask_net, classifier, dataset, device, prior_probs, config):
    mask_net.train()
    optimizer = optim.Adam(mask_net.parameters(), lr=config.learning_rate)
    temp_model = MNISTClassifier().to(device)
    lagrange_mult = nn.Parameter(torch.tensor([0.]), requires_grad=True)
    lagrange_optim = optim.Adam([lagrange_mult], lr=1e-2)

    for epoch in range(config.epochs):
        print(f"[MaskNet] Epoch {epoch+1}/{config.epochs}")
        total_loss = 0
        for i in range(min(config.batch_limit, len(dataset))):
            x, y = dataset[i]
            x = x.unsqueeze(0).to(device)
            y = torch.tensor([y], device=device)
            graph, idx_map = construct_graph_last_layers(classifier, x, y)
            graph = graph.to(device)
            logits = mask_net(graph)
            # Sample mask: stochastic, relaxed, with Gumbel noise
            sampled_mask = gumbel_sigmoid(logits, tau=config.gumbel_tau, hard=False)
            # Apply mask to temp copy of classifier
            with torch.no_grad():
                temp_model.load_state_dict(classifier.state_dict())
            apply_mask_to_model(temp_model, sampled_mask, idx_map)
            out = temp_model(x)
            # Red team loss: push prediction away from target (unlearning)
            loss_forget = F.cross_entropy(out, y)
            # Prior regularization (KL)
            kl = compute_kl_divergence(sampled_mask, prior_probs.to(device))
            # Sparsity control ("Lagrange multiplier" penalty)
            cardinality = sampled_mask.sum()
            sparsity_loss = (cardinality - config.sparsity_target) ** 2
            # Full loss
            loss = loss_forget + config.kl_weight * kl + config.lagrange_weight * sparsity_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"[MaskNet] Epoch loss: {total_loss/config.batch_limit:.4f}")

def evaluate(model, loader, title="Eval"):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(next(model.parameters()).device), target.to(next(model.parameters()).device)
            pred = model(data).argmax(1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    accuracy = correct / total if total > 0 else 0
    print(f"{title}: Accuracy = {accuracy:.4f}")
    return accuracy

def main():
    config = MaskingConfig()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    transform = transforms.ToTensor()
    dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    target_idx = 0
    target_x, target_y = dataset[target_idx]
    retain_indices = [i for i in range(len(dataset)) if i != target_idx]
    retain_subset = Subset(dataset, retain_indices[:5000])
    retain_loader = DataLoader(retain_subset, batch_size=128, shuffle=True)

    classifier = MNISTClassifier().to(device)
    optimizer = optim.Adam(classifier.parameters(), lr=1e-3)
    classifier.train()
    for epoch in range(7):
        total_loss = 0
        for data, target in retain_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = F.cross_entropy(classifier(data), target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}: Avg Loss = {total_loss/len(retain_loader):.4f}")

    print("\nEvaluation before unlearning:")
    evaluate(classifier, retain_loader, "Before Unlearning")

    with torch.no_grad():
        target_x_batch = target_x.unsqueeze(0).to(device)
        target_y_batch = torch.tensor([target_y], device=device)
        out = classifier(target_x_batch)
        orig_conf = F.softmax(out, dim=1)[0, target_y].item()
        pred = out.argmax(dim=1).item()
        print(f"[Before Unlearning] Target sample - Pred: {pred}, Confidence: {orig_conf:.4f}, True: {target_y}")

    prior_probs = get_prior_distribution(classifier)
    mask_net = GCNMaskGenerator(in_channels=4).to(device)

    print("\nTraining probabilistic mask generator (Optimal Probabilistic Masking)...")
    train_mask_generator(mask_net, classifier, dataset, device, prior_probs, config)

    # Final mask application: sample for this datapoint, apply progressively larger masks until confidence drops below
    mask_net.eval()
    data_graph, idx_map = construct_graph_last_layers(classifier, target_x_batch, target_y_batch)
    logits = mask_net(data_graph.to(device)).detach()
    for k in range(config.sparsity_target, config.max_k + 1, config.k_increment):
        # Stochastic sampling, repeated 3 times to reflect distributional effect, take best
        min_conf = 1.0
        for rep in range(3):
            # Relaxed sampling for evaluation; in practice, hard masking for deployment
            mask_prob = gumbel_sigmoid(logits, tau=0.5, hard=True)
            mask = torch.zeros_like(mask_prob)
            topk_idxs = torch.topk(mask_prob, k).indices
            mask[topk_idxs] = 1
            apply_mask_to_model(classifier, mask, idx_map)
            with torch.no_grad():
                output = classifier(target_x_batch)
                conf = F.softmax(output, dim=1)[0, target_y].item()
            min_conf = min(min_conf, conf)
        print(f"k={k}, confidence={min_conf:.4f}")
        if min_conf < config.confidence_threshold:
            break

    print("\nFinal Evaluation after unlearning:")
    evaluate(classifier, retain_loader, "After Unlearning")
    with torch.no_grad():
        out = classifier(target_x_batch)
        pred = out.argmax(dim=1).item()
        conf = F.softmax(out, dim=1)[0, target_y].item()
        print(f"[After Unlearning] Target sample - Pred: {pred}, Confidence: {conf:.4f}, True: {target_y}")

if __name__ == "__main__":
    main()


Using device: cuda
Epoch 1: Avg Loss = 1.2058
Epoch 2: Avg Loss = 0.3810
Epoch 3: Avg Loss = 0.2887
Epoch 4: Avg Loss = 0.2444
Epoch 5: Avg Loss = 0.2222
Epoch 6: Avg Loss = 0.1713
Epoch 7: Avg Loss = 0.1525

Evaluation before unlearning:
Before Unlearning: Accuracy = 0.9618
[Before Unlearning] Target sample - Pred: 5, Confidence: 0.6969, True: 5

Training probabilistic mask generator (Optimal Probabilistic Masking)...
[MaskNet] Epoch 1/7
[MaskNet] Epoch loss: 109357472.6690
[MaskNet] Epoch 2/7
[MaskNet] Epoch loss: 16009.1666
[MaskNet] Epoch 3/7
[MaskNet] Epoch loss: 19726.5471
[MaskNet] Epoch 4/7
[MaskNet] Epoch loss: 20532.4908
[MaskNet] Epoch 5/7
[MaskNet] Epoch loss: 20743.3824
[MaskNet] Epoch 6/7
[MaskNet] Epoch loss: 20882.6840
[MaskNet] Epoch 7/7
[MaskNet] Epoch loss: 20954.3484
k=150, confidence=0.6999
k=175, confidence=0.5078
k=200, confidence=0.4995
k=225, confidence=0.5132
k=250, confidence=0.5000
k=275, confidence=0.5253
k=300, confidence=0.4524

Final Evaluation after unl