In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from tqdm.notebook import tqdm
import numpy as np
from matplotlib import pyplot as plt
import modelutils_v2 as modelutils
import copy

The state part of the network input is 11650 elements long, and the action vector is `npeople` elements long. The critic network maps a (state, action) pair to a scalar value. The actor network maps a state to an action: each of its `npeople` outputs is θ, the success probability of a Bernoulli trial that decides whether the corresponding individual is tested.

In [None]:
# Number of people in the simulation; also the length of the action vector.
npeople = 100
# Length of the state vector. Must be updated if the underlying model changes.
statelen = 11650
# State length plus one entry per person (kept at 11650 + npeople as before).
inputlen = statelen + npeople

In [None]:
class Critic(Model):
  """Critic network: maps a (state, action) input to a single scalar.

  ``call`` takes one tensor ``x``, so the state and action vectors are
  presumably concatenated along the last axis by the caller — TODO confirm.
  """

  def __init__(self):
    # Zero-argument super() avoids naming the class explicitly; the previous
    # super(MyModel, self) referenced a class that does not exist here and
    # raised NameError on construction.
    super().__init__()
    # NOTE(review): these widths are very large — d2 alone holds about
    # 11850 x 12000 ~= 142M weights. Confirm this fits in memory, especially
    # with a raw and a target copy of each network.
    self.d1 = Dense(inputlen+npeople, activation='relu')
    self.d2 = Dense(12000,activation='relu')
    self.d3 = Dense(5000,activation='relu')
    self.d4 = Dense(1000,activation='relu')
    # NOTE(review): d5 and d6 have no activation, so d5 -> d6 -> dout is a
    # chain of affine maps that collapses into a single affine map. Add a
    # nonlinearity if these layers are meant to add capacity.
    self.d5 = Dense(500)
    self.d6 = Dense(100)
    # Linear 1-unit output: an unbounded scalar.
    self.dout = Dense(1)

  def call(self, x):
    """Forward pass: d1 through d6, then the 1-unit output layer."""
    x = self.d1(x)
    x = self.d2(x)
    x = self.d3(x)
    x = self.d4(x)
    x = self.d5(x)
    x = self.d6(x)
    return self.dout(x)

class Actor(Model):
  """Actor network: maps a state to ``npeople`` outputs, one per person.

  Each output is intended to be theta, the success probability of a
  Bernoulli trial that decides whether that person is tested. The output
  layer therefore defaults to a sigmoid, which keeps values in (0, 1).

  Args:
    output_activation: activation of the final ``Dense(npeople)`` layer.
      Defaults to ``'sigmoid'``. Pass ``None`` to get raw, unbounded
      outputs, e.g. if the sigmoid is applied downstream.
  """

  def __init__(self, output_activation='sigmoid'):
    # Zero-argument super(): MyModel is not defined here, so the previous
    # super(MyModel, self) raised NameError on construction.
    super().__init__()
    self.d1 = Dense(inputlen, activation='relu')
    # d2 is assigned exactly once; a duplicate assignment used to create and
    # discard an extra 12000-unit layer.
    self.d2 = Dense(12000,activation='relu')
    self.d3 = Dense(5000,activation='relu')
    self.d4 = Dense(1000,activation='relu')
    # NOTE(review): d5 and d6 have no activation, so d5 -> d6 -> dout is a
    # chain of affine maps. Add a nonlinearity if they are meant to add
    # capacity.
    self.d5 = Dense(500)
    self.d6 = Dense(250)
    self.dout = Dense(npeople, activation=output_activation)

  def call(self, x):
    """Forward pass: d1 through d6, then the npeople-unit output layer."""
    x = self.d1(x)
    x = self.d2(x)
    x = self.d3(x)
    x = self.d4(x)
    x = self.d5(x)
    x = self.d6(x)
    return self.dout(x)

# Build a "target" and a "raw" instance of each network. Within each pair the
# target instance is constructed first, then the raw one.
# NOTE(review): each instance gets its own freshly initialised Dense layers,
# so target and raw start from different random weights once built. If the
# targets are meant to track the raw networks, copy the raw weights into them
# after building — confirm intent.
critic_target, critic_raw = Critic(), Critic()

actor_target, actor_raw = Actor(), Actor()

In [None]:
# NOTE(review): the MNIST data below is not consumed by the Critic/Actor
# networks defined above. It looks like leftover tutorial scaffolding —
# confirm before relying on it.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel values from [0, 255] to [0, 1], then append a trailing channel
# axis and cast to float32.
x_train = (x_train / 255.0)[..., tf.newaxis].astype("float32")
x_test = (x_test / 255.0)[..., tf.newaxis].astype("float32")

In [None]:
# Training pairs: shuffled with a 10000-element buffer, then batched by 32.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10000)
    .batch(32)
)

# Test pairs: batched by 32, not shuffled.
test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(32)
)

In [None]:
# Adam with Keras' default hyperparameters (no arguments supplied).
optimizer = tf.keras.optimizers.Adam()
# from_logits=True: the loss applies softmax itself, so the model's final
# layer is expected to emit raw, unnormalised scores.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True)

In [None]:
# Stateful metrics: each accumulates over every batch it is called on until
# it is reset; read the current value with .result().
train_loss = tf.keras.metrics.Mean(name='train_loss')
test_loss = tf.keras.metrics.Mean(name='test_loss')

train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='train_accuracy')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='test_accuracy')

In [None]:
@tf.function
def train_step(images, labels):
  """Run one optimisation step on a batch and update the training metrics.

  FIXME(review): ``model`` is not defined in any visible cell (only the
  ``critic_*`` / ``actor_*`` instances are), so tracing this function raises
  NameError. Point it at the intended network before running.
  """
  with tf.GradientTape() as tape:
    # training=True only matters for layers such as Dropout that behave
    # differently during training and inference.
    logits = model(images, training=True)
    batch_loss = loss_object(labels, logits)

  # Read the variables after the forward pass, once the model is built.
  weights = model.trainable_variables
  grads = tape.gradient(batch_loss, weights)
  optimizer.apply_gradients(zip(grads, weights))

  train_loss(batch_loss)
  train_accuracy(labels, logits)

In [None]:
@tf.function
def test_step(images, labels):
  """Evaluate one batch (no gradient update) and update the test metrics.

  FIXME(review): ``model`` is not defined in any visible cell, so tracing
  this function raises NameError. Point it at the intended network.
  """
  # training=False only matters for layers such as Dropout that behave
  # differently during training and inference.
  logits = model(images, training=False)
  batch_loss = loss_object(labels, logits)

  test_loss(batch_loss)
  test_accuracy(labels, logits)

In [None]:
episodes = 5

for episode_idx in tqdm(range(episodes)):
  # Reset the metrics at the start of every epoch so the printed numbers are
  # per-epoch values, not running averages over all epochs so far.
  # (reset_state() is the current Keras name; reset_states() is deprecated.)
  train_loss.reset_state()
  train_accuracy.reset_state()
  test_loss.reset_state()
  test_accuracy.reset_state()

  for images, labels in train_ds:
    train_step(images, labels)

  for test_images, test_labels in test_ds:
    test_step(test_images, test_labels)

  template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
  # Report the 1-based index of the loop variable; there is no `epoch` name
  # in scope here.
  print(template.format(episode_idx + 1,
                        train_loss.result(),
                        train_accuracy.result() * 100,
                        test_loss.result(),
                        test_accuracy.result() * 100))