<a href="https://colab.research.google.com/github/weathon/3d2smile/blob/main/Untitled84.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# https://www.tensorflow.org/tutorials/text/image_captioning

In [2]:
import tensorflow as tf
import numpy as np
from PIL import Image
import pylab

In [93]:
# Filter widths. CHANNELS[0] is used by the 7x7 stem conv of ImageEncoder;
# CHANNELS[i + 1] is the first-conv width of residual block i, whose output
# width is CHANNELS[i + 1] * 2. Must hold at least RES_NUMBER + 1 entries.
CHANNELS = [
    64,
    64,
    128,
    256,
    512,
]
# How many downsampling ResNetBlocks ImageEncoder stacks.
RES_NUMBER = 4


class ResNetBlock(tf.keras.layers.Layer):
  """Downsampling residual block with a projected (1x1 conv) shortcut.

  Main path: 3x3 conv (stride 2) -> BatchNorm -> ReLU -> 3x3 conv -> BatchNorm.
  Shortcut: 1x1 conv with stride 2 so its spatial size and channel count
  (`channel2`) match the main path. The two are summed and passed through a
  final ReLU.

  Args:
    channel1: number of filters in the first 3x3 conv.
    channel2: number of filters in the second 3x3 conv and in the shortcut.
    name: Keras layer name.
  """

  def __init__(self, channel1, channel2, name):
    super().__init__(name=name)
    main_path_layers = [
        tf.keras.layers.Conv2D(channel1, (3, 3), padding="same", strides=2),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Conv2D(channel2, (3, 3), padding="same"),
        tf.keras.layers.BatchNormalization(),
    ]
    # Attribute names are left untouched: Keras tracks sublayers by them.
    self.Conv2D = tf.keras.Sequential(main_path_layers)
    self.project = tf.keras.layers.Conv2D(
        channel2, (1, 1), name="image_projection", strides=2)
    self.relu = tf.keras.layers.ReLU()

  def call(self, images):
    main_out = self.Conv2D(images)
    shortcut_out = self.project(images)
    return self.relu(main_out + shortcut_out)

class ImageEncoder(tf.keras.Model):
  """ResNet-style image encoder that returns self-attended spatial features.

  Pipeline: 7x7 stride-2 conv -> 3x3 stride-2 max-pool -> RES_NUMBER
  downsampling ResNetBlocks -> flatten the spatial grid into a sequence ->
  dot-product self-attention over that sequence.

  For a (batch, 400, 400, 3) input the output is (batch, 49, 1024): the grid
  shrinks 400 -> 200 -> 99 -> 50 -> 25 -> 13 -> 7 (7 * 7 = 49 positions), and
  the last block emits CHANNELS[RES_NUMBER] * 2 = 1024 channels. Other input
  resolutions give a different sequence length.
  """

  def __init__(self):
    super().__init__()
    self.conv2D_7 = tf.keras.layers.Conv2D(
        CHANNELS[0], (7, 7), padding="same", strides=2)
    self.pooling = tf.keras.layers.MaxPooling2D(pool_size=(3, 3), strides=2)
    # Block i widens to CHANNELS[i + 1] * 2 channels and halves the grid.
    self.ResBlocks = [
        ResNetBlock(CHANNELS[i + 1], CHANNELS[i + 1] * 2, name=f"resblock_{i}")
        for i in range(RES_NUMBER)
    ]
    # Collapse (H, W) into one sequence axis. The length is inferred (-1), so
    # any input resolution works; a hard-coded (49, 1024) would only fit
    # 400x400 inputs. The channel count is the last block's output width.
    self.flatten = tf.keras.layers.Reshape((-1, CHANNELS[RES_NUMBER] * 2))
    self.attention = tf.keras.layers.Attention()

  def call(self, images):
    """Encode a batch of images into a self-attended feature sequence.

    Args:
      images: image batch, assumed channels-last (batch, H, W, 3); the demo
        cells use 400x400x3.

    Returns:
      Tensor of shape (batch, seq_len, CHANNELS[RES_NUMBER] * 2), where
      seq_len is the number of spatial positions left after downsampling.
    """
    images = self.conv2D_7(images)
    images = self.pooling(images)
    for block in self.ResBlocks:
      images = block(images)
    images = self.flatten(images)
    # Query, value and key are the same sequence -> self-attention.
    return self.attention([images, images, images])

  # To see per-layer output shapes in summary(), wrap call() in a functional
  # tf.keras.Model built on a tf.keras.layers.Input — see
  # https://stackoverflow.com/questions/55235212/model-summary-cant-print-output-shape-while-using-subclass-model

In [None]:
# Scratch check of tf.keras.layers.Reshape. The nested list below has shape
# (1, 2, 1, 2, 3); Reshape's target shape excludes the leading batch axis, so
# the 12 values per sample become (1, 4, 3).
demo_array = np.array([
    [[1, 2, 3], [3, 4, 5]],
    [[5, 6, 8], [7, 8, 50]],
])
print(demo_array.shape)
reshape_input = [[
    [[[1, 2, 3], [3, 4, 5]]],
    [[[5, 6, 8], [7, 8, 50]]],
]]
tf.keras.layers.Reshape((4, 3))(reshape_input)


In [None]:
# Wrap the subclassed encoder in a functional Model so summary() and
# plot_model() can report output shapes for a 400x400 RGB input. call() is
# invoked directly on a symbolic Input (rather than encoder(...)) so the
# encoder's sublayers are recorded as individual nodes in the graph.
encoder = ImageEncoder()
inputs = tf.keras.layers.Input(shape=(400, 400, 3))  # avoid shadowing builtin input()
model = tf.keras.Model(inputs=[inputs], outputs=encoder.call(inputs))
model.summary(expand_nested=True)
tf.keras.utils.plot_model(model, expand_nested=True)

In [None]:
# Standalone shape check of a single ResNetBlock (8 -> 8 filters) on a
# 400x400 RGB input, wrapped in a functional Model for summary()/plot_model().
block = ResNetBlock(8, 8, name="block9")
inputs = tf.keras.layers.Input(shape=(400, 400, 3))  # avoid shadowing builtin input()
model = tf.keras.Model(inputs=[inputs], outputs=block.call(inputs))
model.summary(expand_nested=True)
tf.keras.utils.plot_model(model, expand_nested=True)