# Model Select

In [9]:
from data import create_dataset, create_dataloader

import os
import torch
from omegaconf import OmegaConf
from timm import create_model
from models import MemSeg

import ipywidgets as widgets
from IPython.display import display, clear_output
import matplotlib.pyplot as plt

cfg = OmegaConf.load('./configs.yaml')

# ====================================
# Select Model
# ====================================

def load_model(target_name):
    global model
    global testset 
    
    testset = create_dataset(
        datadir                = cfg['DATASET']['datadir'],
        target                 = target_name, 
        is_train               = False,
        resize                 = cfg['DATASET']['resize'],
        texture_source_dir     = cfg['DATASET']['texture_source_dir'],
        structure_grid_size    = cfg['DATASET']['structure_grid_size'],
        transparency_range     = cfg['DATASET']['transparency_range'],
        perlin_scale           = cfg['DATASET']['perlin_scale'], 
        min_perlin_scale       = cfg['DATASET']['min_perlin_scale'], 
        perlin_noise_threshold = cfg['DATASET']['perlin_noise_threshold']
    )
    
    memory_bank = torch.load(f'./saved_model/MemSeg-{target_name}/memory_bank.pt')
    memory_bank.device = 'cpu'
    for k in memory_bank.memory_information.keys():
        memory_bank.memory_information[k] = memory_bank.memory_information[k].cpu()

    feature_extractor = feature_extractor = create_model(
        cfg['MODEL']['feature_extractor_name'], 
        pretrained    = True, 
        features_only = True
    )
    model = MemSeg(
        memory_bank       = memory_bank,
        feature_extractor = feature_extractor
    )

    model.load_state_dict(torch.load(f'./saved_model/MemSeg-{target_name}/best_model.pt'))
    model.eval()

# ====================================
# Visualization
# ====================================


def result_plot(idx):
    input_i, mask_i, target_i = testset[idx]

    output_i = model(input_i.unsqueeze(0)).detach()
    output_i = torch.nn.functional.softmax(output_i, dim=1)

    def minmax_scaling(img):
        return (((img - img.min()) / (img.max() - img.min())) * 255).to(torch.uint8)

    fig, ax = plt.subplots(1,4, figsize=(15,10))
    
    ax[0].imshow(minmax_scaling(input_i.permute(1,2,0)))
    ax[0].set_title('Input: {}'.format('Normal' if target_i == 0 else 'Abnormal'))
    ax[1].imshow(mask_i, cmap='gray')
    ax[1].set_title('Ground Truth')
    ax[2].imshow(output_i[0][1], cmap='gray')
    ax[2].set_title('Predicted Mask')
    ax[3].imshow(minmax_scaling(input_i.permute(1,2,0)), alpha=1)
    ax[3].imshow(output_i[0][1], cmap='gray', alpha=0.5)
    ax[3].set_title(f'Input X Predicted Mask')
    
    plt.show()
    




# ====================================
# widgets
# ====================================


target_list = widgets.Dropdown(
    options=[m.split('-')[1] for m in os.listdir('./saved_model')],
    value='capsule',
    description='Model:',
    disabled=False,
)
button = widgets.Button(description="Model Change")
output = widgets.Output()


@output.capture()
def on_button_clicked(b):
    clear_output(wait=True)
    load_model(target_name=target_list.value)
    
    # vizualization
    file_list = widgets.Dropdown(
        options=[(file_path, i) for i, file_path in enumerate(testset.file_list)],
        value=0,
        description='image:',
    )

    
    widgets.interact(result_plot, idx=file_list)

button.on_click(on_button_clicked)


display(widgets.HBox([target_list, button]), output)



HBox(children=(Dropdown(description='Model:', index=11, options=('leather', 'pill', 'carpet', 'hazelnut', 'til…

Output()