In [5]:
import numpy as np
import pandas as pd 
import tensorflow as tf

from tensorflow.keras import backend as K
from tensorflow.keras import Model, Input, optimizers
from tensorflow.keras.layers import Dense

In [6]:
from layers.spatial_block import SpatialBlock
from layers.temporal_block import ResidualBlock

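### Dataset

Random synthetic inputs and targets, used only to check that the model builds and trains end to end.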

In [7]:
# Synthetic dataset: random features and random one-hot targets.
# Shapes: data (N_SAMPLES, seq_len, in_dim), labels (N_SAMPLES, seq_len, out_dim).
SEED = 42
N_SAMPLES = 800

in_dim = 50
out_dim = 3
seq_len = 30

# Seeded generator so Restart & Run All reproduces the same data.
rng = np.random.default_rng(SEED)
data = rng.random((N_SAMPLES, seq_len, in_dim))
# The model ends in a per-step softmax trained with categorical cross-entropy, so
# each target must be a probability distribution over `out_dim` classes: use
# one-hot encoded random class indices rather than unnormalised uniform noise.
labels = np.eye(out_dim)[rng.integers(out_dim, size=(N_SAMPLES, seq_len))]

### WattNet

In [9]:
def create_model(in_dim: int, out_dim: int, w_dim: int = 128, emb_dim: int = 8,
                 dilation_depth: int = 4, dropout_prob: float = 0.2, n_stacks: int = 2,
                 seq_len: int = 30, mlp_dim: int = 128, post_mlp_dim: int = 512):
    """
    Creates and compiles a WattNet keras model.

    Architecture: a pre-MLP (three hidden Dense layers of width `mlp_dim`
    followed by a Dense compression to `w_dim`), then `n_stacks` repeats of
    `dilation_depth` (ResidualBlock -> SpatialBlock) pairs, then a post-MLP
    (one hidden Dense layer of width `post_mlp_dim` and a softmax Dense layer
    of width `out_dim`). The model is compiled with the Adam optimizer and
    categorical cross-entropy loss.

    Parameters
    ----------
    in_dim: int
        Number of input features per time step (last axis of the input).
    out_dim: int
        Number of output units per time step; softmax is applied over this axis.
    w_dim: int
        Spatial compression dimension produced by the pre-MLP. Also used as
        `filters` and `groups` of every ResidualBlock and as `w_dim` of every
        SpatialBlock.
    emb_dim: int
        Embedding dimension of scalar values for each of the `w_dim` left after
        compression. NOTE(review): not used by this implementation; kept for
        API compatibility.
    dilation_depth: int
        Number of temporal-spatial blocks per stack. Dilation for the temporal
        block is doubled each time: 2, 4, ..., 2 ** dilation_depth.
    dropout_prob: float
        Dropout probability. NOTE(review): not used by this implementation;
        kept for API compatibility.
    n_stacks: int
        Number of repeats of the `dilation_depth` temporal-spatial blocks.
    seq_len: int
        Length of the time axis; the model input shape is `(seq_len, in_dim)`.
        Pass it explicitly. The default of 30 matches the value the notebook
        previously read from its global `seq_len`.
    mlp_dim: int
        Width of the hidden layers of the pre-MLP.
    post_mlp_dim: int
        Width of the hidden layer of the post-MLP.

    Returns
    -------
    tensorflow.keras.Model
        Compiled model named 'WattNet'.
    """
    x_in = Input(shape=(seq_len, in_dim), name='input_features')

    # Pre MLP: three hidden layers, then compression down to `w_dim` channels.
    x_out = x_in
    for i in range(3):
        x_out = Dense(units=mlp_dim, activation='relu', name=f'preMLP_{i}')(x_out)
    x_out = Dense(units=w_dim, activation='relu', name='preMLP_3')(x_out)

    dilations = [2 ** i for i in range(1, dilation_depth + 1)]
    for s in range(n_stacks):
        for d in dilations:
            # Gated (temporal) block; `d` is the dilation of its temporal convolution.
            x_out = ResidualBlock(dilation=d, filters=w_dim, kernel_size=2, groups=w_dim,
                                  name=f'ResidualBlock_{s}_{d}')(x_out)
            # Attention - slice across temporal dimension.
            x_out = SpatialBlock(stack=s, dilation=d, w_dim=w_dim,
                                 name=f'SpatialBlock_{s}_{d}')(x_out)

    # Post MLP
    x_out = Dense(units=post_mlp_dim, activation='relu', name='postMLP_0')(x_out)
    x_out = Dense(units=out_dim, activation='softmax', name='postMLP_1')(x_out)

    model = Model(inputs=[x_in], outputs=[x_out], name='WattNet')
    model.compile(optimizer='Adam', loss='categorical_crossentropy')
    return model

In [10]:
# Build a small WattNet: 4 stacks, each with 2 temporal-spatial blocks (dilations 2 and 4).
model = create_model(
    in_dim=in_dim,
    out_dim=out_dim,
    w_dim=90,
    emb_dim=1,
    dilation_depth=2,
    dropout_prob=0.2,
    n_stacks=4,
)
model.summary()

### Train & Evaluate

In [None]:
# Train with Keras' default epochs/batch size, made explicit. Hold out the last
# 20% of samples so the run also reports a validation loss. Keep the returned
# History for inspection instead of echoing its repr as cell output.
history = model.fit(x=data, y=labels, epochs=1, batch_size=32, validation_split=0.2)