In [1]:
from res.sequential_tasks import TemporalOrderExp6aSequence as QRSU

In [2]:
# Build a generator for the EASY difficulty level that yields batches of 32 sequences
example_generator = QRSU.get_predefined_generator(
    batch_size=32,
    difficulty_level=QRSU.DifficultyLevel.EASY,
)

In [3]:
# Fetch the batch stored at index 1 of the generator (presumably 0-based, i.e. the second batch — confirm)
example_batch = example_generator[1]

In [5]:
# Shape of the sequence batch; presumably (batch, time steps, one-hot symbol size) — confirm against decode_x
example_batch[0].shape

(32, 9, 8)

In [6]:
# Shape of the label batch: one one-hot class vector per sequence, i.e. (batch, num_classes)
example_batch[1].shape

(32, 4)

In [7]:
# Re-fetch the batch at index 1 and describe what the generator returns:
# a (sequences, labels) tuple, plus the first sequence and its label.
example_batch = example_generator[1]
print('The return type is a {} with length {}.'.format(type(example_batch), len(example_batch)))
print('The first item in the tuple is the batch of sequences with shape {}.'.format(example_batch[0].shape))
print('The first element in the batch of sequences is:\n {}'.format(example_batch[0][0]))
print('The second item in the tuple is the corresponding batch of class labels with shape {}.'.format(example_batch[1].shape))
print('The first element in the batch of class labels is:\n {}'.format(example_batch[1][0]))

The return type is a <class 'tuple'> with length 2.
The first item in the tuple is the batch of sequences with shape (32, 9, 8).
The first element in the batch of sequences is:
 [[0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 0]
 [1 0 0 0 0 0 0 0]
 [0 0 0 0 1 0 0 0]
 [0 0 0 0 1 0 0 0]
 [0 1 0 0 0 0 0 0]
 [0 0 1 0 0 0 0 0]
 [0 0 1 0 0 0 0 0]
 [0 0 0 0 0 0 0 1]]
The second item in the tuple is the corresponding batch of class labels with shape (32, 4).
The first element in the batch of class labels is:
 [0. 1. 0. 0.]


In [17]:
# Raw one-hot matrix of the first sequence in the batch
print(f'First raw element of 1st batch:\n {example_batch[0][0]}')

# Turn the one-hot rows of that sequence back into its symbol string
sequence_decoded = example_generator.decode_x(example_batch[0][0])
print('The sequence is: {}'.format(sequence_decoded))

# Turn the matching one-hot label back into its class name
class_label_decoded = example_generator.decode_y(example_batch[1][0])
print('The class label is: {}'.format(class_label_decoded))

First raw element of 1st batch:
 [[0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 0]
 [1 0 0 0 0 0 0 0]
 [0 0 0 0 1 0 0 0]
 [0 0 0 0 1 0 0 0]
 [0 1 0 0 0 0 0 0]
 [0 0 1 0 0 0 0 0]
 [0 0 1 0 0 0 0 0]
 [0 0 0 0 0 0 0 1]]
The sequence is: BXccYaaE
The class label is: R


In [18]:
# Second raw element of the same batch (index 1 along the batch axis)
print('Second raw element of the batch:\n {}'.format(example_batch[0][1]))
# Decoding the second sequence
sequence_decoded = example_generator.decode_x(example_batch[0][1])
print(f'The sequence is: {sequence_decoded}')

# Decoding the class label of the second sequence
class_label_decoded = example_generator.decode_y(example_batch[1][1])
print(f'The class label is: {class_label_decoded}')

First raw element of 1st batch:
 [[0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 0]
 [0 0 0 0 0 1 0 0]
 [1 0 0 0 0 0 0 0]
 [0 0 0 0 0 1 0 0]
 [1 0 0 0 0 0 0 0]
 [0 0 0 1 0 0 0 0]
 [0 0 0 0 0 0 0 1]]
The sequence is: BdXdXbE
The class label is: Q


### 2. Defining the Model

In [19]:
import tensorflow as tf
class SimpleRNN(tf.keras.Model):
    """Single-layer ReLU SimpleRNN followed by a dense output layer.

    Args:
        input_size: number of features per time step (passed as `input_dim`).
        hidden_size: number of recurrent units.
        output_size: number of units in the final dense layer.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        self.rnn = tf.keras.layers.SimpleRNN(
            input_dim=input_size,
            activation='relu',
            units=hidden_size,
        )
        self.linear = tf.keras.layers.Dense(input_dim=hidden_size, units=output_size)

    def call(self, x):
        """Feed `x` through the recurrent layer, then the dense layer.

        `x` is assumed to be a (batch, time, input_size) batch — TODO confirm.
        The recurrent layer is built without `return_sequences`, so only its
        last output reaches the dense layer.
        """
        return self.linear(self.rnn(x))
    
class SimpleLSTM(tf.keras.Model):
    """Single-layer LSTM followed by a dense output layer.

    Args:
        input_size: number of features per time step (passed as `input_dim`).
        hidden_size: number of LSTM units.
        output_size: number of units in the final dense layer.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        self.lstm = tf.keras.layers.LSTM(units=hidden_size,
                                         input_dim=input_size)
        self.linear = tf.keras.layers.Dense(units=output_size, input_dim=hidden_size)

    def call(self, x):
        """Feed `x` through the LSTM (last output only) and the dense layer."""
        h = self.lstm(x)
        return self.linear(h)

    def get_states_across_time(self, x):
        """Step the LSTM through `x` one time step at a time and record its states.

        The LSTM layer is built without `return_state`, so calling it cannot
        expose the cell state. Instead this steps `self.lstm.cell` directly. The
        cell holds the same weights the layer uses. States start at zeros, like
        the layer's default initial state.

        Args:
            x: batch of input sequences, presumably (batch, time, features) as
                produced by the data generator — TODO confirm. Integer input is
                cast to float32 because the recurrent kernels are float.

        Returns:
            (h, c): hidden states and cell states, each stacked along a new
            leading time axis, i.e. shape (time, batch, hidden_size). This
            matches a per-step concatenation of (1, batch, hidden) states.
            NOTE(review): confirm this is the layout `plot_state` expects.
        """
        x = tf.cast(x, tf.float32)
        units = self.lstm.units
        batch_size = x.shape[0]
        states = [tf.zeros((batch_size, units)), tf.zeros((batch_size, units))]

        h_list, c_list = [], []
        for t in range(x.shape[1]):
            # LSTMCell returns (output, [h, c]); its output is h itself.
            _, states = self.lstm.cell(x[:, t, :], states)
            h_list.append(states[0])
            c_list.append(states[1])

        if not h_list:
            # Zero-length sequences: return empty (0, batch, units) tensors instead of failing in tf.stack.
            empty = tf.zeros((0, batch_size, units))
            return empty, empty

        h = tf.stack(h_list, axis=0)
        c = tf.stack(c_list, axis=0)
        return h, c

In [20]:
def train(model, train_data_gen, criterion, optimizer):
    """Run one training epoch over every batch of `train_data_gen`.

    Args:
        model: Keras model mapping a float batch of sequences to per-class
            scores of shape (batch, num_classes).
        train_data_gen: indexable generator supporting `len()`; item `i` is a
            `(sequences, one_hot_labels)` pair.
        criterion: Keras-style loss called as `criterion(y_true, y_pred)` with
            integer class indices as `y_true`. One example is
            `tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)`.
            This matches how `test` evaluates.
        optimizer: Keras optimizer used to apply the gradients.

    Returns:
        (num_correct, loss): the number of correctly classified training samples
        over the epoch (int), and the loss tensor of the last batch. The loss is
        `None` when the generator yields no batches.
    """
    num_correct = 0
    loss = None

    for batch_idx in range(len(train_data_gen)):
        data, target = train_data_gen[batch_idx]
        # The sequences are integer one-hot arrays; recurrent kernels are float,
        # and Keras does not auto-cast integer inputs.
        data = tf.cast(data, tf.float32)
        # One-hot targets -> class indices, so they can be compared with the argmax predictions.
        labels = tf.argmax(target, axis=1)

        with tf.GradientTape() as tape:
            output = model(data)
            # Keras losses take (y_true, y_pred), unlike PyTorch's (input, target).
            loss = criterion(labels, output)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # TF tensors have no .argmax() method; use tf.argmax.
        y_pred = tf.argmax(output, axis=1)
        num_correct += int(tf.reduce_sum(tf.cast(tf.equal(y_pred, labels), tf.int32)))

    return num_correct, loss
            

In [21]:
def test(model, test_data_gen, criterion):
    """Evaluate `model` on every batch of `test_data_gen` without updating weights.

    Args:
        model: Keras model mapping a float batch of sequences to per-class
            scores of shape (batch, num_classes).
        test_data_gen: indexable generator supporting `len()`; item `i` is a
            `(sequences, one_hot_labels)` pair.
        criterion: Keras-style loss called as `criterion(y_true, y_pred)` with
            integer class indices as `y_true`. One example is
            `tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)`.

    Returns:
        (num_correct, loss): the number of correctly classified samples (int),
        and the loss tensor of the last batch. The loss is `None` when the
        generator yields no batches.
    """
    num_correct = 0
    loss = None

    for batch_idx in range(len(test_data_gen)):
        data, target = test_data_gen[batch_idx]
        # Integer one-hot sequences must be float for the recurrent layers.
        data = tf.cast(data, tf.float32)
        labels = tf.argmax(target, axis=1)
        output = model(data)
        # Keras losses take (y_true, y_pred), unlike PyTorch's (input, target).
        loss = criterion(labels, output)
        # TF tensors have no .argmax() method; use tf.argmax.
        y_pred = tf.argmax(output, axis=1)
        num_correct += int(tf.reduce_sum(tf.cast(tf.equal(y_pred, labels), tf.int32)))

    return num_correct, loss

In [22]:
import matplotlib.pyplot as plt
from res.plot_lib import set_default, plot_state, print_colourbar

In [23]:
# Apply the plotting defaults from res.plot_lib (presumably matplotlib style settings — TODO confirm)
set_default()