In [1]:
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
from tensorflow.keras.optimizers import Adam
import keras
from keras.models import Sequential, Model
from keras.layers import *
from keras.utils import Sequence
from keras.layers import Conv2D, MaxPooling2D
from qkeras import *

from keras.utils import Sequence
from keras.callbacks import CSVLogger
from keras.callbacks import EarlyStopping

import os
import random
from datetime import datetime
import time

import matplotlib.pyplot as plt

# Notebook-level numeric constants.
pi = 3.14159265359  # pi written out to 11 decimal places

# Large / small magnitude bounds.
maxval, minval = 1e9, 1e-9

2025-06-03 02:01:35.095549: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-03 02:01:35.095621: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-03 02:01:35.096717: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-03 02:01:35.104291: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Work from the debug copy of the data-generator code
# (alternative checkout: 'SmartPix/data_generator').
project_dir = 'SmartPix/Datagenerator_debug'

# The path is relative, so a plain os.chdir() fails (or nests) when the cell is
# re-run in the same kernel. Only change directory if we are not already there.
if not os.path.normpath(os.getcwd()).endswith(os.path.normpath(project_dir)):
    os.chdir(project_dir)

# Pure-Python replacement for `!pwd`: show where we ended up.
print(os.getcwd())

/home/das214/SmartPix/Datagenerator_debug


In [3]:
#from dataprep import *
# from OptimizedDataGeneratorNew import OptimizedDataGenerator
from loss import *
from models import *

In [None]:
# OptimizedDataGeneratorNew.py
import os
import gc
import math
import glob
import random
import logging
import datetime
import numpy as np
import pandas as pd

from typing import Union, List, Tuple
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed

from tqdm import tqdm
import tensorflow as tf
from qkeras import quantized_bits

import utils


# custom quantizer

# @tf.function
def QKeras_data_prep_quantizer(data, bits=4, int_bits=0, alpha=1):
    """
    Quantize `data` with a QKeras `quantized_bits` quantizer.

    Args:
        data (tf.Tensor): Input data to quantize.
        bits (int): Total number of bits passed to `quantized_bits`.
        int_bits (int): Number of integer bits passed to `quantized_bits`.
        alpha (float): Forwarded as the `alpha` keyword of `quantized_bits`
            (don't change).

    Returns:
        tf.Tensor: Whatever the quantizer returns when applied to `data`.
    """
    # Build the quantizer and apply it in one step.
    return quantized_bits(bits, int_bits, alpha=alpha)(data)

class OptimizedDataGenerator(tf.keras.utils.Sequence):
    """
    Keras `Sequence` that serves pre-batched `(X, y)` tensors from TFRecord files.

    The generator has two modes, selected in `__init__`:

    * build mode (`load_from_tfrecords_dir is None`): parquet files matching
      ``part.*.parquet`` in `dataset_base_dir` are scanned for statistics,
      cut into batches of `batch_size` rows and written to `tfrecords_dir`,
      one TFRecord file per batch (``batch_<index>.tfrecord``).
    * load mode (`load_from_tfrecords_dir` given): an existing directory of
      TFRecord files is served as-is; no parquet file is read.

    In both modes `__getitem__` reads exactly one TFRecord file per batch.
    """

    # Prefix of the message returned by `save_single_batch` when saving fails.
    _SAVE_ERROR_PREFIX = "Error saving batch"

    def __init__(self, 
            dataset_base_dir: str = "./",
            batch_size: int = 32,
            file_count = None,
            labels_list: Union[List,str] = ['x-midplane','y-midplane','cotAlpha','cotBeta'],
            to_standardize: bool = False,
            input_shape: Tuple = (13,21),
            transpose = None,
            files_from_end = False,
            shuffle=False,

            # Added in Optimized datagenerators 
            load_from_tfrecords_dir: str = None,
            tfrecords_dir: str = None,
            use_time_stamps = -1,
            seed: int = None,
            quantize: bool = False,
            max_workers: int = 1,
            label_scale_pctl: float = 99,
            norm_pos_pctl: float = 99.7,
            norm_neg_pctl: float = 99.7,
            **kwargs,
            ):
        """
        Args:
            dataset_base_dir: Directory searched for ``part.*.parquet`` files (build mode).
            batch_size: Number of rows per batch / TFRecord file (build mode).
            file_count: If not None, keep only this many parquet files.
            labels_list: Label column name(s) read from the parquet files.
            to_standardize: Apply `standardize` to the non-zero pixels (build mode).
            input_shape: `(n_time, height, width)`; required with three entries in build mode.
            transpose: Optional axes permutation applied to the `(rows, *input_shape)` array.
            files_from_end: With `file_count`, take the last files instead of the first.
            shuffle: Shuffle rows inside each file when building, rows inside each
                batch when serving, and the batch order at every epoch end.
            load_from_tfrecords_dir: Existing TFRecord directory (selects load mode).
            tfrecords_dir: Output directory for TFRecords (build mode). It is removed
                and recreated.
            use_time_stamps: Time-slice indices to read; ``-1`` means ``0..19``.
            seed: Shuffle seed; defaults to 13.
            quantize: Quantize `X` with `QKeras_data_prep_quantizer` in `__getitem__`.
            max_workers: Number of worker processes for the statistics pass.
            label_scale_pctl: Percentile of ``|label|`` used as the per-label scale.
            norm_pos_pctl: Percentile used for the positive-value scale.
            norm_neg_pctl: Percentile used for the negative-value scale.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If `load_from_tfrecords_dir` is not a directory, or (build
                mode) `tfrecords_dir` is None, `batch_size` is not positive,
                `input_shape` does not have three entries, or no parquet file is found.
        """
        super().__init__() 

        self.shuffle = shuffle
        self.seed = seed if seed is not None else 13
        if shuffle:
            self.rng = np.random.default_rng(seed = self.seed)

        # A bare string would break `recon_cols + labels_list`; copy so the
        # shared default list can never be mutated through an instance.
        if isinstance(labels_list, str):
            labels_list = [labels_list]

        # State that must exist in BOTH modes. Load mode used to skip these,
        # which made `__len__` fail with AttributeError.
        self.batch_size = batch_size
        self.labels_list = list(labels_list)
        self.input_shape = input_shape
        self.transpose = transpose
        self.to_standardize = to_standardize
        self.quantize = quantize
        self.batch_metadata = []
        self.tfrecord_filenames = None
        self.current_file_index = None
        self.current_dataframes = None
        self.dataset_mean = None
        self.dataset_std = None
        self.dataset_min = None
        self.dataset_max = None
        self.norm_factor_pos = None  
        self.norm_factor_neg = None
        self.labels_scale = None
        
        # If data is already prepared load -> load that data and use
        if load_from_tfrecords_dir is not None:
            if not os.path.isdir(load_from_tfrecords_dir):
                raise ValueError(f"Directory {load_from_tfrecords_dir} does not exist.")
            self.file_offsets = [None]
            self.tfrecords_dir = load_from_tfrecords_dir
        else:
            # Validate cheap preconditions before the expensive statistics pass.
            if tfrecords_dir is None:
                raise ValueError(f"tfrecords_dir is None")
            if batch_size is None or batch_size < 1:
                raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
            if len(input_shape) != 3:
                raise ValueError(
                    f"input_shape must be (n_time, height, width) to build TFRecords, got {input_shape}")
            n_time, height, width = input_shape
            
            # isinstance guard: comparing an ndarray to -1 would be ambiguous in `if`.
            if isinstance(use_time_stamps, (int, np.integer)) and use_time_stamps == -1:
                use_time_stamps = list(np.arange(0,20))
            assert len(use_time_stamps) == n_time, f"Expected {n_time} time steps, got {len(use_time_stamps)}"
    
            # Pixel columns are named by their flat index: time slice t covers
            # columns [t * height * width, (t + 1) * height * width).
            len_xy = height * width
            col_indices = [
                np.arange(t * len_xy, (t + 1) * len_xy).astype(str)
                for t in use_time_stamps
            ]
            self.recon_cols = np.concatenate(col_indices).tolist()
    
            self.max_workers = max_workers
            self.label_scale_pctl = label_scale_pctl
            self.norm_pos_pctl = norm_pos_pctl
            self.norm_neg_pctl = norm_neg_pctl

            self.files = sorted(glob.glob(os.path.join(dataset_base_dir, "part.*.parquet"), recursive=False))
    
            if file_count is not None:
                if not files_from_end:
                    self.files = self.files[:file_count]
                else:
                    # `[-0:]` would return every file, so handle 0 explicitly.
                    self.files = self.files[-file_count:] if file_count else []

            if not self.files:
                raise ValueError(f"No part.*.parquet files selected from {dataset_base_dir}")
    
            self.file_offsets = [0]

            self.process_file_parallel()
    
            utils.safe_remove_directory(tfrecords_dir)
                
            self.tfrecords_dir = tfrecords_dir    
            os.makedirs(self.tfrecords_dir, exist_ok=True)
            self.save_batches_sequentially()

            # Release the cached file but keep the attributes defined, so a later
            # `prepare_batch_data` call reloads instead of raising AttributeError.
            self.current_dataframes = None
            self.current_file_index = None
            
        self.tfrecord_filenames = self._list_tfrecords()
        self.epoch_count = 0
        self.on_epoch_end()


    @staticmethod
    def _tfrecord_sort_key(path):
        """Sort key ordering ``batch_<n>.tfrecord`` numerically by `n`; other names sort last, by path."""
        stem = os.path.splitext(os.path.basename(path))[0]
        suffix = stem.rsplit("_", 1)[-1]
        if suffix.isdigit():
            return (0, int(suffix), path)
        return (1, 0, path)

    def _list_tfrecords(self):
        """
        List the ``*.tfrecord`` files in `self.tfrecords_dir` as a NumPy array.

        Files are ordered by batch number rather than lexicographically, so
        ``batch_2`` precedes ``batch_10`` and position `i` matches
        `batch_metadata[i]` right after a build.
        """
        paths = tf.io.gfile.glob(os.path.join(self.tfrecords_dir, "*.tfrecord"))
        return np.array(sorted(paths, key=self._tfrecord_sort_key))


    def process_file_parallel(self):
        """
        Compute per-file statistics in worker processes and aggregate them.

        Sets `file_offsets` (cumulative row counts, aligned with `self.files`),
        `dataset_mean` (unweighted average of the per-file means),
        `dataset_std` (square root of the average per-file variance),
        `dataset_min` / `dataset_max`, `labels_scale`, `norm_factor_pos` and
        `norm_factor_neg` (maxima over files).

        Raises:
            ValueError: If `self.files` is empty.
        """
        file_infos = [(afile, 
                    self.recon_cols, self.labels_list, 
                    self.label_scale_pctl, self.norm_pos_pctl, self.norm_neg_pctl) 
                    for afile in self.files
                    ]
        if not file_infos:
            raise ValueError("No parquet files to process.")

        # Results are stored by submission index. `as_completed` yields in
        # completion order; consuming that order directly would misalign
        # `file_offsets` with `self.files` whenever files differ in row count.
        results = [None] * len(file_infos)
        with ProcessPoolExecutor(self.max_workers) as executor:
            future_to_pos = {
                executor.submit(self._process_file_single, file_info): pos
                for pos, file_info in enumerate(file_infos)
            }
            for future in tqdm(as_completed(future_to_pos), total=len(file_infos), desc="Processing Files..."):
                results[future_to_pos[future]] = future.result()

        offsets = [0]
        mean_sum = var_sum = None
        data_min = data_max = None
        labels_scale = None
        pos_factor = neg_factor = None

        for amean, avariance, amin, amax, num_rows, file_labels_scale, pos_scale, neg_scale in results:
            offsets.append(offsets[-1] + num_rows)

            if mean_sum is None:
                mean_sum, var_sum = amean, avariance
                data_min, data_max = amin, amax
                labels_scale = file_labels_scale
                pos_factor, neg_factor = pos_scale, neg_scale
            else:
                mean_sum = mean_sum + amean
                var_sum = var_sum + avariance
                data_min = min(data_min, amin)
                data_max = max(data_max, amax)
                labels_scale = np.maximum(labels_scale, file_labels_scale)
                pos_factor = max(pos_factor, pos_scale)
                neg_factor = max(neg_factor, neg_scale)

        n_files = len(file_infos)
        self.dataset_min = data_min
        self.dataset_max = data_max
        self.dataset_mean = mean_sum / n_files
        self.dataset_std = np.sqrt(var_sum / n_files) 
        self.labels_scale = labels_scale
        self.norm_factor_pos = pos_factor
        self.norm_factor_neg = neg_factor
            
        self.file_offsets = np.array(offsets)


    @staticmethod
    def _process_file_single(file_info):
        """
        Compute the statistics of one parquet file (runs in a worker process).

        Args:
            file_info: Tuple `(path, recon_cols, labels_list, label_scale_pctl,
                norm_pos_pctl, norm_neg_pctl)`.

        Returns:
            Tuple `(mean, variance, min, max, n_rows, labels_scale, pos_scale,
            neg_scale)`. Mean and variance are taken over the non-zero pixels
            after the signed log2(1 + |x|) transform; min, max and the two
            percentile scales are taken on the standardized values.
        """
        afile, recon_cols, labels_list, label_scale_pctl, norm_pos_pctl, norm_neg_pctl = file_info

        # Drop NaN rows exactly as `prepare_batch_data` does, so the row count
        # reported here matches the rows that are actually batched later.
        df = (pd.read_parquet(afile, columns=recon_cols + labels_list)
                .dropna(subset=recon_cols)
                .reset_index(drop=True))
        x = df[recon_cols].values

        # Signed log2(1 + |x|) on non-zero pixels; zeros stay zero.
        nonzeros = abs(x) > 0
        x[nonzeros] = np.sign(x[nonzeros]) * np.log1p(abs(x[nonzeros])) / math.log(2)
        # 1e-10 keeps the later division by sqrt(variance) finite.
        amean, avariance = np.mean(x[nonzeros], keepdims=True), np.var(x[nonzeros], keepdims=True) + 1e-10
        centered = np.zeros_like(x)
        centered[nonzeros] = (x[nonzeros] - amean) / np.sqrt(avariance)
        amin, amax = np.min(centered), np.max(centered)

        pos_vals = np.abs(centered[centered  > 0])
        neg_vals = np.abs(centered[centered  < 0])

        pos_scale = (np.percentile(pos_vals, norm_pos_pctl)
                    if pos_vals.size else 1.0)
        neg_scale = (np.percentile(neg_vals, norm_neg_pctl)
                    if neg_vals.size else 1.0)

        len_adf = len(df)

        labels_values = df[labels_list].values
        labels_scale = np.percentile(np.abs(labels_values), label_scale_pctl, axis=0)

        del df
        gc.collect()
        
        return amean, avariance, amin, amax, len_adf, labels_scale, pos_scale, neg_scale

    def standardize(self, x):
        """
        Standardize a batch of (already log-scaled) inputs.

        Subtracts `dataset_mean`, divides by `dataset_std`, divides positive
        values by `norm_factor_pos` and negative values by `norm_factor_neg`,
        then clips to `[dataset_min, dataset_max]`. `x` itself is not modified;
        a new array is returned.

        Args:
            x: Batch of inputs to be normalized.
        Returns:
            The normalized inputs. 
        """
        out = (x - self.dataset_mean)/self.dataset_std
        out[out > 0] = out[out > 0]/self.norm_factor_pos
        out[out < 0] = out[out < 0]/self.norm_factor_neg
        # NOTE(review): the clip bounds were measured before the division by
        # the norm factors - confirm that this is the intended range.
        out = np.clip(out, self.dataset_min, self.dataset_max)
        return out

    def save_batches_sequentially(self):
        """Write every batch to its own TFRecord file, logging failures instead of raising."""
        num_batches = self.__len__()
        errors_found = []
        for i in tqdm(range(num_batches), desc="Saving batches as TFRecords"):
            result = self.save_single_batch(i)
            # Match the exact failure prefix: a plain `"Error" in result` would
            # also trigger on any output path that contains "Error".
            if result.startswith(self._SAVE_ERROR_PREFIX):
                logging.error(result)
                errors_found.append(result)
        
        if errors_found:
            logging.warning(f"Encountered {len(errors_found)} errors during sequential saving of TFRecords.")
        else:
            logging.info("All batches saved successfully in sequential mode.")


    def save_single_batch(self, batch_index):
        """
        Serializes and saves a single batch to a TFRecord file.
        Args:
            batch_index (int): Index of the batch to save.
        Returns:
            str: Path to the saved TFRecord file or an error message.
        """
        
        try:
            filename = f"batch_{batch_index}.tfrecord"
            TFRfile_path = os.path.join(self.tfrecords_dir, filename)
            X, y = self.prepare_batch_data(batch_index)
            serialized_example = self.serialize_example(X, y)
            with tf.io.TFRecordWriter(TFRfile_path) as writer:
                writer.write(serialized_example)
            return TFRfile_path
        except Exception as e:
            # Best effort by design: the caller collects and reports failures.
            return f"{self._SAVE_ERROR_PREFIX} {batch_index}: {e}" 
        
 
    def _load_file_arrays(self, file_idx):
        """
        Load parquet file `file_idx` and return `(recon_values, labels_values)`.

        `recon_values` is log-scaled, optionally standardized, reshaped to
        `(rows, *input_shape)` and optionally transposed. `labels_values` holds
        the raw (unscaled) label columns for the same rows.
        """
        df = (pd.read_parquet(self.files[file_idx],
                            columns=self.recon_cols + self.labels_list)
                .dropna(subset=self.recon_cols)
                .reset_index(drop=True))
        if self.shuffle:
            df = df.sample(frac=1, random_state=self.seed).reset_index(drop=True)

        recon_values = df[self.recon_cols].values
        labels_values = df[self.labels_list].values

        # Same signed log2(1 + |x|) transform as in `_process_file_single`.
        nonzeros = abs(recon_values) > 0
        recon_values[nonzeros] = np.sign(recon_values[nonzeros]) * np.log1p(abs(recon_values[nonzeros])) / np.log(2)
        if self.to_standardize:
            recon_values[nonzeros] = self.standardize(recon_values[nonzeros])
        recon_values = recon_values.reshape((-1, *self.input_shape))
        if self.transpose is not None:
            recon_values = recon_values.transpose(self.transpose)
        return recon_values, labels_values

    def prepare_batch_data(self, batch_index):
        """
        Assemble batch `batch_index` from the parquet files.

        A batch covers global rows `[batch_index * batch_size, ...)` and may span
        several files. The most recently used file is cached in
        `current_dataframes`. Every file segment used is recorded in
        `batch_metadata[batch_index]["segments"]` (`row_end` is inclusive).

        Returns:
            Tuple `(X, y)` of NumPy arrays; `y` is divided by `labels_scale`.
        """
        start_evt = batch_index * self.batch_size  
        total_rows = self.file_offsets[-1]
        end_evt = min(start_evt + self.batch_size, total_rows)

        X_chunks = []
        y_chunks = []

        # Last file whose first global row is <= start_evt.
        file_idx = np.searchsorted(self.file_offsets, start_evt, side="right") - 1

        evt_cursor = start_evt

        while (evt_cursor < end_evt):
            file_start = self.file_offsets[file_idx]
            # `file_offsets` has len(files) + 1 entries, so this index is valid
            # for the last file too (there it equals the total row count).
            file_end = self.file_offsets[file_idx + 1]
            rel_start = evt_cursor - file_start
            rel_end = min(end_evt, file_end) - file_start

            need_rows = rel_end - rel_start

            if file_idx != self.current_file_index or self.current_dataframes is None:
                self.current_dataframes = self._load_file_arrays(file_idx)
                self.current_file_index = file_idx
                gc.collect()

            recon_values, labels_values = self.current_dataframes
            X_chunk = recon_values[rel_start:rel_end]
            y_chunk = labels_values[rel_start:rel_end] / self.labels_scale

            if batch_index == len(self.batch_metadata):
                self.batch_metadata.append({
                    "batch_idx"        : batch_index,
                    "target_batch_size": int(self.batch_size),
                    "actual_batch_size": 0,
                    "shuffled"         : bool(self.shuffle),
                    "shuffle_seed"     : (int(self.seed)
                                        if self.shuffle else None),
                    "segments"         : []
                })
            seg_rows = int(rel_end - rel_start)
            meta = self.batch_metadata[batch_index]
            meta["actual_batch_size"] += seg_rows
            meta["segments"].append({
                "file_idx"  : int(file_idx),
                "file_name" : os.path.basename(self.files[file_idx]),
                "row_start" : int(rel_start),
                "row_end"   : int(rel_end - 1)
            })

            X_chunks.append(X_chunk)
            y_chunks.append(y_chunk)

            evt_cursor += need_rows
            file_idx += 1

        X = np.concatenate(X_chunks, axis=0)
        y = np.concatenate(y_chunks, axis=0)

        return X, y
   

    def serialize_example(self, X, y):
        """
        Serializes a single example (features and labels) to TFRecord format. 
        
        Args:
        - X: Training data
        - y: labelled data
        
        Returns:
        - string (serialized TFRecord example).
        """
        # X and y are float32 (maybe we can reduce this)
        X = tf.cast(X, tf.float32)
        y = tf.cast(y, tf.float32)

        feature = {
            'X': self._bytes_feature(tf.io.serialize_tensor(X)),
            'y': self._bytes_feature(tf.io.serialize_tensor(y)),
        }
        example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
        return example_proto.SerializeToString()

    @staticmethod
    def _bytes_feature(value):
        """
        Converts a string/byte value into a Tf feature of bytes_list
        
        Args: 
        - string/byte value
        
        Returns:
        - tf.train.Feature object as a bytes_list containing the input value.
        """
        if isinstance(value, type(tf.constant(0))): # check if Tf tensor
            value = value.numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def __getitem__(self, batch_index):
        """
        Load the batch from a pre-saved TFRecord file instead of processing raw data.
        Each file contains exactly one batch.
        quantization is done here: Helpful for pretraining without the quantization and the later training with quantized data.
        shuffling is also done here.
        TODO: prefetching (un-done)

        Raises:
            ValueError: If the TFRecord file holds no example.
        """
        tfrecord_path = self.tfrecord_filenames[batch_index]
        raw_dataset = tf.data.TFRecordDataset(tfrecord_path)
        parsed_dataset = raw_dataset.map(self._parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)

        # Get the first (and only) batch from the dataset
        try:
            X_batch, y_batch = next(iter(parsed_dataset))
        except StopIteration:
            raise ValueError(f"No data found in TFRecord file: {tfrecord_path}")

        X_batch = tf.reshape(X_batch, [-1, *X_batch.shape[1:]])
        y_batch = tf.reshape(y_batch, [-1, *y_batch.shape[1:]])

        if self.quantize:
            X_batch = QKeras_data_prep_quantizer(X_batch, bits=4, int_bits=0, alpha=1)

        if self.shuffle:
            # One permutation applied to X and y keeps samples and labels paired.
            indices = tf.range(start=0, limit=tf.shape(X_batch)[0], dtype=tf.int32)
            shuffled_indices = tf.random.shuffle(indices, seed=self.seed)
            X_batch = tf.gather(X_batch, shuffled_indices)
            y_batch = tf.gather(y_batch, shuffled_indices)

        del raw_dataset, parsed_dataset
        return X_batch, y_batch
            
    @staticmethod
    def _parse_tfrecord_fn(example):
        """
        Parses a single TFRecord example.
        
        Returns:
        - X: as a float32 tensor.
        - y: as a float32 tensor.
        """
        feature_description = {
            'X': tf.io.FixedLenFeature([], tf.string),
            'y': tf.io.FixedLenFeature([], tf.string),
        }
        example = tf.io.parse_single_example(example, feature_description)
        X = tf.io.parse_tensor(example['X'], out_type=tf.float32)
        y = tf.io.parse_tensor(example['y'], out_type=tf.float32)
        return X, y

    def __len__(self):
        """
        Phase-aware number of batches:
            once the TFRecord files are listed: len(tfrecord_filenames)
            during initial TFRecord creation: math on file_offsets
            otherwise (load mode, files not listed yet): list them once
        """
        # The listed files are exactly what `__getitem__` can serve. The list is
        # NOT re-globbed or re-sorted here, which would undo the epoch shuffle.
        filenames = getattr(self, "tfrecord_filenames", None)
        if filenames is not None:
            return len(filenames)

        # already have metadata?  Fastest answer.
        batch_metadata = getattr(self, "batch_metadata", None)
        if batch_metadata:
            return len(batch_metadata)

        # still building batches, so compute from source rows.
        if len(self.file_offsets) > 1:         # have real offsets
            total_rows = self.file_offsets[-1]
            return math.ceil(total_rows / self.batch_size)

        # running in "load" mode.
        self.tfrecord_filenames = self._list_tfrecords()
        return len(self.tfrecord_filenames)

    def on_epoch_end(self):
        '''
        This shuffles the file ordering so that it shuffles the ordering in which the TFRecord
        are loaded during the training for each epochs.
        '''
        gc.collect()
        self.epoch_count += 1
        # Log quantization status once
        if self.epoch_count == 1:
            logging.warning(f"Quantization is {self.quantize} in data generator. This may affect model performance.")

        if self.shuffle:
            self.rng.shuffle(self.tfrecord_filenames)
            self.seed += 1 # So that after each epoch the batch is shuffled with a different seed (deterministic)

In [5]:
# Dataset root. The TFRecord cache lives under <root>/TFR_files/2t and the
# parquet shards under <root>/parquets.
dataset_base_dir = "/depot/cms/users/das214/datasets/dataset_2s/dataset_2s_50x12P5_parquets/"
tfrecords_base_dir = os.path.join(dataset_base_dir, "TFR_files", "2t")
tfrecords_dir_train, tfrecords_dir_val = (
    os.path.join(tfrecords_base_dir, split_dir)
    for split_dir in ("TFR_train", "TFR_val")
)

# From here on `dataset_base_dir` points at the parquet shards themselves.
dataset_base_dir = os.path.join(dataset_base_dir, "parquets")

# Rows per batch for both splits, and number of parquet files per split.
batch_size = val_batch_size = 5000
train_file_size, val_file_size = 75, 25

In [9]:
# Arguments common to both generators. `file_type` and `data_format` are not
# named parameters of OptimizedDataGenerator; they land in its **kwargs.
shared_generator_args = dict(
    dataset_base_dir = dataset_base_dir,
    file_type = "parquet",
    data_format = "3D",
    to_standardize = True,
    labels_list = ['x-midplane','y-midplane','cotAlpha','cotBeta'],
    input_shape = (2,13,21), # (20,13,21),
    transpose = (0,2,3,1),
    shuffle = False, # True for training
    use_time_stamps = [0,19],
    max_workers = 2,
)

# validation generator: the LAST `val_file_size` parquet files
start_time = time.time()
validation_generator = OptimizedDataGenerator(
    batch_size = val_batch_size,
    file_count = val_file_size,
    files_from_end = True,
    tfrecords_dir = tfrecords_dir_val,
    **shared_generator_args,
)
print("--- Validation generator %s seconds ---" % (time.time() - start_time))

# training generator: the FIRST `train_file_size` parquet files
# NOTE(review): the two splits overlap if the directory holds fewer than
# train_file_size + val_file_size files - confirm the file count.
start_time = time.time()
training_generator = OptimizedDataGenerator(
    batch_size = batch_size,
    file_count = train_file_size,
    tfrecords_dir = tfrecords_dir_train,
    **shared_generator_args,
)
print("--- Training generator %s seconds ---" % (time.time() - start_time))

Processing Files...:   0%|          | 0/25 [00:00<?, ?it/s]

Processing Files...: 100%|██████████| 25/25 [00:09<00:00,  2.75it/s]


Directory /depot/cms/users/das214/datasets/dataset_2s/dataset_2s_50x12P5_parquets/TFR_files/2t/TFR_val is removed...


Saving batches as TFRecords:   0%|          | 0/102 [00:00<?, ?it/s]2025-06-03 02:02:13.780650: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 3234 MB memory:  -> device: 0, name: NVIDIA A100-PCIE-40GB MIG 1g.5gb, pci bus id: 0000:81:00.0, compute capability: 8.0
Saving batches as TFRecords: 100%|██████████| 102/102 [00:20<00:00,  5.05it/s]


--- Validation generator 29.70891261100769 seconds ---


Processing Files...: 100%|██████████| 75/75 [00:26<00:00,  2.84it/s]


Directory /depot/cms/users/das214/datasets/dataset_2s/dataset_2s_50x12P5_parquets/TFR_files/2t/TFR_train is removed...


Saving batches as TFRecords: 100%|██████████| 306/306 [00:58<00:00,  5.21it/s]


--- Training generator 85.89488315582275 seconds ---


In [10]:
# Cumulative row counts of the training parquet files: entry i is the first
# global row of file i, and the last entry is the total number of rows.
training_generator.file_offsets

array([      0,   20374,   40748,   61122,   81496,  101870,  122244,
        142618,  162992,  183366,  203740,  224114,  244488,  264862,
        285236,  305610,  325984,  346358,  366732,  387106,  407480,
        427854,  448228,  468602,  488976,  509350,  529724,  550098,
        570472,  590846,  611220,  631594,  651968,  672342,  692716,
        713090,  733464,  753838,  774212,  794586,  814960,  835334,
        855708,  876082,  896456,  916830,  937204,  957578,  977952,
        998326, 1018700, 1039074, 1059448, 1079822, 1100196, 1120570,
       1140944, 1161318, 1181692, 1202066, 1222440, 1242814, 1263188,
       1283562, 1303936, 1324310, 1344684, 1365058, 1385432, 1405806,
       1426180, 1446554, 1466928, 1487302, 1507676, 1528050])

In [11]:
# Divisor applied to positive standardized pixel values
# (largest per-file `norm_pos_pctl` percentile).
training_generator.norm_factor_pos

1.4030861830865065

In [12]:
# Divisor applied to negative standardized pixel values
# (largest per-file `norm_neg_pctl` percentile of their magnitudes).
training_generator.norm_factor_neg

2.4879350274661034

In [25]:

def get_best_batch_size(file_offsets, target_bs=5000, window=0.5):
    """
    Find the batch size near `target_bs` that leaves the smallest remainder
    when the total number of rows is split into batches.

    Candidate sizes are the integers in
    `[target_bs - int(window * target_bs), target_bs + int(window * target_bs)]`,
    never going below 1. Among the candidates with the minimal remainder, the
    one closest to `target_bs` is returned (the smaller one on a tie).

    Args:
        file_offsets (Sequence[int] | np.ndarray): Cumulative row offsets; only
            the last entry (the total row count) is used.
        target_bs (int): Target batch size, must be >= 1.
        window (float): Half-width of the search range as a fraction of
            `target_bs`, must be >= 0. Defaults to 0.5.
    Returns:
        tuple[int, int]: `(best_batch_size, residual)` where
        `residual == total_rows % best_batch_size`.
    Raises:
        ValueError: If `target_bs < 1` or `window < 0`.
    """
    if target_bs < 1:
        raise ValueError(f"target_bs must be a positive integer, got {target_bs}")
    if window < 0:
        raise ValueError(f"window must be non-negative, got {window}")

    last_offset = int(file_offsets[-1])
    d_bs = int(window * target_bs)
    # Clamp at 1: a candidate of 0 would make the modulo below divide by zero.
    batch_sizes = np.arange(max(1, target_bs - d_bs), target_bs + d_bs + 1)

    residuals = last_offset % batch_sizes
    min_res   = residuals.min()

    # All bs giving the minimal residual
    candidates = batch_sizes[residuals == min_res]

    # Prefer the one closest to the target
    idx = np.argmin(np.abs(candidates - target_bs))
    # Return plain Python ints rather than NumPy scalars.
    return int(candidates[idx]), int(min_res)

def _build_batch_plan(file_offsets, batch_size, tol = 0.75):
    """
    Pre-compute (row_start, row_end) for every batch.
    If the last batch < 0.5xbatch_size, merge the last two
    and split them evenly, so both new batches are within
    0.5x...1.0xbatch_size.
    """
    total = file_offsets[-1]
    b      = batch_size
    plan   = []
    start  = 0
    while start < total:
        end = min(start + b, total)
        plan.append((start, end))
        start = end

    # Re-balance if the tail is too short
    if len(plan) >= 2:
        last_len = plan[-1][1] - plan[-1][0]
        if last_len < tol * b:
            sec_start = plan[-2][0]
            comb_len  = plan[-1][1] - sec_start
            half      = math.ceil(comb_len / 2)
            plan[-2]  = (sec_start, sec_start + half)
            plan[-1]  = (sec_start + half, sec_start + comb_len)
    return plan


In [30]:

# Batch size near 5000 that divides the training rows with the smallest remainder.
best_bs, min_residual = get_best_batch_size(training_generator.file_offsets, target_bs=5_000) 
print(f"Best batch size: {best_bs}, Min residual: {min_residual}")

# Note: the plan uses the configured `batch_size`, not `best_bs`, so the short
# tail is re-balanced across the last two batches (threshold: 0.9 * batch_size).
plan = _build_batch_plan(training_generator.file_offsets, batch_size=batch_size, tol = 0.9)
print(f"Batch plan: {plan[-5:]}")


Best batch size: 5010, Min residual: 0
Batch plan: [(1505000, 1510000), (1510000, 1515000), (1515000, 1520000), (1520000, 1524025), (1524025, 1528050)]


In [105]:
# Per-batch bookkeeping recorded while the TFRecords were written: which rows
# of which parquet file went into each batch (`row_end` is inclusive).
training_generator.batch_metadata

[{'batch_idx': 0,
  'target_batch_size': 5000,
  'actual_batch_size': 5000,
  'shuffled': False,
  'shuffle_seed': None,
  'segments': [{'file_idx': 0,
    'file_name': 'part.0.parquet',
    'row_start': 0,
    'row_end': 4999}]},
 {'batch_idx': 1,
  'target_batch_size': 5000,
  'actual_batch_size': 5000,
  'shuffled': False,
  'shuffle_seed': None,
  'segments': [{'file_idx': 0,
    'file_name': 'part.0.parquet',
    'row_start': 5000,
    'row_end': 9999}]},
 {'batch_idx': 2,
  'target_batch_size': 5000,
  'actual_batch_size': 5000,
  'shuffled': False,
  'shuffle_seed': None,
  'segments': [{'file_idx': 0,
    'file_name': 'part.0.parquet',
    'row_start': 10000,
    'row_end': 14999}]},
 {'batch_idx': 3,
  'target_batch_size': 5000,
  'actual_batch_size': 5000,
  'shuffled': False,
  'shuffle_seed': None,
  'segments': [{'file_idx': 0,
    'file_name': 'part.0.parquet',
    'row_start': 15000,
    'row_end': 19999}]},
 {'batch_idx': 4,
  'target_batch_size': 5000,
  'actual_batch_

In [106]:
# Print the number of rows actually written to each batch.
for bm in training_generator.batch_metadata:
    print(bm['actual_batch_size'])

5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000
5000


In [107]:
# Size of each of the last two batches if a full batch (5000 rows) and the
# 3050-row tail were merged and split evenly, as `_build_batch_plan` does.
(5000+3050)/2

4025.0

In [108]:
def build_vec_index(gen):
    """
    Build a flat label lookup table over every batch of `gen`.

    Args:
        gen: Indexable, sized batch generator; `gen[b]` must return `(X, y)`
            where `y` has a `.numpy()` method.
    Returns:
        Tuple of NumPy arrays `(labels, batch_ids, event_ids)`: the labels of
        all batches stacked row-wise, and for each row the int32 batch index
        and the int32 position inside that batch.
    """
    label_blocks, batch_ids, event_ids = [], [], []
    for batch_no in tqdm(range(len(gen)), desc = "Building LUT"):
        # Only the labels are needed; the features of the batch are discarded.
        labels = gen[batch_no][1].numpy()
        n_events = len(labels)
        label_blocks.append(labels)
        batch_ids.append(np.full(n_events, batch_no, "int32"))
        event_ids.append(np.arange(n_events, dtype="int32"))
    # vstack -> one row per event, one column per label
    return np.vstack(label_blocks), np.concatenate(batch_ids), np.concatenate(event_ids)

# Labels, batch index and in-batch event index for every training row.
# NOTE: `es` is rebound to the EarlyStopping callback in a later cell, so the
# event-index array is only available under this name until then.
lbls,bs,es    = build_vec_index(training_generator)


Building LUT: 100%|██████████| 306/306 [00:19<00:00, 15.32it/s]


In [109]:
# Build and compile the network. `CreateModel` and `custom_loss` are not defined
# in this notebook; presumably they come from the star-imports of `models` and
# `loss` - verify. The input shape (13,21,2) matches input_shape (2,13,21)
# after the generators' transpose (0,2,3,1).
model=CreateModel((13,21,2),n_filters=5,pool_size=3)
model.compile(
    optimizer=tf.keras.optimizers.Nadam(learning_rate=1e-3),
    loss=custom_loss
)

model.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 13, 21, 2)]       0         
                                                                 
 q_separable_conv2d_1 (QSep  (None, 11, 19, 5)         33        
 arableConv2D)                                                   
                                                                 
 q_activation_5 (QActivatio  (None, 11, 19, 5)         0         
 n)                                                              
                                                                 
 q_conv2d_1 (QConv2D)        (None, 11, 19, 5)         30        
                                                                 
 q_activation_6 (QActivatio  (None, 11, 19, 5)         0         
 n)                                                              
                                                           

In [110]:
from datetime import datetime

# Random 8-hex-digit run id; it names this run's checkpoint directory.
fingerprint = '%08x' % random.randrange(16**8)
# NOTE(review): `timestamp` is not used in any cell visible here.
timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
os.makedirs("trained_models", exist_ok=True)
base_dir = f'./trained_models/model-{fingerprint}-checkpoints'
os.makedirs(base_dir, exist_ok=True)  
# Keras fills in {epoch}, {loss} and {val_loss} when a checkpoint is written.
checkpoint_filepath = base_dir + '/weights.{epoch:02d}-t{loss:.2f}-v{val_loss:.2f}.hdf5'

In [111]:
# Show the run id so the checkpoint directory can be found later.
print(fingerprint)

02f54bce


In [112]:
from tensorflow.keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint, Callback

# Passed as `patience` to the EarlyStopping callback created below.
early_stopping_patience = 50

class CustomModelCheckpoint(ModelCheckpoint):
    """
    `ModelCheckpoint` that keeps only the most recently written weights file.

    After the parent class has (possibly) saved a checkpoint, every older file
    whose name starts with ``weights`` in the checkpoint directory is deleted.
    """

    def on_epoch_end(self, epoch, logs=None):
        """Run the normal checkpoint logic, then delete all but the newest weights file."""
        super().on_epoch_end(epoch, logs)

        # Use the directory of this callback's own `filepath`, not a notebook global.
        ckpt_dir = os.path.dirname(self.filepath) or "."
        checkpoints = [
            os.path.join(ckpt_dir, f)
            for f in os.listdir(ckpt_dir)
            if f.startswith('weights')
        ]
        if len(checkpoints) > 1:
            # Order by modification time. A lexicographic sort puts
            # "weights.100-..." before "weights.99-...", which would delete the
            # newest checkpoint once the epoch number reaches three digits.
            checkpoints.sort(key=lambda path: (os.path.getmtime(path), path))
            for stale_checkpoint in checkpoints[:-1]:
                os.remove(stale_checkpoint)

# Stop once the monitored metric has not improved for `early_stopping_patience`
# epochs and roll back to the best weights.
# NOTE: this rebinds `es`, which held the event-index array returned by
# `build_vec_index` in an earlier cell.
es = EarlyStopping(patience=early_stopping_patience, restore_best_weights=True)

# Save weights only when val_loss improves; CustomModelCheckpoint then removes
# the older weight files.
mcp = CustomModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_loss',
    save_best_only=True,
    save_freq='epoch',
    verbose=1
)

# append=True: re-running fit() adds rows to the same CSV instead of overwriting it.
csv_logger = CSVLogger(f'{base_dir}/training_log.csv', append=True)

In [113]:
# Train on the TFRecord-backed generators. `epochs=1000` is an upper bound;
# the EarlyStopping callback in `callbacks` is expected to end training sooner.
# NOTE(review): the generators were built with shuffle=False and fit() is also
# told not to shuffle, so batches should be seen in a fixed order every epoch -
# confirm this is intended for training.
history = model.fit(
        x=training_generator,
        validation_data=validation_generator,
        callbacks=[es, mcp, csv_logger],
        epochs=1000,
        shuffle=False,
        verbose=1
    )

Epoch 1/1000


Epoch 1: val_loss improved from inf to 3209.36328, saving model to ./trained_models/model-02f54bce-checkpoints/weights.01-t16400.39-v3209.36.hdf5
Epoch 2/1000
Epoch 2: val_loss improved from 3209.36328 to -2254.90234, saving model to ./trained_models/model-02f54bce-checkpoints/weights.02-t431.89-v-2254.90.hdf5
Epoch 3/1000
Epoch 3: val_loss improved from -2254.90234 to -4375.14502, saving model to ./trained_models/model-02f54bce-checkpoints/weights.03-t-2700.99-v-4375.15.hdf5
Epoch 4/1000
Epoch 4: val_loss improved from -4375.14502 to -5775.51367, saving model to ./trained_models/model-02f54bce-checkpoints/weights.04-t-5376.68-v-5775.51.hdf5
Epoch 5/1000
Epoch 5: val_loss improved from -5775.51367 to -7591.24951, saving model to ./trained_models/model-02f54bce-checkpoints/weights.05-t-7001.22-v-7591.25.hdf5
Epoch 6/1000
Epoch 6: val_loss improved from -7591.24951 to -8344.97852, saving model to ./trained_models/model-02f54bce-checkpoints/weights.06-t-7784.75-v-8344.98.hdf5
Epoch 7/1000

KeyboardInterrupt: 