In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from torch_geometric.data import Data as GeoData
from torch_geometric.nn import GCNConv
import gc
from dataclasses import dataclass
from typing import List, Tuple, Dict

@dataclass
class MaskingConfig:
    epochs_soft: int = 5
    epochs_hard: int = 5
    learning_rate: float = 1e-3
    delta: int = 150   # Increased delta
    kl_weight: float = 0.001
    confidence_threshold: float = 0.01
    max_k: int = 300   # Increased max_k for iterative masking
    k_increment: int = 25
    max_grad_norm: float = 1.0
    batch_limit: int = 10

FEATURE_DIM = 28 * 28
HIDDEN_DIM = 300
HIDDEN2_DIM = 100
OUTPUT_DIM = 10
TARGET_LAYERS = ['fc2', 'fc3']

class MNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(FEATURE_DIM, HIDDEN_DIM)
        self.fc2 = nn.Linear(HIDDEN_DIM, HIDDEN2_DIM)
        self.fc3 = nn.Linear(HIDDEN2_DIM, OUTPUT_DIM)

    def forward(self, x):
        x = x.view(-1, FEATURE_DIM)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

class GCNMaskGenerator(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 64)
        self.conv2 = GCNConv(64, 1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        logits = self.conv2(x, edge_index).squeeze()
        return torch.sigmoid(logits)

def get_target_parameters(model: nn.Module, target_layers: List[str] = None) -> List[Tuple[str, nn.Parameter]]:
    if target_layers is None:
        target_layers = TARGET_LAYERS
    return [
        (name, param) for name, param in model.named_parameters()
        if any(layer in name for layer in target_layers) and 'weight' in name
    ]

def construct_graph_last_layers(model: nn.Module, input_data: torch.Tensor, target: torch.Tensor) -> Tuple[GeoData, Dict[str,int]]:
    model.zero_grad()
    output = model(input_data)
    loss = F.cross_entropy(output, target)
    loss.backward()

    params, grads, activ_in, activ_out, names, idx_map = [], [], [], [], [], {}
    idx = 0

    target_params = get_target_parameters(model, TARGET_LAYERS)

    # Flatten all parameters and collect their grads and activations
    for name, param in target_params:
        if param.grad is not None:
            shape = param.shape
            param_data = param.detach().flatten().cpu()
            grad_data = param.grad.detach().flatten().cpu()

            # Use in/out feature dims per weight shape: For fc layers (out_features, in_features)
            in_feat_val = shape[1] if len(shape) > 1 else 1
            out_feat_val = shape[0]

            in_feat = torch.full((param_data.shape[0], 1), float(in_feat_val))
            out_feat = torch.full((param_data.shape[0], 1), float(out_feat_val))

            for i in range(param_data.shape[0]):
                params.append(param_data[i].unsqueeze(0))
                grads.append(grad_data[i].unsqueeze(0))
                activ_in.append(in_feat[i])
                activ_out.append(out_feat[i])
                names.append(name)
                idx_map[f"{name}_{i}"] = idx
                idx += 1

    x = torch.cat([
        torch.stack(params), 
        torch.stack(grads), 
        torch.stack(activ_in), 
        torch.stack(activ_out)
    ], dim=1)

    # Build edges:
    edge_index = []
    n = len(names)

    # Connect nodes from same param tensors and also across param tensors within last layers
    # Connect nodes if they share layer prefix or in the same layer
    for i in range(n):
        for j in range(i+1, min(i+10, n)):  # connect up to 10 neighbors for better info flow
            # connect if in same param tensor (same layer name)
            if names[i].split(".")[0] == names[j].split(".")[0]:
                edge_index.extend([[i, j], [j, i]])
            else:
                # additionally connect fc2 and fc3 param nodes with edges, cross-layer edges
                # Since both layers are connected in the model, allow cross connections between fc2.weight and fc3.weight nodes
                prefix_i = names[i].split(".")[0]
                prefix_j = names[j].split(".")[0]
                if (prefix_i in ['fc2', 'fc3']) and (prefix_j in ['fc2', 'fc3']):
                    edge_index.extend([[i, j], [j, i]])

    if len(edge_index) == 0:
        for i in range(n - 1):
            edge_index.extend([[i, i+1], [i+1, i]])

    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    model.zero_grad()
    return GeoData(x=x, edge_index=edge_index), idx_map

def compute_kl_divergence(q_probs: torch.Tensor, prior_probs: torch.Tensor) -> torch.Tensor:
    eps = 1e-8
    q_probs = q_probs / (q_probs.sum() + eps)
    prior_probs = prior_probs / (prior_probs.sum() + eps)
    q_probs = torch.clamp(q_probs, eps, 1.0)
    prior_probs = torch.clamp(prior_probs, eps, 1.0)
    return (q_probs * (q_probs.log() - prior_probs.log())).sum()

def get_prior_distribution(model: nn.Module) -> torch.Tensor:
    weights = []
    target_params = get_target_parameters(model, TARGET_LAYERS)
    for _, param in target_params:
        weights.append(param.detach().flatten().abs().cpu())
    full = torch.cat(weights)
    return full / (full.sum() + 1e-8)

def apply_mask_to_model(model: nn.Module, mask: torch.Tensor, idx_map: Dict[str, int]) -> None:
    with torch.no_grad():
        param_idx = 0
        target_params = get_target_parameters(model, TARGET_LAYERS)
        for name, param in target_params:
            param_flat = param.detach().clone().view(-1)
            param_size = param_flat.shape[0]
            if param_idx + param_size <= len(mask):
                mask_slice = mask[param_idx:param_idx + param_size]
                masked_param = param_flat * (1 - mask_slice.to(param.device))
                param.data.copy_(masked_param.view(param.shape))
            param_idx += param_size

def safe_construct_graph(model: nn.Module, input_data: torch.Tensor, target: torch.Tensor, max_retries=3):
    for attempt in range(max_retries):
        try:
            return construct_graph_last_layers(model, input_data, target)
        except RuntimeError as e:
            if "out of memory" in str(e).lower() and attempt < max_retries - 1:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()
                continue
            raise e

def train_mask_generator(mask_net: GCNMaskGenerator, classifier: MNISTClassifier, dataset, device, prior_probs, config: MaskingConfig, soft_mask=False):
    mask_net.train()
    optimizer = optim.Adam(mask_net.parameters(), lr=config.learning_rate)
    temp_model = MNISTClassifier().to(device)

    epochs = config.epochs_soft if soft_mask else config.epochs_hard
    tag = "[Soft Mask]" if soft_mask else "[Mask-net]"

    for epoch in range(epochs):
        print(f"{tag} Epoch {epoch + 1}/{epochs}")
        total_loss = 0
        num_batches = min(config.batch_limit, len(dataset))

        for i in range(num_batches):
            try:
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
                gc.collect()

                x, y = dataset[i]
                x = x.unsqueeze(0).to(device)
                y = torch.tensor([y], device=device)

                graph, _ = safe_construct_graph(classifier, x, y)
                graph = graph.to(device)

                q_probs = mask_net(graph)

                if soft_mask:
                    # Soft mask uses q_probs directly as mask
                    mask = q_probs
                else:
                    # Hard mask: top-k mask based on delta
                    k = min(config.delta, len(q_probs))
                    topk = torch.topk(q_probs, k).indices
                    mask = torch.zeros_like(q_probs)
                    mask[topk] = 1

                with torch.no_grad():
                    temp_model.load_state_dict(classifier.state_dict())
                apply_mask_to_model(temp_model, mask, {})

                output = temp_model(x)
                loss_pred = F.cross_entropy(output, y)
                kl = compute_kl_divergence(q_probs, prior_probs.to(device))
                loss = loss_pred + config.kl_weight * kl

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(mask_net.parameters(), max_norm=config.max_grad_norm)
                optimizer.step()

                total_loss += loss.item()

                del x, y, graph, q_probs, mask, output, loss_pred, kl, loss

            except Exception as e:
                print(f"Error in batch {i}: {e}")
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
                gc.collect()
                continue

        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        print(f"{tag} Epoch {epoch + 1} avg-loss {avg_loss:.4f}")

        if device.type == 'cuda':
            torch.cuda.empty_cache()
        gc.collect()

def evaluate(model: nn.Module, loader: DataLoader, title: str = "Eval") -> float:
    model.eval()
    correct, total = 0, 0
    device = next(model.parameters()).device
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    accuracy = correct / total if total > 0 else 0
    print(f"{title}: Accuracy = {accuracy:.4f}")
    return accuracy

def main():
    config = MaskingConfig()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    if device.type == 'cuda':
        torch.cuda.empty_cache()

    transform = transforms.ToTensor()
    dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    target_idx = 0
    target_x, target_y = dataset[target_idx]
    retain_indices = [i for i in range(len(dataset)) if i != target_idx]
    retain_subset = Subset(dataset, retain_indices[:5000])
    retain_loader = DataLoader(retain_subset, batch_size=128, shuffle=True)

    # Train classifier longer
    print("Training classifier...")
    classifier = MNISTClassifier().to(device)
    optimizer = optim.Adam(classifier.parameters(), lr=0.001)
    classifier.train()
    for epoch in range(7):  # Increased epochs
        running_loss = 0.0
        for batch_idx, (data, target) in enumerate(retain_loader):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = classifier(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_loss = running_loss / len(retain_loader)
        print(f"Epoch {epoch+1}: Avg Loss = {avg_loss:.4f}")

    print("\nEvaluating before unlearning...")
    evaluate(classifier, retain_loader, title="Before Unlearning")

    prior_probs = get_prior_distribution(classifier)
    print(f"Prior distribution size: {len(prior_probs)}")

    print("\nTraining soft mask generator...")
    mask_net = GCNMaskGenerator(in_channels=4).to(device)
    train_mask_generator(mask_net, classifier, dataset, device, prior_probs, config, soft_mask=True)

    print("\nTraining hard mask generator...")
    train_mask_generator(mask_net, classifier, dataset, device, prior_probs, config, soft_mask=False)

    print("\nApplying final mask...")
    mask_net.eval()
    classifier.eval()

    try:
        target_x_batch = target_x.unsqueeze(0).to(device)
        target_y_batch = torch.tensor([target_y], device=device)

        classifier.train()
        data_graph, idx_map = safe_construct_graph(classifier, target_x_batch, target_y_batch)
        classifier.eval()

        q_probs = mask_net(data_graph.to(device))
        print(f"Generated {len(q_probs)} probability scores")

        k = config.delta
        while k <= config.max_k:
            topk_idx = torch.topk(q_probs, k).indices
            final_mask = torch.zeros_like(q_probs)
            final_mask[topk_idx] = 1
            apply_mask_to_model(classifier, final_mask, idx_map)

            with torch.no_grad():
                output = classifier(target_x_batch)
                conf = F.softmax(output, dim=1)[0, target_y].item()

            print(f"k={k}, confidence={conf:.4f}")
            if conf < config.confidence_threshold:
                break
            k += config.k_increment

        print("\nFinal Results:")
        evaluate(classifier, retain_loader, title="After Unlearning")

        with torch.no_grad():
            output = classifier(target_x_batch)
            pred = output.argmax(dim=1).item()
            conf = F.softmax(output, dim=1)[0, target_y].item()
        print(f"Target sample - Prediction: {pred}, Confidence: {conf:.4f}, Original: {target_y}")

        del data_graph, q_probs, final_mask, output
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    except Exception as e:
        print(f"Error during final masking: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()


Using device: cuda
Training classifier...
Epoch 1: Avg Loss = 1.2495
Epoch 2: Avg Loss = 0.3917
Epoch 3: Avg Loss = 0.2951
Epoch 4: Avg Loss = 0.2439
Epoch 5: Avg Loss = 0.2250
Epoch 6: Avg Loss = 0.1757
Epoch 7: Avg Loss = 0.1446

Evaluating before unlearning...
Before Unlearning: Accuracy = 0.9698
Prior distribution size: 31000

Training soft mask generator...
[Soft Mask] Epoch 1/5
[Soft Mask] Epoch 1 avg-loss 0.0529
[Soft Mask] Epoch 2/5
[Soft Mask] Epoch 2 avg-loss 0.0529
[Soft Mask] Epoch 3/5
[Soft Mask] Epoch 3 avg-loss 0.0529
[Soft Mask] Epoch 4/5
[Soft Mask] Epoch 4 avg-loss 0.0529
[Soft Mask] Epoch 5/5
[Soft Mask] Epoch 5 avg-loss 0.0529

Training hard mask generator...
[Mask-net] Epoch 1/5
[Mask-net] Epoch 1 avg-loss 1.4178
[Mask-net] Epoch 2/5
[Mask-net] Epoch 2 avg-loss 1.4176
[Mask-net] Epoch 3/5
[Mask-net] Epoch 3 avg-loss 1.4178
[Mask-net] Epoch 4/5
[Mask-net] Epoch 4 avg-loss 1.4173
[Mask-net] Epoch 5/5
[Mask-net] Epoch 5 avg-loss 1.4178

Applying final mask...
Generate