In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from scipy.ndimage import rotate
import Utils
from Utils import Constants
import cv2
from facenet_pytorch import InceptionResnetV1
import copy
from Models import *
from DataLoaders import *

In [3]:
train_labels = pd.read_csv('train_data_augmented_balanceddual.csv')
validation_labels = pd.read_csv('validation_data_augmented_balanceddual.csv')
train_labels.shape, validation_labels.shape

((6842, 166), (1711, 166))

In [4]:
def get_model(file,model=None):
    if model is None:
        model = torch.load(Constants.model_folder + file).to(torch.device('cpu'))
    model.load_state_dict(torch.load(Constants.model_folder + file + '_states'))
    model.eval()
    return model


In [None]:
def save_train_history(model,history,root=''):
    model_name = model.get_identifier()
    
    df = pd.DataFrame(history)
    df['model'] = model_name
    string = root + 'results/history_' + model_name + '3.csv'
    df.to_csv(string,index=False)
    print('saved history to',string)
    return df, string

def train_model(model,
                train_df,
                validation_df,
                root,
                epochs=300,
                lr=.001,
                batch_size=200,
                patience = 20,
                loss_weights = [2,1,.5],
                save_path=None,
                histogram =False,
                upsample=False,
                random_upsample=True,
                softmax_upsample_weights=False,
                upsample_validation=False,
                random_upsample_validation=False,
                embedding_loss_weight = 1,
                classification_loss_weight = 1,
                starting_loss=100000,
                bias_weight_transform=None,
                skintone_regression=False,
                **kwargs,
               ):
    pretraining = classification_loss_weight <= .00001
    if save_path is None:
        save_path = root + 'models/'+ model.get_identifier()
        if random_upsample:
            
            save_path += '_rbalanced'
            if softmax_upsample_weights:
                save_path += 'soft'
        elif upsample:
            save_path += '_balanced'
        if pretraining:
            save_path += '_pretrain'
    if upsample:
        patience = int(patience/5) + 1
        
    if bias_weight_transform is not None:
        weight_cols = [c for c in train_df.columns if '_bias' in c]
        tdf = train_df.copy()
        vdf = validation_df.copy()
        for col in weight_cols:
            tdf[col] = tdf[col].apply(bias_weight_transform)
            vdf[col] = vdf[col].apply(bias_weight_transform)
        train_df = tdf
        validation_df = vdf
    train_loader = TripletFaceGenerator(train_df,Constants.data_root,
                                 batch_size=batch_size,
                                 upsample=upsample,
                                 random_upsample=random_upsample,
                                 softmax=softmax_upsample_weights,
                                 **kwargs)
    validation_loader = TripletFaceGenerator(validation_df,Constants.data_root,
                                      validation=True,
                                      batch_size=batch_size,
                                      upsample=upsample_validation,
                                     random_upsample=random_upsample_validation,
                                     softmax=softmax_upsample_weights,
                                      **kwargs)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train(True)
    
    cel = torch.nn.CrossEntropyLoss()
    format_y = lambda y: y.to(device)
    regression_loss = torch.nn.MSELoss()
    triplet_loss = torch.nn.TripletMarginLoss()
#     embedding_optimizer = torch.optim.Adam(model.encoder.parameters(), lr=lr)
#     decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=lr)
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    
    
    def format_batch(inputs,grad=True):
        xb = []
        for xin in inputs:
            xin = xin.to(device)
            xin.requires_grad_(grad)
            xb.append(xin)
        return xb
    
    def embedding_step(m,xbatch): 
        base = m.encoder(xbatch[0])
        positive = m.encoder(xbatch[1])
        negative = m.encoder(xbatch[2])
        loss = triplet_loss(base,positive,negative)
        loss = torch.mul(loss,embedding_loss_weight)
        return base,loss
    
    
    def classifier_step(m,embedding,ytrue):
        outputs = m.decoder(embedding)
        if not skintone_regression:
            losses = [cel(ypred.float(),format_y(y.long())) for y,ypred in zip(ytrue,outputs)]
            l1 = torch.mul(loss_weights[0],losses[0])
            l2 =  torch.mul(loss_weights[1],losses[1])
            l3 =  torch.mul(loss_weights[2],losses[2])
        else:
            l1 = torch.mul( regression_loss(outputs[0].float().view(-1),format_y(ytrue[0].float())), loss_weights[0])
            l2 = torch.mul( cel(outputs[1].float(),format_y(ytrue[1].long())), loss_weights[1])
            l3 = torch.mul(cel(outputs[2].float(),format_y(ytrue[2].long())), loss_weights[2])
        total_losses = l1 + l2 + l3
        total_loss = torch.mul(total_losses,classification_loss_weight)
        return outputs,total_losses
        
    def train_epoch(model):
        running_loss = 0
        running_embed_loss = 0
        running_accuracy = [0,0,0]
        curr_loss = 0
        count = 0
        for i, [x_batch, y_batch] in enumerate(train_loader):
            x_batch = format_batch(x_batch)
            optimizer.zero_grad()
            embedding,embedding_loss = embedding_step(model, x_batch)
            if pretraining:
                classification_loss = embedding_loss - embedding_loss
            else:
                outputs, classification_loss = classifier_step(model,embedding,y_batch)
            total_loss = classification_loss + embedding_loss
            total_loss.backward()
            optimizer.step()
            running_loss += classification_loss.item()
            running_embed_loss += embedding_loss.item()
            print('curr loss class',classification_loss.item(),'embed', embedding_loss.item(), 'step',i,' | ',end='\r')
            count += 1
            if not pretraining:
                with torch.no_grad():
                    for i,(y,ypred) in enumerate(zip(y_batch,outputs)):
                        accuracy = Utils.categorical_accuracy(ypred.float(),format_y(y))
                        running_accuracy[i] += accuracy.item()
        return running_loss/count,running_embed_loss/count, [a/count for a in running_accuracy]
    
    def val_epoch(model):
        running_loss = 0
        running_embed_loss = 0
        running_accuracy = [0,0,0]
        running_f1 = [0,0,0]
        count = 0
        with torch.no_grad():
            for i, [x_batch, y_batch] in enumerate(validation_loader):
                x_batch = format_batch(x_batch,grad=False)
                embedding,embedding_loss = embedding_step(model, x_batch)
                if pretraining:
                    classification_loss = embedding_loss - embedding_loss
                else:
                    outputs, classification_loss = classifier_step(model,embedding,y_batch)
                running_loss += classification_loss.item()
                running_embed_loss += embedding_loss.item()
                count += 1
                if not pretraining:
                    for i,(y,ypred) in enumerate(zip(y_batch, outputs)):
                        accuracy = Utils.categorical_accuracy(ypred.float(),format_y(y))
                        f1, precision, recall = Utils.macro_f1(torch.argmax(ypred.float(),axis=1),format_y(y))
                        running_accuracy[i] += accuracy.item()
                        running_f1[i] += f1.item()
        return running_loss/count,running_embed_loss/count, [a/count for a in running_accuracy], [f/count for f in running_f1]
    shorten = lambda array: [np.round(a, 3) for a in array]
    
    best_val_loss = starting_loss
    steps_since_improvement = 0
    hist = []
    best_weights = model.state_dict()
    print('model being saved to',save_path)
    for epoch in range(epochs):
        print('epoch',epoch)
        model.train()
        avg_loss,avg_embed_loss, avg_acc = train_epoch(model)
        print('train loss', avg_loss,'train embed loss',avg_embed_loss, 'train accuracy', shorten(avg_acc))
        model.eval()
        val_loss,val_embed_loss, val_acc, val_f1 = val_epoch(model)
        if pretraining:
            val_loss = val_embed_loss
        print('val loss', val_loss, 'val_embed_loss', val_embed_loss, 
              'val accuracy', shorten(val_acc), 'val f1', shorten(val_f1))
        #don't save immediately in case I cancel training
        if best_val_loss > val_loss and epoch > 1:
            torch.save(model,save_path)
            torch.save(model.state_dict(),save_path+'_states')
            print('saving model')
            best_weights = copy.deepcopy(model.state_dict())
            best_val_loss = val_loss
            steps_since_improvement = 0
        else:
            steps_since_improvement += 1
        
        hist_entry = {
            'epoch': epoch,
            'train_loss': avg_loss,
            'train_acc':avg_acc,
            'val_loss':val_loss,
            'val_acc': val_acc,
            'lr': lr,
            'loss_weights': '_'.join([str(l) for l in loss_weights])
        }
        hist.append(hist_entry)
        save_train_history(model,hist,root=root)
        if steps_since_improvement > patience:
            break
    model.load_state_dict(best_weights)
    return model,hist

    
model = TripletModel2(
    encoder=TripletFacenetEncoder(base_model=ResNet18(),base_name='resnet_flatbias',embedding_dropout=.5),
)

m,h = train_model(
    model,
    train_labels,
    validation_labels,
    Constants.data_root,
    batch_size=50,
    embedding_loss_weight=1,
    classification_loss_weight=.5,
    lr=.0001,
    skintone_regression=True,
)
del m
h


(4869, 167)
(1209, 167)
model being saved to ../../data/models/dualencoder_resnet_flatbias_h400_ed5decodewithrregression__st400_a400_g400_std2_ad2_gd2_rbalanced
epoch 0


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
  return F.mse_loss(input, target, reduction=self.reduction)


curr loss class 22.432090759277344 embed 1.3817708492279053 step 96  | 

  return F.mse_loss(input, target, reduction=self.reduction)


train loss 21.748472165088263 train embed loss 1.6540638755778878 train accuracy [0.056, 0.42, 0.606]
val loss 22.112924270629883 val_embed_loss 1.2606768488883973 val accuracy [0.068, 0.504, 0.708] val f1 [0.062, 0.229, 0.332]
saved history to ../../data/results/history_dualencoder_resnet_flatbias_h400_ed5decodewithrregression__st400_a400_g400_std2_ad2_gd23.csv
epoch 1


  return F.mse_loss(input, target, reduction=self.reduction)


train loss 21.646439153320934 train embed loss 1.4623081854411535 train accuracy [0.056, 0.5, 0.672]
val loss 22.01333869934082 val_embed_loss 1.3007596397399903 val accuracy [0.068, 0.521, 0.747] val f1 [0.062, 0.22, 0.361]
saved history to ../../data/results/history_dualencoder_resnet_flatbias_h400_ed5decodewithrregression__st400_a400_g400_std2_ad2_gd23.csv
epoch 2
curr loss class 19.13224983215332 embed 1.189424753189087 step 78  | | 