In [1]:
# Render matplotlib figures inline, directly below the cell that produced them.
%matplotlib inline
# Automatically reload imported modules (e.g. the local trainer_dataloader / networks /
# losses modules) before each cell runs, so edits to them apply without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import pickle
import random
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [3]:
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
from torchvision import transforms, utils, models, datasets
from torch.utils.data import Dataset, DataLoader

In [4]:
from synthetic_utils import *

In [5]:
from trainer_dataloader import *
from networks import *
from losses import *

In [6]:
input_size = 96   # images are resized to input_size x input_size
batch_size = 64
num_workers = 4
num_epochs = 20   # default; the LAFTR and probe training cells override this

# Seed every RNG in play so runs are repeatable. cudnn.benchmark (enabled on GPU in the
# device cell) can still introduce some run-to-run nondeterminism in convolution kernels.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [7]:
# Pick the compute device. On a GPU machine, new float tensors (and therefore freshly
# built models) default to CUDA, cuDNN autotuning is turned on, and the DataLoaders
# below are created with pinned host memory.
cuda = torch.cuda.is_available()
pin_memory = cuda
device = torch.device("cuda" if cuda else "cpu")
if cuda:
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    cudnn.benchmark = True

print('Device set: {}'.format(device))

Device set: cuda


In [8]:
def _make_transform():
    """Resize to an input_size square, convert to a tensor, and normalise each RGB
    channel with mean 0.5 / std 0.5 (ToTensor's [0, 1] range becomes [-1, 1])."""
    return transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


# Both splits use the same preprocessing; no augmentation is applied to 'train'
# (e.g. insert transforms.RandomRotation(10) before ToTensor() to enable some).
data_transforms = {split: _make_transform() for split in ('train', 'val')}

In [9]:
# Root of the synthetic dataset: one ImageFolder tree per split (train / valid / test).
# Set the SYNTHETIC_DATA_PATH environment variable to point elsewhere instead of
# editing this hard-coded, machine-specific default.
DATA_PATH = os.environ.get('SYNTHETIC_DATA_PATH', '/home/var/synthetic_data/dependent_gen/')
TRAIN_PATH = os.path.join(DATA_PATH, 'train')
VAL_PATH = os.path.join(DATA_PATH, 'valid')
TEST_PATH = os.path.join(DATA_PATH, 'test')

In [10]:
# One ImageFolder dataset per split; 'test' reuses the 'val' preprocessing.
# (Despite the *_df suffix these are torchvision datasets, not DataFrames.)
_split_specs = (
    (TRAIN_PATH, 'train'),
    (VAL_PATH, 'val'),
    (TEST_PATH, 'val'),
)
train_df, val_df, test_df = [
    datasets.ImageFolder(root=root, transform=data_transforms[key])
    for root, key in _split_specs
]

In [11]:
# Every split shares the same loader settings, including shuffle=True.
# NOTE(review): shuffling the eval splits looks deliberate. ImageFolder lists samples
# grouped by class, so unshuffled batches would be (near) single-class. Confirm before changing.
_loader_kwargs = dict(batch_size=batch_size, shuffle=True,
                      num_workers=num_workers, pin_memory=pin_memory)
train_loader = DataLoader(train_df, **_loader_kwargs)
val_loader = DataLoader(val_df, **_loader_kwargs)
test_loader = DataLoader(test_df, **_loader_kwargs)

## LAFTR Training

In [12]:
# The LAFTR phase trains far longer than the 20-epoch default from the config cell.
LAFTR_NUM_EPOCHS = 500
num_epochs = LAFTR_NUM_EPOCHS

In [13]:
from synthetic_dataloader import *

# Wrap the train/valid image datasets for LAFTR. Presumably each sample pairs an image
# with its shape (task) and gender (sensitive) labels; see synthetic_dataloader.
shapegender_train, shapegender_valid = map(ShapeGenderDataset, (train_df, val_df))

In [14]:
# Train and validation loaders for LAFTR, built with identical settings (both shuffled).
laftrtrain_loader, laftrval_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=True,
               num_workers=num_workers, pin_memory=pin_memory)
    for split in (shapegender_train, shapegender_valid)
)

In [15]:
# LAFTR components: a LeNet encoder plus two ClassNet heads, the task classifier and the
# adversary (presumably predicting the sensitive attribute; see alfr_train).
# Moving them to `device` explicitly keeps placement correct without relying only on the
# deprecated global default tensor type set in the device cell (a no-op when that is set).
# Construction order (and hence weight-init RNG consumption) is unchanged.
laftr_encoder = LeNet().to(device)
laftr_adversary = ClassNet().to(device)
laftr_classifier = ClassNet().to(device)

In [16]:
# laftr_adv_criterion = AdvDemographicParityLoss()
laftr_adv_criterion = nn.BCELoss()
laftr_cls_criterion = nn.BCELoss()

In [17]:
# One Adam optimizer per LAFTR component, all with the same hyper-parameters.
laftr_opt_adv, laftr_opt_cls, laftr_opt_enc = (
    optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999))
    for net in (laftr_adversary, laftr_classifier, laftr_encoder)
)

# Exponential-style decay: each StepLR multiplies its optimizer's LR by 0.99 every time
# step() is called (once per epoch in the training loop).
# Alternative tried in this notebook: lr_scheduler.CosineAnnealingLR(optimizer=..., T_max=num_epochs).
laftr_scheduler_adv, laftr_scheduler_cls, laftr_scheduler_enc = (
    lr_scheduler.StepLR(optimizer=opt, gamma=0.99, step_size=1)
    for opt in (laftr_opt_adv, laftr_opt_cls, laftr_opt_enc)
)

In [18]:


# Per-epoch history of the LAFTR run, training side.
clsTrain_losses, clsTrain_accs = [], []
advTrain_losses, advTrain_accs = [], []
trainCombined_losses = []

# Per-epoch history of the LAFTR run, validation side.
combinedVal_losses = []
clsVal_losses, clsVal_accs = [], []
advVal_losses, advVal_accs = [], []

# Running wall-clock time per epoch.
epoch_time = AverageMeter()

In [None]:
# LAFTR training: every epoch runs one alfr_train pass (encoder, classifier and adversary,
# each with its own optimizer), then laftr_validate_dp on the validation loader.
ep_end = time.time()
for epoch in range(0, num_epochs):
    print('Epoch: {}/{}'.format(epoch, num_epochs))

    # alfr_train's fifth return value is ignored here, so trainCombined_losses stays empty.
    # TODO(review): if it is the combined training loss, append it to trainCombined_losses
    # so the 'Combined Loss' plot gets a training curve.
    cls_loss, cls_en_acc, adv_loss, adv_acc, _ = alfr_train(
        laftrtrain_loader,
        laftr_encoder, laftr_classifier, laftr_adversary,
        laftr_opt_enc, laftr_opt_cls, laftr_opt_adv,
        laftr_cls_criterion, laftr_adv_criterion, device)

    clsTrain_losses.append(cls_loss)
    clsTrain_accs.append(cls_en_acc)
    advTrain_losses.append(adv_loss)
    advTrain_accs.append(adv_acc)

    print('\nClassifier accuracy: {}\t Adversary Accuracy: {}'.format(cls_en_acc, adv_acc))
    # validate
    print('-'*10)

    combinedVal_loss, clsVal_loss, clsVal_acc, advVal_loss, advVal_acc = laftr_validate_dp(
        laftrval_loader,
        laftr_encoder, laftr_classifier, laftr_adversary,
        laftr_cls_criterion, laftr_adv_criterion, device)

    combinedVal_losses.append(combinedVal_loss)
    clsVal_losses.append(clsVal_loss)
    clsVal_accs.append(clsVal_acc)
    advVal_losses.append(advVal_loss)
    advVal_accs.append(advVal_acc)

    print('%'*20)
    print('Classifier validation acc: {:.4f} \t Adv validation acc: {:.4f}'.format(clsVal_acc, advVal_acc))

    # Decay the learning rates only after this epoch's optimizer updates. Since PyTorch 1.1,
    # calling scheduler.step() before optimizer.step() skips the first value of the schedule.
    laftr_scheduler_adv.step()
    laftr_scheduler_cls.step()
    laftr_scheduler_enc.step()

    print('-' * 20)
    epoch_time.update(time.time() - ep_end)
    ep_end = time.time()
    print('Epoch {}/{}\t'
          'Time {epoch_time.val:.3f} sec ({epoch_time.avg:.3f} sec)'.format(epoch, num_epochs, epoch_time=epoch_time))
    print('-'*20)

Epoch: 0/500
$
Batch: [0/79]	Cls step loss:0.0000 (0.0000)	Adv step loss:-0.0124 (-0.0124)
Cls Acc:0.3594 (0.3594)	Adv Acc:0.5781 (0.5781)
$$$$$$$$$*$
Batch: [20/79]	Cls step loss:0.0092 (0.0092)	Adv step loss:0.0040 (-0.0020)
Cls Acc:0.4219 (0.4844)	Adv Acc:0.5312 (0.5126)
$$$$$$$$$$
Batch: [40/79]	Cls step loss:0.0092 (0.0092)	Adv step loss:0.0036 (0.0010)
Cls Acc:0.5000 (0.4889)	Adv Acc:0.5000 (0.4680)
*$$$$$$$$$$
Batch: [60/79]	Cls step loss:0.0033 (0.0062)	Adv step loss:0.0204 (0.0015)
Cls Acc:0.5625 (0.4857)	Adv Acc:0.4219 (0.4810)
$$*$*$$$$*$*$
Classifier accuracy: 0.4846	 Adversary Accuracy: 0.4796
----------


  "Please ensure they have the same size.".format(target.size(), input.size()))


Test batch: [0/8]	Time 0.478 (0.478)
Classifier loss 0.6913 (0.6913)	Adversary loss 0.7046 (0.7046)
Combined Loss 1.3959 (1.3959)	Classifier Accuracy 0.5469 (0.5469)	Adversary Accuracy 0.3594 (0.3594)
%%%%%%%%%%%%%%%%%%%%
Classifier validation acc: 0.5700 	 Adv validation acc: 0.5200
--------------------
Epoch 0/500	Time 7.952 sec (7.952 sec)
--------------------
Epoch: 1/500


  "Please ensure they have the same size.".format(target.size(), input.size()))


$
Batch: [0/79]	Cls step loss:0.0000 (0.0000)	Adv step loss:0.0037 (0.0037)
Cls Acc:0.5469 (0.5469)	Adv Acc:0.5000 (0.5000)
$$$$$$$$$$
Batch: [20/79]	Cls step loss:0.0000 (0.0000)	Adv step loss:0.0067 (0.0010)
Cls Acc:0.5625 (0.5320)	Adv Acc:0.4375 (0.5097)
$$$$$$$$$$
Batch: [40/79]	Cls step loss:0.0000 (0.0000)	Adv step loss:0.0031 (0.0022)
Cls Acc:0.5625 (0.5400)	Adv Acc:0.5156 (0.5000)
$*$$$*$$*$$$$
Batch: [60/79]	Cls step loss:-0.0051 (-0.0032)	Adv step loss:0.0037 (0.0030)
Cls Acc:0.6094 (0.5441)	Adv Acc:0.5000 (0.4949)
$$$$$$$$$
Classifier accuracy: 0.5444	 Adversary Accuracy: 0.4962
----------
Test batch: [0/8]	Time 0.510 (0.510)
Classifier loss 0.6846 (0.6846)	Adversary loss 0.6929 (0.6929)
Combined Loss 1.3775 (1.3775)	Classifier Accuracy 0.5938 (0.5938)	Adversary Accuracy 0.4375 (0.4375)
%%%%%%%%%%%%%%%%%%%%
Classifier validation acc: 0.5520 	 Adv validation acc: 0.5200
--------------------
Epoch 1/500	Time 8.002 sec (7.977 sec)
--------------------
Epoch: 2/500
$
Batch: [0/7

In [None]:
# Learning curves for the LAFTR run, one panel per metric (train vs. validation).
# The training loop never appends to trainCombined_losses, so a train curve is drawn only
# when a series actually has data. This avoids a misleading, empty 'Train' legend entry.
_laftr_panels = [
    ('Combined Loss', trainCombined_losses, combinedVal_losses),
    ('Cls-Enc Loss', clsTrain_losses, clsVal_losses),
    ('Cls-Enc Accuracy', clsTrain_accs, clsVal_accs),
    ('Adv Loss', advTrain_losses, advVal_losses),
    ('Adv Accuracy', advTrain_accs, advVal_accs),
]

fig, axes = plt.subplots(2, 3, figsize=(20, 12))
for ax, (title, train_curve, val_curve) in zip(axes.flat, _laftr_panels):
    if train_curve:
        ax.plot(train_curve, label='Train')
    ax.plot(val_curve, label='Validation')
    ax.set(title=title, xlabel='Epoch')
    ax.legend()
# Five metrics on a 2x3 grid: hide the unused last axis.
axes[-1, -1].set_visible(False)
fig.tight_layout()

In [None]:
# Wrap the train/valid image datasets for the post-hoc adversary probe
# (presumably labelling each image with its gender attribute; see GenderDataset).
gender_train, gender_valid = map(GenderDataset, (train_df, val_df))

In [None]:
# Train and validation loaders for the adversary probe, built with identical settings.
advtrain_loader, advval_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=True,
               num_workers=num_workers, pin_memory=pin_memory)
    for split in (gender_train, gender_valid)
)

In [None]:
# Fresh ClassNet head used to probe the trained LAFTR encoder's representations.
# Placed on `device` explicitly instead of relying only on the global default tensor type.
adversary = ClassNet().to(device)

In [None]:
# Probe adversary setup, mirroring the LAFTR adversary: Adam (lr 1e-3) with a StepLR that
# multiplies the LR by 0.99 on every step() call.
opt_adv = optim.Adam(adversary.parameters(), lr=0.001, betas=(0.9, 0.999))
scheduler_adv = lr_scheduler.StepLR(optimizer=opt_adv, gamma=0.99, step_size=1)
# Binary cross-entropy on probabilities; AdvDemographicParityLoss is an alternative objective.
adv_criterion = nn.BCELoss()

In [None]:
# Post-hoc probe: for a few epochs, train the fresh `adversary` on top of `laftr_encoder`
# (only opt_adv is passed in), then validate it.
# NOTE(review): presumably the encoder is meant to stay fixed. Confirm that
# train_classifier_epoch neither updates it nor leaves it in train mode.
num_epochs = 5
train_losses = []
train_accs = []
val_losses = []
val_accs = []
epoch_time = AverageMeter()
ep_end = time.time()
for epoch in range(0, num_epochs):
    print('Epoch: {}/{}'.format(epoch, num_epochs))
    # train
    train_loss, train_acc = train_classifier_epoch(advtrain_loader, laftr_encoder,
                                                   adversary, opt_adv, adv_criterion, device)
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    # validate
    print('-'*10)
    val_loss, val_acc = validate_classifier_epoch(advval_loader, laftr_encoder, adversary,
                                                  adv_criterion, device)

    print('Avg validation loss: {} \t Accuracy: {}'.format(val_loss, val_acc))
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    # Step the LR schedule after this epoch's optimizer updates. Since PyTorch 1.1, stepping
    # it first skips the schedule's first learning-rate value.
    scheduler_adv.step()

    print('-' * 20)
    epoch_time.update(time.time() - ep_end)
    ep_end = time.time()
    print('Epoch {}/{}\t'
          'Time {epoch_time.val:.3f} sec ({epoch_time.avg:.3f} sec)'.format(epoch, num_epochs, epoch_time=epoch_time))
    print('-'*20)

In [None]:
# Learning curves of the adversary probe on a 2x2 grid, with labelled epoch axes and a
# tight layout so the panel titles do not overlap the tick labels above them.
_probe_panels = [
    ('training classification loss', train_losses),
    ('training accuracy', train_accs),
    ('validation loss', val_losses),
    ('validation accuracy', val_accs),
]

fig, axes = plt.subplots(2, 2)
for ax, (title, curve) in zip(axes.flat, _probe_panels):
    ax.plot(curve)
    ax.set(title=title, xlabel='Epoch')
fig.tight_layout()