In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

In [None]:
# Load only the 'train' split of EuroSAT; it is carved into training and
# validation subsets further down with take()/skip(). with_info=True also
# returns the DatasetInfo object whose feature spec is printed below.
data, info = tfds.load(
    'eurosat',
    split='train',
    with_info=True,
)

print(f'\nFeatures: {info.features}')
print(f"Loaded examples: {len(data)}")

In [None]:
# Display the structure (keys, shapes, dtypes) of a single dataset element.
data.element_spec

In [None]:
def pre_process(data_element, noise_stddev=0.15):
    """Turn one dataset element into a (noisy_input, clean_target) training pair.

    Args:
        data_element: one element of the EuroSAT dataset; only its 'image'
            entry is read (presumably a uint8 HxWx3 tensor -- TODO confirm
            against `data.element_spec`).
        noise_stddev: standard deviation of the additive Gaussian noise,
            expressed on the rescaled intensity scale. Defaults to 0.15.

    Returns:
        (noisy_image, image): two float32 tensors of the same shape. `image`
        is the input divided by 255 (i.e. in [0, 1] for uint8 input);
        `noisy_image` is `image` plus Gaussian noise, clipped to [0, 1].
    """
    # Rescale integer pixel values to floats by dividing by 255.
    image = tf.cast(data_element['image'], tf.float32) / 255.0
    noise = tf.random.normal(shape=tf.shape(image), stddev=noise_stddev, dtype=tf.float32)
    # Clip so the corrupted input stays in the same [0, 1] range as the target.
    noisy_image = tf.clip_by_value(image + noise, clip_value_min=0.0, clip_value_max=1.0)

    return noisy_image, image


In [None]:
# Split sizes: the first N_TRAIN examples train the model, the next N_TEST
# validate it. Named once so the take()/skip() boundaries cannot drift apart.
N_TRAIN = 25000
N_TEST = 2000

# Training pipeline. The cache holds the raw decoded elements and sits BEFORE
# the stochastic pre_process map. As a result, every epoch draws fresh
# Gaussian noise instead of replaying the noise frozen into the cache on
# epoch 1. The cache also stores compact uint8 images rather than two float32
# copies per example.
train_data = (
    data.take(N_TRAIN)
    .cache()
    .shuffle(buffer_size=2500)
    .map(pre_process, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)
print(train_data.element_spec)

# Validation pipeline. Here caching AFTER pre_process is deliberate: the noise
# is drawn once and then reused, so val_loss is measured on the same noisy
# inputs every epoch and stays comparable across epochs.
test_data = (
    data.skip(N_TRAIN)
    .take(N_TEST)
    .map(pre_process, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)
print(test_data.element_spec)

In [None]:
# Visual sanity check of the first validation batch: each tile shows the noisy
# model input next to its clean target.
noisy_batch, orig_batch = next(iter(test_data))

# 8 x 4 = 32 tiles, one per image in a validation batch of 32; zip() simply
# stops early if a batch is smaller.
fig, axs = plt.subplots(8, 4, figsize=(8, 8))

for ax, noisy, orig in zip(axs.flat, noisy_batch, orig_batch):
    # axis=1 is the width axis for HxWxC images (the model input is (64, 64, 3)),
    # so the noisy and clean images are placed side by side.
    combined = tf.concat([noisy, orig], axis=1)
    ax.imshow(combined)
    ax.axis("off")

fig.suptitle("Noisy input (left) vs. clean target (right)")
# Figure.set_tight_layout is discouraged since Matplotlib 3.6; apply the tight
# layout once instead.
fig.tight_layout()

In [None]:
# Encoder: three Conv2D -> MaxPooling2D -> Dropout(0.2) stages. With the Keras
# default pool size of 2, every stage halves height and width. The filter
# count grows 32 -> 64 -> 128, and the last stage uses a 1x1 kernel.
encoder = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu', kernel_initializer='he_normal'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Dropout(0.2),

    tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu', kernel_initializer='he_normal'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Dropout(0.2),

    tf.keras.layers.Conv2D(128, 1, padding='same', activation='relu', kernel_initializer='he_normal'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Dropout(0.2),
])

In [None]:
# Decoder: mirrors the encoder with three UpSampling2D -> Conv2D -> Dropout(0.2)
# stages. With the Keras default upsampling size of 2, every stage doubles
# height and width. The filter count shrinks 128 -> 64 -> 32, and the first
# stage uses a 1x1 kernel.
decoder = tf.keras.Sequential([
    tf.keras.layers.UpSampling2D(),
    tf.keras.layers.Conv2D(128, 1, padding='same', activation='relu', kernel_initializer='he_normal'),
    tf.keras.layers.Dropout(0.2),

    tf.keras.layers.UpSampling2D(),
    tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu', kernel_initializer='he_normal'),
    tf.keras.layers.Dropout(0.2),

    tf.keras.layers.UpSampling2D(),
    tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu', kernel_initializer='he_normal'),
    tf.keras.layers.Dropout(0.2),
])

In [None]:
# Full denoising autoencoder. The pipeline is:
#   64x64x3 input -> encoder -> Dense(32) bottleneck -> decoder
#   -> 1x1 Conv2D back to 3 channels.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Input(shape=(64, 64, 3)))
model.add(encoder)
# On a 4-D feature map, Dense acts only on the last (channel) axis, so this is
# a linear per-pixel projection from the encoder's 128 channels down to 32.
model.add(tf.keras.layers.Dense(units=32))
model.add(decoder)
# The sigmoid output keeps reconstructions in [0, 1], the same range as the
# clean targets (image / 255). Without it, predictions can leave the valid
# pixel range.
model.add(tf.keras.layers.Conv2D(filters=3, kernel_size=1, padding='same', activation='sigmoid'))


In [None]:
# Print the layer table, expanding the nested encoder/decoder Sequentials so
# their inner layers are listed individually.
model.summary(line_length=None, expand_nested=True)

In [None]:
# Compile with Adam and mean absolute error on pixel values.
# `metrics` is documented as a list; Keras 3 raises if it is given a bare
# Metric object. MAPE is logged in history.history under the key
# 'mean_absolute_percentage_error' (and 'val_...' for validation).
# NOTE(review): MAPE divides by |y_true| (clamped to a small epsilon), so it can
# explode on near-black pixels. Consider MSE/PSNR if the numbers look odd.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.MeanAbsoluteError(),
    metrics=[tf.keras.metrics.MeanAbsolutePercentageError()],
)

In [None]:
# Train for 35 epochs. test_data doubles as the validation set, so the 'val_*'
# entries in `history` are computed on those held-out examples.
history = model.fit(
    train_data,
    validation_data=test_data,
    epochs=35,
)

In [None]:
# Training curves. This is a reconstruction task and the only compiled metric
# is MeanAbsolutePercentageError, so there is no 'categorical_accuracy' key in
# history.history (the original plot raised KeyError). Plot loss and MAPE.
# matplotlib.pyplot is already imported as `plt` in the first cell.

def plot_history_curve(hist, key, title, ylabel):
    """Plot the training and validation series of `key` from a Keras History dict.

    Args:
        hist: `history.history`, mapping metric names to per-epoch lists.
        key: metric name; its validation counterpart is read from f'val_{key}'.
        title: axes title.
        ylabel: y-axis label, also used in the legend entries.

    Returns:
        The matplotlib Axes that was drawn on.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(hist[key], 'black', linewidth=2.0, label=f'Training {ylabel}')
    ax.plot(hist[f'val_{key}'], 'blue', linewidth=2.0, label=f'Validation {ylabel}')
    ax.legend(fontsize=14)
    ax.set_xlabel('Epochs', fontsize=10)
    ax.set_ylabel(ylabel, fontsize=10)
    ax.set_title(title, fontsize=12)
    return ax

plot_history_curve(history.history, 'loss', 'Loss Curves', 'Loss');
plot_history_curve(history.history, 'mean_absolute_percentage_error', 'MAPE Curves', 'MAPE (%)');