In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from scipy.ndimage import rotate
import Utils
from Utils import Constants
import cv2
from facenet_pytorch import InceptionResnetV1
import copy
from Models import *
from DataLoaders import *

In [3]:
train_labels = pd.read_csv('train_data_augmented_balanceddual.csv')
validation_labels = pd.read_csv('validation_data_augmented_balanceddual.csv')
train_labels.shape, validation_labels.shape

((6842, 166), (1711, 166))

In [4]:
def get_model(file,model=None):
    if model is None:
        model = torch.load(Constants.model_folder + file).to(torch.device('cpu'))
    model.load_state_dict(torch.load(Constants.model_folder + file + '_states'))
    model.eval()
    return model


In [None]:
def save_train_history(model,history,root=''):
    model_name = model.get_identifier()
    
    df = pd.DataFrame(history)
    df['model'] = model_name
    string = root + 'results/history_' + model_name + '3.csv'
    df.to_csv(string,index=False)
    print('saved history to',string)
    return df, string

def subclass_accuracy(ypred,y,nclass=None):
    if nclass is None:
        nclass = len(torch.unique(y))
        classes = torch.unique(y)
    else:
        classes = torch.Tensor([i for i in range(nclass)]).float()
    results = torch.zeros(nclass)
    if ypred.ndim > 1:
        ypred = torch.argmax(ypred,1).long()
    for i,c in enumerate(torch.unique(y)):
        yy = (y == c)
        yypred = (ypred == c)
        good = torch.logical_and(yy,yypred).float().sum()/yy.float().sum()
        results[i] = good
    return results

def comp_score(accs,disp,csizes):
    score = 0
    for acc,disp,size in zip(accs,disp,csizes):
        score += size*(acc)*(1-disp**(size/2))
    return score

def train_model(model,
                train_df,
                validation_df,
                root,
                epochs=300,
                lr=.001,
                batch_size=200,
                patience = 20,
                loss_weights = [2,1,.5],
                save_path=None,
                histogram =False,
                upsample=False,
                random_upsample=True,
                softmax_upsample_weights=False,
                upsample_validation=False,
                random_upsample_validation=False,
                embedding_loss_weight = 1,
                classification_loss_weight = 1,
                starting_loss=100000,
                bias_weight_transform=None,
                skintone_regression=False,
                noise=.05,
                skintone_patch_anchor_prob=0,
                run_name = '',
                smote_prob=.5,
                rotate_prob=.1,
                flip_prob=.7,
                **kwargs,
               ):
    pretraining = classification_loss_weight <= .00001
    if save_path is None:
        save_path = root + 'models/'+ run_name + model.get_identifier() 
        if random_upsample:
            
            save_path += '_rbalanced'
            if softmax_upsample_weights:
                save_path += 'soft'
        elif upsample:
            save_path += '_balanced'
        if pretraining:
            save_path += '_pretrain'
        else:
            save_path += 'embed' + str(np.round(embedding_loss_weight/classification_loss_weight,2))
    if skintone_patch_anchor_prob > 0:
        save_path += '_spab'+str(skintone_patch_anchor_prob).replace('.','')
    save_path += '_lw' + '-'.join([str(l) for l in loss_weights])
    save_path += '_noise' + str(noise).replace('0.','').replace('.','')
    if smote_prob > 0:
        save_path += '_smote' + str(smote_prob)
    if rotate_prob + flip_prob > 0:
        save_path += '_flprot' + str(flip_prob) + '-' + str(rotate_prob)
    save_path += "_V" + str(np.round(np.random.random(),4))
    if upsample:
        patience = int(patience/5) + 1
    class_sizes = [10,4,2]
    if bias_weight_transform is not None:
        weight_cols = [c for c in train_df.columns if '_bias' in c]
        tdf = train_df.copy()
        vdf = validation_df.copy()
        for col in weight_cols:
            tdf[col] = tdf[col].apply(bias_weight_transform)
            vdf[col] = vdf[col].apply(bias_weight_transform)
        train_df = tdf
        validation_df = vdf
    train_loader = TripletFaceGenerator(train_df,Constants.data_root,
                                 batch_size=batch_size,
                                 upsample=upsample,
                                 random_upsample=random_upsample,
                                 skintone_patch_anchor_prob=skintone_patch_anchor_prob,
                                 softmax=softmax_upsample_weights,
                                 noise_sigma=noise,
                                 smote_prob=.5,
                                 rotate_prob=rotate_prob,
                                 flip_prob=flip_prob,
                                 **kwargs)
    validation_loader = TripletFaceGenerator(validation_df,Constants.data_root,
                                      validation=True,
                                      batch_size=batch_size,
                                      upsample=upsample_validation,
                                     random_upsample=random_upsample_validation,
                                     softmax=softmax_upsample_weights,
                                     skintone_patch_anchor_prob=0,
                                     noise_sigma=noise,
                                      **kwargs)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train(True)
    
    cel = torch.nn.CrossEntropyLoss()
    format_y = lambda y: y.to(device)
    regression_loss = torch.nn.MSELoss()
    triplet_loss = torch.nn.TripletMarginLoss()
#     embedding_optimizer = torch.optim.Adam(model.encoder.parameters(), lr=lr)
#     decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=lr)
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    
    
    def format_batch(inputs,grad=True):
        xb = []
        for xin in inputs:
            xin = xin.to(device)
            xin.requires_grad_(grad)
            xb.append(xin)
        return xb
    
    def embedding_step(m,xbatch): 
        base = m.encoder(xbatch[0])
        positive = m.encoder(xbatch[1])
        negative = m.encoder(xbatch[2])
        loss = triplet_loss(base,positive,negative)
        loss = torch.mul(loss,embedding_loss_weight)
        return base,loss
    
    
    def classifier_step(m,embedding,ytrue):
        outputs = m.decoder(embedding)
        if not skintone_regression:
            losses = [cel(ypred.float(),format_y(y.long())) for y,ypred in zip(ytrue,outputs)]
            l1 = torch.mul(loss_weights[0],losses[0])
            l2 =  torch.mul(loss_weights[1],losses[1])
            l3 =  torch.mul(loss_weights[2],losses[2])
        else:
            l1 = torch.mul( regression_loss(outputs[0].float().view(-1),format_y(ytrue[0].float())), loss_weights[0])
            l2 = torch.mul( cel(outputs[1].float(),format_y(ytrue[1].long())), loss_weights[1])
            l3 = torch.mul(cel(outputs[2].float(),format_y(ytrue[2].long())), loss_weights[2])
        total_losses = l1 + l2 + l3
        total_loss = torch.mul(total_losses,classification_loss_weight)
        return outputs,total_losses
        
    def train_epoch(model):
        running_loss = 0
        running_embed_loss = 0
        running_accuracy = [0,0,0]
        curr_loss = 0
        count = 0
        for i, [x_batch, y_batch] in enumerate(train_loader):
            x_batch = format_batch(x_batch)
            optimizer.zero_grad()
            embedding,embedding_loss = embedding_step(model, x_batch)
            if pretraining:
                classification_loss = embedding_loss - embedding_loss
            else:
                outputs, classification_loss = classifier_step(model,embedding,y_batch)
            total_loss = classification_loss + embedding_loss
            total_loss.backward()
            optimizer.step()
            running_loss += classification_loss.item()
            running_embed_loss += embedding_loss.item()
            print('curr loss class',classification_loss.item(),'embed', embedding_loss.item(), 'step',i,' | ',end='\r')
            count += 1
            if not pretraining:
                with torch.no_grad():
                    for i,(y,ypred) in enumerate(zip(y_batch,outputs)):
                        accuracy = Utils.categorical_accuracy(ypred.float(),format_y(y))
                        running_accuracy[i] += accuracy.item()
        return running_loss/count,running_embed_loss/count, [a/count for a in running_accuracy]
    
    def val_epoch(model):
        running_loss = 0
        running_embed_loss = 0
        running_accuracy = [0,0,0]
        running_f1 = [0,0,0]
        running_subclass_acc = [None,None,None]
        count = 0
        with torch.no_grad():
            for i, [x_batch, y_batch] in enumerate(validation_loader):
                x_batch = format_batch(x_batch,grad=False)
                embedding,embedding_loss = embedding_step(model, x_batch)
                if pretraining:
                    classification_loss = embedding_loss - embedding_loss
                else:
                    outputs, classification_loss = classifier_step(model,embedding,y_batch)
                running_loss += classification_loss.item()
                running_embed_loss += embedding_loss.item()
                count += 1
                if not pretraining:
                    for i,(y,ypred) in enumerate(zip(y_batch, outputs)):
                        accuracy = Utils.categorical_accuracy(ypred.float(),format_y(y))
                        f1, precision, recall = Utils.macro_f1(torch.argmax(ypred.float(),axis=1),format_y(y))
                        subclass_acc = subclass_accuracy(ypred.float(),format_y(y),class_sizes[i])
                        if running_subclass_acc[i] is None:
                            running_subclass_acc[i] = subclass_acc
                        else:
                            running_subclass_acc[i] += subclass_acc
                        running_accuracy[i] += accuracy.item()
                        running_f1[i] += f1.item()
        running_subclass_acc = [torch.mul(r,1/count) for r in running_subclass_acc]
        disparities = [(torch.max(v) - torch.min(v)).item() for v in running_subclass_acc]
        return running_loss/count,running_embed_loss/count, [a/count for a in running_accuracy], [f/count for f in running_f1],disparities
    shorten = lambda array: [np.round(a, 3) for a in array]
    
    best_val_res = {}
    best_val_loss=1000000
    best_score = 0
    steps_since_improvement = 0
    hist = []
    best_weights = model.state_dict()
    print('model being saved to',save_path)
    for epoch in range(epochs):
        print('epoch',epoch)
        model.train()
        avg_loss,avg_embed_loss, avg_acc = train_epoch(model)
        print('train loss', avg_loss,'train embed loss',avg_embed_loss, 'train accuracy', shorten(avg_acc))
        model.eval()
        val_loss,val_embed_loss, val_acc, val_f1, val_disp = val_epoch(model)
        val_score = comp_score(val_acc,val_disp,class_sizes)
        if pretraining:
            val_loss = val_embed_loss
        print('val loss', val_loss, 'val_embed_loss', val_embed_loss, 
              'val accuracy', shorten(val_acc), 'val f1', shorten(val_f1),
                 'disparities',val_disp, 'score', val_score,
             )
        #don't save immediately in case I cancel training
        if best_score < val_score and epoch > 1:
            torch.save(model,save_path)
            torch.save(model.state_dict(),save_path+'_states')
            print('saving model')
            best_weights = copy.deepcopy(model.state_dict())
            best_val_loss = val_loss
            best_score = val_score
            best_val_res = {'accuracy': val_acc, 'f1': val_f1, 'disparity': val_disp,'score':best_score,'loss':val_loss}
            steps_since_improvement = 0
        else:
            steps_since_improvement += 1
        
        hist_entry = {
            'epoch': epoch,
            'train_loss': avg_loss,
            'train_acc':avg_acc,
            'val_loss':val_loss,
            'val_acc': val_acc,
            'score': val_score,
            'disp': val_disp,
            'lr': lr,
            'loss_weights': '_'.join([str(l) for l in loss_weights])
        }
        hist.append(hist_entry)
        save_train_history(model,hist,root=root)
        if steps_since_improvement > patience:
            break
    model.load_state_dict(best_weights)
    return model,hist,best_val_res,save_path

    
models = [
#      lambda : TripletModel2(encoder=TripletFacenetEncoder(embedding_dropout=.1),decoder=TripletFacenetClassifier(400,st_dims=[20],age_dims=[8],gender_dims=[4]),name='gridsearch_sm0ll'),
#     lambda : TripletModel2(encoder=TripletFacenetEncoder(embedding_dropout=.1),decoder=TripletFacenetClassifier(400,st_dims=[100],age_dims=[40],gender_dims=[20]),name='gridsearch_small'),
#     lambda : TripletModel2(encoder=TripletFacenetEncoder(embedding_dropout=.3),decoder=TripletFacenetClassifier(400,st_dims=[20],age_dims=[8],gender_dims=[4]),name='gridsearch_sm0ll'),
    lambda : TripletModel2(encoder=TripletFacenetEncoder(embedding_dropout=.3),decoder=TripletFacenetClassifier(400,st_dims=[100],age_dims=[40],gender_dims=[20]),name='gridsearch3_small'),
#     lambda : TripletModel2(encoder=TripletFacenetEncoder(embedding_dropout=.3),decoder=TripletFacenetClassifier(400,st_dims=[100],age_dims=[100],gender_dims=[100]),name='gridsearch3_medium'),
#     lambda : TripletModel2(encoder=TripletFacenetEncoder(embedding_dropout=.6),decoder=TripletFacenetClassifier(400,st_dims=[20],age_dims=[8],gender_dims=[4]),name='gridsearch_sm0ll'),
#     lambda : TripletModel2(encoder=TripletFacenetEncoder(embedding_dropout=.6),decoder=TripletFacenetClassifier(400,st_dims=[100],age_dims=[40],gender_dims=[20]),name='gridsearch_small'),
#     lambda : TripletModel2(encoder=TripletFacenetEncoder(embedding_dropout=.6),decoder=TripletFacenetClassifier(400,st_dims=[100],age_dims=[100],gender_dims=[100]),name='gridsearch_medium'),
#         lambda : TripletModel2(encoder=TripletFacenetEncoder(),name='gridsearch_baseline'),
#     lambda : TripletModel2(encoder=TripletFacenetEncoder(base_model=ResNet18()),decoder=TripletFacenetClassifier(400,st_dims=[100],age_dims=[40],gender_dims=[20]),name='gridsearch3_resnetsmall'),
    lambda : TripletModel2(encoder=SimpleEncoder(),name='gridsearch_simple'),
#     lambda : TripletModel2(encoder=FrozenDualEncoder(),name='gridsearch_frozen'),
    lambda : TripletModel2(encoder=SimpleEncoder(fine_tune=False),name='gridsearch_simplefrozen'),
]

results = []
best_score = 0
best_model = ''
for model in models:
    for embed_loss_weight in [.9]:
        for smote_prob in [.5]:
            for noise in [.01]:
                for loss_weights in [[2,1,2]]:
                    for rotate_prob in [0,.2,.5]:
                        for wt in ['softmax']:
                            if embed_loss_weight == 0 and wt == 'flat':
                                continue
                            softmax = (wt == 'softmax')
                            bias_weight_transform=None
                            if wt == 'flat':
                                bias_weight_transform = lambda x: int(x > 0)
                            try:
                                m,h,entry,mname = train_model(
                                    model(),
                                    train_labels,
                                    validation_labels,
                                    Constants.data_root,
                                    batch_size=50,
                                    embedding_loss_weight=embed_loss_weight,
                                    classification_loss_weight=1 - embed_loss_weight,
                                    lr=.0001,
                                    patience=7,
                                    softmax_upsample_weights=softmax,
                                    loss_weights=loss_weights,
                                    skintone_patch_anchor_prob=0,
                                    skintone_regression=False,
                                    smote_prob=smote_prob,
                                    rotate_prob=rotate_prob,
                                    noise=noise,
                                )
                                entry['noise'] = noise
                                entry['embed_weight'] = embed_loss_weight
                                entry['model'] = mname
                                entry['softmax'] = softmax
                                entry['weight_type'] = wt
                                entry['loss_weights'] = loss_weights
                                if entry['score'] > best_score:
                                    best_score = entry['score']
                                    best_model = entry['model']
                                results.append(entry)
                                print('___________','run',len(results)+1,'_________')
                                print(entry['score'],
                                      entry['accuracy'],
                                      entry['model'],
                                     )
                                print('best',best_score,best_model)
                                print('______________')
                                pd.DataFrame(results).to_csv(Constants.result_folder + '_flip_gridsearch.csv')
                            except Exception as e:
                                print(e)

(4869, 167)
(1209, 167)
model being saved to ../../data/models/gridsearch3_smalldualencoder_dualfacenet_h400_ed3triplet_decoder__st100_a40_g20_std2_ad2_gd2_rbalancedsoftembed9.0_lw2-1-2_noise01_smote0.5_flprot0.7-0_V0.183
epoch 0
train loss 7.158701794488089 train embed loss 1.012501631464277 train accuracy [0.154, 0.377, 0.74]
val loss 6.811118698120117 val_embed_loss 0.7784752404689789 val accuracy [0.226, 0.595, 0.856] val f1 [0.101, 0.294, 0.423] disparities [0.561637818813324, 0.7674609422683716, 0.14224380254745483] score 4.584991576101578
saved history to ../../data/results/history_gridsearch3_smalldualencoder_dualfacenet_h400_ed3triplet_decoder__st100_a40_g20_std2_ad2_gd23.csv
epoch 1
train loss 6.79366396884529 train embed loss 0.9420220429191783 train accuracy [0.233, 0.548, 0.827]
val loss 6.526293048858642 val_embed_loss 0.701758234500885 val accuracy [0.286, 0.594, 0.875] val f1 [0.155, 0.295, 0.434] disparities [0.55692058801651, 0.7606326937675476, 0.0968889594078064] sc

In [10]:
pd.DataFrame(results).sort_values('score',ascending=False).model.iloc[0]

'../../data/models/gridsearch_simpledualencoder_dualfacenet_h500_h500_ed3triplet_decoder__st600_a400_g400_std2_ad2_gd2_rbalancedsoftembed9.0_lw2-1-2_noise01_smote0.5_flprot0.7-0_V0.8086'

In [7]:
test = TripletModel2(encoder=TripletFacenetEncoder(),name='gridsearch_baseline')
test.encoder.embedding_size

400

In [8]:
print('lol')

lol
