First, clone the repository:

In [1]:
!git clone https://github.com/vairodp/AstroNet.git

Cloning into 'AstroNet'...
remote: Enumerating objects: 837, done.
remote: Counting objects: 100% (837/837), done.
remote: Compressing objects: 100% (545/545), done.
remote: Total 837 (delta 471), reused 604 (delta 261), pack-reused 0
Receiving objects: 100% (837/837), 35.74 MiB | 19.62 MiB/s, done.
Resolving deltas: 100% (471/471), done.


In [2]:
%cd AstroNet

/content/AstroNet


Then install the missing libraries that our code requires:

In [3]:
!pip install tensorflow_addons
!pip install tensorflow-datasets==4.3.0
!pip install imgaug==0.4.0

Collecting tensorflow_addons
  Downloading tensorflow_addons-0.14.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)

Make sure you're using a GPU so that training and inference run fast.

In [4]:
import tensorflow as tf
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

Found GPU at: /device:GPU:0


In [5]:
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  print(
      '\n\nThis error most likely means that this notebook is not '
      'configured to use a GPU.  Change this in Notebook Settings via the '
      'command palette (cmd/ctrl-shift-P) or the Edit menu.\n\n')
  raise SystemError('GPU device not found')
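
The two checks above stop the notebook when no GPU is found. As a gentler alternative, here is a minimal sketch that only reports what TensorFlow can see. It relies on tf.config.list_physical_devices; the helper name list_gpus is hypothetical and not part of AstroNet.

```python
def list_gpus():
    """Return the names of the GPUs visible to TensorFlow.

    Returns an empty list when no GPU is visible or when TensorFlow
    is not installed. `list_gpus` is a hypothetical helper name.
    """
    try:
        import tensorflow as tf
    except ImportError:
        return []
    return [device.name for device in tf.config.list_physical_devices('GPU')]


gpus = list_gpus()
if gpus:
    print('Found GPU(s): {}'.format(', '.join(gpus)))
else:
    print('No GPU visible to TensorFlow.')
```

Unlike the cells above, this version never raises, so it suits a quick diagnostic rather than a hard requirement.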

Finally, run the training code:

In [7]:
%cd src

/content/AstroNet/src


In [8]:
from unet import SourceSegmentation
from datasets.convo_ska import ConvoSKA
from callbacks.display_callback import DisplayCallback
from configs.train_config import ITER_PER_EPOCH, NUM_EPOCHS


# Set use_class_weights=False to train without per-class weights for the 4 classes
# Set tiny=True to train the smaller version of this model
unet = SourceSegmentation((128,128,1), use_class_weights=True, tiny=False)
unet.model.summary()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.003, clipvalue=1.0)
checkpoint_filepath = '../checkpoints/unet-best.h5'

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    save_best_only=True)
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='../log')
early_stop_callback = tf.keras.callbacks.EarlyStopping(patience=10)

unet.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
dataset_train = ConvoSKA(mode='train').get_dataset()
val_data = ConvoSKA(mode='validation').get_dataset()
display_callback = DisplayCallback(val_data)


unet.fit(
    dataset_train,
    epochs=NUM_EPOCHS,
    steps_per_epoch=ITER_PER_EPOCH,
    validation_data=val_data,
    callbacks=[model_checkpoint_callback, display_callback,
               tensorboard_callback, early_stop_callback])


Model: "U-Net"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 128, 128, 1) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 128, 128, 64) 576         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 128, 128, 64) 256         conv2d[0][0]                     
__________________________________________________________________________________________________
tf.nn.relu (TFOpLambda)         (None, 128, 128, 64) 0           batch_normalization[0][0]        
______________________________________________________________________________________________





KeyboardInterrupt: ignored
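
The training cell mentions weights for the 4 classes but does not show how they are obtained. As an illustration only, here is one common heuristic, inverse-frequency ("balanced") weighting, in plain Python. This is an assumption about what such weights could look like, not necessarily what SourceSegmentation does internally. The function name and the toy labels are hypothetical.

```python
from collections import Counter


def inverse_frequency_weights(labels, num_classes):
    """Weight each class by total / (num_classes * count).

    Rare classes get larger weights; a class that never occurs gets 0.0.
    Hypothetical helper, shown only to illustrate the idea.
    """
    counts = Counter(labels)
    total = len(labels)
    weights = {}
    for class_id in range(num_classes):
        count = counts.get(class_id, 0)
        weights[class_id] = total / (num_classes * count) if count else 0.0
    return weights


# Toy example: 10 pixel labels, heavily skewed toward class 0.
toy_labels = [0] * 6 + [1] * 2 + [2] * 1 + [3] * 1
weights = inverse_frequency_weights(toy_labels, num_classes=4)
for class_id, weight in weights.items():
    print('class {}: weight {:.3f}'.format(class_id, weight))
```

With weights like these, mistakes on the rare classes cost more in the loss than mistakes on the dominant class, which is the usual motivation for class weighting in segmentation.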