This repository provides TensorFlow templates.
By setting the following items in a parameter file, you can train and evaluate in various combinations. You can also add other settings by implementing them in your own settings classes.
- data (and data preprocess)
- model (and model layer)
- optimizer method
- loss function
- metrics
- callbacks
We use Sphinx to generate documentation from the implementation sources.
Please clone this repository and build the documentation locally to view it.
For more details about the build, please see here.
cd tensorflow-template/docs
make html -e SPHINXOPTS='-a -E -D language="en"' # "en" or "ja"
git clone https://github.com/r-dev95/tensorflow-template.git
You need to install uv.
If you don't have a Python development environment yet, see here.
cd tensorflow-template/
uv sync
cd src
uv run python dataset.py --result dataset --data mnist
uv run python train.py --param param/tutorial/param_train.yaml
uv run python eval.py --param param/tutorial/param_eval.yaml
This section describes how to use parameter files (.yaml).
The parameter file is used by the following source files. They also accept some command line arguments, but values in the parameter file override them, so this section assumes that all parameters are set in the parameter file.
- train.py
- eval.py
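The precedence described above can be sketched as follows. This is a minimal illustration, not the repository's actual code: the `merge_params` helper and the `--seed` and `--result` flag spellings are assumptions made for the example, and the parameter file is represented by an already-loaded dictionary instead of parsed YAML.

```python
import argparse


def merge_params(cli_args: dict, file_params: dict) -> dict:
    """Return one settings dict; values from the parameter file win."""
    merged = dict(cli_args)
    merged.update(file_params)
    return merged


# Hypothetical parser: the flag names are assumptions for illustration.
parser = argparse.ArgumentParser()
parser.add_argument("--param", type=str, default="param.yaml")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--result", type=str, default="result")

cli_args = vars(
    parser.parse_args(["--param", "param/tutorial/param_train.yaml", "--seed", "1"])
)

# Stand-in for the contents of the YAML parameter file after loading.
file_params = {"seed": 0, "epochs": 2}

params = merge_params(cli_args, file_params)
print(params["seed"])    # taken from the parameter file, not from --seed
print(params["result"])  # absent from the file, so the CLI default is kept
```

Under this assumption, a value passed on the command line only takes effect when the parameter file does not set the same key.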
Some settings cannot be set using the parameter file. In particular, detailed TensorFlow (Keras) settings must be implemented by referring to the official TensorFlow (Keras) website.
Main parameters that are also implemented as command line arguments are set with zero indentation.
- The main parameters include param, but it is not set in the parameter file because it only works as a command line argument.
Common settings example for train.py and eval.py:
# log handler (idx=0: stream handler, idx=1: file handler)
# (True: set handler, False: not set handler)
# type: list[bool, bool]
handler: [True, True]
# log level (idx=0: stream handler, idx=1: file handler)
# (DEBUG: 10, INFO: 20, WARNING: 30, ERROR: 40, CRITICAL: 50)
# type: list[int, int]
level: [10, 10]
# flag (eager mode: true, graph mode: false)
# type: boolean
eager: false
# random seed
# type: int
seed: 0
# directory path (data save)
# type: str
result: result
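As a rough illustration of how the `handler` and `level` lists could map onto Python's standard `logging` module, consider the sketch below. The `build_logger` function and the log file name are hypothetical and are not taken from this repository; only the index convention (0: stream handler, 1: file handler) and the numeric levels come from the comments above.

```python
import logging
import os
import tempfile


def build_logger(name, handler, level, log_path):
    """Attach a stream and/or file handler according to the two flag lists."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)  # let each handler filter by its own level
    logger.handlers.clear()
    if handler[0]:
        stream = logging.StreamHandler()
        stream.setLevel(level[0])
        logger.addHandler(stream)
    if handler[1]:
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(level[1])
        logger.addHandler(file_handler)
    return logger


log_dir = tempfile.mkdtemp()
logger = build_logger(
    "example",
    handler=[True, True],
    level=[10, 10],
    log_path=os.path.join(log_dir, "log.txt"),
)
logger.debug("both handlers accept DEBUG (10) and above")
```

The numeric levels in the parameter file match the constants of the `logging` module, for example `logging.DEBUG == 10` and `logging.CRITICAL == 50`.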
Settings example for train.py only:
# directory path (training data)
# type: str
train: data/mnist/train
# directory path (validation data)
# type: str | None
valid: data/mnist/test
# batch size (training data)
# type: int
train_batch: 32
# batch size (validation data)
# type: int | None
valid_batch: 1000
# shuffle size
# type: int | None
shuffle: null
# Number of epochs
# type: int
epochs: 2
Settings example for eval.py only:
# directory path (evaluation data)
# type: str
eval: data/mnist/test
# batch size (evaluation data)
# type: int
batch: 1000
For the currently available data, see the keys of the func variable of the SetupData class here.
data settings example:
data:
  kind: mnist
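The lookup described above, where the `kind` value selects an entry of a `func` dictionary, might look roughly like the sketch below. `SetupExample`, its method names, and the `placeholder_kind` key are hypothetical; the real setup classes in this repository may be organized differently.

```python
class SetupExample:
    """Hypothetical setup class; the real classes in this repository may differ."""

    def __init__(self, params: dict) -> None:
        self.params = params
        # The keys of `func` are the values accepted by `kind`.
        self.func = {
            "mnist": self.setup_mnist,
            "placeholder_kind": self.setup_placeholder,
        }

    def setup(self):
        kind = self.params["kind"]
        if kind not in self.func:
            raise ValueError(f"unknown kind: {kind!r}, choose from {sorted(self.func)}")
        return self.func[kind]()

    def setup_mnist(self):
        return "mnist loader"

    def setup_placeholder(self):
        return "placeholder loader"


data_params = {"kind": "mnist"}  # mirrors the `data` section of the parameter file
print(SetupExample(data_params).setup())
```

With this layout, supporting a new `kind` means adding one method and one dictionary entry.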
For the currently available data preprocess, see the keys of the func variable of the Processor class here.
- The kind of data preprocess is set as a list.
- If you set catencode in kind, configure the catencode settings as shown in the following example. The same applies to the subsequent parameters.
data preprocess settings example:
process:
  kind: [catencode, rescale]
  catencode:
    num_tokens: &num_classes 10
    output_mode: one_hot
    sparse: false
  rescale:
    scale: 0.00392156862745098
    offset: 0
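To make the numbers in this example concrete, the following pure-Python sketch shows what a one-hot encoding with `num_tokens: 10` and a rescale of the form `value * scale + offset` compute. It is an illustration only: the helper names are hypothetical, and the repository presumably relies on the corresponding Keras preprocessing layers instead.

```python
import math


def one_hot(label: int, num_tokens: int) -> list:
    """Encode an integer class label as a one-hot list of length num_tokens."""
    return [1.0 if i == label else 0.0 for i in range(num_tokens)]


def rescale(value: float, scale: float, offset: float) -> float:
    """Apply the affine map value * scale + offset."""
    return value * scale + offset


scale = 0.00392156862745098  # approximately 1 / 255
print(math.isclose(scale, 1 / 255))  # True
print(rescale(255, scale, 0))        # close to 1.0
print(one_hot(3, num_tokens=10))     # 1.0 at index 3, 0.0 elsewhere
```

With `offset: 0`, 8-bit pixel values in the range 0 to 255 are mapped to approximately the range 0 to 1.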
For the currently available model, see the keys of the func variable of the SetupModel class here.
model settings example:
model:
  kind: simple
For the currently available model layer, see the keys of the func variable of the SetupLayer class here.
- The kind of model layer is set as a list.
- The value of kind can have "_" + alphanumeric characters at the end.
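The suffix rule lets the same layer type appear more than once with different settings, as with dense_1 and dense_2. One way to recover the layer type from such a name is sketched below; the regular expression and the `base_kind` helper are assumptions for illustration, not the repository's actual parsing code.

```python
import re

# Hypothetical pattern: a trailing "_" followed by alphanumeric characters.
SUFFIX = re.compile(r"_[0-9A-Za-z]+$")


def base_kind(kind: str) -> str:
    """Drop the optional "_<alphanumeric>" suffix to get the layer type."""
    return SUFFIX.sub("", kind)


kinds = ["flatten", "dense_1", "relu", "dense_2"]
print([base_kind(k) for k in kinds])  # ['flatten', 'dense', 'relu', 'dense']
```

Note that under this particular pattern a layer type whose own name contains an underscore would also be truncated, so a real implementation would need to account for that.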
model layer settings example:
layer:
  kind: [flatten, dense_1, relu, dense_2]
  flatten:
    data_format: channels_last
  DENSE: &DENSE
    units: null
    activation: null
    use_bias: true
    kernel_initializer: glorot_uniform
    bias_initializer: zeros
    kernel_regularizer: null
    bias_regularizer: null
    activity_regularizer: null
    kernel_constraint: null
    bias_constraint: null
    lora_rank: null
  dense_1:
    <<: *DENSE
    units: 100
  dense_2:
    <<: *DENSE
    units: *num_classes
  relu:
    max_value: null
    negative_slope: 0
    threshold: 0
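The example above uses a YAML anchor (`&DENSE`) together with the merge key (`<<: *DENSE`) so that dense_1 and dense_2 share defaults and override only `units`. Assuming a YAML loader that supports merge keys, the resulting mappings are roughly equivalent to the dictionary merge sketched here. The `DENSE` dictionary is shortened to three keys for readability.

```python
# Shared defaults, standing in for the `DENSE: &DENSE` mapping (shortened here).
DENSE = {"units": None, "activation": None, "use_bias": True}

# `<<: *DENSE` followed by `units: 100` behaves like a merge in which
# explicitly written keys override the shared defaults.
dense_1 = {**DENSE, "units": 100}

num_classes = 10  # stands in for the `&num_classes` anchor set under `catencode`
dense_2 = {**DENSE, "units": num_classes}

print(dense_1)
print(dense_2)
```

Reusing `*num_classes` for the last dense layer keeps the output size in step with the number of one-hot categories.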
For the currently available optimizer method, see the keys of the func variable of the SetupOpt class here.
- The optimizer method parameter is only used in train.py.
optimizer method settings example:
opt:
  kind: adam
  adam:
    learning_rate: 0.001
    beta_1: 0.9
    beta_2: 0.999
    epsilon: 0.0000001
    amsgrad: false
    weight_decay: null
    clipnorm: null
    clipvalue: null
    global_clipnorm: null
    use_ema: false
    ema_momentum: 0.99
    ema_overwrite_frequency: null
    loss_scale_factor: null
    gradient_accumulation_steps: null
    name: adam
For the currently available loss function, see the keys of the func variable of the SetupLoss class here.
loss function settings example:
loss:
  kind: cce
  cce:
    from_logits: true
    label_smoothing: 0
    axis: -1
    reduction: sum_over_batch_size
    name: categorical_crossentropy
    # dtype: null
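The settings `from_logits: true` and `reduction: sum_over_batch_size` mean that the model outputs raw scores (the last dense layer in the example has no activation), the loss applies softmax itself, and the per-sample losses are summed and divided by the batch size. The pure-Python sketch below illustrates that computation with `label_smoothing: 0`; the function names are hypothetical and this is not the Keras implementation.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_crossentropy(y_true, logits):
    """Cross-entropy for one sample with one-hot y_true and raw logits."""
    probs = softmax(logits)
    return -sum(t * math.log(p) for t, p in zip(y_true, probs) if t > 0)


def batch_loss(batch_true, batch_logits):
    """Sum the per-sample losses and divide by the batch size."""
    losses = [categorical_crossentropy(t, z) for t, z in zip(batch_true, batch_logits)]
    return sum(losses) / len(losses)


# Two classes with equal logits give probability 0.5, so the loss is ln(2).
print(batch_loss([[1.0, 0.0]], [[0.0, 0.0]]))
```

If the model already ended in a softmax layer, `from_logits` would need to be `false` to avoid applying softmax twice.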
For the currently available metrics, see the keys of the func variable of the SetupMetrics class here.
- The kind of metrics is set as a list.
metrics settings example:
metrics:
  kind: [mse, cacc]
  mse:
    name: mean_squared_error
    # dtype: null
  cacc:
    name: categorical_accuracy
    # dtype: null
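For reference, the two metrics in this example can be illustrated in plain Python as follows. This is a simplified sketch with hypothetical helper names, assuming one-hot labels and per-class prediction scores; it is not the Keras implementation.

```python
def categorical_accuracy(batch_true, batch_pred):
    """Fraction of samples whose predicted argmax matches the one-hot label."""

    def argmax(values):
        return max(range(len(values)), key=values.__getitem__)

    hits = sum(argmax(t) == argmax(p) for t, p in zip(batch_true, batch_pred))
    return hits / len(batch_true)


def mean_squared_error(batch_true, batch_pred):
    """Mean of the squared differences over all elements of the batch."""
    squared = [
        (t - p) ** 2
        for ts, ps in zip(batch_true, batch_pred)
        for t, p in zip(ts, ps)
    ]
    return sum(squared) / len(squared)


y_true = [[1.0, 0.0], [0.0, 1.0]]
y_pred = [[0.9, 0.1], [0.8, 0.2]]
print(categorical_accuracy(y_true, y_pred))  # 0.5, only the first sample matches
print(mean_squared_error(y_true, y_pred))
```

Here the second sample is misclassified, so it lowers the accuracy and also dominates the squared error.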
For the currently available callbacks, see the keys of the func variable of the SetupCallbacks class here.
- The callbacks parameter is only used in train.py.
- The kind of callbacks is set as a list.
callbacks settings example:
cb:
  kind: [mcp, csv]
  mcp:
    # filepath: null # The "filepath" is fixed in the code.
    monitor: val_loss
    verbose: 0
    save_best_only: false
    save_weights_only: true
    mode: auto
    save_freq: epoch
    initial_value_threshold: null
  csv:
    # filename: null # The "filename" is fixed in the code.
    separator: ","
    append: false
This repository is licensed under the Apache License 2.0.