In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from scipy.ndimage import rotate
import Utils
from Utils import Constants
import cv2
from facenet_pytorch import InceptionResnetV1
import copy
from Models import *
from DataLoaders import *

In [3]:
train_labels = pd.read_csv('train_data_augmented_balanceddual.csv')
validation_labels = pd.read_csv('validation_data_augmented_balanceddual.csv')
train_labels.shape, validation_labels.shape

((6842, 166), (1711, 166))

In [4]:
def get_model(file,model=None):
    if model is None:
        model = torch.load(Constants.model_folder + file).to(torch.device('cpu'))
    model.load_state_dict(torch.load(Constants.model_folder + file + '_states'))
    model.eval()
    return model


In [None]:
def save_train_history(model,history,root=''):
    model_name = model.get_identifier()
    
    df = pd.DataFrame(history)
    df['model'] = model_name
    string = root + 'results/history_' + model_name + '3.csv'
    df.to_csv(string,index=False)
    print('saved history to',string)
    return df, string

def subclass_accuracy(ypred,y,nclass=None):
    if nclass is None:
        nclass = len(torch.unique(y))
        classes = torch.unique(y)
    else:
        classes = torch.Tensor([i for i in range(nclass)]).float()
    results = torch.zeros(nclass)
    if ypred.ndim > 1:
        ypred = torch.argmax(ypred,1).long()
    for i,c in enumerate(torch.unique(y)):
        yy = (y == c)
        yypred = (ypred == c)
        good = torch.logical_and(yy,yypred).float().sum()/yy.float().sum()
        results[i] = good
    return results

def comp_score(accs,disp,csizes):
    score = 0
    for acc,disp,size in zip(accs,disp,csizes):
        score += size*(acc)*(1-disp**(size/2))
    return score

def train_model(model,
                train_df,
                validation_df,
                root,
                epochs=300,
                lr=.001,
                batch_size=200,
                patience = 20,
                loss_weights = [2,1,.5],
                save_path=None,
                histogram =False,
                upsample=False,
                random_upsample=True,
                softmax_upsample_weights=False,
                upsample_validation=False,
                random_upsample_validation=False,
                embedding_loss_weight = 1,
                classification_loss_weight = 1,
                starting_loss=100000,
                bias_weight_transform=None,
                skintone_regression=False,
                noise=.05,
                skintone_patch_anchor_prob=0,
                run_name = '',
                **kwargs,
               ):
    pretraining = classification_loss_weight <= .00001
    if save_path is None:
        save_path = root + 'models/'+ run_name + model.get_identifier() 
        if random_upsample:
            
            save_path += '_rbalanced'
            if softmax_upsample_weights:
                save_path += 'soft'
        elif upsample:
            save_path += '_balanced'
        if pretraining:
            save_path += '_pretrain'
    if skintone_patch_anchor_prob > 0:
        save_path += '_spab'+str(skintone_patch_anchor_prob).replace('.','')
    save_path += '_lw' + '-'.join([str(l) for l in loss_weights])
    save_path += '_noise' + str(noise).replace('0.','').replace('.','')
    if upsample:
        patience = int(patience/5) + 1
    class_sizes = [10,4,2]
    if bias_weight_transform is not None:
        weight_cols = [c for c in train_df.columns if '_bias' in c]
        tdf = train_df.copy()
        vdf = validation_df.copy()
        for col in weight_cols:
            tdf[col] = tdf[col].apply(bias_weight_transform)
            vdf[col] = vdf[col].apply(bias_weight_transform)
        train_df = tdf
        validation_df = vdf
    train_loader = TripletFaceGenerator(train_df,Constants.data_root,
                                 batch_size=batch_size,
                                 upsample=upsample,
                                 random_upsample=random_upsample,
                                 skintone_patch_anchor_prob=skintone_patch_anchor_prob,
                                 softmax=softmax_upsample_weights,
                                 noise_sigma=noise,
                                 **kwargs)
    validation_loader = TripletFaceGenerator(validation_df,Constants.data_root,
                                      validation=True,
                                      batch_size=batch_size,
                                      upsample=upsample_validation,
                                     random_upsample=random_upsample_validation,
                                     softmax=softmax_upsample_weights,
                                     skintone_patch_anchor_prob=0,
                                     noise_sigma=noise,
                                      **kwargs)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train(True)
    
    cel = torch.nn.CrossEntropyLoss()
    format_y = lambda y: y.to(device)
    regression_loss = torch.nn.MSELoss()
    triplet_loss = torch.nn.TripletMarginLoss()
#     embedding_optimizer = torch.optim.Adam(model.encoder.parameters(), lr=lr)
#     decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=lr)
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    
    
    def format_batch(inputs,grad=True):
        xb = []
        for xin in inputs:
            xin = xin.to(device)
            xin.requires_grad_(grad)
            xb.append(xin)
        return xb
    
    def embedding_step(m,xbatch): 
        base = m.encoder(xbatch[0])
        positive = m.encoder(xbatch[1])
        negative = m.encoder(xbatch[2])
        loss = triplet_loss(base,positive,negative)
        loss = torch.mul(loss,embedding_loss_weight)
        return base,loss
    
    
    def classifier_step(m,embedding,ytrue):
        outputs = m.decoder(embedding)
        if not skintone_regression:
            losses = [cel(ypred.float(),format_y(y.long())) for y,ypred in zip(ytrue,outputs)]
            l1 = torch.mul(loss_weights[0],losses[0])
            l2 =  torch.mul(loss_weights[1],losses[1])
            l3 =  torch.mul(loss_weights[2],losses[2])
        else:
            l1 = torch.mul( regression_loss(outputs[0].float().view(-1),format_y(ytrue[0].float())), loss_weights[0])
            l2 = torch.mul( cel(outputs[1].float(),format_y(ytrue[1].long())), loss_weights[1])
            l3 = torch.mul(cel(outputs[2].float(),format_y(ytrue[2].long())), loss_weights[2])
        total_losses = l1 + l2 + l3
        total_loss = torch.mul(total_losses,classification_loss_weight)
        return outputs,total_losses
        
    def train_epoch(model):
        running_loss = 0
        running_embed_loss = 0
        running_accuracy = [0,0,0]
        curr_loss = 0
        count = 0
        for i, [x_batch, y_batch] in enumerate(train_loader):
            x_batch = format_batch(x_batch)
            optimizer.zero_grad()
            embedding,embedding_loss = embedding_step(model, x_batch)
            if pretraining:
                classification_loss = embedding_loss - embedding_loss
            else:
                outputs, classification_loss = classifier_step(model,embedding,y_batch)
            total_loss = classification_loss + embedding_loss
            total_loss.backward()
            optimizer.step()
            running_loss += classification_loss.item()
            running_embed_loss += embedding_loss.item()
            print('curr loss class',classification_loss.item(),'embed', embedding_loss.item(), 'step',i,' | ',end='\r')
            count += 1
            if not pretraining:
                with torch.no_grad():
                    for i,(y,ypred) in enumerate(zip(y_batch,outputs)):
                        accuracy = Utils.categorical_accuracy(ypred.float(),format_y(y))
                        running_accuracy[i] += accuracy.item()
        return running_loss/count,running_embed_loss/count, [a/count for a in running_accuracy]
    
    def val_epoch(model):
        running_loss = 0
        running_embed_loss = 0
        running_accuracy = [0,0,0]
        running_f1 = [0,0,0]
        running_subclass_acc = [None,None,None]
        count = 0
        with torch.no_grad():
            for i, [x_batch, y_batch] in enumerate(validation_loader):
                x_batch = format_batch(x_batch,grad=False)
                embedding,embedding_loss = embedding_step(model, x_batch)
                if pretraining:
                    classification_loss = embedding_loss - embedding_loss
                else:
                    outputs, classification_loss = classifier_step(model,embedding,y_batch)
                running_loss += classification_loss.item()
                running_embed_loss += embedding_loss.item()
                count += 1
                if not pretraining:
                    for i,(y,ypred) in enumerate(zip(y_batch, outputs)):
                        accuracy = Utils.categorical_accuracy(ypred.float(),format_y(y))
                        f1, precision, recall = Utils.macro_f1(torch.argmax(ypred.float(),axis=1),format_y(y))
                        subclass_acc = subclass_accuracy(ypred.float(),format_y(y),class_sizes[i])
                        if running_subclass_acc[i] is None:
                            running_subclass_acc[i] = subclass_acc
                        else:
                            running_subclass_acc[i] += subclass_acc
                        running_accuracy[i] += accuracy.item()
                        running_f1[i] += f1.item()
        running_subclass_acc = [torch.mul(r,1/count) for r in running_subclass_acc]
        disparities = [(torch.max(v) - torch.min(v)).item() for v in running_subclass_acc]
        return running_loss/count,running_embed_loss/count, [a/count for a in running_accuracy], [f/count for f in running_f1],disparities
    shorten = lambda array: [np.round(a, 3) for a in array]
    
    best_val_res = {}
    best_val_loss=1000000
    best_score = 0
    steps_since_improvement = 0
    hist = []
    best_weights = model.state_dict()
    print('model being saved to',save_path)
    for epoch in range(epochs):
        print('epoch',epoch)
        model.train()
        avg_loss,avg_embed_loss, avg_acc = train_epoch(model)
        print('train loss', avg_loss,'train embed loss',avg_embed_loss, 'train accuracy', shorten(avg_acc))
        model.eval()
        val_loss,val_embed_loss, val_acc, val_f1, val_disp = val_epoch(model)
        val_score = comp_score(val_acc,val_disp,class_sizes)
        if pretraining:
            val_loss = val_embed_loss
        print('val loss', val_loss, 'val_embed_loss', val_embed_loss, 
              'val accuracy', shorten(val_acc), 'val f1', shorten(val_f1),
                 'disparities',val_disp, 'score', val_score,
             )
        #don't save immediately in case I cancel training
        if best_score < val_score and epoch > 1:
            torch.save(model,save_path)
            torch.save(model.state_dict(),save_path+'_states')
            print('saving model')
            best_weights = copy.deepcopy(model.state_dict())
            best_val_loss = val_loss
            best_score = val_score
            best_val_res = {'accuracy': val_acc, 'f1': val_f1, 'disparity': val_disp,'score':best_score,'loss':val_loss}
            steps_since_improvement = 0
        else:
            steps_since_improvement += 1
        
        hist_entry = {
            'epoch': epoch,
            'train_loss': avg_loss,
            'train_acc':avg_acc,
            'val_loss':val_loss,
            'val_acc': val_acc,
            'score': val_score,
            'disp': val_disp,
            'lr': lr,
            'loss_weights': '_'.join([str(l) for l in loss_weights])
        }
        hist.append(hist_entry)
        save_train_history(model,hist,root=root)
        if steps_since_improvement > patience:
            break
    model.load_state_dict(best_weights)
    return model,hist,best_val_res,save_path

    
models = [
     lambda : TripletModel2(encoder=TripletFacenetEncoder(embedding_dropout=.1),decoder=TripletFacenetClassifier(400,st_dims=[20],age_dims=[8],gender_dims=[4]),name='gridsearch_sm0ll'),
    lambda : TripletModel2(encoder=TripletFacenetEncoder(embedding_dropout=.1),decoder=TripletFacenetClassifier(400,st_dims=[100],age_dims=[40],gender_dims=[20]),name='gridsearch_small'),
    lambda : TripletModel2(encoder=TripletFacenetEncoder(embedding_dropout=.3),decoder=TripletFacenetClassifier(400,st_dims=[20],age_dims=[8],gender_dims=[4]),name='gridsearch_sm0ll'),
    lambda : TripletModel2(encoder=TripletFacenetEncoder(embedding_dropout=.3),decoder=TripletFacenetClassifier(400,st_dims=[100],age_dims=[40],gender_dims=[20]),name='gridsearch_small'),
#     lambda : TripletModel2(encoder=TripletFacenetEncoder(embedding_dropout=.3),decoder=TripletFacenetClassifier(400,st_dims=[100],age_dims=[100],gender_dims=[100]),name='gridsearch_medium'),
    lambda : TripletModel2(encoder=TripletFacenetEncoder(embedding_dropout=.6),decoder=TripletFacenetClassifier(400,st_dims=[20],age_dims=[8],gender_dims=[4]),name='gridsearch_sm0ll'),
    lambda : TripletModel2(encoder=TripletFacenetEncoder(embedding_dropout=.6),decoder=TripletFacenetClassifier(400,st_dims=[100],age_dims=[40],gender_dims=[20]),name='gridsearch_small'),
#     lambda : TripletModel2(encoder=TripletFacenetEncoder(embedding_dropout=.6),decoder=TripletFacenetClassifier(400,st_dims=[100],age_dims=[100],gender_dims=[100]),name='gridsearch_medium'),
    #     lambda : TripletModel2(encoder=TripletFacenetEncoder(),name='gridsearch_baseline'),
#     lambda : TripletModel2(encoder=TripletFacenetEncoder(base_model=ResNet18()),name='gridsearch_resnet'),
#     lambda : TripletModel2(encoder=SimpleEncoder(),name='gridsearch_simple'),
#     lambda : TripletModel2(encoder=FrozenDualEncoder(),name='gridsearch_frozen'),
#     lambda : TripletModel2(encoder=SimpleEncoder(fine_tune=False),name='gridsearch_simplefrozen'),
]

results = []
for embed_loss_weight in [.1]:
    for stap in [0]:
        for noise in [.01]:
            for loss_weights in [[5,2,1],[3,1,1],[2,1,2]]:
                for model in models:
                    for wt in ['softmax']:
                        if embed_loss_weight == 0 and wt != 'default':
                            continue
                        softmax = (wt == 'softmax')
                        bias_weight_transform=None
                        if wt == 'flat':
                            bias_weight_transform = lambda x: int(x > 0)
                        try:
                            m,h,entry,mname = train_model(
                                model(),
                                train_labels,
                                validation_labels,
                                Constants.data_root,
                                batch_size=50,
                                embedding_loss_weight=embed_loss_weight,
                                classification_loss_weight=1 - embed_loss_weight,
                                lr=.0001,
                                patience=5,
                                softmax_upsample_weights=softmax,
                                loss_weights=loss_weights,
                                skintone_patch_anchor_prob=stap,
                                skintone_regression=False,
                                noise=noise,
                            )
                            entry['noise'] = noise
                            entry['embed_weight'] = embed_loss_weight
                            entry['model'] = mname
                            entry['st_patch_prob'] = stap
                            entry['softmax'] = softmax
                            entry['weight_type'] = wt
                            entry['loss_weights'] = loss_weights
                            results.append(entry)
                            print('___________')
                            print(entry['score'],
                                  entry['accuracy'],
                                  entry['model'],
                                  entry['embed_weight'],
                                  wt,
                                 )
                            print('______________')
                            pd.DataFrame(results).to_csv(Constants.result_folder + '_weight_gridsearch2.csv')
                        except Exception as e:
                            print(e)

(4869, 167)
(1209, 167)
model being saved to ../../data/models/gridsearch_sm0lldualencoder_dualfacenet_h400_ed1triplet_decoder__st20_a8_g4_std2_ad2_gd2_rbalancedsoft_lw5-2-1_noise01
epoch 0
train loss 14.67518156401965 train embed loss 0.12983261204647775 train accuracy [0.161, 0.471, 0.648]
val loss 14.16243236541748 val_embed_loss 0.11980195984244346 val accuracy [0.242, 0.628, 0.77] val f1 [0.108, 0.305, 0.361] disparities [0.6498503088951111, 0.7573860883712769, 0.5373823046684265] score 3.9273209212496174
saved history to ../../data/results/history_gridsearch_sm0lldualencoder_dualfacenet_h400_ed1triplet_decoder__st20_a8_g4_std2_ad2_gd23.csv
epoch 1
train loss 14.243278892672791 train embed loss 0.1634918027082268 train accuracy [0.238, 0.568, 0.743]
val loss 13.839081535339355 val_embed_loss 0.144316383600235 val accuracy [0.269, 0.636, 0.842] val f1 [0.147, 0.318, 0.414] disparities [0.6150000095367432, 0.7797927856445312, 0.17185050249099731] score 4.847568465592514
saved histor

train loss 12.934541955286143 train embed loss 0.14099312052890964 train accuracy [0.402, 0.673, 0.837]
val loss 13.212791557312011 val_embed_loss 0.15812397509813308 val accuracy [0.32, 0.675, 0.909] val f1 [0.157, 0.318, 0.451] disparities [0.46133333444595337, 0.588626503944397, 0.0074236392974853516] score 6.697386010841393
saving model
saved history to ../../data/results/history_gridsearch_sm0lldualencoder_dualfacenet_h400_ed1triplet_decoder__st20_a8_g4_std2_ad2_gd23.csv
epoch 18
train loss 12.9856234180684 train embed loss 0.1393101379959559 train accuracy [0.387, 0.671, 0.834]
val loss 13.291358108520507 val_embed_loss 0.16231618881225585 val accuracy [0.304, 0.667, 0.905] val f1 [0.146, 0.318, 0.45] disparities [0.512666642665863, 0.5891267657279968, 0.007762312889099121] score 6.475416700236092
saved history to ../../data/results/history_gridsearch_sm0lldualencoder_dualfacenet_h400_ed1triplet_decoder__st20_a8_g4_std2_ad2_gd23.csv
epoch 19
train loss 12.878335563503967 train em

In [12]:
pd.DataFrame(results).sort_values('score',ascending=False)

Unnamed: 0,accuracy,f1,disparity,score,loss,noise,embed_weight,model,st_patch_prob,softmax,weight_type,loss_weights
16,"[0.33537777066230773, 0.7186666560173035, 0.90...","[0.17461035192012786, 0.35525026738643645, 0.4...","[0.43634524941444397, 0.23415428400039673, 0.0...",7.81362,5.663364,0.01,0.1,../../data/models/gridsearch_mediumdualencoder...,0,True,softmax,"[2, 1, 1]"
23,"[0.3525333261489868, 0.690666651725769, 0.9063...","[0.18393097698688507, 0.33637497305870057, 0.4...","[0.5077142715454102, 0.20435869693756104, 0.04...",7.780808,3.549181,0.01,0.1,../../data/models/gridsearch_smalldualencoder_...,0,False,flat,"[1, 1, 1]"
15,"[0.3553777664899826, 0.6983110976219177, 0.918...","[0.18819123089313508, 0.3374246382713318, 0.45...","[0.5330951809883118, 0.28069818019866943, 0.02...",7.771407,5.645075,0.01,0.1,../../data/models/gridsearch_mediumdualencoder...,0,False,default,"[2, 1, 1]"
22,"[0.35857776820659637, 0.6990222072601319, 0.91...","[0.17956785917282103, 0.3260649961233139, 0.45...","[0.4861614406108856, 0.3567206561565399, 0.007...",7.741283,3.541026,0.01,0.1,../../data/models/gridsearch_smalldualencoder_...,0,True,softmax,"[1, 1, 1]"
40,"[0.33608887910842894, 0.6803555417060853, 0.90...","[0.17204354524612428, 0.33915401458740235, 0.4...","[0.4580647051334381, 0.12481123208999634, 0.04...",7.694694,3.577917,0.1,0.1,../../data/models/gridsearch_smalldualencoder_...,0,True,softmax,"[1, 1, 1]"
12,"[0.342133327126503, 0.715111095905304, 0.91075...","[0.17218270450830458, 0.34607224583625795, 0.4...","[0.5049999952316284, 0.34833455085754395, 0.00...",7.634543,5.650136,0.01,0.1,../../data/models/gridsearch_smalldualencoder_...,0,False,default,"[2, 1, 1]"
13,"[0.33697776675224306, 0.691555540561676, 0.903...","[0.1631094115972519, 0.33891214430332184, 0.44...","[0.46533331274986267, 0.2836516797542572, 0.02...",7.609616,5.702067,0.01,0.1,../../data/models/gridsearch_smalldualencoder_...,0,True,softmax,"[2, 1, 1]"
4,"[0.34142221629619596, 0.69262220621109, 0.9159...","[0.18092736423015596, 0.3382172554731369, 0.45...","[0.5006271600723267, 0.3165661096572876, 0.014...",7.605159,3.562725,0.01,0.0,../../data/models/gridsearch_mediumdualencoder...,0,False,default,"[1, 1, 1]"
20,"[0.3245333254337311, 0.7074666547775269, 0.904...","[0.17057571351528167, 0.3438539642095566, 0.45...","[0.4639766216278076, 0.2616163492202759, 0.018...",7.586758,5.707014,0.01,0.1,../../data/models/gridsearch_baselinedualencod...,0,False,flat,"[2, 1, 1]"
17,"[0.3586666572093964, 0.70622220993042, 0.91875...","[0.18839144557714463, 0.3360950267314911, 0.45...","[0.503097414970398, 0.44563278555870056, 0.005...",7.562725,5.623144,0.01,0.1,../../data/models/gridsearch_mediumdualencoder...,0,False,flat,"[2, 1, 1]"


In [None]:
test = TripletModel2(encoder=TripletFacenetEncoder(),name='gridsearch_baseline')
test.encoder.embedding_size

In [9]:
print('lol')

lol
