In [1]:
# label_corrupt_data.py
import os, random, time, csv, itertools
from collections import OrderedDict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

# ----------------------------
# 1. Model: simple 2-layer GCN
# ----------------------------
class GCN(nn.Module):
    def __init__(self, in_ch, hid_ch, out_ch, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(in_ch, hid_ch)
        self.conv2 = GCNConv(hid_ch, out_ch)
        self.dropout = dropout

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x  # raw logits

# ----------------------------
# 2. Utilities
# ----------------------------
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def corrupt_labels(y, num_classes, noise_rate=0.0):
    """
    Corrupts a portion of the labels with random incorrect labels.
    """
    if noise_rate == 0.0:
        return y
    
    y_corrupted = y.clone()
    num_nodes = y.size(0)
    
    # Get a list of indices to be corrupted
    corrupt_indices = torch.bernoulli(torch.full((num_nodes,), noise_rate)).bool()
    
    for i in range(num_nodes):
        if corrupt_indices[i]:
            # Generate a new random label
            new_label = torch.randint(0, num_classes, (1,), device=y.device)
            # Ensure the new label is different from the original
            while new_label == y[i]:
                new_label = torch.randint(0, num_classes, (1,), device=y.device)
            y_corrupted[i] = new_label
            
    return y_corrupted

# ----------------------------
# 3. Train one run with label noise
# ----------------------------
def train_one_run(dataset_name="Cora",
                  hidden=64,
                  dropout=0.5,
                  weight_decay=5e-4,
                  label_smoothing=0.0,
                  noise_rate=0.0,
                  lr=0.01,
                  max_epochs=300,
                  patience=50,
                  seed=0,
                  device=None,
                  verbose=False):

    set_seed(seed)
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    dataset = Planetoid(root=os.path.join("data", dataset_name), name=dataset_name)
    data = dataset[0].to(device)
    
    # Corrupt labels only for the training set
    y_train_corrupted = corrupt_labels(data.y[data.train_mask], dataset.num_classes, noise_rate=noise_rate)

    model = GCN(dataset.num_node_features, hidden, dataset.num_classes, dropout=dropout).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    best_val = -1.0
    best_test = 0.0
    best_epoch = -1
    epochs_no_improve = 0

    for epoch in range(1, max_epochs + 1):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], y_train_corrupted)
        loss.backward()
        optimizer.step()

        # evaluate with original labels
        model.eval()
        with torch.no_grad():
            out_eval = model(data.x, data.edge_index)
            pred = out_eval.argmax(dim=1)
            val_acc = (pred[data.val_mask] == data.y[data.val_mask]).float().mean().item()
            test_acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean().item()

        if val_acc > best_val:
            best_val = val_acc
            best_test = test_acc
            best_epoch = epoch
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if verbose and epoch % 50 == 0:
            print(f"Seed {seed} | epoch {epoch} | loss {loss.item():.4f} | "
                  f"val {val_acc:.4f} | test {test_acc:.4f}")

        if epochs_no_improve >= patience:
            break

    return {"best_val": best_val, "test_at_best": best_test, "best_epoch": best_epoch}

# ----------------------------
# 4. Grid search
# ----------------------------
def run_grid_search(
    seeds=(0,1,2),
    dropout_list=(0.5, 0.7, 0.9),
    label_smooth_list=(0.1, 0.2),
    weight_decay_list=(5e-4, 5e-3),
    hidden_list=(64,),
    noise_rate_list=(0.0, 0.1, 0.2),
    dataset_name="Cora",
    max_epochs=300,
    patience=50,
    out_csv="label_corruption_results.csv",
    device=None,
    verbose=False
):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    combos = list(itertools.product(dropout_list, label_smooth_list, weight_decay_list, hidden_list, noise_rate_list))
    print(f"Running grid: {len(combos)} configs × {len(seeds)} seeds = {len(combos)*len(seeds)} runs")

    # CSV header
    header = ["dropout", "label_smoothing", "weight_decay", "hidden", "noise_rate",
              "seed", "best_val", "test_at_best", "best_epoch"]
    with open(out_csv, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)

    summary = OrderedDict()
    for (dropout, lab_smooth, wd, hid, noise_rate) in combos:
        cfg_key = (dropout, lab_smooth, wd, hid, noise_rate)
        summary[cfg_key] = []

        for seed in seeds:
            res = train_one_run(
                dataset_name=dataset_name,
                hidden=hid,
                dropout=dropout,
                weight_decay=wd,
                label_smoothing=lab_smooth,
                noise_rate=noise_rate,
                lr=0.01,
                max_epochs=max_epochs,
                patience=patience,
                seed=seed,
                device=device,
                verbose=verbose
            )
            with open(out_csv, "a", newline="") as f:
                writer = csv.writer(f)
                writer.writerow([dropout, lab_smooth, wd, hid, noise_rate, seed,
                                 f"{res['best_val']:.6f}", f"{res['test_at_best']:.6f}",
                                 res["best_epoch"]])
            summary[cfg_key].append(res["test_at_best"])

        arr = np.array(summary[cfg_key])
        mean, std = arr.mean(), arr.std()
        print(f"cfg d={dropout}, s={lab_smooth}, wd={wd}, h={hid}, noise_rate={noise_rate} -> mean test {mean:.4f} ± {std:.4f}")

    rows = []
    for k, vals in summary.items():
        dropout, lab_smooth, wd, hid, noise_rate = k
        arr = np.array(vals)
        rows.append({
            "dropout": dropout,
            "label_smoothing": lab_smooth,
            "weight_decay": wd,
            "hidden": hid,
            "noise_rate": noise_rate,
            "mean_test": arr.mean(),
            "std_test": arr.std(),
            "n_runs": len(arr)
        })
    df = pd.DataFrame(rows)
    df = df.sort_values("mean_test", ascending=False).reset_index(drop=True)
    df.to_csv("label_corruption_summary.csv", index=False)
    print("\nSaved detailed runs to", out_csv)
    print("Saved summary to label_corruption_summary.csv")
    return df

# ----------------------------
# 5. Main entry
# ----------------------------
if __name__ == "__main__":
    seeds = (0, 1, 2)
    dropout_list = (0.5, 0.7, 0.9)
    label_smooth_list = (0.1, 0.2)
    weight_decay_list = (5e-4, 5e-3)
    hidden_list = (64,)
    noise_rate_list = (0.0, 0.1, 0.2, 0.3)

    df_summary = run_grid_search(
        seeds=seeds,
        dropout_list=dropout_list,
        label_smooth_list=label_smooth_list,
        weight_decay_list=weight_decay_list,
        hidden_list=hidden_list,
        noise_rate_list=noise_rate_list,
        dataset_name="Cora",
        max_epochs=300,
        patience=50,
        out_csv="label_corruption_runs.csv",
        verbose=False,
    )
    
    print("\n=== Top configs (by mean test acc) ===")
    print

  from .autonotebook import tqdm as notebook_tqdm


Running grid: 48 configs × 3 seeds = 144 runs
cfg d=0.5, s=0.1, wd=0.0005, h=64, noise_rate=0.0 -> mean test 0.8073 ± 0.0058
cfg d=0.5, s=0.1, wd=0.0005, h=64, noise_rate=0.1 -> mean test 0.7913 ± 0.0017
cfg d=0.5, s=0.1, wd=0.0005, h=64, noise_rate=0.2 -> mean test 0.7723 ± 0.0133
cfg d=0.5, s=0.1, wd=0.0005, h=64, noise_rate=0.3 -> mean test 0.7067 ± 0.0196
cfg d=0.5, s=0.1, wd=0.005, h=64, noise_rate=0.0 -> mean test 0.8207 ± 0.0034
cfg d=0.5, s=0.1, wd=0.005, h=64, noise_rate=0.1 -> mean test 0.8027 ± 0.0041
cfg d=0.5, s=0.1, wd=0.005, h=64, noise_rate=0.2 -> mean test 0.7800 ± 0.0120
cfg d=0.5, s=0.1, wd=0.005, h=64, noise_rate=0.3 -> mean test 0.7190 ± 0.0248
cfg d=0.5, s=0.2, wd=0.0005, h=64, noise_rate=0.0 -> mean test 0.8107 ± 0.0021
cfg d=0.5, s=0.2, wd=0.0005, h=64, noise_rate=0.1 -> mean test 0.7947 ± 0.0012
cfg d=0.5, s=0.2, wd=0.0005, h=64, noise_rate=0.2 -> mean test 0.7707 ± 0.0116
cfg d=0.5, s=0.2, wd=0.0005, h=64, noise_rate=0.3 -> mean test 0.7130 ± 0.0139
cfg d=0.5,