In [None]:
import sys
import time
import typing as tt

In [None]:
import numpy as np
import torch
from torch.utils.tensorboard.writer import SummaryWriter

In [None]:
import ptan

In [None]:
class RewardTracker:
    """
    Report finished-episode rewards to stdout and to a scalar writer.

    Can be used as a context manager: entering (re)starts the timing and
    the reward history, exiting closes the writer.
    """

    def __init__(self, writer: "SummaryWriter"):
        """
        :param writer: object providing ``add_scalar(tag, value, step)`` and
            ``close()``; it is closed when the context is exited
        """
        self.writer = writer
        # Set up the tracking state here as well, so that reward() also
        # works on a tracker that was never entered as a context manager.
        self._reset()

    def _reset(self):
        """Start a fresh measurement: current time, frame 0, no rewards."""
        self.ts = time.time()
        self.ts_frame = 0
        self.total_rewards = []

    def __enter__(self):
        self._reset()
        return self

    def __exit__(self, *args):
        self.writer.close()

    def reward(self, reward: float, frame: int,
               epsilon: tt.Optional[float] = None):
        """
        Record one reward, print a progress line and log the scalars.

        :param reward: reward value to record
        :param frame: current frame counter, used as the step of every scalar
        :param epsilon: optional value that is printed and logged as
            "epsilon" when given
        :return: mean of the last 100 recorded rewards once more than 30
            rewards have been recorded, otherwise None
        """
        self.total_rewards.append(reward)
        # Sample the clock once so the speed and the stored timestamp agree.
        now = time.time()
        elapsed = now - self.ts
        # time.time() may not advance between two calls (coarse clock) and
        # may even step backwards; report zero speed in that case instead
        # of raising ZeroDivisionError or logging a negative value.
        if elapsed > 0:
            speed = (frame - self.ts_frame) / elapsed
        else:
            speed = 0.0
        self.ts_frame = frame
        self.ts = now
        mean_reward = np.mean(self.total_rewards[-100:])
        epsilon_str = ""
        if epsilon is not None:
            epsilon_str = f", eps {epsilon:.2f}"
        print(f"{frame}: done {len(self.total_rewards)} games, "
              f"mean reward {mean_reward:.3f}, "
              f"speed {speed:.2f} f/s{epsilon_str}", flush=True)
        if epsilon is not None:
            self.writer.add_scalar("epsilon", epsilon, frame)
        self.writer.add_scalar("speed", speed, frame)
        self.writer.add_scalar("reward_100", mean_reward, frame)
        self.writer.add_scalar("reward", reward, frame)
        return mean_reward if len(self.total_rewards) > 30 else None

In [None]:
def unpack_batch(batch, net, last_val_gamma, device="cpu", states_preprocessor=ptan.agent.default_states_preprocessor):
    """
    Convert a batch of experience items into training tensors.

    :param batch: iterable of items exposing ``state``, ``action``,
        ``reward`` and ``last_state``; a ``last_state`` of ``None`` means no
        value estimate is added to that item's reward
    :param net: callable applied to the preprocessed last states; element
        ``[1]`` of its result is used as the value estimates
    :param last_val_gamma: factor the value estimate of ``last_state`` is
        multiplied by before it is added to the reward
    :param device: device the returned tensors are placed on
    :param states_preprocessor: callable converting a list of states into
        the network input
    :return: preprocessed states, actions tensor (int64), reference values
        tensor (float32)
    """
    states = []
    actions = []
    rewards = []
    not_done_idx = []
    last_states = []
    for idx, exp in enumerate(batch):
        states.append(exp.state)
        actions.append(int(exp.action))
        rewards.append(exp.reward)
        if exp.last_state is not None:
            not_done_idx.append(idx)
            last_states.append(exp.last_state)

    states_v = states_preprocessor(states)
    if torch.is_tensor(states_v):
        states_v = states_v.to(device)
    actions_t = torch.tensor(actions, dtype=torch.long, device=device)

    # Reference value = reward, plus the scaled value estimate of the last
    # state for every item that has one.
    rewards_np = np.array(rewards, dtype=np.float32)
    if not_done_idx:
        last_states_v = states_preprocessor(last_states)
        if torch.is_tensor(last_states_v):
            last_states_v = last_states_v.to(device)
        # The estimates are only used as constants in the target, so skip
        # building an autograd graph for this forward pass.
        with torch.no_grad():
            last_vals_v = net(last_states_v)[1]
        # Column 0 of the value output is taken as the per-item estimate.
        last_vals_np = last_vals_v.cpu().numpy()[:, 0]
        rewards_np[not_done_idx] += last_val_gamma * last_vals_np

    ref_vals_v = torch.tensor(rewards_np, dtype=torch.float32, device=device)
    return states_v, actions_t, ref_vals_v