In [None]:
!pip install watermark

In [None]:

#@title import packages
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors
tfk = tf.keras
tfkl = tf.keras.layers

# warningを非表示にする
tf.autograph.set_verbosity(0)

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

sns.reset_defaults()
sns.set_context(context='talk', font_scale=1.0)
cmap = plt.get_cmap("tab10")

%config InlineBackend.figure_format = 'retina'
%matplotlib inline


In [None]:

#@title distribution of pixels
def trancated_logistic(x, mu, s):
  """Probability mass of a discretized logistic distribution over 8-bit pixels.

  Each value ``x`` strictly inside (0, 255) gets the logistic CDF mass of the
  interval [x - 0.5, x + 0.5]. The edge bins absorb the tails: ``x == 0`` gets
  all mass below 0.5, and ``x == 255`` gets all mass above 254.5. As a result,
  the masses over the integers 0..255 sum to 1. (The misspelled name is kept so
  existing calls keep working.)

  Args:
    x: pixel value(s); scalar or array-like.
    mu: location of the logistic distribution (broadcasts against ``x``).
    s: scale of the logistic distribution; expected to be positive.

  Returns:
    ndarray broadcast to the shape of ``x``: the mass at each value, and 0 for
    values outside [0, 255].
  """
  x = np.asarray(x, dtype=np.float64)

  def _sigmoid(z):
    # tanh form of the logistic function: no overflow for large |z|,
    # unlike 1 / (1 + exp(-z)).
    return 0.5 * (1.0 + np.tanh(0.5 * z))

  cdf_upper = _sigmoid((x + 0.5 - mu) / s)
  cdf_lower = _sigmoid((x - 0.5 - mu) / s)
  # The first matching condition wins; anything out of range falls to 0.
  return np.select(
      [(0 < x) & (x < 255), x == 0, x == 255],
      [cdf_upper - cdf_lower, cdf_upper, 1.0 - cdf_lower],
      default=0.0,
  )

# All 8-bit pixel values at which the distributions are evaluated.
xx = np.arange(0, 256)

fig, axes = plt.subplots(2, 1, figsize=(8, 8))

# Top panel: each discretized-logistic component on its own.
mu_list, s_list = [200, 120, 30], [12, 15, 10]
ax = axes[0]
for mu, s in zip(mu_list, s_list):
    ax.bar(xx, trancated_logistic(xx, mu, s), label=fr'$\mu$={mu}, s={s}', alpha=0.7)
ax.legend()
ax.set_xlabel('pixel')
ax.set_ylabel('density')
ax.set_title('component distributions')

# Bottom panel: the same components combined with mixture weights pi.
pi_list = [0.2, 0.5, 0.3]
assert np.isclose(sum(pi_list), 1.0), "mixture weights must sum to 1"
d = sum(pi * trancated_logistic(xx, mu, s)
        for mu, s, pi in zip(mu_list, s_list, pi_list))
ax = axes[1]
ax.bar(xx, d, color='gray')
ax.set_xlabel('pixel')
ax.set_ylabel('density')
ax.set_title(fr'mixture distribution ($\pi$={pi_list})')

plt.tight_layout()
plt.show()

In [None]:

tf.random.set_seed(42)

# Load MNIST from tensorflow_datasets.
data = tfds.load('mnist')
train_data, test_data = data['train'], data['test']

def image_preprocess(x):
  """Convert one MNIST example into the input structure the model is fed.

  Args:
    x: dict with 'image' and 'label' entries, as yielded by tfds MNIST.

  Returns:
    A 1-tuple ``((image, label),)``. Keras treats the inner pair as the
    model's two inputs and sees no target; the model's loss comes from
    ``add_loss`` rather than from a label.
  """
  # Build new tensors instead of writing back into the input dict.
  image = tf.cast(x['image'], tf.float32)
  # Match the float32 dtype of the Keras label input.
  label = tf.cast(x['label'], tf.float32)
  return ((image, label),)

batch_size = 16
# Shuffle individual examples *before* batching: shuffling after `.batch()`
# only permutes whole batches, so every epoch would reuse the same groups of
# 16 images. `prefetch` overlaps input preparation with training steps.
train_it = (
    train_data
    .map(image_preprocess)
    .shuffle(1000)
    .batch(batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

## Model Definition

In [None]:
# MNIST images are 28x28 with one channel; the class label is a scalar.
image_shape = (28, 28, 1)
label_shape = ()

# Class-conditional PixelCNN over 8-bit pixels.
dist = tfd.PixelCNN(
    image_shape=image_shape,        # (height, width, channels)
    conditional_shape=label_shape,  # shape of the conditioning input (class label)
    num_resnet=1,                   # ResNet layers within each block
    num_hierarchies=2,              # number of blocks
    num_filters=32,                 # number of convolutional filters
    num_logistic_mix=5,             # components in the logistic mixture
    dropout_p=0.3,                  # dropout rate
)

# Symbolic Keras inputs: an image batch and the matching batch of labels.
image_input = tfkl.Input(shape=image_shape)
label_input = tfkl.Input(shape=label_shape)

# Per-image log-likelihood under the label-conditioned PixelCNN.
log_prob = dist.log_prob(image_input, conditional_input=label_input)

# The model outputs log-likelihoods. Training minimizes the mean negative
# log-likelihood attached via add_loss, so no separate target is needed.
class_cond_model = tfk.Model(
    inputs=[image_input, label_input], outputs=log_prob)
class_cond_model.add_loss(-tf.reduce_mean(log_prob))

class_cond_model.compile(optimizer=tfk.optimizers.Adam(), metrics=[])
class_cond_model.fit(train_it, epochs=10, verbose=True)

# Draw n_sample images for each conditioning digit.
n_sample = 4
digits = [1, 2, 3]
# NOTE: relies on PixelCNN.sample tiling `conditional_input` over the
# flattened sample dimensions, so samples[i, j] is conditioned on digits[j].
samples = dist.sample((n_sample, len(digits)), conditional_input=digits)

# squeeze=False keeps `axes` 2-D even when there is a single digit column.
fig, axes = plt.subplots(n_sample, len(digits), figsize=(12, 10), squeeze=False)
for i in range(n_sample):
  for j, digit in enumerate(digits):
    ax = axes[i][j]
    ax.imshow(samples[i, j, ..., 0], cmap="gray")
    ax.set_title(f"sample of digit {digit}")
    ax.axis('off')
plt.tight_layout()
plt.show()