# Active Learning Experiment Notebook

### Import Statements

In [None]:
#Python Library imports
import numpy as np
import torch
import torchvision
from time import time
import random

from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torch import nn, optim
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from torch.autograd import Variable
import os
import glob
import cv2
from tqdm import tqdm

from matplotlib.pyplot import imsave, imread
import matplotlib.pyplot as plt
import sys
import matplotlib.gridspec as gridspec

import copy
import pickle

#Backend py file imports
from floodfill import *
from dataloader import *
from model import *
from oracle import *
from unet import *
import ternausnet.models

### Run ID Setup

#run_id identifies this experiment run. Format: "mm_dd_count", i.e. month, day, and a
#letter counting the runs made on that day (a for the first, b for the second, etc).
run_id = "07_19_a" 

In [None]:
#users_name records who is working on the notebook (vaibhav/alina).
#input() blocks this cell until a name is typed at the prompt.
users_name = input("what is your name: ")
#Echo the captured name back so it can be checked before continuing.
print(f"Your name is: {users_name}.")

## Active Learning Stage

### Initialization

In [None]:
#im_dir is the directory where the oracle pulls images; the same value is also handed to
#get_DataLoader when the dataloader is built.
#NOTE(review): left as an empty string here -- presumably a placeholder that must be set
#to a real path before the following cells are run; confirm.
im_dir = "" #im_dir is the directory where oracle pulls images

In [None]:
# Build the dataloader over the images in im_dir. Per the original note on this
# call, the two numeric arguments are batch_size and num_workers, in that order.
dataloader = get_DataLoader(
    im_dir,
    32,  # batch_size
    2,  # num_workers
)

# Start with empty oracle bookkeeping: one dict for the oracle's results and a
# separate dict for its thresholds.
oracle_results, oracle_results_thresholds = {}, {}

### Initial Training

In [None]:
# Initial training pass. Only the epoch count is passed explicitly; every other
# setting (e.g. batch size) is left to initialize_and_train_model's defaults.
# NOTE(review): the original comment called epochs a default as well --
# presumably 5 matches the function's default; confirm.
(model, loss_tracker, criterion, optimizer) = initialize_and_train_model(
    dataloader,
    epochs=5,
)

# Plot the loss values recorded during training.
plt.plot(loss_tracker)

In [None]:
# Score every patient with the initially trained model. A patient's score is
# how "good" the model thinks that patient's segmentation is.
patient_scores = get_patient_scores(model, dataloader)  # dict: patient -> score

# History of score snapshots, one entry per scoring pass, seeded with the
# scores from the initial model.
all_patient_scores = [patient_scores]

### Oracle Querying

In [None]:
# Ask the oracle about a batch of patients.
#   query_method: one of best, worst, percentile=0.x, uniform
#   query_number: relevant argument controlling the size of the query
# When prompted, answer 1 if the segmentation is correct, 0 if it is impossible,
# or enter a new threshold if a new threshold will help.
new_oracle_results, new_oracle_results_thresholds = query_oracle(
    oracle_results,
    oracle_results_thresholds,
    patient_scores,
    im_dir,
    query_method="best",
    query_number=13,
)

# Adopt the returned dicts as the current oracle state.
oracle_results = new_oracle_results
oracle_results_thresholds = new_oracle_results_thresholds

### Updating

In [None]:
#Re-score every patient with the current model and append the snapshot to all_patient_scores.
#NOTE(review): no model update is visible between the previous scoring pass and this one,
#so this snapshot may simply repeat the last entry -- confirm that is intended.
patient_scores = get_patient_scores(model,dataloader)
all_patient_scores.append(patient_scores)

In [None]:
# Fold the oracle's answers back into the active learning classifier.
# Runs three update rounds of one epoch each, recording a fresh snapshot of the
# patient scores after every round.
#TODO: track model loss somehow along with patient_scores (KEEP TODO AND ADDRESS)
for _ in range(3):
    model = model_update(
        model,
        dataloader,
        oracle_results,
        criterion,
        optimizer,
        num_epochs=1,
    )
    patient_scores = get_patient_scores(model, dataloader)
    all_patient_scores.append(patient_scores)

**Go back to the "Oracle Querying" heading if you want to keep querying images.**

### Plotting Active Learning Metrics

In [None]:
# Print the dispersion metric for each recorded snapshot of patient scores,
# every one evaluated against the current oracle results.
for score_snapshot in all_patient_scores:
    dispersion = calculate_dispersion_metric(score_snapshot, oracle_results)
    print(dispersion)

In [None]:
# Plot the dispersion metric across all recorded snapshots of patient scores,
# in the order the snapshots were taken.
j = [
    calculate_dispersion_metric(score_snapshot, oracle_results)
    for score_snapshot in all_patient_scores
]

plt.plot(j)

In [None]:
# Report how many patients have a score and how many the oracle has answered.
print(f"Length of patient scores: {len(patient_scores)}")
print(f"Length of oracle results: {len(oracle_results)}")

# Plot the latest patient scores in the dict's iteration order.
scores = list(patient_scores.values())
plt.plot(scores)

# Count the oracle answers that are equal to 1.
ones = list(oracle_results.values()).count(1)
print("Number of ones in oracle results: ", ones)