# Gesture Recognition with CAPG DB-a Dataset Using 3D CNN with EMGNet Architecture (one subject for testing)

In this preliminary effort, we perform hand gesture recognition on the CAPG DB-a dataset.
We use the EMGNet architecture and training procedure, but instead of the continuous wavelet transform (CWT), we apply a 3D CNN to sequences of 2D images (the 128 EMG channels of every time step are arranged as an 8 x 16 grid).

In this version:

- EMG data is normalized with the recorded MVC data
- The **EMGNet** architecture will be used, along with the training procedure.
- A **3D CNN** architecture will be adopted into the EMGNet architecture.
- **Raw EMG data** will be used; apart from the MVC normalization above, there will be no preprocessing or feature engineering.
- **Training data:** 17 subjects
- **Test data:** 1 subject
- K-fold cross-validation will be performed (this notebook runs a single fold, holding out subject 18 for testing).

**NOTE** This code has been tested with:
```
    numpy version:        1.23.5
    scipy version:        1.9.3
    sklearn version:      1.2.0
    seaborn version:      0.12.1
    pandas version:       1.5.2
    torch version:        1.12.1+cu113
    matplotlib version:   3.6.2
    CUDA version:         11.2
```

## 1- Preliminaries

### Imports

In [None]:
import sys, os
# Detect whether the notebook runs on the KUACC cluster (recognized by a
# "scratch" component in the working directory) and, if so, move into the
# project folder so the relative paths used throughout the notebook resolve.
direc = os.getcwd()
print("Current Working Directory is: ", direc)
KUACC = "scratch" in direc
if KUACC:
    homedir = os.path.expanduser("~")
    os.chdir(os.path.join(homedir, "comp541-project/capg_3dcnn/"))
    direc = os.getcwd()
    print("Current Working Directory is now: ", direc)
# Make the project-local helper modules importable (paths relative to the cwd).
for _extra_path in ("../src/", "../data/"):
    sys.path.append(_extra_path)
import torch
import torch.nn as nn
from datasets_torch import *
from models_torch import *
from utils_torch import *
from datetime import datetime
import pandas as pd
import numpy as np
import scipy as sp
import sklearn
import seaborn as sns
from sklearn.preprocessing import OneHotEncoder, StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, classification_report, confusion_matrix, accuracy_score, f1_score
import matplotlib
import matplotlib.pyplot as plt
from copy import deepcopy
import statistics
import json
from IPython.display import display
#from cwt import calculate_wavelet_vector, calculate_wavelet_dataset

# Report the versions of the main libraries so they can be compared with the
# tested configuration listed at the top of the notebook.
for _label, _module in (
    ("numpy version:       ", np),
    ("scipy version:       ", sp),
    ("sklearn version:     ", sklearn),
    ("seaborn version:     ", sns),
    ("pandas version:      ", pd),
    ("torch version:       ", torch),
    ("matplotlib version:  ", matplotlib),
):
    print(_label, _module.__version__)


# Checking to see if CUDA is available for us
print("Checking to see if PyTorch recognizes GPU...")
print(torch.cuda.is_available())

# LaTeX rendering is only enabled on the cluster (KUACC); presumably a TeX
# installation is available there.
USE_TEX = KUACC
FONT_SIZE = 12

# Both configurations share everything except the serif font.
plt.rcParams.update({
    "text.usetex": USE_TEX,
    "font.size": FONT_SIZE,
    "font.family": "serif",
    "font.serif": "Computer Modern" if USE_TEX else "Times New Roman",
})

# Do not plot figures inline (only useful for cluster)
# %matplotlib

## 2- Hyperparameters and Settings

### General settings of the study

In [None]:
# Study log: static metadata describing this run, plus per-fold lists that the
# training cell appends results to (it is later dumped to JSON).
k_fold_study = dict(
    code="capg_3dcnn/capg_dba_v004",
    package="torch",
    dataset="capg",
    subdataset="dba",
    # Final-epoch / test accuracies, one entry per fold
    training_accuracies=[],
    validation_accuracies=[],
    testset_accuracies=[],
    # Full training histories, one entry per fold
    history_training_loss=[],
    history_training_metrics=[],
    history_validation_loss=[],
    history_validation_metrics=[],
    preprocessing=None,
    feature_engineering=None,
    k_fold_mode="1 subject for testing",
    # Step used to subsample the generated train/val/test windows ([::ds])
    global_downsampling=10,
)

In [None]:
# Hyper-parameters consumed by the project's model/training helpers
# (PyTorchSmartModule / Conv_Network / generate_cell_array). Several entries
# are overwritten later in the training cell (see notes below).
hparams = {
    # Run name; results and models are stored under this name. autoname is a
    # project helper — presumably it appends a timestamp (TODO confirm).
    "model_name": autoname("capg_3dcnn_dba_v004"),
    # General hyperparameters
    # Presumably the 8 x 16 = 128 EMG channels (the training cell reshapes the
    # channels into an 8 x 16 grid).
    "in_features": 128,
    # Kept at 1 while the data cell is generated; the training cell sets it to
    # 8 (number of gesture classes) right after generate_cell_array.
    "out_features": 1,
    # Placeholders: the training cell overwrites input_shape from the actual
    # data shape and sets output_shape to [8].
    "input_shape":[1,64,8,16],
    "output_shape":[8],
    # Sequence hyperparameters
    # NOTE(review): 0.16 s at 1000 Hz is 160 raw samples; with
    # data_downsampling=5 that would presumably give 32 time steps. The actual
    # window length is decided by generate_cell_array — TODO confirm.
    "in_seq_len_sec": 0.16,
    "out_seq_len_sec": 0,
    "data_sampling_rate_Hz": 1000.0,
    "data_downsampling": 5,
    "sequence_downsampling": 1,
    "in_seq_len": 0,
    "out_seq_len": 0,
    # Convolution blocks
    "num_conv_blocks": 4,
    # 3D convolutions over (time, grid rows, grid columns)
    "conv_dim": 3,
    "conv_params": None,
    # Presumably one entry per conv block (4 blocks)
    "conv_channels": [16, 32, 32, 64],
    "conv_kernel_size": 3,
    "conv_padding": "same",
    "conv_stride": 1,
    "conv_dilation": 1,
    "conv_activation": "ReLU",
    "conv_activation_params": None,#{"negative_slope": 0.1},
    "conv_norm_layer_type": "BatchNorm",
    "conv_norm_layer_position": "before",
    # NOTE(review): in PyTorch BatchNorm, `momentum` is the weight of the NEW
    # batch statistic (default 0.1). 0.99 looks like a Keras-style value
    # (Keras 0.99 ~ PyTorch 0.01) and would make the running statistics
    # track almost only the last batch. Confirm how Conv_Network uses this.
    "conv_norm_layer_params": {'momentum':0.99, 'eps':1.0e-8},
    "conv_dropout": None,
    # Presumably one pooling type per conv block: no pooling in the first three
    # blocks and an adaptive average pool after the last one.
    "pool_type": [None, None, None, "AdaptiveAvg"],
    "pool_kernel_size": 2,
    "pool_padding": 0,
    "pool_stride": 1,
    "pool_dilation": 1,
    "pool_params": None,
    "min_image_size": 1,
    "adaptive_pool_output_size": [1,1,1],
    # Fully connected blocks
    "dense_width": "auto",
    "dense_depth": 0,
    "dense_activation": "ReLU",
    "dense_activation_params": None,
    # No output activation: CrossEntropyLoss expects raw logits.
    "output_activation": None,
    "output_activation_params": None,
    "dense_norm_layer_type": None,
    "dense_norm_layer_position": None,
    "dense_norm_layer_params": {'momentum':0.99, 'eps':1.0e-8},
    "dense_dropout": None,
    # Training procedure
    "l2_reg": 0.0001,
    "batch_size": 512,
    "epochs": 60,
    # NOTE(review): the training cell passes an explicit validation set to
    # model.train_model; whether this entry is still used depends on
    # PyTorchSmartModule (not visible here).
    "validation_data": [0.05,'trainset'],
    # Presumably an early-stopping patience; 1000 > epochs effectively
    # disables it (TODO confirm).
    "validation_tolerance_epochs": 1000,
    "learning_rate": 0.01,
    "learning_rate_decay_gamma": 0.9,
    "loss_function": "CrossEntropyLoss",
    "optimizer": "Adam",
    "optimizer_params": None
}

## 3- Data Processing

### Load and concatenate data

In [None]:
data_dir = "../data/CAPG/parquet"
def load_single_capg_dataset(data_dir, db_str:str="dba"):
    """Load and concatenate every CAPG parquet file of one sub-database.

    Args:
        data_dir: Directory containing the ``.parquet`` files.
        db_str: Substring identifying the sub-database in the file names
            (e.g. ``"dba"``); only ``.parquet`` files whose name contains it
            are loaded.

    Returns:
        One DataFrame holding the rows of all matching files, re-indexed from 0.
        Files are read in sorted name order so the row order is reproducible
        (``os.listdir`` returns entries in arbitrary order).

    Raises:
        ValueError: If ``data_dir`` contains no matching parquet file.
    """
    file_names = sorted(
        name for name in os.listdir(data_dir)
        if name.endswith(".parquet") and db_str in name
    )
    if not file_names:
        raise ValueError(
            "No '.parquet' files containing %r found in %r" % (db_str, data_dir))
    data_lst = []
    for file in file_names:
        print("Loading file: ", file)
        data_lst.append(pd.read_parquet(os.path.join(data_dir, file)))
    return pd.concat(data_lst, axis=0, ignore_index=True)
dba_tot = load_single_capg_dataset(data_dir, db_str="dba")
# Gesture ids 100 and 101 hold the MVC recordings used for normalization;
# all other rows are actual gesture recordings.
is_mvc = dba_tot["gesture"].isin([100, 101])
dba_mvc = dba_tot.loc[is_mvc]
# .copy() makes `dba` an independent frame. The normalization cell assigns
# into it in place (`dba.iloc[:, 3:] = ...`); on a plain boolean-.loc slice of
# `dba_tot` that is chained assignment and triggers SettingWithCopyWarning.
dba = dba_tot.loc[~is_mvc].copy()
print("dba_tot shape: ", dba_tot.shape)
print("dba_mvc shape: ", dba_mvc.shape)
print("dba shape: ", dba.shape)
print("Columns: ")
print(dba_tot.columns)
print("Description: ")
print(dba.iloc[:,:3].describe())

### Normalize EMG Data

Here the recorded MVC (maximum voluntary contraction) values are used to normalize the EMG data: each channel is divided by its peak value during the MVC recordings.

In [None]:
# Per-channel peak of the MVC recordings (columns 3 onwards are the EMG
# channels; the training cell uses the same columns as model inputs).
max_mvc = dba_mvc.iloc[:,3:].max(axis=0)
del dba_mvc
# Dividing by a zero, negative or missing (NaN) peak would silently fill the
# data with inf/NaN values, so fail loudly instead.
if not (max_mvc > 0).all():
    bad_channels = max_mvc[~(max_mvc > 0)].index.tolist()
    raise ValueError("Non-positive or missing MVC peak for channels: %s" % bad_channels)
dba.iloc[:,3:] = dba.iloc[:,3:].div(max_mvc, axis=1)

## 4- Pre-Training

### EMGNet model

In [None]:

class EMGNet(PyTorchSmartModule):
    """EMGNet-style classifier.

    The network has an input stage (``BatchNorm3d`` over the single input
    channel, then ``ReLU``) followed by the configurable convolutional network
    ``Conv_Network(hparams)``. ``hparams`` is forwarded unchanged to both the
    ``PyTorchSmartModule`` base class and ``Conv_Network``.
    """

    def __init__(self, hparams):
        super().__init__(hparams)
        # The input is a single-channel 3D volume, hence BatchNorm3d(1).
        self.prep_block = nn.Sequential(nn.BatchNorm3d(1), nn.ReLU())
        self.main_block = Conv_Network(hparams)

    def forward(self, x):
        return self.main_block(self.prep_block(x))

### Pre-train on 17 subjects (subject 18 is held out for testing)

In [None]:

# EMG channel columns: every column after the first three (presumably the
# subject / gesture / trial metadata columns used below — TODO confirm).
input_cols = list(dba.iloc[:,3:].columns)

# Single fold: subject k is held out for testing; the other subjects are used
# for training/validation. NOTE(review): `num_subjects` is computed but not
# used in this cell (presumably left over from a loop over all subjects).

k = 18
num_subjects = dba['subject'].nunique()

# Keep only every ds-th window of each split (see x_train[::ds] etc. below).
ds = k_fold_study['global_downsampling']


print("\n#################################################################")
print("Using subject %d for testing ..." % (k))
print("#################################################################\n")
subj_for_testing = [k]

# generate_cell_array is called with out_features = 1 (one label column,
# "gesture"); the value is switched to the 8 classes right after the call.
# NOTE(review): the original author flagged this toggling as a workaround
# for buggy behaviour that should be fixed.
hparams['out_features'] = 1

# Build the windowed train / validation / test arrays; subject k goes to the
# test split.
print("Generating data cell ...")
data_processed = generate_cell_array(
    dba, hparams,
    subjects_column="subject", conditions_column="gesture", trials_column="trial",
    input_cols=input_cols, output_cols=["gesture"], specific_conditions=None,
    input_preprocessor=None,
    output_preprocessor=None,
    # Reshape N x L x C windows (C = 8*16 channels) to N x C' x D x H x W with
    # C'=1, D=L, H=8, W=16. This is a plain reshape (no transpose), so it
    # assumes the windows arrive as N x L x C — TODO confirm.
    input_postprocessor=lambda arr: arr.reshape(arr.shape[0], 1, arr.shape[1], 8, 16),
    output_postprocessor = lambda arr:(arr-1).squeeze(), # torch CrossEntropyLoss needs (N,) array of 0-indexed class labels
    subjects_for_testing=subj_for_testing,
    trials_for_testing=None,
    input_scaling=False, output_scaling=False, input_forward_facing=True, output_forward_facing=True,
    data_squeezed=False,
    input_towards_future=False, output_towards_future=False,
    output_include_current_timestep=True,
    use_filtered_data=False, #lpcutoff=CUTOFF, lporder=FILT_ORDER, lpsamplfreq=SAMPL_FREQ,
    return_data_arrays_orig=False,
    return_data_arrays_processed=False,
    return_train_val_test_arrays=False,
    return_train_val_test_data=True,
    verbosity=1
)

# Restore the real number of output classes (8 gestures) for the model.
hparams['out_features'] = 8

print("Extracting downsampled input and output data from the datacell ...")
# Inputs MUST have correct shape (N x 1 x D x 8 x 16 after the reshape above)
x_train = data_processed["x_train"][::ds]
x_val = data_processed["x_val"][::ds]
x_test = data_processed["x_test"][::ds]
# Outputs MUST be zero-indexed class labels
y_train = data_processed["y_train"][::ds]
y_val = data_processed["y_val"][::ds]
y_test = data_processed["y_test"][::ds]
print("x_train shape: ", x_train.shape)
print("x_val shape: ", x_val.shape)
print("x_test shape: ", x_test.shape)
print("y_train shape: ", y_train.shape)
print("y_val shape: ", y_val.shape)
print("y_test shape: ", y_test.shape)
# Free the full (non-downsampled) arrays.
del data_processed
# Make datasets from training, validation and test sets
print("Generating the TensorDataset objects ...")
train_set = TensorDataset(torch.from_numpy(x_train).float(), torch.from_numpy(y_train).long())
val_set = TensorDataset(torch.from_numpy(x_val).float(), torch.from_numpy(y_val).long())
test_set = TensorDataset(torch.from_numpy(x_test).float(), torch.from_numpy(y_test).long())

# Store the hyperparameters in the study log. This stores a reference to the
# same dict, so the shape updates made just below also end up in the log.
k_fold_study['hparams'] = hparams

# Construct model; input/output shapes are taken from the actual data.
print("Constructing the model ...")
hparams['input_shape'] = list(x_train.shape[1:])
hparams['output_shape'] = [8]
print("Model input shape: ", hparams['input_shape'])
print("Model output shape: ", hparams['output_shape'])
model = EMGNet(hparams)
print(model)

# Train model (the commented call below is an older functional API, kept for reference)
print("Training the model ...")
# history = train_pytorch_model(
#     model, [train_set, val_set], batch_size=1024, loss_str='crossentropy', optimizer_str='adam',
#     optimizer_params={'weight_decay':0.0001}, loss_function_params=None, learnrate=0.1,
#     learnrate_decay_gamma=0.95, epochs=200, validation_patience=1000000,
#     verbose=1, script_before_save=True, saveto=None, num_workers=0)
history = model.train_model([train_set, val_set], verbose=1)

# Log the training histories and the last entries of the metric histories
# (final training/validation accuracy).
print("Updating the dictinoary for logging ...")
k_fold_study['history_training_loss'].append(history["training_loss"])
k_fold_study["history_validation_loss"].append(history["validation_loss"])
k_fold_study["history_training_metrics"].append(history["training_metrics"])
k_fold_study["history_validation_metrics"].append(history["validation_metrics"])
k_fold_study["training_accuracies"].append(history["training_metrics"][-1])
k_fold_study["validation_accuracies"].append(history["validation_metrics"][-1])

# Evaluate the pre-trained model on the held-out subject, then save it.
print("Evaluating the model on the test set ...")
# results = evaluate_pytorch_model(model, test_set, loss_str='crossentropy', loss_function_params=None,
# batch_size=1024, device_str="cuda", verbose=True, num_workers=0)
results = model.evaluate_model(test_set, verbose=True)
print("test set accuracy before TL: ")
print(results["metrics"])
k_fold_study["testset_accuracies"].append(results["metrics"])
torch.save(model, make_path("../models/"+hparams['model_name']+"/capg_dba_v004_PRE_MODEL.pt"))
print("Done with the pre-training.")

### Saving the k-fold study log (JSON)

In [None]:
# Persist the study log (hyper-parameters, training histories, accuracies).
# The context manager guarantees the file is flushed and closed, even if
# serialization fails half-way; the original left the handle open.
print("Dumping the JSON file ...")
with open(make_path("../results/"+hparams['model_name']+"/k_fold_study.json"), "w") as json_file:
    json.dump(k_fold_study, json_file, indent=4)
print("Saved the JSON file.")

### Saving general statistics

In [None]:
# Summarize the per-fold accuracies (count, mean, std, quartiles, ...) and
# save the summary table as CSV next to the other results of this run.
print("Saving the general statistics ...")
trn_acc_arr, val_acc_arr, tst_acc_arr = (
    np.array(k_fold_study[key])
    for key in ("training_accuracies", "validation_accuracies", "testset_accuracies")
)
general_dict = dict(
    training_accuracy=trn_acc_arr,
    validation_accuracy=val_acc_arr,
    testset_accuracy=tst_acc_arr,
)
general_results = pd.DataFrame(general_dict)
print("Description of general results:")
general_results_describe = general_results.describe()
display(general_results_describe)
results_csv_path = make_path("../results/" + hparams['model_name'] + "/general_results.csv")
general_results_describe.to_csv(results_csv_path, header=True, index=True)
print("Saved general statistics.")

### Plotting training histories

In [None]:
# import numpy as np
# import json
# import pandas as pd

In [None]:
# k_fold_study = json.load(open("../results/capg_replica_dba_v002_2023_01_07_20_07_25/k_fold_study.json", "r"))

In [None]:
# Plot per-epoch mean +/- one standard deviation (across folds) of the loss
# (top panel) and accuracy (bottom panel) for training and validation.
print("Plotting the taining curve ...")
train_loss = np.array(k_fold_study["history_training_loss"])
val_loss = np.array(k_fold_study["history_validation_loss"])
train_acc = np.array(k_fold_study["history_training_metrics"])
val_acc = np.array(k_fold_study["history_validation_metrics"])

print("Shape of train_loss: ", train_loss.shape)


def _mean_std_over_folds(curves):
    """Return the per-epoch mean and standard deviation across folds (axis 0)."""
    return np.mean(curves, axis=0), np.std(curves, axis=0)


train_loss_mean, train_loss_std = _mean_std_over_folds(train_loss)
val_loss_mean, val_loss_std = _mean_std_over_folds(val_loss)
train_acc_mean, train_acc_std = _mean_std_over_folds(train_acc)
val_acc_mean, val_acc_std = _mean_std_over_folds(val_acc)

print("Shape of train_loss_mean: ", train_loss_mean.shape)
print("Shape of train_loss_std: ", train_loss_std.shape)


def _plot_band(x, mean, std, color, **line_kwargs):
    """Plot a mean curve and shade the mean +/- std band in the same color."""
    plt.plot(x, mean, color=color, **line_kwargs)
    plt.fill_between(x, mean - std, mean + std, color=color, alpha=0.2)


epochs = np.arange(1, train_loss_mean.shape[0] + 1)
plt.figure(figsize=(8,8), dpi=100)

# Top panel: loss curves (the only panel with a legend)
plt.subplot(2,1,1)
plt.grid(True)
_plot_band(epochs, train_loss_mean, train_loss_std, "blue", label="Training")
_plot_band(epochs, val_loss_mean, val_loss_std, "orange", label="Validation")
plt.ylabel("Loss")
plt.legend(loc="upper right")

# Bottom panel: accuracy curves (same colors as above)
plt.subplot(2,1,2)
plt.grid(True)
_plot_band(epochs, train_acc_mean, train_acc_std, "blue")
_plot_band(epochs, val_acc_mean, val_acc_std, "orange")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.subplots_adjust(hspace=0.2)
plt.savefig(make_path("../results/"+k_fold_study['hparams']['model_name']+"/training_history.png"), dpi=300)

print("Done plotting the training curve.")
print("ALL DONE. GOOD BYE!")

### Recording the BatchNorm running statistics before TL

In [None]:
# Snapshot the BatchNorm running statistics of the pre-trained model before the
# TL step resets them. The explicit copy is essential: for a model on the CPU,
# `.cpu()` returns the same tensor and `.numpy()` shares its memory. Without a
# copy, `reset_running_stats()` (which zeroes / fills these buffers in place)
# would silently overwrite the "saved" values as well.
bn_paras_mean = []
bn_paras_var = []
for layer in model.modules():
    if isinstance(layer, (torch.nn.modules.batchnorm.BatchNorm3d, torch.nn.modules.batchnorm.BatchNorm1d)):
        bn_paras_mean.append(layer.running_mean.detach().cpu().numpy().copy())
        bn_paras_var.append(layer.running_var.detach().cpu().numpy().copy())

# Here is the TL part

## Resetting the BatchNorm running statistics (freezing of the non-BatchNorm layers is currently commented out)

In [None]:
# model_2 = EMGNet(hparams)
# x = torch.ones((32,1,64,8,16),dtype=torch.float32, requires_grad=False)*100
# y = model_2(x)
# # model_2.train()
# # model_2.cpu()

# for module in model_2.modules():
#     if isinstance(module,(nn.BatchNorm3d,nn.BatchNorm1d)):
#         print("---------------------------------------------------\n")
#         print("---------------------------------------------------\n")
#         print(module.running_mean)
#         print("---------------------------------------------------\n")
#         print("---------------------------------------------------\n")
#         # y = model_2(torch.ones((32,1,64,8,16),dtype=torch.float32))
#         # print(module.running_mean)
#         break

# for module in model_2.modules():
#     #print("Next module \n")
    
#     if not isinstance(module, (torch.nn.modules.batchnorm.BatchNorm3d,torch.nn.modules.batchnorm.BatchNorm1d)):
#         for param in module.parameters():
#             param.requires_grad_(False)
#         #module.eval()
#     else:
#         module.reset_running_stats()
#         pass
    
    
# x = torch.ones((32,1,64,8,16),dtype=torch.float32,requires_grad=False)*10000
# y = model_2(x)
# # model_2.train()
# # model_2.cpu()
# for module in model_2.modules():
#     if isinstance(module,(nn.BatchNorm3d,nn.BatchNorm1d)):
#         print("---------------------------------------------------\n")
#         print("---------------------------------------------------\n")
#         print(module.running_mean)
#         print("---------------------------------------------------\n")
#         print("---------------------------------------------------\n")
#         # y = model_2(torch.ones((32,1,64,8,16),dtype=torch.float32))
#         # print(module.running_mean)
#         break



In [None]:
# Switch to training mode and clear the running statistics of every
# BatchNorm3d/BatchNorm1d layer, presumably so they are re-estimated on the
# new subject's data during TL. Non-BatchNorm layers are not frozen here.
model.train()
bn_types = (torch.nn.modules.batchnorm.BatchNorm3d, torch.nn.modules.batchnorm.BatchNorm1d)
for bn_layer in (m for m in model.modules() if isinstance(m, bn_types)):
    bn_layer.reset_running_stats()
        

# for module in model.modules():
#     if isinstance(module,(nn.BatchNorm3d,nn.BatchNorm1d)):
#         print("---------------------------------------------------\n")
#         print("---------------------------------------------------\n")
#         print(module.running_mean)
#         print("---------------------------------------------------\n")
#         print("---------------------------------------------------\n")
#         y = module(torch.ones((32,1,64,8,16),dtype=torch.float32))
#         print(module.running_mean)
#         break


### Split the test subject's data into TL training, validation, and test sets

In [None]:
# Split the held-out subject's (downsampled) test data for transfer learning:
#   - first half  -> TL pool: its first 10% is TL validation, the rest TL training
#   - second half -> TL test
# Contiguous slices are used (no shuffling), which keeps overlapping windows
# mostly on one side of each split. NOTE(review): if the samples are ordered
# by gesture, contiguous halves will not cover the same classes — confirm the
# ordering produced by generate_cell_array.
half = len(x_test) // 2
x_tl_pool, y_tl_pool = x_test[:half], y_test[:half]
n_val_TL = len(x_tl_pool) // 10

# Validation is carved out of the TL pool so it is disjoint from the TL
# training windows; validating on samples the model trains on gives an
# optimistic, uninformative validation signal.
x_val_TL = x_tl_pool[:n_val_TL]
y_val_TL = y_tl_pool[:n_val_TL]
x_train_TL = x_tl_pool[n_val_TL:]
y_train_TL = y_tl_pool[n_val_TL:]

x_test_TL = x_test[half:]
y_test_TL = y_test[half:]

train_set_TL = TensorDataset(torch.from_numpy(x_train_TL).float(), torch.from_numpy(y_train_TL).long())
val_set_TL = TensorDataset(torch.from_numpy(x_val_TL).float(), torch.from_numpy(y_val_TL).long())
test_set_TL = TensorDataset(torch.from_numpy(x_test_TL).float(), torch.from_numpy(y_test_TL).long())


### Necessary Functions for TL

In [None]:
def _update_metrics_for_batch(
    predictions:torch.Tensor, targets:torch.Tensor, loss_str:str, classification:bool, regression:bool, 
    verbose:int, batch_num:int, epoch:int, metric:float, num_logits:int):
    if loss_str == "BCELoss":
        # Output layer already includes sigmoid.
        class_predictions = (predictions > 0.5).int()
    elif loss_str == "BCEWithLogitsLoss":
        # Output layer does not include sigmoid. Sigmoid is a part of the loss function.
        class_predictions = (torch.sigmoid(predictions) > 0.5).int()
    elif loss_str in ["NLLLoss", "CrossEntropyLoss"]:
        # nll -> Output layer already includes log_softmax.
        # crossentropy -> Output layer has no log_softmax. It's implemented as a part of the loss function.
        class_predictions = torch.argmax(predictions, dim=1)
        if predictions.shape == targets.shape: # Targets are one-hot encoded probabilities
            target_predictions = torch.argmax(targets, dim=1)
        else: # Targets are class indices
            target_predictions = targets

    if classification:
        if verbose>=2 and batch_num==0 and epoch ==0: 
            print("Shape of model outputs:     ", predictions.shape)
            print("Shape of class predictions: ", class_predictions.shape)
            print("Shape of targets:           ", targets.shape)
        # Calculate accuracy
        correct = (class_predictions == target_predictions).int().sum().item()
        num_logits += target_predictions.numel()
        metric += correct
        if verbose==3 and epoch==0: 
            print("Number of correct answers (this batch - total): %10d - %10d"%(correct, metric))
        # Calculate F1 score
        # f1 = f1_score(targets.cpu().numpy(), class_predictions.cpu().numpy(), average="macro")
    elif regression:
        if verbose==3 and batch_num==0 and epoch==0: 
            print("Shape of predictions: ", predictions.shape)
            print("Shape of targets:     ", targets.shape)
        # Calculate r2_score
        metric += r2_score(targets.cpu().numpy(), predictions.cpu().numpy())
    
    return metric, num_logits



def _test_shapes(predictions:torch.Tensor, targets:torch.Tensor, classification:bool):
    if classification:
        assert predictions.shape[0] == targets.shape[0], "Batch size of targets and predictions must be the same.\n"+\
            "Target shape: %s, Prediction shape: %s\n"%(str(targets.shape), str(predictions.shape))
        if len(predictions.shape) == 1:
            assert targets.shape == predictions.shape, "For 1D predictions, the targets must also be 1D.\n"+\
                "Predictions shape: %s, Targets shape: %s\n"%(str(predictions.shape), str(targets.shape))
        if len(predictions.shape) == 2:
            assert len(targets.shape)==1 or targets.shape == predictions.shape, \
                "For 2D predictions, the targets must be 1D class indices are 2D [N x K] one-hot encoded array, with the same shape as the predictions.\n"+\
                "Predictions shape: %s, Targets shape: %s\n"%(str(predictions.shape), str(targets.shape))
        if len(predictions.shape) > 2:
            assert len(predictions.shape)==len(targets.shape) or len(predictions.shape)==len(targets.shape)+1, \
                "Target dimensionality must be equal to or one less than the prediction dimensionality.\n"+\
                "Target shape: %s, Prediction shape: %s\n"%(str(targets.shape), str(predictions.shape))+\
                "If targets are class indices, they must be of shape (N,), or (N, d1, ..., dm). "+\
                "Otherwise, they must be (N, K) or (N, K, d1, ..., dm) arrays of one-hot encoded probabilities. "+\
                "Predictions must in any case be (N, K) or (N, K, d1, ..., dm).\n"+\
                "N is batch size, K is number of classes and d1 to dm are other dimensionalities of classification, if any."
            if len(predictions.shape) == len(targets.shape):
                assert predictions.shape == targets.shape, "If predictions and targets have the same dimensionality, they must be the same shape.\n"+\
                    "Target shape: %s, Prediction shape: %s\n"%(str(targets.shape), str(predictions.shape))
            else:
                assert predictions.shape[2:] == targets.shape[1:], \
                    "If predictions have shape (N, K, d1, ..., dm) then targets must either have the same shape, or (N, d1, ..., dm).\n"+\
                    "Target shape: %s, Prediction shape: %s\n"%(str(targets.shape), str(predictions.shape))
    else:
        assert predictions.shape == targets.shape, \
            "Target shape must be equal to the prediction shape.\n"+\
            "Target shape: %s, Prediction shape: %s\n"%(str(targets.shape), str(predictions.shape))




def _calculate_epoch_loss_and_metrics(
    cumulative_epoch_loss:float, num_batches:int, verbose:int, epoch:int, 
    hist_loss:dict, hist_metric:dict, display_metrics:bool, cumulative_metric:float, metric_denominator:int):
    # Calculate training epoch loss
    epoch_loss = cumulative_epoch_loss / num_batches
    if verbose==3 and epoch==0: print("Epoch loss (training): %.5f"%epoch_loss)
    if hist_loss is not None: hist_loss.append(epoch_loss)
    # Calculate training epoch metric (accuracy or r2-score)
    if display_metrics:
        epoch_metric = cumulative_metric / metric_denominator
        if verbose==3 and epoch==0: print("Epoch metric: %.5f"%epoch_metric)
        if hist_metric is not None: hist_metric.append(epoch_metric)
    return epoch_loss, epoch_metric, hist_loss, hist_metric



def save_pytorch_model(model:torch.nn.Module, saveto:str, dataloader, script_before_save:bool=True, verbose:int=1):
    """Serialize ``model`` to ``saveto`` on a best-effort basis.

    With ``script_before_save`` the model is traced with ``torch.jit.trace``
    on the first sample of the first batch of ``dataloader`` and saved as
    TorchScript. Otherwise the whole module is saved with ``torch.save``.

    Tracing runs a real forward pass, so it is done on the CPU and in eval
    mode. In training mode, BatchNorm layers would update their running
    statistics with that single example, and BatchNorm1d rejects a batch of
    one. Eval mode also makes the traced graph use the inference behaviour.
    The model's original train/eval mode and device are restored afterwards.

    Args:
        model: Model to save.
        saveto: Destination file path.
        dataloader: Yields ``(inputs, targets)`` batches; only used for tracing.
        script_before_save: Save as TorchScript (True) or with ``torch.save``.
        verbose: Print progress and errors when > 0.

    Returns:
        None. Any exception is printed (when ``verbose > 0``) and swallowed, so
        a failed save does not abort the caller.
    """
    was_training = model.training
    first_param = next(model.parameters(), None)
    original_device = first_param.device if first_param is not None else None
    try:
        if verbose > 0: print("Saving model...")
        if script_before_save:
            example,_ = next(iter(dataloader))
            example = example[0,:].unsqueeze(0)
            model.cpu()
            model.eval()
            with torch.no_grad():
                traced = torch.jit.trace(model, example)
            traced.save(saveto)
        else:
            torch.save(model, saveto)
    except Exception as e:
        # Deliberately best-effort: report the failure instead of raising.
        if verbose > 0:
            print(e)
            print("Failed to save the model.")
    else:
        if verbose > 0: print("Done Saving.")
    finally:
        model.train(was_training)
        if original_device is not None:
            model.to(original_device)
    


def train_pytorch_model_TL(model, dataset, batch_size:int, loss_str:str, optimizer_str:str, optimizer_params:dict=None, loss_function_params:dict=None, learnrate:float=0.001, 
    learnrate_decay_gamma:float=None, epochs:int=10, validation_patience:int=10000, validation_data:float=0.1, verbose:int=1, script_before_save:bool=True, saveto:str=None, 
    num_workers=0):
    """Train a PyTorch model for transfer learning (TL), given some hyperparameters.

    This is a training-only variant. No validation pass, no early stopping and no model saving
    are performed. The validation set is still split off and wrapped in a DataLoader. Its size
    and batch shapes are reported when ``verbose > 0``, but it is never evaluated.
    ``model.eval()`` is never called here, so the model is left in training mode.

    ### Args:
        - `model` (`torch.nn.Module`): Model to train; moved in place to the selected device.
        - `dataset` (`torch.utils.data.Dataset` or 2-element list/tuple): Either a single dataset,
          randomly split into training/validation parts with a fixed seed (42), or a
          `(trainset, valset)` pair that is used as-is.
        - `batch_size` (int): Batch size.
        - `loss_str` (str): Name of a `torch.nn` loss class. Examples: "CrossEntropyLoss", "BCELoss", "MSELoss".
        - `optimizer_str` (str): Name of a `torch.optim` optimizer class. Examples: "Adam", "SGD", "RMSprop".
        - `optimizer_params` (dict, optional): Extra keyword arguments for the optimizer.
        - `loss_function_params` (dict, optional): Extra keyword arguments for the loss function.
        - `learnrate` (float, optional): Learning rate. Defaults to 0.001.
        - `learnrate_decay_gamma` (float, optional): Per-epoch exponential learning-rate decay. Defaults to None (no decay).
        - `epochs` (int, optional): Number of epochs. Defaults to 10.
        - `validation_patience` (int, optional): Currently ignored, because early stopping is disabled. Defaults to 10000.
        - `validation_data` (float, optional): Fraction held out for validation when `dataset` is a single
          dataset. Defaults to 0.1.
        - `verbose` (int, optional): 0 prints nothing. 1 prints setup info and a tqdm bar over epochs.
          2 prints setup info, a per-batch text progress bar (from the second epoch on) and the per-epoch loss.
          3 additionally prints first-epoch diagnostics. Defaults to 1.
        - `script_before_save` (bool, optional): Currently ignored, because saving is disabled. Defaults to True.
        - `saveto` (str, optional): Currently ignored, because saving is disabled. Defaults to None.
        - `num_workers` (int, optional): Number of workers for the dataloaders. Defaults to 0.

    ### Returns:
        - `history` (dict) with the keys:
            - `training_loss`: Mean training loss of each epoch.
            - `learning_rate`: Learning rate after each epoch's scheduler step (the constant `learnrate` without decay).
            - `training_metrics`: Training metric of each epoch (accuracy or R2); present only when a metric is tracked.
    """
    import torch.optim as optim
    from torch.utils.data import DataLoader, random_split
    from timeit import default_timer as timer
    from tqdm import tqdm
    SEED = 42

    # Per-epoch histories
    hist_training_loss = []
    hist_learning_rate = []
    hist_trn_metric = []

    # Empty CUDA cache
    if torch.cuda.is_available(): torch.cuda.empty_cache()

    # Use the given (trainset, valset) pair, or split a single dataset reproducibly.
    if isinstance(dataset, (list, tuple)):
        assert len(dataset)==2, "If dataset is a tuple, it must have only two elements, the training dataset and the validation dataset."
        trainset, valset = dataset
        num_val_data = int(len(valset))
        num_train_data = int(len(trainset))
        num_all_data = num_train_data + num_val_data
    else:
        num_all_data = len(dataset)
        num_val_data = int(validation_data*num_all_data)
        num_train_data = num_all_data - num_val_data
        (trainset, valset) = random_split(dataset, (num_train_data, num_val_data), generator=torch.Generator().manual_seed(SEED))

    if verbose > 0:
        print("Total number of data points:      %d"%num_all_data)
        print("Number of training data points:   %d"%num_train_data)
        print("Number of validation data points: %d"%num_val_data)

    # Generate training and validation dataloaders
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    validloader = DataLoader(valset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    if verbose > 0:
        print("Number of training batches:    %d"%len(trainloader))
        print("Number of validation batches:  %d"%len(validloader))
        print("Batch size:                    %d"%batch_size)
        for x,y in trainloader:
            print("Shape of training input from the dataloader:  ", x.shape)
            print("Shape of training output from the dataloader: ", y.shape)
            break
        for x,y in validloader:
            print("Shape of validation input from the dataloader:  ", x.shape)
            print("Shape of validation output from the dataloader: ", y.shape)
            break

    # Select the device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if verbose > 0: print("Selected device: ", device)
    model.to(device)

    # Instantiate the loss function
    loss_func = getattr(nn, loss_str)
    criterion = loss_func(**loss_function_params) if loss_function_params else loss_func()

    # Instantiate the optimizer
    optimizer_func = getattr(optim, optimizer_str)
    optimizer = optimizer_func(model.parameters(), lr=learnrate, **optimizer_params) if optimizer_params else optimizer_func(model.parameters(), lr=learnrate)

    # Defining learning rate scheduling
    if learnrate_decay_gamma:
        if verbose > 0: print("The learning rate has an exponential decay rate of %.5f."%learnrate_decay_gamma)
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=learnrate_decay_gamma)
        lr_sch = True
    else:
        lr_sch = False

    # Find out if we will display any metric along with the loss.
    if loss_str in ["BCELoss", "BCEWithLogitsLoss", "CrossEntropyLoss", "NLLLoss", "PoissonNLLLoss", "GaussianNLLLoss"]:
        classification, regression = True, False
        trn_metric_name = "Acc"
    elif loss_str in ["MSELoss", "L1Loss", "L2Loss", "HuberLoss", "SmoothL1Loss"]:
        classification, regression = False, True
        trn_metric_name = "R2"
    else:
        classification, regression = False, False
    display_metrics = classification or regression
    if verbose > 0:
        if classification: print("Classification problem detected. We will look at accuracies.")
        elif regression: print("Regression problem detected. We will look at R2 scores.")
        else: print("We have detected neither classification nor regression problem. No metric will be displayed other than loss.")

    num_training_batches = len(trainloader)

    # Text progress bar (verbose >= 2): one block character per `intvl` batches.
    progress_bar_size = 40
    ch = "█"
    intvl = num_training_batches/progress_bar_size

    # Commencing training loop
    tStart = timer()
    loop = tqdm(range(epochs), desc='Training Progress', ncols=100) if verbose==1 else range(epochs)
    for epoch in loop:

        # Initialize per-epoch variables
        epoch_loss_training = 0.0
        newnum = 0
        oldnum = 0
        trn_metric = 0.0
        num_train_logits = 0

        if verbose>=2 and epoch > 0: print("Epoch %3d/%3d ["%(epoch+1, epochs), end="")
        if verbose==3 and epoch ==0: print("First epoch ...")

        ##########################################################################
        # Training
        if verbose==3 and epoch==0: print("\nTraining phase ...")
        model.train()
        for i, data in enumerate(trainloader):
            # Fetch data
            seqs, targets = data[0].to(device), data[1].to(device)
            # Forward propagation
            predictions = model(seqs)
            # Test shapes
            _test_shapes(predictions, targets, classification)
            # Loss calculation and accumulation
            loss = criterion(predictions, targets)
            epoch_loss_training += loss.item()
            # Backpropagation and optimizer update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Metrics calculation
            if display_metrics:
                with torch.no_grad():
                    trn_metric, num_train_logits = _update_metrics_for_batch(
                        predictions, targets, loss_str, classification, regression, verbose, i, epoch, trn_metric, num_train_logits)

            # Visualization of progressbar within the batch
            if verbose>=2 and epoch > 0:
                newnum = int(i/intvl)
                if newnum > oldnum:
                    print((newnum-oldnum)*ch, end="")
                    oldnum = newnum

        # Update learning rate if necessary
        if lr_sch: scheduler.step()

        # Calculate epoch loss and metrics
        epoch_loss_training, trn_metric, hist_training_loss, hist_trn_metric = _calculate_epoch_loss_and_metrics(epoch_loss_training, num_training_batches, verbose, epoch, 
            hist_training_loss, hist_trn_metric, display_metrics, trn_metric, (num_train_logits if classification else num_training_batches))

        if verbose>=2 and epoch > 0: print("] ", end="")

        # NOTE: this TL variant intentionally performs no validation pass and no early
        # stopping; `validloader` is only used for the size/shape report above.

        # Log the learning rate used for the next epoch.
        hist_learning_rate.append(scheduler.get_last_lr()[0] if lr_sch else learnrate)

        if verbose>=2:
            if display_metrics:
                print("Loss: %5.4f |%s: %5.4f " % (
                    epoch_loss_training, trn_metric_name, trn_metric))
            else:
                print("Loss: %5.4f" % (epoch_loss_training))

    ##########################################################################
    # Epilogue
    tFinish = timer()
    if verbose > 0:
        print('Finished Training.')
        print("Training process took %.2f seconds."%(tFinish-tStart))
    # Clear CUDA cache
    if torch.cuda.is_available(): torch.cuda.empty_cache()
    # Generate output dictionary
    history = {
        'training_loss': hist_training_loss,
        'learning_rate': hist_learning_rate}
    if display_metrics:
        history["training_metrics"] = hist_trn_metric
    if verbose > 0: print("Done training.")
    return history

### TL Training 

In [None]:
# Fine-tune the model on the TL split for a single epoch.
train_pytorch_model_TL(
    model,
    [train_set_TL, val_set_TL],  # explicit (train, validation) pair, so no random split happens
    batch_size=hparams["batch_size"],
    loss_str=hparams["loss_function"],
    optimizer_str=hparams["optimizer"],
    optimizer_params=None,
    loss_function_params=None,
    learnrate=hparams["learning_rate"],
    learnrate_decay_gamma=hparams["learning_rate_decay_gamma"],
    epochs=1,
    validation_patience=10000,
    validation_data=0.1,  # unused when a (train, validation) pair is passed
    verbose=1,
    script_before_save=False,
    saveto=None,
    num_workers=0,
)

### TL Testing

In [None]:
# Score the transfer-learned model on the TL test set and show the resulting metrics.
print("Evaluating the model on the test set ...")
results = model.evaluate_model(test_set_TL, verbose=True)
print("test set accuracy after TL: ")
print(results["metrics"])

### Extracting the BN running statistics after TL

In [None]:
# Snapshot the running mean/variance of every BatchNorm3d/BatchNorm1d layer after TL.
# `.copy()` is essential. For a model on the CPU, `.cpu()` returns the buffer itself and
# `.numpy()` shares its memory, so without a copy these arrays would silently follow any later
# update of the running statistics.
# NOTE(review): if the earlier `bn_paras_mean` / `bn_paras_var` snapshot was built without a
# copy, it has the same aliasing problem. Confirm.
bn_paras_mean_TL = []
bn_paras_var_TL = []
for layer in model.modules():
    if isinstance(layer, (torch.nn.modules.batchnorm.BatchNorm3d, torch.nn.modules.batchnorm.BatchNorm1d)):
        bn_paras_mean_TL.append(layer.running_mean.cpu().numpy().copy())
        bn_paras_var_TL.append(layer.running_var.cpu().numpy().copy())

In [None]:
# Check whether TL changed the BN running statistics.
# Comparing two lists of numpy arrays with `==` raises "truth value of an array ... is
# ambiguous" for multi-element arrays, so compare layer by layer instead.
def _same_arrays(before, after):
    """Return True if both lists have equal length and every pair of arrays is exactly equal."""
    return len(before) == len(after) and all(np.array_equal(a, b) for a, b in zip(before, after))

print(_same_arrays(bn_paras_mean, bn_paras_mean_TL))
print(_same_arrays(bn_paras_var, bn_paras_var_TL))

In [None]:
idx = 0
# Running means of BN layer `idx`: the earlier snapshot first, then the one taken after TL.
for bn_snapshot in (bn_paras_mean, bn_paras_mean_TL):
    print(bn_snapshot[idx])

In [None]:
idx = 1
# Running means of BN layer `idx`: the earlier snapshot first, then the one taken after TL.
for bn_snapshot in (bn_paras_mean, bn_paras_mean_TL):
    print(bn_snapshot[idx])