# Example Usage

In [3]:
from run import run_cross_val
from models import efficientnet
import pandas as pd
from utils import gen_config
import pprint

ROOT = "./data"
RANDOM_SEED = 0

The `config` dictionary is passed to the `run` or `run_cross_val` functions to customize the training routine. Its entries set hyperparameters such as the number of epochs, learning rate, and batch size; changing a value in `config` changes the corresponding training setting.

In [2]:
# gen_config returns a config whose settings were found by
# Bayesian hyperparameter optimization for the efficientnet_b0 model
config = gen_config("efficient_net")
config['epochs'] = 10  # override the number of training epochs
pprint.pprint(config)

{'batch_size': 32,
 'brightness': 0.055,
 'contrast': 0.003,
 'crop_size': (224, 224),
 'epochs': 10,
 'h_flip': 0.5,
 'hue': 0.1,
 'image_size': (256, 224),
 'label_smoothing': 0.04,
 'lr': 0.00085,
 'rotation': 10,
 'saturation': 0.075,
 'scale_lower': 0.8,
 'weight_decay': 0.008}


The `run_cross_val` function performs cross-validation for a given model. It takes the following parameters:

- `model_constructor`: A function that returns a model to be trained and evaluated.
- `root`: The root directory of the dataset.
- `config`: The configuration dictionary that contains various training settings.
- `checkpoint_path`: The directory to save the checkpoints.
- `random_seed`: The random seed for reproducibility.

The function performs stratified k-fold cross-validation: the dataset is split into k subsets (folds), each of which approximately preserves the overall class proportions. The model is trained on k-1 folds and evaluated on the remaining fold. This process is repeated k times, so each fold serves as the validation set exactly once.
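The splitting scheme described above can be sketched in a few lines of plain Python. This is only an illustration of stratified k-fold splitting, not the code used inside `run_cross_val`; the helper name `stratified_k_fold` and the toy labels are assumptions made for this example.

```python
import random
from collections import defaultdict


def stratified_k_fold(labels, k, seed=0):
    """Yield (train_indices, val_indices) for each of k stratified folds.

    Hypothetical helper for illustration; not part of this repository.
    """
    rng = random.Random(seed)

    # Group sample indices by class label
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Deal each class's shuffled indices round-robin into the k folds,
    # so every fold receives a near-equal share of every class
    folds = [[] for _ in range(k)]
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        for pos, idx in enumerate(indices):
            folds[pos % k].append(idx)

    # Each fold serves as the validation set exactly once
    for i in range(k):
        val = sorted(folds[i])
        train = sorted(
            idx for j, fold in enumerate(folds) if j != i for idx in fold
        )
        yield train, val


# Made-up, imbalanced toy labels: ten samples of class 0, five of class 1
labels = [0] * 10 + [1] * 5
splits = list(stratified_k_fold(labels, k=5, seed=0))
for fold, (train, val) in enumerate(splits):
    print(fold, len(train), len(val), [labels[i] for i in val])
```

With these toy labels every validation fold holds two samples of class 0 and one of class 1, which mirrors the 2:1 class ratio of the full set.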

The function returns two dictionaries: `best_val_results` and `test_results`. 
- `best_val_results` contains the evaluation results on the validation set for each fold using the best checkpoints.
- `test_results` contains the evaluation results on the test set for each fold using the best checkpoints.
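To make "best checkpoint" concrete, here is a minimal sketch of how the best epoch of one fold could be selected from per-epoch validation metrics. The selection criterion that `run_cross_val` actually uses is not stated here, so the choice of validation accuracy, the helper name `select_best_epoch`, and all numbers below are assumptions for illustration only.

```python
def select_best_epoch(val_history, metric="accuracy", higher_is_better=True):
    """Return (epoch, metrics) for the epoch with the best validation metric.

    Hypothetical helper for illustration; not part of this repository.
    """
    best_epoch, best_metrics = None, None
    for epoch, metrics in enumerate(val_history):
        if best_metrics is None:
            is_better = True
        elif higher_is_better:
            is_better = metrics[metric] > best_metrics[metric]
        else:
            is_better = metrics[metric] < best_metrics[metric]
        if is_better:
            # A real training loop would save the model weights here
            best_epoch, best_metrics = epoch, metrics
    return best_epoch, best_metrics


# Made-up per-epoch validation metrics for a single fold
history = [
    {"loss": 1.20, "accuracy": 0.40},
    {"loss": 0.80, "accuracy": 0.70},
    {"loss": 0.85, "accuracy": 0.68},
]
best_epoch, best_metrics = select_best_epoch(history)
print(best_epoch, best_metrics)
```

Passing `metric="loss", higher_is_better=False` would select by lowest validation loss instead.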

In [17]:
# Note: we pass a function that returns a model, not a model instance
best_val_results, test_results = run_cross_val(
    efficientnet, ROOT, config, False, "./checkpoints/temp", random_seed=RANDOM_SEED
)

Random seed set as 0
### Training ###
Batch: 0 - Loss: 1.42 --- Accuracy: 21.88
Batch: 1 - Loss: 1.38 --- Accuracy: 34.38
Batch: 2 - Loss: 1.37 --- Accuracy: 28.12
Batch: 3 - Loss: 1.40 --- Accuracy: 25.00
Batch: 4 - Loss: 1.41 --- Accuracy: 25.00
Batch: 5 - Loss: 1.39 --- Accuracy: 31.25
Batch: 6 - Loss: 1.39 --- Accuracy: 31.25
Batch: 7 - Loss: 1.39 --- Accuracy: 21.88
Batch: 8 - Loss: 1.32 --- Accuracy: 37.50
Batch: 9 - Loss: 1.34 --- Accuracy: 28.12
Batch: 10 - Loss: 1.27 --- Accuracy: 53.12
Batch: 11 - Loss: 1.30 --- Accuracy: 43.75
Batch: 12 - Loss: 1.32 --- Accuracy: 43.75
Batch: 13 - Loss: 1.28 --- Accuracy: 53.12
Batch: 14 - Loss: 1.25 --- Accuracy: 50.00
Batch: 15 - Loss: 1.19 --- Accuracy: 53.12
Batch: 16 - Loss: 1.23 --- Accuracy: 43.75
Batch: 17 - Loss: 1.05 --- Accuracy: 75.00

Epoch 0: [------------------------------]

Train_loss: 1.3175 --- Train_acc: 38.8889

### Validation and Checkpointing ###
Validation_loss: 1.2302 --- Validation_accuracy: 0.4122 --- Validation_mea

In [26]:
def summary(results_dict):
    # Convert the results dictionary to a DataFrame (one column per metric)
    df = pd.DataFrame(results_dict)

    # Compute the mean and standard deviation of each column
    mean_values = df.mean()
    std_values = df.std()
    mean_std_df = pd.DataFrame({'Mean': mean_values, 'Std': std_values})
    print(mean_std_df)

print("Summary of Validation Results from Best Checkpoints")
summary(best_val_results)
print("\nSummary of Test Results from Best Checkpoints")
summary(test_results)

# Note: the reported loss is identical for the 2-, 3-, and 4-class settings
# because it is always computed over all 4 classes.
# This should be fixed in the future.

Summary of Validation Results from Best Checkpoints
                            Mean       Std
2_class_loss            0.224053  0.077559
2_class_accuracy        0.970243  0.013147
2_class_mean_precision  0.972010  0.011473
2_class_mean_recall     0.970229  0.013138
3_class_loss            0.224053  0.077559
3_class_accuracy        0.951315  0.018076
3_class_mean_precision  0.921985  0.031233
3_class_mean_recall     0.953737  0.011988
4_class_loss            0.224053  0.077559
4_class_accuracy        0.931026  0.023998
4_class_mean_precision  0.912112  0.030413
4_class_mean_recall     0.922715  0.021221

Summary of Test Results from Best Checkpoints
                            Mean       Std
2_class_loss            0.204134  0.031172
2_class_accuracy        0.978462  0.008427
2_class_mean_precision  0.978783  0.008307
2_class_mean_recall     0.978693  0.008375
3_class_loss            0.204134  0.031172
3_class_accuracy        0.950769  0.011666
3_class_mean_precision  0.912057  0.01517