In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from scipy.ndimage import rotate
import Utils
from Utils import Constants
import cv2
from facenet_pytorch import InceptionResnetV1
from Models import *
from DataLoaders import *
import copy

In [3]:
# Cleaned label tables for the three splits (read in train, test, validation order);
# the bare expression at the end previews the training split.
train_labels, test_labels, validation_labels = (
    pd.read_csv(csv_path)
    for csv_path in (
        'train_data_clean.csv',
        'test_data_clean.csv',
        'validation_data_clean.csv',
    )
)
train_labels

Unnamed: 0,name,skin_tone,gender,age,is_face
0,TRAIN0001.png,0,0,1,False
1,TRAIN0002.png,5,1,0,True
2,TRAIN0005.png,1,1,0,False
3,TRAIN0007.png,1,0,1,True
4,TRAIN0009.png,7,0,1,False
...,...,...,...,...,...
6837,TRAIN9992.png,4,0,2,True
6838,TRAIN9993.png,1,1,1,True
6839,TRAIN9995.png,8,0,1,True
6840,TRAIN9998.png,4,1,1,False


In [5]:
def save_train_history(model,history,root=''):
    """Write a training history to CSV and return it as a DataFrame.

    Args:
        model: object whose ``get_identifier()`` names the run. The identifier
            is stored in a ``model`` column and embedded in the file name.
        history: records accepted by ``pd.DataFrame`` (e.g. a list of
            per-epoch dicts).
        root: path prefix; the file is written to
            ``root + 'results/history_<identifier>.csv'``.

    Returns:
        ``(df, path)``: the saved DataFrame and the CSV path it was written to.
    """
    identifier = model.get_identifier()
    history_df = pd.DataFrame(history).assign(model=identifier)
    out_path = ''.join([root, 'results/history_', identifier, '.csv'])
    history_df.to_csv(out_path, index=False)
    print('saved history to', out_path)
    return history_df, out_path

def train_model(model,
                train_df,
                validation_df,
                root,
                epochs=300,
                lr=.0001,
                batch_size=200,
                patience = 20,
                loss_weights = (5,2,1),
                save_path=None,
                histogram =False,
                upsample=False,
                random_upsample=True,
                softmax_upsample_weights=False,
                upsample_validation=False,
                random_upsample_validation=False,
                **kwargs,
               ):
    """Train a three-output model with weighted cross-entropy and early stopping.

    Each batch yields three sets of logits. Their cross-entropy losses are
    scaled by ``loss_weights[0..2]`` and summed. After every epoch the model is
    evaluated on the validation split and the history is written via
    ``save_train_history``. Training stops once the validation loss has not
    improved for more than ``patience`` consecutive epochs.

    Args:
        model: torch module exposing ``get_identifier()``. It is called with a
            single input tensor, or with ``[images, histograms]`` when
            ``histogram`` is True, and must return three logit tensors.
        train_df, validation_df: label tables handed to ``FaceGenerator``.
        root: path prefix. Checkpoints default to ``root + 'models/...'`` and
            history goes under ``root + 'results/'``.
        epochs: maximum number of epochs.
        lr: Adam learning rate.
        batch_size: batch size for both loaders.
        patience: number of non-improving epochs tolerated before stopping.
            When ``upsample`` is True it is reduced to ``int(patience/5) + 1``.
        loss_weights: three multipliers for the per-output losses.
        save_path: checkpoint path. When None it is derived from the model
            identifier and the upsampling flags.
        histogram: also feed a colour histogram of each batch to the model.
        upsample, random_upsample, softmax_upsample_weights: upsampling options
            for the training loader.
        upsample_validation, random_upsample_validation: upsampling options for
            the validation loader (``softmax_upsample_weights`` is shared).
        **kwargs: forwarded to both ``FaceGenerator`` instances.

    Returns:
        ``(model, hist, best_val_loss)``:
        - ``model`` has been restored to the weights with the lowest
          validation loss.
        - ``hist`` is the list of per-epoch history dicts.
        - ``best_val_loss`` is that lowest loss (``inf`` if no epoch ran or
          every validation loss was NaN).
    """
    if save_path is None:
        save_path = root + 'models/' + model.get_identifier()
        if random_upsample:
            save_path += '_rbalanced'
            if softmax_upsample_weights:
                save_path += 'soft'
        elif upsample:
            save_path += '_balanced'
    if upsample:
        # NOTE(review): presumably upsampled epochs are much longer, so fewer
        # stale epochs are tolerated. TODO confirm the intended ratio.
        patience = int(patience / 5) + 1
    train_loader = FaceGenerator(train_df, Constants.data_root,
                                 batch_size=batch_size,
                                 upsample=upsample,
                                 random_upsample=random_upsample,
                                 softmax=softmax_upsample_weights,
                                 **kwargs)
    validation_loader = FaceGenerator(validation_df, Constants.data_root,
                                      validation=True,
                                      batch_size=batch_size,
                                      upsample=upsample_validation,
                                      random_upsample=random_upsample_validation,
                                      softmax=softmax_upsample_weights,
                                      **kwargs)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    def format_y(y):
        # CrossEntropyLoss expects integer class indices on the model's device.
        return y.long().to(device)

    def get_loss(m, xin, ytrue):
        """Forward ``xin`` through ``m``; return (outputs, weighted summed loss)."""
        outputs = m(xin)
        losses = [loss_fn(ypred.float(), format_y(y)) for y, ypred in zip(ytrue, outputs)]
        total_losses = (torch.mul(loss_weights[0], losses[0])
                        + torch.mul(loss_weights[1], losses[1])
                        + torch.mul(loss_weights[2], losses[2]))
        return outputs, total_losses

    def train_epoch():
        """Run one optimisation pass; return (mean loss, mean accuracy per output)."""
        running_loss = 0
        running_accuracy = [0, 0, 0]
        count = 0
        for step, [x_batch, y_batch] in enumerate(train_loader):
            optimizer.zero_grad()
            if histogram:
                xh = torch_color_histogram(torch.clone(x_batch)).to(device)
                xxb = x_batch.to(device)
                # NOTE(review): gradients w.r.t. the inputs are never read; kept
                # as-is in case a model component relies on inputs requiring grad.
                xh.requires_grad_(True)
                xxb.requires_grad_(True)
                xb = [xxb, xh]
            else:
                xb = x_batch.to(device)
                xb.requires_grad_(True)
            outputs, total_losses = get_loss(model, xb, y_batch)
            total_losses.backward()
            optimizer.step()
            running_loss += total_losses.item()
            print('curr loss', total_losses.item(), 'step', step, ' | ', end='\r')
            count += 1
            with torch.no_grad():
                for head, (y, ypred) in enumerate(zip(y_batch, outputs)):
                    accuracy = Utils.categorical_accuracy(ypred.float(), format_y(y))
                    running_accuracy[head] += accuracy.item()
        return running_loss / count, [a / count for a in running_accuracy]

    def val_epoch():
        """Evaluate once; return (mean loss, mean accuracy, mean macro-F1 per output)."""
        running_loss = 0
        running_accuracy = [0, 0, 0]
        running_f1 = [0, 0, 0]
        count = 0
        # No autograd graph is needed for evaluation. Building one (as before)
        # only costs memory.
        with torch.no_grad():
            for x_batch, y_batch in validation_loader:
                if histogram:
                    xb = add_batch_histogram(x_batch, device=device, grad=False)
                else:
                    xb = x_batch.to(device)
                # Single forward pass; the previous extra `model(xb)` call was discarded.
                outputs, total_losses = get_loss(model, xb, y_batch)
                running_loss += total_losses.item()
                count += 1
                for head, (y, ypred) in enumerate(zip(y_batch, outputs)):
                    accuracy = Utils.categorical_accuracy(ypred.float(), format_y(y))
                    f1, precision, recall = Utils.macro_f1(torch.argmax(ypred.float(), axis=1), format_y(y))
                    running_f1[head] += f1.item()
                    running_accuracy[head] += accuracy.item()
        return running_loss / count, [a / count for a in running_accuracy], [f / count for f in running_f1]

    best_val_loss = float('inf')
    steps_since_improvement = 0
    hist = []
    best_weights = None
    print('model being saved to', save_path)
    for epoch in range(epochs):
        print('epoch', epoch)
        model.train(True)
        avg_loss, avg_acc = train_epoch()
        print('train loss', avg_loss, 'train accuracy', avg_acc)
        model.eval()
        val_loss, val_acc, val_f1 = val_epoch()
        print('val loss', val_loss, 'val accuracy', val_acc, 'val f1', val_f1)
        if val_loss < best_val_loss:
            # Deep-copy: state_dict() returns live references that later steps mutate.
            best_weights = copy.deepcopy(model.state_dict())
            best_val_loss = val_loss
            steps_since_improvement = 0
            # NOTE(review): no checkpoint is written for epochs 0 and 1, so if the
            # best epoch is one of those nothing reaches disk.
            if epoch > 1:
                torch.save(model, save_path)
                torch.save(model.state_dict(), save_path + '_states')
        else:
            steps_since_improvement += 1

        hist.append({
            'epoch': epoch,
            'train_loss': avg_loss,
            'train_acc': avg_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'val_f1': val_f1,
            'lr': lr,
            'loss_weights': '_'.join(str(w) for w in loss_weights),
        })
        save_train_history(model, hist, root=root)
        if steps_since_improvement > patience:
            break
    if best_weights is not None:
        # Return the best-validation weights, not the last (post-patience) epoch.
        model.load_state_dict(best_weights)
    return model, hist, best_val_loss


# models= [
#     TripletModel(encoder=TripletFacenetEncoder(base_model=ResNet18(),base_name='resnetbase',embedding_dropout=.1)),
#     TripletModel(encoder=TripletFacenetEncoder(base_model=ResNet18(),base_name='resnetbase',embedding_dropout=.5)),
#     TripletModel(encoder=TripletFacenetEncoder(base_model=DenseNet(),base_name='densenetbase',embedding_dropout=.5)),
#     TripletModel(encoder=TripletFacenetEncoder(base_model=MobileNet(),base_name='mobilenetbase',embedding_dropout=.5)),
#     TripletModel(encoder=TripletFacenetEncoder(base_name='facenetbase',embedding_dropout=.5)),
# ]

def get_model(file,model=None):
    """Load a saved model and its weights from ``Constants.model_folder``.

    Args:
        file: file name, relative to ``Constants.model_folder``, of the pickled
            model. The matching state dict is read from ``file + '_states'``.
        model: optional pre-built module to load the weights into. When None,
            the whole pickled module stored at ``file`` is loaded onto the CPU.

    Returns:
        The model, switched to eval mode.
    """
    path = Constants.model_folder + file
    # map_location='cpu' lets checkpoints written on a GPU load on CPU-only
    # machines. Without it torch.load fails before `.to('cpu')` can run.
    # torch.load unpickles arbitrary objects: only use it on trusted files.
    if model is None:
        model = torch.load(path, map_location='cpu').to(torch.device('cpu'))
    model.load_state_dict(torch.load(path + '_states', map_location='cpu'))
    model.eval()
    return model
# Fine-tune each pretrained dual-encoder on the labelled training split and
# collect (identifier, best validation loss) pairs in `res`.
pretrain_files = [
    'dualencoder_resnetbase_h400_ed3triplet_decoder__st600_a400_g400_std2_ad2_gd2_rbalanced_pretrain',
    'dualencoder_resnetbase_h400_ed5triplet_decoder__st600_a400_g400_std2_ad2_gd2_rbalanced_pretrain',
    'dualencoder_densenetbase_h400_ed5triplet_decoder__st600_a400_g400_std2_ad2_gd2_rbalanced_pretrain',
    'dualencoder_mobilenetbase_h400_ed5triplet_decoder__st600_a400_g400_std2_ad2_gd2_rbalanced_pretrain',
    'dualencoder_dualfacenet_h400_ed5triplet_decoder__st600_a400_g400_std2_ad2_gd2_rbalanced_pretrain',
    'dualencoder_resent_flat_h400_ed5triplet_decoder__st600_a400_g400_std2_ad2_gd2_rbalanced_pretrain',
    'dualencoder_resnet_flatbias_h400_ed5triplet_decoder__st600_a400_g400_std2_ad2_gd2_pretrain'
]
models = [get_model(filename) for filename in pretrain_files]

FINETUNE_BATCH_SIZE = 50
# NOTE(review): DenseNet runs use a much smaller batch, presumably to fit in
# GPU memory. TODO confirm.
DENSENET_BATCH_SIZE = 5

res = []
for model in models:
    batch_size = DENSENET_BATCH_SIZE if 'densenet' in model.get_identifier() else FINETUNE_BATCH_SIZE
    m, h, v = train_model(
        model,
        train_labels,
        validation_labels,
        Constants.data_root,
        batch_size=batch_size,
        lr=.0001,
    )
    entry = (m.get_identifier(), v)
    print('________________________________')
    print(entry)
    print('_______________________________')
    res.append(entry)
    # train_model moves the model to CUDA when available, and `models` keeps it
    # alive. Move it back so GPU memory does not pile up across runs.
    m.to(torch.device('cpu'))
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
res

(4869, 5)
(1209, 5)
model being saved to ../../data/models/dualencoder_resnetbase_h400_ed3triplet_decoder__st600_a400_g400_std2_ad2_gd2_rbalanced
epoch 0


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


train loss 14.800418629938243 train accuracy [0.17446830834509158, 0.46340492671849776, 0.5935875262532916]
val loss 14.452292938232421 val accuracy [0.2039999932050705, 0.508711097240448, 0.6179555368423462] val f1 [0.10830091074109077, 0.33553841829299924, 0.23745435059070588]
saved history to ../../data/results/history_dualencoder_resnetbase_h400_ed3triplet_decoder__st600_a400_g400_std2_ad2_gd2.csv
epoch 1
train loss 14.364390801410286 train accuracy [0.20752953129763507, 0.49687431326934267, 0.6096347877565695]
val loss 14.265449829101563 val accuracy [0.19804443895816803, 0.508711097240448, 0.6183110964298248] val f1 [0.11179661691188812, 0.33553841829299924, 0.24367527663707733]
saved history to ../../data/results/history_dualencoder_resnetbase_h400_ed3triplet_decoder__st600_a400_g400_std2_ad2_gd2.csv
epoch 2
train loss 14.260106612224968 train accuracy [0.2035016058659067, 0.49687431326934267, 0.6233834453991481]
val loss 14.21980312347412 val accuracy [0.19164443999528885, 0.50

[('dualencoder_resnetbase_h400_ed3triplet_decoder__st600_a400_g400_std2_ad2_gd2',
  13.458743553161622),
 ('dualencoder_resnetbase_h400_ed5triplet_decoder__st600_a400_g400_std2_ad2_gd2',
  13.504654159545899),
 ('dualencoder_densenetbase_h400_ed5triplet_decoder__st600_a400_g400_std2_ad2_gd2',
  13.629507064819336),
 ('dualencoder_mobilenetbase_h400_ed5triplet_decoder__st600_a400_g400_std2_ad2_gd2',
  13.406404075622559),
 ('dualencoder_dualfacenet_h400_ed5triplet_decoder__st600_a400_g400_std2_ad2_gd2',
  13.069874420166016),
 ('dualencoder_resent_flat_h400_ed5triplet_decoder__st600_a400_g400_std2_ad2_gd2',
  13.221563568115235),
 ('dualencoder_resnet_flatbias_h400_ed5triplet_decoder__st600_a400_g400_std2_ad2_gd2',
  13.33204002380371)]

In [None]:
# Train the same architectures from scratch (no triplet pretraining) for
# comparison. Results are appended to `res` from the pretrained-model cell.
umodels = [
    TripletModel(encoder=TripletFacenetEncoder(base_model=ResNet18(),base_name='resnetuntrained',embedding_dropout=.1)),
    TripletModel(encoder=TripletFacenetEncoder(base_model=ResNet18(),base_name='resnetuntrained',embedding_dropout=.5)),
    TripletModel(encoder=TripletFacenetEncoder(base_model=ResNet18(),base_name='resnetuntrained',embedding_dropout=.7)),
    TripletModel(encoder=TripletFacenetEncoder(base_model=MobileNet(),base_name='mobilenetuntrained',embedding_dropout=.5)),
    TripletModel(encoder=TripletFacenetEncoder(base_name='facenetuntrained',embedding_dropout=.5)),
]

UNTRAINED_BATCH_SIZE = 50

for model in umodels:
    m, h, v = train_model(
        model,
        train_labels,
        validation_labels,
        Constants.data_root,
        batch_size=UNTRAINED_BATCH_SIZE,
        lr=.0001,
    )
    entry = (m.get_identifier(), v)
    print('________________________________')
    print(entry)
    print('_______________________________')
    res.append(entry)
    # train_model moves the model to CUDA when available, and `umodels` keeps it
    # alive. Move it back so GPU memory does not pile up across runs.
    m.to(torch.device('cpu'))
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
res

(4869, 5)
(1209, 5)
model being saved to ../../data/models/dualencoder_resnetuntrained_h400_ed1triplet_decoder__st600_a400_g400_std2_ad2_gd2_rbalanced
epoch 0
train loss 14.396606221490977 train accuracy [0.2007733561022549, 0.4967239398737343, 0.7054242607282133]
val loss 13.776241035461426 val accuracy [0.2383999916911125, 0.6375111031532288, 0.8346666431427002] val f1 [0.12766882598400117, 0.3003450047969818, 0.4127447283267975]
saved history to ../../data/results/history_dualencoder_resnetuntrained_h400_ed1triplet_decoder__st600_a400_g400_std2_ad2_gd2.csv
epoch 1
train loss 13.824594828547264 train accuracy [0.2515144944190979, 0.5931793603361869, 0.7640493907490555]
val loss 13.529690132141113 val accuracy [0.27484443873167036, 0.6495111036300659, 0.8555555272102356] val f1 [0.1499830374121666, 0.3053774571418762, 0.4249933862686157]
saved history to ../../data/results/history_dualencoder_resnetuntrained_h400_ed1triplet_decoder__st600_a400_g400_std2_ad2_gd2.csv
epoch 2
train loss 

In [None]:
import torchvision