In [1]:
# test cdm encoder
import torch
import torch.nn as nn
from models.unet import *
from models.nn import (
    SiLU,
    conv_nd,
    linear,
    avg_pool_nd,
    zero_module,
    normalization,
    timestep_embedding,
    checkpoint,
)

In [2]:
class EncBlock(nn.Module):
    """One encoder stage: a ``ResBlock`` followed by a ``Downsample`` layer.

    ``ResBlock``, ``Downsample`` and ``TimestepEmbedSequential`` are not defined
    in this notebook (presumably they come from the ``models.unet`` star import).

    Args:
        in_channels: channel count passed to ``ResBlock`` as its input width.
        out_channels: channel count requested from ``ResBlock`` and used as the
            width of the ``Downsample`` layer.
        emb_channels: embedding width handed to ``ResBlock`` (may be ``None``).
        dropout: dropout value handed to ``ResBlock``.
        conv_resample: second positional argument of ``Downsample``.
        use_scale_shift_norm: forwarded to ``ResBlock``.
        dims: forwarded to both ``ResBlock`` and ``Downsample``.
        use_checkpoint: forwarded to ``ResBlock``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        emb_channels=None,
        dropout=0.1,
        conv_resample=True,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
    ):
        super().__init__()
        # Keyword options of the residual block, gathered in one place.
        res_options = dict(
            out_channels=out_channels,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_scale_shift_norm=use_scale_shift_norm,
        )
        self.resblock = ResBlock(in_channels, emb_channels, dropout, **res_options)

        # The downsampling layer is wrapped so it can be called as (x, emb).
        downsample = Downsample(out_channels, conv_resample, dims=dims)
        self.conv = TimestepEmbedSequential(downsample)

    def forward(self, x, emb=None):
        """Run the residual block, then the downsampling step, on ``x``.

        NOTE(review): whether a non-None ``emb`` ever arrives here depends on how
        the enclosing ``TimestepEmbedSequential`` dispatches to its children --
        confirm against ``models.unet``.
        """
        features = self.resblock(x, emb)
        return self.conv(features, emb)

In [3]:
import yaml

# Read the experiment configuration for CIFAR-10 into the global `config` dict.
with open('configs/cifar10.yaml', 'r') as f:
    try:
        config = yaml.safe_load(f)
    except yaml.YAMLError as exc:
        # Show the parser message, then re-raise. Swallowing the error would let
        # the next line fail with a NameError (fresh kernel) or silently print a
        # stale `config` left over from an earlier run.
        print(exc)
        raise
print(config)

{'num_classes': 10, 'dataset': 'cifar10', 'data_dir': '../data/cifar_train', 'log_dir': './log/cdm_log', 'model_path': './log/jscc/model100000.pt', 'encoder_path': './log/unet/encoder100000.pt', 'jscc_encoder_path': './log/jscc/encoder100000.pt', 'dropout': 0.3, 'lr': 0.0001, 'weight_decay': 0.0, 'lr_anneal_steps': 100000, 'batch_size': 64, 'microbatch': -1, 'ema_rate': 0.9999, 'log_interval': 100, 'save_interval': 50000, 'resume_checkpoint': '', 'use_fp16': False, 'fp16_scale_growth': '1e-3', 'image_size': 32, 'num_channels': 128, 'num_res_blocks': 3, 'num_heads': 4, 'num_heads_upsample': -1, 'attention_resolutions': '16,8', 'learn_sigma': True, 'diffusion_steps': 4000, 'noise_schedule': 'cosine', 'schedule_sampler': 'uniform', 'sigma_small': False, 'class_cond': True, 'timestep_respacing': '', 'use_kl': False, 'predict_xstart': False, 'rescale_timesteps': True, 'rescale_learned_sigmas': True, 'use_checkpoint': False, 'use_scale_shift_norm': True, 'num_samples': 100, 'clip_denoised': 

In [4]:
# Display the top-level keys of the loaded config as this cell's output.
config.keys()

dict_keys(['num_classes', 'dataset', 'data_dir', 'log_dir', 'model_path', 'encoder_path', 'jscc_encoder_path', 'dropout', 'lr', 'weight_decay', 'lr_anneal_steps', 'batch_size', 'microbatch', 'ema_rate', 'log_interval', 'save_interval', 'resume_checkpoint', 'use_fp16', 'fp16_scale_growth', 'image_size', 'num_channels', 'num_res_blocks', 'num_heads', 'num_heads_upsample', 'attention_resolutions', 'learn_sigma', 'diffusion_steps', 'noise_schedule', 'schedule_sampler', 'sigma_small', 'class_cond', 'timestep_respacing', 'use_kl', 'predict_xstart', 'rescale_timesteps', 'rescale_learned_sigmas', 'use_checkpoint', 'use_scale_shift_norm', 'num_samples', 'clip_denoised', 'use_ddim', 'encoder_type', 'use_label_embed', 'use_time_embed', 'use_latent', 'latent_dim', 'out_channels', 'hidden_dims', 'c_out', 'noise'])

In [5]:
# load the data
from models.image_datasets import load_data

# Forward the dataset-related config entries to load_data under the same
# keyword names; `data` is later consumed with next(data).
data = load_data(
    **{
        key: config[key]
        for key in ('data_dir', 'batch_size', 'image_size', 'class_cond')
    }
)

In [6]:
class encoder(nn.Module):
    """Convolutional encoder that maps an input batch to one flat vector per sample.

    Layout:
      * a 3x3 stem convolution from ``in_channels`` to ``model_channels``;
      * one ``EncBlock`` per entry of ``channel_mult``, with output width
        ``mult * model_channels``;
      * a head of normalization -> SiLU -> adaptive average pooling to size 1
        -> 1x1 convolution to ``out_channels`` -> flatten.

    An optional timestep embedding and/or label embedding (both of width
    ``model_channels * 4``) is computed in ``forward`` and handed to every block.

    Args:
        in_channels: channels of the input tensor.
        model_channels: base channel width.
        out_channels: width of the returned feature vector.
        dropout: dropout value forwarded to every ``EncBlock``.
        channel_mult: per-stage multipliers of ``model_channels``.
        conv_resample: forwarded to every ``EncBlock``.
        dims: dimensionality of the convolutions / pooling (1, 2 or 3).
        num_classes: number of classes; required iff ``use_label_embed``.
        use_checkpoint: forwarded to every ``EncBlock``.
        use_scale_shift_norm: forwarded to every ``EncBlock``.
        use_time_embed: build and use a timestep-embedding MLP.
        use_label_embed: build and use a class-label embedding table.

    Raises:
        ValueError: if ``dims`` is not 1, 2 or 3.
        AssertionError: if ``use_label_embed`` and ``num_classes`` disagree.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_scale_shift_norm=False,
        use_time_embed=False,
        use_label_embed=False
    ):
        super().__init__()

        # Pick the pooling layer matching `dims` instead of assuming 2-D input.
        pool_classes = {
            1: nn.AdaptiveAvgPool1d,
            2: nn.AdaptiveAvgPool2d,
            3: nn.AdaptiveAvgPool3d,
        }
        if dims not in pool_classes:
            raise ValueError(f"unsupported dimensions: {dims}")

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.dims = dims
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
        self.use_time_embed = use_time_embed
        self.use_label_embed = use_label_embed

        if use_time_embed:
            time_embed_dim = model_channels * 4
            self.time_embed = nn.Sequential(
                linear(model_channels, time_embed_dim),
                SiLU(),
                linear(time_embed_dim, time_embed_dim),
            )
        else:
            time_embed_dim = None

        assert (use_label_embed) == (
            self.num_classes is not None
        ), "must specify num_classes if and only if the model is class-conditional"

        if self.use_label_embed:
            label_embed_dim = model_channels * 4
            self.label_emb = nn.Embedding(num_classes, label_embed_dim)
        else:
            label_embed_dim = None

        # Both embeddings share the same width, so either one fixes the width
        # the blocks must accept; None when no embedding is used.
        embed_dim = time_embed_dim or label_embed_dim

        ch = model_channels
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )

        # NOTE(review): the embedding only reaches an EncBlock if
        # TimestepEmbedSequential passes `emb` to it -- confirm against
        # models.unet before enabling time/label embeddings.
        for mult in channel_mult:
            stage_channels = mult * model_channels
            self.input_blocks.append(
                TimestepEmbedSequential(
                    EncBlock(
                        in_channels=ch,
                        out_channels=stage_channels,
                        emb_channels=embed_dim,
                        dropout=dropout,
                        # Honour the constructor arguments instead of the
                        # previously hard-coded True / False / 2 / False.
                        conv_resample=conv_resample,
                        use_scale_shift_norm=use_scale_shift_norm,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                    )
                )
            )
            ch = stage_channels

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            pool_classes[dims]((1,) * dims),
            conv_nd(dims, ch, out_channels, 1),
            nn.Flatten(),
        )

    def forward(self, x, timesteps=None, y=None):
        """Encode ``x`` into a tensor of shape ``(x.shape[0], out_channels)``.

        Args:
            x: input batch with ``in_channels`` channels.
            timesteps: required when ``use_time_embed`` is set; passed to
                ``timestep_embedding``.
            y: required when ``use_label_embed`` is set; must have shape
                ``(x.shape[0],)``.

        Raises:
            AssertionError: if a required ``timesteps`` / ``y`` is missing or
                ``y`` has the wrong shape.
        """
        if self.use_time_embed:
            assert timesteps is not None, "timesteps must be provided when use_time_embed is set"
            emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        else:
            emb = None

        # label embedding
        if self.use_label_embed:
            # Check for None first so a missing label hits this message rather
            # than an AttributeError on `y.shape`.
            assert y is not None and y.shape == (x.shape[0],), "labels are not provided or are not the right shape"
            label_vec = self.label_emb(y)
            emb = label_vec if emb is None else emb + label_vec

        for block in self.input_blocks:
            x = block(x, emb)
        return self.out(x)

In [8]:
# Build an encoder without time or label conditioning.
Encoder = encoder(
    # channel layout
    in_channels=3,
    model_channels=128,
    out_channels=512,
    channel_mult=(1, 2, 4, 8),
    # block options
    dropout=0,
    conv_resample=True,
    dims=2,
    use_checkpoint=False,
    use_scale_shift_norm=False,
    # conditioning is switched off, so no class count is given
    num_classes=None,
    use_time_embed=False,
    use_label_embed=False,
)

In [9]:
# Smoke test: take one (batch, cond) pair from the data iterator and run the
# encoder on `batch` alone; `cond` is not used in this cell.
batch, cond = next(data)
temp = Encoder(batch)
# The encoder head ends in pooling to size 1 followed by Flatten, so the
# printed shape is 2-D: (batch size, out_channels).
print(temp.shape)

torch.Size([64, 512])


In [12]:
# Display the module tree of the encoder instance as this cell's output.
Encoder

encoder(
  (input_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): TimestepEmbedSequential(
      (0): EncBlock(
        (resblock): ResBlock(
          (in_layers): Sequential(
            (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (emb_layers): Identity()
          (out_layers): Sequential(
            (0): GroupNorm32(32, 128, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Dropout(p=0, inplace=False)
            (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (skip_connection): Identity()
        )
        (conv): TimestepEmbedSequential(
          (0): Downsample(
            (op): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          )
        )
      )
    )
  