In [1]:
import keras
import numpy as np
import tensorflow as tf
import tqdm.notebook as tqdm
import matplotlib.pyplot as plt

2025-04-13 10:53:52.296068: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744505632.309947  136304 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744505632.314384  136304 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744505632.325473  136304 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744505632.325490  136304 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744505632.325491  136304 computation_placer.cc:177] computation placer alr

In [2]:
# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all up front.
# Per the TensorFlow docs, set_memory_growth raises RuntimeError once the GPUs have been
# initialized; that error is printed rather than aborting the notebook.
gpus = tf.config.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    print(e)

In [99]:
class ReplayBuffer:

    def __init__(self, max_size: int, observation_shape: tuple[int, ...], action_shape: tuple[int, ...], seed: int | None = None):
        """Stores the replay history with a maximum of `max_size` entries, removing old entries as needed.

        Parameters:
            max_size: maximal number of entries to keep
            observation_space: specification of the observation space
            action_space: specification of the action space
            seed: seed to initialize the internal random number generator for reproducibility"""

        self.current_observations = np.zeros((max_size, *observation_shape), dtype=np.float32)
        self.next_observations = np.zeros((max_size, *observation_shape), dtype=np.float32)
        self.actions = np.zeros((max_size,), dtype=np.float32)
        self.rewards = np.zeros((max_size,), dtype=np.float32)
        self.dones = np.zeros((max_size,), dtype=np.float32)
        
        self.max_size = max_size
        self.rng = np.random.default_rng(seed=seed)
        self.buffer_pointer = 0
        self.current_size = 0
        
        
    def add(self, current_observations: np.ndarray, actions: np.ndarray, rewards: np.ndarray, next_observations: np.ndarray, dones: np.ndarray) -> None:
        """Add a new entry to the buffer.

        Parameters:
            current_observations: environment state observed at the current step
            actions: action taken by the model
            rewards: reward received after taking the action
            next_observations: environment state obversed after taking the action
            dones: whether the episode has ended or not"""
        batch_size = current_observations.shape[0]
        idxs = (np.arange(batch_size) + self.buffer_pointer) % self.max_size
        
        self.current_observations[idxs] = current_observations
        self.actions[idxs] = actions
        self.rewards[idxs] = rewards
        self.next_observations[idxs] = next_observations
        self.dones[idxs] = dones

        self.buffer_pointer = (idxs[-1] + 1) % self.max_size
        self.current_size = min(self.max_size, self.current_size + batch_size)
        
    def sample(self, n_samples: int, replace: bool = True) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Randomly samples `n_samples` from the buffer.

        Parameters:
            n_samples: number of samples to select
            replace: sample with or without replacement

        Returns:
            current observations, actions, rewards, next observations, done"""
        return self[self.rng.choice(self.current_size, size=n_samples, replace=replace)]

    def clear(self) -> None:
        """Clears the buffer"""
        self.buffer_pointer = 0
        self.current_size = 0

    def __getitem__(self, index: int | np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Gets a sample at `index`

        Parameters:
            index: index of the sample to get

        Returns:
            current observation, action, reward, next observation, done"""
        return (
            self.current_observations[index],
            self.actions[index],
            self.rewards[index],
            self.next_observations[index],
            self.dones[index]
        )
        
    def __len__(self) -> int:
        """Returns the number of entries in the buffer"""
        return self.current_size

In [100]:
def get_name(prefix: str | None, suffix: str | None, separator: str = '_'):
    return prefix and prefix + separator + suffix

In [101]:
def get_model(
    observation_shape: tuple[int, ...],
    features: int,
    blocks: int, 
    activation: str | keras.layers.Activation | None = 'silu',
    dropout_rate: float = 0.,
    multiply_freq: int = 1,
    name: str | None = None,
    output_activation: str = 'linear',
) -> keras.Model:
    """Builds an MLP mapping an input of shape `observation_shape` to one scalar per sample.

    Each of the `blocks` hidden blocks is Dense -> BatchNormalization -> Activation, followed
    by Dropout when `dropout_rate > 0`. The width starts at `features` and is halved after
    every block, never going below 1 unit.

    Parameters:
        observation_shape: shape of one input sample (without the batch axis)
        features: width of the first hidden Dense layer
        blocks: number of hidden blocks
        activation: hidden-layer activation passed to `keras.layers.Activation`
        dropout_rate: dropout rate applied after each hidden block; `0` adds no Dropout layer
        multiply_freq: currently unused -- TODO(review): confirm intended meaning before wiring it in
        name: model name, also used as prefix for layer names (see `get_name`)
        output_activation: activation of the final 1-unit Dense layer

    Returns:
        a `keras.Model` whose output has shape `(batch,)`"""
    inputs = x = keras.layers.Input(observation_shape, name=get_name(name, 'input'))

    for i in range(blocks):
        x = keras.layers.Dense(features, name=get_name(name, f'dense_{i}'))(x)
        x = keras.layers.BatchNormalization(name=get_name(name, f'norm_{i}'))(x)
        x = keras.layers.Activation(activation, name=get_name(name, f'activation_{i}'))(x)
        if dropout_rate > 0:
            x = keras.layers.Dropout(dropout_rate, name=get_name(name, f'dropout_{i}'))(x)

        # Halve the width each block, but keep at least one unit so deep stacks stay valid.
        features = max(features // 2, 1)

    x = keras.layers.Dense(1, activation=output_activation, name=get_name(name, 'output'))(x)
    # Drop the trailing size-1 axis: (batch, 1) -> (batch,).
    x = keras.layers.Reshape((), name=get_name(name, 'flatten_output'))(x)
    return keras.models.Model(inputs=inputs, outputs=x, name=name)

In [102]:
def get_actor(
    observation_shape: tuple[int, ...],
    features: int,
    blocks: int, 
    activation: str | keras.layers.Activation | None = 'silu',
    dropout_rate: float = 0.,
    multiply_freq: int = 1,
    name: str | None = None
) -> keras.Model:
    """Builds the policy network.

    This is the `get_model` MLP with a `tanh` output, so every predicted action lies in
    [-1, 1]. All other arguments are forwarded to `get_model` unchanged."""
    return get_model(
        observation_shape=observation_shape,
        features=features,
        blocks=blocks,
        activation=activation,
        dropout_rate=dropout_rate,
        multiply_freq=multiply_freq,
        name=name,
        output_activation='tanh',
    )

In [103]:
def get_critic(
    observation_shape: tuple[int, ...],
    features: int,
    blocks: int, 
    activation: str | keras.layers.Activation | None = 'silu',
    dropout_rate: float = 0.,
    multiply_freq: int = 1,
    name: str | None = None
) -> keras.Model:
    """Builds the Q-network.

    The model takes `[observations, actions]`, where each action is a single value. It
    concatenates the two inputs and scores the pair with a nested `get_model` backbone
    (linear output). It returns one Q-value per sample."""
    obs_input = keras.layers.Input(observation_shape, name=get_name(name, 'obs_input'))
    action_input = keras.layers.Input((1, ), name=get_name(name, 'act_input'))
    joint = keras.layers.Concatenate(name=get_name(name, 'concat'))([obs_input, action_input])

    # The backbone sees the concatenated feature vector as its "observation".
    backbone = get_model(
        observation_shape=tuple(joint.shape[1:]),
        features=features,
        blocks=blocks,
        activation=activation,
        dropout_rate=dropout_rate,
        multiply_freq=multiply_freq,
        name=get_name(name, 'backbone'),
    )
    q_values = backbone(joint)

    return keras.models.Model(inputs=[obs_input, action_input], outputs=q_values, name=name)

In [105]:
class Sampler:
    def __init__(self, sigma: float, seed: int | None = None):
        """Selects a random action with probability `epsilon` otherwise selects the most probably action given by the model.

        Parameters:
            epsilon: the probability to select a random action
            seed: seed to initialize the internal random number generator for reproducibility"""
        self.sigma = sigma
        self.rng = np.random.default_rng(seed=seed)

    def __call__(self, actions: np.ndarray) -> np.ndarray:
        """Select an action given the `probabilities

        Parameters:
            probabilities: probabilities for each action

        Returns:
            index of the selected action"""
        batch, *_ = actions.shape

        return np.clip(actions + self.rng.normal(scale=self.sigma, size=batch), -1, 1)

In [106]:
def play_game(
    model: keras.Model,
    buffer: ReplayBuffer,
    env: Environment,
    steps: int,
    sampler: Sampler,
    observations: np.ndarray | None = None,
    one_episode: bool = False
) -> np.ndarray:
    """Roll out `model` in the environment `env`, storing every transition in `buffer`.

    Each iteration advances all `env.n_env` environments by one step, so `steps // env.n_env`
    iterations are run; any remainder of `steps` is dropped.

    Arguments:
        model: policy network; its output is passed through `sampler` to get the actions
        buffer: replay buffer that receives every transition
        env: environment to play; `reset()` returns observations and `step(actions)` returns
            (observations, rewards, terminated flags)
        steps: total number of steps to record across all environments
        sampler: turns model outputs into the actions actually taken (e.g. adds exploration noise)
        observations: observations to continue from; `None` resets `env` first
        one_episode: exit as soon as any of the environments reports that its episode ended

    Returns:
        the observations to continue from on the next call, or `None` when the loop stopped
        early because of `one_episode` (so that the next call resets the environment)
    """
    state = env.reset() if observations is None else observations

    for _ in range(steps // env.n_env):
        policy_output = model(state, training=False).numpy()
        chosen = sampler(policy_output)
        next_state, rewards, terminated = env.step(chosen)

        buffer.add(
            current_observations=state,
            actions=chosen,
            rewards=rewards,
            next_observations=next_state,
            dones=terminated,
        )

        if one_episode and np.any(terminated):
            return None

        state = next_state

    return state

In [107]:
class QDataset(keras.utils.PyDataset):
    """Keras `PyDataset` that draws a fresh random minibatch from a replay buffer on every access."""

    def __init__(
        self,
        steps_per_epoch: int,
        batch_size: int,
        buffer: ReplayBuffer,
        **kwargs,
    ):
        """
        Parameters:
            steps_per_epoch: number of batches reported by `len()`
            batch_size: number of transitions per batch
            buffer: replay buffer to sample from
            **kwargs: forwarded to `keras.utils.PyDataset` (e.g. `workers`, `use_multiprocessing`, `max_queue_size`)"""
        super().__init__(**kwargs)
        self.buffer = buffer
        self.batch_size = batch_size
        self.steps_per_epoch = steps_per_epoch

        # Placeholder target of zeros returned alongside every batch.
        self._answer = np.zeros(batch_size, dtype=np.float32)

    def __len__(self) -> int:
        """Returns the number of batches per epoch."""
        return self.steps_per_epoch

    def __getitem__(self, idx: int) -> tuple[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
        """Returns `(transitions, zeros)`.

        `transitions` is `buffer.sample(batch_size)`, i.e. (current observations, actions,
        rewards, next observations, dones). `idx` is ignored because every batch is an
        independent random draw."""
        return self.buffer.sample(self.batch_size), self._answer

In [113]:
# Replay / sampling sizes
buffer_size = 1 << 18      # 262_144 transitions kept in the training replay buffer
steps_per_play = 1 << 17   # 131_072 environment steps collected per epoch (~1.25 full games)
steps_per_epoch = 1 << 10  # 1_024 gradient steps per epoch
batch_size = 512           # minibatch size; also used as the number of parallel environments

# Exploration-noise schedule
end_sigma = 0.1
epochs = 128
decay_epochs = epochs // 2

# NOTE(review): update_frequency and play_frequency are not read by any visible cell.
update_frequency = 512
play_frequency = 128

gamma = 0.99  # discount factor in the TD target
rho = 0.9     # soft-update coefficient: target = rho * target + (1 - rho) * current

In [114]:
# Initial exploration-noise standard deviation; the training loop overwrites `sampler.sigma`
# each epoch from `sigma_decay`.
sigma = 1 / 3
sampler = Sampler(sigma=sigma)

In [115]:
# Keras' polynomial learning-rate schedule reused as a noise schedule. With its defaults
# (power=1, cycle=False) sigma decays linearly from `sigma` to `end_sigma` over
# `decay_epochs` epochs and then stays at `end_sigma`.
sigma_decay = keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=sigma,
    decay_steps=decay_epochs,
    end_learning_rate=end_sigma,
)

In [116]:
# Online ("current") and target networks for the actor and the critic. Each target starts
# as an exact copy of its online network and is frozen; in the training loop it is moved
# only by `update_weights`.
actor_model_current = get_actor(observation_shape=(4, ), features=256, blocks=5, name='fishing_actor_current')
actor_model_target = get_actor(observation_shape=(4, ), features=256, blocks=5, name='fishing_actor_target')
actor_model_target.set_weights(actor_model_current.get_weights())
actor_model_target.trainable = False

critic_model_current = get_critic(observation_shape=(4, ), features=256, blocks=5, name='fishing_critic_current')
critic_model_target = get_critic(observation_shape=(4, ), features=256, blocks=5, name='fishing_critic_target')
critic_model_target.set_weights(critic_model_current.get_weights())
critic_model_target.trainable = False

In [117]:
# Print the online actor's layers, output shapes and parameter counts.
actor_model_current.summary()

In [118]:
# Print the online critic's architecture. The MLP backbone is a nested model, so expand it
# to list its Dense/BatchNormalization/Activation layers instead of one opaque entry.
critic_model_current.summary(expand_nested=True)

In [119]:
# One AdamW optimizer per online network, both with learning rate 1e-4. Building each one
# against its network's trainable weights up front creates its per-variable state before
# training, so the loop can call `apply` directly.
optimizer_critic = keras.optimizers.AdamW(learning_rate=1e-4)
optimizer_actor = keras.optimizers.AdamW(learning_rate=1e-4)

optimizer_critic.build(critic_model_current.trainable_weights)
optimizer_actor.build(actor_model_current.trainable_weights)

In [120]:
# Separate environments for data collection and for validation, each with `n_env=batch_size`.
train_env, val_env = Environment(n_env=batch_size), Environment(n_env=batch_size)

In [121]:
# Training replay memory: keeps the most recent `buffer_size` transitions.
train_buffer = ReplayBuffer(observation_shape=(4, ), action_shape=(), max_size=buffer_size)

# Validation memory sized for `val_env.max_iter` steps of every validation environment.
# It presumably holds one full validation rollout (assumes episodes last at most `max_iter`
# steps -- TODO confirm).
val_buffer = ReplayBuffer(observation_shape=(4, ), action_shape=(), max_size=val_env.max_iter * batch_size)

In [122]:
# Random replay minibatches from `train_buffer`; `len(dataset)` is the number of gradient
# steps per epoch.
dataset = QDataset(buffer=train_buffer, batch_size=batch_size, steps_per_epoch=steps_per_epoch)

In [123]:
def update_weights(model: keras.Model, target_model: keras.Model, decay: float | None = None):
    """Soft (Polyak) update of `target_model` toward `model`, in place.

    For every weight, trainable and non-trainable (e.g. BatchNormalization statistics):
    target <- decay * target + (1 - decay) * current.

    Parameters:
        model: source network
        target_model: network updated in place; must have the same architecture as `model`
        decay: fraction of the target kept on each update; defaults to the notebook-level `rho`

    Raises:
        ValueError: if the two models have a different number of weights"""
    if decay is None:
        decay = rho
    # Same weight lists that get_weights()/set_weights() use, but assigning on the variables
    # avoids copying every weight to host NumPy arrays and back on each training step.
    for source, target in zip(model.weights, target_model.weights, strict=True):
        target.assign(decay * target + (1 - decay) * source)

In [129]:
@tf.function
def train_step(current_observations, actions, rewards, next_observations, dones):
    """Runs one critic update followed by one actor update on a replay minibatch.

    The critic regresses Q(s, a) onto the one-step TD target built from the target networks.
    The actor is then updated to maximize the online critic's Q(s, actor(s)). Compiled with
    `tf.function`: `gamma` is captured at trace time, and batches have a fixed shape, so a
    single trace is reused.

    Returns:
        (critic_loss, actor_loss) as scalar tensors"""
    next_actions = actor_model_target(next_observations, training=False)
    next_q = critic_model_target([next_observations, next_actions], training=False)
    # Built outside both tapes, so no gradient flows into the target networks.
    q_target = rewards + gamma * (1 - dones) * next_q

    with tf.GradientTape() as tape:
        q_current = critic_model_current([current_observations, actions], training=True)
        critic_loss = keras.ops.mean(keras.ops.square(q_current - q_target))
    critic_variables = critic_model_current.trainable_weights
    optimizer_critic.apply(tape.gradient(critic_loss, critic_variables), critic_variables)

    with tf.GradientTape() as tape:
        policy_actions = actor_model_current(current_observations, training=True)
        actor_loss = -keras.ops.mean(critic_model_current([current_observations, policy_actions], training=False))
    actor_variables = actor_model_current.trainable_weights
    optimizer_actor.apply(tape.gradient(actor_loss, actor_variables), actor_variables)

    return critic_loss, actor_loss


pbar = tqdm.trange(epochs)
last_observation = None

for epoch in pbar:
    # Anneal the exploration noise according to `sigma_decay`.
    sampler.sigma = sigma_decay(epoch).numpy()

    # Collect fresh experience, continuing the previous epoch's episodes.
    last_observation = play_game(
        actor_model_current,
        train_buffer,
        train_env,
        steps_per_play,
        sampler,
        last_observation
    )

    for step in tqdm.trange(len(dataset), leave=False):
        transitions, _ = dataset[step]
        train_step(*(tf.convert_to_tensor(part) for part in transitions))

        # Soft-update the target networks after every gradient step.
        update_weights(critic_model_current, critic_model_target)
        update_weights(actor_model_current, actor_model_target)

    # Validation: noise-free rollout into an emptied buffer until any environment finishes;
    # the score is the total collected reward divided by the number of environments.
    val_buffer.clear()
    play_game(
        actor_model_current,
        val_buffer,
        val_env,
        int(1e10),
        Sampler(0),
        None,
        one_episode=True
    )
    score = val_buffer.rewards[:len(val_buffer)].sum() / val_env.n_env

    pbar.set_description(f'{score}')

  0%|          | 0/128 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [130]:
# Scratch check: a separate environment instance for stepping the trained actor by hand.
# NOTE(review): the positional `1` is presumably `n_env` (other cells pass `n_env=` by keyword) -- confirm.
env = Environment(1)

In [131]:
# Reset the scratch environment and keep its initial observations.
o = env.reset()

In [134]:
# Display the initial observations (the recorded output is a (1, 4) array).
o

array([[ 0.03625419,  1.03625419,  0.53625419, -0.01125997]])

In [149]:
# Query the online actor in inference mode; it returns one tanh-bounded action per row of `o`.
a = actor_model_current(o, training=False)
a

<tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.09180834], dtype=float32)>

In [150]:
# Step with the actor's noise-free action; step returns (observations, rewards, terminated flags).
o, r, d = env.step(a.numpy())

In [151]:
# Display the transition just taken.
o, r, d

(array([[0.10723707, 1.10723707, 0.09820086, 0.0547027 ]]),
 array([0.73081578]),
 array([False]))