In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from scipy.ndimage import rotate
import Utils
from Utils import Constants
import cv2
from facenet_pytorch import InceptionResnetV1
from Models import *
from DataLoaders import *

In [3]:
def get_all_labels():
    labels = ['train_data_clean.csv','validation_data_clean.csv','test_data_clean.csv']
    dfs = [pd.read_csv(f) for f in labels]
    df = pd.concat(dfs,axis=0).reset_index().drop('index',axis=1)
    return df
all_labels = get_all_labels()
all_labels

Unnamed: 0,name,skin_tone,gender,age,is_face
0,TRAIN0001.png,0,0,1,False
1,TRAIN0002.png,5,1,0,True
2,TRAIN0005.png,1,1,0,False
3,TRAIN0007.png,1,0,1,True
4,TRAIN0009.png,7,0,1,False
...,...,...,...,...,...
11548,TEST2995.png,2,0,2,True
11549,TEST2996.png,4,0,1,True
11550,TEST2997.png,0,1,1,True
11551,TEST2998.png,3,1,1,True


In [11]:
class UnsupervisedTripletEncoder(BaseModel):
    #model to use as a triplet loss
    #will tak in list of three image batchs
    #returns list of tree embeedidng batchs + predictions on first batch of images
    def __init__(self,
                 base_model = None,
                 feature_extractor = None,
                 hidden_dims = [400],
                 embedding_dropout=.3,
                 base_name='model',
                 fine_tune=False,
                 **kwargs):
        super(UnsupervisedTripletEncoder,self).__init__()
                               
        if base_model is None:
            base_model = InceptionResnetV1(pretrained='vggface2')
            base_name = 'dualfacenet'
        else:
            base_name = base_model.get_identifier()
        
        
        if feature_extractor is None:
            feature_extractor = InceptionResnetV1(pretrained='vggface2')
        for param in feature_extractor.parameters():
            param.requires_grad = fine_tune
        for param in base_model.parameters():
            param.requires_grad = True
            
        self.base_model = base_model
        self.feature_extractor = feature_extractor
        
        self.embedding_dropout = torch.nn.Dropout(p=embedding_dropout)
        curr_dim = base_model.logits.in_features + feature_extractor.logits.in_features
        hidden_layers = []
        
        for i,size in enumerate(hidden_dims):
            layer = torch.nn.Linear(curr_dim, size)
            curr_dim = size
            hidden_layers.append(layer)
            hidden_layers.append(torch.nn.ReLU())
            
        self.hidden_layers = torch.nn.ModuleList(hidden_layers)
        
        self.embedding_size = hidden_dims[-1]
        self.norm = torch.nn.BatchNorm1d(self.embedding_size)
        
        def add_dims(n,dims,prefix):
            for dim in dims:
                n += '_'+prefix+str(dim)
            return n
        
        name_string = 'unsupervised_encoder_' + base_name
        name_string = add_dims(name_string,hidden_dims,'h')
        name_string += '_ed' + str(embedding_dropout).replace('0.','')
        
        
    def forward(self,x):
        xb = self.base_model(x)
        xf = self.feature_extractor(x)
        x = torch.cat((xb,xf),axis=-1)
        x = self.embedding_dropout(x)
        for layer in self.hidden_layers:
            x = layer(x)
        x = self.norm(x)
        return x

In [None]:
def save_train_history(model,history,root=''):
    model_name = model.get_identifier()
    
    df = pd.DataFrame(history)
    df['model'] = model_name
    string = root + 'results/history_' + model_name + '.csv'
    df.to_csv(string,index=False)
    print('saved history to',string)
    return df, string

def train_model(model,
                df,
                root,
                epochs=300,
                lr=.001,
                batch_size=200,
                patience = 20,
                save_path=None,
                histogram =False,
                upsample=True,
                **kwargs,
               ):
    if save_path is None:
        save_path = root + 'models/'+ model.get_identifier()
        if upsample:
            save_path += '_balanced'
    if upsample:
        patience = int(patience/5) + 1
    data_loader = UnsupervisedTripletGenerator(df,Constants.data_root,batch_size=batch_size,upsample=upsample,**kwargs)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train(True)
    
    triplet_loss = torch.nn.TripletMarginLoss()
    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    format_y = lambda y: y.long()#.to(device)
    
    def format_batch(inputs,grad=True):
        xb = []
        for xin in inputs:
            xin = xin.to(device)
            xin.requires_grad_(grad)
            xb.append(xin)
        return xb
    
    def embedding_step(m,xbatch): 
        base = m(xbatch[0])
        positive = m(xbatch[1])
        negative = m(xbatch[2])
        loss = triplet_loss(base,positive,negative)
        return base,loss
    
    def train_epoch():
        running_loss = 0
        running_accuracy = [0,0,0]
        curr_loss = 0
        count = 0
        for i, [x_batch,y_batch] in enumerate(data_loader):
            x_batch = format_batch(x_batch)
            optimizer.zero_grad()
            embedding,embedding_loss = embedding_step(model, x_batch)
            embedding_loss.backward()
            optimizer.step()
            running_loss += embedding_loss.item()
            print('curr loss', embedding_loss.item(), 'step',i,' | ',end='\r')
            count += 1
        return running_loss/count
    
    def val_epoch():
        print('no validation nerd')
    
    best_val_loss = 100000
    steps_since_improvement = 0
    hist = []
    best_weights = model.state_dict()
    print('model being saved to',save_path)
    for epoch in range(epochs):
        print('epoch',epoch)
        model.train(True)
        avg_loss, avg_acc = train_epoch()
        print('train loss', avg_loss, 'train accuracy', avg_acc)
#         model.train(False)
#         val_loss,val_embed_loss, val_acc = val_epoch()
#         print('val loss', val_loss, 'val_embed_loss', val_embed_loss, 'val accuracy', val_acc)

        #don't save immediately in case I cancel training
        if best_val_loss > avg_loss and epoch > 1:
            torch.save(model,save_path)
            best_weights = model.state_dict()
            best_val_loss = avg_loss
            steps_since_improvement = 0
        else:
            steps_since_improvement += 1
        
        hist_entry = {
            'epoch': epoch,
            'train_loss': avg_loss,
#             'train_acc':avg_acc,
            'val_loss':avg_loss,
#             'val_acc': val_acc,
            'lr': lr,
#             'loss_weights': '_'.join([str(l) for l in loss_weights])
        }
        hist.append(hist_entry)
        save_train_history(model,hist,root=root)
        if steps_since_improvement > patience:
            break
    return model,hist

m,h = train_model(
    UnsupervisedTripletEncoder(),
    all_labels,
    Constants.data_root,
    batch_size=50,
    histogram=False,
    lr=.0001,
)
del m
h

model being saved to ../../data/models/abstractmodel_balanced
epoch 0
curr loss 0.2504828870296478 step 169  | | 