In [26]:
# evaluate PSNR, SSIM, and LPIPS on CIFAR10
import argparse
import itertools

import torch
import yaml

from models import dist_util, logger
from models.image_datasets import load_data
from models.resample import create_named_schedule_sampler
from models.script_util import create_reverse_process, create_sampler
from models.script_util import (
    model_and_diffusion_defaults,
    encoder_defaults,
    create_model_and_diffusion,
    create_encoder,
    select_config,
    create_diffusion
)
from models.train_util import TrainLoop

In [27]:
# Load the experiment configuration from YAML.
# A malformed file must stop the notebook here. Previously the YAMLError was only
# printed, leaving `config` undefined (NameError on the next line) or, on a
# re-run, silently reusing a stale `config` from an earlier execution.
with open('configs/cifar10.yaml', 'r') as f:
    config = yaml.safe_load(f)
print(config)

{'num_classes': 10, 'dataset': 'cifar10', 'data_dir': '../data/cifar_train', 'log_dir': './jscc_log', 'model_path': './jscc_log/model100000.pt', 'encoder_path': './log/encoder010000.pt', 'jscc_encoder_path': './jscc_log/encoder100000.pt', 'dropout': 0.3, 'lr': 0.0001, 'weight_decay': 0.0, 'lr_anneal_steps': 100000, 'batch_size': 64, 'microbatch': -1, 'ema_rate': 0.9999, 'log_interval': 10, 'save_interval': 50000, 'resume_checkpoint': '', 'use_fp16': False, 'fp16_scale_growth': '1e-3', 'image_size': 32, 'num_channels': 128, 'num_res_blocks': 3, 'num_heads': 4, 'num_heads_upsample': -1, 'attention_resolutions': '16,8', 'learn_sigma': True, 'diffusion_steps': 4000, 'noise_schedule': 'cosine', 'schedule_sampler': 'uniform', 'sigma_small': False, 'class_cond': True, 'timestep_respacing': '', 'use_kl': False, 'predict_xstart': False, 'rescale_timesteps': True, 'rescale_learned_sigmas': True, 'use_checkpoint': False, 'use_scale_shift_norm': True, 'num_samples': 100, 'clip_denoised': True, 'us

In [28]:
# Set up (possibly distributed) execution, then load the pretrained diffusion
# model and the latent encoder from the checkpoint paths given in `config`.
dist_util.setup_dist()

model, diffusion = create_model_and_diffusion(
    **select_config(config, model_and_diffusion_defaults().keys())
)
model.to(dist_util.dev())
model.load_state_dict(
    dist_util.load_state_dict(config['model_path'], map_location="cpu")
)
# This notebook only evaluates. Modules start in training mode and
# load_state_dict does not change that, so switch off dropout (the printed
# config has dropout=0.3) and other train-only behaviour.
model.eval()

encoder_type = config['encoder_type']
if encoder_type == 'unet':
    encoder = create_encoder(**select_config(config, encoder_defaults().keys()))
    encoder_ckpt = config['encoder_path']
elif encoder_type == 'jscc':
    from models.autoencoder import JSCC_encoder  # only needed for this encoder type
    encoder = JSCC_encoder(hidden_dims=config['hidden_dims'])
    encoder_ckpt = config['jscc_encoder_path']
else:
    # Without this branch `encoder` would stay undefined and later cells would
    # fail with a NameError far away from the real cause.
    raise ValueError(f"unsupported encoder_type: {encoder_type!r} (expected 'unet' or 'jscc')")
encoder.to(dist_util.dev())
encoder.load_state_dict(
    dist_util.load_state_dict(encoder_ckpt, map_location="cpu")
)
encoder.eval()

schedule_sampler = create_named_schedule_sampler(config['schedule_sampler'], diffusion)

In [29]:
# Build the image data iterator. Every loader setting is taken from `config`.
# NOTE(review): the printed config points data_dir at '../data/cifar_train';
# confirm that evaluating on the training split is intended.
loader_settings = {
    key: config[key]
    for key in ('data_dir', 'batch_size', 'image_size', 'class_cond')
}
data = load_data(**loader_settings)

In [30]:
# Step counts handed to the two procedure factories as `T`. Their precise
# meaning is defined in models.script_util. reverse_func's output 'sample'
# is later fed to sample_fn as its starting noise.
REVERSE_T = 250
SAMPLER_T = 20

reverse_func = create_reverse_process(config, T=REVERSE_T)
sample_fn = create_sampler(config, T=SAMPLER_T)

In [31]:
from piq import FID, ssim, psnr, MSID, LPIPS
# NOTE(review): MSID and LPIPS are imported but not instantiated in this cell,
# although the notebook header promises LPIPS. Also, piq's FID compares
# precomputed feature tensors, not raw images; confirm how it is fed.

# Metric handles: ssim/psnr are piq's functional forms, FID is a module instance.
metric_fid, metric_ssim, metric_psnr = FID(), ssim, psnr

# Running totals, accumulated per batch by the evaluation loop.
res_fid = res_ssim = res_psnr = 0

In [32]:
# Smoke test: pull one batch and encode it before running the full evaluation.
batch, cond = next(data)
batch = batch.to(dist_util.dev())
cond['y'] = cond['y'].to(dist_util.dev())

# Inference only. Without no_grad, `z` carries the encoder's autograd graph, and
# storing it in the module-level `model_kwargs` keeps that graph (and its
# activations) alive for the rest of the session.
with torch.no_grad():
    z = encoder(batch)
model_kwargs = {"y": cond['y'], "latent": z}

In [33]:
# Run the two-stage pipeline on the smoke-test batch: reverse_func produces the
# final noising output, whose 'sample' seeds sample_fn as its starting noise.
# Inference only, so no autograd graphs are built.
with torch.no_grad():
    out = reverse_func(model, batch, model_kwargs=model_kwargs)
    # Take the shape from the batch itself rather than hard-coding
    # (config['batch_size'], 3, ...), so it always matches `noise`.
    sample = sample_fn(
        model,
        tuple(batch.shape),
        noise=out['sample'],
        clip_denoised=config['clip_denoised'],
        model_kwargs=model_kwargs,
    )

In [34]:
# Evaluate reconstruction quality (PSNR, SSIM, LPIPS) over a bounded number of batches.
#
# `data` is an iterator (the smoke test pulls from it with next()). A bare `for`
# over it only terminates if the loader is finite; the previous run never
# finished and had to be interrupted, so the batch count is now explicit.
NUM_EVAL_BATCHES = -(-config['num_samples'] // config['batch_size'])  # ceil(num_samples / batch_size)

# LPIPS is the perceptual metric the notebook header promises.
metric_lpips = LPIPS()

# Reset the running totals so re-running this cell does not add to stale values.
res_ssim = 0
res_psnr = 0
res_lpips = 0
# FID is a distribution-level metric over Inception features of the full sample
# set. Per-batch FID on flattened pixels (the previous approach) is not
# meaningful, and per-batch FIDs cannot be summed, so it is marked NaN rather
# than reported as a misleading number.
# TODO: compute FID from Inception features of all generated vs. real images.
res_fid = float('nan')

# Modules start in training mode and load_state_dict does not change that;
# evaluation must run with dropout etc. disabled. eval() is idempotent.
model.eval()
encoder.eval()

n_batches = 0
model_kwargs = {}
with torch.no_grad():  # inference only: avoid building autograd graphs
    for i, (batch, cond) in enumerate(itertools.islice(data, NUM_EVAL_BATCHES)):
        batch = batch.to(dist_util.dev())
        cond['y'] = cond['y'].to(dist_util.dev())

        model_kwargs["y"] = cond['y']
        z = encoder(batch)
        model_kwargs["latent"] = z
        # final noising output
        out = reverse_func(model, batch, model_kwargs=model_kwargs)
        # generated samples; shape taken from the batch itself so it always
        # matches `noise`, even for a short batch
        sample = sample_fn(
            model,
            tuple(batch.shape),
            noise=out['sample'],
            clip_denoised=config['clip_denoised'],
            model_kwargs=model_kwargs,
        )
        # Map to [0, 1] (assumes images in [-1, 1], as the (x+1)/2 mapping
        # implies) and clamp so tiny float overshoot stays inside the
        # data_range=1.0 that piq's metrics expect by default.
        sample_01 = ((sample + 1) / 2).clamp(0, 1)
        batch_01 = ((batch + 1) / 2).clamp(0, 1)
        res_ssim += metric_ssim(sample_01, batch_01).item()
        res_psnr += metric_psnr(sample_01, batch_01).item()
        res_lpips += metric_lpips(sample_01, batch_01).item()
        n_batches += 1
        if i % 10 == 0:
            print(f'batch {i} done')

# Each per-batch value is already a batch mean, so report the mean over batches.
if n_batches:
    print(f'{n_batches} batches: PSNR {res_psnr / n_batches:.3f} | '
          f'SSIM {res_ssim / n_batches:.4f} | LPIPS {res_lpips / n_batches:.4f}')
else:
    print('no batches evaluated')

KeyboardInterrupt: 