In [1]:
import os
import sys
# Make the parent of the kernel's current working directory importable, so that
# `project_config` and `utils/` resolve. This is relative to the cwd, so the
# notebook must be started from its own folder for the path to be right.
sys.path.append(os.path.abspath(".."))

# NOTE(review): star imports hide where names come from. Presumably project_config
# supplies the *_PATH constants used below (MMAGIC_CONFIGS_DIR_PATH,
# SRC_DATA_CFG_FILE_PATH, ...), and utils.set_flexible_cfg supplies VisMode,
# VisBackendType, ParamSchedulerType, get_param_scheduler,
# get_visualizer_and_custom_hook and get_DA_repr — TODO confirm and switch to
# explicit imports.
from project_config import *
from mmengine import Config
from utils.set_flexible_cfg import *

In [3]:
# Where the resolved config is dumped, where runs write their outputs, and the
# name prefix shared by both.
ASSETS_CONFIG_DIR, WORKDIR = ASSETS_CONFIGS_SRGAN_DIR_PATH, OUT_SRGAN_DIR_PATH
CFG_NAME = "srgan"

# Upstream mmagic SRGAN config, given relative to the mmagic configs directory.
mmagic_cfg_suffix = "srgan_resnet/srgan_x4c64b16_1xb16-1000k_div2k.py"
mmagic_cfg_path = os.path.join(MMAGIC_CONFIGS_DIR_PATH, mmagic_cfg_suffix)

# Quick sanity check before loading: show (exists?, resolved path).
(os.path.exists(mmagic_cfg_path), mmagic_cfg_path)

(True,
 '/home/featurize/work/AI6126project2/mmagic/configs/srgan_resnet/srgan_x4c64b16_1xb16-1000k_div2k.py')

In [4]:
# Start from the upstream mmagic config, then merge the project's data config and
# visualisation config on top of it, in that order.
cfg = Config.fromfile(mmagic_cfg_path)
data_cfg, vis_cfg = (
    Config.fromfile(override_path, lazy_import=False)
    for override_path in (SRC_DATA_CFG_FILE_PATH, SRC_VIS_CFG_FILE_PATH)
)
for override_cfg in (data_cfg, vis_cfg):
    cfg.merge_from_dict(options=override_cfg.to_dict())
# print(cfg.pretty_text)  # uncomment to inspect the merged config

In [None]:
# Training length, batch size and logging cadence for this run.
# NOTE(review): earlier notes called the "original" values batch 32 / 500k iters,
# but the upstream config filename reads "1xb16-1000k". The 32/500k figures may come
# from the merged data config or from another notebook — confirm before relying on them.
cfg.train_dataloader.batch_size = 16
cfg.train_cfg.max_iters = 400_000
cfg.train_cfg.val_interval = 5_000
cfg.default_hooks.logger.interval = 100

# Checkpoint at every validation step, keep at most two checkpoints, and also
# track the best one by PSNR (higher is better).
cfg.default_hooks.checkpoint.update(
    interval=cfg.train_cfg.val_interval,
    max_keep_ckpts=2,
    save_best='PSNR',
    rule="greater",
)

# Dataloader worker counts keep their configured values; uncomment to override.
# cfg.train_dataloader.num_workers = 4
# cfg.val_dataloader.num_workers = 2
# cfg.val_dataloader.num_workers = 2

In [6]:
# Experiment knobs. They feed the next cell (scheduler, visualiser, Cutblur handling)
# and are encoded into cfg.experiment_name, and therefore into the run's work_dir.
scale = 4  # not referenced in the cells shown here; presumably used later — TODO confirm

# vis_mode = VisMode.TRAIN
vis_mode = VisMode.VAL
vis_type = VisBackendType.WANDB
# Setting this to False replaces cfg.model.generator.rgb_mean with zeros in the next cell.
normalization = True

scheduler_type = ParamSchedulerType.MULTISTEP
# scheduler_type = ParamSchedulerType.COSINE

cosine_n_periods = 4
cosine_eta_min_ratio = 0.001  # currently unused: the eta_min line below is commented out

# One kwargs dict shared by every scheduler type. Presumably get_param_scheduler picks
# the keys relevant to `scheduler_type` — TODO confirm.
scheduler_kwargs = dict(
    by_epoch = False,

    # MultiStepLR: multiply the LR by `gamma` at each milestone (iterations, as by_epoch=False).
    gamma = 0.5,
    # NOTE(review): this list was annotated "for edsr"; confirm it is intended for SRGAN.
    milestones = [50000, 100000, 200000, 300000],

    # CosineRestartLR: `cosine_n_periods` restart periods that together cover exactly
    # max_iters. Plain floor division drops the remainder when max_iters is not
    # divisible by cosine_n_periods, which would leave the final iterations outside
    # every period. Here the remainder is spread one iteration at a time over the
    # first periods instead. This is identical to the old result when it divides evenly.
    periods = [
        cfg.train_cfg.max_iters // cosine_n_periods
        + (1 if i < cfg.train_cfg.max_iters % cosine_n_periods else 0)
        for i in range(cosine_n_periods)
    ],
    # NOTE(review): with MultiOptimWrapperConstructor, cfg.optim_wrapper.optimizer may
    # not exist; adjust this before re-enabling it.
    # eta_min = cfg.optim_wrapper.optimizer.lr*0.001,
    restart_weights = [1]*cosine_n_periods,
)

# Extra W&B metric definitions; attached to the W&B backend once it is located.
loss_metrics = [dict(name='loss', step_metric='iter')]

In [7]:
# Install the LR schedule and the visualisation backend chosen above into cfg.
cfg.param_scheduler = get_param_scheduler(scheduler_type, scheduler_kwargs)
visualizer, custom_vis_hook = get_visualizer_and_custom_hook(
    vis_mode, vis_type
)
cfg.visualizer = visualizer
# Register the hook only once; re-running this cell would otherwise append a duplicate.
if custom_vis_hook not in cfg.custom_hooks:
    cfg.custom_hooks.append(custom_vis_hook)

# The W&B entry among the visualiser's backends (None if W&B is not configured).
# Its run name and metric definitions are filled in once experiment_name is known.
wandb_backend = next(
    (vb for vb in cfg.visualizer.vis_backends if vb.get('type') == 'WandbVisBackend'),
    None
)

# Model / data-pipeline adjustments driven by the knobs above and the augmentation setup.
if not normalization:
    # NOTE(review): rgb_mean looks like an EDSR-style generator argument; confirm the
    # SRGAN generator in this config accepts it before turning normalization off.
    cfg.model.generator.rgb_mean = [0, 0, 0]

# Short tag describing the augmentation pipeline; also used in experiment_name.
da_repr = get_DA_repr(cfg)

if 'Cutblur' in da_repr:
    # With Cutblur the LR input is brought to GT resolution. The generator therefore
    # does no upsampling, validation inputs get the same nearest-neighbour resize, and
    # the train-side degradation targets GT size.
    # NOTE(review): confirm the generator type accepts upscale_factor=1.
    cfg.model.generator.upscale_factor = 1
    val_resize = dict(
        type='Resize',
        keys='img',
        scale=(cfg.gt_h_size, cfg.gt_w_size),
        keep_ratio=True,
        interpolation='nearest',
    )
    # Insert just before the last val transform, and only once, so re-running the
    # cell does not stack identical resizes.
    if val_resize not in cfg.val_pipeline:
        cfg.val_pipeline.insert(-1, val_resize)
    for transform in cfg.train_pipeline:
        if transform.type == 'SetValues':
            transform.dictionary.scale = 1
        elif transform.type == 'FinalRandomSecondOrderDegradation':
            transform.params.target_size = (cfg.gt_h_size, cfg.gt_w_size)

# Train on 128-pixel GT patches (only the first PairedRandomCrop is changed).
for transform in cfg.train_pipeline:
    if transform.type == 'PairedRandomCrop':
        transform.gt_patch_size = 128
        break

# The dataloaders' datasets carry their own pipeline entries; point them at the edited lists.
cfg.train_dataloader.dataset.pipeline = cfg.train_pipeline
cfg.val_dataloader.dataset.pipeline = cfg.val_pipeline


# The run name encodes the main knobs so runs are distinguishable in W&B and on disk.
cfg.experiment_name = (
    f"{CFG_NAME}"
    # f"_b{cfg.train_dataloader.batch_size}"
    f"_iter{cfg.train_cfg.max_iters}"
    f"_ps{scheduler_type.name}"
    f"_op{cfg.optim_wrapper.constructor}"
    # f"_lr{cfg.optim_wrapper.optimizer.lr:.0e}".replace('-', 'm')
    f"_da{da_repr}"
    f"_n{normalization}"
)

if wandb_backend:
    wandb_backend.init_kwargs.name = cfg.experiment_name
    # Add only metric definitions that are not already present, so re-running the
    # cell does not register duplicates.
    for metric in loss_metrics:
        if metric not in wandb_backend.define_metric_cfg:
            wandb_backend.define_metric_cfg.append(metric)

# Expose the (already customised) backends at the top level of cfg.
cfg.vis_backends = cfg.visualizer.vis_backends

# Everything for this run lands in one folder named after the experiment.
cfg.work_dir = os.path.join(WORKDIR, cfg.experiment_name)
cfg.save_dir = cfg.work_dir
cfg.default_hooks.checkpoint.out_dir = cfg.work_dir

In [8]:
from datetime import datetime

# Snapshot the fully merged config into the assets folder, stamped with today's date.
# NOTE: the stamp has day resolution and omits experiment_name, so a second dump on
# the same day overwrites this file; use "%Y%m%d%H%M%S" for one file per run.
ymd_timestamp = datetime.now().strftime("%Y%m%d")
cfg_asset_path = os.path.join(
    ASSETS_CONFIG_DIR,
    f"{CFG_NAME}_{ymd_timestamp}.py"
)
# Create the target directory up front in case it does not exist yet (e.g. fresh checkout).
os.makedirs(ASSETS_CONFIG_DIR, exist_ok=True)
cfg.dump(cfg_asset_path)
cfg_asset_path

'/home/featurize/work/AI6126project2/assets/configs/srgan_resnet/srgan_20250424.py'

In [9]:
cfg["work_dir"]  # display the run's output folder

'/home/featurize/out/AI6126project2/srgan_resnet/srgan_iter400000_psMULTISTEP_opMultiOptimWrapperConstructor_daPRC_FlipH_FlipV_RT_Cutblur_nTrue'

In [10]:
cfg["experiment_name"]  # display the run name used for W&B and the work_dir

'srgan_iter400000_psMULTISTEP_opMultiOptimWrapperConstructor_daPRC_FlipH_FlipV_RT_Cutblur_nTrue'