In [1]:
import argparse
import yaml
from models import dist_util, logger
from models.image_datasets import load_data
from models.resample import create_named_schedule_sampler
from models.script_util import (
    model_and_diffusion_defaults,
    encoder_defaults,
    create_model_and_diffusion,
    create_encoder,
    select_config,
)
from models.train_util import TrainLoop

In [2]:
# Load the experiment configuration (model, diffusion, data and checkpoint
# settings) from YAML into the `config` dict used by every later cell.
with open('configs/cifar10.yaml', 'r') as f:
    # Deliberately no try/except: a yaml.YAMLError must stop the notebook here.
    # Printing it and carrying on would leave `config` unbound and surface
    # later as an unrelated NameError.
    config = yaml.safe_load(f)

# An empty or scalar-only YAML file parses to None / a scalar; fail early with
# a clear message instead of a confusing error on the first config[...] lookup.
if not isinstance(config, dict):
    raise ValueError(
        f"Expected a mapping at the top level of the config, got {type(config).__name__}"
    )
print(config)

{'image_size': 32, 'num_channels': 128, 'num_res_blocks': 3, 'learn_sigma': True, 'dropout': 0.3, 'diffusion_steps': 4000, 'noise_schedule': 'cosine', 'lr': 0.0001, 'dataset': 'cifar10', 'data_dir': '../data/cifar10', 'log_dir': './log', 'schedule_sampler': 'uniform', 'weight_decay': 0.0, 'lr_anneal_steps': 10000, 'batch_size': 64, 'microbatch': -1, 'ema_rate': 0.9999, 'log_interval': 10, 'save_interval': 1000, 'resume_checkpoint': '', 'use_fp16': False, 'fp16_scale_growth': '1e-3', 'num_heads': 4, 'num_heads_upsample': -1, 'attention_resolutions': '16,8', 'sigma_small': False, 'class_cond': True, 'num_classes': 10, 'use_label_embed': False, 'latent_dim': 512, 'use_latent': True, 'out_channels': 512, 'timestep_respacing': '', 'use_kl': False, 'predict_xstart': False, 'rescale_timesteps': True, 'rescale_learned_sigmas': True, 'use_checkpoint': False, 'use_scale_shift_norm': True, 'use_time_embed': False, 'model_path': './log/model010000.pt', 'encoder_path': './log/encoder010000.pt', 'nu

In [3]:
# Run the project's distributed setup helper before any device is requested.
# NOTE(review): presumably this initialises the torch.distributed process
# group — confirm against models/dist_util.py.
dist_util.setup_dist()

In [4]:
# Build the denoising model together with its diffusion process. The keyword
# arguments come from `select_config`, called with the loaded config and the
# key set of the model/diffusion defaults.
# NOTE(review): presumably select_config keeps only those keys — confirm.
model_kwargs = select_config(config, model_and_diffusion_defaults().keys())
model, diffusion = create_model_and_diffusion(**model_kwargs)
model.to(dist_util.dev())

# Build the encoder the same way, from the encoder-specific default keys.
encoder_kwargs = select_config(config, encoder_defaults().keys())
encoder = create_encoder(**encoder_kwargs)
# Last expression of the cell: its return value is what the notebook displays
# (the recorded output is the EncoderModel module tree).
encoder.to(dist_util.dev())

EncoderModel(
  (input_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (emb_layers): Identity()
        (out_layers): Sequential(
          (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0.3, inplace=False)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (skip_connection): Identity()
      )
    )
    (2): TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), 

In [5]:
# Timestep sampler, selected by the name stored in the config and bound to the
# diffusion object created earlier.
sampler_name = config['schedule_sampler']
schedule_sampler = create_named_schedule_sampler(sampler_name, diffusion)

# Forward the data-related config entries to load_data unchanged; each config
# key is passed as the keyword argument of the same name.
loader_keys = ('data_dir', 'batch_size', 'image_size', 'class_cond')
loader_kwargs = {key: config[key] for key in loader_keys}
data = load_data(**loader_kwargs)

In [6]:
# Pull one (images, conditioning) pair from the data iterator and move the
# tensors used below onto the working device.
batch, cond = next(data)
device = dist_util.dev()
batch = batch.to(device)
# Only the 'y' entry of the conditioning dict is moved; other entries, if any,
# are left untouched.
cond['y'] = cond['y'].to(device)

In [7]:
# Run the encoder on the image batch, passing the labels as `y`; the result
# `z` is later fed to the model as `latent`.
z = encoder(batch, y=cond['y'])

In [8]:
# Sanity check: shape of the encoder output (recorded run: [64, 512]).
z.shape

torch.Size([64, 512])

In [9]:
# Sanity check: shape of the input image batch (recorded run: [64, 3, 32, 32]).
batch.shape

torch.Size([64, 3, 32, 32])

In [10]:
# NOTE(review): these imports belong in the first (imports) cell; they are kept
# here so this cell still runs on its own and `np` stays available to any
# later cell that uses it.
import torch
import numpy as np

# One integer timestep per sample: 0, 1, ..., batch_size - 1.
# The length is taken from the actual batch instead of a hard-coded 64, so it
# cannot drift from config['batch_size'] or from a shorter batch.
# torch.arange on integer arguments always yields int64, whereas np.arange's
# default integer dtype is platform-dependent.
timesteps = torch.arange(batch.shape[0])

In [11]:
# Sanity check: dtype of the timestep tensor (recorded run: torch.int64).
timesteps.dtype

torch.int64

In [12]:
# Smoke-test forward pass of the model on the batch, conditioned on the labels
# and on the encoder latent `z`.
# The timesteps are moved with dist_util.dev() — the same device the model,
# batch and labels were placed on — instead of a hard-coded .cuda(), which
# fails on CPU-only machines and can pick a different GPU than the model's.
# NOTE(review): the recorded output is all zeros, and this cell was executed
# (In [12]) before the checkpoint was loaded (In [15]); presumably the
# untrained output layer is zero-initialised — confirm.
model(batch, timesteps.to(dist_util.dev()), y=cond['y'], latent=z)

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          

In [15]:
# Read the model checkpoint named in the config (mapped to CPU on load), then
# copy the weights into the already-constructed model.
model_state = dist_util.load_state_dict(config['model_path'], map_location="cpu")
# Last expression: the notebook displays load_state_dict's return value, which
# reports whether the checkpoint keys matched the model's.
model.load_state_dict(model_state)

<All keys matched successfully>

In [20]:
# Collect the names of all encoder parameters; named_parameters() yields
# (name, parameter) pairs and only the name is kept.
params = [pair[0] for pair in encoder.named_parameters()]
# Number of named parameters, for comparison with the checkpoint's entry count.
len(params)

168

In [18]:
# Load the encoder checkpoint (mapped to CPU) and count its entries, to compare
# against the encoder's own parameter count.
encoder_checkpoint = dist_util.load_state_dict(config['encoder_path'], map_location="cpu")
len(encoder_checkpoint)

168

In [13]:
# Read the encoder checkpoint named in the config (mapped to CPU on load), then
# copy the weights into the already-constructed encoder.
encoder_state = dist_util.load_state_dict(config['encoder_path'], map_location="cpu")
# Last expression: the notebook displays load_state_dict's return value, which
# reports whether the checkpoint keys matched the encoder's.
encoder.load_state_dict(encoder_state)

<All keys matched successfully>

In [14]:
# Re-run the encoder on the same batch and display the latents; by execution
# order (In [14] follows In [13]) this is after the encoder checkpoint load.
# NOTE(review): unlike the earlier call, no `y=cond['y']` is passed here —
# confirm whether the encoder ignores labels or this omission is unintended.
encoder(batch)

tensor([[-0.5540, -0.2016,  0.0689,  ...,  0.6870,  0.3185,  0.3228],
        [-0.6892, -0.2513,  0.2256,  ...,  0.5964,  0.3277,  0.4934],
        [-0.6982, -0.3601,  0.0855,  ...,  0.4440,  0.4496,  0.5317],
        ...,
        [-0.5263, -0.2650,  0.1995,  ...,  0.6355,  0.3570,  0.3893],
        [-0.5152, -0.3647,  0.0553,  ...,  0.5611,  0.4886,  0.3204],
        [-0.4804, -0.0814, -0.0747,  ...,  0.7765,  0.2212,  0.2401]],
       device='cuda:0', grad_fn=<ReshapeAliasBackward0>)

: 