# Active Learning Experiment Notebook

##### TODO
If using the auto-oracle, we want:
- visualization of what the oracle is seeing/saying

In general, we want:
- training curves (make sure to maintain a history from initial training)

Convergence criterion?

### Import Statements

In [None]:
#Python Library imports
import random
import torch 
import numpy as np
import glob
import os
import datetime 


#Backend py file imports
from dataloader import get_DataLoader
from disc_model import disc_model

from auto_oracle import query_oracle_automatic
# from manual_oracle import query_oracle

from experiment import save_active_learning_results, remove_bad_oracle_results
from experiment import update_dir_with_oracle_info, redirect_saved_oracle_filepaths_to_thresheld_directory, save_files_for_nnunet

from nnunet_model import convert_2d_image_to_nifti, plan_and_preprocess
import nnunet_model
import unet_model

%matplotlib inline

### Seed all Random Generators

In [None]:
random_seed_number = 44

In [None]:
torch.manual_seed(random_seed_number)
torch.cuda.manual_seed(random_seed_number)
np.random.seed(random_seed_number)
random.seed(random_seed_number)

torch.backends.cudnn.enabled = False
torch.backends.cudnn.deterministic = True

### Run ID Setup

This section sets up a `run_id` to identify each run and a folder to save its outputs to.
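
The path construction described above can be sketched as follows. This is only an illustration: `build_run_dir` is a hypothetical helper (it is not part of the backend files), and `/tmp/AllOracleRuns` is a placeholder output root. The sketch also converts the iteration number to an integer, since `input` returns a string.

```python
import datetime
import os


def build_run_dir(output_dir, iter_num, today=None):
    """Return (run_id, run_dir) for a given day and iteration.

    Hypothetical helper for illustration; not part of the backend files.
    """
    today = today or datetime.date.today()
    run_id = today.strftime("%y_%m_%d")  # two-digit year, month, day
    iter_num = int(iter_num)  # input() returns a string; reject non-integers early
    if iter_num < 0:
        raise ValueError("iter_num must be non-negative")
    run_dir = os.path.join(output_dir, "Run_" + run_id, "Iter" + str(iter_num))
    return run_id, run_dir


# Placeholder output root and a fixed date so the result is reproducible.
run_id, run_dir = build_run_dir("/tmp/AllOracleRuns", "2", datetime.date(2023, 1, 5))
print(run_dir)
```

Fixing the date through the optional `today` argument is a convenience for testing; in the notebook the current date is used.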

In [None]:
#run_id has format of "yy_mm_dd", and iter_num is the current run on the day (0,1,2,etc)
run_id = datetime.date.today().strftime("%y_%m_%d")
iter_num = input("which iteration is this of the day? ")

# Where do you want to save all outputs?
output_dir = "/usr/xtmp/jly16/mammoproj/nnunet_integration_tmp/AllOracleRuns"
run_dir = os.path.join(output_dir, "Run_" + run_id, "Iter" + str(iter_num))

In [None]:
#users_name tells us who is working on the notebook (vaibhav/alina/julia)
users_name = input("what is your name: ")
print(f"Your name is: {users_name}.")

## Active Learning Stage

### Set Image Directories

This section sets up directories for the active learning. We will need a directory of images for training the discriminator, a directory that the oracle will query from, and a directory that has the ground truth segmentations.

All files within the directories should be .npy and have shape (2, r, c) where the first channel is the original image, and the second channel is the binarized segmentation. 
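
A quick way to sanity-check a file against this layout is sketched below. `check_pair_file` is a hypothetical helper, the data is synthetic, and the check that the mask contains only 0 and 1 is an assumption about what "binarized" means here (the real files may use a different encoding).

```python
import os
import tempfile

import numpy as np


def check_pair_file(path):
    """Load a .npy file and verify the (2, r, c) image/segmentation layout."""
    arr = np.load(path)
    if arr.ndim != 3 or arr.shape[0] != 2:
        raise ValueError(f"{path}: expected shape (2, r, c), got {arr.shape}")
    image, segmentation = arr[0], arr[1]
    # Assumption: a binarized mask contains no values other than 0 and 1.
    if not np.isin(segmentation, (0, 1)).all():
        raise ValueError(f"{path}: segmentation channel is not binary")
    return image, segmentation


# Synthetic example: a random 4x5 "image" stacked with a binary mask.
rng = np.random.default_rng(0)
pair = np.stack([rng.random((4, 5)), (rng.random((4, 5)) > 0.5).astype(float)])
with tempfile.TemporaryDirectory() as tmp:
    example_path = os.path.join(tmp, "example.npy")
    np.save(example_path, pair)
    image, segmentation = check_pair_file(example_path)
print(image.shape, segmentation.shape)
```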

In [None]:
discriminator_training_dir = "/usr/xtmp/vs196/mammoproj/Data/final_dataset/train/" 
oracle_query_dir = "/usr/xtmp/vs196/mammoproj/Data/final_dataset/train/" 
ground_truth_dir = "/usr/xtmp/vs196/mammoproj/Data/final_dataset/train/" 

### Initial Discriminator Training

In this section, we initialize the discriminator by training a VGG11 network to discriminate between "good" segmentations (labeled `1`) and "bad" segmentations (labeled `0`). The training data comes from the `discriminator_training_dir` specified [above](#Set-Image-Directories), with the (`image`, `segmentation`) pairs deliberately mismatched for half of each batch. The mismatched pairs get a label of `0`; the correctly matched pairs get a label of `1`.

Parameters that can be tuned: 
- `batch_size`
- `init_disc_epochs` (number of epochs trained for initializing the discriminator)

The mismatching logic can be found in the `disc_model.initialize_model` method. The way pairs are mismatched may also warrant tweaking.
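
To make the labeling scheme concrete, here is a minimal sketch of one possible way to mismatch half a batch. It is an assumed scheme for illustration only: the actual logic in `disc_model.initialize_model` may differ, `mismatch_half` is a hypothetical name, and strings stand in for image and segmentation tensors.

```python
import random


def mismatch_half(images, segmentations, rng):
    """Return (images, segmentations, labels) with half the pairs mismatched.

    Assumed scheme for illustration only: the segmentations of a randomly
    chosen half of the batch are rotated by one position among themselves,
    so each of those images is paired with another image's segmentation.
    """
    n = len(images)
    chosen = rng.sample(range(n), n // 2)  # indices that will be mismatched
    new_segs = list(segmentations)
    labels = [1] * n  # 1 = matched pair, 0 = mismatched pair
    if len(chosen) >= 2:  # rotating needs at least two items to change anything
        for src, dst in zip(chosen, chosen[1:] + chosen[:1]):
            new_segs[dst] = segmentations[src]
            labels[dst] = 0
    return list(images), new_segs, labels


# Strings stand in for the real image and segmentation arrays.
images = [f"img{i}" for i in range(8)]
segs = [f"seg{i}" for i in range(8)]
_, mixed, labels = mismatch_half(images, segs, random.Random(44))
print(list(zip(images, mixed, labels)))
```

Rotating within the chosen half guarantees that no mismatched image accidentally keeps its own segmentation, which a plain shuffle would not.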

In [None]:
# Generates dataloader to churn out batches of the images from discriminator_training_dir. 
# Takes in batch_size and num_workers
batch_size = 32 # TUNABLE PARAMETER
dataloader = get_DataLoader(discriminator_training_dir, batch_size, 2)

In [None]:
# instantiate, load, and initialize discriminator model by training N epochs (see disc_model.initialize_model for details)
discriminator = disc_model()
discriminator.load_model(dataloader) 

init_disc_epochs = 10 # TUNABLE PARAMETER
discriminator.initialize_model(batch_size = batch_size, epochs=init_disc_epochs) # initial training


#### Take a look at initial discriminator training performance

In [None]:
discriminator.plot_loss() ## TODO: should this be by epoch instead?... why was it by batch initially? 

In [None]:
discriminator.plot_distribution(discriminator_training_dir)

In [None]:
discriminator.show_disc(discriminator_training_dir)

### Generating initial patient scores

Here, we get a score from the initialized discriminator for every (`image`, `segmentation`) pair in the `oracle_query_dir`. The score indicates how good or bad the pair is (low scores mean the segmentation does not match the image well; high scores mean it matches the image well). These scores will be used to choose which images to show the oracle.

Note: this step must use the same `batch_size` as above.

In [None]:
# create a holder for all patient scores (this will be a list of dictionaries)
all_patient_scores = [] 

In [None]:
# Gets the patient scores based on initial trained discriminator model. 
# A patient score is how "good" the discriminator model thinks the segmentation is
# patient_scores is a dictionary of patient:score which gets appended to the all_patient_scores list

patient_dataloader = get_DataLoader(oracle_query_dir, batch_size, 2)
patient_scores = discriminator.get_scores(patient_dataloader)  

all_patient_scores.append(patient_scores)

### Oracle Querying

This section queries the oracle: the system selects some (`image`, `segmentation`) pairs to show the oracle, and the oracle provides feedback on whether the segmentation is good (`1`), bad (`0`), or needs a new threshold (for clear over- or under-segmentation). The oracle's feedback is then used to further train the discriminator.

The method the system uses to select which pairs to show the oracle is specified using `query_method`. Currently, the method options are `"uniform"`, `"best"`, `"worst"`, `"percentile=0.x"` (where you specify a percentile), `"random"`, `"middle"`. These methods refer to the patient scores generated by the discriminator (eg "best" are the pairs with the highest discriminator score). 
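
The sketch below shows how score-based selection of this kind could work. The semantics are assumptions made for illustration, not the behavior of `query_oracle_automatic`: in particular, `"uniform"` is taken to mean evenly spaced positions in the score ranking, `"percentile=0.x"` a window centred on that point of the ranking, and `"middle"` the same as `"percentile=0.5"`. `select_queries` is a hypothetical name.

```python
import random


def select_queries(patient_scores, query_method, query_number, rng=None):
    """Pick patient IDs to show the oracle (illustrative semantics only)."""
    rng = rng or random.Random()
    ranked = sorted(patient_scores, key=patient_scores.get)  # lowest score first
    k = min(query_number, len(ranked))
    if query_method == "worst":
        return ranked[:k]
    if query_method == "best":
        return ranked[len(ranked) - k:]
    if query_method == "random":
        return rng.sample(ranked, k)
    if query_method == "uniform":
        # Assumed meaning: evenly spaced positions across the ranking.
        if k == 1:
            return [ranked[len(ranked) // 2]]
        step = (len(ranked) - 1) / (k - 1)
        return [ranked[round(i * step)] for i in range(k)]
    if query_method == "middle":
        query_method = "percentile=0.5"  # assumed equivalence
    if query_method.startswith("percentile="):
        # Assumed meaning: a window of k items centred on the requested percentile.
        fraction = float(query_method.split("=", 1)[1])
        centre = int(fraction * (len(ranked) - 1))
        start = max(0, min(centre - k // 2, len(ranked) - k))
        return ranked[start:start + k]
    raise ValueError(f"unknown query_method: {query_method}")


# Ten hypothetical patients with scores 0.0, 0.1, ..., 0.9.
scores = {f"patient{i}": i / 10 for i in range(10)}
print(select_queries(scores, "uniform", 4))
```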

The oracle can be either a human or the computer. For a human oracle, use `query_oracle` (which may need some debugging). For the computer oracle (aka auto-oracle), use `query_oracle_automatic`. Here, we use the auto-oracle.

A major part of this research project is to see how the AL system reacts to different query methods and number of images queried. 

This process of querying and updating the discriminator can be repeated until the discriminator performs satisfactorily. 

In [None]:
#Initializes oracle results dict and thresholds dict
oracle_results = {}
oracle_results_thresholds = {}

#### chk1

In [None]:
# YOU SHOULD CHOOSE THE QUERY METHOD AND QUERY_NUMBER
#      query_method: how it chooses the images to show. (best, worst, percentile=0.x, uniform)
#      query_number: how many images to query at once

# oracle_results is a dictionary that stores image_name:result. The result is 1 if good, 0 if bad
# oracle_results_thresholds is a dictionary that stores the threshold that produced the best segmentation

_,_ = query_oracle_automatic(oracle_results, oracle_results_thresholds, patient_scores,
                ground_truth_dir, oracle_query_dir,
                query_method="uniform", query_number=10)


### Updating the Discriminator

Once the oracle has been queried, we update the discriminator by training it on the queried (`image`, `segmentation`) pairs and their good (`1`) / bad (`0`) labels.

You can tune: 
- `update_disc_epochs`: how many epochs we update the discriminator with

In [None]:
#Update the discriminator with data from the oracle for N number of epochs
update_disc_epochs = 1 ## TUNABLE PARAMETER

discriminator.update_model(oracle_results, batch_size=batch_size, num_epochs=update_disc_epochs)
patient_scores = discriminator.get_scores(patient_dataloader)
all_patient_scores.append(patient_scores)

#### Visualize discriminator post-update

In [None]:
discriminator.plot_loss()

In [None]:
discriminator.plot_distribution(discriminator_training_dir)

In [None]:
discriminator.show_disc(discriminator_training_dir)

# Go back to the chk1 heading if you want to keep querying images

### Deal with some file saving/organization

This section saves some of the results from the discriminator stage of our AL. It:
1. saves the pairs identified as good (i.e. correct segmentations) into the `CorrectSegmentations` subfolder of our `run_dir` (their paths are stored in `saved_oracle_filepaths`).
2. creates an `OracleThresholdedImage_ff` subfolder that holds all the images from `oracle_query_dir`, with segmentations rethresholded according to `oracle_results_thresholds`.
3. builds `new_saved_oracle_filepaths`, which point to the pairs identified as good at their locations in the `OracleThresholdedImage_ff` subfolder.
4. if we are using nnunet, redirects `new_saved_oracle_filepaths` again and reprocesses the files into the nnunet file/folder format, which is identified by an auto-generated nnunet `task_id`.
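
The auto-generated `task_id` in step 4 is one more than the highest existing task number. A minimal sketch of that lookup follows. `next_task_id` and its `first_id` fallback are hypothetical, and the folder name `Task503_example` is made up for the demo. Unlike a plain "last item of the sorted folder list" lookup, this version compares the numbers themselves and does not fail when no task folder exists yet.

```python
import os
import re
import tempfile


def next_task_id(raw_data_dir, first_id=500):
    """Return one more than the largest TaskNNN id found in raw_data_dir."""
    ids = []
    for name in os.listdir(raw_data_dir):
        match = re.match(r"Task(\d{3})_", name)  # e.g. "Task504_duke-mammo"
        if match:
            ids.append(int(match.group(1)))
    # Assumed default: fall back to first_id when no task folders exist yet.
    return max(ids) + 1 if ids else first_id


# Demo on a temporary folder that mimics the nnUNet_raw_data layout.
with tempfile.TemporaryDirectory() as raw_data_dir:
    for name in ("Task503_example", "Task504_duke-mammo", "notes"):
        os.mkdir(os.path.join(raw_data_dir, name))
    task_id = next_task_id(raw_data_dir)
print(task_id)
```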

In [None]:
# Space for saving oracle results and pickling data structures
saved_oracle_filepaths = save_active_learning_results(
    run_dir, oracle_results, oracle_results_thresholds, oracle_query_dir)

# # not necessary as oracle_results is never even used again in this method.
# oracle_results = remove_bad_oracle_results(oracle_results)

# report how many images the oracle classified as correct (possibly none)
if len(saved_oracle_filepaths) == 0:
    print("No oracle results classified as correct.")
else:
    print(
        f"Oracle classifies {len(saved_oracle_filepaths)} images as correct.")


In [None]:
segmenter_train_dir = update_dir_with_oracle_info(run_dir, oracle_results_thresholds, oracle_query_dir)

In [None]:
new_saved_oracle_filepaths = redirect_saved_oracle_filepaths_to_thresheld_directory(
    saved_oracle_filepaths, segmenter_train_dir)

In [None]:
unet = False

In [None]:
if not unet: 
    last_task = sorted(glob.glob(os.path.join(os.environ['nnUNet_raw_data_base'], 'nnUNet_raw_data','Task*')))[-1]
    last_task = last_task.split('nnUNet_raw_data/Task')[-1][:3]
    task_id = int(last_task) + 1
    save_files_for_nnunet(task_id, run_id, new_saved_oracle_filepaths)

### Update Segmenter

In this section, we update the segmenter (which has been pre-trained on CBIS-DDSM) using the good (`image`, `segmentation`) pairs saved above.

TUNABLE PARAMETER: 
- `segmenter_update_epochs`: how many epochs to update the segmenter with.

In [None]:
# initialize the model
if unet:
    segmenter = unet_model.unet_model()
    segmenter_train_dir = new_saved_oracle_filepaths
else:
    segmenter = nnunet_model.nnunet_model()
    segmenter_train_dir = os.path.join(os.environ['nnUNet_preprocessed'], f'Task{task_id}_{run_id}')
segmenter.load_model(segmenter_train_dir)

In [None]:
# update the segmenter using the correct image, segmentation pairs
segmenter_update_epochs = 5 ## TUNABLE PARAMETER
segmenter.update_model(num_epochs = segmenter_update_epochs);

In [None]:
# Optionally save the model from this iteration so it can be used later
if unet:
    model_save_path = os.path.join(run_dir, "unetmodel.pth")
else:
    model_save_path = os.path.join(run_dir, 'all', "Iter" + str(iter_num)+".model")
segmenter.save_model(model_save_path)

### Visualize Segmentations

In [None]:
if unet: 
    base_dir = "/usr/xtmp/vs196/mammoproj/Data/final_dataset/train/*"
else:
    base_dir = '/usr/xtmp/jly16/mammoproj/data/nnUNet_raw_data_base/nnUNet_raw_data/Task504_duke-mammo/imagesTr'
filepaths = [str(f) for f in np.random.choice(glob.glob(os.path.join(base_dir, '*')), 5)]

print(filepaths)
segmenter.show_segmentations(filepaths)



### Generate Predictions and Validation

In [None]:
# evaluation 1: generate new segmentations of TRAINING images and save them. (This is for the next stage of active learning)

# Dir for segmentations marked correct by the oracle. We do not want to overwrite the old segmentation, so save them here as an archive
correct_save_dir = os.path.join(run_dir, "Segmentations_C" )
# Dir for completely new set of segmentations created by the updated segmenter
save_dir = os.path.join(run_dir,"Segmentations")

if unet: 
    segmentation_folder = discriminator_training_dir
else:
    segmentation_folder = '/usr/xtmp/jly16/mammoproj/data/nnUNet_raw_data_base/nnUNet_raw_data/Task504_duke-mammo/imagesTr'

segmenter.predict(segmentation_folder, save_dir, correct_save_dir = correct_save_dir, saved_oracle_filepaths = saved_oracle_filepaths)   
# Use save_dir as the oracle image dir for the next iteration. It holds the unbinarized segmentations from the recently trained UNet


In [None]:
# evaluation 2: generate segmentations of VALIDATION and see how accurate our new segmenter is
if unet:
    valid_input_dir = "/usr/xtmp/vs196/mammoproj/Data/manualfa/manual_validation/"
    valid_output_dir = None
else:
    valid_input_dir = os.path.join(
        os.environ['nnUNet_raw_data_base'], 'nnUNet_raw_data', "Task504_duke-mammo", 'imagesTs')
    valid_output_dir = os.path.join(run_dir, "ValSegmentations")
validation_metric = segmenter.validate(valid_input_dir, valid_output_dir)
print(f"Metric of new segmenter after active learning is: {validation_metric}.")


### Plotting Active Learning Metrics

In [None]:
# #Prints out metrics for all the patient scores from each update.
# for i in all_patient_scores:
#     print(oracle.calculate_dispersion_metric(i,oracle_results))

In [None]:
# #Plot the dispersion metric
# j = []
# for i in all_patient_scores:
#     j.append(oracle.calculate_dispersion_metric(i,oracle_results))
    
# plt.plot(j)