In [None]:
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
from tensorflow.keras.optimizers import Adam
import tensorflow.keras
from keras.models import Sequential, Model
from keras.layers import *
from keras.utils import Sequence
from keras.layers import Conv2D, MaxPooling2D
from qkeras import *

from keras.utils import Sequence
from keras.callbacks import CSVLogger
from keras.callbacks import EarlyStopping

# NOTE: the plain-module imports are deliberately kept *after* the star imports
# above so that a name exported by one of those packages can never shadow them.
import os
import random
from datetime import datetime
import time

import numpy as np

# Numeric constants for this notebook. None of them is referenced by the other
# cells visible here -- presumably kept for interactive use; confirm before
# removing.

# pi written out as a literal with 11 decimal places.
pi = 3.14159265359

# Large / small bounds.
maxval, minval = 1e9, 1e-9

In [None]:
from dataloaders.OptimizedDataGeneratorNew import OptimizedDataGenerator
from loss import *
from models import *

In [None]:
# Seed every global random number generator this notebook can reach so that a
# "Restart Kernel -> Run All" is repeatable. `seed` is also passed explicitly
# to the data generators further down.
seed = 10
tf.random.set_seed(seed)  # TensorFlow global seed
random.seed(seed)         # Python stdlib `random` module
# NumPy's legacy global generator was previously left unseeded; whether the
# project data loader draws from it is not visible here, so seed it as well.
np.random.seed(seed)

In [None]:
# Root of the parquet dataset; every other location below is derived from it.
dataset_base_dir = "/data/dajiang/smart-pixels/dataset_3sr/shuffled/dataset_3sr_50x12P5_parquets/contained/recon3D"

# Per-split directories under the dataset root.
dataset_train_dir, dataset_validation_dir = (
    os.path.join(dataset_base_dir, split_name)
    for split_name in ("train", "validation")
)

# TFRecord locations: <root>/TFR_files/20t/{TFR_train,TFR_val}.
tfrecords_base_dir = os.path.join(dataset_base_dir, "TFR_files", "20t")
tfrecords_dir_train, tfrecords_dir_val = (
    os.path.join(tfrecords_base_dir, tfr_subdir)
    for tfr_subdir in ("TFR_train", "TFR_val")
)

In [None]:
# Sanity check: report how many directory entries each split contains.
# (os.listdir counts every entry, not only parquet files.)
n_training_entries = len(os.listdir(dataset_train_dir))
print(f'Number of training files: {n_training_entries}')

n_validation_entries = len(os.listdir(dataset_validation_dir))
print(f'Number of validation files: {n_validation_entries}')

In [None]:
# Batch sizes handed to the data generators (training and validation use the
# same value).
batch_size = val_batch_size = 5000

# Number of directory entries per split; passed to the generators as
# `file_count`.
train_file_size, val_file_size = (
    len(os.listdir(split_dir))
    for split_dir in (dataset_train_dir, dataset_validation_dir)
)

In [None]:
# Build the validation generator straight from the parquet files and time how
# long construction takes.
# NOTE(review): presumably this pass also writes the TFRecords under
# `tfrecords_dir_val` that a later cell loads -- confirm against
# OptimizedDataGenerator.
start_time = time.time()

validation_generator_options = dict(
    # where the data comes from and how it is laid out
    dataset_base_dir=dataset_validation_dir,
    file_type="parquet",
    data_format="3D",
    batch_size=val_batch_size,
    file_count=val_file_size,
    # preprocessing / label selection
    to_standardize=True,
    include_y_local=False,
    labels_list=['x-midplane', 'y-midplane', 'cotAlpha', 'cotBeta'],
    input_shape=(20, 13, 21),
    transpose=(0, 2, 3, 1),
    shuffle=False,
    files_from_end=True,
    # TFRecord output and loader settings
    tfrecords_dir=tfrecords_dir_val,
    use_time_stamps=-1,
    max_workers=2,
)
validation_generator = OptimizedDataGenerator(**validation_generator_options)

elapsed_seconds = time.time() - start_time
print("--- Validation generator %s seconds ---" % elapsed_seconds)


In [None]:
# Build the training generator straight from the parquet files and time how
# long construction takes.
# NOTE(review): presumably this pass also writes the TFRecords under
# `tfrecords_dir_train` that a later cell loads -- confirm against
# OptimizedDataGenerator.
start_time = time.time()

training_generator_options = dict(
    # where the data comes from and how it is laid out
    dataset_base_dir=dataset_train_dir,
    file_type="parquet",
    data_format="3D",
    batch_size=batch_size,
    file_count=train_file_size,
    # preprocessing / label selection
    to_standardize=True,
    include_y_local=False,
    labels_list=['x-midplane', 'y-midplane', 'cotAlpha', 'cotBeta'],
    input_shape=(20, 13, 21),
    transpose=(0, 2, 3, 1),
    # shuffling is off here; the reload-from-TFRecords cell turns it on
    shuffle=False,
    # TFRecord output and loader settings
    tfrecords_dir=tfrecords_dir_train,
    use_time_stamps=-1,
    max_workers=2,
)
training_generator = OptimizedDataGenerator(**training_generator_options)

elapsed_seconds = time.time() - start_time
print("--- Training generator %s seconds ---" % elapsed_seconds)

In [None]:
# Replace both generators with ones that read from the TFRecord directories.
# The two differ only in the directory they load; shuffling, the seed and
# quantization are the same for both.
tfrecords_load_options = dict(shuffle=True, seed=seed, quantize=True)

training_generator = OptimizedDataGenerator(
    load_from_tfrecords_dir=tfrecords_dir_train,
    **tfrecords_load_options,
)

validation_generator = OptimizedDataGenerator(
    load_from_tfrecords_dir=tfrecords_dir_val,
    **tfrecords_load_options,
)


In [None]:
# Build and compile the network.
# `CreateModel` and `custom_loss` are not imported by name in this notebook --
# presumably they come from the `models` / `loss` star imports; confirm.
model_input_shape = (13, 21, 20)
model = CreateModel(model_input_shape, n_filters=5, pool_size=3)

adam_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
model.compile(optimizer=adam_optimizer, loss=custom_loss)

# Print the layer-by-layer summary as a quick architecture check.
model.summary()

In [None]:
# training
# Set up a fresh, uniquely named checkpoint directory for this run and the
# callback that writes weights into it.
pitch = '50x12P5'

# Draw the fingerprint from the OS entropy source rather than the global
# `random` generator. The global generator is seeded with a fixed value near
# the top of the notebook, so `random.randrange` would be expected to yield
# the same fingerprint on every fresh run; the directory creation below would
# then fail because the previous run's directory already exists. This also
# leaves the seeded global generator's state untouched.
fingerprint = '%08x' % random.SystemRandom().randrange(16**8)
base_dir = '/home/dajiang/smart-pixels-ml/weights/weights_7pitches/dataset_3sr_contained_weights/'
weights_dir = base_dir + 'weights-{}-bs{}-{}-20t-checkpoints'.format(pitch, batch_size, fingerprint)

# create output directories
# makedirs creates any missing parents (including `base_dir`) in one call.
# `exist_ok` is left at its default so an unexpected name collision still
# raises instead of silently mixing checkpoints from two runs.
os.makedirs(weights_dir)

# One weights file per epoch; the name records epoch, training loss and
# validation loss (filled in by Keras at save time).
checkpoint_filepath = weights_dir + '/weights.{epoch:02d}-t{loss:.2f}-v{val_loss:.2f}.hdf5'
mcp = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_loss',
    save_best_only=False,  # keep every epoch, not only the best val_loss
)

print('Model fingerprint: {}'.format(fingerprint))

In [None]:
# Run training. Weights are checkpointed through `mcp`; the returned History
# object is kept in `history` for later inspection.
fit_options = dict(
    x=training_generator,
    validation_data=validation_generator,
    callbacks=[mcp],
    epochs=1000,
    shuffle=False,  # shuffling now occurs within the data-loader
    verbose=1,
)
history = model.fit(**fit_options)