In [None]:
! conda install nb_conda_kernels ipywidgets -c conda-forge -n base -y
! conda env create -f environment.yml
! conda env update -f environment.yml

In [1]:
import torch
import torchvision
import pandas as pd
import PIL
from time import time
from torch import nn, optim
from torchvision import transforms
from variable_width_resnet import resnet10vw
import json
from math import ceil

In [2]:
data_dir = "/nobackup/data/celebA"

label = "Blond_Hair"
spurious = "Male"

batch_size_teacher = 1024
batch_size_student = 2048
max_workers = 64
epochs = 60
max_width = 96
folds = 10
temperature = 3.

device = torch.device("cuda:0") if torch.has_cuda else torch.device("cpu")

In [3]:
attributes = pd.read_csv(f"{data_dir}/list_attr_celeba.csv")

In [4]:
def ReweightedLoss(group_fractions):
    cross_entropy = torch.nn.CrossEntropyLoss(reduction="none")
    
    def loss(logits, labels, groups):
        losses = cross_entropy(logits, labels)
        
        for group_id in range(group_fractions.shape[0]):
            losses[groups == group_id] /= group_fractions[group_id]
            
        return losses.mean()
        
    return loss

def DistillationLossReweighted(temperature, group_fractions):
    cross_entropy = torch.nn.CrossEntropyLoss(reduction="none")
        
    def loss(student_logits, teacher_logits, target, groups):
        last_dim = len(student_logits.shape) - 1
        
        p_t = nn.functional.softmax(teacher_logits/temperature, dim=last_dim)
        log_p_s = nn.functional.log_softmax(student_logits/temperature, dim=last_dim)
        
        losses = cross_entropy(student_logits, target) - (p_t * log_p_s).sum(dim=last_dim)
        
        for group_id in range(group_fractions.shape[0]):
            losses[groups == group_id] /= group_fractions[group_id]
        
        return losses.mean()
    
    return loss

In [5]:
def evaluate(model, loader, name=""):
    total_correct = torch.zeros(2,2)
    total_items = torch.zeros(2,2)
    training = model.training
    model.eval()
    
    with torch.no_grad():
        for data, attrs in iter(loader):
            target = attrs[..., 1]
            data, target = data.to(device), target.to(device)
            
            out = model(data)
            
            correct = (out.argmax(axis=1) == target).to("cpu")
            batch_size = correct.shape[0]
            
            for batch_idx in range(batch_size):
                bin_idx = list(attrs[batch_idx])
                
                total_correct[bin_idx] += correct[batch_idx]
                total_items[bin_idx] += 1
                
    model.train(training)
                
    group_accuracies = total_correct / total_items
                
    print(f"{name} set results:")
    print(f"Average case accuracy: {total_correct.sum()}/{total_items.sum()}={total_correct.sum()/total_items.sum()}\n")
    
    status = lambda x: " not " if x == 0 else " "
    
    def from_flat_idx(x):
        idx0 = 1 if x > 1 else 0
        idx1 = x - 2 * idx0
        
        return idx0, idx1
    
    worst = from_flat_idx(group_accuracies.argmin())
    
    print(f"Worst case group:{status(worst[0])}{spurious.lower()} [spurious] and{status(worst[1])}{label.lower()} [target]")
    print(f"Worst case accuracy: {total_correct[worst]}/{total_items[worst]}={group_accuracies[worst]}\n")
    
#     best = from_flat_idx(group_accuracies.argmax())
#     print(f"Best case group:{status(best[0])}{spurious.lower()} [spurious] and{status(best[1])}{label.lower()} [target]")
#     print(f"Best case accuracy: {total_correct[best]}/{total_items[best]}={group_accuracies[best]}\n")

    return {
        "worst": {
            "correct": total_correct[worst].item(),
            "total": total_items[worst].item(),
        },
        "average": {
            "correct": total_correct.sum().item(),
            "total": total_items.sum().item(),
        }
    }

In [6]:
class CelebAFairnessDataset(torchvision.datasets.VisionDataset):
    def __init__(self, 
                 split,
                 attribute_df=attributes, 
                 is_training=True, 
                 augment=False,
                 spurious=spurious, 
                 label=label,
                 root=f"{data_dir}/img_align_celeba/"):        
        df = attribute_df.loc[split]
        
        self.files = root + df.image_id.values
        
        self.attrs = df[[spurious, label]].values
        
        # Negative cases are labeled 0 instead of -1
        self.attrs[self.attrs < 0] = 0
        
        # Overparameterization paper transform:
        orig_w = 178
        orig_h = 218
        orig_min_dim = min(orig_w, orig_h)
        target_resolution = (224, 224)

        if (not is_training) or (not augment):
            self.transform = transforms.Compose([
                transforms.CenterCrop(orig_min_dim),
                transforms.Resize(target_resolution),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        else:
            # Orig aspect ratio is 0.81, so we don't squish it in that direction any more
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(
                    target_resolution,
                    scale=(0.7, 1.0),
                    ratio=(1.0, 1.3333333333333333),
                    interpolation=2),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        
    def __getitem__(self, idx):
        img = PIL.Image.open(self.files[idx])
        attrs = self.attrs[idx]
        
        
        return [
            self.transform(img),
            torch.Tensor(attrs).long()
        ]
        
    def __len__(self):
        return self.files.shape[0]

In [7]:
# permuted_indices = torch.randperm(attributes.index.values.shape[0])
# torch.save(permuted_indices, "permutation_cross_validation.pt")
permuted_indices = torch.load("permutation_cross_validation.pt")

In [8]:
def run_fold(fold, permuted_indices=permuted_indices, augment_teacher=False, augment_student=True, load=False):
    message = f"Starting fold {fold+1} of {folds}"
    print(message)
    print("-" * len(message))
    
    datapoints = permuted_indices.numel()
    indices = torch.arange(datapoints)
    test_size = ceil(datapoints/folds)
    test_end = ceil(datapoints * (fold + 1) / folds)
    test_start = test_end - test_size
    
    test = permuted_indices[test_start:test_end]
    train = permuted_indices[(indices > test_end) | (indices < test_start)]
    training = CelebAFairnessDataset(train, augment=augment_teacher)
    
    teacher_train_loader = torch.utils.data.DataLoader(training,
                                               num_workers=max_workers,
                                               batch_size=batch_size_teacher,
                                               shuffle=True,
                                               pin_memory=True)
    student_train_loader = torch.utils.data.DataLoader(CelebAFairnessDataset(train, augment=augment_student),
                                               num_workers=max_workers,
                                               batch_size=batch_size_student,
                                               shuffle=True,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(CelebAFairnessDataset(test),
                                               num_workers=max_workers,
                                               batch_size=batch_size_student,
                                               shuffle=False,
                                               pin_memory=True)
    
    test_results = lambda model: evaluate(model, test_loader, name="Test")
    
    attrs = torch.Tensor(training.attrs).long()
    group_fractions = torch.bincount(attrs[..., 0] + attrs[...,1]*2).float()
    group_fractions /= group_fractions.sum()
    
    if load:
        start_epoch, is_teacher, last_fold = torch.load("state.pt")
    
    if load and last_fold == fold:
        print("restored state")
        teacher = torch.load("teacher_model.pt").to(device)
        student = torch.load("student_model.pt").to(device)
        student_no_distillation = torch.load("student_no_distillation.pt").to(device)
        
        optimizer_teacher = torch.load("teacher_optimizer.pt")
        optimizer_student = torch.load("student_optimizer.pt")
        optimizer_no_distillation = torch.load("student_no_distillation_optimizer.pt")
    else:
        teacher = nn.DataParallel(resnet10vw(96, num_classes=2)).to(device)
        student =  nn.DataParallel(resnet10vw(10, num_classes=2)).to(device)
        student_no_distillation = nn.DataParallel(resnet10vw(10, num_classes=2)).to(device)

        optimizer_teacher = optim.Adam(teacher.parameters())
        optimizer_student = optim.Adam(student.parameters())
        optimizer_no_distillation = optim.Adam(student_no_distillation.parameters())
        
        start_epoch = -1
        is_teacher = True
    
    def save(epoch, is_teacher):
        torch.save(teacher, "teacher_model.pt")
        torch.save(student, "student_model.pt")
        torch.save(student_no_distillation, "student_no_distillation.pt")
        
        torch.save(optimizer_teacher, "teacher_optimizer.pt")
        torch.save(optimizer_student, "student_optimizer.pt")
        torch.save(optimizer_no_distillation, "student_no_distillation_optimizer.pt")
               
        torch.save([epoch, is_teacher, fold], "state.pt")

    if is_teacher:
        # Train teacher   
        print("\nTraining teacher model")
                                             
        criterion = ReweightedLoss(group_fractions)

        teacher.train()
        
        for epoch in range(start_epoch+1, epochs):
            epoch_loss = 0.
            start = time()

            for data, attrs in iter(teacher_train_loader):
                target = attrs[..., 1] # Spurious label is index 0
                groups = attrs[..., 0] + attrs[..., 1] * 2

                data, target = data.to(device), target.to(device)

                out = teacher(data)
                loss = criterion(out, target, groups)
                loss.backward()

                epoch_loss += loss.item()

                optimizer_teacher.step()      
                optimizer_teacher.zero_grad()

            end = time()

            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}: total loss {epoch_loss} in {end - start} seconds")
                save(epoch, True)
                
        start_epoch = -1

    # Train student
    print("\nTraining student models")
    
    criterion = DistillationLossReweighted(temperature, group_fractions)
    criterion_no_distillation = ReweightedLoss(group_fractions)
            
    teacher.eval()
    student.train()

    for epoch in range(start_epoch+1, epochs):
        epoch_loss = 0.
        epoch_loss_no_distillation = 0.
        start = time()

        for data, attrs in iter(student_train_loader):
            target = attrs[..., 1] # Spurious label is index 0
            groups = attrs[..., 0] + attrs[..., 1] * 2

            data, target = data.to(device), target.to(device)

            out_student = student(data)

            with torch.no_grad():
                out_teacher = teacher(data)

            loss = criterion(out_student, out_teacher, target, groups)
            loss.backward()

            epoch_loss += loss.item()

            optimizer_student.step()
            optimizer_student.zero_grad()
            
            out = student_no_distillation(data)
            loss = criterion_no_distillation(out, target, groups)
            loss.backward()

            epoch_loss_no_distillation += loss.item()

            optimizer_no_distillation.step()      
            optimizer_no_distillation.zero_grad()

        end = time()

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}: total loss {epoch_loss} (with distillation) {epoch_loss_no_distillation} (without distillation) in {end - start} seconds")
            save(epoch, False)
            
    # Evaluate
    print("\n Teacher results:")
    teacher_results = test_results(teacher)
    print("\n Student results:")
    student_results = test_results(student)
    print("\n Student results (without distillation):")
    student_no_distillation_results = test_results(student_no_distillation)
    
    return {
        "teacher": teacher_results,
        "student": student_results,
        "student_no_distillation": student_no_distillation_results,
    }

In [None]:
results = []

for i in range(folds):
    results.append(run_fold(i))

    with open("results-unaugmented-teacher-augmented-student.json", "w") as f:
        json.dump(results, f)

Starting fold 1 of 10
---------------------

Training teacher model
Epoch 10: total loss 142.2655338048935 in 87.93975853919983 seconds
Epoch 20: total loss 82.69816225767136 in 85.7453465461731 seconds
Epoch 30: total loss 36.644175469875336 in 87.8168773651123 seconds
Epoch 40: total loss 19.08888928219676 in 84.54312634468079 seconds
Epoch 50: total loss 165.0139730423689 in 85.87446737289429 seconds
Epoch 60: total loss 1.4040055001387373 in 88.7175042629242 seconds

Training student model
Epoch 10: total loss 152.24468314647675 in 89.62184691429138 seconds
Epoch 20: total loss 163.03148370981216 in 84.49458503723145 seconds
Epoch 30: total loss 104.89422810077667 in 90.25111675262451 seconds
Epoch 40: total loss 103.32020539045334 in 94.48962211608887 seconds
Epoch 50: total loss 106.28642398118973 in 87.42717695236206 seconds
Epoch 60: total loss 95.63616496324539 in 90.29020571708679 seconds

 Teacher results:
Test set results:
Average case accuracy: 19177.0/20260.0=0.9465449452

In [20]:
for key in ["teacher", "student"]:
    print(key)

    get = lambda k: sum([i[key][k]["correct"] for i in results])/sum([i[key][k]["total"] for i in results])
    
    print(f"average: {get('average')}")

    print(f"worst: {get('worst')}")

teacher
average: 0.9449160908193485
worst: 0.5363064608347627
student
average: 0.9203849950641658
worst: 0.8201507489743345


### Augmented Student and Teacher

In [None]:
with open("results-augmented-teacher-augmented-student.json", "r") as f:
    results_augmented = json.load(f)
    
# results_augmented = []
folds_complete = len(results_augmented)

for i in range(folds_complete, folds):
    results_augmented.append(run_fold(i, augment_teacher=True, load=(i==folds_complete)))

    with open("results-augmented-teacher-augmented-student.json", "w") as f:
        json.dump(results_augmented, f)

Starting fold 9 of 10
---------------------
restored state

Training teacher model

Training student models
Epoch 10: total loss 504.9422287940979 (with distillation) 278.8648054599762 (without distillation) in 106.14713788032532 seconds
Epoch 20: total loss 499.7467966079712 (with distillation) 277.3459746837616 (without distillation) in 96.11409091949463 seconds


In [16]:
for key in results_augmented[0]:
    print(key)

    get = lambda k: sum([i[key][k]["correct"] for i in results_augmented])/sum([i[key][k]["total"] for i in results_augmented])
    
    print(f"average: {get('average')}")
    print(f"worst: {get('worst')}")

teacher
average: 0.9432329713721619
worst: 0.5946255002858777
student
average: 0.8647087857847976
worst: 0.23943396226415095
student_no_distillation
average: 0.8477986179664363
worst: 0.10624169986719788


In [17]:
results_augmented

[{'teacher': {'worst': {'correct': 129.0, 'total': 169.0},
   'average': {'correct': 18863.0, 'total': 20260.0}},
  'student': {'worst': {'correct': 120.0, 'total': 169.0},
   'average': {'correct': 18869.0, 'total': 20260.0}},
  'student_no_distillation': {'worst': {'correct': 138.0, 'total': 169.0},
   'average': {'correct': 17894.0, 'total': 20260.0}}},
 {'teacher': {'worst': {'correct': 96.0, 'total': 168.0},
   'average': {'correct': 19140.0, 'total': 20260.0}},
  'student': {'worst': {'correct': 121.0, 'total': 168.0},
   'average': {'correct': 18792.0, 'total': 20260.0}},
  'student_no_distillation': {'worst': {'correct': 117.0, 'total': 168.0},
   'average': {'correct': 18940.0, 'total': 20260.0}}},
 {'teacher': {'worst': {'correct': 76.0, 'total': 170.0},
   'average': {'correct': 19229.0, 'total': 20260.0}},
  'student': {'worst': {'correct': 132.0, 'total': 170.0},
   'average': {'correct': 18489.0, 'total': 20260.0}},
  'student_no_distillation': {'worst': {'correct': 119.0

#### Rerunning iteration 9

Iteration 9 appears to have had a bug from incorrectly restoring from a saved state; recalculating the results

In [18]:
results_augmented[8] = run_fold(8, augment_teacher=True)

Starting fold 9 of 10
---------------------

Training teacher model
Epoch 10: total loss 153.7123766541481 in 86.90923738479614 seconds
Epoch 20: total loss 115.20748114585876 in 86.31769490242004 seconds
Epoch 30: total loss 77.30689215660095 in 87.16270160675049 seconds
Epoch 40: total loss 42.02192668616772 in 85.42110061645508 seconds
Epoch 50: total loss 25.61875607073307 in 88.89468908309937 seconds
Epoch 60: total loss 2.5034093619324267 in 86.91546654701233 seconds

Training student models
Epoch 10: total loss 168.15435528755188 (with distillation) 71.89970940351486 (without distillation) in 103.67305898666382 seconds
Epoch 20: total loss 122.86082565784454 (with distillation) 45.19628405570984 (without distillation) in 106.7306900024414 seconds
Epoch 30: total loss 114.55890446901321 (with distillation) 39.20469355583191 (without distillation) in 103.05396866798401 seconds
Epoch 40: total loss 138.7109397649765 (with distillation) 52.43850961327553 (without distillation) in 10

In [20]:
for key in results_augmented[0]:
    print(key)

    get = lambda k: sum([i[key][k]["correct"] for i in results_augmented])/sum([i[key][k]["total"] for i in results_augmented])
    
    print(f"average: {get('average')}")
    print(f"worst: {get('worst')}")

teacher
average: 0.9430700888450148
worst: 0.5946255002858777
student
average: 0.9304689042448173
worst: 0.6861063464837049
student_no_distillation
average: 0.9255972359328727
worst: 0.6740994854202401


In [21]:
with open("results-augmented-teacher-augmented-student.json", "w") as f:
        json.dump(results_augmented, f)

### Unaugmented Student and Teacher

In [9]:
with open("results-unaugmented-teacher-unaugmented-student.json", "r") as f:
    results_unaugmented = json.load(f)

# results_unaugmented = []
folds_complete = len(results_unaugmented)

for i in range(folds_complete, folds):
    results_unaugmented.append(run_fold(i, augment_student=False, load=(i==folds_complete)))

    with open("results-unaugmented-teacher-unaugmented-student.json", "w") as f:
        json.dump(results_unaugmented, f)

Starting fold 4 of 10
---------------------
restored state

Training teacher model

Training student models
Epoch 10: total loss 504.7459120750427 (with distillation) 250.44861817359924 (without distillation) in 105.20854139328003 seconds
Epoch 20: total loss 504.5829710960388 (with distillation) 250.48265051841736 (without distillation) in 106.46074628829956 seconds
Epoch 30: total loss 506.7782826423645 (with distillation) 251.40503478050232 (without distillation) in 100.30206108093262 seconds
Epoch 40: total loss 504.2388153076172 (with distillation) 250.2957627773285 (without distillation) in 102.28609013557434 seconds
Epoch 50: total loss 504.8196077346802 (with distillation) 250.55920815467834 (without distillation) in 104.39407134056091 seconds
Epoch 60: total loss 507.2005515098572 (with distillation) 251.6191053390503 (without distillation) in 113.73196864128113 seconds

 Teacher results:
Test set results:
Average case accuracy: 17621.0/20260.0=0.8697433471679688

Worst case g

KeyboardInterrupt: 