## Setup

In [None]:
! pip install ipython numpy pillow pandas six tensorflow tensorflow-hub

In [None]:
# @title Imports and utility functions
import os

import IPython
from IPython.display import display
import numpy as np
import PIL.Image
import pandas as pd
import six

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

import tensorflow_hub as hub

In [2]:
def imgrid(imarray, cols=8, pad=1):
  """Tile a batch of images into one grid image.

  Args:
    imarray: array of shape (N, H, W, C).
    cols: number of grid columns (>= 1).
    pad: zero-valued pixels inserted between neighbouring tiles (>= 0).

  Returns:
    Array of shape (rows*(H+pad) - pad, cols*(W+pad) - pad, C), where
    rows = ceil(N / cols). Empty trailing cells are filled with zeros.
  """
  pad = int(pad)
  assert pad >= 0
  cols = int(cols)
  assert cols >= 1
  N, H, W, C = imarray.shape
  rows = int(np.ceil(N / float(cols)))
  batch_pad = rows * cols - N
  assert batch_pad >= 0
  # Pad the batch up to a full grid and give every tile a bottom/right border.
  post_pad = [batch_pad, pad, pad, 0]
  pad_arg = [[0, p] for p in post_pad]
  imarray = np.pad(imarray, pad_arg, 'constant')
  H += pad
  W += pad
  grid = (imarray
          .reshape(rows, cols, H, W, C)
          .transpose(0, 2, 1, 3, 4)
          .reshape(rows*H, cols*W, C))
  # Drop the border after the last row/column. Slicing with an explicit end
  # (rather than `-pad`) keeps the whole grid when pad == 0.
  return grid[:rows*H - pad, :cols*W - pad]


def imshow(a, format='png', jpeg_fallback=True):
  """Display an image array inline in the notebook.

  Args:
    a: image array; it is cast to uint8 before encoding.
    format: PIL output format used to encode the image.
    jpeg_fallback: if displaying raises IOError, retry once as JPEG.

  Returns:
    Whatever `IPython.display.display` returns for the encoded image.
  """
  a = np.asarray(a, dtype=np.uint8)
  if six.PY3:
    str_file = six.BytesIO()
  else:
    str_file = six.StringIO()
  PIL.Image.fromarray(a).save(str_file, format)
  png_data = str_file.getvalue()
  try:
    disp = display(IPython.display.Image(png_data))
  except IOError:
    if jpeg_fallback and format != 'jpeg':
      # `.format` must apply to the message string, not to print()'s return
      # value (None under Python 3).
      print('Warning: image was too large to display in format "{}"; '
            'trying jpeg instead.'.format(format))
      return imshow(a, format='jpeg')
    else:
      raise
  return disp


In [3]:
class Generator(object):
  """Wraps a TF-Hub GAN generator loaded into its own TF1 graph and session.

  The module is loaded with `hub.load` and its "default" signature is called
  once on `tf.zeros` stand-in tensors built from the signature's input specs.
  Real values are substituted for those tensors through `feed_dict` in
  `get_samples`. Because the stand-ins have a static shape, every batch fed to
  `get_samples` presumably has to match the signature's batch size.
  """

  def __init__(self, module_spec):
    """Load the generator from `module_spec` (a TF-Hub handle or path)."""
    self._module_spec = module_spec
    self._sess = None
    self._graph = tf.Graph()
    self._load_model()

  @property
  def z_dim(self):
    """Size of the latent vector (last dimension of the `z` input)."""
    return self._z.shape[-1].value

  @property
  def conditional(self):
    """True when the signature has a `labels` input."""
    return self._labels is not None

  def _load_model(self):
    """Load the module and build the sampling graph inside `self._graph`."""
    with self._graph.as_default():
      self._generator = hub.load(self._module_spec)
      signature = self._generator.signatures["default"]
      # structured_input_signature is (args, kwargs); inputs are keyword specs.
      input_signature = signature.structured_input_signature[1]
      # Stand-in tensors with the expected shapes/dtypes. They are overridden
      # via feed_dict at sampling time.
      inputs = {
          key: tf.zeros(shape=spec.shape, dtype=spec.dtype)
          for key, spec in input_signature.items()
      }
      self._samples = signature(**inputs)
      self._z = inputs.get("z")
      self._labels = inputs.get("labels")

  def _init_session(self):
    """Create the session and run the variable initializers once."""
    if self._sess is None:
      self._sess = tf.Session(graph=self._graph)
      self._sess.run(tf.global_variables_initializer())

  def get_noise(self, num_samples, seed=None):
    """Draw standard-normal latent vectors.

    Args:
      num_samples: number of vectors to draw (ignored when `seed` is a sequence).
      seed: None (use the current global NumPy RNG state), a scalar seed for
        the whole batch, or a sequence with one seed per vector.

    Returns:
      Array of shape (num_samples, z_dim), or (len(seed), z_dim) for a seed
      sequence. A None or scalar seed yields float64; a seed sequence yields
      float32.
    """
    if seed is None:
      return np.random.normal(size=[num_samples, self.z_dim])
    if np.isscalar(seed):
      np.random.seed(seed)
      return np.random.normal(size=[num_samples, self.z_dim])
    z = np.empty(shape=(len(seed), self.z_dim), dtype=np.float32)
    for i, s in enumerate(seed):
      np.random.seed(s)
      z[i] = np.random.normal(size=[self.z_dim])
    return z

  @staticmethod
  def _select_output(outputs):
    """Pick the image array out of the signature's fetched outputs.

    The "generated" key is preferred. If the output dict has a single entry,
    that entry is used. Otherwise a KeyError lists the available keys.
    """
    if not isinstance(outputs, dict):
      return outputs
    if "generated" in outputs:
      return outputs["generated"]
    if len(outputs) == 1:
      return next(iter(outputs.values()))
    raise KeyError("Cannot tell which generator output holds the images; "
                   "available keys: {}".format(sorted(outputs)))

  def get_samples(self, z, labels=None):
    """Run the generator and return uint8 images.

    Args:
      z: latent batch (e.g. from `get_noise`).
      labels: one integer class id per row of `z`; required if conditional.

    Returns:
      uint8 image array. The raw outputs (presumably in [0, 1]) are scaled by
      256 and clipped to [0, 255].
    """
    with self._graph.as_default():
      self._init_session()

      feed_dict = {self._z: z}
      if self.conditional:
        assert labels is not None
        assert labels.shape[0] == z.shape[0]
        feed_dict[self._labels] = labels

      # Run for every model, conditional or not.
      outputs = self._sess.run(self._samples, feed_dict=feed_dict)
      generated_samples = self._select_output(outputs)
      return np.uint8(np.clip(256 * generated_samples, 0, 255))

## Select a model

In [None]:
# @title Select a model { run: "auto" }

model_name = "S3GAN 128x128 20% labels (FID 6.9, IS 98.1)"  # @param ["S3GAN 256x256 10% labels (FID 8.8, IS 130.7)", "S3GAN 128x128 2.5% labels (FID 12.6, IS 48.7)", "S3GAN 128x128 5% labels (FID 8.4, IS 74.0)", "S3GAN 128x128 10% labels (FID 7.6, IS 90.3)", "S3GAN 128x128 20% labels (FID 6.9, IS 98.1)"]

# Display name (without the metrics suffix) -> TF-Hub handle.
models = {
    "S3GAN 256x256 10% labels": "https://tfhub.dev/google/compare_gan/s3gan_10_256x256/1",
    "S3GAN 128x128 2.5% labels": "https://tfhub.dev/google/compare_gan/s3gan_2_5_128x128/1",
    "S3GAN 128x128 5% labels": "https://tfhub.dev/google/compare_gan/s3gan_5_128x128/1",
    "S3GAN 128x128 10% labels": "https://tfhub.dev/google/compare_gan/s3gan_10_128x128/1",
    "S3GAN 128x128 20% labels": "https://tfhub.dev/google/compare_gan/s3gan_20_128x128/1",
}

# Drop the " (FID ..., IS ...)" suffix to obtain the dictionary key.
model_key = model_name.partition(" (")[0]
module_spec = models[model_key]
print("Module spec:", module_spec)

tf.reset_default_graph()
print("Loading model...")
sampler = Generator(module_spec)
print("Model loaded.")

## Sample

In [None]:
# @title Sampling { run: "auto" }

num_rows = 2  # grid rows to display
num_cols = 3  # grid columns to display
noise_seed = 23  # seed for the latent noise
label_str = "980) volcano"  # "<class id>) <name>"; use -1 for random labels
num_classes = 1000  # ImageNet classes; needed when label == -1

# The generator's stand-in input tensors have a fixed batch size (apparently
# 64 for this module), so always generate a full batch and show only the
# first num_rows * num_cols images.
num_samples = 64
num_display = num_rows * num_cols
assert num_display <= num_samples, "grid is larger than one generator batch"
z = sampler.get_noise(num_samples, seed=noise_seed)

label = int(label_str.split(')')[0])
if label == -1:
    # Pick a random class per sample.
    labels = np.random.randint(0, num_classes, size=(num_samples))
else:
    # Same class for every sample.
    labels = np.asarray([label] * num_samples)

samples = sampler.get_samples(z, labels)
imshow(imgrid(samples[:num_display], cols=num_cols))


In [None]:
# @title Interpolation { run: "auto" }

noise_seed_A = 11  # seed for the start latent z_A
noise_seed_B = 0  # seed for the end latent z_B
num_samples = 1  # number of independent interpolation paths
num_interps = 6  # frames per path, endpoints included
label_str = "1) goldfish, Carassius auratus"  # "<class id>) <name>"


def interpolate(A, B, num_interps):
  """Blend A toward B in `num_interps` evenly spaced steps (endpoints included).

  Each blend (1 - w) * A + w * B is divided by sqrt(w**2 + (1 - w)**2). For
  independent standard-normal A and B, that division restores unit variance.

  Raises:
    ValueError: if A and B have different shapes.

  Returns:
    Array of shape (num_interps,) + A.shape.
  """
  weights = np.linspace(0, 1, num_interps)
  if A.shape != B.shape:
    raise ValueError('A and B must have the same shape to interpolate.')
  blended = []
  for w in weights:
    scale = np.sqrt(w**2 + (1-w)**2)
    blended.append(((1-w)*A + w*B) / scale)
  return np.array(blended)


def interpolate_and_shape(A, B, num_interps):
  """Interpolate A -> B and flatten to one row per generator input.

  Args:
    A, B: arrays of shape (num_paths, ...) with identical shapes.
    num_interps: steps per path.

  Returns:
    Array of shape (num_paths * num_interps, prod(A.shape[1:])), ordered
    path-major: rows [i*num_interps, (i+1)*num_interps) are the steps A[i] -> B[i].
  """
  interps = interpolate(A, B, num_interps)  # (num_interps, num_paths, ...)
  # Take the path count from A rather than from the notebook-global num_samples.
  num_paths = A.shape[0]
  return np.swapaxes(interps, 0, 1).reshape(num_paths * num_interps, -1)

# Every interpolation frame uses the same class.
label = int(label_str.split(')')[0])
labels = np.asarray([label] * num_samples * num_interps)

z_A = sampler.get_noise(num_samples, seed=noise_seed_A)
z_B = sampler.get_noise(num_samples, seed=noise_seed_B)
z = interpolate_and_shape(z_A, z_B, num_interps)

# The generator's stand-in inputs have a fixed batch size (apparently 64 for
# this module), so zero-pad the real rows up to a full batch. The padded rows
# produce throwaway images that are sliced off before display.
generator_batch_size = 64
num_real = z.shape[0]
assert num_real <= generator_batch_size, (
    "num_samples * num_interps must not exceed the generator batch size")

padded_z = np.zeros((generator_batch_size, z.shape[1]), dtype=z.dtype)
padded_z[:num_real] = z

padded_labels = np.zeros(generator_batch_size, dtype=labels.dtype)
padded_labels[:num_real] = labels

samples = sampler.get_samples(padded_z, padded_labels)
imshow(imgrid(samples[:num_real], cols=num_interps))

## Discriminator

In [None]:
class Discriminator(object):
  """Wraps the discriminator of a TF-Hub GAN module in its own TF1 graph/session.

  The module is loaded with the {"disc", "bsNone"} tags. One placeholder per
  input of the "default" signature is created, so `predict` can feed batches
  (assumes "bsNone" means an unspecified batch dimension — TODO confirm).
  """

  def __init__(self, module_spec):
    """Load the discriminator from `module_spec` (a TF-Hub handle or path)."""
    self._module_spec = module_spec
    self._sess = None
    self._graph = tf.Graph()
    self._load_model()

  @property
  def conditional(self):
    """True when the signature has a `labels` input."""
    return "labels" in self._inputs

  @property
  def image_shape(self):
    """Per-image input shape (the `images` placeholder shape without batch)."""
    return self._inputs["images"].shape.as_list()[1:]

  def _load_model(self):
    """Load the module and wire placeholders into its default signature."""
    with self._graph.as_default():
      self._discriminator = hub.load(self._module_spec, tags={"disc", "bsNone"})
      signature = self._discriminator.signatures["default"]
      # structured_input_signature is (args, kwargs); inputs are keyword specs.
      input_specs = signature.structured_input_signature[1]
      self._inputs = {
          name: tf.placeholder(spec.dtype, spec.shape.as_list(), name)
          for name, spec in input_specs.items()
      }
      self._outputs = signature(**self._inputs)

  def _init_session(self):
    """Create the session and run the variable initializers once."""
    if self._sess is None:
      self._sess = tf.Session(graph=self._graph)
      self._sess.run(tf.global_variables_initializer())

  def predict(self, images, labels=None):
    """Run the discriminator on a batch.

    Args:
      images: batch whose per-image shape matches `image_shape`.
      labels: one class id per image; required when `conditional`.

    Returns:
      The fetched signature outputs (a dict of arrays).
    """
    with self._graph.as_default():
      self._init_session()
      feed_dict = {self._inputs["images"]: images}
      if self.conditional:
        assert labels is not None
        assert labels.shape[0] == images.shape[0]
        feed_dict[self._inputs["labels"]] = labels
      return self._sess.run(self._outputs, feed_dict=feed_dict)

In [None]:
disc = Discriminator(module_spec)

# Random inputs just to exercise predict(): images in [0, 1) with the
# discriminator's expected per-image shape, and random class ids.
batch_size = 4
num_classes = 1000
images = np.random.random(size=(batch_size, *disc.image_shape))
labels = np.random.randint(low=0, high=num_classes, size=batch_size)

disc.predict(images, labels=labels)


Credit: adapted from the TensorFlow Hub example notebook https://github.com/tensorflow/hub/blob/master/examples/colab/s3gan_generation_with_tf_hub.ipynb