In [None]:
"""
Training uses hat.models.hat_model.HATModel, not the bare network hat.archs.hat_arch.HAT.
"""

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from basicsr.data.transforms import augment, paired_random_crop
import os
# Move into the HAT package directory so that `archs` / `models` can be imported
# as top-level modules. The target is relative, so the call is guarded: re-running
# this cell after it has already run would otherwise look for HAT_official/hat
# inside itself and raise FileNotFoundError.
_hat_dir = os.path.join('HAT_official', 'hat')
if not os.getcwd().endswith(_hat_dir):
    os.chdir(_hat_dir)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Sanity check: display the working directory after the chdir in the import cell.
os.getcwd()

'/workspace/competition/2022.09_SR/HAT_official/hat'

In [6]:
import cv2
import numpy as np
import glob as glob
import wandb
import albumentations as A
from archs.hat_arch import HAT
import random
from tqdm import tqdm

In [None]:
def load_img_to_tensor(img_path):
    """Read an image file and return it as a float RGB tensor in CHW layout.

    Args:
        img_path (str): Path of the image file to read.

    Returns:
        torch.Tensor: float32 tensor with the channel axis first and channels
            in RGB order; pixel values are divided by 255.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``. ``cv2.imread``
            reports failure by returning None instead of raising, which
            previously surfaced as an unrelated AttributeError on ``.astype``.
    """
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if img is None:
        raise FileNotFoundError(f'cv2.imread could not read image: {img_path}')
    img = img.astype(np.float32) / 255.
    # OpenCV decodes to HWC / BGR: reverse the channel axis (BGR -> RGB),
    # then move it to the front (HWC -> CHW).
    img = np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))
    return torch.from_numpy(img).float()

class SRDataset(Dataset):
    """Paired low-/high-resolution image dataset with a seeded train/valid split.

    File names are taken from ``lr_path``; the image with the same name is read
    from ``hr_path``. The sorted name list is shuffled with ``seed`` and cut at
    ``train_ratio``: the first part is the training split, the rest is the
    validation split. Two instances built with the same paths, ratio and seed
    but different ``mode`` therefore never share a file.

    Args:
        lr_path (str): Directory holding the low-resolution images.
        hr_path (str): Directory holding the matching high-resolution images.
        train_ratio (float): Fraction of the files assigned to the training split.
        mode (str): ``'train'`` or ``'valid'``.
        seed (int): Seed of the private RNG used for the split shuffle.
        transform: Albumentations pipeline. It is stored on the instance but is
            not applied; augmentation uses the basicsr helpers instead.
        augmentation_prop (float): Probability that a training sample goes
            through random crop + flip/rotation.
        gt_size (int): High-resolution patch size passed to ``paired_random_crop``.
        scale (int): Upscaling factor passed to ``paired_random_crop``.

    Raises:
        ValueError: If ``mode`` is neither ``'train'`` nor ``'valid'``.
    """

    def __init__(self,
                 lr_path,
                 hr_path,
                 train_ratio=0.8,
                 mode='train',
                 seed=2022,
                 transform=A.Compose([
                     A.HorizontalFlip(p=0.5),
                     A.Rotate(limit=180, p=0.9),
                    ], additional_targets={'image2': 'image'}),
                 augmentation_prop=0.5,
                 gt_size=480,
                 scale=4):
        # Raising a plain string is itself a TypeError; use a real exception.
        if mode not in ('train', 'valid'):
            raise ValueError(f'invalid mode. {mode}')

        self.lr_path = lr_path
        self.hr_path = hr_path
        self.mode = mode
        self.gt_size = gt_size
        self.scale = scale

        # os.listdir order is arbitrary, so sort before shuffling to make the
        # split depend only on the seed. A private Random instance honours the
        # `seed` argument without reseeding the global `random` module.
        img_list = sorted(os.listdir(lr_path))
        random.Random(seed).shuffle(img_list)
        split = round(len(img_list) * train_ratio)
        self.img_list = img_list[:split] if mode == 'train' else img_list[split:]

        self.transform = transform
        self.augmentation_prop = augmentation_prop

    def __len__(self):
        """Return the number of image pairs in this split."""
        return len(self.img_list)

    @staticmethod
    def _read_image(path):
        """Read `path` with OpenCV as a float32 HWC/BGR array divided by 255."""
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread returns None instead of raising on a missing/corrupt file.
            raise FileNotFoundError(f'cv2.imread could not read image: {path}')
        return img.astype(np.float32) / 255.

    @staticmethod
    def _to_tensor(img):
        """Convert an HWC/BGR array into a CHW/RGB float tensor."""
        return torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()

    def __getitem__(self, idx):
        """Return the ``(lr, hr)`` tensor pair at ``idx``.

        Training samples are augmented with probability ``augmentation_prop``;
        validation samples are returned unmodified.
        """
        name = self.img_list[idx]
        lr = self._read_image(os.path.join(self.lr_path, name))
        hr = self._read_image(os.path.join(self.hr_path, name))

        if self.mode == 'train' and random.random() < self.augmentation_prop:
            # random crop (patch size and scale passed positionally)
            hr, lr = paired_random_crop(hr, lr, self.gt_size, self.scale)
            # flip, rotation
            hr, lr = augment([hr, lr], hflip=True, rotation=True)

        return self._to_tensor(lr), self._to_tensor(hr)
    
# The import cell changes the working directory to HAT_official/hat, so data
# paths have to be written relative to that directory (like the '../options'
# and '../experiments' paths used later in this notebook), not relative to the
# directory the notebook was launched from.
lr_path = '../datasets/data/train/120_480/lr'
hr_path = '../datasets/data/train/120_480/hr'

# Same paths and ratio for both splits: 90% of the files train, 10% validate.
train_set = SRDataset(lr_path=lr_path, hr_path=hr_path, train_ratio=0.9, mode='train')
valid_set = SRDataset(lr_path=lr_path, hr_path=hr_path, train_ratio=0.9, mode='valid')
print(len(train_set), len(valid_set))

batch_size=1
# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=6)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=6)

In [None]:
# Build the x4 HAT network with the hyper-parameters that match the HAT-L
# pretrained checkpoint loaded below.
model = HAT(upscale=4,
            in_chans=3,
            img_size=64,
            window_size=16,
            compress_ratio=3,
            squeeze_factor=30,
            conv_scale=0.01,
            overlap_ratio=0.5,
            img_range=1.,
            depths=(6,6,6,6,6,6,6,6,6,6,6,6),
            embed_dim=180,
            num_heads=(6,6,6,6,6,6,6,6,6,6,6,6),
            mlp_ratio=2,
            upsampler='pixelshuffle',
            resi_connection='1conv')
# The working directory is HAT_official/hat (see the import cell), so the
# checkpoint is addressed relative to it -- the same location the later
# build_model cell uses for 'pretrain_network_g'.
pretrained_path = '../experiments/pretrained_models/HAT-L_SRx4_ImageNet-pretrain.pth'
# map_location='cpu' keeps loading independent of the device the checkpoint
# was saved from; the model is moved to the GPU later by the training cell.
pretrained_states = torch.load(pretrained_path, map_location='cpu')
# The checkpoint stores the weights under the 'params_ema' key.
model.load_state_dict(pretrained_states['params_ema'])

In [None]:
# Training objective: mean absolute error between the network output and the
# high-resolution target.
loss_fn = nn.L1Loss()

# AdamW over every parameter of the network.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.99),
    weight_decay=0.001,
)

In [None]:
class PSNR:
    """Peak Signal to Noise Ratio.

    With the default ``data_range`` img1 and img2 are expected to have range
    [0, 255]. Pass ``data_range=1.0`` for images scaled to [0, 1].
    """

    def __init__(self):
        self.name = "PSNR"

    @staticmethod
    def __call__(img1, img2, data_range=255.0):
        """Return the PSNR in dB between two image tensors of the same shape.

        Args:
            img1 (torch.Tensor): First image.
            img2 (torch.Tensor): Second image.
            data_range (float): Peak signal value of the images.

        Returns:
            torch.Tensor: Scalar tensor ``20 * log10(data_range / sqrt(mse))``.
        """
        # Integer images (e.g. uint8) wrap around on subtraction and are
        # rejected by torch.mean, so promote them to float first.
        if not torch.is_floating_point(img1):
            img1 = img1.float()
        if not torch.is_floating_point(img2):
            img2 = img2.float()
        mse = torch.mean((img1 - img2) ** 2)
        return 20 * torch.log10(data_range / torch.sqrt(mse))
metric=PSNR()

In [None]:
# window_size: the training/validation loop pads each input so that its height
#              and width are multiples of this value before the forward pass.
# scale:       factor used to convert that padding into output pixels when the
#              padded border is cropped off the network output.
window_size, scale = 64, 4

In [None]:
# Number of batches in one pass over the validation loader.
len(valid_loader)

In [None]:
# Create the AMP gradient scaler once, before training starts. The training
# loop reuses this single instance for scale / step / update on every iteration.
scaler = torch.cuda.amp.GradScaler()

In [None]:
device = 'cuda'
model = model.to(device)
model.train()
epochs = 10  # NOTE: not used below; the run length is governed by total_iter
total_iter = 10000
log_every = 1   # print the training loss every `log_every` iterations
val_every = 2   # validate and save a checkpoint every `val_every` iterations
ckpt_dir = 'experiments/1007'


def pad_to_multiple(img, multiple):
    """Reflect-pad the bottom/right of an NCHW tensor so H and W divide by `multiple`.

    Returns:
        tuple: (padded tensor, rows added, columns added).
    """
    _, _, h, w = img.size()
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    if pad_h or pad_w:
        img = F.pad(img, (0, pad_w, 0, pad_h), 'reflect')
    return img, pad_h, pad_w


def upscale_padded(net, lq):
    """Run `net` on `lq` padded to `window_size` and crop the padding off the output."""
    lq, pad_h, pad_w = pad_to_multiple(lq, window_size)
    out = net(lq)
    # Remove the padded border, which is `scale` times larger in the output.
    _, _, h, w = out.size()
    return out[:, :, 0:h - pad_h * scale, 0:w - pad_w * scale]


if len(train_loader) == 0:
    # The while-loop below would otherwise spin forever without progress.
    raise RuntimeError('train_loader is empty')

# torch.save does not create missing directories.
os.makedirs(ckpt_dir, exist_ok=True)

prog_bar = tqdm(total=total_iter)
i = 0
while i < total_iter:
    for lr, hr in train_loader:
        lr, hr = lr.to(device), hr.to(device)

        with torch.cuda.amp.autocast():
            sr = upscale_padded(model, lr)
            loss = loss_fn(sr, hr)

        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        scaler.scale(loss).backward()
        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)
        # Updates the scale for next iteration.
        scaler.update()
        # Gradients live on the optimizer's params; GradScaler has no zero_grad().
        optimizer.zero_grad(set_to_none=True) # set_to_none=True here can modestly improve performance

        train_loss = loss.item()
        i += 1
        prog_bar.update(1)
        # set_description expects a string, not a float.
        prog_bar.set_description(f'loss {train_loss:.4f}')
        if i % log_every == 0:
            tqdm.write(f'{i} {round(train_loss, 4)}')

        # validation
        if i % val_every == 0:
            model.eval()
            running_loss = 0.0
            running_psnr = 0.0
            n_samples = 0
            with torch.no_grad():
                for val_lr, val_hr in valid_loader:
                    val_lr, val_hr = val_lr.to(device), val_hr.to(device)
                    val_sr = upscale_padded(model, val_lr)

                    n_batch = val_sr.shape[0]
                    running_loss += loss_fn(val_sr, val_hr).item() * n_batch
                    for bi in range(n_batch):
                        # `metric` is documented for the [0, 255] range while
                        # the tensors here are in [0, 1]: clamp and rescale.
                        running_psnr += float(metric(val_sr[bi].clamp(0, 1) * 255.,
                                                     val_hr[bi] * 255.))
                    n_samples += n_batch
            # Report per-sample averages rather than sums over the loader.
            val_loss = running_loss / max(n_samples, 1)
            val_psnr = running_psnr / max(n_samples, 1)
            tqdm.write(f'{i} valid | loss {round(val_loss, 4)} | psnr {round(val_psnr, 4)}')
            # Saved under 'params_ema' so the file loads the same way as the
            # pretrained checkpoint (these are the raw weights, not an EMA).
            torch.save({'params_ema': model.state_dict()},
                       os.path.join(ckpt_dir, f'HAT-L_120-960_{i}.pth'))
            model.train()

        if i==total_iter:
            break
prog_bar.close()

In [7]:
# Re-create the bare x4 HAT network (no weights are loaded in this cell).
# `depths` and `num_heads` both hold twelve entries of 6.
model = HAT(
    img_size=64,
    in_chans=3,
    upscale=4,
    upsampler='pixelshuffle',
    resi_connection='1conv',
    img_range=1.,
    embed_dim=180,
    depths=(6,) * 12,
    num_heads=(6,) * 12,
    window_size=16,
    overlap_ratio=0.5,
    mlp_ratio=2,
    compress_ratio=3,
    squeeze_factor=30,
    conv_scale=0.01,
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [11]:
# Check whether the bare HAT network exposes a `feed_data` attribute
# (the recorded output below is False).
'feed_data' in dir(model)

False

In [None]:
def parse_options(root_path, is_train=True, argv=None):
    """Parse command-line arguments and the option YAML into an `opt` dict.

    Args:
        root_path (str): Directory under which ``experiments/<name>`` (training)
            or ``results/<name>`` (testing) output paths are placed.
        is_train (bool): Whether to set up training paths (True) or test
            paths (False).
        argv (list[str] | None): Arguments to parse instead of the process
            command line. The default (None) keeps the original behaviour of
            reading ``sys.argv``; pass e.g. ``['-opt', 'cfg.yml']`` when calling
            from a notebook, where ``sys.argv`` belongs to the Jupyter kernel.

    Returns:
        tuple: ``(opt, args)`` -- the populated option dict and the argparse
            namespace.
    """
    # Imported locally because this notebook never imports these names at the
    # top level; without them the function fails with NameError.
    import argparse
    from os import path as osp

    from basicsr.utils import set_random_seed
    from basicsr.utils.dist_util import get_dist_info, init_dist

    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument(
        '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
    args = parser.parse_args(argv)

    # parse yml to dict
    opt = yaml_load(args.opt)

    # distributed settings
    if args.launcher == 'none':
        opt['dist'] = False
        print('Disable distributed.', flush=True)
    else:
        opt['dist'] = True
        if args.launcher == 'slurm' and 'dist_params' in opt:
            init_dist(args.launcher, **opt['dist_params'])
        else:
            init_dist(args.launcher)
    opt['rank'], opt['world_size'] = get_dist_info()

    # random seed: draw one if the YAML has none, and record it in opt
    seed = opt.get('manual_seed')
    if seed is None:
        seed = random.randint(1, 10000)
        opt['manual_seed'] = seed
    # offset by rank so each process gets a different seed
    set_random_seed(seed + opt['rank'])

    # force to update yml options
    if args.force_yml is not None:
        from basicsr.utils.options import _postprocess_yml_value

        for entry in args.force_yml:
            # Split on the first '=' only, so values may themselves contain '='.
            keys, value = entry.split('=', 1)
            keys, value = keys.strip(), value.strip()
            value = _postprocess_yml_value(value)
            # Walk the nested dict directly instead of exec()-ing a string
            # built from command-line input. Intermediate keys must already
            # exist (KeyError otherwise); only the final key is assigned.
            *parent_keys, last_key = keys.split(':')
            node = opt
            for key in parent_keys:
                node = node[key]
            node[last_key] = value

    opt['auto_resume'] = args.auto_resume
    opt['is_train'] = is_train

    # debug setting
    if args.debug and not opt['name'].startswith('debug'):
        opt['name'] = 'debug_' + opt['name']

    if opt['num_gpu'] == 'auto':
        opt['num_gpu'] = torch.cuda.device_count()

    # datasets
    for phase, dataset in opt['datasets'].items():
        # for multiple datasets, e.g., val_1, val_2; test_1, test_2
        phase = phase.split('_')[0]
        dataset['phase'] = phase
        if 'scale' in opt:
            dataset['scale'] = opt['scale']
        if dataset.get('dataroot_gt') is not None:
            dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
        if dataset.get('dataroot_lq') is not None:
            dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])

    # paths: expand '~' in resume-state and pretrained-network entries
    for key, val in opt['path'].items():
        if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
            opt['path'][key] = osp.expanduser(val)

    if is_train:
        experiments_root = osp.join(root_path, 'experiments', opt['name'])
        opt['path']['experiments_root'] = experiments_root
        opt['path']['models'] = osp.join(experiments_root, 'models')
        opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
        opt['path']['log'] = experiments_root
        opt['path']['visualization'] = osp.join(experiments_root, 'visualization')

        # change some options for debug mode
        if 'debug' in opt['name']:
            if 'val' in opt:
                opt['val']['val_freq'] = 8
            opt['logger']['print_freq'] = 1
            opt['logger']['save_checkpoint_freq'] = 8
    else:  # test
        results_root = osp.join(root_path, 'results', opt['name'])
        opt['path']['results_root'] = results_root
        opt['path']['log'] = results_root
        opt['path']['visualization'] = osp.join(results_root, 'visualization')

    return opt, args

In [26]:
from basicsr.utils.registry import MODEL_REGISTRY
# Inspect which model classes are currently registered with BasicSR.
# NOTE(review): the recorded output lists no HAT entry; presumably HATModel is
# only registered once `models.hat_model` has been imported -- confirm.
MODEL_REGISTRY.__dict__

{'_name': 'model',
 '_obj_map': {'SRModel': basicsr.models.sr_model.SRModel,
  'VideoBaseModel': basicsr.models.video_base_model.VideoBaseModel,
  'EDVRModel': basicsr.models.edvr_model.EDVRModel,
  'SwinIRModel': basicsr.models.swinir_model.SwinIRModel,
  'HiFaceGANModel': basicsr.models.hifacegan_model.HiFaceGANModel,
  'SRGANModel': basicsr.models.srgan_model.SRGANModel,
  'StyleGAN2Model': basicsr.models.stylegan2_model.StyleGAN2Model,
  'VideoGANModel': basicsr.models.video_gan_model.VideoGANModel,
  'VideoRecurrentModel': basicsr.models.video_recurrent_model.VideoRecurrentModel,
  'ESRGANModel': basicsr.models.esrgan_model.ESRGANModel,
  'VideoRecurrentGANModel': basicsr.models.video_recurrent_gan_model.VideoRecurrentGANModel}}

In [28]:
# Re-check the working directory; the following cells use paths relative to it.
os.getcwd()

'/workspace/competition/2022.09_SR/HAT_official/hat'

In [29]:
from basicsr.models import build_model
from basicsr.utils.options import parse_options
from models import hat_model
import yaml
from collections import OrderedDict

def ordered_yaml():
    """Build a yaml Loader/Dumper pair that handles mappings as OrderedDict.

    The Dumper gets a representer for OrderedDict and the Loader gets a
    constructor for the default mapping tag that builds an OrderedDict.

    Returns:
        tuple: yaml Loader and Dumper.
    """
    # Use the C implementations when PyYAML provides them, else pure Python.
    try:
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    Dumper.add_representer(
        OrderedDict,
        lambda dumper, data: dumper.represent_dict(data.items()),
    )
    Loader.add_constructor(
        mapping_tag,
        lambda loader, node: OrderedDict(loader.construct_pairs(node)),
    )
    return Loader, Dumper

def yaml_load(f):
    """Load YAML from a file path or directly from a string.

    Args:
        f (str): Path of an existing file, or the YAML text itself.

    Returns:
        dict: Loaded dict.
    """
    if not os.path.isfile(f):
        # Not an existing file: parse the argument itself as YAML text.
        return yaml.load(f, Loader=ordered_yaml()[0])

    with open(f, 'r') as stream:
        loader_cls, _ = ordered_yaml()
        return yaml.load(stream, Loader=loader_cls)
    
# Despite its name, `root_path` holds the path of the training option YAML
# (relative to the current working directory), not a project root directory.
root_path = '../options/train/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain.yml'
# Load the options straight from the YAML file instead of going through
# parse_options, which takes its option path from command-line arguments.
opt = yaml_load(root_path)
# opt, args = parse_options(root_path, is_train=True)

In [33]:
# parse_options was bypassed, so fill in by hand the runtime fields it would
# have set: training mode on, distributed mode off.
opt.update(is_train=True, dist=False)

# Point the generator at the ImageNet-pretrained HAT-L checkpoint.
pretrained_g = '../experiments/pretrained_models/HAT-L_SRx4_ImageNet-pretrain.pth'
opt['path']['pretrain_network_g'] = pretrained_g

model = build_model(opt)

In [34]:
# Check whether the object returned by build_model(opt) exposes `feed_data`
# (the recorded output below is True).
'feed_data' in dir(model)

True

In [15]:
# NOTE(review): the AttributeError recorded below names a 'HAT' object, so this
# output appears to come from a run in which `model` was still the bare HAT
# network rather than the build_model(opt) result -- re-run to refresh it.
# NOTE(review): the call passes no arguments; presumably feed_data expects a
# batch of data -- confirm against the model class before relying on this cell.
model.feed_data()

AttributeError: 'HAT' object has no attribute 'feed_data'