In [3]:
import os
import shutil
import pathlib
import sys
# Directory scanned for trained weight files; it is passed as `root` to
# convert_weight_files_to_model_files in the last cell.
# NOTE(review): absolute path, presumably specific to the original
# environment; adjust it when running elsewhere.
ROOT = pathlib.Path('/notebooks/code/volumes/checkpoints/weights')

In [5]:
!rm -rf image-super-resolution
!git clone https://github.com/thekevinscott/image-super-resolution.git
!mv image-super-resolution/ISR image-super-resolution/IISR
sys.path.insert(0,"/notebooks/code/notebooks/image-super-resolution")

Cloning into 'image-super-resolution'...
remote: Enumerating objects: 1578, done.
remote: Counting objects: 100% (295/295), done.
remote: Compressing objects: 100% (236/236), done.
remote: Total 1578 (delta 89), reused 264 (delta 59), pack-reused 1283
Receiving objects: 100% (1578/1578), 15.01 MiB | 32.28 MiB/s, done.
Resolving deltas: 100% (792/792), done.


In [6]:
from IISR.models import RDN, RRDN
from tqdm import tqdm
import tensorflow as tf

def get_model(model, arch_params):
    """Build the super-resolution network named by ``model``.

    Args:
        model: Architecture identifier, either ``'rdn'`` or ``'rrdn'``.
        arch_params: Dict forwarded unchanged as the ``arch_params`` keyword
            of the selected model class.

    Returns:
        A new ``RDN`` or ``RRDN`` instance.

    Raises:
        ValueError: If ``model`` is not one of the supported identifiers.
    """
    if model == 'rdn':
        return RDN(arch_params=arch_params)
    if model == 'rrdn':
        return RRDN(arch_params=arch_params)
    # f-string (rather than '+') so a non-string identifier such as None is
    # reported as an invalid model instead of failing with a TypeError.
    raise ValueError(f'No valid model found for {model}')
    
def get_params(folder):
    """Parse architecture parameters out of a weights folder name.

    The name is split on ``'-'`` and its first seven fields are read as
    ``<arch>-C<n>-D<n>-G<n>-G0<n>-T<n>-x<n>``; any further fields are ignored.

    Args:
        folder: Folder name (not a full path), e.g.
            ``'rrdn-C4-D3-G32-G064-T10-x4-patchsize128-...'``.

    Returns:
        Tuple ``(arch, arch_params)`` where ``arch_params`` holds the integer
        values for ``'C'``, ``'D'``, ``'G'``, ``'G0'`` and ``'x'``, plus
        ``'T'`` when ``arch`` is ``'rrdn'``.

    Raises:
        ValueError: If the name has fewer than seven ``'-'``-separated fields
            or a numeric part cannot be converted to ``int``.
    """
    fields = folder.split('-')
    if len(fields) < 7:
        raise ValueError(
            f'Cannot parse architecture parameters from folder name {folder!r}'
        )
    # Only the leading fields carry architecture parameters; the trailing
    # ones are not needed to rebuild the model.
    arch, C, D, G, G0, T, x = fields[:7]

    arch_params = {
        'C': int(C[1:]),
        'D': int(D[1:]),
        'G': int(G[1:]),
        'G0': int(G0[2:]),  # two-character prefix "G0"
        'x': int(x[1:]),
    }
    if arch == 'rrdn':
        arch_params['T'] = int(T[1:])
    return arch, arch_params

# def save_model(weights, output, arch, x, C, D, G, G0, T):
#     model.model.load_weights('/code/weights/' + weights)
#     model.model.save('/code/weights/' + output)

def get_weights(folder):
    """List the ``hdf5`` weight files found one level below ``folder``.

    Every sub-directory of ``folder`` is scanned; entries of ``folder`` that
    are not directories are skipped.  A file is kept when its name ends with
    ``'hdf5'`` and does not contain ``'srgan'``.

    Args:
        folder: Directory holding the sub-folders (``str`` or ``pathlib.Path``).

    Returns:
        List of file paths as strings, sorted by sub-folder and then by file
        name so the result does not depend on directory listing order.
    """
    folder = pathlib.Path(folder)
    weights = []
    for date_folder in sorted(folder.iterdir()):
        if not date_folder.is_dir():
            # A stray file next to the sub-folders would make os.listdir fail.
            continue
        weights += [
            str(date_folder / f)
            for f in sorted(os.listdir(date_folder))
            if 'srgan' not in f and f.endswith('hdf5')
        ]
    return weights

def convert_weight_files_to_model_files(root, target):
    """Convert every weight file under ``root`` into a saved model file.

    Each entry of ``root`` is treated as an architecture folder: its name is
    parsed with ``get_params`` and its weight files are listed with
    ``get_weights``.  Every weight file is loaded into a freshly built model
    and saved under ``target``, keeping the last three path components of the
    weight file and replacing its extension with ``.h5``.

    Failures (an unparsable folder or a weight file that cannot be converted)
    are collected and printed at the end instead of aborting the run.

    Args:
        root: Directory holding the architecture folders
            (``str`` or ``pathlib.Path``).
        target: Directory that receives the converted ``.h5`` files
            (``str`` or ``pathlib.Path``).
    """
    print(f'make weights.zip file for folder {root}')
    root = pathlib.Path(root)
    target = pathlib.Path(target)
    weights = []
    errs = []

    for folder in sorted(os.listdir(root)):
        try:
            arch, arch_params = get_params(folder)
            weights += [(w, arch, arch_params) for w in get_weights(root / folder)]
        except Exception as e:
            # One stray or oddly named entry should not stop the other
            # folders from being converted; report it with the rest.
            errs += [(str(root / folder), e)]

    i = 0
    for weight, arch, arch_params in tqdm(weights):
        try:
            tf.keras.backend.clear_session() # needed for https://github.com/tensorflow/tfjs/issues/755#issuecomment-489665951
            model = get_model(arch, arch_params)
            model.model.load_weights(weight)
            # Keep "<arch folder>/<sub-folder>/<file>" and swap only the file
            # extension, so dots elsewhere in the path cannot truncate the name.
            relative = pathlib.PurePath(*pathlib.PurePath(weight).parts[-3:])
            target_path = target / relative.with_suffix('.h5')
            target_path.parent.mkdir(parents=True, exist_ok=True)
            # A plain string path is accepted by every Keras version.
            model.model.save(str(target_path))
            i += 1
        except Exception as e:
            errs += [(weight, e)]

    print(f'Successfully processed {i} files')
    if len(errs) > 0:
        print(f'The following {len(errs)} weights could not be processed\n-------------------------------')
        for err, e in errs:
            print(err, e)
            
def make_weights_zip(root):
    """Zip the current working directory into ``<root>.zip``.

    Args:
        root: Base name of the archive to create, without the ``.zip``
            extension (passed to ``shutil.make_archive`` as ``base_name``).
    """
    print('Start making zip file')
    shutil.make_archive(base_name=root, format='zip', root_dir='./')
    print('Ready for download')

ModuleNotFoundError: No module named 'ISR'

In [None]:
# Remove the artefacts of a previous run, then convert every weight file found
# under ROOT into ./weights.  The cleanup is best-effort: a missing file or
# folder is not an error.
folder_name = 'weights'

try:
    os.remove(f'./{folder_name}.zip')
except OSError:
    # Only filesystem errors are ignored, so Ctrl-C still interrupts the cell.
    pass
shutil.rmtree(f'./{folder_name}', ignore_errors=True)

convert_weight_files_to_model_files(ROOT, pathlib.Path(f'./{folder_name}'))
# make_weights_zip(ROOT)