Skip to content

Commit

Permalink
Merge pull request #233 from philipperemy/plot_tcn_model
Browse files Browse the repository at this point in the history
plot TCN model - fix
  • Loading branch information
philipperemy committed May 11, 2022
2 parents 8801791 + 4b35518 commit 05cfa3d
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 30 deletions.
41 changes: 21 additions & 20 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: Keras TCN CI

on: [push, pull_request]
on: [ push, pull_request ]

jobs:
build:
Expand All @@ -9,24 +9,25 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: [3.8]
python-version: [ 3.8 ]

steps:
- uses: actions/checkout@v1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install tox flake8
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --max-complexity 10 --max-line-length 127 --statistics
- name: Test with tox
run: |
tox
- uses: actions/checkout@v1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo apt-get install -y graphviz
python -m pip install --upgrade pip
pip install tox flake8
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --max-complexity 10 --max-line-length 127 --statistics
- name: Test with tox
run: |
tox
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
numpy
keract
matplotlib
pydot
25 changes: 25 additions & 0 deletions tasks/plot_tcn_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from tcn import TCN
import tensorflow as tf

timesteps = 32
input_dim = 5
input_shape = (timesteps, input_dim)
forecast_horizon = 3
num_features = 4

inputs = tf.keras.layers.Input(shape=input_shape, name='input')
tcn_out = TCN(nb_filters=64, kernel_size=3, nb_stacks=1, activation='LeakyReLU')(inputs)
outputs = tf.keras.layers.Dense(forecast_horizon * num_features, activation='linear')(tcn_out)
outputs = tf.reshape(outputs, shape=(-1, forecast_horizon, num_features), name='ouput')
model = tf.keras.Model(inputs=inputs, outputs=outputs)

tf.keras.utils.plot_model(
model,
to_file='TCN_model.png',
show_shapes=True,
show_dtype=True,
show_layer_names=True,
rankdir='TB',
dpi=200,
layer_range=None,
)
13 changes: 7 additions & 6 deletions tcn/tcn.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import inspect
from typing import List

# pylint: disable=E0611,E0401
from tensorflow.keras import backend as K, Model, Input, optimizers
# pylint: disable=E0611
# pylint: disable=E0611,E0401
from tensorflow.keras import layers
# pylint: disable=E0611
# pylint: disable=E0611,E0401
from tensorflow.keras.layers import Activation, SpatialDropout1D, Lambda
# pylint: disable=E0611
# pylint: disable=E0611,E0401
from tensorflow.keras.layers import Layer, Conv1D, Dense, BatchNormalization, LayerNormalization


Expand Down Expand Up @@ -229,7 +230,7 @@ def __init__(self,
self.nb_stacks = nb_stacks
self.kernel_size = kernel_size
self.nb_filters = nb_filters
self.activation = activation
self.activation_name = activation
self.padding = padding
self.kernel_initializer = kernel_initializer
self.use_batch_norm = use_batch_norm
Expand Down Expand Up @@ -280,7 +281,7 @@ def build(self, input_shape):
nb_filters=res_block_filters,
kernel_size=self.kernel_size,
padding=self.padding,
activation=self.activation,
activation=self.activation_name,
dropout_rate=self.dropout_rate,
use_batch_norm=self.use_batch_norm,
use_layer_norm=self.use_layer_norm,
Expand Down Expand Up @@ -362,7 +363,7 @@ def get_config(self):
config['use_skip_connections'] = self.use_skip_connections
config['dropout_rate'] = self.dropout_rate
config['return_sequences'] = self.return_sequences
config['activation'] = self.activation
config['activation'] = self.activation_name
config['use_batch_norm'] = self.use_batch_norm
config['use_layer_norm'] = self.use_layer_norm
config['use_weight_norm'] = self.use_weight_norm
Expand Down
14 changes: 10 additions & 4 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,22 @@
envlist = {py3}-tensorflow-{2.6.2,2.7.0,2.8.0,2.9.0}

[testenv]
deps = -rrequirements.txt
deps = pytest
pylint
flake8
-rrequirements.txt
tensorflow-2.6.2: tensorflow==2.6.2
tensorflow-2.7.0: tensorflow==2.7.0
tensorflow-2.8.0: tensorflow==2.8.0
tensorflow-2.9.0: tensorflow==2.9.0rc2
changedir = tasks/
commands = python tcn_call_test.py
commands = pylint --disable=R,C,W,E1136 ../tcn
flake8 ../tcn --count --select=E9,F63,F7,F82 --show-source --statistics
flake8 ../tcn --count --exclude=michel,tests --max-line-length 127 --statistics
python tcn_call_test.py
python save_reload_sequential_model.py
python sequential.py
python multi_length_sequences.py
python plot_tcn_model.py
passenv = *
install_command = pip install {packages}

install_command = pip install {packages}

0 comments on commit 05cfa3d

Please sign in to comment.