In [None]:
# Step up one directory before anything else runs: later cells use paths
# relative to the working directory ('options/...', 'tb_logger/...', './').
# NOTE(review): assumes the notebook lives one level below the repo root — confirm.
%cd ..

In [None]:
import os

# Environment configuration is done BEFORE torch / basicsr are imported.
# OpenMP runtimes commonly read OMP_NUM_THREADS when the library is first
# loaded, so assigning it after `import torch` may silently have no effect.
os.environ['OMP_NUM_THREADS'] = '1'
# NOTE(review): machine-specific absolute path; adjust for the host you run on.
os.environ['TORCH_HOME'] = '/sun/home_torch'

import argparse
import copy
import datetime
import logging
import math
import random
import time
import warnings
from os import path as osp

import torch

from basicsr.data import build_dataloader, build_dataset
from basicsr.data.data_sampler import EnlargedSampler
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import build_model
from basicsr.utils import (MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger,
                           init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed)
from basicsr.utils.dist_util import get_dist_info, init_dist
from basicsr.utils.options import dict2str, parse

# Motivated by "UserWarning: Detected call of `lr_scheduler.step()` before
# `optimizer.step()`", but note that this filter silences EVERY UserWarning,
# not only that one.
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
def parse_options(root_path, is_train=True, cli_args=None):
    """Build the option dict for this notebook session.

    The notebook has no real command line, so the argparse tokens are supplied
    programmatically instead of being read from ``sys.argv``.

    Args:
        root_path (str): Forwarded to ``parse`` together with the YAML path.
        is_train (bool): Forwarded to ``parse``. Default: True.
        cli_args (list[str] | None): Command-line style tokens handed to
            argparse, e.g. ``['--launcher', 'none', '-opt', 'options/x.yml']``.
            ``None`` (default) uses the notebook's built-in tokens: the
            ``pytorch`` launcher and the VQGAN stage-1 option file.

    Returns:
        The object returned by ``parse`` with the keys ``'dist'``, ``'rank'``,
        ``'world_size'`` and ``'manual_seed'`` filled in.

    Raises:
        SystemExit: Raised by argparse when ``cli_args`` is invalid (for
            example when the required ``-opt`` token is missing).
    """
    if cli_args is None:
        cli_args = ['--launcher', 'pytorch', '-opt', 'options/VQGAN_512_ds32_nearest_stage1_harmer.yml']

    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
    args = parser.parse_args(list(cli_args))
    opt = parse(args.opt, root_path, is_train=is_train)

    # Distributed flag follows the launcher choice only; no process group is
    # initialised here (init_dist is deliberately not called in the notebook).
    # NOTE(review): with the default 'pytorch' launcher opt['dist'] is True while
    # rank/world_size are pinned to a single process below. If downstream code
    # relies on an initialised process group when opt['dist'] is set, pass
    # cli_args with '--launcher', 'none' instead — confirm.
    if args.launcher == 'none':
        opt['dist'] = False
        print('Disable distributed.', flush=True)
    else:
        opt['dist'] = True

    # Single-process values, hard-wired instead of queried via get_dist_info().
    opt['rank'] = 0
    opt['world_size'] = 1

    # Random seed: draw one when the options do not provide it, and store it
    # back into opt so the value that was used is visible afterwards.
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    set_random_seed(seed + opt['rank'])

    return opt


def init_loggers(opt):
    """Set up the text logger and, when enabled, the tensorboard logger.

    Args:
        opt (dict): Options. Reads ``opt['path']['log']``, ``opt['name']`` and
            ``opt['logger']`` (keys ``'wandb'`` and ``'use_tb_logger'``).

    Returns:
        tuple: ``(logger, tb_logger)``; ``tb_logger`` is None when
        ``opt['logger']['use_tb_logger']`` is falsy or absent.

    Raises:
        AssertionError: If a wandb project is configured while
            ``use_tb_logger`` is not exactly True.
    """
    log_path = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
    logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_path)
    # Dump the full option dict into the log.
    logger.info(dict2str(opt))

    logger_cfg = opt['logger']

    # wandb has to be initialised ahead of tensorboard so the two sync properly.
    wandb_cfg = logger_cfg.get('wandb')
    if wandb_cfg is not None and wandb_cfg.get('project') is not None:
        assert logger_cfg.get('use_tb_logger') is True, ('should turn on tensorboard when using wandb')
        init_wandb_logger(opt)

    if not logger_cfg.get('use_tb_logger'):
        return logger, None
    return logger, init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))

In [None]:
# Absolute form of the current working directory; displayed as the cell output.
root_path = osp.abspath(osp.curdir)
root_path

In [None]:
# Resolve the run configuration (parse_options also calls set_random_seed).
opt = parse_options(root_path, is_train=True)

# cuDNN benchmark mode on; the deterministic flag is left untouched.
torch.backends.cudnn.benchmark = True

# Decide what to restore from disk. A resume checkpoint takes precedence over
# a pretrained generator; when neither is configured nothing is loaded and
# resume_state stays None.
resume_state = None
if opt['path'].get('resume_state'):
    # Whole resume checkpoint, with storages moved onto the current CUDA device.
    device_id = torch.cuda.current_device()
    resume_state = torch.load(
        opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id))
elif opt['path'].get('pretrain_network_g'):
    # Only the entry stored under `param_key_g` is kept from the checkpoint.
    device_id = torch.cuda.current_device()
    pretrained_state = torch.load(
        opt['path']['pretrain_network_g'],
        map_location=lambda storage, loc: storage.cuda(device_id),
    )[f"{opt['path']['param_key_g']}"]

# Experiment / tensorboard directories are only (re)created when not resuming.
if resume_state is None:
    make_exp_dirs(opt)
    if opt['rank'] == 0 and opt['logger'].get('use_tb_logger'):
        mkdir_and_rename(osp.join('tb_logger', opt['name']))

# Text logger plus optional tensorboard logger.
logger, tb_logger = init_loggers(opt)

In [None]:
# Construct the model from the parsed options via basicsr's build_model.
# NOTE(review): parse_options sets opt['dist'] = True by default while init_dist
# is never called in this notebook; if build_model relies on an initialised
# process group when opt['dist'] is set, this call may fail — confirm.
model = build_model(opt)