In [1]:
from collections.abc import MutableMapping, MutableSequence
from copy import deepcopy
from pathlib import Path

import torch
from omegaconf import DictConfig
from tqdm.notebook import tqdm

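# Convert Segmentation Models

Rewrite segmentation checkpoints created under the old `bioblue` package name so that every string stored in them refers to `sunscc` instead. The untouched original is kept next to the converted file as `last_bioblue.ckpt`.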

In [10]:
def replace_substring_in_dict(d, old_substring, new_substring):
    """Recursively replace ``old_substring`` with ``new_substring`` in every string of ``d``, in place.

    Walks nested mutable mappings (``dict``, ``OrderedDict``, omegaconf ``DictConfig``) and
    mutable sequences (``list``, omegaconf ``ListConfig``). Strings stored inside lists are
    rewritten as well; a dict-only walk would miss them. Mapping keys and non-string leaves
    (tensors, numbers, ``None``, tuples, ...) are left untouched.

    Args:
        d: the container to rewrite; it is mutated in place.
        old_substring: text to look for in string values.
        new_substring: replacement text.

    Returns:
        ``d`` itself, for convenience.
    """
    if isinstance(d, MutableMapping):
        slots = list(d.keys())  # snapshot so the assignments below never disturb iteration
    elif isinstance(d, MutableSequence):
        slots = range(len(d))
    else:
        return d

    for slot in slots:
        value = d[slot]
        if isinstance(value, str):
            replaced = value.replace(old_substring, new_substring)
            # Write back only real changes, so untouched entries keep their original node
            # (presumably relevant for omegaconf interpolations, which read back resolved).
            if replaced != value:
                d[slot] = replaced
        elif isinstance(value, (MutableMapping, MutableSequence)):
            replace_substring_in_dict(value, old_substring, new_substring)
    return d

def convert_segmentation_run_dir(segmenter_run_dir, ckpt_name='last.ckpt', map_location='cpu'):
    """Convert a segmentation run's checkpoint from the `bioblue` package to `sunscc`.

    ``models/<ckpt_name>`` is first moved aside to ``models/<stem>_bioblue<suffix>``, but only
    if that backup does not exist yet. The backup is then loaded, every string in it has
    ``'bioblue'`` replaced by ``'sunscc'`` (via ``replace_substring_in_dict``), and the result
    is saved back to ``models/<ckpt_name>``. Because the backup is always the source,
    re-running the conversion on the same directory is safe.

    Args:
        segmenter_run_dir: run directory (``str`` or ``Path``) containing a ``models/`` folder.
        ckpt_name: checkpoint file name inside ``models/``; defaults to ``'last.ckpt'``.
        map_location: forwarded to ``torch.load``. ``'cpu'`` lets the conversion run on a
            machine without the device the checkpoint was saved from. Pass ``None`` to keep
            tensors on their original device.

    Raises:
        FileNotFoundError: if neither the backup nor ``models/<ckpt_name>`` exists.
    """
    import inspect

    models_dir = Path(segmenter_run_dir) / 'models'
    ckpt_path = models_dir / ckpt_name
    backup_path = ckpt_path.with_name(f"{ckpt_path.stem}_bioblue{ckpt_path.suffix}")

    # 1) If needed, keep the original checkpoint as '<stem>_bioblue.ckpt'
    if not backup_path.exists():
        if not ckpt_path.exists():
            raise FileNotFoundError(f"Neither {backup_path} nor {ckpt_path} exists")
        print(f'Should move {ckpt_path.stem} to {backup_path.stem}')
        ckpt_path.rename(backup_path)

    # 2) Load the backup. torch>=2.6 defaults ``weights_only`` to True, which refuses
    #    non-tensor objects (presumably the hyper-parameter configs stored here). torch<1.13
    #    does not accept the argument at all. A full load unpickles arbitrary objects, so
    #    only use this on trusted checkpoints.
    load_kwargs = {'map_location': map_location}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(backup_path, **load_kwargs)

    # The loaded object is private to this call, so edit it in place (no deepcopy).
    replace_substring_in_dict(checkpoint, 'bioblue', "sunscc")

    # 3) Save the sunscc checkpoint under the original name
    torch.save(checkpoint, ckpt_path)

In [13]:
# Segmentation runs to convert; they all live under the same dated outputs folder.
# NOTE(review): this cell resolves runs from '../outputs' while the classification cell uses
# '../../outputs'; one of them presumably assumes a different working directory — confirm.
SEGMENTATION_OUTPUTS_DIR = Path('../outputs/2023-01-22')
segmentation_run_names = [
    '01-18-26_2013-15_UNet_T425_T375_T325_StepLR_epoch_1_run0_SUNSCC',
    '01-18-26_2013-15_UNet_T425_T375_T325_StepLR_epoch_1_run2_SUNSCC',
    '01-18-26_2013-15_UNet_T425_T375_T325_StepLR_epoch_1_run3_SUNSCC',
    '01-18-26_2013-15_UNet_T425_T375_T325_StepLR_epoch_1_run4_SUNSCC',
    '05-55-04_2013-15_UNet_T425_T375_T325_StepLR_epoch_1_run5_SUNSCC',
    '05-55-04_2013-15_UNet_T425_T375_T325_StepLR_epoch_1_run6_SUNSCC',
    '06-30-11_2013-15_UNet_T425_T375_T325_StepLR_epoch_1_run7_SUNSCC',
    '10-23-14_2013-15_UNet_T425_T375_T325_StepLR_epoch_1_run9_SUNSCC',
    '13-15-38_2013-15_UNet_T425_T375_T325_StepLR_epoch_1_run8_SUNSCC',
    '10-24-14_2013-15_UNet_T425_T375_T325_StepLR_epoch_1_run10_SUNSCC',
]
bioblue_run_dirs = [SEGMENTATION_OUTPUTS_DIR / run_name for run_name in segmentation_run_names]

for run_dir in tqdm(bioblue_run_dirs):
    convert_segmentation_run_dir(run_dir)

  0%|          | 0/10 [00:00<?, ?it/s]

Should move last to last_bioblue
Should move last to last_bioblue
Should move last to last_bioblue
Should move last to last_bioblue
Should move last to last_bioblue
Should move last to last_bioblue
Should move last to last_bioblue
Should move last to last_bioblue


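# Convert Classification Models

Apply the same `bioblue` → `sunscc` rewrite to each encoder/MLP checkpoint of the classifier runs. Each original is kept next to its converted file as `<name>_bioblue.ckpt`.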

In [4]:
def replace_substring_in_dict(d, old_substring, new_substring):
    """Recursively replace ``old_substring`` with ``new_substring`` in every string of ``d``, in place.

    Walks nested mutable mappings (``dict``, ``OrderedDict``, omegaconf ``DictConfig``) and
    mutable sequences (``list``, omegaconf ``ListConfig``). Strings stored inside lists are
    rewritten as well; a dict-only walk would miss them. Mapping keys and non-string leaves
    (tensors, numbers, ``None``, tuples, ...) are left untouched.

    Args:
        d: the container to rewrite; it is mutated in place.
        old_substring: text to look for in string values.
        new_substring: replacement text.

    Returns:
        ``d`` itself, for convenience.
    """
    if isinstance(d, MutableMapping):
        slots = list(d.keys())  # snapshot so the assignments below never disturb iteration
    elif isinstance(d, MutableSequence):
        slots = range(len(d))
    else:
        return d

    for slot in slots:
        value = d[slot]
        if isinstance(value, str):
            replaced = value.replace(old_substring, new_substring)
            # Write back only real changes, so untouched entries keep their original node
            # (presumably relevant for omegaconf interpolations, which read back resolved).
            if replaced != value:
                d[slot] = replaced
        elif isinstance(value, (MutableMapping, MutableSequence)):
            replace_substring_in_dict(value, old_substring, new_substring)
    return d

def convert_segmentation_run_dir(
    classifier_run_dir,
    ckpt_names=(
        "ENCODER_MLP1_MLP2_MLP3.ckpt",
        "ENCODER_MLP1_MLP2.ckpt",
        "ENCODER_MLP1.ckpt",
    ),
    map_location='cpu',
):
    """Convert a classification run's checkpoints from the `bioblue` package to `sunscc`.

    NOTE(review): despite its name, this is the *classification* converter. Once run, it
    shadows the segmentation converter of the same name defined in an earlier cell. The
    calling cell uses this name, so rename both together (e.g. to
    ``convert_classification_run_dir``).

    For every name in ``ckpt_names``:
    - ``models/<name>`` is moved aside to ``models/<stem>_bioblue<suffix>`` unless that
      backup already exists.
    - The backup is loaded and every string in it has ``'bioblue'`` replaced by ``'sunscc'``
      (via ``replace_substring_in_dict``).
    - The result is saved back to ``models/<name>``.

    Because the backup is always the source, re-running the conversion is safe.

    Args:
        classifier_run_dir: run directory (``str`` or ``Path``) containing a ``models/`` folder.
        ckpt_names: checkpoint file names inside ``models/`` to convert. Defaults to the three
            encoder/MLP checkpoints this notebook was written for.
        map_location: forwarded to ``torch.load``. ``'cpu'`` lets the conversion run on a
            machine without the device the checkpoints were saved from. Pass ``None`` to keep
            tensors on their original device.

    Raises:
        FileNotFoundError: if, for some name, neither the backup nor the checkpoint exists.
    """
    import inspect

    models_dir = Path(classifier_run_dir) / 'models'

    # torch>=2.6 defaults ``weights_only`` to True, which refuses non-tensor objects
    # (presumably the hyper-parameter configs stored here). torch<1.13 does not accept the
    # argument at all. A full load unpickles arbitrary objects, so only use this on trusted
    # checkpoints.
    load_kwargs = {'map_location': map_location}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False

    for ckpt in ckpt_names:
        ckpt_path = models_dir / ckpt
        # Insert the tag before the extension (robust to '.ckpt' appearing elsewhere in the name).
        backup_path = ckpt_path.with_name(f"{ckpt_path.stem}_bioblue{ckpt_path.suffix}")

        # 1) If needed, keep the original checkpoint as '<stem>_bioblue.ckpt'
        if not backup_path.exists():
            if not ckpt_path.exists():
                raise FileNotFoundError(f"Neither {backup_path} nor {ckpt_path} exists")
            print(f'Should move {ckpt} to {backup_path.name}')
            ckpt_path.rename(backup_path)

        # 2) Rewrite the strings in the backup. The loaded object is private to this
        #    iteration, so it is edited in place (no deepcopy of a potentially large checkpoint).
        checkpoint = torch.load(backup_path, **load_kwargs)
        replace_substring_in_dict(checkpoint, 'bioblue', "sunscc")

        # 3) Save the sunscc checkpoint under the original name
        torch.save(checkpoint, ckpt_path)

In [5]:
# Classifier runs to convert.
# NOTE(review): this cell resolves runs from '../../outputs' while the segmentation cell uses
# '../outputs'; one of them presumably assumes a different working directory — confirm.
CLASSIFICATION_OUTPUTS_DIR = Path('../../outputs/rebuttal')
bioblue_run_dirs = [
    CLASSIFICATION_OUTPUTS_DIR / 'SUNSCC_AllRevisedFiltered_Rebuttal_WithHideNoAug_0.0_class1_100epochs_run21',
]

for run_dir in tqdm(bioblue_run_dirs):
    convert_segmentation_run_dir(run_dir)

  0%|          | 0/1 [00:00<?, ?it/s]

Should move  to last_bioblue
Should move  to last_bioblue
