<a href="https://colab.research.google.com/github/tmontaj/scripter/blob/main/Notebooks/wave2letter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import tensorflow as tf
import numpy as np

In [4]:
class FirstBlock(tf.keras.layers.Layer):
  """Input stage: a strided 1-D convolution, batch normalization, then ReLU.

  Args:
    filters: number of output channels produced by the convolution.
    kernel_size: width of the convolution window.
    strides: convolution stride. The default of 2 makes this block shorten
      the steps axis; with 'same' padding the output length is
      ceil(input_length / strides).
    **kwargs: forwarded unchanged to ``tf.keras.layers.Layer``.
  """

  def __init__(self, filters=250, kernel_size=48, strides=2, **kwargs):
    super().__init__(**kwargs)
    self.conv = tf.keras.layers.Conv1D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        padding='same',
    )
    self.batch_norm = tf.keras.layers.BatchNormalization()
    self.relu = tf.keras.layers.ReLU()

  def call(self, input_):
    """Return relu(batch_norm(conv(input_)))."""
    return self.relu(self.batch_norm(self.conv(input_)))

In [5]:
class MidBlock(tf.keras.layers.Layer):
  """Intermediate stage: stride-1 1-D convolution, batch normalization, ReLU.

  The convolution uses 'same' padding and Keras' default stride of 1, so the
  number of steps is preserved; ``filters`` sets the output channel count.

  Args:
    filters: number of output channels produced by the convolution.
    kernel_size: width of the convolution window.
    **kwargs: forwarded unchanged to ``tf.keras.layers.Layer``.
  """

  def __init__(self, filters=250, kernel_size=7, **kwargs):
    super().__init__(**kwargs)
    self.conv = tf.keras.layers.Conv1D(
        filters=filters, kernel_size=kernel_size, padding='same')
    self.batch_norm = tf.keras.layers.BatchNormalization()
    self.relu = tf.keras.layers.ReLU()

  def call(self, input_):
    """Run ``input_`` through the conv, batch norm and ReLU, in that order."""
    out = input_
    for layer in (self.conv, self.batch_norm, self.relu):
      out = layer(out)
    return out

In [7]:
class LastBlock(tf.keras.layers.Layer):
  """Final stage: wide conv, 1x1 conv, then a 1x1 output projection.

  Layout::

    Conv1D(2000, kernel 32) -> BatchNormalization -> ReLU
    Conv1D(2000, kernel 1)  -> BatchNormalization -> ReLU
    Conv1D(output_size, kernel 1)                  (linear output)

  Args:
    output_size: number of channels produced by the final projection
      (presumably the size of the output symbol set — confirm).
    **kwargs: forwarded unchanged to ``tf.keras.layers.Layer``.
  """

  def __init__(self, output_size=40, **kwargs):
    super().__init__(**kwargs)
    self.conv1 = tf.keras.layers.Conv1D(filters=2000, kernel_size=32, padding='same')
    # Each hidden conv gets its own BatchNormalization: a BN layer learns
    # scale/offset and tracks moving statistics for one specific channel
    # count and feature distribution, so it must not be shared between convs.
    self.batch_norm1 = tf.keras.layers.BatchNormalization()
    self.conv2 = tf.keras.layers.Conv1D(filters=2000, kernel_size=1, padding='same')
    self.batch_norm2 = tf.keras.layers.BatchNormalization()
    self.conv3 = tf.keras.layers.Conv1D(filters=output_size, kernel_size=1, padding='same')
    # ReLU has no weights, so a single instance can be reused safely.
    self.relu = tf.keras.layers.ReLU()

  def call(self, input_):
    """Map ``input_`` to ``output_size`` channels per step."""
    x = self.relu(self.batch_norm1(self.conv1(input_)))
    x = self.relu(self.batch_norm2(self.conv2(x)))
    # The output projection is deliberately left linear: it produces raw
    # per-step scores (logits). A ReLU here would clamp every score to >= 0,
    # which is undesirable if these are fed to a softmax/CTC-style loss
    # (assumed — the loss is not shown in this notebook).
    return self.conv3(x)

In [8]:
class Wav2Let(tf.keras.Model):
  """Fully convolutional network: FirstBlock, seven MidBlocks, then LastBlock.

  The blocks are applied strictly in sequence; there are no skip
  connections.

  Args:
    output_size: number of output channels of the final block, forwarded to
      ``LastBlock`` (default 40, same as ``LastBlock``'s own default).
    **kwargs: forwarded unchanged to ``tf.keras.Model``.
  """

  def __init__(self, output_size=40, **kwargs):
    super().__init__(**kwargs)
    # Sub-blocks must be stored as attributes of ``self``: ``call`` looks them
    # up there, and Keras only tracks the weights of layers assigned to
    # attributes (plain local variables are discarded when __init__ returns).
    self.first_block = FirstBlock()
    self.mid_block1 = MidBlock()
    self.mid_block2 = MidBlock()
    self.mid_block3 = MidBlock()
    self.mid_block4 = MidBlock()
    self.mid_block5 = MidBlock()
    self.mid_block6 = MidBlock()
    self.mid_block7 = MidBlock()
    self.last_block = LastBlock(output_size=output_size)

  def call(self, input_):
    """Run ``input_`` through every block in order and return the output."""
    x = self.first_block(input_)
    mid_blocks = (
        self.mid_block1,
        self.mid_block2,
        self.mid_block3,
        self.mid_block4,
        self.mid_block5,
        self.mid_block6,
        self.mid_block7,
    )
    for block in mid_blocks:
      x = block(x)
    return self.last_block(x)


In [9]:
# Build the network with its default configuration. No input shape is given
# here, so Keras creates the weights only when the model is first called.
model: Wav2Let = Wav2Let()