In [1]:
import torch
import os
import torchvision
from dataloader import CelebA
from data_manager.manage_csv import *
import pandas as pd
from evaluation_metrics import eval_scores, print_eval_scores
from dataloader import create_dataset_split
import numpy as np
from torch import nn

In [2]:
""" Define batch size, number of epochs for training and dataset paths """

# Dataset locations (relative paths): image folder and attribute-label CSV.
PATH_TO_IMAGES = 'data/img_align_celeba'
PATH_TO_LABELS = 'data/list_attr_celeba.csv'
# Directory in which save_model() writes '<model_id>.pth' checkpoints.
PATH_TO_MODELS = 'models/'
# CSV that save_model() reads to look up the best partial accuracy logged so far.
PATH_TO_VALIDATION_SCORES = 'metadata/validation_scores.csv'
# Passed to create_dataset_split() and used to compute VALIDATION_SIZE.
BATCH_SIZE = 128
NUM_EPOCHS = 2
# Base of the per-epoch multiplier (rate ** epoch) applied in optimize_threshold().
ADAPTIVE_THRESHOLD_RATE = 1.2


In [3]:
""" Load pretrained model and dataset """ 

# The pretrained backbone is loaded later, inside MobileNet.__init__; this cell
# only builds the dataset and its splits.
# CelebA is the project-local dataset class; it is constructed with augment=False.
dataset = CelebA(PATH_TO_IMAGES, PATH_TO_LABELS, augment=False)
# create_dataset_split is project-local; its three return values are used below
# as the train / validation / test iterables.
train_data, val_data, test_data = create_dataset_split(dataset=dataset, batch_size=BATCH_SIZE)

In [4]:
# Divisor used by optimize_threshold(). len(val_data) is treated as the number
# of batches elsewhere (validate() averages over it), so this is an upper bound
# on the sample count: it over-counts whenever the last batch is not full.
# NOTE(review): assumes val_data is a batched loader — confirm in create_dataset_split.
VALIDATION_SIZE = len(val_data) * BATCH_SIZE

In [5]:
class MobileNet(nn.Module):
    """MobileNetV2 fetched through torch.hub with a caller-supplied head.

    The hub model's ``classifier`` attribute is replaced by ``classifier``;
    the resulting network is stored as ``self.model`` and ``forward`` simply
    delegates to it.
    """

    def __init__(self, classifier):
        super().__init__()
        backbone = torch.hub.load(
            'pytorch/vision:v0.10.0',
            'mobilenet_v2',
            weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2,
        )
        # Swap the stock head for the one supplied by the caller.
        backbone.classifier = classifier
        self.model = backbone

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)


In [6]:
""" Create classifier for the core model """

# Three-layer head: 1280 -> 128 -> 64 -> 40, with BatchNorm + ReLU after each
# hidden layer and dropout decreasing from 0.3 to 0.1 towards the output.
classifier_1 = nn.Sequential(
    nn.Dropout(p=0.3),
    nn.Linear(in_features=1280, out_features=128),
    nn.BatchNorm1d(num_features=128),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(in_features=128, out_features=64),
    nn.BatchNorm1d(num_features=64),
    nn.ReLU(),
    nn.Dropout(p=0.1),
    nn.Linear(in_features=64, out_features=40),
)


In [7]:
# Two-layer head: 1280 -> 128 -> 40 with one BatchNorm + ReLU hidden stage.
classifier_2 = nn.Sequential(
    nn.Dropout(p=0.3),
    nn.Linear(in_features=1280, out_features=128),
    nn.BatchNorm1d(num_features=128),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(in_features=128, out_features=40),
)

In [8]:
# Minimal head: dropout followed by a single 1280 -> 40 linear layer.
classifier_3 = nn.Sequential(
    nn.Dropout(p=0.3),
    nn.Linear(in_features=1280, out_features=40),
)

In [9]:
# Build the network with the Dropout + single-Linear head (classifier_3);
# classifier_1 / classifier_2 are alternative heads that can be passed instead.
model = MobileNet(classifier=classifier_3)

Using cache found in /Users/peterbrezovcsik/.cache/torch/hub/pytorch_vision_v0.10.0


In [10]:
""" Define loss function and optimizer """

# The loss is applied to the raw model outputs (sigmoid is only used afterwards
# for thresholding), hence the with-logits variant.
loss = nn.BCEWithLogitsLoss()
# The optimizer is handed every model parameter; fit() later sets
# requires_grad = False on the feature layers.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-2,
)


In [12]:
""" Define model metadata to store in model_metadata.csv """

VERSION_NUM = 5

# Run identity and hyper-parameters, read back from the live objects
# (model / loss / optimizer) rather than retyped by hand.
MODEL_ID = '_'.join(['MOBILE_NET_V2', str(VERSION_NUM)])
NUM_OF_HEAD_LAYERS = len(model.model.classifier)
LOSS_NAME = type(loss).__name__
OPTI_NAME = type(optimizer).__name__
LEARNING_RATE = optimizer.defaults['lr']
WEIGHT_DECAY = optimizer.defaults['weight_decay']
CLASSIFICATION_THRESHOLD = 0.5

# Starting threshold repeated 40 times, matching the 40 outputs of the heads.
CLASSIFICATION_THRESHOLD_VECTOR = torch.tensor([CLASSIFICATION_THRESHOLD] * 40)


""" Define model metadata file path and header """

MODEL_METADATA_PATH = 'metadata/'
MODEL_METADATA_FILE = 'model_metadata.csv'
# MODEL_METADATA_HEADER and MODEL_ARGS are positional pairs: keep them in the same order.
MODEL_METADATA_HEADER = [
    'model_id',
    'num_of_head_layers',
    'batch_size',
    'num_epochs',
    'loss_fn',
    'optimizer',
    'learning_rate',
    'threshold',
    'weight_decay',
]
MODEL_ARGS = [
    MODEL_ID,
    NUM_OF_HEAD_LAYERS,
    BATCH_SIZE,
    NUM_EPOCHS,
    LOSS_NAME,
    OPTI_NAME,
    LEARNING_RATE,
    CLASSIFICATION_THRESHOLD,
    WEIGHT_DECAY,
]


""" Define class_wise_accuracy file path and header """

CLASS_WISE_ACCURACY_FILE = 'class_wise_accuracy.csv'
CLASS_WISE_ACCURACY_HEADER = dataset.attr_names


""" Define validation scores file path and header """

VALIDATION_SCORES_FILE = 'validation_scores.csv'
VALIDATION_SCORES_HEADER = [
    'model_id',
    'epoch',
    'f1_score',
    'recall_score',
    'precision_score',
    'hamming_loss',
    'hamming_score',
    'partial_accuracy',
    'loss',
]


""" Define threshold values file path and header """

THRESHOLD_VALUES_FILE = 'threshold_values.csv'
THRESHOLD_VALUES_HEADER = dataset.attr_names


In [13]:
def save_model(model_id, current_accuracy):
    """Checkpoint the module-level ``model`` when ``current_accuracy`` is a new best.

    The best previous value is the maximum ``partial_accuracy`` logged for
    ``model_id`` in the CSV at ``PATH_TO_VALIDATION_SCORES``. When no row exists
    for ``model_id`` yet, the model is always saved.

    Args:
        model_id: Value matched against the ``model_id`` column; also used as
            the checkpoint file name (``<model_id>.pth`` in ``PATH_TO_MODELS``).
        current_accuracy: Partial accuracy of the epoch that was just validated.
    """
    scores = pd.read_csv(PATH_TO_VALIDATION_SCORES)
    previous = scores.loc[scores['model_id'] == model_id, 'partial_accuracy']
    best_accuracy = previous.max()

    # max() of an empty selection is NaN. `x is np.nan` is an identity test that
    # only matches the np.nan singleton, and every comparison against NaN is
    # False, so the missing-history case has to be tested explicitly and first.
    if pd.isna(best_accuracy) or current_accuracy > best_accuracy:
        print('Saving model...')
        # torch.save does not create missing directories.
        os.makedirs(PATH_TO_MODELS, exist_ok=True)
        path = os.path.join(PATH_TO_MODELS, model_id + '.pth')
        torch.save(model.state_dict(), path)

In [14]:
def calculate_fp_fn(y_pred, y_true):
    """Count false positives and false negatives along dimension 0.

    Args:
        y_pred: Tensor of predictions; an entry counts as positive when it equals 1.
        y_true: Tensor of labels, compared element-wise against ``y_pred``.

    Returns:
        Tuple ``(fp, fn)``: ``fp`` counts entries with prediction 1 and label 0,
        ``fn`` counts entries with prediction 0 and label 1, each summed over
        dimension 0.
    """
    false_positive_mask = torch.logical_and(y_pred == 1, y_true == 0)
    false_negative_mask = torch.logical_and(y_pred == 0, y_true == 1)
    return false_positive_mask.sum(dim=0), false_negative_mask.sum(dim=0)

In [15]:
def optimize_threshold(y_pred, y_true, current_epoch, current_threshold):
    """Shift the threshold by the scaled false-positive / false-negative imbalance.

    The update is
    ``current_threshold + ADAPTIVE_THRESHOLD_RATE ** current_epoch * (fp - fn) / VALIDATION_SIZE``
    where ``fp`` and ``fn`` come from ``calculate_fp_fn``: more false positives
    than false negatives raise the threshold, the opposite lowers it.

    Args:
        y_pred: Thresholded predictions for the current batch.
        y_true: Labels for the current batch.
        current_epoch: Exponent applied to ``ADAPTIVE_THRESHOLD_RATE``.
        current_threshold: Threshold value(s) to be updated.

    Returns:
        The updated threshold; ``current_threshold`` itself is not modified.
    """
    false_pos, false_neg = calculate_fp_fn(y_pred, y_true)
    epoch_scale = np.power(ADAPTIVE_THRESHOLD_RATE, current_epoch)
    scaled_imbalance = epoch_scale * (false_pos - false_neg)
    return current_threshold + scaled_imbalance / VALIDATION_SIZE

    

In [16]:
# MetricSaver is not defined in this notebook; presumably it is provided by the
# star import from data_manager.manage_csv — TODO confirm.
metric_saver = MetricSaver(MODEL_METADATA_PATH)

# Create the four CSV files used for logging. The files must exist before
# save_model() / validate() read from and append to them.
# NOTE(review): two of these calls unpack the header with * and two pass the
# list as a single argument — confirm this matches the MetricSaver signatures.
metric_saver.create_model_metadata_csv(MODEL_METADATA_FILE, MODEL_METADATA_HEADER)
metric_saver.create_class_wise_acc_csv(CLASS_WISE_ACCURACY_FILE, *CLASS_WISE_ACCURACY_HEADER)
metric_saver.create_validation_scores_csv(VALIDATION_SCORES_FILE, VALIDATION_SCORES_HEADER)
metric_saver.create_threshold_values_csv(THRESHOLD_VALUES_FILE, *THRESHOLD_VALUES_HEADER)

# Log this run's hyper-parameters; MODEL_ARGS is ordered like MODEL_METADATA_HEADER.
metric_saver.save_model_metadata(*MODEL_ARGS)

model_metadata.csv.csv already exists.
class_wise_accuracy.csv already exists.
validation_scores.csv already exists.
threshold_values.csv already exists.
Saving model metadata...


In [17]:
""" Validation function for the classifier """

def validate(model, val_data, epoch, model_id, loss_fn, threshold):
    """Evaluate ``model`` on ``val_data``, adapt the thresholds and log the results.

    For every batch the raw outputs are scored with ``loss_fn``, passed through
    a sigmoid and compared against the current threshold. The threshold is then
    updated with ``optimize_threshold`` and the updated value is used for the
    next batch. Batch metrics from ``eval_scores`` are averaged with equal
    weight per batch.

    After the loop, ``save_model`` is called with the averaged partial accuracy
    (before this epoch's row is written, so the comparison is against earlier
    epochs only), and the averaged metrics and final thresholds are appended
    through the module-level ``metric_saver``.

    The model is moved to the selected device and is left in eval mode.

    Args:
        model: Network called on the image batches.
        val_data: Sized iterable yielding ``(images, labels)`` pairs.
        epoch: Zero-based epoch index; logged as ``epoch + 1``.
        model_id: Identifier used for the checkpoint and the CSV rows.
        loss_fn: Callable applied to ``(raw_outputs, labels)``.
        threshold: Starting threshold(s) for this validation pass.

    Returns:
        The threshold after the last batch update, on the evaluation device.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    sigmoid = torch.nn.Sigmoid()
    avg_partial_accuracy = avg_f1_score = avg_recall_score = avg_precision_score = avg_hamming_loss = avg_hamming_score = avg_loss = 0.0
    avg_label_wise_accuracy_score = np.zeros(shape=(40))

    NUM_OF_BATCHES = len(val_data)

    # The comparison below needs the threshold on the same device as the outputs.
    current_threshold = torch.as_tensor(threshold, device=device)

    with torch.no_grad():
        for i, batch in enumerate(val_data):
            images, labels = batch
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)

            # Loss is computed on the raw outputs, before the sigmoid.
            loss = loss_fn(outputs, labels)

            outputs = sigmoid(outputs)
            result = outputs > current_threshold

            current_threshold = optimize_threshold(result, labels, epoch, current_threshold)

            f1_score, recall, precision, hamming_loss, ham_score, partial_accuracy, label_wise_accuracy =\
                        eval_scores(labels.cpu().detach().numpy(), result.cpu().detach().numpy(), loss.item(), print_out=True, epoch=epoch+1, batch=i)

            ### Increase average scores by the current batch scores
            avg_f1_score += f1_score
            avg_recall_score += recall
            avg_precision_score += precision
            avg_hamming_loss += hamming_loss
            avg_hamming_score += ham_score
            avg_partial_accuracy += partial_accuracy
            avg_label_wise_accuracy_score += label_wise_accuracy
            avg_loss += loss.item()

    ### Calculate average scores
    avg_f1_score /= NUM_OF_BATCHES
    avg_recall_score /= NUM_OF_BATCHES
    avg_precision_score /= NUM_OF_BATCHES
    avg_hamming_loss /= NUM_OF_BATCHES
    avg_hamming_score /= NUM_OF_BATCHES
    avg_partial_accuracy /= NUM_OF_BATCHES
    avg_label_wise_accuracy_score /= NUM_OF_BATCHES
    avg_loss /= NUM_OF_BATCHES

    ### Save validation scores to csv file
    save_model(model_id, avg_partial_accuracy)
    metric_saver.save_class_wise_accuracy(model_id, epoch+1, *avg_label_wise_accuracy_score)
    metric_saver.save_validation_scores(model_id, epoch+1, *(avg_f1_score, avg_recall_score, avg_precision_score, avg_hamming_loss, avg_hamming_score, avg_partial_accuracy, avg_loss))
    # .numpy() is only valid for CPU tensors, so copy back to the host first.
    metric_saver.save_threshold_values(model_id, epoch+1, *(current_threshold.detach().cpu().numpy()))

    print_eval_scores(avg_f1_score, avg_recall_score, avg_precision_score, avg_hamming_loss, avg_hamming_score, avg_partial_accuracy, avg_label_wise_accuracy_score, avg_loss)
    print('Threshold values: ', current_threshold)
    return current_threshold


In [36]:
""" Test function for the classifier """

def test(model, val_data, model_id, loss_fn, threshold):
    """Evaluate ``model`` on held-out data with a fixed threshold and print the scores.

    Each batch is scored with ``loss_fn`` on the raw outputs, passed through a
    sigmoid and compared against ``threshold``. Unlike validation, the threshold
    is never updated and nothing is written to disk. Batch metrics from
    ``eval_scores`` are averaged with equal weight per batch and printed.

    The model is moved to the selected device and is left in eval mode.

    Args:
        model: Network called on the image batches.
        val_data: Sized iterable yielding ``(images, labels)`` pairs. Despite
            the name, this is the data to test on.
        model_id: Not used at present; kept for the planned score saving.
        loss_fn: Callable applied to ``(raw_outputs, labels)``.
        threshold: Decision threshold(s) applied to the sigmoid outputs.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    sigmoid = torch.nn.Sigmoid()
    avg_partial_accuracy = avg_f1_score = avg_recall_score = avg_precision_score = avg_hamming_loss = avg_hamming_score = avg_loss = 0.0
    avg_label_wise_accuracy_score = np.zeros(shape=(40))

    NUM_OF_BATCHES = len(val_data)

    # The comparison below needs the threshold on the same device as the outputs.
    current_threshold = torch.as_tensor(threshold, device=device)

    with torch.no_grad():
        for i, batch in enumerate(val_data):
            images, labels = batch
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)

            # Loss is computed on the raw outputs, before the sigmoid.
            loss = loss_fn(outputs, labels)

            outputs = sigmoid(outputs)
            result = outputs > current_threshold

            f1_score, recall, precision, hamming_loss, ham_score, partial_accuracy, label_wise_accuracy =\
                        eval_scores(labels.cpu().detach().numpy(), result.cpu().detach().numpy(), loss.item(), print_out=True, epoch='Test', batch=i)

            ### Increase average scores by the current batch scores
            avg_f1_score += f1_score
            avg_recall_score += recall
            avg_precision_score += precision
            avg_hamming_loss += hamming_loss
            avg_hamming_score += ham_score
            avg_partial_accuracy += partial_accuracy
            avg_label_wise_accuracy_score += label_wise_accuracy
            avg_loss += loss.item()

    ### Calculate average scores
    avg_f1_score /= NUM_OF_BATCHES
    avg_recall_score /= NUM_OF_BATCHES
    avg_precision_score /= NUM_OF_BATCHES
    avg_hamming_loss /= NUM_OF_BATCHES
    avg_hamming_score /= NUM_OF_BATCHES
    avg_partial_accuracy /= NUM_OF_BATCHES
    avg_label_wise_accuracy_score /= NUM_OF_BATCHES
    avg_loss /= NUM_OF_BATCHES

    # TODO: Create a separate function for saving test scores (class-wise
    # accuracy, averaged scores and threshold values are not persisted yet).

    print_eval_scores(avg_f1_score, avg_recall_score, avg_precision_score, avg_hamming_loss, avg_hamming_score, avg_partial_accuracy, avg_label_wise_accuracy_score, avg_loss)
    print('Threshold values: ', current_threshold)


In [None]:
""" Training function for the classifier """

def fit(model, train_data, val_data, optimizer, loss_fn, epochs, model_id, threshold):
    """Train the classifier head, validating after every epoch and testing at the end.

    The parameters under ``model.model.features`` are frozen, so only the
    remaining parameters receive gradients. After each epoch ``validate`` is
    run and the threshold it returns is used for the next epoch. Once all
    epochs are done, ``test`` is run on the module-level ``test_data``.

    Args:
        model: Network to train; must expose ``model.model.features``.
        train_data: Iterable yielding ``(images, labels)`` training batches.
        val_data: Validation data handed to ``validate``.
        optimizer: Optimizer already bound to the model's parameters.
        loss_fn: Callable applied to ``(raw_outputs, labels)``.
        epochs: Number of passes over ``train_data``.
        model_id: Identifier forwarded to ``validate`` and ``test``.
        threshold: Starting decision threshold(s) for the sigmoid outputs.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Module.to() moves the parameters in place, so an optimizer created
    # earlier keeps referring to the same parameter objects.
    model.to(device)
    # The threshold is compared against the outputs, so it must share their device.
    threshold = torch.as_tensor(threshold, device=device)
    sigmoid = torch.nn.Sigmoid()

    # Freeze feature layers
    for param in model.model.features.parameters():
        param.requires_grad = False

    for epoch in range(epochs):
        # validate() switches the model to eval mode at the end of every epoch;
        # re-enable training mode here so Dropout / BatchNorm behave correctly
        # in every epoch, not only the first one.
        model.train(mode=True)

        for i, batch in enumerate(train_data):
            images, labels = batch
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            # Thresholded predictions are only needed for the metrics.
            with torch.no_grad():
                result = sigmoid(outputs) > threshold

            # Measure model performance for every batch (called for its printed output)
            eval_scores(labels.cpu().detach().numpy(), result.cpu().detach().numpy(), loss.item(), print_out=True, epoch=epoch+1, batch=i)

        # Evaluation metrics for every epoch
        threshold = validate(model, val_data, epoch, model_id, loss_fn, threshold)

    # NOTE: test_data is read from module scope; it is not a parameter of fit().
    test(model, test_data, model_id, loss_fn, threshold)

In [None]:
""" Train the model """

fit(
    model=model,
    train_data=train_data,
    val_data=val_data,
    optimizer=optimizer,
    loss_fn=loss,
    epochs=NUM_EPOCHS,
    model_id=MODEL_ID,
    threshold=CLASSIFICATION_THRESHOLD_VECTOR,
)