In [None]:
# standard library
import os
import random
import re
import shutil
import time
from pathlib import Path

# third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.optimizers import Adam, Nadam

# custom (project-local)
from loss import *
from models import *
from dataloaders import utils
from dataloaders import OptimizedDataGenerator as DG

In [None]:
# Report the compute resources visible to this kernel.
n_cpus = os.cpu_count()
print("Num CPU:", n_cpus)
print(utils.check_GPU())

In [None]:
# Dataset root containing a `recon3D/` (inputs) and a `labels/` subdirectory.
# Set the SMART_PIXELS_DATASET environment variable to use another copy, e.g.
# /depot/cms/users/dkondra/smart-pixels/dataset8/unflipped-positive,
# without editing the notebook. The default is the path used so far.
dataset_path = os.environ.get(
    "SMART_PIXELS_DATASET", "/depot/cms/users/das214/dataset8/unflipped"
)
data_directory_path = os.path.join(dataset_path, 'recon3D/')
labels_directory_path = os.path.join(dataset_path, 'labels/')

# Sorted by path (presumably so input and label files line up by name). Within
# this notebook the generators receive the directories, so these lists are only
# printed and sanity-checked here.
data_files_path_list = sorted(
    os.path.join(data_directory_path, f) for f in os.listdir(data_directory_path)
)
labels_files_path_list = sorted(
    os.path.join(labels_directory_path, f) for f in os.listdir(labels_directory_path)
)

print(data_directory_path)
print(labels_directory_path)
print(len(data_files_path_list))
print(len(labels_files_path_list))

# Flag a mismatch early instead of failing deep inside preprocessing.
if len(data_files_path_list) != len(labels_files_path_list):
    print(
        f"WARNING: found {len(data_files_path_list)} data files but "
        f"{len(labels_files_path_list)} label files"
    )

In [None]:
# Run configuration. `output_directory` is the notebook's working directory.
output_directory = Path(".").resolve()

# Samples per batch for the training / validation generators.
batch_size, val_batch_size = 5000, 5000
# Number of input files given to the training / validation generators
# (passed to them as `file_count`).
train_file_size, val_file_size = 142, 6

# Smaller setting for quick debugging runs:
#   batch_size = val_batch_size = 500; train_file_size = 20; val_file_size = 6

In [None]:
# Create the output location if needed (no-op when it already exists).
Path(output_directory).mkdir(parents=True, exist_ok=True)
print(output_directory)

In [None]:
# Directories for the TFRecord files written/read by the data generators.
# By default, fixed shared paths are used so that preprocessed TFRecords can be
# reloaded in a later session via `load_from_tfrecords_dir`.
# Set `use_shared_tfrecords = False` to use notebook-local directories
# `tfrecords_{train,validation}_{stamp}` under `output_directory` instead.
use_shared_tfrecords = True
stamp = 1  # deterministic suffix; use '%08x' % random.randrange(16**8) for a unique one

if use_shared_tfrecords:
    tfrecords_dir_train = "/depot/cms/users/das214/tfrecords_20t_train_d8"
    tfrecords_dir_validation = "/depot/cms/users/das214/tfrecords_20t_val_d8"
else:
    tfrecords_dir_train = Path(output_directory, f"tfrecords_train_{stamp}").resolve()
    tfrecords_dir_validation = Path(output_directory, f"tfrecords_validation_{stamp}").resolve()

# Uncomment to delete existing TFRecord directories before regenerating them:
# utils.safe_remove_directory(tfrecords_dir_train)
# utils.safe_remove_directory(tfrecords_dir_validation)

In [None]:
# Validation data generator.
#
# Caution: building the generator this way (from the source files, with
# `tfrecords_dir` set) removes any data already stored in `tfrecords_dir`.
# To reuse TFRecord files written earlier, construct the generator with
# `load_from_tfrecords_dir` instead.

start_time = time.time()
validation_generator = DG.OptimizedDataGenerator(
    # source files
    data_directory_path=data_directory_path,
    labels_directory_path=labels_directory_path,
    file_type="parquet",
    is_directory_recursive=False,
    file_count=val_file_size,
    # presumably takes files from the end of the listing so they do not overlap
    # the training files (the training generator omits this flag) — TODO confirm
    files_from_end=True,
    # sample layout
    data_format="3D",
    input_shape=(2, 13, 21),  # alternative: (20, 13, 21)
    transpose=(0, 2, 3, 1),
    use_time_stamps=[0, 19],  # alternative: -1
    # targets and preprocessing
    labels_list=['x-midplane', 'y-midplane', 'cotAlpha', 'cotBeta'],
    include_y_local=False,
    to_standardize=True,
    # batching and output
    batch_size=val_batch_size,
    shuffle=False,
    tfrecords_dir=tfrecords_dir_validation,
    max_workers=2,  # keep small: too many workers use up all RAM
)

print("--- Validation generator %s seconds ---" % (time.time() - start_time))

In [None]:
# Training data generator.
#
# Caution: building the generator this way (from the source files, with
# `tfrecords_dir` set) removes any data already stored in `tfrecords_dir`.
# To reuse TFRecord files written earlier, construct the generator with
# `load_from_tfrecords_dir` instead.

start_time = time.time()
training_generator = DG.OptimizedDataGenerator(
    # source files
    data_directory_path=data_directory_path,
    labels_directory_path=labels_directory_path,
    file_type="parquet",
    is_directory_recursive=False,
    file_count=train_file_size,
    # sample layout
    data_format="3D",
    input_shape=(2, 13, 21),  # alternative: (20, 13, 21)
    transpose=(0, 2, 3, 1),
    use_time_stamps=[0, 19],  # alternative: -1
    # targets and preprocessing
    labels_list=['x-midplane', 'y-midplane', 'cotAlpha', 'cotBeta'],
    include_y_local=False,
    to_standardize=True,
    # batching and output
    batch_size=batch_size,
    shuffle=False,  # set True to shuffle
    tfrecords_dir=tfrecords_dir_train,
    max_workers=2,  # keep small: too many workers use up all RAM
)
print("--- Training generator %s seconds ---" % (time.time() - start_time))

In [None]:
# Optional cell: rebuild both generators from TFRecord files that were written
# earlier, so a later session can skip the (slow) preprocessing-and-saving step.
# Comment this cell out to keep the generators constructed in the cells above.
# Both generators are reloaded with identical options.
reload_options = {"shuffle": True, "seed": 13, "quantize": True}

training_generator = DG.OptimizedDataGenerator(
    load_from_tfrecords_dir=tfrecords_dir_train,
    **reload_options,
)

validation_generator = DG.OptimizedDataGenerator(
    load_from_tfrecords_dir=tfrecords_dir_validation,
    **reload_options,
)

In [None]:
# Channels-last model input (13, 21, 2): matches the generators' (2, 13, 21)
# samples after transpose=(0, 2, 3, 1), assuming the generators apply that
# transpose themselves — TODO confirm.
input_shape = (13, 21, 2)
model = CreateModel(input_shape, pool_size=3, n_filters=5)
model.summary()

In [None]:
# Learning-rate setup. One epoch is len(training_generator) batches.
steps_per_epoch = len(training_generator)

initial_learning_rate = 1e-3
alpha = 0.01                          # CosineDecay `alpha` (decay floor fraction)
warmup_target = 1e-1                  # LR reached at the end of warmup
warmup_steps = 10 * steps_per_epoch   # 10 epochs of warmup
decay_steps = 90 * steps_per_epoch    # 90 epochs of cosine decay

# Cosine decay with warmup. Requires a TF/Keras version whose CosineDecay
# accepts `warmup_target` / `warmup_steps`.
# NOTE(review): the warmup peak (1e-1) is 100x the constant LR used by default
# below — review before enabling the schedule.
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=decay_steps,
    alpha=alpha,
    warmup_target=warmup_target,
    warmup_steps=warmup_steps,
)

# Set to True to train with `lr_schedule`; by default a constant LR is used.
use_lr_schedule = False

model.compile(
    optimizer=tf.keras.optimizers.Nadam(
        learning_rate=lr_schedule if use_lr_schedule else initial_learning_rate
    ),
    loss=custom_loss,
)

In [None]:
# Random 8-hex-digit id for this run. Checkpoints and the CSV training log are
# written to ./trained_models/model-<fingerprint>-checkpoints/.
fingerprint = f"{random.randrange(16**8):08x}"
base_dir = f'./trained_models/model-{fingerprint}-checkpoints'
for directory in ("trained_models", base_dir):
    os.makedirs(directory, exist_ok=True)
# Keras fills in the {epoch}/{loss}/{val_loss} fields when saving.
# NOTE(review): recent Keras versions require weights-only checkpoint paths to
# end in `.weights.h5`; confirm the installed version accepts `.hdf5`.
checkpoint_filepath = base_dir + '/weights.{epoch:02d}-t{loss:.2f}-v{val_loss:.2f}.hdf5'

In [None]:
# Show this run's id (it names the checkpoint directory).
print(f"{fingerprint}")

In [None]:
from tensorflow.keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint, Callback

# Epochs without improvement tolerated before EarlyStopping halts training.
early_stopping_patience = 50

class CustomModelCheckpoint(ModelCheckpoint):
    """ModelCheckpoint that keeps only the newest `weights*` checkpoint file.

    After the parent callback has handled an epoch (and possibly written a new
    checkpoint), every file in the checkpoint directory whose name starts with
    'weights' is deleted except the one with the highest epoch number.
    """

    # Leading epoch number in names like 'weights.07-t0.12-v0.34.hdf5'.
    _EPOCH_PATTERN = re.compile(r"weights\.(\d+)")

    @classmethod
    def _checkpoint_epoch(cls, filename):
        """Return the epoch encoded in a checkpoint filename, or -1 if absent."""
        match = cls._EPOCH_PATTERN.match(filename)
        return int(match.group(1)) if match else -1

    def on_epoch_end(self, epoch, logs=None):
        super().on_epoch_end(epoch, logs)
        # Directory the checkpoint template writes into.
        checkpoint_dir = os.path.dirname(self.filepath) or "."
        checkpoints = [f for f in os.listdir(checkpoint_dir) if f.startswith('weights')]
        # Sort numerically, not as strings: the template pads epochs to only two
        # digits, so a string sort places 'weights.100-...' before 'weights.99-...'
        # and would delete the newest checkpoint once training passes epoch 99.
        checkpoints.sort(key=self._checkpoint_epoch)
        for stale in checkpoints[:-1]:
            os.remove(os.path.join(checkpoint_dir, stale))

# Training callbacks:
#  - es: stops after `early_stopping_patience` epochs without improvement and
#    restores the best weights (restore_best_weights=True);
#  - mcp: saves weights whenever val_loss improves and prunes older checkpoints;
#  - csv_logger: appends per-epoch metrics to training_log.csv in base_dir.
es = EarlyStopping(restore_best_weights=True, patience=early_stopping_patience)

mcp = CustomModelCheckpoint(
    checkpoint_filepath,
    monitor='val_loss',
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
    save_freq='epoch',
)

csv_logger = CSVLogger(f'{base_dir}/training_log.csv', append=True)

In [None]:
# Train; epochs=1000 is only an upper bound since EarlyStopping (es) ends the run.
# shuffle=False: any shuffling is left to the generators' own `shuffle` settings.
model.fit(
    training_generator,
    validation_data=validation_generator,
    epochs=1000,
    callbacks=[es, mcp, csv_logger],
    verbose=1,
    shuffle=False,
)

In [None]:
# # clean up tf records
# utils.safe_remove_directory(tfrecords_dir_train)
# utils.safe_remove_directory(tfrecords_dir_validation)