# In [None]:  (Jupyter notebook cell marker, kept as a comment so the file parses as Python)
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import time
import sys
import string
import random
import math
import os
import pickle
import copy
from collections import defaultdict, deque
import traceback # For better error printing

# --- Global configuration ---
# Character-level alphabet: lowercase ASCII letters, decimal digits and the space.
ALPHABET = [*string.ascii_lowercase, *string.digits, ' ']
PAD_TOKEN = "<PAD>"
SOS_TOKEN = "<SOS>"
EOS_TOKEN = "<EOS>"
UNK_TOKEN = "<UNK>"  # declared, but not part of SPECIAL_TOKENS / VOCAB below
USER_AVATAR_TOKEN = "<USER_AVATAR>"  # only for the cell type

# The vocabulary is the special tokens followed by the alphabet, so the
# special tokens occupy the lowest indices.
SPECIAL_TOKENS = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN]
VOCAB = SPECIAL_TOKENS + ALPHABET
idx_to_char = dict(enumerate(VOCAB))
char_to_idx = {token: index for index, token in idx_to_char.items()}
VOCAB_SIZE = len(VOCAB)
PAD_IDX = char_to_idx[PAD_TOKEN]
SOS_IDX = char_to_idx[SOS_TOKEN]
EOS_IDX = char_to_idx[EOS_TOKEN]
print(f"Vocabulário Transformer ({VOCAB_SIZE}): {' '.join(VOCAB)}")

# --- ByteWorld configuration ---
WORLD_WIDTH = 30
WORLD_HEIGHT = 20
STATE_DIM = 1
# Integer code stored in each grid cell (code 4 is not assigned).
CELL_TYPES = {
    'empty': 0,
    'wall': 1,
    'food': 2,
    'danger': 3,
    'terminal': 5,
    'user_avatar': 6,
}
# Glyph used when printing each cell code.
CELL_CHARS = {
    0: '.',
    1: '#',
    2: 'F',
    3: '!',
    5: 'T',
    6: 'U',
}
TERMINAL_X = 5
TERMINAL_Y = 5
# *** USER_AVATAR_X/Y are defined here, at global scope ***
# Defaults; overwritten if loaded or initialised.
USER_AVATAR_X = 6
USER_AVATAR_Y = 5

# --- ESMA agent configuration (responsiveness and coherence adjustments) ---
PERCEPTION_WINDOW = 60
MAX_ENERGY = 1000.0
ENERGY_DECAY_RATE = 0.001  # tuning note: loses energy faster
MOVE_ENERGY_COST = 0.01
FOOD_ENERGY_GAIN = 100.0
DANGER_ENERGY_LOSS = -25.0
CURIOSITY_INCREASE = 0.01
CURIOSITY_DECREASE_INTERACT = 0.7
SAFETY_THRESHOLD = 2.5
INACTION_PENALTY_STEPS = 5  # tuning note: inaction is penalised sooner
INACTION_CURIOSITY_BOOST = 0.05
INACTION_ENERGY_PENALTY = 0.6  # tuning note: larger inaction penalty
LOW_INTELLIGENCE_THRESHOLD = -5.0
LOW_INTELLIGENCE_PENALTY = -2.0
NO_RESPONSE_PENALTY = -20.0  # learning penalty for ignoring the user
RESPONSE_REWARD_BONUS = 50.0  # learning bonus for responding
IGNORE_USER_PENALTY = -10.5  # direct energy penalty for ignoring the user
RESPONSE_ENERGY_BONUS = 20.5  # direct energy bonus for responding
MIN_SPEECH_LENGTH_FOR_BONUS = 8  # minimum length for the coherence bonus
COHERENT_SPEECH_BONUS = 2.0  # small reward bonus for longer speech

# Action names; an action id is the position of its name in this list.
ACTIONS = [
    "MOVE_N", "MOVE_S", "MOVE_E", "MOVE_W",
    "EAT",
    "GOTO_FOOD", "GOTO_SAFE", "GOTO_TERM",
    "SPEAK",
]
id_to_action = dict(enumerate(ACTIONS))
action_to_id = {name: index for index, name in id_to_action.items()}
NUM_ACTIONS = len(ACTIONS)

# --- Motivational learning parameters ---
MAX_MOTIVATED_RULES = 10000
RULE_CONFIDENCE_THRESHOLD = 0.05
REWARD_LEARNING_RATE = 1.99
CONFIDENCE_LEARNING_RATE = 0.05
EXPLORATION_RATE = 0.15  # slightly reduced
DISCOUNT_FACTOR = 0.9
NEED_BOOST_FACTOR = 10.0
SPEECH_REWARD_BONUS = 50.0
FEEDBACK_REWARD = 80.0
FEEDBACK_PUNISH = -40.0

# --- Transformer LM parameters ---
LM_D_MODEL = 96
LM_NHEAD = 4
LM_NUM_LAYERS = 4
LM_DIM_FEEDFORWARD = 256
LM_DROPOUT = 0.1
LM_CONTEXT_WINDOW = 48
LM_LEARNING_RATE = 1e-4
LM_BATCH_SIZE = 64
MIN_GENERATION_LENGTH = 30  # minimum tokens to generate before allowing EOS
EOS_PENALTY_FACTOR = 1.5  # how much to penalize EOS before the minimum length

# --- General settings ---
additional_simulation_steps = 0
print_interval = 1
prune_interval = 10000
seq_limit = 64

# --- Device ---
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Dispositivo: {device}")

# --- ESMA data structures (globals) ---
# Agent attributes; functions in this file read/write the keys 'x', 'y',
# 'energy', 'curiosity', 'safety' and 'last_speech'.
agent_state = {}
# Bounded history of recently seen characters (filled by update_working_memory).
language_working_memory = deque(maxlen=LM_CONTEXT_WINDOW + 20)
# Maps a flattened perception tuple to {'id': int, 'count': int} (see chunk_perception).
perceptual_chunk_lexicon = {}
# Next id handed out by chunk_perception for a perception not seen before.
next_percept_chunk_id = 0
byte_world = [] # Will be populated in __main__
terminal_buffer = "" # Used by the chat
world_step = 0
current_step = 0 # For the working memory
# Number of consecutive execute_action calls in which the agent position did not change.
steps_stuck = 0
# Agent position recorded the last time it changed (see execute_action).
last_pos = (-1, -1)
# Single-entry cache: the most recent perception tuple -> its chunk id (see chunk_perception).
chunk_cache = {}
last_state_action_before_good_reward = None
# Running score; multiplied by INTELLIGENCE_DECAY on every execute_action call.
intelligence_score = 0.0
INTELLIGENCE_DECAY = 0.999
# Added to intelligence_score when the user gives 'feed' feedback.
INTELLIGENCE_FEED_BONUS = 10.0
# NOTE(review): not used in the visible part of the file; presumably the
# counterpart for 'punish' feedback -- confirm.
INTELLIGENCE_PUNISH_PENALTY = -10.0
# Scale of the intelligence gain after a rewarded EAT action.
INTELLIGENCE_EAT_BONUS = 50.5
# Added (negative) while the agent's safety value is below 0.2.
INTELLIGENCE_DANGER_PENALTY = -1.0
# Added (negative) while energy is below 10% of MAX_ENERGY.
INTELLIGENCE_LOW_ENERGY_PENALTY = -0.5

def create_motivated_rule_entry():
    """Return a fresh rule record with zeroed reward, confidence and count."""
    return dict(expected_reward=0.0, confidence=0.0, count=0)


def create_motivated_context_dict():
    """Return a per-context mapping whose missing action ids get a fresh rule record."""
    # Named module-level factories (rather than lambdas) are used as default
    # factories; presumably so the nested defaultdicts can be pickled -- confirm.
    return defaultdict(create_motivated_rule_entry)


# context key -> action id -> rule record. Will be populated in __main__.
motivated_rule_base = defaultdict(create_motivated_context_dict)

# Bounded FIFO of (padded context tuple, target id) training pairs for the LM;
# once full, the oldest pairs are dropped as new ones are appended.
lm_training_buffer = deque(maxlen=LM_BATCH_SIZE * 20)
# ((percept id, internal state tuple), action id) recorded by execute_action for
# the most recent SPEAK action, or None.
last_agent_speech_context = None

# --- Transformer LM model (globals) ---
lm_model = None # Will be instantiated in __main__
lm_optimizer = None # Will be instantiated in __main__
# Targets equal to PAD_IDX do not contribute to the loss.
lm_criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)


# --- Transformer LM Model Definition ---
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence-first tensor.

    The table ``pe`` has shape ``(max_len, 1, d_model)``: even feature columns
    hold ``sin(position * frequency)`` and odd columns hold
    ``cos(position * frequency)``. It is registered as a buffer, so it moves
    with the module across devices and is not a trainable parameter.

    Args:
        d_model: Size of the feature dimension. Both even and odd sizes are
            supported.
        dropout: Dropout probability applied after the positions are added.
        max_len: Number of positions available in the table.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 500):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the angle table instead of failing on a shape
        # mismatch. For an even d_model the slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(x + pe[:x.size(0)])``.

        The first dimension of ``x`` indexes positions and the last must have
        size ``d_model``; the table is broadcast over the middle dimension.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class MiniTransformerLM(nn.Module):
    """Small causal language model built on ``nn.TransformerEncoder``.

    Token ids are embedded, scaled by ``sqrt(d_model)``, given positional
    encodings and passed through a pre-norm, batch-first encoder under a
    causal mask. A linear layer maps every position to vocabulary logits.
    """

    def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward, dropout, max_len):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=PAD_IDX)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_len + 1)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)
        self.output_layer = nn.Linear(d_model, vocab_size)
        self.init_weights()

    def init_weights(self) -> None:
        """Re-initialise the embedding and output layer uniformly in [-0.1, 0.1]."""
        bound = 0.1
        self.embedding.weight.data.uniform_(-bound, bound)
        self.output_layer.bias.data.zero_()
        self.output_layer.weight.data.uniform_(-bound, bound)
        pad_row = self.embedding.padding_idx
        if pad_row is not None:
            # The uniform re-init above also overwrote the padding row; zero it again.
            with torch.no_grad():
                self.embedding.weight[pad_row].fill_(0)

    def _generate_square_subsequent_mask(self, sz: int, dev: torch.device) -> torch.Tensor:
        """Return an ``sz x sz`` float mask: ``-inf`` above the diagonal, 0 elsewhere."""
        blocked = torch.full((sz, sz), float('-inf'), device=dev)
        return torch.triu(blocked, diagonal=1)

    def forward(self, src: torch.Tensor, src_padding_mask: torch.Tensor = None) -> torch.Tensor:
        """Return logits of shape ``(batch, seq, vocab_size)`` for ``src`` of shape ``(batch, seq)``.

        ``src_padding_mask`` is forwarded unchanged as the encoder's
        ``src_key_padding_mask``.
        """
        # NOTE(review): if inputs are left-padded, a padded query position can
        # only attend to padded keys once the causal and key-padding masks are
        # combined -- confirm this cannot yield NaN rows on the installed
        # PyTorch version.
        causal_mask = self._generate_square_subsequent_mask(src.size(1), src.device)
        scaled = self.embedding(src) * math.sqrt(self.d_model)
        # The positional encoder indexes positions along dim 0, so switch to
        # (seq, batch, d_model) for that call and back afterwards.
        positioned = self.pos_encoder(scaled.transpose(0, 1)).transpose(0, 1)
        encoded = self.transformer_encoder(
            positioned,
            mask=causal_mask,
            src_key_padding_mask=src_padding_mask,
        )
        return self.output_layer(encoded)


# --- ESMA Functions ---
def update_working_memory(char):
    """Append ``char`` to the language working memory when it is a known vocabulary symbol.

    Symbols missing from ``char_to_idx`` are ignored. Each accepted symbol also
    advances the global ``current_step`` counter by one.
    """
    global current_step, language_working_memory  # modifies globals
    if char not in char_to_idx:
        return
    language_working_memory.append(char)
    current_step += 1

def add_to_lm_buffer(text_sequence, is_user_input=False, weight_multiplier=1, is_rewarded_speech=False):
    """Turn ``text_sequence`` into next-token training pairs and add them to the LM buffer.

    The sequence is encoded as ``SOS + known symbols + EOS``; symbols missing
    from ``char_to_idx`` are dropped. For every position after SOS, one pair
    ``(context, target)`` is produced, where ``context`` is the up-to
    ``LM_CONTEXT_WINDOW`` preceding ids, left-padded with ``PAD_IDX``.

    The whole set of pairs is appended ``base_weight * weight_multiplier``
    times, which over-represents important text in the bounded buffer.

    Args:
        text_sequence: Iterable of vocabulary symbols (typically a string).
        is_user_input: Use a base weight of 30 (takes precedence over
            ``is_rewarded_speech``).
        weight_multiplier: Extra repetition factor. Non-integer products are
            truncated towards zero; a non-positive result adds nothing.
        is_rewarded_speech: Use a base weight of 50.
    """
    global lm_training_buffer  # modifies global
    if not text_sequence:
        return
    ids = [SOS_IDX] + [char_to_idx[c] for c in text_sequence if c in char_to_idx] + [EOS_IDX]
    if len(ids) < 3:
        return  # nothing but SOS/EOS survived the vocabulary filter

    if is_user_input:
        base_weight = 30
    elif is_rewarded_speech:
        base_weight = 50
    else:
        base_weight = 1
    # range() needs an int; truncate so a fractional multiplier does not raise.
    repeats = int(base_weight * weight_multiplier)
    if repeats <= 0:
        return

    # The pairs do not depend on the repetition index, so build them once.
    samples = []
    for i in range(1, len(ids)):
        target = ids[i]
        context = ids[max(0, i - LM_CONTEXT_WINDOW):i]
        if not context or target == PAD_IDX:
            continue
        padded_context = [PAD_IDX] * (LM_CONTEXT_WINDOW - len(context)) + context
        if any(t != PAD_IDX for t in padded_context):
            samples.append((tuple(padded_context), target))

    for _ in range(repeats):
        lm_training_buffer.extend(samples)

def train_language_model():
    """Run one optimisation step of the LM on a random batch from the training buffer.

    Returns:
        The batch loss as a float, or ``None`` when no step was taken (buffer
        smaller than one batch, model/optimizer not created yet, tensor
        construction failed, non-finite loss, or a ``RuntimeError`` during the
        forward/backward pass).
    """
    global lm_training_buffer, lm_model, lm_optimizer, lm_criterion  # reads/modifies globals
    buffer_len = len(lm_training_buffer)
    if buffer_len < LM_BATCH_SIZE:
        return None
    if lm_model is None or lm_optimizer is None:
        return None  # ensure the model exists

    lm_model.train()
    chosen = random.sample(range(buffer_len), min(LM_BATCH_SIZE, buffer_len))
    batch = [lm_training_buffer[pos] for pos in chosen]

    # Once the buffer is more than 90% full, keep only its newest entries.
    if buffer_len > lm_training_buffer.maxlen * 0.9:
        newest = list(lm_training_buffer)[-LM_BATCH_SIZE * 5:]
        lm_training_buffer.clear()
        lm_training_buffer.extend(newest)

    contexts, targets = zip(*batch)
    if not contexts or not targets:
        return None

    try:
        context_tensor = torch.tensor(contexts, dtype=torch.long, device=device)
        target_tensor = torch.tensor(targets, dtype=torch.long, device=device)
    except ValueError as e:
        print(f"Erro tensor batch: {e}\nC:{contexts}\nT:{targets}")
        return None

    padding_mask = (context_tensor == PAD_IDX)
    lm_optimizer.zero_grad()
    try:
        logits = lm_model(context_tensor, src_padding_mask=padding_mask)
        # Only the prediction at the last context position is trained.
        loss = lm_criterion(logits[:, -1, :], target_tensor)
        if torch.isnan(loss) or torch.isinf(loss):
            print("Aviso: Perda NaN/Inf.")
            lm_optimizer.zero_grad()
            return None
        loss.backward()
        torch.nn.utils.clip_grad_norm_(lm_model.parameters(), 0.5)
        lm_optimizer.step()
        return loss.item()
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            print("Erro CUDA OOM.")
            torch.cuda.empty_cache()
        else:
            print(f"RuntimeError LM: {e}.")
        lm_optimizer.zero_grad()
        return None

def generate_speech(max_len=25, temperature=0.6):
    """Sample a character string from the language model.

    The prompt is SOS followed by the last ``LM_CONTEXT_WINDOW`` symbols of the
    language working memory. Each step left-pads the most recent ids to
    ``LM_CONTEXT_WINDOW``, runs the model and samples the next id from the
    5 highest logits at the last position.

    Args:
        max_len: Maximum number of sampling steps (and of generated ids).
        temperature: Divides the top-k logits before the softmax. A value
            ``<= 0`` selects the single highest logit instead of sampling.

    Returns:
        The generated text without EOS symbols. ``"?"`` when no model exists.
        When nothing was generated, one random alphabet character (or ``""``
        if the vocabulary contains no alphabet characters).
    """
    global language_working_memory, lm_model, MIN_GENERATION_LENGTH, EOS_PENALTY_FACTOR, EOS_IDX # Reads globals
    if lm_model is None: return "?" # Cannot generate without model

    lm_model.eval()
    context_list = list(language_working_memory)[-LM_CONTEXT_WINDOW:]
    current_ids = [SOS_IDX] + [char_to_idx.get(c, PAD_IDX) for c in context_list if c in char_to_idx]
    generated_ids = []
    generated_token_count = 0
    # Consecutive sampling failures; the counter is reset after every successful step.
    fallback_count = 0; max_fallbacks = 3
    with torch.no_grad():
        for _ in range(max_len):
            input_ids = current_ids[-LM_CONTEXT_WINDOW:]
            padded_input_ids = ([PAD_IDX] * (LM_CONTEXT_WINDOW - len(input_ids))) + input_ids
            input_tensor = torch.tensor([padded_input_ids], dtype=torch.long, device=device)
            padding_mask = (input_tensor == PAD_IDX)
            try:
                 logits = lm_model(input_tensor, src_padding_mask=padding_mask)
                 next_token_logits = logits[:, -1, :]
                 # Discourage an early end: lower the EOS logit until enough tokens exist.
                 if generated_token_count < MIN_GENERATION_LENGTH:
                     eos_logit_val = next_token_logits[0, EOS_IDX].item()
                     penalty = abs(eos_logit_val) * EOS_PENALTY_FACTOR + 1.0
                     next_token_logits[0, EOS_IDX] -= penalty
                 if torch.isnan(next_token_logits).any() or torch.isinf(next_token_logits).any(): raise ValueError("NaN/Inf logits")
                 # Top-k sampling restricted to the 5 most likely ids.
                 k = 5
                 top_k_logits, top_k_indices = torch.topk(next_token_logits, k, dim=-1)
                 if temperature <= 0: next_token_id = top_k_indices[0, 0].item()
                 else:
                     scaled_logits_k = top_k_logits / (temperature + 1e-6)
                     probs_k = torch.softmax(scaled_logits_k, dim=-1)
                     if torch.isnan(probs_k).any() or torch.isinf(probs_k).any() or probs_k.sum() < 1e-9: raise ValueError("Invalid probs")
                     # PAD and SOS must never be sampled: zero them, then renormalise.
                     probs_k[0, (top_k_indices[0] == PAD_IDX).nonzero(as_tuple=True)[0]] = 0.0
                     probs_k[0, (top_k_indices[0] == SOS_IDX).nonzero(as_tuple=True)[0]] = 0.0
                     if probs_k.sum() < 1e-9: raise ValueError("No valid tokens")
                     probs_k = probs_k / probs_k.sum()
                     chosen_relative_idx = torch.multinomial(probs_k, num_samples=1).item()
                     next_token_id = top_k_indices[0, chosen_relative_idx].item()
                 fallback_count = 0
            except Exception as e:
                 # Any failure above falls back to a uniformly random alphabet
                 # character; generation stops after max_fallbacks failures in a row.
                 valid_alphabet_indices = [i for i, char in enumerate(VOCAB) if char in ALPHABET];
                 if not valid_alphabet_indices: break
                 next_token_id = random.choice(valid_alphabet_indices); fallback_count += 1
                 if fallback_count >= max_fallbacks: print(f"Aviso: Fallback excessivo ({e})."); break
            # EOS only terminates generation once the minimum length is reached.
            if next_token_id == EOS_IDX and generated_token_count >= MIN_GENERATION_LENGTH: break
            # An EOS sampled before the minimum length is still appended to the
            # context; it is removed from the returned text below.
            if next_token_id != PAD_IDX and next_token_id != SOS_IDX:
                 generated_ids.append(next_token_id); current_ids.append(next_token_id)
                 if next_token_id != EOS_IDX: generated_token_count += 1
            if len(generated_ids) >= max_len: break
    final_generated_ids = [idx for idx in generated_ids if idx != EOS_IDX]
    if not final_generated_ids:
         valid_alphabet_indices = [i for i, char in enumerate(VOCAB) if char in ALPHABET];
         return idx_to_char.get(random.choice(valid_alphabet_indices), '?') if valid_alphabet_indices else ""
    return "".join([idx_to_char.get(idx, '?') for idx in final_generated_ids])

# --- initialize_world (Corrected) ---
# This function now only MODIFIES the global byte_world, terminal_buffer, USER_AVATAR_X/Y
def initialize_world():
    """Build a fresh ByteWorld grid and store it in the global ``byte_world``.

    The grid gets a wall border, the terminal, the user avatar, up to 20 food
    cells and up to 5 danger cells. Danger is kept away from the agent's start
    position, the terminal and the avatar. ``terminal_buffer`` is cleared, and
    ``USER_AVATAR_X`` / ``USER_AVATAR_Y`` are updated when the avatar has to be
    placed somewhere other than its preferred cell.
    """
    global terminal_buffer, byte_world, agent_state, USER_AVATAR_X, USER_AVATAR_Y
    print("Inicializando ByteWorld...")
    terminal_buffer = ""

    empty_cell = CELL_TYPES['empty']
    wall_cell = CELL_TYPES['wall']

    # Build the grid locally and publish it only when it is complete.
    local_world = [[empty_cell for _ in range(WORLD_WIDTH)] for _ in range(WORLD_HEIGHT)]
    for y in range(WORLD_HEIGHT):
        local_world[y][0] = wall_cell
        local_world[y][WORLD_WIDTH - 1] = wall_cell
    for x in range(WORLD_WIDTH):
        local_world[0][x] = wall_cell
        local_world[WORLD_HEIGHT - 1][x] = wall_cell

    def is_free(cx, cy):
        """True when (cx, cy) lies inside the world and is still empty."""
        return (0 <= cy < WORLD_HEIGHT and 0 <= cx < WORLD_WIDTH
                and local_world[cy][cx] == empty_cell)

    # Terminal, clamped to the interior of the border.
    term_y_safe = max(1, min(TERMINAL_Y, WORLD_HEIGHT - 2))
    term_x_safe = max(1, min(TERMINAL_X, WORLD_WIDTH - 2))
    if 0 <= term_y_safe < WORLD_HEIGHT and 0 <= term_x_safe < WORLD_WIDTH:
        local_world[term_y_safe][term_x_safe] = CELL_TYPES['terminal']

    # Avatar: preferred cell first, then a random neighbour of the terminal,
    # and finally (1, 1).
    avatar_y_safe = max(1, min(USER_AVATAR_Y, WORLD_HEIGHT - 2))
    avatar_x_safe = max(1, min(USER_AVATAR_X, WORLD_WIDTH - 2))
    if is_free(avatar_x_safe, avatar_y_safe):
        local_world[avatar_y_safe][avatar_x_safe] = CELL_TYPES['user_avatar']
    else:
        print(f"    Posição inicial do avatar ({avatar_x_safe},{avatar_y_safe}) ocupada. Procurando nova...")
        for _ in range(100):
            ny = TERMINAL_Y + random.randint(-1, 1)
            nx = TERMINAL_X + random.randint(-1, 1)
            if not is_free(nx, ny):
                continue
            USER_AVATAR_X, USER_AVATAR_Y = nx, ny  # update globals
            local_world[ny][nx] = CELL_TYPES['user_avatar']
            print(f"    Avatar colocado em ({nx},{ny}) via fallback.")
            break
        else:
            print("    Falha ao colocar avatar perto do terminal. Usando (1,1).")
            USER_AVATAR_X, USER_AVATAR_Y = 1, 1
            if is_free(1, 1):
                local_world[1][1] = CELL_TYPES['user_avatar']
            else:
                print("    AVISO CRÍTICO: Posição de fallback (1,1) também ocupada ou inválida!")

    # Food: at most 300 attempts, stopping at 20 placed cells.
    num_food = 0
    for _ in range(300):
        x = random.randint(1, WORLD_WIDTH - 2)
        y = random.randint(1, WORLD_HEIGHT - 2)
        if is_free(x, y):
            local_world[y][x] = CELL_TYPES['food']
            num_food += 1
        if num_food >= 20:
            break

    # Danger: at most 150 attempts, stopping at 5 placed cells.
    # Use the agent's position if available, else the centre of the world.
    start_x = agent_state.get('x', WORLD_WIDTH // 2)
    start_y = agent_state.get('y', WORLD_HEIGHT // 2)

    def in_protected_zone(cx, cy):
        """True near the agent start, the terminal or the user avatar."""
        near_start = abs(cx - start_x) < 4 and abs(cy - start_y) < 4
        near_term = abs(cx - TERMINAL_X) < 3 and abs(cy - TERMINAL_Y) < 3
        near_avatar = abs(cx - USER_AVATAR_X) < 3 and abs(cy - USER_AVATAR_Y) < 3
        return near_start or near_term or near_avatar

    num_danger = 0
    for _ in range(150):
        x = random.randint(1, WORLD_WIDTH - 2)
        y = random.randint(1, WORLD_HEIGHT - 2)
        if is_free(x, y) and not in_protected_zone(x, y):
            local_world[y][x] = CELL_TYPES['danger']
            num_danger += 1
        if num_danger >= 5:
            break

    byte_world = local_world  # publish the completed grid
    print("ByteWorld inicializado.")


def print_world(world, agent_st, term_buf):
    """Print the grid, the terminal buffer, the intelligence score and a status line.

    The agent is drawn as ``@`` and the user avatar as ``U``; both take
    precedence over the cell glyph. Cells missing from a short row print ``?``.
    """
    global world_step, steps_stuck, intelligence_score, USER_AVATAR_X, USER_AVATAR_Y  # reads globals
    border = "-" * (WORLD_WIDTH + 2)
    print(border)
    if not isinstance(world, list) or not world:
        print("## ERRO: Mundo inválido ##")
        return

    for y in range(WORLD_HEIGHT):
        if y >= len(world):
            continue
        agent_x_print = agent_st.get('x', -1)
        agent_y_print = agent_st.get('y', -1)
        glyphs = []
        for x in range(WORLD_WIDTH):
            if x >= len(world[y]):
                glyphs.append('?')
                continue
            char_to_print = CELL_CHARS.get(world[y][x], '?')
            if x == agent_x_print and y == agent_y_print:
                glyphs.append("@")
            elif x == USER_AVATAR_X and y == USER_AVATAR_Y:
                glyphs.append('U')
            else:
                glyphs.append(char_to_print)
        line = "#" + "".join(glyphs) + "#"
        # Side panel to the right of the terminal row and the row below it.
        if y == TERMINAL_Y:
            line += f"  Terminal: [{term_buf}]"
        elif y == TERMINAL_Y + 1:
            line += f"  Inteligência: {intelligence_score:.2f}"
        print(line)

    print(border)
    print(f" Step: {world_step} | Energy: {agent_st.get('energy', 0.0):.1f} | Curiosity: {agent_st.get('curiosity', 0.0):.2f} | Safety: {agent_st.get('safety', 0.0):.2f} | Stuck: {steps_stuck} | Last Speech: '{agent_st.get('last_speech','')}'")

def get_perception(agent_x, agent_y, world):
    """Return the square window of cell codes around the agent.

    The result is a ``PERCEPTION_WINDOW x PERCEPTION_WINDOW`` list of lists.
    Entry ``[row][col]`` holds the world cell at
    ``(agent_x + col - half, agent_y + row - half)`` with
    ``half = PERCEPTION_WINDOW // 2``. Positions outside the world, or outside
    the rows actually present in ``world``, are reported as walls.

    Offsets are derived from the buffer indices, so every written index is
    valid. Iterating offsets ``-half..half`` inclusive yields one index too
    many for an even window size and raised ``IndexError`` whenever that extra
    cell was inside the world.
    """
    window_size = PERCEPTION_WINDOW
    half = window_size // 2
    wall = CELL_TYPES['wall']
    perception = [[wall] * window_size for _ in range(window_size)]
    if not world:
        return perception

    for row in range(window_size):
        world_y = agent_y + row - half
        if not (0 <= world_y < WORLD_HEIGHT and world_y < len(world)):
            continue  # whole row lies outside the world: keep walls
        world_row = world[world_y]
        if not world_row:
            continue
        x_limit = min(WORLD_WIDTH, len(world_row))
        window_row = perception[row]
        for col in range(window_size):
            world_x = agent_x + col - half
            if 0 <= world_x < x_limit:
                window_row[col] = world_row[world_x]
    return perception

def flatten_perception(perception_window):
    """Concatenate the rows of ``perception_window`` into a single tuple."""
    return tuple(cell for row in perception_window for cell in row)

def chunk_perception(flat_perception_tuple):
    """Return the chunk id for a flattened perception, registering it if it is new.

    Every call for a perception stored in the lexicon increments that entry's
    ``'count'``. ``chunk_cache`` is replaced by a one-entry dict holding the
    perception handled last, which short-cuts repeated lookups of it.

    The size limit only blocks *new* entries: a perception already present in
    the lexicon still resolves to its id after the limit is reached.

    Returns:
        The integer chunk id, or ``None`` when the perception is unknown and
        the lexicon already holds more than 30000 entries.
    """
    global next_percept_chunk_id, perceptual_chunk_lexicon, chunk_cache  # modifies globals
    # Perception tuples are large; use single .get() lookups to avoid hashing
    # the key more often than necessary.
    cached_id = chunk_cache.get(flat_perception_tuple)
    entry = perceptual_chunk_lexicon.get(flat_perception_tuple)
    if cached_id is not None:
        if entry is not None:
            entry['count'] += 1
        return cached_id

    if entry is None:
        if len(perceptual_chunk_lexicon) > 30000:
            return None  # limit lexicon size
        entry = {'id': next_percept_chunk_id, 'count': 0}
        perceptual_chunk_lexicon[flat_perception_tuple] = entry
        next_percept_chunk_id += 1

    entry['count'] += 1
    chunk_cache = {flat_perception_tuple: entry['id']}  # remember only the latest perception
    return entry['id']

def get_simplified_internal_state(state):
    """Discretise the agent's continuous state into a tuple of four labels.

    Returns:
        ``(hunger, boredom, safety_level, near_user)`` where

        * hunger is ``"starving"`` / ``"hungry"`` / ``"ok"`` / ``"full"`` for
          energy below 15%, 40%, 80% of ``MAX_ENERGY``, or above;
        * boredom is ``"very_bored"`` / ``"bored"`` / ``"ok"`` / ``"engaged"``
          for curiosity above 0.9, 0.6, 0.2, or below;
        * safety_level is ``"in_danger"`` / ``"wary"`` / ``"safe"`` for safety
          below 0.15, 0.5, or above;
        * near_user is ``"near_user"`` when the squared distance to the user
          avatar is at most 4, otherwise ``"far_user"``.
    """
    global USER_AVATAR_X, USER_AVATAR_Y  # reads globals
    energy = state.get('energy', 0.0)
    curiosity = state.get('curiosity', 0.0)
    safety = state.get('safety', 0.0)

    if energy < MAX_ENERGY * 0.15:
        hunger = "starving"
    elif energy < MAX_ENERGY * 0.4:
        hunger = "hungry"
    elif energy < MAX_ENERGY * 0.8:
        hunger = "ok"
    else:
        hunger = "full"

    if curiosity > 0.9:
        boredom = "very_bored"
    elif curiosity > 0.6:
        boredom = "bored"
    elif curiosity > 0.2:
        boredom = "ok"
    else:
        boredom = "engaged"

    if safety < 0.15:
        safety_level = "in_danger"
    elif safety < 0.5:
        safety_level = "wary"
    else:
        safety_level = "safe"

    agent_x = state.get('x', -1)
    agent_y = state.get('y', -1)
    near_user = "far_user"
    # A coordinate of -1 means the position is unknown.
    if agent_x != -1 and agent_y != -1:
        dist_sq_user = (agent_x - USER_AVATAR_X) ** 2 + (agent_y - USER_AVATAR_Y) ** 2
        if dist_sq_user <= 4:
            near_user = "near_user"
    return (hunger, boredom, safety_level, near_user)

def select_action(percept_chunk_id, internal_state_tuple, agent_st):
    """Choose the id of the next action from hard-wired drives and learned rules.

    The checks run in this order and the first one that returns wins:

    1. Unknown perception (``percept_chunk_id is None``): random action.
    2. ``"in_danger"``: step to the non-wall neighbour farthest from the
       nearest danger cell (random MOVE if no neighbour is free).
    3. Text pending in the terminal buffer: probabilistic SPEAK / GOTO_TERM.
    4. ``"starving"``: EAT when standing on food, otherwise GOTO_FOOD.
    5. ``"hungry"``: EAT when standing on food, otherwise step towards the
       first food cell found in the perception window, otherwise GOTO_FOOD
       with probability 0.7.
    6. Social urges: SPEAK at the terminal or GOTO_TERM when near the user.
    7. ``"very_bored"`` and far from the user: random MOVE.
    8. Learned rules for this (perception, internal state) context, unless an
       exploration draw overrides them.
    9. Random action.

    Args:
        percept_chunk_id: Chunk id of the current perception, or ``None``.
        internal_state_tuple: ``(hunger, boredom, safety, near_user)`` labels.
        agent_st: Agent state dict; ``'x'``, ``'y'`` and ``'curiosity'`` are read.

    Returns:
        An action id (a key of ``id_to_action``).
    """
    global motivated_rule_base, steps_stuck, terminal_buffer, byte_world, EXPLORATION_RATE # Reads/Uses globals
    if percept_chunk_id is None: return action_to_id[random.choice(ACTIONS)]
    context_key = (percept_chunk_id, internal_state_tuple); hunger, boredom, safety, near_user = internal_state_tuple
    ax, ay = agent_st.get('x', -1), agent_st.get('y', -1)
    is_at_terminal = (ax == TERMINAL_X and ay == TERMINAL_Y)
    has_terminal_text_pending = bool(terminal_buffer)
    # --- Drive 1: flee from the nearest danger cell ---
    if safety == "in_danger":
        nearest_danger = None; min_dist_sq = float('inf')
        for y_ in range(WORLD_HEIGHT):
            for x_ in range(WORLD_WIDTH):
                 if 0 <= y_ < len(byte_world) and 0 <= x_ < len(byte_world[y_]) and byte_world[y_][x_] == CELL_TYPES['danger']:
                      dist_sq = (ax - x_)**2 + (ay - y_)**2
                      if dist_sq < min_dist_sq: min_dist_sq = dist_sq; nearest_danger = (x_, y_)
        # If no danger cell exists, control falls through to the next drives.
        if nearest_danger:
             danger_x, danger_y = nearest_danger; best_flee_action = None; max_flee_dist = min_dist_sq
             possible_moves = {"MOVE_N": (ax, ay-1), "MOVE_S": (ax, ay+1), "MOVE_E": (ax+1, ay), "MOVE_W": (ax-1, ay)}
             valid_flee_moves = []
             for move_name, (nx, ny) in possible_moves.items():
                  if 0 <= ny < WORLD_HEIGHT and 0 <= nx < WORLD_WIDTH and byte_world[ny][nx] != CELL_TYPES['wall']:
                       valid_flee_moves.append((move_name, (nx - danger_x)**2 + (ny - danger_y)**2))
             # Take the move whose destination is farthest from the danger cell.
             if valid_flee_moves:
                 valid_flee_moves.sort(key=lambda item: item[1], reverse=True); best_flee_action = valid_flee_moves[0][0]
                 return action_to_id[best_flee_action]
             else: move_actions = [k for k,v in action_to_id.items() if "MOVE" in k]; return action_to_id[random.choice(move_actions)] if move_actions else action_to_id.get("SPEAK", 0)
    # --- Drive 2: answer pending terminal text (this branch always returns) ---
    if has_terminal_text_pending:
        if is_at_terminal and 'SPEAK' in action_to_id and random.random() < 0.98: return action_to_id['SPEAK']
        elif not is_at_terminal and 'GOTO_TERM' in action_to_id and random.random() < 0.90: return action_to_id['GOTO_TERM']
        elif 'SPEAK' in action_to_id and random.random() < 0.7: return action_to_id['SPEAK']
        else: return action_to_id[random.choice(ACTIONS)]
    # --- Drive 3: hunger ---
    if hunger == "starving":
        if 0 <= ay < WORLD_HEIGHT and 0 <= ax < WORLD_WIDTH and byte_world[ay][ax] == CELL_TYPES['food'] and 'EAT' in action_to_id: return action_to_id['EAT']
        if 'GOTO_FOOD' in action_to_id: return action_to_id['GOTO_FOOD']
    if hunger == "hungry":
        if 0 <= ay < WORLD_HEIGHT and 0 <= ax < WORLD_WIDTH and byte_world[ay][ax] == CELL_TYPES['food'] and 'EAT' in action_to_id: return action_to_id['EAT']
        # Scan the perception window row by row and stop at the first food
        # cell; its offset from the window centre is (dx, dy).
        perception = get_perception(ax, ay, byte_world); food_pos_local = None
        for r, row in enumerate(perception):
            for c, cell_type in enumerate(row):
                if cell_type == CELL_TYPES['food']: food_pos_local = (c - PERCEPTION_WINDOW//2, r - PERCEPTION_WINDOW//2); break
            if food_pos_local: break
        if food_pos_local:
             # Step along the axis with the larger remaining distance.
             food_dx, food_dy = food_pos_local; action_name = ""
             if abs(food_dx) > abs(food_dy): action_name = "MOVE_E" if food_dx > 0 else "MOVE_W"
             elif food_dy != 0: action_name = "MOVE_S" if food_dy > 0 else "MOVE_N"
             elif food_dx != 0: action_name = "MOVE_E" if food_dx > 0 else "MOVE_W"
             if action_name in action_to_id: return action_to_id[action_name]
        if 'GOTO_FOOD' in action_to_id and random.random() < 0.7: return action_to_id['GOTO_FOOD']
    # --- Drive 4: social interaction, scaled by curiosity and proximity to the user ---
    user_interaction_urge = agent_st.get('curiosity', 0.0) + (0.3 if near_user == "near_user" else 0.0)
    if is_at_terminal and not has_terminal_text_pending and random.random() < user_interaction_urge * 0.6 and 'SPEAK' in action_to_id: return action_to_id['SPEAK']
    if near_user == "near_user" and not is_at_terminal and not has_terminal_text_pending and random.random() < 0.5 and 'GOTO_TERM' in action_to_id: return action_to_id['GOTO_TERM']
    # --- Drive 5: boredom ---
    if boredom == "very_bored" and near_user == "far_user":
         move_actions = [k for k,v in action_to_id.items() if "MOVE" in k]; return action_to_id[random.choice(move_actions)] if move_actions else action_to_id.get("SPEAK", 0)
    # --- Learned rules: the exploration rate grows with the time spent stuck ---
    current_exploration_rate = EXPLORATION_RATE + min(0.7, steps_stuck * 0.04)
    if context_key in motivated_rule_base and motivated_rule_base[context_key] and random.random() > current_exploration_rate:
        possible_actions = motivated_rule_base[context_key]
        valid_actions = {aid: data for aid, data in possible_actions.items() if id_to_action.get(aid)}
        if valid_actions:
            # Score = expected reward weighted by squared confidence, plus a
            # little Gaussian noise.
            action_scores = {aid: data.get('expected_reward', 0.0) * (data.get('confidence', 0.0)**2 + 1e-3) for aid, data in valid_actions.items()}
            action_scores = {aid: score + random.gauss(0, 0.05) for aid, score in action_scores.items()}
            best_action_id = max(action_scores, key=action_scores.get)
            # Only trust the best rule when its confidence clears the threshold.
            if valid_actions[best_action_id].get('confidence', 0.0) > RULE_CONFIDENCE_THRESHOLD: return best_action_id
    # --- Fallback: uniformly random action ---
    return action_to_id[random.choice(ACTIONS)]


# --- execute_action (Corrected) ---
def _nearest_cell_of_type(world, cell_type, from_x, from_y):
    """Return ``((x, y), squared_distance)`` of the ``cell_type`` cell closest to ``(from_x, from_y)``.

    The scan is row by row, limited to ``WORLD_HEIGHT x WORLD_WIDTH`` and to
    the rows/columns actually present in ``world``; on ties the first cell
    found wins. Returns ``(None, inf)`` when no such cell exists.
    """
    nearest = None
    best_dist_sq = float('inf')
    for y_ in range(min(WORLD_HEIGHT, len(world))):
        row = world[y_]
        for x_ in range(min(WORLD_WIDTH, len(row))):
            if row[x_] == cell_type:
                dist_sq = (from_x - x_) ** 2 + (from_y - y_) ** 2
                if dist_sq < best_dist_sq:
                    best_dist_sq = dist_sq
                    nearest = (x_, y_)
    return nearest, best_dist_sq


def execute_action(agent_st, action_id, world, user_input_was_present):
    """Apply one action to the agent and the world and update the agent's internal state.

    Args:
        agent_st: Agent state dict. ``'x'``, ``'y'``, ``'energy'`` and
            ``'curiosity'`` are read; ``'x'``/``'y'`` (movement actions only),
            ``'energy'``, ``'curiosity'``, ``'safety'`` and ``'last_speech'``
            (SPEAK only) are written.
        action_id: Key of ``id_to_action``; unknown ids behave as a no-op action.
        world: Grid of cell codes, indexed ``world[y][x]``. Eaten food is
            removed from it in place.
        user_input_was_present: Truthy when user text was waiting for a reply.
            Speaking then earns an energy bonus; any other action is penalised.

    Returns:
        ``(reward, agent_generated_speech)``. ``reward`` is the net energy
        change accumulated during the step; the speech is ``""`` unless the
        action was SPEAK.

    Side effects on globals: ``terminal_buffer`` (speech at the terminal),
    ``byte_world`` (food spawned in response to food-related speech),
    ``steps_stuck``, ``last_pos``, ``last_agent_speech_context`` and
    ``intelligence_score``.
    """
    global terminal_buffer, byte_world, steps_stuck, last_pos, last_agent_speech_context, intelligence_score
    action_name = id_to_action.get(action_id, "UNKNOWN")
    ax = agent_st.get('x', 0)
    ay = agent_st.get('y', 0)
    new_x, new_y = ax, ay
    reward = -ENERGY_DECAY_RATE
    interaction_occurred = False
    agent_generated_speech = ""

    is_speak = action_name == "SPEAK"
    is_goto = "GOTO" in action_name
    is_move = "MOVE" in action_name
    responded_to_user = bool(is_speak and user_input_was_present)
    ignored_user = bool(not is_speak and user_input_was_present)
    if ignored_user:
        reward += IGNORE_USER_PENALTY
    elif responded_to_user:
        reward += RESPONSE_ENERGY_BONUS

    # Remember the context in which the agent decided to speak.
    if is_speak:
        flat = flatten_perception(get_perception(ax, ay, world))
        percept_id = chunk_perception(flat)
        internal_tuple = get_simplified_internal_state(agent_st)
        if percept_id is not None:
            last_agent_speech_context = ((percept_id, internal_tuple), action_id)
        else:
            last_agent_speech_context = None

    # --- GOTO logic: resolve the target cell; (-1, -1) means "no target" ---
    target_x, target_y = -1, -1
    if action_name == "GOTO_FOOD":
        nearest_food, _ = _nearest_cell_of_type(world, CELL_TYPES['food'], ax, ay)
        if nearest_food is not None:
            target_x, target_y = nearest_food
    elif action_name == "GOTO_SAFE":
        nearest_danger, _ = _nearest_cell_of_type(world, CELL_TYPES['danger'], ax, ay)
        if nearest_danger is not None:
            danger_x, danger_y = nearest_danger
            flee_options = []
            # Neighbours in the order N, S, E, W; on ties the earlier one wins.
            for step_x, step_y in ((0, -1), (0, 1), (1, 0), (-1, 0)):
                cand_x, cand_y = ax + step_x, ay + step_y
                if (0 <= cand_y < WORLD_HEIGHT and 0 <= cand_x < WORLD_WIDTH
                        and world[cand_y][cand_x] != CELL_TYPES['wall']):
                    flee_dist_sq = (cand_x - danger_x) ** 2 + (cand_y - danger_y) ** 2
                    flee_options.append(((cand_x, cand_y), flee_dist_sq))
            if flee_options:
                target_x, target_y = max(flee_options, key=lambda item: item[1])[0]
    elif action_name == "GOTO_TERM":
        target_x, target_y = TERMINAL_X, TERMINAL_Y

    # --- Movement execution ---
    if (is_goto and target_x != -1) or is_move:
        move_dx, move_dy = 0, 0
        if is_move:
            move_dx, move_dy = {
                "MOVE_N": (0, -1),
                "MOVE_S": (0, 1),
                "MOVE_W": (-1, 0),
                "MOVE_E": (1, 0),
            }.get(action_name, (0, 0))
        else:
            # One step towards the target along the axis with the larger distance.
            dx = target_x - ax
            dy = target_y - ay
            if abs(dx) > abs(dy):
                move_dx = int(math.copysign(1, dx))
            elif dy != 0:
                move_dy = int(math.copysign(1, dy))
        if move_dx != 0 or move_dy != 0:
            temp_x, temp_y = ax + move_dx, ay + move_dy
            if (0 <= temp_y < WORLD_HEIGHT and 0 <= temp_x < WORLD_WIDTH
                    and world[temp_y][temp_x] != CELL_TYPES['wall']):
                new_x, new_y = temp_x, temp_y
            else:
                reward -= 0.5  # bumped into a wall or the edge of the world
        agent_st['x'], agent_st['y'] = new_x, new_y
        if is_goto and new_x == target_x and new_y == target_y:
            interaction_occurred = True
    elif is_goto:
        # GOTO without a reachable target. This penalty used to sit inside the
        # movement branch, where it could never be reached.
        reward -= 0.2

    # --- Local actions ---
    elif action_name == "EAT":
        on_food = (0 <= ay < WORLD_HEIGHT and 0 <= ax < WORLD_WIDTH
                   and world[ay][ax] == CELL_TYPES['food'])
        if on_food:
            reward += FOOD_ENERGY_GAIN
            world[ay][ax] = CELL_TYPES['empty']
            interaction_occurred = True
        else:
            reward -= 0.5
    elif is_speak:
        agent_generated_speech = generate_speech(max_len=20)
        agent_st['last_speech'] = agent_generated_speech
        interaction_occurred = True
        if ax == TERMINAL_X and ay == TERMINAL_Y:
            terminal_buffer = agent_generated_speech
            print(f"      Agente falou no terminal: '{agent_generated_speech}'")
            add_to_lm_buffer(agent_generated_speech)
            spoken = agent_generated_speech.lower()
            mentions_food = "food" in spoken or "comer" in spoken or "fome" in spoken
            if mentions_food and random.random() < 0.6:
                # The environment may answer food-related speech by dropping
                # food within 3 cells of the agent.
                # NOTE(review): food is written to the global byte_world, not to
                # `world`; presumably callers pass byte_world as `world` -- confirm.
                for _ in range(20):
                    dist = random.randint(1, 3)
                    fx = ax + random.randint(-dist, dist)
                    fy = ay + random.randint(-dist, dist)
                    if (0 <= fy < len(byte_world) and 0 <= fx < len(byte_world[fy])
                            and byte_world[fy][fx] == CELL_TYPES['empty']):
                        byte_world[fy][fx] = CELL_TYPES['food']
                        reward += 15.0
                        print(f"      *** Ambiente reagiu com comida em ({fx},{fy})! ***")
                        break
        else:
            print(f"      Agente falou no vazio: '{agent_generated_speech}'")
            add_to_lm_buffer(agent_generated_speech)

    # --- Update internal state ---
    current_agent_x = agent_st.get('x', 0)
    current_agent_y = agent_st.get('y', 0)
    if 0 <= current_agent_y < WORLD_HEIGHT and 0 <= current_agent_x < WORLD_WIDTH:
        if world[current_agent_y][current_agent_x] == CELL_TYPES['danger']:
            reward += DANGER_ENERGY_LOSS
    else:
        reward -= 1.0
    agent_st['energy'] = max(0.0, min(MAX_ENERGY, agent_st.get('energy', 0.0) + reward))

    if interaction_occurred:
        agent_st['curiosity'] = agent_st.get('curiosity', 0.0) * (1.0 - CURIOSITY_DECREASE_INTERACT)
    else:
        agent_st['curiosity'] = min(1.0, agent_st.get('curiosity', 0.0) + CURIOSITY_INCREASE)

    # Safety rises from 0 (standing on danger) towards 1 with the squared
    # distance to the nearest danger cell; it is 1.0 when there is none.
    _, min_dist_sq = _nearest_cell_of_type(world, CELL_TYPES['danger'], current_agent_x, current_agent_y)
    agent_st['safety'] = 1.0 - math.exp(-min_dist_sq / (SAFETY_THRESHOLD**2 * 5))

    current_pos = (agent_st.get('x'), agent_st.get('y'))
    if current_pos == last_pos:
        steps_stuck += 1
    else:
        steps_stuck = 0
        last_pos = current_pos
    if steps_stuck > INACTION_PENALTY_STEPS:
        agent_st['curiosity'] = min(1.0, agent_st.get('curiosity', 0.0) + INACTION_CURIOSITY_BOOST)
        agent_st['energy'] = max(0.0, agent_st.get('energy', 0) - INACTION_ENERGY_PENALTY)
        reward -= INACTION_ENERGY_PENALTY

    # --- Intelligence score ---
    intelligence_score *= INTELLIGENCE_DECAY
    if action_name == "EAT" and reward > 0:
        intelligence_score += INTELLIGENCE_EAT_BONUS * ((reward - FOOD_ENERGY_GAIN) / FOOD_ENERGY_GAIN + 1)
    # Read the state that was passed in, not the module-level agent_state.
    if agent_st['energy'] < MAX_ENERGY * 0.1:
        intelligence_score += INTELLIGENCE_LOW_ENERGY_PENALTY
    if agent_st['safety'] < 0.2:
        intelligence_score += INTELLIGENCE_DANGER_PENALTY
    if intelligence_score < LOW_INTELLIGENCE_THRESHOLD:
        reward += LOW_INTELLIGENCE_PENALTY
    if responded_to_user:
        intelligence_score += 0.5
    elif ignored_user:
        intelligence_score -= 1.0
    return reward, agent_generated_speech


def learn_motivated_rules(prev_percept_id, prev_internal_state, action_id, reward,
                          current_percept_id, current_internal_state,
                          user_input_was_present,
                          agent_generated_speech=None, user_feedback=None):
    """Update the rule for (previous context, action) from one observed transition.

    The rule entry stored in ``motivated_rule_base`` under the key
    ``(prev_percept_id, prev_internal_state)`` and ``action_id`` is created on
    demand and then adjusted in this order:

    1. Reward shaping: a penalty when user input was present and the action
       was not SPEAK, a bonus when it was SPEAK, a bonus for speech of at
       least ``MIN_SPEECH_LENGTH_FOR_BONUS`` characters, and the explicit
       ``'feed'`` / ``'punish'`` feedback. Each of these also nudges the
       entry's confidence.
    2. Delayed credit: if the remembered previous action was SPEAK and the
       current non-SPEAK action earned ``reward > 5.0``, the earlier SPEAK
       rule is pulled toward ``reward + SPEECH_REWARD_BONUS``.
    3. A one-step Q-learning style update of ``expected_reward`` using the
       best known value of the next context, followed by a confidence update.

    Args:
        prev_percept_id, prev_internal_state: context before the action.
        action_id: id of the executed action (looked up in ``id_to_action``).
        reward: immediate reward returned by the environment step.
        current_percept_id, current_internal_state: context after the action.
        user_input_was_present: truthy when the user had typed something for
            the agent on this step.
        agent_generated_speech: text produced by the agent, if any.
        user_feedback: ``'feed'``, ``'punish'`` or ``None``.

    Returns:
        None. Does nothing when any of the four context values is ``None``.
    """
    global motivated_rule_base, last_state_action_before_good_reward, intelligence_score, \
           NO_RESPONSE_PENALTY, RESPONSE_REWARD_BONUS, \
           MIN_SPEECH_LENGTH_FOR_BONUS, COHERENT_SPEECH_BONUS

    # A transition with a missing endpoint cannot be learned from.
    endpoints = (prev_percept_id, prev_internal_state, current_percept_id, current_internal_state)
    if any(value is None for value in endpoints):
        return

    context_key = (prev_percept_id, prev_internal_state)
    context_rules = motivated_rule_base[context_key]
    if action_id not in context_rules:
        context_rules[action_id] = create_motivated_rule_entry()
    rule_entry = context_rules[action_id]

    action_name = id_to_action.get(action_id)
    spoke = action_name == "SPEAK"
    final_reward = reward
    responded_to_user = spoke and user_input_was_present
    ignored_user = (not spoke) and user_input_was_present

    # --- Reward shaping for (not) answering the user ---
    if ignored_user:
        final_reward += NO_RESPONSE_PENALTY
        halved_confidence = rule_entry.get('confidence', 0.0) * 0.5
        rule_entry['confidence'] = max(0.0, halved_confidence - CONFIDENCE_LEARNING_RATE * 5)
    elif responded_to_user:
        final_reward += RESPONSE_REWARD_BONUS
        raised_confidence = rule_entry.get('confidence', 0.0) + CONFIDENCE_LEARNING_RATE * 4
        rule_entry['confidence'] = min(1.0, raised_confidence)

    # --- Bonus for speech that is long enough ---
    speech_earns_bonus = bool(
        spoke
        and agent_generated_speech
        and len(agent_generated_speech) >= MIN_SPEECH_LENGTH_FOR_BONUS
    )
    if speech_earns_bonus:
        final_reward += COHERENT_SPEECH_BONUS
        raised_confidence = rule_entry.get('confidence', 0.0) + CONFIDENCE_LEARNING_RATE * 0.5
        rule_entry['confidence'] = min(1.0, raised_confidence)
        print(f"    [Learn] Bônus Coerência (+{COHERENT_SPEECH_BONUS:.1f})")

    # --- Explicit user feedback ---
    if user_feedback == 'feed':
        final_reward += FEEDBACK_REWARD
        intelligence_score += INTELLIGENCE_FEED_BONUS
        print(f"    [Learn] Reforço Feed (+{FEEDBACK_REWARD}) p/ {context_key}->{action_name}")
        raised_confidence = rule_entry.get('confidence', 0.0) + CONFIDENCE_LEARNING_RATE * 5
        rule_entry['confidence'] = min(1.0, raised_confidence)
        if agent_generated_speech:
            print(f"    [Learn] Add fala recompensada '{agent_generated_speech}' ao buffer")
            add_to_lm_buffer(agent_generated_speech, is_rewarded_speech=True)
    elif user_feedback == 'punish':
        final_reward += FEEDBACK_PUNISH
        intelligence_score += INTELLIGENCE_PUNISH_PENALTY
        print(f"    [Learn] Punição Punish ({FEEDBACK_PUNISH}) p/ {context_key}->{action_name}")
        lowered_confidence = rule_entry.get('confidence', 0.0) - CONFIDENCE_LEARNING_RATE * 4
        rule_entry['confidence'] = max(0.0, lowered_confidence)

    # --- Delayed credit for a SPEAK that preceded a well-rewarded action ---
    remembered = last_state_action_before_good_reward
    if remembered is not None:
        speak_context, speak_action_id = remembered
        earlier_was_speak = id_to_action.get(speak_action_id) == "SPEAK"
        if earlier_was_speak and reward > 5.0 and not spoke:
            if speak_context in motivated_rule_base and speak_action_id in motivated_rule_base[speak_context]:
                earlier_rule = motivated_rule_base[speak_context][speak_action_id]
                earlier_q = earlier_rule.get('expected_reward', 0.)
                boosted_target = (reward + SPEECH_REWARD_BONUS)
                earlier_rule['expected_reward'] = earlier_q + REWARD_LEARNING_RATE * 0.7 * (boosted_target - earlier_q)
                earlier_rule['confidence'] = min(1., earlier_rule.get('confidence', 0.) + CONFIDENCE_LEARNING_RATE * 3)
            # The memory is consumed whether or not the earlier rule still exists.
            last_state_action_before_good_reward = None

    # Remember a SPEAK for possible later credit; forget on a weak reward.
    if spoke:
        last_state_action_before_good_reward = (context_key, action_id)
    elif reward <= 5.0:
        last_state_action_before_good_reward = None

    # --- One-step value update toward reward + discounted best next value ---
    next_context_key = (current_percept_id, current_internal_state)
    max_next_q_value = 0.0
    if next_context_key in motivated_rule_base and motivated_rule_base[next_context_key]:
        next_rules = motivated_rule_base[next_context_key]
        max_next_q_value = max(
            (entry.get('expected_reward', 0.0) for entry in next_rules.values()),
            default=0.0,
        )
    target_q_value = final_reward + DISCOUNT_FACTOR * max_next_q_value
    old_q_value = rule_entry.get('expected_reward', 0.0)
    rule_entry['expected_reward'] = old_q_value + REWARD_LEARNING_RATE * (target_q_value - old_q_value)
    rule_entry['count'] = rule_entry.get('count', 0) + 1

    # --- Confidence: decay always; add the Q-driven change only when no
    #     explicit signal above already adjusted the confidence. ---
    prev_conf = rule_entry.get('confidence', 0.0)
    q_diff = rule_entry['expected_reward'] - old_q_value
    confidence_change = CONFIDENCE_LEARNING_RATE * (1.0 + q_diff * 0.5)
    confidence_decay = 0.998
    explicit_signal = user_feedback or ignored_user or responded_to_user or speech_earns_bonus
    if explicit_signal:
        rule_entry['confidence'] = min(1.0, max(0.0, rule_entry['confidence'] * confidence_decay))
    else:
        rule_entry['confidence'] = min(1.0, max(0.0, prev_conf * confidence_decay + confidence_change))

def prune_motivated_rules(min_confidence=0.01, min_count=3):
    """Drop unreliable or clearly harmful rules from ``motivated_rule_base``.

    A rule is removed when either
      * its confidence is below ``min_confidence`` after at least
        ``min_count`` uses, or
      * its expected reward is below -15.0 after more than 5 uses.

    Contexts left without any rule (including ones that were already empty)
    are removed as well. A summary line is printed when at least one rule
    was deleted.
    """
    global motivated_rule_base
    removed_rules = 0
    empty_contexts = []

    # Iterate over snapshots so entries can be deleted while scanning.
    for ctx, rules in list(motivated_rule_base.items()):
        doomed_actions = [
            act
            for act, entry in list(rules.items())
            if (entry['confidence'] < min_confidence and entry['count'] >= min_count)
            or (entry['expected_reward'] < -15.0 and entry['count'] > 5)
        ]
        for act in doomed_actions:
            if ctx in motivated_rule_base and act in motivated_rule_base[ctx]:
                del motivated_rule_base[ctx][act]
                removed_rules += 1
        if ctx in motivated_rule_base and not motivated_rule_base[ctx]:
            empty_contexts.append(ctx)

    for ctx in empty_contexts:
        if ctx in motivated_rule_base:
            del motivated_rule_base[ctx]

    if removed_rules > 0:
        print(f"    [Pruning] Removidas {removed_rules} regras motivadas.")

def get_input_stream(total_steps, text_corpus):
    """Yield exactly ``total_steps`` characters simulating a text input stream.

    Texts are emitted one after another, character by character, keeping only
    characters present in ``ALPHABET``. After each text whose last character
    is not a space (and which is not whitespace-only) a single ``' '`` is
    emitted as a separator. When every text has been used the working list is
    reshuffled and the cycle restarts.

    If the corpus is empty, or none of its texts can ever produce a character,
    random characters from ``ALPHABET`` are emitted instead, so the generator
    always terminates.

    Args:
        total_steps: number of characters to produce.
        text_corpus: sequence of strings. It is never modified; shuffling is
            done on a private copy.

    Yields:
        Single-character strings.
    """
    print(f"Simulando fluxo de entrada com {len(text_corpus)} textos base...")
    allowed = set(ALPHABET)

    def _can_emit(text):
        # A text is useful if it holds an allowed character or would at least
        # earn the trailing separator.
        if not text:
            return False
        if any(ch in allowed for ch in text):
            return True
        return bool(text.strip()) and text[-1] != ' '

    # Private, pre-filtered copy: protects the caller's list from shuffle()
    # and guarantees that each pass over the list emits something (the
    # original spun forever on corpora such as ["", "\n"]).
    texts = [text for text in text_corpus if _can_emit(text)]

    emitted = 0
    if not texts:
        while emitted < total_steps:
            yield random.choice(ALPHABET)
            emitted += 1
    else:
        text_index = 0
        while emitted < total_steps:
            if text_index >= len(texts):
                text_index = 0
                random.shuffle(texts)
            current_text = texts[text_index]
            for char in current_text:
                if emitted >= total_steps:
                    break
                if char in allowed:
                    yield char
                    emitted += 1
            # Separator between consecutive texts, only if budget remains.
            if emitted < total_steps and current_text.strip() and current_text[-1] != ' ':
                yield ' '
                emitted += 1
            text_index += 1
    print("\nFim do fluxo de entrada simulado.")

def pretrain_lm(num_batches=5000):
    """Pre-train the language model on a small built-in phrase list.

    The phrase list is repeated 500 times, shuffled and fed into
    ``lm_training_buffer`` through ``add_to_lm_buffer`` until a target size is
    reached. Then ``train_language_model`` is called ``num_batches`` times;
    the average loss of each 500-batch window is printed. Whenever the buffer
    holds fewer than ``LM_BATCH_SIZE * 5`` items it is topped up from a random
    sample of the phrases. The buffer is cleared when pre-training ends.

    Returns early (after a warning) if the buffer could not be filled with at
    least ``LM_BATCH_SIZE`` items.
    """
    global lm_training_buffer
    print(f"\n--- Iniciando Pré-treino do LM ({num_batches} batches) ---")

    base_phrases = [
        "oi", "ola", "tudo bem", "sim", "nao", "tchau", "bom dia", "boa noite",
        "como vai", "vou bem", "e voce", "obrigado", "de nada", "por favor",
        "a b c d e f g", "1 2 3 4 5", "hello world", "teste 123",
        "preciso de ajuda", "o que e isso", "sim esta tudo bem", "nao obrigado",
        "ate logo", "falar com voce", "preciso comer", "energia baixa",
        "cuidado perigo", "fome", "seguro", "perto terminal", "longe perigo",
        "comer comida", "ir terminal", "falar oi", "responder sim",
        "responder nao", "eu quero", "voce pode", "ajuda por favor",
        "me da comida", "esta com fome", "voce esta bem", "onde esta a comida",
        "comida perto", "perigo perto", "estou seguro", "preciso energia",
        "vou comer", "vou fugir", "mover norte", "mover sul", "mover leste",
        "mover oeste", "ir para comida", "ir para seguro", "ir para terminal",
        "responde", "fala comigo", "o que voce acha",
    ]
    pretrain_data = base_phrases * 500
    random.shuffle(pretrain_data)
    window_losses = []

    # --- Initial buffer fill ---
    print("  Preenchendo buffer...")
    target_buffer_size = min(lm_training_buffer.maxlen, len(pretrain_data) * LM_CONTEXT_WINDOW // 2)
    items_added = 0
    for phrase in pretrain_data:
        add_to_lm_buffer(phrase)
        items_added += 1
        if len(lm_training_buffer) >= target_buffer_size:
            break
    print(f"  Buffer LM preenchido com {len(lm_training_buffer)} ex ({items_added} textos).")

    if len(lm_training_buffer) < LM_BATCH_SIZE:
        print("  Aviso: Buffer insuficiente.")
        return

    # --- Training batches ---
    print("  Iniciando batches...")
    refill_sample_size = min(len(pretrain_data) // 2, 1000)
    for batch_idx in range(num_batches):
        batch_loss = train_language_model()
        if batch_loss is not None:
            window_losses.append(batch_loss)

        batches_done = batch_idx + 1
        if batches_done % 500 == 0:
            if window_losses:
                avg_loss = sum(window_losses) / len(window_losses)
            else:
                avg_loss = 0
            print(f"  Pré-treino LM - Batch {batches_done}/{num_batches}, Avg Loss: {avg_loss:.4f}")
            window_losses = []

        # Top the buffer up again when it runs low.
        if len(lm_training_buffer) < LM_BATCH_SIZE * 5:
            print("  Repopulando buffer...")
            for phrase in random.sample(pretrain_data, k=refill_sample_size):
                add_to_lm_buffer(phrase)
                if len(lm_training_buffer) >= target_buffer_size:
                    break

    lm_training_buffer.clear()
    print("--- Fim do Pré-treino do LM ---")


# --- chat_with_esma_t (Corrected loss initialization) ---
def chat_with_esma_t():
    """Run the interactive console loop between the user and the ESMA-T agent.

    Each loop iteration prints the world, reads one line from the user and
    then, unless a ``/sim`` run just happened, lets the agent take one turn
    (perceive, select an action, execute it, learn from the outcome).

    User input is either a slash command or free text:
      * ``/sair``               leave the chat.
      * ``/save``               leave the chat through the KeyboardInterrupt path.
      * ``/passar``             type nothing for the agent; it still takes its turn.
      * ``/feed``, ``/punish``  queue feedback for the next learning update
                                (only when ``last_agent_speech_context`` is set).
      * ``/place food X Y``     put food on an empty cell.
      * ``/move u [n|s|e|w]``   move the user avatar one cell.
      * ``/info``, ``/help``    print agent status / the command list.
      * ``/sim N``              run N agent steps without user input.
      * any other text          its first 25 characters become the terminal
                                buffer presented to the agent.

    The loop ends after 2000 iterations, on ``/sair``/``/save``, on EOF or
    Ctrl-C at the prompt, when ``byte_world`` is invalid, or when the agent
    has no energy left at the start of its turn.

    This function relies on the module-level state listed in the ``global``
    statement being initialised before it is called. It returns None.
    """
    global agent_state, byte_world, world_step, terminal_buffer, steps_stuck, last_pos, \
           language_working_memory, motivated_rule_base, perceptual_chunk_lexicon, \
           next_percept_chunk_id, chunk_cache, lm_training_buffer, lm_model, lm_optimizer, \
           last_state_action_before_good_reward, intelligence_score, USER_AVATAR_X, USER_AVATAR_Y, \
           current_step, last_agent_speech_context

    commands_help = "Comandos: /sair, /save, /passar, /feed, /punish, /place food X Y, /move u [n|s|e|w], /info, /help, /sim N, [texto]"

    print("\n--- ESMA-T Chat (Agente com Transformer LM) ---")
    print(f"Iniciando chat no Passo: {world_step}.")
    print(f"  Chunks:{len(perceptual_chunk_lexicon)} | Regras:{sum(len(v) for v in motivated_rule_base.values())}")
    print(commands_help)

    # Reset chat-specific state; steps_stuck and last_pos carry over untouched.
    if 'language_working_memory' not in globals() or not isinstance(language_working_memory, deque):
        language_working_memory = deque(maxlen=LM_CONTEXT_WINDOW + 10)
    language_working_memory.clear()
    language_working_memory.append(idx_to_char[SOS_IDX])
    terminal_buffer = ""              # drop any leftover user text
    last_agent_speech_context = None
    pending_feedback = None

    chat_steps = 0
    max_chat_steps = 2000

    while chat_steps < max_chat_steps:
        chat_steps += 1
        print("-" * 40)
        if not agent_state:
            agent_state = {'x': -1, 'y': -1}
            print("ERRO: agent_state vazio!")
        if not isinstance(byte_world, list) or not byte_world:
            print("ERRO: byte_world inválido!")
            break

        print_world(byte_world, agent_state, terminal_buffer)
        skip_agent_turn = False

        # ------------------------------------------------------------------
        # User input / command handling
        # ------------------------------------------------------------------
        try:
            user_input = input("Tu: ").lower().strip()
            if user_input.startswith('/'):
                command_parts = user_input.split()
                command = command_parts[0]

                if command == '/sair':
                    print("Saindo...")
                    break

                elif command == '/save':
                    # Reuses the interrupt path below to leave the loop.
                    raise KeyboardInterrupt

                elif command == '/passar':
                    print("... (passando turno) ...")

                elif command == '/feed':
                    if last_agent_speech_context:
                        pending_feedback = 'feed'
                        print("    [Feedback] OK. Reforço aplicado.")
                    else:
                        print("    (Nada p/ /feed)")

                elif command == '/punish':
                    if last_agent_speech_context:
                        pending_feedback = 'punish'
                        print("    [Feedback] OK. Punição aplicada.")
                    else:
                        print("    (Nada p/ /punish)")

                elif command == '/place' and len(command_parts) >= 4 and command_parts[1] == 'food':
                    try:
                        px = int(command_parts[2])
                        py = int(command_parts[3])
                        if 0 <= py < WORLD_HEIGHT and 0 <= px < WORLD_WIDTH and byte_world[py][px] == CELL_TYPES['empty']:
                            byte_world[py][px] = CELL_TYPES['food']
                            print(f"    Comida em ({px},{py})")
                        else:
                            print("    Inválido/Ocupado.")
                    except (ValueError, IndexError):
                        print("    Uso: /place food X Y")

                elif command == '/move' and len(command_parts) >= 3 and command_parts[1] == 'u':
                    direction = command_parts[2]
                    ux, uy = USER_AVATAR_X, USER_AVATAR_Y
                    nux, nuy = ux, uy
                    if direction not in ('n', 's', 'e', 'w'):
                        print("    Direção inválida (n,s,e,w)")
                        continue
                    if direction == 'n' and uy > 0:
                        nuy -= 1
                    elif direction == 's' and uy < WORLD_HEIGHT - 1:
                        nuy += 1
                    elif direction == 'w' and ux > 0:
                        nux -= 1
                    elif direction == 'e' and ux < WORLD_WIDTH - 1:
                        nux += 1
                    else:
                        # Valid direction, but the avatar already sits on that
                        # border. Previously reported as an invalid direction.
                        print("    Movimento inválido p/ avatar.")
                        continue
                    target_is_free = (
                        0 <= nuy < WORLD_HEIGHT and 0 <= nux < WORLD_WIDTH
                        and byte_world[nuy][nux] != CELL_TYPES['wall']
                        and not (nux == agent_state.get('x') and nuy == agent_state.get('y'))
                    )
                    if target_is_free:
                        if 0 <= USER_AVATAR_Y < len(byte_world) and 0 <= USER_AVATAR_X < len(byte_world[USER_AVATAR_Y]):
                            byte_world[USER_AVATAR_Y][USER_AVATAR_X] = CELL_TYPES['empty']
                        USER_AVATAR_X, USER_AVATAR_Y = nux, nuy
                        byte_world[USER_AVATAR_Y][USER_AVATAR_X] = CELL_TYPES['user_avatar']
                        print(f"    Avatar movido p/ ({USER_AVATAR_X},{USER_AVATAR_Y})")
                    else:
                        print("    Movimento inválido p/ avatar.")

                elif command == '/info':
                    print(f"    Info Agente:")
                    print(f"      Pos: ({agent_state.get('x',-1)},{agent_state.get('y',-1)}) E:{agent_state.get('energy',0.):.1f}/{MAX_ENERGY} C:{agent_state.get('curiosity',0.):.2f} S:{agent_state.get('safety',0.):.2f}")
                    print(f"      Intel:{intelligence_score:.2f} Stuck:{steps_stuck} LastAct:{id_to_action.get(agent_state.get('last_action_id'),'N/A')} LastRew:{agent_state.get('last_reward',0.):.2f}")
                    print(f"      Regras:{sum(len(v) for v in motivated_rule_base.values())} Chunks:{len(perceptual_chunk_lexicon)} LM Buf:{len(lm_training_buffer)}/{lm_training_buffer.maxlen}")
                    print(f"      Avatar:({USER_AVATAR_X},{USER_AVATAR_Y}) LastSpeechCtx:{last_agent_speech_context}")

                elif command == '/help':
                    print("    " + commands_help)

                elif command == '/sim' and len(command_parts) >= 2 and command_parts[1].isdigit():
                    num_sim_steps = int(command_parts[1])
                    if num_sim_steps <= 0:
                        print("    /sim N > 0")
                        continue
                    print(f"\n--- Simulação {num_sim_steps} passos ---")
                    sim_start_time = time.time()
                    sim_agent_died = False
                    for sim_step in range(num_sim_steps):
                        if agent_state.get('energy', 0) <= 0:
                            print(f"\nAgente morreu sim @ {sim_step+1}!")
                            sim_agent_died = True
                            break
                        sim_user_input_present = False
                        sim_prev_state = copy.deepcopy(agent_state)
                        sim_perc = get_perception(sim_prev_state.get('x', 0), sim_prev_state.get('y', 0), byte_world)
                        sim_flat = flatten_perception(sim_perc)
                        sim_prev_perc_id = chunk_perception(sim_flat)
                        sim_prev_internal = get_simplified_internal_state(sim_prev_state)
                        if sim_prev_perc_id is None or sim_prev_internal is None:
                            continue
                        sim_action_id = select_action(sim_prev_perc_id, sim_prev_internal, agent_state)
                        sim_reward, sim_speech = execute_action(agent_state, sim_action_id, byte_world, sim_user_input_present)
                        agent_state['last_action_id'] = sim_action_id
                        agent_state['last_reward'] = sim_reward
                        sim_new_perc = get_perception(agent_state.get('x', 0), agent_state.get('y', 0), byte_world)
                        sim_new_flat = flatten_perception(sim_new_perc)
                        sim_new_perc_id = chunk_perception(sim_new_flat)
                        sim_new_internal = get_simplified_internal_state(agent_state)
                        if sim_new_perc_id is not None and sim_new_internal is not None:
                            learn_motivated_rules(sim_prev_perc_id, sim_prev_internal, sim_action_id, sim_reward, sim_new_perc_id, sim_new_internal, user_input_was_present=sim_user_input_present, agent_generated_speech=sim_speech, user_feedback=None)
                        if sim_speech:
                            add_to_lm_buffer(sim_speech)
                            for c in sim_speech:
                                update_working_memory(c)
                        world_step += 1
                        if world_step % 20 == 0:
                            train_language_model()
                        if world_step % prune_interval == 0:
                            prune_motivated_rules()
                        if (sim_step + 1) % 500 == 0:
                            print(f"  ... Sim passo {sim_step+1}/{num_sim_steps}")
                    elapsed_time = time.time() - sim_start_time
                    # When the agent dies at index sim_step, only sim_step
                    # iterations were actually executed.
                    steps_done = sim_step if sim_agent_died else num_sim_steps
                    # Guard against a zero interval on coarse clocks.
                    steps_per_sec = steps_done / elapsed_time if elapsed_time > 0 else 0.0
                    print(f"--- Simulação concluída. {steps_done} passos em {elapsed_time:.2f}s ({steps_per_sec:.1f} steps/s avg) ---")
                    skip_agent_turn = True

                else:
                    print(f"    Comando desconhecido: {user_input}")

            elif user_input:
                # Free text: hand (at most 25 chars of) it to the agent.
                terminal_buffer = user_input[:25]
                print(f"Terminal p/ agente: [{terminal_buffer}]")
                add_to_lm_buffer(terminal_buffer, is_user_input=True)
                for c in terminal_buffer:
                    update_working_memory(c)
        except EOFError:
            print("\nSaindo (EOF)...")
            break
        except KeyboardInterrupt:
            print("\nSalvando e saindo (Interrupção)...")
            break

        if skip_agent_turn:
            continue

        # ------------------------------------------------------------------
        # Agent turn
        # ------------------------------------------------------------------
        if agent_state.get('energy', 0) <= 0:
            print("\nAgente sem energia!")
            print("\nAgente morreu. Fim.")
            break

        lm_loss_turn = None  # set only when the LM is trained on this turn
        user_input_was_present_for_agent = bool(terminal_buffer)
        prev_agent_state = copy.deepcopy(agent_state)
        perception_turn = get_perception(prev_agent_state.get('x', 0), prev_agent_state.get('y', 0), byte_world)
        flat_perception_turn = flatten_perception(perception_turn)
        prev_percept_id_turn = chunk_perception(flat_perception_turn)
        prev_internal_tuple_turn = get_simplified_internal_state(prev_agent_state)
        if prev_percept_id_turn is None or prev_internal_tuple_turn is None:
            print("Aviso: Falha percepção/estado.")
            action_id = action_to_id[random.choice(ACTIONS)]
        else:
            action_id = select_action(prev_percept_id_turn, prev_internal_tuple_turn, agent_state)

        world_step += 1
        immediate_reward, agent_speech_output = execute_action(agent_state, action_id, byte_world, user_input_was_present_for_agent)
        if user_input_was_present_for_agent:
            terminal_buffer = ""  # the user's text is shown to the agent for one turn only
        agent_state['last_action_id'] = action_id
        agent_state['last_reward'] = immediate_reward

        new_perception_turn = get_perception(agent_state.get('x', 0), agent_state.get('y', 0), byte_world)
        new_flat_perception_turn = flatten_perception(new_perception_turn)
        new_percept_chunk_id_turn = chunk_perception(new_flat_perception_turn)
        new_internal_state_tuple_turn = get_simplified_internal_state(agent_state)
        if new_percept_chunk_id_turn is not None and new_internal_state_tuple_turn is not None and prev_percept_id_turn is not None and prev_internal_tuple_turn is not None:
            learn_motivated_rules(prev_percept_id_turn, prev_internal_tuple_turn, action_id, immediate_reward, new_percept_chunk_id_turn, new_internal_state_tuple_turn, user_input_was_present=user_input_was_present_for_agent, agent_generated_speech=agent_speech_output, user_feedback=pending_feedback)
        else:
            print("Aviso: Estado inválido pós-ação. Pulando aprendizado.")
        pending_feedback = None
        if agent_speech_output:
            for c in agent_speech_output:
                update_working_memory(c)

        # Train the LM periodically.
        if world_step % 5 == 0:
            lm_loss_turn = train_language_model()

        # Prune rules.
        if world_step % prune_interval == 0:
            prune_motivated_rules()

        # Occasionally drop random food on an empty interior cell.
        if world_step % 150 == 0 and random.random() < 0.15:
            if isinstance(byte_world, list) and byte_world:
                fx, fy = random.randint(1, WORLD_WIDTH-2), random.randint(1, WORLD_HEIGHT-2)
                if 0 <= fy < len(byte_world) and 0 <= fx < len(byte_world[fy]) and byte_world[fy][fx] == CELL_TYPES['empty']:
                    byte_world[fy][fx] = CELL_TYPES['food']

        # Report the LM loss if it was computed this turn.
        if lm_loss_turn is not None:
            print(f"  (LM Loss: {lm_loss_turn:.4f})")

    print("\nFim do Chat.")


# --- Ponto de Entrada Principal ---
if __name__ == "__main__":
    # Script entry point: build the LM, restore or initialise the global
    # agent/world state, optionally run a silent simulation, run the
    # interactive chat, and always try to save the final state afterwards.
    import ast  # local import: parses saved rule-context keys without eval()

    # --- Instantiate LM and Optimizer ---
    lm_model = MiniTransformerLM(VOCAB_SIZE, LM_D_MODEL, LM_NHEAD, LM_NUM_LAYERS, LM_DIM_FEEDFORWARD, LM_DROPOUT, LM_CONTEXT_WINDOW + 5).to(device)
    lm_optimizer = optim.Adam(lm_model.parameters(), lr=LM_LEARNING_RATE)
    print(f"Modelo Transformer LM instanciado com {sum(p.numel() for p in lm_model.parameters()):,} parâmetros.")

    # --- Load/Init State ---
    # NOTE(review): state is loaded from 'esma_transformer_state.pkl' but saved
    # below to 'esma_transformer_state_final.pkl', so a saved run is not picked
    # up automatically on the next start. Confirm whether that is intended.
    save_path_esma = 'esma_transformer_state.pkl'
    loaded_successfully = False
    if os.path.exists(save_path_esma):
        print(f"Carregando estado ESMA-T de {save_path_esma}...")
        try:
            # SECURITY: torch.load unpickles the file; only load state files
            # from a trusted source.
            loaded_data = torch.load(save_path_esma, map_location='cpu')
            required_keys = ['agent_state', 'world_step', 'byte_world', 'perceptual_chunk_lexicon', 'next_percept_chunk_id', 'motivated_rule_base', 'lm_model_state_dict', 'lm_optimizer_state_dict', 'current_step', 'last_state_action_before_good_reward', 'intelligence_score', 'USER_AVATAR_X', 'USER_AVATAR_Y']
            if isinstance(loaded_data, dict) and all(k in loaded_data for k in required_keys):
                agent_state = loaded_data['agent_state']
                world_step = loaded_data['world_step']
                byte_world = loaded_data['byte_world']
                perceptual_chunk_lexicon = loaded_data['perceptual_chunk_lexicon']
                next_percept_chunk_id = loaded_data['next_percept_chunk_id']

                # Rules are saved with str(context tuple) keys and str(action id)
                # sub-keys; rebuild the tuple / int keys here.
                motivated_rule_base = defaultdict(create_motivated_context_dict)
                saved_rules = loaded_data.get('motivated_rule_base', {})
                for context_str, next_options in saved_rules.items():
                    try:
                        # literal_eval accepts only Python literals, so text in
                        # the state file cannot execute code (eval() could).
                        context_tuple = ast.literal_eval(context_str)
                    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                        continue  # unparsable key: skip this context
                    if isinstance(context_tuple, tuple) and len(context_tuple) == 2:
                        motivated_rule_base[context_tuple].update({int(k): v for k, v in next_options.items() if isinstance(v, dict) and str(k).isdigit()})

                lm_model.load_state_dict(loaded_data['lm_model_state_dict'])
                lm_model.to(device)
                lm_optimizer.load_state_dict(loaded_data['lm_optimizer_state_dict'])
                # Optimizer state was loaded on CPU; move its tensors to the
                # model's device.
                for state in lm_optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.to(device)

                current_step = loaded_data.get('current_step', world_step)
                last_state_action_before_good_reward = loaded_data.get('last_state_action_before_good_reward')
                intelligence_score = loaded_data.get('intelligence_score', 0.0)
                USER_AVATAR_X = loaded_data.get('USER_AVATAR_X', TERMINAL_X + 1)
                USER_AVATAR_Y = loaded_data.get('USER_AVATAR_Y', TERMINAL_Y)

                # --- Validate / repair the user avatar position ---
                if not (0 <= USER_AVATAR_Y < WORLD_HEIGHT and 0 <= USER_AVATAR_X < WORLD_WIDTH):
                    print("Aviso: Posição do Avatar carregada inválida. Resetando.")
                    USER_AVATAR_X, USER_AVATAR_Y = TERMINAL_X + 1, TERMINAL_Y
                try:
                    if not byte_world or USER_AVATAR_Y >= len(byte_world) or USER_AVATAR_X >= len(byte_world[0]):
                        raise IndexError("Avatar coordinates out of bounds")
                    if byte_world[USER_AVATAR_Y][USER_AVATAR_X] != CELL_TYPES['user_avatar']:
                        print(f"Aviso: Célula ({USER_AVATAR_X}, {USER_AVATAR_Y}) não é avatar. Tentando realocar...")
                        if byte_world[USER_AVATAR_Y][USER_AVATAR_X] == CELL_TYPES['empty']:
                            byte_world[USER_AVATAR_Y][USER_AVATAR_X] = CELL_TYPES['user_avatar']
                            print("    Célula marcada como avatar.")
                        else:
                            # Try random cells in the 3x3 neighbourhood of the terminal.
                            placed = False
                            for _ in range(100):
                                ny, nx = TERMINAL_Y + random.randint(-1, 1), TERMINAL_X + random.randint(-1, 1)
                                if 0 <= ny < WORLD_HEIGHT and 0 <= nx < WORLD_WIDTH:
                                    if byte_world[ny][nx] == CELL_TYPES['empty']:
                                        USER_AVATAR_X = nx
                                        USER_AVATAR_Y = ny
                                        byte_world[ny][nx] = CELL_TYPES['user_avatar']
                                        placed = True
                                        print(f"    Avatar realocado para ({USER_AVATAR_X}, {USER_AVATAR_Y}).")
                                        break
                            if not placed:
                                print("    Falha ao realocar. Colocando em (1,1).")
                                USER_AVATAR_X, USER_AVATAR_Y = 1, 1
                                if 1 < WORLD_HEIGHT and 1 < WORLD_WIDTH and byte_world[1][1] == CELL_TYPES['empty']:
                                    byte_world[1][1] = CELL_TYPES['user_avatar']
                                else:
                                    print("AVISO: Fallback (1,1) ocupado/inválido!")
                except IndexError as ie:
                    print(f"ERRO GRAVE: Índice do avatar ({USER_AVATAR_X},{USER_AVATAR_Y}) fora dos limites! Resetando. Erro: {ie}")
                    USER_AVATAR_X, USER_AVATAR_Y = 1, 1
                    if WORLD_HEIGHT > 1 and WORLD_WIDTH > 1 and len(byte_world) > 1 and len(byte_world[0]) > 1:
                        if byte_world[1][1] == CELL_TYPES['empty']:
                            byte_world[1][1] = CELL_TYPES['user_avatar']
                        else:
                            print("AVISO: Fallback (1,1) ocupado/inválido!")
                    else:
                        print("AVISO: Mundo inválido para fallback (1,1).")

                print(f"Estado ESMA-T carregado (Passo: {world_step}, Intel: {intelligence_score:.2f}).")
                loaded_successfully = True
            else:
                print("Arq estado inválido.")
        except EOFError:
            print("Erro: Arq .pkl vazio/corrompido.")
        except Exception as e:
            # Top-level boundary: report and fall back to a fresh start.
            print(f"Erro fatal ao carregar: {e}.")
            traceback.print_exc()

    # --- Initialization from Zero ---
    if not loaded_successfully:
        print("Iniciando do zero.")
        agent_state = {'x':WORLD_WIDTH//2,'y':WORLD_HEIGHT//2,'energy':MAX_ENERGY,'curiosity':0.,'safety':1.,'last_action_id':None,'last_reward':0.,'last_speech':"",'last_speech_rewarded':None}
        language_working_memory.clear()
        perceptual_chunk_lexicon = {}
        next_percept_chunk_id = 0
        terminal_buffer = ""
        world_step = 0
        steps_stuck = 0
        last_pos = (-1, -1)
        chunk_cache = {}
        motivated_rule_base = defaultdict(create_motivated_context_dict)
        USER_AVATAR_X, USER_AVATAR_Y = TERMINAL_X + 1, TERMINAL_Y
        initialize_world()  # initialises the global byte_world
        current_step = 0
        last_state_action_before_good_reward = None
        intelligence_score = 0.
        last_agent_speech_context = None
        lm_optimizer = optim.Adam(lm_model.parameters(), lr=LM_LEARNING_RATE)  # fresh optimizer state
        pretrain_lm(num_batches=6000)

    # --- Optional Silent Simulation ---
    run_simulation_before_chat = False
    if run_simulation_before_chat:
        print("\n--- Simulação Silenciosa ---")
        sim_steps = 10000
        sim_start_time = time.time()
        sim_agent_died = False
        for step in range(sim_steps):
            if agent_state.get('energy', 0) <= 0:
                print(f"Agente morreu sim @ {step+1}.")
                sim_agent_died = True
                break
            sim_user_input_present = False
            sim_prev_state = copy.deepcopy(agent_state)
            sim_perc = get_perception(sim_prev_state.get('x', 0), sim_prev_state.get('y', 0), byte_world)
            sim_flat = flatten_perception(sim_perc)
            sim_prev_perc_id = chunk_perception(sim_flat)
            sim_prev_internal = get_simplified_internal_state(sim_prev_state)
            if sim_prev_perc_id is None or sim_prev_internal is None:
                continue
            sim_action_id = select_action(sim_prev_perc_id, sim_prev_internal, agent_state)
            sim_reward, sim_speech = execute_action(agent_state, sim_action_id, byte_world, sim_user_input_present)
            agent_state['last_action_id'] = sim_action_id
            agent_state['last_reward'] = sim_reward
            sim_new_perc = get_perception(agent_state.get('x', 0), agent_state.get('y', 0), byte_world)
            sim_new_flat = flatten_perception(sim_new_perc)
            sim_new_perc_id = chunk_perception(sim_new_flat)
            sim_new_internal = get_simplified_internal_state(agent_state)
            if sim_new_perc_id is not None and sim_new_internal is not None:
                learn_motivated_rules(sim_prev_perc_id, sim_prev_internal, sim_action_id, sim_reward, sim_new_perc_id, sim_new_internal, user_input_was_present=sim_user_input_present, agent_generated_speech=sim_speech, user_feedback=None)
            if sim_speech:
                add_to_lm_buffer(sim_speech)
                for c in sim_speech:
                    update_working_memory(c)
            world_step += 1
            if world_step % 20 == 0:
                train_language_model()
            if world_step % prune_interval == 0:
                prune_motivated_rules()
            if (step + 1) % 1000 == 0:
                print(f"  Sim passo {step+1}/{sim_steps}")
        sim_end_time = time.time()
        # When the agent dies at index `step`, only `step` iterations ran.
        steps_done_sim = step if sim_agent_died else sim_steps
        print(f"--- Fim Sim ({steps_done_sim} passos em {sim_end_time-sim_start_time:.2f}s) ---")

    # --- Start Interactive Chat ---
    if lm_model is None or lm_optimizer is None or not byte_world or not agent_state:
        print("ERRO CRÍTICO: Estado global não inicializado corretamente antes do chat.")
    else:
        print("\n--- Iniciando Chat Interativo ---")
        try:
            chat_with_esma_t()
        except KeyboardInterrupt:
            print("\nChat Interrompido.")
        except Exception as e:
            print(f"\nErro no chat: {e}")
            traceback.print_exc()
        finally:
            # Always attempt to persist the state, however the chat ended.
            save_path_final = 'esma_transformer_state_final.pkl'
            print(f"\nSalvando estado final em {save_path_final}...")
            try:
                # Tuple context keys and int action ids are stored as strings;
                # the loader above converts them back.
                rule_base_mot_to_save = {}
                for context_k, actions_v in motivated_rule_base.items():
                    context_str = str(context_k)
                    rule_base_mot_to_save[context_str] = {str(k): v for k, v in actions_v.items()}
                state_to_save = {'agent_state':agent_state, 'world_step':world_step, 'byte_world':byte_world, 'perceptual_chunk_lexicon':perceptual_chunk_lexicon, 'next_percept_chunk_id':next_percept_chunk_id, 'motivated_rule_base':rule_base_mot_to_save, 'current_step':current_step, 'last_state_action_before_good_reward':last_state_action_before_good_reward, 'intelligence_score':intelligence_score, 'USER_AVATAR_X':USER_AVATAR_X, 'USER_AVATAR_Y':USER_AVATAR_Y, 'lm_model_state_dict':lm_model.state_dict(), 'lm_optimizer_state_dict':lm_optimizer.state_dict()}
                torch.save(state_to_save, save_path_final)
                print("Estado final salvo.")
            except Exception as e:
                print(f"Erro ao salvar estado final: {e}")

    print("\nExecução principal concluída.")

Vocabulário Transformer (40): <PAD> <SOS> <EOS> a b c d e f g h i j k l m n o p q r s t u v w x y z 0 1 2 3 4 5 6 7 8 9  
Dispositivo: cuda
Modelo Transformer LM instanciado com 540,200 parâmetros.
Iniciando do zero.
Inicializando ByteWorld...
ByteWorld inicializado.

--- Iniciando Pré-treino do LM (6000 batches) ---
  Preenchendo buffer...
  Buffer LM preenchido com 1280 ex (111 textos).
  Iniciando batches...
  Pré-treino LM - Batch 500/6000, Avg Loss: 1.5685
  Pré-treino LM - Batch 1000/6000, Avg Loss: 0.4748
  Pré-treino LM - Batch 1500/6000, Avg Loss: 0.3743
  Pré-treino LM - Batch 2000/6000, Avg Loss: 0.3462
  Pré-treino LM - Batch 2500/6000, Avg Loss: 0.3360
  Pré-treino LM - Batch 3000/6000, Avg Loss: 0.3350
  Pré-treino LM - Batch 3500/6000, Avg Loss: 0.3267
  Pré-treino LM - Batch 4000/6000, Avg Loss: 0.3296
  Pré-treino LM - Batch 4500/6000, Avg Loss: 0.3238
  Pré-treino LM - Batch 5000/6000, Avg Loss: 0.3240
  Pré-treino LM - Batch 5500/6000, Avg Loss: 0.3143
  Pré-treino L

Traceback (most recent call last):
  File "<ipython-input-7-9488f5b748d2>", line 967, in <cell line: 0>
    chat_with_esma_t() # Chat function now assumes globals are ready
    ^^^^^^^^^^^^^^^^^^
  File "<ipython-input-7-9488f5b748d2>", line 826, in chat_with_esma_t
    if loss is not None: lm_loss_turn = loss
       ^^^^
UnboundLocalError: cannot access local variable 'loss' where it is not associated with a value
