In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from scipy.ndimage import rotate
import Utils
from Utils import Constants
import cv2
from facenet_pytorch import InceptionResnetV1
from Models import *
from DataLoaders import *

In [3]:
# Cleaned label tables for the three dataset splits, read in train/test/validation order.
train_labels, test_labels, validation_labels = (
    pd.read_csv(csv_path)
    for csv_path in ('train_data_clean.csv', 'test_data_clean.csv', 'validation_data_clean.csv')
)
train_labels

Unnamed: 0,name,skin_tone,gender,age,is_face
0,TRAIN0001.png,0,0,1,False
1,TRAIN0002.png,5,1,0,True
2,TRAIN0005.png,1,1,0,False
3,TRAIN0007.png,1,0,1,True
4,TRAIN0009.png,7,0,1,False
...,...,...,...,...,...
6837,TRAIN9992.png,4,0,2,True
6838,TRAIN9993.png,1,1,1,True
6839,TRAIN9995.png,8,0,1,True
6840,TRAIN9998.png,4,1,1,False


In [None]:
def save_train_history(model,history,root=''):
    """Write a training history to ``<root>/results/history_<identifier>.csv``.

    Args:
        model: object exposing ``get_identifier()``; the identifier names the
            output file and is stored in an added ``model`` column.
        history: records accepted by ``pd.DataFrame`` (e.g. a list of
            per-epoch dicts).
        root: directory prefix. It is joined with ``os.path.join``, so it works
            with or without a trailing separator; ``''`` writes relative to
            the current working directory.

    Returns:
        ``(df, path)``: the DataFrame that was written and the CSV path.
    """
    import os  # local import: the notebook's import cell does not import os

    model_name = model.get_identifier()

    df = pd.DataFrame(history)
    df['model'] = model_name
    path = os.path.join(root, 'results', 'history_' + model_name + '.csv')
    # Create the results directory on first use instead of failing inside to_csv.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    df.to_csv(path, index=False)
    print('saved history to', path)
    return df, path

def train_model(model,
                train_df,
                validation_df,
                root,
                epochs=300,
                lr=.001,
                batch_size=200,
                patience=20,
                loss_weights=(2, 1, .5),
                save_path=None,
                histogram=False,
                upsample=False,
                random_upsample=True,
                softmax_upsample_weights=False,
                upsample_validation=False,
                random_upsample_validation=False,
                restore_best=False,
                **kwargs,
               ):
    """Train a three-head classifier with early stopping on validation loss.

    Each batch is passed through ``model``; every output head gets a
    ``CrossEntropyLoss`` against the matching label tensor, scaled by the
    corresponding entry of ``loss_weights``, and the three are summed.
    After every epoch the model is evaluated on ``validation_df``, the history
    is written with ``save_train_history`` and, when the validation loss
    improves (from epoch 2 onward), the whole model and its state dict are
    saved to ``save_path`` / ``save_path + '_states'``.

    Args:
        model: network to train. Must provide ``get_identifier()``; its forward
            pass is assumed to return a sequence of three logit tensors
            (TODO confirm against the model classes).
        train_df, validation_df: label tables handed to ``FaceGenerator``.
        root: prefix for the ``models/`` and ``results/`` output paths.
        epochs: maximum number of epochs.
        lr: Adam learning rate.
        batch_size: batch size for both loaders.
        patience: training stops once more than ``patience`` consecutive epochs
            fail to improve the validation loss. When ``upsample`` is set, it is
            reduced to ``int(patience / 5) + 1``.
        loss_weights: per-head multipliers for the three losses.
        save_path: checkpoint path. Defaults to
            ``root + 'models/' + model.get_identifier()``, plus ``'_balanced'``
            when upsampling.
        histogram: if True, the model receives ``[images, histogram_features]``
            (built via ``torch_color_histogram`` / ``add_batch_histogram``)
            instead of images alone.
        upsample, random_upsample: sampling options for the training loader.
        softmax_upsample_weights: forwarded as ``softmax`` to both loaders.
        upsample_validation, random_upsample_validation: sampling options for
            the validation loader.
        restore_best: if True, load the weights from the epoch with the lowest
            validation loss into ``model`` before returning. Defaults to False,
            in which case the final-epoch model is returned.
        **kwargs: forwarded to both ``FaceGenerator`` instances.

    Returns:
        ``(model, hist)``. ``hist`` is a list of per-epoch dicts with keys
        epoch, train_loss, train_acc, val_loss, val_acc, val_f1, lr and
        loss_weights.

    Raises:
        ValueError: if the training or validation loader yields no batches.
    """
    if save_path is None:
        save_path = root + 'models/' + model.get_identifier()
        if upsample:
            save_path += '_balanced'
    if upsample:
        patience = int(patience / 5) + 1
    train_loader = FaceGenerator(train_df, Constants.data_root,
                                 batch_size=batch_size,
                                 upsample=upsample,
                                 random_upsample=random_upsample,
                                 softmax=softmax_upsample_weights,
                                 **kwargs)
    validation_loader = FaceGenerator(validation_df, Constants.data_root,
                                      validation=True,
                                      batch_size=batch_size,
                                      upsample=upsample_validation,
                                      random_upsample=random_upsample_validation,
                                      softmax=softmax_upsample_weights,
                                      **kwargs)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    def format_y(y):
        # CrossEntropyLoss needs int64 class indices on the model's device.
        return y.long().to(device)

    def get_loss(m, xin, ytrue):
        """Forward ``xin`` through ``m``; return (outputs, weighted sum of head losses)."""
        outputs = m(xin)
        losses = [loss_fn(ypred.float(), format_y(y)) for y, ypred in zip(ytrue, outputs)]
        total_loss = (torch.mul(loss_weights[0], losses[0])
                      + torch.mul(loss_weights[1], losses[1])
                      + torch.mul(loss_weights[2], losses[2]))
        return outputs, total_loss

    def train_epoch():
        """One optimisation pass; returns (mean loss, per-head mean accuracy)."""
        running_loss = 0
        running_accuracy = [0, 0, 0]
        count = 0
        for step, (x_batch, y_batch) in enumerate(train_loader):
            optimizer.zero_grad()
            if histogram:
                xh = torch_color_histogram(torch.clone(x_batch)).to(device)
                xxb = x_batch.to(device)
                xh.requires_grad_(True)
                xxb.requires_grad_(True)
                xb = [xxb, xh]
            else:
                xb = x_batch.to(device)
                xb.requires_grad_(True)
            outputs, total_loss = get_loss(model, xb, y_batch)
            total_loss.backward()
            optimizer.step()
            running_loss += total_loss.item()
            print('curr loss', total_loss.item(), 'step', step, ' | ', end='\r')
            count += 1
            with torch.no_grad():
                for head, (y, ypred) in enumerate(zip(y_batch, outputs)):
                    accuracy = Utils.categorical_accuracy(ypred.float(), format_y(y))
                    running_accuracy[head] += accuracy.item()
        if count == 0:
            raise ValueError('training loader produced no batches')
        return running_loss / count, [a / count for a in running_accuracy]

    def val_epoch():
        """One evaluation pass; returns (mean loss, per-head accuracy, per-head macro F1)."""
        running_loss = 0
        running_accuracy = [0, 0, 0]
        running_f1 = [0, 0, 0]
        count = 0
        # Evaluation needs no gradients; skip building autograd graphs.
        with torch.no_grad():
            for x_batch, y_batch in validation_loader:
                if histogram:
                    xb = add_batch_histogram(x_batch, device=device, grad=False)
                else:
                    xb = x_batch.to(device)
                # Single forward pass: get_loss returns the outputs it computed.
                outputs, total_loss = get_loss(model, xb, y_batch)
                running_loss += total_loss.item()
                count += 1
                for head, (y, ypred) in enumerate(zip(y_batch, outputs)):
                    accuracy = Utils.categorical_accuracy(ypred.float(), format_y(y))
                    f1, precision, recall = Utils.macro_f1(torch.argmax(ypred.float(), axis=1), format_y(y))
                    running_f1[head] += f1.item()
                    running_accuracy[head] += accuracy.item()
        if count == 0:
            raise ValueError('validation loader produced no batches')
        return (running_loss / count,
                [a / count for a in running_accuracy],
                [f / count for f in running_f1])

    def snapshot_weights():
        # state_dict() returns references to the live parameter tensors, so the
        # values must be copied to freeze the weights of the current epoch.
        return {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    best_val_loss = float('inf')
    steps_since_improvement = 0
    hist = []
    # Snapshots cost a full copy of the weights, so only take them when they will be used.
    best_weights = snapshot_weights() if restore_best else None
    print('model being saved to', save_path)
    for epoch in range(epochs):
        print('epoch', epoch)
        model.train(True)
        avg_loss, avg_acc = train_epoch()
        print('train loss', avg_loss, 'train accuracy', avg_acc)
        model.eval()
        val_loss, val_acc, val_f1 = val_epoch()
        print('val loss', val_loss, 'val accuracy', val_acc, 'val f1', val_f1)
        if val_loss < best_val_loss:
            if restore_best:
                best_weights = snapshot_weights()
            best_val_loss = val_loss
            steps_since_improvement = 0
            # NOTE(review): improvements in epochs 0 and 1 update best_val_loss
            # but are not written to disk; if neither is beaten later, no
            # checkpoint file is produced.
            if epoch > 1:
                torch.save(model, save_path)
                torch.save(model.state_dict(), save_path + '_states')
        else:
            steps_since_improvement += 1

        hist.append({
            'epoch': epoch,
            'train_loss': avg_loss,
            'train_acc': avg_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'val_f1': val_f1,
            'lr': lr,
            'loss_weights': '_'.join(str(w) for w in loss_weights),
        })
        save_train_history(model, hist, root=root)
        if steps_since_improvement > patience:
            break
    if restore_best:
        model.load_state_dict(best_weights)
    return model, hist

# Train the dual FaceNet model on the cleaned train/validation splits,
# writing checkpoints and history under Constants.data_root.
m, h = train_model(
    DualFacenetModel(hidden_dims=[1000],
                     embedding_dropout=.4,
                     st_dropout=.2,
                     age_dropout=.2,
                     gender_dropout=.2),
    train_labels,
    validation_labels,
    Constants.data_root,
    batch_size=100,
    lr=.0001,
)
# Only the per-epoch history is inspected here; drop the notebook's model reference.
del m
h

(4869, 5)
(1209, 5)
model being saved to ../../data/models/dual_dualfacenet_h1000_st600_a400_g400_ed4_std2_ad2_gd2
epoch 0


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


train loss 6.265864177626007 train accuracy [0.17110026148813112, 0.4962200437273298, 0.5995149217089828]
val loss 6.054253431466909 val accuracy [0.20324785778155693, 0.5112820451076214, 0.5788034108968881] val f1 [0.15425822654595742, 0.33729202242997974, 0.36569592585930455]
saved history to ../../data/results/history_dual_dualfacenet_h1000_st600_a400_g400_ed4_std2_ad2_gd2.csv
epoch 1
train loss 6.042875338573845 train accuracy [0.19851818422273715, 0.4966282066033811, 0.6299526618451489]
val loss 5.982596544119028 val accuracy [0.19102563537084138, 0.5112820451076214, 0.7028204936247605] val f1 [0.10486833120767887, 0.33729202242997974, 0.33180152682157665]
saved history to ../../data/results/history_dual_dualfacenet_h1000_st600_a400_g400_ed4_std2_ad2_gd2.csv
epoch 2
train loss 6.010139163659543 train accuracy [0.1972227088954984, 0.4966282066033811, 0.6537888062243559]
val loss 5.8998511387751655 val accuracy [0.20324785778155693, 0.5112820451076214, 0.7913675078978906] val f1 [0.

train loss 5.674396067249532 train accuracy [0.3000887261361492, 0.6244838639181487, 0.7838597857222265]
val loss 5.597300896277795 val accuracy [0.29786324271788966, 0.6915384393471938, 0.7982905736336341] val f1 [0.1486168526686155, 0.33689115139154285, 0.39832684856194717]
saved history to ../../data/results/history_dual_dualfacenet_h1000_st600_a400_g400_ed4_std2_ad2_gd2.csv
epoch 22
train loss 5.6758275226670865 train accuracy [0.29314994842422254, 0.6245755541081331, 0.7838804904295473]
val loss 5.5753099001370945 val accuracy [0.29931623488664627, 0.6799145157520587, 0.8783760437598596] val f1 [0.15752129170757073, 0.328329188319353, 0.43675888730929446]
saved history to ../../data/results/history_dual_dualfacenet_h1000_st600_a400_g400_ed4_std2_ad2_gd2.csv
epoch 23
train loss 5.665095679614009 train accuracy [0.2996302812683339, 0.6255752614566258, 0.7880331058891452]
val loss 5.555079496823824 val accuracy [0.3055555488054569, 0.6815384442989643, 0.8814529730723455] val f1 [0.15

train loss 5.53846381635082 train accuracy [0.3558739958977213, 0.6447589312280927, 0.8005944770209643]
val loss 5.542310641362117 val accuracy [0.3325640971844013, 0.6760683564039377, 0.8452136516571045] val f1 [0.17393367680219504, 0.32748963282658505, 0.4206194304502927]
saved history to ../../data/results/history_dual_dualfacenet_h1000_st600_a400_g400_ed4_std2_ad2_gd2.csv
epoch 43
train loss 5.554860309678681 train accuracy [0.35073054566675305, 0.645871031041048, 0.8016444639283784]
val loss 5.530455625974215 val accuracy [0.3279487078006451, 0.6776068210601807, 0.8745298981666565] val f1 [0.17288227952443636, 0.32813676045491147, 0.43512069949736965]
saved history to ../../data/results/history_dual_dualfacenet_h1000_st600_a400_g400_ed4_std2_ad2_gd2.csv
epoch 44
train loss 5.518769380997639 train accuracy [0.3675983335290636, 0.6502484302131497, 0.7999615292159878]
val loss 5.519388785729041 val accuracy [0.3317093929419151, 0.6899999930308416, 0.880683747621683] val f1 [0.1781532

In [None]:
# Scratch check: build a DualHistogramModel with its default arguments; as the
# cell's last expression, Jupyter displays its repr.
DualHistogramModel()

In [None]:
import torchvision