In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from scipy.ndimage import rotate
import Utils
from Utils import Constants
import cv2
from facenet_pytorch import InceptionResnetV1
from Models import *
from DataLoaders import *

SyntaxError: invalid syntax (DataLoaders.py, line 299)

In [None]:
# Load the train / validation label tables for the augmented, class-balanced
# "dualhistogram" split (read in that order) and display the training table.
train_labels, validation_labels = (
    pd.read_csv(csv_path)
    for csv_path in (
        'train_data_augmented_balanceddualhistogram.csv',
        'validation_data_augmented_balanceddualhistogram.csv',
    )
)
train_labels

In [None]:
class TripletFacenetModel(FacenetModel):
    """FacenetModel variant whose forward pass serves a triplet loss.

    ``forward`` takes a sequence of exactly three image batches. Each one is
    mapped through the shared embedding trunk (``embed``). The three
    classification heads (``st_layers``, ``age_layers``, ``gender_layers``,
    applied via ``apply_layers``) are run only on the embedding of the FIRST
    batch.

    Returns a 2-tuple::

        ([first_emb, second_emb, third_emb], [st_pred, age_pred, gender_pred])

    ``base_model``, ``embedding_dropout``, ``hidden_layers``, the head layer
    lists, ``apply_layers`` and ``name_string`` all come from ``FacenetModel``
    (not visible here).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Prefix the inherited name so triplet runs are labelled distinctly.
        self.name_string = 'triplet_' + self.name_string

    def embed(self, x):
        """Embed one image batch: backbone -> embedding dropout -> hidden layers."""
        features = self.embedding_dropout(self.base_model(x))
        for hidden in self.hidden_layers:
            features = hidden(features)
        return features

    def forward(self, x):
        """Embed all three batches, then classify from the first embedding only."""
        first_batch, second_batch, third_batch = x
        # Embed in the original order (first, second, third) so any stochastic
        # layers (dropout) consume randomness identically.
        first_emb = self.embed(first_batch)
        second_emb = self.embed(second_batch)
        third_emb = self.embed(third_batch)

        predictions = [
            self.apply_layers(first_emb, head)
            for head in (self.st_layers, self.age_layers, self.gender_layers)
        ]
        return [first_emb, second_emb, third_emb], predictions
    
# Construct the model once with default arguments; the cell output shows the
# module's repr, acting as a quick check that the subclass builds.
TripletFacenetModel()

In [None]:
def categorical_accuracy(ypred, y):
    """Fraction of rows whose arg-max score matches the target class index.

    Args:
        ypred: score tensor with classes along dim 1 (the same layout that is
            fed to ``CrossEntropyLoss`` during training).
        y: target class indices; cast to ``long`` before comparing.

    Returns:
        0-dim float tensor holding the batch-mean accuracy.
    """
    hits = y.long() == ypred.argmax(dim=1)
    return hits.float().mean()

def save_train_history(model, history, root=''):
    """Write a model's per-epoch training history to CSV.

    Args:
        model: object exposing ``get_identifier()``; the identifier names the
            output file and fills an added ``model`` column.
        history: anything ``pd.DataFrame`` accepts (here: a list of per-epoch
            dicts).
        root: prefix concatenated verbatim (no separator inserted) in front of
            ``'results/history_<identifier>.csv'``.

    Returns:
        ``(history_df, csv_path)``: the DataFrame that was written and its path.
    """
    identifier = model.get_identifier()
    history_df = pd.DataFrame(history).assign(model=identifier)
    csv_path = ''.join((root, 'results/history_', identifier, '.csv'))
    history_df.to_csv(csv_path, index=False)
    print('saved history to', csv_path)
    return history_df, csv_path

def train_model(model,
                train_df,
                validation_df,
                root,
                epochs=300,
                lr=.001,
                batch_size=200,
                patience=20,
                loss_weights=(2, 1, .5),
                save_path=None,
                histogram=False,
                upsample=True,
                embedding_loss_weight=1,
                classification_loss_weight=1,
                **kwargs,
               ):
    """Train a triplet model with early stopping on validation loss.

    Per-batch loss::

        classification_loss_weight * sum_k(loss_weights[k] * CE_k)
        + embedding_loss_weight * TripletMarginLoss(e, e_pos, e_neg)

    where ``CE_k`` is the cross-entropy of head ``k`` (st, age, gender). The
    three embeddings are the first, second and third embedding outputs of the
    model, in that order.

    Args:
        model: module whose forward takes three image batches and returns
            ``([e, e_pos, e_neg], [st_logits, age_logits, gender_logits])``,
            as ``TripletFacenetModel.forward`` does.
        train_df, validation_df: label tables passed to ``TripletFaceGenerator``.
        root: path prefix (concatenated verbatim) for ``models/`` and the
            ``results/`` history written by ``save_train_history``.
        epochs: maximum number of epochs.
        lr: Adam learning rate.
        batch_size: batch size for both generators.
        patience: epochs without validation improvement tolerated before
            stopping. When ``upsample`` is true it is reduced to
            ``int(patience / 5) + 1``.
        loss_weights: weights of the st, age and gender cross-entropy terms.
        save_path: file the best model is ``torch.save``d to. When ``None`` it
            is ``root + 'models/' + model.get_identifier()``, with
            ``'_balanced'`` appended if ``upsample``.
        histogram: accepted for call compatibility; not used by this function.
        upsample: forwarded to both generators; also shortens ``patience``.
        embedding_loss_weight: multiplier on the triplet loss.
        classification_loss_weight: multiplier on the weighted classification loss.
        **kwargs: forwarded to both ``TripletFaceGenerator`` instances.

    Returns:
        ``(model, hist)``: the model with its best-validation weights restored,
        and the list of per-epoch history dicts.
    """
    import copy

    if save_path is None:
        save_path = root + 'models/' + model.get_identifier()
        if upsample:
            save_path += '_balanced'
    if upsample:
        patience = int(patience / 5) + 1
    train_loader = TripletFaceGenerator(train_df, Constants.data_root, batch_size=batch_size, upsample=upsample, **kwargs)
    validation_loader = TripletFaceGenerator(validation_df, Constants.data_root, validation=True, batch_size=batch_size, upsample=upsample, **kwargs)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    loss_fn = torch.nn.CrossEntropyLoss()
    triplet_loss = torch.nn.TripletMarginLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    def format_y(y):
        """Move a label batch to ``device`` as integer class indices."""
        return y.long().to(device)

    def format_batch(inputs, grad=True):
        """Move each input batch to ``device`` and set its ``requires_grad`` flag."""
        xb = []
        for xin in inputs:
            xin = xin.to(device)
            xin.requires_grad_(grad)
            xb.append(xin)
        return xb

    def get_classification_loss(outputs, ytrue):
        """Weighted sum of the per-head cross-entropy losses."""
        losses = [loss_fn(ypred.float(), format_y(y)) for y, ypred in zip(ytrue, outputs)]
        return sum(weight * loss for weight, loss in zip(loss_weights, losses))

    def get_loss(m, xbatch, ybatch, grad=True):
        """Forward one triplet batch; return (head predictions, combined loss)."""
        outputs = m(format_batch(xbatch, grad=grad))
        # The model returns a 2-tuple: ([3 embeddings], [3 head predictions]).
        (e, ep, en), prediction = outputs
        embed_loss = triplet_loss(e, ep, en)
        closs = get_classification_loss(prediction, ybatch)
        loss = classification_loss_weight * closs + embedding_loss_weight * embed_loss
        return prediction, loss

    def train_epoch():
        """One optimisation pass; returns (mean loss, per-head mean accuracy)."""
        running_loss = 0
        running_accuracy = [0, 0, 0]
        count = 0
        for step, (x_batch, y_batch) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs, total_losses = get_loss(model, x_batch, y_batch)
            total_losses.backward()
            optimizer.step()
            running_loss += total_losses.item()
            print('curr loss', total_losses.item(), 'step', step, ' | ', end='\r')
            count += 1
            with torch.no_grad():
                for head_idx, (y, ypred) in enumerate(zip(y_batch, outputs)):
                    accuracy = categorical_accuracy(ypred.float(), format_y(y))
                    running_accuracy[head_idx] += accuracy.item()
        return running_loss / count, [a / count for a in running_accuracy]

    def val_epoch():
        """Evaluation pass without autograd; returns (mean loss, per-head mean accuracy)."""
        running_loss = 0
        running_accuracy = [0, 0, 0]
        count = 0
        # No graphs are needed for validation; skipping them saves memory/compute.
        with torch.no_grad():
            for x_batch, y_batch in validation_loader:
                outputs, total_losses = get_loss(model, x_batch, y_batch, grad=False)
                running_loss += total_losses.item()
                count += 1
                for head_idx, (y, ypred) in enumerate(zip(y_batch, outputs)):
                    accuracy = categorical_accuracy(ypred.float(), format_y(y))
                    running_accuracy[head_idx] += accuracy.item()
        return running_loss / count, [a / count for a in running_accuracy]

    best_val_loss = float('inf')
    steps_since_improvement = 0
    hist = []
    # state_dict() returns live references; deep-copy so the snapshot does not
    # silently follow later optimizer updates.
    best_weights = copy.deepcopy(model.state_dict())
    print('model being saved to', save_path)
    for epoch in range(epochs):
        print('epoch', epoch)
        model.train(True)
        avg_loss, avg_acc = train_epoch()
        print('train loss', avg_loss, 'train accuracy', avg_acc)
        model.train(False)
        val_loss, val_acc = val_epoch()
        print('val loss', val_loss, 'val accuracy', val_acc)
        if best_val_loss > val_loss:
            torch.save(model, save_path)
            best_weights = copy.deepcopy(model.state_dict())
            best_val_loss = val_loss
            steps_since_improvement = 0
        else:
            steps_since_improvement += 1

        hist_entry = {
            'epoch': epoch,
            'train_loss': avg_loss,
            'train_acc': avg_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'lr': lr,
            'loss_weights': '_'.join([str(l) for l in loss_weights])
        }
        hist.append(hist_entry)
        save_train_history(model, hist, root=root)
        if steps_since_improvement > patience:
            break
    # Return the best checkpoint (matching what was saved to disk), not the
    # weights from the last, possibly worse, epoch.
    model.load_state_dict(best_weights)
    return model, hist

# Train the triplet variant. The returned model is deleted right away and only
# the per-epoch history list is kept and displayed.
m, h = train_model(
    model=TripletFacenetModel(),
    train_df=train_labels,
    validation_df=validation_labels,
    root=Constants.data_root,
    batch_size=100,
    histogram=False,
    lr=.0001,
)
del m
h