In [6]:
# NN model to predict the true state given observations.
# Inputs: observation: observations at some time step of the game
#         legals_mask: possible child states
# Output size: get_all_states() (see the imperfect information games notebook).
# I am not sure about the loss function for now, so this notebook is still in progress and not fully completed.


import collections
import functools
import os
from typing import Sequence

import numpy as np
import tensorflow.compat.v1 as tf

def cascade(x, fns):
    """Thread `x` through every callable in `fns`, in order.

    Each callable receives the previous callable's result; the last result is
    returned. With no callables, `x` is returned unchanged.
    """
    return functools.reduce(lambda value, fn: fn(value), fns, x)

# Short alias for the Keras layers namespace; Model._define_graph creates its
# Dense, Activation and Softmax layers through it.
tfkl = tf.keras.layers

class TrainInput(collections.namedtuple("TrainInput", "observation legals_mask state")):
    """Inputs for training the Model: one example, or a stacked batch of them.

    Fields:
      observation: Observation values; stacked as float32.
      legals_mask: Mask over the outputs; stacked as booleans.
      state: Training target; stacked without an explicit dtype.
    """

    @staticmethod
    def stack(train_inputs):
        """Combine several examples into one batched TrainInput.

        Args:
          train_inputs: Non-empty iterable of (observation, legals_mask, state)
            triples, e.g. TrainInput instances.

        Returns:
          A TrainInput whose fields are NumPy arrays with the examples stacked
          along a new leading axis.
        """
        observation, legals_mask, state = zip(*train_inputs)
        return TrainInput(
            np.array(observation, dtype=np.float32),
            # Use the builtin `bool`: the `np.bool` alias was deprecated in
            # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
            np.array(legals_mask, dtype=bool),
            np.array(state))


class Losses(collections.namedtuple("Losses", "state l2")):
    """Losses from a training step.

    Fields:
      state: Loss of the state head.
      l2: L2 regularisation loss.
    """

    @property
    def total(self):
        """Sum of the state loss and the L2 regularisation loss."""
        return self.state + self.l2

    def __str__(self):
        # Adjacent string literals are concatenated verbatim, so the space
        # separating "state" from "l2" has to be written explicitly.
        return ("Losses(total: {:.3f}, state: {:.3f}, "
                "l2: {:.3f})").format(self.total, self.state, self.l2)

    def __add__(self, other):
        """Field-wise sum (replaces tuple concatenation)."""
        return Losses(self.state + other.state,
                      self.l2 + other.l2)

    def __truediv__(self, n):
        """Field-wise division by `n`, e.g. to average accumulated losses."""
        return Losses(self.state / n, self.l2 / n)

class Model(object):

    def __init__(self, session, saver, path):
        # Init a model. Use build_model, from_checkpoint or from_graph instead."""
        self._session = session
        self._saver = saver
        self._path = path
        
        def get_var(name):
            return self._session.graph.get_tensor_by_name(name + ":0")

        self._input = get_var("input")
        self._legals_mask = get_var("legals_mask")
        self._training = get_var("training")
        self._state_softmax = get_var("state_softmax")
        self._state_loss = get_var("state_loss")
        self._l2_reg_loss = get_var("l2_reg_loss")
        self._state_targets = get_var("state_targets")
        self._train = self._session.graph.get_operation_by_name("train")

    @classmethod
    def build_model(cls, model_type, input_shape, output_size, nn_width, nn_depth,
                    weight_decay, learning_rate, path):
        """Build a model with the specified params on a fresh graph.

        All arguments except `path` are forwarded unchanged to `_define_graph`;
        `path` is handed to the constructor.

        Returns:
          An instance of `cls` bound to a new session whose variables have
          been initialised.
        """
        graph = tf.Graph()
        with graph.as_default():
            cls._define_graph(
                model_type, input_shape, output_size,
                nn_width, nn_depth, weight_decay, learning_rate)
            # Created after the graph is defined so that it covers every variable.
            init_op = tf.variables_initializer(
                tf.global_variables(), name="init_all_vars_op")
            # Saver only works on CPU.
            with tf.device("/cpu:0"):
                saver = tf.train.Saver(
                    name="saver", sharded=False, max_to_keep=10000)

        session = tf.Session(graph=graph)
        # Entered here without a matching __exit__ in this method.
        session.__enter__()
        session.run(init_op)
        return cls(session, saver, path)

    @classmethod
    def from_checkpoint(cls, checkpoint, path=None):
        """Load a model from a checkpoint and restore its saved variables.

        NOTE(review): this relies on `cls.from_graph`, which is not defined in
        the visible part of this class -- confirm it exists.

        Args:
          checkpoint: Passed to `from_graph` and then to `load_checkpoint`.
          path: Passed through to `from_graph`.
        """
        restored = cls.from_graph(checkpoint, path)
        restored.load_checkpoint(checkpoint)
        return restored

    @staticmethod
    def _define_graph(model_type, input_shape, output_size,
                      nn_width, nn_depth, weight_decay, learning_rate):
        """Define the model graph in the current default graph.

        Creates the ops that `__init__` later looks up by name: the
        placeholders `input`, `legals_mask`, `training` and `state_targets`,
        the tensors `state_softmax`, `state_loss` and `l2_reg_loss`, and the
        `train` operation.

        Args:
          model_type: Not read by this function.
          input_shape: Shape of one observation; only its element count is used.
          output_size: Width of the state head and of the legality mask.
          nn_width: Units in every hidden dense layer.
          nn_depth: Number of dense + ReLU layers in the torso.
          weight_decay: Factor applied to the L2 loss of each non-bias variable.
          learning_rate: Learning rate given to the Adam optimizer.
        """
        # Inference inputs. The input shape is ignored beyond its size: the
        # observation is treated as a flat array.
        flat_size = int(np.prod(input_shape))
        observations = tf.placeholder(tf.float32, [None, flat_size], name="input")
        legals_mask = tf.placeholder(
            tf.bool, [None, output_size], name="legals_mask")
        # Created so it can be fed by name; no op in this graph consumes it.
        tf.placeholder(tf.bool, name="training")

        # Torso: nn_depth dense layers, each followed by a ReLU.
        hidden = observations
        for layer_index in range(nn_depth):
            dense = tfkl.Dense(nn_width, name=f"torso_{layer_index}_dense")
            relu = tfkl.Activation("relu")
            hidden = relu(dense(hidden))

        # The state head.
        head_dense = tfkl.Dense(nn_width, name="state_dense")
        head_relu = tfkl.Activation("relu")
        state_head = head_relu(head_dense(hidden))

        raw_logits = tfkl.Dense(output_size, name="state")(state_head)
        # Entries where the mask is False get a very negative logit.
        masked_fill = -1e32 * tf.ones_like(raw_logits)
        state_logits = tf.where(legals_mask, raw_logits, masked_fill)
        probabilities = tfkl.Softmax()(state_logits)
        tf.identity(probabilities, name="state_softmax")

        state_targets = tf.placeholder(
            shape=[None, output_size], dtype=tf.float32, name="state_targets")
        per_example_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=state_logits, labels=state_targets)
        state_loss = tf.reduce_mean(per_example_loss, name="state_loss")

        # L2 regularisation over every trainable variable except biases.
        decay_terms = []
        for var in tf.trainable_variables():
            if "/bias:" in var.name:
                continue
            decay_terms.append(weight_decay * tf.nn.l2_loss(var))
        l2_reg_loss = tf.add_n(decay_terms, name="l2_reg_loss")

        total_loss = state_loss + l2_reg_loss

        optimizer = tf.train.AdamOptimizer(learning_rate)

        # Nothing is ever added to this list in this function, so the control
        # dependency below is currently empty.
        bn_updates = []
        with tf.control_dependencies(bn_updates):
            optimizer.minimize(total_loss, name="train")

    @property
    def num_trainable_variables(self):
        """Total number of elements across this model's trainable variables.

        The variables are taken from the graph owned by this model's session
        rather than from the process-wide default graph, so the count stays
        correct when another graph is the default (e.g. several models exist).
        """
        variables = self._session.graph.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES)
        return sum(np.prod(v.shape) for v in variables)

    def print_trainable_variables(self):
        """Print `name: shape` for every trainable variable of this model.

        The variables are taken from the graph owned by this model's session
        rather than from the process-wide default graph, so the listing stays
        correct when another graph is the default.
        """
        variables = self._session.graph.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES)
        for v in variables:
            print("{}: {}".format(v.name, v.shape))

    def inference(self, observation, legals_mask):
        return self._session.run(
                            self._state_softmax,
                            feed_dict={self._input: np.array(observation, dtype=np.float32),
                               self._legals_mask: np.array(legals_mask, dtype=np.bool),
                               self._training: False})

    def update(self, train_inputs: Sequence[TrainInput]):
        """Run one training step on a batch and report its losses.

        Args:
          train_inputs: Examples that `TrainInput.stack` merges into one batch.

        Returns:
          A `Losses` built from the state loss and the L2 regularisation loss
          fetched in the same `session.run` call that executes the train op.
        """
        batch = TrainInput.stack(train_inputs)

        feeds = {
            self._input: batch.observation,
            self._legals_mask: batch.legals_mask,
            self._state_targets: batch.state,
            self._training: True,
        }
        fetches = [self._train, self._state_loss, self._l2_reg_loss]

        # The first result belongs to the train op and is discarded.
        _, state_value, l2_value = self._session.run(fetches, feed_dict=feeds)
        return Losses(state_value, l2_value)

    def save_checkpoint(self, step):
        return self._saver.save(
            self._session,
            os.path.join(self._path, "checkpoint"),
            global_step=step)

    def load_checkpoint(self, path):
        return self._saver.restore(self._session, path)