# Initialization

## Imports

### Import Modules

In [None]:
import gc
import time

from tasks.optic_disc_cup.datasets import DrishtiDataset, RimOneDataset
from tasks.optic_disc_cup.metrics import calc_disc_cup_iou
from config.config_type import AllConfig, DataConfig, DataTuneConfig, LearnConfig, WeaselConfig, ProtoSegConfig
from data.dataset_loaders import DatasetLoaderParamSimple
from learners.protoseg import ProtoSegLearner
from learners.weasel import WeaselLearner
from models.u_net import UNet

from torch import cuda

### Autoreload Import

In [None]:
# Load (or reload) IPython's autoreload extension so edits to the project
# modules listed below are picked up without restarting the kernel.
%reload_ext autoreload
# Mode 1: automatically reload only the modules registered via %aimport.
%autoreload 1
%aimport config.config_type
%aimport models.u_net
%aimport data.types, data.few_sparse_dataset, data.dataset_loaders
%aimport learners.learner, learners.weasel, learners.protoseg
%aimport tasks.optic_disc_cup.datasets, tasks.optic_disc_cup.metrics

In [None]:
# Trigger a module reload immediately instead of waiting for the next cell run.
%autoreload now

## All Config

### Short Training

In [None]:
# ---- Short training run (fewer epochs/tuning steps than the "Long Training" cell) ----

# Data / input pipeline settings.
data_config: DataConfig = {
    'num_classes': 3,
    'num_channels': 3,
    'num_workers': 0,
    'batch_size': 1,
    'resize_to': (256, 256),
}

# Tuning-phase settings: shot counts to try, and the sparsity values to try
# for each sparse-annotation type.
data_tune_config: DataTuneConfig = {
    'shot_list': [5],
    'sparsity_dict': {
        'point': [10],
        'grid': [25],
        'contour': [1],
        'skeleton': [1],
        'region': [1],
        'point_old': [10],
        'grid_old': [25],
    },
}

# Optimisation / training-loop settings.
learn_config: LearnConfig = {
    'should_resume': False,
    'use_gpu': True,
    'num_epochs': 8,
    'optimizer_lr': 1e-3,
    'optimizer_weight_decay': 5e-5,
    'optimizer_momentum': 0.9,
    'scheduler_step_size': 150,
    'scheduler_gamma': 0.2,
    'tune_freq': 4,
    'exp_name': '',
}

# Weasel-specific settings.
weasel_config: WeaselConfig = {
    'use_first_order': False,
    'update_param_step_size': 0.3,
    'tune_epochs': 6,
    'tune_test_freq': 3,
}

# ProtoSeg-specific settings (embedding_size is also the UNet output width
# in the ProtoSeg section).
protoseg_config: ProtoSegConfig = {
    'embedding_size': 4,
}

# Aggregate config handed to the learners; the entries are the same dict
# objects as above, so later in-place updates affect both views.
all_config: AllConfig = {
    'data': data_config,
    'data_tune': data_tune_config,
    'learn': learn_config,
    'weasel': weasel_config,
    'protoseg': protoseg_config,
}

### Long Training

In [None]:
# data_config: DataConfig = {
#     'num_classes': 3,
#     'num_channels': 3,
#     'num_workers': 0,
#     'batch_size': 1,
#     'resize_to': (256, 256)
# }
# 
# data_tune_config: DataTuneConfig = {
#     'shot_list': [10],
#     'sparsity_dict': {
#         'point': [10],
#         'grid': [25],
#         'contour': [1],
#         'skeleton': [1],
#         'region': [1],
#         'point_old': [10],
#         'grid_old': [25]
#     }
# }
# 
# learn_config: LearnConfig = {
#     'should_resume': False,
#     'use_gpu': True,
#     'num_epochs': 200,
#     'optimizer_lr': 1e-3,
#     'optimizer_weight_decay': 5e-5,
#     'optimizer_momentum': 0.9,
#     'scheduler_step_size': 150,
#     'scheduler_gamma': 0.2,
#     'tune_freq': 40,
#     'exp_name': ''
# }
# 
# weasel_config: WeaselConfig = {
#     'use_first_order': False,
#     'update_param_step_size': 0.3,
#     'tune_epochs': 40,
#     'tune_test_freq': 8
# }
# 
# protoseg_config: ProtoSegConfig = {
#     'embedding_size': 4
# }
# 
# all_config: AllConfig = {
#     'data': data_config,
#     'data_tune': data_tune_config,
#     'learn': learn_config,
#     'weasel': weasel_config,
#     'protoseg': protoseg_config
# }

# Dataset Exploration

## Additional Imports

In [None]:
# import numpy as np
# from matplotlib import pyplot as plt
# 
# from data.types import SparsityValue
# 
# plt.style.use('dark_background')

## RIM-ONE

### Create Dataset

In [None]:
# Sparse-annotation parameters for RIM-ONE. Passed to the dataset below as
# `sparsity_params` and reused by `rim_one_meta_loader_params`.
rim_one_sparsity_params: dict = {
    'point_dot_size': 5,
    'grid_dot_size': 4,
    'contour_radius_dist': 4,
    'contour_radius_thick': 2,
    'skeleton_radius_thick': 4,
    'region_compactness': 0.5
}

# Exploration-only training split of RIM-ONE. Class count and resize target
# come from the active config so this dataset cannot drift from the setup
# the learners are trained with (previously duplicated as literals).
rim_one_data = RimOneDataset(
    mode='train',
    num_classes=all_config['data']['num_classes'],
    num_shots=5,
    resize_to=all_config['data']['resize_to'],
    split_seed=0,
    sparsity_params=rim_one_sparsity_params
)

### Check Sparse Masks

In [None]:
# sparsity_values: dict[str, SparsityValue] = {
#     'point': 10,
#     'grid': 20,
#     'contour': 1,
#     'skeleton': 1,
#     'region': 1,
#     'point_old': 10,
#     'grid_old': 20
# }
# image, mask, sparse_masks, image_filename = rim_one_data.get_data_with_sparse_all(0, sparsity_values)
# print(image.shape, image.max(), image.min(), image_filename)
# print(mask.shape, mask.dtype, np.unique(mask))
# 
# n_rows = int(np.ceil(len(sparse_masks) / 2)) + 1 
# _, axs = plt.subplots(n_rows, 2, figsize=(12, n_rows*6))
# axs = axs.flat
# axs[0].imshow(image)
# axs[1].imshow(mask)
# for i, sparsity in enumerate(sparse_masks):
#     axs[i+2].imshow(sparse_masks[sparsity])

### Check Others

In [None]:
# image_sizes = []
# for image_path, mask_path in rim_one_data.get_all_data_path():
#     image, _ = rim_one_data.read_image_mask(image_path, mask_path)
#     image_sizes.append(image.shape)
# 
# image_sizes = np.array(image_sizes)
# 
# print(np.unique(image_sizes[:,0], return_counts=True))
# print(image_sizes[:,0].min(), image_sizes[:,0].max())
# print(np.unique(image_sizes[:,1], return_counts=True))
# print(image_sizes[:,1].min(), image_sizes[:,1].max())

## DRISHTI

### Create Dataset

In [None]:
# Sparse-annotation parameters for DRISHTI. Passed to the dataset below as
# `sparsity_params` and reused by `drishti_tune_loader_params`.
drishti_sparsity_params: dict = {
    'point_dot_size': 4,
    'grid_dot_size': 4,
    'contour_radius_dist': 4,
    'contour_radius_thick': 1,
    'skeleton_radius_thick': 3,
    'region_compactness': 0.5
}

# Exploration-only training split of DRISHTI. Class count and resize target
# come from the active config so this dataset cannot drift from the setup
# the learners are trained with (previously duplicated as literals).
drishti_data = DrishtiDataset(
    mode='train',
    num_classes=all_config['data']['num_classes'],
    num_shots=5,
    resize_to=all_config['data']['resize_to'],
    split_seed=0,
    sparsity_params=drishti_sparsity_params
)

### Check Sparse Masks

In [None]:
# sparsity_values: dict[str, SparsityValue] = {
#     'point': 10,
#     'grid': 25,
#     'contour': 1,
#     'skeleton': 1,
#     'region': 1,
#     'point_old': 10,
#     'grid_old': 25
# }
# image, mask, sparse_masks, image_filename = drishti_data.get_data_with_sparse_all(1, sparsity_values)
# print(image.shape, image.max(), image.min(), image_filename)
# print(mask.shape, mask.dtype, np.unique(mask))
# 
# n_rows = int(np.ceil(len(sparse_masks) / 2)) + 1
# _, axs = plt.subplots(n_rows, 2, figsize=(12, n_rows*6))
# axs = axs.flat
# axs[0].imshow(image)
# axs[1].imshow(mask, cmap='gray')
# for i, sparsity in enumerate(sparse_masks):
#     axs[i+2].imshow(sparse_masks[sparsity])

### Check Others

In [None]:
# Print the distinct values and min/max of the first two dimensions of every
# raw DRISHTI image. The reader is drishti_data's own (not rim_one_data's),
# because the paths iterated here come from the DRISHTI dataset.
# image_sizes = []
# for image_path, mask_path in drishti_data.get_all_data_path():
#     image, _ = drishti_data.read_image_mask(image_path, mask_path)
#     image_sizes.append(image.shape)
# 
# image_sizes = np.array(image_sizes)
# 
# print(np.unique(image_sizes[:,0], return_counts=True))
# print(image_sizes[:,0].min(), image_sizes[:,0].max())
# print(np.unique(image_sizes[:,1], return_counts=True))
# print(image_sizes[:,1].min(), image_sizes[:,1].max())

## Dataset Loader Params

In [None]:
# Meta-training loader: RIM-ONE with 'random' sparsity mode/value.
# num_shots starts at -1 and is reassigned in the learner sections.
rim_one_meta_loader_params: DatasetLoaderParamSimple = {
    'dataset_class': RimOneDataset,
    'dataset_kwargs': dict(
        split_seed=0,
        split_test_size=0.2,
        num_shots=-1,
        sparsity_mode='random',
        sparsity_value='random',
        sparsity_params=rim_one_sparsity_params,
    ),
}

# Tuning loader: DRISHTI. No shot/sparsity kwargs are set here; presumably
# the learner supplies them from the 'data_tune' config (TODO confirm).
drishti_tune_loader_params: DatasetLoaderParamSimple = {
    'dataset_class': DrishtiDataset,
    'dataset_kwargs': dict(
        split_seed=0,
        split_test_size=0.2,
        sparsity_params=drishti_sparsity_params,
    ),
}

## Num Workers Check

In [None]:
# Time one full pass over `dataset_instance` using a DataLoader with
# `n_workers` workers and print the elapsed wall-clock seconds. The print
# reports the `n_workers` parameter, not the global loop variable.
# from torch.utils.data import DataLoader, Dataset
# 
# def check_num_workers(dataset_instance: Dataset, dataset_name: str, n_workers: int):
#     data_loader = DataLoader(dataset_instance,
#                              batch_size=3,
#                              num_workers=n_workers,
#                              shuffle=True,
#                              pin_memory=True)
#     start_time = time.time()
#     # noinspection PyUnusedLocal
#     for idx, data in enumerate(data_loader):
#         # print("{} - data {}/{}".format(dataset_name, idx+1, len(data_loader)))
#         pass
#     end_time = time.time()
#     print("{} - {} workers: {} seconds".format(dataset_name, n_workers, end_time - start_time))

In [None]:
# Uncomment exactly one range: if both are enabled, the second overrides the first.
# num_workers_range = range(0, 4, 1)
# num_workers_range = range(0, 40, 3)

In [None]:
# rim_one_dataset = RimOneDataset(
#     'train',
#     all_config['data']['num_classes'],
#     all_config['data']['resize_to'],
#     **rim_one_meta_loader_params['dataset_kwargs']
# )
# 
# for num_workers in num_workers_range:
#     check_num_workers(rim_one_dataset, 'RO', num_workers)

In [None]:
# drishti_dataset = DrishtiDataset(
#     'train',
#     all_config['data']['num_classes'],
#     all_config['data']['resize_to'],
#     **drishti_tune_loader_params['dataset_kwargs']
# )
# 
# for num_workers in num_workers_range:
#     check_num_workers(drishti_dataset, 'DR', num_workers)

# Weasel Learner

## Initialization

### Update Config

In [None]:
# Per-run overrides for the Weasel learner. The commented alternatives
# ("L" experiment name, larger batch) appear to match the long configuration.
all_config['data'].update(batch_size=3)
# all_config['data'].update(batch_size=14)

# all_config['data'].update(num_workers=3)

# all_config['learn'].update(should_resume=True)

all_config['learn'].update(exp_name='v3 RO-DR S WS')
# all_config['learn'].update(exp_name='v3 RO-DR L WS')

In [None]:
# Override the RIM-ONE meta-loader shot count (initialised to -1) for this learner.
rim_one_meta_loader_params['dataset_kwargs'].update(num_shots=50)

### Create Model

In [None]:
# Segmentation UNet: input channels -> per-class outputs.
net = UNet(
    all_config['data']['num_channels'],
    all_config['data']['num_classes'],
)

# Report only the parameters that will be updated during training.
n_params = sum(param.numel() for param in net.parameters() if param.requires_grad)
print(f'# of parameters: {n_params}')

### Create Learner

In [None]:
# Arguments (roles judged from the variable names): network, config, list of
# meta-training loader params (RIM-ONE), tuning loader params (DRISHTI),
# and the disc/cup IoU metric function.
learner = WeaselLearner(
    net,
    all_config,
    [rim_one_meta_loader_params],
    drishti_tune_loader_params,
    calc_disc_cup_iou,
)

## Learning

In [None]:
# Run the Weasel learning loop. Whatever happens, detach the learner's log
# handlers and release the model so GPU memory is free for the next section.
try:
    learner.learn()
except BaseException:
    # BaseException on purpose: an interrupted run (e.g. KeyboardInterrupt)
    # is also passed to learner.log_error() before propagating.
    learner.log_error()
    # Bare `raise` re-raises the active exception with its original traceback.
    raise
finally:
    try:
        learner.remove_log_handlers()
    finally:
        # Nested so the model/learner are freed even if handler removal fails.
        del net
        del learner
        gc.collect()
        cuda.empty_cache()

In [None]:
# Pause for one minute before starting the ProtoSeg run; presumably gives the
# freed GPU memory/resources time to settle (TODO confirm intent).
time.sleep(60.0)

# Protoseg Learner

## Initialization

### Update Config

In [None]:
# Per-run overrides for the ProtoSeg learner. The commented alternatives
# ("L" experiment name, larger batch) appear to match the long configuration.
all_config['data'].update(batch_size=5)
# all_config['data'].update(batch_size=36)

# all_config['data'].update(num_workers=0)

# all_config['learn'].update(should_resume=True)

all_config['learn'].update(exp_name='v3 RO-DR S PS')
# all_config['learn'].update(exp_name='v3 RO-DR L PS')

In [None]:
# Override the RIM-ONE meta-loader shot count (initialised to -1) for this learner.
rim_one_meta_loader_params['dataset_kwargs'].update(num_shots=50)

### Create Model

In [None]:
# Unlike the Weasel model, the second UNet argument here is the ProtoSeg
# embedding size rather than the number of classes.
net = UNet(
    all_config['data']['num_channels'],
    all_config['protoseg']['embedding_size'],
)

# Report only the parameters that will be updated during training.
n_params = sum(param.numel() for param in net.parameters() if param.requires_grad)
print(f'# of parameters: {n_params}')

### Create Learner

In [None]:
# Arguments (roles judged from the variable names): network, config, list of
# meta-training loader params (RIM-ONE), tuning loader params (DRISHTI),
# and the disc/cup IoU metric function.
learner = ProtoSegLearner(
    net,
    all_config,
    [rim_one_meta_loader_params],
    drishti_tune_loader_params,
    calc_disc_cup_iou,
)

## Learning

In [None]:
# Run the ProtoSeg learning loop. Whatever happens, detach the learner's log
# handlers and release the model so GPU memory is freed afterwards.
try:
    learner.learn()
except BaseException:
    # BaseException on purpose: an interrupted run (e.g. KeyboardInterrupt)
    # is also passed to learner.log_error() before propagating.
    learner.log_error()
    # Bare `raise` re-raises the active exception with its original traceback.
    raise
finally:
    try:
        learner.remove_log_handlers()
    finally:
        # Nested so the model/learner are freed even if handler removal fails.
        del net
        del learner
        gc.collect()
        cuda.empty_cache()

# Other