# Training a super-resolution model

The dataset comes with several pretrained super-resolution models we used as a benchmark:

- HighResNet
- SRCNN Multi-Frame
- SRCNN Single-Image

We trained the models on a [p3.2xlarge](https://aws.amazon.com/ec2/instance-types/p3/) instance.
Training usually takes between 45 minutes and 1.5 hours on a single GPU when using 8 low-resolution revisits and the entire dataset.

The splits we used are available in the `stratified_train_val_test_split.csv` file.  
These splits are stratified to ensure equal representation of all LCCS/IPCC/SMOD classes within each split.  
To run on a smaller subset, you can manually specify the number of AOIs to be used in each split using the `--train_split`, `--val_split`, `--test_split` arguments.
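
As a minimal sketch, the split arguments can be appended to a command list like the `default_train_command` defined further down. The `base_command` stub, the helper name `with_subset_splits`, and the AOI counts are illustrative assumptions, not values used in the benchmark.

```python
# Illustrative stub; in this notebook it would be `default_train_command`.
base_command = ["python", "--batch_size", "16", "--model", "srcnn"]


def with_subset_splits(command, n_train, n_val, n_test):
    """Return a copy of `command` with explicit AOI counts for each split."""
    return list(command) + [
        "--train_split", str(n_train),
        "--val_split", str(n_val),
        "--test_split", str(n_test),
    ]


# Hypothetical AOI counts for a quick smoke test on a small subset.
small_command = with_subset_splits(base_command, n_train=20, n_val=5, n_test=5)
print(small_command)
```

The helper returns a new list, so the original command stays available for full-size runs.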

To train the network, or reproduce this benchmark, you can run the following commands:


In [None]:
%load_ext autoreload
%autoreload 2
from src.train import *
import sys

In [None]:
default_train_command = [
    # Batch size, gpus, limits
    "python",
    "--batch_size", "16",
    "--gpus", "-1", # comment out this line to run on the CPU
    "--max_steps", "50000",
    "--precision", "16",
    "--num_workers", "8", # number of data-loading workers; adjust to your CPU core count

    # Model/Hyperparameters
    "--model", "srcnn",
    "--w_mse", "0.3",
    "--w_mae", "0.4",
    "--w_ssim", "0.3",
    "--hidden_channels", "128",
    "--shift_px", "2",
    "--shift_mode", "lanczos",
    "--shift_step", "0.5",
    "--residual_layers", "1",
    "--learning_rate", "1e-4",
    
    # Data
    "--dataset", "JIF",
    "--root", "dataset", # dataset directory name
    "--revisits", "1", # single frame SR
    "--input_size", "160", "160",
    "--output_size", "500", "500",
    "--chip_size", "80", "80",
    "--chip_stride", "80", "80",
    "--lr_bands_to_use", "true_color",
    "--use_single_frame_sr", "True",
    "--normalize_lr", "True",
    "--randomly_rotate_and_flip_images", "True",
    "--shuffle", "True",
    "--subset_train", "0.2",
    #"--radiometry_depth", "12",

    # Training, validation, test splits
    "--list_of_aois", "dataset/train_val_test.csv",
    
    # WandB is disabled by default.
    # Uncomment the line below to enable WandB if needed.
    #"--use_wandb"
]

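def run_training_command(training_command, running_on_windows=True):
    # Copy the list so that appending flags below does not modify the caller's command.
    sys.argv = list(training_command)
    if running_on_windows:
        # Appended last so that it takes precedence over any earlier --num_workers value.
        sys.argv += ["--num_workers", "0"]
    cli_main()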

**NOTE**: Keep in mind the training was done on an instance with 1xV100 and 64 GB of RAM.  
The batch size might be too large for your local computer.  

If CUDA runs out of memory, consider decreasing the batch size in `default_train_command` above.  
You can also decrease the number of revisits to any number from 1 to 8.
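
One way to apply such changes without editing the list by hand is to look up a flag and replace the value that follows it. The helper `override_flag` below is a hypothetical convenience, not part of `src.train`, and the stub command and replacement values are illustrative only.

```python
def override_flag(command, flag, value):
    """Return a copy of `command` where the value after `flag` is replaced.

    Raises ValueError if `flag` is not present in the command.
    """
    updated = list(command)
    position = updated.index(flag)  # raises ValueError if the flag is missing
    updated[position + 1] = str(value)
    return updated


# Illustrative stub standing in for `default_train_command`.
command = ["python", "--batch_size", "16", "--revisits", "8"]

# Halve the batch size and use fewer revisits to lower GPU memory use.
lighter = override_flag(command, "--batch_size", 8)
lighter = override_flag(lighter, "--revisits", 4)
print(lighter)  # ['python', '--batch_size', '8', '--revisits', '4']
```

Looking the flag up by name keeps the override valid even if arguments are later added to or removed from the command.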

If the data-loading workers run out of shared memory, you can increase it on Linux by running:  
`sudo mount -o remount,size={YOUR_RAM_SIZE, e.g. 64G} /dev/shm`
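
To see how much shared memory is currently available before remounting, a small check like the following may help. It is a sketch that uses only the standard library; the function name `shared_memory_gib` is made up for this example, and it assumes the shared-memory mount lives at `/dev/shm`, as in the command above.

```python
import os
import shutil


def shared_memory_gib(path="/dev/shm"):
    """Return (total, free) size of the mount at `path` in GiB, or None if it is absent."""
    if not os.path.isdir(path):
        return None  # e.g. on Windows or macOS, where /dev/shm does not exist
    usage = shutil.disk_usage(path)
    gib = 1024 ** 3
    return usage.total / gib, usage.free / gib


sizes = shared_memory_gib()
if sizes is None:
    print("/dev/shm not found on this system")
else:
    total, free = sizes
    print(f"/dev/shm: {total:.1f} GiB total, {free:.1f} GiB free")
```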

If running on Windows, pass `running_on_windows=True` to the `run_training_command` function; this sets `--num_workers` to 0.

### Training a single model

In [None]:
run_training_command(default_train_command, running_on_windows=True)

### Reproducing the benchmark

In [None]:
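benchmark_random_seeds = [431608443, 122938034, 315114726]
benchmark_data_seed = 386564310

# Locate values by flag name; hard-coded list indices break whenever
# an argument is added to or removed from default_train_command.
model_index = default_train_command.index("--model") + 1
revisits_index = default_train_command.index("--revisits") + 1


def make_replicates(command):
    # One command per random seed, all sharing the same data split seed.
    return [
        command
        + ["--data_split_seed", str(benchmark_data_seed)]
        + ["--seed", str(seed)]
        for seed in benchmark_random_seeds
    ]


# HighResNet triple replicates (8 revisits)
# Assumption: "highresnet" is the --model value that src.train expects for HighResNet.
# Other single-frame defaults above (e.g. --use_single_frame_sr, --subset_train) are left unchanged here.
highresnet_command = list(default_train_command)
highresnet_command[model_index] = "highresnet"
highresnet_command[revisits_index] = "8"
highresnet_replicates = make_replicates(highresnet_command)

# SRCNN Multi-Frame triple replicates
# Change model to SRCNN
srcnn_multiframe_command = list(highresnet_command)
srcnn_multiframe_command[model_index] = "srcnn"
srcnn_multiframe_replicates = make_replicates(srcnn_multiframe_command)

# SRCNN Single-Image triple replicates
# Change number of revisits to 1
srcnn_single_image_command = list(srcnn_multiframe_command)
srcnn_single_image_command[revisits_index] = "1"
srcnn_single_image_replicates = make_replicates(srcnn_single_image_command)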

In [None]:
for replicates in [highresnet_replicates, srcnn_multiframe_replicates, srcnn_single_image_replicates]:
    for replicate_training_command in replicates:
        run_training_command(replicate_training_command, running_on_windows=True)
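
Before launching all nine training runs, it can be worth checking that each set of replicates carries the intended arguments. The sketch below parses a command list into a dictionary and reports the flags whose values differ between two commands. The helper names and the stub commands are illustrative assumptions; when a flag is repeated, the last occurrence wins, which mirrors how `argparse` treats repeated options.

```python
def command_to_dict(command):
    """Map each `--flag` to the list of values that follow it (program name skipped)."""
    parsed = {}
    current = None
    for token in command[1:]:
        if token.startswith("--"):
            current = token
            parsed[current] = []  # a repeated flag overwrites the earlier entry
        elif current is not None:
            parsed[current].append(token)
    return parsed


def diff_commands(a, b):
    """Return {flag: (values_in_a, values_in_b)} for flags whose values differ."""
    da, db = command_to_dict(a), command_to_dict(b)
    return {
        flag: (da.get(flag), db.get(flag))
        for flag in sorted(set(da) | set(db))
        if da.get(flag) != db.get(flag)
    }


# Illustrative stubs standing in for one multi-frame and one single-image replicate.
multi = ["python", "--model", "srcnn", "--revisits", "8", "--input_size", "160", "160"]
single = ["python", "--model", "srcnn", "--revisits", "1", "--input_size", "160", "160"]
print(diff_commands(multi, single))  # {'--revisits': (['8'], ['1'])}
```

Running the same comparison on the first command of each replicate list should show only the model or the number of revisits changing between benchmark configurations.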