In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from scipy.ndimage import rotate
import Utils
from Utils import Constants
import cv2
from facenet_pytorch import InceptionResnetV1
from Models import *
from DataLoaders import *

In [None]:
# Per-split label tables; the file names indicate an augmented, balanced
# ("dualhistogram") variant of the dataset.
train_labels, validation_labels = [
    pd.read_csv(csv_path)
    for csv_path in (
        'train_data_augmented_balanceddualhistogram.csv',
        'validation_data_augmented_balanceddualhistogram.csv',
    )
]
train_labels

In [None]:
class TripletFacenetEncoder(BaseModel):
    """Dual-backbone face encoder: the embedding half of a triplet model.

    One image batch is fed through two backbones (``base_model`` and
    ``feature_extractor``). Their outputs are concatenated, passed through
    dropout and a stack of Linear+ReLU layers, then batch-normalised. The
    result is the embedding the triplet loss is computed on.

    Args:
        base_model: trainable backbone. Defaults to a VGGFace2-pretrained
            ``InceptionResnetV1``. A supplied model must expose
            ``get_identifier()`` and ``logits.in_features``.
        feature_extractor: second backbone, trainable only if ``fine_tune``.
            Defaults to a VGGFace2-pretrained ``InceptionResnetV1``. It must
            expose ``logits.in_features``.
        hidden_dims: output sizes of the hidden Linear layers. The last entry
            is the embedding size. If empty, the embedding is the concatenated
            backbone output.
        embedding_dropout: dropout probability on the concatenated backbone
            outputs.
        base_name: ignored. It is overwritten by ``'dualfacenet'`` or
            ``base_model.get_identifier()``; kept for call compatibility.
        fine_tune: whether ``feature_extractor`` parameters receive gradients.
        **kwargs: accepted and ignored.
    """
    def __init__(self,
                 base_model = None,
                 feature_extractor = None,
                 hidden_dims = [400],
                 embedding_dropout=.3,
                 base_name='model',
                 fine_tune=False,
                 **kwargs):
        super(TripletFacenetEncoder,self).__init__()

        if base_model is None:
            base_model = InceptionResnetV1(pretrained='vggface2')
            base_name = 'dualfacenet'
        else:
            base_name = base_model.get_identifier()

        if feature_extractor is None:
            feature_extractor = InceptionResnetV1(pretrained='vggface2')
        # Set the secondary backbone's trainability first, then force the main
        # backbone trainable, so the main backbone wins if both are one object.
        for param in feature_extractor.parameters():
            param.requires_grad = fine_tune
        for param in base_model.parameters():
            param.requires_grad = True

        self.base_model = base_model
        self.feature_extractor = feature_extractor

        self.embedding_dropout = torch.nn.Dropout(p=embedding_dropout)
        # NOTE(review): assumes each backbone's forward() output width equals
        # its logits.in_features (appears to hold for InceptionResnetV1 in
        # embedding mode) — confirm for custom backbones.
        curr_dim = base_model.logits.in_features + feature_extractor.logits.in_features
        hidden_layers = []
        for size in hidden_dims:
            hidden_layers.append(torch.nn.Linear(curr_dim, size))
            hidden_layers.append(torch.nn.ReLU())
            curr_dim = size
        self.hidden_layers = torch.nn.ModuleList(hidden_layers)

        # Equals hidden_dims[-1] when hidden_dims is non-empty. Otherwise it is
        # the concatenated backbone width, instead of raising IndexError.
        self.embedding_size = curr_dim
        self.norm = torch.nn.BatchNorm1d(self.embedding_size)

        def add_dims(n,dims,prefix):
            # Append '_<prefix><dim>' for each layer size.
            for dim in dims:
                n += '_'+prefix+str(dim)
            return n

        name_string = 'dualencoder_' + base_name
        name_string = add_dims(name_string,hidden_dims,'h')
        name_string += '_ed' + str(embedding_dropout).replace('0.','')
        # Previously computed but never stored. The other models in this file
        # set self.name_string, and TripletModel calls encoder.get_identifier(),
        # which presumably reads it.
        self.name_string = name_string

    def forward(self,x):
        """Embed one image batch: both backbones -> concat -> dropout -> MLP -> BatchNorm."""
        xb = self.base_model(x)
        xf = self.feature_extractor(x)
        x = torch.cat((xb,xf),dim=-1)
        x = self.embedding_dropout(x)
        for layer in self.hidden_layers:
            x = layer(x)
        x = self.norm(x)
        return x
    
class TripletFacenetClassifier(BaseModel):
    """Multi-head classifier applied to encoder embeddings.

    Three independent heads are built with ``self.make_output`` and applied
    with ``self.apply_layers``. Both methods are inherited from ``BaseModel``,
    which is defined outside this notebook. The heads are:

    - an 'st' head with 10 outputs,
    - an age head with 4 outputs,
    - a gender head with 2 outputs.

    Args:
        input_dim: width of the incoming embedding.
        st_dims, age_dims, gender_dims: hidden layer sizes for each head.
        st_dropout, age_dropout, gender_dropout: dropout for each head.
        **kwargs: accepted and ignored.
    """
    def __init__(self,
                 input_dim,
                 st_dims = [600],
                 age_dims = [400],
                 gender_dims = [400],
                 st_dropout = .2,
                 age_dropout = .2,
                 gender_dropout = .2,
                 **kwargs):
        super(TripletFacenetClassifier,self).__init__()

        self.st_layers = self.make_output(input_dim,st_dims,10,st_dropout)
        self.age_layers = self.make_output(input_dim,age_dims,4,age_dropout)
        self.gender_layers = self.make_output(input_dim,gender_dims,2,gender_dropout)

        def add_dims(n,dims,prefix):
            # Append '_<prefix><dim>' for each layer size.
            for dim in dims:
                n += '_'+prefix+str(dim)
            return n

        name_string = 'triplet_decoder_'
        name_string = add_dims(name_string,st_dims,'st')

        name_string = add_dims(name_string,age_dims,'a')
        name_string = add_dims(name_string,gender_dims,'g')
        name_string += '_std' + str(st_dropout).replace('0.','')
        name_string += '_ad' + str(age_dropout).replace('0.','')
        name_string += '_gd' + str(gender_dropout).replace('0.','')
        self.name_string = name_string

    def embed(self,x):
        """Not supported: this module consumes embeddings, it does not produce them.

        The previous body was copied from the encoder. It referenced
        ``base_model``, ``embedding_dropout`` and ``hidden_layers``, none of
        which are set anywhere in this class, so it could only fail with an
        obscure AttributeError. Compute embeddings with the encoder instead.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            'TripletFacenetClassifier takes embeddings as input; '
            'use the encoder (e.g. TripletFacenetEncoder) to embed images')

    def forward(self,x):
        """Return ``[st_logits, age_logits, gender_logits]`` for an embedding batch."""
        x_st = self.apply_layers(x,self.st_layers)
        x_age = self.apply_layers(x,self.age_layers)
        x_gender = self.apply_layers(x,self.gender_layers)
        return [x_st,x_age,x_gender]
    
class TripletModel(BaseModel):
    """Encoder followed by decoder: ``forward(x)`` returns ``decoder(encoder(x))``.

    Args:
        encoder: embedding model. Defaults to ``TripletFacenetEncoder()``.
        decoder: model applied to the encoder output. Defaults to a
            ``TripletFacenetClassifier`` sized to ``encoder.embedding_size``.
    """

    def __init__(self, encoder=None, decoder=None):
        super(TripletModel, self).__init__()
        enc = TripletFacenetEncoder() if encoder is None else encoder
        dec = TripletFacenetClassifier(enc.embedding_size) if decoder is None else decoder
        self.encoder = enc
        self.decoder = dec
        # The identifier is the two parts' identifiers joined without a separator.
        self.name_string = enc.get_identifier() + dec.get_identifier()

    def forward(self, x):
        embedding = self.encoder(x)
        return self.decoder(embedding)
    
# Smoke test: build the default model (TripletFacenetEncoder creates two
# VGGFace2-pretrained InceptionResnetV1 backbones) and display its module repr.
TripletModel()

In [None]:
def categorical_accuracy(ypred, y):
    """Fraction of rows whose highest-scoring class equals the target.

    Args:
        ypred: score tensor with classes along dim 1 (as fed to the loss).
        y: class-index targets; cast to long before comparison.

    Returns:
        0-d float tensor with the mean accuracy.
    """
    hits = ypred.argmax(dim=1).long().eq(y.long())
    return hits.float().mean()

def save_train_history(model,history,root=''):
    """Write the training history to ``<root>results/history_<identifier>.csv``.

    Args:
        model: any object with ``get_identifier()``. Its identifier names the
            file and fills the added ``model`` column.
        history: list of per-epoch dicts, one row each. List-valued entries
            (e.g. per-head accuracies) are written by pandas as their string repr.
        root: path prefix, concatenated as-is, so include a trailing separator.

    Returns:
        ``(df, path)``: the DataFrame that was written and the CSV path.
    """
    model_name = model.get_identifier()

    df = pd.DataFrame(history)
    df['model'] = model_name
    path = root + 'results/history_' + model_name + '.csv'
    # Create results/ on first use. Otherwise the first save (after epoch 0 of
    # a training run) fails with FileNotFoundError.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    df.to_csv(path,index=False)
    print('saved history to',path)
    return df, path

def train_model(model,
                train_df,
                validation_df,
                root,
                epochs=300,
                lr=.001,
                batch_size=200,
                patience = 20,
                loss_weights = [2,1,.5],
                save_path=None,
                histogram =False,
                upsample=True,
                embedding_loss_weight = 1,
                classification_loss_weight = 1,
                **kwargs,
               ):
    """Jointly train a model's encoder (triplet loss) and decoder (cross-entropy).

    Each step proceeds as follows:

    1. Encode the three image batches yielded by ``TripletFaceGenerator``
       (anchor, positive, negative).
    2. Compute ``TripletMarginLoss`` on the three embeddings.
    3. Classify the anchor embedding with the decoder.
    4. Back-propagate the sum of the weighted triplet loss and the weighted
       classification loss through a single Adam optimizer.

    Early stopping tracks the validation classification loss. From epoch 2 on,
    every improvement saves the whole model with ``torch.save``.

    Args:
        model: object with ``encoder``, ``decoder`` and ``get_identifier()``
            (e.g. ``TripletModel``).
        train_df, validation_df: label tables passed to ``TripletFaceGenerator``.
        root: path prefix for the default ``models/`` checkpoint and the
            ``results/`` history CSV.
        epochs: maximum number of epochs.
        lr: Adam learning rate.
        batch_size: loader batch size.
        patience: epochs without improvement before stopping. When
            ``upsample`` is True it is reduced to ``int(patience / 5) + 1``.
        loss_weights: weights for the three decoder heads' cross-entropy losses.
        save_path: checkpoint path. Defaults to
            ``root + 'models/' + model.get_identifier()``, plus ``'_balanced'``
            when upsampling.
        histogram: unused; kept for call compatibility.
        upsample: forwarded to both loaders.
        embedding_loss_weight: multiplier on the triplet loss.
        classification_loss_weight: multiplier on the weighted sum of head losses.
        **kwargs: forwarded to both ``TripletFaceGenerator`` instances.

    Returns:
        ``(model, hist)``. ``model`` carries the weights of the best validation
        epoch if a checkpoint was saved, otherwise the final weights. ``hist``
        is the list of per-epoch dicts also written by ``save_train_history``.

    Raises:
        ValueError: if a loader yields no batches in an epoch.
    """
    if save_path is None:
        save_path = root + 'models/'+ model.get_identifier()
        if upsample:
            save_path += '_balanced'
    if upsample:
        # Presumably because upsampled epochs are much longer; kept as-is.
        patience = int(patience/5) + 1
    train_loader = TripletFaceGenerator(train_df,Constants.data_root,batch_size=batch_size,upsample=upsample,**kwargs)
    validation_loader = TripletFaceGenerator(validation_df,Constants.data_root,validation=True,batch_size=batch_size,upsample=upsample,**kwargs)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train(True)

    loss_fn = torch.nn.CrossEntropyLoss()
    triplet_loss = torch.nn.TripletMarginLoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)

    def format_y(y):
        # Class-index targets as int64 on the training device.
        return y.long().to(device)

    def format_batch(inputs):
        # Only move the images to the device. Gradients w.r.t. the input
        # images are never used, so requires_grad is deliberately not set.
        return [xin.to(device) for xin in inputs]

    def embedding_step(m, xbatch):
        """Encode anchor/positive/negative; return (anchor embedding, weighted triplet loss)."""
        anchor = m.encoder(xbatch[0])
        positive = m.encoder(xbatch[1])
        negative = m.encoder(xbatch[2])
        loss = triplet_loss(anchor, positive, negative) * embedding_loss_weight
        return anchor, loss

    def classifier_step(m, embedding, ytrue):
        """Decode the anchor embedding; return (head outputs, weighted classification loss)."""
        outputs = m.decoder(embedding)
        losses = [loss_fn(ypred.float(), format_y(y)) for y, ypred in zip(ytrue, outputs)]
        weighted = sum(w * l for w, l in zip(loss_weights, losses))
        # Bug fix: the classification_loss_weight-scaled loss used to be
        # computed and then discarded in favour of the unscaled sum.
        return outputs, weighted * classification_loss_weight

    def run_epoch(loader, train):
        """One pass over ``loader``.

        Returns (mean classification loss, mean triplet loss, per-head mean
        accuracy). Parameters are updated only when ``train`` is True.
        """
        running_loss = 0.0
        running_embed_loss = 0.0
        running_accuracy = [0.0, 0.0, 0.0]
        count = 0
        with torch.set_grad_enabled(train):
            for step, (x_batch, y_batch) in enumerate(loader):
                x_batch = format_batch(x_batch)
                if train:
                    optimizer.zero_grad()
                embedding, embedding_loss = embedding_step(model, x_batch)
                outputs, classification_loss = classifier_step(model, embedding, y_batch)
                if train:
                    (classification_loss + embedding_loss).backward()
                    optimizer.step()
                running_loss += classification_loss.item()
                running_embed_loss += embedding_loss.item()
                if train:
                    print('curr loss class',classification_loss.item(),'embed', embedding_loss.item(), 'step',step,' | ',end='\r')
                count += 1
                with torch.no_grad():
                    for head, (y, ypred) in enumerate(zip(y_batch, outputs)):
                        running_accuracy[head] += categorical_accuracy(ypred.float(), format_y(y)).item()
        if count == 0:
            raise ValueError('loader yielded no batches; cannot average epoch metrics')
        return running_loss/count, running_embed_loss/count, [a/count for a in running_accuracy]

    best_val_loss = float('inf')
    best_state = None
    steps_since_improvement = 0
    hist = []
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    print('model being saved to',save_path)
    for epoch in range(epochs):
        print('epoch',epoch)
        model.train(True)
        avg_loss, avg_embed_loss, avg_acc = run_epoch(train_loader, train=True)
        print('train loss', avg_loss,'train embed loss',avg_embed_loss, 'train accuracy', avg_acc)
        model.train(False)
        val_loss, val_embed_loss, val_acc = run_epoch(validation_loader, train=False)
        print('val loss', val_loss, 'val_embed_loss', val_embed_loss, 'val accuracy', val_acc)
        #don't save immediately in case I cancel training
        if best_val_loss > val_loss and epoch > 1:
            torch.save(model,save_path)
            # state_dict() returns references to the live tensors, so snapshot
            # a detached CPU copy to be able to restore the best weights later.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            best_val_loss = val_loss
            steps_since_improvement = 0
        else:
            steps_since_improvement += 1

        hist_entry = {
            'epoch': epoch,
            'train_loss': avg_loss,
            'train_acc':avg_acc,
            'val_loss':val_loss,
            'val_acc': val_acc,
            'lr': lr,
            'loss_weights': '_'.join([str(l) for l in loss_weights]),
            'train_embed_loss': avg_embed_loss,
            'val_embed_loss': val_embed_loss,
        }
        hist.append(hist_entry)
        save_train_history(model,hist,root=root)
        if steps_since_improvement > patience:
            break
    if best_state is not None:
        # Return the same weights that were checkpointed to save_path.
        model.load_state_dict(best_state)
    return model,hist

# Train the default dual-FaceNet triplet model on the balanced label tables.
m, h = train_model(
    TripletModel(),
    train_labels,
    validation_labels,
    root=Constants.data_root,
    lr=.0001,
    batch_size=50,
    histogram=False,
)
# Only the per-epoch history is kept; drop the reference to the trained model.
del m
h