# Initialization


## Import


### Import Modules


In [None]:
import gc
import time

import numpy as np
from matplotlib import pyplot as plt

from config.config_type import (
    AllConfig,
    DataConfig,
    DataTuneConfig,
    LearnConfig,
    SaveConfig,
    WeaselConfig,
)
from data.get_meta_datasets import get_meta_datasets
from data.get_tune_loaders import get_tune_loaders
from learners.protoseg import ProtoSegLearner
from learners.weasel import WeaselLearner
from models.u_net import UNet
from tasks.optic_disc_cup.datasets import DrishtiDataset, RimOneDataset
from tasks.optic_disc_cup.metrics import calc_disc_cup_iou

from torch import cuda

# Use matplotlib's dark theme for all figures drawn after this point.
plt.style.use("dark_background")

### Autoreload Import


In [None]:
# Load (or reload) IPython's autoreload extension in mode 1: only the modules
# registered below with %aimport are reloaded before each cell executes, so
# edits to these project modules are picked up without restarting the kernel.
# (Comments are kept on their own lines because line magics would otherwise
# receive them as arguments.)
%reload_ext autoreload
%autoreload 1
%aimport config.config_type
%aimport models.u_net
%aimport data.few_sparse_dataset, data.get_meta_datasets, data.get_tune_loaders
%aimport learners.learner, learners.weasel, learners.protoseg
%aimport tasks.optic_disc_cup.datasets, tasks.optic_disc_cup.metrics

In [None]:
# Trigger an autoreload pass right now instead of waiting for the next cell.
%autoreload now

## All Config


### Short Training


In [None]:
# Short schedule (8 epochs, tuning every 4) for quick end-to-end runs. The
# commented-out "Long Training" cell below holds the full-length values.
all_config: AllConfig = {
    "data": {
        "num_classes": 3,
        "num_channels": 3,
        "num_workers": 0,
        "batch_size": 1,
        "resize_to": (256, 256),
    },
    "data_tune": {
        "list_shots": [5],
        "list_sparsity_point": [50],
        "list_sparsity_grid": [10],
        "list_sparsity_contour": [1],
        "list_sparsity_skeleton": [1],
        "list_sparsity_region": [1],
    },
    "learn": {
        "should_resume": False,
        "use_gpu": True,
        "num_epochs": 8,
        "optimizer_lr": 1e-3,
        "optimizer_weight_decay": 5e-5,
        "optimizer_momentum": 0.9,
        "scheduler_step_size": 150,
        "scheduler_gamma": 0.2,
        "tune_freq": 4,
        "meta_used_datasets": 1,
        "meta_iterations": 5,
    },
    "save": {
        "ckpt_path": "./ckpt/",
        "output_path": "./outputs/",
        "exp_name": "",
    },
    "weasel": {
        "use_first_order": False,
        "update_param_step_size": 0.3,
        "tune_epochs": 6,
        "tune_test_freq": 3,
    },
}

# Named handles for each section. They are the very same dict objects nested in
# all_config, so in-place edits made through all_config (e.g. batch_size,
# exp_name) are visible through these names as well.
data_config: DataConfig = all_config["data"]
data_tune_config: DataTuneConfig = all_config["data_tune"]
learn_config: LearnConfig = all_config["learn"]
save_config: SaveConfig = all_config["save"]
weasel_config: WeaselConfig = all_config["weasel"]

### Long Training


In [None]:
# data_config: DataConfig = {
#     'num_classes': 3,
#     'num_channels': 3,
#     'num_workers': 0,
#     'batch_size': 1,
#     'resize_to': (256, 256)
# }
#
# data_tune_config: DataTuneConfig = {
#     'list_shots': [20],
#     'list_sparsity_point': [50],
#     'list_sparsity_grid': [10],
#     'list_sparsity_contour': [1],
#     'list_sparsity_skeleton': [1],
#     'list_sparsity_region': [1]
# }
#
# learn_config: LearnConfig = {
#     'should_resume': False,
#     'use_gpu': True,
#     'num_epochs': 200,
#     'optimizer_lr': 1e-3,
#     'optimizer_weight_decay': 5e-5,
#     'optimizer_momentum': 0.9,
#     'scheduler_step_size': 150,
#     'scheduler_gamma': 0.2,
#     'tune_freq': 40,
#     'meta_used_datasets': 1,
#     'meta_iterations': 5
# }
#
# save_config: SaveConfig = {
#     'ckpt_path': './ckpt/',
#     'output_path': './outputs/',
#     'exp_name': ''
# }
#
# weasel_config: WeaselConfig = {
#     'use_first_order': False,
#     'update_param_step_size': 0.3,
#     'tune_epochs': 40,
#     'tune_test_freq': 8
# }
#
# all_config: AllConfig = {
#     'data': data_config,
#     'data_tune': data_tune_config,
#     'learn': learn_config,
#     'save': save_config,
#     'weasel': weasel_config
# }

# Dataset Exploration


## RIM-ONE


### Create Dataset


In [None]:
# Parameters used when generating sparse annotations for RIM-ONE. The same
# dict is passed to the meta-training dataset in the learner sections.
rim_one_sparsity_params: dict = {
    "contour_radius_dist": 4,
    "contour_radius_thick": 2,
    "skeleton_radius_thick": 4,
    "region_compactness": 0.5,
}

# Exploration-only copy of the RIM-ONE training split. The class count and
# resize target come from the experiment config rather than repeated literals,
# so this preview stays in sync with the data the learners are built from.
rim_one_data = RimOneDataset(
    mode="train",
    num_classes=all_config["data"]["num_classes"],
    num_shots=5,
    resize_to=all_config["data"]["resize_to"],
    split_seed=0,
    sparsity_params=rim_one_sparsity_params,
)

### Check Sparse Masks


In [None]:
# Fetch sample 0 together with all of its sparse masks. The trailing positional
# values appear to be the per-type sparsity settings (point=50, grid=10,
# contour/skeleton/region=1, mirroring data_tune_config) — TODO confirm order.
image, mask, sparse_masks, image_filename = rim_one_data.get_data_with_sparse_all(
    0, 50, 10, 1, 1, 1
)
print(image.shape, image.max(), image.min(), image_filename)
print(mask.shape, mask.dtype, np.unique(mask))

# Grid layout: row 0 holds the image and its dense mask, then two sparse masks
# per row.
n_rows = int(np.ceil(len(sparse_masks) / 2)) + 1
_, axs = plt.subplots(n_rows, 2, figsize=(12, n_rows * 6))
axs = axs.ravel()  # works for both the 1-D (n_rows == 1) and 2-D cases
axs[0].imshow(image)
axs[1].imshow(mask)
for ax, sm in zip(axs[2:], sparse_masks):
    ax.imshow(sm)
# An odd number of sparse masks leaves the final panel unused; hide it rather
# than rendering an empty framed axis.
for ax in axs[2 + len(sparse_masks):]:
    ax.axis("off")

### Check Image Sizes


In [None]:
# image_sizes = []
# for image_path, mask_path in rim_one_data.get_all_data_path():
#     image, _ = rim_one_data.read_image_mask(image_path, mask_path)
#     image_sizes.append(image.shape)

# image_sizes = np.array(image_sizes)

# print(np.unique(image_sizes[:,0], return_counts=True))
# print(image_sizes[:,0].min(), image_sizes[:,0].max())
# print(np.unique(image_sizes[:,1], return_counts=True))
# print(image_sizes[:,1].min(), image_sizes[:,1].max())

## DRISHTI


### Create Dataset


In [None]:
# Parameters used when generating sparse annotations for DRISHTI. The same
# dict is passed to the tuning loaders in the learner sections.
drishti_sparsity_params: dict = {
    "contour_radius_dist": 4,
    "contour_radius_thick": 1,
    "skeleton_radius_thick": 3,
    "region_compactness": 0.5,
}

# Exploration-only copy of the DRISHTI training split. The class count and
# resize target come from the experiment config rather than repeated literals,
# so this preview stays in sync with the data the learners are built from.
drishti_data = DrishtiDataset(
    mode="train",
    num_classes=all_config["data"]["num_classes"],
    num_shots=5,
    resize_to=all_config["data"]["resize_to"],
    split_seed=0,
    sparsity_params=drishti_sparsity_params,
)

### Check Sparse Masks


In [None]:
# Fetch sample 1 together with all of its sparse masks. The trailing positional
# values appear to be the per-type sparsity settings (point=50, grid=20,
# contour/skeleton/region=1) — TODO confirm order.
image, mask, sparse_masks, image_filename = drishti_data.get_data_with_sparse_all(
    1, 50, 20, 1, 1, 1
)
print(image.shape, image.max(), image.min(), image_filename)
print(mask.shape, mask.dtype, np.unique(mask))

# Grid layout: row 0 holds the image and its dense mask, then two sparse masks
# per row.
n_rows = int(np.ceil(len(sparse_masks) / 2)) + 1
_, axs = plt.subplots(n_rows, 2, figsize=(12, n_rows * 6))
axs = axs.ravel()  # works for both the 1-D (n_rows == 1) and 2-D cases
axs[0].imshow(image)
axs[1].imshow(mask, cmap="gray")
for ax, sm in zip(axs[2:], sparse_masks):
    ax.imshow(sm)
# An odd number of sparse masks leaves the final panel unused; hide it rather
# than rendering an empty framed axis.
for ax in axs[2 + len(sparse_masks):]:
    ax.axis("off")

### Check Image Sizes


In [None]:
# Disabled check: tabulate the height (shape[0]) and width (shape[1]) of each
# image returned by read_image_mask for every DRISHTI path. Uncomment to run.
# image_sizes = []
# for image_path, mask_path in drishti_data.get_all_data_path():
#     image, _ = drishti_data.read_image_mask(image_path, mask_path)
#     image_sizes.append(image.shape)

# image_sizes = np.array(image_sizes)

# print(np.unique(image_sizes[:,0], return_counts=True))
# print(image_sizes[:,0].min(), image_sizes[:,0].max())
# print(np.unique(image_sizes[:,1], return_counts=True))
# print(image_sizes[:,1].min(), image_sizes[:,1].max())

# Weasel Learner


## Initialization


### Update Config


In [None]:
# Weasel run settings: batch size 3 and the short-schedule experiment name
# (swap in the commented name when using the long configuration).
all_config["data"].update(batch_size=3)
all_config["save"].update(exp_name="weasel_short_rimone_to_drishti")
# all_config['save']['exp_name'] = 'weasel_long_rimone_to_drishti'

### Create Model


In [None]:
# U-Net sized by the configured input channels and output classes.
net = UNet(
    all_config["data"]["num_channels"],
    all_config["data"]["num_classes"],
)

# Report how many parameters will actually be trained (requires_grad=True).
n_params = sum(
    param.numel()
    for param in filter(lambda param: param.requires_grad, net.parameters())
)
print(f"# of parameters: {n_params}")

### Prepare Data


In [None]:
# Aliases for the config sections read below (same dict objects as in all_config).
data_cfg = all_config["data"]
tune_cfg = all_config["data_tune"]

# Dataset options shared by RIM-ONE and DRISHTI; only the sparsity-generation
# parameters differ between the two.
common_dataset_kwargs = {
    "split_seed": 0,
    "split_test_size": 0.8,
    "sparsity_mode": "random",
    "sparsity_value": "random",
}

# Meta dataset source: RIM-ONE.
meta_set = get_meta_datasets(
    [
        {
            "dataset_class": RimOneDataset,
            "num_classes": data_cfg["num_classes"],
            "resize_to": data_cfg["resize_to"],
            "kwargs": {
                **common_dataset_kwargs,
                "sparsity_params": rim_one_sparsity_params,
            },
        }
    ]
)

# Tuning data: DRISHTI, with shot counts and sparsity levels from data_tune.
tune_loader = get_tune_loaders(
    dataset_class=DrishtiDataset,
    dataset_kwargs={
        **common_dataset_kwargs,
        "sparsity_params": drishti_sparsity_params,
    },
    num_classes=data_cfg["num_classes"],
    resize_to=data_cfg["resize_to"],
    shots=tune_cfg["list_shots"],
    point=tune_cfg["list_sparsity_point"],
    grid=tune_cfg["list_sparsity_grid"],
    contour=tune_cfg["list_sparsity_contour"],
    skeleton=tune_cfg["list_sparsity_skeleton"],
    region=tune_cfg["list_sparsity_region"],
    batch_size=data_cfg["batch_size"],
    num_workers=data_cfg["num_workers"],
)

### Create Learner


In [None]:
# Assemble the Weasel learner from the U-Net, the full config, the RIM-ONE meta
# dataset, the DRISHTI tuning loader(s) and the calc_disc_cup_iou metric function.
learner = WeaselLearner(net, all_config, meta_set, tune_loader, calc_disc_cup_iou)

## Learning


In [None]:
# Run the Weasel experiment. Schedule and save locations are presumably taken
# from all_config["learn"] / ["weasel"] / ["save"] — confirm in WeaselLearner.
learner.learn()

In [None]:
# Drop the notebook's references to the Weasel model and learner so the memory
# they hold can be reclaimed before the ProtoSeg experiment starts.
net, learner = None, None

# Collect any lingering reference cycles, then return the CUDA caching
# allocator's unused blocks.
gc.collect()
cuda.empty_cache()

# Fixed 60 s pause between experiments. NOTE(review): the reason is not stated
# here (possibly to let the GPU settle) — confirm before changing it.
time.sleep(60)

# ProtoSeg Learner


## Initialization


### Update Config


In [None]:
# ProtoSeg run settings: batch size 5 and the short-schedule experiment name
# (uncomment the alternatives to resume, or to use the long configuration).
all_config["data"].update(batch_size=5)
# all_config['learn']['should_resume'] = True
all_config["save"].update(exp_name="protoseg_short_rimone_to_drishti")
# all_config['save']['exp_name'] = 'protoseg_long_rimone_to_drishti'

### Create Model


In [None]:
# Fresh U-Net for ProtoSeg, sized by the configured channels and classes.
net = UNet(
    all_config["data"]["num_channels"],
    all_config["data"]["num_classes"],
)

# Report how many parameters will actually be trained (requires_grad=True).
n_params = sum(
    param.numel()
    for param in filter(lambda param: param.requires_grad, net.parameters())
)
print(f"# of parameters: {n_params}")

### Prepare Data


In [None]:
# Aliases for the config sections read below (same dict objects as in all_config).
data_cfg = all_config["data"]
tune_cfg = all_config["data_tune"]

# Dataset options shared by RIM-ONE and DRISHTI; only the sparsity-generation
# parameters differ between the two.
common_dataset_kwargs = {
    "split_seed": 0,
    "split_test_size": 0.8,
    "sparsity_mode": "random",
    "sparsity_value": "random",
}

# Meta dataset source: RIM-ONE.
meta_set = get_meta_datasets(
    [
        {
            "dataset_class": RimOneDataset,
            "num_classes": data_cfg["num_classes"],
            "resize_to": data_cfg["resize_to"],
            "kwargs": {
                **common_dataset_kwargs,
                "sparsity_params": rim_one_sparsity_params,
            },
        }
    ]
)

# Tuning data: DRISHTI, rebuilt so it picks up the ProtoSeg batch size.
tune_loader = get_tune_loaders(
    dataset_class=DrishtiDataset,
    dataset_kwargs={
        **common_dataset_kwargs,
        "sparsity_params": drishti_sparsity_params,
    },
    num_classes=data_cfg["num_classes"],
    resize_to=data_cfg["resize_to"],
    shots=tune_cfg["list_shots"],
    point=tune_cfg["list_sparsity_point"],
    grid=tune_cfg["list_sparsity_grid"],
    contour=tune_cfg["list_sparsity_contour"],
    skeleton=tune_cfg["list_sparsity_skeleton"],
    region=tune_cfg["list_sparsity_region"],
    batch_size=data_cfg["batch_size"],
    num_workers=data_cfg["num_workers"],
)

### Create Learner


In [None]:
# Assemble the ProtoSeg learner from the U-Net, the full config, the RIM-ONE
# meta dataset, the DRISHTI tuning loader(s) and the calc_disc_cup_iou metric.
learner = ProtoSegLearner(net, all_config, meta_set, tune_loader, calc_disc_cup_iou)

## Learning


In [None]:
# Run the ProtoSeg experiment. Schedule and save locations are presumably taken
# from all_config["learn"] / ["save"] — confirm in ProtoSegLearner.
learner.learn()

In [None]:
# Drop the notebook's references to the ProtoSeg model and learner so the
# memory they hold can be reclaimed.
net, learner = None, None

# Collect any lingering reference cycles, then return the CUDA caching
# allocator's unused blocks.
gc.collect()
cuda.empty_cache()