# Training SRCNN on the AID dataset

- Reference code (adapted from): https://github.com/worldstrat/worldstrat.git

In [1]:
%load_ext autoreload
%autoreload 2
from src.train import *
import sys

In [None]:
# Base command line for src.train.cli_main, installed as sys.argv by
# run_training_command. Every token must be a str: cli_main presumably parses
# sys.argv with argparse, which expects string tokens (non-str values such as
# True or 1.0 raise TypeError while options are being parsed).
default_train_command = [
    # Program name (sys.argv[0]); not parsed as an option
    "python",

    # Batch size, gpus, limits
    "--batch_size", "16",
    "--gpus", "-1",
    "--max_steps", "50000",
    "--precision", "16",
    "--num_workers", "0",

    # Model/Hyperparameters
    "--model", "SRCNN",  # must be uppercase
    "--w_mse", "0.3",
    "--w_mae", "0.4",
    "--w_ssim", "0.3",
    "--hidden_channels", "64",
    "--residual_layers", "4",
    "--padding_mode", "reflect",
    "--sr_kernel_size", "1",
    "--use_dropout", "False",
    "--use_batchnorm", "False",
    "--learning_rate", "1e-4",

    # Data
    "--root", "AID-dataset/",
    "--zoom_factor", "4",  # affects the model
    "--output_size", "600", "600",
    "--chip_size", "600", "600",
    "--chip_stride", "600", "600",
    "--randomly_rotate_and_flip_images", "True",
    "--shuffle", "True",
    "--subset_train", "1.0",

    # Logging / run options
    "--use_wandb", "True",
    "--benchmark", "True",
    "--upload_checkpoint", "False",
]

def run_training_command(training_command, running_on_windows=True):
    """Run ``cli_main`` as if it were launched from a shell with ``training_command``.

    Args:
        training_command: argv-style list of str tokens; element 0 is the
            program name. The list is copied, never modified, so the same base
            command can be reused across calls.
        running_on_windows: when True, ``["--num_workers", "0"]`` is appended
            to the arguments, because multiprocessing data loading is
            inefficient on Windows (compared to Linux).

    ``sys.argv`` is restored when ``cli_main`` returns or raises, so the
    notebook kernel is not left with the training arguments installed.
    """
    # Copy first: `+=` on the caller's list would append to it on every call.
    argv = list(training_command)
    if running_on_windows:
        argv += ["--num_workers", "0"]

    saved_argv = sys.argv
    sys.argv = argv
    try:
        cli_main()
    finally:
        sys.argv = saved_argv

### Training a single model

In [None]:
# Train one model with the default configuration. With running_on_windows=False
# no extra --num_workers override is appended to the command.
run_training_command(
    training_command=default_train_command,
    running_on_windows=False,
)

Global seed set to 1337
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

### Reproducing the benchmark

In [None]:
benchmark_random_seeds = [431608443, 122938034, 315114726]
benchmark_data_seed = 386564310


def with_cli_arg(command, flag, value):
    """Return a copy of ``command`` with option ``flag`` set to ``value``.

    If ``flag`` is present, the token after its last occurrence is replaced;
    otherwise ``flag`` and ``value`` are appended. Options are located by name
    rather than by list index, so reordering the base command cannot silently
    overwrite the wrong option. ``command`` itself is never modified, which
    keeps this cell safe to re-run.
    """
    updated = list(command)
    value = str(value)
    if flag in updated:
        # Edit the last occurrence: argparse keeps the last value of a repeated option.
        position = len(updated) - 1 - updated[::-1].index(flag)
        if position + 1 < len(updated):
            updated[position + 1] = value
        else:
            updated.append(value)
    else:
        updated += [flag, value]
    return updated


def make_replicates(base_command):
    """One training command per benchmark seed, all sharing the fixed data-split seed."""
    return [
        base_command
        + ["--data_split_seed", str(benchmark_data_seed)]
        + ["--seed", str(seed)]
        for seed in benchmark_random_seeds
    ]


# HighResNet triple replicates
# NOTE(review): "HighResNet" assumes --model takes the model class name (the
# base command notes the name is case-sensitive) — confirm against src.train.
highresnet_replicates = make_replicates(
    with_cli_arg(default_train_command, "--model", "HighResNet")
)

# SRCNN MultiFrame triple replicates
srcnn_base_command = with_cli_arg(default_train_command, "--model", "SRCNN")
srcnn_multiframe_replicates = make_replicates(srcnn_base_command)

# SRCNN Single Image triple replicates: same model, number of revisits set to 1.
# NOTE(review): --revisits is not part of default_train_command; this assumes
# src.train defines it (as the upstream worldstrat CLI does) — confirm.
srcnn_single_image_replicates = make_replicates(
    with_cli_arg(srcnn_base_command, "--revisits", "1")
)

In [None]:
# Train every benchmark replicate, finishing one model group before the next.
replicate_groups = (
    highresnet_replicates,
    srcnn_multiframe_replicates,
    srcnn_single_image_replicates,
)
for group in replicate_groups:
    for command in group:
        run_training_command(command, running_on_windows=True)