In [None]:
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras.datasets import mnist
import tensorflow.keras.backend as K
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape, Input, BatchNormalization, Dropout
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, LeakyReLU
from tensorflow import keras

In [None]:
# Load every image found under straw/ as 48x48 RGB, in shuffled batches of 100.
# Each batch is an (images, labels) pair; the labels are discarded in the next cell.
dataset = keras.utils.image_dataset_from_directory(
    'straw/',
    image_size=(48, 48),
    color_mode='rgb',
    shuffle=True,
    batch_size=100,
)

In [None]:
# Materialise the whole directory dataset as a single in-memory array of images,
# keeping only the image part of each (images, labels) batch.
images_list = [image_batch for image_batch, _ in dataset]

images = np.concatenate(images_list, axis=0)


In [None]:
# Training batch size; BUFFER_SIZE starts out as the number of loaded images.
BATCH_SIZE = 100
BUFFER_SIZE = len(images)

In [None]:
# Drop the trailing partial batch so the image count is a whole multiple of BATCH_SIZE.
BUFFER_SIZE -= BUFFER_SIZE % BATCH_SIZE
images = images[:BUFFER_SIZE]

In [None]:
np.shape(images)  # sanity check: expected (num_images, 48, 48, 3) given the loader settings

In [None]:
# Scale pixel values from [0, 255] to [0, 1] to match the generator's sigmoid output.
# Guarded so that re-running this cell does not divide by 255 a second time
# (an already-scaled array has no value above 1). `initial=0` keeps an empty array safe.
if images.max(initial=0) > 1:
    images = images / 255

In [None]:
# In-memory tf.data pipeline: shuffle with a buffer as large as the whole array,
# then split into training batches.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(images)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

In [None]:
# Dimensionality of the generator's latent (noise) vector. Kept at 2 so the
# latent space can be visualised as a 2-D grid in the plotting cell further down.
hidden_dim = 2

In [None]:
# Generator: maps a `hidden_dim`-dimensional latent vector to a 48x48 RGB image
# with values in [0, 1] (sigmoid output), matching the scaled training images.
# Spatial size: 6x6 -> 6x6 -> 12x12 -> 24x24 -> 48x48.
#
# Hidden layers follow the DCGAN pattern "linear layer -> BatchNormalization -> LeakyReLU".
# Previously each hidden layer also had activation='relu', which discarded every
# negative pre-activation *before* normalisation and stacked two non-linearities.
# Removing it changes no weight shapes, but a model trained with the old graph
# will produce different outputs.
generator = tf.keras.Sequential([
  Input(shape=(hidden_dim,)),

  Dense(6 * 6 * 512),  # reshaped to 6x6x512 below
  BatchNormalization(),
  LeakyReLU(),

  Reshape((6, 6, 512)),

  Conv2DTranspose(256, (5, 5), strides=(1, 1), padding='same'),  # 6x6
  BatchNormalization(),
  LeakyReLU(),

  Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same'),  # 12x12
  BatchNormalization(),
  LeakyReLU(),

  Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same'),  # 24x24
  BatchNormalization(),
  LeakyReLU(),

  Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', activation='sigmoid'),  # 48x48x3
])

In [None]:
# Discriminator: 48x48 RGB image -> a single raw logit (the final Dense(1) has no
# activation). Three identical stride-2 conv blocks (Conv2D -> LeakyReLU -> Dropout)
# halve the resolution each time: 48 -> 24 -> 12 -> 6.
discriminator = tf.keras.Sequential(
    [Input(shape=(48, 48, 3))]
    + [
        layer
        for filters in (64, 128, 256)
        for layer in (
            Conv2D(filters, (5, 5), strides=(2, 2), padding='same'),
            LeakyReLU(),
            Dropout(0.3),
        )
    ]
    + [Flatten(), Dense(1)]
)

In [None]:
# Shared binary cross-entropy for both networks. The discriminator ends in a plain
# Dense(1), so its outputs are raw logits, hence from_logits=True.
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def generator_loss(fake_output):
  """Generator loss: push the discriminator to label generated images as real (1).

  Args:
    fake_output: discriminator logits for generated images.

  Returns:
    Binary cross-entropy between all-ones targets and `fake_output`.
  """
  return cross_entropy(tf.ones_like(fake_output), fake_output)

In [None]:
def discriminator_loss(real_output, fake_output):
  """Discriminator loss: real images should be scored 1, generated images 0.

  Args:
    real_output: discriminator logits for real images.
    fake_output: discriminator logits for generated images.

  Returns:
    Sum of the cross-entropy on the real batch and on the generated batch.
  """
  targets_real = tf.ones_like(real_output)
  targets_fake = tf.zeros_like(fake_output)
  return (cross_entropy(targets_real, real_output)
          + cross_entropy(targets_fake, fake_output))

In [None]:
# One Adam optimizer per network (learning rate 1e-4 each) so their internal
# state is kept separate.
generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.Adam(learning_rate=1e-4),
    tf.keras.optimizers.Adam(learning_rate=1e-4),
)

In [None]:
@tf.function
def train_step(images):
  """Run one simultaneous update of the generator and the discriminator.

  Args:
    images: one batch of real images from `train_dataset` (presumably
      (batch, 48, 48, 3) floats in [0, 1], like the generator output — confirm).

  Returns:
    (gen_loss, disc_loss) for this batch.
  """
  # One latent vector per real image. Using the real batch size instead of the
  # BATCH_SIZE constant keeps the real and fake batches the same size even when
  # the last batch is shorter.
  batch_size = tf.shape(images)[0]
  noise = tf.random.normal([batch_size, hidden_dim])

  # Two tapes: each network's loss is differentiated w.r.t. its own weights only.
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    generated_images = generator(noise, training=True)

    # Discriminator logits (not gradients) for the real and the generated batch.
    real_output = discriminator(images, training=True)
    fake_output = discriminator(generated_images, training=True)

    gen_loss = generator_loss(fake_output)
    disc_loss = discriminator_loss(real_output, fake_output)

  gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
  gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

  return gen_loss, disc_loss

In [None]:
# Preview one image from the generator before training starts.
# training=False: in training mode BatchNormalization also updates its moving
# statistics, so a training=True "preview" would silently change the model.
noise = tf.random.normal([BATCH_SIZE, hidden_dim])
generated_images = generator(noise, training=False)
plt.imshow(generated_images[0])  # index 0 exists for any BATCH_SIZE >= 1
plt.show()

In [None]:
import time
def train(dataset, epochs):
  """Train the GAN for `epochs` passes over `dataset`.

  Every batch is handed to `train_step`, which updates both networks. For each
  epoch a row of '=' progress marks, the mean generator/discriminator losses and
  the epoch duration are printed.

  Args:
    dataset: iterable of real-image batches (e.g. `train_dataset`).
    epochs: number of passes over `dataset`.

  Returns:
    List with the mean generator loss of each epoch.

  Raises:
    ValueError: if `dataset` yields no batches.
  """
  history = []
  MAX_PRINT_LABEL = 10
  # Expected batches per epoch (rounded up); only used to space out the progress
  # marks. Integer arithmetic matters: a float threshold makes `n % th == 0`
  # true at n == 0 and almost never afterwards.
  num_batches = -(-BUFFER_SIZE // BATCH_SIZE)
  print_every = max(1, num_batches // MAX_PRINT_LABEL)

  for epoch in range(1, epochs + 1):
    # Use the `epochs` argument rather than a global, so train() works standalone.
    print(f'{epoch}/{epochs}: ', end='')

    start = time.time()
    n = 0

    gen_loss_epoch = 0
    disc_loss_epoch = 0
    for image_batch in dataset:
      gen_loss, disc_loss = train_step(image_batch)
      # The losses returned by train_step are already reduced scalars, so they are
      # accumulated directly (no keras.backend.mean, which newer Keras no longer exports).
      gen_loss_epoch += gen_loss
      disc_loss_epoch += disc_loss
      if n % print_every == 0: print('=', end='')
      n += 1

    if n == 0:
      raise ValueError('dataset yielded no batches')

    history += [gen_loss_epoch / n]
    print(f': gen_loss={float(history[-1]):.4f} disc_loss={float(disc_loss_epoch / n):.4f}')
    print('Время эпохи {} составляет {} секунд'.format(epoch, time.time() - start))

  return history

In [None]:
# Full training run -- expensive: 3000 passes over the whole dataset.
EPOCHS = 3000
history = train(train_dataset, epochs=EPOCHS)


In [None]:
# Decode a (2n+1) x (2n+1) grid of latent points covering [-0.5, 0.5] on both
# latent axes and show the generated images in the same layout.
n = 2
total = 2 * n + 1

# The grid is laid out in 2-D, so it only makes sense for a 2-D latent space.
assert hidden_dim == 2, 'latent-grid plot assumes hidden_dim == 2'

# Build all latent points up front (row-major: i selects the row, j the column)
# and decode them with one predict call instead of one call -- and one progress
# bar -- per grid point.
latent_grid = np.array(
    [[0.5 * i / n, 0.5 * j / n] for i in range(-n, n + 1) for j in range(-n, n + 1)],
    dtype=np.float32,
)
grid_images = generator.predict(latent_grid, verbose=0)

plt.figure(figsize=(total, total))
for num, img in enumerate(grid_images, start=1):
  ax = plt.subplot(total, total, num)
  plt.imshow(img)
  ax.get_xaxis().set_visible(False)
  ax.get_yaxis().set_visible(False)

plt.show()

In [None]:
# NOTE(review): the file name says "50x50" (the model outputs 48x48 images) and
# "strawnerry" (presumably "strawberry"); left as-is so existing loaders keep working.
generator.save(filepath='50x50_strawnerry_GAN_3000_EPOCHS_generator2.keras')

In [None]:
# NOTE(review): same naming caveats as the generator file ("50x50" vs 48x48 inputs,
# "strawnerry"); path kept unchanged for compatibility.
discriminator.save(filepath='50x50_strawnerry_GAN_3000_EPOCHS_discriminator2.keras')