In [None]:
import random

import numpy as np
import tensorflow as tf
from keras import layers, models

def create_value_model(input_dim=121, hidden_units=(128, 64)):
    """Build the feed-forward value network used to score observations.

    The network is a stack of ReLU ``Dense`` layers followed by a single
    linear output unit, so it maps one flat observation vector to one scalar.

    Args:
        input_dim: Length of the flat observation vector. Defaults to 121
            (presumably one entry per board cell -- confirm against the
            environment's ``get_observation``).
        hidden_units: Iterable with the width of each hidden layer, in order.
            Defaults to ``(128, 64)``.

    Returns:
        An uncompiled ``keras`` ``Sequential`` model with output shape
        ``(batch, 1)``.
    """
    stack = [layers.Input(shape=(input_dim,))]
    stack.extend(layers.Dense(units, activation='relu') for units in hidden_units)
    stack.append(layers.Dense(1))  # Linear head: predict a single value
    return models.Sequential(stack)


In [None]:
def train_value_model(env_class, num_players=3, episodes=1000, search_depth=2,
                      batch_size=64, buffer_size=5000, learning_rate=1e-3,
                      log_every=50):
    """Train a value network from self-play games driven by max-n search.

    Each episode plays one full game in a fresh ``env_class()`` instance, with
    every move chosen by ``maxn_alpha_beta_nn`` using the current network as
    the leaf evaluator. After the game, every recorded (observation, player)
    pair is labelled with that player's final score and pushed into a replay
    buffer; one mean-squared-error gradient step is then taken on a random
    mini-batch.

    Args:
        env_class: Zero-argument callable producing an environment exposing
            ``is_terminal``, ``get_current_player``, ``get_observation``,
            ``apply_move`` and ``get_final_scores``.
        num_players: Number of players; sizes the alpha/beta vectors.
        episodes: Number of self-play games to run.
        search_depth: ``max_depth`` forwarded to ``maxn_alpha_beta_nn``.
        batch_size: Mini-batch size for each gradient step.
        buffer_size: Maximum number of samples kept in the replay buffer
            (oldest samples are dropped first).
        learning_rate: Adam learning rate.
        log_every: Print the loss on episodes divisible by this value;
            ``0`` or ``None`` disables logging.

    Returns:
        The trained model created by ``create_value_model()``.

    Raises:
        ValueError: If ``batch_size`` exceeds ``buffer_size``, in which case
            no training step could ever run.
    """
    if batch_size > buffer_size:
        raise ValueError(
            f"batch_size ({batch_size}) must not exceed buffer_size ({buffer_size})"
        )

    model = create_value_model()
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    replay_buffer = []

    for episode in range(episodes):
        env = env_class()
        game_memory = []

        # --- Self-play: generate one full game with the current network ---
        while not env.is_terminal():
            player_id = env.get_current_player()
            # Copy the observation: if the env hands back a view of its live
            # board, later apply_move calls would silently rewrite the sample.
            obs = np.array(env.get_observation(player_id), dtype=np.float32)

            # The search returns a (value, move) pair; only the move is used.
            _, move = maxn_alpha_beta_nn(
                env,
                depth=0,
                max_depth=search_depth,
                num_players=num_players,
                value_net=model,
                alpha=np.full(num_players, -np.inf),
                beta=np.full(num_players, np.inf)
            )

            game_memory.append((obs, player_id))
            env.apply_move(player_id, move)

        # --- Label every visited position with the mover's final score ---
        final_scores = env.get_final_scores()
        replay_buffer.extend((obs, final_scores[pid]) for obs, pid in game_memory)

        # Drop the oldest samples in place instead of reallocating the list.
        overflow = len(replay_buffer) - buffer_size
        if overflow > 0:
            del replay_buffer[:overflow]

        if len(replay_buffer) < batch_size:
            continue  # Not enough data for a training step yet.

        # --- One MSE gradient step on a uniformly sampled mini-batch ---
        batch = random.sample(replay_buffer, batch_size)
        batch_obs = tf.convert_to_tensor(
            np.stack([sample_obs for sample_obs, _ in batch]), dtype=tf.float32
        )
        batch_targets = tf.convert_to_tensor(
            [target for _, target in batch], dtype=tf.float32
        )

        with tf.GradientTape() as tape:
            # Model output is (batch, 1); squeeze to (batch,) to match targets.
            preds = tf.squeeze(model(batch_obs, training=True), axis=1)
            loss = tf.reduce_mean(tf.square(preds - batch_targets))

        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if log_every and episode % log_every == 0:
            print(f"Episode {episode} | Training loss: {loss.numpy():.4f}")

    return model


In [None]:
def run_inference_game(env_class, value_net, search_depth=2, num_players=3, render=False):
    """Play one complete game where every move is chosen by max-n search.

    A fresh ``env_class()`` is created and stepped until ``is_terminal()``
    returns true. Each move comes from ``maxn_alpha_beta_nn`` with
    ``value_net`` as its evaluator. The final scores and the move count are
    printed once the game ends.

    Args:
        env_class: Zero-argument callable producing an environment exposing
            ``is_terminal``, ``get_current_player``, ``apply_move`` and
            ``get_final_scores`` (plus ``render`` when ``render=True``).
        value_net: Value network forwarded to ``maxn_alpha_beta_nn``.
        search_depth: ``max_depth`` forwarded to the search.
        num_players: Number of players; sizes the alpha/beta vectors.
        render: When true, print each chosen move and call ``env.render()``.

    Returns:
        Whatever ``env.get_final_scores()`` returns for the finished game.
    """
    env = env_class()
    move_count = 0

    while not env.is_terminal():
        player_id = env.get_current_player()

        # The search returns a (value, move) pair; only the move is used.
        _, move = maxn_alpha_beta_nn(
            env,
            depth=0,
            max_depth=search_depth,
            num_players=num_players,
            value_net=value_net,
            alpha=np.full(num_players, -np.inf),
            beta=np.full(num_players, np.inf)
        )

        if render:
            # Rendered before the move is applied, so the board shown is the
            # position the move was chosen from.
            print(f"\nPlayer {player_id} plays {move}")
            env.render()  # Optional: only if your env supports it

        env.apply_move(player_id, move)
        move_count += 1

    final_scores = env.get_final_scores()
    print("\n🎉 Game Over")
    for pid, score in enumerate(final_scores):
        print(f"Player {pid}: Score = {score}")
    print(f"Total Moves: {move_count}")
    return final_scores


In [None]:
# Train the value network through self-play using the default hyperparameters.
trained_model = train_value_model(env_class=YourChineseCheckersEnv)

# Play one rendered game with the trained network; as the last expression in
# the cell, the returned final scores are also displayed by the notebook.
run_inference_game(
    env_class=YourChineseCheckersEnv,
    value_net=trained_model,
    search_depth=2,
    render=True,
)
