In [1]:

import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import wandb
from sae_lens import SAE
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f56b0f306b0>

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else device

In [4]:
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_20/width_16k/canonical",
)
sae = sae.to(device)

In [5]:
def select_action(policy_net, observation):
    # Get activated logits from policy_net (B, dict_size)
    logits = policy_net(observation)
    # Add small random noise for diversity
    epsilon = torch.randn_like(logits) * 0.01
    logits_noisy = logits + epsilon
    # Select top-1 value and index from noisy logits
    topk_vals, topk_indices = torch.topk(logits_noisy, k=1, dim=-1)
    # Clamp top values to a minimum of 1 (values below 1 become 1)
    topk_vals = torch.clamp(topk_vals, min=50.0)
    # Create one-hot action vector with the clamped top value at the selected index
    action = torch.zeros_like(logits)
    action.scatter_(1, topk_indices, topk_vals)
    # Compute softmax probabilities from the original logits and get log probability of the chosen index
    probs = torch.softmax(logits, dim=-1)
    chosen_probs = torch.gather(probs, 1, topk_indices)
    log_prob = torch.log(chosen_probs).squeeze(1)
    return action, log_prob

In [6]:
class PolicyNetwork(nn.Module):
    def __init__(self, latent_dim, dict_size):
        super(PolicyNetwork, self).__init__()
        self.linear = nn.Linear(latent_dim, dict_size)
        self.activation = nn.Tanh()  # using Tanh instead of ReLU
    
    def forward(self, obs):
        logits = self.linear(obs)
        activated = self.activation(logits)
        return activated

In [7]:
class CriticNetwork(nn.Module):
    def __init__(self, latent_dim):
        super(CriticNetwork, self).__init__()
        self.fc1 = nn.Linear(latent_dim, latent_dim)
        self.act = nn.Tanh()  # using Tanh here
        self.fc2 = nn.Linear(latent_dim, 1)
        
    def forward(self, obs):
        x = self.act(self.fc1(obs))
        value = self.fc2(x)
        return value

In [8]:
def batch_steering_hook(policy_net, sae):
    class SteeringHook:
        def __init__(self, policy_net, sae):
            self.policy_net = policy_net
            self.sae = sae
            self.observation = None   # will be tensor of shape (B, latent_dim)
            self.action = None        # (B,)
            self.log_prob = None      # (B,)

        def __call__(self, module, inputs):
            residual = inputs[0]  # shape: (B, seq_len, hidden_dim)
            observation = residual[:, -1, :]  # (B, latent_dim)
            self.observation = observation.detach()
            action, log_prob = select_action(self.policy_net, observation)
            steering = sae.decode(action)  # shape: (B, hidden_dim)
            self.action = action.detach()
            self.log_prob = log_prob.detach()
            # Add the corresponding steering vector to the last token.
            residual[:, -1, :] = residual[:, -1, :] + steering
            return (residual)
    return SteeringHook(policy_net, sae)

In [9]:
class PPOTrainer:
    def __init__(self, policy, critic, batch_size=8, ppo_clip=0.2, lr=1e-4):
        self.policy = policy
        self.critic = critic
        self.batch_size = batch_size
        self.ppo_clip = ppo_clip
        self.optimizer_policy = optim.Adam(self.policy.parameters(), lr=lr)
        self.optimizer_critic = optim.Adam(self.critic.parameters(), lr=lr)
    
    def compute_advantages(self, rewards, values):
        return rewards - values

    def train_step(self, observations, actions, rewards, old_log_probs):
        # Compute policy outputs.
        mean = self.policy(observations)  # shape: (B, action_dim)
        sigma = torch.ones_like(mean) * 0.1  # fixed sigma
        dist = torch.distributions.Normal(mean, sigma)
        new_log_probs = dist.log_prob(actions).sum(dim=-1)
        
        # Compute critic outputs.
        values = self.critic(observations).squeeze(-1)
        
        advantages = rewards - values  # simple advantage
        ratio = torch.exp(new_log_probs - old_log_probs)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1.0 - self.ppo_clip, 1.0 + self.ppo_clip) * advantages
        policy_loss = -torch.min(surr1, surr2).mean()
        
        critic_loss = nn.MSELoss()(values, rewards)
        
        # Combine losses to perform a single backward pass.
        total_loss = policy_loss + critic_loss
        
        self.optimizer_policy.zero_grad()
        self.optimizer_critic.zero_grad()
        total_loss.backward()
        self.optimizer_policy.step()
        self.optimizer_critic.step()
        
        return policy_loss.item(), critic_loss.item()

In [10]:
class MMLUDataLoader:
    def __init__(self, dataset, split, limit=None):
        self.data = dataset[split]
        if limit is not None:
            self.data = self.data.select(range(limit))
        self.index = 0
        self.n_samples = len(self.data)
        
    def get_batch(self, batch_size):
        batch = []
        for _ in range(batch_size):
            sample = self.data[self.index % self.n_samples]
            self.index += 1
            batch.append({
                "question": sample["question"],
                "choices": sample["choices"],
                "answer": sample["answer"]
            })
        return batch

In [11]:
mmlu_dataset = load_dataset("cais/mmlu", "all")
train_loader = MMLUDataLoader(mmlu_dataset, split="auxiliary_train")
val_loader = MMLUDataLoader(mmlu_dataset, split="validation", limit=100)

In [12]:
LATENT_DIM = 2304   # Gemma's latent dimension (from layer 20)
DICT_SIZE = 16384   # SAE dictionary size

In [13]:
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
llm = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", output_hidden_states=True).to(device)
llm.eval()

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.44it/s]


Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemm

In [14]:
policy_net = PolicyNetwork(LATENT_DIM, DICT_SIZE).to(device)
critic_net = CriticNetwork(LATENT_DIM).to(device)
ppo_trainer = PPOTrainer(policy_net, critic_net, batch_size=8, ppo_clip=0.2, lr=1e-4)

In [15]:
from tqdm import tqdm
import os

def train(num_steps=20000, validate_every=100, checkpoint_every=200, checkpoint_dir="./checkpoints"):
    wandb.init(project="gemma_mmlu_ppo")
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_val_accuracy = 0.0
    train_acc_sum = 0.0
    train_acc_count = 0
    for step in tqdm(range(1, num_steps + 1), desc="Training Steps"):
        batch = train_loader.get_batch(ppo_trainer.batch_size)
        prompts, correct_answers = [], []
        for sample in batch:
            question = sample["question"]
            choices = sample["choices"]
            if isinstance(sample["answer"], int):
                correct_answer = chr(65 + sample["answer"])
            else:
                correct_answer = sample["answer"].strip().upper()
            prompt = (question + "\n" +
                      "\n".join(f"{chr(65+i)}. {choice}" for i, choice in enumerate(choices)) +
                      "\nChoose one of the following options only: A, B, C, or D" +
                      "\nAnswer:")
            prompts.append(prompt)
            correct_answers.append(correct_answer)
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
        steering_hook = batch_steering_hook(policy_net, sae)
        hook_handle = llm.model.layers[20].register_forward_pre_hook(steering_hook)
        generated_ids = llm.generate(inputs, max_new_tokens=1)
        hook_handle.remove()
        batch_rewards = []
        for i in range(generated_ids.shape[0]):
            gen_tok = tokenizer.decode(generated_ids[i, -1]).strip()
            predicted_label = gen_tok[0].upper() if gen_tok else ""
            reward = 1 if predicted_label == correct_answers[i] else 0
            batch_rewards.append(reward)
        obs_batch = steering_hook.observation
        action_batch = steering_hook.action
        log_prob_batch = steering_hook.log_prob
        rewards_tensor = torch.tensor(batch_rewards, dtype=torch.float32, device=device)
        nonzero_vals = action_batch[action_batch != 0]
        avg_activation = nonzero_vals.mean().item() if nonzero_vals.numel() > 0 else 0.0
        policy_loss, critic_loss = ppo_trainer.train_step(obs_batch, action_batch, rewards_tensor, log_prob_batch)
        train_accuracy = sum(batch_rewards) / len(batch_rewards)
        train_acc_sum += train_accuracy
        train_acc_count += 1
        wandb.log({"step": step, "policy_loss": policy_loss, "critic_loss": critic_loss, 
                   "train_accuracy": train_accuracy, "avg_activation": avg_activation})
        if step % validate_every == 0:
            used_indices = set()  # clear used_indices each validation step
            val_batch = val_loader.get_batch(val_loader.n_samples)
            val_prompts, correct_answers_val = [], []
            for sample in val_batch:
                question = sample["question"]
                choices = sample["choices"]
                if isinstance(sample["answer"], int):
                    correct_answer = chr(65 + sample["answer"])
                else:
                    correct_answer = sample["answer"].strip().upper()
                prompt = (question + "\n" +
                          "\n".join(f"{chr(65+i)}. {choice}" for i, choice in enumerate(choices)) +
                          "\nChoose one of the following options only: A, B, C, or D" +
                          "\nAnswer:")
                val_prompts.append(prompt)
                correct_answers_val.append(correct_answer)
            inputs_val = tokenizer(val_prompts, return_tensors="pt", padding=True, truncation=True).input_ids.to(device)
            steering_hook_val = batch_steering_hook(policy_net, sae)
            hook_handle_val = llm.model.layers[20].register_forward_pre_hook(steering_hook_val)
            generated_ids_val = llm.generate(inputs_val, max_new_tokens=1)
            hook_handle_val.remove()
            total_correct = 0
            val_action_batch = steering_hook_val.action
            topk_vals_val, topk_indices_val = torch.topk(val_action_batch, k=1, dim=-1)
            for idx in topk_indices_val.view(-1).tolist():
                used_indices.add(idx)
            for i in range(generated_ids_val.shape[0]):
                gen_tok = tokenizer.decode(generated_ids_val[i, -1]).strip()
                predicted_label = gen_tok[0].upper() if gen_tok else ""
                if predicted_label == correct_answers_val[i]:
                    total_correct += 1
            val_accuracy = total_correct / len(val_batch)
            avg_train_accuracy = train_acc_sum / train_acc_count if train_acc_count > 0 else 0
            wandb.log({"step": step, "val_accuracy": val_accuracy, "avg_train_accuracy": avg_train_accuracy, 
                       "unique_indices": len(used_indices)})
            print(f"Step {step}: Policy Loss {policy_loss:.4f}, Critic Loss {critic_loss:.4f}, "
                  f"Avg Train Acc {avg_train_accuracy:.4f}, Val Acc {val_accuracy:.4f}, "
                  f"Unique Indices: {len(used_indices)}")
            train_acc_sum, train_acc_count = 0.0, 0
        if step % checkpoint_every == 0:
            checkpoint_path = os.path.join(
                checkpoint_dir,
                f"gemma-2-2b_layer20_ppo_lr{ppo_trainer.optimizer_policy.defaults['lr']}_batch{ppo_trainer.batch_size}_step{step}.pt"
            )
            torch.save({
                'step': step,
                'policy_state_dict': policy_net.state_dict(),
                'critic_state_dict': critic_net.state_dict(),
                'optimizer_policy_state_dict': ppo_trainer.optimizer_policy.state_dict(),
                'optimizer_critic_state_dict': ppo_trainer.optimizer_critic.state_dict()
            }, checkpoint_path)
            if step % validate_every == 0 and val_accuracy > best_val_accuracy:
                best_val_accuracy = val_accuracy
                best_checkpoint_path = os.path.join(checkpoint_dir, "best_policy.pt")
                torch.save({
                    'step': step,
                    'policy_state_dict': policy_net.state_dict(),
                    'critic_state_dict': critic_net.state_dict(),
                    'optimizer_policy_state_dict': ppo_trainer.optimizer_policy.state_dict(),
                    'optimizer_critic_state_dict': ppo_trainer.optimizer_critic.state_dict()
                }, best_checkpoint_path)


In [16]:
train()

[34m[1mwandb[0m: Currently logged in as: [33mseonglae[0m ([33mtexonom[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Training Steps:   0%|          | 0/20000 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.
Training Steps:   0%|          | 100/20000 [01:17<9:04:38,  1.64s/it]

Step 100: Policy Loss 0.0409, Critic Loss 0.3624, Avg Train Acc 0.3825, Val Acc 0.4800, Unique Indices: 100


Training Steps:   1%|          | 199/20000 [02:00<58:08,  5.68it/s]  

Step 200: Policy Loss -0.0000, Critic Loss 0.0144, Avg Train Acc 0.6550, Val Acc 0.4900, Unique Indices: 98


Training Steps:   2%|▏         | 300/20000 [02:38<6:23:43,  1.17s/it] 

Step 300: Policy Loss 0.1686, Critic Loss 0.1412, Avg Train Acc 0.7300, Val Acc 0.4800, Unique Indices: 99


Training Steps:   2%|▏         | 399/20000 [03:18<3:31:29,  1.54it/s]

Step 400: Policy Loss 0.0563, Critic Loss 0.0481, Avg Train Acc 0.7937, Val Acc 0.4800, Unique Indices: 98


Training Steps:   2%|▎         | 500/20000 [04:33<8:29:47,  1.57s/it] 

Step 500: Policy Loss -0.0000, Critic Loss 0.0733, Avg Train Acc 0.8938, Val Acc 0.4800, Unique Indices: 99


Training Steps:   3%|▎         | 599/20000 [05:12<1:08:47,  4.70it/s]

Step 600: Policy Loss -0.0000, Critic Loss 0.0716, Avg Train Acc 0.8612, Val Acc 0.5000, Unique Indices: 99


Training Steps:   4%|▎         | 700/20000 [06:58<11:13:19,  2.09s/it]

Step 700: Policy Loss 0.0097, Critic Loss 0.3348, Avg Train Acc 0.7013, Val Acc 0.4800, Unique Indices: 100


Training Steps:   4%|▍         | 799/20000 [08:34<4:56:35,  1.08it/s] 

Step 800: Policy Loss 0.0429, Critic Loss 0.0936, Avg Train Acc 0.7100, Val Acc 0.4900, Unique Indices: 97


Training Steps:   4%|▍         | 900/20000 [10:17<9:48:30,  1.85s/it] 

Step 900: Policy Loss 0.1717, Critic Loss 0.1383, Avg Train Acc 0.7025, Val Acc 0.4900, Unique Indices: 99


Training Steps:   5%|▍         | 999/20000 [11:50<4:45:29,  1.11it/s]

Step 1000: Policy Loss 0.1444, Critic Loss 0.1196, Avg Train Acc 0.7238, Val Acc 0.4900, Unique Indices: 99


Training Steps:   6%|▌         | 1100/20000 [14:17<9:49:11,  1.87s/it] 

Step 1100: Policy Loss 0.1838, Critic Loss 0.3074, Avg Train Acc 0.6787, Val Acc 0.4900, Unique Indices: 99


Training Steps:   6%|▌         | 1199/20000 [15:47<5:06:54,  1.02it/s]

Step 1200: Policy Loss 0.0298, Critic Loss 0.1865, Avg Train Acc 0.7200, Val Acc 0.5100, Unique Indices: 97


Training Steps:   6%|▋         | 1300/20000 [17:38<9:51:52,  1.90s/it] 

Step 1300: Policy Loss 0.0414, Critic Loss 0.3279, Avg Train Acc 0.6863, Val Acc 0.5000, Unique Indices: 99


Training Steps:   7%|▋         | 1399/20000 [19:15<5:05:49,  1.01it/s]

Step 1400: Policy Loss 0.0871, Critic Loss 0.2009, Avg Train Acc 0.6925, Val Acc 0.4900, Unique Indices: 99


Training Steps:   8%|▊         | 1500/20000 [21:04<10:37:24,  2.07s/it]

Step 1500: Policy Loss 0.0243, Critic Loss 0.2891, Avg Train Acc 0.6937, Val Acc 0.5000, Unique Indices: 100


Training Steps:   8%|▊         | 1599/20000 [22:41<5:16:15,  1.03s/it] 

Step 1600: Policy Loss 0.2449, Critic Loss 0.2176, Avg Train Acc 0.6875, Val Acc 0.5000, Unique Indices: 97


Training Steps:   8%|▊         | 1700/20000 [24:38<9:55:42,  1.95s/it] 

Step 1700: Policy Loss 0.1278, Critic Loss 0.2253, Avg Train Acc 0.6987, Val Acc 0.4900, Unique Indices: 99


Training Steps:   9%|▉         | 1799/20000 [26:18<5:10:42,  1.02s/it]

Step 1800: Policy Loss 0.0486, Critic Loss 0.2593, Avg Train Acc 0.7050, Val Acc 0.5000, Unique Indices: 100


Training Steps:  10%|▉         | 1900/20000 [28:06<9:33:34,  1.90s/it] 

Step 1900: Policy Loss 0.0580, Critic Loss 0.0870, Avg Train Acc 0.7113, Val Acc 0.4900, Unique Indices: 100


Training Steps:  10%|▉         | 1999/20000 [29:42<5:04:17,  1.01s/it]

Step 2000: Policy Loss 0.1487, Critic Loss 0.1989, Avg Train Acc 0.6975, Val Acc 0.4800, Unique Indices: 97


Training Steps:  10%|█         | 2100/20000 [31:42<9:20:54,  1.88s/it] 

Step 2100: Policy Loss 0.1816, Critic Loss 0.1425, Avg Train Acc 0.7150, Val Acc 0.5000, Unique Indices: 99


Training Steps:  11%|█         | 2199/20000 [33:20<4:41:21,  1.05it/s]

Step 2200: Policy Loss 0.0984, Critic Loss 0.1764, Avg Train Acc 0.7075, Val Acc 0.5100, Unique Indices: 98


Training Steps:  12%|█▏        | 2300/20000 [35:13<9:10:36,  1.87s/it] 

Step 2300: Policy Loss 0.1493, Critic Loss 0.2851, Avg Train Acc 0.7050, Val Acc 0.4800, Unique Indices: 99


Training Steps:  12%|█▏        | 2399/20000 [36:50<7:32:34,  1.54s/it]

Step 2400: Policy Loss 0.1623, Critic Loss 0.1573, Avg Train Acc 0.6700, Val Acc 0.4900, Unique Indices: 100


Training Steps:  12%|█▎        | 2500/20000 [38:36<10:41:23,  2.20s/it]

Step 2500: Policy Loss 0.0960, Critic Loss 0.1678, Avg Train Acc 0.6925, Val Acc 0.5200, Unique Indices: 100


Training Steps:  13%|█▎        | 2599/20000 [40:15<5:08:48,  1.06s/it] 

Step 2600: Policy Loss -0.0000, Critic Loss 0.0685, Avg Train Acc 0.6737, Val Acc 0.5000, Unique Indices: 97


Training Steps:  14%|█▎        | 2700/20000 [42:05<9:43:53,  2.03s/it] 

Step 2700: Policy Loss 0.0116, Critic Loss 0.0228, Avg Train Acc 0.7075, Val Acc 0.5000, Unique Indices: 100


Training Steps:  14%|█▍        | 2799/20000 [43:42<4:45:46,  1.00it/s]

Step 2800: Policy Loss 0.0662, Critic Loss 0.1831, Avg Train Acc 0.6750, Val Acc 0.4900, Unique Indices: 99


Training Steps:  14%|█▍        | 2900/20000 [45:28<10:21:57,  2.18s/it]

Step 2900: Policy Loss 0.0715, Critic Loss 0.1299, Avg Train Acc 0.7013, Val Acc 0.4800, Unique Indices: 100


Training Steps:  15%|█▍        | 2999/20000 [46:59<4:56:17,  1.05s/it] 

Step 3000: Policy Loss 0.1003, Critic Loss 0.2673, Avg Train Acc 0.6825, Val Acc 0.4700, Unique Indices: 99


Training Steps:  16%|█▌        | 3100/20000 [50:01<9:53:30,  2.11s/it]  

Step 3100: Policy Loss 0.2471, Critic Loss 0.2874, Avg Train Acc 0.6963, Val Acc 0.4900, Unique Indices: 99


Training Steps:  16%|█▌        | 3199/20000 [51:36<5:07:56,  1.10s/it]

Step 3200: Policy Loss 0.1095, Critic Loss 0.1565, Avg Train Acc 0.7200, Val Acc 0.5000, Unique Indices: 99


Training Steps:  16%|█▋        | 3300/20000 [53:27<9:15:07,  1.99s/it] 

Step 3300: Policy Loss 0.0634, Critic Loss 0.1897, Avg Train Acc 0.6663, Val Acc 0.5000, Unique Indices: 99


Training Steps:  17%|█▋        | 3399/20000 [55:01<4:13:40,  1.09it/s]

Step 3400: Policy Loss 0.1591, Critic Loss 0.1481, Avg Train Acc 0.6600, Val Acc 0.4900, Unique Indices: 99


Training Steps:  18%|█▊        | 3500/20000 [56:50<8:52:32,  1.94s/it] 

Step 3500: Policy Loss 0.1417, Critic Loss 0.3130, Avg Train Acc 0.7050, Val Acc 0.4900, Unique Indices: 100


Training Steps:  18%|█▊        | 3599/20000 [58:25<4:55:25,  1.08s/it]

Step 3600: Policy Loss 0.1367, Critic Loss 0.3177, Avg Train Acc 0.6637, Val Acc 0.4800, Unique Indices: 99


Training Steps:  18%|█▊        | 3700/20000 [1:00:13<8:45:22,  1.93s/it]

Step 3700: Policy Loss 0.2872, Critic Loss 0.3500, Avg Train Acc 0.6863, Val Acc 0.4900, Unique Indices: 98


Training Steps:  19%|█▉        | 3799/20000 [1:01:48<4:45:25,  1.06s/it]

Step 3800: Policy Loss 0.0854, Critic Loss 0.1482, Avg Train Acc 0.6763, Val Acc 0.5100, Unique Indices: 99


Training Steps:  20%|█▉        | 3900/20000 [1:03:36<9:44:17,  2.18s/it] 

Step 3900: Policy Loss 0.0892, Critic Loss 0.1526, Avg Train Acc 0.6963, Val Acc 0.5100, Unique Indices: 100


Training Steps:  20%|█▉        | 3999/20000 [1:05:14<4:27:36,  1.00s/it]

Step 4000: Policy Loss 0.0153, Critic Loss 0.1563, Avg Train Acc 0.6837, Val Acc 0.4800, Unique Indices: 98


Training Steps:  20%|██        | 4100/20000 [1:07:59<8:36:20,  1.95s/it] 

Step 4100: Policy Loss 0.0428, Critic Loss 0.3077, Avg Train Acc 0.7200, Val Acc 0.4900, Unique Indices: 100


Training Steps:  21%|██        | 4199/20000 [1:09:41<4:58:17,  1.13s/it]

Step 4200: Policy Loss -0.0000, Critic Loss 0.0732, Avg Train Acc 0.7212, Val Acc 0.4800, Unique Indices: 98


Training Steps:  22%|██▏       | 4300/20000 [1:11:31<8:35:45,  1.97s/it] 

Step 4300: Policy Loss 0.2404, Critic Loss 0.2043, Avg Train Acc 0.6837, Val Acc 0.4800, Unique Indices: 98


Training Steps:  22%|██▏       | 4399/20000 [1:13:09<4:05:24,  1.06it/s]

Step 4400: Policy Loss 0.0314, Critic Loss 0.1385, Avg Train Acc 0.6787, Val Acc 0.4900, Unique Indices: 98


Training Steps:  22%|██▎       | 4500/20000 [1:14:59<8:18:59,  1.93s/it] 

Step 4500: Policy Loss 0.0431, Critic Loss 0.2318, Avg Train Acc 0.6950, Val Acc 0.4800, Unique Indices: 99


Training Steps:  23%|██▎       | 4599/20000 [1:16:33<3:54:22,  1.10it/s]

Step 4600: Policy Loss -0.0000, Critic Loss 0.2268, Avg Train Acc 0.6875, Val Acc 0.4800, Unique Indices: 98


Training Steps:  24%|██▎       | 4700/20000 [1:18:23<8:04:11,  1.90s/it] 

Step 4700: Policy Loss 0.0623, Critic Loss 0.1435, Avg Train Acc 0.6813, Val Acc 0.4900, Unique Indices: 100


Training Steps:  24%|██▍       | 4799/20000 [1:19:54<3:51:41,  1.09it/s]

Step 4800: Policy Loss 0.1264, Critic Loss 0.2884, Avg Train Acc 0.6775, Val Acc 0.4900, Unique Indices: 99


Training Steps:  24%|██▍       | 4900/20000 [1:21:49<7:41:27,  1.83s/it] 

Step 4900: Policy Loss 0.0505, Critic Loss 0.2704, Avg Train Acc 0.6950, Val Acc 0.5000, Unique Indices: 100


Training Steps:  25%|██▍       | 4999/20000 [1:23:27<4:41:35,  1.13s/it]

Step 5000: Policy Loss 0.0911, Critic Loss 0.1359, Avg Train Acc 0.6675, Val Acc 0.4900, Unique Indices: 100


Training Steps:  26%|██▌       | 5100/20000 [1:25:16<8:03:30,  1.95s/it] 

Step 5100: Policy Loss 0.0832, Critic Loss 0.2428, Avg Train Acc 0.6913, Val Acc 0.4900, Unique Indices: 100


Training Steps:  26%|██▌       | 5199/20000 [1:26:53<4:46:39,  1.16s/it]

Step 5200: Policy Loss 0.0131, Critic Loss 0.1437, Avg Train Acc 0.7150, Val Acc 0.5000, Unique Indices: 99


Training Steps:  26%|██▋       | 5300/20000 [1:28:45<7:54:04,  1.93s/it] 

Step 5300: Policy Loss 0.0834, Critic Loss 0.1849, Avg Train Acc 0.6775, Val Acc 0.4800, Unique Indices: 96


Training Steps:  27%|██▋       | 5399/20000 [1:30:22<3:44:00,  1.09it/s]

Step 5400: Policy Loss 0.1039, Critic Loss 0.1672, Avg Train Acc 0.6900, Val Acc 0.4900, Unique Indices: 99


Training Steps:  28%|██▊       | 5500/20000 [1:32:10<7:11:29,  1.79s/it] 

Step 5500: Policy Loss 0.0338, Critic Loss 0.1131, Avg Train Acc 0.6737, Val Acc 0.4900, Unique Indices: 99


Training Steps:  28%|██▊       | 5599/20000 [1:33:43<3:38:00,  1.10it/s]

Step 5600: Policy Loss 0.2138, Critic Loss 0.3193, Avg Train Acc 0.6837, Val Acc 0.4900, Unique Indices: 100


Training Steps:  28%|██▊       | 5700/20000 [1:35:31<7:49:25,  1.97s/it] 

Step 5700: Policy Loss 0.0786, Critic Loss 0.1774, Avg Train Acc 0.7125, Val Acc 0.5000, Unique Indices: 100


Training Steps:  29%|██▉       | 5799/20000 [1:37:04<3:18:28,  1.19it/s]

Step 5800: Policy Loss 0.2222, Critic Loss 0.2353, Avg Train Acc 0.6937, Val Acc 0.4800, Unique Indices: 97


Training Steps:  30%|██▉       | 5900/20000 [1:39:25<7:42:02,  1.97s/it] 

Step 5900: Policy Loss 0.0459, Critic Loss 0.1550, Avg Train Acc 0.7087, Val Acc 0.4900, Unique Indices: 99


Training Steps:  30%|██▉       | 5999/20000 [1:40:58<3:28:17,  1.12it/s]

Step 6000: Policy Loss 0.1404, Critic Loss 0.0764, Avg Train Acc 0.6737, Val Acc 0.4900, Unique Indices: 100


Training Steps:  30%|███       | 6100/20000 [1:42:49<8:24:37,  2.18s/it] 

Step 6100: Policy Loss 0.0440, Critic Loss 0.1775, Avg Train Acc 0.6600, Val Acc 0.4900, Unique Indices: 100


Training Steps:  31%|███       | 6199/20000 [1:44:27<3:25:57,  1.12it/s]

Step 6200: Policy Loss 0.1194, Critic Loss 0.1678, Avg Train Acc 0.6775, Val Acc 0.4800, Unique Indices: 99


Training Steps:  32%|███▏      | 6300/20000 [1:46:21<7:34:38,  1.99s/it] 

Step 6300: Policy Loss 0.2318, Critic Loss 0.1854, Avg Train Acc 0.6887, Val Acc 0.4900, Unique Indices: 98


Training Steps:  32%|███▏      | 6399/20000 [1:47:54<3:49:56,  1.01s/it]

Step 6400: Policy Loss 0.1072, Critic Loss 0.1416, Avg Train Acc 0.6725, Val Acc 0.4800, Unique Indices: 97


Training Steps:  32%|███▎      | 6500/20000 [1:49:38<6:50:23,  1.82s/it] 

Step 6500: Policy Loss 0.1667, Critic Loss 0.1496, Avg Train Acc 0.6837, Val Acc 0.5000, Unique Indices: 97


Training Steps:  33%|███▎      | 6599/20000 [1:51:18<3:39:51,  1.02it/s]

Step 6600: Policy Loss -0.0000, Critic Loss 0.1517, Avg Train Acc 0.6763, Val Acc 0.5000, Unique Indices: 100


Training Steps:  34%|███▎      | 6700/20000 [1:53:02<6:54:32,  1.87s/it] 

Step 6700: Policy Loss 0.0350, Critic Loss 0.2111, Avg Train Acc 0.6825, Val Acc 0.4700, Unique Indices: 100


Training Steps:  34%|███▍      | 6799/20000 [1:54:40<3:56:30,  1.07s/it]

Step 6800: Policy Loss 0.1935, Critic Loss 0.2566, Avg Train Acc 0.6787, Val Acc 0.5000, Unique Indices: 97


Training Steps:  34%|███▍      | 6900/20000 [1:56:27<7:04:07,  1.94s/it] 

Step 6900: Policy Loss 0.1398, Critic Loss 0.1663, Avg Train Acc 0.7125, Val Acc 0.4700, Unique Indices: 99


Training Steps:  35%|███▍      | 6999/20000 [1:57:57<3:51:10,  1.07s/it]

Step 7000: Policy Loss 0.1150, Critic Loss 0.1539, Avg Train Acc 0.6987, Val Acc 0.4900, Unique Indices: 99


Training Steps:  36%|███▌      | 7100/20000 [1:59:50<7:02:29,  1.97s/it] 

Step 7100: Policy Loss 0.0673, Critic Loss 0.1793, Avg Train Acc 0.6887, Val Acc 0.5000, Unique Indices: 100


Training Steps:  36%|███▌      | 7199/20000 [2:01:19<3:07:45,  1.14it/s]

Step 7200: Policy Loss 0.1192, Critic Loss 0.1698, Avg Train Acc 0.6937, Val Acc 0.4800, Unique Indices: 98


Training Steps:  36%|███▋      | 7300/20000 [2:03:13<8:10:22,  2.32s/it] 

Step 7300: Policy Loss 0.1126, Critic Loss 0.1687, Avg Train Acc 0.6937, Val Acc 0.4900, Unique Indices: 98


Training Steps:  37%|███▋      | 7399/20000 [2:04:54<3:27:08,  1.01it/s]

Step 7400: Policy Loss 0.0656, Critic Loss 0.1251, Avg Train Acc 0.6950, Val Acc 0.5000, Unique Indices: 99


Training Steps:  38%|███▊      | 7500/20000 [2:06:37<6:30:11,  1.87s/it] 

Step 7500: Policy Loss 0.0661, Critic Loss 0.0942, Avg Train Acc 0.7137, Val Acc 0.5100, Unique Indices: 99


Training Steps:  38%|███▊      | 7599/20000 [2:08:13<3:23:39,  1.01it/s]

Step 7600: Policy Loss 0.0494, Critic Loss 0.1257, Avg Train Acc 0.7075, Val Acc 0.5100, Unique Indices: 98


Training Steps:  38%|███▊      | 7700/20000 [2:10:04<6:43:42,  1.97s/it] 

Step 7700: Policy Loss 0.0127, Critic Loss 0.0724, Avg Train Acc 0.7025, Val Acc 0.4900, Unique Indices: 99


Training Steps:  39%|███▉      | 7799/20000 [2:11:40<3:24:46,  1.01s/it]

Step 7800: Policy Loss 0.1050, Critic Loss 0.2521, Avg Train Acc 0.7137, Val Acc 0.4800, Unique Indices: 98


Training Steps:  40%|███▉      | 7900/20000 [2:13:26<6:14:24,  1.86s/it] 

Step 7900: Policy Loss 0.0939, Critic Loss 0.1716, Avg Train Acc 0.6663, Val Acc 0.4800, Unique Indices: 99


Training Steps:  40%|███▉      | 7999/20000 [2:15:02<3:09:56,  1.05it/s]

Step 8000: Policy Loss 0.0724, Critic Loss 0.0896, Avg Train Acc 0.6475, Val Acc 0.4900, Unique Indices: 98


Training Steps:  40%|████      | 8100/20000 [2:17:00<6:34:45,  1.99s/it] 

Step 8100: Policy Loss 0.0082, Critic Loss 0.2935, Avg Train Acc 0.6750, Val Acc 0.5000, Unique Indices: 100


Training Steps:  41%|████      | 8199/20000 [2:18:35<3:08:43,  1.04it/s]

Step 8200: Policy Loss 0.0575, Critic Loss 0.1676, Avg Train Acc 0.7075, Val Acc 0.5000, Unique Indices: 99


Training Steps:  42%|████▏     | 8300/20000 [2:21:14<6:13:35,  1.92s/it] 

Step 8300: Policy Loss 0.0574, Critic Loss 0.1497, Avg Train Acc 0.6813, Val Acc 0.4900, Unique Indices: 99


Training Steps:  42%|████▏     | 8399/20000 [2:22:56<3:01:45,  1.06it/s]

Step 8400: Policy Loss 0.2164, Critic Loss 0.3786, Avg Train Acc 0.6587, Val Acc 0.4800, Unique Indices: 98


Training Steps:  42%|████▎     | 8500/20000 [2:24:19<5:47:40,  1.81s/it] 

Step 8500: Policy Loss 0.0606, Critic Loss 0.1698, Avg Train Acc 0.7188, Val Acc 0.4900, Unique Indices: 96


Training Steps:  43%|████▎     | 8599/20000 [2:25:23<2:13:42,  1.42it/s]

Step 8600: Policy Loss 0.2454, Critic Loss 0.2244, Avg Train Acc 0.7538, Val Acc 0.5100, Unique Indices: 99


Training Steps:  44%|████▎     | 8700/20000 [2:26:47<5:20:54,  1.70s/it] 

Step 8700: Policy Loss 0.1238, Critic Loss 0.2144, Avg Train Acc 0.7412, Val Acc 0.4900, Unique Indices: 99


Training Steps:  44%|████▍     | 8799/20000 [2:27:58<1:46:36,  1.75it/s]

Step 8800: Policy Loss 0.1063, Critic Loss 0.0946, Avg Train Acc 0.7400, Val Acc 0.4900, Unique Indices: 100


Training Steps:  44%|████▍     | 8900/20000 [2:29:18<4:51:50,  1.58s/it]

Step 8900: Policy Loss 0.0950, Critic Loss 0.3161, Avg Train Acc 0.7300, Val Acc 0.5100, Unique Indices: 99


Training Steps:  45%|████▍     | 8999/20000 [2:30:23<1:54:41,  1.60it/s]

Step 9000: Policy Loss 0.0959, Critic Loss 0.1684, Avg Train Acc 0.7225, Val Acc 0.4900, Unique Indices: 100


Training Steps:  46%|████▌     | 9100/20000 [2:31:42<5:14:59,  1.73s/it]

Step 9100: Policy Loss 0.1490, Critic Loss 0.1205, Avg Train Acc 0.7338, Val Acc 0.4900, Unique Indices: 99


Training Steps:  46%|████▌     | 9199/20000 [2:32:49<2:28:16,  1.21it/s]

Step 9200: Policy Loss 0.1656, Critic Loss 0.2404, Avg Train Acc 0.7525, Val Acc 0.4900, Unique Indices: 100


Training Steps:  46%|████▋     | 9300/20000 [2:34:10<4:35:45,  1.55s/it]

Step 9300: Policy Loss 0.0072, Critic Loss 0.1284, Avg Train Acc 0.7250, Val Acc 0.4900, Unique Indices: 100


Training Steps:  47%|████▋     | 9399/20000 [2:35:17<2:05:17,  1.41it/s]

Step 9400: Policy Loss 0.0710, Critic Loss 0.1974, Avg Train Acc 0.7338, Val Acc 0.4800, Unique Indices: 100


Training Steps:  48%|████▊     | 9500/20000 [2:36:35<4:40:59,  1.61s/it]

Step 9500: Policy Loss 0.1649, Critic Loss 0.1525, Avg Train Acc 0.7150, Val Acc 0.5000, Unique Indices: 100


Training Steps:  48%|████▊     | 9599/20000 [2:37:40<1:48:02,  1.60it/s]

Step 9600: Policy Loss 0.0939, Critic Loss 0.2533, Avg Train Acc 0.7238, Val Acc 0.5000, Unique Indices: 97


Training Steps:  48%|████▊     | 9700/20000 [2:39:38<4:48:54,  1.68s/it] 

Step 9700: Policy Loss 0.0221, Critic Loss 0.2864, Avg Train Acc 0.7113, Val Acc 0.5000, Unique Indices: 97


Training Steps:  49%|████▉     | 9799/20000 [2:40:45<2:04:00,  1.37it/s]

Step 9800: Policy Loss 0.2416, Critic Loss 0.2772, Avg Train Acc 0.7350, Val Acc 0.4900, Unique Indices: 98


Training Steps:  50%|████▉     | 9900/20000 [2:42:03<4:27:39,  1.59s/it]

Step 9900: Policy Loss -0.0000, Critic Loss 0.0801, Avg Train Acc 0.7200, Val Acc 0.4800, Unique Indices: 100


Training Steps:  50%|████▉     | 9999/20000 [2:43:14<1:46:15,  1.57it/s]

Step 10000: Policy Loss 0.0452, Critic Loss 0.1398, Avg Train Acc 0.7175, Val Acc 0.5000, Unique Indices: 99


Training Steps:  50%|█████     | 10100/20000 [2:44:32<4:38:04,  1.69s/it]

Step 10100: Policy Loss 0.0681, Critic Loss 0.3298, Avg Train Acc 0.7512, Val Acc 0.4800, Unique Indices: 100


Training Steps:  51%|█████     | 10199/20000 [2:45:38<1:54:01,  1.43it/s]

Step 10200: Policy Loss -0.0000, Critic Loss 0.0465, Avg Train Acc 0.7350, Val Acc 0.5000, Unique Indices: 98


Training Steps:  52%|█████▏    | 10300/20000 [2:47:07<4:14:07,  1.57s/it] 

Step 10300: Policy Loss 0.0567, Critic Loss 0.0788, Avg Train Acc 0.7525, Val Acc 0.4900, Unique Indices: 100


Training Steps:  52%|█████▏    | 10399/20000 [2:48:13<1:46:27,  1.50it/s]

Step 10400: Policy Loss -0.0000, Critic Loss 0.0585, Avg Train Acc 0.7375, Val Acc 0.4700, Unique Indices: 99


Training Steps:  52%|█████▎    | 10500/20000 [2:49:32<4:17:24,  1.63s/it]

Step 10500: Policy Loss 0.0425, Critic Loss 0.0401, Avg Train Acc 0.7075, Val Acc 0.4800, Unique Indices: 98


Training Steps:  53%|█████▎    | 10599/20000 [2:50:38<1:29:32,  1.75it/s]

Step 10600: Policy Loss 0.1629, Critic Loss 0.2134, Avg Train Acc 0.7150, Val Acc 0.5100, Unique Indices: 99


Training Steps:  54%|█████▎    | 10700/20000 [2:52:40<4:19:13,  1.67s/it] 

Step 10700: Policy Loss 0.1158, Critic Loss 0.1821, Avg Train Acc 0.7575, Val Acc 0.4900, Unique Indices: 98


Training Steps:  54%|█████▍    | 10799/20000 [2:53:47<1:35:37,  1.60it/s]

Step 10800: Policy Loss 0.1188, Critic Loss 0.1750, Avg Train Acc 0.7412, Val Acc 0.5000, Unique Indices: 100


Training Steps:  55%|█████▍    | 10900/20000 [2:55:07<4:12:26,  1.66s/it]

Step 10900: Policy Loss 0.0512, Critic Loss 0.1027, Avg Train Acc 0.7525, Val Acc 0.5000, Unique Indices: 97


Training Steps:  55%|█████▍    | 10999/20000 [2:56:12<1:25:03,  1.76it/s]

Step 11000: Policy Loss 0.0564, Critic Loss 0.1449, Avg Train Acc 0.7300, Val Acc 0.5100, Unique Indices: 99


Training Steps:  56%|█████▌    | 11100/20000 [2:57:28<4:02:28,  1.63s/it]

Step 11100: Policy Loss 0.0181, Critic Loss 0.2069, Avg Train Acc 0.7288, Val Acc 0.5100, Unique Indices: 99


Training Steps:  56%|█████▌    | 11199/20000 [2:58:33<1:32:14,  1.59it/s]

Step 11200: Policy Loss -0.0000, Critic Loss 0.1970, Avg Train Acc 0.7425, Val Acc 0.4800, Unique Indices: 100


Training Steps:  56%|█████▋    | 11300/20000 [2:59:51<4:07:52,  1.71s/it]

Step 11300: Policy Loss 0.0395, Critic Loss 0.1177, Avg Train Acc 0.7275, Val Acc 0.4900, Unique Indices: 99


Training Steps:  57%|█████▋    | 11399/20000 [3:00:57<1:24:57,  1.69it/s]

Step 11400: Policy Loss 0.0496, Critic Loss 0.0716, Avg Train Acc 0.7125, Val Acc 0.4900, Unique Indices: 100


Training Steps:  57%|█████▊    | 11500/20000 [3:02:17<4:00:44,  1.70s/it]

Step 11500: Policy Loss 0.1120, Critic Loss 0.0630, Avg Train Acc 0.7288, Val Acc 0.5000, Unique Indices: 100


Training Steps:  58%|█████▊    | 11599/20000 [3:03:16<30:43,  4.56it/s]  

Step 11600: Policy Loss -0.0000, Critic Loss 0.0909, Avg Train Acc 0.7525, Val Acc 0.4900, Unique Indices: 97


Training Steps:  58%|█████▊    | 11700/20000 [3:03:46<2:42:17,  1.17s/it]

Step 11700: Policy Loss 0.0686, Critic Loss 0.1652, Avg Train Acc 0.8688, Val Acc 0.5000, Unique Indices: 99


Training Steps:  59%|█████▉    | 11799/20000 [3:04:05<24:47,  5.51it/s]  

Step 11800: Policy Loss 0.0381, Critic Loss 0.0749, Avg Train Acc 0.8363, Val Acc 0.5000, Unique Indices: 100


Training Steps:  60%|█████▉    | 11901/20000 [3:04:57<1:50:23,  1.22it/s] 

Step 11900: Policy Loss 0.1057, Critic Loss 0.2013, Avg Train Acc 0.7338, Val Acc 0.5100, Unique Indices: 99


Training Steps:  60%|█████▉    | 11999/20000 [3:05:11<18:55,  7.04it/s]  

Step 12000: Policy Loss -0.0000, Critic Loss 0.3127, Avg Train Acc 0.6250, Val Acc 0.4900, Unique Indices: 99


Training Steps:  61%|██████    | 12101/20000 [3:05:46<1:45:35,  1.25it/s] 

Step 12100: Policy Loss 0.0851, Critic Loss 0.2398, Avg Train Acc 0.6562, Val Acc 0.4800, Unique Indices: 99


Training Steps:  61%|██████    | 12199/20000 [3:06:00<16:15,  8.00it/s]  

Step 12200: Policy Loss 0.0683, Critic Loss 0.4488, Avg Train Acc 0.6713, Val Acc 0.4900, Unique Indices: 98


Training Steps:  62%|██████▏   | 12301/20000 [3:06:25<1:43:07,  1.24it/s]

Step 12300: Policy Loss 0.2795, Critic Loss 0.2671, Avg Train Acc 0.6950, Val Acc 0.5000, Unique Indices: 97


Training Steps:  62%|██████▏   | 12399/20000 [3:06:38<17:28,  7.25it/s]  

Step 12400: Policy Loss 0.1710, Critic Loss 0.1641, Avg Train Acc 0.6725, Val Acc 0.5100, Unique Indices: 98


Training Steps:  62%|██████▎   | 12500/20000 [3:07:15<3:34:33,  1.72s/it]

Step 12500: Policy Loss 0.0323, Critic Loss 0.5051, Avg Train Acc 0.6200, Val Acc 0.5100, Unique Indices: 98


Training Steps:  63%|██████▎   | 12599/20000 [3:08:26<1:40:12,  1.23it/s]

Step 12600: Policy Loss 0.1092, Critic Loss 0.1923, Avg Train Acc 0.3787, Val Acc 0.4900, Unique Indices: 98


Training Steps:  64%|██████▎   | 12701/20000 [3:09:12<1:50:52,  1.10it/s]

Step 12700: Policy Loss 0.2228, Critic Loss 0.2129, Avg Train Acc 0.7438, Val Acc 0.5000, Unique Indices: 97


Training Steps:  64%|██████▍   | 12799/20000 [3:09:33<26:35,  4.51it/s]  

Step 12800: Policy Loss 0.0691, Critic Loss 0.0874, Avg Train Acc 0.7075, Val Acc 0.4900, Unique Indices: 100


Training Steps:  64%|██████▍   | 12900/20000 [3:10:36<3:07:16,  1.58s/it]

Step 12900: Policy Loss 0.0828, Critic Loss 0.2269, Avg Train Acc 0.8413, Val Acc 0.5000, Unique Indices: 98


Training Steps:  65%|██████▍   | 12999/20000 [3:11:42<1:28:03,  1.33it/s]

Step 13000: Policy Loss 0.0282, Critic Loss 0.0210, Avg Train Acc 0.8975, Val Acc 0.5000, Unique Indices: 99


Training Steps:  66%|██████▌   | 13100/20000 [3:13:18<3:09:54,  1.65s/it] 

Step 13100: Policy Loss 0.0926, Critic Loss 0.1449, Avg Train Acc 0.8363, Val Acc 0.4900, Unique Indices: 99


Training Steps:  66%|██████▌   | 13199/20000 [3:14:59<1:49:09,  1.04it/s]

Step 13200: Policy Loss 0.0500, Critic Loss 0.1764, Avg Train Acc 0.6875, Val Acc 0.5000, Unique Indices: 98


Training Steps:  66%|██████▋   | 13300/20000 [3:16:46<3:26:40,  1.85s/it]

Step 13300: Policy Loss 0.1955, Critic Loss 0.3208, Avg Train Acc 0.7075, Val Acc 0.4900, Unique Indices: 100


Training Steps:  67%|██████▋   | 13399/20000 [3:18:15<1:44:03,  1.06it/s]

Step 13400: Policy Loss 0.0591, Critic Loss 0.2625, Avg Train Acc 0.7137, Val Acc 0.4900, Unique Indices: 99


Training Steps:  68%|██████▊   | 13500/20000 [3:20:59<3:25:46,  1.90s/it] 

Step 13500: Policy Loss 0.2044, Critic Loss 0.2245, Avg Train Acc 0.7150, Val Acc 0.5100, Unique Indices: 99


Training Steps:  68%|██████▊   | 13599/20000 [3:22:32<1:34:18,  1.13it/s]

Step 13600: Policy Loss 0.1143, Critic Loss 0.2015, Avg Train Acc 0.6863, Val Acc 0.4900, Unique Indices: 98


Training Steps:  68%|██████▊   | 13700/20000 [3:24:16<3:20:25,  1.91s/it]

Step 13700: Policy Loss -0.0000, Critic Loss 0.3027, Avg Train Acc 0.7150, Val Acc 0.4900, Unique Indices: 97


Training Steps:  69%|██████▉   | 13799/20000 [3:25:48<1:45:21,  1.02s/it]

Step 13800: Policy Loss 0.2983, Critic Loss 0.3283, Avg Train Acc 0.6875, Val Acc 0.5000, Unique Indices: 100


Training Steps:  70%|██████▉   | 13900/20000 [3:28:03<3:19:29,  1.96s/it] 

Step 13900: Policy Loss 0.1376, Critic Loss 0.2458, Avg Train Acc 0.6825, Val Acc 0.5000, Unique Indices: 99


Training Steps:  70%|██████▉   | 13999/20000 [3:29:43<2:18:29,  1.38s/it]

Step 14000: Policy Loss 0.1438, Critic Loss 0.1096, Avg Train Acc 0.7050, Val Acc 0.5000, Unique Indices: 100


Training Steps:  70%|███████   | 14100/20000 [3:31:34<3:17:03,  2.00s/it]

Step 14100: Policy Loss 0.3111, Critic Loss 0.2688, Avg Train Acc 0.6887, Val Acc 0.5100, Unique Indices: 100


Training Steps:  71%|███████   | 14199/20000 [3:33:16<1:40:23,  1.04s/it]

Step 14200: Policy Loss 0.0788, Critic Loss 0.2832, Avg Train Acc 0.6875, Val Acc 0.4800, Unique Indices: 99


Training Steps:  72%|███████▏  | 14300/20000 [3:35:09<3:00:17,  1.90s/it]

Step 14300: Policy Loss 0.0454, Critic Loss 0.1071, Avg Train Acc 0.7163, Val Acc 0.4900, Unique Indices: 100


Training Steps:  72%|███████▏  | 14399/20000 [3:36:44<1:34:58,  1.02s/it]

Step 14400: Policy Loss 0.0482, Critic Loss 0.2193, Avg Train Acc 0.6987, Val Acc 0.4800, Unique Indices: 99


Training Steps:  72%|███████▎  | 14500/20000 [3:38:34<2:51:48,  1.87s/it]

Step 14500: Policy Loss 0.1214, Critic Loss 0.2406, Avg Train Acc 0.7063, Val Acc 0.4900, Unique Indices: 99


Training Steps:  73%|███████▎  | 14599/20000 [3:40:13<1:27:25,  1.03it/s]

Step 14600: Policy Loss 0.0696, Critic Loss 0.3806, Avg Train Acc 0.7225, Val Acc 0.4800, Unique Indices: 99


Training Steps:  74%|███████▎  | 14700/20000 [3:42:07<2:47:05,  1.89s/it]

Step 14700: Policy Loss 0.1094, Critic Loss 0.1150, Avg Train Acc 0.7013, Val Acc 0.4900, Unique Indices: 100


Training Steps:  74%|███████▍  | 14799/20000 [3:43:49<1:29:36,  1.03s/it]

Step 14800: Policy Loss 0.0476, Critic Loss 0.1421, Avg Train Acc 0.7037, Val Acc 0.4900, Unique Indices: 97


Training Steps:  74%|███████▍  | 14900/20000 [3:45:40<2:45:12,  1.94s/it]

Step 14900: Policy Loss 0.0472, Critic Loss 0.2408, Avg Train Acc 0.6763, Val Acc 0.4900, Unique Indices: 100


Training Steps:  75%|███████▍  | 14999/20000 [3:47:13<1:37:28,  1.17s/it]

Step 15000: Policy Loss 0.0723, Critic Loss 0.1873, Avg Train Acc 0.6913, Val Acc 0.4900, Unique Indices: 100


Training Steps:  76%|███████▌  | 15100/20000 [3:49:05<3:04:42,  2.26s/it]

Step 15100: Policy Loss -0.0000, Critic Loss 0.2075, Avg Train Acc 0.6787, Val Acc 0.4900, Unique Indices: 100


Training Steps:  76%|███████▌  | 15199/20000 [3:50:39<1:52:12,  1.40s/it]

Step 15200: Policy Loss 0.0593, Critic Loss 0.1587, Avg Train Acc 0.7013, Val Acc 0.5000, Unique Indices: 98


Training Steps:  76%|███████▋  | 15300/20000 [3:52:25<2:30:57,  1.93s/it]

Step 15300: Policy Loss 0.0769, Critic Loss 0.1217, Avg Train Acc 0.6887, Val Acc 0.4800, Unique Indices: 98


Training Steps:  77%|███████▋  | 15399/20000 [3:53:59<1:08:59,  1.11it/s]

Step 15400: Policy Loss 0.0478, Critic Loss 0.3772, Avg Train Acc 0.7063, Val Acc 0.4800, Unique Indices: 98


Training Steps:  78%|███████▊  | 15500/20000 [3:55:44<2:20:32,  1.87s/it]

Step 15500: Policy Loss 0.0736, Critic Loss 0.0888, Avg Train Acc 0.6775, Val Acc 0.4900, Unique Indices: 100


Training Steps:  78%|███████▊  | 15599/20000 [3:57:26<1:17:52,  1.06s/it]

Step 15600: Policy Loss 0.0238, Critic Loss 0.0740, Avg Train Acc 0.6913, Val Acc 0.5000, Unique Indices: 100


Training Steps:  78%|███████▊  | 15700/20000 [3:59:13<2:14:00,  1.87s/it]

Step 15700: Policy Loss 0.0181, Critic Loss 0.0545, Avg Train Acc 0.7300, Val Acc 0.4900, Unique Indices: 98


Training Steps:  79%|███████▉  | 15799/20000 [4:00:52<1:19:18,  1.13s/it]

Step 15800: Policy Loss 0.1557, Critic Loss 0.2199, Avg Train Acc 0.6613, Val Acc 0.5200, Unique Indices: 99


Training Steps:  80%|███████▉  | 15900/20000 [4:02:44<2:23:52,  2.11s/it]

Step 15900: Policy Loss 0.1873, Critic Loss 0.2610, Avg Train Acc 0.6675, Val Acc 0.5000, Unique Indices: 98


Training Steps:  80%|███████▉  | 15999/20000 [4:04:25<1:08:19,  1.02s/it]

Step 16000: Policy Loss 0.0696, Critic Loss 0.2511, Avg Train Acc 0.6813, Val Acc 0.5000, Unique Indices: 100


Training Steps:  80%|████████  | 16100/20000 [4:06:47<2:05:02,  1.92s/it] 

Step 16100: Policy Loss 0.2370, Critic Loss 0.1962, Avg Train Acc 0.6787, Val Acc 0.4900, Unique Indices: 98


Training Steps:  81%|████████  | 16199/20000 [4:08:25<1:03:25,  1.00s/it]

Step 16200: Policy Loss 0.1236, Critic Loss 0.1616, Avg Train Acc 0.6913, Val Acc 0.5000, Unique Indices: 99


Training Steps:  82%|████████▏ | 16300/20000 [4:10:14<1:56:45,  1.89s/it]

Step 16300: Policy Loss 0.1073, Critic Loss 0.1597, Avg Train Acc 0.6713, Val Acc 0.4900, Unique Indices: 99


Training Steps:  82%|████████▏ | 16399/20000 [4:11:54<55:42,  1.08it/s]  

Step 16400: Policy Loss 0.0019, Critic Loss 0.1536, Avg Train Acc 0.7000, Val Acc 0.4800, Unique Indices: 100


Training Steps:  82%|████████▎ | 16500/20000 [4:13:55<1:48:40,  1.86s/it]

Step 16500: Policy Loss 0.0033, Critic Loss 0.2381, Avg Train Acc 0.6850, Val Acc 0.4800, Unique Indices: 99


Training Steps:  83%|████████▎ | 16599/20000 [4:15:35<58:52,  1.04s/it]  

Step 16600: Policy Loss 0.0432, Critic Loss 0.0673, Avg Train Acc 0.7150, Val Acc 0.4900, Unique Indices: 100


Training Steps:  84%|████████▎ | 16700/20000 [4:17:26<1:47:27,  1.95s/it]

Step 16700: Policy Loss 0.0199, Critic Loss 0.2719, Avg Train Acc 0.7275, Val Acc 0.4800, Unique Indices: 99


Training Steps:  84%|████████▍ | 16799/20000 [4:19:08<1:03:31,  1.19s/it]

Step 16800: Policy Loss 0.1316, Critic Loss 0.2254, Avg Train Acc 0.6737, Val Acc 0.4800, Unique Indices: 97


Training Steps:  84%|████████▍ | 16900/20000 [4:20:57<1:32:31,  1.79s/it]

Step 16900: Policy Loss 0.0431, Critic Loss 0.2980, Avg Train Acc 0.6913, Val Acc 0.5000, Unique Indices: 98


Training Steps:  85%|████████▍ | 16999/20000 [4:22:37<53:37,  1.07s/it]  

Step 17000: Policy Loss -0.0000, Critic Loss 0.4121, Avg Train Acc 0.6713, Val Acc 0.4800, Unique Indices: 100


Training Steps:  86%|████████▌ | 17100/20000 [4:24:23<1:39:23,  2.06s/it]

Step 17100: Policy Loss 0.0769, Critic Loss 0.0810, Avg Train Acc 0.6900, Val Acc 0.4900, Unique Indices: 99


Training Steps:  86%|████████▌ | 17199/20000 [4:26:00<48:26,  1.04s/it]  

Step 17200: Policy Loss 0.0277, Critic Loss 0.1549, Avg Train Acc 0.6875, Val Acc 0.4900, Unique Indices: 99


Training Steps:  86%|████████▋ | 17300/20000 [4:27:47<1:26:56,  1.93s/it]

Step 17300: Policy Loss 0.0492, Critic Loss 0.1800, Avg Train Acc 0.6750, Val Acc 0.5100, Unique Indices: 100


Training Steps:  87%|████████▋ | 17399/20000 [4:29:23<44:28,  1.03s/it]  

Step 17400: Policy Loss 0.0449, Critic Loss 0.0933, Avg Train Acc 0.6913, Val Acc 0.4900, Unique Indices: 98


Training Steps:  88%|████████▊ | 17500/20000 [4:32:12<1:18:58,  1.90s/it] 

Step 17500: Policy Loss 0.0692, Critic Loss 0.0983, Avg Train Acc 0.6800, Val Acc 0.4900, Unique Indices: 99


Training Steps:  88%|████████▊ | 17599/20000 [4:33:49<38:48,  1.03it/s]  

Step 17600: Policy Loss 0.1909, Critic Loss 0.1808, Avg Train Acc 0.6937, Val Acc 0.4900, Unique Indices: 98


Training Steps:  88%|████████▊ | 17700/20000 [4:35:39<1:14:28,  1.94s/it]

Step 17700: Policy Loss 0.1456, Critic Loss 0.1083, Avg Train Acc 0.7087, Val Acc 0.4900, Unique Indices: 99


Training Steps:  89%|████████▉ | 17799/20000 [4:37:17<35:49,  1.02it/s]  

Step 17800: Policy Loss 0.0371, Critic Loss 0.2705, Avg Train Acc 0.6725, Val Acc 0.4800, Unique Indices: 100


Training Steps:  90%|████████▉ | 17900/20000 [4:39:10<1:08:12,  1.95s/it]

Step 17900: Policy Loss 0.1868, Critic Loss 0.2645, Avg Train Acc 0.6850, Val Acc 0.4900, Unique Indices: 98


Training Steps:  90%|████████▉ | 17999/20000 [4:40:41<30:53,  1.08it/s]  

Step 18000: Policy Loss 0.0569, Critic Loss 0.1712, Avg Train Acc 0.6900, Val Acc 0.4900, Unique Indices: 99


Training Steps:  90%|█████████ | 18100/20000 [4:42:27<59:34,  1.88s/it]  

Step 18100: Policy Loss 0.0267, Critic Loss 0.0968, Avg Train Acc 0.6725, Val Acc 0.4800, Unique Indices: 100


Training Steps:  91%|█████████ | 18199/20000 [4:44:01<27:42,  1.08it/s]

Step 18200: Policy Loss 0.1941, Critic Loss 0.2952, Avg Train Acc 0.7188, Val Acc 0.5000, Unique Indices: 98


Training Steps:  92%|█████████▏| 18300/20000 [4:45:47<53:25,  1.89s/it]  

Step 18300: Policy Loss 0.0848, Critic Loss 0.2624, Avg Train Acc 0.7075, Val Acc 0.4900, Unique Indices: 99


Training Steps:  92%|█████████▏| 18399/20000 [4:47:21<27:55,  1.05s/it]

Step 18400: Policy Loss 0.1085, Critic Loss 0.2285, Avg Train Acc 0.6975, Val Acc 0.4900, Unique Indices: 99


Training Steps:  92%|█████████▎| 18500/20000 [4:49:04<46:18,  1.85s/it]  

Step 18500: Policy Loss 0.1157, Critic Loss 0.0793, Avg Train Acc 0.6900, Val Acc 0.5000, Unique Indices: 99


Training Steps:  93%|█████████▎| 18599/20000 [4:50:44<23:49,  1.02s/it]

Step 18600: Policy Loss 0.0387, Critic Loss 0.1451, Avg Train Acc 0.6600, Val Acc 0.4900, Unique Indices: 98


Training Steps:  94%|█████████▎| 18700/20000 [4:52:34<41:58,  1.94s/it]  

Step 18700: Policy Loss 0.0127, Critic Loss 0.2754, Avg Train Acc 0.6813, Val Acc 0.4900, Unique Indices: 98


Training Steps:  94%|█████████▍| 18799/20000 [4:54:16<22:33,  1.13s/it]

Step 18800: Policy Loss 0.0658, Critic Loss 0.2369, Avg Train Acc 0.6837, Val Acc 0.4800, Unique Indices: 98


Training Steps:  94%|█████████▍| 18900/20000 [4:56:31<33:49,  1.84s/it]  

Step 18900: Policy Loss 0.2403, Critic Loss 0.3420, Avg Train Acc 0.6713, Val Acc 0.4900, Unique Indices: 99


Training Steps:  95%|█████████▍| 18999/20000 [4:58:03<20:29,  1.23s/it]

Step 19000: Policy Loss 0.0640, Critic Loss 0.4077, Avg Train Acc 0.6800, Val Acc 0.4900, Unique Indices: 99


Training Steps:  96%|█████████▌| 19100/20000 [4:59:51<27:44,  1.85s/it]

Step 19100: Policy Loss 0.0054, Critic Loss 0.1984, Avg Train Acc 0.6775, Val Acc 0.5100, Unique Indices: 100


Training Steps:  96%|█████████▌| 19199/20000 [5:01:24<13:39,  1.02s/it]

Step 19200: Policy Loss 0.0818, Critic Loss 0.2795, Avg Train Acc 0.6800, Val Acc 0.5100, Unique Indices: 98


Training Steps:  96%|█████████▋| 19300/20000 [5:03:16<23:00,  1.97s/it]

Step 19300: Policy Loss -0.0000, Critic Loss 0.0945, Avg Train Acc 0.6863, Val Acc 0.4800, Unique Indices: 99


Training Steps:  97%|█████████▋| 19399/20000 [5:04:49<08:50,  1.13it/s]

Step 19400: Policy Loss 0.0751, Critic Loss 0.3206, Avg Train Acc 0.7200, Val Acc 0.5000, Unique Indices: 98


Training Steps:  98%|█████████▊| 19500/20000 [5:06:36<15:46,  1.89s/it]

Step 19500: Policy Loss -0.0000, Critic Loss 0.0488, Avg Train Acc 0.6763, Val Acc 0.4800, Unique Indices: 99


Training Steps:  98%|█████████▊| 19599/20000 [5:08:12<06:35,  1.02it/s]

Step 19600: Policy Loss 0.1070, Critic Loss 0.2638, Avg Train Acc 0.7000, Val Acc 0.4900, Unique Indices: 100


Training Steps:  98%|█████████▊| 19700/20000 [5:09:53<09:09,  1.83s/it]

Step 19700: Policy Loss 0.1853, Critic Loss 0.3818, Avg Train Acc 0.6950, Val Acc 0.4900, Unique Indices: 100


Training Steps:  99%|█████████▉| 19799/20000 [5:11:30<03:14,  1.03it/s]

Step 19800: Policy Loss 0.0351, Critic Loss 0.1969, Avg Train Acc 0.6925, Val Acc 0.5000, Unique Indices: 99


Training Steps: 100%|█████████▉| 19900/20000 [5:13:25<03:05,  1.86s/it]

Step 19900: Policy Loss 0.1335, Critic Loss 0.1881, Avg Train Acc 0.7087, Val Acc 0.5100, Unique Indices: 100


Training Steps: 100%|█████████▉| 19999/20000 [5:14:56<00:00,  1.09it/s]

Step 20000: Policy Loss 0.2102, Critic Loss 0.3176, Avg Train Acc 0.7188, Val Acc 0.4800, Unique Indices: 100


Training Steps: 100%|██████████| 20000/20000 [5:15:05<00:00,  1.06it/s]
