In [None]:
!pip install keras-visualizer

In [None]:
!pip install astroNN

In [None]:
from keras_visualizer import visualizer

In [None]:
import numpy as np
from tensorflow.keras import utils
from tensorflow.keras import layers
from astroNN.datasets import load_galaxy10
import matplotlib.pyplot as plt
import tensorflow as tf
import pathlib
import numpy as np
from sklearn.model_selection import train_test_split
from astroNN.datasets.galaxy10 import galaxy10cls_lookup
from PIL import Image as im

In [None]:
# Load Galaxy10 via astroNN; the next cell unpacks it as (images, labels).
dataset = load_galaxy10()

In [None]:
# Unpack the raw dataset tuple. Later cells rebind `images`/`labels` to new arrays; `dataset` itself is not modified.
images, labels = dataset

In [None]:
# Preview one raw sample.
# Cast to uint8 so this cell still works after `images` is converted to float32 further down.
# PIL's fromarray rejects 3-channel float32 arrays.
# Assumes pixel values are in 0-255, as the float32 cast later keeps them.
print(images[0].shape)
first_image = im.fromarray(images[0].astype(np.uint8))
plt.axis("off")
plt.imshow(first_image);

In [None]:
# Remove low-count classes and re-index the remaining ones to 0..4.
# The cell rebuilds from the untouched `dataset` tuple on every run.
# Re-running it therefore cannot drop or re-map already re-indexed labels a second time.
KEPT_CLASSES = [0, 1, 2, 4, 7]  # original Galaxy10 ids; list position = new label id

raw_images, raw_labels = dataset
keep_mask = np.isin(raw_labels, KEPT_CLASSES)
images = raw_images[keep_mask]

# Lookup table: original id -> new id (0->0, 1->1, 2->2, 4->3, 7->4).
label_map = np.zeros(max(KEPT_CLASSES) + 1, dtype=raw_labels.dtype)
label_map[KEPT_CLASSES] = np.arange(len(KEPT_CLASSES))
labels = label_map[raw_labels[keep_mask].astype(np.intp)]

assert len(images) == len(labels)
assert set(np.unique(labels).tolist()) <= set(range(len(KEPT_CLASSES)))

In [None]:
# One-hot encode the 5 remaining classes and move labels and images to float32 for TensorFlow.
labels = utils.to_categorical(labels, num_classes=5).astype(np.float32)
images = np.asarray(images, dtype=np.float32)

In [None]:
# Sanity check: one row per sample, one column per class after one-hot encoding.
print(np.shape(labels))

In [None]:
# Datasets: stratified, seeded 80/20 split.
# The seed makes the split reproducible, including the val_img order used when inspecting predictions later.
SPLIT_SEED = 42
BATCH_SIZE = 128

train_idx, val_idx = train_test_split(
  np.arange(labels.shape[0]),
  test_size=0.2,
  random_state=SPLIT_SEED,
  # labels are one-hot here, so stratify on the class index to keep class ratios equal in both splits.
  stratify=labels.argmax(axis=1),
)
train_labels, val_labels = labels[train_idx], labels[val_idx]
val_img = images[val_idx]  # kept as a NumPy array for plotting predictions later
train_ds = tf.data.Dataset.from_tensor_slices((images[train_idx], train_labels)).batch(BATCH_SIZE)
val_ds = tf.data.Dataset.from_tensor_slices((val_img, val_labels)).batch(BATCH_SIZE)

In [None]:
# Human-readable names of the kept Galaxy10 classes, in new-label order (label i -> class_names[i]).
class_names = [galaxy10cls_lookup(orig_id) for orig_id in (0, 1, 2, 4, 7)]
print(class_names)
num_classes = len(class_names)

In [None]:
# Input pipeline: cache batches in memory and prefetch the next batch while the current one trains.
AUTOTUNE = tf.data.AUTOTUNE

# Reshuffle the cached training batches every epoch.
# The shuffle must come after cache(); otherwise the first epoch's order would be frozen into the cache.
# train_ds is already batched, so batch contents stay fixed and only their order changes.
# NOTE(review): buffer_size counts batches; 1024 is meant to cover all of them — confirm via train_ds.cardinality().
train_ds = (
  train_ds.cache()
  .shuffle(buffer_size=1024, reshuffle_each_iteration=True)
  .prefetch(buffer_size=AUTOTUNE)
)
# Validation stays unshuffled so model.predict(val_ds) lines up row-by-row with val_img.
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

In [None]:
# Data augmentation with the stable Keras preprocessing layers.
# The `layers.experimental.preprocessing` aliases are deprecated and missing from newer Keras releases.
# These random layers are only active during training.
data_augmentation = tf.keras.Sequential(
  [
    layers.RandomFlip(mode="horizontal"),
    layers.RandomContrast(factor=0.1),
    # factor is a fraction of a full turn: 0.3 -> random rotation within ±0.3 * 2π.
    # fill_value is omitted: it is only used when fill_mode="constant".
    layers.RandomRotation(0.3, fill_mode="nearest", interpolation="bilinear", seed=None),
  ]
)

In [None]:
def create_model(num_classes=5, input_shape=(69, 69, 3), learning_rate=0.00001):
  """Build and compile the Galaxy10 CNN classifier.

  Args:
    num_classes: width of the output layer; must equal the one-hot label width
      (5 after the low-count classes are removed).
    input_shape: (height, width, channels) of a single image.
    learning_rate: Adam learning rate.

  Returns:
    A compiled tf.keras.Sequential model with a softmax output and
    categorical cross-entropy loss.
  """
  model = tf.keras.Sequential([
    # Input must be the first entry so the whole model, augmentation included, knows its input shape.
    tf.keras.Input(shape=input_shape),
    data_augmentation,
    # NOTE(review): pixels arrive as unscaled 0-255 floats (the preprocessing cell only casts to float32).
    # Consider adding layers.Rescaling(1 / 255).
    layers.Conv2D(32, kernel_size=7, strides=2, padding="same"),
    layers.MaxPool2D(),
    layers.LeakyReLU(alpha=0.4),
    layers.Conv2D(32, kernel_size=7, strides=2, padding="same"),
    layers.MaxPool2D(),
    layers.LeakyReLU(alpha=0.4),
    layers.Conv2D(64, kernel_size=7, strides=2, padding="same"),
    layers.MaxPool2D(),
    layers.LeakyReLU(alpha=0.4),
    layers.Conv2D(128, kernel_size=7, strides=2, padding="same"),
    layers.LeakyReLU(alpha=0.4),
    layers.Dropout(0.2),
    layers.Flatten(),
    layers.Dense(64, activation="relu"),
    layers.Dense(32, activation="relu"),
    # categorical_crossentropy expects each output row to be a probability distribution over mutually exclusive classes.
    # softmax produces one; sigmoid does not.
    layers.Dense(num_classes, activation="softmax"),
  ])
  model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
  )
  return model


model = create_model()

In [None]:
# Train for up to `epochs` epochs.
# EarlyStopping keeps the weights from the epoch with the best val_accuracy.
# model.predict below then uses the model behind the best validation score taken from `history`,
# rather than whatever the final epoch produced.
# NOTE(review): older TF/Keras versions restore the best weights only when early stopping actually
# triggers — confirm for the installed version.
epochs = 400
early_stopping = tf.keras.callbacks.EarlyStopping(
  monitor="val_accuracy",
  patience=40,
  restore_best_weights=True,
)
history = model.fit(train_ds, validation_data=val_ds, epochs=epochs, callbacks=[early_stopping])

In [None]:
# Pull the per-epoch training/validation curves out of the Keras History object.
hist = history.history
y_loss, y_vloss = hist['loss'], hist['val_loss']
y_acc, y_vacc = hist['accuracy'], hist['val_accuracy']

In [None]:
# Best validation accuracy reached over all recorded epochs.
best_val_acc = max(y_vacc)
print(best_val_acc)

In [None]:
# Learning curves per epoch: validation in red, training in blue.
# Titles and legends make each panel readable on its own.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
panels = [
  (ax1, 'loss', y_loss, y_vloss),
  (ax2, 'accuracy', y_acc, y_vacc),
]
for ax, metric, train_curve, val_curve in panels:
  ax.plot(np.arange(len(val_curve)), val_curve, marker='.', c='red', label='validation')
  ax.plot(np.arange(len(train_curve)), train_curve, marker='.', c='blue', label='training')
  ax.grid()
  ax.set(xlabel='epoch', ylabel=metric, title=metric.capitalize())
  ax.legend()

fig.tight_layout()
plt.show()


In [None]:
# Predicted class index for every validation image.
# val_ds is built from the same rows as val_img, in the same order, and is never shuffled.
# Row i of preds therefore corresponds to val_img[i]; the assert guards that alignment.
preds = model.predict(val_ds)
cs = np.argmax(preds, axis=1)
assert len(cs) == len(val_img), "predictions are not aligned with val_img"
print(cs.shape)
print(val_img.shape)

In [None]:
# Show validation images the model assigned to one class (new label 4; its name is class_names[4]).
# Each image gets its own subplot; repeated plt.imshow calls would only leave the last one visible.
TARGET_CLASS = 4
MAX_SHOWN = 16  # cap the grid size; raise to see more

target_idx = np.flatnonzero(cs == TARGET_CLASS)[:MAX_SHOWN]
n_shown = len(target_idx)
n_cols = min(4, max(n_shown, 1))
n_rows = max(1, -(-n_shown // n_cols))  # ceiling division

fig, axes = plt.subplots(n_rows, n_cols, figsize=(2.5 * n_cols, 2.5 * n_rows), squeeze=False)
for ax in axes.flat:
  ax.axis("off")
for ax, idx in zip(axes.flat, target_idx):
  # val_img holds float32 pixels in 0-255.
  # imshow clips floats to [0, 1], so convert back to uint8 for display.
  ax.imshow(val_img[idx].astype(np.uint8))
  ax.set_title(f"#{idx}", fontsize=8)
fig.suptitle(f"Predicted: {class_names[TARGET_CLASS]} ({n_shown} shown)")
plt.show()