In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import os
import pickle
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [3]:
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
from torchvision import transforms, utils, models, datasets
from torch.utils.data import Dataset, DataLoader

In [4]:
from synthetic_utils import *

In [5]:
from trainer_dataloader import *
from networks import *
from losses import *

In [6]:
# Experiment-wide hyper-parameters consumed by the transform, loader and
# training cells further down.
num_epochs = 20    # default epoch budget; re-assigned ahead of each training loop
num_workers = 4    # worker processes handed to every DataLoader
batch_size = 64    # mini-batch size handed to every DataLoader
input_size = 96    # images are resized to input_size x input_size

In [7]:
# Pick the compute device once; `cuda` and `pin_memory` are reused by later cells
# (pin_memory is forwarded to every DataLoader and is only enabled with a GPU).
cuda = torch.cuda.is_available()
pin_memory = cuda
device = torch.device("cuda" if cuda else "cpu")
if cuda:
    # Make CUDA float tensors the default tensor type and let cuDNN auto-tune.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    cudnn.benchmark = True

print('Device set: {}'.format(device))

Device set: cuda


In [8]:
# Per-split preprocessing: resize to an input_size x input_size square, convert
# to a tensor, then normalise each of the three channels with mean 0.5 / std 0.5.
# The RandomRotation(10) augmentation for the training split is left disabled.
data_transforms = dict(
    train=transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
    val=transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

In [9]:
# Root of the synthetic dataset; each split lives in its own sub-directory.
DATA_PATH = '/home/var/synthetic_data/dependent_gen/'
TRAIN_PATH, VAL_PATH, TEST_PATH = (
    os.path.join(DATA_PATH, split_dir) for split_dir in ('train', 'valid', 'test')
)

In [10]:
# ImageFolder datasets for the three splits (despite the `_df` suffix these are
# not DataFrames). The test split reuses the 'val' transform.
train_df, val_df, test_df = (
    datasets.ImageFolder(root=split_root, transform=data_transforms[transform_key])
    for split_root, transform_key in (
        (TRAIN_PATH, 'train'),
        (VAL_PATH, 'val'),
        (TEST_PATH, 'val'),
    )
)

In [11]:
# One shuffled DataLoader per image-folder split, all sharing the same settings.
train_loader, val_loader, test_loader = (
    DataLoader(
        split,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    for split in (train_df, val_df, test_df)
)

## LAFTR Training

In [12]:
# Epoch budget for the LAFTR training loop (replaces the earlier default).
num_epochs = 1_000

In [13]:
from synthetic_dataloader import *
# Wrap the train / validation image-folder splits in ShapeGenderDataset; these
# wrappers feed the LAFTR loaders.
shapegender_train, shapegender_valid = map(ShapeGenderDataset, (train_df, val_df))

In [14]:
# Shuffled loaders over the ShapeGender train / validation wrappers.
laftrtrain_loader, laftrval_loader = (
    DataLoader(
        split,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    for split in (shapegender_train, shapegender_valid)
)

In [15]:
# LAFTR components. Construction order (encoder, adversary, classifier) is kept
# so that any RNG-driven weight initialisation stays the same.
laftr_encoder = LeNetBN()  # alternative encoder: LeNet()
laftr_adversary, laftr_classifier = ClassNet(), ClassNet()

In [16]:
# Loss functions: binary cross-entropy for the classifier and
# AdvDemographicParityLoss for the adversary (nn.BCELoss() is the alternative
# adversary criterion).
laftr_cls_criterion = nn.BCELoss()
laftr_adv_criterion = AdvDemographicParityLoss()

In [17]:
# One Adam optimiser per network, created in adversary / classifier / encoder
# order. (SGD with lr=0.001, momentum=0.9 was the alternative tried here.)
laftr_opt_adv, laftr_opt_cls, laftr_opt_enc = (
    optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999))
    for net in (laftr_adversary, laftr_classifier, laftr_encoder)
)

# Matching schedulers: each step() multiplies the learning rate by 0.99.
laftr_scheduler_adv, laftr_scheduler_cls, laftr_scheduler_enc = (
    lr_scheduler.StepLR(optimizer=opt, gamma=0.99, step_size=1)
    for opt in (laftr_opt_adv, laftr_opt_cls, laftr_opt_enc)
)

In [18]:
# Per-epoch training history (classifier side, then adversary side).
clsTrain_losses, clsTrain_accs, clsTrainCombined_losses = [], [], []
advTrain_losses, advTrain_accs, advTrainCombined_losses = [], [], []

# Per-epoch validation history.
combinedVal_losses = []
clsVal_losses, clsVal_accs = [], []
advVal_losses, advVal_accs = [], []

# The training loop only checkpoints once validation accuracy exceeds best_acc.
best_acc = 0.7
# Accumulates wall-clock time per epoch (the loop reads its .val and .avg).
epoch_time = AverageMeter()


In [19]:
# LAFTR training loop: one training pass and one validation pass per epoch,
# appending the returned metrics to the history lists and checkpointing the
# three networks whenever validation improves.

# Checkpoints are written to ./weights; create it up front so the first
# torch.save() cannot fail on a missing directory part-way through training.
os.makedirs('./weights', exist_ok=True)

ep_end = time.time()
for epoch in range(0, num_epochs):
    print('Epoch: {}/{}'.format(epoch, num_epochs))

    # NOTE(review): the schedulers are stepped before the optimizers run, which
    # is the older PyTorch convention; newer releases expect scheduler.step()
    # after optimizer.step() -- confirm against the installed version.
    laftr_scheduler_adv.step()
    laftr_scheduler_cls.step()
    laftr_scheduler_enc.step()

    # Training pass; returns per-epoch losses / accuracies for both heads.
    cls_loss, cls_en_acc, adv_loss, adv_acc, cls_en_combinedLoss, adv_combinedLoss = alfr_train(
        laftrtrain_loader,
        laftr_encoder, laftr_classifier, laftr_adversary,
        laftr_opt_enc, laftr_opt_cls, laftr_opt_adv,
        laftr_cls_criterion, laftr_adv_criterion, device)

    clsTrain_losses.append(cls_loss)
    clsTrain_accs.append(cls_en_acc)
    clsTrainCombined_losses.append(cls_en_combinedLoss)
    advTrain_losses.append(adv_loss)
    advTrain_accs.append(adv_acc)
    advTrainCombined_losses.append(adv_combinedLoss)

    print('\nClassifier accuracy: {}\t Adversary Accuracy: {}'.format(cls_en_acc, adv_acc))

    # Validation pass.
    print('-'*10)
    combinedVal_loss, clsVal_loss, clsVal_acc, advVal_loss, advVal_acc = laftr_validate_dp(
        laftrval_loader,
        laftr_encoder, laftr_classifier, laftr_adversary,
        laftr_cls_criterion, laftr_adv_criterion, device)

    combinedVal_losses.append(combinedVal_loss)
    clsVal_losses.append(clsVal_loss)
    clsVal_accs.append(clsVal_acc)
    advVal_losses.append(advVal_loss)
    advVal_accs.append(advVal_acc)

    print('%'*20)
    print('Classifier validation acc: {:.4f} \t Adv validation acc: {:.4f}'.format(clsVal_acc, advVal_acc))

    # Checkpoint only when the classifier beats the best validation accuracy so
    # far AND the adversary's validation accuracy stays below 0.9.
    if clsVal_acc > best_acc and advVal_acc < 0.9:
        best_acc = clsVal_acc
        print('saving weights')
        torch.save(laftr_encoder, './weights/encoder_{}_{}.pth'.format(epoch, clsVal_acc))
        torch.save(laftr_classifier, './weights/cls_{}_{}.pth'.format(epoch, clsVal_acc))
        torch.save(laftr_adversary, './weights/adv_{}_{}.pth'.format(epoch, advVal_acc))

    # Per-epoch wall-clock timing.
    print('-' * 20)
    epoch_time.update(time.time() - ep_end)
    ep_end = time.time()
    print('Epoch {}/{}\t'
          'Time {epoch_time.val:.3f} sec ({epoch_time.avg:.3f} sec)'.format(epoch, num_epochs, epoch_time=epoch_time))
    print('-'*20)

Epoch: 0/1000
$$*$$$$$$$$$$$$$$$$$$$
Batch: [40/79]	Cls step loss:0.7034 (0.7034)	Adv step loss:-0.7216 (-0.5964)
Cls Acc:0.6719 (0.6307)	Adv Acc:0.4688 (0.4093)
*$$$$$$$$$$$$$$$$*$$$
Classifier accuracy: 0.632	 Adversary Accuracy: 0.3428
----------


  "Please ensure they have the same size.".format(target.size(), input.size()))


%%%%%%%%%%%%%%%%%%%%
Classifier validation acc: 0.6780 	 Adv validation acc: 0.2280
--------------------
Epoch 0/1000	Time 16.161 sec (16.161 sec)
--------------------
Epoch: 1/1000


  "Please ensure they have the same size.".format(target.size(), input.size()))


$$$$$*$$$$$$$$$*$$$*$$$$
Batch: [40/79]	Cls step loss:0.6539 (0.6441)	Adv step loss:-0.5462 (-0.5534)
Cls Acc:0.7344 (0.7069)	Adv Acc:0.4688 (0.4466)
$$$$*$$$$*$$$$$$$$$$$
Classifier accuracy: 0.7036	 Adversary Accuracy: 0.4826
----------
%%%%%%%%%%%%%%%%%%%%
Classifier validation acc: 0.6640 	 Adv validation acc: 0.5740
--------------------
Epoch 1/1000	Time 15.656 sec (15.909 sec)
--------------------
Epoch: 2/1000
$*$$*$$$$$$$$$$$*$$$$$$$
Batch: [40/79]	Cls step loss:0.7663 (0.7085)	Adv step loss:-0.7644 (-0.7037)
Cls Acc:0.6406 (0.6288)	Adv Acc:0.6562 (0.5065)
$$*$$$$$$$$$$$$$$$$$
Classifier accuracy: 0.647	 Adversary Accuracy: 0.51
----------
%%%%%%%%%%%%%%%%%%%%
Classifier validation acc: 0.6720 	 Adv validation acc: 0.5400
--------------------
Epoch 2/1000	Time 16.140 sec (15.986 sec)
--------------------
Epoch: 3/1000
$$$$*$$$$$$$$$$*$$$$$$$
Batch: [40/79]	Cls step loss:0.9008 (0.7615)	Adv step loss:-0.7844 (-0.6929)
Cls Acc:0.5156 (0.6101)	Adv Acc:0.5312 (0.5293)
$$$$$$$*$$$$$

Process Process-190:
Process Process-191:
Process Process-189:
Traceback (most recent call last):
Process Process-192:
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
  File "/home/var/anaconda3/envs/diss/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 106, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
Traceback (most recent call last):
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/var/anaconda3/envs/diss/lib/python3.6/multiproces

Traceback (most recent call last):
  File "/home/var/anaconda3/envs/diss/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-19-046a6df7f401>", line 29, in <module>
    laftr_cls_criterion, laftr_adv_criterion, device)
  File "/home/var/fairness/synthetic_expr/trainer_dataloader.py", line 439, in laftr_validate_dp
    for batch_idx, (imgs, shape, color) in enumerate(dataloader):
  File "/home/var/anaconda3/envs/diss/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 330, in __next__
    idx, batch = self._get_batch()
  File "/home/var/anaconda3/envs/diss/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 309, in _get_batch
    return self.data_queue.get()
  File "/home/var/anaconda3/envs/diss/lib/python3.6/queue.py", line 164, in get
    self.not_empty.wait()
  File "/home/var/anaconda3/envs/diss/lib/python3.6/threading.py", line 295, in wait
    

KeyboardInterrupt: 

KeyboardInterrupt


In [None]:
# Plot the LAFTR training / validation histories and persist both the figure
# and the raw metric lists under ./plots.

# Each entry: (subplot slot on a 4x3 grid, panel title, [(series, legend label), ...]).
laftr_panels = [
    (431, 'Combined Val Loss', [(combinedVal_losses, 'Validation')]),
    (432, 'Cls-Enc Loss', [(clsTrain_losses, 'Train'), (clsVal_losses, 'Validation')]),
    (433, 'Cls-Enc Accuracy', [(clsTrain_accs, 'Train'), (clsVal_accs, 'Validation')]),
    (434, 'Adv Loss', [(advTrain_losses, 'Train'), (advVal_losses, 'Validation')]),
    (435, 'Adv Accuracy', [(advTrain_accs, 'Train'), (advVal_accs, 'Validation')]),
    (436, 'Step Loss', [(advTrainCombined_losses, 'Adversary'), (clsTrainCombined_losses, 'Classifier')]),
]

plt.figure(figsize=(20,20))
for slot, panel_title, curves in laftr_panels:
    plt.subplot(slot)
    plt.title(panel_title)
    for series, curve_label in curves:
        plt.plot(series, label=curve_label)
    plt.legend()

plt.tight_layout()

# Make sure the output directory exists before writing into it.
os.makedirs('./plots', exist_ok=True)

# Take the timestamp once and zero-pad it: the figure and the pickle are then
# guaranteed to share a suffix, and e.g. 01:15 and 11:05 no longer collide.
run_stamp = time.strftime('%H%M')

plt.savefig('./plots/laftr_train_{}.pdf'.format(run_stamp))
pkl_path = './plots/metrics_{}.pkl'.format(run_stamp)
with open(pkl_path, 'wb') as f:
    # The list order is the on-disk schema; keep it stable for any reader.
    pickle.dump([advTrainCombined_losses, clsTrainCombined_losses, combinedVal_losses, clsTrain_losses,
                 clsVal_losses, clsTrain_accs, clsVal_accs, advTrain_losses, advVal_losses, advTrain_accs, advVal_accs], f)

In [None]:
# Wrap the train / validation image-folder splits in GenderDataset; these
# wrappers feed the standalone adversary loaders.
gender_train, gender_valid = map(GenderDataset, (train_df, val_df))

In [None]:
# Shuffled loaders over the GenderDataset train / validation wrappers.
advtrain_loader, advval_loader = (
    DataLoader(
        split,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    for split in (gender_train, gender_valid)
)

In [None]:
# Fresh ClassNet instance, separate from laftr_adversary. It is passed to
# train_classifier_epoch together with laftr_encoder further down.
# NOTE(review): presumably a post-hoc probe on the trained encoder's
# representation -- confirm against train_classifier_epoch.
adversary = ClassNet()

In [None]:
# Adam optimiser over the standalone adversary, a scheduler that multiplies the
# learning rate by 0.99 on every step(), and plain BCE as the objective
# (AdvDemographicParityLoss() is the alternative criterion).
opt_adv = optim.Adam(
    adversary.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
)
scheduler_adv = lr_scheduler.StepLR(optimizer=opt_adv, gamma=0.99, step_size=1)
adv_criterion = nn.BCELoss()

In [None]:
# Run train_classifier_epoch / validate_classifier_epoch on the standalone
# adversary (with laftr_encoder passed alongside) for a few epochs, keeping the
# per-epoch loss and accuracy of both passes.
num_epochs = 5
train_losses, train_accs = [], []
val_losses, val_accs = [], []
epoch_time = AverageMeter()
ep_end = time.time()
for epoch in range(num_epochs):
    print('Epoch: {}/{}'.format(epoch, num_epochs))
    scheduler_adv.step()

    # training pass
    train_loss, train_acc = train_classifier_epoch(
        advtrain_loader, laftr_encoder, adversary, opt_adv, adv_criterion, device)
    train_losses.append(train_loss)
    train_accs.append(train_acc)

    # validation pass
    print('-' * 10)
    val_loss, val_acc = validate_classifier_epoch(
        advval_loader, laftr_encoder, adversary, adv_criterion, device)
    print('Avg validation loss: {} \t Accuracy: {}'.format(val_loss, val_acc))
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    # per-epoch wall-clock timing
    print('-' * 20)
    epoch_time.update(time.time() - ep_end)
    ep_end = time.time()
    print('Epoch {}/{}\tTime {epoch_time.val:.3f} sec ({epoch_time.avg:.3f} sec)'.format(
        epoch, num_epochs, epoch_time=epoch_time))
    print('-' * 20)

In [None]:
# 2x2 summary of the standalone adversary run:
# top row = training (loss, accuracy), bottom row = validation (loss, accuracy).
plt.subplot(2, 2, 1)
plt.title('training classification loss')
plt.plot(train_losses)

plt.subplot(2, 2, 2)
plt.title('training accuracy')
plt.plot(train_accs)

plt.subplot(2, 2, 3)
plt.title('validation loss')
plt.plot(val_losses)

plt.subplot(2, 2, 4)
plt.title('validation accuracy')
plt.plot(val_accs)