<a href="https://colab.research.google.com/github/taweener11/darkSideUnmasked/blob/main/wip_demogpairs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os

cores = os.cpu_count() # Count the number of cores in a computer
cores

8

In [3]:

#@title shell pipeline for unzipping! this needs to run every time

!unzip -q "/content/drive/My Drive/Datasets/demogpairs/DemogPairs.zip" -d "/content/demogpairs/"

In [4]:
def read_metadata_file(filepath, gender_label, race_label):
    """
    Read a DemogPairs metadata txt file and collect image paths with labels.

    Args:
      filepath (str): path to the metadata txt file
      gender_label (int): 0 for female, 1 for male
      race_label (str): string label for race, e.g. 'black', 'white', 'asian'

    Returns:
      List of tuples: (image_relative_path, gender_label, race_label)
    """
    samples = []
    with open(filepath, 'r') as f:
        lines = f.readlines()
        for line in lines:
            line = line.strip()
            if not line or line.lower().startswith('db_code'):
                continue
            parts = line.split()
            if len(parts) < 2:
                continue
            # parts[1] is the image path relative to DemogPairs folder
            img_path = parts[1]
            samples.append((img_path, gender_label, race_label))
    return samples


In [6]:
import os

# metadata_dir = '/content/drive/My Drive/demogpairs/Metadata'  # adjust path
metadata_dir = '/content/demogpairs/Metadata'  # edited this to run with the local environment



# Map filenames to gender and race labels
metadata_info = {
    'Black_Females.txt': (0, 'black'),
    'Black_Males.txt': (1, 'black'),
    'White_Females.txt': (0, 'white'),
    'White_Males.txt': (1, 'white'),
    'Asian_Females.txt': (0, 'asian'),
    'Asian_Males.txt': (1, 'asian')
}

all_samples = []

for fname, (gender, race) in metadata_info.items():
    full_path = os.path.join(metadata_dir, fname)
    print(f"Reading {full_path} ...")
    samples = read_metadata_file(full_path, gender, race)
    all_samples.extend(samples)

print(f"Total samples loaded: {len(all_samples)}")

Reading /content/demogpairs/Metadata/Black_Females.txt ...
Reading /content/demogpairs/Metadata/Black_Males.txt ...
Reading /content/demogpairs/Metadata/White_Females.txt ...
Reading /content/demogpairs/Metadata/White_Males.txt ...
Reading /content/demogpairs/Metadata/Asian_Females.txt ...
Reading /content/demogpairs/Metadata/Asian_Males.txt ...
Total samples loaded: 10800


In [16]:
import torch
from torchvision import datasets, transforms

In [29]:
# defining a transform that is smaller per suggestion of rasmus

image_size = 64

transform=transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                          std=[0.5, 0.5, 0.5])
])

In [69]:
#@title preliminary analysis on the dataset -- determine the number of examples for subsetting
demogpairs_root = '/content/demogpairs/DemogPairs'

# Build final dataset list with full paths
dataset = []
for rel_path, gender, race in all_samples:
    img_full_path = os.path.join(demogpairs_root, rel_path)
    if os.path.isfile(img_full_path):
        dataset.append((img_full_path, gender, race))
    else:
        print(f"Missing file: {img_full_path}")

print(f"Final dataset size after filtering missing files: {len(dataset)}")

# Extract identity labels from the image paths
# Assuming identity is the directory name right after demogpairs_root
identity_labels_list = []
for img_path, gender, race in dataset:
    # Split the path and get the second to last element (which should be the identity folder)
    parts = img_path.split(os.sep)
    # Find the index of demogpairs_root in the parts
    try:
        root_index = parts.index('DemogPairs')
        # The identity folder is expected to be the element after 'DemogPairs'
        if root_index + 1 < len(parts):
            identity = parts[root_index + 1]
            identity_labels_list.append(identity)
        else:
            # Handle cases where the path doesn't follow the expected structure
            print(f"Warning: Could not extract identity from path: {img_path}")
            identity_labels_list.append("unknown_identity") # Or handle as appropriate
    except ValueError:
        # Handle cases where 'DemogPairs' is not in the path
        print(f"Warning: 'DemogPairs' not found in path: {img_path}")
        identity_labels_list.append("unknown_identity") # Or handle as appropriate


# Convert to a pandas Series for easier counting
import pandas as pd
# Use the created list of identity labels
identity_series = pd.Series(identity_labels_list)

identity_counts = identity_series.value_counts()
# Select top 1000 identities. Ensure there are at least 1000 unique identities.
if len(identity_counts) >= 1000:
    top_1000_identities = identity_counts.nlargest(1000)
else:
    print(f"Warning: Less than 1000 unique identities found. Using all {len(identity_counts)} identities.")
    top_1000_identities = identity_counts

# Get the indices corresponding to the images belonging to the top 1000 identities
# We need the original indices from the `dataset` list
top_1000_identity_names = top_1000_identities.index.tolist()
top_1000_indices = [i for i, (img_path, gender, race) in enumerate(dataset)
                    if img_path.split(os.sep)[-2] in top_1000_identity_names]


# Create a subset of the dataset containing only the top 1000 identities
from torch.utils.data import Subset
# You can create a Subset using the original list and the selected indices
# Note: Subset is typically used with PyTorch Datasets, not plain Python lists.
# If you intend to use this with PyTorch DataLoader later, you might need to
# convert 'dataset' into a custom PyTorch Dataset first.
# For now, let's just have the list of tuples for the top 1000 identities:
dataset_top_1000 = [dataset[i] for i in top_1000_indices]


min_samples = top_1000_identities.min()
max_samples = top_1000_identities.max()

print(f"Minimum samples per identity: {min_samples}")
print(f"Maximum samples per identity: {max_samples}")
print(f"Number of samples in dataset_top_1000: {len(dataset_top_1000)}")

# printing the number of classes per group
for key, value in metadata_info.items():
    gender, race = value
    gender_str = 'male' if gender ==1 else 'female'
    count = len([s for s in dataset if s[1] == gender and s[2] == race])
    print(f'Count of {race} {gender_str} =' + str(count))


Final dataset size after filtering missing files: 10800
Minimum samples per identity: 18
Maximum samples per identity: 18
Number of samples in dataset_top_1000: 10800
Count of black female =1800
Count of black male =1800
Count of white female =1800
Count of white male =1800
Count of asian female =1800
Count of asian male =1800


In [71]:
#@title accessing the relevant indices and subsetting for the testsets
from sklearn.model_selection import train_test_split


dataset_gender_labels = np.array([s[1] for s in dataset])
dataset_race_labels = np.array([s[2] for s in dataset])

composite_labels = np.array([f"{gender}_{race}" for _, gender, race in dataset])


full_indices = np.arange(len(dataset))

# stratified train/test split by the composite gender-race label
train_indices, test_indices = train_test_split(
    full_indices,
    test_size=0.2,
    random_state=42,
    stratify=composite_labels # Use the composite labels here
)
print(f"Train samples: {len(train_indices)}, Test samples: {len(test_indices)}")

train_dataset = []
# Iterate through the indices in train_indices
for index in train_indices:
    # Get the tuple (img_full_path, gender, race) from the original dataset list using the index
    img_full_path, gender, race = dataset[index]
    # Check if the file exists before adding to the train_dataset (optional, but good practice)
    if os.path.isfile(img_full_path):
        train_dataset.append((img_full_path, gender, race))
    else:
        print(f"Missing file: {img_full_path}")

print(f"Train dataset size after filtering missing files: {len(train_dataset)}")

# Repeat the same logic for the test_indices to create the test_dataset
test_dataset = []
for index in test_indices:
    img_full_path, gender, race = dataset[index]
    if os.path.isfile(img_full_path):
        test_dataset.append((img_full_path, gender, race))
    else:
        print(f"Missing file: {img_full_path}")

print(f"Test dataset size after filtering missing files: {len(test_dataset)}")


Train samples: 8640, Test samples: 2160
Train dataset size after filtering missing files: 8640
Test dataset size after filtering missing files: 2160


In [65]:
#@title subset maker for specified distribution

import numpy as np
import os
from torch.utils.data import Subset
import pandas as pd # Import pandas for value_counts

rng = np.random.default_rng(seed=42)

races = ['black', 'white', 'asian']

def make_train_subsets_from_list(dataset_list, proportions, subgroup = (0, 'asian')):
    """
    Read a DemogPairs metadata txt file and collect image paths with labels.

    Args:
      dataset_list (list): A list of tuples (img_full_path, gender, race)
      proportions (list): A list of proportions for the subgroup.
      subgroup (tuple): (gender, race) for the subgroup to vary.

    Returns:
      Dict of torch.utils.data.Subset: Subsets of the original list, keyed by proportion.
    """
    train_subsets = {}

    # Extract identity labels, gender, and race directly from the input list
    dataset_identity_labels = [img_path.split(os.sep)[-2] for img_full_path, _, _ in dataset_list] # Assuming identity is the folder name
    dataset_gender_labels = [gender for _, gender, _ in dataset_list]
    dataset_race_labels = [race for _, _, race in dataset_list]

    # Calculate base_number based on the minimum samples per identity in this list
    identity_counts = pd.Series(dataset_identity_labels).value_counts()
    base_number = identity_counts.min() if not identity_counts.empty else 0

    # Map original identity names to a numerical label for easier processing
    unique_identities = np.unique(dataset_identity_labels)
    identity_mapping = {name: i for i, name in enumerate(unique_identities)}
    numerical_identity_labels = np.array([identity_mapping[name] for name in dataset_identity_labels])

    for prop in proportions:
        selected_original_indices_for_prop = [] # Collect original indices for the current proportion

        # the indices 'c' here refer to the numerical identity labels
        for c_num in np.unique(numerical_identity_labels):
            # Get the actual identity name
            identity_name = unique_identities[c_num]

            # Indices within the *current dataset_list* that correspond to identity 'c_num'
            indices_for_identity = np.where(numerical_identity_labels == c_num)[0]

            # Separate indices by gender and race *within this identity*
            main_sg_indices_for_identity = [
                idx for idx in indices_for_identity
                if dataset_gender_labels[idx] == subgroup[0] and dataset_race_labels[idx] == subgroup[1]
            ]

            rng.shuffle(main_sg_indices_for_identity)

            # Determine number of samples for the main subgroup
            if len(main_sg_indices_for_identity) < base_number:
                n_main_sg = int(np.floor(len(main_sg_indices_for_identity) * prop))
            else:
                n_main_sg = int(np.floor(base_number * prop))

            # Collect the original indices for the selected main subgroup samples
            selected_original_indices_for_prop.extend(main_sg_indices_for_identity[:n_main_sg])

            # selecting for the non-main subgroups *within this identity*
            for gender in range(2):
                for race in races:
                    if race != subgroup[1]:
                        subgroup_indices_for_identity = [
                            idx for idx in indices_for_identity
                            if dataset_gender_labels[idx] == gender and dataset_race_labels[idx] == race
                        ]
                        rng.shuffle(subgroup_indices_for_identity)

                        # Calculate how many samples from this subgroup to select
                        available_other_subgroups_count = 0
                        for g_other in range(2):
                            for r_other in races:
                                if r_other != subgroup[1]:
                                     if any(dataset_gender_labels[idx] == g_other and dataset_race_labels[idx] == r_other for idx in indices_for_identity):
                                         available_other_subgroups_count += 1

                        if available_other_subgroups_count > 0:
                            target_per_other_subgroup = int(np.floor((base_number * (1-prop)) / available_other_subgroups_count))
                        else:
                             target_per_other_subgroup = 0

                        # Number of samples to select for this specific non-main subgroup
                        n_subgroup = min(len(subgroup_indices_for_identity), target_per_other_subgroup)

                        # Collect the original indices for the selected non-main subgroup samples
                        selected_original_indices_for_prop.extend(subgroup_indices_for_identity[:n_subgroup])

        # Shuffle the collected original indices for the current proportion
        rng.shuffle(selected_original_indices_for_prop)

        # Create the Subset using the original list and the selected original indices
        train_subsets[prop] = Subset(dataset_list, selected_original_indices_for_prop)

    return train_subsets


In [72]:
#@title sanity check
from torch.utils.data import DataLoader

train_subsets = make_train_subsets_from_list(train_dataset, [0.25, 0.5, 0.75], subgroup = (0, 'asian'))

# get the first sample (index 0) from the Subset
first_sample = train_subsets[0.25][0]

# The structure of first_sample depends on your original dataset's __getitem__ method.
# Based on your 'dataset' list and 'make_train_subsets_from_list' function,
# the original dataset is a list of tuples (img_full_path, gender, race).
# A Subset wrapping this list will return the tuple directly.
# The 'transform' passed to the function is not applied by the Subset itself,
# but is intended to be used by a DataLoader or a custom Dataset class.

print("First sample (path, gender, race):")
print(first_sample)

# If you want to see more samples, you can loop
print("\nFirst 3 samples:")
for i in range(min(3, len(train_subsets[0.25]))):
    print(train_subsets[0.25][i])

# If your intention was to see the transformations applied, you would use a DataLoader:
train_loader_0_25 = DataLoader(train_subsets[0.25], batch_size=1, shuffle=False)
first_batch = next(iter(train_loader_0_25))
print("\nFirst batch from DataLoader (transformed image tensor, labels):")
print(first_batch)


First sample (path, gender, race):
('/content/demogpairs/DemogPairs/elizabeth_mitchell/043.jpg', 0, 'white')

First 3 samples:
('/content/demogpairs/DemogPairs/elizabeth_mitchell/043.jpg', 0, 'white')
('/content/demogpairs/DemogPairs/jamie_hector/0246_03.jpg', 1, 'black')
('/content/demogpairs/DemogPairs/cameron_crowe/010.jpg', 1, 'white')

First batch from DataLoader (transformed image tensor, labels):
[('/content/demogpairs/DemogPairs/elizabeth_mitchell/043.jpg',), tensor([0]), ('white',)]


In [44]:
def in_class(predict, label, classes):
    probs = torch.zeros(classes)
    for i in range(classes):
        # in_class_id = torch.tensor(label == i, dtype= torch.float)
        in_class_id = (label == i).clone().detach().float()
        # correct_predict = torch.tensor(predict == label, dtype= torch.float)
        correct_predict = (predict == label).clone().detach().float()
        in_class_correct_predict = (correct_predict) * (in_class_id)
        acc = torch.sum(in_class_correct_predict).item() / torch.sum(in_class_id).item()
        probs[i] = acc

    return probs

In [45]:
#@title pipeline for wandb

import wandb

In [46]:
# creating dataloaders
from torch.utils.data import DataLoader
batch_size = 64

In [47]:
import torch.nn as nn
import torch.nn.functional as F

## Various utility functions (not in utils yet)

In [49]:
def trades_loss(model,
                x_natural,
                y,
                optimizer,
                step_size=8/2550,
                epsilon=8/255,
                perturb_steps=10,
                beta=1.0):
    '''
    Source https://github.com/yaodongyu/TRADES/blob/master/trades.py
    '''
    # define KL-loss
    criterion_kl = nn.KLDivLoss(size_average=False)
    model.eval()
    batch_size = len(x_natural)

    # generate adversarial example
    x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach()
    for _ in range(perturb_steps):
        x_adv.requires_grad_()
        with torch.enable_grad():
            loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                   F.softmax(model(x_natural), dim=1))
        grad = torch.autograd.grad(loss_kl, [x_adv])[0]
        x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
        x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)

    model.train()

    x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)

    # zero gradient
    optimizer.zero_grad()

    # calculate robust loss
    logits = model(x_natural)
    loss_natural = F.cross_entropy(logits, y)
    loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                                    F.softmax(model(x_natural), dim=1))
    loss = loss_natural + beta * loss_robust
    return loss

In [50]:
class LinfPGDAttack(nn.Module):
    def __init__(self, model, epsilon, steps=10, step_size=0.003):
        super().__init__()
        self.model = model
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = step_size

    def perturb(self, x_natural, y):
        x_adv = x_natural.clone().requires_grad_(True)
        with torch.enable_grad():
            for i in range(self.steps):

                self.model.zero_grad()
                # calculate loss
                output = self.model(x_adv)
                # Selecting the first column of y (assuming it's the identity label)
                loss = nn.CrossEntropyLoss()(output, y[:, 0])

                # gradient
                grad = torch.autograd.grad(loss, x_adv)[0]

                # clipping
                perturbation = torch.clamp(self.step_size * torch.sign(grad), -self.epsilon, self.epsilon)

                # clamping
                x_adv = torch.clamp(x_adv + perturbation, 0, 1)

        return x_adv

    def forward(self, x_natural, y):
        x_adv = self.perturb(x_natural, y)
        return x_adv

In [53]:
#@title initializing a wandb run

# api key: bd1c08839d0c8c49e7c3efe9aabe2d9c644befb6

wandb.init(project="face-adv-fairness", name="demogpairs-demo", config={"learning_rate": 0.001, "epochs": 20})
wandb.finish()

In [56]:
#@title utils: pgd-attack

import torch.nn as nn
import torch.nn.functional as F

class LinfPGDAttack(nn.Module):
    def __init__(self, model, epsilon, steps=10, step_size=0.003):
        super().__init__()
        self.model = model
        self.epsilon = epsilon
        self.steps = steps
        self.step_size = step_size

    def perturb(self, x_natural, y):
        """
        Computes the gradient of the cross-entropy loss with respect to the input
        image `x_adv` and updates the image based on the gradient direction. The
        perturbation is clipped to ensure it stays within a specified epsilon range
        and is finally clamped to ensure pixel values are valid.

        The resulting perturbed image is returned.
        """
        x_adv = x_natural.clone().requires_grad_(True)
        with torch.enable_grad():
            for i in range(self.steps):

                self.model.zero_grad()
                # calculate loss
                output = self.model(x_adv)
                loss = nn.CrossEntropyLoss()(output, y)


                # gradient
                grad = torch.autograd.grad(loss, x_adv)[0]


                # clipping
                perturbation = torch.clamp(self.step_size * torch.sign(grad), -self.epsilon, self.epsilon)

                # clamping
                x_adv = torch.clamp(x_adv + perturbation, 0, 1)


        return x_adv

    def forward(self, x_natural, y):
        x_adv = self.perturb(x_natural, y)
        return x_adv

In [57]:
#@title utils: eval_test, eval_robust

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.optim as optim



def eval_test(model, test_loader, device):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            test_loss += F.cross_entropy(outputs, targets).item()
            pred = outputs.max(1, keepdim=True)[1]
            correct += pred.eq(targets.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)

    print('Test: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    test_accuracy = 100. * correct / len(test_loader.dataset)
    return test_loss, test_accuracy


def eval_robust(model, test_loader, pgd_attack, device):
    model.eval()
    robust_loss = 0
    correct = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            adv = pgd_attack(inputs, targets)
            outputs = model(adv)
            robust_loss += F.cross_entropy(outputs, targets).item()
            pred = outputs.max(1, keepdim=True)[1]
            correct += pred.eq(targets.view_as(pred)).sum().item()
    robust_loss /= len(test_loader.dataset)

    print('LinfPGD Attack: Average loss: {:.4f}, Robust Accuracy: {}/{} ({:.0f}%)'.format(
        robust_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    robust_accuracy = 100. * correct / len(test_loader.dataset)
    return robust_loss, robust_accuracy


def mixup_data(x, y, mixup_alpha=1.0):
    '''
    Source https://github.com/facebookresearch/mixup-cifar10/blob/main/train.py
    '''
    lam = np.random.beta(mixup_alpha, mixup_alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]

    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    '''
    Source https://github.com/facebookresearch/mixup-cifar10/blob/main/train.py
    '''
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


def trades_loss(model,
                x_natural,
                y,
                optimizer,
                step_size=0.003,
                epsilon=8/255,
                perturb_steps=10,
                beta=1.0):
    '''
    Source https://github.com/yaodongyu/TRADES/blob/master/trades.py
    '''
    # define KL-loss
    criterion_kl = nn.KLDivLoss(size_average=False)
    model.eval()
    batch_size = len(x_natural)

    # generate adversarial example
    x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach()
    for _ in range(perturb_steps):
        x_adv.requires_grad_()
        with torch.enable_grad():
            loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                   F.softmax(model(x_natural), dim=1))
        grad = torch.autograd.grad(loss_kl, [x_adv])[0]
        x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
        x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)

    model.train()

    x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)

    # zero gradient
    optimizer.zero_grad()

    # calculate robust loss
    logits = model(x_natural)
    loss_natural = F.cross_entropy(logits, y)
    loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                                    F.softmax(model(x_natural), dim=1))
    loss = loss_natural + beta * loss_robust
    return loss

In [73]:
def train_ep(model, train_loader, mode, pgd_attack, optimizer, criterion, epoch, batch_size):
    model.train()
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        labels = targets[:, 0] # the first column is the identity label

        if mode == 'natural':
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        elif mode == 'adv_train': # [Ref] https://arxiv.org/abs/1706.06083
            model.eval()
            adv_x = pgd_attack(inputs, targets)
            model.train()

            optimizer.zero_grad()
            outputs = model(adv_x)
            loss = criterion(outputs, targets)

        elif mode == 'adv_train_trades': # [Ref] https://arxiv.org/abs/1901.08573
            optimizer.zero_grad()
            loss = trades_loss(model=model, x_natural=inputs, y=targets, optimizer=optimizer)

        # elif mode == 'adv_train_mixup': # [Ref] https://arxiv.org/abs/1710.09412
        #     model.eval()
        #     benign_inputs, benign_targets_a, benign_targets_b, benign_lam = mixup_data(inputs, targets)
        #     adv_x = pgd_attack(inputs, targets)
        #     adv_inputs, adv_targets_a, adv_targets_b, adv_lam = mixup_data(adv_x, targets)

        #     model.train()
        #     optimizer.zero_grad()

        #     benign_outputs = model(benign_inputs)
        #     adv_outputs = model(adv_inputs)
        #     loss_1 = mixup_criterion(criterion, benign_outputs, benign_targets_a, benign_targets_b, benign_lam)
        #     loss_2 = mixup_criterion(criterion, adv_outputs, adv_targets_a, adv_targets_b, adv_lam)

        #     loss = (loss_1 + loss_2) / 2

        else:
            print("No training mode specified.")
            raise ValueError()

        loss.backward()
        optimizer.step()

        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{:05d}/{} ({:.0f}%)]\t Loss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(inputs), len(train_loader) * batch_size,
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))



In [59]:
#@title resnet module

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.optim as optim


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out



class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion * 4, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18(num_classes=10):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)


In [60]:
#@title modified train and test functions for celeba

def train_ep(model, train_loader, mode, pgd_attack, optimizer, criterion, epoch, batch_size):
    model.train()
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        # Extract the identity label from the multi-dimensional target tensor
        labels = targets[:, 0] # Assuming the first column is the identity label


        if mode == 'natural':
            optimizer.zero_grad()
            outputs = model(inputs)
            # Use the extracted identity labels as the target for CrossEntropyLoss
            loss = criterion(outputs, labels)

        elif mode == 'adv_train': # [Ref] https://arxiv.org/abs/1706.06083
            model.eval()
            # Pass the original multi-dimensional targets to the attack
            adv_x = pgd_attack(inputs, targets) # The attack will extract labels internally
            model.train()

            optimizer.zero_grad()
            outputs = model(adv_x)
            # Use the extracted identity labels as the target for CrossEntropyLoss
            loss = criterion(outputs, labels)

        elif mode == 'adv_train_trades': # [Ref] https://arxiv.org/abs/1901.08573
            optimizer.zero_grad()
            loss = trades_loss(model=model, x_natural=inputs, y=labels, optimizer=optimizer)


        # elif mode == 'adv_train_mixup': # [Ref] https://arxiv.org/abs/1710.09412
        #     model.eval()
        #     # Mixup needs 1D targets. You would need to modify mixup_data to work with the extracted labels.
        #     benign_inputs, benign_targets_a, benign_targets_b, benign_lam = mixup_data(inputs, labels)
        #     adv_x = pgd_attack(inputs, targets) # Pass original targets to attack
        #     # Mixup needs 1D targets. You would need to modify mixup_data to work with the extracted labels from adv_x?
        #     # This part of mixup with adversarial training might need careful consideration of how targets are handled.
        #     adv_inputs, adv_targets_a, adv_targets_b, adv_lam = mixup_data(adv_x, labels) # Using extracted labels


        #     model.train()
        #     optimizer.zero_grad()

        #     benign_outputs = model(benign_inputs)
        #     adv_outputs = model(adv_inputs)
        #     # Use the extracted 1D labels for criterion
        #     loss_1 = mixup_criterion(criterion, benign_outputs, benign_targets_a, benign_targets_b, benign_lam)
        #     loss_2 = mixup_criterion(criterion, adv_outputs, adv_targets_a, adv_targets_b, adv_lam)

        #     loss = (loss_1 + loss_2) / 2

        else:
            print("No training mode specified.")
            raise ValueError()

        loss.backward()
        optimizer.step()

        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{:05d}/{} ({:.0f}%)]\t Loss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(inputs), len(train_loader) * batch_size,
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))

            wandb.log({f"train_loss {train_loader.dataset}": loss.item()}, step=epoch)


def train(model, train_loader, val_loader_f, val_loader_m, pgd_attack,
          mode='natural', epochs=25, batch_size=256, learning_rate=0.001, momentum=0.9, weight_decay=2e-4,
          checkpoint_path='model1.pt'):

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate) # Using Adam as in your failing block, but only for model

    best_acc = 0.0 # Keep track of best average accuracy across genders

    for epoch in range(epochs):
        # training
        # Pass the extracted labels in train_ep as modified above
        train_ep(model, train_loader, mode, pgd_attack, optimizer, criterion, epoch, batch_size)

        val_acc_f = 0.0
        val_acc_m = 0.0
        val_loss_f = 0.0
        val_loss_m = 0.0

        if val_loader_f and len(val_loader_f.dataset) > 0:
            val_loss_f, val_acc_f = eval_test_celeba(model, val_loader_f, device, name = 'female')
            robust_loss_f, robust_accuracy_f = eval_robust_celeba(model, val_loader_f, pgd, device, name='female', epoch = epoch)


        if val_loader_m and len(val_loader_m.dataset) > 0:
            val_loss_m, val_acc_m = eval_test_celeba(model, val_loader_m, device, name = 'male')
            robust_loss_m, robust_accuracy_m = eval_robust_celeba(model, val_loader_m, pgd, device, name = 'male', epoch = epoch)



        val_acc = (val_acc_f + val_acc_m) / 2

        # remember best acc@1 and save checkpoint
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)


        # save checkpoint if is a new best
        if is_best:
            torch.save(model.state_dict(), checkpoint_path)
        print(f'Average accuracy: {val_acc:.2f}, female: {val_acc_f:.2f}, male: {val_acc_m:.2f}')

        wandb.log({"val_loss_female": val_loss_f, "val_accuracy_female": val_acc_f,
               "val_loss_male": val_loss_m, "val_accuracy_male": val_acc_m,
               "average_val_accuracy": val_acc}, step=epoch)







def eval_test_celeba(model, dataloader, device, name):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            labels = targets[:, 0] # Extract identity label
            outputs = model(inputs)
            test_loss += F.cross_entropy(outputs, labels).item() * inputs.size(0)
            pred = outputs.max(1, keepdim=True)[1]
            correct += pred.eq(labels.view_as(pred)).sum().item()
            total += inputs.size(0)
    test_loss /= total if total > 0 else 1
    accuracy = 100. * correct / total if total > 0 else 0

    # print(f'Test: Average loss: {test_loss:.4f}, Accuracy: {correct}/{total} ({accuracy:.0f}%)')
    # wandb.log(f"clean_test_loss {name}: {test_loss}", step=epoch)
    # wandb.log(f"clean_test_accuracy {name}: {accuracy}", step=epoch)
    return test_loss, accuracy


# convenience funtion to log predictions for a batch of test images
def log_test_predictions(images, labels, outputs, predicted, test_table, log_counter):
  # obtain confidence scores for all classes
  scores = F.softmax(outputs.data, dim=1)
  log_scores = scores.cpu().numpy()
  log_images = images.cpu().numpy()
  log_labels = labels.cpu().numpy()
  log_preds = predicted.cpu().numpy()
  # adding ids based on the order of the images
  _id = 0
  for i, l, p, s in zip(log_images, log_labels, log_preds, log_scores):
    # add required info to data table:
    # id, image pixels, model's guess, true label, scores for all classes
    img_id = str(_id) + "_" + str(log_counter)
    test_table.add_data(img_id, wandb.Image(i), p, l, *s)
    _id += 1
    if _id == batch_size:
      break


NUM_BATCHES_TO_LOG = 10

def eval_robust_celeba(model, dataloader, pgd_attack, device, name, epoch):
    model.eval()
    robust_loss = 0
    correct = 0
    total = 0

    success_count = 0
    log_counter = 0
    columns=["id", "image", "guess", "truth"]
    for image_id in dataloader.dataset.indices:
      columns.append("score_" + str(image_id))
    test_table = wandb.Table(columns=columns)


    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            labels = targets[:, 0] # extract identity label

            outputs_clean = model(inputs)
            pred_clean = outputs_clean.max(1, keepdim=True)[1]


            adv = pgd_attack(inputs, targets)
            outputs_adv = model(adv)
            robust_loss += F.cross_entropy(outputs_adv, labels).item()
            pred_adv = outputs_adv.max(1, keepdim=True)[1]
            correct += pred_adv.eq(labels.view_as(pred_adv)).sum().item()
            total += inputs.size(0)

            if log_counter < NUM_BATCHES_TO_LOG:
              log_test_predictions(inputs, labels, outputs_adv, pred_adv, test_table, log_counter)
              log_counter += 1

            # keeping track of successful attacks
            mask = pred_clean == labels
            succesful_attacks = (pred_adv != labels) & mask
            success_count += succesful_attacks.sum().item()



    attack_success_rate = success_count / correct if correct > 0 else 0
    print(f'Attack success rate: {attack_success_rate:.2f}%')
    robust_loss /= len(dataloader.dataset) if total > 0 else 1
    robust_accuracy = 100. * correct / total if total > 0 else 0

    print(f'LinfPGD Attack: Average loss: {robust_loss:.4f}, Robust Accuracy: {robust_accuracy:.0f}%)')

    wandb.log({f"robust_loss_{name}": robust_loss}, step=epoch)
    wandb.log({f"robust_accuracy_{name}": robust_accuracy}, step=epoch)
    wandb.log({f"attack_success_rate_{name}": attack_success_rate}, step=epoch)


    # ✨ W&B: Log predictions table to wandb
    wandb.log({"test_predictions" : test_table})

    return robust_loss, robust_accuracy


In [None]:
#@title small sanity check

wandb.init(project="face-adv-fairness", name="celeba-sanity-check", config={"learning_rate": 0.001, "epochs": 1})
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = ResNet18(num_classes=1000).to(device) # ResNet for identity classification
val_loader_f = DataLoader(test_subsets_f[0.25], batch_size=64, shuffle=False) # Shuffle usually False for validation
val_loader_m = DataLoader(test_subsets_m[0.25], batch_size=64, shuffle=False) # Shuffle usually False for validation
pgd = LinfPGDAttack(model, epsilon=8/255, step_size = 2/255, steps = 10)

robust_loss, robust_accuracy = eval_robust_celeba(model, val_loader_f, pgd, device, name = 'female', epoch = 0)

In [None]:
#@title training run: old
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
epsilon = 8/255
training_mode = "adv_train" # Or 'natural' if you want to train naturally
batch_size = 64

proportions = [0.25, 0.5, 0.75]

for proportion in proportions:
    # Re-initialize model and attack for each proportion if needed, otherwise move outside loop
    # If training separately for each proportion, re-initialization is correct.
    model = ResNet18(num_classes=1000).to(device) # ResNet for identity classification
    # Note: number of classes (1000) should match the number of unique identities
    # we filtered initially by top 1000 identitites but this might be limiting perhaps?
    # it gives very few examples on the test set
    # make a new run for each example
    wandb.init(project="face-adv-fairness", name=f"celeba-gender-{proportion}", config={"learning_rate": 0.001, "epochs": 30})


    num_identity_classes = 1000 # Assuming the ResNet18 model is configured for 1000 classes
    model = ResNet18(num_classes=num_identity_classes).to(device)

    pgd = LinfPGDAttack(model, epsilon=epsilon, step_size = epsilon/10, steps = 10)

    # train function definition already includes criterion and optimizer definition.
    # Move best_acc outside the inner epoch loop within the train function.
    # The train function saves checkpoint, so best_acc is managed internally.

    train_loader = DataLoader(train_subsets[proportion], batch_size=batch_size, shuffle=True)

    val_loader_f = None
    val_loader_m = None

    if proportion in test_subsets_f and len(test_subsets_f[proportion].indices) > 0:
        val_loader_f = DataLoader(test_subsets_f[proportion], batch_size=batch_size, shuffle=False) # Shuffle usually False for validation
    if proportion in test_subsets_m and len(test_subsets_m[proportion].indices) > 0:
        val_loader_m = DataLoader(test_subsets_m[proportion], batch_size=batch_size, shuffle=False) # Shuffle usually False for validation


    # call the modified train function
    train(model, train_loader=train_loader, mode=training_mode,
          val_loader_f=val_loader_f, val_loader_m=val_loader_m,
          pgd_attack=pgd, learning_rate=0.001,
          checkpoint_path=f'model_adv_prop{int(proportion*100)}.pt', epochs=20) # Save checkpoints with proportion





In [None]:
# convenience funtion to log predictions for a batch of test images
def log_test_predictions(images, labels, outputs, predicted, test_table, log_counter):
  # obtain confidence scores for all classes
  scores = F.softmax(outputs.data, dim=1)
  log_scores = scores.cpu().numpy()
  log_images = images.cpu().numpy()
  log_labels = labels.cpu().numpy()
  log_preds = predicted.cpu().numpy()
  # adding ids based on the order of the images
  _id = 0
  for i, l, p, s in zip(log_images, log_labels, log_preds, log_scores):
    # Transpose image dimensions from (C, H, W) to (H, W, C) for wandb.Image
    i_transposed = np.transpose(i, (1, 2, 0))

    # add required info to data table:
    # id, image pixels, model's guess, true label, scores for all classes
    img_id = str(_id) + "_" + str(log_counter)
    # Use the transposed image data
    test_table.add_data(img_id, wandb.Image(i_transposed), p, l, *s)
    _id += 1
    if _id == batch_size:
      break

In [None]:
#@title modified train and test functions for celeba

def train_ep(model, train_loader, mode, pgd_attack, optimizer, criterion, epoch, batch_size):
    model.train()
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        # Extract the identity label from the multi-dimensional target tensor
        labels = targets[:, 0] # Assuming the first column is the identity label


        if mode == 'natural':
            optimizer.zero_grad()
            outputs = model(inputs)
            # Use the extracted identity labels as the target for CrossEntropyLoss
            loss = criterion(outputs, labels)

        elif mode == 'adv_train': # [Ref] https://arxiv.org/abs/1706.06083
            model.eval()
            # Pass the original multi-dimensional targets to the attack
            # The attack will extract the identity label internally using targets[:, 0]
            adv_x = pgd_attack(inputs, targets)
            model.train()

            optimizer.zero_grad()
            outputs = model(adv_x)
            # Use the extracted identity labels as the target for CrossEntropyLoss
            loss = criterion(outputs, labels)

        elif mode == 'adv_train_trades': # [Ref] https://arxiv.org/abs/1901.08573
            optimizer.zero_grad()
            # For trades_loss, you need to pass the identity labels as it's used directly in the loss calculation.
            loss = trades_loss(model=model, x_natural=inputs, y=labels, optimizer=optimizer)


        # elif mode == 'adv_train_mixup': # [Ref] https://arxiv.org/abs/1710.09412
        #     model.eval()
        #     # Mixup needs 1D targets. You would need to modify mixup_data to work with the extracted labels.
        #     benign_inputs, benign_targets_a, benign_targets_b, benign_lam = mixup_data(inputs, labels)
        #     # Pass original targets to attack
        #     adv_x = pgd_attack(inputs, targets)
        #     # Mixup needs 1D targets. You would need to modify mixup_data to work with the extracted labels from adv_x?
        #     # This part of mixup with adversarial training might need careful consideration of how targets are handled.
        #     adv_inputs, adv_targets_a, adv_targets_b, adv_lam = mixup_data(adv_x, labels) # Using extracted labels


        #     model.train()
        #     optimizer.zero_grad()

        #     benign_outputs = model(benign_inputs)
        #     adv_outputs = model(adv_inputs)
        #     # Use the extracted 1D labels for criterion
        #     loss_1 = mixup_criterion(criterion, benign_outputs, benign_targets_a, benign_targets_b, benign_lam)
        #     loss_2 = mixup_criterion(criterion, adv_outputs, adv_targets_a, adv_targets_b, adv_lam)

        #     loss = (loss_1 + loss_2) / 2

        else:
            print("No training mode specified.")
            raise ValueError()

        loss.backward()
        optimizer.step()

        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{:05d}/{} ({:.0f}%)]\t Loss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(inputs), len(train_loader) * batch_size,
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))

            # It seems like train_loader.dataset is a Subset, not a string name.
            # Using the proportion from the loop might be better.
            # You might need to pass the proportion to train_ep or handle logging outside this loop.
            # For now, removing the dataset name from log key to avoid issues.
            wandb.log({"train_loss": loss.item()}, step=epoch)


def train(model, train_loader, val_loader_f, val_loader_m, pgd_attack,
          mode='natural', epochs=25, batch_size=256, learning_rate=0.001, momentum=0.9, weight_decay=2e-4,
          checkpoint_path='model1.pt'):

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate) # Using Adam as in your failing block, but only for model

    best_acc = 0.0 # Keep track of best average accuracy across genders

    for epoch in range(epochs):
        # training
        # Pass the extracted labels in train_ep as modified above
        train_ep(model, train_loader, mode, pgd_attack, optimizer, criterion, epoch, batch_size)

        val_acc_f = 0.0
        val_acc_m = 0.0
        val_loss_f = 0.0
        val_loss_m = 0.0

        # Get the number of output classes from the model's linear layer
        num_classes = model.linear.out_features

        if val_loader_f and len(val_loader_f.dataset) > 0:
            val_loss_f, val_acc_f = eval_test_celeba(model, val_loader_f, device, name = 'female', epoch = epoch)
            # Pass num_classes to eval_robust_celeba
            robust_loss_f, robust_accuracy_f = eval_robust_celeba(model, val_loader_f, pgd_attack, device, name='female', epoch = epoch, num_classes=num_classes)


        if val_loader_m and len(val_loader_m.dataset) > 0:
            val_loss_m, val_acc_m = eval_test_celeba(model, val_loader_m, device, name = 'male', epoch = epoch)
            # Pass num_classes to eval_robust_celeba
            robust_loss_m, robust_accuracy_m = eval_robust_celeba(model, val_loader_m, pgd_attack, device, name = 'male', epoch = epoch, num_classes=num_classes)



        val_acc = (val_acc_f + val_acc_m) / 2

        # remember best acc@1 and save checkpoint
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)


        # save checkpoint if is a new best
        if is_best:
            torch.save(model.state_dict(), checkpoint_path)
        print(f'Average accuracy: {val_acc:.2f}, female: {val_acc_f:.2f}, male: {val_acc_m:.2f}')

        wandb.log({"val_loss_female": val_loss_f, "val_accuracy_female": val_acc_f,
               "val_loss_male": val_loss_m, "val_accuracy_male": val_acc_m,
               "average_val_accuracy": val_acc}, step=epoch)


def eval_test_celeba(model, dataloader, device, name, epoch): # Added epoch parameter for logging
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            labels = targets[:, 0] # Extract identity label
            outputs = model(inputs)
            test_loss += F.cross_entropy(outputs, labels).item() * inputs.size(0)
            pred = outputs.max(1, keepdim=True)[1]
            correct += pred.eq(labels.view_as(pred)).sum().item()
            total += inputs.size(0)
    test_loss /= total if total > 0 else 1
    accuracy = 100. * correct / total if total > 0 else 0

    print(f'Test {name}: Average loss: {test_loss:.4f}, Accuracy: {correct}/{total} ({accuracy:.0f}%)')
    # Log clean test loss and accuracy
    wandb.log({f"clean_test_loss_{name}": test_loss}, step=epoch)
    wandb.log({f"clean_test_accuracy_{name}": accuracy}, step=epoch)
    return test_loss, accuracy



In [None]:
# convenience funtion to log predictions for a batch of test images
def log_test_predictions(images, labels, outputs, predicted, test_table, log_counter, batch_size_for_log):
  # obtain confidence scores for all classes
  scores = F.softmax(outputs.data, dim=1)
  # only log the score with the highest probability as the guess
  # log_scores = scores.cpu().numpy()
  log_images = images.cpu().numpy()
  log_labels = labels.cpu().numpy()
  log_preds = predicted.cpu().numpy()
  # adding ids based on the order of the images
  _id = 0
  for i, l, p in zip(log_images, log_labels, log_preds):
    # Transpose image dimensions from (C, H, W) to (H, W, C) for wandb.Image
    if p != l:
      # Transpose image dimensions from (C, H, W) to (H, W, C) for wandb.Image
      i_transposed = np.transpose(i, (1, 2, 0))

      # add required info to data table:
      # id, image pixels, model's incorrect guess, true label
      img_id = str(_id) + "_" + str(log_counter)
      # Use the transposed image data
      test_table.add_data(img_id, wandb.Image(i_transposed), p, l)


    _id += 1
    # Use the provided batch_size_for_log for comparison
    if _id == batch_size_for_log:
      break


NUM_BATCHES_TO_LOG = 10

# Added num_classes parameter
def eval_robust_celeba(model, dataloader, pgd_attack, device, name, epoch, num_classes):
    model.eval()
    robust_loss = 0
    correct = 0
    total = 0

    success_count = 0
    log_counter = 0
    # Initialize columns based on the number of classes
    columns=["id", "image", "guess", "truth"]
    # took out the other classes as cols bc it was tedious
    test_table = wandb.Table(columns=columns)


    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            labels = targets[:, 0] # extract identity label

            outputs_clean = model(inputs)
            pred_clean = outputs_clean.max(1, keepdim=True)[1]

            # Pass the original targets to the attack
            adv = pgd_attack(inputs, targets)
            outputs_adv = model(adv)
            robust_loss += F.cross_entropy(outputs_adv, labels).item() * inputs.size(0)
            pred_adv = outputs_adv.max(1, keepdim=True)[1]
            correct += pred_adv.eq(labels.view_as(pred_adv)).sum().item()
            total += inputs.size(0)

            if log_counter < NUM_BATCHES_TO_LOG:
                # Pass the actual batch size of the current inputs to log_test_predictions
                log_test_predictions(inputs, labels, outputs_adv, pred_adv, test_table, log_counter, inputs.size(0))
                log_counter += 1


            # keeping track of successful attacks
            # Ensure mask uses the correct comparison (pred_clean vs labels)
            mask = pred_clean.view_as(labels) == labels
            succesful_attacks = (pred_adv.view_as(labels) != labels) & mask
            success_count += succesful_attacks.sum().item()

    robust_loss /= total if total > 0 else 1
    robust_accuracy = 100. * correct / total if total > 0 else 0
    attack_success_rate = success_count / correct if correct > 0 else 0 # Calculate attack success rate based on correct predictions


    print(f'LinfPGD Attack {name}: Average loss: {robust_loss:.4f}, Robust Accuracy: {robust_accuracy:.0f}%)')
    print(f'Attack success rate {name}: {attack_success_rate:.2f}%')

    wandb.log({f"robust_loss_{name}": robust_loss}, step=epoch)
    wandb.log({f"robust_accuracy_{name}": robust_accuracy}, step=epoch)
    wandb.log({f"attack_success_rate_{name}": attack_success_rate}, step=epoch)


    # ✨ W&B: Log predictions table to wandb
    wandb.log({"test_predictions" : test_table}, step=epoch) # Log table at each epoch step

    return robust_loss, robust_accuracy

In [None]:
#@title training run: new, with balanced datasets

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
epsilon = 8/255
training_mode = "adv_train" # Or 'natural' if you want to train naturally
batch_size = 64

proportions = [0.25, 0.5, 0.75]

for proportion in proportions:
    # Re-initialize model and attack for each proportion if needed, otherwise move outside loop
    # If training separately for each proportion, re-initialization is correct.
    model = ResNet18(num_classes=999).to(device) # ResNet for identity classification
    # Note: number of classes (1000) should match the number of unique identities
    # it gives very few examples on the test set


    # make a new run for each example
    wandb.init(project="face-adv-fairness", name=f"celeba-gender-w-{proportion}", config={"learning_rate": 0.001, "epochs": 30})


    num_identity_classes = 999 # Assuming the ResNet18 model is configured for 1000 classes
    model = ResNet18(num_classes=num_identity_classes).to(device)

    pgd = LinfPGDAttack(model, epsilon=epsilon, step_size = epsilon/10, steps = 10)

    # train function definition already includes criterion and optimizer definition.
    # Move best_acc outside the inner epoch loop within the train function.
    # The train function saves checkpoint, so best_acc is managed internally.

    train_loader = DataLoader(train_subsets_new[proportion], batch_size=batch_size, shuffle=True)

    val_loader_f = None
    val_loader_m = None

    if proportion in test_subsets_f and len(test_subsets_f[proportion].indices) > 0:
        val_loader_f = DataLoader(test_subsets_f[proportion], batch_size=batch_size, shuffle=False) # Shuffle usually False for validation
    if proportion in test_subsets_m and len(test_subsets_m[proportion].indices) > 0:
        val_loader_m = DataLoader(test_subsets_m[proportion], batch_size=batch_size, shuffle=False) # Shuffle usually False for validation


    # call the modified train function
    train(model, train_loader=train_loader, mode=training_mode,
          val_loader_f=val_loader_f, val_loader_m=val_loader_m,
          pgd_attack=pgd, learning_rate=0.001,
          checkpoint_path=f'model_adv_prop{int(proportion*100)}.pt', epochs=20) # Save checkpoints with proportion







cuda


0,1
attack_success_rate_female,▁▁▁▁▁▁▂█
attack_success_rate_male,▁▁▁▁▁▁▁█
average_val_accuracy,███████▁
clean_test_accuracy_female,███████▁█
clean_test_accuracy_male,███████▁
clean_test_loss_female,▆▆▅▄▄▃▂█▁
clean_test_loss_male,▆▇▅▅▄▂▁█
robust_accuracy_female,███████▁
robust_accuracy_male,███████▁
robust_loss_female,▃▂▁▁▂▁▁█

0,1
attack_success_rate_female,0.16429
attack_success_rate_male,0.18495
average_val_accuracy,85.86093
clean_test_accuracy_female,90.19868
clean_test_accuracy_male,85.82781
clean_test_loss_female,0.23455
clean_test_loss_male,0.3507
robust_accuracy_female,74.17219
robust_accuracy_male,73.04636
robust_loss_female,0.49018


Test female: Average loss: 0.3671, Accuracy: 1364/1510 (90%)
LinfPGD Attack female: Average loss: 0.3810, Robust Accuracy: 90%)
Attack success rate female: 0.00%
Test male: Average loss: 0.3903, Accuracy: 1352/1510 (90%)
LinfPGD Attack male: Average loss: 0.4038, Robust Accuracy: 90%)
Attack success rate male: 0.00%
Average accuracy: 89.93, female: 90.33, male: 89.54
Test female: Average loss: 0.3067, Accuracy: 1364/1510 (90%)
LinfPGD Attack female: Average loss: 0.3323, Robust Accuracy: 90%)
Attack success rate female: 0.00%
Test male: Average loss: 0.3216, Accuracy: 1352/1510 (90%)
LinfPGD Attack male: Average loss: 0.3471, Robust Accuracy: 90%)
Attack success rate male: 0.00%
Average accuracy: 89.93, female: 90.33, male: 89.54
Test female: Average loss: 0.2885, Accuracy: 1364/1510 (90%)
LinfPGD Attack female: Average loss: 0.3157, Robust Accuracy: 90%)
Attack success rate female: 0.00%
Test male: Average loss: 0.3124, Accuracy: 1352/1510 (90%)
LinfPGD Attack male: Average loss: 0.33