In [122]:
%load_ext autoreload
%autoreload 2
from collections import defaultdict
from typing import Optional

import numpy as np
import torch
import tqdm
from tensordict.nn import TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase
from torch import nn

from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import (
    CatTensors,
    EnvBase,
    Transform,
    TransformedEnv,
    UnsqueezeTransform,
)
from torchrl.envs.transforms.transforms import _apply_to_composite
from torchrl.envs.utils import check_env_specs, step_mdp

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [123]:
import numpy as np
import random
import torch
from pytorch3d import transforms
import math
from linguamechanica.kinematics import DifferentiableOpenChainMechanism
from linguamechanica.kinematics import UrdfRobotLibrary

In [156]:
error_done_threshold = 1e-3
weights = torch.Tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).cuda()
urdf_robot = UrdfRobotLibrary.dobot_cr5()
chain_index = 1
used_open_chain = urdf_robot.extract_open_chains(0.3)[chain_index].to(weights.device)
thetas_count = used_open_chain.screws.shape[0]
pose_count = 6
on_manifold_count = (9 * 2) + ( 2 * thetas_count)
batch_size = 1024

In [157]:
used_open_chain.screws

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.0000e+00],
        [ 1.4700e-01, -5.3996e-07, -2.3803e-09, -3.6732e-06, -1.0000e+00,
          1.3644e-11]], device='cuda:0')

In [158]:
used_open_chain.initial_matrix

tensor([[-3.6732e-06,  0.0000e+00, -1.0000e+00,  0.0000e+00],
        [ 1.0000e+00, -3.6732e-06, -3.6732e-06,  0.0000e+00],
        [-3.6732e-06, -1.0000e+00,  1.3492e-11,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  1.4700e-01,  1.0000e+00]], device='cuda:0')

In [159]:
thetas = torch.Tensor([[0.0],[0.0],[1.5]]).cuda()
transformation = used_open_chain.forward_transformation(
    thetas
)
pose = transforms.se3_log_map(transformation.get_matrix())
pose

tensor([[-0.1083,  0.0695,  0.1083,  1.2092,  1.2092, -1.2092],
        [-0.1083,  0.0695,  0.1083,  1.2092,  1.2092, -1.2092],
        [-0.0703,  0.1049,  0.1096,  1.1915,  1.1915,  1.1043]],
       device='cuda:0')

In [160]:
used_open_chain = used_open_chain.to(thetas.device)

In [161]:
target_thetas = torch.Tensor([[1.0],[0.0],[0.5]]).cuda()
target_transformation = used_open_chain.forward_transformation(
    target_thetas
)
target_pose = transforms.se3_log_map(target_transformation.get_matrix())
target_pose

tensor([[-0.0773,  0.0870,  0.1150,  1.1177,  1.1177,  0.3366],
        [-0.1083,  0.0695,  0.1083,  1.2092,  1.2092, -1.2092],
        [-0.0890,  0.0761,  0.1146,  1.1232,  1.1232, -0.4472]],
       device='cuda:0')

In [162]:
def compute_error_pose(open_chain, thetas, target_pose):
    #print("compute_error_pose", thetas.shape, target_pose.shape)
    #print("thetas causing nans", thetas)
    current_transformation = open_chain.forward_transformation(thetas)
    target_transformation = transforms.se3_exp_map(target_pose)
    #print("current_transformation")
    #print(current_transformation.get_matrix())
    #print("target_transformation")
    #print(target_transformation)
    current_trans_to_target = current_transformation.compose(
        transforms.Transform3d(matrix=target_transformation).inverse()
    )
    current_trans_to_target = current_trans_to_target.to(thetas.device).get_matrix()
    #print("current_trans_to_target")
    #print(current_trans_to_target)
    error_pose = transforms.se3_log_map(current_trans_to_target)
    return error_pose

In [163]:
compute_error_pose(used_open_chain, thetas, target_pose)

tensor([[ 2.4125e-09, -1.0885e-08,  1.7733e-08,  9.1252e-01, -4.9851e-01,
         -9.1252e-01],
        [-1.4901e-08,  1.8626e-09, -5.6969e-15, -8.1887e-08,  5.9598e-08,
          8.1887e-08],
        [-1.8561e-08,  4.7160e-08, -2.9986e-08, -5.6182e-01,  8.7497e-01,
          9.1252e-01]], device='cuda:0')

In [164]:
#def force_parameters_within_bounds(thetas):
#    thetas[thetas >  math.pi] -= 2.0 * torch.pi
#    thetas[thetas < -math.pi] += 2.0 * torch.pi
#    return thetas

In [165]:
def compute_reward(thetas, target_pose, weights, error_done_threshold, open_chain):
    if len(thetas.shape) == 1:
        thetas = thetas.unsqueeze(0)
    if len(target_pose.shape) == 1:
        target_pose = target_pose.unsqueeze(0)
    open_chain = open_chain.to(thetas.device)
    #print("!!!!!!!!!!!!!!!!!!!!!")
    #print("thetas", thetas)
    #print("target_pose", target_pose.shape)
    #print("target_pose", target_pose)
    #print("!!!!!!!!!!!!!!!!!!!!!")
    error_pose = compute_error_pose(
        open_chain, thetas, target_pose
    )
    weights = weights.to(thetas.device)
    pose_error = DifferentiableOpenChainMechanism.compute_weighted_error(
        error_pose, weights
    )
    done = pose_error < error_done_threshold
    reward = - pose_error
    return reward, done

In [166]:
angles = torch.Tensor([0.1, -0.1, 0.2, -0.2])
angles_sin = angles.sin()
angles_cos = angles.cos()
torch.atan2(angles_sin, angles_cos) - angles

tensor([0., 0., 0., 0.])

In [167]:
def _step(self, tensordict):
    thetas = torch.atan2(tensordict["thetas_sin"], tensordict["thetas_cos"])
    theta_deltas = tensordict["action"]
    #print("thetas", thetas.shape)
    #print("theta_deltas", theta_deltas.shape)
    theta_deltas_sin, theta_deltas_cos = None, None
    if len(theta_deltas.shape) == 2:
        theta_deltas_sin = theta_deltas[:, 0:1]
        theta_deltas_cos = theta_deltas[:, 1:2]
    else:
        theta_deltas_sin = theta_deltas[0]
        theta_deltas_cos = theta_deltas[1]
    theta_deltas = torch.atan2(theta_deltas_sin, theta_deltas_cos)
    
    #print("theta_deltas", theta_deltas.shape)
    #print("thetas", thetas.shape)
    #max_theta_deltas = tensordict["params", "max_theta_deltas"]
    #print(thetas.shape, theta_deltas.shape)
    new_thetas = thetas + (theta_deltas * 1.000)
    #new_thetas = new_thetas.clamp(-max_theta_deltas, max_theta_deltas)
    target_pose = tensordict["target_pose"]
    #print("target_pose in _step", target_pose)
    #print("STEP", f"Target: {target_pose}", f"Thetas: {thetas}", f"New Thetas: {new_thetas}")
    #TODO: I have no idea if this is a good idea or not
    #new_thetas = force_parameters_within_bounds(new_thetas)
    #print("----------------------------")
    #print("new_thetas.shape", new_thetas.shape)
    #print("target_pose", target_pose)
    #print("weights.shape", weights.shape)
    #print("----------------------------")
    reward, done = compute_reward(new_thetas, target_pose, weights, error_done_threshold, self.open_chain)
    done = torch.zeros_like(reward, dtype=torch.bool)
    out = TensorDict(
        {
            "next": {
                "thetas_sin": new_thetas.sin(),
                "thetas_cos": new_thetas.cos(),
                "target_pose": target_pose,
                "params": tensordict["params"],
                "reward": reward,
                "done": done,
            }
        },
        tensordict.shape,
    )
    return out

In [168]:
def uniformly_sample_parameters_within_constraints(open_chain, batch_size):
    samples = []
    for sample_idx in range(batch_size):
        coordinates = []
        for i in range(len(open_chain.joint_limits)):
            # TODO: check if unconstrained works
            coordinates.append(
                random.uniform(
                    open_chain.joint_limits[i][0],
                    open_chain.joint_limits[i][1],
                )
            )
        samples.append(torch.Tensor(coordinates).unsqueeze(0))
    return torch.cat(samples, 0)


In [169]:
def generate_random_target_pose(target_thetas, open_chain):
    if len(target_thetas.shape) == 1:
        target_thetas = target_thetas.unsqueeze(0)
    open_chain = open_chain.to(target_thetas.device)
    target_transformation = open_chain.forward_transformation(
        target_thetas
    )
    target_pose = transforms.se3_log_map(
        target_transformation.get_matrix()
    )
    if target_thetas.shape[0] == 1:
        target_thetas = target_thetas.squeeze(0)
    #print("generate_random_target_pose", target_pose)
    return target_pose

In [170]:
def _reset(self, tensordict):
    if tensordict is None or tensordict.is_empty():
        # if no tensordict is passed, we generate a single set of hyperparameters
        # Otherwise, we assume that the input tensordict contains all the relevant
        # parameters to get started.
        tensordict = self.gen_params(batch_size=self.batch_size)
    batch_size = 1 if len(tensordict.shape) == 0 else tensordict.shape[0]
    thetas = uniformly_sample_parameters_within_constraints(self.open_chain, batch_size).to(device=self.device)
    if batch_size == 1:
        thetas = thetas.squeeze(0)    
    #thetas = force_parameters_within_bounds(thetas)
    #TODO: randommize this better
    target_thetas = thetas + torch.randn(thetas.shape).to(self.device)
    #target_thetas = force_parameters_within_bounds(target_thetas)
    target_pose   = generate_random_target_pose(target_thetas, self.open_chain)
    if batch_size == 1:
        target_pose = target_pose.squeeze(0)
    #print("target_pose.shape", target_pose.shape)
    out = TensorDict(
        {
            "thetas_sin": thetas.sin(),
            "thetas_cos": thetas.cos(),
            "target_pose": target_pose,
            "params": tensordict["params"],
        },
        batch_size=tensordict.shape,
    )
    return out

In [171]:
def _make_spec(self, td_params):
    # Under the hood, this will populate self.output_spec["observation"]
    self.observation_spec = CompositeSpec(
        thetas_sin=BoundedTensorSpec(
            minimum=-torch.ones(thetas_count) * torch.pi,
            maximum= torch.ones(thetas_count) * torch.pi,
            shape=(thetas_count),
            dtype=torch.float32,
        ),
        thetas_cos=BoundedTensorSpec(
            minimum=-torch.ones(thetas_count) * torch.pi,
            maximum= torch.ones(thetas_count) * torch.pi,
            shape=(thetas_count),
            dtype=torch.float32,
        ),
        #TODO: bounds are wrong. They need to be the ones in the robot constraints
        target_pose=BoundedTensorSpec(
            minimum=-torch.ones(pose_count) * 10000.0,
            maximum= torch.ones(pose_count) * 10000.0,
            shape=(pose_count),
            dtype=torch.float32,
        ),
        # we need to add the "params" to the observation specs, as we want
        # to pass it at each step during a rollout
        params=make_composite_from_td(td_params["params"]),
        shape=(),
    )
    # since the environment is stateless, we expect the previous output as input.
    # For this, EnvBase expects some state_spec to be available
    self.state_spec = self.observation_spec.clone()
    # action-spec will be automatically wrapped in input_spec when
    # `self.action_spec = spec` will be called supported
    #TODO: bounds are wrong
    self.action_spec = BoundedTensorSpec(
        minimum=-torch.ones(thetas_count * 2) * 100000.0,
        maximum=+torch.ones(thetas_count * 2) * 100000.0,
        shape=(thetas_count * 2),
        dtype=torch.float32,
    )
    self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1))


def make_composite_from_td(td):
    # custom funtion to convert a tensordict in a similar spec structure
    # of unbounded values.
    composite = CompositeSpec(
        {
            key: make_composite_from_td(tensor)
            if isinstance(tensor, TensorDictBase)
            else UnboundedContinuousTensorSpec(
                dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
            )
            for key, tensor in td.items()
        },
        shape=td.shape,
    )
    return composite

In [172]:
def _set_seed(self, seed: Optional[int]):
    rng = torch.manual_seed(seed)
    self.rng = rng

In [173]:
def gen_params(batch_size=None) -> TensorDictBase:
    if batch_size is None:
        batch_size = []
    td = TensorDict(
        {
            "params": TensorDict(
                {
                    "max_theta_deltas": torch.ones(thetas_count) * torch.pi,
                },
                [],
            )
        },
        [],
    )
    if batch_size:
        td = td.expand(batch_size).contiguous()
    return td

In [174]:
class InverseKinematicsEnv(EnvBase):
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 30,
    }
    batch_locked = False

    def __init__(self, open_chain=None, td_params=None, seed=None, device="cpu"):
        super().__init__()
        if td_params is None:
            td_params = self.gen_params()
        self.open_chain = open_chain
        super().__init__(device=device, batch_size=[])
        self._make_spec(td_params)
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)

    # Helpers: _make_step and gen_params
    gen_params = staticmethod(gen_params)
    _make_spec = _make_spec

    # Mandatory methods: _step, _reset and _set_seed
    _reset = _reset
    _step = _step#staticmethod(_step)
    _set_seed = _set_seed

In [175]:
env = InverseKinematicsEnv(open_chain=used_open_chain)
check_env_specs(env)

check_env_specs succeeded!


We can have a look at our specs to have a visual representation of the environment
signature:




In [176]:
print("observation_spec:", env.observation_spec)
print("state_spec:", env.state_spec)
print("reward_spec:", env.reward_spec)

observation_spec: CompositeSpec(
    thetas_sin: BoundedTensorSpec(
        shape=torch.Size([2]),
        space=ContinuousBox(
            minimum=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, contiguous=True), 
            maximum=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    thetas_cos: BoundedTensorSpec(
        shape=torch.Size([2]),
        space=ContinuousBox(
            minimum=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, contiguous=True), 
            maximum=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    target_pose: BoundedTensorSpec(
        shape=torch.Size([6]),
        space=ContinuousBox(
            minimum=Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float32, contiguous=True), 
            maxim

We can execute a couple of commands too to check that the output structure
matches what is expected.



In [177]:
td = env.reset()
print("reset tensordict", td)

reset tensordict TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        params: TensorDict(
            fields={
                max_theta_deltas: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        target_pose: Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float32, is_shared=False),
        thetas_cos: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        thetas_sin: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)


We can run the :func:`env.rand_step` to generate
an action randomly from the ``action_spec`` domain. A tensordict containing
the hyperparams and the current state **must** be passed since our
environment is stateless. In stateful contexts, ``env.rand_step()`` works
perfectly too.




In [178]:
td = env.rand_step(td)
print("random step tensordict", td)

random step tensordict TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                params: TensorDict(
                    fields={
                        max_theta_deltas: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([]),
                    device=cpu,
                    is_shared=False),
                reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                target_pose: Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.float32, is_shared=False),
                thetas_cos: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_s

In [179]:
def get_pose_and_error_pose(thetas, target_pose, open_chain):
    open_chain = open_chain.to(thetas.device)
    error_pose = compute_error_pose(open_chain, thetas, target_pose)
    transformation = open_chain.forward_transformation(thetas)
    pose = transforms.se3_log_map(transformation.get_matrix())
    return pose, error_pose

In [212]:
'''
transformed_env = TransformedEnv(
    env,
    # Unsqueezes the observations that we will concatenate
    UnsqueezeTransform(
        unsqueeze_dim=0,
        in_keys=["thetas", "target_pose"],
        in_keys_inv=["thetas", "target_pose"],
    ),
)
'''
transformed_env = TransformedEnv(env)
'''
class OnManifodErrorTransform(Transform):
    def __init__(self, in_keys, out_keys, open_chain):
        super().__init__(in_keys, out_keys)
        self.open_chain = open_chain
        
    def _apply_transform(self, obs: torch.Tensor) -> None:
        #print("-----------------")
        #print("obs.shape", obs.shape)
        #print("obs", obs)
        thetas_sin, thetas_cos, target_pose = None, None, None
        if len(obs.shape) == 1:
            target_pose = obs[:-thetas_count*2].unsqueeze(0)
            thetas_cos = obs[obs.shape[0] - (thetas_count*2):obs.shape[0] - thetas_count].unsqueeze(0)
            thetas_sin = obs[obs.shape[0] - thetas_count:].unsqueeze(0)            
        elif len(obs.shape) == 2:
            target_pose = obs[:,:-thetas_count*2]
            thetas_cos = obs[:,obs.shape[1] - (thetas_count*2):obs.shape[1] - thetas_count]
            thetas_sin = obs[:,obs.shape[1] - thetas_count:]
        #print(thetas_cos)
        #print(thetas_sin)
        #print("thetas.shape", thetas.shape)
        #print("target_pose.shape", target_pose.shape)
        thetas = torch.atan2(thetas_sin, thetas_cos)
        #print(thetas)
        pose, error_pose = get_pose_and_error_pose(thetas, target_pose, self.open_chain)
        # pose decomposition
        pose_linear =  pose[:, :3]
        pose_angular_cos =  pose[:, 3:].cos()
        pose_angular_sin =  pose[:, 3:].sin()
        # error pose decomposition 
        error_pose_linear =  error_pose[:, :3]
        error_pose_angular_cos =  error_pose[:, 3:].cos()
        error_pose_angular_sin =  error_pose[:, 3:].sin()
        
        #print("---------------------")
        #print("obs", obs)
        #print("error_pose", error_pose)
        #print("pose", pose)
        #print("target_pose", target_pose)
        #print("---------------------")
        
        #manifold_error = torch.cat([pose_linear, pose_angular_cos, pose_angular_sin, error_pose_linear, error_pose_angular_cos, error_pose_angular_sin, thetas_sin, thetas_cos], 1)
        manifold_error = torch.cat([pose_linear, pose_angular_cos , pose_angular_sin, error_pose_linear, error_pose_angular_cos, error_pose_angular_sin, thetas_cos, thetas_sin], 1)        
        #print(manifold_error.shape, on_manifold_count)
        #print("manifold_error", manifold_error)
        #print("obs", obs.shape, manifold_error.shape)
        if len(obs.shape) == 1:
            return manifold_error.squeeze(0)
        else:
            #print("---------------------")
            #print("error_pose", error_pose.abs()[:, :])
            #print("error_pose max", error_pose[:,-1].max().item())
            #print("error_pose min", error_pose[:,-1].min().item())
            #print("---------------------")
            return manifold_error

    # _apply_to_composite will execute the observation spec transform across all
    # in_keys/out_keys pairs and write the result in the observation_spec which
    # is of type ``Composite``
        
    #TODO minimum and maximum are incorrect!!
    @_apply_to_composite
    def transform_observation_spec(self, observation_spec):
        return BoundedTensorSpec(
            minimum=-10000,
            maximum=10000,
            shape=(on_manifold_count),
            dtype=observation_spec.dtype,
            device=observation_spec.device,
        )
        
cat_transform = CatTensors(
    in_keys=["thetas_sin", "thetas_cos", "target_pose"], dim=-1, out_key="observation", del_keys=False
)
transformed_env.append_transform(cat_transform)
on_manifold_error = OnManifodErrorTransform(in_keys=["observation"], 
                                            out_keys=["on_manifold_error"],
                                            open_chain=used_open_chain)
transformed_env.append_transform(on_manifold_error)
'''

'\nclass OnManifodErrorTransform(Transform):\n    def __init__(self, in_keys, out_keys, open_chain):\n        super().__init__(in_keys, out_keys)\n        self.open_chain = open_chain\n        \n    def _apply_transform(self, obs: torch.Tensor) -> None:\n        #print("-----------------")\n        #print("obs.shape", obs.shape)\n        #print("obs", obs)\n        thetas_sin, thetas_cos, target_pose = None, None, None\n        if len(obs.shape) == 1:\n            target_pose = obs[:-thetas_count*2].unsqueeze(0)\n            thetas_cos = obs[obs.shape[0] - (thetas_count*2):obs.shape[0] - thetas_count].unsqueeze(0)\n            thetas_sin = obs[obs.shape[0] - thetas_count:].unsqueeze(0)            \n        elif len(obs.shape) == 2:\n            target_pose = obs[:,:-thetas_count*2]\n            thetas_cos = obs[:,obs.shape[1] - (thetas_count*2):obs.shape[1] - thetas_count]\n            thetas_sin = obs[:,obs.shape[1] - thetas_count:]\n        #print(thetas_cos)\n        #print(thetas_s

In [213]:
#cat_transform = CatTensors(
#    in_keys=["sin", "cos", "thdot"], dim=-1, out_key="observation", del_keys=False
#)
#transformed_env.append_transform(cat_transform)

Once more, let us check that our env specs match what is received:



In [214]:
check_env_specs(transformed_env)

check_env_specs succeeded!


## Executing a rollout

Executing a rollout is a succession of simple steps:

* reset the environment
* while some condition is not met:

  * compute an action given a policy
  * execute a step given this action
  * collect the data
  * make a MDP step

* gather the data and return

These operations have been convinently wrapped in the :func:`EnvBase.rollout`
method, from which we provide a simplified version here below.



In [215]:
def simple_rollout(steps=100):
    # preallocate:
    data = TensorDict({}, [steps])
    # reset
    _data = transformed_env.reset()
    for i in range(steps):
        _data["action"] = transformed_env.action_spec.rand()
        _data = transformed_env.step(_data)
        data[i] = _data
        _data = step_mdp(_data, keep_other=True)
    return data


print("data from rollout:", simple_rollout(100))

data from rollout: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([100, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                params: TensorDict(
                    fields={
                        max_theta_deltas: Tensor(shape=torch.Size([100, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([100]),
                    device=None,
                    is_shared=False),
                reward: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                target_pose: Tensor(shape=torch.Size([100, 6]), device=cpu, dtype=torch.float32, is_shared=False),
                thetas_cos: Tensor(shape=torch.Size([100, 2]), de

## Batching computations

The last unexplored end of our tutorial is the ability that we have to
batch computations in TorchRL. Because our environment does not
make any assumptions regarding the input data shape, we can seamlessly
execute it over batches of data. Even better: for non-batch-locked
environments such as our Pendulum, we can change the batch size on the fly
without recreating the env.
To do this, we just generate parameters with the desired shape.




In [216]:
batch_size = 10  # number of environments to be executed in batch
td = transformed_env.reset(transformed_env.gen_params(batch_size=[batch_size]))
print(f"reset (batch size of {batch_size})", td)
td = transformed_env.rand_step(td)
print(f"rand step (batch size of {batch_size})", td)

reset (batch size of 10) TensorDict(
    fields={
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        params: TensorDict(
            fields={
                max_theta_deltas: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False),
        target_pose: Tensor(shape=torch.Size([10, 6]), device=cpu, dtype=torch.float32, is_shared=False),
        thetas_cos: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        thetas_sin: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)
rand step (batch size of 10) TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([10, 1])

executing a rollout with a batch of data requires us to reset the env
out of the rollout function, since we need to define the batch_size
dynamically and this is not supported by :func:`EnvBase.rollout`:




In [217]:
rollout = transformed_env.rollout(
    3,
    auto_reset=False,  # we're executing the reset out of the ``rollout`` call
    tensordict=transformed_env.reset(transformed_env.gen_params(batch_size=[batch_size])),
)
print("rollout of len 3 (batch size of 10):", rollout)

rollout of len 3 (batch size of 10): TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                params: TensorDict(
                    fields={
                        max_theta_deltas: Tensor(shape=torch.Size([10, 3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([10, 3]),
                    device=None,
                    is_shared=False),
                reward: Tensor(shape=torch.Size([10, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                target_pose: Tensor(shape=torch.Size([10, 3, 6]), device=cpu, dtype=torch.float32, is_shared=False),
                thetas_cos: Tenso

## Training a simple policy

In this example, we will train a simple policy using the reward as a
differentiable objective (i.e. a negative loss).
We will take advantage of the fact that our dynamic system is fully
differentiable to backpropagate through the trajectory return and adjust the
weights of our policy to maximise this value directly. Of course, in many
settings many of the assumptions we make do not hold, such as
differentiability of the system and full access to the underlying mechanics.

Still, this is a very simple example that showcases how a training loop can
be coded with a custom environment in TorchRL.

Let us first write the policy network:




In [262]:
torch.manual_seed(0)
transformed_env.set_seed(0)

net = nn.Sequential(
    nn.Linear(in_features=1, out_features=4, bias=True),
    nn.Tanh(),
    nn.Linear(in_features=4, out_features=thetas_count, bias=True)
).cuda()

import torch.nn.functional as F

class DummyActor(nn.Module):
    def __init__(
        self,
    ):
        super(DummyActor, self).__init__()
        self.fc1 = nn.Linear(in_features=on_manifold_count, out_features=1024, bias=True)
        self.fc2 = nn.Linear(in_features=1024+on_manifold_count, out_features=1024, bias=True)
        self.fc3 = nn.Linear(in_features=1024+on_manifold_count, out_features=512, bias=True)
        self.fc4 = nn.Linear(in_features=512+on_manifold_count, out_features=512, bias=True)
        self.fc5 = nn.Linear(in_features=512+on_manifold_count, out_features=128, bias=True)
        self.fc6 = nn.Linear(in_features=128+on_manifold_count, out_features=64, bias=True)
        self.fc_cos = nn.Linear(in_features=64+on_manifold_count, out_features=1, bias=True)
        self.fc_sin = nn.Linear(in_features=64+on_manifold_count, out_features=1, bias=True)

    def forward(self, thetas_sin, thetas_cos, target_pose):
        #print(thetas_sin)
        #print(thetas_cos)
        #print(target_pose)
        thetas = torch.atan2(thetas_sin, thetas_cos)
        pose, error_pose = get_pose_and_error_pose(thetas, target_pose, used_open_chain)
        #print(error_pose)
        #print("state", thetas_sin.shape, thetas_cos.shape, target_pose.shape)
        # pose decomposition
        pose_linear =  pose[:, :3]
        pose_angular_cos =  pose[:, 3:].cos()
        pose_angular_sin =  pose[:, 3:].sin()
        # error pose decomposition 
        error_pose_linear =  error_pose[:, :3]
        error_pose_angular_cos =  error_pose[:, 3:].cos()
        error_pose_angular_sin =  error_pose[:, 3:].sin()
        
        manifold_error = torch.cat([pose_linear, pose_angular_cos , pose_angular_sin, error_pose_linear, error_pose_angular_cos, error_pose_angular_sin, thetas_cos, thetas_sin], 1)      
        #print(manifold_error.shape, on_manifold_count)
        x = torch.cat([F.tanh(self.fc1(manifold_error)), manifold_error],1)
        x = torch.cat([F.tanh(self.fc2(x)), manifold_error],1)
        x = torch.cat([F.tanh(self.fc3(x)), manifold_error],1)
        x = torch.cat([F.tanh(self.fc4(x)), manifold_error],1)
        x = torch.cat([F.tanh(self.fc5(x)), manifold_error],1)
        x = torch.cat([F.tanh(self.fc6(x)), manifold_error],1)
        cos = self.fc_cos(x).cos()
        sin = self.fc_sin(x).sin()
        return torch.cat([sin, cos], 1)

net = DummyActor().cuda()

and our optimizer:




### Training loop

We will successively:

* generate a trajectory
* sum the rewards
* backpropagate through the graph defined by these operations
* clip the gradient norm and make an optimization step
* repeat

At the end of the training loop, we should have a final reward close to 0
which demonstrates that the pendulum is upward and still as desired.




In [None]:
policy = TensorDictModule(
    net,
    in_keys=["thetas_sin", "thetas_cos", "target_pose"],#["on_manifold_error"], #["target_pose"],
    out_keys=["action"],
).cuda()
optim = torch.optim.Adam(policy.parameters(), lr=2e-8)

batch_size = 1024
iterations = 20000_000
pbar = tqdm.tqdm(range(iterations // batch_size))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, iterations)
logs = defaultdict(list)

# current_env = env 
current_env = transformed_env.cuda()
for _ in pbar:
    init_td = current_env.reset(current_env.gen_params(batch_size=[batch_size])).cuda()
    rollout = current_env.rollout(10, policy, tensordict=init_td, auto_reset=False).cuda()
    #print(rollout)
    traj_return = rollout["next", "reward"].mean()
    (-traj_return).backward()
    gn = torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
    optim.step()
    optim.zero_grad()
    pbar.set_description(
        f"reward: {traj_return: 4.4f}, "
        f"last reward: {rollout[..., -1]['next', 'reward'].mean(): 4.4f}, gradient norm: {gn: 4.4}"
    )
    logs["return"].append(traj_return.item())
    logs["last_reward"].append(rollout[..., -1]["next", "reward"].mean().item())
    scheduler.step()


def plot():
    import matplotlib
    from matplotlib import pyplot as plt

    is_ipython = "inline" in matplotlib.get_backend()
    if is_ipython:
        from IPython import display

    with plt.ion():
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.plot(logs["return"])
        plt.title("returns")
        plt.xlabel("iteration")
        plt.subplot(1, 2, 2)
        plt.plot(logs["last_reward"])
        plt.title("last reward")
        plt.xlabel("iteration")
        if is_ipython:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        plt.show()


plot()

  0%|                                                                                                                                                                                                 | 0/19531 [00:00<?, ?it/s]

tensor([[ 1.5293e-08,  2.1085e-08,  2.3731e-08, -1.7189e+00, -2.3063e-01,
          9.5181e-01],
        [-2.6863e-08, -1.7829e-08, -5.2087e-09, -3.7674e-02,  2.4164e-01,
         -2.1069e-01],
        [ 3.7101e-08,  1.5985e-08, -1.1906e-08, -6.9569e-01, -1.0387e+00,
         -2.7239e+00],
        ...,
        [-1.1692e-08,  1.1671e-08,  8.6188e-09, -2.2443e-01,  1.5274e-01,
         -5.1895e-01],
        [ 1.0299e-08,  6.0083e-09, -1.1187e-08, -4.0940e+00,  1.6028e+00,
          1.2373e+00],
        [ 5.3793e-09, -1.9685e-08, -7.5560e-09, -7.0278e-01,  2.2770e-01,
         -1.7690e+00]], device='cuda:0')
tensor([[ 1.5517e-08,  1.8745e-08,  2.7232e-08, -1.9169e+00, -1.0499e-01,
          9.9216e-01],
        [-1.5659e-08, -6.5553e-10, -1.3448e-08, -3.5228e-02,  3.5276e-01,
         -1.0042e-01],
        [ 1.9840e-08, -4.9988e-09, -4.3673e-09, -5.3543e-01, -6.0340e-01,
         -2.6082e+00],
        ...,
        [-7.7682e-09,  1.5036e-08,  1.3388e-08,  8.9535e-02, -9.1402e-02,
         

reward: -3.0201, last reward: -3.4927, gradient norm:  3.132:   0%|                                                                                                                         | 1/19531 [00:00<2:28:54,  2.19it/s]

tensor([[ 1.4913e-09, -1.8469e-10, -5.7627e-09, -1.4494e-02,  5.0418e-01,
          9.0235e-01],
        [-1.0935e-08, -1.0439e-08, -9.4621e-09,  9.9725e-01, -7.1647e-01,
         -7.8964e-01],
        [-2.2584e-08, -4.0277e-08,  9.4199e-09, -1.6289e+00,  7.1680e-01,
         -1.3365e+00],
        ...,
        [-9.2242e-09, -3.8547e-08,  9.8892e-09,  4.3221e-01, -4.9574e-01,
         -2.6841e-01],
        [-9.2309e-09,  3.7170e-08,  2.1526e-08,  4.1278e-01, -7.9855e-01,
          8.8736e-01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  6.7406e-01,  2.9792e-01,
         -7.4248e-01]], device='cuda:0')
tensor([[-1.5209e-08, -8.9685e-09, -2.2888e-09,  8.0877e-02,  7.8847e-01,
          1.1246e+00],
        [ 1.7359e-08, -4.3100e-09,  1.3527e-08,  1.0579e+00, -9.1941e-01,
         -5.9806e-01],
        [-2.1151e-08, -1.7326e-08, -1.2246e-08, -1.5859e+00,  7.2497e-01,
         -1.3314e+00],
        ...,
        [ 7.7965e-09,  6.3502e-10,  8.2734e-10,  2.8094e-01, -3.8681e-01,
         

reward: -3.1637, last reward: -3.7115, gradient norm:  2.291:   0%|                                                                                                                         | 2/19531 [00:00<2:32:51,  2.13it/s]

tensor([[-3.9225e-08,  1.8277e-08, -7.6577e-09,  6.2803e-02,  1.2001e-01,
          1.7576e-01],
        [ 1.2457e-09, -3.0749e-08, -8.7557e-09, -3.8586e-02,  3.6724e-01,
         -4.7658e-01],
        [-6.0970e-09, -2.7162e-08,  3.1314e-08,  1.3550e-01, -2.1603e+00,
          1.4725e-01],
        ...,
        [-1.1914e-08,  1.5652e-08,  9.4801e-09, -2.1059e-01, -4.5798e-01,
          2.0183e-01],
        [ 2.9451e-08, -4.0263e-08,  9.6351e-09,  1.8976e-01, -1.7179e-01,
         -2.9406e-01],
        [-6.0605e-08,  2.0874e-08, -2.0366e-08,  1.1653e+00, -1.0602e-01,
         -2.5224e-02]], device='cuda:0')
tensor([[-2.5153e-08,  1.5771e-08, -5.3704e-09,  9.7508e-02,  2.1049e-01,
          7.9027e-02],
        [-1.3949e-08, -6.2841e-08, -6.8711e-09, -1.1805e-02,  5.3286e-01,
         -3.0943e-01],
        [ 3.5534e-09,  1.3107e-09,  4.2567e-08, -8.9193e-02, -2.3706e+00,
          2.2052e-01],
        ...,
        [ 4.5792e-08,  2.1774e-08,  2.1239e-08, -1.9350e-01, -3.8147e-01,
         

reward: -3.1059, last reward: -3.6517, gradient norm:  2.6:   0%|                                                                                                                           | 3/19531 [00:01<2:29:56,  2.17it/s]

tensor([[ 5.3635e-09, -2.0262e-09,  7.5036e-09, -2.9988e-01, -4.2564e-01,
          3.6363e-01],
        [ 3.9462e-08, -5.2029e-08,  1.6224e-08,  6.3193e-01, -1.0810e+00,
         -1.6184e-01],
        [-1.2890e-08, -1.1194e-08,  1.7281e-09, -1.4141e+00, -9.3791e-02,
         -5.4504e-01],
        ...,
        [ 2.7181e-08, -3.8187e-08,  1.5772e-08,  9.5032e-01,  2.0907e-01,
         -1.9484e+00],
        [-3.4044e-10,  6.8513e-09, -1.0458e-08,  2.1299e-01,  2.2305e-02,
         -1.1916e-01],
        [-7.5185e-09, -2.5744e-09,  9.4726e-09, -4.8415e-03,  7.3450e-02,
         -6.7390e-01]], device='cuda:0')
tensor([[ 3.7727e-08,  2.6416e-09,  5.2831e-09, -1.5756e-01, -1.6563e-01,
          6.6248e-01],
        [ 2.5814e-08, -5.9823e-08,  2.2278e-09,  5.7298e-01, -1.4445e+00,
          9.1590e-02],
        [-9.6981e-09, -1.5731e-08,  9.5353e-10, -1.0890e+00, -2.4969e-01,
         -8.7588e-01],
        ...,
        [ 2.6578e-08, -3.8261e-08,  1.6194e-08,  9.7174e-01,  2.0367e-01,
         

reward: -3.0898, last reward: -3.5722, gradient norm:  2.585:   0%|                                                                                                                         | 4/19531 [00:01<2:33:27,  2.12it/s]

tensor([[-4.0690e-08,  5.5767e-08, -1.2478e-07, -3.8629e+00,  1.3975e+00,
          1.8848e+00],
        [ 3.0040e-08,  1.9407e-08, -4.4756e-08,  3.4446e-01,  6.5639e-01,
          2.7195e-01],
        [-5.1161e-08,  1.0164e-08,  5.7596e-09,  3.8405e-01, -5.1237e-01,
         -8.6593e-01],
        ...,
        [-1.0536e-08,  4.5677e-08, -1.3465e-08, -2.9015e-01,  7.5532e-01,
         -2.2540e-01],
        [-2.5540e-08, -7.5019e-08, -1.1041e-09,  5.5037e-02, -2.4855e-01,
         -2.9751e-02],
        [ 1.1440e-08, -1.7766e-08, -1.3069e-08,  4.6712e-01,  7.6710e-01,
         -5.9193e-01]], device='cuda:0')
tensor([[-4.3759e-08,  5.4085e-08, -1.4574e-07, -3.8108e+00,  1.0806e+00,
          1.7933e+00],
        [ 3.9086e-08,  1.6632e-09, -3.6253e-08,  4.0035e-01,  7.0427e-01,
          3.3295e-01],
        [-3.2862e-08,  2.3218e-08,  1.9855e-08,  4.3370e-01, -7.2870e-01,
         -6.4236e-01],
        ...,
        [ 1.3780e-08, -5.9009e-08, -3.6584e-09, -1.7329e-01,  6.0634e-01,
         

reward: -3.1352, last reward: -3.6875, gradient norm:  2.588:   0%|                                                                                                                         | 5/19531 [00:02<2:35:32,  2.09it/s]

tensor([[-5.2590e-08,  1.1709e-08, -5.0255e-10,  5.7224e-01, -7.8094e-01,
         -6.7483e-01],
        [ 7.7959e-08, -1.4488e-08, -8.5790e-09,  5.7904e-02, -5.8368e-03,
         -3.1229e-01],
        [ 6.0038e-08, -1.4558e-07, -1.1405e-08, -6.5825e-02,  2.2570e-01,
         -3.9142e+00],
        ...,
        [ 4.0223e-09, -1.0592e-08,  8.9528e-09, -4.9103e-01, -1.9266e-01,
         -6.4239e-01],
        [-1.7627e-09, -6.0688e-10,  3.0314e-08,  1.2720e+00, -1.1220e+00,
         -4.9463e-02],
        [-3.7091e-09, -2.9133e-09, -1.5605e-08,  2.1303e-01,  4.7566e-01,
         -4.2310e-01]], device='cuda:0')
tensor([[-3.0158e-08,  2.6516e-08,  1.6064e-08,  5.8728e-01, -1.0547e+00,
         -4.2330e-01],
        [ 5.8838e-08,  1.7100e-08, -5.4523e-09,  2.7639e-01, -5.9725e-02,
         -8.7070e-02],
        [ 3.6919e-08, -1.2094e-07, -9.0927e-09, -6.6503e-02,  2.2867e-01,
         -3.9124e+00],
        ...,
        [ 1.5040e-09, -9.5036e-09,  5.2927e-09, -9.0561e-01, -1.5958e-01,
         

reward: -3.0645, last reward: -3.4999, gradient norm:  2.648:   0%|                                                                                                                         | 6/19531 [00:02<2:31:44,  2.14it/s]

tensor([[ 4.0179e-07,  3.3637e-07,  1.2746e-07, -2.9999e-02, -4.9281e-02,
         -4.6338e+00],
        [ 2.6083e-10,  1.1135e-09,  6.1541e-09,  3.5154e-01, -1.0479e-01,
         -3.5654e-01],
        [-3.5994e-09, -1.4156e-08, -1.9983e-08,  9.1785e-01,  9.5812e-01,
         -8.9995e-01],
        ...,
        [-5.6733e-08,  2.9616e-08,  5.9482e-09, -2.7684e-01,  4.9049e-02,
         -5.9496e-01],
        [-2.2373e-08,  2.4671e-08, -1.8051e-08, -3.0991e-02,  1.3028e-01,
          6.8568e-01],
        [ 2.6771e-08,  6.6900e-09, -1.2897e-08,  5.6920e-01,  3.2392e-01,
          3.6419e-01]], device='cuda:0')
tensor([[ 3.2416e-07,  1.9483e-07,  6.1978e-08, -4.3184e-01, -5.4767e-01,
         -4.3245e+00],
        [-2.5967e-10,  7.9763e-09,  8.4163e-09,  1.2168e-01, -5.2303e-02,
         -1.2788e-01],
        [ 2.2135e-08, -2.9903e-08, -4.3500e-08,  9.5609e-01,  1.1861e+00,
         -9.9187e-01],
        ...,
        [ 2.5603e-08,  2.7379e-08, -3.1456e-09, -2.6622e-02,  8.2347e-03,
         

reward: -3.1711, last reward: -3.6946, gradient norm:  2.404:   0%|                                                                                                                         | 7/19531 [00:03<2:29:31,  2.18it/s]

tensor([[-8.1700e-08, -3.6568e-08, -2.8209e-08,  1.4334e+00,  9.4588e-01,
         -1.0085e-01],
        [ 1.8858e-08,  3.5505e-08, -7.0054e-09, -2.2751e-01,  2.8907e-01,
          5.2881e-01],
        [-8.1500e-09,  5.3133e-09,  2.0547e-08,  1.2515e+00, -9.5232e-01,
          4.2135e-01],
        ...,
        [-3.3600e-09,  6.8013e-08, -1.1038e-08,  2.4091e-02, -1.9397e-01,
          1.1765e+00],
        [-3.0219e-08,  8.8396e-09,  1.3767e-08,  1.2087e+00, -1.2720e+00,
         -5.8101e-01],
        [ 9.8492e-13, -1.9646e-10, -2.1013e-09,  1.8388e-01,  2.2232e-02,
          6.8901e-01]], device='cuda:0')
tensor([[-6.6519e-08, -7.9765e-08,  1.0331e-07, -3.8387e+00, -2.3225e+00,
          1.0308e-01],
        [ 3.1566e-08, -2.8632e-09, -1.9954e-09, -1.1973e-01,  1.7864e-01,
          6.8485e-01],
        [-3.4368e-09,  1.6880e-08,  1.8389e-08,  1.0123e+00, -9.4394e-01,
          6.1667e-01],
        ...,
        [ 9.6988e-10, -2.3622e-08, -6.5021e-09,  1.4017e-02,  2.1410e-01,
         

reward: -3.1078, last reward: -3.6050, gradient norm:  2.647:   0%|                                                                                                                         | 8/19531 [00:03<2:35:22,  2.09it/s]

tensor([[-3.4141e-09, -2.5847e-08,  1.0595e-07, -3.5060e+00,  4.5662e-01,
          7.8854e-01],
        [ 3.0608e-08,  3.7026e-08, -5.0552e-09,  3.0895e-01,  1.3230e-01,
          7.2693e-01],
        [ 4.5172e-09, -2.7235e-08,  3.9268e-08, -2.1951e+00, -3.4847e+00,
         -8.1678e-01],
        ...,
        [-2.3090e-09, -5.7119e-08, -3.1228e-08, -8.9059e-01,  6.3048e-01,
         -1.2064e+00],
        [ 1.4909e-08,  8.1265e-09,  5.0548e-09, -1.5079e-01, -4.2100e-01,
         -1.1182e+00],
        [-7.2016e-09, -3.8657e-09, -1.3976e-09, -9.8668e-02,  1.8187e-01,
         -9.8789e-01]], device='cuda:0')
tensor([[-3.4659e-09, -2.6591e-08,  4.9352e-08, -3.4526e+00,  6.3803e-01,
          6.0757e-01],
        [-4.1534e-08,  2.2682e-08,  1.2639e-09,  5.5571e-01,  1.6573e-01,
          9.3729e-01],
        [-2.2207e-08, -2.7407e-08,  5.8307e-08, -2.3948e+00, -3.3739e+00,
         -1.0763e+00],
        ...,
        [-3.4681e-08, -1.9161e-08, -1.8414e-08, -6.8227e-01,  5.7677e-01,
         

reward: -3.0903, last reward: -3.6121, gradient norm:  2.983:   0%|                                                                                                                         | 9/19531 [00:04<2:40:29,  2.03it/s]

tensor([[ 2.1457e-08, -6.7616e-08,  3.3919e-09,  6.2757e-01, -5.0281e-01,
          7.6145e-01],
        [ 1.8654e-08, -1.8502e-08,  7.8037e-09,  4.2892e-01, -6.8132e-02,
         -3.3656e-01],
        [-1.0300e-08,  2.5028e-08, -7.4287e-10, -8.5751e-02,  5.9757e-02,
          3.9120e-01],
        ...,
        [ 4.9081e-09,  5.2216e-08,  1.7381e-08, -5.1612e-01,  4.0185e-02,
         -3.7247e-01],
        [ 4.0674e-09, -2.0494e-07,  3.4598e-08,  1.5496e-01, -2.4831e-01,
         -4.1997e+00],
        [-1.3739e-08, -6.2583e-09, -3.9283e-09,  3.3408e-01,  1.6504e-01,
         -5.1838e-01]], device='cuda:0')
tensor([[-1.0646e-08, -7.3358e-08,  2.0991e-09,  7.8743e-01, -8.9912e-01,
          1.0236e+00],
        [ 1.7358e-08, -1.0623e-08,  1.0030e-08,  4.0893e-01, -6.9004e-02,
         -3.1803e-01],
        [-1.2494e-08,  1.1333e-08,  5.6151e-11,  3.9750e-02, -3.2535e-02,
          5.4602e-01],
        ...,
        [-5.1047e-08,  6.1916e-08,  1.0360e-08, -3.2081e-01,  5.5874e-02,
         

reward: -3.0883, last reward: -3.6209, gradient norm:  2.531:   0%|                                                                                                                        | 10/19531 [00:04<2:39:09,  2.04it/s]

tensor([[-3.4267e-08,  4.6477e-08, -6.2334e-09, -1.2978e-03, -5.5961e-02,
          9.9235e-02],
        [ 8.1649e-08, -1.7284e-07, -3.4554e-08,  4.5173e-01,  9.9404e-01,
         -3.8083e+00],
        [-8.5789e-09, -1.7314e-08,  1.8906e-10, -1.6697e-01, -1.7651e-01,
         -9.3088e-01],
        ...,
        [ 1.0639e-08,  2.3325e-08, -5.4118e-09, -4.4900e+00, -8.1767e-01,
         -7.0959e-01],
        [-1.0067e-07,  3.5099e-08, -6.3430e-09,  8.5161e-01, -2.9422e-01,
          5.3301e-01],
        [ 8.1079e-09, -1.9485e-08,  6.2049e-09, -6.6847e-01, -3.5790e-01,
         -6.7667e-01]], device='cuda:0')
tensor([[-1.8555e-08,  1.5256e-08, -7.4241e-09, -4.1253e-05,  5.4911e-03,
          3.7817e-02],
        [ 6.0673e-08, -8.4449e-08, -1.0582e-09,  1.0571e-01,  1.4399e-01,
         -3.5725e+00],
        [-7.4350e-11,  7.5992e-11,  3.5770e-09,  4.6420e-02,  3.6551e-02,
         -6.4280e-01],
        ...,
        [ 7.8236e-10,  3.2739e-08, -7.5241e-09, -3.9759e+00, -1.5416e+00,
         

reward: -3.1644, last reward: -3.6461, gradient norm:  2.35:   0%|                                                                                                                         | 11/19531 [00:05<2:35:28,  2.09it/s]

tensor([[ 3.2160e-09, -2.5001e-08,  2.5984e-08,  4.8531e-01, -8.8873e-01,
         -1.0274e-01],
        [-2.3070e-08, -2.9750e-08,  2.6525e-09, -2.8609e-01, -3.3558e-01,
         -5.0993e-01],
        [-4.6019e-08, -5.9974e-08,  1.5986e-08,  2.0356e-02, -1.3559e-02,
         -2.3402e+00],
        ...,
        [-1.2216e-08,  1.0936e-08, -1.6108e-10, -5.6585e-01,  9.0804e-02,
          6.9935e-01],
        [-2.1259e-08, -2.8208e-08, -4.4114e-09, -2.5945e-01,  1.9143e-01,
         -1.8497e+00],
        [ 2.4767e-09, -4.3283e-09, -8.2772e-09, -9.2028e-01,  5.4447e-01,
         -2.0487e-01]], device='cuda:0')
tensor([[ 1.0418e-08, -2.1254e-08,  3.9259e-08,  4.8365e-01, -9.9986e-01,
         -1.2479e-02],
        [-3.3439e-08, -5.4797e-08,  5.0643e-09, -4.5230e-02, -3.6258e-02,
         -1.4000e-01],
        [-5.8949e-08,  8.1197e-09,  1.0361e-08,  3.5998e-01, -3.6973e-01,
         -1.8824e+00],
        ...,
        [ 5.9081e-09,  1.8842e-08, -4.1909e-10, -4.7022e-01,  5.2597e-02,
         

reward: -3.0831, last reward: -3.5985, gradient norm:  2.515:   0%|                                                                                                                        | 12/19531 [00:05<2:33:55,  2.11it/s]

tensor([[-1.1973e-08,  2.4638e-08, -3.3798e-09,  4.4071e-01,  9.7277e-01,
          9.7391e-01],
        [-1.9185e-08,  2.5701e-08,  1.4748e-08, -4.6089e-01, -5.5696e-01,
          9.0660e-01],
        [-3.2853e-08, -1.4194e-07,  4.5960e-08, -6.0402e-01, -9.5221e-01,
         -4.2019e+00],
        ...,
        [-4.7656e-10,  1.4559e-09, -1.3602e-09, -1.5630e+00, -6.3252e-01,
          3.8882e-01],
        [ 8.2556e-09,  8.6462e-09, -2.7185e-08,  8.0907e-01,  1.2118e+00,
          9.5214e-01],
        [ 5.0805e-08, -4.1216e-09,  3.6779e-09, -7.1651e-02,  1.1731e-01,
         -4.3838e-01]], device='cuda:0')
tensor([[-1.6531e-08,  5.5284e-09, -2.0597e-08,  3.6393e-01,  1.3083e+00,
          6.4610e-01],
        [-2.7883e-08,  9.2448e-09,  8.6597e-09, -2.2526e-01, -3.6482e-01,
          6.6007e-01],
        [ 8.8463e-09,  5.9399e-08, -1.8459e-08,  3.2719e-01,  7.9167e-01,
          1.4676e+00],
        ...,
        [ 2.5686e-09, -4.5038e-09,  6.4757e-09, -1.1534e+00, -7.0118e-01,
         

reward: -3.0770, last reward: -3.6246, gradient norm:  2.169:   0%|                                                                                                                        | 13/19531 [00:06<2:33:02,  2.13it/s]

tensor([[-7.4649e-09, -4.2368e-09, -7.0399e-10, -1.5708e+00,  5.6652e-01,
         -6.7940e-01],
        [-3.6793e-08, -3.3940e-09,  1.3365e-08, -4.1999e-01, -2.7639e-01,
          3.3836e-01],
        [ 1.4050e-11,  4.4624e-11, -1.3955e-09, -5.4457e-02,  3.4908e-02,
         -1.5610e+00],
        ...,
        [-5.7855e-09, -2.2335e-08,  1.5254e-08,  4.5879e-02,  1.9479e-02,
         -9.7006e-01],
        [ 1.5169e-08,  4.5653e-08, -3.3070e-08, -3.6766e+00,  2.2317e+00,
          8.9659e-01],
        [ 2.6700e-08,  6.6021e-09,  1.7485e-09, -3.1685e-01, -2.4930e-01,
          6.5348e-01]], device='cuda:0')
tensor([[-6.6392e-09, -6.9143e-09,  2.1025e-09, -1.3299e+00,  2.2866e-01,
         -1.0598e+00],
        [-2.6879e-08,  1.4749e-08,  5.8354e-09, -2.7508e-01, -1.4515e-01,
          5.3435e-01],
        [ 6.2895e-09,  5.3644e-09, -1.0860e-09, -1.9360e-01,  1.4814e-01,
         -1.3890e+00],
        ...,
        [-2.3306e-09, -2.8039e-08,  1.9748e-08,  8.6545e-02,  3.8961e-02,
         

reward: -3.1352, last reward: -3.5757, gradient norm:  3.021:   0%|                                                                                                                        | 14/19531 [00:06<2:32:32,  2.13it/s]

tensor([[-2.0787e-09, -5.9923e-09,  1.9603e-08, -5.8260e-01, -5.4332e-01,
         -1.2851e-01],
        [ 9.5261e-09,  3.1483e-08,  8.9328e-09,  4.1151e-02, -2.8944e-01,
         -2.6114e-01],
        [-1.9138e-08,  1.0381e-08, -8.2728e-09, -2.8765e-02,  3.6259e-01,
         -5.1957e-01],
        ...,
        [-7.7948e-09,  4.4892e-08,  7.6824e-09,  8.6660e-02, -5.9343e-01,
         -2.8652e-02],
        [ 6.7835e-09,  1.0590e-08,  1.1246e-08, -1.0606e+00, -1.4312e+00,
          1.6596e-01],
        [-1.3314e-07, -1.2748e-07, -1.1278e-07,  6.7174e-01,  1.9989e+00,
         -4.0513e+00]], device='cuda:0')
tensor([[ 1.0103e-08,  1.2388e-08,  9.4993e-09, -3.5229e-01, -2.2080e-01,
          2.4202e-01],
        [ 1.2700e-08,  1.5177e-08,  8.9272e-09,  2.1565e-03, -7.0329e-02,
         -4.1836e-02],
        [-3.1683e-08,  3.9607e-08,  4.3207e-09, -3.7527e-03,  2.3475e-01,
         -3.9712e-01],
        ...,
        [ 1.2558e-08,  1.4439e-08,  1.8915e-08,  1.2776e-02, -8.6179e-01,
         

reward: -3.1057, last reward: -3.6452, gradient norm:  2.638:   0%|                                                                                                                        | 15/19531 [00:07<2:38:11,  2.06it/s]

tensor([[ 6.4593e-09, -4.6162e-08,  1.0206e-09, -2.2232e-01,  3.1702e-01,
         -6.3933e-01],
        [ 8.1047e-10, -4.0728e-09, -3.0259e-09,  2.3621e-01,  6.9731e-02,
         -3.5922e-01],
        [-1.0624e-09, -5.8616e-09,  7.8876e-10, -1.3933e+00, -3.7896e-02,
          4.7713e-01],
        ...,
        [-1.6930e-09,  6.9263e-09, -3.1292e-09,  3.5109e-01, -1.4957e-01,
         -4.1402e-01],
        [-2.8832e-08, -1.1504e-07, -2.8333e-08, -9.0152e-01,  7.5798e-01,
         -1.3586e+00],
        [ 7.0322e-10,  1.5788e-08,  2.6468e-09, -8.8983e-01, -3.4843e-01,
          1.3549e-01]], device='cuda:0')
tensor([[-2.2230e-08, -3.6072e-09,  2.5949e-09, -2.2918e-02,  4.7726e-02,
         -3.1937e-01],
        [ 1.0748e-09, -4.1815e-09, -2.8338e-09,  3.2573e-01,  1.1360e-01,
         -4.5357e-01],
        [-9.2646e-10, -3.8520e-09,  3.3892e-09, -1.7044e+00,  2.0794e-01,
          6.5102e-01],
        ...,
        [ 6.3983e-09,  8.0969e-09, -3.1205e-09,  1.8033e-01, -9.6237e-02,
         

reward: -3.0674, last reward: -3.6228, gradient norm:  2.244:   0%|                                                                                                                        | 16/19531 [00:07<2:33:58,  2.11it/s]

tensor([[-2.0624e-08,  3.4062e-08,  2.8940e-08, -1.5895e-01, -7.6646e-01,
          3.8334e-01],
        [ 2.2077e-08,  9.3520e-09,  1.2433e-08, -5.2347e-01, -7.8886e-01,
          1.7845e-01],
        [-3.3086e-08,  4.6090e-09,  6.3158e-09, -5.7016e-01, -4.7743e-01,
          1.2928e+00],
        ...,
        [ 2.0852e-08, -7.7214e-09,  1.3355e-08, -1.0182e-01, -1.1239e-02,
         -9.1492e-01],
        [-1.1839e-08,  1.2085e-08,  6.3395e-09,  4.8389e-01, -3.3307e-01,
          4.3989e-01],
        [ 1.1597e-08, -1.0802e-08, -1.0871e-08, -7.9276e-01,  1.2608e+00,
         -1.3247e+00]], device='cuda:0')
tensor([[-1.9815e-08,  3.3306e-08,  2.4785e-08, -1.2497e-01, -7.0978e-01,
          3.2871e-01],
        [ 2.8226e-08,  3.0907e-08, -5.8610e-09, -3.1808e-01, -6.3141e-01,
         -4.3451e-02],
        [-8.3697e-08, -3.4227e-08, -1.7297e-09, -4.4263e-01, -2.9836e-01,
          1.5253e+00],
        ...,
        [ 3.9031e-08, -9.1811e-10,  1.2803e-08,  1.8360e-01,  4.6606e-02,
         

reward: -3.0756, last reward: -3.5762, gradient norm:  2.331:   0%|                                                                                                                        | 17/19531 [00:08<2:32:18,  2.14it/s]

tensor([[ 1.1767e-08, -2.2219e-08,  8.9107e-09, -9.9729e-01,  3.6149e-01,
          7.6025e-02],
        [-1.1811e-08,  3.5546e-08,  6.7526e-09, -3.8164e-03, -4.4646e-01,
          6.0031e-01],
        [ 2.1634e-08, -4.0679e-08, -3.0440e-09, -2.3275e-01,  2.3864e-01,
         -6.4125e-01],
        ...,
        [ 2.8250e-08,  3.0565e-08,  2.9104e-08,  3.2957e-01, -6.2067e-01,
          1.0778e+00],
        [ 1.4502e-08, -9.6842e-09,  1.6932e-08, -4.4859e-01,  1.1215e+00,
         -1.0516e+00],
        [ 1.1797e-08, -4.9838e-09, -1.0342e-08, -7.6244e-01,  9.7970e-01,
          4.7061e-01]], device='cuda:0')
tensor([[ 4.3496e-08, -1.8604e-08, -2.4104e-09, -8.1924e-01,  3.7442e-01,
          2.2914e-01],
        [ 8.1433e-09,  1.3679e-08,  4.7850e-09,  1.6841e-02, -3.2018e-01,
          4.8396e-01],
        [ 1.5438e-08, -2.6961e-08, -2.2202e-09, -2.8569e-01,  3.2336e-01,
         -5.4016e-01],
        ...,
        [ 1.1187e-08, -1.6583e-08,  2.2392e-08,  1.1299e-01, -3.4209e-01,
         

reward: -3.1187, last reward: -3.6035, gradient norm:  2.455:   0%|                                                                                                                        | 18/19531 [00:08<2:30:22,  2.16it/s]

tensor([[ 3.5376e-08, -6.3319e-09,  2.6583e-08, -9.3666e-01, -6.9862e-01,
          4.5187e-01],
        [-6.3273e-09, -1.5813e-08, -1.3573e-08,  1.0919e+00, -3.2829e-01,
         -1.4751e+00],
        [-2.7813e-08,  4.1778e-08,  5.4411e-09,  5.6508e-01, -9.7153e-01,
         -3.7055e-01],
        ...,
        [ 5.1341e-08, -4.8544e-08, -4.9038e-08, -1.7006e+00,  1.3152e+00,
         -7.1345e-01],
        [ 7.3382e-09,  1.4260e-08,  1.9476e-08, -6.9482e-02, -3.4726e-01,
          1.0065e+00],
        [-9.3661e-09, -1.9490e-08,  6.1423e-09,  1.1187e+00,  9.1303e-03,
         -6.4327e-01]], device='cuda:0')
tensor([[ 3.3281e-08, -1.3044e-08,  1.9560e-08, -1.4280e+00, -6.8734e-01,
          7.0622e-01],
        [-4.2186e-09, -1.9844e-08, -1.3019e-08,  1.4180e+00, -2.4967e-01,
         -1.5853e+00],
        [-1.2242e-08,  2.9820e-08,  4.2748e-09,  5.6069e-01, -1.0814e+00,
         -2.7769e-01],
        ...,
        [ 4.7619e-08, -5.1497e-08, -4.5162e-08, -1.3712e+00,  1.3220e+00,
         

reward: -3.1539, last reward: -3.6480, gradient norm:  2.592:   0%|                                                                                                                        | 19/19531 [00:08<2:31:52,  2.14it/s]

tensor([[-1.7570e-08,  1.0674e-08,  8.1142e-09, -3.0575e+00,  2.0596e+00,
          6.5502e-02],
        [ 4.0258e-09, -2.7274e-08, -1.3512e-08, -6.4446e-01,  2.1988e-01,
         -2.1874e-01],
        [ 1.9021e-08, -3.7010e-08,  6.9916e-09, -3.9345e-01, -3.3803e-01,
         -8.7750e-01],
        ...,
        [-1.3751e-08,  2.7975e-08, -5.7713e-09,  4.1351e-02, -1.6016e-01,
         -3.3395e-01],
        [ 1.9539e-08,  2.9704e-08, -2.7429e-08, -3.8903e+00, -6.6941e-01,
         -1.5685e+00],
        [-1.1091e-08, -1.8844e-08,  5.0866e-08,  1.2919e-01, -1.7555e-01,
          4.8156e-02]], device='cuda:0')
tensor([[ 7.4761e-09, -1.3867e-08, -1.9259e-08, -3.0958e+00,  2.0385e+00,
          7.9857e-02],
        [ 2.7769e-09, -1.7603e-08, -1.7417e-08, -7.6424e-01,  3.2730e-01,
         -7.1492e-02],
        [-5.5777e-09, -3.3456e-08,  2.5614e-10, -1.6891e-01, -1.0518e-01,
         -5.8930e-01],
        ...,
        [ 1.6342e-08,  3.3991e-10,  1.4543e-08,  4.2131e-02, -5.0011e-01,
         

reward: -3.1512, last reward: -3.6437, gradient norm:  2.275:   0%|                                                                                                                        | 20/19531 [00:09<2:33:31,  2.12it/s]

tensor([[-2.6467e-08, -3.4143e-08,  7.8867e-09,  5.6826e-02, -4.7859e-01,
         -2.9207e-01],
        [-2.0316e-08, -1.1325e-08,  7.0165e-09, -1.0877e+00, -1.3977e+00,
          4.8354e-01],
        [ 9.9297e-09,  2.2574e-08,  5.8762e-09,  1.7443e-01, -2.9481e-01,
          1.1509e-01],
        ...,
        [-3.8149e-09, -1.8249e-08,  1.0773e-09, -9.8575e-01,  5.5520e-01,
          2.2241e-01],
        [-2.0539e-08,  7.7069e-09, -1.5713e-10,  4.6559e-01, -3.8879e-02,
         -8.8001e-01],
        [ 2.0223e-09, -1.2142e-08, -1.1711e-09,  1.4976e+00, -7.3162e-02,
         -4.6882e-01]], device='cuda:0')
tensor([[-1.4165e-08, -4.5788e-08,  8.2213e-09,  1.1457e-02, -3.1627e-01,
         -1.3276e-01],
        [-9.5958e-09, -3.3382e-09,  1.9794e-08, -1.2134e+00, -1.4218e+00,
          5.2570e-01],
        [ 1.6987e-08,  1.3404e-08,  2.9344e-09,  1.0857e-01, -2.0837e-01,
          2.2309e-01],
        ...,
        [ 3.4510e-08, -8.1668e-10, -1.0008e-08, -5.9618e-01,  4.9707e-01,
         

reward: -3.1489, last reward: -3.6351, gradient norm:  2.684:   0%|▏                                                                                                                       | 21/19531 [00:09<2:32:33,  2.13it/s]

tensor([[ 1.9091e-08, -3.3013e-09,  3.5099e-09, -5.5977e-01,  4.6977e-01,
          1.4491e+00],
        [ 3.4552e-09, -3.0786e-08,  6.9293e-09, -1.3166e-01,  4.2021e-01,
          1.6043e-01],
        [-1.2381e-08,  1.2684e-08, -1.0375e-08,  2.9676e-01,  3.1514e-02,
          4.0413e-01],
        ...,
        [ 1.7509e-08, -5.8190e-08,  9.4789e-09,  1.3619e-01, -5.2549e-01,
         -6.7006e-01],
        [ 3.4763e-08, -1.9926e-08,  7.5230e-08,  1.2494e+00, -4.1524e+00,
         -3.5648e-01],
        [-3.9094e-09,  8.5717e-09,  5.9684e-09, -4.2118e+00, -2.3895e-01,
         -1.4232e+00]], device='cuda:0')
tensor([[-2.8508e-08, -8.8923e-08,  4.5677e-08,  1.5590e+00, -1.8393e+00,
         -3.5048e+00],
        [-1.1689e-09, -6.0175e-08,  4.7063e-09, -1.2690e-01,  3.7674e-01,
          1.1891e-01],
        [-1.0144e-08,  1.6227e-08, -9.5683e-09, -2.5492e-02,  1.4002e-03,
          7.2845e-01],
        ...,
        [-2.9608e-09, -1.5860e-08,  5.9492e-09,  8.8572e-02, -8.2644e-01,
         

reward: -3.1103, last reward: -3.5928, gradient norm:  2.582:   0%|▏                                                                                                                       | 22/19531 [00:10<2:34:57,  2.10it/s]

tensor([[-4.1211e-09,  3.6358e-09,  7.2922e-08, -7.1352e-01, -4.4299e+00,
          1.0791e+00],
        [ 1.5666e-08, -2.7014e-09,  1.8031e-09, -1.3156e+00, -1.8436e-01,
         -2.9552e-01],
        [ 1.9981e-08, -1.7394e-08,  8.1098e-09, -1.3272e-01,  2.5483e-01,
          1.8755e-01],
        ...,
        [-2.4711e-08, -2.0040e-08,  6.3953e-09,  3.5276e-02, -1.7214e-01,
         -1.3553e+00],
        [ 8.4286e-09, -4.7494e-08, -2.3143e-08,  3.0746e-01,  7.7473e-01,
         -9.8397e-01],
        [ 1.3636e-08, -4.2776e-08,  1.1116e-09,  2.1161e+00,  1.2238e+00,
         -3.0345e+00]], device='cuda:0')
tensor([[-3.8458e-08,  1.4085e-08,  9.7144e-08, -8.5531e-01, -4.3987e+00,
          8.8279e-01],
        [ 1.7489e-08, -8.4230e-09, -3.2314e-09, -1.1553e+00, -2.5205e-01,
         -4.4025e-01],
        [ 1.5366e-08, -1.7099e-09, -1.6602e-09, -1.5512e-01,  3.2667e-01,
          2.5970e-01],
        ...,
        [-2.2075e-08,  4.0069e-09, -2.4293e-09,  2.0173e-03, -1.5641e-02,
         

reward: -3.0646, last reward: -3.5649, gradient norm:  2.373:   0%|▏                                                                                                                       | 23/19531 [00:10<2:33:22,  2.12it/s]

tensor([[ 1.7085e-09,  8.1571e-08,  1.9436e-07, -3.4681e-01, -3.8583e+00,
          1.9824e+00],
        [ 5.6078e-08, -6.9936e-08, -2.6130e-08, -2.3858e+00, -1.7335e+00,
         -2.8167e+00],
        [ 7.6342e-08, -1.1304e-07,  5.9615e-08,  2.2675e-01, -2.0782e-01,
         -9.6756e-01],
        ...,
        [-5.8918e-09, -3.9765e-08,  1.4674e-08, -7.9637e-01,  5.8239e-01,
         -7.8582e-02],
        [ 1.4660e-08, -3.3725e-09,  7.3293e-09, -2.7914e-01, -9.6258e-01,
          3.0994e-01],
        [-2.7823e-08, -1.0712e-08,  1.0946e-08, -9.2669e-01, -6.3495e-01,
         -7.8911e-01]], device='cuda:0')
tensor([[-4.6299e-09,  2.9706e-08,  9.2759e-08, -3.9242e-01, -3.8740e+00,
          2.0031e+00],
        [ 6.0776e-08, -7.3608e-08, -2.4870e-08, -2.4662e+00, -2.0370e+00,
         -2.5688e+00],
        [ 8.7900e-08, -1.0367e-07,  6.7990e-08,  3.7691e-01, -2.9431e-01,
         -1.1124e+00],
        ...,
        [-1.8102e-08, -2.6047e-08, -1.9783e-09, -8.4286e-01,  6.8219e-01,
         

reward: -3.2010, last reward: -3.7523, gradient norm:  2.636:   0%|▏                                                                                                                       | 24/19531 [00:11<2:38:33,  2.05it/s]

tensor([[-4.7576e-08,  6.8505e-09,  1.5084e-08,  6.4827e-01, -9.6533e-03,
          7.3968e-01],
        [ 8.5197e-09, -1.8165e-08,  1.2939e-09, -4.0395e-01, -6.4309e-02,
         -8.8423e-01],
        [-1.7210e-09,  3.1338e-08,  6.4434e-09, -9.1534e-01, -5.5119e-01,
         -5.0292e-02],
        ...,
        [ 2.4645e-08,  4.0161e-08,  3.3548e-09, -2.9513e-02, -3.5398e-01,
          1.0682e+00],
        [-6.5814e-08,  1.8760e-08,  1.1895e-08, -8.0228e-01,  5.2751e-01,
         -1.4235e-01],
        [-5.6097e-10, -1.4765e-08, -2.7855e-09, -1.1311e-01,  2.7096e-01,
          1.1556e-01]], device='cuda:0')
tensor([[-6.4836e-08,  1.7149e-08,  2.3429e-08,  1.1066e+00, -2.6637e-01,
          1.0841e+00],
        [-2.3835e-09, -2.5025e-08, -2.9504e-09, -1.2660e-01, -3.8228e-02,
         -1.1654e+00],
        [ 6.6497e-08,  7.8709e-08,  5.5373e-09, -7.3737e-01, -3.1956e-01,
          1.9902e-01],
        ...,
        [ 2.1994e-08,  3.9370e-09, -2.2228e-09, -3.1895e-02, -2.0281e-01,
         

reward: -3.1930, last reward: -3.6889, gradient norm:  3.275:   0%|▏                                                                                                                       | 25/19531 [00:11<2:36:48,  2.07it/s]

tensor([[-1.2971e-08,  1.1716e-08,  3.0641e-09, -5.9701e-01,  1.8117e-02,
          3.7857e-01],
        [-2.2460e-08,  4.5851e-08, -7.6652e-08, -2.6742e+00,  1.3468e+00,
          1.6145e+00],
        [-6.9188e-09, -4.7998e-08,  3.0745e-08,  1.2048e-01, -2.6020e+00,
         -7.5073e-01],
        ...,
        [ 7.0135e-09,  2.2796e-09, -2.1755e-09,  5.5628e-01, -5.4654e-01,
         -6.5555e-01],
        [-1.7039e-08, -1.2514e-08, -7.8907e-09,  5.3996e-01,  4.9349e-01,
          5.7063e-01],
        [-4.6804e-08,  3.8441e-08,  3.9977e-09, -1.4688e+00, -3.2427e+00,
         -2.3848e+00]], device='cuda:0')
tensor([[-1.2259e-08,  1.3727e-08,  4.2174e-09, -6.6207e-01,  4.1405e-02,
          4.3778e-01],
        [-1.6797e-09,  4.4122e-08, -5.1449e-08, -2.7093e+00,  1.2835e+00,
          1.5360e+00],
        [-8.0939e-09, -5.1443e-08,  4.0313e-08,  1.8853e-01, -2.5603e+00,
         -8.1977e-01],
        ...,
        [ 2.0495e-08,  1.8757e-08,  5.8240e-10,  2.2253e-01, -3.2099e-01,
         

reward: -3.1929, last reward: -3.7262, gradient norm:  2.25:   0%|▏                                                                                                                        | 26/19531 [00:12<2:34:13,  2.11it/s]

tensor([[ 4.5804e-08,  1.9175e-08,  1.5605e-09, -3.3279e-01, -1.1882e-01,
         -1.3047e+00],
        [-8.7307e-10,  8.2281e-09, -8.7212e-08,  7.0887e-01, -3.4810e+00,
          1.6967e+00],
        [-2.2114e-08, -1.6003e-08, -1.8522e-08, -3.1020e-01,  8.4854e-01,
         -1.2061e+00],
        ...,
        [-1.7191e-08,  5.9462e-09,  4.9626e-08, -5.0011e-01, -1.8475e+00,
         -1.2295e-01],
        [-1.7709e-08,  1.9176e-08,  6.9351e-09,  4.5298e-01, -1.0746e-01,
         -9.8102e-01],
        [-5.9747e-08, -3.1483e-08,  5.8113e-08, -1.1500e+00, -1.0954e+00,
         -1.5661e+00]], device='cuda:0')
tensor([[-1.6838e-09,  3.4656e-09, -1.6376e-10,  8.6110e-02,  1.1843e-02,
         -9.0418e-01],
        [ 1.1200e-08,  5.3374e-09, -6.3288e-08,  7.3938e-01, -3.4639e+00,
          1.6726e+00],
        [-1.6566e-08, -1.3957e-08, -1.2277e-08, -3.1266e-01,  6.9474e-01,
         -1.3696e+00],
        ...,
        [ 3.3628e-09,  2.0231e-08,  4.3504e-08, -6.0442e-01, -1.6153e+00,
         

reward: -3.0867, last reward: -3.5937, gradient norm:  2.694:   0%|▏                                                                                                                       | 27/19531 [00:12<2:31:56,  2.14it/s]

tensor([[-5.4843e-09,  1.2044e-08, -1.6118e-08,  2.3180e-01,  6.6598e-01,
         -3.6619e-01],
        [-1.9531e-08,  1.1709e-08,  1.4620e-08,  3.6871e-01, -4.8142e-01,
         -5.8877e-01],
        [ 4.4112e-09, -4.2779e-08,  5.9543e-09, -3.5603e-01,  7.4905e-02,
         -1.4929e+00],
        ...,
        [ 1.3290e-08, -1.0734e-07, -3.1024e-08, -1.0285e+00,  1.4170e-01,
         -1.3418e+00],
        [ 7.9744e-09,  3.3405e-08,  4.3416e-08, -3.2009e-01, -1.2816e+00,
          9.5228e-01],
        [ 6.6975e-09,  2.0734e-08, -6.0476e-09, -2.3937e-01, -3.8193e-01,
         -8.1894e-01]], device='cuda:0')
tensor([[ 4.1651e-09, -1.7905e-08, -1.8105e-08,  2.2343e-01,  8.1376e-01,
         -4.8541e-01],
        [-1.8373e-08,  2.8278e-08,  1.6699e-08,  2.5500e-01, -3.8579e-01,
         -4.5982e-01],
        [ 4.3650e-09, -4.2125e-08,  2.4960e-09, -3.5673e-01,  7.5182e-02,
         -1.4921e+00],
        ...,
        [-8.7761e-08, -5.1506e-08, -8.5869e-09, -5.6342e-01,  1.9295e-01,
         

reward: -3.0847, last reward: -3.6494, gradient norm:  2.303:   0%|▏                                                                                                                       | 28/19531 [00:13<2:32:30,  2.13it/s]

tensor([[ 4.3299e-08, -4.9183e-08, -6.0853e-09,  9.6757e-02,  7.4978e-03,
         -1.4482e+00],
        [ 1.2105e-08,  1.7611e-08, -1.9175e-10, -1.2630e-01,  1.7492e-01,
         -3.7073e-01],
        [-2.3410e-08,  6.8551e-09, -1.6049e-09, -1.2811e+00, -1.1458e+00,
         -5.0657e-01],
        ...,
        [-1.2142e-08, -1.4893e-08, -4.8406e-09, -1.6505e+00, -3.5755e-01,
         -3.8316e-01],
        [-8.3734e-09,  8.3734e-09,  3.0624e-08, -1.1136e+00, -1.0853e+00,
          4.6731e-01],
        [-3.2238e-08,  5.9037e-09, -1.1126e-08,  1.0684e-01, -3.7916e-01,
          4.3428e-01]], device='cuda:0')
tensor([[-2.3265e-08, -4.9626e-08, -8.3652e-09,  2.4569e-01,  1.8268e-03,
         -1.3034e+00],
        [ 1.9622e-08,  1.9039e-08, -9.7064e-10, -1.2851e-01,  1.7883e-01,
         -3.6622e-01],
        [-5.3118e-08, -6.3138e-09, -1.8659e-08, -1.1614e+00, -7.4403e-01,
         -3.0075e-01],
        ...,
        [ 1.8358e-08, -2.0018e-08, -1.2462e-10, -1.5036e+00, -1.9862e-01,
         

reward: -3.1371, last reward: -3.7002, gradient norm:  2.592:   0%|▏                                                                                                                       | 29/19531 [00:13<2:40:07,  2.03it/s]

tensor([[ 2.6289e-08,  1.7300e-08,  1.1479e-09,  7.0952e-02,  2.8083e-01,
         -6.8066e-01],
        [ 3.1600e-08,  1.0850e-08,  1.3586e-09, -1.7972e-01, -6.6330e-02,
          2.5856e-01],
        [-6.4259e-08, -1.6894e-07,  9.0478e-08, -7.5343e-01, -1.7303e+00,
         -3.5228e+00],
        ...,
        [-1.5949e-08,  3.1715e-08, -1.0025e-08, -1.0694e-01,  1.4123e-01,
          1.2327e+00],
        [ 1.3146e-08,  1.0851e-08, -3.0005e-09,  6.6758e-01, -3.6060e-01,
         -4.3103e-01],
        [ 6.0053e-09, -1.0888e-08, -5.2013e-09,  8.9935e-01,  7.8848e-01,
         -5.5304e-01]], device='cuda:0')
tensor([[ 3.5994e-08,  4.6909e-08,  7.3198e-11,  1.6819e-01,  4.6217e-01,
         -4.7260e-01],
        [ 3.0295e-08,  2.1676e-08, -2.3655e-09,  2.2303e-02,  1.1093e-02,
          4.3407e-02],
        [-4.3171e-08, -1.1863e-07,  5.9781e-08, -8.5272e-01, -1.7661e+00,
         -3.3840e+00],
        ...,
        [ 3.1555e-08,  5.8309e-08, -2.0034e-08, -2.0115e-01,  3.2362e-01,
         

reward: -3.1143, last reward: -3.5912, gradient norm:  2.482:   0%|▏                                                                                                                       | 30/19531 [00:14<2:42:24,  2.00it/s]

tensor([[-3.9336e-08, -5.8130e-08, -4.0606e-08,  2.4560e+00, -3.3835e+00,
         -8.3174e-01],
        [ 2.0794e-07, -3.1414e-07, -4.5603e-08, -2.7571e-02,  2.3845e-02,
         -4.0818e+00],
        [ 5.0254e-09,  2.1858e-08,  7.3552e-09, -3.4531e-02, -3.2696e-02,
          4.4004e-01],
        ...,
        [-1.9782e-08,  3.4082e-08,  2.6574e-09,  3.3477e-04,  5.9105e-01,
         -1.9341e+00],
        [ 7.6422e-09,  2.3623e-08, -1.5427e-09,  5.4972e-01, -4.3289e-01,
         -5.5677e-02],
        [-3.8759e-09, -2.2237e-08, -7.4530e-09, -1.0120e+00, -5.8043e-01,
          3.3842e-01]], device='cuda:0')
tensor([[-3.8932e-08, -5.6923e-08,  1.3554e-08,  2.4490e+00, -3.3902e+00,
         -8.4225e-01],
        [ 7.4066e-08, -2.1496e-07, -1.9887e-08,  2.2547e-01, -2.3107e-01,
         -3.8971e+00],
        [-5.1427e-09,  9.0154e-09,  4.5762e-09, -1.0397e-01, -9.0037e-02,
          5.2806e-01],
        ...,
        [-1.7395e-08,  3.4530e-08,  1.9400e-09,  3.5537e-02,  6.9638e-01,
         

reward: -3.1345, last reward: -3.7301, gradient norm:  2.546:   0%|▏                                                                                                                       | 31/19531 [00:14<2:37:41,  2.06it/s]

tensor([[-3.4121e-08, -6.1314e-09, -9.5042e-09, -2.2363e-01, -4.8052e-01,
          4.7608e-01],
        [ 3.9739e-08, -9.5849e-08,  4.0758e-08,  1.7454e+00, -2.0160e-01,
         -3.1830e+00],
        [-6.9688e-10, -1.4868e-08,  2.8890e-09,  1.6431e-02,  4.7290e-02,
          8.4650e-02],
        ...,
        [ 1.1336e-08,  1.1798e-08,  5.7627e-09, -3.1100e-02, -5.7959e-02,
          7.9006e-01],
        [-6.9998e-08, -3.0425e-08, -4.5059e-08, -2.7712e+00, -3.4299e+00,
         -3.8404e-01],
        [ 1.2402e-08,  8.9219e-09, -5.7451e-09, -5.9678e-01, -4.3107e-01,
         -5.1905e-01]], device='cuda:0')
tensor([[ 1.0300e-08,  1.2535e-08, -1.3953e-09, -1.6708e-01, -2.7983e-01,
          6.8680e-01],
        [-5.5885e-08, -1.2116e-07, -1.4788e-08,  1.5816e+00, -2.4120e-01,
         -3.2028e+00],
        [ 5.2078e-09, -1.7856e-08,  9.4482e-09, -2.0828e-01, -3.3603e-01,
          5.1880e-01],
        ...,
        [ 1.3302e-08,  1.1903e-08, -1.4007e-09,  6.5046e-02,  1.6665e-01,
         

reward: -3.1284, last reward: -3.6228, gradient norm:  2.814:   0%|▏                                                                                                                       | 32/19531 [00:15<2:42:24,  2.00it/s]

tensor([[ 9.1220e-09, -3.3085e-08,  2.1891e-09, -5.2225e-01,  1.4834e-01,
         -1.0684e+00],
        [-4.0779e-08, -2.1004e-08, -1.1907e-08,  8.8605e-01,  3.2767e-01,
         -8.3005e-01],
        [ 4.3917e-08,  5.7208e-09,  7.9835e-09,  3.8010e-01,  2.6914e-01,
         -1.6692e+00],
        ...,
        [ 3.5268e-09,  4.4060e-08, -1.5695e-08,  5.7806e-01, -2.9753e-02,
          1.2474e+00],
        [-2.4615e-08,  4.9076e-09, -2.1456e-09, -2.0798e-01, -4.4315e-01,
         -2.7275e-01],
        [-4.1383e-08, -2.8982e-08, -1.6007e-08,  2.4662e-01,  1.0200e+00,
         -1.5116e+00]], device='cuda:0')
tensor([[ 1.3664e-08, -2.1038e-08,  4.5324e-09, -3.9392e-01,  8.2836e-02,
         -1.2173e+00],
        [-1.4040e-08, -6.1586e-09, -9.2398e-09,  1.3391e+00,  1.9007e-01,
         -3.9566e-01],
        [ 3.8901e-08,  1.2755e-09,  2.4634e-09,  7.6225e-01,  3.5109e-01,
         -1.2545e+00],
        ...,
        [-1.0958e-07, -4.9426e-08,  5.0274e-08, -2.3893e+00,  5.4264e-01,
         

reward: -3.1929, last reward: -3.6996, gradient norm:  2.698:   0%|▏                                                                                                                       | 33/19531 [00:15<2:48:26,  1.93it/s]

tensor([[ 6.7564e-09, -5.1135e-08,  3.1365e-08,  3.4329e-01, -1.1593e+00,
         -1.0089e+00],
        [ 1.5350e-08, -3.1321e-08, -5.0975e-09,  5.0980e-02,  3.4089e-02,
         -9.0838e-01],
        [ 1.3434e-08, -2.6127e-08,  3.5760e-08, -1.2354e+00, -9.7961e-01,
         -3.9233e-02],
        ...,
        [ 1.0265e-08, -1.3241e-08,  1.4445e-08,  2.2812e-02, -5.1613e-01,
         -8.5345e-01],
        [-3.3380e-08,  5.2422e-08,  1.0119e-08, -3.2332e+00, -1.5777e+00,
          1.9664e+00],
        [-3.1121e-08,  4.0926e-08,  3.6740e-09, -4.7496e-01,  1.1292e+00,
         -1.6231e+00]], device='cuda:0')
tensor([[ 1.0683e-08, -4.7923e-08,  2.7666e-08,  2.3541e-01, -1.0365e+00,
         -9.2133e-01],
        [ 3.7999e-08, -1.4461e-08, -9.7389e-09,  2.4791e-01,  1.2966e-01,
         -6.8925e-01],
        [-1.5044e-08, -2.7993e-08,  2.2046e-08, -1.1355e+00, -6.8867e-01,
          1.7011e-01],
        ...,
        [ 6.1075e-09,  1.4917e-09,  1.4404e-08, -1.3803e-02, -3.3704e-01,
         

reward: -3.1658, last reward: -3.7014, gradient norm:  2.613:   0%|▏                                                                                                                       | 34/19531 [00:16<2:51:47,  1.89it/s]

tensor([[ 7.5384e-09,  2.0928e-09,  1.7689e-08, -7.1037e-03, -3.2147e-02,
         -5.7138e-01],
        [-1.6095e-08, -6.6347e-09,  2.0536e-09, -1.0911e+00, -1.3114e-01,
          3.9446e-01],
        [-3.9740e-13,  3.4600e-12,  4.6160e-09,  1.4943e-03,  2.0693e-04,
          1.3946e-01],
        ...,
        [-3.7596e-09, -1.3165e-08,  1.2679e-08,  3.0069e-01,  6.8525e-01,
         -3.3686e-02],
        [-2.7273e-08, -2.0902e-09, -1.1416e-08,  4.2945e-01,  3.4403e-01,
          1.1987e+00],
        [ 1.9022e-08, -2.2020e-08, -6.2960e-09,  1.1584e+00,  5.3652e-02,
         -3.8626e-01]], device='cuda:0')
tensor([[ 6.0300e-09,  4.4096e-08,  1.8333e-08, -1.5556e-01, -3.6749e-01,
         -2.0281e-01],
        [-1.6037e-08, -7.7895e-09,  3.9371e-09, -1.0916e+00, -1.3094e-01,
          3.9484e-01],
        [ 3.2161e-10,  1.2970e-09,  4.8673e-09, -2.2921e-01, -5.2257e-03,
          3.6734e-01],
        ...,
        [-6.7978e-09, -1.3752e-08,  1.0139e-08,  2.9995e-01,  8.3106e-01,
         

reward: -3.0953, last reward: -3.6701, gradient norm:  2.445:   0%|▏                                                                                                                       | 35/19531 [00:16<2:54:47,  1.86it/s]

tensor([[ 5.2286e-08, -1.6015e-08,  1.1153e-08,  6.2210e-01,  8.3772e-01,
         -1.1830e+00],
        [-1.1179e-08, -1.3901e-08,  3.7129e-08,  2.5898e+00, -2.6053e+00,
         -5.2147e-01],
        [ 1.7760e-08, -1.1006e-08,  1.5463e-08, -2.6545e-02, -1.4744e-01,
          1.1325e+00],
        ...,
        [ 3.3516e-08, -3.4240e-08, -1.0976e-08, -7.4495e-01,  4.9982e-01,
         -8.1783e-01],
        [-7.4273e-08,  7.3776e-09, -1.0108e-08, -8.3702e-01,  2.4628e-01,
          3.0139e-01],
        [-1.7739e-08,  2.3792e-08,  9.1767e-09, -3.0750e-01, -6.4011e-01,
          7.7760e-01]], device='cuda:0')
tensor([[ 6.1089e-08, -5.0740e-08, -7.9161e-09,  7.1676e-01,  8.7622e-01,
         -1.0773e+00],
        [-2.8226e-09, -2.5847e-08,  2.5015e-08,  2.5893e+00, -2.6062e+00,
         -5.2225e-01],
        [-5.4478e-09, -1.6805e-08,  9.6858e-09, -2.1701e-01, -5.5041e-01,
          1.4893e+00],
        ...,
        [ 2.5523e-08, -2.1157e-08, -1.1272e-08, -8.7288e-01,  7.9499e-01,
         

reward: -3.0816, last reward: -3.6214, gradient norm:  2.958:   0%|▏                                                                                                                       | 36/19531 [00:17<2:53:43,  1.87it/s]

tensor([[ 1.6308e-08, -4.5821e-09,  5.9877e-09,  1.8834e-01,  4.5101e-02,
         -4.2646e-01],
        [-1.5601e-08,  1.4158e-08, -6.5936e-09,  1.3955e-02, -7.3984e-03,
         -9.2137e-02],
        [ 3.1184e-09, -1.8189e-08, -1.9460e-08, -5.3312e-01,  1.2907e+00,
          6.9855e-01],
        ...,
        [ 1.4570e-08, -3.6605e-08,  1.2212e-08, -2.2754e-01,  2.1258e-01,
         -4.6344e-01],
        [-3.5547e-08, -5.3224e-08,  4.9369e-08, -1.8554e+00, -1.1008e+00,
         -1.0973e+00],
        [ 3.3741e-08,  1.5926e-08, -1.2718e-09,  2.7799e-01, -4.9169e-02,
          3.9049e-01]], device='cuda:0')
tensor([[ 1.4907e-08,  1.2644e-09,  7.4972e-09, -2.1072e-01, -7.0684e-03,
         -2.4664e-02],
        [ 1.0025e-09,  2.2870e-08, -8.3222e-09, -1.1484e-01,  7.2521e-02,
          5.9357e-02],
        [ 1.4193e-08, -1.3317e-08, -2.8917e-08, -7.3075e-01,  1.3890e+00,
          5.1265e-01],
        ...,
        [-1.5503e-08, -3.0209e-09,  1.3006e-08, -1.0125e-01,  1.1091e-01,
         

reward: -3.1888, last reward: -3.6491, gradient norm:  2.64:   0%|▏                                                                                                                        | 37/19531 [00:18<2:54:54,  1.86it/s]

tensor([[ 2.2328e-08,  1.4471e-08, -1.7394e-08, -3.4609e-02, -5.8809e-02,
          6.5002e-02],
        [ 1.8145e-08,  5.7262e-09, -2.5379e-09, -1.8123e+00,  5.4164e-01,
          4.7466e-01],
        [ 2.4632e-08,  1.7806e-08, -9.0406e-10,  8.0495e-02,  2.8949e-01,
         -2.3539e-01],
        ...,
        [-6.2552e-08, -1.3827e-07,  6.4624e-08, -3.9125e+00, -1.0034e+00,
         -1.7159e+00],
        [-6.5609e-08,  7.2387e-08,  3.2012e-09, -5.0137e-01,  1.0070e-01,
          1.2478e+00],
        [-1.2517e-08,  1.0590e-08, -4.3474e-09,  5.2589e-01, -2.3706e-01,
          7.1515e-01]], device='cuda:0')
tensor([[ 6.6332e-09, -3.1795e-10, -8.2660e-09, -1.4882e-01, -2.0640e-01,
          2.4925e-01],
        [ 1.8226e-08,  1.0485e-08, -3.9148e-09, -1.6751e+00,  3.5090e-01,
          3.8913e-01],
        [ 3.6971e-08, -2.6504e-09, -6.5996e-09, -3.8464e-02, -7.6854e-02,
          1.4837e-01],
        ...,
        [-9.5136e-08, -1.0937e-07,  1.2605e-08, -3.8124e+00, -6.1334e-01,
         

reward: -3.1489, last reward: -3.6289, gradient norm:  2.762:   0%|▏                                                                                                                       | 38/19531 [00:18<2:53:00,  1.88it/s]

tensor([[ 1.0599e-08, -2.8139e-09,  4.5138e-09, -1.7473e-01, -4.8164e-01,
          4.5974e-01],
        [-5.9577e-08, -2.4339e-09, -7.6436e-10,  1.7590e-01,  5.1531e-02,
         -8.2219e-02],
        [ 1.1936e-08,  7.6900e-09,  9.2096e-09, -1.0540e-01,  4.0334e-01,
         -1.1145e+00],
        ...,
        [ 3.6430e-07, -7.4397e-08, -1.6701e-07, -2.0738e+00, -9.2101e-01,
         -3.8578e+00],
        [ 2.2641e-08,  3.2428e-08, -6.3022e-09,  3.9533e-02,  7.5368e-03,
          1.2156e+00],
        [ 1.7267e-08, -4.1211e-08,  1.3770e-08,  3.0251e-01, -7.4368e-01,
         -4.1398e-01]], device='cuda:0')
tensor([[ 9.7808e-09,  1.2375e-08,  1.6436e-09, -4.4657e-01, -7.6823e-01,
          7.7365e-01],
        [-1.5196e-08, -2.3402e-09,  6.4632e-10,  4.3209e-01,  6.8886e-02,
          1.6859e-01],
        [ 6.4633e-08, -5.9496e-09,  5.5818e-09, -5.8082e-03,  6.4671e-02,
         -7.9759e-01],
        ...,
        [ 3.4933e-07, -3.8212e-08, -2.4203e-07, -2.2540e+00, -8.0114e-01,
         

reward: -3.1416, last reward: -3.6614, gradient norm:  2.815:   0%|▏                                                                                                                       | 39/19531 [00:19<2:58:19,  1.82it/s]

tensor([[ 1.6222e-08, -3.5688e-08,  8.1915e-09, -6.1561e-01, -4.0651e-01,
         -1.7247e+00],
        [ 3.8040e-08,  2.0501e-07,  1.8898e-08, -3.3261e+00,  2.8347e+00,
         -5.6369e-01],
        [-1.3991e-08,  1.2273e-08, -7.5867e-09, -1.2036e+00,  3.3562e-01,
         -1.6736e-01],
        ...,
        [ 2.0478e-08,  3.7773e-08,  5.9785e-09,  6.0810e-01, -3.8935e-01,
          9.1508e-01],
        [-1.5438e-08, -1.6094e-08,  8.8715e-09, -1.7040e-01,  1.5813e-01,
         -3.8910e-02],
        [ 2.0524e-08,  1.4097e-08, -3.3903e-09, -3.4400e-01,  9.1265e-03,
          6.4405e-01]], device='cuda:0')
tensor([[-4.2457e-09, -1.9837e-08, -1.4674e-09, -2.9782e-01, -1.2928e-01,
         -1.4352e+00],
        [ 3.3044e-08,  1.9260e-07,  4.9212e-08, -3.4589e+00,  2.4543e+00,
         -2.2089e-01],
        [-1.4826e-08,  1.3864e-08, -2.9416e-09, -1.2059e+00,  3.3825e-01,
         -1.6461e-01],
        ...,
        [ 2.9494e-08,  3.5557e-08,  1.0329e-08,  8.1069e-01, -7.3011e-01,
         

reward: -3.1357, last reward: -3.6244, gradient norm:  3.118:   0%|▏                                                                                                                       | 40/19531 [00:19<2:57:35,  1.83it/s]

tensor([[ 1.4584e-08, -1.1949e-08,  2.9656e-09, -2.6127e-01,  5.1371e-02,
          3.8748e-02],
        [ 5.2433e-08, -3.2100e-08,  4.1706e-09,  6.2668e-01,  1.5835e+00,
         -3.7967e+00],
        [ 2.7323e-10, -1.0295e-09, -1.5797e-09,  1.1394e+00,  3.3027e-01,
          1.3316e-01],
        ...,
        [-2.5313e-08,  1.8294e-08,  4.6684e-09,  5.3558e-01, -5.0440e-01,
          7.0978e-01],
        [ 7.7372e-08,  3.4717e-08, -1.6911e-08, -8.6422e-01, -5.1948e-02,
         -4.3723e+00],
        [-1.3713e-08,  8.5466e-09, -4.2790e-09, -7.5674e-01,  1.0464e-01,
          3.1697e-01]], device='cuda:0')
tensor([[ 4.7289e-08,  1.4144e-08, -2.6564e-09, -2.4575e-02,  7.9914e-03,
          2.7947e-01],
        [ 6.0635e-08, -2.9185e-08,  2.4714e-09,  5.0142e-01,  1.4850e+00,
         -3.9796e+00],
        [ 8.6734e-10,  6.2234e-09,  4.7570e-09,  8.2806e-01,  8.4206e-02,
          4.8333e-01],
        ...,
        [ 7.2647e-09,  3.9113e-08,  9.0345e-09,  2.0465e-01, -2.8540e-01,
         

reward: -3.1155, last reward: -3.6212, gradient norm:  2.311:   0%|▎                                                                                                                       | 41/19531 [00:20<2:54:39,  1.86it/s]

tensor([[ 1.6437e-08, -2.5453e-08,  2.5790e-08, -1.6921e+00, -9.2063e-01,
         -1.1936e-01],
        [-3.3131e-09, -6.3563e-08,  8.1222e-09,  4.3227e-01,  6.2002e-02,
          1.1366e+00],
        [ 2.0706e-09, -3.2869e-09, -1.5539e-08, -1.2864e+00,  5.0508e-01,
          2.3262e-01],
        ...,
        [-3.5963e-08, -5.6243e-08,  1.5256e-08, -3.1750e-01, -1.3348e+00,
         -2.4641e+00],
        [ 2.7424e-08,  6.9553e-09, -3.8803e-08, -4.6672e-01,  1.1594e+00,
          7.8100e-01],
        [ 9.2278e-09,  1.5251e-09,  2.1924e-08, -5.3349e-01, -8.3191e-01,
         -3.1197e-02]], device='cuda:0')
tensor([[-3.4959e-08, -4.6451e-08,  3.6934e-08, -1.4940e+00, -5.0005e-01,
          1.3531e-01],
        [ 2.8724e-08, -5.2707e-08,  1.3886e-08,  7.3152e-01,  8.8259e-03,
          1.3569e+00],
        [ 5.1739e-09,  4.6170e-10, -2.6543e-08, -1.3230e+00,  5.6033e-01,
          2.7246e-01],
        ...,
        [ 2.3185e-08, -2.1435e-08,  1.9990e-08, -4.9472e-01, -1.4612e+00,
         

reward: -3.1991, last reward: -3.7397, gradient norm:  2.506:   0%|▎                                                                                                                       | 42/19531 [00:20<2:46:25,  1.95it/s]

tensor([[-1.4193e-08,  3.4898e-09, -1.3173e-08, -3.4992e-01,  6.3560e-01,
          1.1871e+00],
        [-9.1400e-09,  8.4912e-09,  3.9178e-08,  3.8011e-01, -1.5284e+00,
         -1.0890e+00],
        [ 3.9371e-08,  4.3656e-08,  4.0769e-09, -4.4507e-01,  3.3734e-01,
         -1.6454e+00],
        ...,
        [ 1.7400e-08,  2.4369e-08, -1.2179e-08,  1.1346e+00,  7.5485e-01,
          5.6823e-01],
        [ 5.8919e-09,  4.6160e-08, -2.7182e-08,  1.4438e-03,  9.9544e-01,
          3.5452e-01],
        [ 1.7363e-08, -1.5230e-09, -1.0895e-08,  1.2987e-02, -1.5648e-02,
          1.0951e+00]], device='cuda:0')
tensor([[-2.5015e-08,  4.3452e-08, -1.3508e-08, -2.4233e-01,  5.2755e-01,
          1.3477e+00],
        [-3.7164e-08,  4.9497e-10,  1.2891e-08,  2.1984e-01, -1.7742e+00,
         -8.1934e-01],
        [ 6.5175e-08, -2.9267e-08,  1.7191e-09, -2.4933e-01,  2.2882e-01,
         -1.4829e+00],
        ...,
        [ 1.8216e-08,  2.0826e-08, -1.8668e-08,  1.2228e+00,  1.1284e+00,
         

reward: -3.0330, last reward: -3.6014, gradient norm:  2.557:   0%|▎                                                                                                                       | 43/19531 [00:21<2:42:21,  2.00it/s]

tensor([[-5.0349e-09,  2.0522e-08,  1.2323e-08,  3.9446e-03, -3.1884e-01,
          9.3507e-01],
        [ 3.5372e-08, -2.5304e-08, -2.2639e-08, -5.7688e-01,  1.1743e+00,
          2.8428e-01],
        [ 1.1442e-07, -1.1265e-08,  9.7994e-08,  8.4469e-01, -2.5613e-01,
         -4.7196e-01],
        ...,
        [ 8.5491e-09, -7.0751e-09,  5.3844e-09, -6.0663e-01, -1.9696e-01,
         -1.8763e+00],
        [ 1.6368e-08, -2.4666e-08, -1.6470e-08, -3.1181e-01,  1.5392e-01,
         -1.0495e+00],
        [ 1.0605e-08,  1.2031e-08, -3.8480e-09, -8.5316e-02, -1.0426e-01,
          9.8877e-01]], device='cuda:0')
tensor([[-2.6618e-08,  1.4212e-08,  8.3170e-09, -2.8248e-04, -3.4695e-01,
          9.5979e-01],
        [ 1.5918e-08,  3.4871e-09, -8.1375e-09, -6.7163e-01,  1.2242e+00,
          2.0195e-01],
        [ 1.1001e-07, -3.2578e-08,  9.5542e-08,  6.0918e-01, -2.5866e-01,
         -2.7987e-01],
        ...,
        [ 3.2492e-08, -1.4076e-08, -6.1522e-09, -5.2271e-01, -1.4819e-01,
         

reward: -3.0293, last reward: -3.5808, gradient norm:  2.726:   0%|▎                                                                                                                       | 44/19531 [00:21<2:37:25,  2.06it/s]

tensor([[-1.8459e-08, -4.8227e-08, -5.7666e-09,  6.7121e-01,  1.0510e+00,
          1.2472e+00],
        [-2.8543e-08, -1.0238e-08,  7.3362e-09, -2.5821e+00, -7.8466e-01,
         -4.9423e-01],
        [-4.0174e-08,  1.3638e-08, -3.3926e-08,  1.4816e+00,  8.4566e-03,
          4.2286e-01],
        ...,
        [ 2.4184e-08, -1.5608e-08,  3.7942e-09,  8.5861e-01,  3.2062e-01,
         -1.9965e+00],
        [ 2.5168e-08, -5.6081e-08,  5.0063e-09, -1.1927e-02, -5.2979e-01,
         -5.2765e-01],
        [ 1.2815e-08, -8.4898e-09, -7.8989e-09, -3.3520e-01,  4.4343e-02,
          8.3242e-01]], device='cuda:0')
tensor([[ 3.5135e-09, -5.6938e-08,  6.5714e-08, -1.9294e+00, -2.5302e+00,
         -2.9170e+00],
        [-2.3985e-08, -2.0861e-08,  5.9685e-09, -2.2264e+00, -1.0719e+00,
         -8.5694e-01],
        [-7.0336e-08, -5.7236e-09,  4.1538e-08, -4.1907e+00,  5.0929e-01,
         -1.3702e+00],
        ...,
        [ 7.0096e-10, -1.2399e-08, -2.5138e-09,  3.3475e-01,  4.2729e-02,
         

reward: -3.1446, last reward: -3.5565, gradient norm:  2.733:   0%|▎                                                                                                                       | 45/19531 [00:22<2:34:36,  2.10it/s]

tensor([[ 1.0866e-08, -2.4728e-08,  1.1823e-08, -4.3359e-01, -4.4788e-01,
         -6.1329e-01],
        [ 2.8987e-08,  7.1944e-09, -2.2699e-08,  7.3242e-01,  6.3858e-02,
         -3.1887e-01],
        [ 5.9854e-08,  4.0254e-08,  7.1336e-09, -3.2724e-01, -1.4951e-01,
          8.2072e-01],
        ...,
        [-1.5594e-08,  2.2938e-09, -3.4417e-09,  5.9426e-01,  2.0894e-01,
         -3.7152e-01],
        [-2.5345e-08, -3.1623e-08, -9.0578e-09,  3.3043e-02,  5.1238e-01,
         -1.3593e+00],
        [ 1.0655e-08, -2.7074e-08, -1.7884e-08, -1.0320e-01,  6.6767e-01,
         -1.5410e+00]], device='cuda:0')
tensor([[-3.8307e-08, -2.2530e-08,  1.8184e-08, -4.0349e-01, -3.9363e-01,
         -5.6127e-01],
        [ 1.0519e-07,  4.5880e-08, -7.2835e-09,  1.0774e+00, -9.4886e-02,
          1.4104e-02],
        [ 4.7821e-08,  3.5514e-08, -1.1480e-09, -1.1863e-01, -3.8422e-02,
          1.0588e+00],
        ...,
        [-2.8645e-08, -1.1533e-08, -5.7097e-09,  4.2340e-01,  1.0486e-01,
         

reward: -3.1600, last reward: -3.6510, gradient norm:  2.94:   0%|▎                                                                                                                        | 46/19531 [00:22<2:35:11,  2.09it/s]

tensor([[ 2.1595e-08, -1.0335e-09,  2.2123e-08, -6.2585e-02, -4.5627e+00,
         -1.2669e-01],
        [ 1.9892e-08, -1.2189e-08,  1.6757e-08,  7.8243e-01, -1.0699e+00,
          1.0585e+00],
        [-4.1614e-08,  5.9743e-09,  1.5782e-11, -3.1631e-01, -1.2607e-01,
          1.0368e+00],
        ...,
        [ 6.1066e-09, -1.9256e-08,  6.8249e-10, -1.3169e-02, -2.9004e-02,
          8.0649e-01],
        [ 1.2513e-07, -4.4131e-09,  7.7669e-09, -5.4314e-01,  6.9433e-02,
         -5.1476e-01],
        [ 1.1970e-08,  3.1395e-09,  6.4830e-08,  1.4225e+00, -1.5171e+00,
         -3.1256e-01]], device='cuda:0')
tensor([[ 2.4072e-08,  2.1722e-08,  4.1664e-08, -4.9078e-01, -4.3078e+00,
         -4.0052e-01],
        [ 3.9250e-08, -1.6291e-08,  2.2996e-08,  7.6987e-01, -9.8148e-01,
          1.0188e+00],
        [ 2.6464e-08,  6.0653e-08, -1.5929e-09, -5.9946e-02, -1.5009e-02,
          1.3138e+00],
        ...,
        [-6.2215e-09, -3.5427e-08, -7.9936e-09,  2.6835e-01,  3.5337e-01,
         

reward: -3.1442, last reward: -3.6866, gradient norm:  2.696:   0%|▎                                                                                                                       | 47/19531 [00:22<2:34:05,  2.11it/s]

tensor([[-1.9094e-08, -4.5654e-09, -1.8943e-08, -9.4603e-03,  1.1406e-01,
         -4.6168e-01],
        [-2.6494e-08,  2.1706e-08, -2.8131e-09,  1.4145e-01,  1.0587e-01,
         -8.4035e-01],
        [ 1.9765e-09, -2.2275e-08,  3.0931e-09,  1.2363e-02,  5.6202e-03,
         -1.7766e-01],
        ...,
        [-1.5495e-08, -2.8180e-08, -5.4638e-09, -4.1379e-01,  5.3392e-01,
         -2.1130e+00],
        [ 1.4585e-08,  3.0547e-08,  6.8094e-09, -1.8813e-01, -1.4522e-01,
          1.4691e+00],
        [-2.1413e-11, -1.1730e-10, -1.1754e-09,  1.9441e-01, -5.2851e-02,
         -5.0791e-01]], device='cuda:0')
tensor([[-1.4045e-08,  1.4625e-08, -1.5857e-08, -2.7851e-02, -2.6226e-01,
         -8.4587e-02],
        [ 7.4250e-08,  1.8494e-08, -4.4263e-09,  3.3005e-01,  1.9748e-01,
         -6.2805e-01],
        [-6.1301e-08, -2.6489e-08, -2.5282e-09,  3.0982e-01,  8.6644e-02,
          1.2914e-01],
        ...,
        [-1.7777e-08, -3.3662e-08, -2.8421e-09, -4.0071e-01,  5.0051e-01,
         

reward: -3.1436, last reward: -3.6149, gradient norm:  2.507:   0%|▎                                                                                                                       | 48/19531 [00:23<2:31:58,  2.14it/s]

tensor([[ 1.4206e-08,  3.2351e-08,  1.9502e-08,  2.4736e-01, -4.6493e-01,
          7.9780e-02],
        [ 7.5677e-09, -1.5547e-08,  2.1324e-08,  1.6190e-02, -1.6362e-01,
          2.1890e-01],
        [-2.3694e-08, -1.7199e-08,  2.1106e-08,  2.5786e-02, -8.4023e-01,
         -1.6157e+00],
        ...,
        [-5.5556e-09, -2.0293e-09, -1.3876e-08,  1.1656e-01,  2.6485e-01,
         -3.6427e-01],
        [ 1.1853e-08,  1.9865e-08,  1.0401e-08,  7.4750e-01,  4.4267e-01,
         -2.5605e-02],
        [ 7.0909e-08,  4.6465e-08,  1.3133e-08, -2.9544e-02, -2.4566e-02,
          1.3021e+00]], device='cuda:0')
tensor([[ 5.1311e-09,  1.6328e-08,  1.8248e-08,  1.6515e-01, -3.6449e-01,
          2.0594e-01],
        [ 1.3013e-08, -2.9621e-09,  2.7024e-08, -4.3990e-02, -5.3191e-01,
          5.6803e-01],
        [-2.4310e-08, -1.6450e-08,  2.0750e-08, -6.5301e-03, -7.4421e-01,
         -1.5584e+00],
        ...,
        [-7.2823e-09, -4.9451e-10, -1.3096e-08,  1.4796e-02,  2.4453e-02,
         

reward: -3.1910, last reward: -3.6785, gradient norm:  2.765:   0%|▎                                                                                                                       | 49/19531 [00:23<2:30:13,  2.16it/s]

tensor([[-9.6891e-09,  9.1893e-09,  4.6431e-09,  3.4232e-01,  2.4444e-01,
          1.4220e+00],
        [-1.7845e-08, -4.3844e-08,  4.5590e-08,  1.1341e-01, -1.6565e+00,
         -1.2749e+00],
        [-1.1370e-08, -1.1588e-08,  5.5738e-09, -7.3193e-02, -5.9188e-02,
          1.0944e+00],
        ...,
        [ 1.5419e-08, -3.5177e-08, -1.3127e-09, -1.8481e-01,  9.0368e-02,
         -8.2826e-01],
        [ 3.2176e-08, -6.8609e-08,  4.3524e-08,  1.2089e+00, -2.4363e+00,
         -3.0703e+00],
        [ 3.9906e-09,  7.7023e-08, -2.5862e-08,  6.4441e-03,  4.1209e-01,
          4.6189e-01]], device='cuda:0')
tensor([[ 3.1109e-08, -3.4386e-08,  1.6191e-09, -3.6059e-01, -1.9437e-01,
         -4.5683e+00],
        [-1.9526e-08, -4.2690e-08,  4.3712e-08, -8.6386e-02, -1.3341e+00,
         -1.1656e+00],
        [-1.4087e-08, -1.0057e-08,  5.8767e-09,  3.6817e-02,  3.4226e-02,
          9.5712e-01],
        ...,
        [ 1.2942e-08, -3.5687e-08, -3.8383e-09, -2.9487e-01,  1.6941e-01,
         

reward: -3.2041, last reward: -3.6584, gradient norm:  2.853:   0%|▎                                                                                                                       | 50/19531 [00:24<2:29:36,  2.17it/s]

tensor([[-3.1711e-08, -1.3958e-08,  7.5253e-09, -2.1053e-02,  3.1985e-02,
         -8.2658e-01],
        [ 6.1914e-09,  8.0640e-10,  6.1559e-08,  2.0732e+00, -3.7845e+00,
         -7.8821e-01],
        [ 1.2213e-08, -2.4768e-10, -1.9903e-08, -2.8958e-01,  4.4063e-01,
          7.3585e-01],
        ...,
        [-1.1614e-07,  2.6434e-09,  3.9089e-08,  1.4146e+00,  2.3502e-01,
          7.4113e-01],
        [-5.0680e-09, -2.9716e-08,  2.0912e-08,  1.4305e-01,  3.6945e-01,
         -4.1056e-01],
        [ 1.8684e-08,  1.2392e-08,  1.7917e-08,  2.6426e-01, -3.8751e-01,
         -1.0579e+00]], device='cuda:0')
tensor([[-4.5280e-08,  1.8990e-08,  1.1343e-08,  1.0528e-01, -2.2249e-01,
         -5.4364e-01],
        [-1.1868e-08, -4.4944e-09, -2.9524e-08, -9.3769e-01,  1.3561e+00,
          2.1602e-01],
        [ 2.4614e-08,  1.1702e-08, -1.9535e-08, -2.7480e-01,  4.0078e-01,
          6.9988e-01],
        ...,
        [-1.1615e-08,  2.2605e-08, -1.1385e-07, -3.8614e+00, -8.1125e-02,
         

reward: -3.0667, last reward: -3.5995, gradient norm:  2.304:   0%|▎                                                                                                                       | 51/19531 [00:24<2:32:01,  2.14it/s]

tensor([[ 6.3455e-09, -6.6531e-09, -2.2088e-08,  5.7258e-01,  5.4260e-01,
         -1.9064e-02],
        [ 7.4185e-10, -7.3924e-08,  1.5412e-08,  3.4934e-01, -7.1207e-01,
         -1.1080e+00],
        [ 7.1925e-08,  2.3355e-08,  2.4770e-08, -7.9542e-01,  1.5427e+00,
         -2.2087e+00],
        ...,
        [ 1.4954e-09, -7.2143e-09,  2.7619e-09,  1.1031e-01, -3.4218e-02,
         -3.9222e-01],
        [-2.4977e-08,  4.4727e-08, -3.2575e-09, -2.7556e+00, -1.7985e+00,
          2.7327e+00],
        [ 5.0123e-08, -9.6486e-09,  1.8247e-08, -1.4819e+00, -3.7753e-01,
          3.1842e-01]], device='cuda:0')
tensor([[-6.7538e-09,  6.6995e-09, -5.2854e-09,  5.6332e-01,  5.2325e-01,
         -1.2640e-04],
        [ 4.4520e-09, -6.9076e-08,  1.2704e-08,  2.6198e-01, -6.1754e-01,
         -1.0183e+00],
        [ 6.8073e-08, -2.6822e-08,  1.1367e-08, -4.0746e-01,  1.2095e+00,
         -2.1639e+00],
        ...,
        [ 1.3195e-09, -7.3079e-09,  2.0859e-09,  6.8216e-02, -2.2787e-02,
         

reward: -3.0144, last reward: -3.4727, gradient norm:  2.354:   0%|▎                                                                                                                       | 52/19531 [00:25<2:30:07,  2.16it/s]

tensor([[ 2.7900e-08,  3.4870e-08,  1.8318e-08,  6.2312e-01, -6.3533e-02,
         -1.9222e-01],
        [ 6.2296e-08,  2.5496e-09,  6.3632e-09, -7.5141e-01, -4.0950e-01,
         -1.3554e-01],
        [ 2.6659e-08, -4.0440e-08,  6.4472e-09, -1.2015e-01,  6.4604e-01,
          6.1485e-01],
        ...,
        [ 3.7035e-08,  4.0430e-08,  1.8653e-08, -3.6041e-01, -1.1889e+00,
         -2.4275e-02],
        [-1.9407e-08,  1.9078e-08,  9.1021e-09,  1.2351e+00, -8.9313e-01,
         -1.8104e-01],
        [ 5.6011e-08,  2.1988e-08,  1.7290e-08, -7.2395e-01,  4.1848e-01,
         -2.9491e-01]], device='cuda:0')
tensor([[ 2.9600e-08,  2.9754e-08,  1.8209e-08,  5.1399e-01, -7.9940e-02,
         -9.1261e-02],
        [-1.8330e-10, -8.1631e-09,  1.1527e-09, -5.5044e-01, -2.1099e-01,
          1.1791e-01],
        [ 4.3659e-08, -6.2018e-08,  9.4308e-09, -1.1581e-01,  6.3814e-01,
          6.2386e-01],
        ...,
        [ 2.4345e-08, -7.4164e-09,  1.7773e-08, -4.0595e-01, -8.2232e-01,
         

reward: -3.1125, last reward: -3.5933, gradient norm:  2.489:   0%|▎                                                                                                                       | 53/19531 [00:25<2:34:38,  2.10it/s]

tensor([[-1.9687e-08,  3.9412e-08,  2.2008e-08,  1.2281e-01, -2.4882e+00,
         -5.7656e-01],
        [ 1.2819e-08, -1.0137e-08,  1.7255e-08,  1.0013e+00,  1.1244e-01,
          2.0114e-01],
        [-4.8263e-08, -2.4654e-08,  1.2877e-08, -7.5640e-02,  5.8136e-02,
         -6.4726e-01],
        ...,
        [-3.2177e-08, -2.7643e-09, -6.1955e-09,  2.3688e-01, -4.4346e-01,
         -9.5538e-02],
        [ 4.3177e-08,  8.1605e-09,  7.6478e-09,  1.9763e-01,  6.3661e-01,
          3.1590e-01],
        [-3.6743e-09, -7.9120e-09,  1.6334e-08, -8.8541e-02,  2.2900e-01,
          4.5759e-01]], device='cuda:0')
tensor([[-2.4227e-08,  4.0183e-08,  9.3061e-09,  1.8400e-01, -2.5344e+00,
         -5.6471e-01],
        [-2.4406e-09, -4.3638e-09,  2.0771e-08,  1.2829e+00,  3.6228e-01,
         -9.0474e-02],
        [-3.0188e-08,  2.3329e-09,  1.3598e-08,  1.1656e-01, -1.1623e-01,
         -3.8956e-01],
        ...,
        [-1.2687e-08,  4.5618e-08,  4.1308e-09,  2.6088e-01, -5.7373e-01,
         

reward: -3.1047, last reward: -3.6467, gradient norm:  2.652:   0%|▎                                                                                                                       | 54/19531 [00:26<2:32:48,  2.12it/s]

tensor([[-3.1080e-09,  1.9951e-08,  2.0763e-08, -2.2423e-01, -2.0752e-01,
         -4.6940e-01],
        [-7.0503e-08,  3.6033e-08, -9.4290e-09, -1.3420e+00,  5.4068e-01,
          8.4669e-02],
        [ 5.2238e-09, -1.4330e-08, -1.3490e-09, -3.1401e-01, -7.1002e-02,
         -6.9720e-01],
        ...,
        [ 7.1551e-09, -8.3687e-09,  5.7880e-08,  5.3640e-01, -2.1334e+00,
         -3.0727e-01],
        [-1.6278e-08,  1.9369e-08,  1.3586e-08, -1.6458e+00, -2.3699e-01,
         -1.2689e-01],
        [-4.2153e-09, -8.5884e-09,  4.6261e-10, -1.1381e+00, -9.0178e-02,
         -2.0026e+00]], device='cuda:0')
tensor([[-4.1076e-09,  2.7617e-08,  2.3943e-08, -1.7155e-01, -1.4671e-01,
         -3.9296e-01],
        [-2.4114e-08,  2.2287e-08, -1.2507e-08, -1.1450e+00,  5.7699e-01,
          2.3286e-01],
        [ 6.1410e-09, -1.6631e-08,  2.0625e-09, -3.0518e-01, -7.0406e-02,
         -7.0617e-01],
        ...,
        [ 3.5353e-09, -1.1582e-08,  4.3942e-08,  4.2288e-01, -2.0637e+00,
         

reward: -3.1036, last reward: -3.6403, gradient norm:  2.752:   0%|▎                                                                                                                       | 55/19531 [00:26<2:31:27,  2.14it/s]

tensor([[ 1.1764e-08, -1.6383e-08, -1.3743e-09,  1.7640e-01,  6.7490e-02,
         -4.3697e+00],
        [ 2.5623e-07, -4.7632e-08, -1.7309e-07, -2.0280e+00,  4.0161e-01,
         -4.1081e+00],
        [-5.9849e-09,  5.5106e-09,  1.4722e-09, -9.5873e-02, -4.7572e-01,
          1.4755e+00],
        ...,
        [ 5.6059e-08, -2.1475e-08, -1.1563e-08,  8.2212e-02, -9.9526e-02,
         -3.2669e-01],
        [ 4.4311e-09, -2.7903e-08, -1.0735e-08,  1.3336e+00,  6.9189e-01,
         -1.6194e+00],
        [ 3.9809e-08,  2.6595e-08, -4.4513e-10, -3.8123e-01, -4.5088e-01,
          1.5007e-01]], device='cuda:0')
tensor([[ 5.9662e-09, -8.1838e-09,  3.4357e-09,  6.0382e-02,  2.4737e-02,
         -4.4209e+00],
        [ 1.9713e-07, -3.9793e-08, -1.6572e-07, -2.1977e+00,  7.3616e-01,
         -3.6147e+00],
        [-4.1708e-09, -3.0351e-09,  2.2621e-09,  4.3200e-01,  1.6370e+00,
         -4.2665e+00],
        ...,
        [ 1.4328e-08, -1.4954e-08, -1.4583e-08,  1.6965e-01, -2.4489e-01,
         

reward: -3.1417, last reward: -3.6783, gradient norm:  2.704:   0%|▎                                                                                                                       | 56/19531 [00:27<2:29:31,  2.17it/s]

tensor([[-5.4038e-08,  1.3760e-08, -1.4938e-09,  3.0268e-02, -3.2659e-01,
         -3.7267e+00],
        [ 4.5970e-08,  3.2221e-08, -1.9459e-08,  4.5644e-01,  9.2576e-01,
          4.0998e-01],
        [-8.0932e-10, -2.2573e-08,  1.7772e-08, -4.0993e+00, -1.4650e-01,
         -6.3701e-01],
        ...,
        [ 1.5075e-08, -2.7119e-08, -1.6763e-08, -9.9026e-02,  4.4305e-01,
         -4.3982e-01],
        [ 6.8268e-09, -3.5469e-08, -5.3485e-08,  1.2881e-01,  4.7176e-01,
         -5.5221e-01],
        [ 2.8700e-08, -9.2829e-08,  5.8071e-08, -3.2059e-01, -2.1249e+00,
         -3.5902e+00]], device='cuda:0')
tensor([[ 1.5375e-08,  3.4554e-09,  4.4898e-09, -3.4434e-02, -7.6792e-01,
         -3.3743e+00],
        [ 2.1545e-08,  2.0289e-08, -1.7380e-08,  8.6481e-01,  1.1460e+00,
          6.8436e-01],
        [-2.2716e-09, -1.8152e-08,  9.4899e-09, -3.8637e+00,  2.6204e-01,
         -6.3639e-01],
        ...,
        [-1.8969e-08, -3.2596e-09, -3.2234e-09, -5.8207e-02,  3.4323e-01,
         

reward: -3.1653, last reward: -3.6791, gradient norm:  2.73:   0%|▎                                                                                                                        | 57/19531 [00:27<2:39:30,  2.03it/s]

tensor([[ 5.4439e-08,  2.5689e-08, -5.7944e-09,  1.8671e-02,  3.3946e-02,
          1.5318e-01],
        [ 6.3760e-09, -8.4001e-09, -1.2701e-09, -3.6884e-01,  3.4169e-01,
          3.2334e-01],
        [ 7.6363e-10, -5.6367e-09,  6.9975e-09,  6.4833e-01, -2.5925e-01,
         -1.1200e-02],
        ...,
        [-1.0034e-09,  7.2963e-11, -3.2383e-08,  7.7375e-01, -1.9294e-01,
          4.8173e-01],
        [ 5.4158e-09,  6.1469e-09, -3.1340e-09,  3.0094e-02, -1.5549e-01,
         -1.6766e+00],
        [-2.7303e-08,  5.1928e-08,  3.6449e-09,  1.3718e-02,  4.7977e-01,
          4.3281e-01]], device='cuda:0')
tensor([[ 6.0150e-08,  1.5341e-08, -1.1915e-08,  2.1792e-01,  2.8096e-01,
          4.6200e-01],
        [-2.0412e-08,  2.6121e-08, -1.2761e-08, -5.1755e-01,  6.6961e-01,
          6.2232e-01],
        [ 1.2901e-09,  2.5020e-09,  9.9455e-09,  4.9567e-01, -2.4186e-01,
          1.3200e-01],
        ...,
        [-1.3644e-08,  5.2511e-08, -1.3510e-08,  9.6482e-01, -3.5603e-01,
         

reward: -3.1462, last reward: -3.6664, gradient norm:  2.982:   0%|▎                                                                                                                       | 58/19531 [00:28<2:48:14,  1.93it/s]

tensor([[ 1.8321e-08, -3.9220e-08,  4.4095e-09,  5.3211e-01, -5.5181e-01,
         -2.0226e+00],
        [-3.9360e-08,  5.1176e-08, -1.0858e-08,  3.9876e-01, -7.8810e-01,
         -4.3480e-01],
        [-1.1685e-08, -1.0919e-08, -1.9949e-09,  4.9803e-01,  2.4708e-01,
         -1.1861e+00],
        ...,
        [ 7.7270e-09, -9.8957e-10, -3.5340e-09,  1.7677e-01,  8.0404e-01,
          1.5992e-01],
        [ 7.9855e-10,  3.9259e-08, -3.7496e-08, -6.1946e-01,  8.4714e-01,
          1.1459e+00],
        [-5.8769e-09, -1.9456e-09,  8.1902e-09, -6.2071e-02, -3.8268e-01,
         -4.9007e-01]], device='cuda:0')
tensor([[ 1.9649e-08, -4.8862e-08,  4.6133e-09,  9.1480e-01, -7.1338e-01,
         -2.1807e+00],
        [-4.4001e-08,  5.2541e-08, -1.8755e-08,  2.4256e-01, -6.3036e-01,
         -2.5290e-01],
        [-1.0319e-08, -1.2665e-08, -9.2225e-09,  5.0589e-01,  2.5381e-01,
         -1.1938e+00],
        ...,
        [-9.3004e-09, -1.4131e-08,  5.5661e-09,  1.5302e-01,  9.0895e-01,
         

reward: -3.1146, last reward: -3.6136, gradient norm:  2.569:   0%|▎                                                                                                                       | 59/19531 [00:28<2:52:44,  1.88it/s]

tensor([[ 2.1438e-09, -2.4425e-10,  8.4672e-09, -1.0308e-01, -4.8598e-01,
         -5.6963e-01],
        [-4.5495e-08, -2.3453e-08,  5.4888e-08,  1.4168e+00, -4.0414e+00,
         -4.9276e-01],
        [-7.4424e-09, -2.9801e-08,  2.3002e-08, -2.4509e-01,  8.5283e-02,
         -6.7953e-01],
        ...,
        [ 1.0964e-08, -8.0253e-10,  1.4463e-08,  1.2908e+00, -1.8275e+00,
          9.5069e-01],
        [ 6.3431e-09,  1.7301e-08, -2.7206e-08,  4.2965e-01,  8.2244e-01,
         -3.8737e-01],
        [ 2.7577e-08,  2.0775e-08, -9.3877e-10, -1.1936e-02, -1.3169e-01,
          9.2314e-01]], device='cuda:0')
tensor([[ 2.4060e-09, -2.9815e-10,  9.2347e-09, -1.0975e-01, -4.9938e-01,
         -5.5457e-01],
        [ 5.3151e-10,  2.2727e-10, -3.0651e-10,  1.0669e+00, -3.9669e+00,
         -5.7159e-01],
        [ 3.2668e-09, -2.7006e-08,  1.9361e-08, -4.2646e-01,  2.0144e-01,
         -4.6176e-01],
        ...,
        [-2.1453e-08,  4.8850e-08,  5.4883e-08,  1.1513e+00, -1.8239e+00,
         

reward: -3.1522, last reward: -3.6367, gradient norm:  2.805:   0%|▎                                                                                                                       | 60/19531 [00:29<2:53:22,  1.87it/s]

tensor([[-5.6130e-08,  3.2159e-08,  5.1123e-09,  6.8719e-01, -2.2284e-02,
          2.7149e-01],
        [ 2.4356e-08,  4.0248e-08,  9.5191e-10,  1.1317e-01,  3.2534e-01,
          1.0980e+00],
        [-2.9810e-08,  1.4785e-08,  3.4057e-09, -1.2642e+00, -4.2268e-01,
          2.3416e-01],
        ...,
        [ 1.1224e-09, -5.5041e-09,  9.0130e-10,  4.5753e-01,  4.4620e-02,
         -3.9987e-01],
        [-2.5105e-09,  1.4137e-08, -5.3588e-08, -3.3121e+00,  1.2218e+00,
          1.1931e+00],
        [ 3.8071e-08, -1.5475e-08,  2.4278e-09, -3.8255e-01,  3.0369e-01,
         -9.5670e-01]], device='cuda:0')
tensor([[ 2.6722e-08,  3.1418e-08,  3.9775e-09,  1.0863e+00, -2.6411e-01,
          6.2258e-01],
        [ 2.4351e-08,  4.2615e-08,  1.0340e-09,  3.1403e-02,  6.2423e-02,
          1.3705e+00],
        [-4.5954e-08, -1.5261e-08,  1.7511e-08, -1.0840e+00, -2.2691e-01,
          4.5635e-01],
        ...,
        [ 2.7197e-09, -7.5986e-09, -1.7872e-09,  7.4570e-01,  1.8703e-01,
         

reward: -3.1969, last reward: -3.6539, gradient norm:  2.943:   0%|▎                                                                                                                       | 61/19531 [00:29<2:47:46,  1.93it/s]

tensor([[-9.4139e-08, -9.0926e-09, -3.2622e-08,  8.2351e-01,  9.2593e-01,
         -6.6215e-01],
        [-2.1362e-08, -5.0622e-08,  1.1446e-08, -7.2104e-01, -4.2039e-01,
         -1.2538e+00],
        [-3.1742e-08, -1.3379e-09,  1.5923e-08, -7.2527e-01,  1.8232e-01,
         -1.7027e-01],
        ...,
        [ 1.2066e-08, -2.8137e-08, -2.7753e-09,  6.2095e-03, -1.5257e-01,
         -8.2259e-01],
        [-2.4013e-08, -2.8743e-08,  1.6302e-08, -2.1004e+00,  6.9631e-01,
         -1.6146e+00],
        [-1.1075e-08,  6.8505e-09, -1.5156e-08, -1.7975e-01,  4.2436e-01,
          8.0507e-01]], device='cuda:0')
tensor([[-7.5790e-09,  5.5734e-09, -7.5231e-08,  1.2427e+00,  9.7440e-01,
         -3.1491e-01],
        [-3.3175e-08, -6.8030e-08,  4.0341e-09, -4.7291e-01, -1.9055e-01,
         -1.0142e+00],
        [ 4.6310e-08,  2.8923e-08, -5.9617e-09, -5.0889e-01,  1.8516e-01,
          2.6051e-02],
        ...,
        [ 1.8350e-09, -1.5593e-08,  4.3913e-11, -4.5914e-04, -4.9651e-02,
         

reward: -3.1093, last reward: -3.6452, gradient norm:  2.641:   0%|▍                                                                                                                       | 62/19531 [00:30<2:48:45,  1.92it/s]

tensor([[-3.0552e-08,  5.8875e-09, -1.1192e-10,  6.2230e-01, -4.1213e-01,
         -4.3323e-02],
        [-2.2104e-08,  5.0055e-08, -2.5157e-08, -5.7430e-01,  2.8444e-01,
          1.2664e+00],
        [ 1.5091e-08, -9.0326e-09, -7.8091e-09, -1.0352e+00, -3.0591e-01,
         -1.0687e-01],
        ...,
        [-2.3648e-08, -3.4873e-08, -7.8005e-09,  2.4449e-01,  6.3768e-01,
         -1.1404e+00],
        [ 3.4077e-08, -1.0002e-10, -1.9918e-08, -7.5010e-01, -4.2223e-01,
          3.0928e-01],
        [-3.8102e-08, -3.2697e-08,  1.1439e-08,  3.6252e-01,  2.5298e-02,
          4.2578e-01]], device='cuda:0')
tensor([[-2.2481e-08,  1.4820e-08,  3.2386e-09,  7.0882e-01, -4.2890e-01,
         -1.1972e-01],
        [ 2.2624e-08,  2.4787e-09, -9.4411e-09, -5.2239e-01,  2.7528e-01,
          1.3218e+00],
        [ 1.4696e-08,  8.0527e-10, -8.7291e-09, -6.7124e-01, -5.6205e-02,
          2.7799e-01],
        ...,
        [-1.1612e-09, -6.2315e-08, -1.4337e-08,  2.0376e-01,  5.8599e-01,
         

reward: -3.1280, last reward: -3.6213, gradient norm:  2.604:   0%|▍                                                                                                                       | 63/19531 [00:30<2:43:03,  1.99it/s]

tensor([[-1.7887e-08, -3.2988e-08, -1.4796e-09, -4.4041e-02,  1.9852e-02,
         -2.1484e+00],
        [ 1.2188e-08, -1.1665e-08,  2.2980e-08, -1.4548e+00, -9.3520e-01,
          1.5135e-01],
        [ 8.6996e-08,  5.5179e-09, -4.8114e-08, -9.2909e-01, -5.3128e-01,
         -4.9625e-01],
        ...,
        [-6.6961e-08,  1.4281e-08, -1.0615e-08,  7.0621e-01,  2.8909e-01,
          1.4676e+00],
        [ 7.3577e-09, -1.3718e-08, -6.6216e-09,  1.5724e-02, -1.2618e-02,
         -9.9200e-01],
        [ 5.8074e-08, -1.5950e-08, -2.2086e-08,  5.9441e-01, -4.4981e+00,
         -7.3025e-01]], device='cuda:0')
tensor([[-4.1021e-08, -4.8769e-08, -5.9116e-09,  1.7905e-01, -6.1165e-02,
         -2.3320e+00],
        [ 8.0994e-09, -1.0824e-08,  1.6937e-08, -1.8879e+00, -8.2596e-01,
          3.2536e-01],
        [ 2.2461e-08, -4.5855e-08, -2.3779e-08, -8.0793e-01, -3.7354e-01,
         -3.5114e-01],
        ...,
        [ 3.2044e-08, -1.8754e-07,  7.2182e-11, -2.3192e+00, -5.8816e-01,
         

reward: -3.1520, last reward: -3.6021, gradient norm:  2.584:   0%|▍                                                                                                                       | 64/19531 [00:31<2:38:50,  2.04it/s]

tensor([[-1.0162e-07, -3.0089e-08, -6.7901e-08, -3.9094e-01,  4.0189e-02,
         -1.8348e+00],
        [-6.0289e-08, -3.3223e-08, -1.2106e-08,  2.6188e-01,  3.8799e-01,
         -5.3543e-01],
        [ 6.8672e-09,  1.4997e-08, -3.6536e-09, -5.6628e-02, -3.6902e-01,
         -2.7060e-03],
        ...,
        [ 1.4661e-08,  1.8216e-07, -2.2623e-08, -3.7193e+00,  2.1507e+00,
         -1.2722e+00],
        [-1.3503e-08,  2.2996e-08, -4.9508e-09,  6.3340e-03,  4.7249e-02,
         -1.7138e+00],
        [-3.8729e-08,  1.6100e-08, -1.5941e-08,  1.0800e+00,  9.4115e-01,
         -1.6883e-02]], device='cuda:0')
tensor([[-9.7688e-08,  8.5694e-09, -6.8552e-08, -7.6585e-01,  2.3709e-01,
         -1.3882e+00],
        [-4.5000e-08, -2.9227e-08, -1.5025e-08,  2.8006e-01,  4.0436e-01,
         -5.1086e-01],
        [ 1.7583e-08,  1.2206e-08, -1.3639e-09, -2.2390e-01, -6.7575e-01,
          3.1855e-01],
        ...,
        [ 9.0483e-09,  1.4749e-07, -4.8927e-08, -3.7362e+00,  2.1270e+00,
         

reward: -3.1903, last reward: -3.6826, gradient norm:  2.73:   0%|▍                                                                                                                        | 65/19531 [00:31<2:35:58,  2.08it/s]

tensor([[-7.8589e-09,  2.0881e-08,  9.5075e-09, -5.0638e-01, -1.0025e+00,
         -1.7355e-01],
        [-1.0905e-08, -3.0235e-08, -1.0734e-08, -4.2681e-02,  3.7453e-01,
         -9.5228e-02],
        [-3.8150e-08, -2.0209e-08,  7.1981e-09, -1.7286e+00, -1.3750e+00,
          1.6984e-01],
        ...,
        [ 1.1585e-08,  2.3755e-08,  7.0612e-09, -3.5631e+00,  2.0135e-01,
          1.7616e+00],
        [ 2.4076e-08, -7.8564e-09, -1.3879e-09, -3.0955e-02,  8.1136e-02,
          1.2774e+00],
        [ 1.5941e-08,  5.1107e-08, -1.9105e-08, -5.2692e-01,  3.9408e-01,
          4.3172e-01]], device='cuda:0')
tensor([[ 1.9935e-08,  2.2022e-08,  2.4140e-08, -6.8159e-01, -1.1009e+00,
         -1.8085e-02],
        [-1.1622e-08, -1.4821e-08, -8.0676e-09, -3.1212e-02,  4.7013e-01,
         -1.9845e-03],
        [-3.0401e-08, -2.8696e-08,  1.5240e-08, -1.1168e+00, -1.3748e+00,
         -9.3115e-02],
        ...,
        [ 6.1000e-09,  1.0580e-08,  3.6520e-09, -3.6001e+00,  3.4991e-01,
         

reward: -3.1750, last reward: -3.7104, gradient norm:  2.931:   0%|▍                                                                                                                       | 66/19531 [00:32<2:36:17,  2.08it/s]

tensor([[ 8.4143e-10, -3.4208e-09,  2.4673e-09,  2.1745e-01, -2.1508e-02,
         -4.4806e-01],
        [-1.9804e-09, -3.0865e-08,  2.8517e-10,  1.2197e-01,  1.0256e+00,
         -3.8527e-01],
        [-8.0214e-08, -2.5943e-08,  6.2241e-08, -1.7276e+00, -4.7159e-01,
         -1.0014e+00],
        ...,
        [-1.4265e-08,  1.1978e-08, -3.2851e-09,  2.3023e-02, -1.1030e-02,
          1.1429e-01],
        [ 9.3828e-09, -2.1488e-08, -3.0181e-09, -2.8590e-01,  2.2225e-01,
         -1.5920e-01],
        [-4.2035e-10,  3.6772e-09, -1.2286e-08, -7.7029e-01,  6.0839e-01,
         -8.5963e-01]], device='cuda:0')
tensor([[-7.6477e-08,  9.0210e-09, -8.3191e-10,  4.0623e-01, -8.0998e-02,
         -2.4977e-01],
        [ 2.8188e-09,  5.6614e-10,  3.0814e-09,  1.7897e-01,  8.0924e-01,
         -2.1169e-01],
        [-1.1714e-07, -1.6602e-08,  5.1604e-08, -1.3952e+00, -1.3819e-01,
         -8.5548e-01],
        ...,
        [-2.5883e-08,  2.1364e-08, -6.8529e-09, -2.1884e-01,  1.4664e-01,
         

reward: -3.2514, last reward: -3.7206, gradient norm:  2.946:   0%|▍                                                                                                                       | 67/19531 [00:32<2:32:57,  2.12it/s]

tensor([[-2.9608e-07,  3.5030e-08,  6.2541e-08, -2.4964e-01,  4.1141e-02,
         -4.6520e+00],
        [ 2.1171e-09, -5.3336e-09,  1.6197e-08,  8.5215e-01, -7.3723e-01,
          6.4570e-01],
        [ 9.6585e-09, -1.5138e-08,  1.6774e-08, -5.1250e-01, -4.9586e-01,
         -7.2705e-01],
        ...,
        [ 2.7305e-08,  5.1971e-08,  5.9224e-08,  1.4054e-01, -1.7317e+00,
          4.1567e-01],
        [ 2.5170e-08,  3.7605e-08,  3.0322e-08, -2.8063e-01, -6.5798e-01,
         -4.1495e-01],
        [ 2.5738e-08, -5.3764e-08,  3.9289e-09,  5.2519e-01,  1.4243e-01,
         -8.8776e-01]], device='cuda:0')
tensor([[-2.8534e-07,  1.1352e-07,  5.0496e-10,  5.8464e-01, -1.9404e-01,
         -4.2926e+00],
        [ 2.3081e-08, -3.7672e-08,  1.7005e-08,  9.2687e-01, -9.8694e-01,
          7.8822e-01],
        [ 4.9258e-09, -2.4262e-09,  1.9031e-08, -6.7744e-01, -5.5420e-01,
         -5.5446e-01],
        ...,
        [ 1.4045e-08,  8.0278e-08,  3.5244e-08, -1.1362e-01, -1.4349e+00,
         

reward: -3.0942, last reward: -3.6241, gradient norm:  2.64:   0%|▍                                                                                                                        | 68/19531 [00:33<2:33:44,  2.11it/s]

tensor([[ 4.6584e-08,  2.4204e-08, -3.2439e-08,  6.7258e-01,  7.5559e-01,
         -3.3962e-01],
        [-8.5589e-08, -9.6919e-08, -7.8414e-08, -1.6766e-01, -4.0978e-01,
         -1.1690e+00],
        [ 1.4970e-08,  4.5241e-08, -1.8765e-08, -1.4662e-01, -1.0369e+00,
         -9.2655e-01],
        ...,
        [-1.3346e-09, -8.9645e-09, -4.6720e-09,  9.6959e-01, -2.1422e-02,
          3.6433e-01],
        [ 1.0861e-08,  4.1034e-09,  9.0485e-09,  4.4743e-01, -1.1932e+00,
         -4.1351e-01],
        [-1.4506e-10, -1.4672e-09,  6.0639e-09,  8.5157e-01,  1.1577e-01,
         -1.0667e-01]], device='cuda:0')
tensor([[ 3.5960e-08,  3.8008e-08, -3.5349e-08,  6.9113e-01,  7.6251e-01,
         -3.2208e-01],
        [-8.6124e-08, -6.4905e-08, -6.7792e-08, -3.8766e-01, -6.4212e-01,
         -8.3753e-01],
        [ 2.0390e-08, -2.4764e-10, -1.2619e-09, -1.7897e-01, -1.0781e+00,
         -8.7449e-01],
        ...,
        [ 4.0126e-09,  1.4301e-08,  1.0117e-08,  1.2387e+00, -1.9595e-01,
         

reward: -3.1974, last reward: -3.6823, gradient norm:  2.301:   0%|▍                                                                                                                       | 69/19531 [00:33<2:31:56,  2.13it/s]

tensor([[ 8.4428e-09, -4.6385e-09, -2.7729e-08,  8.5349e-01, -8.6679e-01,
         -7.3537e-01],
        [ 7.9830e-08, -5.1020e-08,  5.8809e-09, -3.8896e-02, -5.1323e-02,
         -4.3046e+00],
        [-1.5287e-08,  6.4639e-09, -6.5085e-09, -2.2674e-01,  1.3311e-01,
         -2.2710e-01],
        ...,
        [-4.7133e-09,  4.0126e-08, -2.0113e-08,  1.4927e+00, -2.2051e-01,
          4.9739e-01],
        [ 1.1092e-08,  1.2037e-08,  1.1733e-09, -8.9229e-03, -5.3289e-03,
          1.4885e+00],
        [-4.2938e-09, -1.1356e-09, -3.9657e-08,  2.4309e-01,  6.4831e-01,
         -6.8321e-01]], device='cuda:0')
tensor([[ 2.2460e-08,  2.8069e-09, -2.4903e-08,  6.5428e-01, -7.9218e-01,
         -6.0511e-01],
        [ 6.5825e-08, -6.0833e-08,  1.0566e-08,  3.1282e-01,  3.2874e-01,
         -4.0507e+00],
        [-1.3060e-08,  2.4666e-08, -9.3858e-09, -4.0041e-01,  3.0530e-01,
          1.2485e-02],
        ...,
        [-4.2432e-08, -8.1880e-08,  4.9636e-08, -4.1024e+00,  9.9289e-01,
         

reward: -3.0716, last reward: -3.5635, gradient norm:  2.265:   0%|▍                                                                                                                       | 70/19531 [00:34<2:36:50,  2.07it/s]

tensor([[ 9.7974e-09,  2.2201e-08, -4.2108e-08,  8.1013e-01,  5.7215e-01,
          1.2578e+00],
        [-8.1604e-09,  3.1125e-08, -2.4710e-09, -5.5305e-01,  4.6355e-01,
         -6.4513e-01],
        [-1.8980e-08, -2.2484e-09, -5.1270e-09,  9.9047e-01, -6.5140e-01,
         -6.8220e-01],
        ...,
        [-2.2546e-08, -3.8591e-08,  2.2297e-08, -1.1965e+00, -1.0956e+00,
         -1.3863e-01],
        [ 1.4706e-08, -1.8375e-09,  5.7693e-09, -4.2634e-03,  4.2998e-02,
          2.4568e-01],
        [ 6.4777e-09, -1.4880e-08,  3.0834e-09,  2.5370e-01, -3.5244e-01,
         -1.2477e+00]], device='cuda:0')
tensor([[ 3.3641e-08, -1.6058e-07,  3.1432e-08, -2.6154e+00, -1.3861e+00,
         -3.2092e+00],
        [ 2.0636e-08,  2.1775e-08, -7.2704e-09, -3.3665e-01,  3.5349e-01,
         -4.4072e-01],
        [-3.1160e-08, -2.6951e-08, -3.1345e-08,  1.1141e+00, -1.1278e+00,
         -2.7597e-01],
        ...,
        [-1.8064e-08, -1.5712e-08,  2.7498e-08, -1.4797e+00, -1.0871e+00,
         

reward: -3.0318, last reward: -3.4822, gradient norm:  2.803:   0%|▍                                                                                                                       | 71/19531 [00:34<2:38:35,  2.05it/s]

tensor([[-1.8448e-08, -1.5087e-08, -6.6829e-09,  3.6454e-02,  1.1028e-01,
         -7.0452e-01],
        [ 4.6903e-08,  6.9329e-08, -2.1316e-08,  5.1497e-02,  2.5759e-01,
          2.3230e-01],
        [ 2.1320e-08,  1.1813e-08, -4.2555e-08, -2.2504e-01,  1.0654e+00,
          8.5366e-01],
        ...,
        [ 1.3204e-08, -6.1550e-08,  1.2035e-08,  1.0181e-01, -4.9738e-01,
         -8.3422e-01],
        [-2.1681e-09,  1.5346e-08,  1.1275e-09, -2.4562e-01, -4.8375e-01,
          2.2561e-01],
        [-6.5401e-09, -2.0786e-08,  1.9103e-08, -8.5502e-02, -1.3849e-01,
         -8.1700e-01]], device='cuda:0')
tensor([[ 7.3423e-09,  8.4787e-09, -1.1996e-08,  1.3218e-01,  2.9436e-01,
         -4.9540e-01],
        [ 3.7926e-08,  6.6964e-08, -2.2943e-08,  1.7820e-01,  5.1242e-01,
          4.9598e-01],
        [ 1.2817e-08,  6.3506e-09, -3.0640e-08, -3.7364e-01,  1.2066e+00,
          6.5955e-01],
        ...,
        [ 2.9885e-10, -6.3857e-08,  1.5311e-08,  6.6007e-02, -7.2299e-01,
         

reward: -3.1868, last reward: -3.6514, gradient norm:  2.509:   0%|▍                                                                                                                       | 72/19531 [00:35<2:36:41,  2.07it/s]

tensor([[ 4.3930e-08,  5.2220e-08,  2.3150e-09, -8.7666e-01, -5.0500e-01,
         -8.2778e-01],
        [ 6.8171e-09,  1.8094e-09, -4.7963e-09, -6.1459e-02,  1.7115e-01,
         -4.9437e-01],
        [-9.3986e-08, -9.9649e-08,  8.7382e-09, -7.4050e-01,  3.6344e-02,
         -3.7989e+00],
        ...,
        [ 3.0127e-08, -1.4250e-08,  1.3452e-09, -1.0761e-01,  4.5405e-02,
         -4.8642e-02],
        [-2.0415e-08, -1.0278e-07,  5.5558e-08, -1.2217e-02, -1.5098e-01,
          4.0067e-01],
        [-4.1303e-09, -2.3811e-08, -8.2229e-09,  3.8977e-01, -1.4551e+00,
         -1.5276e+00]], device='cuda:0')
tensor([[ 4.5271e-08,  8.7930e-09,  1.3896e-08, -5.7679e-01, -2.0050e-01,
         -5.1075e-01],
        [ 1.9697e-10,  1.4924e-08,  2.4824e-09, -9.1178e-02,  5.0377e-01,
         -1.6074e-01],
        [-1.2229e-08, -6.7193e-08, -2.0966e-09, -1.8200e-01,  3.2995e-02,
         -3.6172e+00],
        ...,
        [ 2.7230e-08, -1.9437e-08, -2.6421e-10,  2.0565e-01, -1.3476e-01,
         

reward: -3.1755, last reward: -3.6728, gradient norm:  2.412:   0%|▍                                                                                                                       | 73/19531 [00:35<2:36:22,  2.07it/s]

tensor([[ 4.6327e-09, -6.5627e-10, -1.6978e-08,  9.7814e-01,  4.9805e-01,
          2.4278e-01],
        [-1.5194e-08, -1.9115e-08,  6.4751e-09,  2.3500e-01, -9.8343e-02,
         -1.1759e+00],
        [-1.0153e-08, -1.3357e-08, -1.5859e-08,  1.4322e-01,  1.3437e+00,
          3.2803e-01],
        ...,
        [-3.1486e-08, -2.7132e-08, -3.5668e-08,  2.3580e+00, -3.1137e-01,
         -3.3000e+00],
        [-2.8884e-08,  8.2003e-08, -7.4957e-09, -6.6786e-02,  2.3951e-01,
          8.7797e-01],
        [ 4.5969e-08,  3.8569e-08, -3.2775e-08,  4.8916e-01,  7.3201e-01,
         -7.9329e-01]], device='cuda:0')
tensor([[ 2.4974e-09,  4.9765e-09, -7.2658e-09,  7.9791e-01,  2.8458e-01,
          4.9565e-01],
        [-9.0275e-09, -1.4576e-08,  5.4913e-09,  1.7425e-01, -7.9014e-02,
         -1.1199e+00],
        [-5.7049e-09, -1.4328e-08, -1.8618e-08,  2.3913e-01,  1.1293e+00,
          5.2732e-01],
        ...,
        [-4.1563e-08, -4.1303e-09, -3.0511e-08,  2.3905e+00, -4.1587e-01,
         

reward: -3.0314, last reward: -3.5351, gradient norm:  2.294:   0%|▍                                                                                                                       | 74/19531 [00:36<2:33:47,  2.11it/s]

tensor([[-7.4318e-08, -9.2531e-09,  3.6148e-08, -7.8171e-01, -3.8519e-01,
         -6.3242e-01],
        [ 3.2144e-09, -7.0891e-10,  5.5250e-09,  1.3564e-01,  1.6357e-01,
          5.9713e-01],
        [-1.2855e-08,  9.2535e-09,  1.7467e-09, -2.4271e-01, -1.0697e-01,
          1.2565e+00],
        ...,
        [ 1.7410e-08,  2.6675e-10,  1.1275e-08, -1.5543e-01, -8.1711e-01,
         -2.1522e-01],
        [ 1.0355e-08,  2.5280e-09,  1.8592e-08,  5.7962e-01,  5.6782e-01,
          3.2958e-01],
        [-3.7455e-09, -7.6079e-09, -2.2656e-08,  1.0281e+00, -2.0118e+00,
         -4.1411e-01]], device='cuda:0')
tensor([[-8.8223e-08, -1.2572e-08,  1.8419e-08, -4.5777e-01, -1.2488e-01,
         -2.8625e-01],
        [-6.8517e-09,  3.1062e-09,  5.6408e-09, -5.3866e-02, -4.9040e-02,
          8.7883e-01],
        [-1.2505e-08,  9.3904e-09,  4.3597e-09, -3.3847e-01, -1.3165e-01,
          1.3381e+00],
        ...,
        [ 4.5111e-08, -2.4101e-09,  8.1839e-09, -1.7739e-01, -5.2408e-01,
         

reward: -3.1352, last reward: -3.6447, gradient norm:  2.647:   0%|▍                                                                                                                       | 75/19531 [00:36<2:33:43,  2.11it/s]

tensor([[-4.4887e-08, -5.2130e-09,  9.7835e-09,  5.8175e-01,  2.9266e-01,
          1.1938e+00],
        [-8.2683e-08, -4.0556e-08,  1.4518e-08, -6.2930e-01, -1.0255e-01,
         -5.8842e-01],
        [ 3.4075e-08,  8.8492e-09,  3.4812e-08,  1.0019e-01, -5.6908e-02,
          4.9985e-01],
        ...,
        [ 4.8470e-08, -6.4667e-09,  1.0003e-08, -4.4675e-01, -4.4558e-01,
          1.1056e-01],
        [-1.9037e-08, -2.7527e-08, -3.0685e-09,  3.2127e-01,  2.7597e-01,
         -1.4677e+00],
        [-6.0172e-09,  1.6286e-08, -4.1358e-09, -3.3801e-02, -1.2031e-01,
          3.4286e-01]], device='cuda:0')
tensor([[-4.6363e-08,  1.7328e-09,  1.4004e-08,  8.5601e-01,  3.1567e-01,
          1.3676e+00],
        [-8.8639e-08, -1.4350e-08,  3.6458e-09, -2.0564e-01,  1.0226e-02,
         -1.8569e-01],
        [-1.9566e-08,  6.2750e-08,  3.7578e-08,  3.4009e-01, -2.6990e-01,
          7.9661e-01],
        ...,
        [-1.2096e-08, -2.0521e-08, -2.8800e-10, -2.5301e-01, -1.8191e-01,
         

reward: -3.1348, last reward: -3.6689, gradient norm:  2.471:   0%|▍                                                                                                                       | 76/19531 [00:37<2:32:55,  2.12it/s]

tensor([[-7.5198e-09, -9.8935e-10, -4.5051e-09,  6.4191e-02, -4.5866e-02,
         -2.2460e-01],
        [ 5.3236e-08,  8.6844e-08, -7.6869e-08,  4.8769e-01,  6.7156e-01,
          1.0102e+00],
        [-1.6978e-08, -1.5588e-09, -1.5715e-09, -1.3889e+00, -2.4261e-01,
          6.9101e-01],
        ...,
        [ 2.0201e-09,  8.8033e-09,  3.3495e-09,  6.2048e-01,  2.7540e-01,
          6.7196e-01],
        [-3.0575e-08, -8.4872e-09, -1.0843e-08,  5.0395e-02,  7.2485e-02,
         -1.0130e+00],
        [ 1.6359e-08, -3.2288e-08,  2.5516e-08, -1.4322e+00,  7.5317e-01,
         -5.9200e-01]], device='cuda:0')
tensor([[-1.9433e-10,  7.5930e-09, -5.9360e-09, -4.7918e-02,  3.9617e-02,
         -8.3810e-02],
        [ 4.4385e-08,  9.6138e-08, -6.8434e-08,  7.6436e-01,  8.1454e-01,
          1.1909e+00],
        [-2.4605e-09,  2.2626e-08, -1.3206e-08, -1.2257e+00, -9.3404e-02,
          8.9665e-01],
        ...,
        [ 8.8772e-10,  8.9009e-09,  1.7066e-09,  7.7797e-01,  4.5852e-01,
         

reward: -3.0941, last reward: -3.6002, gradient norm:  2.394:   0%|▍                                                                                                                       | 77/19531 [00:37<2:35:06,  2.09it/s]

tensor([[-3.4605e-09,  3.7187e-08,  5.0568e-08,  8.9527e-01, -1.7623e+00,
          6.6563e-01],
        [-9.0915e-09,  1.1043e-08, -3.6768e-08,  7.6342e-01,  1.1560e+00,
         -1.5891e-01],
        [-1.6584e-10, -9.6094e-10,  1.9313e-09, -9.1100e-01,  2.0587e-01,
         -3.0337e-01],
        ...,
        [ 1.5595e-09, -7.4875e-08, -8.6778e-08,  2.1896e-01,  1.0796e+00,
         -1.5697e+00],
        [-1.4778e-08,  5.8417e-09,  4.6177e-09,  4.6194e-01, -5.0487e-01,
          1.1640e+00],
        [ 3.1491e-08, -6.1541e-09,  1.3648e-09,  1.0485e+00, -1.0648e-01,
         -3.5339e-01]], device='cuda:0')
tensor([[-7.5719e-09,  3.6247e-08,  3.2106e-08,  5.3400e-01, -1.6239e+00,
          9.9532e-01],
        [ 2.0303e-08,  2.8687e-08, -4.3181e-08,  9.1836e-01,  1.2073e+00,
         -4.5687e-02],
        [ 2.4288e-11, -7.8451e-10,  2.3615e-09, -6.3688e-01,  4.5531e-02,
         -6.0804e-01],
        ...,
        [-3.5443e-09, -9.2261e-08, -8.6646e-08,  2.5077e-01,  6.5366e-01,
         

reward: -3.1436, last reward: -3.6440, gradient norm:  2.66:   0%|▍                                                                                                                        | 78/19531 [00:37<2:35:47,  2.08it/s]

tensor([[-4.5102e-08, -8.0306e-08,  1.9688e-08, -8.7916e-01, -5.3859e-01,
          5.0577e-02],
        [-2.2440e-08, -3.3030e-08,  4.0441e-08, -9.5766e-03, -7.7361e-01,
         -9.2221e-01],
        [-1.7511e-08,  9.5046e-10, -1.6804e-08, -8.8403e-01,  9.1718e-01,
         -3.4528e-01],
        ...,
        [-5.0307e-08, -7.6312e-09, -3.2844e-08, -9.0708e-01, -4.4639e-01,
         -4.8754e-01],
        [-3.8070e-09,  2.0281e-08,  2.1291e-08,  9.9681e-01, -5.4968e-01,
          1.0999e+00],
        [ 4.8739e-08,  5.0531e-08, -1.0480e-07,  1.8107e-01,  3.6487e-01,
         -6.7217e-01]], device='cuda:0')
tensor([[-1.9783e-08, -3.8905e-08,  6.5976e-09, -6.6628e-01, -2.7609e-01,
          3.5317e-01],
        [-1.9621e-08, -2.9902e-08,  2.9401e-08, -1.7151e-02, -7.5057e-01,
         -9.0489e-01],
        [-3.3376e-08,  1.6232e-08, -1.8257e-08, -9.0865e-01,  1.2177e+00,
         -1.1893e-01],
        ...,
        [-5.2426e-08, -2.5018e-09, -4.1533e-08, -1.0236e+00, -4.3811e-01,
         

reward: -3.1413, last reward: -3.6896, gradient norm:  3.309:   0%|▍                                                                                                                       | 79/19531 [00:38<2:35:40,  2.08it/s]

tensor([[ 4.8286e-08, -3.1762e-08, -4.7101e-08,  1.2913e+00, -4.7132e-01,
          2.4581e-01],
        [ 1.5729e-08,  1.0195e-07, -6.0896e-08, -3.4251e-01,  9.7948e-01,
         -2.3154e-01],
        [-5.4590e-08,  1.9303e-08, -4.6171e-11,  7.5231e-01,  5.0587e-02,
          1.2189e+00],
        ...,
        [ 4.0727e-08, -7.0970e-09,  2.1984e-08,  1.1788e-01, -9.3804e-01,
         -3.9647e-01],
        [ 2.9414e-08, -2.2769e-08, -9.6681e-09, -4.1948e-01,  1.0265e+00,
         -4.0367e+00],
        [-4.2625e-08,  3.7106e-08,  1.2305e-08,  1.2340e-01, -1.1403e+00,
         -1.4922e+00]], device='cuda:0')
tensor([[ 5.3060e-08, -1.3243e-08, -3.8340e-08,  1.4323e+00, -7.0126e-01,
          3.9359e-01],
        [ 4.9829e-08,  1.1554e-07, -7.2625e-08, -2.5078e-01,  1.3040e+00,
          3.2299e-02],
        [-3.6473e-08,  1.9215e-08, -2.2217e-09,  1.0304e+00, -5.2748e-02,
          1.3895e+00],
        ...,
        [ 3.9310e-08, -1.1063e-08,  2.0540e-08,  2.0937e-02, -7.4608e-01,
         

reward: -3.1108, last reward: -3.5739, gradient norm:  2.172:   0%|▍                                                                                                                       | 80/19531 [00:38<2:33:04,  2.12it/s]

tensor([[ 2.0162e-08,  1.0186e-08, -3.3619e-08, -4.5799e+00,  1.1006e-01,
         -5.3573e-01],
        [-3.2547e-08,  6.7229e-09,  1.0360e-08, -1.7059e-01,  6.0727e-02,
         -1.1077e+00],
        [ 1.5349e-08, -6.9255e-09,  1.0713e-08, -3.2769e-02, -1.9763e-02,
         -9.3466e-02],
        ...,
        [-1.2511e-08, -3.8146e-08,  7.4593e-09,  3.2006e-01,  1.4094e-01,
         -1.1774e+00],
        [-7.0881e-09,  2.6547e-09,  2.4261e-09,  2.8657e-02, -4.0497e-02,
          7.0509e-01],
        [ 3.1367e-08,  1.4192e-08,  2.0555e-08, -8.2221e-01, -8.6644e-01,
         -1.1437e-01]], device='cuda:0')
tensor([[ 5.4195e-09, -3.0811e-08, -1.0129e-07, -4.3415e+00,  4.7177e-01,
         -7.1084e-01],
        [-1.3758e-08,  3.5860e-08,  1.9293e-09,  1.6894e-01, -9.7453e-02,
         -7.4583e-01],
        [ 3.1920e-09, -2.1278e-08,  1.1420e-08,  1.7487e-01,  1.3942e-01,
         -3.5275e-01],
        ...,
        [-1.7541e-08, -3.9767e-08, -3.6958e-09,  7.5897e-01,  1.5484e-01,
         

reward: -3.0218, last reward: -3.5467, gradient norm:  2.74:   0%|▌                                                                                                                        | 81/19531 [00:39<2:35:29,  2.08it/s]

tensor([[ 2.1830e-08, -6.5349e-09, -5.0496e-09,  7.3919e-02,  1.0234e+00,
         -1.3347e+00],
        [ 1.7733e-09,  4.4188e-09, -1.7337e-08,  3.5771e-01,  1.4816e-01,
          1.6619e-01],
        [-4.1502e-08,  5.5624e-08, -7.4684e-09, -3.2067e-01,  1.9251e-01,
          8.2119e-02],
        ...,
        [ 1.5849e-08,  6.7570e-09,  1.3573e-08, -7.7127e-01, -1.1054e+00,
          6.6073e-01],
        [-7.4527e-08,  2.5216e-08,  2.2707e-08, -1.3398e+00, -6.9493e-01,
         -3.5770e+00],
        [ 4.5261e-09, -1.8627e-08,  6.3969e-11,  3.2638e-01,  2.2325e-01,
         -4.5588e-01]], device='cuda:0')
tensor([[ 2.9145e-08,  2.2066e-08, -1.5546e-08,  1.3001e-01,  1.1039e+00,
         -1.2306e+00],
        [ 2.0232e-09,  3.6540e-09, -1.9162e-08,  4.0729e-01,  1.8311e-01,
          1.0687e-01],
        [ 1.7758e-08,  1.1777e-08, -3.0138e-09, -4.2653e-02,  3.5921e-02,
          4.0151e-01],
        ...,
        [-3.9733e-09,  1.7832e-09,  9.8979e-10, -8.4808e-01, -1.1364e+00,
         

reward: -3.1463, last reward: -3.6880, gradient norm:  2.5:   0%|▌                                                                                                                         | 82/19531 [00:39<2:35:21,  2.09it/s]

tensor([[-4.1930e-08, -3.6826e-08,  5.9779e-09,  2.4283e-01, -7.2786e-02,
         -8.9190e-01],
        [ 9.3884e-09, -6.7815e-08,  8.4991e-09, -2.4727e-01, -1.1491e+00,
         -2.0609e+00],
        [ 2.7053e-08,  1.0643e-08,  1.2592e-08,  1.9882e-01, -8.0039e-01,
          6.3412e-01],
        ...,
        [ 1.1727e-08,  2.8125e-08, -2.3312e-09, -2.6469e-01, -1.3077e-01,
         -2.1637e-01],
        [ 5.1279e-09,  2.9462e-09,  6.9369e-09,  1.5336e-01, -4.8447e-01,
         -1.0705e+00],
        [-4.3306e-08,  2.5567e-08,  4.3599e-08,  1.3667e+00, -1.4204e+00,
         -2.4195e-01]], device='cuda:0')
tensor([[-1.0813e-08, -3.5069e-08,  1.1304e-08,  4.1835e-01, -1.7212e-01,
         -6.8663e-01],
        [ 7.2016e-09, -6.7202e-08,  6.5849e-09, -2.6696e-01, -1.0134e+00,
         -2.0170e+00],
        [ 1.3491e-08,  5.3500e-09,  6.3250e-09,  1.9950e-01, -7.9505e-01,
          6.3010e-01],
        ...,
        [ 2.9367e-08,  1.5773e-08, -1.5312e-09, -1.1983e-01, -4.7303e-02,
         

reward: -3.0945, last reward: -3.6127, gradient norm:  2.484:   0%|▌                                                                                                                       | 83/19531 [00:40<2:35:03,  2.09it/s]

tensor([[ 3.9004e-08, -2.7043e-08, -6.9655e-09, -7.4088e-01, -1.4777e-02,
         -3.6891e+00],
        [ 7.5216e-09,  1.3247e-08, -1.9450e-08,  1.0849e+00, -2.3482e-02,
         -9.4529e-01],
        [ 6.7748e-08, -8.1405e-08,  5.4358e-09, -2.4002e-01, -3.7934e-03,
          1.7754e-01],
        ...,
        [-3.6296e-09,  2.7221e-08, -5.0974e-09,  3.0237e-01,  1.2977e-01,
          9.0517e-01],
        [-1.8177e-08, -2.5787e-08, -7.5224e-09,  1.0644e+00, -5.2519e-01,
         -1.5263e-01],
        [-1.2091e-08, -2.7226e-08,  2.0517e-08, -3.5321e+00, -2.5390e-01,
         -1.7075e+00]], device='cuda:0')
tensor([[ 7.5074e-08, -5.4811e-08, -1.5762e-08, -8.0468e-01,  1.5586e-04,
         -3.6328e+00],
        [-1.1513e-08,  3.4580e-08, -1.9947e-08,  1.3886e+00, -2.7787e-01,
         -5.7558e-01],
        [ 6.0971e-09, -3.1719e-08,  1.0017e-08,  1.7667e-01, -3.4481e-02,
          5.9392e-01],
        ...,
        [-2.8894e-09,  2.7492e-08, -4.6028e-09,  2.5406e-01,  1.0087e-01,
         

reward: -3.1246, last reward: -3.6880, gradient norm:  2.857:   0%|▌                                                                                                                       | 84/19531 [00:40<2:33:00,  2.12it/s]