In [24]:
# One-time environment setup via IPython shell escapes.
# nb_conda_kernels (installed into base) lets Jupyter list conda envs as kernels —
# presumably so the env defined in environment.yml can be chosen as this notebook's kernel.
! conda install nb_conda_kernels ipywidgets -c conda-forge -n base -y
# `env create` fails with "prefix already exists" once the env is present (see output below);
# the `env update` that follows then syncs the existing env with environment.yml.
! conda env create -f environment.yml
! conda env update -f environment.yml


CondaValueError: prefix already exists: /u/n/p/npalumbo/anaconda3/envs/distillation

Collecting package metadata (repodata.json): done
Solving environment: done


  current version: 4.9.0
  latest version: 4.9.2

Please update conda by running

    $ conda update -n base conda



Downloading and Extracting Packages
matplotlib-3.3.3     | 6 KB      | ##################################### | 100% 
expat-2.2.9          | 191 KB    | ##################################### | 100% 
libiconv-1.16        | 1.4 MB    | ##################################### | 100% 
pcre-8.44            | 261 KB    | ##################################### | 100% 
libxml2-2.9.9        | 1.3 MB    | ##################################### | 100% 
fontconfig-2.13.1    | 327 KB    | ##################################### | 100% 
pip-20.3.1           | 1.1 MB    | ##################################### | 100% 
xorg-libxau-1.0.9    | 13 KB     | ##################################### | 100% 
matplotlib-base-3.3. | 6.7 MB    |

In [1]:
import torch
import torchvision
import numpy as np
import pandas as pd
import PIL
from time import time
from torch import nn, optim
from torchvision import transforms
from variable_width_resnet import resnet10vw
import json
from math import ceil

In [2]:
data_dir = "/nobackup/data/celebA"

# Task definition: `label` is the prediction target and `spurious` the attribute the
# model may latch onto; both are column names of list_attr_celeba.csv.
label = "Blond_Hair"
spurious = "Male"

batch_size_teacher = 1024
batch_size_student = 2048
max_workers = 64          # DataLoader worker processes per loader
epochs = 50               # epochs for the teacher and, separately, for the students
teacher_width = 96        # width argument passed to resnet10vw
student_width = 8
augmented = True
augment_student = augmented
augment_teacher = augmented
temperature = 3.          # softmax temperature used by the distillation loss

# Seed the RNGs up front so weight init and shuffling are reproducible across
# Restart & Run All (cuDNN kernels may still introduce nondeterminism on GPU).
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# torch.has_cuda only says PyTorch was *built* with CUDA; whether a GPU is actually
# usable must be checked at runtime, otherwise "cuda:0" fails on GPU-less hosts.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
attributes = pd.read_csv(f"{data_dir}/list_attr_celeba.csv")
partitions = pd.read_csv(f"{data_dir}/list_eval_partition.csv")

# The dataset class selects rows of `attributes` with a boolean mask built from
# `partitions`, which aligns on the index: both files must describe the same images
# in the same row order, or split assignment is silently wrong.
assert len(attributes) == len(partitions), "attribute and partition files differ in length"
if "image_id" in partitions.columns:
    assert (attributes["image_id"].values == partitions["image_id"].values).all(), \
        "attribute and partition rows are not aligned by image_id"

In [4]:
def ReweightedLoss(group_fractions):
    cross_entropy = torch.nn.CrossEntropyLoss(reduction="none")
    
    def loss(logits, labels, groups):
        losses = cross_entropy(logits, labels)
        
        for group_id in range(group_fractions.shape[0]):
            losses[groups == group_id] /= group_fractions[group_id]
            
        return losses.mean()
        
    return loss

def DistillationLossReweighted(temperature, group_fractions):
    cross_entropy = torch.nn.CrossEntropyLoss(reduction="none")
        
    def loss(student_logits, teacher_logits, target, groups):
        last_dim = len(student_logits.shape) - 1
        
        p_t = nn.functional.softmax(teacher_logits/temperature, dim=last_dim)
        log_p_s = nn.functional.log_softmax(student_logits/temperature, dim=last_dim)
        
        losses = cross_entropy(student_logits, target) - (p_t * log_p_s).sum(dim=last_dim)
        
        for group_id in range(group_fractions.shape[0]):
            losses[groups == group_id] /= group_fractions[group_id]
        
        return losses.mean()
    
    return loss

In [5]:
def evaluate(model, loader, name=""):
    """Print and return the average and worst-group accuracy of ``model`` on ``loader``.

    Each batch from ``loader`` is ``(data, attrs)`` where ``attrs[..., 0]`` is the
    binary spurious attribute and ``attrs[..., 1]`` the binary target label. Counts
    are tallied on a 2x2 grid indexed ``[spurious, target]``. Uses the module-level
    ``device``, ``spurious`` and ``label`` settings.

    The model's train/eval mode is restored on exit, even if evaluation raises.

    Args:
        model: classifier whose output's argmax along dim 1 is the predicted class.
        loader: iterable of ``(data, attrs)`` batches.
        name: prefix for the printed report (e.g. "Test").

    Returns:
        dict with ``"worst"`` (accuracy of the worst non-empty group) and
        ``"average"`` (overall accuracy), both Python floats.
    """
    total_correct = torch.zeros(2, 2)
    total_items = torch.zeros(2, 2)
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for data, attrs in loader:
                target = attrs[..., 1]
                data, target = data.to(device), target.to(device)

                correct = (model(data).argmax(dim=1) == target).float().cpu()

                # Flat cell id = spurious * 2 + target; one bincount per batch replaces
                # the former per-example Python loop.
                cell = attrs[..., 0] * 2 + attrs[..., 1]
                total_correct += torch.bincount(cell, weights=correct, minlength=4).view(2, 2)
                total_items += torch.bincount(cell, minlength=4).view(2, 2)
    finally:
        model.train(was_training)

    group_accuracies = total_correct / total_items
    average = total_correct.sum() / total_items.sum()

    print(f"{name} set results:")
    print(f"Average case accuracy: {total_correct.sum()}/{total_items.sum()}={average}\n")

    # An empty group gives NaN (0/0), which argmin would report as "worst";
    # rank empty groups last instead.
    ranked = torch.where(total_items > 0, group_accuracies,
                         torch.full_like(group_accuracies, float("inf")))
    worst = divmod(int(ranked.argmin()), 2)  # (spurious, target) grid coordinates

    def status(flag):
        return " not " if flag == 0 else " "

    print(f"Worst case group:{status(worst[0])}{spurious.lower()} [spurious] and{status(worst[1])}{label.lower()} [target]")
    print(f"Worst case accuracy: {total_correct[worst]}/{total_items[worst]}={group_accuracies[worst]}\n")

    return {
        "worst": group_accuracies[worst].item(),
        "average": average.item()
    }

In [6]:
class CelebAFairnessDataset(torchvision.datasets.VisionDataset):
    """CelebA images paired with their (spurious attribute, target label).

    Each item is ``[image_tensor, attrs]`` where ``attrs`` is a length-2 long
    tensor ``[spurious, label]`` with the CSV's negative values mapped to 0.

    Args:
        attribute_df: per-image attribute table containing ``image_id`` and the
            ``spurious``/``label`` columns, row-aligned with ``partition_df``.
        partition_df: table whose ``partition`` column is 0 for train, 1 for
            valid and 2 for test.
        split: one of ``"train"``, ``"valid"``, ``"test"``.
        augment: if True use random resized crop + horizontal flip, otherwise a
            deterministic center crop + resize.
        spurious: column name of the spurious attribute.
        label: column name of the target label.
        root: directory prefix prepended to each ``image_id``.

    Raises:
        ValueError: if ``split`` is not a recognised split name.
    """

    _SPLIT_IDS = {"train": 0, "valid": 1, "test": 2}

    def __init__(self,
                 attribute_df=attributes,
                 partition_df=partitions,
                 split="train",
                 augment=False,
                 spurious=spurious,
                 label=label,
                 root=f"{data_dir}/img_align_celeba/"):
        # Previously any unrecognised split (e.g. "validation") silently became the test set
        if split not in self._SPLIT_IDS:
            raise ValueError(f"split must be one of {sorted(self._SPLIT_IDS)}, got {split!r}")
        partition_id = self._SPLIT_IDS[split]

        # Initialise the VisionDataset base so `root`/`transform` exist (repr relies on them)
        super().__init__(root, transform=self._build_transform(augment))

        df = attribute_df[partition_df.partition == partition_id]

        self.files = root + df.image_id.values

        self.attrs = df[[spurious, label]].values

        # Negative cases are labeled 0 instead of -1
        self.attrs[self.attrs < 0] = 0

    @staticmethod
    def _build_transform(augment):
        """Return the image pipeline: random crop/flip if ``augment``, else center crop."""
        # Overparameterization paper transform:
        orig_w = 178
        orig_h = 218
        orig_min_dim = min(orig_w, orig_h)
        target_resolution = (224, 224)
        # Standard ImageNet channel mean/std
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        if not augment:
            return transforms.Compose([
                transforms.CenterCrop(orig_min_dim),
                transforms.Resize(target_resolution),
                transforms.ToTensor(),
                normalize,
            ])

        # Orig aspect ratio is 0.81, so we don't squish it in that direction any more
        return transforms.Compose([
            transforms.RandomResizedCrop(
                target_resolution,
                scale=(0.7, 1.0),
                ratio=(1.0, 1.3333333333333333),
                interpolation=2),  # 2 == PIL.Image.BILINEAR
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``[transformed_image, attrs]``."""
        # Close the file as soon as the transform has read the pixels instead of
        # leaving the handle to the garbage collector in every worker process.
        with PIL.Image.open(self.files[idx]) as img:
            image = self.transform(img)

        return [
            image,
            torch.tensor(self.attrs[idx], dtype=torch.long)
        ]

    def __len__(self):
        return len(self.files)

In [7]:
def make_loader(dataset, batch_size, shuffle):
    """Wrap ``dataset`` in a DataLoader using this notebook's worker/pinning settings."""
    return torch.utils.data.DataLoader(dataset,
                                       num_workers=max_workers,
                                       batch_size=batch_size,
                                       shuffle=shuffle,
                                       pin_memory=True)


training = CelebAFairnessDataset(augment=augment_teacher)

teacher_train_loader = make_loader(training, batch_size_teacher, shuffle=True)
student_train_loader = make_loader(CelebAFairnessDataset(augment=augment_student),
                                   batch_size_student, shuffle=True)
validation_loader = make_loader(CelebAFairnessDataset(split="valid"),
                                batch_size_student, shuffle=False)
test_loader = make_loader(CelebAFairnessDataset(split="test"),
                          batch_size_student, shuffle=False)


def test_results(model):
    """Evaluate ``model`` on the test split (prints a report, returns the metrics dict)."""
    return evaluate(model, test_loader, name="Test")


def valid_results(model):
    """Evaluate ``model`` on the validation split (prints a report, returns the metrics dict)."""
    return evaluate(model, validation_loader, name="Validation")


# Group id = spurious + 2 * label, the same encoding the training loops compute per batch.
train_attrs = torch.tensor(training.attrs, dtype=torch.long)
# minlength=4 keeps one slot per (spurious, label) combination even if a group is
# absent from the data, so group_fractions[g] is a valid lookup for every g in 0..3.
group_fractions = torch.bincount(train_attrs[..., 0] + train_attrs[..., 1] * 2, minlength=4).float()
group_fractions /= group_fractions.sum()

In [8]:
def _make_resnet(width):
    """Build a 2-class resnet10vw of ``width``, wrapped in DataParallel and moved to ``device``."""
    return nn.DataParallel(resnet10vw(width, num_classes=2)).to(device)


teacher = _make_resnet(teacher_width)
student = _make_resnet(student_width)
student_no_distillation = _make_resnet(student_width)
# Start the baseline student from the distilled student's initial weights, so the
# comparison isolates the distillation term rather than initialization luck.
student_no_distillation.load_state_dict(student.state_dict())

optimizer_teacher = optim.Adam(teacher.parameters())
optimizer_student = optim.Adam(student.parameters())
optimizer_no_distillation = optim.Adam(student_no_distillation.parameters())

In [9]:
teacher_test_results = []
teacher_validation_results = []

# Train teacher
print("\nTraining teacher model")

criterion = ReweightedLoss(group_fractions)

teacher.train()

for epoch in range(epochs):
    epoch_loss = 0.
    start = time()

    for data, attrs in teacher_train_loader:
        target = attrs[..., 1]  # Spurious label is index 0
        groups = attrs[..., 0] + attrs[..., 1] * 2  # same group ids as group_fractions

        data, target = data.to(device), target.to(device)

        out = teacher(data)
        loss = criterion(out, target, groups)
        loss.backward()

        epoch_loss += loss.item()

        optimizer_teacher.step()
        optimizer_teacher.zero_grad()

    end = time()
    # evaluate() switches to eval mode internally and restores the previous mode
    teacher_test_results.append(test_results(teacher))
    teacher_validation_results.append(valid_results(teacher))

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: total loss {epoch_loss} in {end - start} seconds")

# Train student
print("\nTraining student models")

criterion = DistillationLossReweighted(temperature, group_fractions)
criterion_no_distillation = ReweightedLoss(group_fractions)

# The teacher only supplies soft targets from here on
teacher.eval()
student.train()
student_no_distillation.train()

student_test_results = []
student_validation_results = []
student_no_distillation_test_results = []
student_no_distillation_validation_results = []

for epoch in range(epochs):
    epoch_loss = 0.
    epoch_loss_no_distillation = 0.
    start = time()

    # Both students see exactly the same batches; only their loss differs
    for data, attrs in student_train_loader:
        target = attrs[..., 1]  # Spurious label is index 0
        groups = attrs[..., 0] + attrs[..., 1] * 2

        data, target = data.to(device), target.to(device)

        # Distilled student: hard labels + teacher soft targets
        out_student = student(data)

        with torch.no_grad():
            out_teacher = teacher(data)

        loss = criterion(out_student, out_teacher, target, groups)
        loss.backward()

        epoch_loss += loss.item()

        optimizer_student.step()
        optimizer_student.zero_grad()

        # Baseline student: hard labels only
        out = student_no_distillation(data)
        loss = criterion_no_distillation(out, target, groups)
        loss.backward()

        epoch_loss_no_distillation += loss.item()

        optimizer_no_distillation.step()
        optimizer_no_distillation.zero_grad()

    end = time()
    student_test_results.append(test_results(student))
    student_validation_results.append(valid_results(student))
    student_no_distillation_test_results.append(test_results(student_no_distillation))
    student_no_distillation_validation_results.append(valid_results(student_no_distillation))

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: total loss {epoch_loss} (with distillation) {epoch_loss_no_distillation} (without distillation) in {end - start} seconds")

results = {
    "teacher_size": teacher_width,
    "student_size": student_width,
    "teacher_augmentation": augment_teacher,
    "student_augmentation": augment_student,
    "teacher_test": teacher_test_results,
    "student_test": student_test_results,
    "student_no_distillation_test": student_no_distillation_test_results,
    "teacher_validation": teacher_validation_results,
    "student_validation": student_validation_results,
    "student_no_distillation_validation": student_no_distillation_validation_results,
}

# Per-epoch metric dicts from evaluate(), keyed by model and split
with open(f"{teacher_width}-to-{student_width}{'-augmented' if augmented else ''}.json", "w") as f:
    json.dump(results, f)


Training teacher model
Test set results:
Average case accuracy: 16711.0/19962.0=0.8371405601501465

Worst case group: male [spurious] and blond_hair [target]
Worst case accuracy: 145.0/180.0=0.8055555820465088

Validation set results:
Average case accuracy: 16253.0/19867.0=0.8180903196334839

Worst case group: not male [spurious] and not blond_hair [target]
Worst case accuracy: 6496.0/8535.0=0.7611013650894165

Test set results:
Average case accuracy: 17230.0/19962.0=0.8631399869918823

Worst case group: not male [spurious] and not blond_hair [target]
Worst case accuracy: 8206.0/9767.0=0.8401761054992676

Validation set results:
Average case accuracy: 16794.0/19867.0=0.8453214168548584

Worst case group: not male [spurious] and not blond_hair [target]
Worst case accuracy: 6833.0/8535.0=0.8005858063697815

Test set results:
Average case accuracy: 17666.0/19962.0=0.8849814534187317

Worst case group: male [spurious] and not blond_hair [target]
Worst case accuracy: 6599.0/7535.0=0.875779