In [None]:

import os
import sys
import math
import json
import shutil
import random
import numpy as np
from copy import copy
from collections import defaultdict
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt

import torch
# Share tensors between processes using the "file_system" strategy.
torch.multiprocessing.set_sharing_strategy("file_system")
# Restrict cuDNN to deterministic algorithms.
torch.backends.cudnn.deterministic = True

%load_ext autoreload
%autoreload 2

## 1. Training

We will use a `Trainer` class to perform the training and validation experiments. It requires a configuration file where all training and experiment parameters are defined in a hierarchical structure. The configuration is usually a `.yaml` file, which is converted to a hierarchical configuration using [Hydra](https://hydra.cc/). 

#### Training configuration
The default configuration is specified in `configs/base.yaml`. Default values can be overridden using dotted (attribute-style) keys; for instance, to change the `batch_size`, use `train.batch_size=32`.

In [None]:
from omegaconf import OmegaConf

# Load the base configuration from disk and print it back as YAML for inspection.
cfg = OmegaConf.create(OmegaConf.load("./configs/base.yaml"))
print(OmegaConf.to_yaml(cfg))

#### Initialization

The initialization of the `Trainer` class will create a model for training based on the network specified in the `.yaml` file, define the loss function, and prepare an output directory `<cfg.exp.log_dir>/<exp_name>` in which to save model checkpoints, logs, and training and prediction outputs. If you don't want to log anything, you can set the parameter `train.debug` of the hydra config to `False`. 


In [None]:
from trainer import Trainer

# Build the trainer for the baseline experiment, then show the network
# it created and the value of its `num_params` attribute.
trainer = Trainer(cfg, exp_name='exp1_training')
print(trainer.network)
print("Training parameters:", trainer.num_params)

#### Prepare data for training and evaluation

Next, we obtain the list of shot IDs with available labels from `data.label_dir`.

In [None]:
from dataset.datasets import split_data
from utils.misc import get_files_in_dir

# Each '.csv' entry found in the label directory is cut at its first '.'
# to obtain the shot ID.
data = [f.split('.')[0] for f in get_files_in_dir(cfg.data.label_dir, file_end='.csv')]

# NOTE(review): no seed is passed to split_data here (a later cell passes
# seed=cfg.rng.seed) — confirm this split is reproducible across runs.
train_shots, test_shots = split_data(data, train_split=cfg.data.train_split)
print(f"{data=}\n{train_shots=}\n{test_shots=}")

#### Train function
Next, we call the `train` function of the `Trainer` with the training and test shot IDs. The `train` function initializes the dataloaders and creates a model checkpointer for saving the model with the best metric, as well as schedulers for adjusting the learning rate. Note that the training monitors the metric specified in `train.monitor` to save model checkpoints according to the monitor mode, e.g. 'min' for 'loss' and 'max' for 'accuracy'. The trainer uses early stopping to halt the training if the metric specified in `train.early_stopping_metric` doesn't improve for `train.early_stop_patience` steps.  

The trainer will create a model for training based on the network specified in the config file, train the network, and save the network states in an output directory `<cfg.exp.log_dir>/<exp_name>`, where model checkpoints, logs, and any prediction outputs will also be saved. If you don't want to log anything, you can set the parameter `train.debug` of the hydra config to `False`.

By default, the model is trained to predict the ELM type classification only. If ELM detection is required (identifying where the ELM peaks are), you need to set `net.detection=True`, which will train a separate detection head. 

In [None]:
# Train on the training shots; the held-out shots are passed as the test sets.
trainer.train(train_shots, test_sets=test_shots)

#### Print metrics

In [None]:
# Print the tabulated test metrics (detection metrics only when cfg.net.detection is set)
print("Classification results:\n", trainer.cls_metric_logger.tabulate_metrics())
if cfg.net.detection:
    print("Detection results:\n", trainer.det_metric_logger.tabulate_metrics())

#### Plot Confusion Matrix

In [None]:
import matplotlib
%matplotlib inline

# Plot the classification confusion matrix, and the detection one as well
# when cfg.net.detection is set.
trainer.cls_metric_logger.plot_confusion_matrix()
if cfg.net.detection:
    trainer.det_metric_logger.plot_confusion_matrix()

#### Plot training history

In [None]:
# Plot the history recorded by the trainer during `train`.
trainer.plot_history()

#### Plot ELM Detections

In [None]:
from dataset.datasets import ELMDataset
import matplotlib
%matplotlib inline
shot_id = []

h = trainer.evaluate(['30462'], phase='eval')
trainer.plot_preds(phase='eval')

#### Visualize the progress in TensorBoard

In a new terminal, launch TensorBoard with the log dir set to `./logs`.

```
$ tensorboard --logdir ./logs/ --bind_all
```

## 2. K-Fold Validation

We will use the same steps as in the previous section, but this time we will rotate the validation set K times and run the training on each split. 

In [None]:
from omegaconf import OmegaConf
from utils.misc import get_files_in_dir
from dataset.datasets import split_data
from trainer import Trainer
%load_ext autoreload
%autoreload 2

cfg = OmegaConf.create(OmegaConf.load("./configs/base.yaml"))

n_folds = 5

data = [f.split('.')[0] for f in get_files_in_dir(cfg.data.label_dir, file_end='.csv')]
kfolds_data = split_data(data, n_folds=n_folds)

kfold_cls_results = []
kfold_det_results = []

for i in range(n_folds):
    
    # Create test samples for this fold
    train_shots, test_shots = kfolds_data[i]
    print(f"Fold={i}/{n_folds} \n{train_shots=}\n{test_shots=}")

    # create a trainer and train on each fold 
    kfold_trainer = Trainer(cfg, exp_name=f"exp1_{n_folds}folds_training/fold{i+1}")
    kfold_trainer.train(train_shots, test_sets=test_shots)

    kfold_cls_results.append(kfold_trainer.cls_metric_logger.results)
    if cfg.net.detection:
        kfold_det_results.append(kfold_trainer.det_metric_logger.results)

### Visualize K-Fold Results

In [None]:
import pandas as pd

# Use dedicated loop names: the previous loop variable `data` overwrote the
# notebook-level `data` list of shot IDs.
for k, (fold_train_shots, fold_test_shots) in enumerate(kfolds_data):
    print(f"Fold {k+1}: train shots={fold_train_shots} test shots={fold_test_shots}")

kfold_cls_results_df = pd.DataFrame(kfold_cls_results)
print("K-Fold Results: \n", kfold_cls_results_df)

# Count columns are summed over folds; ratio metrics are averaged over folds.
sum_metrics = kfold_cls_results_df[["tp", "fp", "tn", "fn"]].sum()
mean_metrics = kfold_cls_results_df[["accuracy", "precision", "recall", "f1"]].mean()
avg_results = pd.DataFrame([sum_metrics.tolist() + mean_metrics.tolist()], 
                           columns=sum_metrics.index.tolist() + mean_metrics.index.tolist()
                          )
print("Average K-Fold Results: \n", avg_results)

In [None]:
from utils.vis_utils import plot_bar_metrics
import matplotlib
%matplotlib inline

# kfold_cls_results holds one results entry per fold.
plot_bar_metrics(kfold_cls_results, xticks_label='Fold', figsize=(8, 6))

In [None]:
# Detection results are only collected when cfg.net.detection is set, so the
# list may be empty; skip the plot in that case (as the other detection plots do).
if len(kfold_det_results) > 0:
    plot_bar_metrics(kfold_det_results, metric_type='Detection', xticks_label='Fold', figsize=(8, 6))

## 3. Active Learning with Random Sampling
For active learning, we will use the same config `base.yaml`. We will incrementally add training shots as we go through the AL iterations. Due to the limited number of available labels, we will keep the same test shots for evaluation at each iteration. We can expect accuracy to increase over the AL iterations and to exceed that of the normal training and of fold 5 of the K-fold training.  

In [None]:
import random
import numpy as np
from omegaconf import OmegaConf

from utils.misc import get_files_in_dir
from dataset.datasets import split_data
from trainer import Trainer

%load_ext autoreload
%autoreload 2
    
cfg = OmegaConf.create(OmegaConf.load("./configs/base.yaml"))

data = [f.split('.')[0] for f in get_files_in_dir(cfg.data.label_dir, file_end='.csv')]
train_shots, test_shots = split_data(data, train_split=cfg.data.train_split)
print(f"{data=}\n{train_shots=}\n{test_shots=}")

al_trainer = Trainer(cfg, exp_name='active_learning_with_random_sampling')

rng = np.random.default_rng(cfg.rng.seed)

# Active learning Options
INITIAL_LABELS = 5
N_CYCLES = 5
QUERY_BATCH_SIZE = 2

# Split labeled and unlabeled data
labeled_shots = []
unlabeled_shots = train_shots

al_cls_results = []
al_det_results = []
for i in range(N_CYCLES):
    
    print(f"\nActive Learning Cycle {i + 1}")
    
    # Random sample from unlabelled indices
    selected_shots = rng.choice(unlabeled_shots, 
                                size=QUERY_BATCH_SIZE, 
                                replace=False,
                                ).tolist()

    labeled_shots.extend(selected_shots)
    
    unlabeled_shots = [v for v in unlabeled_shots if v not in selected_shots]
    
    print(f"{labeled_shots=}\n{unlabeled_shots=}")

    al_trainer.train(train_sets=labeled_shots,
                     test_sets=test_shots,
                    )
    al_trainer.save_states(ckpt_name=f"model_states_cycle{i}")
    
    al_cls_results.append(al_trainer.cls_metric_logger.results)
    if cfg.net.detection:
        al_det_results.append(al_trainer.det_metric_logger.results)

    if not len(unlabeled_shots)>0:
        break
        

#### Visualize results of Active Learning with Random Sampling

In [None]:
import matplotlib
%matplotlib inline
from utils.vis_utils import plot_bar_metrics
# Visualize classification results (al_cls_results holds one entry per completed AL cycle)
plot_bar_metrics(al_cls_results, metric_type='classification', xticks_label='Iteration')


In [None]:
# Detection results exist only when the detection branch was enabled;
# an empty list means there is nothing to plot.
if al_det_results:
    plot_bar_metrics(
        al_det_results,
        metric_type='Detection',
        xticks_label='Iteration',
    )


## 4. Active Learning with Uncertainty Sampling

In [None]:
import random
import numpy as np
from copy import copy
from omegaconf import OmegaConf
from scipy.stats import entropy

import torch
from torch.utils.data import DataLoader

from utils.misc import get_files_in_dir
from dataset.datasets import split_data, collate_fn
from dataset.datasets import ELMDataset
from trainer import Trainer

%load_ext autoreload
%autoreload 2

# Active learning options (read as module-level globals by the functions below).
# NOTE(review): INITIAL_LABELS is only referenced by the commented-out sampling
# code in this cell — presumably the size of the initial labelled set.
INITIAL_LABELS = 5
# Maximum number of most-uncertain shots selected after each cycle.
QUERY_BATCH_SIZE = 5
# Maximum number of active learning cycles.
N_CYCLES = 5

def compute_entropy(model, dataset, device):
    """Score every item of `dataset` by the entropy of the model's predictions.

    The model is switched to eval mode and run without gradients, one item per
    batch. For each batch the class logits (and the detection logits, when the
    model returns them) are turned into probabilities with a softmax over
    dim 1, the entropy is taken along axis 1, and the mean of those entropies
    is recorded.

    Args:
        model: network called as `model(batch.dalpha)` and returning a
            `(cls_preds, elm_preds)` pair; `elm_preds` may be None.
        dataset: dataset compatible with the project's `collate_fn`.
        device: device the batches are moved to.

    Returns:
        A `(cls_uncertainties, det_uncertainties)` pair of lists with one
        score per batch. The second list stays empty when the model returns
        no detection output.
    """
    model.eval()

    loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)

    cls_scores, det_scores = [], []

    with torch.no_grad():
        for sample in loader:
            sample = sample.to(device)
            cls_logits, det_logits = model(sample.dalpha)

            cls_probs = torch.softmax(cls_logits, dim=1).cpu().numpy()
            cls_scores.append(entropy(cls_probs, axis=1).mean())

            # No detection output: only the classification score is recorded.
            if det_logits is None:
                continue

            det_probs = torch.softmax(det_logits, dim=1).cpu().numpy()
            det_scores.append(entropy(det_probs, axis=1).mean())

    return cls_scores, det_scores

def al_uncertainity_iterations(cfg, train_shots, test_shots, initial_shots, exp_name='al_uncertainity_sampling'):
    """Run active learning cycles with entropy-based uncertainty sampling.

    Each cycle moves the currently selected shots from the unlabelled pool to
    the labelled pool, trains on the labelled pool, records the test metrics,
    and then selects the most uncertain remaining shots (highest mean
    classification entropy) for the next cycle.

    Reads the module-level constants `N_CYCLES` (maximum number of cycles) and
    `QUERY_BATCH_SIZE` (maximum number of shots selected per cycle).

    Args:
        cfg: experiment configuration passed to `Trainer` and `ELMDataset`.
        train_shots: shot IDs forming the full pool available for labelling.
        test_shots: shot IDs used for evaluation in every cycle.
        initial_shots: shot IDs labelled in the first cycle.
        exp_name: experiment name passed to `Trainer`.

    Returns:
        `(trainer, cls_results, det_results)`: the trainer after the last
        cycle and one results entry per completed cycle. `det_results` stays
        empty unless `cfg.net.detection` is set.
    """
    # Create the trainer
    trainer = Trainer(cfg, exp_name=exp_name)

    labeled_shots = []
    # Copy the inputs so the caller's lists are never aliased.
    unlabeled_shots = list(train_shots)
    selected_shots = list(initial_shots)

    cls_results = []
    det_results = []
    for i in range(N_CYCLES):

        print(f"\nActive Learning Cycle {i + 1}")

        labeled_shots.extend(selected_shots)

        unlabeled_shots = [s for s in unlabeled_shots if s not in selected_shots]

        print(f"{labeled_shots=}\n{unlabeled_shots=}")

        # Train only on the shots labelled so far. Training on the full
        # `train_shots` list would make every cycle identical and defeat the
        # purpose of active learning.
        trainer.train(labeled_shots, test_sets=test_shots)

        cls_results.append(trainer.cls_metric_logger.results)
        if cfg.net.detection:
            det_results.append(trainer.det_metric_logger.results)

        # Nothing left to query: stop before building an empty dataset.
        if not unlabeled_shots:
            break

        # Create a dataset over the remaining unlabelled shots
        unlabeled_dataset = ELMDataset(cfg.data,
                                       label_files=unlabeled_shots,
                                       mode='test',
                                      )

        cls_uncertainty_scores, det_uncertainity_scores = compute_entropy(
            model=trainer.network,
            dataset=unlabeled_dataset,
            device=trainer.device,
            )

        # Select the most uncertain samples (largest entropies sort last).
        query_indices = np.argsort(cls_uncertainty_scores)[-QUERY_BATCH_SIZE:]

        # Queue the uncertain shots for labelling in the next cycle.
        # NOTE(review): assumes the dataset yields one item per shot, in the
        # same order as `unlabeled_shots` — confirm against ELMDataset.
        selected_shots = [unlabeled_shots[idx] for idx in query_indices]

    return trainer, cls_results, det_results

cfg = OmegaConf.create(OmegaConf.load("./configs/base.yaml"))

# Load train/test data (the split is seeded from the config)
shots = [f.split('.')[0] for f in get_files_in_dir(cfg.data.label_dir, file_end='.csv')]
train_shots, test_shots = split_data(shots, train_split=cfg.data.train_split, seed=cfg.rng.seed)
print(f"{shots=}\n{train_shots=}\n{test_shots=}")

# Initialize labeled and unlabeled samples
# (random selection of the initial shots is disabled in favour of the fixed list below)
# rng = np.random.default_rng(cfg.rng.seed)

# selected_shots = rng.choice(train_shots, 
#                             size=INITIAL_LABELS, 
#                             replace=False,
#                             ).tolist()

# NOTE(review): hard-coded shot IDs — confirm they are all part of `train_shots`.
initial_shots = ['30418', '30424', '30441', '30449', '30457']

# Run the uncertainty-sampling loop with the default experiment name.
ent_trainer, al_ent_cls_results, al_ent_det_results = al_uncertainity_iterations(
    cfg, 
    train_shots, 
    test_shots, 
    initial_shots,
)

In [None]:
# Visualize classification resultsimport matplotlib 
%matplotlib inline

from utils.vis_utils import plot_bar_metrics

plot_bar_metrics(al_ent_cls_results, metric_type='classification', xticks_label='Iteration')

In [None]:
# Detection results exist only when the detection branch was enabled;
# an empty list means there is nothing to plot.
if al_ent_det_results:
    plot_bar_metrics(
        al_ent_det_results,
        metric_type='Detection',
        xticks_label='Iteration',
    )

## 5. Active Learning with K-Fold Validation

In [None]:
from omegaconf import OmegaConf
from utils.misc import get_files_in_dir
from dataset.datasets import split_data
from trainer import Trainer
%load_ext autoreload
%autoreload 2

cfg = OmegaConf.create(OmegaConf.load("./configs/base.yaml"))

INITIAL_LABELS = 5
QUERY_BATCH_SIZE = 5
N_CYCLES = 5

n_folds = 5

data = [f.split('.')[0] for f in get_files_in_dir(cfg.data.label_dir, file_end='.csv')]
kfolds_data = split_data(data, n_folds=n_folds)

kfold_al_cls_results = []
kfold_al_det_results = []

rng = np.random.default_rng(cfg.rng.seed)

for i in range(n_folds):
    
    # Create test samples for this fold
    train_shots, 
    = kfolds_data[i]
    print(f"Fold={i}/{n_folds} \n{train_shots=}\n{test_shots=}")

    initial_shots = rng.choice(train_shots, 
                                size=INITIAL_LABELS, 
                                replace=False,
                                ).tolist()

    _, _kfold_al_cls_results, _kfold_al_det_results = al_uncertainity_iterations(
        cfg, 
        train_shots, 
        test_shots, 
        initial_shots,
    )

    kfold_al_cls_results.append(_kfold_al_cls_results)
    kfold_al_det_results.append(_kfold_al_det_results)

In [None]:
# One bar chart per fold, titled with the 1-based fold number.
for fold_no, per_cycle_results in enumerate(kfold_al_cls_results, start=1):
    plot_bar_metrics(
        per_cycle_results,
        metric_type=f'Classification (Fold-{fold_no})',
        xticks_label='Iteration',
    )