<a href="https://colab.research.google.com/github/taweener11/darkSideUnmasked/blob/main/5_28_intersectional_fixed_demogpairs_dataset_extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title setting up reproducible pipeline that uses our shared colab folder

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
!unzip -q "/content/drive/My Drive/Datasets/demogpairs/DemogPairs.zip" -d "/content/demogpairs/"

In [7]:
#@title defining path


import os


ROOT_DIR = '/content'
os.makedirs(ROOT_DIR, exist_ok=True)
DEMOGPAIRS_FOLDER = os.path.join(ROOT_DIR, 'demogpairs')

In [8]:
def read_metadata_file(filepath, gender_label, race_label):
    """
    Read a DemogPairs metadata txt file and collect image paths with labels.

    Args:
      filepath (str): path to the metadata txt file
      gender_label (int): 0 for female, 1 for male
      race_label (str): string label for race, e.g. 'black', 'white', 'asian'

    Returns:
      List of tuples: (image_relative_path, gender_label, race_label)
    """
    samples = []
    with open(filepath, 'r') as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip()
            if not line or line.lower().startswith('db_code'):
                continue
            parts = line.split()
            if len(parts) < 2:
                continue
            # parts[1] is the image path relative to DemogPairs folder
            img_path = parts[1]
            samples.append((img_path, gender_label, race_label))
    return samples


In [12]:
import os

metadata_dir = '/content/demogpairs/Metadata'  # adjust path

# map filenames to gender and race labels
metadata_info = {
    'Black_Females.txt': (0, 'black'),
    'Black_Males.txt': (1, 'black'),
    'White_Females.txt': (0, 'white'),
    'White_Males.txt': (1, 'white'),
    'Asian_Females.txt': (0, 'asian'),
    'Asian_Males.txt': (1, 'asian')
}

all_samples = []

for fname, (gender, race) in metadata_info.items():
    full_path = os.path.join(metadata_dir, fname)
    print(f"Reading {full_path} ...")
    samples = read_metadata_file(full_path, gender, race)
    all_samples.extend(samples)

print(f"Total samples loaded: {len(all_samples)}")

Reading /content/demogpairs/Metadata/Black_Females.txt ...
Reading /content/demogpairs/Metadata/Black_Males.txt ...
Reading /content/demogpairs/Metadata/White_Females.txt ...
Reading /content/demogpairs/Metadata/White_Males.txt ...
Reading /content/demogpairs/Metadata/Asian_Females.txt ...
Reading /content/demogpairs/Metadata/Asian_Males.txt ...
Total samples loaded: 10800


In [11]:
demogpairs_root = '/content/demogpairs/DemogPairs'

# build final dataset list with full paths
dataset = []
for rel_path, gender, race in all_samples:
    img_full_path = os.path.join(demogpairs_root, rel_path)
    if os.path.isfile(img_full_path):
        dataset.append((img_full_path, gender, race))
    else:
        print(f"Missing file: {img_full_path}")

print(f"Final dataset size after filtering missing files: {len(dataset)}")

Final dataset size after filtering missing files: 10800


In [None]:
import numpy as np

gender_labels = np.array([s[1] for s in dataset])
race_labels = np.array([s[2] for s in dataset])

unique_races = np.unique(race_labels)
balanced_indices = []

rng = np.random.default_rng(seed=42)

# Find minimal count per race-gender group
min_count = min(
    np.sum((gender_labels == gender) & (race_labels == race))
    for race in unique_races
    for gender in [0,1]
)
print(f"Balancing all race-gender groups to {min_count} samples each")

for race in unique_races:
    for gender in [0,1]:
        group_indices = np.where((race_labels == race) & (gender_labels == gender))[0]
        rng.shuffle(group_indices)
        balanced_indices.extend(group_indices[:min_count])

balanced_dataset = [dataset[i] for i in balanced_indices]
print(f"Balanced dataset size (race+gender): {len(balanced_dataset)}")

Balancing all race-gender groups to 1800 samples each
Balanced dataset size (race+gender): 10800


In [None]:
import numpy as np

# Example: training indices subset (can be full train, or a subset)
train_indices = np.arange(len(dataset))  # or your train split indices

# Extract gender and race array for those indices
gender_labels = np.array([dataset[i][1] for i in train_indices])  # 0=female, 1=male
race_labels = np.array([dataset[i][2] for i in train_indices])    # e.g. 'black', 'white', 'asian'

unique_races = np.unique(race_labels)
print(f"Unique races in training data: {unique_races}")

# Group indices by race and gender
indices_by_race_gender = {}
for race in unique_races:
    for gender in [0, 1]:
        mask = (race_labels == race) & (gender_labels == gender)
        group_indices = train_indices[mask]
        indices_by_race_gender[(race, gender)] = group_indices
        print(f"Count for {race} {'female' if gender == 0 else 'male'}: {len(group_indices)}")

# Find minimal count (for balanced sampling across all race+gender groups)
min_count = min(len(idxs) for idxs in indices_by_race_gender.values())
print(f"Balancing all race-gender groups to {min_count} samples each")

rng = np.random.default_rng(seed=42)

train_subsets_f = {}
train_subsets_m = {}
train_subsets = {}

subset_sizes = [100, 500, 1000]  # example subset sizes total (must be divisible by number of groups)

# Number of race-gender groups (race groups * 2 genders)
num_groups = len(unique_races) * 2

for size in subset_sizes:
    # Adjust size to nearest lower multiple of num_groups
    if size % num_groups != 0:
        adjusted_size = (size // num_groups) * num_groups
        print(f"Adjusting subset size {size} -> {adjusted_size} for divisibility by {num_groups}")
        size = adjusted_size

    per_group_n = size // num_groups
    balanced_indices = []

    for (race, gender), group_indices in indices_by_race_gender.items():
        if len(group_indices) < per_group_n:
            raise ValueError(f"Group {race} {gender} has fewer samples ({len(group_indices)}) than requested {per_group_n}")
        shuffled = np.copy(group_indices)
        rng.shuffle(shuffled)
        balanced_indices.extend(shuffled[:per_group_n])

        if gender == 0:
            train_subsets_f.setdefault(size, []).extend(shuffled[:per_group_n])
        else:
            train_subsets_m.setdefault(size, []).extend(shuffled[:per_group_n])

    balanced_indices = np.array(balanced_indices)
    rng.shuffle(balanced_indices)

    train_subsets[size] = balanced_indices

print(f"Females: {len(train_subsets_f[size])}, Males: {len(train_subsets_m[size])}")

Unique races in training data: ['asian' 'black' 'white']
Count for asian female: 1800
Count for asian male: 1800
Count for black female: 1800
Count for black male: 1800
Count for white female: 1800
Count for white male: 1800
Balancing all race-gender groups to 1800 samples each
Adjusting subset size 100 -> 96 for divisibility by 6
Adjusting subset size 500 -> 498 for divisibility by 6
Adjusting subset size 1000 -> 996 for divisibility by 6
Females: 498, Males: 498


In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

# Assuming dataset: list of (img_path, gender, race)
dataset_gender_labels = np.array([s[1] for s in dataset])
dataset_race_labels = np.array([s[2] for s in dataset])

full_indices = np.arange(len(dataset))

# Stratify on gender for train/test split (change as needed)
train_indices, test_indices = train_test_split(
    full_indices, test_size=0.2, random_state=42, stratify=dataset_gender_labels
)
print(f"Train samples: {len(train_indices)}, Test samples: {len(test_indices)}")

proportions = [0.25, 0.5, 0.75]
unique_races = np.unique(dataset_race_labels)
rng = np.random.default_rng(seed=42)

# ========== Group 1: Gender partitions (all races combined) ==========
gender_train_subsets = {}
gender_test_subsets_f = {}
gender_test_subsets_m = {}

gender_train_labels = dataset_gender_labels[train_indices]
gender_test_labels = dataset_gender_labels[test_indices]

for p in proportions:
    female_idx_train = np.where(gender_train_labels == 0)[0]
    male_idx_train = np.where(gender_train_labels == 1)[0]
    N_train = min(len(female_idx_train), len(male_idx_train))

    rng.shuffle(female_idx_train)
    rng.shuffle(male_idx_train)

    num_female_train = int(N_train * p)
    num_male_train = N_train - num_female_train

    female_chosen = train_indices[female_idx_train[:num_female_train]]
    male_chosen = train_indices[male_idx_train[:num_male_train]]

    combined_train = np.concatenate([female_chosen, male_chosen])
    rng.shuffle(combined_train)

    gender_train_subsets[p] = Subset(dataset, combined_train)

    # Balanced test set: half female, half male
    female_idx_test = np.where(gender_test_labels == 0)[0]
    male_idx_test = np.where(gender_test_labels == 1)[0]
    N_test = min(len(female_idx_test), len(male_idx_test))

    half_test = N_test // 2
    female_test_chosen = test_indices[female_idx_test[:half_test]]
    male_test_chosen = test_indices[male_idx_test[:half_test]]

    gender_test_subsets_f[p] = Subset(dataset, female_test_chosen)
    gender_test_subsets_m[p] = Subset(dataset, male_test_chosen)

# ========== Group 2: Race partitions (all genders combined) ==========
race_train_subsets = {}
race_test_subsets = {}

test_race_indices = {
    race: test_indices[dataset_race_labels[test_indices] == race] for race in unique_races
}

for race in unique_races:
    race_mask_train = (dataset_race_labels[train_indices] == race)
    race_train_indices = train_indices[race_mask_train]
    nonrace_mask_train = ~race_mask_train
    nonrace_train_indices = train_indices[nonrace_mask_train]

    total_race = len(race_train_indices)
    total_nonrace = len(nonrace_train_indices)

    if total_race == 0 or total_nonrace == 0:
        print(f"Warning: insufficient samples for race {race}, skipping.")
        continue

    for p in proportions:
        num_race_samples = int(total_race * p)
        # Number of non-race samples to keep proportional (approximate)
        num_nonrace_samples = (
            min(total_nonrace, int(num_race_samples * (1 - p) / p)) if p > 0 else total_nonrace
        )

        rng.shuffle(race_train_indices)
        rng.shuffle(nonrace_train_indices)

        chosen_race = race_train_indices[:num_race_samples]
        chosen_nonrace = nonrace_train_indices[:num_nonrace_samples]

        combined = np.concatenate([chosen_race, chosen_nonrace])
        rng.shuffle(combined)

        race_train_subsets[(race, p)] = Subset(dataset, combined)
        race_test_subsets[race] = Subset(dataset, test_race_indices[race])

# ========== Group 3: Intersectional partitions (race + gender) ==========
intersection_train_subsets = {}
intersection_test_subsets = {}

for p in proportions:
    train_indices_per_group = []
    test_indices_per_group = []

    for race in unique_races:
        for gender in [0, 1]:  # 0: female, 1: male
            # Mask for train intersection group
            train_mask = (dataset_race_labels[train_indices] == race) & \
                         (dataset_gender_labels[train_indices] == gender)
            test_mask = (dataset_race_labels[test_indices] == race) & \
                        (dataset_gender_labels[test_indices] == gender)

            train_group_indices = train_indices[train_mask]
            test_group_indices = test_indices[test_mask]

            n_train = len(train_group_indices)
            n_test = len(test_group_indices)

            if n_train == 0 or n_test == 0:
                print(f"Skipping group race={race} gender={gender} due to insufficient samples "
                      f"(train: {n_train}, test: {n_test})")
                continue

            # Number of samples to pick scaled by p (at least 1 if non-empty)
            num_train_samples = max(1, int(n_train * p))
            num_test_samples = max(1, int(n_test * p))

            rng.shuffle(train_group_indices)
            rng.shuffle(test_group_indices)

            chosen_train = train_group_indices[:num_train_samples]
            chosen_test = test_group_indices[:num_test_samples]

            train_indices_per_group.append(chosen_train)
            test_indices_per_group.append(chosen_test)

    if len(train_indices_per_group) == 0 or len(test_indices_per_group) == 0:
        print(f"No valid intersectional groups found for proportion {p}, skipping.")
        continue

    combined_train_indices = np.concatenate(train_indices_per_group)
    combined_test_indices = np.concatenate(test_indices_per_group)
    rng.shuffle(combined_train_indices)
    rng.shuffle(combined_test_indices)

    intersection_train_subsets[p] = Subset(dataset, combined_train_indices)
    intersection_test_subsets[p] = Subset(dataset, combined_test_indices)

# ========== Summary Printing ==========
print("\n=== Gender partitions (all races combined) ===")
for p in proportions:
    train_subset = gender_train_subsets[p]
    train_genders = [dataset[i][1] for i in train_subset.indices]
    total_train = len(train_genders)
    pct_female = np.mean(np.array(train_genders) == 0) if total_train > 0 else 0
    print(f"Prop {int(p*100)}: Train samples = {total_train}, Female ratio = {pct_female*100:.2f}%")
    print(f"  Test Females: {len(gender_test_subsets_f[p].indices)}")
    print(f"  Test Males: {len(gender_test_subsets_m[p].indices)}")

print("\n=== Race partitions (all genders combined) ===")
for (race, p), subset in sorted(race_train_subsets.items()):
    train_size = len(subset.indices)
    print(f"Race {race} Proportion {int(p*100)}: Train samples = {train_size}")
for race in unique_races:
    test_sub = race_test_subsets.get(race, None)
    if test_sub:
        print(f"Race {race} Test samples: {len(test_sub.indices)}")

print("\n=== Intersectional partitions (race + gender) ===")
for p, subset in intersection_train_subsets.items():
    train_genders = [dataset[i][1] for i in subset.indices]
    train_races = [dataset[i][2] for i in subset.indices]
    total_train = len(train_genders)
    pct_female = np.mean(np.array(train_genders) == 0) if total_train > 0 else 0
    print(f"Prop {int(p*100)}: Train samples = {total_train}, Female ratio = {pct_female*100:.2f}%")
    for race in unique_races:
        for gender in [0, 1]:
            count = sum(
                (np.array(train_races) == race) & (np.array(train_genders) == gender)
            )
            gender_str = "Female" if gender == 0 else "Male"
            print(f"  Race {race} {gender_str}: {count} samples")
    test_subset = intersection_test_subsets.get(p, None)
    if test_subset:
        test_genders = [dataset[i][1] for i in test_subset.indices]
        test_races = [dataset[i][2] for i in test_subset.indices]
        total_test = len(test_genders)
        pct_female_test = np.mean(np.array(test_genders) == 0) if total_test > 0 else 0
        print(f"  Test samples = {total_test}, Female ratio = {pct_female_test * 100:.2f}%")

Train samples: 8640, Test samples: 2160

=== Gender partitions (all races combined) ===
Prop 25: Train samples = 4320, Female ratio = 25.00%
  Test Females: 540
  Test Males: 540
Prop 50: Train samples = 4320, Female ratio = 50.00%
  Test Females: 540
  Test Males: 540
Prop 75: Train samples = 4320, Female ratio = 75.00%
  Test Females: 540
  Test Males: 540

=== Race partitions (all genders combined) ===
Race asian Proportion 25: Train samples = 2876
Race asian Proportion 50: Train samples = 2876
Race asian Proportion 75: Train samples = 2876
Race black Proportion 25: Train samples = 2912
Race black Proportion 50: Train samples = 2912
Race black Proportion 75: Train samples = 2912
Race white Proportion 25: Train samples = 2848
Race white Proportion 50: Train samples = 2850
Race white Proportion 75: Train samples = 2850
Race asian Test samples: 724
Race black Test samples: 687
Race white Test samples: 749

=== Intersectional partitions (race + gender) ===
Prop 25: Train samples = 2158,

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset, DataLoader

# ----------------------------- #
# Assume dataset is a list of (img_path, gender, race)
dataset_gender_labels = np.array([s[1] for s in dataset])
dataset_race_labels = np.array([s[2] for s in dataset])

# Stratified train/test split by gender (to preserve gender distribution)
full_indices = np.arange(len(dataset))
train_indices, test_indices = train_test_split(
    full_indices,
    test_size=0.2,
    random_state=42,
    stratify=dataset_gender_labels
)
print(f"Train samples: {len(train_indices)}, Test samples: {len(test_indices)}")

proportions = [0.25, 0.5, 0.75]  # proportions to create subsets on
unique_races = np.unique(dataset_race_labels)
rng = np.random.default_rng(seed=42)

# --------------------------------- #
# Group 1: Gender partitions (all races combined)
gender_train_subsets = {}
gender_test_subsets_f = {}
gender_test_subsets_m = {}

gender_train_labels = dataset_gender_labels[train_indices]
gender_test_labels = dataset_gender_labels[test_indices]

for p in proportions:
    female_idx_train = np.where(gender_train_labels == 0)[0]
    male_idx_train = np.where(gender_train_labels == 1)[0]
    N_train = min(len(female_idx_train), len(male_idx_train))

    rng.shuffle(female_idx_train)
    rng.shuffle(male_idx_train)

    num_female_train = int(N_train * p)
    num_male_train = N_train - num_female_train

    chosen_female_train_abs = train_indices[female_idx_train[:num_female_train]]
    chosen_male_train_abs = train_indices[male_idx_train[:num_male_train]]

    combined_train = np.concatenate([chosen_female_train_abs, chosen_male_train_abs])
    rng.shuffle(combined_train)

    gender_train_subsets[p] = Subset(dataset, combined_train)

    # Balanced test sets (50% female, 50% male) overall
    female_idx_test = np.where(gender_test_labels == 0)[0]
    male_idx_test = np.where(gender_test_labels == 1)[0]
    N_test = min(len(female_idx_test), len(male_idx_test))

    rng.shuffle(female_idx_test)
    rng.shuffle(male_idx_test)

    half_test = N_test // 2
    chosen_female_test = test_indices[female_idx_test[:half_test]]
    chosen_male_test = test_indices[male_idx_test[:half_test]]

    gender_test_subsets_f[p] = Subset(dataset, chosen_female_test)
    gender_test_subsets_m[p] = Subset(dataset, chosen_male_test)

# --------------------------------- #
# Group 2: Race partitions (all genders; varying race proportions)
race_train_subsets = {}
race_test_subsets = {}

# Precompute test indices by race for reporting + subset creation
test_race_indices = {
    race: test_indices[dataset_race_labels[test_indices] == race] for race in unique_races
}

# Balanced test set for all races combined (same as gender balanced test set)
female_test_idx = np.where(gender_test_labels == 0)[0]
male_test_idx = np.where(gender_test_labels == 1)[0]
N_test = min(len(female_test_idx), len(male_test_idx))
rng.shuffle(female_test_idx)
rng.shuffle(male_test_idx)
half_test = N_test // 2
global_test_females = test_indices[female_test_idx[:half_test]]
global_test_males = test_indices[male_test_idx[:half_test]]
common_test_subset_f = Subset(dataset, global_test_females)
common_test_subset_m = Subset(dataset, global_test_males)

for race in unique_races:
    race_mask_train = (dataset_race_labels[train_indices] == race)
    race_train_indices = train_indices[race_mask_train]
    nonrace_train_indices = train_indices[~race_mask_train]

    total_race = len(race_train_indices)
    total_nonrace = len(nonrace_train_indices)

    for p in proportions:
        num_race_samples = int(total_race * p)
        num_nonrace_samples = (
            min(total_nonrace, int(num_race_samples * (1 - p) / p)) if p > 0 else total_nonrace
        )

        if num_race_samples == 0 or num_nonrace_samples == 0:
            print(f"Warning: insufficient data for race {race} with proportion {p}, skipping.")
            continue

        rng.shuffle(race_train_indices)
        rng.shuffle(nonrace_train_indices)

        chosen_race = race_train_indices[:num_race_samples]
        chosen_nonrace = nonrace_train_indices[:num_nonrace_samples]

        combined = np.concatenate([chosen_race, chosen_nonrace])
        rng.shuffle(combined)

        race_train_subsets[(race, p)] = Subset(dataset, combined)
        # Store test subsets keyed by race only (tests do not depend on p)
        race_test_subsets[race] = Subset(dataset, test_race_indices[race])

# --------------------------------- #
# Group 3: Intersectional partitions (race + gender combined)
intersection_train_subsets = {}
intersection_test_subsets_f = {}
intersection_test_subsets_m = {}

for race in unique_races:
    race_mask_train = (dataset_race_labels[train_indices] == race)
    race_train_indices = train_indices[race_mask_train]
    gender_train_race = dataset_gender_labels[race_train_indices]

    race_mask_test = (dataset_race_labels[test_indices] == race)
    race_test_indices = test_indices[race_mask_test]
    gender_test_race = dataset_gender_labels[race_test_indices]

    female_train_idx = np.where(gender_train_race == 0)[0]
    male_train_idx = np.where(gender_train_race == 1)[0]
    female_test_idx = np.where(gender_test_race == 0)[0]
    male_test_idx = np.where(gender_test_race == 1)[0]

    N_train = min(len(female_train_idx), len(male_train_idx))
    N_test = min(len(female_test_idx), len(male_test_idx))

    if N_train == 0 or N_test == 0:
        print(f"Warning: insufficient data for race {race}, skipping intersectional subsets.")
        continue

    for p in proportions:
        rng.shuffle(female_train_idx)
        rng.shuffle(male_train_idx)
        rng.shuffle(female_test_idx)
        rng.shuffle(male_test_idx)

        num_female_train = int(N_train * p)
        num_male_train = N_train - num_female_train

        chosen_female_train_abs = race_train_indices[female_train_idx[:num_female_train]]
        chosen_male_train_abs = race_train_indices[male_train_idx[:num_male_train]]

        combined_train = np.concatenate([chosen_female_train_abs, chosen_male_train_abs])
        rng.shuffle(combined_train)

        # Unlike train, keep test balanced 50/50 regardless of p because test is fixed
        half_test = N_test // 2
        chosen_female_test_abs = race_test_indices[female_test_idx[:half_test]]
        chosen_male_test_abs = race_test_indices[male_test_idx[:half_test]]

        # Save per (race, p) key to avoid overwriting:
        intersection_train_subsets[(race, p)] = Subset(dataset, combined_train)
        intersection_test_subsets_f[(race, p)] = Subset(dataset, chosen_female_test_abs)
        intersection_test_subsets_m[(race, p)] = Subset(dataset, chosen_male_test_abs)

# --------------- Helpers for DataLoaders ----------------

def get_train_loader(group, key, batch_size=64):
    if group == 'gender':
        dataset_subset = gender_train_subsets[key]
    elif group == 'race':
        dataset_subset = race_train_subsets[key]
    elif group == 'intersectional':
        dataset_subset = intersection_train_subsets[key]
    else:
        raise ValueError(f"Unknown group '{group}'")
    return DataLoader(dataset_subset, batch_size=batch_size, shuffle=True)


def get_test_loaders(group, key, batch_size=64):
    if group == 'gender':
        return (
            DataLoader(gender_test_subsets_f[key], batch_size=batch_size, shuffle=False),
            DataLoader(gender_test_subsets_m[key], batch_size=batch_size, shuffle=False),
        )
    elif group == 'race':
        # key can be race or (race, p); test subsets only saved by race, so get first element if tuple
        race = key if not isinstance(key, tuple) else key[0]
        return (DataLoader(race_test_subsets[race], batch_size=batch_size, shuffle=False), None)
    elif group == 'intersectional':
        # Intersectional test subsets saved with keys (race, p)
        return (
            DataLoader(intersection_test_subsets_f[key], batch_size=batch_size, shuffle=False),
            DataLoader(intersection_test_subsets_m[key], batch_size=batch_size, shuffle=False),
        )
    else:
        raise ValueError(f"Unknown group '{group}'")

# --------------- Example usage ----------------

batch_size = 64
# Gender group loader example (50% females)
train_loader_gender = get_train_loader('gender', 0.5, batch_size)
# Race group loader example (50% proportion of 'black')
train_loader_race = get_train_loader('race', ('black', 0.5), batch_size)
# Intersectional group loader example ('white', 75% females)
train_loader_intersection = get_train_loader('intersectional', ('white', 0.75), batch_size)

print(f"Gender group train loader batches: {len(train_loader_gender)}")
print(f"Race group train loader batches: {len(train_loader_race)}")
print(f"Intersectional group train loader batches: {len(train_loader_intersection)}")

Train samples: 8640, Test samples: 2160
Gender group train loader batches: 68
Race group train loader batches: 46
Intersectional group train loader batches: 22


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np

device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 64
num_epochs = 5

def evaluate(model, arc_head, loader, device):
    model.eval()
    arc_head.eval()
    running_corrects = 0
    running_loss = 0
    total_samples = 0
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.squeeze(1).to(device)
            features = model(images)
            logits = arc_head(features, labels)
            loss = criterion(logits, labels)
            preds = torch.argmax(logits, dim=1)
            running_corrects += (preds == labels).sum().item()
            running_loss += loss.item() * images.size(0)
            total_samples += images.size(0)
    avg_loss = running_loss / total_samples if total_samples > 0 else 0
    avg_acc = running_corrects / total_samples if total_samples > 0 else 0
    return avg_loss, avg_acc

def train_and_eval_group(name, train_subsets, test_subsets_f, test_subsets_m, proportions,
                         backbone, arc_head, batch_size=64, epochs=5):
    if all(isinstance(k, tuple) and len(k) == 3 for k in train_subsets.keys()):
        # Intersectional groups keyed by (race, gender, p)
        for (race, gender, p) in sorted(train_subsets.keys()):
            gender_str = "Female" if gender == 0 else "Male"
            print(f"\n== Group: {name} - {race} {gender_str} - Proportion {p} ==")
            model = nn.Sequential(backbone, nn.Flatten(start_dim=1)).to(device)
            arc = arc_head.to(device)
            criterion = nn.CrossEntropyLoss()
            optimizer = torch.optim.Adam(list(model.parameters()) + list(arc.parameters()), lr=1e-3)
            best_acc = 0

            train_loader = DataLoader(train_subsets[(race, gender, p)], batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(test_subsets_f[(race, gender, p)], batch_size=batch_size, shuffle=False)

            for epoch in range(epochs):
                model.train()
                arc.train()
                total_loss = total_correct = total_samples = 0
                for images, labels in train_loader:
                    images = images.to(device)
                    labels = labels.squeeze(1).to(device)
                    optimizer.zero_grad()
                    features = model(images)
                    logits = arc(features, labels)
                    loss = criterion(logits, labels)
                    loss.backward()
                    optimizer.step()
                    preds = torch.argmax(logits, dim=1)
                    total_correct += (preds == labels).sum().item()
                    total_loss += loss.item() * images.size(0)
                    total_samples += images.size(0)
                train_loss = total_loss / total_samples if total_samples > 0 else 0
                train_acc = total_correct / total_samples if total_samples > 0 else 0
                print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_acc*100:.2f}%")

                val_loss, val_acc = evaluate(model, arc, val_loader, device)
                print(f"Val Loss={val_loss:.4f}, Val Acc={val_acc*100:.2f}%")

                if val_acc > best_acc:
                    best_acc = val_acc
                    print(f"New best val acc at epoch {epoch+1}: {best_acc*100:.2f}%")
        print(f"== Completed training group {name} ==\n")

    else:
        # Non-intersectional groups, keyed by proportion or (race,p)
        for p in proportions:
            print(f"\n== Group: {name} - Proportion {p} ==")
            model = nn.Sequential(backbone, nn.Flatten(start_dim=1)).to(device)
            arc = arc_head.to(device)
            criterion = nn.CrossEntropyLoss()
            optimizer = torch.optim.Adam(list(model.parameters()) + list(arc.parameters()), lr=1e-3)
            best_acc = 0

            train_loader = DataLoader(train_subsets[p], batch_size=batch_size, shuffle=True)
            val_loader_f = DataLoader(test_subsets_f[p], batch_size=batch_size, shuffle=False)
            val_loader_m = DataLoader(test_subsets_m[p], batch_size=batch_size, shuffle=False)

            for epoch in range(epochs):
                model.train()
                arc.train()
                total_loss = total_correct = total_samples = 0
                for images, labels in train_loader:
                    images = images.to(device)
                    labels = labels.squeeze(1).to(device)
                    optimizer.zero_grad()
                    features = model(images)
                    logits = arc(features, labels)
                    loss = criterion(logits, labels)
                    loss.backward()
                    optimizer.step()
                    preds = torch.argmax(logits, dim=1)
                    total_correct += (preds == labels).sum().item()
                    total_loss += loss.item() * images.size(0)
                    total_samples += images.size(0)
                train_loss = total_loss / total_samples if total_samples > 0 else 0
                train_acc = total_correct / total_samples if total_samples > 0 else 0
                print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_acc*100:.2f}%")

                val_loss_f, val_acc_f = evaluate(model, arc, val_loader_f, device)
                val_loss_m, val_acc_m = evaluate(model, arc, val_loader_m, device)
                avg_val_acc = (val_acc_f + val_acc_m) / 2
                print(f"Val Acc Female: {val_acc_f*100:.2f}%, Male: {val_acc_m*100:.2f}%, Avg: {avg_val_acc*100:.2f}%")

                if avg_val_acc > best_acc:
                    best_acc = avg_val_acc
                    print(f"New best avg val acc at epoch {epoch+1}: {best_acc*100:.2f}%")
        print(f"== Completed training group {name} ==\n")

# Dummy backbone and ArcHead as before
class DummyArcHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 10)
    def forward(self, x, y):
        return self.fc(x)

backbone = nn.Sequential(
    nn.Conv2d(3, 16, 3, stride=2, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten()
)

arc_head = DummyArcHead()

from torch.utils.data import Dataset

class DummyDataset(Dataset):
    def __init__(self, size):
        self.size = size
        self.data = torch.randn(size, 3, 64, 64)
        self.labels = torch.randint(0, 10, (size, 1))
    def __len__(self):
        return self.size
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

class DummySubset(Dataset):
    def __init__(self, size):
        self.data = DummyDataset(size)
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

proportions = [0.25, 0.5, 0.75]
unique_races = ['asian', 'black', 'white']

# Dummy subsets for gender level (all races combined)
gender_train_subsets = {p: DummySubset(80) for p in proportions}
gender_test_subsets_f = {p: DummySubset(10) for p in proportions}
gender_test_subsets_m = {p: DummySubset(10) for p in proportions}

# Dummy subsets for race level (all genders combined)
race_train_subsets = {(race, p): DummySubset(80) for race in unique_races for p in proportions}
race_test_subsets_f = {p: DummySubset(10) for p in proportions}  # reuse female/male for simplicity below
race_test_subsets_m = {p: DummySubset(10) for p in proportions}

# Dummy subsets for intersectional groups (each race × gender × proportion)
intersection_train_subsets = {}
intersection_test_subsets_f = {}
intersection_test_subsets_m = {}
for race in unique_races:
    for gender in [0, 1]:
        for p in proportions:
            intersection_train_subsets[(race, gender, p)] = DummySubset(80)
            intersection_test_subsets_f[(race, gender, p)] = DummySubset(10) if gender == 0 else DummySubset(0)
            intersection_test_subsets_m[(race, gender, p)] = DummySubset(10) if gender == 1 else DummySubset(0)

print("=== Starting training on gender partitions ===")
train_and_eval_group("Gender", gender_train_subsets, gender_test_subsets_f, gender_test_subsets_m,
                     proportions, backbone, arc_head, batch_size=batch_size, epochs=num_epochs)

print("=== Starting training on race partitions ===")
for race in unique_races:
    test_f_filtered = {p: race_test_subsets_f[p] for p in proportions}
    test_m_filtered = {p: race_test_subsets_m[p] for p in proportions}
    train_filt = {p: race_train_subsets[(race, p)] for p in proportions}
    train_and_eval_group(f"Race-{race}", train_filt, test_f_filtered, test_m_filtered,
                         proportions, backbone, arc_head, batch_size=batch_size, epochs=num_epochs)

print("=== Starting training on intersectional partitions ===")
train_and_eval_group("Intersectional", intersection_train_subsets,
                     intersection_test_subsets_f, intersection_test_subsets_m,
                     proportions, backbone, arc_head, batch_size=batch_size, epochs=num_epochs)

print("All training runs complete.")

=== Starting training on gender partitions ===

== Group: Gender - Proportion 0.25 ==
Epoch 1: Train Loss=2.2851, Train Acc=16.25%
Val Acc Female: 10.00%, Male: 10.00%, Avg: 10.00%
New best avg val acc at epoch 1: 10.00%
Epoch 2: Train Loss=2.2834, Train Acc=16.25%
Val Acc Female: 10.00%, Male: 10.00%, Avg: 10.00%
Epoch 3: Train Loss=2.2820, Train Acc=16.25%
Val Acc Female: 10.00%, Male: 10.00%, Avg: 10.00%
Epoch 4: Train Loss=2.2812, Train Acc=16.25%
Val Acc Female: 10.00%, Male: 10.00%, Avg: 10.00%
Epoch 5: Train Loss=2.2800, Train Acc=16.25%
Val Acc Female: 10.00%, Male: 10.00%, Avg: 10.00%

== Group: Gender - Proportion 0.5 ==
Epoch 1: Train Loss=2.2935, Train Acc=13.75%
Val Acc Female: 10.00%, Male: 10.00%, Avg: 10.00%
New best avg val acc at epoch 1: 10.00%
Epoch 2: Train Loss=2.2915, Train Acc=13.75%
Val Acc Female: 10.00%, Male: 10.00%, Avg: 10.00%
Epoch 3: Train Loss=2.2899, Train Acc=13.75%
Val Acc Female: 10.00%, Male: 10.00%, Avg: 10.00%
Epoch 4: Train Loss=2.2884, Train Ac

In [None]:
#@title shell pipeline for unzipping! this needs to run every time

!unzip -q "/content/drive/My Drive/Datasets/celeba/img_align_celeba.zip" -d "/content/celeba/"

In [None]:
data_dir = '/content' # setting it to the local environment

In [None]:
import torch
from torchvision import datasets, transforms

In [None]:
# defining a transform that is smaller per suggestion of rasmus

image_size = 64

transform=transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                          std=[0.5, 0.5, 0.5])
])

In [None]:
# transfering files from gdrive to here so that they would work without us uploading manually all the time
# import module
import shutil

shutil.copyfile('/content/drive/My Drive/Datasets/celeba/identity_CelebA.txt', '/content/celeba/identity_CelebA.txt')
shutil.copyfile('/content/drive/My Drive/Datasets/celeba/list_attr_celeba.txt', '/content/celeba/list_attr_celeba.txt')
shutil.copyfile('/content/drive/My Drive/Datasets/celeba/list_bbox_celeba.txt', '/content/celeba/list_bbox_celeba.txt')
shutil.copyfile('/content/drive/My Drive/Datasets/celeba/list_landmarks_align_celeba.txt', '/content/celeba/list_landmarks_align_celeba.txt')
shutil.copyfile('/content/drive/My Drive/Datasets/celeba/list_eval_partition.txt', '/content/celeba/list_eval_partition.txt')



'/content/celeba/list_eval_partition.txt'

In [None]:
from torchvision.datasets import CelebA


# it creates a folder on the go!

try:
    dataset = CelebA(
        root='/content',
        split='train',
        target_type='attr',
        transform=transform,
        download=False # this works now!!!! its just important that it is in the root folder
    )
except Exception as e:
    print("CelebA error:", e)

In [None]:
#@title sanity check

import os

data_dir = '/content/celeba'

print("Root contents:", os.listdir(data_dir))
print("Images folder exists:", os.path.isdir(os.path.join(data_dir, 'img_align_celeba')))
print("Sample images:", os.listdir(os.path.join(data_dir, 'img_align_celeba'))[:3])
print("Has attribute file:", os.path.isfile(os.path.join(data_dir, 'list_attr_celeba.txt')))

Root contents: ['identity_CelebA.txt', 'list_bbox_celeba.txt', 'list_attr_celeba.txt', 'list_eval_partition.txt', 'img_align_celeba', 'list_landmarks_align_celeba.txt']
Images folder exists: True
Sample images: ['085474.jpg', '129511.jpg', '100524.jpg']
Has attribute file: True


In [None]:
#@title sanity check 2 & the moment of truth!!

# adding a dataloader and a basic model

from torch.utils.data import DataLoader
train_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

In [None]:
#@title adjusting the training data, different distributions
import numpy as np
from torch.utils.data import Subset


test_dataset = CelebA(
    root='/content',
    split='test',
    target_type='attr',
    transform=transform,
    download=False # i set it to true in case there is some secret metadata?? it is looking for
)


In [None]:
import pandas as pd
# Get the identity information from the training dataset
identity_labels = dataset.identity
# Convert to a pandas Series for easier counting
identity_series = pd.Series(identity_labels.squeeze().numpy())
identity_counts = identity_series.value_counts()
top_1000_identities = identity_counts.nlargest(1000)
# Get the indices corresponding to the top 1000 identities
top_1000_indices = identity_series[identity_series.isin(top_1000_identities.index)].index
# Create a subset of the dataset containing only the top 1000 identities
dataset_top_1000 = Subset(dataset, top_1000_indices)


min_samples = top_1000_identities.min()
max_samples = top_1000_identities.max()

print(f"Minimum samples per identity: {min_samples}")
print(f"Maximum samples per identity: {max_samples}")


Minimum samples per identity: 30
Maximum samples per identity: 35


In [None]:
import numpy as np # Make sure numpy is imported if it hasn't been already

male_idx = test_dataset.attr_names.index('Male')

gender_labels_test_subset = []
for i in top_1000_indices:
  # Note: As discussed before, using training indices on the test dataset
  # might lead to issues or misalignment. Assuming this is intended for now.
  if i < len(test_dataset):
    gender_labels_test_subset.append(test_dataset.attr[i, male_idx])


# Convert the list to a NumPy array
gender_labels_test_subset_np = np.array(gender_labels_test_subset)


# Now use np.where on the NumPy array
# This is the part that fixes the DeprecationWarning
female_test_subset_indices = np.where(gender_labels_test_subset_np == 0)[0]
male_test_subset_indices   = np.where(gender_labels_test_subset_np ==  1)[0]


print(len(female_test_subset_indices))
print(len(male_test_subset_indices))


N_test = min(len(female_test_subset_indices), len(male_test_subset_indices))

rng_test = np.random.default_rng(seed=42)
shuffled_female_test_subset_indices = np.copy(female_test_subset_indices)
shuffled_male_test_subset_indices   = np.copy(male_test_subset_indices)
rng_test.shuffle(shuffled_female_test_subset_indices)
rng_test.shuffle(shuffled_male_test_subset_indices)


test_subsets = {}

# Create training subsets
test_subsets_f = {}
test_subsets_m = {}
# even split for all examples. we can change this later but we want to be able to generalize... we want there to be the same number of examples for men and women and for these to be in the same set...
# we will put this to the loop.







2300
1510


In [None]:
import numpy as np
from torch.utils.data import Subset


# choose smallest n
# proportions = [0, 0.1, 0.25, 0.5, 0.75, 1.0] # changed this bc it doesn't make sense
proportions = [0.25, 0.5, 0.75]
male_idx = test_dataset.attr_names.index('Male')

# You need to create subsets from the test_dataset using test_dataset-specific indices
# The previous code was creating subsets of the training dataset.
# It seems like you want to create training subsets with varying gender proportions
# and test subsets for evaluation (separated by gender).

# For the training subsets (assuming you still want to use indices from the training dataset,
# but with the identity filtering from before):
# You will need to re-calculate the gender labels for the *training* dataset based on top_1000_indices.
male_idx_train = dataset.attr_names.index('Male')
gender_labels_train_subset = dataset.attr[top_1000_indices, male_idx_train] # Use gender from training dataset
female_train_subset_indices = np.where(gender_labels_train_subset == 0)[0]
male_train_subset_indices   = np.where(gender_labels_train_subset ==  1)[0]

N_train = min(len(female_train_subset_indices), len(male_train_subset_indices))

rng_train = np.random.default_rng(seed=42)
shuffled_female_train_subset_indices = np.copy(female_train_subset_indices)
shuffled_male_train_subset_indices   = np.copy(male_train_subset_indices)
rng_train.shuffle(shuffled_female_train_subset_indices)
rng_train.shuffle(shuffled_male_train_subset_indices)


# Create training subsets
train_subsets = {}
for p in proportions:
    num_females_train = int(N_train * p)
    num_males_train = N_train - num_females_train

    q = min(p, 1-p)
    num_females_test = int(N_test * q) # even split for testing
    num_males_test = num_females_test

    chosen_female_train = shuffled_female_train_subset_indices[:num_females_train] if num_females_train > 0 else np.array([], dtype=int)
    chosen_male_train   = shuffled_male_train_subset_indices[:num_males_train]   if num_males_train > 0   else np.array([], dtype=int)

    chosen_female_test = shuffled_female_test_subset_indices[:num_females_test]
    chosen_male_test   = shuffled_male_test_subset_indices[:num_males_test]

    # These indices are relative to the 'dataset_top_1000' subset,
    # so you need to map them back to the original 'dataset' indices if Subset requires it.
    # Since top_1000_indices is the mapping, we can directly use that:
    original_indices_train = np.concatenate([
        top_1000_indices[chosen_female_train],
        top_1000_indices[chosen_male_train]
    ]).astype(int)
    rng_train.shuffle(original_indices_train)
    train_subsets[p] = Subset(dataset, original_indices_train)
    test_subsets_f[p] = Subset(test_dataset, chosen_female_test)
    test_subsets_m[p] = Subset(test_dataset, chosen_male_test)



# Verification as before
for p in proportions:
    # Verification for the training subset
    indices_train = train_subsets[p].indices
    # Need to get genders for these original training indices from the *full* training dataset
    genders_train = dataset.attr[indices_train, male_idx_train]
    percent_female_train = (genders_train == 0).sum()/len(indices_train) if len(indices_train) > 0 else 0
    print(f"Train Subset (Prop {int(p*100)}%): Target {int(p*100)}% -- Actual {percent_female_train*100:.2f}% females, {(genders_train == 0).sum()} samples")


    number_female_test = len(test_subsets_f[p].indices)
    number_male_test = len(test_subsets_m[p].indices)
    print(f"Number of female test samples: {number_female_test}")
    print(f"Number of male test samples: {number_male_test}")




Train Subset (Prop 25%): Target 25% -- Actual 24.99% females, 2480 samples
Number of female test samples: 377
Number of male test samples: 377
Train Subset (Prop 50%): Target 50% -- Actual 50.00% females, 4961 samples
Number of female test samples: 755
Number of male test samples: 755
Train Subset (Prop 75%): Target 75% -- Actual 74.99% females, 7441 samples
Number of female test samples: 377
Number of male test samples: 377


In [None]:
# creating dataloaders
from torch.utils.data import DataLoader

batch_size = 64

train_loader = DataLoader(train_subsets[0.5], batch_size=batch_size, shuffle=True)
# val_loader = DataLoader(test_subsets[0.5], batch_size=batch_size, shuffle=True)


In [None]:
import torch.nn as nn
import torch.nn.functional as F

In [None]:
#title this is a more complicated model but used more commonly in FR


class ArcMarginProduct(nn.Module):
    def __init__(self, in_features, out_features, s=30.0, m=0.5, easy_margin=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        self.cos_m = torch.cos(torch.tensor(self.m))
        self.sin_m = torch.sin(torch.tensor(self.m))
        self.th = torch.cos(torch.tensor(3.14159265 - self.m))
        self.mm = torch.sin(torch.tensor(3.14159265 - self.m)) * self.m

    def forward(self, input, label):
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt(1.0 - torch.clamp(cosine ** 2, 0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output


In [None]:
def evaluate(model, arc_head, dataloader, device):
    model.eval()
    arc_head.eval()
    total, correct, running_loss = 0, 0, 0.0
    for i, (images, identity_labels) in enumerate(dataloader):
        images, labels = images.to(device), identity_labels[:,0].to(device) # Selecting the first column of identity_labels
        features = model(images)
        logits = arc_head(features, labels) # Using the modified labels for arc_head
        loss = criterion(logits, labels)
        running_loss += loss.item() * images.size(0)
        preds = logits.argmax(1)
        correct += (preds == labels).sum().item()
        total += images.size(0)
    avg_loss = running_loss / total
    acc = correct / total
    print(f"Test set: loss={avg_loss:.4f}, accuracy={acc*100:.2f}, data loader{dataloader}%")
    return avg_loss, acc



In [None]:
import torchvision.models as models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
backbone = models.resnet18(weights=None)
feature_dim = backbone.fc.in_features
backbone.fc = nn.Identity()

n_classes=(dataset.identity.unique())
print(n_classes)
# arc_head = ArcMarginProduct(feature_dim, out_features=n_classes).to(device)
arc_head = ArcMarginProduct(feature_dim, 10177).to(device) # Tracy update



tensor([    1,     2,     3,  ..., 10175, 10176, 10177])


In [None]:
#@title sanity check for the eval code
import torch.nn as nn
import torch.nn.functional as F

model = nn.Sequential(backbone, nn.Flatten(start_dim=1)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(list(model.parameters()) + list(arc_head.parameters()), lr=1e-3)



# for epoch in range(1):
#   avg_loss, acc = evaluate(model, arc_head, val_loader, device)

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


model = nn.Sequential(backbone, nn.Flatten(start_dim=1)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(list(model.parameters()) + list(arc_head.parameters()), lr=1e-3)
best_acc = 0.0


for proportion in proportions:
    model = nn.Sequential(backbone, nn.Flatten(start_dim=1)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(list(model.parameters()) + list(arc_head.parameters()), lr=1e-3)
    best_acc = 0.0
    train_loader = DataLoader(train_subsets[proportion], batch_size=batch_size, shuffle=True)

    # Initialize val_loader_f and val_loader_m to None
    val_loader_f = None
    val_loader_m = None

    # Check if the subsets have any samples before creating DataLoaders
    if len(test_subsets_f[proportion].indices) > 0:
        val_loader_f = DataLoader(test_subsets_f[proportion], batch_size=batch_size, shuffle=True)
    if len(test_subsets_m[proportion].indices) > 0:
        val_loader_m = DataLoader(test_subsets_m[proportion], batch_size=batch_size, shuffle=True)

    for epoch in range(50):
        model.train()
        arc_head.train()
        total, correct, running_loss = 0, 0, 0.0
        for i, (images, identity_labels) in enumerate(train_loader):
            images, labels = images.to(device), identity_labels[:,0].to(device) # Selecting the first column of identity_labels
            features = model(images)
            logits = arc_head(features, labels) # Using the modified labels for arc_head
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
            preds = logits.argmax(1)
            correct += (preds == labels).sum().item()
            total += images.size(0)
            if (i+1) % 50 == 0: print(f"Batch {i+1}/{len(train_loader)} - Loss {loss.item():.4f}")
        print(f"Epoch {epoch+1}: Loss={running_loss/total:.4f}  Accuracy={correct/total*100:.2f}%")

        # Check if the validation loaders are not empty before evaluating

        if val_loader_f:
            val_loss_f, val_acc_f = evaluate(model, arc_head, val_loader_f, device)
        if val_loader_m:
            val_loss_m, val_acc_m = evaluate(model, arc_head, val_loader_m, device)

        if val_loader_m is not None and val_loader_f is not None:
          val_acc = (val_acc_f + val_acc_m) / 2
        elif val_loader_m:
          val_acc = val_acc_m
        elif val_loader_f:
          val_acc = val_acc_f
        else:
          val_acc = 0.0

        if val_acc > best_acc:
          best_acc = val_acc
          torch.save({
              'epoch': epoch,
              'model_state_dict': model.state_dict(),
              'arc_head_state_dict': arc_head.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'val_acc': val_acc,
          }, f'model{proportion}_checkpoint.pth')

          if val_loader_m and val_loader_f:
            print(f"New best model saved at epoch {epoch+1} with average acc {val_acc_f*100:.2f}, female acc {val_acc_f*100:.2f}, male acc {val_acc_m*100:.2f}%")
          elif val_loader_m:
            print(f"New best model saved at epoch {epoch+1} with male acc {val_acc_m*100:.2f}%")
          elif val_loader_f:
            print(f"New best model saved at epoch {epoch+1} with female acc {val_acc_f*100:.2f}%")

Batch 50/156 - Loss 3.8559
Batch 100/156 - Loss 2.8617
Batch 150/156 - Loss 3.2255
Epoch 1: Loss=2.9338  Accuracy=71.79%
Test set: loss=1.5582, accuracy=88.59%
Test set: loss=1.5568, accuracy=89.12%
New best model saved at epoch 1 with average acc 88.59, female acc 88.59, male acc 89.12%
Batch 50/156 - Loss 3.2611
Batch 100/156 - Loss 3.1904
Batch 150/156 - Loss 4.8022
Epoch 2: Loss=2.9162  Accuracy=72.49%
Test set: loss=1.6898, accuracy=86.47%
Test set: loss=1.6631, accuracy=88.06%
Batch 50/156 - Loss 2.4079
Batch 100/156 - Loss 1.6515
Batch 150/156 - Loss 2.6151
Epoch 3: Loss=2.7828  Accuracy=75.84%
Test set: loss=1.3345, accuracy=89.39%
Test set: loss=1.4649, accuracy=89.66%
New best model saved at epoch 3 with average acc 89.39, female acc 89.39, male acc 89.66%
Batch 50/156 - Loss 3.0508
Batch 100/156 - Loss 3.8725
Batch 150/156 - Loss 3.3005
Epoch 4: Loss=2.6610  Accuracy=77.30%
Test set: loss=1.6540, accuracy=87.53%
Test set: loss=1.6887, accuracy=87.53%
Batch 50/156 - Loss 2.27

In [None]:
#@title putting in the utils here for easier dev
from torch.autograd import Variable
import numpy as np
from torchvision import transforms, datasets
from torch.utils.data import DataLoader


def eval_robust(model, test_loader, pgd_attack, device):
    model.eval()
    robust_loss = 0
    correct = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            adv = pgd_attack(inputs, targets)
            outputs = model(adv)
            robust_loss += F.cross_entropy(outputs, targets).item()
            pred = outputs.max(1, keepdim=True)[1]
            correct += pred.eq(targets.view_as(pred)).sum().item()
    robust_loss /= len(test_loader.dataset)

    print('LinfPGD Attack: Average loss: {:.4f}, Robust Accuracy: {}/{} ({:.0f}%)'.format(
        robust_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    robust_accuracy = 100. * correct / len(test_loader.dataset)
    return robust_loss, robust_accuracy



In [None]:
def trades_loss(model,
                x_natural,
                y,
                optimizer,
                step_size=8/2550,
                epsilon=8/255,
                perturb_steps=10,
                beta=1.0):
    '''
    Source https://github.com/yaodongyu/TRADES/blob/master/trades.py
    '''
    # define KL-loss
    criterion_kl = nn.KLDivLoss(size_average=False)
    model.eval()
    batch_size = len(x_natural)

    # generate adversarial example
    x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach()
    for _ in range(perturb_steps):
        x_adv.requires_grad_()
        with torch.enable_grad():
            loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                   F.softmax(model(x_natural), dim=1))
        grad = torch.autograd.grad(loss_kl, [x_adv])[0]
        x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
        x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)

    model.train()

    x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)

    # zero gradient
    optimizer.zero_grad()

    # calculate robust loss
    logits = model(x_natural)
    loss_natural = F.cross_entropy(logits, y)
    loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                                    F.softmax(model(x_natural), dim=1))
    loss = loss_natural + beta * loss_robust
    return loss

In [None]:
class LinfPGDAttack(nn.Module):
    def __init__(self, model, epsilon, steps=10, step_size=0.003):
        super().__init__()
        self.model = model
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = step_size

    def perturb(self, x_natural, y):
        x_adv = x_natural.clone().requires_grad_(True)
        with torch.enable_grad():
            for i in range(self.steps):

                self.model.zero_grad()
                # calculate loss
                output = self.model(x_adv)
                # Selecting the first column of y (assuming it's the identity label)
                loss = nn.CrossEntropyLoss()(output, y[:, 0])

                # gradient
                grad = torch.autograd.grad(loss, x_adv)[0]

                # clipping
                perturbation = torch.clamp(self.step_size * torch.sign(grad), -self.epsilon, self.epsilon)

                # clamping
                x_adv = torch.clamp(x_adv + perturbation, 0, 1)

        return x_adv

    def forward(self, x_natural, y):
        x_adv = self.perturb(x_natural, y)
        return x_adv

In [None]:
def train_ep(model, train_loader, mode, pgd_attack, optimizer, criterion, epoch, batch_size):
    model.train()
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        # Selecting the first column of targets, assuming it represents the identity label
        labels = targets[:, 0]

        if mode == 'natural':
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)  # Use labels instead of targets

        elif mode == 'adv_train':  # [Ref] https://arxiv.org/abs/1706.06083
            model.eval()
            adv_x = pgd_attack(inputs, targets)
            model.train()

            optimizer.zero_grad()
            outputs = model(adv_x)
            loss = criterion(outputs, labels)  # Use labels instead of targets

        elif mode == 'adv_train_trades':  # [Ref] https://arxiv.org/abs/1901.08573
            optimizer.zero_grad()

        # elif mode == 'adv_train_mixup': # [Ref] https://arxiv.org/abs/1710.09412
        #     model.eval()
        #     benign_inputs, benign_targets_a, benign_targets_b, benign_lam = mixup_data(inputs, targets)
        #     adv_x = pgd_attack(inputs, targets)
        #     adv_inputs, adv_targets_a, adv_targets_b, adv_lam = mixup_data(adv_x, targets)

        #     model.train()
        #     optimizer.zero_grad()

        #     benign_outputs = model(benign_inputs)
        #     adv_outputs = model(adv_inputs)
        #     loss_1 = mixup_criterion(criterion, benign_outputs, benign_targets_a, benign_targets_b, benign_lam)
        #     loss_2 = mixup_criterion(criterion, adv_outputs, adv_targets_a, adv_targets_b, adv_lam)

        #     loss = (loss_1 + loss_2) / 2

        else:
            print("No training mode specified.")
            raise ValueError()

        loss.backward()
        optimizer.step()

        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{:05d}/{} ({:.0f}%)]\t Loss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(inputs), len(train_loader) * batch_size,
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))

In [None]:
import torch.optim as optim

def train(model, train_loader, val_loader_f, val_loader_m, pgd_attack,
          mode='natural', epochs=25, batch_size=256, learning_rate=0.001, momentum=0.9, weight_decay=2e-4,
          checkpoint_path='model1.pt'):

    criterion = nn.CrossEntropyLoss()
    # change this to adam!!!
    # optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    optimizer = torch.optim.Adam(list(model.parameters()) + list(arc_head.parameters()), lr=1e-3)

    best_acc = 0.0
    for epoch in range(epochs):
        # training
        train_ep(model, train_loader, mode, pgd_attack, optimizer, criterion, epoch, batch_size)


        val_loss_f, val_acc_f = evaluate(model, arc_head, val_loader_f, device)
        val_loss_m, val_acc_m = evaluate(model, arc_head, val_loader_m, device)

        val_acc = (val_acc_f + val_acc_m) / 2

        # remember best acc@1 and save checkpoint
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)

        # save checkpoint if is a new best
        if is_best:
            torch.save(model.state_dict(), checkpoint_path)
        print(f'Avergae accuracy: {val_acc}, female: {val_acc_f}, male: {val_acc_m}')
        print('================================================================')



In [None]:
# import torch.optim as optim

# def train(model, train_loader, val_loader, pgd_attack,
#           mode='natural', epochs=25, batch_size=256, learning_rate=0.001, momentum=0.9, weight_decay=2e-4,
#           checkpoint_path='model1.pt'):
#     criterion = nn.CrossEntropyLoss()
#     # change this to adam!!!
#     optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)

#     best_acc = 0
#     for epoch in range(epochs):
#         # training
#         train_ep(model, train_loader, mode, pgd_attack, optimizer, criterion, epoch, batch_size)

#         # evaluate clean accuracy
#         test_loss, test_acc = evaluate(model, arc_head, val_loader, device)

#         # remember best acc@1 and save checkpoint
#         is_best = test_acc > best_acc
#         best_acc = max(test_acc, best_acc)

#         # save checkpoint if is a new best
#         if is_best:
#             torch.save(model.state_dict(), checkpoint_path)
#         print(f'Accuracy: {test_acc}')
#         print('================================================================')

In [None]:
#@title training loop with backbone using ArcFace

# sanity check for attack, rasmus' suggestion


epsilon = 8/255
pgd = LinfPGDAttack(model, epsilon=epsilon, step_size = epsilon/10, steps = 10)  # instantiate the LinfPGDAttack
training_mode = "adv_train"


# note for us!! check the training loop and confirm whether it is similar to the previous one we had


for proportion in proportions:
    model = nn.Sequential(backbone, nn.Flatten(start_dim=1)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(list(model.parameters()) + list(arc_head.parameters()), lr=1e-3)
    best_acc = 0.0
    train_loader = DataLoader(train_subsets[proportion], batch_size=batch_size, shuffle=True)

    # Initialize val_loader_f and val_loader_m to None
    val_loader_f = None
    val_loader_m = None

    # Check if the subsets have any samples before creating DataLoaders
    if len(test_subsets_f[proportion].indices) > 0:
        val_loader_f = DataLoader(test_subsets_f[proportion], batch_size=batch_size, shuffle=True)
    if len(test_subsets_m[proportion].indices) > 0:
        val_loader_m = DataLoader(test_subsets_m[proportion], batch_size=batch_size, shuffle=True)



    train(model, train_loader=train_loader, mode=training_mode, val_loader_f=val_loader_f, val_loader_m= val_loader_m, pgd_attack=pgd, learning_rate=0.001, checkpoint_path='model_adv.pt', epochs=70)




KeyboardInterrupt: 

In [None]:
# sanity check for attack, rasmus' suggestion

epsilon = 8/255
pgd = LinfPGDAttack(model, epsilon=epsilon, step_size = epsilon/10, steps = 10)  # instantiate the LinfPGDAttack
training_mode = "adv_train"


# note for us!! check the training loop and confirm whether it is similar to the previous one we had


for proportion in proportions:
    model = nn.Sequential(backbone, nn.Flatten(start_dim=1)).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(list(model.parameters()) + list(arc_head.parameters()), lr=1e-3)
    best_acc = 0.0
    train_loader = DataLoader(train_subsets[proportion], batch_size=batch_size, shuffle=True)

    # Initialize val_loader_f and val_loader_m to None
    val_loader_f = None
    val_loader_m = None

    # Check if the subsets have any samples before creating DataLoaders
    if len(test_subsets_f[proportion].indices) > 0:
        val_loader_f = DataLoader(test_subsets_f[proportion], batch_size=batch_size, shuffle=True)
    if len(test_subsets_m[proportion].indices) > 0:
        val_loader_m = DataLoader(test_subsets_m[proportion], batch_size=batch_size, shuffle=True)



    train(model, train_loader=train_loader, mode=training_mode, val_loader_f=val_loader_f, val_loader_m= val_loader_m, pgd_attack=pgd, learning_rate=0.001, checkpoint_path='model_adv.pt', epochs=70)


In [None]:
#@title adjusting this with a simpler model

import torch.nn as nn
import torch.nn.functional as F

class LinfPGDAttack(nn.Module):
    def __init__(self, model, epsilon, steps=10, step_size=0.003):
        super().__init__()
        self.model = model
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = step_size

    def perturb(self, x_natural, y):
        """
        Computes the gradient of the cross-entropy loss with respect to the input
        image `x_adv` and updates the image based on the gradient direction. The
        perturbation is clipped to ensure it stays within a specified epsilon range
        and is finally clamped to ensure pixel values are valid.

        The resulting perturbed image is returned.
        """
        # *********** Your code starts here ***********
        x_adv = x_natural.clone().requires_grad_(True)
        with torch.enable_grad():
            for i in range(self.steps):

                self.model.zero_grad()
                # calculate loss
                output = self.model(x_adv)
                loss = nn.CrossEntropyLoss()(output, y)


                # gradient
                grad = torch.autograd.grad(loss, x_adv)[0]


                # clipping
                perturbation = torch.clamp(self.step_size * torch.sign(grad), -self.epsilon, self.epsilon)

                # clamping
                x_adv = torch.clamp(x_adv + perturbation, 0, 1)





        # *********** Your code ends here *************

        return x_adv

    def forward(self, x_natural, y):
        x_adv = self.perturb(x_natural, y)
        return x_adv

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.optim as optim



def make_dataloader(data_path, batch_size):
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_dataset = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform_train)
    val_dataset = datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader


def eval_test(model, test_loader, device):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            test_loss += F.cross_entropy(outputs, targets).item()
            pred = outputs.max(1, keepdim=True)[1]
            correct += pred.eq(targets.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    print('Test: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    test_accuracy = 100. * correct / len(test_loader.dataset)
    return test_loss, test_accuracy


def eval_robust(model, test_loader, pgd_attack, device):
    model.eval()
    robust_loss = 0
    correct = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            adv = pgd_attack(inputs, targets)
            outputs = model(adv)
            robust_loss += F.cross_entropy(outputs, targets).item()
            pred = outputs.max(1, keepdim=True)[1]
            correct += pred.eq(targets.view_as(pred)).sum().item()
    robust_loss /= len(test_loader.dataset)

    print('LinfPGD Attack: Average loss: {:.4f}, Robust Accuracy: {}/{} ({:.0f}%)'.format(
        robust_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    robust_accuracy = 100. * correct / len(test_loader.dataset)
    return robust_loss, robust_accuracy


def mixup_data(x, y, mixup_alpha=1.0):
    '''
    Source https://github.com/facebookresearch/mixup-cifar10/blob/main/train.py
    '''
    lam = np.random.beta(mixup_alpha, mixup_alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]

    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    '''
    Source https://github.com/facebookresearch/mixup-cifar10/blob/main/train.py
    '''
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


def trades_loss(model,
                x_natural,
                y,
                optimizer,
                step_size=0.003,
                epsilon=8/255,
                perturb_steps=10,
                beta=1.0):
    '''
    Source https://github.com/yaodongyu/TRADES/blob/master/trades.py
    '''
    # define KL-loss
    criterion_kl = nn.KLDivLoss(size_average=False)
    model.eval()
    batch_size = len(x_natural)

    # generate adversarial example
    x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach()
    for _ in range(perturb_steps):
        x_adv.requires_grad_()
        with torch.enable_grad():
            loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                   F.softmax(model(x_natural), dim=1))
        grad = torch.autograd.grad(loss_kl, [x_adv])[0]
        x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
        x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)

    model.train()

    x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)

    # zero gradient
    optimizer.zero_grad()

    # calculate robust loss
    logits = model(x_natural)
    loss_natural = F.cross_entropy(logits, y)
    loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                                    F.softmax(model(x_natural), dim=1))
    loss = loss_natural + beta * loss_robust
    return loss

In [None]:
def train_ep(model, train_loader, mode, pgd_attack, optimizer, criterion, epoch, batch_size):
    model.train()
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        labels = targets[:, 0] # the first column is the identity label

        if mode == 'natural':
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        elif mode == 'adv_train': # [Ref] https://arxiv.org/abs/1706.06083
            model.eval()
            adv_x = pgd_attack(inputs, targets)
            model.train()

            optimizer.zero_grad()
            outputs = model(adv_x)
            loss = criterion(outputs, targets)

        elif mode == 'adv_train_trades': # [Ref] https://arxiv.org/abs/1901.08573
            optimizer.zero_grad()
            loss = trades_loss(model=model, x_natural=inputs, y=targets, optimizer=optimizer)

        # elif mode == 'adv_train_mixup': # [Ref] https://arxiv.org/abs/1710.09412
        #     model.eval()
        #     benign_inputs, benign_targets_a, benign_targets_b, benign_lam = mixup_data(inputs, targets)
        #     adv_x = pgd_attack(inputs, targets)
        #     adv_inputs, adv_targets_a, adv_targets_b, adv_lam = mixup_data(adv_x, targets)

        #     model.train()
        #     optimizer.zero_grad()

        #     benign_outputs = model(benign_inputs)
        #     adv_outputs = model(adv_inputs)
        #     loss_1 = mixup_criterion(criterion, benign_outputs, benign_targets_a, benign_targets_b, benign_lam)
        #     loss_2 = mixup_criterion(criterion, adv_outputs, adv_targets_a, adv_targets_b, adv_lam)

        #     loss = (loss_1 + loss_2) / 2

        else:
            print("No training mode specified.")
            raise ValueError()

        loss.backward()
        optimizer.step()

        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{:05d}/{} ({:.0f}%)]\t Loss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(inputs), len(train_loader) * batch_size,
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))



In [None]:
def train(model, train_loader, val_loader, pgd_attack,
          mode='natural', epochs=25, batch_size=256, learning_rate=0.1, momentum=0.9, weight_decay=2e-4,
          checkpoint_path='model1.pt'):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)

    best_acc = 0
    for epoch in range(epochs):
        # training
        train_ep(model, train_loader, mode, pgd_attack, optimizer, criterion, epoch, batch_size)

        # evaluate clean accuracy
        test_loss, test_acc = eval_test(model, val_loader, device)

        # remember best acc@1 and save checkpoint
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)

        # save checkpoint if is a new best
        if is_best:
            torch.save(model.state_dict(), checkpoint_path)
        print('================================================================')

In [None]:
#@title resnet module

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.optim as optim


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out



class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion * 4, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18(num_classes=10):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)


In [None]:
#@title adjusting this with a simpler model


class LinfPGDAttack(nn.Module):
    def __init__(self, model, epsilon, steps=10, step_size=0.003):
        super().__init__()
        self.model = model
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = step_size

    def perturb(self, x_natural, y):
        """
        Computes the gradient of the cross-entropy loss with respect to the input
        image `x_adv` and updates the image based on the gradient direction. The
        perturbation is clipped to ensure it stays within a specified epsilon range
        and is finally clamped to ensure pixel values are valid.

        The resulting perturbed image is returned.
        """
        # *********** Your code starts here ***********
        x_adv = x_natural.clone().requires_grad_(True)
        # Extract the identity label from the multi-dimensional target tensor
        labels = y[:, 0] # Assuming the first column is the identity label
        with torch.enable_grad():
            for i in range(self.steps):

                self.model.zero_grad()
                # calculate loss
                output = self.model(x_adv)
                # Use the extracted identity labels as the target for CrossEntropyLoss
                loss = nn.CrossEntropyLoss()(output, labels)


                # gradient
                grad = torch.autograd.grad(loss, x_adv)[0]


                # clipping
                perturbation = torch.clamp(self.step_size * torch.sign(grad), -self.epsilon, self.epsilon)

                # clamping
                x_adv = torch.clamp(x_adv + perturbation, 0, 1)

        # *********** Your code ends here *************

        return x_adv

    def forward(self, x_natural, y):
        x_adv = self.perturb(x_natural, y)
        return x_adv

# Modified train_ep function to handle multi-dimensional targets from CelebA
def train_ep(model, train_loader, mode, pgd_attack, optimizer, criterion, epoch, batch_size):
    model.train()
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        # Extract the identity label from the multi-dimensional target tensor
        labels = targets[:, 0] # Assuming the first column is the identity label


        if mode == 'natural':
            optimizer.zero_grad()
            outputs = model(inputs)
            # Use the extracted identity labels as the target for CrossEntropyLoss
            loss = criterion(outputs, labels)

        elif mode == 'adv_train': # [Ref] https://arxiv.org/abs/1706.06083
            model.eval()
            # Pass the original multi-dimensional targets to the attack
            adv_x = pgd_attack(inputs, targets) # The attack will extract labels internally
            model.train()

            optimizer.zero_grad()
            outputs = model(adv_x)
            # Use the extracted identity labels as the target for CrossEntropyLoss
            loss = criterion(outputs, labels)

        elif mode == 'adv_train_trades': # [Ref] https://arxiv.org/abs/1901.08573
            optimizer.zero_grad()
            # Note: The original trades_loss function expects a 1D target tensor.
            # You might need to adapt trades_loss similar to LinfPGDAttack if you plan to use this mode
            # with CelebA's multi-dimensional targets. For now, using extracted labels might work,
            # but the original TRADES formulation uses clean labels for the KL divergence part.
            # This might require further adjustments based on the specific TRADES implementation
            # you are using.
            # Assuming for now that trades_loss can handle the extracted 1D labels.
            loss = trades_loss(model=model, x_natural=inputs, y=labels, optimizer=optimizer)


        # elif mode == 'adv_train_mixup': # [Ref] https://arxiv.org/abs/1710.09412
        #     model.eval()
        #     # Mixup needs 1D targets. You would need to modify mixup_data to work with the extracted labels.
        #     benign_inputs, benign_targets_a, benign_targets_b, benign_lam = mixup_data(inputs, labels)
        #     adv_x = pgd_attack(inputs, targets) # Pass original targets to attack
        #     # Mixup needs 1D targets. You would need to modify mixup_data to work with the extracted labels from adv_x?
        #     # This part of mixup with adversarial training might need careful consideration of how targets are handled.
        #     adv_inputs, adv_targets_a, adv_targets_b, adv_lam = mixup_data(adv_x, labels) # Using extracted labels


        #     model.train()
        #     optimizer.zero_grad()

        #     benign_outputs = model(benign_inputs)
        #     adv_outputs = model(adv_inputs)
        #     # Use the extracted 1D labels for criterion
        #     loss_1 = mixup_criterion(criterion, benign_outputs, benign_targets_a, benign_targets_b, benign_lam)
        #     loss_2 = mixup_criterion(criterion, adv_outputs, adv_targets_a, adv_targets_b, adv_lam)

        #     loss = (loss_1 + loss_2) / 2

        else:
            print("No training mode specified.")
            raise ValueError()

        loss.backward()
        optimizer.step()

        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{:05d}/{} ({:.0f}%)]\t Loss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(inputs), len(train_loader) * batch_size,
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))

# Keep the rest of the training loop as is, ensuring the train function calls the modified train_ep
# and evaluate functions (which already handle the identity label extraction).
# Make sure the 'train' function signature matches how it's called in the loop:
# train(model, train_loader=train_loader, mode=training_mode, val_loader_f=val_loader_f, val_loader_m= val_loader_m, pgd_attack=pgd, learning_rate=0.001, checkpoint_path='model_adv.pt', epochs=70)
# The 'train' function you provided in the notebook takes 'val_loader', but your calling code
# passes 'val_loader_f'. You should update the train function signature or the call site
# to be consistent. Since you added logic for val_loader_f and val_loader_m in the loop
# where the error occurred, it seems you intend to evaluate on female and male subsets
# separately during training. You should adjust the 'train' function to accept both
# val_loader_f and val_loader_m and call your 'evaluate' function twice for each epoch.


# Here's the updated train function to accept female and male validation loaders
def train(model, train_loader, val_loader_f, val_loader_m, pgd_attack,
          mode='natural', epochs=25, batch_size=256, learning_rate=0.001, momentum=0.9, weight_decay=2e-4,
          checkpoint_path='model1.pt'):

    criterion = nn.CrossEntropyLoss()
    # change this to adam!!!
    # optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    # Note: Your training loop defined the optimizer with arc_head parameters.
    # The train function here needs to accept or define the optimizer appropriately.
    # Let's assume the optimizer is defined outside and passed in, or redefined here
    # to include arc_head parameters if using ArcFace. If using the simple ResNet18
    # without ArcFace for this adversarial training part, the current optimizer definition
    # (SGD or Adam on model.parameters() only) might be correct.
    # Based on the trace, it seems you are using the simple ResNet18 for the PGD attack part.
    # Let's keep the SGD optimizer as in the original train function definition.
    # If you are using ArcFace, you need to adapt this train function as well.

    # Based on the code in the failing cell block, the optimizer was defined *inside* the loop
    # and included arc_head parameters. Let's adjust the train function to match that,
    # assuming you are training with ArcFace in this section.
    # **However**, the eval_test and eval_robust functions called within this 'train' function
    # only take 'model', not 'arc_head'. You will need to adapt eval_test and eval_robust
    # if they are intended to work with the ArcFace setup (which predicts identity classes
    # using features from the backbone and the arc_head).
    # Let's assume for now that the adversarial training is on the simpler ResNet18
    # for a standard classification task (which seems to be what the PGD attack code expects),
    # and the ArcFace training is a separate block of code. If you intend to do adversarial
    # training *with* the ArcFace setup, you'll need more significant modifications to the
    # attack and evaluation functions.

    # Reverting to the optimizer definition as seen in the cell where the error occurred,
    # which includes arc_head parameters, assuming you want to train the entire ArcFace model adversarially.
    # **IMPORTANT:** This requires adapting eval_test and eval_robust to work with the ArcFace model.
    # For simplicity and to fix the immediate error, let's assume the PGD attack is on the
    # simple classification task and the ArcFace part is separate, or you need to pass both
    # model and arc_head to the train_ep and evaluation functions.

    # Let's assume the intention is to train the simple ResNet18 adversarially.
    # In this case, the optimizer should only optimize model.parameters().
    # If you want to adversarially train the ArcFace setup, you need to pass
    # both the model (backbone) and arc_head to train_ep and the attack.

    # Given the structure of the failing code block, it seems you are initializing a simple ResNet18
    # *inside* the loop for each proportion and then calling this `train` function.
    # This suggests adversarial training on the simple ResNet18.
    # The optimizer should then be for the model's parameters only.
    optimizer = optim.Adam(model.parameters(), lr=learning_rate) # Using Adam as in your failing block, but only for model

    best_acc = 0.0 # Keep track of best average accuracy across genders

    for epoch in range(epochs):
        # training
        # Pass the extracted labels in train_ep as modified above
        train_ep(model, train_loader, mode, pgd_attack, optimizer, criterion, epoch, batch_size)


        # Evaluate clean accuracy on both subsets using the eval_test function
        # Note: eval_test currently calculates standard classification accuracy,
        # not accuracy in an embedding space with ArcFace distance.
        # If you want to evaluate with ArcFace, you need a different evaluation function.
        # Assuming standard classification evaluation for now.
        val_acc_f = 0.0
        val_acc_m = 0.0
        val_loss_f = 0.0
        val_loss_m = 0.0

        if val_loader_f and len(val_loader_f.dataset) > 0:
            # eval_test is designed for 1D targets, but the DataLoader yields multi-dimensional targets.
            # You need to adapt eval_test or the DataLoader to yield 1D targets for this evaluation.
            # Or, pass the appropriate labels to eval_test.
            # Let's modify eval_test slightly to extract the identity label.

            # For consistency with the error source, let's assume eval_test needs 1D targets.
            # We'll extract the labels before calling eval_test. This is not ideal as it
            # involves iterating through the dataset again. A better approach would be to
            # modify eval_test to accept multi-dimensional targets and extract the label internally.
            # Given the constraint to fix the immediate error, let's assume eval_test needs 1D targets.
            # However, eval_test takes a DataLoader, so modifying it is better.

            # Let's adapt eval_test to extract the identity label.
            # Need to redefine eval_test to handle the multi-dimensional target.
            # **This requires modifying the eval_test function definition.**
            # Assuming the `evaluate` function defined earlier (which takes arc_head and extracts labels)
            # is the correct evaluation function for your setup, you should use that here.

            # Replacing eval_test calls with the `evaluate` function used previously.
            # This assumes `evaluate` is defined and available in this scope.
            # The `evaluate` function also requires the arc_head, which is not passed to this `train` function.
            # This highlights a mismatch in your code structure.
            # Either the `train` function needs `arc_head` or the adversarial training part is meant for
            # a model without ArcFace.

            # Let's assume you want to adversarially train the simple ResNet18.
            # In that case, the targets should be 1D class labels, and the model should output logits
            # for the number of classes (identities). The ResNet18 is instantiated with num_classes=1000,
            # which aligns with the idea of predicting identity.
            # The `CelebA` dataset yields `identity_labels` which is already a tensor where the first column
            # is the identity.

            # Let's return to modifying `train_ep` and `LinfPGDAttack` to handle the multi-dimensional
            # targets from the DataLoader and extract the first column for the loss calculation.
            # This was done in the code block above this function definition.

            # Now, for evaluation, we need a function that evaluates the model on the identity prediction task.
            # The `evaluate` function defined earlier does this with the ArcFace setup.
            # If you are using the simple ResNet18 directly for classification,
            # you need an evaluation function that takes the model and a DataLoader,
            # extracts the identity label, and calculates accuracy. Let's use `eval_test` but
            # modify it to extract the target.

            # Redefine eval_test here to handle CelebA targets
            def eval_test_celeba(model, dataloader, device):
                model.eval()
                test_loss = 0
                correct = 0
                total = 0
                with torch.no_grad():
                    for inputs, targets in dataloader:
                        inputs, targets = inputs.to(device), targets.to(device)
                        labels = targets[:, 0] # Extract identity label
                        outputs = model(inputs)
                        test_loss += F.cross_entropy(outputs, labels).item() * inputs.size(0)
                        pred = outputs.max(1, keepdim=True)[1]
                        correct += pred.eq(labels.view_as(pred)).sum().item()
                        total += inputs.size(0)
                test_loss /= total if total > 0 else 1
                accuracy = 100. * correct / total if total > 0 else 0

                print(f'Test: Average loss: {test_loss:.4f}, Accuracy: {correct}/{total} ({accuracy:.0f}%)')
                return test_loss, accuracy

            val_loss_f, val_acc_f = eval_test_celeba(model, val_loader_f, device)

        if val_loader_m and len(val_loader_m.dataset) > 0:
            val_loss_m, val_acc_m = eval_test_celeba(model, val_loader_m, device)

        val_acc = (val_acc_f + val_acc_m) / 2

        # remember best acc@1 and save checkpoint
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)

        # save checkpoint if is a new best
        if is_best:
            torch.save(model.state_dict(), checkpoint_path)
        print(f'Average accuracy: {val_acc:.2f}, female: {val_acc_f:.2f}, male: {val_acc_m:.2f}')



In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
epsilon = 8/255
training_mode = "adv_train" # Or 'natural' if you want to train naturally
batch_size = 64

proportions = [0.25, 0.5, 0.75]

for proportion in proportions:
    # Re-initialize model and attack for each proportion if needed, otherwise move outside loop
    # If training separately for each proportion, re-initialization is correct.
    model = ResNet18(num_classes=1000).to(device) # ResNet for identity classification
    # Note: number of classes (1000) should match the number of unique identities
    # we filtered initially by top 1000 identitites but this might be limiting perhaps?
    # it gives very few examples on the test set


    num_identity_classes = 1000 # Assuming the ResNet18 model is configured for 1000 classes
    model = ResNet18(num_classes=num_identity_classes).to(device)

    pgd = LinfPGDAttack(model, epsilon=epsilon, step_size = epsilon/10, steps = 10)

    # train function definition already includes criterion and optimizer definition.
    # Move best_acc outside the inner epoch loop within the train function.
    # The train function saves checkpoint, so best_acc is managed internally.

    train_loader = DataLoader(train_subsets[proportion], batch_size=batch_size, shuffle=True)

    val_loader_f = None
    val_loader_m = None

    if proportion in test_subsets_f and len(test_subsets_f[proportion].indices) > 0:
        val_loader_f = DataLoader(test_subsets_f[proportion], batch_size=batch_size, shuffle=False) # Shuffle usually False for validation
    if proportion in test_subsets_m and len(test_subsets_m[proportion].indices) > 0:
        val_loader_m = DataLoader(test_subsets_m[proportion], batch_size=batch_size, shuffle=False) # Shuffle usually False for validation


    # call the modified train function
    train(model, train_loader=train_loader, mode=training_mode,
          val_loader_f=val_loader_f, val_loader_m=val_loader_m,
          pgd_attack=pgd, learning_rate=0.001,
          checkpoint_path=f'model_adv_prop{int(proportion*100)}.pt', epochs=70) # Save checkpoints with proportion

cuda
Test: Average loss: 0.3262, Accuracy: 346/377 (92%)
Test: Average loss: 0.3451, Accuracy: 341/377 (90%)
Average accuracy: 91.11, female: 91.78, male: 90.45
Test: Average loss: 0.6440, Accuracy: 285/377 (76%)
Test: Average loss: 0.6484, Accuracy: 280/377 (74%)
Average accuracy: 74.93, female: 75.60, male: 74.27
Test: Average loss: 0.4370, Accuracy: 344/377 (91%)
Test: Average loss: 0.4509, Accuracy: 337/377 (89%)
Average accuracy: 90.32, female: 91.25, male: 89.39
Test: Average loss: 0.4701, Accuracy: 346/377 (92%)
Test: Average loss: 0.4788, Accuracy: 341/377 (90%)
Average accuracy: 91.11, female: 91.78, male: 90.45
Test: Average loss: 0.3856, Accuracy: 346/377 (92%)
Test: Average loss: 0.3988, Accuracy: 341/377 (90%)
Average accuracy: 91.11, female: 91.78, male: 90.45
Test: Average loss: 0.3699, Accuracy: 346/377 (92%)
Test: Average loss: 0.3862, Accuracy: 341/377 (90%)
Average accuracy: 91.11, female: 91.78, male: 90.45
