# Imports

In [None]:
%env CUDA_VISIBLE_DEVICES=1

In [None]:
import torch
import re
import os
from collections import OrderedDict

In [None]:
# Device every checkpoint is mapped onto by torch.load(map_location=DEVICE) below.
# With CUDA_VISIBLE_DEVICES=1 set by the %env cell above (it must run before CUDA
# is initialized), the default 'cuda' device is physical GPU 1.
DEVICE = torch.device('cuda')

# Load models and rename layers

In [None]:
from tqdm.auto import tqdm

In [None]:
%run ../models/checkpoint/__init__.py
%run ../utils/files.py

In [None]:
def test_rename_layers(layers):
    """Preview what the currently defined rename_layer() does to a set of names.

    Prints one line for every name that rename_layer() changes: the old name
    left-padded to 42 columns, followed by the new name.

    Args:
        layers: iterable of layer / state-dict key names, e.g.
            ``checkpoint['model'].keys()``. It is materialized once, so a
            generator is not exhausted by the membership checks.

    Returns:
        list of ``(old_name, new_name)`` pairs whose new name is not already
        one of the input names. Names that rename_layer() leaves unchanged are
        never included, since they trivially are already present.
    """
    layers = list(layers)
    existing = set(layers)  # O(1) membership, independent of the input type
    new_layers = []
    for layer in layers:
        layer2 = rename_layer(layer)
        if layer2 not in existing:
            new_layers.append((layer, layer2))
        if layer2 != layer:
            # Only changed names are printed, so the new name is always shown.
            print(f'{layer:<42} {layer2}')
    return new_layers

In [None]:
def load_checkpoint(run_id):
    """Load the first checkpoint file found in a run's checkpoint folder.

    Files are examined in sorted name order, so the same file is chosen on
    every call when the folder holds several files whose name contains
    'checkpoint'.

    Args:
        run_id: run identifier accepted by get_checkpoint_folder.

    Returns:
        The object returned by torch.load for that file, with its tensors
        mapped onto DEVICE.

    Raises:
        FileNotFoundError: if no entry in the folder has 'checkpoint' in its name.
    """
    folder = get_checkpoint_folder(run_id, save_mode=False, assert_exists=True)

    # os.listdir order is arbitrary; sort for a reproducible choice.
    for filename in sorted(os.listdir(folder)):
        if 'checkpoint' in filename:
            # NOTE: torch.load unpickles the file -- only use on trusted checkpoints.
            return torch.load(os.path.join(folder, filename), map_location=DEVICE)
    raise FileNotFoundError(f'No checkpoint found in {folder}')

In [None]:
def rename_checkpoints(run_ids, fixed=True, dry=True):
    """Rewrite the state-dict keys of every checkpoint in the given runs.

    For each run, every file in its checkpoint folder whose name contains
    'checkpoint' is loaded onto DEVICE. Each key of ``checkpoint['model']``
    is passed through the currently defined rename_layer(). The updated
    checkpoint is then either only reported (dry run) or written with
    torch.save.

    Args:
        run_ids: iterable of run identifiers accepted by get_checkpoint_folder.
        fixed: if True, save to a parallel path in which the last '/models/'
            segment of the checkpoint path becomes '/models/fixed/', leaving
            the original file untouched. If False, overwrite the original file.
        dry: if True, only print where each checkpoint would be written.

    Raises:
        ValueError: if two keys of one checkpoint are renamed to the same new
            key (one tensor would otherwise silently replace the other), or if
            ``fixed`` is True but a checkpoint path contains no '/models/'
            segment (the "fixed" copy would otherwise overwrite the original).
    """
    for run_id in tqdm(run_ids):
        folder = get_checkpoint_folder(run_id, save_mode=False, assert_exists=True)

        # Sorted so the processing order (and printed output) is reproducible.
        for filename in sorted(os.listdir(folder)):
            if 'checkpoint' not in filename:
                continue
            filepath = os.path.join(folder, filename)

            checkpoint = torch.load(filepath, map_location=DEVICE)
            checkpoint['model'] = _rename_state_dict(checkpoint['model'], filepath)

            path_dest = _fixed_destination(filepath) if fixed else filepath

            if dry:
                if filepath == path_dest:
                    print(f'Would override {filepath}')
                else:
                    print(f'Would save from {filepath} to {path_dest}')
            else:
                os.makedirs(os.path.dirname(path_dest), exist_ok=True)
                torch.save(checkpoint, path_dest)


def _rename_state_dict(state_dict, source):
    """Return a new OrderedDict with every key of ``state_dict`` passed through rename_layer()."""
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        new_key = rename_layer(key)
        # A collision would silently drop a tensor; refuse instead.
        if new_key in new_state_dict:
            raise ValueError(
                f'{source}: renaming {key!r} -> {new_key!r} collides with another key'
            )
        new_state_dict[new_key] = value
    return new_state_dict


def _fixed_destination(filepath):
    """Return ``filepath`` with its last '/models/' segment replaced by '/models/fixed/'."""
    head, sep, tail = filepath.rpartition('/models/')
    if not sep:
        # Without this guard the "fixed" destination would equal the source path.
        raise ValueError(
            f"Cannot build a 'fixed' destination for {filepath}: no '/models/' in path"
        )
    return f'{head}/models/fixed/{tail}'

## RG

In [None]:
def rename_layer(name):
    """Map an old state-dict key of an RG model to its new name.

    Currently folds the ``classifier.1`` sub-module into ``classifier``,
    e.g. ``classifier.1.weight`` -> ``classifier.weight``. The match requires
    ``classifier.1`` to be followed by ``.`` or the end of the key, so keys
    such as ``classifier.10.weight`` are left untouched instead of being
    mangled into ``classifier0.weight``.

    Other renames used previously in this notebook, currently disabled:
        attention_layer  -> attention
        lstm_cell        -> word_lstm
        W_vocab          -> word_fc
        embeddings_table -> word_embeddings
        features_fc.2    -> features_fc
        stop_control.0   -> stop_control
        last_fc.1        -> last_fc
    """
    return re.sub(r'classifier\.1(?=\.|$)', 'classifier', name)

In [None]:
# RG runs: every run folder whose name starts with a 4-digit_6-digit stamp
# (presumably MMDD_HHMMSS, e.g. '0430_231758'). Sorted so that run_ids[0],
# inspected in the next cell, is the same run on every execution.
run_ids = [
    # RunId('0430_231758', True, 'rg'),
    # RunId('0428_200424', True, 'rg'),
    RunId(name, False, 'rg')
    for name in sorted(os.listdir(_get_parent_folder('models', False, 'rg')))
    if re.search(r'^\d{4}_\d{6}', name)
]
len(run_ids)

In [None]:
# Load the first RG run's checkpoint (binds the notebook global `checkpoint`)
# and print which of its state-dict keys rename_layer() would change.
checkpoint = load_checkpoint(run_ids[0])
test_rename_layers(checkpoint['model'].keys())

In [None]:
# WARNING: fixed=False, dry=False overwrites every RG checkpoint file in place.
# Run with dry=True first to review the planned writes.
rename_checkpoints(run_ids, fixed=False, dry=False)

## CLS-SEG

In [None]:
def rename_layer(name):
    """Map an old state-dict key of a CLS-SEG model to its new name.

    Folds the ``classifier.1`` sub-module into ``classifier``
    (``classifier.1.weight`` -> ``classifier.weight``). The match requires a
    following ``.`` or end of key, so ``classifier.10.weight`` is not mangled
    into ``classifier0.weight``.
    """
    return re.sub(r'classifier\.1(?=\.|$)', 'classifier', name)

In [None]:
# CLS-SEG runs: every entry in the cls-seg models folder except the 'debug'
# and 'fixed' folders. Sorted so that run_ids[0] is reproducible.
run_ids = [
    RunId(name, False, 'cls-seg')
    for name in sorted(os.listdir(_get_parent_folder('models', False, 'cls-seg')))
    if name not in ('debug', 'fixed')
]
# Optional: restrict to one architecture.
# run_ids = [r for r in run_ids if 'densenet-121-cls-seg' in r.name]
len(run_ids)

In [None]:
# Load a CLS-SEG checkpoint to inspect. Without this line `checkpoint` would
# still hold the RG checkpoint loaded in the RG section, and the cells below
# would inspect the wrong model.
checkpoint = load_checkpoint(run_ids[0])
len(checkpoint['model'])

In [None]:
# Show the state-dict keys that mention 'classifier' (the keys rename_layer targets).
[c for c in checkpoint['model'].keys() if 'classifier' in c]

In [None]:
# Print which keys of the loaded checkpoint rename_layer() would change.
test_rename_layers(checkpoint['model'].keys())

In [None]:
# fixed defaults to True: renamed copies are written under '/models/fixed/'
# (see rename_checkpoints); the original checkpoint files are not modified.
rename_checkpoints(run_ids, dry=False)