# DQN

The DQN training method is a neural-network adaptation of the Q-learning algorithm. It also comes with a set of extensions that become relevant when moving to deep neural networks and to harder tasks (that is, tasks with larger state spaces).

This notebook implements DQN for the [CartPole](https://gymnasium.farama.org/environments/classic_control/cart_pole/) environment, where the goal is to keep a pole balanced upright by moving only the cart it is attached to. PyTorch is used to train the neural network that approximates the Q-function.

Based on a course assignment from "Applied Problems of Data Analysis", Intelligent Data Analysis minor, Faculty of Computer Science, HSE University.

![cartpole](https://gymnasium.farama.org/_images/cart_pole.gif)

![cartpole](https://www.researchgate.net/publication/362568623/figure/fig5/AS:1187029731807278@1660021350587/Screen-capture-of-the-OpenAI-Gym-CartPole-problem-with-annotations-showing-the-cart.png)

In [1]:
# If running in Colab, install dependencies
try:
    import google.colab
    COLAB = True
except ModuleNotFoundError:
    COLAB = False

if COLAB:
    !pip -q install "gymnasium[classic-control, atari, accept-rom-license]"
    !pip -q install piglet
    !pip -q install imageio_ffmpeg
    !pip -q install moviepy==1.0.3
    !pip -q install setuptools==59.8.0
    try:
        import skmultiflow
    except ModuleNotFoundError:
        !git clone --quiet https://github.com/ugadiarov-la-phystech-edu/scikit-multiflow.git && pip -q install ./scikit-multiflow


In [2]:
import abc
import base64
import io
import math
import os
import random
import time
from dataclasses import dataclass

import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import pygame
from gymnasium import spaces
from gymnasium.envs.registration import WrapperSpec
%matplotlib inline

if COLAB:
    from google.colab import files
    from google.colab.patches import cv2_imshow
    from google.colab import output

### Action Space

The action is an `ndarray` with shape `(1,)` which can take values `{0, 1}` indicating the direction of the fixed force the cart is pushed with.

- 0: Push cart to the left
- 1: Push cart to the right

**Note**: The velocity that is reduced or increased by the applied force is not fixed; it depends on the angle
the pole is pointing, because the center of gravity of the pole changes the amount of energy needed to move the cart underneath it.

### Observation Space

The observation is a `ndarray` with shape `(4,)` with the values corresponding to the following positions and velocities:

| Num | Observation           | Min                 | Max               |
|-----|-----------------------|---------------------|-------------------|
| 0   | Cart Position         | -4.8                | 4.8               |
| 1   | Cart Velocity         | -Inf                | Inf               |
| 2   | Pole Angle            | ~ -0.418 rad (-24°) | ~ 0.418 rad (24°) |
| 3   | Pole Angular Velocity | -Inf                | Inf               |

**Note:** While the ranges above denote the possible values of each element of the observation space, they do not reflect the allowed values of the state space in an unterminated episode. In particular:

- The cart x-position (index 0) can take values between `(-4.8, 4.8)`, but the episode terminates if the cart leaves the `(-2.4, 2.4)` range.
- The pole angle can be observed between `(-.418, .418)` radians (or **±24°**), but the episode terminates if the pole angle is not in the range `(-.2095, .2095)` (or **±12°**).

### Rewards

Since the goal is to keep the pole upright for as long as possible, a reward of `+1` is given for every step taken,
including the termination step. The threshold for rewards is 500 for v1 and 200 for v0.

### Starting State

All observations are assigned a uniformly random value in `(-0.05, 0.05)`.

### Episode End

The episode ends if any one of the following occurs:

1. Termination: Pole Angle is greater than ±12°
2. Termination: Cart Position is greater than ±2.4 (center of the cart reaches the edge of the display)
3. Truncation: Episode length is greater than 500 (200 for v0)
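The three end conditions above can be written as a small pure-Python check. This is a sketch that mirrors the listed thresholds, not Gymnasium's own implementation; the function name `episode_status` and its arguments are hypothetical, and `steps` is assumed to count the steps taken so far.

```python
import math

# Thresholds from the "Episode End" list above
X_THRESHOLD = 2.4                      # cart position limit
THETA_THRESHOLD = 12 * math.pi / 180   # 12 degrees, ~0.2095 rad

def episode_status(x, theta, steps, max_steps=500):
    """Hypothetical helper: classify a CartPole state as
    'terminated', 'truncated' or 'running'."""
    # Termination: the pole fell too far or the cart left the track
    if abs(x) > X_THRESHOLD or abs(theta) > THETA_THRESHOLD:
        return "terminated"
    # Truncation: the time limit was reached (500 steps for v1)
    if steps >= max_steps:
        return "truncated"
    return "running"

print(episode_status(0.0, 0.0, steps=10))   # running
print(episode_status(2.5, 0.0, steps=10))   # terminated
print(episode_status(0.0, 0.21, steps=10))  # terminated
print(episode_status(0.0, 0.0, steps=500))  # truncated
```

Keeping termination and truncation apart matters for the TD target: only a true termination should zero out the bootstrap term, while a truncated episode could have continued.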

In [3]:
env = gym.make("CartPole-v1", max_episode_steps=1000)
env.reset()

# Inspect the observation and action spaces
print(f'{env.observation_space=}')
print(f'{env.action_space=}')

n_actions = env.action_space.n
state_dim = env.observation_space.shape
print(f'Action_space: {n_actions} | State_space: {state_dim}')

env.observation_space=Box([-4.8               -inf -0.41887903        -inf], [4.8               inf 0.41887903        inf], (4,), float32)
env.action_space=Discrete(2)
Action_space: 2 | State_space: (4,)


Since the cart-pole state is described not by "raw" features but by already preprocessed ones (coordinates, angles), we don't need a complex architecture to start with. Let's begin with this one:
<img src="https://raw.githubusercontent.com/Tviskaron/mipt/master/2020/RL/figures/DQN.svg">

- Fully connected layers (``torch.nn.Linear``)
- Simple activation functions (``torch.nn.ReLU``)
- Sigmoids and similar activation functions may perform poorly on unnormalized inputs.

- The agent's Q-function is approximated by minimizing the mean squared TD error:
$$
\delta = Q_{\theta}(s, a) - \left[r(s, a) + \gamma \cdot \max_{a'} Q_{-}(s', a')\right]
$$
$$
L = \frac{1}{N} \sum_i \delta_i^2,
$$
where
* $s, a, r, s'$ are the state, action, reward, and next state;
* $\gamma$ is the discount factor;
* for a terminal $s'$ the $\max$ term is dropped, so the target is just $r(s, a)$.

$Q_{-}(s',a')$ is the same function as $Q_{\theta}$ (the output of the neural network), but no gradients flow through it during training. In papers, stopping the gradient is often denoted $SG(\cdot)$. In PyTorch, the `Tensor.detach()` method returns a tensor detached from the computation graph, and the `torch.no_grad()` context manager runs the computations inside it without tracking gradients.
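To make the formulas concrete, here is a two-transition batch worked through in plain Python. The Q-values are made-up numbers chosen for illustration; for the terminal transition the bootstrap term is multiplied by $(1 - 1_\text{terminated}) = 0$, the same convention `compute_td_target` uses below.

```python
gamma = 0.99

# Made-up batch: (Q_theta(s, a), r, max_a' Q_-(s', a'), terminated)
batch = [
    (1.5, 1.0, 2.0, False),  # regular transition
    (0.5, 1.0, 3.0, True),   # terminal transition: the max term is ignored
]

deltas = []
for q_sa, r, max_q_next, terminated in batch:
    # TD target r + gamma * max_a' Q_-(s', a'), bootstrap dropped at termination
    target = r + gamma * max_q_next * (0.0 if terminated else 1.0)
    deltas.append(q_sa - target)

# Mean squared TD error L = (1/N) * sum_i delta_i^2
loss = sum(d ** 2 for d in deltas) / len(deltas)
print(deltas)  # approximately [-1.48, -0.5]
print(loss)    # approximately 1.2202
```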

In [4]:
import torch
import torch.nn as nn

def create_network(input_dim, hidden_dims, output_dim):
    layers = []

    # 1st layer (input -> hidden)
    layers.append(nn.Linear(input_dim, hidden_dims[0]))
    layers.append(nn.ReLU())

    # hidden layers (hidden -> hidden)
    for i in range(len(hidden_dims)-1):
        layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1]))
        layers.append(nn.ReLU())

    # output layer (hidden -> output)
    layers.append(nn.Linear(hidden_dims[-1], output_dim))

    # assembling a network
    network = nn.Sequential(*layers)

    return network
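For a sense of scale, the number of trainable parameters of the MLP built by `create_network` can be counted by hand: each `nn.Linear(n_in, n_out)` layer has `n_in * n_out` weights plus `n_out` biases, and `nn.ReLU` has none. The helper below (a hypothetical name, pure Python) does that count for the layer sizes used in this notebook.

```python
def count_mlp_params(input_dim, hidden_dims, output_dim):
    """Hypothetical helper: parameter count of the Linear/ReLU stack
    produced by create_network (weights + biases of every Linear layer)."""
    dims = [input_dim, *hidden_dims, output_dim]
    return sum(n_in * n_out + n_out for n_in, n_out in zip(dims[:-1], dims[1:]))

print(count_mlp_params(4, [64, 64], 2))    # 4610
print(count_mlp_params(4, [128, 128], 2))  # 17410
print(count_mlp_params(4, [256, 256], 2))  # 67586
```

In PyTorch the same number should come out of `sum(p.numel() for p in Q.parameters())`.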

$\epsilon$-greedy action selection:

In [5]:
def select_action_eps_greedy(Q, state, epsilon):
    """
    Args:
        Q: neural network predicting Q-values
        state: current state
        epsilon: probability of random action selection (0-1)
    Returns:
        selected action (int)
    """
    if not isinstance(state, torch.Tensor):
        state = torch.tensor(state, dtype=torch.float32)

    # getting Q-values for all actions
    with torch.no_grad():
        Q_s = Q(state).numpy()

    # greedy action (max Q-value)
    greedy_action = np.argmax(Q_s)

    # random action
    random_action = np.random.randint(len(Q_s))

    # ϵ-greedy selection
    action = random_action if random.random() < epsilon else greedy_action

    action = int(action)
    return action


Q = create_network(
    input_dim=np.prod(state_dim), hidden_dims=[64, 64], output_dim=n_actions
)
select_action_eps_greedy(Q, env.reset()[0].flatten(), epsilon=0.1)

1
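Because the random branch of `select_action_eps_greedy` draws uniformly from all actions, including the greedy one, the greedy action is actually taken with probability $1 - \epsilon + \epsilon / |A|$ and every other action with probability $\epsilon / |A|$. A small pure-Python sketch (the function name is hypothetical) that computes these probabilities, assuming ties are broken towards the first maximum as `np.argmax` does:

```python
def eps_greedy_probs(q_values, epsilon):
    """Hypothetical helper: action probabilities of the epsilon-greedy policy."""
    n = len(q_values)
    greedy = max(range(n), key=lambda i: q_values[i])  # first index of the max
    probs = [epsilon / n] * n                          # exploration mass, spread uniformly
    probs[greedy] += 1.0 - epsilon                     # exploitation mass on the greedy action
    return probs

print(eps_greedy_probs([0.2, 0.7], epsilon=0.1))  # approximately [0.05, 0.95]
```

With CartPole's two actions and $\epsilon = 0.1$, the greedy action is therefore chosen 95% of the time, not 90%.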

In [6]:
def to_tensor(x, dtype=np.float32):
    if isinstance(x, torch.Tensor):
        return x
    x = np.asarray(x, dtype=dtype)
    x = torch.from_numpy(x)
    return x

def compute_td_target(
        Q, rewards, next_states, terminateds, gamma=0.99, check_shapes=True,
):
    """Computes TD-target by formula:
    target = r + gamma * max(Q(s',a')) * (1 - terminated)
    """
    # input -> tensors
    r = to_tensor(rewards)  # shape: [batch_size]
    s_next = to_tensor(next_states)  # shape: [batch_size, state_size]
    term = to_tensor(terminateds, bool)  # shape: [batch_size]

    # Q[s_next, .]: Q-values of all actions in the next state
    with torch.no_grad():  # stop gradients through the target (same network Q)
        Q_sn = Q(s_next)  # shape: [batch_size, n_actions]

    # V*[s_next] = max_a' Q(s_next, a'): value estimate of the next state
    V_sn = torch.max(Q_sn, dim=1)[0]

    # multiplying by ~term is the same as (1 - terminated):
    # the bootstrap term is dropped after termination
    target = r + gamma * V_sn * (~term)

    # checking shapes
    if check_shapes:
        assert Q_sn.dim() == 2, \
            "Q_sn must contain q-values for all actions [batch_size, n_actions]"
        assert V_sn.dim() == 1, \
            "V_sn must be a vector [batch_size] (max over action axis)"
        assert target.dim() == 1, \
            "target must be a vector [batch_size]"

    return target

def compute_td_loss(
        Q, states, actions, td_target, regularizer=.1, out_non_reduced_losses=False
):
    """Computes TD error (MSE) between current Q-values and TD targets.

    Args:
        Q: Neural network predicting Q-values
        states: Current states (batch)
        actions: Selected actions (batch)
        td_target: Computed TD targets (batch)
        regularizer: L1 regularization coefficient
        out_non_reduced_losses: If True, also returns individual losses

    Returns:
        Mean loss over the batch (and individual losses if requested)
    """
    # Convert inputs to tensors
    s = to_tensor(states)  # shape: [batch_size, state_size]
    a = to_tensor(actions, int).long()  # shape: [batch_size]
    td_target = to_tensor(td_target)  # shape: [batch_size]

    # Get Q(s,a) for selected actions
    Q_s = Q(s)  # shape: [batch_size, n_actions]
    Q_s_a = Q_s.gather(1, a.unsqueeze(1)).squeeze(1)  # shape: [batch_size]

    # Compute TD error (difference between current estimates and targets)
    td_error = Q_s_a - td_target  # shape: [batch_size]

    # MSE loss to minimize
    td_losses = td_error.pow(2)  # shape: [batch_size]
    loss = td_losses.mean()  # scalar

    # Add L1 regularization on Q-values
    loss += regularizer * torch.abs(Q_s_a).mean()

    if out_non_reduced_losses:
        return loss, td_losses.detach()
    return loss

In [7]:
def eval_dqn(env_name, Q):
    """Evaluates the performance of the algorithm on a single episode"""
    env = gym.make(env_name)
    s, _ = env.reset()
    done, ep_return = False, 0.

    while not done:
        # epsilon = 0 makes the agent act greedily
        a = select_action_eps_greedy(Q, s, epsilon=0.)
        s_next, r, terminated, truncated, _ = env.step(a)
        done = terminated or truncated
        ep_return += r
        s = s_next

        if done:
            break

    return ep_return

In [8]:
from collections import deque

def linear(st, end, duration, t):
    """
    Linear interpolation of values within the range [st, end],
    using time progress t relative to the total duration.
    """

    if t >= duration:
        return end
    return st + (end - st) * (t / duration)

def run_dqn(
        env_name="CartPole-v1",
        hidden_dims=(128, 128), lr=1e-3, gamma=0.99,
        eps_st=.4, eps_end=.02, eps_dur=.25, total_max_steps=100_000,
        train_schedule=1, eval_schedule=1000, smooth_ret_window=10, success_ret=200.
):
    env = gym.make(env_name)
    eval_return_history = deque(maxlen=smooth_ret_window)

    Q = create_network(
        input_dim=env.observation_space.shape[0], hidden_dims=hidden_dims, output_dim=env.action_space.n
    )
    opt = torch.optim.Adam(Q.parameters(), lr=lr)

    s, _ = env.reset()
    done = False

    for global_step in range(1, total_max_steps + 1):
        epsilon = linear(eps_st, eps_end, eps_dur * total_max_steps, global_step)

        a = select_action_eps_greedy(Q, s, epsilon=epsilon)
        s_next, r, terminated, truncated, _ = env.step(a)
        done = terminated or truncated

        if global_step % train_schedule == 0:
            opt.zero_grad()
            td_target = compute_td_target(Q, [r], [s_next], [terminated], gamma=gamma)
            loss = compute_td_loss(Q, [s], [a], td_target)
            loss.backward()
            opt.step()

        if global_step % eval_schedule == 0:
            eval_return = eval_dqn(env_name, Q)
            eval_return_history.append(eval_return)
            avg_return = np.mean(eval_return_history)
            print(f'{global_step=} | {avg_return=:.3f} | {epsilon=:.3f}')
            if avg_return >= success_ret:
                print('Решено!')
                break

        s = s_next
        if done:
            s, _ = env.reset()
            done = False

run_dqn(eval_schedule=250)

global_step=250 | avg_return=30.000 | epsilon=0.396
global_step=500 | avg_return=20.500 | epsilon=0.392
global_step=750 | avg_return=19.667 | epsilon=0.389
global_step=1000 | avg_return=18.000 | epsilon=0.385
global_step=1250 | avg_return=17.800 | epsilon=0.381
global_step=1500 | avg_return=22.667 | epsilon=0.377
global_step=1750 | avg_return=23.000 | epsilon=0.373
global_step=2000 | avg_return=25.250 | epsilon=0.370
global_step=2250 | avg_return=29.667 | epsilon=0.366
global_step=2500 | avg_return=29.500 | epsilon=0.362
global_step=2750 | avg_return=29.200 | epsilon=0.358
global_step=3000 | avg_return=30.800 | epsilon=0.354
global_step=3250 | avg_return=36.700 | epsilon=0.351
global_step=3500 | avg_return=37.800 | epsilon=0.347
global_step=3750 | avg_return=40.400 | epsilon=0.343
global_step=4000 | avg_return=38.200 | epsilon=0.339
global_step=4250 | avg_return=37.500 | epsilon=0.335
global_step=4500 | avg_return=37.500 | epsilon=0.332
global_step=4750 | avg_return=34.900 | epsilon=0.

- `avg_return` is the average return over the last ten evaluation episodes (one greedy episode is run every `eval_schedule` steps). It stays low for roughly the first 1000 steps, then grows and converges within 5000-15000 steps, depending on the network architecture.
- If the network does not reach the target return by the end of the loop, try increasing the number of neurons in the hidden layers or changing the initial $\epsilon$.
- The `epsilon` variable controls how much the agent explores the environment. This implementation decays the exploration rate linearly.
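The values in the `epsilon` column of the log can be reproduced by hand. With the defaults of `run_dqn` (`eps_st=.4`, `eps_end=.02`, `eps_dur=.25`, `total_max_steps=100_000`) the decay lasts `0.25 * 100_000 = 25_000` steps, so each step lowers $\epsilon$ by `0.38 / 25_000`. The snippet repeats `linear` from the cell above so that it runs on its own:

```python
def linear(st, end, duration, t):
    """Linear interpolation from st to end over `duration` steps (same as above)."""
    if t >= duration:
        return end
    return st + (end - st) * (t / duration)

duration = 0.25 * 100_000  # eps_dur * total_max_steps = 25_000 steps
for step in (250, 1000, 4750, 25_000, 50_000):
    print(step, round(linear(0.4, 0.02, duration, step), 4))
# 250 -> 0.3962, 1000 -> 0.3848, 4750 -> 0.3278, then 0.02 from step 25_000 on
```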

## DQN with Experience Replay

Next, a replay buffer is added: a queue of tuples $\{(s, a, r, s', 1_\text{terminated})\}$.

During training, every new transition is appended to the buffer, and all updates are performed on transitions sampled from it.
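The notebook stores the buffer in a `collections.deque` with `maxlen`, so once it is full each `append` silently drops the oldest transition. A minimal pure-Python sketch of that behaviour with made-up transitions, followed by the `zip(*samples)` transposition that `sample_batch` uses to turn a list of tuples into per-field tuples:

```python
import random
from collections import deque

replay_buffer = deque(maxlen=3)  # tiny capacity to make eviction visible

# Made-up transitions (s, a, r, s_next, terminated); s is just the step index here
for t in range(5):
    replay_buffer.append((t, t % 2, 1.0, t + 1, t == 4))

print([tr[0] for tr in replay_buffer])  # [2, 3, 4]: transitions 0 and 1 were evicted

# Uniform sampling with replacement, then transpose into per-field tuples
samples = random.choices(list(replay_buffer), k=4)
states, actions, rewards, next_states, terminateds = zip(*samples)
print(len(states), len(actions))  # 4 4
```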

In [15]:
def sample_batch(replay_buffer, n_samples):
    """
    Randomly samples n_samples transitions from the replay buffer.

    Params:
    - replay_buffer: a collection of transitions (e.g., deque or list), 
      where each transition is a tuple (state, action, reward, next_state, terminated).
    - n_samples: number of transitions to sample.

    Output:
    - states: array of states.
    - actions: array of actions.
    - rewards: array of rewards.
    - next_states: array of next states.
    - terminateds: array of episode termination flags.
    """
    rng = np.random.default_rng()
    indices = rng.choice(len(replay_buffer), size=n_samples, replace=True)
    samples = [replay_buffer[i] for i in indices]

    # unpack sample to separate lists
    states, actions, rewards, next_states, terminateds = zip(*samples)

    # lists -> np.array
    return (
        np.array(states),
        np.array(actions),
        np.array(rewards),
        np.array(next_states),
        np.array(terminateds)
    )

In [16]:
def run_dqn_rb(
        env_name="CartPole-v1",
        hidden_dims=(256, 256), lr=1e-3, gamma=0.99,
        eps_st=.4, eps_end=.02, eps_dur=.25, total_max_steps=100_000,
        train_schedule=4, replay_buffer_size=400, batch_size=32,
        eval_schedule=1000, smooth_ret_window=5, success_ret=200.
):
    env = gym.make(env_name)
    replay_buffer = deque(maxlen=replay_buffer_size)
    eval_return_history = deque(maxlen=smooth_ret_window)

    Q = create_network(
        input_dim=env.observation_space.shape[0], hidden_dims=hidden_dims, output_dim=env.action_space.n
    )
    opt = torch.optim.Adam(Q.parameters(), lr=lr)

    s = env.reset()
    if isinstance(s, tuple):
        s, _ = s
    done = False

    for global_step in range(1, total_max_steps + 1):
        epsilon = linear(eps_st, eps_end, eps_dur * total_max_steps, global_step)
        a = select_action_eps_greedy(Q, s, epsilon=epsilon)
        s_next, r, terminated, truncated, _ = env.step(a)

        replay_buffer.append((s, a, r, s_next, terminated))
        done = terminated or truncated

        if global_step % train_schedule == 0:
            train_batch = sample_batch(replay_buffer, batch_size)
            states, actions, rewards, next_states, terminateds = train_batch

            opt.zero_grad()
            td_target = compute_td_target(Q, rewards, next_states, terminateds, gamma=gamma)
            loss = compute_td_loss(Q, states, actions, td_target)
            loss.backward()
            opt.step()

        if global_step % eval_schedule == 0:
            eval_return = eval_dqn(env_name, Q)
            eval_return_history.append(eval_return)
            avg_return = np.mean(eval_return_history)
            print(f'{global_step=} | {avg_return=:.3f} | {epsilon=:.3f}')
            if avg_return >= success_ret:
                print('Решено!')
                break

        s = s_next
        if done:
            s, _ = env.reset()
            done = False

run_dqn_rb(eval_schedule=250)

global_step=250 | avg_return=8.000 | epsilon=0.396
global_step=500 | avg_return=16.500 | epsilon=0.392
global_step=750 | avg_return=18.333 | epsilon=0.389
global_step=1000 | avg_return=17.500 | epsilon=0.385
global_step=1250 | avg_return=18.000 | epsilon=0.381
global_step=1500 | avg_return=35.800 | epsilon=0.377
global_step=1750 | avg_return=32.800 | epsilon=0.373
global_step=2000 | avg_return=38.600 | epsilon=0.370
global_step=2250 | avg_return=58.400 | epsilon=0.366
global_step=2500 | avg_return=74.400 | epsilon=0.362
global_step=2750 | avg_return=143.400 | epsilon=0.358
global_step=3000 | avg_return=153.400 | epsilon=0.354
global_step=3250 | avg_return=194.600 | epsilon=0.351
global_step=3500 | avg_return=271.800 | epsilon=0.347
Решено!


## DQN with Prioritized Experience Replay

Each sample stored in the buffer gets a priority value, which determines how often the sample is randomly drawn into a training batch. A well-chosen priority can make training more efficient. A popular choice is the absolute TD error: this way, Q-function training focuses on the samples where the approximator makes larger errors.

Keep in mind, however, that this value quickly becomes stale if it is not updated, while recomputing it for the whole buffer at every step is expensive. So a balance between priority accuracy and speed has to be found.

Here is what I will do next:

- Use the absolute TD error as the priority.
- Since the TD error is computed for every training batch anyway, use those values to update the priorities of the corresponding samples in the buffer.
- Periodically sort the buffer so that newly added transitions replace the ones with the lowest priority (i.e., the smallest errors). Sorting is expensive, so it is done infrequently.
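The eviction trick in the last point relies on the same `deque(maxlen=...)` behaviour as the plain replay buffer: after sorting by priority in ascending order, the lowest-priority transition sits on the left, and `append` on a full deque drops exactly the leftmost element. A minimal sketch with made-up `(priority, transition_id)` pairs:

```python
from collections import deque

# Made-up (priority, transition_id) pairs in a full buffer of capacity 4
buffer = deque([(0.9, "a"), (0.1, "b"), (2.5, "c"), (0.4, "d")], maxlen=4)

# Re-sort so that the least important samples are on the left (like sort_replay_buffer)
buffer = deque(sorted(buffer, key=lambda sample: sample[0]), maxlen=buffer.maxlen)

# Appending to a full deque evicts the leftmost, i.e. lowest-priority, sample
buffer.append((1.7, "e"))
print([tid for _, tid in buffer])  # ['d', 'a', 'c', 'e']
```

Between two re-sorts new samples are appended on the right regardless of their priority, so the buffer is only approximately sorted; this is the accuracy/speed trade-off described above.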

NB: softmax is very sensitive to the scale of its inputs and often requires tuning the temperature. To partially mitigate this, instead of `softmax(priorities)` one can first apply the function $\text{symlog}(x) = \text{sign}(x) \cdot \log (|x| + 1)$, i.e. use `softmax(symlog(priorities))`, and skip temperature tuning. The idea comes from the DreamerV3 paper.
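A quick check of why symlog helps, in plain Python with made-up priorities. For non-negative $x$ we have $\exp(\text{symlog}(x)) = x + 1$, so `softmax(symlog(p))` reduces to $(p_i + 1) / \sum_j (p_j + 1)$: sampling proportional to the priority plus one, with no exponential blow-up. A plain softmax of the same values puts almost all probability on the largest priority.

```python
import math

def symlog(x):
    # sign(x) * log(|x| + 1)
    return math.copysign(math.log(abs(x) + 1), x)

def softmax(xs, temp=1.0):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp((x - m) / temp) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

priorities = [0.1, 1.0, 10.0, 100.0]  # made-up |TD errors|

plain = softmax(priorities)
smoothed = softmax([symlog(p) for p in priorities])

print([round(p, 4) for p in plain])     # [0.0, 0.0, 0.0, 1.0]
print([round(p, 4) for p in smoothed])  # [0.0096, 0.0174, 0.0956, 0.8775]
```

With the default `temp=1.` used in `sample_prioritized_batch`, the sampling probabilities are therefore simply proportional to `priority + 1`.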

In [11]:
def symlog(x):
    """
    Compute symlog values for a vector `x`.
    It's an inverse operation for symexp.
    """
    return np.sign(x) * np.log(np.abs(x) + 1)

def softmax(xs, temp=1.):
    exp_xs = np.exp((xs - xs.max()) / temp)
    return exp_xs / exp_xs.sum()

def sample_prioritized_batch(replay_buffer, n_samples):
    # collect priorities
    priorities = np.array([sample[0] for sample in replay_buffer])

    # transform priorities through symlog
    symlog_priorities = symlog(priorities)

    # turn them into sampling probabilities with softmax
    probs = softmax(symlog_priorities)

    # sample indices according to these probabilities (with replacement)
    indices = np.random.choice(len(replay_buffer), size=n_samples, p=probs)

    # fetch the sampled transitions from the buffer
    sampled = [replay_buffer[idx] for idx in indices]

    states, actions, rewards, next_states, terminateds = zip(*[sample[1:] for sample in sampled])

    batch = (
        np.array(states), np.array(actions), np.array(rewards),
        np.array(next_states), np.array(terminateds)
    )
    return batch, indices

def update_batch(replay_buffer, indices, batch, new_priority):
    """Overwrites the sampled transitions at `indices`
    with the same transitions carrying updated priority values."""
    states, actions, rewards, next_states, terminateds = batch

    for i in range(len(indices)):
        new_sample = (
            new_priority[i], states[i], actions[i], rewards[i],
            next_states[i], terminateds[i]
        )
        replay_buffer[indices[i]] = new_sample

def sort_replay_buffer(replay_buffer):
    """Sorts replay buffer to move samples with
    lesser priority to the beginning ==> they will be
    replaced with the new samples sooner."""
    new_rb = deque(maxlen=replay_buffer.maxlen)
    new_rb.extend(sorted(replay_buffer, key=lambda sample: sample[0]))
    return new_rb

In [12]:
import numpy as np

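# Compatibility shim: some older gym/gymnasium code still references np.bool8,
# which recent NumPy versions no longer provide
if not hasattr(np, 'bool8'):
    np.bool8 = np.bool_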

def run_dqn_prioritized_rb(
        env_name="CartPole-v1",
        hidden_dims=(256, 256), lr=1e-3, gamma=0.99,
        eps_st=.4, eps_end=.02, eps_dur=.25, total_max_steps=100_000,
        train_schedule=4, replay_buffer_size=400, batch_size=32,
        eval_schedule=1000, smooth_ret_window=5, success_ret=200.
):
    env = gym.make(env_name)
    replay_buffer = deque(maxlen=replay_buffer_size)
    eval_return_history = deque(maxlen=smooth_ret_window)

    Q = create_network(
        input_dim=env.observation_space.shape[0], hidden_dims=hidden_dims,
        output_dim=env.action_space.n
    )
    opt = torch.optim.Adam(Q.parameters(), lr=lr)

    s = env.reset()
    if isinstance(s, tuple):
        s, _ = s
    done = False

    for global_step in range(1, total_max_steps + 1):
        epsilon = linear(
            eps_st, eps_end, eps_dur * total_max_steps, global_step
        )
        a = select_action_eps_greedy(Q, s, epsilon=epsilon)

        result = env.step(a)
        if len(result) == 5:
            s_next, r, terminated, truncated, _ = result
        else:
            s_next, r, done, _ = result
            terminated = done
            truncated = False

        # Priority of the new sample: |TD error| (computed without gradients)
        with torch.no_grad():
            state_tensor = torch.as_tensor(s, dtype=torch.float32).unsqueeze(0)  # (1, state_dim)
            next_state_tensor = torch.as_tensor(s_next, dtype=torch.float32).unsqueeze(0)

            q_values = Q(state_tensor)
            next_q_values = Q(next_state_tensor)

            # Q(s, a)
            q_val = q_values[0, a]

            # max_a' Q(s', a')
            max_next_q_val = next_q_values.max(1).values[0]

            # TD target
            td_target = r + gamma * max_next_q_val * (1.0 - float(terminated))

            # absolute TD error (scalar), used as the priority
            priority = abs(td_target.item() - q_val.item())

        replay_buffer.append((priority, s, a, r, s_next, terminated))
        done = terminated or truncated

        if global_step % train_schedule == 0:
            train_batch, indices = sample_prioritized_batch(
                replay_buffer, batch_size
            )
            (
                states, actions, rewards,
                next_states, terminateds
            ) = train_batch

            opt.zero_grad()
            td_target = compute_td_target(Q, rewards, next_states, terminateds, gamma=gamma)
            loss, td_losses = compute_td_loss(Q, states, actions, td_target, out_non_reduced_losses=True)
            loss.backward()
            opt.step()

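            # td_losses holds squared TD errors; take the square root so the
            # stored priority is |TD error|, matching the priority of new samples
            update_batch(
                replay_buffer, indices, train_batch, td_losses.sqrt().numpy()
            )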

        # with much slower scheduler periodically re-sort replay buffer
        # such that it will overwrite the least important samples
        if global_step % (10 * train_schedule) == 0:
            replay_buffer = sort_replay_buffer(replay_buffer)

        if global_step % eval_schedule == 0:
            eval_return = eval_dqn(env_name, Q)
            eval_return_history.append(eval_return)
            avg_return = np.mean(eval_return_history)
            print(f'{global_step=} | {avg_return=:.3f} | {epsilon=:.3f}')
            if avg_return >= success_ret:
                print('Решено!')
                break

        s = s_next
        if done:
            s = env.reset()
            if isinstance(s, tuple):
                s, _ = s
            done = False

run_dqn_prioritized_rb(eval_schedule=250)

global_step=250 | avg_return=16.000 | epsilon=0.396
global_step=500 | avg_return=13.000 | epsilon=0.392
global_step=750 | avg_return=14.000 | epsilon=0.389
global_step=1000 | avg_return=18.500 | epsilon=0.385
global_step=1250 | avg_return=18.600 | epsilon=0.381
global_step=1500 | avg_return=30.200 | epsilon=0.377
global_step=1750 | avg_return=32.800 | epsilon=0.373
global_step=2000 | avg_return=35.000 | epsilon=0.370
global_step=2250 | avg_return=63.400 | epsilon=0.366
global_step=2500 | avg_return=61.800 | epsilon=0.362
global_step=2750 | avg_return=88.800 | epsilon=0.358
global_step=3000 | avg_return=131.400 | epsilon=0.354
global_step=3250 | avg_return=170.400 | epsilon=0.351
global_step=3500 | avg_return=186.400 | epsilon=0.347
global_step=3750 | avg_return=186.600 | epsilon=0.343
global_step=4000 | avg_return=147.000 | epsilon=0.339
global_step=4250 | avg_return=143.200 | epsilon=0.335
global_step=4500 | avg_return=145.400 | epsilon=0.332
global_step=4750 | avg_return=147.400 | ep