In [None]:
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
from tensorflow.keras.optimizers import Adam
import keras
from keras.models import Sequential, Model
from keras.layers import *
from keras.utils import Sequence
from keras.layers import Conv2D, MaxPooling2D
from qkeras import *

from keras.utils import Sequence
from keras.callbacks import CSVLogger
from keras.callbacks import EarlyStopping

import os
import json
import random

# Module-level numeric constants.
# NOTE(review): none of these is referenced by the training cells below, and
# `pi` is a truncated literal (math.pi is the full-precision value) -- confirm
# whether they are still needed before relying on or removing them.
pi = 3.14159265359

# Presumably generic upper/lower clamp bounds -- TODO confirm intended use.
maxval = 1.0e9
minval = 1.0e-9

In [None]:
#from dataprep import *
from dataloaders.OptimizedDataGenerator import OptimizedDataGenerator
from loss import *
from models_bnorm import *

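# Scaling Lists for Different Pixel Pitches
Each list is a `scaling_list` for the data generators. It holds one factor per label in `labels_list` (`x-midplane`, `y-midplane`, `cotAlpha`, `cotBeta`), in that order. Use the list that matches the dataset's pixel pitch; this notebook uses 50x12.5x100 µm. The x and y factors are 1.5 × the pixel pitch (in µm) in that direction.

* 100x25x100 µm:  [150.0, 37.5, 10.0, 1.22]
* 50x25x100 µm:   [75.0, 37.5, 10.0, 1.22]
* 50x20x100 µm:   [75.0, 30.0, 10.0, 1.22]
* 50x15x100 µm:   [75.0, 22.5, 10.0, 1.22]
* 50x12.5x100 µm: [75.0, 18.75, 10.0, 1.22]
* 50x10x100 µm:   [75.0, 15.0, 10.0, 1.22]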

In [None]:
# You can define a JSON configuration file locally
# {
#    "data_base_dir": "/data/dajiang/smartPixels",
#    "tfrecords_base_dir" : "/data/dajiang/smartPixels",
#    "model_base_dir": "/home/dajiang/smart-pixels-ml/weights"
# }
# NOTE(review): the example values above have no `/dataset_2s` or `/tfrecords`
# suffix, unlike the defaults below -- confirm which directory layout is intended.
config_file_path = 'config.json'


def load_path_config(path, defaults):
    """Return ``defaults`` overlaid with the values found in the JSON file at ``path``.

    Only keys present in ``defaults`` are read from the file. A key that the file
    omits (or sets to ``null``) keeps its default value instead of becoming ``None``.

    Args:
        path: Path to a JSON file containing a single object.
        defaults: Mapping of config key -> default value. It is not modified.

    Returns:
        A tuple ``(values, from_file)``: the merged dict, and whether ``path`` existed.
    """
    merged = dict(defaults)
    if not os.path.exists(path):
        return merged, False
    with open(path, 'r') as file:
        file_values = json.load(file)
    for key in merged:
        value = file_values.get(key)
        if value is not None:
            merged[key] = value
    return merged, True


# Defaults are used when the file does not exist, and for any key the file does not set.
config, config_from_file = load_path_config(
    config_file_path,
    {
        'data_base_dir': "/data/dajiang/smartPixels/dataset_2s",
        'tfrecords_base_dir': "/data/dajiang/smartPixels/tfrecords",
        'model_base_dir': "/home/dajiang/smart-pixels-ml/weights",
    },
)
data_base_dir = config['data_base_dir']
tfrecords_base_dir = config['tfrecords_base_dir']
model_base_dir = config['model_base_dir']

if config_from_file:
    print(f"Use config info from file: {data_base_dir}, {tfrecords_base_dir}, {model_base_dir}")
else:
    print(f"File does not exist. Use default config info: {data_base_dir}, {tfrecords_base_dir}, {model_base_dir}")

In [None]:
%%time

batch_size = 1000
val_batch_size = 1000
train_file_size = 50
val_file_size = 10

tfrecords_dir_train = f"{tfrecords_base_dir}/tfrecords_20t_train_50x12P5_bnorm_timeslices2"
tfrecords_dir_val = f"{tfrecords_base_dir}/tfrecords_20t_val_50x12P5_bnorm_timeslices2"

load_from_tfrecords_enabled = True

training_generator = OptimizedDataGenerator(
    data_directory_path = f"{data_base_dir}/dataset_2s_50x12P5_parquets/unflipped/recon3D/",
    labels_directory_path = f"{data_base_dir}/dataset_2s_50x12P5_parquets/unflipped/labels/",
    is_directory_recursive = False,
    file_type = "parquet",
    data_format = "3D",
    batch_size = batch_size,
    file_count = train_file_size,
    to_standardize= True,
    include_y_local= False,
    labels_list = ['x-midplane','y-midplane','cotAlpha','cotBeta'],
    scaling_list = [75.0, 18.75, 10.0, 1.22],
    input_shape = (2,13,21),
    transpose = (0,2,3,1),
    files_from_end=True,
    shuffle= True,

    load_from_tfrecords_dir = tfrecords_dir_train if load_from_tfrecords_enabled else None,
    tfrecords_dir = tfrecords_dir_train,
    use_time_stamps = -1, #-1
    max_workers = 1, # Don't make this too large (will use up all RAM)
    seed = 10,
    quantize = True # Quantization ON
)

validation_generator = OptimizedDataGenerator(
    data_directory_path = f"{data_base_dir}/dataset_2s_50x12P5_parquets/unflipped/recon3D/",
    labels_directory_path = f"{data_base_dir}/dataset_2s_50x12P5_parquets/unflipped/labels/",
    is_directory_recursive = False,
    file_type = "parquet",
    data_format = "3D",
    batch_size = val_batch_size,
    file_count = val_file_size,
    to_standardize= True,
    include_y_local= False,
    labels_list = ['x-midplane','y-midplane','cotAlpha','cotBeta'],
    scaling_list = [75.0, 18.75, 10.0, 1.22],
    input_shape = (2,13,21),
    transpose = (0,2,3,1),
    files_from_end=True,
    shuffle= True,

    load_from_tfrecords_dir = tfrecords_dir_val if load_from_tfrecords_enabled else None,
    tfrecords_dir = tfrecords_dir_val,
    use_time_stamps = -1, #-1
    max_workers = 1, # Don't make this too large (will use up all RAM)
    seed = 10,
    quantize = True # Quantization ON
)

In [None]:
# Build the model (it is compiled in the next cell).
# The input shape (13, 21, 2) is the generators' input_shape (2, 13, 21) after
# their transpose (0, 2, 3, 1): the size-2 axis (presumably the two time slices,
# per the `timeslices2` directory names) is moved last.
model = CreateModel((13, 21, 2), n_filters=5, pool_size=3)
model.summary()

In [None]:
# Compile with Adam (learning rate 1e-3) and the project's custom loss
# (presumably provided by the `from loss import *` cell above).
model.compile(
    loss=custom_loss,
    optimizer=Adam(learning_rate=0.001),
)

In [None]:
%%time
#TODO: Use a flag to skip model training and load the pre-trained model.
# TRAINING_ENABLED = True

# training
es = EarlyStopping(
    patience=50,
    restore_best_weights=True
)

base_dir = f"{model_base_dir}/weights_7pitches/weights-50x12P5x100_timeslices2-checkpoints"
os.makedirs(base_dir, exist_ok=True)
checkpoint_filepath = base_dir + '/weights.{epoch:02d}-t{loss:.2f}-v{val_loss:.2f}.hdf5'
mcp = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_loss',
    save_best_only=False,
)

class ScalePrintingCallback(keras.callbacks.Callback):    
    def on_epoch_end(self, epoch, logs=None):
        scale_layer = self.model.layers[-1]
        print(
            f"scaling layer ({epoch}):", 
            scale_layer.scale, 
            tf.math.softplus(scale_layer.scale)
        )

print_scale = ScalePrintingCallback()

history = model.fit(x=training_generator,
                    validation_data=validation_generator,
                    callbacks=[mcp],
                    epochs=100,
                    shuffle=False, # shuffling now occurs within the data-loader
                    verbose=1)

In [None]:
# Persist the trained model (legacy HDF5 and native .keras formats) and its weights.
# NOTE(review): these files hold whatever weights the model has when `fit`
# returns. They are the best-epoch weights only if an EarlyStopping callback
# with restore_best_weights=True was passed to `fit` and actually restored them.
# NOTE(review): in tf.keras 2.x, save_weights() with a non-.h5 suffix writes
# TensorFlow checkpoint files (e.g. `best_model_weights.keras.index`), not a
# .keras archive; Keras 3 instead requires a `.weights.h5` suffix -- verify.
weights_dir = f"{model_base_dir}/weights_7pitches"
for model_file in ("best_model.hdf5", "best_model.keras"):
    model.save(f"{weights_dir}/{model_file}")
for weights_file in ("best_model_weights.hdf5", "best_model_weights.keras"):
    model.save_weights(f"{weights_dir}/{weights_file}")