In [1]:
import os

os.environ['KERAS_BACKEND'] = 'jax'

from pathlib import Path

import h5py as h5
import keras
import numpy as np

from hgq.config import QuantizerConfigScope
from hgq.layers import QAdd, QAveragePooling1D, QEinsumDenseBatchnorm, QGlobalAveragePooling1D
from hgq.utils import trace_minmax
from hgq.utils.sugar import BetaScheduler, Dataset, FreeEBOPs, ParetoFront, PBar, PieceWiseSchedule

In [2]:
# Raw HDF5 shards, searched recursively.
# NOTE: the '/tmp/val' directory is used as the *test* set in this notebook;
# the validation split is carved out of the training data later.
train_path_root = Path('/tmp/train')
test_path_root = Path('/tmp/val')
# Sort the glob results: directory iteration order is arbitrary, and the file order
# fixes the concatenated sample order that the train/val permutation is applied to.
train_paths = sorted(train_path_root.glob('**/*.h5'))
test_paths = sorted(test_path_root.glob('**/*.h5'))

In [3]:
# Model input geometry: N constituents per jet and n features per constituent.
# These match the loader's constituent slice and its three selected feature columns.
N = 64
n = 3

# Seed Python, NumPy and the Keras backend so that weight initialisation and the
# train/val permutation (np.random.permutation) are reproducible across runs.
SEED = 42
keras.utils.set_random_seed(SEED)

In [4]:
def _load_jets(paths, n_constituents, channels=(5, 8, 11)):
    """Load constituent features and integer class labels from HDF5 files.

    Parameters
    ----------
    paths : sequence of path-like
        HDF5 files, each holding ``jetConstituentList`` and ``jets`` datasets.
    n_constituents : int
        Number of leading constituents kept (axis 1 of ``jetConstituentList``).
    channels : sequence of int
        Feature columns of ``jetConstituentList`` to keep (axis 2).

    Returns
    -------
    X : np.ndarray of float16, shape (rows, n_constituents, len(channels))
    y : np.ndarray of int, shape (rows, 1)
        ``argmax`` over columns ``-6:-1`` of ``jets``. These are presumably the
        one-hot class columns; verify against the dataset schema.

    Raises
    ------
    FileNotFoundError
        If ``paths`` is empty (``np.concatenate`` would otherwise fail opaquely).
    """
    if len(paths) == 0:
        raise FileNotFoundError('no HDF5 files to load; check train_path_root / test_path_root')
    xs, ys = [], []
    for path in paths:
        # Explicit read-only mode: older h5py versions defaulted to 'a' (read/write/create).
        with h5.File(path, 'r') as f:
            # Same order of operations as before: slice constituents, cast to float16,
            # then keep only the selected feature channels.
            xs.append(np.array(f['jetConstituentList'][:, :n_constituents, :], dtype=np.float16)[:, :, list(channels)])
            ys.append(np.argmax(f['jets'][:, -6:-1], axis=-1, keepdims=True))
    X, y = np.concatenate(xs), np.concatenate(ys)
    assert len(X) == len(y), f'{len(X)} feature rows vs {len(y)} label rows'
    return X, y


# Use N (the model's Input((N, n)) length) instead of a second hard-coded 64.
X_train, y_train = _load_jets(train_paths, N)
X_test, y_test = _load_jets(test_paths, N)

In [5]:
# Sanity-check the loaded arrays against the model geometry before building anything:
# features are (rows, N constituents, n features) and every row has a label.
assert X_train.shape[1:] == (N, n) and X_test.shape[1:] == (N, n)
assert len(X_train) == len(y_train) and len(X_test) == len(y_test)
X_train.shape

(620000, 64, 3)

In [6]:
# Every HGQ layer built inside these two scopes picks up the given quantizer defaults:
# weights/biases use b0=4, i0=2, k0=1 and activations ("datalane") use f0=3.
# These are presumably initial bit-width settings that training then adjusts; confirm in the hgq docs.
with QuantizerConfigScope(place=('weight', 'bias'), b0=4, i0=2, k0=1), QuantizerConfigScope(place='datalane', f0=3):
    inp = keras.layers.Input((N, n))

    # Per-constituent embedding. The einsum has no weight along `n`, so every
    # constituent is mixed by the same channel weights.
    x = QEinsumDenseBatchnorm('bnc,cC->bnC', (N, 64), bias_axes='C', activation='relu')(inp)

    # Set block: a per-constituent term `s` plus a term `d` computed from the average
    # over all N constituents. The add broadcasts `d` back onto every constituent.
    s = QEinsumDenseBatchnorm('bnc,cC->bnC', (N, 64), bias_axes='C')(x)
    qs = QAveragePooling1D(pool_size=N)(x)
    d = QEinsumDenseBatchnorm('bnc,cC->bnC', (1, 64), bias_axes='C')(qs)
    x = keras.layers.ReLU()(QAdd()([s, d]))

    x = QEinsumDenseBatchnorm('bnc,cC->bnC', (N, 64), bias_axes='C', activation='relu')(x)

    # Collapse the constituent axis, then apply a shrinking ReLU MLP head.
    x = keras.layers.Flatten()(QGlobalAveragePooling1D()(x))
    for _width in (64, 32, 16):
        x = QEinsumDenseBatchnorm('bc,cC->bC', _width, bias_axes='C', activation='relu')(x)

    # Five raw logits with no activation; the loss below is built with from_logits=True.
    out = QEinsumDenseBatchnorm('bc,cC->bC', 5, bias_axes='C')(x)

model = keras.Model(inp, out)
metrics = ['accuracy']
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
opt = keras.optimizers.Adam(1e-3)
model.compile(loss=loss, optimizer=opt, metrics=metrics)

In [7]:
# Output directory for Pareto-front checkpoints (created, with parents, if missing).
save_path = Path('/tmp/results')
os.makedirs(save_path, exist_ok=True)

In [8]:
# Shuffle the training rows and hold out the last 20% as the validation split.
# NOTE: this cell overwrites X_train/y_train in place, so it is NOT idempotent.
# Re-running it would re-split and re-normalise already-processed data; re-run the
# loading cell first.
N_train = int(0.8 * len(X_train))
order = np.random.permutation(len(X_train))
X_train = X_train[order]
y_train = y_train[order]
X_val = X_train[N_train:]
y_val = y_train[N_train:]
X_train = X_train[:N_train]
y_train = y_train[:N_train]

# Per-feature standardisation statistics, computed on the training split only, so
# nothing leaks from val/test. Cast to float32 once: the stats need float32 precision,
# and the cast array is reused for the normalised training set.
X_train_f32 = X_train.astype(np.float32)
_bias = X_train_f32.mean(axis=(0, 1))
_std = X_train_f32.std(axis=(0, 1))
# Guard constant features: a zero std would turn the whole feature into NaN/inf.
_std = np.where(_std > 0, _std, 1.0).astype(np.float32)
X_train = (X_train_f32 - _bias) / _std
del X_train_f32
X_val = (X_val - _bias) / _std
X_test = (X_test - _bias) / _std

In [9]:
# Wrap each split in an hgq Dataset placed on the first GPU.
# The batch sizes divide the splits evenly: 496k/2480 = 200 train steps,
# 124k/2480 = 50 val steps and 260k/2600 = 100 test steps. This is presumably
# chosen to avoid a ragged final batch.
_device = 'gpu:0'
_fit_batch = 2480
train_data = Dataset(X_train, y_train, batch_size=_fit_batch, device=_device)
val_data = Dataset(X_val, y_val, batch_size=_fit_batch, device=_device)
test_data = Dataset(X_test, y_test, batch_size=2600, device=_device)

In [10]:
# Live progress line: train/val loss and accuracy plus the current `beta`
# (presumably the EBOPs-penalty weight driven by the BetaScheduler below).
pbar = PBar(metric='loss: {loss:.2f}/{val_loss:.2f} - acc: {accuracy:.2%}/{val_accuracy:.2%} - beta: {beta:.2e}')

terminate_on_nan = keras.callbacks.TerminateOnNaN()


def _val_acc_above_half(logs):
    """ParetoFront gate: only epochs whose val_accuracy exceeds 0.5 are considered."""
    return logs['val_accuracy'] > 0.5


# Checkpoint epochs that land on the Pareto front of (val_accuracy, ebops).
# `sides` presumably encodes the direction per metric (+1 maximise, -1 minimise).
save = ParetoFront(
    path=save_path / 'ckpts',
    metrics=['val_accuracy', 'ebops'],
    sides=[1, -1],
    enable_if=_val_acc_above_half,
    fname_format='epoch={epoch}-acc={accuracy:.2%}-val_acc={val_accuracy:.2%}-EBOPs={ebops}.keras',
)

ebops = FreeEBOPs()

# Beta schedule: linear from 1e-7 at epoch 0 towards 5e-7 at epoch 1000, then constant.
# model.fit runs only 200 epochs, so beta stops partway up the ramp (~1.8e-7).
_beta_points = [[0, 1e-7, 'linear'], [1000, 5.0e-7, 'constant']]
beta_sched = BetaScheduler(PieceWiseSchedule(_beta_points))

# Keep this order: beta_sched and ebops come before save and pbar, which read the
# `beta`/`ebops` log entries (presumably populated by those two callbacks).
callbacks = [beta_sched, ebops, save, pbar, terminate_on_nan]

In [11]:
# Train for 200 epochs. verbose=0 because the PBar callback reports progress.
# The returned History is kept for inspecting loss/metric curves instead of
# printing its uninformative repr.
history = model.fit(train_data, epochs=200, validation_data=val_data, callbacks=callbacks, verbose=0)

  0%|          | 0/200 [00:00<?, ?epoch/s]

loss: 0.66/0.56 - acc: 80.66%/80.30% - beta: 1.80e-07 - EBOPs: 601,402: 100%|██████████| 200/200 [03:59<00:00,  1.20s/epoch]   


<keras.src.callbacks.history.History at 0x7fe198c23380>

In [12]:
# Restore a checkpoint from the Pareto front recorded by the `save` callback.
# ParetoFront only writes checkpoints for epochs passing its `enable_if` gate
# (val_accuracy > 0.5), so the front can be empty. Fail with a clear message
# instead of an opaque IndexError.
# NOTE(review): which Pareto point paths[0] is (most accurate? fewest EBOPs?)
# depends on ParetoFront's ordering; confirm, and select the point explicitly.
if len(save.paths) == 0:
    raise RuntimeError(f'no checkpoints on the Pareto front under {save_path / "ckpts"}')
model.load_weights(save.paths[0])

In [13]:
# Record activation ranges for hardware conversion.
# First pass: the training set, with the default `reset` (presumably clearing any
# previously recorded ranges).
# Second pass: the validation set with reset=False, accumulating on top so the
# recorded ranges cover both splits.
# NOTE(review): test inputs are not traced. Check how values outside the traced
# ranges are handled (saturate vs wrap) on unseen data.
trace_minmax(model, train_data)
trace_minmax(model, val_data, reset=False)

In [14]:
# Held-out test metrics, returned as [loss, accuracy] in compile order.
model.evaluate(x=test_data)

[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 17ms/step - accuracy: 0.8065 - loss: 0.5566


[0.5568761825561523, 0.806607723236084]

In [15]:
from da4ml.codegen import VerilogModel
from da4ml.converter.hgq2.parser import trace_model
from da4ml.trace import HWConfig, comb_trace
from hls4ml.converters import convert_from_keras_model

In [16]:
# Trace the calibrated model into a da4ml combinational graph, save the emulator
# binary, and write a pipelined Verilog project.
# Distinct names are used here so the Keras Input/output tensors `inp`/`out` from the
# model-definition cell are not clobbered (rebuilding keras.Model(inp, out) would break).
# NOTE(review): the meanings of HWConfig(1, -1, -1) and solver_options={'hard_dc': 2}
# should be confirmed against the da4ml docs.
trace_inp, trace_out = trace_model(model, solver_options={'hard_dc': 2}, hwconf=HWConfig(1, -1, -1))
solution = comb_trace(trace_inp, trace_out)
solution.save_binary('/tmp/emulator.bin')
verilog_model = VerilogModel(
    solution,
    prj_name='jet_classifier_large',
    path='/tmp/verilog_test',
    part_name='xcvu13p-flga2577-2-e',
    clock_period=2,  # presumably ns (i.e. 500 MHz); confirm units in da4ml
    clock_uncertainty=0.0,
    latency_cutoff=2,  # the repr reports "28 stages @ max_delay=2" for this setting
)
verilog_model.write()

In [17]:
# Summary of the generated design: top module, I/O widths, pipeline stages and estimated LUT/FF cost.
verilog_model

Top Module: jet_classifier_large
192 (1407 bits) -> 5 (86 bits)
28 stages @ max_delay=2
Estimated cost: 319680 LUTs, 253441 FFs
Emulator is **not compiled**

In [18]:
# Build the emulator for the generated design. The repr reported
# "Emulator is **not compiled**", and this is presumably required before `predict`.
# NOTE(review): `_compile` is a private da4ml method; check whether a public equivalent exists.
verilog_model._compile()

In [19]:
# Reference Keras outputs for the bit-exactness checks. The large batch is only for inference throughput.
r_keras = model.predict(x=X_test, verbose=0, batch_size=26000)

In [20]:
# Run the test set through the compiled Verilog emulator for comparison with r_keras.
r_verilog = verilog_model.predict(X_test)

In [21]:
# Bit-exactness check: the Verilog emulator must reproduce the Keras outputs exactly.
# np.array_equal also compares shapes, so a shape mismatch cannot be masked by
# broadcasting the way `np.all(a == b)` can. The assert makes Run-All fail loudly.
verilog_matches_keras = np.array_equal(r_keras, r_verilog)
assert verilog_matches_keras, 'Verilog emulator output differs from the Keras model'
verilog_matches_keras

np.True_

In [23]:
# hls4ml conversion config: distributed-arithmetic strategy with ReuseFactor 1
# (fully parallel). The model-level 'fixed<-1,0>' precision is presumably a
# placeholder, since the HGQ layers carry their own quantizers; confirm in the hls4ml docs.
hls_config = {
    'Model': {
        'Strategy': 'distributed_arithmetic',
        'Precision': 'fixed<-1,0>',
        'ReuseFactor': 1,
    },
}
hls4ml_model = convert_from_keras_model(model, output_dir='/tmp/hls4ml_test', hls_config=hls_config)

In [24]:
# Build hls4ml's software-emulation library; presumably required before `predict` below.
hls4ml_model.compile()

In [25]:
# The hls4ml emulator is fed a C-contiguous buffer. np.ascontiguousarray returns
# X_test itself when it already is contiguous, so this usually costs nothing.
X_test_contig = np.ascontiguousarray(X_test)
r_hls4ml = hls4ml_model.predict(X_test_contig)

In [26]:
# Bit-exactness check: the hls4ml emulation must reproduce the Keras outputs exactly.
# np.array_equal also compares shapes, so a shape mismatch cannot be masked by
# broadcasting the way `np.all(a == b)` can. The assert makes Run-All fail loudly.
hls4ml_matches_keras = np.array_equal(r_keras, r_hls4ml)
assert hls4ml_matches_keras, 'hls4ml emulation output differs from the Keras model'
hls4ml_matches_keras

np.True_