# ORACLE PTN Template
This notebook serves as a template for ORACLE PTN (prototypical network) experiments.  
It can be run on its own by setting STANDALONE to True (search for "STANDALONE" to see where),  
but it is intended to be executed as part of a `*papermill.py` script. See any of the  
experiments that have a papermill script to get started with that workflow.  

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

    
import os, json, sys, time, random
import numpy as np
import torch
from torch.optim import Adam
from  easydict import EasyDict
import matplotlib.pyplot as plt

from steves_models.steves_ptn import Steves_Prototypical_Network

from steves_utils.lazy_iterable_wrapper import Lazy_Iterable_Wrapper
from steves_utils.iterable_aggregator import Iterable_Aggregator
from steves_utils.ptn_train_eval_test_jig import  PTN_Train_Eval_Test_Jig
from steves_utils.torch_sequential_builder import build_sequential
from steves_utils.torch_utils import get_dataset_metrics, ptn_confusion_by_domain_over_dataloader
from steves_utils.utils_v2 import (per_domain_accuracy_from_confusion, get_datasets_base_path)
from steves_utils.PTN.utils import independent_accuracy_assesment

from steves_utils.simple_datasets.ORACLE.episodic_dataset_accessor import get_episodic_dataloaders
from steves_utils.ORACLE.utils_v2 import (
    ALL_DISTANCES_FEET,
    ALL_SERIAL_NUMBERS,
    ALL_RUNS,
    serial_number_to_id
)

from steves_utils.ptn_do_report import (
    get_jig_diagram,
    get_results_table,
    get_parameters_table,
    get_domain_accuracies,
)

# Allowed Parameters
These are the allowed parameters, not defaults.
The injected parameters must contain exactly these keys, no more and no fewer (the notebook raises an exception otherwise).

Papermill uses the cell tag "parameters" to inject the real parameters below this cell.
Enable the cell-tag view in your notebook editor to see which cell carries that tag.

In [None]:
# Reference parameter set. Its KEYS define exactly what must be injected (checked in the
# next cell); its VALUES are only used when the notebook is run STANDALONE.
allowed_parameters = {
    "experiment_name": "MANUAL ORACLE PTN",
    "lr": 0.001,
    "device": "cuda",

    "seed": 1337,
    "desired_classes_source": ALL_SERIAL_NUMBERS,
    "desired_classes_target": ALL_SERIAL_NUMBERS,

    # Domains are ORACLE distances
    "source_domains": [38],
    "target_domains": [20, 44, 2, 8, 14, 26, 32, 50, 56, 62],

    "num_examples_per_class_per_domain_source": 100,
    "num_examples_per_class_per_domain_target": 100,

    # Episode shape: one "way" per source class
    "n_shot": 3,
    "n_way": len(ALL_SERIAL_NUMBERS),
    "n_query": 2,
    "train_k_factor": 1,
    "val_k_factor": 2,
    "test_k_factor": 2,

    "n_epoch": 3,

    "patience": 10,
    "criteria_for_best": "target",
    "normalize_source": False,
    "normalize_target": False,

    # Embedding network spec consumed by build_sequential
    "x_net": [
        {"class": "nnReshape", "kargs": {"shape": [-1, 1, 2, 256]}},
        {"class": "Conv2d", "kargs": {"in_channels": 1, "out_channels": 256, "kernel_size": (1, 7), "bias": False, "padding": (0, 3)}},
        {"class": "ReLU", "kargs": {"inplace": True}},
        {"class": "BatchNorm2d", "kargs": {"num_features": 256}},

        {"class": "Conv2d", "kargs": {"in_channels": 256, "out_channels": 80, "kernel_size": (2, 7), "bias": True, "padding": (0, 3)}},
        {"class": "ReLU", "kargs": {"inplace": True}},
        {"class": "BatchNorm2d", "kargs": {"num_features": 80}},
        {"class": "Flatten", "kargs": {}},

        # The (2,7) conv with (0,3) padding collapses the 2 rows and keeps width 256:
        # 80 channels x 256 positions once flattened.
        {"class": "Linear", "kargs": {"in_features": 80 * 256, "out_features": 256}},
        {"class": "ReLU", "kargs": {"inplace": True}},
        {"class": "BatchNorm1d", "kargs": {"num_features": 256}},

        {"class": "Linear", "kargs": {"in_features": 256, "out_features": 256}},
    ],

    # Reporting / bookkeeping; these basically never need to change
    "NUM_LOGS_PER_EPOCH": 10,
    "BEST_MODEL_PATH": "./best_model.pth",
}

In [None]:
# Set this to True if you want to run this template directly
STANDALONE = False

# A standalone run falls back to allowed_parameters when papermill injected nothing.
if STANDALONE and "parameters" not in locals() and "parameters" not in globals():
    print("parameters not injected, running with allowed_parameters!")
    parameters = allowed_parameters

if "parameters" not in locals() and "parameters" not in globals():
    raise Exception("Parameter injection failed")

# Attribute-style access for everything downstream (p.lr, p.seed, ...)
p = EasyDict(parameters)

# The injected key set must equal the allowed key set: nothing missing, nothing extra.
allowed_keys = set(allowed_parameters)
supplied_keys = set(p)

if supplied_keys != allowed_keys:
    print("Parameters are incorrect")
    if supplied_keys - allowed_keys:
        print("Shouldn't have:", str(supplied_keys - allowed_keys))
    if allowed_keys - supplied_keys:
        print("Need to have:", str(allowed_keys - supplied_keys))
    raise RuntimeError("Parameters are incorrect")



In [None]:
###################################
# Set the RNGs and make it all deterministic
###################################
# Seed Python's, NumPy's and torch's RNGs from the single experiment seed.
np.random.seed(p.seed)
random.seed(p.seed)
torch.manual_seed(p.seed)

# With deterministic algorithms enabled, PyTorch raises a RuntimeError on cuBLAS ops
# (CUDA >= 10.2) unless CUBLAS_WORKSPACE_CONFIG is set before cuBLAS is first used.
# This cell runs before the model is built or moved to p.device, so set it here.
# setdefault keeps any value already exported by the caller (e.g. the papermill driver).
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
torch.use_deterministic_algorithms(True)

In [None]:
# Tensors and module parameters created from here on (including x_net, built in a later
# cell) default to float64. The original note says third-party code pulled into the
# project requires this; which code is not recorded. TODO: confirm and document.
torch.set_default_dtype(torch.double)

In [None]:
###################################
# Build the network(s)
# Note: It's critical to do this AFTER setting the RNG
###################################
# Presumably build_sequential constructs torch layers, whose weights are randomly
# initialized on construction. That is why this must follow the seeding cell (and the
# default-dtype cell).
x_net = build_sequential(p["x_net"])

In [None]:
# Wall-clock start of the timed section (dataset construction, model construction and
# training); total_experiment_time_secs is measured relative to this.
start_time_secs = time.time()

In [None]:
###################################
# Build the dataset
###################################
def _episodic_splits(serial_numbers, distances, num_examples_per_distance_per_serial, normalize_type):
    """Build the (train, val, test) episodic ORACLE dataloaders for one group of domains.

    Only the class list, the domains (distances), the example budget and the normalization
    differ between source and target. The seed, episode shape (n_shot / n_way / n_query)
    and split k-factors are shared and read from ``p``.
    """
    return get_episodic_dataloaders(
        serial_numbers=serial_numbers,
        distances=distances,
        num_examples_per_distance_per_serial=num_examples_per_distance_per_serial,
        iterator_seed=p.seed,
        n_shot=p.n_shot,
        n_way=p.n_way,
        n_query=p.n_query,
        train_val_test_k_factors=(p.train_k_factor, p.val_k_factor, p.test_k_factor),
        normalize_type=normalize_type,
    )


source_original_train, source_original_val, source_original_test = _episodic_splits(
    serial_numbers=p.desired_classes_source,
    distances=p.source_domains,
    num_examples_per_distance_per_serial=p.num_examples_per_class_per_domain_source,
    normalize_type=p.normalize_source,
)
target_original_train, target_original_val, target_original_test = _episodic_splits(
    serial_numbers=p.desired_classes_target,
    distances=p.target_domains,
    num_examples_per_distance_per_serial=p.num_examples_per_class_per_domain_target,
    normalize_type=p.normalize_target,
)

# The original loaders yield (<domain>, <episode>) pairs. The PTN jig only needs the
# episode, so every split is lazily wrapped to strip the domain off.
transform_lambda = lambda ex: ex[1]

source_processed_train, source_processed_val, source_processed_test = [
    Lazy_Iterable_Wrapper(loader, transform_lambda)
    for loader in (source_original_train, source_original_val, source_original_test)
]
target_processed_train, target_processed_val, target_processed_test = [
    Lazy_Iterable_Wrapper(loader, transform_lambda)
    for loader in (target_original_train, target_original_val, target_original_test)
]

# Single access point for all data: datasets.<source|target>.<original|processed>.<train|val|test>
datasets = EasyDict({
    "source": {
        "original": {
            "train": source_original_train,
            "val": source_original_val,
            "test": source_original_test,
        },
        "processed": {
            "train": source_processed_train,
            "val": source_processed_val,
            "test": source_processed_test,
        },
    },
    "target": {
        "original": {
            "train": target_original_train,
            "val": target_original_val,
            "test": target_original_test,
        },
        "processed": {
            "train": target_processed_train,
            "val": target_processed_val,
            "test": target_processed_test,
        },
    },
})

In [None]:
###################################
# Build the model
###################################
# Prototypical network around the x_net embedding. x_shape=(2, 256) matches the
# [-1, 1, 2, 256] reshape at the head of x_net.
model = Steves_Prototypical_Network(x_net, x_shape=(2, 256))

# Adam over every model parameter at the configured learning rate.
optimizer = Adam(model.parameters(), lr=p.lr)

In [None]:
###################################
# Train
###################################
# The jig checkpoints the best model to p.BEST_MODEL_PATH and runs on p.device.
jig = PTN_Train_Eval_Test_Jig(model, p.BEST_MODEL_PATH, p.device)

# Train on source episodes only; both source and target validation streams are supplied.
# p.criteria_for_best presumably chooses which one drives best-model selection / patience.
jig.train(
    optimizer=optimizer,
    train_iterable=datasets.source.processed.train,
    source_val_iterable=datasets.source.processed.val,
    target_val_iterable=datasets.target.processed.val,
    num_epochs=p.n_epoch,
    patience=p.patience,
    criteria_for_best=p.criteria_for_best,
    num_logs_per_epoch=p.NUM_LOGS_PER_EPOCH,
)

In [None]:
# Elapsed wall-clock seconds since start_time_secs: covers dataset and model construction
# plus training. The evaluation in the next cell is NOT included.
total_experiment_time_secs = time.time() - start_time_secs

In [None]:
###################################
# Evaluate the model
###################################
# (accuracy, loss) as reported by the jig for each held-out split.
source_test_label_accuracy, source_test_label_loss = jig.test(datasets.source.processed.test)
target_test_label_accuracy, target_test_label_loss = jig.test(datasets.target.processed.test)

source_val_label_accuracy, source_val_label_loss = jig.test(datasets.source.processed.val)
target_val_label_accuracy, target_val_label_loss = jig.test(datasets.target.processed.val)

history = jig.get_history()

# Epochs actually run, taken from the length of the jig's epoch-index history
# (presumably fewer than p.n_epoch when patience stops training early).
total_epochs_trained = len(history["epoch_indices"])

# Confusion over source-val and target-val combined. The *original* loaders are used because,
# unlike the processed ones, they still carry each episode's domain.
val_dl = Iterable_Aggregator((datasets.source.original.val, datasets.target.original.val))
confusion = ptn_confusion_by_domain_over_dataloader(model, p.device, val_dl)
per_domain_accuracy = per_domain_accuracy_from_confusion(confusion)

# Turn each plain accuracy into {"accuracy", "source?"}, flagging the source domains.
# Only existing keys are reassigned (for a dict this is allowed during iteration).
for domain, accuracy in per_domain_accuracy.items():
    per_domain_accuracy[domain] = {"accuracy": accuracy, "source?": domain in p.source_domains}

# Sanity check: recompute every accuracy independently and demand an exact match with the jig.
_source_test_label_accuracy = independent_accuracy_assesment(model, datasets.source.processed.test)
_target_test_label_accuracy = independent_accuracy_assesment(model, datasets.target.processed.test)
_source_val_label_accuracy = independent_accuracy_assesment(model, datasets.source.processed.val)
_target_val_label_accuracy = independent_accuracy_assesment(model, datasets.target.processed.val)

assert _source_test_label_accuracy == source_test_label_accuracy
assert _target_test_label_accuracy == target_test_label_accuracy
assert _source_val_label_accuracy == source_val_label_accuracy
assert _target_val_label_accuracy == target_val_label_accuracy

# Full record of the run: parameters, final metrics, training history and dataset metrics.
experiment = {
    "experiment_name": p.experiment_name,
    "parameters": dict(p),
    "results": {
        "source_test_label_accuracy": source_test_label_accuracy,
        "source_test_label_loss": source_test_label_loss,
        "target_test_label_accuracy": target_test_label_accuracy,
        "target_test_label_loss": target_test_label_loss,
        "source_val_label_accuracy": source_val_label_accuracy,
        "source_val_label_loss": source_val_label_loss,
        "target_val_label_accuracy": target_val_label_accuracy,
        "target_val_label_loss": target_val_label_loss,
        "total_epochs_trained": total_epochs_trained,
        "total_experiment_time_secs": total_experiment_time_secs,
        "confusion": confusion,
        "per_domain_accuracy": per_domain_accuracy,
    },
    "history": history,
    "dataset_metrics": get_dataset_metrics(datasets, "ptn"),
}

In [None]:
###################################
# Write out the results
###################################
def write_results(p: "EasyDict", experiment, path=None) -> None:
    """Serialize the experiment record to a JSON file (2-space indent).

    Args:
        p: Experiment parameters. Only consulted for ``EXPERIMENT_JSON_PATH``, and only
            when ``path`` is not given.
        experiment: JSON-serializable experiment record.
        path: Destination file. Defaults to ``p.EXPERIMENT_JSON_PATH``. That key is not in
            ``allowed_parameters``, and the validation cell rejects any other keys, so
            within this template callers must pass ``path`` explicitly.

    Raises:
        TypeError: if ``experiment`` contains values ``json`` cannot serialize.
    """
    if path is None:
        path = p.EXPERIMENT_JSON_PATH
    with open(path, "w", encoding="utf-8") as f:
        json.dump(experiment, f, indent=2)

In [None]:
# Render the jig diagram for this run (presumably the training/validation curves from
# experiment["history"]); the returned axes are kept in `ax`.
ax = get_jig_diagram(experiment)
plt.show()

In [None]:
# Results table for this run; as the cell's final expression, Jupyter displays its value.
get_results_table(experiment)

In [None]:
# Parameters table for this run; as the cell's final expression, Jupyter displays its value.
get_parameters_table(experiment)


In [None]:
# Per-domain accuracy view (built from experiment["results"]["per_domain_accuracy"], presumably);
# as the cell's final expression, Jupyter displays its value.
get_domain_accuracies(experiment)

In [None]:
# Headline accuracies for quick inspection in the notebook output.
print(
    "Source Test Label Accuracy:", experiment["results"]["source_test_label_accuracy"],
    "Target Test Label Accuracy:", experiment["results"]["target_test_label_accuracy"],
)
print(
    "Source Val Label Accuracy:", experiment["results"]["source_val_label_accuracy"],
    "Target Val Label Accuracy:", experiment["results"]["target_val_label_accuracy"],
)

In [None]:
# The whole experiment record as a JSON string, emitted as this cell's output (raises
# TypeError if anything in it is not JSON-serializable). Presumably the driving
# *papermill.py script reads this output; confirm before changing this cell.
json.dumps(experiment)