In [1]:
import torch
import data as Data
import model as Model
import argparse
import logging
import core.logger as Logger
import core.metrics as Metrics
from tensorboardX import SummaryWriter
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
# Validation data: paired low-/high-resolution femur images (64 -> 256) read
# from `dataroot`, wrapped in the project's dataset/dataloader factories.
dataset_opt = {
    # NOTE(review): "CelebaHQ" looks like a leftover label from the template
    # config (this notebook loads femur PNGs); confirm nothing downstream
    # depends on it before renaming.
    "name": "CelebaHQ",
    "mode": "LRHR",
    "dataroot": "data/datasets/femurPNGs/femur_LH_64_256",
    "datatype": "img",  # lmdb or img, path of img files
    "l_resolution": 64,
    "r_resolution": 256,
    "data_len": -1,  # data length in validation (-1 presumably = whole set; confirm in Data)
    "chtype": "L",
    # Keep the validation order deterministic: later cells pick samples by
    # their enumeration index in val_loader, which is only meaningful when
    # the loader does not shuffle.
    "use_shuffle": False,
}

val_set = Data.create_dataset(dataset_opt, 'val')
val_loader = Data.create_dataloader(val_set, dataset_opt, 'val')

In [3]:
# Model / diffusion configuration used to build the network and load a
# trained checkpoint for inference.
model_opt = {
    "gpu_ids": [0],
    "phase": "val",  # also used to pick beta_schedule[phase] when switching schedules
    "distributed": False,
    "model": {
        "which_model_G": "sr3",
        "finetune_norm": False,
        "unet": {
            # NOTE(review): in_channel=2 presumably is the conditioning image
            # concatenated with the 1-channel noisy sample ("conditional": True
            # below) -- confirm against the UNet definition.
            "in_channel": 2,
            "out_channel": 1,
            "inner_channel": 64,
            "channel_multiplier": [
                1,
                2,
                4,
                8,
                8
            ],
            "attn_res": [
                32
            ],
            "res_blocks": 2,
            "dropout": 0.1
        },
        "beta_schedule": {
            # NOTE(review): linear_start/linear_end are presumably only read by
            # linear-type schedules; kept for parity with the training config.
            "train": {
                "schedule": "cosine",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            },
            "val": {
                "schedule": "cosine",
                "n_timestep": 2000,
                "linear_start": 1e-6,
                "linear_end": 1e-2
            }
        },
        "diffusion": {
            # Tied to the dataset's HR resolution instead of a second literal,
            # so changing the data resolution cannot leave this stale.
            "image_size": dataset_opt["r_resolution"],
            "channels": 1,
            "conditional": True
        }
    },
    "train": {
        "n_iter": 1000000,
        "val_freq": 1e4,
        "save_checkpoint_freq": 1e4,
        "print_freq": 200,
        "optimizer": {
            "type": "adam",
            "lr": 1e-4
        },
        "ema_scheduler": {
            "step_start_ema": 5000,
            "update_ema_every": 1,
            "ema_decay": 0.9999
        }
    },
    "path": {
        "log": "logs",
        "tb_logger": "tb_logger",
        "results": "results",
        "checkpoint": "checkpoint",
        "resume_state": "experiments/sr_ffhq_211013_112422/checkpoint/I380000_E130"
        # Earlier checkpoints:
        # "resume_state": "experiments/sr_ffhq_211010_120508/checkpoint/I90000_E31"
        # "resume_state": "experiments/sr_ffhq_211010_120508/checkpoint/I160000_E55"
    }
}

diffusion = Model.create_model(model_opt)

In [4]:
# Switch the model to the beta schedule configured for the current phase
# (model_opt['phase'] is "val"), tagging it with that same phase name.
diffusion.set_new_noise_schedule(
    model_opt['model']['beta_schedule'][model_opt['phase']],
    schedule_phase=model_opt['phase'],
)

In [5]:
def test_image(val_data, model, result_path, scalefactor):
    diffusion=model
    diffusion.feed_data(val_data)
    diffusion.test(continous=False)
    visuals = diffusion.get_current_visuals()
    sr_img = Metrics.tensor2img(visuals['SR'])  # uint8
    hr_img = Metrics.tensor2img(visuals['HR'])  # uint8
    lr_img = Metrics.tensor2img(visuals['LR'])  # uint8
    fake_img = Metrics.tensor2img(visuals['INF'])  # uint8
    lr_img = np.repeat(np.repeat(lr_img, scalefactor, axis=0), scalefactor, axis=1)
    out_img = np.concatenate((hr_img,lr_img,fake_img,sr_img),axis=1)
    Metrics.save_img(out_img, '{}.png'.format(result_path))

#     # generation
#     Metrics.save_img(
#         hr_img, '{}_hr.png'.format(result_path))
#     Metrics.save_img(
#         sr_img, '{}_sr.png'.format(result_path))
#     Metrics.save_img(
#         lr_img, '{}_lr.png'.format(result_path))
#     Metrics.save_img(
#         fake_img, '{}_inf.png'.format(result_path))

In [6]:
def test_image_continuous(val_data, model, result_path, scalefactor):
    diffusion=model
    diffusion.feed_data(val_data)
    diffusion.test(continous=True)
    visuals = diffusion.get_current_visuals()
    sr_img = Metrics.tensor2img(visuals['SR'])  # uint8
    srflat = np.concatenate([sr_img[:,:,i] for i in range(len(sr_img[0,0,:]))],axis=1)
    Metrics.save_img(srflat, '{}.png'.format(result_path))

In [14]:
# Save the continuous-mode SR output for selected validation samples.
# Earlier runs used idx = [1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000].
idx = [500]
# Integer HR/LR ratio: `//` keeps it an int for np.repeat, and the assert
# rejects resolutions that are not exact multiples instead of truncating.
scalefactor = dataset_opt['r_resolution'] // dataset_opt['l_resolution']
assert scalefactor * dataset_opt['l_resolution'] == dataset_opt['r_resolution'], \
    "r_resolution must be an integer multiple of l_resolution"

# Label outputs with the checkpoint actually loaded (e.g. 'E130' from
# '.../I380000_E130') so file names cannot drift from resume_state.
resume_state = model_opt['path'].get('resume_state')
ckpt_tag = os.path.basename(resume_state).split('_')[-1] if resume_state else 'no_ckpt'
os.makedirs("misc", exist_ok=True)  # ensure the output directory exists

wanted = set(idx)
processed = set()
if wanted:
    last = max(wanted)
    for i, val_data in enumerate(val_loader):
        if i in wanted:
            test_image_continuous(
                val_data, diffusion,
                "misc/continuousim_{}_{}_r4".format(ckpt_tag, i), scalefactor)
            processed.add(i)
        if i >= last:
            break  # stop right after the last requested sample; don't load more
missing = sorted(wanted - processed)
if missing:
    print("Requested indices not reached in val_loader:", missing)

sampling loop time step: 100%|██████████| 2000/2000 [02:38<00:00, 12.61it/s]
sampling loop time step: 100%|██████████| 2000/2000 [02:38<00:00, 12.60it/s]
sampling loop time step: 100%|██████████| 2000/2000 [02:38<00:00, 12.61it/s]
sampling loop time step: 100%|██████████| 2000/2000 [02:38<00:00, 12.61it/s]
sampling loop time step: 100%|██████████| 2000/2000 [02:38<00:00, 12.61it/s]
sampling loop time step: 100%|██████████| 2000/2000 [02:38<00:00, 12.62it/s]
sampling loop time step: 100%|██████████| 2000/2000 [02:38<00:00, 12.60it/s]
sampling loop time step: 100%|██████████| 2000/2000 [02:38<00:00, 12.60it/s]
sampling loop time step: 100%|██████████| 2000/2000 [02:38<00:00, 12.61it/s]


In [9]:
# Save HR | LR | INF | SR comparison strips for selected validation samples.
idx = [500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000]
# Integer HR/LR ratio: `//` keeps it an int for np.repeat, and the assert
# rejects resolutions that are not exact multiples instead of truncating.
scalefactor = dataset_opt['r_resolution'] // dataset_opt['l_resolution']
assert scalefactor * dataset_opt['l_resolution'] == dataset_opt['r_resolution'], \
    "r_resolution must be an integer multiple of l_resolution"

# Label outputs with the checkpoint actually loaded (e.g. 'E130' from
# '.../I380000_E130') so file names cannot drift from resume_state.
resume_state = model_opt['path'].get('resume_state')
ckpt_tag = os.path.basename(resume_state).split('_')[-1] if resume_state else 'no_ckpt'
os.makedirs("misc", exist_ok=True)  # ensure the output directory exists

wanted = set(idx)
processed = set()
if wanted:
    last = max(wanted)
    for i, val_data in enumerate(val_loader):
        if i in wanted:
            test_image(
                val_data, diffusion,
                "misc/testim_{}_{}_cosinesched_r1".format(ckpt_tag, i), scalefactor)
            processed.add(i)
        if i >= last:
            break  # stop right after the last requested sample; don't load more
missing = sorted(wanted - processed)
if missing:
    print("Requested indices not reached in val_loader:", missing)
    

sampling loop time step: 100%|██████████| 2000/2000 [01:08<00:00, 29.10it/s]
sampling loop time step: 100%|██████████| 2000/2000 [01:08<00:00, 29.00it/s]
sampling loop time step: 100%|██████████| 2000/2000 [01:11<00:00, 28.14it/s]
sampling loop time step: 100%|██████████| 2000/2000 [01:11<00:00, 27.90it/s]
sampling loop time step: 100%|██████████| 2000/2000 [01:12<00:00, 27.71it/s]
sampling loop time step:   4%|▍         | 86/2000 [00:02<01:06, 28.88it/s]


KeyboardInterrupt: 