In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from torch_geometric.data import Data as GeoData
from torch_geometric.nn import GCNConv
import random
import numpy as np

# Fix random seeds for reproducibility.
# NumPy is imported by this cell, so it is seeded alongside Python's `random`
# and PyTorch to keep every available RNG in a known state.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# --- Config ---
class Config:
    """Hyper-parameters shared by every stage of the masking experiment."""

    # Device used for models and tensors: CUDA when available, otherwise CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Classifier training.
    batch_size = 128
    lr = 0.001
    epochs = 2  # classifier epochs

    # Mask-generator training and mask selection.
    mask_layers = ['fc2.weight', 'fc3.weight']  # last two layers for masking
    mask_epochs = 2  # mask training epochs per refinement
    iterative_refinement_steps = 3
    mask_threshold = 0.5
    kl_weight = 0.1  # weight of KL divergence regularization


config = Config()

# --- Load MNIST ---
# ToTensor is the only preprocessing applied to the images.
transform = transforms.ToTensor()
# Both splits are stored under ./data; download=True is passed for each.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Training batches are shuffled; test batches keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

# Forget set: the 50 training samples at indices 50..99 (inclusive).
forget_indices = list(range(50, 100))
forget_data = Subset(train_dataset, forget_indices)

# --- Classifier ---

class MNISTClassifier(nn.Module):
    """Fully connected 784 -> 256 -> 128 -> 10 classifier with ReLU activations."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        """Flatten the input to rows of 784 values and return the 10 class logits."""
        flat = x.view(-1, 28*28)
        hidden1 = F.relu(self.fc1(flat))
        hidden2 = F.relu(self.fc2(hidden1))
        return self.fc3(hidden2)

# --- Mask Generator GCN for multiple layers ---

class MaskGenerator(nn.Module):
    """Three-layer GCN that emits one mask logit per node of the parameter graph."""

    def __init__(self, in_channels=1):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 64)
        self.conv2 = GCNConv(64, 32)
        self.conv3 = GCNConv(32, 1)  # output logits per node

    def forward(self, data):
        """Return a 1-D tensor of logits, one entry per node in ``data``."""
        features = data.x
        # Edges are moved to wherever the node features live before message passing.
        edges = data.edge_index.to(features.device)
        hidden = self.conv1(features, edges)
        hidden = self.conv2(F.relu(hidden), edges)
        scores = self.conv3(F.relu(hidden), edges)
        return scores.squeeze(-1)  # shape [num_nodes]

# --- Build Param Graph with same output node connections ---

def _chain_edges(start, rows, cols):
    """Return bidirectional edges between neighbouring entries of each row.

    The block is a row-major (rows x cols) matrix whose first entry is node
    ``start``. Edges are emitted as (i, i+1), (i+1, i) pairs, row by row.
    """
    first = (start
             + torch.arange(rows).unsqueeze(1) * cols
             + torch.arange(max(cols - 1, 0)).unsqueeze(0)).reshape(-1)
    second = first + 1
    src = torch.stack([first, second], dim=1).reshape(-1)
    dst = torch.stack([second, first], dim=1).reshape(-1)
    return src, dst


def _cross_edges(start_prev, cols_prev, start_next, rows_next, cols_next):
    """Return bidirectional edges joining row j of one layer to column j of the next.

    For every neuron j shared by the two layers, each weight in row j of the
    earlier matrix is linked with each weight in column j of the later matrix.
    """
    neuron = torch.arange(cols_next).unsqueeze(1)
    prev_nodes = start_prev + neuron * cols_prev + torch.arange(cols_prev).unsqueeze(0)
    next_nodes = start_next + torch.arange(rows_next).unsqueeze(0) * cols_next + neuron
    grid = (cols_next, cols_prev, rows_next)
    a = prev_nodes.unsqueeze(2).expand(grid).reshape(-1)
    b = next_nodes.unsqueeze(1).expand(grid).reshape(-1)
    src = torch.stack([a, b], dim=1).reshape(-1)
    dst = torch.stack([b, a], dim=1).reshape(-1)
    return src, dst


def build_param_graph_same_output(model):
    """Build a graph whose nodes are the weights of ``config.mask_layers``.

    Node features are the (detached, CPU) weight values, one scalar per node,
    concatenated in the order the layers are listed. Edges are:

    * within each layer, a bidirectional chain along every row (weights that
      feed the same output neuron);
    * between consecutive layers, bidirectional links from every weight in
      row j of the earlier layer to every weight in column j of the later one.

    Edge construction is vectorised; for the two-layer configuration the edge
    list is identical, in content and order, to the nested-loop formulation.

    Args:
        model: module whose ``named_parameters()`` contain every name in
            ``config.mask_layers``.

    Returns:
        (data, cum_sizes): a ``GeoData`` with ``x`` of shape [num_nodes, 1] and
        ``edge_index`` of shape [2, num_edges], plus the NumPy cumulative node
        offsets (length ``len(config.mask_layers) + 1``) of each layer.

    Raises:
        ValueError: if a masked parameter is not 2-D, or if two consecutive
            layers do not share a neuron dimension.
    """
    named = dict(model.named_parameters())
    shapes, features, sizes = [], [], []
    for layer in config.mask_layers:
        param = named[layer]
        if param.dim() != 2:
            raise ValueError(
                f"mask layer {layer!r} must be a 2-D weight, got shape {tuple(param.shape)}"
            )
        shapes.append(tuple(param.shape))
        flat = param.detach().reshape(-1, 1).cpu()
        features.append(flat)
        sizes.append(flat.size(0))

    x = torch.cat(features, dim=0)
    cum_sizes = np.cumsum([0] + sizes)

    src_parts, dst_parts = [], []

    # Within each layer: connect weights with the same output neuron (same row).
    for k, (rows, cols) in enumerate(shapes):
        src, dst = _chain_edges(int(cum_sizes[k]), rows, cols)
        src_parts.append(src)
        dst_parts.append(dst)

    # Between consecutive layers: rows of the earlier layer and columns of the
    # later one are indexed by the same neuron.
    for k in range(len(shapes) - 1):
        rows_prev, cols_prev = shapes[k]
        rows_next, cols_next = shapes[k + 1]
        if rows_prev != cols_next:
            raise ValueError(
                f"layers {config.mask_layers[k]!r} and {config.mask_layers[k + 1]!r} "
                f"do not share a neuron dimension ({rows_prev} vs {cols_next})"
            )
        src, dst = _cross_edges(int(cum_sizes[k]), cols_prev,
                                int(cum_sizes[k + 1]), rows_next, cols_next)
        src_parts.append(src)
        dst_parts.append(dst)

    edge_index = torch.stack([torch.cat(src_parts), torch.cat(dst_parts)])
    data = GeoData(x=x, edge_index=edge_index)
    return data, cum_sizes

# --- Gumbel Sigmoid for soft/hard mask reparameterization ---

def gumbel_sigmoid(logits, tau=1.0, hard=False, eps=1e-10):
    """Sample a relaxed Bernoulli (binary-concrete) mask from ``logits``.

    The binary relaxation perturbs the logits with Logistic(0, 1) noise, i.e.
    the difference of two Gumbel samples. Adding a single Gumbel sample is
    biased towards 1: with zero logits the mean would be about 0.60, not 0.5.

    Args:
        logits: tensor of Bernoulli logits.
        tau: temperature; smaller values push samples towards 0 or 1.
        hard: when True the forward value is thresholded at 0.5 while the
            gradient of the soft sample is kept (straight-through estimator).
        eps: lower clamp for the uniform sample so its logarithm stays finite.

    Returns:
        Tensor shaped like ``logits`` with values in [0, 1].
    """
    uniform = torch.rand_like(logits).clamp(min=eps)
    # log(u) - log(1 - u) with u ~ U(0, 1) is Logistic(0, 1) distributed.
    logistic_noise = torch.log(uniform) - torch.log1p(-uniform)
    y = torch.sigmoid((logits + logistic_noise) / tau)
    if hard:
        y_hard = (y > 0.5).float()
        # Forward pass uses the hard value; backward pass uses d(y)/d(logits).
        y = (y_hard - y).detach() + y
    return y

# --- KL divergence for Bernoulli (mask logits q vs prior p) ---

def bernoulli_kl(logits, prior=0.5):
    """Return the summed KL(q || p) between Bernoulli(sigmoid(logits)) and Bernoulli(prior).

    A constant 1e-10 is added inside every logarithm to keep it finite.
    """
    tiny = 1e-10
    q = torch.sigmoid(logits)
    p = torch.tensor(prior, device=logits.device)
    keep_term = q * (torch.log(q + tiny) - torch.log(p + tiny))
    drop_term = (1 - q) * (torch.log(1 - q + tiny) - torch.log(1 - p + tiny))
    per_node = keep_term + drop_term
    return per_node.sum()

# --- Apply mask logits to classifier parameters ---

def apply_mask_logits_to_classifier(classifier, mask_logits, cum_sizes, hard=True):
    """Multiply the masked layers of ``classifier`` in place by the mask.

    ``iterative_refinement`` selects a mask by evaluating the *thresholded*
    mask (``sigmoid(logits) > config.mask_threshold``). Applying the soft
    probabilities here would produce a different model from the one that was
    evaluated, so the hard mask is the default.

    Args:
        classifier: model whose ``config.mask_layers`` parameters are scaled.
        mask_logits: 1-D tensor of logits, one per masked weight, laid out in
            the order of ``config.mask_layers``.
        cum_sizes: cumulative offsets of each layer inside ``mask_logits``.
        hard: when True (default) apply the 0/1 thresholded mask; when False
            scale the weights by the sigmoid probabilities instead.
    """
    mask = torch.sigmoid(mask_logits).detach().cpu()
    if hard:
        mask = (mask > config.mask_threshold).float()

    named = dict(classifier.named_parameters())
    with torch.no_grad():
        for idx, layer_name in enumerate(config.mask_layers):
            param = named[layer_name]
            start, end = int(cum_sizes[idx]), int(cum_sizes[idx + 1])
            param.mul_(mask[start:end].view(param.shape).to(param.device))

# --- Evaluate confidence on a data point ---

def get_confidence(model, data_point, device):
    """Return the softmax probability ``model`` assigns to the labelled class.

    Args:
        model: classifier; it is switched to eval mode.
        data_point: ``(x, y)`` pair of one unbatched input and its class index.
        device: device the input is moved to before the forward pass.

    Returns:
        The probability of class ``y`` as a Python float.
    """
    sample, label = data_point
    model.eval()
    batch = sample.unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        class_probs = F.softmax(model(batch), dim=1)
    return class_probs[0, label].item()

# --- Evaluate accuracy on a dataloader ---

def evaluate_model(model, loader, device):
    """Print and return the top-1 accuracy of ``model`` over ``loader``.

    The model is put in eval mode and gradients are disabled while scoring.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)
            predicted = model(data).argmax(dim=1)
            hits += predicted.eq(target).sum().item()
            seen += target.size(0)
    acc = hits / seen
    print(f"Accuracy: {acc:.4f}")
    return acc

# --- Train classifier ---

def train_classifier(model, loader, device):
    """Train ``model`` with Adam and cross-entropy for ``config.epochs`` epochs.

    The summed batch loss of each epoch is printed.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    model.train()
    for epoch in range(config.epochs):
        batch_losses = []
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)
            optimizer.zero_grad()
            batch_loss = criterion(model(data), target)
            batch_loss.backward()
            optimizer.step()
            batch_losses.append(batch_loss.item())
        total_loss = sum(batch_losses)
        print(f"[Classifier] Epoch {epoch+1}, Loss={total_loss:.4f}")

# --- Train mask generator ---

def train_mask_generator(mask_net, classifier, forget_data, idx_map, config):
    """Train ``mask_net`` to produce a mask that lowers confidence on the forget set.

    For each forget sample the masked layers of the (frozen) classifier are
    multiplied by a relaxed Bernoulli mask and the probability assigned to the
    true class is minimised, regularised by a KL term towards a 0.5 prior.

    Two defects of a naive "copy the model and ``mul_`` under ``no_grad``"
    approach are avoided here:

    * the masked weights are built as differentiable tensors and fed through
      ``torch.func.functional_call``, so the data term actually back-propagates
      into ``mask_net`` (an in-place masked copy cuts that path, leaving only
      the KL term to train the mask);
    * the objective is the true-class probability rather than cross-entropy:
      minimising cross-entropy would *raise* confidence on the forget points,
      while the refinement loop accepts a mask only when that confidence is low.

    The parameter graph depends only on the classifier, which is not modified
    here, so it is built once instead of once per sample.

    Args:
        mask_net: GCN producing one logit per parameter node.
        classifier: trained model; its weights are read but never changed.
        forget_data: iterable of ``(x, y)`` pairs, one unbatched input and its
            integer label.
        idx_map: unused; kept for call compatibility.
        config: provides ``device``, ``mask_epochs``, ``mask_layers`` and
            ``kl_weight``.
    """
    mask_net.train()
    optimizer = optim.Adam(mask_net.parameters(), lr=1e-3)
    prior_prob = 0.5

    data_graph, cum_sizes = build_param_graph_same_output(classifier)
    data_graph = data_graph.to(config.device)

    # Detached snapshot of every parameter/buffer; masked layers are overridden
    # per step with differentiable (weight * mask) tensors.
    base_state = {name: tensor.detach() for name, tensor in classifier.state_dict().items()}

    was_training = classifier.training
    classifier.eval()
    try:
        for epoch in range(config.mask_epochs):
            total_loss = 0.0
            for x, y in forget_data:
                x = x.unsqueeze(0).to(config.device)

                logits = mask_net(data_graph)
                mask = gumbel_sigmoid(logits, tau=1.0, hard=False)

                masked_state = dict(base_state)
                for idx, layer_name in enumerate(config.mask_layers):
                    weight = base_state[layer_name]
                    start, end = int(cum_sizes[idx]), int(cum_sizes[idx + 1])
                    masked_state[layer_name] = weight * mask[start:end].view(weight.shape)

                output = torch.func.functional_call(classifier, masked_state, (x,))
                forget_conf = F.softmax(output, dim=1)[0, int(y)]

                kl_loss = bernoulli_kl(logits, prior_prob)
                loss = forget_conf + config.kl_weight * kl_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            print(f"[MaskNet] Epoch {epoch+1}, Loss={total_loss:.4f}")
    finally:
        # Leave the classifier in whatever mode the caller had it in.
        classifier.train(was_training)

# --- Iterative refinement ---

def iterative_refinement(classifier, mask_net, forget_data, idx_map, config):
    """Alternate mask training and evaluation, keeping the best acceptable mask.

    After each round of ``train_mask_generator`` the thresholded mask is applied
    to a copy of the classifier, which is scored on ``train_loader`` and on the
    forget points. A mask is kept when the mean forget confidence is below 0.1
    and its training accuracy is at least the best seen so far. The loop stops
    early once the kept mask also reaches 0.95 accuracy.

    The parameter graph is built once up front: the classifier is never
    modified here, so rebuilding it every step is wasted work, and it also
    guarantees ``cum_sizes`` is defined when zero steps are configured. Mask
    logits used for evaluation are computed without building an autograd graph.

    Args:
        classifier: trained model (left unchanged).
        mask_net: mask generator, trained in place.
        forget_data: iterable of ``(x, y)`` forget samples.
        idx_map: forwarded unchanged to ``train_mask_generator``.
        config: provides ``device``, ``iterative_refinement_steps``,
            ``mask_threshold`` and ``mask_layers``.

    Returns:
        (best_mask_logits, cum_sizes): the detached logits of the best accepted
        mask, or ``None`` if no mask met the criteria, and the per-layer
        cumulative offsets into those logits.
    """
    best_mask_logits = None
    best_acc = 0

    data_graph, cum_sizes = build_param_graph_same_output(classifier)
    data_graph = data_graph.to(config.device)

    for iteration in range(config.iterative_refinement_steps):
        print(f"\nIterative Refinement Step {iteration+1}/{config.iterative_refinement_steps}")

        train_mask_generator(mask_net, classifier, forget_data, idx_map, config)

        with torch.no_grad():
            mask_logits = mask_net(data_graph)
        hard_mask = (torch.sigmoid(mask_logits) > config.mask_threshold).float()

        # Evaluate the mask on a throw-away copy so the classifier stays intact.
        temp_model = MNISTClassifier().to(config.device)
        temp_model.load_state_dict(classifier.state_dict())
        temp_params = dict(temp_model.named_parameters())

        with torch.no_grad():
            for idx, layer_name in enumerate(config.mask_layers):
                param = temp_params[layer_name]
                start, end = int(cum_sizes[idx]), int(cum_sizes[idx + 1])
                param.mul_(hard_mask[start:end].view(param.shape).to(param.device))

        train_acc = evaluate_model(temp_model, train_loader, config.device)

        confs = [get_confidence(temp_model, data_point, config.device)
                 for data_point in forget_data]
        mean_conf = np.mean(confs)
        print(f"Mean forget point confidence after masking: {mean_conf:.4f}")

        if train_acc >= best_acc and mean_conf < 0.1:
            best_acc = train_acc
            best_mask_logits = mask_logits.detach()

        if best_mask_logits is not None and best_acc >= 0.95 and mean_conf < 0.1:
            print("Stopping early: good mask found.")
            break

    return best_mask_logits, cum_sizes

# --- Main pipeline ---

def main():
    """Train the classifier, learn an unlearning mask, apply it and report results."""
    device = config.device

    classifier = MNISTClassifier().to(device)
    print("Training classifier...")
    train_classifier(classifier, train_loader, device)
    print("Accuracy before unlearning:")
    evaluate_model(classifier, test_loader, device)

    def forget_confidences():
        # Confidence of the current classifier on every forget point, in order.
        return [get_confidence(classifier, point, device) for point in forget_data]

    # Print forget point confidences BEFORE unlearning
    print("\nForget point confidences BEFORE unlearning:")
    for i, conf in enumerate(forget_confidences()):
        print(f"Forget point {i} confidence before unlearning: {conf:.4f}")

    mask_net = MaskGenerator(in_channels=1).to(device)
    idx_map = None

    best_mask_logits, cum_sizes = iterative_refinement(classifier, mask_net, forget_data, idx_map, config)

    # Fall back to whatever the mask net currently outputs when nothing was accepted.
    if best_mask_logits is None:
        print("No good mask found, using last mask logits.")
        data_graph, cum_sizes = build_param_graph_same_output(classifier)
        best_mask_logits = mask_net(data_graph.to(device)).detach()

    print("Applying final mask to classifier...")
    apply_mask_logits_to_classifier(classifier, best_mask_logits, cum_sizes)

    print("Accuracy after unlearning:")
    evaluate_model(classifier, test_loader, device)

    print("\nForget point confidences AFTER unlearning:")
    for i, conf in enumerate(forget_confidences()):
        print(f"Forget point {i} confidence after unlearning: {conf:.4f}")


if __name__ == "__main__":
    main()


Training classifier...
[Classifier] Epoch 1, Loss=165.5215
[Classifier] Epoch 2, Loss=62.5292
Accuracy before unlearning:
Accuracy: 0.9678

Forget point confidences BEFORE unlearning:
Forget point 0 confidence before unlearning: 0.9921
Forget point 1 confidence before unlearning: 0.9999
Forget point 2 confidence before unlearning: 0.9997
Forget point 3 confidence before unlearning: 0.9086
Forget point 4 confidence before unlearning: 0.8934
Forget point 5 confidence before unlearning: 0.9430
Forget point 6 confidence before unlearning: 0.9999
Forget point 7 confidence before unlearning: 0.9929
Forget point 8 confidence before unlearning: 0.9998
Forget point 9 confidence before unlearning: 0.9911
Forget point 10 confidence before unlearning: 0.9949
Forget point 11 confidence before unlearning: 0.9973
Forget point 12 confidence before unlearning: 0.9990
Forget point 13 confidence before unlearning: 0.9872
Forget point 14 confidence before unlearning: 0.9887
Forget point 15 confidence befo

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch_geometric.data import Data as GeoData
from torch_geometric.nn import GCNConv
import numpy as np
import random

# Config
class Config:
    """Settings for the single-forget-point masking experiment."""

    # CUDA when available, otherwise CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Classifier training.
    batch_size = 128
    epochs = 2
    lr = 0.01

    # Mask-generator training.
    mask_lr = 0.005
    mask_epochs = 5  # increased mask training epochs
    kl_coeff = 0.1
    l1_coeff = 1e-3  # L1 regularization weight for mask sparsity


config = Config()

# Seed the NumPy, Python and PyTorch generators with the same value.
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)

# Model
class Classifier(nn.Module):
    """784 -> 256 -> 128 -> 10 multilayer perceptron with ReLU hidden layers."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return the class logits."""
        pixels = x.view(-1, 28*28)
        first = F.relu(self.fc1(pixels))
        second = F.relu(self.fc2(first))
        logits = self.fc3(second)
        return logits

# GCN Mask Generator
class MaskNet(nn.Module):
    """Two-layer GCN whose sigmoid output is a per-node mask value in (0, 1)."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        """Run both graph convolutions on ``data.x`` / ``data.edge_index``."""
        edges = data.edge_index
        hidden = F.relu(self.conv1(data.x, edges))
        scores = self.conv2(hidden, edges)
        return torch.sigmoid(scores)

# Dataset
# ToTensor is the only preprocessing applied to the images.
transform = transforms.ToTensor()
# Both MNIST splits are stored in the current directory; download=True is passed for each.
train_dataset = datasets.MNIST(root=".", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root=".", train=False, download=True, transform=transform)
# Test batches keep dataset order (no shuffling).
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

# Helpers
def evaluate_model(model, loader, device):
    """Print and return the fraction of samples in ``loader`` classified correctly."""
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predicted = model(inputs).max(dim=1).indices
            n_correct += predicted.eq(labels).sum().item()
            n_seen += labels.size(0)
    acc = n_correct / n_seen
    print(f"Accuracy: {acc:.4f}")
    return acc

def get_forget_point(target_label=2, dataset=None):
    """Return the first sample of ``dataset`` whose label equals ``target_label``.

    Args:
        target_label: class label to look for (defaults to digit 2).
        dataset: iterable of ``(x, y)`` pairs; defaults to the module-level
            ``train_dataset``.

    Returns:
        ``(index, x_batched, y)`` where ``x_batched`` is ``x`` with a leading
        batch dimension of size 1.

    Raises:
        ValueError: if no sample carries ``target_label``. Falling off the end
            would return ``None`` and fail later at the caller's tuple unpacking.
    """
    if dataset is None:
        dataset = train_dataset
    for i, (x, y) in enumerate(dataset):
        if y == target_label:
            return i, x.unsqueeze(0), y
    raise ValueError(f"no sample with label {target_label!r} found in dataset")

# Build param graph for last two layers
# Parameters that take part in masking, keyed by their `named_parameters()`
# name. Membership in this dict is what selects a parameter for the graph and
# for mask application; the tuples record the expected tensor shapes.
layer_shapes = {
    'fc2.weight': (128, 256),
    'fc2.bias': (128,),
    'fc3.weight': (10, 128),
    'fc3.bias': (10,)
}

def _row_clique_edges(shape, start_idx):
    """Return directed edges fully connecting the weights of each output row.

    For a 2-D ``shape`` (rows, cols) whose first entry is node ``start_idx``,
    every ordered pair of distinct weights in the same row is linked (no self
    loops). Any other rank (e.g. a bias vector) yields no edges.
    """
    if len(shape) != 2:
        return torch.empty((2, 0), dtype=torch.long)
    rows, cols = shape
    col_ids = torch.arange(cols)
    src = col_ids.unsqueeze(1).expand(cols, cols)
    dst = col_ids.unsqueeze(0).expand(cols, cols)
    off_diagonal = src != dst
    # All (j, k) pairs with j != k, in row-major order: shape [2, cols*(cols-1)].
    pairs = torch.stack([src[off_diagonal], dst[off_diagonal]])
    row_bases = start_idx + torch.arange(rows) * cols
    # Broadcast the per-row offset over the pair template: [2, rows, pairs].
    return (pairs.unsqueeze(1) + row_bases.view(1, -1, 1)).reshape(2, -1)


def build_param_graph(model):
    """Build a graph over the parameters listed in ``layer_shapes``.

    Each selected parameter entry becomes one node whose single feature is its
    (detached) value; nodes follow ``model.named_parameters()`` order. Weights
    that belong to the same output neuron (same matrix row) are fully
    connected; bias entries get no edges.

    Edges are generated with tensor operations rather than nested Python list
    comprehensions, which for a 128x256 weight would otherwise allocate over
    eight million two-element lists. Edge content and order are unchanged.

    Args:
        model: module exposing the parameters named in ``layer_shapes``.

    Returns:
        (data, num_nodes): a ``GeoData`` with ``x`` of shape [num_nodes, 1] and
        ``edge_index`` of shape [2, num_edges], and the node count.
    """
    flat_params, edge_blocks = [], []
    offset = 0
    for name, param in model.named_parameters():
        if name not in layer_shapes:
            continue
        flat_params.append(param.detach().view(-1))
        edge_blocks.append(_row_clique_edges(tuple(param.shape), offset))
        offset += param.numel()

    param_tensor = torch.cat(flat_params)

    if edge_blocks:
        edge_index = torch.cat(edge_blocks, dim=1)
    else:
        # To avoid error if no edges (e.g., only biases)
        edge_index = torch.empty((2, 0), dtype=torch.long)

    return GeoData(x=param_tensor.view(-1, 1), edge_index=edge_index), param_tensor.numel()

# KL divergence
def kl_divergence(q):
    """Return the mean KL divergence between Bernoulli(q) and Bernoulli(0.5).

    ``q`` is clamped to [1e-8, 1 - 1e-8] before the logarithms are taken.
    """
    eps = 1e-8
    prior = torch.full_like(q, 0.5)
    probs = torch.clamp(q, eps, 1 - eps)
    keep_term = probs * torch.log(probs / prior)
    drop_term = (1 - probs) * torch.log((1 - probs) / (1 - prior))
    return (keep_term + drop_term).mean()

def apply_mask_to_weights(model, mask_tensor):
    """Multiply, in place, every parameter named in ``layer_shapes`` by its mask slice.

    ``mask_tensor`` is consumed sequentially in ``named_parameters()`` order,
    one value per parameter entry. The update runs under ``torch.no_grad()``.
    """
    cursor = 0
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name not in layer_shapes:
                continue
            stop = cursor + param.numel()
            param.mul_(mask_tensor[cursor:stop].view(param.shape))
            cursor = stop

# Clone model utility
def clone_model_with_mask(model, mask):
    """Return a fresh ``Classifier`` carrying ``model``'s weights with ``mask`` applied."""
    masked_copy = Classifier().to(config.device)
    weights = model.state_dict()
    masked_copy.load_state_dict(weights)
    apply_mask_to_weights(masked_copy, mask)
    return masked_copy

# Training Mask Generator
def train_mask_generator(mask_net, model, forget_data, device, config):
    """Train ``mask_net`` so the masked classifier loses confidence on the forget point.

    The loss is ``10 * confidence + kl_coeff * KL(mask || 0.5) + l1_coeff * mean(mask)``.

    The masked weights are built as differentiable tensors (``weight * mask``)
    and evaluated with ``torch.func.functional_call``. Copying the model and
    multiplying its parameters in place under ``no_grad`` would disconnect the
    confidence term from ``mask_net``, leaving only the regularisers to drive
    training.

    Args:
        mask_net: GCN returning one mask value in (0, 1) per parameter node.
        model: trained classifier; switched to eval mode, weights unchanged.
        forget_data: ``(x, y)`` with ``x`` already batched and ``y`` the label.
        device: device for the graph and the forget sample.
        config: provides ``mask_lr``, ``mask_epochs``, ``kl_coeff``, ``l1_coeff``.

    Returns:
        Detached 0/1 float mask obtained by thresholding, at 0.5, the output of
        the mask net *after* the final optimizer step.
    """
    optimizer = optim.Adam(mask_net.parameters(), lr=config.mask_lr)
    model.eval()
    data, _ = build_param_graph(model)
    data = data.to(device)

    conf_coeff = 10.0  # added confidence loss weight

    x = forget_data[0].to(device)
    y = int(forget_data[1])

    # Frozen snapshot of the classifier; masked entries are overridden per epoch.
    frozen = {name: tensor.detach() for name, tensor in model.state_dict().items()}
    masked_names = [(name, param.numel(), param.shape)
                    for name, param in model.named_parameters()
                    if name in layer_shapes]

    for epoch in range(config.mask_epochs):
        mask_net.train()
        optimizer.zero_grad()
        soft_mask = mask_net(data).view(-1)

        masked_state = dict(frozen)
        start = 0
        for name, length, shape in masked_names:
            masked_state[name] = frozen[name] * soft_mask[start:start + length].view(shape)
            start += length

        out = torch.func.functional_call(model, masked_state, (x,))
        conf = F.softmax(out, dim=1)[0, y]

        loss = conf_coeff * conf + config.kl_coeff * kl_divergence(soft_mask) + config.l1_coeff * soft_mask.mean()

        loss.backward()

        # Optional: print grad norm. GCNConv keeps its weight either directly on
        # the layer or on an inner `lin` module depending on the library version.
        conv2_weight = getattr(mask_net.conv2, "lin", mask_net.conv2).weight
        grad_norm = conv2_weight.grad.norm().item() if conv2_weight.grad is not None else 0
        print(f"[MaskNet] Epoch {epoch+1}/{config.mask_epochs} - Loss: {loss.item():.6f} - Confidence: {conf.item():.6f} - GradNorm: {grad_norm:.6f}")

        optimizer.step()

    # Re-evaluate after the last update so the returned mask reflects the final
    # mask-net weights (and is defined even when mask_epochs is 0).
    with torch.no_grad():
        final_soft_mask = mask_net(data).view(-1)
    hard_mask = (final_soft_mask > 0.5).float().detach()
    return hard_mask


# Main pipeline
if __name__ == "__main__":
    # --- 1. Train the classifier on the full MNIST training set ---
    classifier = Classifier().to(config.device)
    optimizer = optim.Adam(classifier.parameters(), lr=config.lr)
    loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)

    for epoch in range(config.epochs):
        classifier.train()
        total_loss = 0  # sum of per-batch losses for this epoch
        for x, y in loader:
            x, y = x.to(config.device), y.to(config.device)
            optimizer.zero_grad()
            out = classifier(x)
            loss = F.cross_entropy(out, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"[Classifier] Epoch {epoch+1}, Loss: {total_loss:.4f}")

    # --- 2. Baseline metrics before any masking ---
    print("Before Unlearning:")
    evaluate_model(classifier, test_loader, config.device)

    # --- 3. Pick the forget point and record the model's confidence on it ---
    forget_idx, forget_x, forget_y = get_forget_point()
    forget_data = (forget_x.to(config.device), forget_y)

    classifier.eval()
    conf_before = F.softmax(classifier(forget_data[0]), dim=1)[0, forget_data[1]].item()
    print(f"Confidence of forget point before unlearning: {conf_before:.4f}")

    # --- 4. Learn a mask and apply it to the classifier in place ---
    mask_net = MaskNet(1, 32, 1).to(config.device)
    final_mask = train_mask_generator(mask_net, classifier, forget_data, config.device, config)
    apply_mask_to_weights(classifier, final_mask)

    # --- 5. Metrics after masking ---
    print("After Unlearning:")
    evaluate_model(classifier, test_loader, config.device)

    conf_after = F.softmax(classifier(forget_data[0]), dim=1)[0, forget_data[1]].item()
    print(f"Confidence of forget point after unlearning: {conf_after:.4f}")


[Classifier] Epoch 1/10 - Loss: 113.5418
[Classifier] Epoch 2/10 - Loss: 56.1859
[Classifier] Epoch 3/10 - Loss: 46.8765
[Classifier] Epoch 4/10 - Loss: 43.9637
[Classifier] Epoch 5/10 - Loss: 36.4129
[Classifier] Epoch 6/10 - Loss: 31.2952
[Classifier] Epoch 7/10 - Loss: 35.9336
[Classifier] Epoch 8/10 - Loss: 33.2211
[Classifier] Epoch 9/10 - Loss: 29.2238
[Classifier] Epoch 10/10 - Loss: 28.1778
Before Unlearning:
Test Accuracy: 0.9709
Forgetting data index 5, label 2
Confidence of forget point before unlearning: 1.0000
Initial test accuracy: 0.9709
[MaskNet] Epoch 1/30 - Loss: 1.882084 - ForgetConf: 0.848138 - TestAcc: 0.9520
[MaskNet] Epoch 2/30 - Loss: 1.909341 - ForgetConf: 0.817752 - TestAcc: 0.9490
[MaskNet] Epoch 3/30 - Loss: 1.904680 - ForgetConf: 0.820161 - TestAcc: 0.9491
[MaskNet] Epoch 4/30 - Loss: 1.889951 - ForgetConf: 0.828629 - TestAcc: 0.9499
[MaskNet] Epoch 5/30 - Loss: 1.876714 - ForgetConf: 0.836931 - TestAcc: 0.9505
[MaskNet] Epoch 6/30 - Loss: 1.868831 - Forget

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import gc
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch_geometric.nn import GCNConv
from dataclasses import dataclass
import random

@dataclass
class Config:
    """Settings for the margin-gated unlearning experiment."""
    batch_size: int = 128
    # "cuda" when available at import time, otherwise "cpu".
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    lr_classifier: float = 1e-3
    lr_mask: float = 1e-3
    classifier_epochs: int = 10
    mask_epochs: int = 20
    # Starts as None despite the float annotation; main() assigns it from margin statistics.
    delta: float = None  # will be set dynamically
    kl_coeff: float = 1e-3
    hidden_channels: int = 32
    overlap_threshold: float = 0.1  # cosine similarity threshold
    # Upper bound on how many random forget-point candidates main() tries.
    max_attempts: int = 50

config = Config()

# ToTensor is the only preprocessing step in the pipeline.
transform = transforms.Compose([transforms.ToTensor()])
# Both MNIST splits are stored under ./data; download=True is passed for each.
train_set = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_set = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Training batches are shuffled; the test loader uses the DataLoader default order.
train_loader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=config.batch_size)

class Classifier(nn.Module):
    """Three-layer perceptron (784 -> 256 -> 128 -> 10) for flattened 28x28 inputs."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        """Return class logits; both hidden layers are followed by ReLU."""
        activations = x.view(-1, 28*28)
        for hidden_layer in (self.fc1, self.fc2):
            activations = F.relu(hidden_layer(activations))
        return self.fc3(activations)

class MaskNet(nn.Module):
    """Two-layer GCN producing a sigmoid-squashed value per graph node."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply conv1 -> ReLU -> conv2 -> sigmoid to node features ``x``."""
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        scores = self.conv2(hidden, edge_index)
        return torch.sigmoid(scores)

def kl_divergence(mask_logits, prior=0.5):
    """Return the mean KL divergence between Bernoulli(mask) and Bernoulli(prior).

    Despite the parameter name, the values are used directly as probabilities:
    they are clamped to [1e-6, 1 - 1e-6] and no sigmoid is applied here.
    """
    probs = mask_logits.clamp(1e-6, 1 - 1e-6)
    keep_term = probs * torch.log(probs / prior)
    drop_term = (1 - probs) * torch.log((1 - probs) / (1 - prior))
    return (keep_term + drop_term).mean()

def get_model_weights(model):
    """Return every parameter of ``model`` flattened and concatenated into one 1-D tensor."""
    flattened = []
    for param in model.parameters():
        flattened.append(param.view(-1))
    return torch.cat(flattened)

def apply_mask_to_weights(model, mask, sizes):
    """Scale each parameter of ``model`` in place by its slice of the flat ``mask``.

    ``sizes`` lists the element count of every parameter in ``model.parameters()``
    order; consecutive slices of ``mask`` are reshaped to the matching parameter.
    """
    cursor = 0
    with torch.no_grad():
        for param, size in zip(model.parameters(), sizes):
            piece = mask[cursor:cursor + size]
            param.data.mul_(piece.view(param.shape))
            cursor += size

def calculate_margin(classifier, x, y):
    """Return the true-class logit minus the largest other logit for sample ``x``.

    Only the first row of the classifier output is used. The classifier is put
    in eval mode and no gradients are tracked.
    """
    classifier.eval()
    with torch.no_grad():
        logits = classifier(x)[0]
        true_logit = logits[y]
        rivals = torch.cat([logits[:y], logits[y+1:]])
        return (true_logit - rivals.max()).item()

def is_unlearning_feasible_margin(classifier, x, y, delta):
    """Report whether the logit margin of ``(x, y)`` is strictly greater than ``delta``.

    The margin and threshold are printed before the comparison is returned.
    """
    margin = calculate_margin(classifier, x, y)
    print(f"[Feasibility Check] Margin: {margin:.4f}, Threshold: {delta}")
    exceeds_threshold = margin > delta
    return exceeds_threshold

def build_graph_from_model(model):
    """Build a chain graph with one node per scalar parameter of ``model``.

    Every node feature is the constant 1.0 (the parameter values themselves
    are not used as features). Within each parameter tensor, consecutive
    flattened entries are linked in both directions; entries of different
    parameter tensors are never connected.

    Edges are produced with tensor operations instead of a Python loop over
    every parameter entry. Content and order of the edge list are unchanged,
    and a model without any edge now yields a well-formed [2, 0] long tensor.

    Returns:
        (x, edge_index, sizes): node features of shape [total_params, 1] and
        ``edge_index`` of shape [2, num_edges], both on ``config.device``, plus
        the element count of each parameter in ``model.parameters()`` order.
    """
    sizes = [param.numel() for param in model.parameters()]
    total_params = sum(sizes)
    x = torch.ones((total_params, 1)).to(config.device)

    src_parts, dst_parts = [], []
    offset = 0
    for size in sizes:
        left = offset + torch.arange(max(size - 1, 0))
        right = left + 1
        # Interleave (i, i+1) and (i+1, i) to keep the original edge order.
        src_parts.append(torch.stack([left, right], dim=1).reshape(-1))
        dst_parts.append(torch.stack([right, left], dim=1).reshape(-1))
        offset += size

    if src_parts:
        edge_index = torch.stack([torch.cat(src_parts), torch.cat(dst_parts)])
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
    edge_index = edge_index.contiguous().to(config.device)
    return x, edge_index, sizes

def train_mask_generator(mask_net, classifier, forget_data, config):
    """Train ``mask_net`` to produce a soft mask that unlearns one forget point.

    Each epoch the mask net scores every parameter, the frozen classifier
    weights are multiplied by that mask, and the probability the masked model
    gives to the forget label is minimised together with a KL regulariser.

    Compared with copying the classifier and scaling ``param.data`` in place:

    * the masked weights stay in the autograd graph (they are passed to
      ``torch.func.functional_call``), so the data term reaches ``mask_net``;
      scaling ``.data`` under ``no_grad`` leaves only the KL term trainable;
    * the objective is the forget-label probability rather than cross-entropy,
      because minimising cross-entropy on the forget point would *increase*
      the confidence that unlearning is supposed to remove;
    * no model copy or garbage-collection pass is needed per epoch;
    * the returned mask is recomputed after the final optimizer step.

    Args:
        mask_net: GCN mapping ``(x, edge_index)`` to one value in (0, 1) per node.
        classifier: trained model; switched to eval mode, weights unchanged.
        forget_data: ``(x_forget, y_forget)`` with a batched input and a scalar
            label tensor.
        config: provides ``lr_mask``, ``mask_epochs`` and ``kl_coeff``.

    Returns:
        (mask, sizes): the detached soft mask, one value per classifier
        parameter entry, and the per-parameter element counts.
    """
    classifier.eval()
    x_init, edge_index, sizes = build_graph_from_model(classifier)
    optimizer = optim.Adam(mask_net.parameters(), lr=config.lr_mask)
    x_forget, y_forget = forget_data

    # Frozen view of the classifier parameters, in `parameters()` order.
    frozen = [(name, param.detach()) for name, param in classifier.named_parameters()]

    for epoch in range(config.mask_epochs):
        optimizer.zero_grad()
        mask = mask_net(x_init, edge_index).view(-1)

        masked_params = {}
        offset = 0
        for (name, weight), size in zip(frozen, sizes):
            masked_params[name] = weight * mask[offset:offset + size].view(weight.shape)
            offset += size

        output = torch.func.functional_call(classifier, masked_params, (x_forget,))
        forget_conf = F.softmax(output, dim=1)[0, y_forget]
        conf = forget_conf.item()
        kl = kl_divergence(mask)
        loss = forget_conf + config.kl_coeff * kl
        loss.backward()

        # GCNConv keeps its weight either on an inner `lin` module or directly
        # on the layer depending on the library version.
        conv2_weight = getattr(mask_net.conv2, "lin", mask_net.conv2).weight
        grad_norm = conv2_weight.grad.norm().item() if conv2_weight.grad is not None else 0
        print(f"[MaskNet] Epoch {epoch+1}/{config.mask_epochs} - Loss: {loss.item():.6f} - Confidence: {conf:.6f} - GradNorm: {grad_norm:.6f}")
        optimizer.step()

    # Evaluate once more so the result reflects the final weights and is
    # defined even when mask_epochs is 0.
    with torch.no_grad():
        final_mask = mask_net(x_init, edge_index).view(-1)
    return final_mask.detach(), sizes

def train_classifier(model, loader, config):
    """Fit ``model`` with Adam and cross-entropy for ``config.classifier_epochs`` epochs.

    Batches are moved to ``config.device``; the summed batch loss is printed
    after every epoch.
    """
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.lr_classifier)
    model.train()
    for epoch in range(config.classifier_epochs):
        batch_losses = []
        for x, y in loader:
            x = x.to(config.device)
            y = y.to(config.device)
            optimizer.zero_grad()
            batch_loss = loss_fn(model(x), y)
            batch_loss.backward()
            optimizer.step()
            batch_losses.append(batch_loss.item())
        total_loss = sum(batch_losses)
        print(f"[Classifier] Epoch {epoch+1}, Loss: {total_loss:.4f}")

def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` on ``loader`` (batches moved to ``config.device``)."""
    model.eval()
    n_right, n_total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(config.device)
            y = y.to(config.device)
            predicted = model(x).argmax(dim=1)
            n_right += (predicted == y).sum().item()
            n_total += y.size(0)
    return n_right / n_total

def cosine_similarity(a, b):
    """Return the cosine similarity of 1-D tensors ``a`` and ``b``.

    1e-8 is added to the product of the norms so zero vectors do not divide by zero.
    """
    numerator = a.dot(b)
    denominator = a.norm() * b.norm() + 1e-8
    return numerator / denominator

def main():
    """Train a classifier, pick a forget point by logit margin, learn a mask and apply it.

    Steps: train and score the classifier; derive the margin threshold from the
    first 1000 training samples; draw random candidates until one passes the
    margin test and the overlap test; apply that candidate's mask to the
    classifier in place and print accuracy and forget-point confidence.
    """
    classifier = Classifier().to(config.device)
    train_classifier(classifier, train_loader, config)

    acc_before = evaluate(classifier, test_loader)
    print(f"\nBefore Unlearning:\nAccuracy: {acc_before:.4f}")

    # Calculate margin stats on a subset to set a realistic delta threshold
    margins = []
    num_samples_for_margin = 1000
    print("\nCalculating margin statistics to set delta threshold...")
    for i in range(num_samples_for_margin):
        x_i, y_i = train_set[i]
        x_i = x_i.unsqueeze(0).to(config.device)
        y_i = torch.tensor(y_i).to(config.device)
        margin = calculate_margin(classifier, x_i, y_i)
        margins.append(margin)
    margins_np = np.array(margins)
    mean_margin = margins_np.mean()
    std_margin = margins_np.std()
    # Threshold is one standard deviation below the mean margin.
    suggested_delta = mean_margin - std_margin
    print(f"Margin stats - min: {margins_np.min():.4f}, max: {margins_np.max():.4f}, mean: {mean_margin:.4f}, std: {std_margin:.4f}")
    print(f"Setting feasibility margin threshold (delta) to: {suggested_delta:.4f}")
    config.delta = suggested_delta

    accepted_masks = []

    forget_x, forget_y, forget_idx, forget_mask, sizes = None, None, None, None, None

    # Try random training samples until one is accepted or attempts run out.
    for attempt in range(1, config.max_attempts + 1):
        idx = random.randint(0, len(train_set) - 1)
        x_i, y_i = train_set[idx]
        x_i = x_i.unsqueeze(0).to(config.device)
        y_i = torch.tensor(y_i).to(config.device)
        print(f"\nAttempt {attempt}: Testing forget point idx {idx} with label {y_i.item()}")

        if not is_unlearning_feasible_margin(classifier, x_i, y_i, config.delta):
            print("❌ Not feasible by margin, trying another point...")
            continue

        print(f"✅ Forget point idx {idx} is feasible for unlearning.")

        # A fresh mask net is trained for every candidate that passes the margin test.
        mask_net = MaskNet(1, config.hidden_channels, 1).to(config.device)
        mask_logits, sizes = train_mask_generator(mask_net, classifier, (x_i, y_i), config)

        # NOTE(review): accepted_masks is only appended to immediately before the
        # `break` below, so it is always empty here and this check never rejects.
        overlap_ok = True
        for prev_mask in accepted_masks:
            sim = cosine_similarity(mask_logits, prev_mask).item()
            if sim > config.overlap_threshold:
                print(f"❌ Overlap {sim:.4f} exceeds threshold with existing mask, trying next point...")
                overlap_ok = False
                break

        if overlap_ok:
            print(f"✅ Forget point idx {idx} accepted for unlearning after overlap check.")
            accepted_masks.append(mask_logits)
            forget_x, forget_y, forget_idx, forget_mask = x_i, y_i, idx, mask_logits
            break
    else:
        # Reached only when the loop finished without `break`.
        print("❌ No feasible forget point found within max attempts passing overlap checks.")
        return

    output = classifier(forget_x)
    conf_before = F.softmax(output, dim=1)[0, forget_y].item()
    print(f"Confidence of forget point before unlearning: {conf_before:.4f}")

    # Scales every classifier parameter in place by the learned mask values.
    apply_mask_to_weights(classifier, forget_mask, sizes)

    acc_after = evaluate(classifier, test_loader)
    output = classifier(forget_x)
    conf_after = F.softmax(output, dim=1)[0, forget_y].item()

    print(f"\nAfter Unlearning:\nAccuracy: {acc_after:.4f}")
    print(f"Confidence of forget point after unlearning: {conf_after:.4f}")

if __name__ == "__main__":
    main()
