In [1]:

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt

import mccd
from astropy.io import fits

# NOTE(review): `%pylab` star-imports numpy/matplotlib names into the global
# namespace (hence the "Populating the interactive namespace" output below).
# Later cells may silently depend on names injected this way -- presumably
# `time`, which the training cell uses without importing it explicitly.
# Prefer explicit imports plus `%matplotlib inline`.
%pylab inline

# Record which TensorFlow version this run used.
print(tf.__version__)






runstats and/or skimage could not be imported because not installed
Populating the interactive namespace from numpy and matplotlib
2.4.4


In [2]:

# Show the GPUs visible to TensorFlow (an empty list means none is visible).
gpu_devices = tf.config.list_physical_devices('GPU')
gpu_devices


[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [20]:

from tensorflow.keras.layers import Layer, Conv2D, LeakyReLU, PReLU, UpSampling2D, MaxPooling2D, Activation
from tensorflow.keras.models import Model
from tensorflow_addons.layers import SpectralNormalization

class Conv(Layer):
    """A 'same'-padded 2D convolution followed by a non-linearity.

    Args:
        n_filters: number of output channels of the convolution.
        kernel_size: convolution kernel size.
        non_linearity: 'lrelu' -> LeakyReLU(0.1); 'prelu' -> PReLU with its
            parameters shared over axes 1 and 2 (the spatial axes for
            channels-last inputs); any other value is passed to
            `Activation` (e.g. 'relu').
        spectral_normalization: if True, wrap the Conv2D in
            `SpectralNormalization`.
        power_iterations: power iterations used by `SpectralNormalization`.
        **kwargs: forwarded to `Layer.__init__` (e.g. `name`).
    """

    def __init__(self, n_filters, kernel_size=3, non_linearity='relu', spectral_normalization=False, power_iterations=5, **kwargs):
        super().__init__(**kwargs)
        self.n_filters = n_filters
        self.kernel_size = kernel_size
        self.non_linearity = non_linearity
        self.spectral_normalization = spectral_normalization
        self.power_iterations = power_iterations
        # Build the convolution once and optionally wrap it, instead of
        # duplicating the Conv2D arguments in two branches.
        conv = Conv2D(
            filters=self.n_filters,
            kernel_size=self.kernel_size,
            padding='same',
            activation=None,
        )
        if self.spectral_normalization:
            conv = SpectralNormalization(conv, power_iterations=self.power_iterations)
        self.conv = conv
        if self.non_linearity == 'lrelu':
            self.act = LeakyReLU(0.1)
        elif self.non_linearity == 'prelu':
            self.act = PReLU(shared_axes=[1, 2])
        else:
            self.act = Activation(self.non_linearity)

    def call(self, inputs):
        """Apply the convolution, then the activation."""
        return self.act(self.conv(inputs))

    def get_config(self):
        """Return the constructor arguments so Keras can serialize/re-create the layer.

        Without this override, Keras' default `get_config` refuses to handle
        layers whose `__init__` takes extra arguments.
        """
        config = super().get_config()
        config.update({
            'n_filters': self.n_filters,
            'kernel_size': self.kernel_size,
            'non_linearity': self.non_linearity,
            'spectral_normalization': self.spectral_normalization,
            'power_iterations': self.power_iterations,
        })
        return config

class ConvBlock(Layer):
    """A stack of `n_non_lins` identical `Conv` layers applied in sequence.

    Args:
        n_filters: output channels of every convolution in the block.
        kernel_size: kernel size of every convolution in the block.
        non_linearity: activation spec forwarded to each `Conv`.
        n_non_lins: number of `Conv` (convolution + activation) layers.
        spectral_normalization: forwarded to each `Conv`.
        power_iterations: forwarded to each `Conv`.
        **kwargs: forwarded to `Layer.__init__` (e.g. `name`).
    """

    def __init__(self, n_filters, kernel_size=3, non_linearity='relu', n_non_lins=2, spectral_normalization=False, power_iterations=5, **kwargs):
        super().__init__(**kwargs)
        self.n_filters = n_filters
        self.kernel_size = kernel_size
        self.non_linearity = non_linearity
        self.n_non_lins = n_non_lins
        self.spectral_normalization = spectral_normalization
        self.power_iterations = power_iterations
        self.convs = [
            Conv(
                n_filters=self.n_filters,
                kernel_size=self.kernel_size,
                non_linearity=self.non_linearity,
                spectral_normalization=self.spectral_normalization,
                power_iterations=self.power_iterations,
            ) for _ in range(self.n_non_lins)
        ]

    def call(self, inputs):
        """Feed `inputs` through every `Conv` of the block in order."""
        outputs = inputs
        for conv in self.convs:
            outputs = conv(outputs)
        return outputs

    def get_config(self):
        """Return the constructor arguments so Keras can serialize/re-create the layer."""
        config = super().get_config()
        config.update({
            'n_filters': self.n_filters,
            'kernel_size': self.kernel_size,
            'non_linearity': self.non_linearity,
            'n_non_lins': self.n_non_lins,
            'spectral_normalization': self.spectral_normalization,
            'power_iterations': self.power_iterations,
        })
        return config

class UpConv(Layer):
    """2x spatial upsampling (`UpSampling2D`) followed by a linear 'same' convolution.

    No activation is applied after the convolution.

    Args:
        n_filters: output channels of the convolution.
        kernel_size: convolution kernel size.
        spectral_normalization: if True, wrap the Conv2D in
            `SpectralNormalization`.
        power_iterations: power iterations used by `SpectralNormalization`.
        **kwargs: forwarded to `Layer.__init__` (e.g. `name`).
    """

    def __init__(self, n_filters, kernel_size=3, spectral_normalization=False, power_iterations=5, **kwargs):
        super().__init__(**kwargs)
        self.n_filters = n_filters
        self.kernel_size = kernel_size
        self.spectral_normalization = spectral_normalization
        self.power_iterations = power_iterations
        # Build the convolution once and optionally wrap it, instead of
        # duplicating the Conv2D arguments in two branches.
        conv = Conv2D(
            filters=self.n_filters,
            kernel_size=self.kernel_size,
            padding='same',
            activation=None,
        )
        if self.spectral_normalization:
            conv = SpectralNormalization(conv, power_iterations=self.power_iterations)
        self.conv = conv
        self.up = UpSampling2D(size=(2, 2))

    def call(self, inputs):
        """Upsample by 2 along both spatial axes, then convolve."""
        return self.conv(self.up(inputs))

    def get_config(self):
        """Return the constructor arguments so Keras can serialize/re-create the layer."""
        config = super().get_config()
        config.update({
            'n_filters': self.n_filters,
            'kernel_size': self.kernel_size,
            'spectral_normalization': self.spectral_normalization,
            'power_iterations': self.power_iterations,
        })
        return config


class Unet(Model):
    """U-Net with optional spectral normalization, hard-wired for 51x51 inputs.

    Inputs are zero-padded to 64x64 (`pad`). They then go through an encoder
    of `ConvBlock` + 2x2 max-pooling levels, a bottleneck `ConvBlock`, and a
    decoder of `UpConv` + skip-concatenation + `ConvBlock` levels. A final
    1x1 convolution projects to `n_output_channels`, and the result is cropped
    back to 51x51 (`crop`).

    Because pooling happens `len(layers_n_channels) - 1` times and the decoder
    concatenates with the encoder features, 64 must be divisible by
    2 ** (len(layers_n_channels) - 1). Otherwise the skip shapes do not match.

    Args:
        n_output_channels: channels produced by the final 1x1 convolution.
        kernel_size: kernel size of every non-final convolution.
        layers_n_channels: filters per resolution level. The last entry is the
            bottleneck; each other entry defines one encoder/decoder level.
        layers_n_non_lins: number of conv + activation layers per `ConvBlock`.
        non_linearity: activation spec forwarded to every `ConvBlock`.
        spectral_normalization: wrap every Conv2D in `SpectralNormalization`.
        power_iterations: power iterations used by `SpectralNormalization`.
        **kwargs: forwarded to `Model.__init__`.
    """

    def __init__(
            self,
            n_output_channels=1,
            kernel_size=3,
            layers_n_channels=(64, 128, 256, 512, 1024),
            layers_n_non_lins=2,
            non_linearity='relu',
            spectral_normalization=False,
            power_iterations=5,
            **kwargs,
        ):
        super().__init__(**kwargs)
        self.n_output_channels = n_output_channels
        self.kernel_size = kernel_size
        # Copy so the default is immutable and later changes to the caller's
        # list cannot affect this model.
        self.layers_n_channels = list(layers_n_channels)
        self.n_layers = len(self.layers_n_channels)
        self.spectral_normalization = spectral_normalization
        self.layers_n_non_lins = layers_n_non_lins
        self.non_linearity = non_linearity
        self.power_iterations = power_iterations

        # Arguments shared by every ConvBlock (encoder, bottleneck, decoder).
        block_kwargs = dict(
            kernel_size=self.kernel_size,
            non_linearity=self.non_linearity,
            n_non_lins=self.layers_n_non_lins,
            spectral_normalization=self.spectral_normalization,
            power_iterations=self.power_iterations,
        )
        # NOTE: sub-layers are created in the same order as before, so Keras
        # auto-generated layer names (and hence saved weights) are unchanged.
        self.down_convs = [
            ConvBlock(n_filters=n_channels, **block_kwargs)
            for n_channels in self.layers_n_channels[:-1]
        ]
        self.down = MaxPooling2D(pool_size=(2, 2), padding='same')
        self.bottom_conv = ConvBlock(n_filters=self.layers_n_channels[-1], **block_kwargs)
        self.up_convs = [
            ConvBlock(n_filters=n_channels, **block_kwargs)
            for n_channels in self.layers_n_channels[:-1]
        ]
        self.ups = [
            UpConv(
                n_filters=n_channels,
                kernel_size=self.kernel_size,
                spectral_normalization=self.spectral_normalization,
                power_iterations=self.power_iterations,
            ) for n_channels in self.layers_n_channels[:-1]
        ]
        final_conv = Conv2D(
            filters=self.n_output_channels,
            kernel_size=1,
            padding='same',
            activation=None,
        )
        if self.spectral_normalization:
            final_conv = SpectralNormalization(final_conv, power_iterations=self.power_iterations)
        self.final_conv = final_conv

    def pad(self, image):
        r"""Zero-pad the two spatial axes of a rank-4 tensor by 6 (before) and 7 (after).

        This turns 51x51 images into 64x64 so that they survive the repeated
        2x2 pooling/upsampling. The batch and channel axes are not padded.
        """
        pad = tf.constant([[0,0], [6,7],[6,7], [0,0]])
        return tf.pad(image, pad, "CONSTANT")

    def crop(self, image):
        r"""Crop the padded output back to the 51x51 window at offset (6, 6).

        This undoes `pad` and returns a tensor (not a NumPy array).
        """
        return tf.image.crop_to_bounding_box(image, 6, 6, 51, 51)

    def call(self, inputs):
        """Run pad -> encoder -> bottleneck -> decoder with skips -> 1x1 conv -> crop."""
        scales = []
        outputs = self.pad(inputs)
        for conv in self.down_convs:
            outputs = conv(outputs)
            scales.append(outputs)  # keep pre-pooling features for the skip connection
            outputs = self.down(outputs)
        outputs = self.bottom_conv(outputs)
        # Decoder visits the levels from deepest to shallowest.
        for scale, conv, up in zip(scales[::-1], self.up_convs[::-1], self.ups[::-1]):
            outputs = up(outputs)
            outputs = tf.concat([outputs, scale], axis=-1)
            outputs = conv(outputs)
        outputs = self.final_conv(outputs)
        outputs = self.crop(outputs)
        return outputs



In [21]:
from mccd.denoising.evaluate import keras_psnr, center_keras_psnr
from mccd.denoising.preprocessing import eigenPSF_data_gen

In [22]:
# Inspect the available eigenPSF datasets and their on-disk sizes before picking one.
!ls -lah /n05data/tliaudat/new_deepmccd/training_realistic_sims/output_mccd/eigenPSF_datasets/

total 12G
drwxrwxr-x 2 tliaudat tliaudat  105 Feb 22 14:27 .
drwxrwxr-x 9 tliaudat tliaudat  251 Feb 22 13:46 ..
-rw-rw-r-- 1 tliaudat tliaudat 5.9G Feb 22 14:27 all_eigenpsfs.fits
-rw-rw-r-- 1 tliaudat tliaudat 447M Feb 22 14:27 global_eigenpsfs.fits
-rw-rw-r-- 1 tliaudat tliaudat 5.5G Feb 22 14:27 local_eigenpsfs.fits


In [26]:

# Run configuration: every tunable knob of this experiment lives here.
args = {
    'run_id_name': 'spec_norm_unet',
    'dataset_path': '/n05data/tliaudat/new_deepmccd/training_realistic_sims/output_mccd/eigenPSF_datasets/local_eigenpsfs.fits',
    'base_save_path': '/n05data/tliaudat/new_deepmccd/sandbox/testing_spectral_norm/',
    'batch_size': 32,
    'data_train_ratio': 0.8,
    'n_epochs': 100,
    'lr_param': 1e-3,
    'use_lr_scheduler': True,
    'layers_n_channel': 64,
    'layers_levels': 5,
    'kernel_size': 3,
    'n_shuffle': 50,
    'spectral_normalization': True,
    'power_iterations': 1,
    'seed': 42,  # seeds the train/test shuffle so the split is reproducible
}

# Paths
run_id_name = args['run_id_name']
eigenpsf_dataset_path = args['dataset_path']
base_save_path = args['base_save_path']
checkpoint_path = base_save_path + 'cp_' + run_id_name + '.h5'

# Save parameters
# np.save(base_save_path + 'params_' + run_id_name + '.npy', args, allow_pickle=True)

# Training parameters
batch_size = args['batch_size']
n_epochs = args['n_epochs']
lr_param = args['lr_param']

# # Save output prints to logfile
# old_stdout = sys.stdout
# log_file = open(base_save_path + run_id_name + '_output.log','w')
# sys.stdout = log_file
# print('Starting the log file.')

print(tf.test.gpu_device_name())

print('Load data..')
# Copy the column into memory inside a context manager so the FITS file
# handle is closed instead of leaking for the rest of the session.
with fits.open(eigenpsf_dataset_path) as hdul:
    img = np.array(hdul[1].data['VIGNETS_NOISELESS'])

# Seeded in-place shuffle along the first axis -> reproducible train/test split.
rng = np.random.default_rng(args['seed'])
rng.shuffle(img)

size_train = int(len(img) * args['data_train_ratio'])
training, test = img[:size_train], img[size_train:]

print('Prepare datasets..')
training = eigenPSF_data_gen(
    data=training,
    snr_range= [1e-3, 100],
    img_shape=(51, 51),
    batch_size=batch_size,
    n_shuffle=args['n_shuffle'],
    noise_estimator=False,
    enhance_noise=True,
)

test = eigenPSF_data_gen(
    data=test,
    snr_range= [1e-3, 100],
    img_shape=(51, 51),
    batch_size=1,
    noise_estimator=False,
    enhance_noise=True,
)




/device:GPU:0
Load data..
Prepare datasets..


In [27]:

steps = int(size_train/batch_size)

# Increasing the filter number with a factor of 2
layers_n_channels = [args['layers_n_channel'] * (2**it) for it in range(args['layers_levels'])]
print('layers_n_channels: ', layers_n_channels)

model = Unet(
    n_output_channels=1,
    kernel_size=args['kernel_size'],
    layers_n_channels=layers_n_channels,
    spectral_normalization=args['spectral_normalization'],
    power_iterations=args['power_iterations'],
)

# Keeps only the best weights according to the *training* MSE (not `val_mse`).
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='mse',
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
    mode='min',
    save_freq='epoch',
    options=None
)

# Step-decay schedule: halve the learning rate every `lr_decay_every` epochs,
# never going below `lr_min`.
lr_decay_every = 25
lr_min = 1e-5

def l_rate_schedule(epoch):
    """Return the learning rate for `epoch`, starting from `lr_param`.

    This used to hard-code 1e-3, which silently overrode `lr_param`.
    """
    return max(lr_param / 2**(epoch // lr_decay_every), lr_min)
lr_cback = tf.keras.callbacks.LearningRateScheduler(l_rate_schedule)

if args['use_lr_scheduler']:
    models_callbacks = [cp_callback, lr_cback]
else:
    models_callbacks = [cp_callback]

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr_param),
    loss='mse',
    metrics=['mse', keras_psnr, center_keras_psnr],
)



layers_n_channels:  [64, 128, 256, 512, 1024]


In [28]:
# Imported explicitly: previously `time` only existed because `%pylab`
# leaked it into the namespace.
import time

print('Start model training and timing..')
start_train = time.perf_counter()  # monotonic clock for elapsed-time measurement
history = model.fit(
    training,
    validation_data=test,
    steps_per_epoch=steps,
    epochs=n_epochs,
    validation_steps=1,
    callbacks=models_callbacks,
    shuffle=False,
    verbose=1,
)
print('Model training ended..')
end_train = time.perf_counter()
print('Train elapsed time: %f'%(end_train-start_train))

# Save history file. This is best effort: a failed save is reported rather
# than silently swallowed, and Ctrl-C is no longer caught (unlike a bare `except:`).
history_path = base_save_path + run_id_name + '_history_file.npy'
try:
    np.save(history_path, history.history, allow_pickle=True)
except Exception as err:
    print('Could not save the training history to %s: %s' % (history_path, err))


Start model training and timing..
Epoch 1/100

Epoch 00001: mse improved from inf to 0.00051, saving model to /n05data/tliaudat/new_deepmccd/sandbox/testing_spectral_norm/cp_spec_norm_unet.h5
Epoch 2/100

Epoch 00002: mse improved from 0.00051 to 0.00021, saving model to /n05data/tliaudat/new_deepmccd/sandbox/testing_spectral_norm/cp_spec_norm_unet.h5
Epoch 3/100

Epoch 00003: mse improved from 0.00021 to 0.00017, saving model to /n05data/tliaudat/new_deepmccd/sandbox/testing_spectral_norm/cp_spec_norm_unet.h5
Epoch 4/100

Epoch 00004: mse improved from 0.00017 to 0.00017, saving model to /n05data/tliaudat/new_deepmccd/sandbox/testing_spectral_norm/cp_spec_norm_unet.h5
Epoch 5/100

Epoch 00005: mse improved from 0.00017 to 0.00015, saving model to /n05data/tliaudat/new_deepmccd/sandbox/testing_spectral_norm/cp_spec_norm_unet.h5
Epoch 6/100

Epoch 00006: mse did not improve from 0.00015
Epoch 7/100

Epoch 00007: mse improved from 0.00015 to 0.00014, saving model to /n05data/tliaudat/new

KeyboardInterrupt: 

In [None]:


"""
UNET 32
With power_iterations=10
- One epoch is ~380s

With power_iterations=5
- One epoch is  ~309s

With power_iterations=1
- One epoch is  ~242s

UNET 64
With power_iterations=1
- One epoch is  ~522s
"""
