First, clone the repository:

In [None]:
!git clone https://github.com/vairodp/AstroNet.git

In [None]:
%cd AstroNet

Then install the missing libraries that our code requires:

In [None]:
!pip install tensorflow_addons
!pip install tensorflow-datasets==4.3.0
!pip install imgaug==0.4.0

Make sure you're using a GPU runtime to get fast training and inference.

In [None]:
import tensorflow as tf

# Fail fast when no GPU is visible, since training and inference rely on one for speed.
# The guidance is part of the raised error itself: this is the first check in the
# notebook, so any message printed by a later cell would never be reached.
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError(
      'GPU device not found. This most likely means that this notebook is not '
      'configured to use a GPU. Change this in Notebook Settings via the '
      'command palette (cmd/ctrl-shift-P) or the Edit menu.')
print('Found GPU at: {}'.format(device_name))

In [None]:
# NOTE(review): this repeats the GPU check from the previous cell. Because that
# cell raises on the same condition first, the guidance printed here is only
# shown when the previous cell was skipped.
device_name = tf.test.gpu_device_name()
if not device_name == '/device:GPU:0':
  print(''.join([
      '\n\nThis error most likely means that this notebook is not ',
      'configured to use a GPU.  Change this in Notebook Settings via the ',
      'command palette (cmd/ctrl-shift-P) or the Edit menu.\n\n',
  ]))
  raise SystemError('GPU device not found')

Finally, move into the `src` directory and run the training code:

In [None]:
%cd src

In [None]:
import os

import tensorflow as tf

from unet import SourceSegmentation
from datasets.convo_ska import ConvoSKA
from callbacks.display_callback import DisplayCallback
from configs.train_config import ITER_PER_EPOCH, NUM_EPOCHS

# --- Configuration -----------------------------------------------------------
SEED = 42
INPUT_SHAPE = (128, 128, 1)   # input shape passed to SourceSegmentation
LEARNING_RATE = 0.003
CLIP_VALUE = 1.0              # Adam clips each gradient element to [-1, 1]
EARLY_STOP_PATIENCE = 10      # epochs without val_loss improvement before stopping

# Seed TensorFlow's global RNG for reproducible runs.
# NOTE(review): augmentations implemented with other libraries (e.g. imgaug,
# installed above) may use their own RNGs and are not covered by this seed.
tf.random.set_seed(SEED)

# --- Model -------------------------------------------------------------------
# Set use_class_weights=False to run the model without weights for the 4 classes
# Set tiny=True to run the smaller version of this model
unet = SourceSegmentation(INPUT_SHAPE, use_class_weights=True, tiny=False)
unet.model.summary()
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE, clipvalue=CLIP_VALUE)

# --- Callbacks ---------------------------------------------------------------
checkpoint_filepath = '../checkpoints/unet-best.h5'
# Create the checkpoint folder up front so saving the first best epoch cannot
# fail on a missing directory.
os.makedirs(os.path.dirname(checkpoint_filepath), exist_ok=True)

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    save_best_only=True)
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='../log')
# Watch the same metric as the checkpoint, and when training stops early roll
# the in-memory model back to its best epoch instead of keeping the last one.
early_stop_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=EARLY_STOP_PATIENCE,
    restore_best_weights=True)

# --- Compile -----------------------------------------------------------------
# NOTE(review): from_logits=True assumes the network's final layer does not
# apply a softmax — confirm against unet.py.
unet.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])

# --- Data --------------------------------------------------------------------
dataset_train = ConvoSKA(mode='train').get_dataset()
val_data = ConvoSKA(mode='validation').get_dataset()
display_callback = DisplayCallback(val_data)

# --- Train -------------------------------------------------------------------
# Keep the History object for later inspection/plotting instead of discarding it.
history = unet.fit(
    dataset_train,
    epochs=NUM_EPOCHS,
    steps_per_epoch=ITER_PER_EPOCH,
    validation_data=val_data,
    callbacks=[model_checkpoint_callback, display_callback,
               tensorboard_callback, early_stop_callback])
