***
<font size="6"><center><b>
Multi-Level Routing with Hierarchical Capsule Voting
</b></center></font>
***

# Files and Libraries

In [None]:
# Base Libraries
import argparse
import json
import os
import sys
import math
import random
import csv
import numpy as np # type: ignore
import pandas as pd # type: ignore
import matplotlib # type: ignore
import matplotlib.pyplot as plt # type: ignore
from datetime import datetime
from treelib import Tree # type: ignore
import platform

import tensorflow as tf # type: ignore
from tensorflow import keras # type: ignore
from tensorflow.keras.preprocessing.image import ImageDataGenerator # type: ignore
from tensorflow.keras.models import Sequential  # type: ignore
from tensorflow.keras.layers import Dense, Activation, Flatten, Dropout, BatchNormalization  # type: ignore
from tensorflow.keras.layers import Conv2D, MaxPooling2D  # type: ignore
from tensorflow.keras import regularizers, optimizers  # type: ignore
from tensorflow.keras import backend as K  # type: ignore
    ## Tensorflow_docs
import tensorflow_docs as tfdocs # type: ignore
import tensorflow_docs.plots # type: ignore
# Supporting Libraries:
# sys.path.append('../../') ### adding system path for src folder
from src import *

# Auto reload local libraries if updated (for development in jupyter)
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Print a summary of the host/runtime environment. `sysenv` is not defined in this
# notebook; presumably it comes from `from src import *` — confirm in src.
print(sysenv.systeminfo())

# System Arguments


In [None]:
def parse_tuple(arg_string):
    """Parse a comma-separated tuple of integers from a command-line string.

    Used as the argparse ``type`` for ``--input_size``. Accepts forms such as
    ``"(64, 64, 3)"``, ``"64,64,3"``, ``" [32, 32, 3] "`` and tolerates a
    single trailing comma (``"(28, 28, 1,)"``).

    Args:
        arg_string: The raw option value.

    Returns:
        tuple[int, ...]: The parsed integers, in order.

    Raises:
        ValueError: If the string contains no values or a component is not an
            integer (argparse reports this as a usage error).
    """
    # Strip surrounding whitespace first so the bracket strip can reach the
    # brackets; otherwise " (64, 64, 3)" would keep its leading "(".
    inner = arg_string.strip().strip('()[]')
    parts = [part.strip() for part in inner.split(',')]
    # An empty final component comes from a trailing comma, e.g. "(28, 28, 1,)".
    if parts and parts[-1] == '':
        parts.pop()
    if not parts:
        raise ValueError(f'no integer values found in {arg_string!r}')
    return tuple(int(part) for part in parts)

parser = argparse.ArgumentParser(
    prog='Multi-Level Routing',
    description='Multi-Level Routing with Hierarchical Capsule Voting',
)
# Placeholder only: `args` is rebound to the parsed namespace once the
# parameter dict is fed through `parser.parse_args` in a later cell.
args = argparse.Namespace(description='Multi-Level Routing with Hierarchical Capsule Voting')

# System config:
parser.add_argument('--seed', default=42, type=int, help='random seed')
parser.add_argument('--gpus', default='0', type=str, help='comma-separated gpu id(s) to use, e.g. "0" or "0,1"')

# Dataset config:
parser.add_argument('--dataset', default='CIFAR10', type=str, help='train dataset', choices=[
    'MNIST',
    'E_MNIST',
    'F_MNIST',
    'CIFAR10',
    'CIFAR100',
    'S_Cars',
    'CU_Birds',
    'M_Tree',
    'M_Tree_L4',
    'M_Tree_L3',
    'M_Tree_L2',
    'M_Tree_L1',
])

parser.add_argument('--data_path', default=None, type=str, help='path to the dataset folder')
parser.add_argument('--data_normalize', default='StandardScaler', type=str, help='data normalization', choices=['MinMaxScaler', 'StandardScaler', 'None'])
parser.add_argument('--data_aug', default='H_MixUp', type=str, help='data augmentation', choices=[
    'None', 'MixUp', 'CutMix',
    'MixupAndCutMix', 'MixupORCutMix',
    'H_MixUp', 'H_CutMix',
    'H_MixupAndCutMix', 'H_MixupORCutMix',
])
parser.add_argument('--data_aug_alpha', default=0.2, type=float, help='data augmentation alpha value')
# argparse only applies `type` to string values, so the tuple default is used as-is.
parser.add_argument('--input_size', default=(64, 64, 3), type=parse_tuple, help='input image size, e.g. "(64, 64, 3)"')


# model configs:
parser.add_argument('--optimizer', default='adam', type=str, help='optimizer', choices=['adam', 'sgd'])
# store_false flag: args.DefaultLrScheduler is True unless the flag is passed.
parser.add_argument('--DefaultLrScheduler', action='store_false', help='Pass this flag to disable the exponential-decay learning rate scheduler callback (enabled by default)')
parser.add_argument('--initial_lr', default=0.001, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--lr_decay_rate', default=0.9, type=float, help='learning rate decay factor')
parser.add_argument('--lr_decay_exe', default=10, type=int, help='learning rate decay epoch')

parser.add_argument('--LossType', default='margin', type=str, help='loss function', choices=['margin', 'crossentropy'])
parser.add_argument('--LossWeightType', default='Dynamic', type=str, help='Loss Weight Type', choices=['None', 'Dynamic', 'Static'])

parser.add_argument('--metric', default='accuracy', type=str, help='metric', choices=['accuracy','loss'])

parser.add_argument('--PCaps_dim', default=8, type=int, help='feature dimension')
parser.add_argument('--SCaps_dim', default=16, type=int, help='feature dimension')
parser.add_argument('--SCaps_dim_mode', default='same', type=str, help='Change of feature dimension with hierarchical level. Increase or decrease by factor of 2', choices=['same', 'increase', 'decrease'])
parser.add_argument('--Routing_N', default=2, type=int, help='number of routing iterations')

# store_false flag: args.compile_model is True unless the flag is passed.
parser.add_argument('--compile_model', action='store_false', help='compile model (enabled by default; pass this flag to disable)')
parser.add_argument('--backbone_net', default='custom', type=str, help='backbone network', choices=['custom',*[model for model in dir(keras.applications) if callable(getattr(keras.applications, model))]])
# Only the listed choices are accepted (argparse rejects a file path here);
# the value is also embedded in the log-directory name.
parser.add_argument('--backbone_net_weights', default=None, type=str, help="backbone network weights: 'imagenet' or 'None'", choices=['imagenet', 'None'])

# FOR HTR_CapsNet
parser.add_argument('--HTR_taxonomy_temperature', default=0.5, type=float, help='HTR_CapsNet: Controls the sharpness of sigmoid function for taxonomy')
parser.add_argument('--HTR_mask_threshold_high', default=0.9, type=float, help='HTR_CapsNet: Upper bound for mask values')
parser.add_argument('--HTR_mask_threshold_low', default=0.1, type=float, help='HTR_CapsNet: Lower bound for mask values')
parser.add_argument('--HTR_mask_temperature', default=0.5, type=float, help='HTR_CapsNet: Temperature for routing softmax (HIGH VALUE - more distributed routing; LOW VALUE - more focused routing)')
parser.add_argument('--HTR_mask_center', default=0.5, type=float, help='HTR_CapsNet: Center value for the mask')
parser.add_argument('--HTR_att_num_heads', default=16, type=int, help='HTR_CapsNet: Number of attention heads')
parser.add_argument('--HTR_att_key_dim', default=32, type=int, help='HTR_CapsNet: Key dimension for attention mechanism')

# training configs:
parser.add_argument('--epochs', default=100, type=int, metavar='epoch', help='number of total epochs to run')
parser.add_argument('--batch_size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)')

parser.add_argument('--NoEarlyStop', action='store_true', help='Disable early stopping (by default training stops early on the validation accuracy of the last output)')
parser.add_argument('--early_stop_tolerance', default=20, type=int, metavar='N', help='Early stop tolerance; number of epochs to wait before stopping')

parser.add_argument('--mode', default='BUH_CapsNet', type=str, help='training mode',
                    choices=['BUH_CapsNet', 
                             'HD_CapsNet',
                             'ML_CapsNet', 
                             'HDR_CapsNet', 
                             'HD_CapsNet_Eff', 
                             'HTR_CapsNet', 
                             'HD_CapsNet_EM'])

parser.add_argument('--logs', default='logs',type=str, help='log directory file name')
parser.add_argument('--Test_only', action='store_true', help='Skip training and load the best saved weights for evaluation')


# Hyperparameters:
parser.add_argument('--m_plus', default=0.9, type=float, help='Margin Loss: m_plus')
parser.add_argument('--m_minus', default=0.1, type=float, help='Margin Loss: m_minus')
parser.add_argument('--lambda_val', default=0.5, type=float, help='Down-weighting of the loss for absent classes')

# Directories and Unique ID:
parser.add_argument('--dir_uid', default=None, type=str, help='Unique ID for the experiment (appended to the log directory name)')

In [None]:
# This cell is tagged with "parameters" in the Jupyter Notebook
# Values are strings that are fed to `parser.parse_args` as '--key value';
# the special value 'BOOL_FLAG' produces a bare on/off switch '--key'.
args_dict = {
            'dataset' : 'CIFAR10', # 'MNIST', 'E_MNIST', 'F_MNIST', 'CIFAR10', 'CIFAR100', 'S_Cars', 'CU_Birds', 'M_Tree', 'M_Tree_L4', 'M_Tree_L3', 'M_Tree_L2', 'M_Tree_L1'
            'data_aug' : 'MixUp', # 'None', 'MixUp', 'CutMix', 'MixupAndCutMix', 'MixupORCutMix', 'H_MixUp', 'H_CutMix', 'H_MixupAndCutMix', 'H_MixupORCutMix'
            'mode' : 'HTR_CapsNet', # 'BUH_CapsNet', 'HD_CapsNet', 'ML_CapsNet', 'HDR_CapsNet', 'HD_CapsNet_Eff', 'HTR_CapsNet', 'HD_CapsNet_EM'
            'SCaps_dim' : '32', # '8', '16', '32', '64', '128'
            'SCaps_dim_mode' : 'decrease', # 'same', 'increase', 'decrease'
            'epochs' : '15', # '100', '200', '300', '400', '500'
            'gpus' : '0', # comma-separated GPU ids, e.g. '0' or '0,1'
            'input_size' : '(64, 64, 3)', # '(28, 28, 1)', '(32, 32, 3)', '(64, 64, 3)', '(128, 128, 3)', '(224, 224, 3)'
            # Raw string so Windows backslashes are not read as escapes (\a, \t, \f).
            'data_path' : r'P:\path\to\dataset\folder', # Path to the dataset folder
            'dir_uid' : 'Test',
            # 'Test_only':'BOOL_FLAG', # 'BOOL_FLAG' Special key for on/off flag in argparse
            'HTR_mask_threshold_high' : '0.99', # '0.9', '0.95', '0.99' (Upper bound for mask values)
            'HTR_mask_threshold_low' : '0.1', # '0.1', '0.05', '0.01' (Lower bound for mask values)
            'HTR_att_num_heads' : '16', # '4', '8', '16', '32', '64' (Number of attention heads)
            'HTR_att_key_dim' : '32', # '8', '16', '32', '64', '128' (Key dimension for attention mechanism)
            'Routing_N' : '3', # '1', '2', '3' (Number of routing iterations)
            }

In [None]:
# The parameters cell may arrive as a JSON string (e.g. when injected by a
# notebook-parameterising tool); normalise it to a dict first.
if isinstance(args_dict, str):
    args_dict = json.loads(args_dict)

# Flatten {'key': 'value'} into ['--key', 'value']; the marker value 'BOOL_FLAG'
# yields the bare switch ['--key'] for store_true / store_false options.
args = parser.parse_args([
    token
    for option, value in args_dict.items()
    for token in ([f'--{option}'] if value == 'BOOL_FLAG' else [f'--{option}', str(value)])
])

# Replace the literal string 'None' with a real None in the parsed namespace.
# NOTE: this also turns 'None' *choices* (e.g. --LossWeightType, --data_aug,
# --data_normalize) into None, so downstream checks must compare against None.
vars(args).update({name: (None if val == 'None' else val) for name, val in vars(args).items()})

In [None]:
# GPU Growth: For dynamic GPU memory allocation
# NOTE(review): `sysenv.gpugrowth` is a project helper; presumably it enables
# TensorFlow memory growth on the GPU ids given by --gpus — confirm in src.
sysenv.gpugrowth(gpus = args.gpus).memory_growth()

# Log Directory

In [None]:
# Build the experiment log directory from dataset / augmentation / mode /
# backbone. A non-empty, non-blank --dir_uid is appended to the backbone segment.
args.log_dir = sysenv.log_dir([
    args.dataset,
    args.data_aug,
    args.mode,
    f'backbone-{args.backbone_net}-{args.backbone_net_weights}'
    + (f'-{args.dir_uid}' if args.dir_uid and not args.dir_uid.isspace() else ''),
])

# Import Dataset

In [None]:
# Load the selected dataset with the configured normalisation and augmentation,
# then expose its three splits under short names.
dataset = datasets._load_(
    args.dataset,
    args=args,
    batch_size=args.batch_size,
    data_normalizing=args.data_normalize,
    data_augmentation=args.data_aug,
    data_aug_alpha=args.data_aug_alpha,
    image_size=args.input_size,
)
ds_train, ds_val, ds_test = dataset.training, dataset.validation, dataset.test

In [30]:
### Print Dataset Sample ###
# datasets.print_hierarchical_ds_sample(ds_train, print_batch_size = 2, show_images= True)

# Callbacks

In [31]:
# TensorBoard writes into a timestamped sub-folder of <log_dir>/tb_logs;
# CSVLogger appends (append=True) to <log_dir>/log.csv.
tb = keras.callbacks.TensorBoard(
    os.path.join(
        args.log_dir,
        "tb_logs",
        datetime.now().strftime("%Y%m%d-%H%M%S"),
    )
)
CSVLogger = keras.callbacks.CSVLogger(
    os.path.join(args.log_dir, "log.csv"),
    append=True,
)
CallBacks = [tb, CSVLogger]

# Model Architecture

In [None]:
# SELECTING LOSS FUNCTION
# Reject anything other than the two supported loss types up front.
if args.LossType not in ('margin', 'crossentropy'):
    raise ValueError('Invalid LossType')
LossFunction = (
    models.capsnet.MarginLoss(m_plus=args.m_plus, m_minus=args.m_minus, lambda_=args.lambda_val)
    if args.LossType == 'margin'
    else keras.losses.CategoricalCrossentropy()
)

# SELECTING LOSS WEIGHTS
# The modifier is created unconditionally: it supplies both the dynamic
# (`values`) and the static (`initial_lw`) loss weights.
lw_modifier = models.dynamic_LW_Modifier(num_classes = dataset.num_classes, directory = args.log_dir)
if args.LossWeightType == 'Dynamic':
    LW_Value = lw_modifier.values
    # Also registered as a callback (presumably to update the weights during training).
    CallBacks = [*CallBacks, lw_modifier]
elif args.LossWeightType == 'Static':
    LW_Value = lw_modifier.initial_lw
elif args.LossWeightType in (None, 'None'):
    # The argument-parsing cell rewrites the string 'None' to None, so the
    # 'None' choice arrives here as None; accept both spellings.
    LW_Value = None
else:
    raise ValueError('Invalid LossWeightType')


# SELECTING OPTIMIZER
# Pass --initial_lr explicitly. Otherwise it is ignored whenever the LR scheduler
# below is disabled, and Keras' per-optimizer default learning rate applies instead.
if args.optimizer == 'adam':
    Optimizer = keras.optimizers.Adam(learning_rate=args.initial_lr)
elif args.optimizer == 'sgd':
    Optimizer = keras.optimizers.SGD(learning_rate=args.initial_lr)
else:
    raise ValueError('Invalid optimizer')

# SELECTING LEARNING RATE SCHEDULER
# --DefaultLrScheduler is a store_false flag: True (scheduler on) unless the flag is passed.
if args.DefaultLrScheduler:
    print('Using Learning Rate Scheduler')
    LR_Decay = models.LR_ExponentialDecay(initial_LR = args.initial_lr, start_epoch = args.lr_decay_exe, decay_factor = args.lr_decay_rate)
    CallBacks = [*CallBacks, LR_Decay.get_scheduler_callback()]

In [None]:
def get_compiled_model():
    """Build the model selected by ``--mode`` and compile it.

    Reads the notebook-level globals ``args``, ``dataset``, ``Optimizer``,
    ``LossFunction`` and ``LW_Value``. Every output head is given the same
    loss object and the metric named by ``--metric``.

    Returns:
        The compiled model returned by ``models.get_model``.
    """
    net = models.get_model(
        model_name=args.mode,
        args=args,
        input_shape=args.input_size,
        num_classes=dataset.num_classes,
        taxonomy=dataset.taxonomy,
    )
    heads = net.output_names
    net.compile(
        optimizer=Optimizer,
        loss=dict.fromkeys(heads, LossFunction),
        loss_weights=LW_Value,
        metrics=dict.fromkeys(heads, args.metric),
    )
    return net

if len(args.gpus.split(',')) <= 1:
    # A single GPU id: build and compile directly.
    model = get_compiled_model()
else:
    # Several GPU ids: mirror the model across them. Windows uses an explicit
    # HierarchicalCopyAllReduce (presumably because the default cross-device
    # reduction is unavailable there).
    if platform.system() == 'Windows':
        strategy = tf.distribute.MirroredStrategy(cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
    else:
        strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = get_compiled_model()

keras.utils.plot_model(model, to_file=os.path.join(args.log_dir, "Architecture.png"), show_shapes=True, expand_nested=True)
model.summary()

# Model Training

In [12]:
# One weights-only checkpoint per output head, each kept at its best
# `val_<head>_accuracy` and written to <log_dir>/epoch-best-<head>.weights.h5.
checkpoint = [
    keras.callbacks.ModelCheckpoint(
        os.path.join(args.log_dir, f'epoch-best-{head}.weights.h5'),
        monitor=f'val_{head}_accuracy',
        save_best_only=True,
        save_weights_only=True,
        verbose=1,
    )
    for head in model.output_names
]
CallBacks = [*CallBacks, *checkpoint]

if not args.NoEarlyStop:
    # Early stopping watches only the last output head's validation accuracy.
    early_stop = keras.callbacks.EarlyStopping(
        monitor=f'val_{model.output_names[-1]}_accuracy',
        patience=args.early_stop_tolerance,
        mode='max',
        restore_best_weights=True,
    )
    CallBacks = [*CallBacks, early_stop]

## Train/Test Model

In [None]:
# Checkpoint written by the ModelCheckpoint callback for the last output head;
# it is loaded both in test-only mode and after training.
best_weights_path = os.path.join(args.log_dir, f"epoch-best-{model.output_names[-1]}.weights.h5")

if args.Test_only:
    try:
        model.load_weights(best_weights_path)
    except Exception as err:
        # Keep raising ValueError as before, but chain the real cause instead of
        # hiding it (the former bare `except:` also caught KeyboardInterrupt).
        raise ValueError('Model Weights not found') from err
    print(f'Model Weights Loaded Successfully from {best_weights_path}')
else:
    print('Training the model from scratch | Will overwrite the files in the log directory')
    history = model.fit(ds_train,
                        epochs = args.epochs,
                        validation_data = ds_val,
                        callbacks = CallBacks,
                        verbose=1)

    print('Training Completed....')
    print(f'loading the best weights from {best_weights_path}')
    model.load_weights(best_weights_path)

    pd.DataFrame(history.history).to_csv(os.path.join(args.log_dir,'training_history.csv'), index=False) # Saving training history

    plotter = tfdocs.plots.HistoryPlotter()
    # Only training-side keys are selected here; 'val_*' keys are excluded.
    accuracy_metrics = [metric for metric in history.history if not metric.startswith('val') and metric.endswith('_accuracy')]
    loss_metrics = [metric for metric in history.history if not metric.startswith('val') and metric.endswith('_loss')]
    for metric in accuracy_metrics:
        plotter.plot({metric.split('_accuracy')[0].capitalize(): history}, metric=metric)

    # Accuracy is bounded by [0, 1], so pin the y-axis to that range.
    plt.title("Model Accuracy")
    plt.ylim([0, 1])
    plt.savefig(os.path.join(args.log_dir, 'Model_Accuracy.png'))
    plt.close()

    for metric in loss_metrics:
        plotter.plot({metric.split('_loss')[0].capitalize(): history}, metric=metric)

    # Losses (e.g. categorical cross-entropy) can exceed 1; fix only the lower
    # bound so the curves are not clipped.
    plt.title("Model Loss")
    plt.ylim(bottom=0)
    plt.savefig(os.path.join(args.log_dir, 'Model_Loss.png'))
    plt.close()

## Model Analysis

In [None]:
# Evaluate on the test split. return_dict=True keys every value by its metric
# name. Indexing `model.metrics_names` by position is fragile for multi-output
# models: in Keras 3 it may not list every entry of the flat result list, which
# mislabels values or raises IndexError.
results_by_name = model.evaluate(ds_test, verbose=1, return_dict=True)
results = list(results_by_name.values())  # keep `results` as a flat list of values
print('\n'.join(f"{n}. {name} ==> {value}" for n, (name, value) in enumerate(results_by_name.items(), start=1)))

In [17]:
# Run the model over the test pipeline and collect labels and predictions.
# With return_images=False, x_data presumably carries no images — TODO confirm in src.
x_data, y_true, y_pred = models.predict_from_pipeline(
    model,
    ds_test,
    return_images=False,
)

In [None]:
# Level-wise metrics are saved into the log directory; nothing is shown inline.
metrics.lvl_wise_metric(
    y_true=y_true,
    y_pred=y_pred,
    savedir=args.log_dir,
    show_graph=False,
    show_report=False,
)

In [None]:
h_measurements, consistency, exact_match, get_performance_report = metrics.hmeasurements(y_true, y_pred, dataset.label_tree)

# Prepend run identification (dataset, model name, parameter counts) to the
# hierarchical report; keys already present in the report take precedence.
get_performance_report = {
    'Dataset': dataset.name,
    'Model': model.name,
    'Total Parameters': model.count_params(),
    'Total Trainable Parameters': sum(tf.keras.backend.count_params(w) for w in model.trainable_weights),
    **get_performance_report,
}

# One row per report entry, a single 'Value' column; saved to CSV and printed.
performance_metrics = pd.DataFrame(
    list(get_performance_report.values()),
    index=list(get_performance_report.keys()),
    columns=['Value'],
)
performance_metrics.to_csv(os.path.join(args.log_dir, 'performance_metrics.csv'))
print(performance_metrics)

In [None]:
# Persist the final run configuration (all parsed args plus log_dir) as JSON.
# json.dump escapes non-ASCII by default, so the explicit encoding is only for clarity.
with open(os.path.join(args.log_dir, 'args.json'), 'w', encoding='utf-8') as fid:
    json.dump(vars(args), fid, indent=2)

# END