In [1]:

import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets import load_dataset
import wandb

In [3]:

import torch

In [4]:
from sae_lens import SAE

In [5]:
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7fe196b34690>

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else device

In [7]:
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_20/width_16k/canonical",
)
sae = sae.to(device)

In [8]:
def select_action(policy_net, observation):
    # Get activated logits from policy_net (B, dict_size)
    logits = policy_net(observation)
    # Add small random noise for diversity
    epsilon = torch.randn_like(logits) * 0.01
    logits_noisy = logits + epsilon
    # Select top-1 value and index from noisy logits
    topk_vals, topk_indices = torch.topk(logits_noisy, k=1, dim=-1)
    # Clamp top values to a minimum of 1 (values below 1 become 1)
    topk_vals = torch.clamp(topk_vals, min=50.0)
    # Create one-hot action vector with the clamped top value at the selected index
    action = torch.zeros_like(logits)
    action.scatter_(1, topk_indices, topk_vals)
    # Compute softmax probabilities from the original logits and get log probability of the chosen index
    probs = torch.softmax(logits, dim=-1)
    chosen_probs = torch.gather(probs, 1, topk_indices)
    log_prob = torch.log(chosen_probs).squeeze(1)
    return action, log_prob

In [9]:
class PolicyNetwork(nn.Module):
    def __init__(self, latent_dim, dict_size):
        super(PolicyNetwork, self).__init__()
        self.linear = nn.Linear(latent_dim, dict_size)
        self.activation = nn.Tanh()  # using Tanh instead of ReLU
    
    def forward(self, obs):
        logits = self.linear(obs.to(torch.float32))
        activated = self.activation(logits)
        return activated

In [10]:
class CriticNetwork(nn.Module):
    def __init__(self, latent_dim):
        super(CriticNetwork, self).__init__()
        self.fc1 = nn.Linear(latent_dim, latent_dim)
        self.act = nn.Tanh()  # using Tanh here
        self.fc2 = nn.Linear(latent_dim, 1)
        
    def forward(self, obs):
        x = self.act(self.fc1(obs))
        value = self.fc2(x)
        return value

In [11]:
def batch_steering_hook(policy_net, sae):
    class SteeringHook:
        def __init__(self, policy_net, sae):
            self.policy_net = policy_net
            self.sae = sae
            self.observation = None   # will be tensor of shape (B, latent_dim)
            self.action = None        # (B,)
            self.log_prob = None      # (B,)

        def __call__(self, module, inputs):
            residual = inputs[0]  # shape: (B, seq_len, hidden_dim)
            observation = residual[:, -1, :]  # (B, latent_dim)
            self.observation = observation.detach()
            action, log_prob = select_action(self.policy_net, observation)
            steering = sae.decode(action)  # shape: (B, hidden_dim)
            self.action = action.detach()
            self.log_prob = log_prob.detach()
            # Add the corresponding steering vector to the last token.
            residual[:, -1, :] = residual[:, -1, :] + steering
            return (residual)
    return SteeringHook(policy_net, sae)

In [12]:
class PPOTrainer:
    def __init__(self, policy, critic, batch_size=8, ppo_clip=0.2, lr=1e-4):
        self.policy = policy
        self.critic = critic
        self.batch_size = batch_size
        self.ppo_clip = ppo_clip
        self.optimizer_policy = optim.Adam(self.policy.parameters(), lr=lr)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr=lr)
    
    def compute_advantages(self, rewards, values):
        return rewards - values

    def train_step(self, observations, actions, rewards, old_log_probs):
        # Compute policy outputs.
        mean = self.policy(observations)  # shape: (B, action_dim)
        sigma = torch.ones_like(mean) * 0.1  # fixed sigma
        dist = torch.distributions.Normal(mean, sigma)
        new_log_probs = dist.log_prob(actions).sum(dim=-1)
        
        # Compute critic outputs.
        values = self.critic(observations).squeeze(-1)
        
        advantages = rewards - values  # simple advantage
        ratio = torch.exp(new_log_probs - old_log_probs)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1.0 - self.ppo_clip, 1.0 + self.ppo_clip) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()
        
        critic_loss = nn.MSELoss()(values, rewards)
        
        # Combine losses to perform a single backward pass.
        total_loss = policy_loss + critic_loss
        
        self.optimizer_policy.zero_grad()
        self.optimizer_critic.zero_grad()
        total_loss.backward()
        self.optimizer_policy.step()
        self.optimizer_critic.step()
        
        return policy_loss.item(), critic_loss.item()

In [13]:
class MMLUDataLoader:
    def __init__(self, dataset, split, limit=None):
        self.data = dataset[split]
        if limit is not None:
            self.data = self.data.select(range(limit))
        self.index = 0
        self.n_samples = len(self.data)
        
    def get_batch(self, batch_size):
        batch = []
        for _ in range(batch_size):
            sample = self.data[self.index % self.n_samples]
            self.index += 1
            batch.append({
                "question": sample["question"],
                "choices": sample["choices"],
                "answer": sample["answer"]
            })
        return batch

In [14]:
mmlu_dataset = load_dataset("cais/mmlu", "all")
mmlu_dataset = mmlu_dataset.shuffle(seed=42)
train_loader = MMLUDataLoader(mmlu_dataset, split="auxiliary_train")
val_loader = MMLUDataLoader(mmlu_dataset, split="validation", limit=100)

In [15]:
LATENT_DIM = 2304   # Gemma's latent dimension (from layer 20)
DICT_SIZE = 16384   # SAE dictionary size

In [16]:
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
llm = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", output_hidden_states=True, device_map="auto", torch_dtype=torch.bfloat16)
llm.eval()

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.80it/s]


Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemm

In [17]:
policy_net = PolicyNetwork(LATENT_DIM, DICT_SIZE).to(device)
critic_net = CriticNetwork(LATENT_DIM).to(device)
ppo_trainer = PPOTrainer(policy_net, critic_net, batch_size=8, ppo_clip=0.2, lr=1e-4)

In [None]:
from tqdm import tqdm
import os

def train(num_steps=20000, validate_every=100, checkpoint_every=10000, checkpoint_dir="./checkpoints"):
    wandb.init(project="gemma_mmlu_ppo")
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_val_accuracy = 0.0
    train_acc_sum = 0.0
    train_acc_count = 0
    for step in tqdm(range(1, num_steps + 1), desc="Training Steps"):
        batch = train_loader.get_batch(ppo_trainer.batch_size)
        prompts, correct_answers = [], []
        for sample in batch:
            question = sample["question"]
            choices = sample["choices"]
            if isinstance(sample["answer"], int):
                correct_answer = chr(65 + sample["answer"])
            else:
                correct_answer = sample["answer"].strip().upper()
            prompt = (question + "\n" +
                      "\n".join(f"{chr(65+i)}. {choice}" for i, choice in enumerate(choices)) +
                      "\nChoose one of the following options only: A, B, C, or D" +
                      "\nAnswer:")
            prompts.append(prompt)
            correct_answers.append(correct_answer)
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
        steering_hook = batch_steering_hook(policy_net, sae)
        hook_handle = llm.model.layers[20].register_forward_pre_hook(steering_hook)
        generated_ids = llm.generate(inputs, max_new_tokens=1)
        hook_handle.remove()
        batch_rewards = []
        for i in range(generated_ids.shape[0]):
            gen_tok = tokenizer.decode(generated_ids[i, -1]).strip()
            predicted_label = gen_tok[0].upper() if gen_tok else ""
            reward = 1 if predicted_label == correct_answers[i] else 0
            batch_rewards.append(reward)
        obs_batch = steering_hook.observation.to(torch.float32)
        action_batch = steering_hook.action.to(torch.float32)
        log_prob_batch = steering_hook.log_prob.to(torch.float32)
        rewards_tensor = torch.tensor(batch_rewards, dtype=torch.float32, device=device)
        nonzero_vals = action_batch[action_batch != 0]
        avg_activation = nonzero_vals.mean().item() if nonzero_vals.numel() > 0 else 0.0
        policy_loss, critic_loss = ppo_trainer.train_step(obs_batch, action_batch, rewards_tensor, log_prob_batch)
        train_accuracy = sum(batch_rewards) / len(batch_rewards)
        train_acc_sum += train_accuracy
        train_acc_count += 1
        wandb.log({"step": step, "policy_loss": policy_loss, "critic_loss": critic_loss, 
                   "train_accuracy": train_accuracy, "avg_activation": avg_activation})
        if step % validate_every == 0:
            used_indices = set()  # clear used_indices each validation step
            val_batch = val_loader.get_batch(val_loader.n_samples)
            val_prompts, correct_answers_val = [], []
            for sample in val_batch:
                question = sample["question"]
                choices = sample["choices"]
                if isinstance(sample["answer"], int):
                    correct_answer = chr(65 + sample["answer"])
                else:
                    correct_answer = sample["answer"].strip().upper()
                prompt = (question + "\n" +
                          "\n".join(f"{chr(65+i)}. {choice}" for i, choice in enumerate(choices)) +
                          "\nChoose one of the following options only: A, B, C, or D" +
                          "\nAnswer:")
                val_prompts.append(prompt)
                correct_answers_val.append(correct_answer)
            inputs_val = tokenizer(val_prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
            steering_hook_val = batch_steering_hook(policy_net, sae)
            hook_handle_val = llm.model.layers[20].register_forward_pre_hook(steering_hook_val)
            generated_ids_val = llm.generate(inputs_val, max_new_tokens=1)
            hook_handle_val.remove()
            total_correct = 0
            val_action_batch = steering_hook_val.action
            topk_vals_val, topk_indices_val = torch.topk(val_action_batch, k=1, dim=-1)
            for idx in topk_indices_val.view(-1).tolist():
                used_indices.add(idx)
            for i in range(generated_ids_val.shape[0]):
                gen_tok = tokenizer.decode(generated_ids_val[i, -1]).strip()
                predicted_label = gen_tok[0].upper() if gen_tok else ""
                if predicted_label == correct_answers_val[i]:
                    total_correct += 1
            val_accuracy = total_correct / len(val_batch)
            avg_train_accuracy = train_acc_sum / train_acc_count if train_acc_count > 0 else 0
            wandb.log({"step": step, "val_accuracy": val_accuracy, "avg_train_accuracy": avg_train_accuracy, 
                       "unique_indices": len(used_indices)})
            print(f"Step {step}: Policy Loss {policy_loss:.4f}, Critic Loss {critic_loss:.4f}, "
                  f"Avg Train Acc {avg_train_accuracy:.4f}, Val Acc {val_accuracy:.4f}, "
                  f"Unique Indices: {len(used_indices)}")
            train_acc_sum, train_acc_count = 0.0, 0
        if step % checkpoint_every == 0:
            checkpoint_path = os.path.join(
                checkpoint_dir,
                f"gemma-2-2b_layer20_ppo_lr{ppo_trainer.optimizer_policy.defaults['lr']}_batch{ppo_trainer.batch_size}_step{step}.pt"
            )
            torch.save({
                'step': step,
                'policy_state_dict': policy_net.state_dict(),
                'critic_state_dict': critic_net.state_dict(),
                'optimizer_policy_state_dict': ppo_trainer.optimizer_policy.state_dict(),
                'optimizer_critic_state_dict': ppo_trainer.optimizer_critic.state_dict()
            }, checkpoint_path)
        if step % validate_every == 0 and val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_checkpoint_path = os.path.join(checkpoint_dir, "best_policy.pt")
            torch.save({
                'step': step,
                'policy_state_dict': policy_net.state_dict(),
                'critic_state_dict': critic_net.state_dict(),
                'optimizer_policy_state_dict': ppo_trainer.optimizer_policy.state_dict(),
                'optimizer_critic_state_dict': ppo_trainer.optimizer_critic.state_dict()
            }, best_checkpoint_path)


In [19]:
train()

[34m[1mwandb[0m: Currently logged in as: [33mseonglae[0m ([33mtexonom[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Training Steps:   0%|          | 0/20000 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.
Training Steps:   0%|          | 99/20000 [00:36<1:57:12,  2.83it/s]

Step 100: Policy Loss 0.0700, Critic Loss 0.2164, Avg Train Acc 0.6800, Val Acc 0.5500, Unique Indices: 98


Training Steps:   1%|          | 199/20000 [01:21<2:23:19,  2.30it/s] 

Step 200: Policy Loss 0.0615, Critic Loss 0.1671, Avg Train Acc 0.7225, Val Acc 0.5600, Unique Indices: 98


Training Steps:   2%|▏         | 300/20000 [02:13<7:23:09,  1.35s/it] 

Step 300: Policy Loss 0.1097, Critic Loss 0.3177, Avg Train Acc 0.7037, Val Acc 0.5500, Unique Indices: 100


Training Steps:   2%|▏         | 400/20000 [02:53<7:24:54,  1.36s/it]

Step 400: Policy Loss 0.1678, Critic Loss 0.1663, Avg Train Acc 0.7013, Val Acc 0.5300, Unique Indices: 99


Training Steps:   2%|▎         | 500/20000 [03:33<7:10:15,  1.32s/it]

Step 500: Policy Loss 0.0632, Critic Loss 0.1955, Avg Train Acc 0.7125, Val Acc 0.5400, Unique Indices: 99


Training Steps:   3%|▎         | 600/20000 [04:15<7:18:41,  1.36s/it]

Step 600: Policy Loss 0.1359, Critic Loss 0.1128, Avg Train Acc 0.7100, Val Acc 0.5400, Unique Indices: 99


Training Steps:   4%|▎         | 700/20000 [04:55<7:46:26,  1.45s/it]

Step 700: Policy Loss 0.1918, Critic Loss 0.1532, Avg Train Acc 0.7137, Val Acc 0.5400, Unique Indices: 100


Training Steps:   4%|▍         | 800/20000 [05:36<7:07:21,  1.34s/it]

Step 800: Policy Loss 0.0721, Critic Loss 0.2030, Avg Train Acc 0.7063, Val Acc 0.5400, Unique Indices: 98


Training Steps:   4%|▍         | 900/20000 [06:17<7:19:55,  1.38s/it]

Step 900: Policy Loss 0.0898, Critic Loss 0.2524, Avg Train Acc 0.6800, Val Acc 0.5400, Unique Indices: 98


Training Steps:   5%|▍         | 999/20000 [06:54<2:01:16,  2.61it/s]

Step 1000: Policy Loss 0.1694, Critic Loss 0.2952, Avg Train Acc 0.7087, Val Acc 0.5500, Unique Indices: 99


Training Steps:   6%|▌         | 1100/20000 [07:43<6:55:44,  1.32s/it] 

Step 1100: Policy Loss 0.1674, Critic Loss 0.3697, Avg Train Acc 0.7000, Val Acc 0.5400, Unique Indices: 100


Training Steps:   6%|▌         | 1200/20000 [08:26<6:59:11,  1.34s/it]

Step 1200: Policy Loss 0.0649, Critic Loss 0.0550, Avg Train Acc 0.7000, Val Acc 0.5500, Unique Indices: 99


Training Steps:   6%|▋         | 1300/20000 [09:07<6:51:15,  1.32s/it]

Step 1300: Policy Loss 0.0511, Critic Loss 0.1357, Avg Train Acc 0.7262, Val Acc 0.5400, Unique Indices: 99


Training Steps:   7%|▋         | 1400/20000 [09:48<7:08:49,  1.38s/it]

Step 1400: Policy Loss 0.1495, Critic Loss 0.2742, Avg Train Acc 0.7200, Val Acc 0.5500, Unique Indices: 98


Training Steps:   8%|▊         | 1500/20000 [10:30<6:58:38,  1.36s/it]

Step 1500: Policy Loss 0.0561, Critic Loss 0.3372, Avg Train Acc 0.6987, Val Acc 0.5300, Unique Indices: 97


Training Steps:   8%|▊         | 1600/20000 [11:10<6:43:49,  1.32s/it]

Step 1600: Policy Loss 0.1422, Critic Loss 0.1601, Avg Train Acc 0.7113, Val Acc 0.5400, Unique Indices: 98


Training Steps:   8%|▊         | 1700/20000 [11:51<6:48:30,  1.34s/it]

Step 1700: Policy Loss 0.2365, Critic Loss 0.2893, Avg Train Acc 0.6725, Val Acc 0.5400, Unique Indices: 98


Training Steps:   9%|▉         | 1800/20000 [12:34<7:07:19,  1.41s/it]

Step 1800: Policy Loss 0.1722, Critic Loss 0.1809, Avg Train Acc 0.7225, Val Acc 0.5500, Unique Indices: 100


Training Steps:  10%|▉         | 1900/20000 [13:16<6:35:31,  1.31s/it]

Step 1900: Policy Loss -0.0000, Critic Loss 0.2344, Avg Train Acc 0.7025, Val Acc 0.5500, Unique Indices: 99


Training Steps:  10%|▉         | 1999/20000 [13:57<2:30:56,  1.99it/s]

Step 2000: Policy Loss 0.0764, Critic Loss 0.1354, Avg Train Acc 0.6900, Val Acc 0.5500, Unique Indices: 99


Training Steps:  10%|█         | 2100/20000 [14:46<6:31:27,  1.31s/it] 

Step 2100: Policy Loss 0.0196, Critic Loss 0.3916, Avg Train Acc 0.7025, Val Acc 0.5400, Unique Indices: 97


Training Steps:  11%|█         | 2200/20000 [15:27<6:31:55,  1.32s/it]

Step 2200: Policy Loss 0.0507, Critic Loss 0.1604, Avg Train Acc 0.6775, Val Acc 0.5400, Unique Indices: 98


Training Steps:  12%|█▏        | 2300/20000 [16:09<7:01:14,  1.43s/it]

Step 2300: Policy Loss 0.0471, Critic Loss 0.4179, Avg Train Acc 0.7238, Val Acc 0.5400, Unique Indices: 99


Training Steps:  12%|█▏        | 2400/20000 [16:51<6:26:48,  1.32s/it]

Step 2400: Policy Loss 0.1199, Critic Loss 0.1351, Avg Train Acc 0.7087, Val Acc 0.5300, Unique Indices: 100


Training Steps:  12%|█▎        | 2500/20000 [17:32<6:42:43,  1.38s/it]

Step 2500: Policy Loss 0.1036, Critic Loss 0.3240, Avg Train Acc 0.7087, Val Acc 0.5500, Unique Indices: 97


Training Steps:  13%|█▎        | 2600/20000 [18:15<6:32:29,  1.35s/it]

Step 2600: Policy Loss 0.0743, Critic Loss 0.1009, Avg Train Acc 0.7050, Val Acc 0.5400, Unique Indices: 97


Training Steps:  14%|█▎        | 2700/20000 [18:56<6:50:03,  1.42s/it]

Step 2700: Policy Loss 0.1505, Critic Loss 0.2602, Avg Train Acc 0.6987, Val Acc 0.5400, Unique Indices: 100


Training Steps:  14%|█▍        | 2800/20000 [19:38<6:25:56,  1.35s/it]

Step 2800: Policy Loss 0.0697, Critic Loss 0.2949, Avg Train Acc 0.7087, Val Acc 0.5400, Unique Indices: 99


Training Steps:  14%|█▍        | 2900/20000 [20:18<6:16:16,  1.32s/it]

Step 2900: Policy Loss 0.1101, Critic Loss 0.2973, Avg Train Acc 0.6975, Val Acc 0.5500, Unique Indices: 100


Training Steps:  15%|█▍        | 2999/20000 [20:56<1:42:36,  2.76it/s]

Step 3000: Policy Loss 0.0856, Critic Loss 0.2315, Avg Train Acc 0.6925, Val Acc 0.5400, Unique Indices: 99


Training Steps:  16%|█▌        | 3100/20000 [21:48<6:26:35,  1.37s/it] 

Step 3100: Policy Loss 0.1376, Critic Loss 0.2883, Avg Train Acc 0.7075, Val Acc 0.5500, Unique Indices: 98


Training Steps:  16%|█▌        | 3200/20000 [22:29<6:16:17,  1.34s/it]

Step 3200: Policy Loss 0.1706, Critic Loss 0.1183, Avg Train Acc 0.7100, Val Acc 0.5500, Unique Indices: 100


Training Steps:  16%|█▋        | 3300/20000 [23:13<6:32:26,  1.41s/it]

Step 3300: Policy Loss 0.0396, Critic Loss 0.1492, Avg Train Acc 0.7037, Val Acc 0.5500, Unique Indices: 98


Training Steps:  17%|█▋        | 3400/20000 [23:55<6:16:00,  1.36s/it]

Step 3400: Policy Loss 0.2126, Critic Loss 0.2540, Avg Train Acc 0.7150, Val Acc 0.5500, Unique Indices: 98


Training Steps:  18%|█▊        | 3500/20000 [24:37<6:19:42,  1.38s/it]

Step 3500: Policy Loss -0.0000, Critic Loss 0.5584, Avg Train Acc 0.6837, Val Acc 0.5500, Unique Indices: 100


Training Steps:  18%|█▊        | 3600/20000 [25:19<6:22:56,  1.40s/it]

Step 3600: Policy Loss 0.0147, Critic Loss 0.2224, Avg Train Acc 0.7400, Val Acc 0.5300, Unique Indices: 100


Training Steps:  18%|█▊        | 3700/20000 [26:01<6:13:24,  1.37s/it]

Step 3700: Policy Loss 0.0895, Critic Loss 0.1331, Avg Train Acc 0.7113, Val Acc 0.5600, Unique Indices: 98


Training Steps:  19%|█▉        | 3800/20000 [26:43<6:11:53,  1.38s/it]

Step 3800: Policy Loss 0.0042, Critic Loss 0.5440, Avg Train Acc 0.7037, Val Acc 0.5300, Unique Indices: 100


Training Steps:  20%|█▉        | 3900/20000 [27:26<6:01:09,  1.35s/it]

Step 3900: Policy Loss 0.0840, Critic Loss 0.0654, Avg Train Acc 0.7212, Val Acc 0.5400, Unique Indices: 100


Training Steps:  20%|█▉        | 3999/20000 [28:05<1:53:03,  2.36it/s]

Step 4000: Policy Loss 0.1891, Critic Loss 0.2555, Avg Train Acc 0.7087, Val Acc 0.5500, Unique Indices: 99


Training Steps:  20%|██        | 4100/20000 [28:56<5:51:16,  1.33s/it] 

Step 4100: Policy Loss 0.1419, Critic Loss 0.2184, Avg Train Acc 0.6925, Val Acc 0.5400, Unique Indices: 100


Training Steps:  21%|██        | 4200/20000 [29:38<6:01:59,  1.37s/it]

Step 4200: Policy Loss 0.2589, Critic Loss 0.2626, Avg Train Acc 0.7075, Val Acc 0.5400, Unique Indices: 99


Training Steps:  22%|██▏       | 4300/20000 [30:20<6:17:33,  1.44s/it]

Step 4300: Policy Loss 0.0817, Critic Loss 0.1889, Avg Train Acc 0.7125, Val Acc 0.5400, Unique Indices: 98


Training Steps:  22%|██▏       | 4400/20000 [31:02<5:52:03,  1.35s/it]

Step 4400: Policy Loss 0.0726, Critic Loss 0.1629, Avg Train Acc 0.6837, Val Acc 0.5300, Unique Indices: 98


Training Steps:  22%|██▎       | 4500/20000 [31:44<5:44:01,  1.33s/it]

Step 4500: Policy Loss -0.0000, Critic Loss 0.0957, Avg Train Acc 0.7075, Val Acc 0.5300, Unique Indices: 99


Training Steps:  23%|██▎       | 4600/20000 [32:26<5:45:25,  1.35s/it]

Step 4600: Policy Loss 0.0543, Critic Loss 0.1613, Avg Train Acc 0.6937, Val Acc 0.5400, Unique Indices: 98


Training Steps:  24%|██▎       | 4700/20000 [33:07<6:09:08,  1.45s/it]

Step 4700: Policy Loss -0.0000, Critic Loss 0.2214, Avg Train Acc 0.7075, Val Acc 0.5500, Unique Indices: 100


Training Steps:  24%|██▍       | 4800/20000 [33:50<6:11:33,  1.47s/it]

Step 4800: Policy Loss 0.0874, Critic Loss 0.2731, Avg Train Acc 0.7188, Val Acc 0.5600, Unique Indices: 100


Training Steps:  24%|██▍       | 4900/20000 [34:30<5:42:56,  1.36s/it]

Step 4900: Policy Loss 0.2129, Critic Loss 0.2452, Avg Train Acc 0.7087, Val Acc 0.5400, Unique Indices: 100


Training Steps:  25%|██▍       | 4999/20000 [35:08<1:21:12,  3.08it/s]

Step 5000: Policy Loss 0.1744, Critic Loss 0.1861, Avg Train Acc 0.7063, Val Acc 0.5400, Unique Indices: 100


Training Steps:  26%|██▌       | 5100/20000 [35:58<5:45:50,  1.39s/it] 

Step 5100: Policy Loss 0.0494, Critic Loss 0.0505, Avg Train Acc 0.7238, Val Acc 0.5400, Unique Indices: 100


Training Steps:  26%|██▌       | 5200/20000 [36:40<5:37:27,  1.37s/it]

Step 5200: Policy Loss 0.2084, Critic Loss 0.1203, Avg Train Acc 0.6925, Val Acc 0.5500, Unique Indices: 99


Training Steps:  26%|██▋       | 5300/20000 [37:23<5:39:50,  1.39s/it]

Step 5300: Policy Loss 0.0312, Critic Loss 0.2146, Avg Train Acc 0.7075, Val Acc 0.5500, Unique Indices: 98


Training Steps:  27%|██▋       | 5400/20000 [38:04<5:32:33,  1.37s/it]

Step 5400: Policy Loss 0.0104, Critic Loss 0.0481, Avg Train Acc 0.7113, Val Acc 0.5500, Unique Indices: 99


Training Steps:  28%|██▊       | 5500/20000 [38:47<5:42:47,  1.42s/it]

Step 5500: Policy Loss 0.1531, Critic Loss 0.2636, Avg Train Acc 0.7087, Val Acc 0.5400, Unique Indices: 98


Training Steps:  28%|██▊       | 5600/20000 [39:30<5:32:49,  1.39s/it]

Step 5600: Policy Loss 0.0729, Critic Loss 0.0576, Avg Train Acc 0.7125, Val Acc 0.5500, Unique Indices: 99


Training Steps:  28%|██▊       | 5700/20000 [40:12<5:30:33,  1.39s/it]

Step 5700: Policy Loss 0.1203, Critic Loss 0.2846, Avg Train Acc 0.6850, Val Acc 0.5500, Unique Indices: 100


Training Steps:  29%|██▉       | 5800/20000 [40:53<5:10:39,  1.31s/it]

Step 5800: Policy Loss 0.1287, Critic Loss 0.3268, Avg Train Acc 0.7100, Val Acc 0.5500, Unique Indices: 96


Training Steps:  30%|██▉       | 5900/20000 [41:36<5:21:25,  1.37s/it]

Step 5900: Policy Loss 0.1272, Critic Loss 0.2097, Avg Train Acc 0.7238, Val Acc 0.5400, Unique Indices: 100


Training Steps:  30%|██▉       | 5999/20000 [42:15<1:28:09,  2.65it/s]

Step 6000: Policy Loss 0.2462, Critic Loss 0.2086, Avg Train Acc 0.7013, Val Acc 0.5500, Unique Indices: 99


Training Steps:  30%|███       | 6100/20000 [43:06<5:18:22,  1.37s/it] 

Step 6100: Policy Loss 0.3567, Critic Loss 0.4224, Avg Train Acc 0.7013, Val Acc 0.5500, Unique Indices: 100


Training Steps:  31%|███       | 6200/20000 [43:48<5:19:56,  1.39s/it]

Step 6200: Policy Loss 0.0505, Critic Loss 0.1718, Avg Train Acc 0.7163, Val Acc 0.5300, Unique Indices: 99


Training Steps:  32%|███▏      | 6300/20000 [44:31<5:06:38,  1.34s/it]

Step 6300: Policy Loss 0.0972, Critic Loss 0.1619, Avg Train Acc 0.7338, Val Acc 0.5400, Unique Indices: 99


Training Steps:  32%|███▏      | 6400/20000 [45:12<5:10:15,  1.37s/it]

Step 6400: Policy Loss 0.0133, Critic Loss 0.1926, Avg Train Acc 0.7188, Val Acc 0.5400, Unique Indices: 99


Training Steps:  32%|███▎      | 6500/20000 [45:55<5:20:47,  1.43s/it]

Step 6500: Policy Loss 0.0613, Critic Loss 0.0747, Avg Train Acc 0.7175, Val Acc 0.5400, Unique Indices: 96


Training Steps:  33%|███▎      | 6600/20000 [46:39<5:18:09,  1.42s/it]

Step 6600: Policy Loss 0.1214, Critic Loss 0.3144, Avg Train Acc 0.6875, Val Acc 0.5500, Unique Indices: 100


Training Steps:  34%|███▎      | 6700/20000 [47:22<4:53:45,  1.33s/it]

Step 6700: Policy Loss 0.0929, Critic Loss 0.2339, Avg Train Acc 0.6950, Val Acc 0.5600, Unique Indices: 99


Training Steps:  34%|███▍      | 6801/20000 [48:05<3:45:08,  1.02s/it]

Step 6800: Policy Loss 0.0721, Critic Loss 0.2469, Avg Train Acc 0.6950, Val Acc 0.5200, Unique Indices: 97


Training Steps:  34%|███▍      | 6900/20000 [48:45<5:00:22,  1.38s/it]

Step 6900: Policy Loss 0.0163, Critic Loss 0.1607, Avg Train Acc 0.6863, Val Acc 0.5400, Unique Indices: 99


Training Steps:  35%|███▍      | 6999/20000 [49:23<1:25:57,  2.52it/s]

Step 7000: Policy Loss 0.0499, Critic Loss 0.2656, Avg Train Acc 0.6875, Val Acc 0.5400, Unique Indices: 100


Training Steps:  36%|███▌      | 7100/20000 [50:14<5:01:45,  1.40s/it]

Step 7100: Policy Loss 0.0487, Critic Loss 0.0680, Avg Train Acc 0.6713, Val Acc 0.5500, Unique Indices: 100


Training Steps:  36%|███▌      | 7200/20000 [50:55<4:49:13,  1.36s/it]

Step 7200: Policy Loss -0.0000, Critic Loss 0.2217, Avg Train Acc 0.6975, Val Acc 0.5500, Unique Indices: 96


Training Steps:  36%|███▋      | 7300/20000 [51:37<4:47:19,  1.36s/it]

Step 7300: Policy Loss 0.0447, Critic Loss 0.0808, Avg Train Acc 0.7275, Val Acc 0.5400, Unique Indices: 98


Training Steps:  37%|███▋      | 7400/20000 [52:18<4:36:44,  1.32s/it]

Step 7400: Policy Loss 0.0665, Critic Loss 0.3210, Avg Train Acc 0.7025, Val Acc 0.5400, Unique Indices: 100


Training Steps:  38%|███▊      | 7500/20000 [52:59<4:35:10,  1.32s/it]

Step 7500: Policy Loss 0.0804, Critic Loss 0.0797, Avg Train Acc 0.7388, Val Acc 0.5500, Unique Indices: 99


Training Steps:  38%|███▊      | 7600/20000 [53:40<4:43:45,  1.37s/it]

Step 7600: Policy Loss 0.0123, Critic Loss 0.2126, Avg Train Acc 0.6813, Val Acc 0.5400, Unique Indices: 97


Training Steps:  38%|███▊      | 7700/20000 [54:22<4:36:22,  1.35s/it]

Step 7700: Policy Loss -0.0000, Critic Loss 0.2247, Avg Train Acc 0.7300, Val Acc 0.5500, Unique Indices: 99


Training Steps:  39%|███▉      | 7800/20000 [55:04<4:37:50,  1.37s/it]

Step 7800: Policy Loss 0.0661, Critic Loss 0.1162, Avg Train Acc 0.7312, Val Acc 0.5400, Unique Indices: 100


Training Steps:  40%|███▉      | 7900/20000 [55:46<4:44:36,  1.41s/it]

Step 7900: Policy Loss 0.1831, Critic Loss 0.2076, Avg Train Acc 0.7113, Val Acc 0.5400, Unique Indices: 99


Training Steps:  40%|███▉      | 7999/20000 [56:26<1:29:19,  2.24it/s]

Step 8000: Policy Loss 0.1610, Critic Loss 0.1022, Avg Train Acc 0.7125, Val Acc 0.5400, Unique Indices: 99


Training Steps:  40%|████      | 8100/20000 [57:16<4:23:49,  1.33s/it]

Step 8100: Policy Loss 0.1382, Critic Loss 0.2318, Avg Train Acc 0.7087, Val Acc 0.5400, Unique Indices: 99


Training Steps:  41%|████      | 8200/20000 [57:59<4:33:49,  1.39s/it]

Step 8200: Policy Loss 0.0265, Critic Loss 0.1473, Avg Train Acc 0.7013, Val Acc 0.5500, Unique Indices: 99


Training Steps:  42%|████▏     | 8300/20000 [58:41<4:32:16,  1.40s/it]

Step 8300: Policy Loss 0.2250, Critic Loss 0.2451, Avg Train Acc 0.7063, Val Acc 0.5300, Unique Indices: 98


Training Steps:  42%|████▏     | 8400/20000 [59:23<4:20:07,  1.35s/it]

Step 8400: Policy Loss 0.0205, Critic Loss 0.0963, Avg Train Acc 0.7063, Val Acc 0.5500, Unique Indices: 99


Training Steps:  42%|████▎     | 8500/20000 [1:00:04<4:19:18,  1.35s/it]

Step 8500: Policy Loss 0.3003, Critic Loss 0.2251, Avg Train Acc 0.7300, Val Acc 0.5500, Unique Indices: 96


Training Steps:  43%|████▎     | 8600/20000 [1:00:46<4:13:47,  1.34s/it]

Step 8600: Policy Loss 0.0447, Critic Loss 0.2238, Avg Train Acc 0.6925, Val Acc 0.5500, Unique Indices: 100


Training Steps:  44%|████▎     | 8700/20000 [1:01:27<4:12:30,  1.34s/it]

Step 8700: Policy Loss 0.0643, Critic Loss 0.1568, Avg Train Acc 0.6800, Val Acc 0.5500, Unique Indices: 99


Training Steps:  44%|████▍     | 8800/20000 [1:02:10<4:12:04,  1.35s/it]

Step 8800: Policy Loss 0.0532, Critic Loss 0.0737, Avg Train Acc 0.7050, Val Acc 0.5300, Unique Indices: 98


Training Steps:  44%|████▍     | 8900/20000 [1:02:51<4:12:00,  1.36s/it]

Step 8900: Policy Loss 0.0375, Critic Loss 0.1722, Avg Train Acc 0.6887, Val Acc 0.5400, Unique Indices: 100


Training Steps:  45%|████▍     | 8999/20000 [1:03:30<1:09:15,  2.65it/s]

Step 9000: Policy Loss 0.0883, Critic Loss 0.2223, Avg Train Acc 0.7288, Val Acc 0.5300, Unique Indices: 100


Training Steps:  46%|████▌     | 9100/20000 [1:04:22<4:08:17,  1.37s/it]

Step 9100: Policy Loss 0.0635, Critic Loss 0.1557, Avg Train Acc 0.7200, Val Acc 0.5500, Unique Indices: 100


Training Steps:  46%|████▌     | 9200/20000 [1:05:03<3:59:07,  1.33s/it]

Step 9200: Policy Loss 0.0506, Critic Loss 0.1391, Avg Train Acc 0.7150, Val Acc 0.5400, Unique Indices: 99


Training Steps:  46%|████▋     | 9300/20000 [1:05:43<3:58:03,  1.33s/it]

Step 9300: Policy Loss 0.1216, Critic Loss 0.1154, Avg Train Acc 0.7312, Val Acc 0.5500, Unique Indices: 100


Training Steps:  47%|████▋     | 9400/20000 [1:06:25<4:01:18,  1.37s/it]

Step 9400: Policy Loss -0.0000, Critic Loss 0.3529, Avg Train Acc 0.7137, Val Acc 0.5300, Unique Indices: 100


Training Steps:  48%|████▊     | 9500/20000 [1:07:08<3:56:26,  1.35s/it]

Step 9500: Policy Loss 0.0791, Critic Loss 0.1557, Avg Train Acc 0.6850, Val Acc 0.5500, Unique Indices: 99


Training Steps:  48%|████▊     | 9600/20000 [1:07:50<3:49:25,  1.32s/it]

Step 9600: Policy Loss 0.1081, Critic Loss 0.0842, Avg Train Acc 0.7063, Val Acc 0.5500, Unique Indices: 98


Training Steps:  48%|████▊     | 9700/20000 [1:08:31<3:47:19,  1.32s/it]

Step 9700: Policy Loss 0.1095, Critic Loss 0.2435, Avg Train Acc 0.7075, Val Acc 0.5400, Unique Indices: 97


Training Steps:  49%|████▉     | 9800/20000 [1:09:12<3:46:04,  1.33s/it]

Step 9800: Policy Loss 0.0702, Critic Loss 0.4119, Avg Train Acc 0.7100, Val Acc 0.5500, Unique Indices: 98


Training Steps:  50%|████▉     | 9900/20000 [1:09:54<3:47:47,  1.35s/it]

Step 9900: Policy Loss 0.2213, Critic Loss 0.2473, Avg Train Acc 0.7137, Val Acc 0.5600, Unique Indices: 100


Training Steps:  50%|████▉     | 9999/20000 [1:10:32<1:08:16,  2.44it/s]

Step 10000: Policy Loss 0.0516, Critic Loss 0.2655, Avg Train Acc 0.6887, Val Acc 0.5300, Unique Indices: 99


Training Steps:  50%|█████     | 10100/20000 [1:11:24<3:42:03,  1.35s/it]

Step 10100: Policy Loss 0.0788, Critic Loss 0.2234, Avg Train Acc 0.7013, Val Acc 0.5400, Unique Indices: 99


Training Steps:  51%|█████     | 10200/20000 [1:12:07<3:38:43,  1.34s/it]

Step 10200: Policy Loss 0.0779, Critic Loss 0.0870, Avg Train Acc 0.7188, Val Acc 0.5400, Unique Indices: 98


Training Steps:  52%|█████▏    | 10300/20000 [1:12:49<3:36:40,  1.34s/it]

Step 10300: Policy Loss 0.0233, Critic Loss 0.1400, Avg Train Acc 0.7238, Val Acc 0.5400, Unique Indices: 100


Training Steps:  52%|█████▏    | 10400/20000 [1:13:31<3:36:25,  1.35s/it]

Step 10400: Policy Loss 0.1789, Critic Loss 0.1138, Avg Train Acc 0.7188, Val Acc 0.5400, Unique Indices: 99


Training Steps:  52%|█████▎    | 10500/20000 [1:14:12<3:30:31,  1.33s/it]

Step 10500: Policy Loss 0.0615, Critic Loss 0.1466, Avg Train Acc 0.6950, Val Acc 0.5400, Unique Indices: 99


Training Steps:  53%|█████▎    | 10600/20000 [1:14:54<3:28:28,  1.33s/it]

Step 10600: Policy Loss 0.2021, Critic Loss 0.2113, Avg Train Acc 0.6937, Val Acc 0.5500, Unique Indices: 100


Training Steps:  54%|█████▎    | 10700/20000 [1:15:35<3:24:01,  1.32s/it]

Step 10700: Policy Loss 0.0398, Critic Loss 0.1457, Avg Train Acc 0.6963, Val Acc 0.5500, Unique Indices: 97


Training Steps:  54%|█████▍    | 10800/20000 [1:16:16<3:25:51,  1.34s/it]

Step 10800: Policy Loss 0.0788, Critic Loss 0.0607, Avg Train Acc 0.7188, Val Acc 0.5200, Unique Indices: 98


Training Steps:  55%|█████▍    | 10900/20000 [1:16:57<3:26:51,  1.36s/it]

Step 10900: Policy Loss 0.0222, Critic Loss 0.2309, Avg Train Acc 0.7288, Val Acc 0.5500, Unique Indices: 98


Training Steps:  55%|█████▍    | 10999/20000 [1:17:34<54:16,  2.76it/s]  

Step 11000: Policy Loss 0.2414, Critic Loss 0.3150, Avg Train Acc 0.7100, Val Acc 0.5500, Unique Indices: 99


Training Steps:  56%|█████▌    | 11100/20000 [1:18:22<3:17:43,  1.33s/it]

Step 11100: Policy Loss -0.0000, Critic Loss 0.0717, Avg Train Acc 0.7000, Val Acc 0.5400, Unique Indices: 99


Training Steps:  56%|█████▌    | 11200/20000 [1:19:04<3:14:57,  1.33s/it]

Step 11200: Policy Loss 0.0958, Critic Loss 0.1176, Avg Train Acc 0.7238, Val Acc 0.5600, Unique Indices: 100


Training Steps:  56%|█████▋    | 11300/20000 [1:19:44<3:15:16,  1.35s/it]

Step 11300: Policy Loss 0.1208, Critic Loss 0.1740, Avg Train Acc 0.7262, Val Acc 0.5400, Unique Indices: 100


Training Steps:  57%|█████▋    | 11400/20000 [1:20:24<3:09:47,  1.32s/it]

Step 11400: Policy Loss 0.0944, Critic Loss 0.1505, Avg Train Acc 0.7163, Val Acc 0.5300, Unique Indices: 100


Training Steps:  57%|█████▊    | 11500/20000 [1:21:06<3:08:35,  1.33s/it]

Step 11500: Policy Loss -0.0000, Critic Loss 0.2596, Avg Train Acc 0.7000, Val Acc 0.5400, Unique Indices: 99


Training Steps:  58%|█████▊    | 11600/20000 [1:21:48<3:03:39,  1.31s/it]

Step 11600: Policy Loss 0.1142, Critic Loss 0.1899, Avg Train Acc 0.7262, Val Acc 0.5300, Unique Indices: 98


Training Steps:  58%|█████▊    | 11700/20000 [1:22:30<3:06:44,  1.35s/it]

Step 11700: Policy Loss -0.0000, Critic Loss 0.1791, Avg Train Acc 0.6875, Val Acc 0.5400, Unique Indices: 100


Training Steps:  59%|█████▉    | 11800/20000 [1:23:11<3:00:36,  1.32s/it]

Step 11800: Policy Loss 0.0122, Critic Loss 0.1133, Avg Train Acc 0.6800, Val Acc 0.5600, Unique Indices: 100


Training Steps:  60%|█████▉    | 11900/20000 [1:23:53<3:01:23,  1.34s/it]

Step 11900: Policy Loss 0.0568, Critic Loss 0.1635, Avg Train Acc 0.6950, Val Acc 0.5400, Unique Indices: 99


Training Steps:  60%|█████▉    | 11999/20000 [1:24:33<51:47,  2.57it/s]  

Step 12000: Policy Loss 0.1019, Critic Loss 0.2613, Avg Train Acc 0.6813, Val Acc 0.5500, Unique Indices: 99


Training Steps:  60%|██████    | 12100/20000 [1:25:23<3:01:08,  1.38s/it]

Step 12100: Policy Loss 0.2464, Critic Loss 0.2849, Avg Train Acc 0.7013, Val Acc 0.5300, Unique Indices: 99


Training Steps:  61%|██████    | 12200/20000 [1:26:07<2:53:49,  1.34s/it]

Step 12200: Policy Loss 0.1015, Critic Loss 0.0912, Avg Train Acc 0.7200, Val Acc 0.5200, Unique Indices: 97


Training Steps:  62%|██████▏   | 12300/20000 [1:26:48<2:49:54,  1.32s/it]

Step 12300: Policy Loss -0.0000, Critic Loss 0.4314, Avg Train Acc 0.6913, Val Acc 0.5400, Unique Indices: 99


Training Steps:  62%|██████▏   | 12400/20000 [1:27:30<3:02:17,  1.44s/it]

Step 12400: Policy Loss 0.0407, Critic Loss 0.0896, Avg Train Acc 0.7075, Val Acc 0.5400, Unique Indices: 97


Training Steps:  62%|██████▎   | 12500/20000 [1:28:11<2:46:09,  1.33s/it]

Step 12500: Policy Loss 0.2051, Critic Loss 0.2061, Avg Train Acc 0.6787, Val Acc 0.5400, Unique Indices: 99


Training Steps:  63%|██████▎   | 12600/20000 [1:28:52<2:45:12,  1.34s/it]

Step 12600: Policy Loss 0.0819, Critic Loss 0.1724, Avg Train Acc 0.6975, Val Acc 0.5500, Unique Indices: 99


Training Steps:  64%|██████▎   | 12700/20000 [1:29:33<2:42:18,  1.33s/it]

Step 12700: Policy Loss 0.0450, Critic Loss 0.2125, Avg Train Acc 0.7087, Val Acc 0.5500, Unique Indices: 96


Training Steps:  64%|██████▍   | 12800/20000 [1:30:15<2:40:18,  1.34s/it]

Step 12800: Policy Loss 0.0062, Critic Loss 0.0662, Avg Train Acc 0.6837, Val Acc 0.5400, Unique Indices: 99


Training Steps:  64%|██████▍   | 12900/20000 [1:30:56<2:43:44,  1.38s/it]

Step 12900: Policy Loss 0.1674, Critic Loss 0.1975, Avg Train Acc 0.7175, Val Acc 0.5300, Unique Indices: 98


Training Steps:  65%|██████▍   | 12999/20000 [1:31:34<51:22,  2.27it/s]  

Step 13000: Policy Loss 0.0953, Critic Loss 0.0995, Avg Train Acc 0.7137, Val Acc 0.5400, Unique Indices: 99


Training Steps:  66%|██████▌   | 13100/20000 [1:32:23<2:38:42,  1.38s/it]

Step 13100: Policy Loss 0.1702, Critic Loss 0.1306, Avg Train Acc 0.7050, Val Acc 0.5400, Unique Indices: 100


Training Steps:  66%|██████▌   | 13200/20000 [1:33:03<2:29:52,  1.32s/it]

Step 13200: Policy Loss 0.2086, Critic Loss 0.2262, Avg Train Acc 0.7075, Val Acc 0.5500, Unique Indices: 99


Training Steps:  66%|██████▋   | 13300/20000 [1:33:44<2:31:32,  1.36s/it]

Step 13300: Policy Loss 0.0158, Critic Loss 0.1438, Avg Train Acc 0.7163, Val Acc 0.5500, Unique Indices: 99


Training Steps:  67%|██████▋   | 13400/20000 [1:34:25<2:28:42,  1.35s/it]

Step 13400: Policy Loss 0.0685, Critic Loss 0.1167, Avg Train Acc 0.6737, Val Acc 0.5500, Unique Indices: 100


Training Steps:  68%|██████▊   | 13500/20000 [1:35:05<2:24:12,  1.33s/it]

Step 13500: Policy Loss 0.1062, Critic Loss 0.1566, Avg Train Acc 0.7150, Val Acc 0.5500, Unique Indices: 99


Training Steps:  68%|██████▊   | 13600/20000 [1:35:46<2:19:48,  1.31s/it]

Step 13600: Policy Loss 0.1043, Critic Loss 0.2380, Avg Train Acc 0.6937, Val Acc 0.5400, Unique Indices: 98


Training Steps:  68%|██████▊   | 13700/20000 [1:36:29<2:28:03,  1.41s/it]

Step 13700: Policy Loss 0.1105, Critic Loss 0.1216, Avg Train Acc 0.7150, Val Acc 0.5500, Unique Indices: 97


Training Steps:  69%|██████▉   | 13800/20000 [1:37:10<2:19:58,  1.35s/it]

Step 13800: Policy Loss 0.0677, Critic Loss 0.2352, Avg Train Acc 0.7100, Val Acc 0.5400, Unique Indices: 100


Training Steps:  70%|██████▉   | 13900/20000 [1:37:51<2:16:26,  1.34s/it]

Step 13900: Policy Loss 0.0922, Critic Loss 0.1440, Avg Train Acc 0.7375, Val Acc 0.5500, Unique Indices: 99


Training Steps:  70%|██████▉   | 13999/20000 [1:38:28<41:04,  2.43it/s]  

Step 14000: Policy Loss 0.2745, Critic Loss 0.2272, Avg Train Acc 0.6825, Val Acc 0.5400, Unique Indices: 100


Training Steps:  70%|███████   | 14100/20000 [1:39:17<2:10:40,  1.33s/it]

Step 14100: Policy Loss 0.1312, Critic Loss 0.2518, Avg Train Acc 0.7150, Val Acc 0.5500, Unique Indices: 100


Training Steps:  71%|███████   | 14200/20000 [1:39:58<2:16:15,  1.41s/it]

Step 14200: Policy Loss 0.2206, Critic Loss 0.2058, Avg Train Acc 0.6787, Val Acc 0.5400, Unique Indices: 98


Training Steps:  72%|███████▏  | 14300/20000 [1:40:40<2:06:25,  1.33s/it]

Step 14300: Policy Loss 0.1144, Critic Loss 0.1780, Avg Train Acc 0.7137, Val Acc 0.5400, Unique Indices: 99


Training Steps:  72%|███████▏  | 14400/20000 [1:41:23<2:11:31,  1.41s/it]

Step 14400: Policy Loss 0.1657, Critic Loss 0.2117, Avg Train Acc 0.7025, Val Acc 0.5400, Unique Indices: 99


Training Steps:  72%|███████▎  | 14500/20000 [1:42:07<2:08:45,  1.40s/it]

Step 14500: Policy Loss 0.0525, Critic Loss 0.2358, Avg Train Acc 0.6963, Val Acc 0.5400, Unique Indices: 99


Training Steps:  73%|███████▎  | 14600/20000 [1:42:48<2:00:57,  1.34s/it]

Step 14600: Policy Loss 0.1473, Critic Loss 0.1559, Avg Train Acc 0.7025, Val Acc 0.5400, Unique Indices: 98


Training Steps:  74%|███████▎  | 14700/20000 [1:43:28<1:56:49,  1.32s/it]

Step 14700: Policy Loss 0.0570, Critic Loss 0.2265, Avg Train Acc 0.6900, Val Acc 0.5400, Unique Indices: 98


Training Steps:  74%|███████▍  | 14800/20000 [1:44:10<1:58:19,  1.37s/it]

Step 14800: Policy Loss 0.0208, Critic Loss 0.3155, Avg Train Acc 0.7050, Val Acc 0.5400, Unique Indices: 100


Training Steps:  74%|███████▍  | 14900/20000 [1:44:52<1:59:43,  1.41s/it]

Step 14900: Policy Loss -0.0000, Critic Loss 0.2311, Avg Train Acc 0.7063, Val Acc 0.5500, Unique Indices: 99


Training Steps:  75%|███████▍  | 14999/20000 [1:45:30<35:42,  2.33it/s]  

Step 15000: Policy Loss 0.0555, Critic Loss 0.2107, Avg Train Acc 0.7137, Val Acc 0.5500, Unique Indices: 99


Training Steps:  76%|███████▌  | 15100/20000 [1:46:20<1:49:19,  1.34s/it]

Step 15100: Policy Loss 0.0644, Critic Loss 0.0703, Avg Train Acc 0.7075, Val Acc 0.5400, Unique Indices: 100


Training Steps:  76%|███████▌  | 15200/20000 [1:47:00<1:48:32,  1.36s/it]

Step 15200: Policy Loss 0.0899, Critic Loss 0.0836, Avg Train Acc 0.6975, Val Acc 0.5400, Unique Indices: 98


Training Steps:  76%|███████▋  | 15300/20000 [1:47:43<1:44:37,  1.34s/it]

Step 15300: Policy Loss 0.2010, Critic Loss 0.1684, Avg Train Acc 0.7113, Val Acc 0.5500, Unique Indices: 98


Training Steps:  77%|███████▋  | 15400/20000 [1:48:24<1:42:41,  1.34s/it]

Step 15400: Policy Loss 0.0648, Critic Loss 0.1520, Avg Train Acc 0.6963, Val Acc 0.5400, Unique Indices: 99


Training Steps:  78%|███████▊  | 15500/20000 [1:49:05<1:41:13,  1.35s/it]

Step 15500: Policy Loss 0.1203, Critic Loss 0.1911, Avg Train Acc 0.7037, Val Acc 0.5400, Unique Indices: 100


Training Steps:  78%|███████▊  | 15600/20000 [1:49:48<1:39:01,  1.35s/it]

Step 15600: Policy Loss -0.0000, Critic Loss 0.3708, Avg Train Acc 0.6975, Val Acc 0.5400, Unique Indices: 100


Training Steps:  78%|███████▊  | 15700/20000 [1:50:30<1:40:34,  1.40s/it]

Step 15700: Policy Loss 0.0918, Critic Loss 0.1941, Avg Train Acc 0.7050, Val Acc 0.5500, Unique Indices: 98


Training Steps:  79%|███████▉  | 15800/20000 [1:51:13<1:33:18,  1.33s/it]

Step 15800: Policy Loss 0.2337, Critic Loss 0.1541, Avg Train Acc 0.7100, Val Acc 0.5400, Unique Indices: 100


Training Steps:  80%|███████▉  | 15900/20000 [1:51:56<1:34:16,  1.38s/it]

Step 15900: Policy Loss 0.1366, Critic Loss 0.0759, Avg Train Acc 0.7125, Val Acc 0.5400, Unique Indices: 97


Training Steps:  80%|███████▉  | 15999/20000 [1:52:33<27:02,  2.47it/s]  

Step 16000: Policy Loss 0.1000, Critic Loss 0.3201, Avg Train Acc 0.6975, Val Acc 0.5600, Unique Indices: 100


Training Steps:  80%|████████  | 16100/20000 [1:53:23<1:29:12,  1.37s/it]

Step 16100: Policy Loss -0.0000, Critic Loss 0.1139, Avg Train Acc 0.7438, Val Acc 0.5500, Unique Indices: 100


Training Steps:  81%|████████  | 16200/20000 [1:54:05<1:28:50,  1.40s/it]

Step 16200: Policy Loss 0.0646, Critic Loss 0.2747, Avg Train Acc 0.7037, Val Acc 0.5400, Unique Indices: 100


Training Steps:  82%|████████▏ | 16300/20000 [1:54:47<1:23:47,  1.36s/it]

Step 16300: Policy Loss 0.0489, Critic Loss 0.1802, Avg Train Acc 0.7013, Val Acc 0.5500, Unique Indices: 100


Training Steps:  82%|████████▏ | 16400/20000 [1:55:30<1:20:25,  1.34s/it]

Step 16400: Policy Loss 0.1661, Critic Loss 0.2745, Avg Train Acc 0.7063, Val Acc 0.5300, Unique Indices: 100


Training Steps:  82%|████████▎ | 16500/20000 [1:56:12<1:22:43,  1.42s/it]

Step 16500: Policy Loss 0.0109, Critic Loss 0.2575, Avg Train Acc 0.7212, Val Acc 0.5400, Unique Indices: 99


Training Steps:  83%|████████▎ | 16600/20000 [1:56:53<1:14:32,  1.32s/it]

Step 16600: Policy Loss 0.0906, Critic Loss 0.1855, Avg Train Acc 0.6913, Val Acc 0.5500, Unique Indices: 100


Training Steps:  84%|████████▎ | 16700/20000 [1:57:35<1:15:17,  1.37s/it]

Step 16700: Policy Loss -0.0000, Critic Loss 0.1013, Avg Train Acc 0.7037, Val Acc 0.5500, Unique Indices: 99


Training Steps:  84%|████████▍ | 16800/20000 [1:58:17<1:11:39,  1.34s/it]

Step 16800: Policy Loss 0.0017, Critic Loss 0.2332, Avg Train Acc 0.7087, Val Acc 0.5500, Unique Indices: 100


Training Steps:  84%|████████▍ | 16900/20000 [1:59:01<1:14:46,  1.45s/it]

Step 16900: Policy Loss 0.0788, Critic Loss 0.1785, Avg Train Acc 0.6837, Val Acc 0.5400, Unique Indices: 99


Training Steps:  85%|████████▍ | 16999/20000 [1:59:38<18:19,  2.73it/s]  

Step 17000: Policy Loss 0.0377, Critic Loss 0.0628, Avg Train Acc 0.7037, Val Acc 0.5600, Unique Indices: 100


Training Steps:  86%|████████▌ | 17100/20000 [2:00:27<1:06:01,  1.37s/it]

Step 17100: Policy Loss 0.1633, Critic Loss 0.1676, Avg Train Acc 0.7037, Val Acc 0.5500, Unique Indices: 98


Training Steps:  86%|████████▌ | 17200/20000 [2:01:08<1:02:30,  1.34s/it]

Step 17200: Policy Loss 0.0182, Critic Loss 0.0807, Avg Train Acc 0.7163, Val Acc 0.5500, Unique Indices: 98


Training Steps:  86%|████████▋ | 17300/20000 [2:01:51<1:00:10,  1.34s/it]

Step 17300: Policy Loss 0.0818, Critic Loss 0.1478, Avg Train Acc 0.7150, Val Acc 0.5400, Unique Indices: 98


Training Steps:  87%|████████▋ | 17400/20000 [2:02:32<1:00:07,  1.39s/it]

Step 17400: Policy Loss 0.0485, Critic Loss 0.1692, Avg Train Acc 0.7087, Val Acc 0.5400, Unique Indices: 98


Training Steps:  88%|████████▊ | 17500/20000 [2:03:13<56:17,  1.35s/it]  

Step 17500: Policy Loss 0.0466, Critic Loss 0.1384, Avg Train Acc 0.7113, Val Acc 0.5400, Unique Indices: 99


Training Steps:  88%|████████▊ | 17600/20000 [2:03:55<53:48,  1.35s/it]

Step 17600: Policy Loss 0.0594, Critic Loss 0.2116, Avg Train Acc 0.7137, Val Acc 0.5400, Unique Indices: 98


Training Steps:  88%|████████▊ | 17700/20000 [2:04:37<50:25,  1.32s/it]

Step 17700: Policy Loss 0.1598, Critic Loss 0.1618, Avg Train Acc 0.7050, Val Acc 0.5500, Unique Indices: 98


Training Steps:  89%|████████▉ | 17800/20000 [2:05:20<50:06,  1.37s/it]

Step 17800: Policy Loss 0.1075, Critic Loss 0.2271, Avg Train Acc 0.6913, Val Acc 0.5300, Unique Indices: 100


Training Steps:  90%|████████▉ | 17900/20000 [2:06:01<46:20,  1.32s/it]

Step 17900: Policy Loss 0.0033, Critic Loss 0.0812, Avg Train Acc 0.7137, Val Acc 0.5500, Unique Indices: 99


Training Steps:  90%|████████▉ | 17999/20000 [2:06:40<12:44,  2.62it/s]

Step 18000: Policy Loss 0.0534, Critic Loss 0.3428, Avg Train Acc 0.7100, Val Acc 0.5400, Unique Indices: 99


Training Steps:  90%|█████████ | 18100/20000 [2:07:34<44:54,  1.42s/it]  

Step 18100: Policy Loss 0.1238, Critic Loss 0.4046, Avg Train Acc 0.7025, Val Acc 0.5500, Unique Indices: 100


Training Steps:  91%|█████████ | 18200/20000 [2:08:16<41:04,  1.37s/it]

Step 18200: Policy Loss 0.0892, Critic Loss 0.1968, Avg Train Acc 0.6813, Val Acc 0.5300, Unique Indices: 99


Training Steps:  92%|█████████▏| 18300/20000 [2:08:57<39:42,  1.40s/it]

Step 18300: Policy Loss 0.0251, Critic Loss 0.1577, Avg Train Acc 0.7312, Val Acc 0.5400, Unique Indices: 100


Training Steps:  92%|█████████▏| 18400/20000 [2:09:40<38:47,  1.45s/it]

Step 18400: Policy Loss 0.0986, Critic Loss 0.2200, Avg Train Acc 0.7175, Val Acc 0.5200, Unique Indices: 99


Training Steps:  92%|█████████▎| 18500/20000 [2:10:23<33:34,  1.34s/it]

Step 18500: Policy Loss 0.0295, Critic Loss 0.0661, Avg Train Acc 0.7025, Val Acc 0.5400, Unique Indices: 99


Training Steps:  93%|█████████▎| 18600/20000 [2:11:05<31:29,  1.35s/it]

Step 18600: Policy Loss 0.1460, Critic Loss 0.1631, Avg Train Acc 0.7150, Val Acc 0.5500, Unique Indices: 98


Training Steps:  94%|█████████▎| 18700/20000 [2:11:47<29:08,  1.35s/it]

Step 18700: Policy Loss 0.0959, Critic Loss 0.2273, Avg Train Acc 0.7125, Val Acc 0.5600, Unique Indices: 100


Training Steps:  94%|█████████▍| 18800/20000 [2:12:31<28:04,  1.40s/it]

Step 18800: Policy Loss 0.0950, Critic Loss 0.1787, Avg Train Acc 0.7188, Val Acc 0.5400, Unique Indices: 99


Training Steps:  94%|█████████▍| 18900/20000 [2:13:12<25:54,  1.41s/it]

Step 18900: Policy Loss 0.0441, Critic Loss 0.0877, Avg Train Acc 0.7250, Val Acc 0.5400, Unique Indices: 99


Training Steps:  95%|█████████▍| 18999/20000 [2:13:51<07:59,  2.09it/s]

Step 19000: Policy Loss 0.1494, Critic Loss 0.2325, Avg Train Acc 0.7175, Val Acc 0.5500, Unique Indices: 99


Training Steps:  96%|█████████▌| 19100/20000 [2:14:43<20:32,  1.37s/it]

Step 19100: Policy Loss 0.0614, Critic Loss 0.0878, Avg Train Acc 0.6713, Val Acc 0.5400, Unique Indices: 100


Training Steps:  96%|█████████▌| 19200/20000 [2:15:26<17:59,  1.35s/it]

Step 19200: Policy Loss 0.0497, Critic Loss 0.2137, Avg Train Acc 0.7137, Val Acc 0.5400, Unique Indices: 97


Training Steps:  96%|█████████▋| 19300/20000 [2:16:07<15:45,  1.35s/it]

Step 19300: Policy Loss 0.1375, Critic Loss 0.2077, Avg Train Acc 0.6937, Val Acc 0.5400, Unique Indices: 99


Training Steps:  97%|█████████▋| 19400/20000 [2:16:49<13:23,  1.34s/it]

Step 19400: Policy Loss -0.0000, Critic Loss 0.2984, Avg Train Acc 0.6800, Val Acc 0.5400, Unique Indices: 98


Training Steps:  98%|█████████▊| 19500/20000 [2:17:30<12:09,  1.46s/it]

Step 19500: Policy Loss 0.0446, Critic Loss 0.3279, Avg Train Acc 0.6687, Val Acc 0.5400, Unique Indices: 99


Training Steps:  98%|█████████▊| 19600/20000 [2:18:13<09:03,  1.36s/it]

Step 19600: Policy Loss 0.1908, Critic Loss 0.2090, Avg Train Acc 0.6850, Val Acc 0.5400, Unique Indices: 100


Training Steps:  98%|█████████▊| 19700/20000 [2:18:53<07:00,  1.40s/it]

Step 19700: Policy Loss 0.0440, Critic Loss 0.0882, Avg Train Acc 0.7175, Val Acc 0.5400, Unique Indices: 100


Training Steps:  99%|█████████▉| 19800/20000 [2:19:35<04:23,  1.32s/it]

Step 19800: Policy Loss 0.1723, Critic Loss 0.1670, Avg Train Acc 0.7013, Val Acc 0.5500, Unique Indices: 99


Training Steps: 100%|█████████▉| 19900/20000 [2:20:15<02:10,  1.30s/it]

Step 19900: Policy Loss 0.0366, Critic Loss 0.3603, Avg Train Acc 0.7050, Val Acc 0.5500, Unique Indices: 100


Training Steps: 100%|█████████▉| 19999/20000 [2:20:53<00:00,  2.74it/s]

Step 20000: Policy Loss 0.0700, Critic Loss 0.0692, Avg Train Acc 0.7400, Val Acc 0.5400, Unique Indices: 100


Training Steps: 100%|██████████| 20000/20000 [2:21:02<00:00,  2.36it/s]
