In [1]:
# !pip install tensorflow==2.2.0 tensorflow-addons==0.10.0 tensorflow_io==0.14.0 matplotlib Pillow tensorflow-probability==0.9.0 tensorflow-datasets==3.0.0
%config Completer.use_jedi = False
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

In [3]:
import tensorflow as tf

tf.__version__

'2.7.0'

In [4]:
import tensorflow.experimental.numpy as tnp

In [5]:
import datasets
ds = datasets.get_dataset('cifar10', tfds_data_dir='tensorflow_datasets')
img_sz = ds._img_size
n_train = ds.num_train_examples
ds = ds.train_input_fn({'batch_size': 256})

ds_iter = iter(ds)

from datasets import data_preprocess
data = next(ds_iter)
x = data_preprocess(data['image'])
labels = data['label']



In [6]:
import tensorflow.compat.v2 as tf
import tensorflow_addons as tfa
import nn


def nonlinearity(x):
    if FLAGS.act == 'lrelu':
        return tf.nn.leaky_relu(x)
    if FLAGS.act == 'swish':
        return tf.nn.swish(x)

    raise NotImplementedError


class normalize(tf.keras.layers.Layer):
    def __init__(self, name, *args, **kwargs):
        super(normalize, self).__init__(name=name, *args, **kwargs)
        if FLAGS.normalize:
            if FLAGS.normalize == 'group_norm':
                self.norm = tfa.layers.GroupNormalization(groups=32, epsilon=1e-06)
            elif FLAGS.normalize == 'batch_norm':
                self.norm = tf.keras.layers.BatchNormalization()
            elif FLAGS.normalize == 'instance_norm':
                self.norm = tfa.layers.InstanceNormalization()

    def call(self, inputs, **kwargs):
        if FLAGS.normalize:
            inputs = self.norm(inputs, training=True)
        return inputs


class downsample(tf.keras.layers.Layer):
    def __init__(self, name, with_conv, *args, **kwargs):
        super(downsample, self).__init__(name=name, *args, **kwargs)
        self.with_conv = with_conv

    def build(self, input_shape):
        B, H, W, C = input_shape
        if self.with_conv:
            self.conv2d = nn.conv2d(name='conv', num_units=C, filter_size=3, stride=2, spec_norm=FLAGS.spec_norm)
        # print('{}: x={}'.format(self.name, input_shape))

    def call(self, inputs, **kwargs):
        B, H, W, C = inputs.shape

        if self.with_conv:
            x = self.conv2d(inputs)
        else:
            x = tf.nn.avg_pool(inputs, 2, 2, 'SAME')
        assert x.shape == [B, H // 2, W // 2, C]

        return x


class resnet_block(tf.keras.layers.Layer):
    def __init__(self, *, name, out_ch=None):
        super(resnet_block, self).__init__(name=name)
        self.out_ch = out_ch
        self.conv_shortcut = FLAGS.res_conv_shortcut
        self.spec_norm = FLAGS.spec_norm
        self.use_scale = FLAGS.res_use_scale

    def build(self, input_shape):
        B, H, W, C = input_shape
        if self.out_ch is None:
            self.out_ch = C
        self.normalize_1 = normalize('norm1')
        self.normalize_2 = normalize('norm2')

        self.dense = nn.dense(name='temb_proj', num_units=self.out_ch, spec_norm=self.spec_norm)
        self.conv2d_1 = nn.conv2d(name='conv1', num_units=self.out_ch, spec_norm=self.spec_norm)

        self.conv2d_2 = nn.conv2d(
            name='conv2', num_units=self.out_ch, init_scale=0., spec_norm=self.spec_norm, use_scale=self.use_scale
        )
        if self.conv_shortcut:
            self.conv2d_shortcut = nn.conv2d(name='conv_shortcut', num_units=self.out_ch, spec_norm=self.spec_norm)
        else:
            self.nin_shortcut = nn.nin(name='nin_shortcut', num_units=self.out_ch, spec_norm=self.spec_norm)
        # print('{}: x={}'.format(self.name, input_shape))

    def call(self, inputs, temb=None, dropout=0.):
        B, H, W, C = inputs.shape
        x = inputs
        h = inputs

        h = nonlinearity(self.normalize_1(h))
        h = self.conv2d_1(h)

        if temb is not None:

            # add in timestep embedding
            temp_o = self.dense(nonlinearity(temb))
            # print(h.shape, temp_o.shape)
            h += temp_o[:, None, None, :]

        h = nonlinearity(self.normalize_2(h))
        h = tf.nn.dropout(h, rate=dropout)
        h = self.conv2d_2(h)

        if C != self.out_ch:
            if self.conv_shortcut:
                x = self.conv2d_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        assert x.shape == h.shape
        return x + h


class attn_block(tf.keras.layers.Layer):
    def __init__(self, name):
        super(attn_block, self).__init__(name=name)

    def build(self, input_shape):
        B, H, W, C = input_shape
        self.normalize = normalize(name='norm')
        self.nin_q = nn.nin(name='q', num_units=C)
        self.nin_k = nn.nin(name='k', num_units=C)
        self.nin_v = nn.nin(name='v', num_units=C)

        self.nin_proj_out = nn.nin(name='proj_out', num_units=C, init_scale=0.)
        # print('{}: x={}'.format(self.name, input_shape))

    def call(self, inputs):
        x = inputs
        B, H, W, C = x.shape

        h = self.normalize(x)
        q = self.nin_q(h)
        k = self.nin_k(h)
        v = self.nin_v(h)

        w = tf.einsum('bhwc,bHWc->bhwHW', q, k) * (int(C) ** (-0.5))
        w = tf.reshape(w, [B, H, W, H * W])
        w = tf.nn.softmax(w, -1)
        w = tf.reshape(w, [B, H, W, H, W])

        h = tf.einsum('bhwHW,bHWc->bhwc', w, v)
        h = self.nin_proj_out(h)

        assert h.shape == x.shape

        return x + h


class net_res_temb2(tf.keras.layers.Layer):
    def __init__(self, *, name, ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, num_classes=10):
        super(net_res_temb2, self).__init__(name=name)
        self.ch, self.ch_mult = ch, ch_mult
        self.num_res_blocks = num_res_blocks
        self.attn_resolutions = attn_resolutions
        self.num_resolutions = len(self.ch_mult)
        self.resamp_with_conv = FLAGS.resamp_with_conv
        self.use_attention = FLAGS.use_attention
        self.spec_norm = FLAGS.spec_norm
        self.num_classes = num_classes

    def build(self, input_shape):
        # timestep embedding
        self.temb_dense_0 = nn.dense(name='temb/dense0', num_units=self.ch * 4, spec_norm=self.spec_norm)
        self.temb_dense_1 = nn.dense(name='temb/dense1', num_units=self.ch * 4, spec_norm=self.spec_norm)
        self.temb_dense_2 = nn.dense(name='temb/dense2', num_units=self.ch * self.ch_mult[-1], spec_norm=False)

        self.linear = nn.dense(name='classifier', num_units=self.num_classes, spec_norm=False)

        S = input_shape[-3]
        self.res_levels = []
        self.attn_s = dict()
        self.downsample_s = []

        # downsample
        self.conv2d_in = nn.conv2d(name='conv_in', num_units=self.ch, spec_norm=self.spec_norm)
        for i_level in range(self.num_resolutions):
            res_s = []
            if self.use_attention and S in self.attn_resolutions:
                self.attn_s[str(S)] = []
            for i_block in range(self.num_res_blocks):
                res_s.append(
                    resnet_block(
                        name='level_{}_block_{}'.format(i_level, i_block), out_ch=self.ch * self.ch_mult[i_level]
                    )
                )
                if self.use_attention and S in self.attn_resolutions:
                    self.attn_s[str(S)].append(attn_block(name='down_{}_attn_{}'.format(i_level, i_block)))
            self.res_levels.append(res_s)

            if i_level != self.num_resolutions - 1:
                self.downsample_s.append(downsample(name='downsample_{}'.format(i_level), with_conv=self.resamp_with_conv))
                S = S // 2

        # end
        self.normalize_out = normalize(name='norm_out')
        self.fc_out = nn.dense(name='dense_out', num_units=1, spec_norm=False)

    def call(self, inputs, t, dropout):
        x = inputs
        B, S, _, _ = x.shape
        assert x.dtype == tf.float32 and x.shape[2] == S
        if isinstance(t, int) or len(t.shape) == 0:
            t = tf.ones([B], dtype=tf.int32) * t

        # Timestep embedding
        temb = nn.get_timestep_embedding(t, self.ch)

        temb = self.temb_dense_0(temb)

        temb = self.temb_dense_1(nonlinearity(temb))

        assert temb.shape == [B, self.ch * 4]

        # downsample
        h = self.conv2d_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.res_levels[i_level][i_block](h, temb=temb, dropout=dropout)

                if self.use_attention:
                    if h.shape[1] in self.attn_resolutions:
                        h = self.attn_s[str(h.shape[1])][i_block](h)

            if i_level != self.num_resolutions - 1:
                h = self.downsample_s[i_level](h)

        # end
        if FLAGS.final_act == 'relu':
            h = tf.nn.relu(h)
        elif FLAGS.final_act == 'swish':
            h = tf.nn.swish(h)
        elif FLAGS.final_act == 'lrelu':
            tf.nn.leaky_relu(x)
        else:
            raise NotImplementedError
        h = tf.reduce_sum(h, [1, 2])
        temb_final = self.temb_dense_2(nonlinearity(temb))
        feature = h * temb_final
        logits = self.linear(feature)
        h = tf.reduce_sum(feature, axis=1)

        return h, logits

 The versions of TensorFlow you are currently using is 2.7.0 and is not supported. 
Some things might work, some things might not.
If you were to encounter a bug, do not file an issue.
If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. 
You can find the compatibility matrix in TensorFlow Addon's readme:
https://github.com/tensorflow/addons


In [7]:
import argparse
FLAGS = argparse.Namespace()
FLAGS.act = 'lrelu'
FLAGS.normalize = None
FLAGS.res_conv_shortcut = True
FLAGS.spec_norm = True
FLAGS.res_use_scale = True
FLAGS.resamp_with_conv = False
FLAGS.use_attention = False
FLAGS.final_act = 'relu'

In [8]:
ch_mult = (1, 2, 2, 2)
net = net_res_temb2(name='net', ch=128, ch_mult=ch_mult, num_res_blocks=2, attn_resolutions=(16,))

In [9]:
x = tf.random.uniform([64, 32, 32, 3], minval=-.5, maxval=.5)
t = tf.random.uniform(shape=[64], maxval=6, dtype=tf.int32)
out = net(x, t, 0)[0]
out.shape







TensorShape([64])

In [10]:
import math
DEFAULT_DTYPE = tf.float32

In [7]:
def get_timestep_embedding(timesteps, embedding_dim: int):
    """
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32

    half_dim = embedding_dim // 2

    emb = math.log(10000) / (half_dim - 1)
    emb = tf.exp(tf.range(half_dim, dtype=DEFAULT_DTYPE) * -emb)
    print(timesteps.shape, emb.shape)
    # emb = tf.range(num_embeddings, dtype=DEFAULT_DTYPE)[:, None] * emb[None, :]
    emb = tf.cast(timesteps, dtype=DEFAULT_DTYPE)[:, None] * emb[None, :]

    emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=1)
    print(emb.shape)
    if embedding_dim % 2 == 1:  # zero pad
        # emb = tf.concat([emb, tf.zeros([num_embeddings, 1])], axis=1)
        emb = tf.pad(emb, [[0, 0], [0, 1]])
    assert emb.shape == [timesteps.shape[0], embedding_dim]
    return emb

t = tf.random.uniform(shape=[64], maxval=6, dtype=tf.int32)
temb = get_timestep_embedding(t, 127)
temb.shape


(64,) (63,)
(64, 126)


TensorShape([64, 127])

# Recovery Likelihood

In [12]:
import numpy as np
import tensorflow.compat.v2 as tf


def get_beta_schedule(*, beta_start, beta_end, num_diffusion_timesteps):
    betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
    betas = np.append(betas, 1.)
    assert betas.shape == (num_diffusion_timesteps + 1,)


    return betas


def get_sigma_schedule(*, beta_start, beta_end, num_diffusion_timesteps):
    """
    Get the noise level schedule
    :param beta_start: begin noise level
    :param beta_end: end noise level
    :param num_diffusion_timesteps: number of timesteps
    :return:
    -- sigmas: sigma_{t+1}, scaling parameter of epsilon_{t+1}
    -- a_s: sqrt(1 - sigma_{t+1}^2), scaling parameter of x_t
    """
    betas = np.linspace(beta_start, beta_end, 1000, dtype=np.float64)
    betas = np.append(betas, 1.)
    assert isinstance(betas, np.ndarray)
    betas = betas.astype(np.float64)
    assert (betas > 0).all() and (betas <= 1).all()
    sqrt_alphas = np.sqrt(1. - betas)
    idx = tf.cast(np.concatenate([np.arange(num_diffusion_timesteps) * (1000 // ((num_diffusion_timesteps - 1) * 2)), [999]]), dtype=tf.int32)
    a_s = np.concatenate(
        [[np.prod(sqrt_alphas[: idx[0] + 1])],
         np.asarray([np.prod(sqrt_alphas[idx[i - 1] + 1: idx[i] + 1]) for i in np.arange(1, len(idx))])])
    sigmas = np.sqrt(1 - a_s ** 2)

    return sigmas, a_s


import torch
from torch.distributions.multivariate_normal import MultivariateNormal


class RecoveryLikelihood(tf.keras.Model):
    def __init__(self, hps):
        super(RecoveryLikelihood, self).__init__()
        self.hps = hps
        self.num_timesteps = FLAGS.num_diffusion_timesteps

        self.sigmas, self.a_s = get_sigma_schedule(beta_start=0.0001, beta_end=0.02, num_diffusion_timesteps=self.num_timesteps)
        self.a_s_cum = np.cumprod(self.a_s)
        self.sigmas_cum = np.sqrt(1 - self.a_s_cum ** 2)
        self.a_s_prev = self.a_s.copy()
        self.a_s_prev[-1] = 1
        self.is_recovery = np.ones(self.num_timesteps + 1, dtype=np.float32)
        self.is_recovery[-1] = 0
        centers = torch.load('../%s_one_mean.pt' % hps.problem)
        covs = torch.load('../%s_one_cov.pt' % hps.problem)
        size = [3, 32, 32]
        self.dist = MultivariateNormal(centers, covariance_matrix=covs + 1e-4 * torch.eye(int(np.prod(size))))

        if self.hps.img_sz == 32:
            ch_mult = (1, 2, 2, 2)
        elif self.hps.img_sz == 128:
            ch_mult = (1, 2, 2, 2, 4, 4)
        elif self.hps.img_sz == 64:
            ch_mult = (1, 2, 2, 2, 4)
        elif self.hps.img_sz == 256:
            ch_mult = (1, 1, 2, 2, 2, 4, 4,)
        else:
            raise NotImplementedError

        self.net = net_res_temb2(name='net', ch=128, ch_mult=ch_mult, num_res_blocks=FLAGS.num_res_blocks, attn_resolutions=(16,))

    def init(self, x_shape):
        """
        Initialization function to activate model weights.
        :param x_shape: input date shape
        """
        x = tf.random.uniform(x_shape, minval=-.5, maxval=.5)
        self.net(x, 0, dropout=0.)

    @staticmethod
    def _extract(a, t, x_shape):
        """
        Extract some coefficients at specified timesteps,
        then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
        """
        if isinstance(t, int) or len(t.shape) == 0:
            t = tf.ones(x_shape[0], dtype=tf.int32) * t
        bs, = t.shape
        assert x_shape[0] == bs
        out = tf.gather(tf.convert_to_tensor(a, dtype=tf.float32), t)
        assert out.shape == [bs]
        return tf.reshape(out, [bs] + ((len(x_shape) - 1) * [1]))

    def q_sample(self, x_start, t, *, noise=None):
        """
        Diffuse the data (t == 0 means diffused for 1 step)
        """
        if noise is None:
            noise = tf.random.normal(shape=x_start.shape)
            # noise = self.dist.sample((x_start.shape[0],)).view((-1, 3, x_start.shape[2], x_start.shape[2])).permute((0, 2, 3, 1)).numpy()
        assert noise.shape == x_start.shape
        x_t = self._extract(self.a_s_cum, t, x_start.shape) * x_start +
        self._extract(self.sigmas_cum, t, x_start.shape) * noise

    return x_t

def q_sample_pairs(self, x_start, t):
    """
    Generate a pair of disturbed images for training
    :param x_start: x_0
    :param t: time step t
    :return: x_t, x_{t+1}
    """
    noise = tf.random.normal(shape=x_start.shape)
    # noise = self.dist.sample((x_start.shape[0],)).view((-1, 3, x_start.shape[2], x_start.shape[2])).permute((0, 2, 3, 1)).numpy()
    x_t = self.q_sample(x_start, t)
    x_t_plus_one = self._extract(self.a_s, t+1, x_start.shape) * x_t +
    self._extract(self.sigmas, t+1, x_start.shape) * noise

return x_t, x_t_plus_one, noise

def q_sample_progressive(self, x_0):
    """
    Generate a full sequence of disturbed images
    """
    x_preds = []
    for t in range(self.num_timesteps + 1):
        t_now = tf.ones([x_0.shape[0]], dtype=tf.int32) * t
        x = self.q_sample(x_0, t_now)
        x_preds.append(x)
    x_preds = tf.stack(x_preds, axis=0)

    return x_preds

# === Training loss ===
def training_losses(self, x_pos, x_neg, t, *, dropout=0.):
    """
    Training loss calculation
    """
    a_s = self._extract(self.a_s_prev, t + 1, x_pos.shape)
    y_pos = a_s * x_pos
    y_neg = a_s * x_neg
    pos_f = self.net(y_pos, t, dropout=dropout)[0]
    neg_f = self.net(y_neg, t, dropout=dropout)[0]
    loss = - (pos_f - neg_f)

    loss_scale = 1.0 / (tf.cast(tf.gather(self.sigmas, t + 1), tf.float32) / self.sigmas[1])
    loss = loss_scale * loss

    loss_ts = tf.math.unsorted_segment_mean(tf.abs(loss), t, self.num_timesteps)
    f_ts = tf.math.unsorted_segment_mean(tf.abs(pos_f), t, self.num_timesteps)

    return tf.nn.compute_average_loss(loss, global_batch_size=self.hps.n_batch_train), loss_ts, f_ts

def log_prob(self, y, t, tilde_x, b0, sigma, is_recovery, *, dropout):
    return self.net(y, t, dropout=dropout)[0] / tf.reshape(b0, [-1]) - tf.reduce_sum((y - tilde_x) ** 2 / 2 / sigma ** 2 * is_recovery, axis=[1, 2, 3])

def grad_f(self, y, t, tilde_x, b0, sigma, is_recovery, *, dropout):
    with tf.GradientTape() as tape:
        tape.watch(y)
        log_p_y = self.log_prob(y, t, tilde_x, b0, sigma, is_recovery, dropout=dropout)
    grad_y = tape.gradient(log_p_y, y)
    return grad_y, log_p_y

# === Sampling ===
def p_sample_langevin(self, tilde_x, t, *, dropout):
    """
    Langevin sampling function
    """
    sigma = self._extract(self.sigmas, t + 1, tilde_x.shape)
    sigma_cum = self._extract(self.sigmas_cum, t, tilde_x.shape)
    is_recovery = self._extract(self.is_recovery, t + 1, tilde_x.shape)
    a_s = self._extract(self.a_s_prev, t + 1, tilde_x.shape)

    c_t_square = sigma_cum / self.sigmas_cum[0]
    step_size_square = c_t_square * self.hps.mcmc_step_size_b_square * sigma ** 2

    y = tf.identity(tilde_x)
    is_accepted_summary = tf.zeros(y.shape[0], dtype=tf.float32)
    grad_y, log_p_y = self.grad_f(y, t, tilde_x, step_size_square, sigma, is_recovery, dropout=dropout)

    for _ in tf.range(tf.convert_to_tensor(self.hps.mcmc_num_steps)):
        noise = tf.random.normal(y.shape)
        # noise = self.dist.sample((y.shape[0], )).view(y.shape.as_list()).numpy()
        y_new = y + 0.5 * step_size_square * grad_y + tf.sqrt(step_size_square) * noise * FLAGS.noise_scale

        grad_y_new, log_p_y_new = self.grad_f(y_new, t, tilde_x, step_size_square, sigma, is_recovery, dropout=dropout)
        y, grad_y, log_p_y = y_new, grad_y_new, log_p_y_new

    is_accepted_summary = is_accepted_summary / tf.convert_to_tensor(self.hps.mcmc_num_steps, dtype=tf.float32)
    is_accepted_summary = tf.reduce_mean(is_accepted_summary)

    x = y / a_s

    disp = tf.math.unsorted_segment_mean(
        tf.norm(tf.reshape(x, [x.shape[0], -1]) - tf.reshape(tilde_x, [tilde_x.shape[0], -1]), axis=1),
        t, self.num_timesteps)

    return x, disp, is_accepted_summary

@tf.function
def p_sample_progressive(self, noise):
    """
    Sample a sequence of images with the sequence of noise levels
    """
    num = noise.shape[0]
    x_neg_t = noise
    x_neg = tf.zeros([self.hps.num_diffusion_timesteps, num, self.hps.img_sz, self.hps.img_sz, 3], dtype=tf.float32)
    x_neg = tf.concat([x_neg, tf.expand_dims(noise, axis=0)], axis=0)
    is_accepted_summary = tf.constant(0.)

    for t in tf.range(self.hps.num_diffusion_timesteps - 1, -1, -1):
        x_neg_t, _, is_accepted = self.p_sample_langevin(x_neg_t, t, dropout=0.)
        is_accepted_summary = is_accepted_summary + is_accepted
        x_neg_t = tf.reshape(x_neg_t, [num, self.hps.img_sz, self.hps.img_sz, 3])
        insert_mask = tf.equal(t, tf.range(self.hps.num_diffusion_timesteps + 1, dtype=tf.int32))
        insert_mask = tf.reshape(tf.cast(insert_mask, dtype=tf.float32), [-1, *([1] * len(noise.shape))])
        x_neg = insert_mask * tf.expand_dims(x_neg_t, axis=0) + (1. - insert_mask) * x_neg
    is_accepted_summary = is_accepted_summary / tf.convert_to_tensor(self.hps.num_diffusion_timesteps, dtype=tf.float32)
    return x_neg, is_accepted_summary

def p_sample_progressive_inner(self, noise):
    """
    Sample a sequence of images with the sequence of noise levels, without tf.function decoration
    """
    num = noise.shape[0]
    x_neg_t = noise
    x_neg = tf.zeros([self.hps.num_diffusion_timesteps, num, self.hps.img_sz, self.hps.img_sz, 3], dtype=tf.float32)
    x_neg = tf.concat([x_neg, tf.expand_dims(noise, axis=0)], axis=0)
    is_accepted_summary = tf.constant(0.)

    for t in tf.range(self.hps.num_diffusion_timesteps - 1, -1, -1):
        x_neg_t, _, is_accepted = self.p_sample_langevin(x_neg_t, t, dropout=0.)
        is_accepted_summary = is_accepted_summary + is_accepted
        x_neg_t = tf.reshape(x_neg_t, [num, self.hps.img_sz, self.hps.img_sz, 3])
        insert_mask = tf.equal(t, tf.range(self.hps.num_diffusion_timesteps + 1, dtype=tf.int32))
        insert_mask = tf.reshape(tf.cast(insert_mask, dtype=tf.float32), [-1, *([1] * len(noise.shape))])
        x_neg = insert_mask * tf.expand_dims(x_neg_t, axis=0) + (1. - insert_mask) * x_neg
    is_accepted_summary = is_accepted_summary / tf.convert_to_tensor(self.hps.num_diffusion_timesteps, dtype=tf.float32)
    return x_neg, is_accepted_summary

@tf.function
def distribute_p_sample_progressive(self, noise, strategy):
    """
    Multi-device distributed version of p_sample_progressive
    """
    samples, is_accepted = strategy.run(self.p_sample_progressive_inner, args=(noise,))
    samples = tf.concat(samples.values, axis=1)
    is_accepted = strategy.reduce(tf.distribute.ReduceOp.MEAN, is_accepted, axis=None)

    return samples, is_accepted

In [13]:
FLAGS.jobid = 0
FLAGS.logdir = ''
FLAGS.eager = False
FLAGS.ckpt_load = None
FLAGS.device = 0
FLAGS.tpu = False
FLAGS.tpu_name = None
FLAGS.tpu_zone = None
FLAGS.rnd_seed = 1
FLAGS.problem = 'cifar10'
FLAGS.n_batch_train = 64
FLAGS.lr = 0.0001
FLAGS.beta_1 = 0.9
FLAGS.n_iters = 1000000
FLAGS.grad_clip = False
FLAGS.warmup = 1000
FLAGS.n_batch_per_iter = 1
FLAGS.cosine_decay = False
FLAGS.opt = 'adam'
FLAGS.eval = False
FLAGS.include_xpred_freq = 1
FLAGS.eval_fid = False
FLAGS.fid_n_samples = 64
FLAGS.fid_n_iters = 40000
FLAGS.fid_n_batch = 64
FLAGS.num_res_blocks = 2
FLAGS.num_diffusion_timesteps = 6
FLAGS.randflip = True
FLAGS.dropout = 0.0
FLAGS.normalize = None
FLAGS.use_attention = False
FLAGS.resamp_with_conv = False
FLAGS.spec_norm = True
FLAGS.res_conv_shortcut = True
FLAGS.res_use_scale = True
FLAGS.ma_decay = 0.999
FLAGS.noise_scale = 1.0
FLAGS.mcmc_num_steps = 30
FLAGS.mcmc_step_size_b_square = 0.0002


FLAGS.num_diffusion_timesteps = 6
FLAGS.img_sz = 32
FLAGS.num_res_blocks = 8
FLAGS.n_batch_train = 256
FLAGS.noise_scale = 1.0

diffusion = RecoveryLikelihood(FLAGS)
diffusion.init(x.shape)


In [10]:
hps = FLAGS
B = x.shape[0]
t = tf.random.uniform(shape=[B], maxval=diffusion.num_timesteps, dtype=tf.int32)
x_pos, x_neg, noise = diffusion.q_sample_pairs(x, t)

NameError: name 'diffusion' is not defined

In [19]:
x_neg, disp, is_accepted = diffusion.p_sample_langevin(x_neg, t, dropout=hps.dropout)

KeyboardInterrupt: 

In [17]:
loss, loss_ts, f_ts = diffusion.training_losses(x_pos, x_neg, t, dropout=hps.dropout)

In [21]:
print(t.numpy())
print(diffusion.sigmas)

[3 0 4 1 0 0 2 5 1 3 4 0 5 2 2 1 4 3 4 1 5 0 2 0 2 3 5 3 0 0 0 5 1 5 3 3 0
 0 5 5 0 1 1 0 3 2 4 2 2 0 4 0 3 0 2 1 3 1 0 3 3 5 4 1]
[0.01       0.32368022 0.51649529 0.63221107 0.71324214 0.77336903
 0.99974058]


In [22]:
loss_scale = 1.0 / (tf.cast(tf.gather(diffusion.sigmas, t + 1), tf.float32) / diffusion.sigmas[1])
print(loss_scale)


tf.Tensor(
[0.45381534 1.         0.4185327  0.6266857  1.         1.
 0.51198125 0.3237642  0.6266857  0.45381534 0.4185327  1.
 0.3237642  0.51198125 0.51198125 0.6266857  0.4185327  0.45381534
 0.4185327  0.6266857  0.3237642  1.         0.51198125 1.
 0.51198125 0.45381534 0.3237642  0.45381534 1.         1.
 1.         0.3237642  0.6266857  0.3237642  0.45381534 0.45381534
 1.         1.         0.3237642  0.3237642  1.         0.6266857
 0.6266857  1.         0.45381534 0.51198125 0.4185327  0.51198125
 0.51198125 1.         0.4185327  1.         0.45381534 1.
 0.51198125 0.6266857  0.45381534 0.6266857  1.         0.45381534
 0.45381534 0.3237642  0.4185327  0.6266857 ], shape=(64,), dtype=float32)


<tf.Tensor: shape=(64,), dtype=float64, numpy=
array([0.66666667,       -inf, 0.75      , 0.        ,       -inf,
             -inf, 0.5       , 0.8       , 0.        , 0.66666667,
       0.75      ,       -inf, 0.8       , 0.5       , 0.5       ,
       0.        , 0.75      , 0.66666667, 0.75      , 0.        ,
       0.8       ,       -inf, 0.5       ,       -inf, 0.5       ,
       0.66666667, 0.8       , 0.66666667,       -inf,       -inf,
             -inf, 0.8       , 0.        , 0.8       , 0.66666667,
       0.66666667,       -inf,       -inf, 0.8       , 0.8       ,
             -inf, 0.        , 0.        ,       -inf, 0.66666667,
       0.5       , 0.75      , 0.5       , 0.5       ,       -inf,
       0.75      ,       -inf, 0.66666667,       -inf, 0.5       ,
       0.        , 0.66666667, 0.        ,       -inf, 0.66666667,
       0.66666667, 0.8       , 0.75      , 0.        ])>