In [1]:
# import gym
# import jax.numpy as jnp
# env = gym.make('CartPole-v1')

# states = []
# actions = []
# next_states = []

# for _ in range(1000):
#     obs, _ = env.reset()
#     done = False
#     while not done:
#         action = env.action_space.sample()
#         actions.append(action)
#         new_obs, _, terminated, truncated, _ = env.step(action)
#         done = terminated or truncated
#         next_states.append(new_obs)
#         states.append(obs)
#         obs = new_obs



In [2]:
# from chat GPT
import gym
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
import optax

# from gymnax code
def save_pkl_object(obj, filename):
    """Pickle ``obj`` to ``filename``, creating missing parent directories.

    Args:
        obj: Any picklable Python object.
        filename: Destination path. An existing file at that path is
            overwritten.
    """
    import pickle
    from pathlib import Path

    # Make sure the destination directory exists before opening the file.
    Path(filename).parent.mkdir(parents=True, exist_ok=True)

    # Mode "wb" truncates, so any previous content is replaced.
    with open(filename, "wb") as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)

    print(f"Stored data at {filename}.")


def load_pkl_object(filename: str):
    """Read back and return the object pickled at ``filename``.

    Args:
        filename: Path of a pickle file.

    Returns:
        The unpickled Python object.
    """
    import pickle

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(filename, "rb") as handle:
        restored = pickle.load(handle)
    print(f"Loaded data from {filename}.")
    return restored

class TransitionModel(nn.Module):
    """Action-gated linear transition model.

    The last feature of ``inputs`` is treated as the action and all preceding
    features as the state. Two independent ``Dense(4)`` heads are applied to
    the state; the action selects between them:

    * head 0 is scaled by ``(action - 1) ** 2`` (1 when action == 0, 0 when action == 1)
    * head 1 is scaled by ``action``            (0 when action == 0, 1 when action == 1)

    Slicing uses the last axis, so a single sample ``(features,)`` and a batch
    ``(batch, features)`` share one code path, and extra leading batch
    dimensions are handled as well.

    An earlier experiment (input rescaling and a 3x256 ReLU MLP) used to be
    kept here as commented-out code.
    """

    @nn.compact
    def __call__(self, inputs):
        """Predict a 4-dimensional output from ``[state..., action]`` inputs.

        Args:
            inputs: Array whose last axis holds the state features followed by
                the action as the final entry.

        Returns:
            Array with the same leading dimensions as ``inputs`` and a last
            axis of size 4.
        """
        state = inputs[..., :-1]
        action = inputs[..., -1:]  # keep the trailing axis so it broadcasts over the 4 outputs

        # The creation order of the two Dense layers is kept (head 0 first,
        # then head 1) so their auto-generated parameter names stay the same.
        x_0 = nn.Dense(4)(state)
        x_1 = nn.Dense(4)(state)

        gate_1 = action              # if zero == 0, if one == 1
        gate_0 = (action - 1) ** 2   # if zero == 1, if one == 0
        return x_0 * gate_0 + x_1 * gate_1


# Define the optimizer: plain Adam with a constant learning rate
# (no learning-rate schedule or weight decay is configured here).
optimizer = optax.chain(
    optax.adam(learning_rate=1e-4),
)

# Fixed seed so the random rollouts below yield the same dataset on every run.
DATA_SEED = 0

# Generate the input and output arrays from 1000 episodes of randomly sampled actions.
env = gym.make('CartPole-v1')
env.action_space.seed(DATA_SEED)
states = []
actions = []
next_states = []
for episode in range(1000):
    # Seed only the first reset; later resets continue the environment's RNG stream.
    obs, _ = env.reset(seed=DATA_SEED if episode == 0 else None)
    done = False
    while not done:
        action = env.action_space.sample()
        new_obs, _, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        # Record the transition (obs, action) -> new_obs before advancing.
        states.append(obs)
        actions.append(action)
        next_states.append(new_obs)
        obs = new_obs
# Forward model: each input row is [state, action]; the target is the next state.
inputs_arr = jnp.concatenate([jnp.array(states, dtype=jnp.float32), jnp.array(actions, dtype=jnp.float32).reshape(-1, 1)], axis=1)
outputs_arr = jnp.array(next_states, dtype=jnp.float32)

network = TransitionModel()
# Initialise parameters from a single (unbatched) example row.
params = network.init(jax.random.PRNGKey(0), inputs_arr[0])
# Create a train state object.
# NOTE: this rebinds the name `train_state` from the imported flax module to a
# TrainState instance; the module has to be imported again before
# `train_state.TrainState` can be used a second time.
train_state = train_state.TrainState.create(
    apply_fn=network.apply,
    params=params,
    tx=optimizer,
)

@jax.jit
def train_step(train_state, x, y):
    """Run one gradient update on a batch and report its loss.

    Args:
        train_state: State object exposing ``apply_fn``, ``params`` and
            ``apply_gradients``.
        x: Batch of model inputs.
        y: Batch of regression targets matching the model output.

    Returns:
        ``(new_train_state, loss)`` where ``loss`` is the squared error summed
        over the last axis and averaged over the remaining axes, evaluated at
        the parameters before the update.
    """
    def batch_loss(params):
        residual = train_state.apply_fn(params, x) - y
        return jnp.mean(jnp.sum(residual**2, axis=-1))

    loss, grads = jax.value_and_grad(batch_loss)(train_state.params)
    return train_state.apply_gradients(grads=grads), loss

# Train the model with sequential (unshuffled) mini-batches.
num_epochs = 500
batch_size = 32
# Integer division: a trailing partial batch is never visited.
num_batches = inputs_arr.shape[0] // batch_size

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for batch in range(num_batches):
        # Rows [batch * batch_size, (batch + 1) * batch_size) form the current mini-batch.
        rows = slice(batch * batch_size, (batch + 1) * batch_size)
        train_state, batch_loss = train_step(train_state, inputs_arr[rows], outputs_arr[rows])
        epoch_loss += batch_loss
    # Report the mean of the per-batch losses for this epoch.
    epoch_loss /= num_batches
    print(f"Epoch {epoch+1} - Training Loss: {epoch_loss}")

# Persist the learned parameters of the forward model.
save_pkl_object({"params": train_state.params}, "forward_model.pkl")


  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)


Epoch 1 - Training Loss: 2.1585497856140137
Epoch 2 - Training Loss: 1.830747127532959
Epoch 3 - Training Loss: 1.550766944885254
Epoch 4 - Training Loss: 1.3095893859863281
Epoch 5 - Training Loss: 1.1013367176055908
Epoch 6 - Training Loss: 0.9216362237930298
Epoch 7 - Training Loss: 0.7670014500617981
Epoch 8 - Training Loss: 0.6344825029373169
Epoch 9 - Training Loss: 0.5215258598327637
Epoch 10 - Training Loss: 0.42594948410987854
Epoch 11 - Training Loss: 0.34591734409332275
Epoch 12 - Training Loss: 0.27981871366500854
Epoch 13 - Training Loss: 0.22610893845558167
Epoch 14 - Training Loss: 0.18318219482898712
Epoch 15 - Training Loss: 0.14934857189655304
Epoch 16 - Training Loss: 0.1229403167963028
Epoch 17 - Training Loss: 0.10246969014406204
Epoch 18 - Training Loss: 0.08672597259283066
Epoch 19 - Training Loss: 0.07476284354925156
Epoch 20 - Training Loss: 0.06582226604223251
Epoch 21 - Training Loss: 0.05924617126584053
Epoch 22 - Training Loss: 0.05442427843809128
Epoch 23 

KeyboardInterrupt: 

In [None]:
from flax.training import train_state

# Define the optimizer: plain Adam with a constant learning rate
# (no learning-rate schedule or weight decay is configured here).
optimizer = optax.chain(
    optax.adam(learning_rate=1e-4),
)

# Fixed seed so the random rollouts below yield the same dataset on every run.
DATA_SEED = 0

# Generate the input and output arrays from 1000 episodes of randomly sampled actions.
env = gym.make('CartPole-v1')
env.action_space.seed(DATA_SEED)
states = []
actions = []
next_states = []
for episode in range(1000):
    # Seed only the first reset; later resets continue the environment's RNG stream.
    obs, _ = env.reset(seed=DATA_SEED if episode == 0 else None)
    done = False
    while not done:
        action = env.action_space.sample()
        new_obs, _, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        # Record the transition (obs, action) -> new_obs before advancing.
        states.append(obs)
        actions.append(action)
        next_states.append(new_obs)
        obs = new_obs
# Backward model: each input row is [next_state, action]; the target is the preceding state.
inputs_arr = jnp.concatenate([jnp.array(next_states, dtype=jnp.float32), jnp.array(actions, dtype=jnp.float32).reshape(-1, 1)], axis=1)
outputs_arr = jnp.array(states, dtype=jnp.float32)

network = TransitionModel()
# Initialise parameters from a single (unbatched) example row.
params = network.init(jax.random.PRNGKey(0), inputs_arr[0])
# Create a train state object.
# NOTE: this rebinds the name `train_state` from the imported flax module to a
# TrainState instance; the module has to be imported again before
# `train_state.TrainState` can be used a second time.
train_state = train_state.TrainState.create(
    apply_fn=network.apply,
    params=params,
    tx=optimizer,
)

@jax.jit
def train_step(train_state, x, y):
    """Run one gradient update on a batch and report its loss.

    Args:
        train_state: State object exposing ``apply_fn``, ``params`` and
            ``apply_gradients``.
        x: Batch of model inputs.
        y: Batch of regression targets matching the model output.

    Returns:
        ``(new_train_state, loss)`` where ``loss`` is the squared error summed
        over the last axis and averaged over the remaining axes, evaluated at
        the parameters before the update.
    """
    def batch_loss(params):
        residual = train_state.apply_fn(params, x) - y
        return jnp.mean(jnp.sum(residual**2, axis=-1))

    loss, grads = jax.value_and_grad(batch_loss)(train_state.params)
    return train_state.apply_gradients(grads=grads), loss

# Train the model with sequential (unshuffled) mini-batches.
num_epochs = 500
batch_size = 32
# Integer division: a trailing partial batch is never visited.
num_batches = inputs_arr.shape[0] // batch_size

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for batch in range(num_batches):
        # Rows [batch * batch_size, (batch + 1) * batch_size) form the current mini-batch.
        rows = slice(batch * batch_size, (batch + 1) * batch_size)
        train_state, batch_loss = train_step(train_state, inputs_arr[rows], outputs_arr[rows])
        epoch_loss += batch_loss
    # Report the mean of the per-batch losses for this epoch.
    epoch_loss /= num_batches
    print(f"Epoch {epoch+1} - Training Loss: {epoch_loss}")

# Persist the learned parameters of the backward model.
save_pkl_object({"params": train_state.params}, "backward_model.pkl")

Epoch 1 - Training Loss: 2.2498042583465576
Epoch 2 - Training Loss: 1.894377589225769
Epoch 3 - Training Loss: 1.5968819856643677
Epoch 4 - Training Loss: 1.3440979719161987
Epoch 5 - Training Loss: 1.1280978918075562
Epoch 6 - Training Loss: 0.9433769583702087
Epoch 7 - Training Loss: 0.7855345010757446
Epoch 8 - Training Loss: 0.6508222222328186
Epoch 9 - Training Loss: 0.536070704460144
Epoch 10 - Training Loss: 0.43867701292037964
Stored data at backward_model.pkl.
