# Main script

This code is based on the method described in the following paper:

[1] "Deep Optimization Prior for THz Model Parameter Estimation", T.M. Wong, H. Bauermeister, M. Kahl, P. Haring Bolivar, M. Moeller, A. Kolb,
Winter Conference on Applications of Computer Vision (WACV) 2022.

If you use this code in your scientific publication, please cite the mentioned paper.
The code and the algorithm are for non-commercial use only.

For further details, please visit the project website: https://github.com/tak-wong/Deep-Optimization-Prior

In [None]:
from MoAE import *

In [None]:
def get_dataset_filename(dataset_name):
    """Return the relative .mat filename that belongs to a dataset name.

    Resolution rules:
      * 'metalpcb' (compared case-insensitively) maps to
        'MetalPCB_91x446x446.mat'.
      * A name that starts (case-sensitively) with one of the known dataset
        family prefixes maps to '<prefix>/<dataset_name>_91x446x446.mat'.
      * Any other name maps to the empty string.
    """
    if dataset_name.lower() == 'metalpcb':
        return 'MetalPCB_91x446x446.mat'

    # Each family lives in a sub-directory named after its prefix.
    families = (
        'MetalPCB_AWGN',
        'MetalPCB_ShotNoise',
        'SynthUSAF_AWGN',
        'SynthUSAF_ShotNoise',
        'SynthObj_AWGN',
        'SynthObj_ShotNoise',
    )
    for family in families:
        if dataset_name.startswith(family):
            return "{}/{}_91x446x446.mat".format(family, dataset_name)

    # Unknown dataset name: the caller receives an empty filename.
    return ''

# Example 1: MetalPCB

In [None]:
if __name__ == '__main__':
    # Run configuration for the 'metalpcb' dataset.
    seed, lr, epochs = 0, 0.01, 1200
    dataset_name = 'metalpcb'
    dataset_filename = get_dataset_filename(dataset_name)
    dataset_path, dest_path = './dataset', './result'
    verbose = True
    debug = True

    # Hyperparameters and optimizer; both names are expected to come from the
    # star import of MoAE.
    hp = hyperparameter_unet_thz(use_seed=seed, learning_rate=lr, epochs=epochs)
    optimizer = autoencoder_unet_thz(
        dataset_name, dataset_filename, dataset_path, dest_path, hp, verbose
    )

    if debug:
        # Debug settings: RUNS is set to 1 and every plot/save interval to 100.
        optimizer.RUNS = 1
        for _quantity in ('LOSS', 'LR', 'PARAMETERS', 'LOSSMAP', 'PIXEL'):
            for _action in ('PLOT', 'SAVE'):
                setattr(optimizer, 'INTERVAL_{}_{}'.format(_action, _quantity), 100)
    optimizer.train()

In [None]:
    # Run configuration for the 'MetalPCB_AWGN_n20db' dataset.
    seed, lr, epochs = 0, 0.01, 1200
    dataset_name = 'MetalPCB_AWGN_n20db'
    dataset_filename = get_dataset_filename(dataset_name)
    dataset_path, dest_path = './dataset', './result'
    verbose = True
    debug = False

    # Hyperparameters and optimizer; both names are expected to come from the
    # star import of MoAE.
    hp = hyperparameter_nonet1st_thz(use_seed=seed, learning_rate=lr, epochs=epochs)
    optimizer = autoencoder_nonet1st_thz(
        dataset_name, dataset_filename, dataset_path, dest_path, hp, verbose
    )

    optimizer.train()

# Example 2: SynthUSAF+ShotNoise

In [None]:
    # Run configuration for the 'SynthUSAF_ShotNoise_p10db' dataset.
    # `seed` is defined here (with the value the other example cells use) so
    # that this cell does not raise a NameError when it is executed without
    # first running a cell that happens to define it.
    seed = 0
    lr = 0.01
    epochs = 1200
    dataset_name = 'SynthUSAF_ShotNoise_p10db'
    dataset_filename = get_dataset_filename(dataset_name)
    dataset_path = './dataset'
    dest_path = './result'
    verbose = True
    debug = False

    # Hyperparameters and optimizer; both names are expected to come from the
    # star import of MoAE.
    hp = hyperparameter_nonet2nd_thz(use_seed = seed, learning_rate = lr, epochs = epochs)
    optimizer = autoencoder_nonet2nd_thz(dataset_name, dataset_filename, dataset_path, dest_path, hp, verbose)
    optimizer.train()

# Example 3: SynthObj+AWGN

In [None]:
    # Run configuration for the 'SynthObj_AWGN_p0db' dataset.
    # `seed` is defined here (with the value the other example cells use) so
    # that this cell does not raise a NameError when it is executed without
    # first running a cell that happens to define it.
    seed = 0
    lr = 0.01
    epochs = 1200
    dataset_name = 'SynthObj_AWGN_p0db'
    dataset_filename = get_dataset_filename(dataset_name)
    dataset_path = './dataset'
    dest_path = './result'
    verbose = True
    debug = False

    # Hyperparameters and optimizer; both names are expected to come from the
    # star import of MoAE.
    hp = hyperparameter_ppae_thz(use_seed = seed, learning_rate = lr, epochs = epochs)
    optimizer = autoencoder_ppae_thz(dataset_name, dataset_filename, dataset_path, dest_path, hp, verbose)
    optimizer.train()