In [6]:
%pip install -q spuco torch torchvision scikit-learn tqdm matplotlib

In [7]:
import os, json, random, time
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

from spuco.datasets import SpuCoMNIST, SpuriousFeatureDifficulty
from spuco.models import model_factory
from spuco.robust_train import ERM
from spuco.evaluate import Evaluator

from torch.utils.data import DataLoader, Subset
from torch.optim import SGD

# -----------------
# Global config
# -----------------
DEFAULT_SEED = 1234
SCORING_BATCH_SIZE = 128
BASE_EPOCHS = 5            # epochs when keep_ratio=1.0
SCALE_EPOCHS_BY_KEEP = True
EARLY_CLUSTER_WARMUP_EPOCHS = 0.5  # fractional epochs for early cluster warmup

# Grids
DIFFICULTIES = [
    SpuriousFeatureDifficulty.MAGNITUDE_SMALL,
    SpuriousFeatureDifficulty.MAGNITUDE_MEDIUM,
    SpuriousFeatureDifficulty.MAGNITUDE_LARGE,
]
STRENGTHS = [0.9, 0.95, 0.995]
KEEP_RATIOS = [0.1, 0.3, 0.5, 0.7, 0.9]
HEURISTICS = [
    "base_line",
    "random",
    "loss",
    "gradnorm",
    "confident",
    "entropy",
    "forgetting",
    "early_cluster",
    "rmi",
]

LOG_DIR = "logs_spuco_grid"
os.makedirs(LOG_DIR, exist_ok=True)


def set_seeds(seed: int = DEFAULT_SEED):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seeds(DEFAULT_SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)



Using device: cuda


In [8]:
# -----------------
# Dataset builder
# -----------------

def build_datasets(difficulty, strength):
    classes = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
    trainset = SpuCoMNIST(
        root="data/dataset",
        split="train",
        classes=classes,
        spurious_feature_difficulty=difficulty,
        spurious_correlation_strength=strength,
    )
    trainset.initialize()
    testset = SpuCoMNIST(
        root="data/dataset",
        split="test",
        classes=classes,
        spurious_feature_difficulty=difficulty,
        spurious_correlation_strength=strength,
    )
    testset.initialize()
    return trainset, testset



In [9]:
# -----------------
# Heuristic utilities
# -----------------

@torch.no_grad()
def compute_losses(model, dataset, device):
    model.eval()
    losses = []
    loader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=2)
    crit = torch.nn.CrossEntropyLoss(reduction="none")
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss_vals = crit(logits, y)
        losses.extend(loss_vals.cpu().numpy())
    return np.array(losses)


def compute_gradnorms(model, dataset, device):
    model.eval()
    grad_norms = []
    crit = torch.nn.CrossEntropyLoss()
    for i in range(len(dataset)):
        x, y = dataset[i]
        x = x.unsqueeze(0).to(device)
        y = torch.tensor([y]).to(device)
        model.zero_grad()
        out = model(x)
        loss = crit(out, y)
        loss.backward()
        total_norm = 0
        for p in model.parameters():
            if p.grad is not None:
                total_norm += p.grad.detach().norm().item()
        grad_norms.append(total_norm)
    return np.array(grad_norms)


def compute_confidence(model, dataset, device):
    model.eval()
    conf = []
    loader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=2)
    softmax = torch.nn.Softmax(dim=1)
    with torch.no_grad():
        for x, _ in loader:
            x = x.to(device)
            p = softmax(model(x))
            max_conf = p.max(dim=1).values
            conf.extend(max_conf.cpu().numpy())
    return np.array(conf)


def compute_entropy(model, dataset, device):
    model.eval()
    ent = []
    log_softmax = torch.nn.LogSoftmax(dim=1)
    loader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=2)
    with torch.no_grad():
        for x, _ in loader:
            x = x.to(device)
            logits = model(x)
            logp = log_softmax(logits)
            p = torch.exp(logp)
            entropy_vals = -(p * logp).sum(dim=1)
            ent.extend(entropy_vals.cpu().numpy())
    return np.array(ent)


def compute_forgetting_events(model, dataset, device, warmup_epochs=8):
    n = len(dataset)
    seen_correct = np.zeros(n, dtype=bool)
    forget_count = np.zeros(n)
    loader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=2)
    crit = torch.nn.CrossEntropyLoss()
    opt = torch.optim.SGD(model.parameters(), lr=1e-2)
    for epoch in range(warmup_epochs):
        for batch_idx, (x, y) in enumerate(loader):
            x, y = x.to(device), y.to(device)
            model.train()
            opt.zero_grad()
            logits = model(x)
            loss = crit(logits, y)
            loss.backward()
            opt.step()
            pred = logits.argmax(dim=1)
            correct = (pred == y).cpu().numpy()
            start = batch_idx * loader.batch_size
            end = start + len(y)
            global_indices = np.arange(start, end)
            for i, idx in enumerate(global_indices):
                if seen_correct[idx] and not correct[i]:
                    forget_count[idx] += 1
                seen_correct[idx] = correct[i]
    return forget_count


@torch.no_grad()
def compute_representations(model, dataset, device, batch_size=256):
    model.eval()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    reps = []
    for x, _ in loader:
        x = x.to(device)
        f = model.forward(x)
        reps.append(f.cpu())
    reps = torch.cat(reps, dim=0).numpy()
    return reps


from sklearn.cluster import KMeans

def fast_craig_select(reps, budget):
    km = KMeans(n_clusters=budget, init='k-means++', n_init=1, max_iter=1, random_state=DEFAULT_SEED)
    km.fit(reps)
    centers = km.cluster_centers_
    idxs = []
    for c in centers:
        d = np.linalg.norm(reps - c, axis=1)
        idxs.append(np.argmin(d))
    return np.array(idxs, dtype=int)


@torch.no_grad()
def compute_feature_centroids(model, dataset, device):
    model.eval()
    loader = DataLoader(dataset, batch_size=SCORING_BATCH_SIZE, shuffle=False, num_workers=2)
    feats_by_class = {}
    for x, y in loader:
        x = x.to(device)
        f = model.forward(x).cpu().numpy()
        y = y.numpy()
        for fi, yi in zip(f, y):
            feats_by_class.setdefault(yi, []).append(fi)
    centroids = {c: np.mean(np.stack(v), axis=0) for c, v in feats_by_class.items()}
    return centroids


def compute_rmi(model, dataset, device):
    """Representation-Margin Influence heuristic."""
    model.eval()
    centroids = compute_feature_centroids(model, dataset, device)
    rmi_scores = []
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)
    crit = torch.nn.CrossEntropyLoss()

    for x, y in loader:
        x = x.to(device)
        y = y.to(device)

        model.zero_grad()
        logits = model(x)
        loss = crit(logits, y)
        loss.backward()

        grad_norm_sq = sum(
            (p.grad.detach().norm() ** 2 for p in model.parameters() if p.grad is not None)
        )
        grad_influence = grad_norm_sq.sqrt().item()

        f = model.forward(x).detach().cpu().numpy()[0]
        y_class = int(y.cpu().item())
        margins = []
        for c, mu in centroids.items():
            if c == y_class:
                continue
            margin = np.linalg.norm(f - mu)
            margins.append(margin)
        margin = min(margins) if margins else 1.0

        rmi_scores.append(grad_influence / margin)

    return np.array(rmi_scores)



In [10]:
# -----------------
# Early Loss Cluster warmup + selection + training/eval
# -----------------

def warmup_model_for_scoring(model, dataset, device,
                             epochs=5,
                             batch_size=SCORING_BATCH_SIZE,
                             lr=1e-2):
    model.train()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    num_steps = int(max(1, epochs * len(loader)))  # len(loader) ~ ceil(N / batch_size)
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True)
    crit = torch.nn.CrossEntropyLoss()
    step = 0
    while step < num_steps:
        for x, y in loader:
            if step >= num_steps:
                break
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            logits = model(x)
            loss = crit(logits, y)
            loss.backward()
            opt.step()
            step += 1
    model.eval()
    return model


def early_loss_cluster_prune(trainset, model, device, keep_ratio):
    warmup_model_for_scoring(model, trainset, device)
    losses = compute_losses(model, trainset, device)
    kmeans = KMeans(n_clusters=2, random_state=DEFAULT_SEED)
    labels = kmeans.fit_predict(losses.reshape(-1, 1))
    cluster_means = [losses[labels == k].mean() for k in range(2)]
    low_loss_cluster = np.argmin(cluster_means)
    high_loss_indices = np.where(labels != low_loss_cluster)[0]
    low_loss_indices = np.where(labels == low_loss_cluster)[0]

    topk = int(len(trainset) * keep_ratio)
    chosen = list(high_loss_indices)
    if len(chosen) < topk:
        need = topk - len(chosen)
        fill = np.random.choice(low_loss_indices, need, replace=False)
        chosen = np.concatenate([chosen, fill])
    else:
        chosen = np.random.choice(chosen, topk, replace=False)

    subset = Subset(trainset, chosen)
    subset = attach_subset_metadata(subset, trainset)
    print(f"[early_cluster] keep_ratio={keep_ratio}, target={topk}, high_loss={len(high_loss_indices)}, low_loss={len(low_loss_indices)}, selected={len(chosen)}")
    return subset


def attach_subset_metadata(subset, original_dataset):
    for attr in ["group_weights", "group_partition", "groups"]:
        if hasattr(original_dataset, attr):
            setattr(subset, attr, getattr(original_dataset, attr))
    return subset


def select_data_heuristic(trainset, model, heuristic, keep_ratio):
    n = len(trainset)

    if heuristic == "random":
        idx = np.random.choice(n, int(n * keep_ratio), replace=False)

    elif heuristic == "loss":
        losses = compute_losses(model, trainset, device)
        topk = int(n * keep_ratio)
        idx = np.argsort(losses)[-topk:]

    elif heuristic == "gradnorm":
        gradnorms = compute_gradnorms(model, trainset, device)
        topk = int(n * keep_ratio)
        idx = np.argsort(gradnorms)[-topk:]

    elif heuristic == "confident":
        conf = compute_confidence(model, trainset, device)
        topk = int(n * keep_ratio)
        idx = np.argsort(conf)[:topk]

    elif heuristic == "entropy":
        ent = compute_entropy(model, trainset, device)
        topk = int(n * keep_ratio)
        idx = np.argsort(ent)[-topk:]

    elif heuristic == "forgetting":
        forgets = compute_forgetting_events(model, trainset, device)
        topk = int(n * keep_ratio)
        idx = np.argsort(forgets)[-topk:]

    elif heuristic == "early_cluster":
        return early_loss_cluster_prune(trainset, model, device, keep_ratio)

    elif heuristic == "rmi":
        scores = compute_rmi(model, trainset, device)
        topk = int(len(trainset) * keep_ratio)
        idx = np.argsort(scores)[-topk:]

    elif heuristic == "craig":
        reps = compute_representations(model, trainset, device)
        budget = int(n * keep_ratio)
        idx = fast_craig_select(reps, budget)

    else:
        raise ValueError(f"Unknown heuristic: {heuristic}")

    subset = Subset(trainset, idx)
    subset = attach_subset_metadata(subset, trainset)
    return subset


def train_and_eval(train_subset, testset, name, keep_ratio):
    epochs = BASE_EPOCHS
    if SCALE_EPOCHS_BY_KEEP and keep_ratio > 0:
        epochs = max(1, int(round(BASE_EPOCHS / keep_ratio)))

    model = model_factory("lenet", train_subset[0][0].shape, train_subset.dataset.num_classes).to(device)
    optimizer = SGD(model.parameters(), lr=1e-2, momentum=0.9, nesterov=True)

    erm = ERM(
        trainset=train_subset,
        model=model,
        num_epochs=epochs,
        batch_size=128,
        optimizer=optimizer,
        device=device,
        verbose=False,
    )
    erm.train()

    evaluator = Evaluator(
        testset=testset,
        group_partition=testset.group_partition,
        group_weights=testset.group_weights,
        batch_size=64,
        model=model,
        device=device,
        verbose=False,
    )
    evaluator.evaluate()

    return evaluator.worst_group_accuracy, evaluator.average_accuracy



In [11]:
# -----------------
# Logging utilities
# -----------------
import csv

def init_csv(log_path):
    header = ["seed", "difficulty", "strength", "keep_ratio", "heuristic",
              "worst_group_acc", "average_acc", "timestamp"]
    if not os.path.exists(log_path):
        with open(log_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(header)


def append_csv(log_path, row):
    with open(log_path, "a", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(row)


def save_barplot(worst_dict, avg_dict, out_path_prefix):
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    axes[0].bar(list(worst_dict.keys()), list(worst_dict.values()), color='skyblue')
    axes[0].set_ylabel("Worst Group Accuracy")
    axes[0].set_title("Worst Group Accuracy Across Heuristics")
    axes[0].tick_params(axis='x', rotation=45)

    axes[1].bar(list(avg_dict.keys()), list(avg_dict.values()), color='salmon')
    axes[1].set_ylabel("Average Accuracy")
    axes[1].set_title("Average Accuracy Across Heuristics")
    axes[1].tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.savefig(out_path_prefix + "_bars.png", dpi=200)
    plt.close(fig)


def save_text_log(text, out_path):
    with open(out_path, "w") as f:
        f.write(text)



In [12]:
# -----------------
# Experiment loop (dry-run ready)
# -----------------

log_csv = os.path.join(LOG_DIR, "results.csv")
init_csv(log_csv)

results = []

for difficulty in DIFFICULTIES:
    for strength in STRENGTHS:
        print(f"\n=== Difficulty={difficulty}, Strength={strength} ===")
        trainset, testset = build_datasets(difficulty, strength)

        # Baseline (full data)
        model_lenet = model_factory("lenet", trainset[0][0].shape, trainset.num_classes).to(device)
        optimizer = SGD(model_lenet.parameters(), lr=1e-2, momentum=0.9, nesterov=True)
        erm_full = ERM(
            trainset=trainset,
            model=model_lenet,
            num_epochs=BASE_EPOCHS,
            batch_size=128,
            optimizer=optimizer,
            device=device,
            verbose=False,
        )
        erm_full.train()

        evaluator = Evaluator(
            testset=testset,
            group_partition=testset.group_partition,
            group_weights=testset.group_weights,
            batch_size=64,
            model=model_lenet,
            device=device,
            verbose=False,
        )
        evaluator.evaluate()
        baseline_wg = evaluator.worst_group_accuracy
        baseline_avg = evaluator.average_accuracy

        baseline_row = {
            "seed": DEFAULT_SEED,
            "difficulty": str(difficulty),
            "strength": strength,
            "keep_ratio": 1.0,
            "heuristic": "base_line",
            "worst_group_acc": baseline_wg[1] if isinstance(baseline_wg, tuple) else baseline_wg,
            "average_acc": baseline_avg,
        }

        results.append(baseline_row)
        append_csv(log_csv, [DEFAULT_SEED, difficulty, strength, 1.0, "base_line", baseline_row["worst_group_acc"], baseline_avg, time.time()])

        # Precompute scoring for forgetting / early_cluster once per (difficulty, strength)
        forgetting_scores = None
        early_losses = None
        early_kmeans = None
        high_loss_indices = None
        low_loss_indices = None

        # Scoring model for forward-only heuristics: reuse trained baseline
        scoring_model_forward = model_lenet

        # For forgetting: train a small warmup model once, compute forgetting counts once
        forget_model = model_factory("lenet", trainset[0][0].shape, trainset.num_classes).to(device)
        forgetting_scores = compute_forgetting_events(forget_model, trainset, device)

        # For early_cluster: warm up once and cache losses/cluster splits
        early_model = model_factory("lenet", trainset[0][0].shape, trainset.num_classes).to(device)
        warmup_model_for_scoring(early_model, trainset, device)
        early_losses = compute_losses(early_model, trainset, device)
        early_kmeans = KMeans(n_clusters=2, random_state=DEFAULT_SEED)
        labels = early_kmeans.fit_predict(early_losses.reshape(-1, 1))
        cluster_means = [early_losses[labels == k].mean() for k in range(2)]
        low_loss_cluster = np.argmin(cluster_means)
        high_loss_indices = np.where(labels != low_loss_cluster)[0]
        low_loss_indices = np.where(labels == low_loss_cluster)[0]

        def select_early_cached(trainset, keep_ratio):
            topk = int(len(trainset) * keep_ratio)
            chosen = list(high_loss_indices)
            if len(chosen) < topk:
                need = topk - len(chosen)
                fill = np.random.choice(low_loss_indices, need, replace=False)
                chosen = np.concatenate([chosen, fill])
            else:
                chosen = np.random.choice(chosen, topk, replace=False)
            subset = Subset(trainset, chosen)
            subset = attach_subset_metadata(subset, trainset)
            print(f"[early_cluster cached] keep_ratio={keep_ratio}, target={topk}, high_loss={len(high_loss_indices)}, low_loss={len(low_loss_indices)}, selected={len(chosen)}")
            return subset

        def select_forgetting_cached(trainset, keep_ratio):
            topk = int(len(trainset) * keep_ratio)
            idx = np.argsort(forgetting_scores)[-topk:]
            subset = Subset(trainset, idx)
            subset = attach_subset_metadata(subset, trainset)
            return subset

        # For each keep_ratio and heuristic
        for keep_ratio in KEEP_RATIOS:
            print(f"-- keep_ratio={keep_ratio} --")
            for heuristic in HEURISTICS:
                print(f"   Heuristic: {heuristic}")
                if heuristic == "base_line":
                    continue

                if heuristic == "forgetting":
                    subset = select_forgetting_cached(trainset, keep_ratio)
                    scoring_model = scoring_model_forward  # irrelevant; subset already chosen
                elif heuristic == "early_cluster":
                    subset = select_early_cached(trainset, keep_ratio)
                    scoring_model = scoring_model_forward
                else:
                    scoring_model = scoring_model_forward
                    subset = select_data_heuristic(trainset, scoring_model, heuristic, keep_ratio)

                wg_acc, avg_acc = train_and_eval(subset, testset, heuristic, keep_ratio)
                print(f"      Done. Worst group: {wg_acc}, Avg: {avg_acc}")

                row = {
                    "seed": DEFAULT_SEED,
                    "difficulty": str(difficulty),
                    "strength": strength,
                    "keep_ratio": keep_ratio,
                    "heuristic": heuristic,
                    "worst_group_acc": wg_acc[1] if isinstance(wg_acc, tuple) else wg_acc,
                    "average_acc": avg_acc,
                }
                results.append(row)
                append_csv(log_csv, [DEFAULT_SEED, difficulty, strength, keep_ratio, heuristic, row["worst_group_acc"], avg_acc, time.time()])

            # Per-config plots
            worst_dict = {}
            avg_dict = {}
            for r in results:
                if r["difficulty"] == str(difficulty) and r["strength"] == strength and r["keep_ratio"] == keep_ratio:
                    worst_dict[r["heuristic"]] = r["worst_group_acc"]
                    avg_dict[r["heuristic"]] = r["average_acc"]
            out_prefix = os.path.join(LOG_DIR, f"diff_{difficulty}_str_{strength}_kr_{keep_ratio}")
            save_barplot(worst_dict, avg_dict, out_prefix)

# Save a JSON summary
with open(os.path.join(LOG_DIR, "results.json"), "w") as f:
    json.dump(results, f, indent=2)

print("Done. Logs saved to", LOG_DIR)




=== Difficulty=SpuriousFeatureDifficulty.MAGNITUDE_SMALL, Strength=0.9 ===


100%|██████████| 9.91M/9.91M [00:00<00:00, 17.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 494kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.47MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 11.5MB/s]
Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.53it/s]


-- keep_ratio=0.1 --
   Heuristic: base_line
   Heuristic: random


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:08<00:00,  3.11it/s]

      Done. Worst group: ((4, 2), 92.9471032745592), Avg: 97.08999999999999
   Heuristic: loss



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:06<00:00,  3.95it/s]


      Done. Worst group: ((4, 2), 82.36775818639798), Avg: 97.39
   Heuristic: gradnorm


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.50it/s]

      Done. Worst group: ((4, 2), 61.20906801007557), Avg: 94.39000000000001
   Heuristic: confident



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:06<00:00,  3.85it/s]

      Done. Worst group: ((4, 2), 71.28463476070529), Avg: 96.56000000000002
   Heuristic: entropy



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.34it/s]

      Done. Worst group: ((1, 2), 89.2156862745098), Avg: 96.76000000000002
   Heuristic: forgetting



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.29it/s]

      Done. Worst group: ((4, 2), 88.66498740554157), Avg: 95.18
   Heuristic: early_cluster
[early_cluster cached] keep_ratio=0.1, target=4800, high_loss=311, low_loss=47693, selected=4800



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:06<00:00,  3.73it/s]

      Done. Worst group: ((2, 4), 92.24598930481284), Avg: 97.00999999999999
   Heuristic: rmi



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.45it/s]


      Done. Worst group: ((4, 2), 70.52896725440806), Avg: 94.08999999999997
-- keep_ratio=0.3 --
   Heuristic: base_line
   Heuristic: random


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.25it/s]

      Done. Worst group: ((2, 4), 96.2566844919786), Avg: 98.10999999999999
   Heuristic: loss



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.16it/s]


      Done. Worst group: ((4, 2), 94.20654911838791), Avg: 98.33
   Heuristic: gradnorm


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.55it/s]

      Done. Worst group: ((3, 3), 97.48110831234257), Avg: 98.92
   Heuristic: confident



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.53it/s]

      Done. Worst group: ((4, 2), 88.66498740554157), Avg: 98.05000000000001
   Heuristic: entropy



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:06<00:00,  3.59it/s]

      Done. Worst group: ((2, 4), 91.71122994652407), Avg: 98.27
   Heuristic: forgetting



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.56it/s]

      Done. Worst group: ((4, 2), 89.16876574307305), Avg: 97.66
   Heuristic: early_cluster
[early_cluster cached] keep_ratio=0.3, target=14401, high_loss=311, low_loss=47693, selected=14401



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.56it/s]

      Done. Worst group: ((3, 1), 90.17632241813602), Avg: 97.17999999999999
   Heuristic: rmi



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.24it/s]


      Done. Worst group: ((4, 2), 83.1234256926952), Avg: 97.78999999999999
-- keep_ratio=0.5 --
   Heuristic: base_line
   Heuristic: random


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.24it/s]

      Done. Worst group: ((4, 4), 96.46464646464646), Avg: 98.21
   Heuristic: loss



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.23it/s]


      Done. Worst group: ((4, 4), 96.96969696969697), Avg: 98.68999999999997
   Heuristic: gradnorm


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.21it/s]

      Done. Worst group: ((2, 4), 96.524064171123), Avg: 98.75000000000003
   Heuristic: confident



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.24it/s]

      Done. Worst group: ((2, 4), 96.524064171123), Avg: 98.53
   Heuristic: entropy



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.57it/s]

      Done. Worst group: ((4, 2), 88.41309823677582), Avg: 97.97999999999998
   Heuristic: forgetting



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.21it/s]

      Done. Worst group: ((2, 4), 93.58288770053476), Avg: 98.24
   Heuristic: early_cluster
[early_cluster cached] keep_ratio=0.5, target=24002, high_loss=311, low_loss=47693, selected=24002



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.21it/s]

      Done. Worst group: ((3, 3), 95.21410579345088), Avg: 98.13999999999999
   Heuristic: rmi



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.39it/s]


      Done. Worst group: ((3, 3), 96.72544080604534), Avg: 98.78999999999999
-- keep_ratio=0.7 --
   Heuristic: base_line
   Heuristic: random


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.28it/s]

      Done. Worst group: ((2, 4), 95.45454545454545), Avg: 98.08999999999999
   Heuristic: loss



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:08<00:00,  3.12it/s]


      Done. Worst group: ((2, 4), 96.79144385026738), Avg: 98.84
   Heuristic: gradnorm


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.13it/s]

      Done. Worst group: ((2, 4), 94.6524064171123), Avg: 98.49000000000001
   Heuristic: confident



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:08<00:00,  3.12it/s]

      Done. Worst group: ((4, 4), 96.96969696969697), Avg: 98.65000000000002
   Heuristic: entropy



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.23it/s]

      Done. Worst group: ((4, 2), 96.4735516372796), Avg: 98.63
   Heuristic: forgetting



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.31it/s]

      Done. Worst group: ((2, 4), 87.16577540106952), Avg: 97.75000000000001
   Heuristic: early_cluster
[early_cluster cached] keep_ratio=0.7, target=33602, high_loss=311, low_loss=47693, selected=33602



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:08<00:00,  3.12it/s]

      Done. Worst group: ((3, 3), 95.96977329974811), Avg: 98.39999999999998
   Heuristic: rmi



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.47it/s]


      Done. Worst group: ((2, 4), 95.72192513368984), Avg: 98.54000000000002
-- keep_ratio=0.9 --
   Heuristic: base_line
   Heuristic: random


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.15it/s]

      Done. Worst group: ((2, 4), 93.85026737967914), Avg: 98.01
   Heuristic: loss



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.15it/s]


      Done. Worst group: ((2, 4), 96.524064171123), Avg: 98.63000000000001
   Heuristic: gradnorm


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.16it/s]

      Done. Worst group: ((4, 2), 94.7103274559194), Avg: 98.05
   Heuristic: confident



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.20it/s]

      Done. Worst group: ((4, 2), 96.97732997481108), Avg: 98.69000000000003
   Heuristic: entropy



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.49it/s]

      Done. Worst group: ((2, 4), 81.01604278074866), Avg: 97.29
   Heuristic: forgetting



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.37it/s]

      Done. Worst group: ((4, 4), 95.20202020202021), Avg: 98.36999999999996
   Heuristic: early_cluster
[early_cluster cached] keep_ratio=0.9, target=43203, high_loss=311, low_loss=47693, selected=43203



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.16it/s]

      Done. Worst group: ((3, 1), 95.71788413098237), Avg: 98.14
   Heuristic: rmi



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:07<00:00,  3.47it/s]


      Done. Worst group: ((2, 4), 96.2566844919786), Avg: 98.53999999999999

=== Difficulty=SpuriousFeatureDifficulty.MAGNITUDE_SMALL, Strength=0.95 ===


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:08<00:00,  2.80it/s]


-- keep_ratio=0.1 --
   Heuristic: base_line
   Heuristic: random


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:08<00:00,  2.80it/s]

      Done. Worst group: ((2, 1), 94.93333333333334), Avg: 97.11
   Heuristic: loss



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.62it/s]


      Done. Worst group: ((2, 4), 52.94117647058823), Avg: 93.91000000000003
   Heuristic: gradnorm


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.62it/s]

      Done. Worst group: ((2, 4), 28.609625668449198), Avg: 89.23999999999998
   Heuristic: confident



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.60it/s]

      Done. Worst group: ((2, 4), 79.14438502673796), Avg: 97.2
   Heuristic: entropy



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:08<00:00,  2.82it/s]

      Done. Worst group: ((4, 1), 45.84382871536524), Avg: 90.45000000000003
   Heuristic: forgetting



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.61it/s]

      Done. Worst group: ((2, 3), 86.13333333333334), Avg: 95.17999999999999
   Heuristic: early_cluster
[early_cluster cached] keep_ratio=0.1, target=4800, high_loss=345, low_loss=47659, selected=4800



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.62it/s]

      Done. Worst group: ((4, 0), 93.70277078085643), Avg: 97.18000000000002
   Heuristic: rmi



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.59it/s]


      Done. Worst group: ((2, 4), 39.839572192513366), Avg: 93.31
-- keep_ratio=0.3 --
   Heuristic: base_line
   Heuristic: random


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.61it/s]

      Done. Worst group: ((4, 0), 94.20654911838791), Avg: 97.82000000000001
   Heuristic: loss



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.55it/s]


      Done. Worst group: ((2, 4), 92.51336898395722), Avg: 98.53000000000003
   Heuristic: gradnorm


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.58it/s]

      Done. Worst group: ((4, 3), 89.39393939393939), Avg: 96.99999999999999
   Heuristic: confident



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.55it/s]

      Done. Worst group: ((2, 4), 69.25133689839572), Avg: 95.84
   Heuristic: entropy



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.66it/s]

      Done. Worst group: ((4, 0), 93.95465994962217), Avg: 98.24000000000001
   Heuristic: forgetting



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.52it/s]

      Done. Worst group: ((2, 2), 95.73333333333333), Avg: 98.32999999999998
   Heuristic: early_cluster
[early_cluster cached] keep_ratio=0.3, target=14401, high_loss=345, low_loss=47659, selected=14401



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:10<00:00,  2.48it/s]

      Done. Worst group: ((4, 2), 65.74307304785894), Avg: 95.59999999999998
   Heuristic: rmi



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.53it/s]


      Done. Worst group: ((2, 4), 69.5187165775401), Avg: 96.49
-- keep_ratio=0.5 --
   Heuristic: base_line
   Heuristic: random


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.51it/s]

      Done. Worst group: ((3, 4), 95.71788413098237), Avg: 97.97000000000001
   Heuristic: loss



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:10<00:00,  2.46it/s]


      Done. Worst group: ((4, 0), 95.46599496221663), Avg: 98.41000000000001
   Heuristic: gradnorm


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:10<00:00,  2.47it/s]

      Done. Worst group: ((2, 4), 88.23529411764706), Avg: 97.12
   Heuristic: confident



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:10<00:00,  2.44it/s]

      Done. Worst group: ((2, 4), 89.3048128342246), Avg: 97.75000000000001
   Heuristic: entropy



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:10<00:00,  2.46it/s]

      Done. Worst group: ((2, 4), 77.27272727272727), Avg: 96.77000000000001
   Heuristic: forgetting



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.53it/s]

      Done. Worst group: ((2, 2), 94.66666666666667), Avg: 98.35000000000001
   Heuristic: early_cluster
[early_cluster cached] keep_ratio=0.5, target=24002, high_loss=345, low_loss=47659, selected=24002



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.57it/s]

      Done. Worst group: ((4, 2), 77.32997481108312), Avg: 96.00999999999998
   Heuristic: rmi



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:10<00:00,  2.47it/s]


      Done. Worst group: ((2, 4), 91.97860962566845), Avg: 97.77999999999999
-- keep_ratio=0.7 --
   Heuristic: base_line
   Heuristic: random


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:10<00:00,  2.44it/s]

      Done. Worst group: ((3, 3), 97.22921914357683), Avg: 98.37999999999998
   Heuristic: loss



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:10<00:00,  2.44it/s]


      Done. Worst group: ((0, 4), 91.725768321513), Avg: 97.49999999999999
   Heuristic: gradnorm


Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.54it/s]

      Done. Worst group: ((4, 2), 69.52141057934509), Avg: 96.7
   Heuristic: confident



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.65it/s]

      Done. Worst group: ((2, 4), 89.03743315508021), Avg: 97.67999999999999
   Heuristic: entropy



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.61it/s]

      Done. Worst group: ((4, 3), 97.22222222222223), Avg: 98.78
   Heuristic: forgetting



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.66it/s]

      Done. Worst group: ((2, 1), 95.46666666666667), Avg: 98.36999999999999
   Heuristic: early_cluster
[early_cluster cached] keep_ratio=0.7, target=33602, high_loss=345, low_loss=47659, selected=33602



Evaluating group-wise accuracy: 100%|██████████| 25/25 [00:09<00:00,  2.51it/s]

      Done. Worst group: ((2, 4), 85.8288770053476), Avg: 97.5
   Heuristic: rmi





KeyboardInterrupt: 