In [2]:
import os
# For Colab/Google Drive integration:
from google.colab import drive
drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive/FinRL/final')  # Change to your project folder in Drive

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import os
import json
import matplotlib.pyplot as plt
from collections import deque
import random
from typing import Tuple, List, Dict, Any
import warnings
warnings.filterwarnings('ignore')

# ==================== IMPROVED HYPERPARAMETERS ====================
PPO_VERSION = "5_7_3"
CLIP_RATIO = 0.2
LAMBDA_GAE = 0.95
ENTROPY_COEF = 0.05
VALUE_COEF = 0.5
MAX_GRAD_NORM = 0.5
TARGET_KL = 0.02
NUM_EPOCHS = 6
BATCH_SIZE = 64
HORIZON_LEN = 4096
GAMMA = 0.99
GAE_LAMBDA = 0.95
ADV_NORM = True
VALUE_NORM = True
RETURN_NORM = True

# Training parameters
LEARNING_RATE = 3e-4
TOTAL_STEPS = 2000000
SAVE_GAP = 50000
EVAL_TIMES = 8
REPEAT_TIMES = 4
REWARD_SCALE = 1e-1

# Environment parameters
NUM_SIMS = 64
STEP_GAP = 2
MAX_POSITION = 1
SLIPPAGE = 7e-7
NUM_IGNORE_STEP = 60

NET_DIMS = [256, 256, 128, 128, 64]
ACTOR_LR = 3e-4
CRITIC_LR = 3e-4

# Environment parameters
STEP_GAP = 2
MAX_POSITION = 1
SLIPPAGE = 7e-7
NUM_IGNORE_STEP = 60
NORMALIZE_LLM_SIGNALS = True

def check_cuda():
    """Check CUDA availability and print status"""
    try:
        if torch.cuda.is_available():
            print(f"CUDA version: {torch.version.cuda}")
            print(f"GPU count: {torch.cuda.device_count()}")
            if torch.cuda.device_count() > 0:
                print(f"Current GPU: {torch.cuda.current_device()}")
                print(f"GPU name: {torch.cuda.get_device_name(0)}")
            return True
        else:
            print("CUDA is not available")
            return False
    except Exception as e:
        print(f"Error checking CUDA: {e}")
        return False

# ====================PPO NETWORK ARCHITECTURE ====================

class SimplePPOActor(nn.Module):
    def __init__(self, state_dim: int, action_dim: int, net_dims: List[int]):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim

        #MLP layers
        layers = []
        input_dim = state_dim

        for i, dim in enumerate(net_dims):
            layers.extend([
                nn.Linear(input_dim, dim),
                nn.ReLU(),
                nn.Dropout(0.1)
            ])

            if i < len(net_dims) - 1:
                layers.append(nn.LayerNorm(dim))
            input_dim = dim

        # Output layer
        layers.append(nn.Linear(input_dim, action_dim))

        self.net = nn.Sequential(*layers)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight, gain=0.01)
            nn.init.constant_(module.bias, 0.0)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.net(state)

    def get_action_probs(self, state: torch.Tensor) -> torch.Tensor:
        """Get action probabilities using softmax"""
        logits = self.forward(state)
        return F.softmax(logits, dim=-1)

    def get_action_log_probs(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Get log probabilities for specific actions"""
        action_probs = self.get_action_probs(state)
        log_probs = torch.log(action_probs + 1e-8)
        action = action.unsqueeze(-1)
        gathered = log_probs.gather(-1, action).squeeze(-1)

        return gathered

class SimplePPOCritic(nn.Module):
    """Critic network for PPO"""

    def __init__(self, state_dim: int, net_dims: List[int]):
        super().__init__()
        self.state_dim = state_dim

        # Build MLP layers
        layers = []
        input_dim = state_dim

        for i, dim in enumerate(net_dims):
            layers.extend([
                nn.Linear(input_dim, dim),
                nn.ReLU(),
                nn.Dropout(0.1)
            ])

            if i < len(net_dims) - 1:
                layers.append(nn.LayerNorm(dim))
            input_dim = dim

        # Output layer for single value
        layers.append(nn.Linear(input_dim, 1))

        self.net = nn.Sequential(*layers)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight, gain=1.0)
            nn.init.constant_(module.bias, 0.0)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.net(state).squeeze(-1)

# ==================== PPO AGENT ====================

class ImprovedPPOTrader:
    """PPO agent"""

    def __init__(self, state_dim: int, action_dim: int, net_dims: List[int],
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.state_dim = state_dim
        self.action_dim = action_dim

        # Networks
        self.actor = SimplePPOActor(state_dim, action_dim, net_dims).to(device)
        self.critic = SimplePPOCritic(state_dim, net_dims).to(device)

        self.optimizer = optim.Adam([
            {'params': self.actor.parameters(), 'lr': ACTOR_LR},
            {'params': self.critic.parameters(), 'lr': CRITIC_LR}
        ])
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=TOTAL_STEPS, eta_min=LEARNING_RATE*0.1)

        # PPO buffers
        self.states = []
        self.actions = []
        self.rewards = []
        self.values = []
        self.log_probs = []
        self.dones = []

        # Training stats
        self.total_steps = 0
        self.episode_rewards = []
        self.episode_lengths = []

    def select_action(self, state: torch.Tensor, training: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Select action using current policy"""
        with torch.no_grad():
            action_probs = self.actor.get_action_probs(state)

            if training:
                temperature = 1.2  # Slightly higher temperature for more exploration
                logits = self.actor(state)
                scaled_logits = logits / temperature
                scaled_probs = F.softmax(scaled_logits, dim=-1)
                dist = torch.distributions.Categorical(scaled_probs)
                action = dist.sample()
                log_prob = dist.log_prob(action)
            else:
                # Greedy action selection
                action = torch.argmax(action_probs, dim=-1)
                log_prob = torch.zeros_like(action, dtype=torch.float)

            value = self.critic(state)

        return action, log_prob, value

    def store_transition(self, state: torch.Tensor, action: torch.Tensor,
                        reward: float, value: torch.Tensor, log_prob: torch.Tensor, done: bool):
        """Store transition in PPO buffer"""
        self.states.append(state.cpu().numpy())
        self.actions.append(action.cpu().numpy())
        self.rewards.append(reward)
        self.values.append(value.cpu().numpy())
        self.log_probs.append(log_prob.cpu().numpy())
        self.dones.append(done)

    def compute_gae(self, rewards: List[float], values: List[float],
                   dones: List[bool], gamma: float = GAMMA, lambda_: float = GAE_LAMBDA) -> Tuple[np.ndarray, np.ndarray]:
        """Compute Generalized Advantage Estimation"""
        advantages = np.zeros_like(rewards)
        last_gae_lam = 0

        for t in reversed(range(len(rewards))):
            if t == len(rewards) - 1:
                next_value = 0
            else:
                next_value = values[t + 1]

            delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]
            advantages[t] = last_gae_lam = delta + gamma * lambda_ * (1 - dones[t]) * last_gae_lam

        returns = advantages + values
        return advantages, returns

    def update_policy(self) -> Dict[str, float]:
        """Policy update"""
        if len(self.states) < BATCH_SIZE:
            return {}

        # Convert to tensors
        states = torch.FloatTensor(np.array(self.states)).to(self.device)
        actions = torch.LongTensor(np.array(self.actions)).to(self.device)
        old_log_probs = torch.FloatTensor(np.array(self.log_probs)).to(self.device)
        old_values = torch.FloatTensor(np.array(self.values)).to(self.device)

        # Compute GAE
        advantages, returns = self.compute_gae(self.rewards, self.values, self.dones)
        advantages = torch.FloatTensor(advantages).to(self.device)
        returns = torch.FloatTensor(returns).to(self.device)

        # Normalize advantages
        if ADV_NORM:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Normalize returns
        if RETURN_NORM:
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        actor_losses = []
        critic_losses = []
        kl_divs = []
        entropy_values = []

        for epoch in range(NUM_EPOCHS):
            # Shuffle data
            indices = torch.randperm(len(states))

            for start_idx in range(0, len(states), BATCH_SIZE):
                end_idx = min(start_idx + BATCH_SIZE, len(states))
                batch_indices = indices[start_idx:end_idx]

                batch_states = states[batch_indices]
                batch_actions = actions[batch_indices]
                batch_old_log_probs = old_log_probs[batch_indices]
                batch_advantages = advantages[batch_indices]
                batch_returns = returns[batch_indices]

                # Get current policy outputs
                current_log_probs = self.actor.get_action_log_probs(batch_states, batch_actions)
                current_values = self.critic(batch_states)

                # PPO clipped objective
                ratio = torch.exp(current_log_probs - batch_old_log_probs)
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1 - CLIP_RATIO, 1 + CLIP_RATIO) * batch_advantages
                actor_loss = -torch.min(surr1, surr2).mean()

                value_loss = F.mse_loss(current_values, batch_returns)

                # Entropy bonus
                action_probs = self.actor.get_action_probs(batch_states)
                entropy = -(action_probs * torch.log(action_probs + 1e-8)).sum(dim=-1).mean()

                # Total loss
                total_loss = actor_loss + VALUE_COEF * value_loss - ENTROPY_COEF * entropy

                self.optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    list(self.actor.parameters()) + list(self.critic.parameters()),
                    MAX_GRAD_NORM
                )
                self.optimizer.step()

                # Compute KL divergence for monitoring
                with torch.no_grad():
                    kl_div = (batch_old_log_probs - current_log_probs).mean()

                actor_losses.append(actor_loss.item())
                critic_losses.append(value_loss.item())
                kl_divs.append(kl_div.item())
                entropy_values.append(entropy.item())

            avg_kl = np.mean(kl_divs[-len(kl_divs)//NUM_EPOCHS:])
            if avg_kl > TARGET_KL:
                print(f"Early stopping at epoch {epoch} due to high KL divergence: {avg_kl:.6f}")
                break

        # Clear buffer
        self.states.clear()
        self.actions.clear()
        self.rewards.clear()
        self.values.clear()
        self.log_probs.clear()
        self.dones.clear()

        return {
            'actor_loss': np.mean(actor_losses),
            'critic_loss': np.mean(critic_losses),
            'kl_div': np.mean(kl_divs),
            'entropy': np.mean(entropy_values),
            'epochs_completed': epoch + 1
        }

    def save_agent(self, save_dir: str):
        """Save agent models"""
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.actor.state_dict(), os.path.join(save_dir, 'actor.pth'))
        torch.save(self.critic.state_dict(), os.path.join(save_dir, 'critic.pth'))

        # Save training stats
        stats = {
            'total_steps': self.total_steps,
            'episode_rewards': self.episode_rewards,
            'episode_lengths': self.episode_lengths
        }
        with open(os.path.join(save_dir, 'training_stats.json'), 'w') as f:
            json.dump(stats, f)

    def save_or_load_agent(self, cwd: str, if_save: bool):
        """Save or load agent models - compatibility method for evaluation"""
        if if_save:
            self.save_agent(cwd)
        else:
            return self.load_agent(cwd)

    def load_agent(self, load_dir: str):
        """Load agent models"""
        actor_path = os.path.join(load_dir, 'actor.pth')
        critic_path = os.path.join(load_dir, 'critic.pth')

        if os.path.exists(actor_path) and os.path.exists(critic_path):
            self.actor.load_state_dict(torch.load(actor_path, map_location=self.device))
            self.critic.load_state_dict(torch.load(critic_path, map_location=self.device))
            print(f"Loaded PPO agent from {load_dir}")
            return True
        else:
            print(f"Model files not found in {load_dir}")
            return False

# ==================== IMPROVED TRAINING FUNCTION ====================

def train_improved_ppo_agent(timeframe: str = '1sec') -> ImprovedPPOTrader:
    """Train an improved PPO agent for the specified timeframe"""
    # Check CUDA status first
    cuda_available = check_cuda()

    try:
        if cuda_available and torch.cuda.is_available():
            device = torch.device("cuda:0")
            gpu_id = 0
            print(f"CUDA available, using GPU: {device}")
        else:
            device = torch.device("cpu")
            gpu_id = -1
            print(f"CUDA not available, using CPU: {device}")

        # Create environment
        env = TradeSimulator(
            num_sims=NUM_SIMS,
            slippage=SLIPPAGE,
            max_position=MAX_POSITION,
            step_gap=STEP_GAP,
            gpu_id=gpu_id,
            timeframe=timeframe
        )

    except Exception as e:
        print(f"Error creating environment: {e}")
        print("Trying with CPU fallback...")

        device = torch.device("cpu")
        gpu_id = -1

        env = TradeSimulator(
            num_sims=NUM_SIMS,
            slippage=SLIPPAGE,
            max_position=MAX_POSITION,
            step_gap=STEP_GAP,
            gpu_id=gpu_id,
            timeframe=timeframe
        )

    print(f"Environment ready: {env.full_seq_len} steps available")
    print(f"Environment info: state_dim={env.state_dim}, action_dim={env.action_dim}")

    # Create agent
    state_dim = env.state_dim
    action_dim = env.action_dim

    agent = ImprovedPPOTrader(state_dim, action_dim, NET_DIMS, device=str(device))
    print(f"Improved PPO agent created successfully on {device}")

    initial_state = env.reset()
    print(f"Environment initialized, state shape: {initial_state.shape}")

    # Training loop
    episode_reward = 0
    episode_length = 0
    update_count = 0

    print(f"Starting training for {TOTAL_STEPS} steps...")

    for step in range(TOTAL_STEPS):
        # Get current state (first simulation)
        state = env.get_state(env.step_is)
        state_single = state[0:1]  # Take first simulation for single agent training

        # Select action
        action, log_prob, value = agent.select_action(state_single, training=True)

        # Convert action from [0,1,2] to [-1,0,1] for environment (SHORT, HOLD, LONG)
        action_env = (action - 1).expand(env.num_sims)  # Convert to environment action space

        # Execute action
        next_state, reward, terminal, info = env.step(action_env)

        #Store transition with proper reward scaling and shaping
        base_reward = reward[0].item() * REWARD_SCALE

        #Reward shaping to encourage trading
        current_position = env.position[0].item()
        prev_position = getattr(env, '_prev_position', 0)
        position_change = abs(current_position - prev_position)
        trading_bonus = 0.001 * position_change  # Small bonus for taking positions

        #Profit incentive
        profit_incentive = 0.01 * base_reward if base_reward > 0 else 0

        reward_single = base_reward + trading_bonus + profit_incentive
        done_single = terminal[0].item()
        agent.store_transition(state_single, action, reward_single, value, log_prob, done_single)

        # Store previous position for next step
        env._prev_position = env.position[0].item()

        #Enhanced debugging for asset tracking
        if step % 1000 == 0:  # Debug every 1000 steps
            current_price = env.price_ary[env.step_is[0] + env.step_i, 2].item()
            print(f"🔍 Debug: Step {step} - Asset=${env.asset[0].item():.2f}, "
                  f"Cash=${env.cash[0].item():.2f}, "
                  f"Position={env.position[0].item()}, "
                  f"Price=${current_price:.2f}, "
                  f"Reward={reward_single:.6f}")

        # Update episode stats
        episode_reward += reward_single
        episode_length += 1

        # Update policy when buffer is full
        if len(agent.states) >= HORIZON_LEN:
            update_stats = agent.update_policy()
            update_count += 1

            # Learning rate scheduling for longer training
            agent.scheduler.step()

            # Better logging frequency
            if update_count % 50 == 0:
                print(f"Step {step}: Actor Loss: {update_stats.get('actor_loss', 0):.6f}, "
                      f"Critic Loss: {update_stats.get('critic_loss', 0):.6f}, "
                      f"KL Div: {update_stats.get('kl_div', 0):.6f}, "
                      f"Entropy: {update_stats.get('entropy', 0):.6f}")

        # Reset episode if done
        if done_single:
            agent.episode_rewards.append(episode_reward)
            agent.episode_lengths.append(episode_length)

            # Enhanced episode logging with asset tracking
            final_asset = env.asset[0].item()
            final_cash = env.cash[0].item()
            if len(agent.episode_rewards) % 20 == 0:
                avg_reward = np.mean(agent.episode_rewards[-20:])
                avg_length = np.mean(agent.episode_lengths[-20:])
                print(f"Episode {len(agent.episode_rewards)}: Avg Reward: {avg_reward:.6f}, Avg Length: {avg_length:.1f}, "
                      f"Final Asset=${final_asset:.2f}, Final Cash=${final_cash:.2f}")

            episode_reward = 0
            episode_length = 0
            env.reset()

        # Save agent periodically
        if step % SAVE_GAP == 0 and step > 0:
            save_dir = f"./trained_agents/{PPO_VERSION}/{timeframe}/PPO"
            agent.save_agent(save_dir)
            print(f"Saved agent at step {step}")

        # Update total steps
        agent.total_steps = step

    # Final save
    save_dir = f"./trained_agents/{PPO_VERSION}/{timeframe}/PPO"
    agent.save_agent(save_dir)
    print(f"Training completed! Agent saved to {save_dir}")

    return agent

# ==================== TRADING ENVIRONMENT ====================

class TradeSimulator:
    def __init__(self, num_sims=32, slippage=5e-5, max_position=1, step_gap=1,
                 delay_step=1, num_ignore_step=60, device=torch.device("cpu"), gpu_id=-1, timeframe='1sec'):
        self.device = torch.device(f"cuda:{gpu_id}") if gpu_id >= 0 else device
        self.num_sims = num_sims
        self.slippage = slippage
        self.delay_step = delay_step
        self.max_holding = 60 * 60 // step_gap
        self.max_position = max_position
        self.step_gap = step_gap
        self.sim_ids = torch.arange(self.num_sims, device=self.device)
        self.timeframe = timeframe

        self.load_data()

        self.env_name = "TradeSimulator-v0"
        self.state_dim = 2 + 8 + 2
        self.action_dim = 3
        self.if_discrete = True
        self.max_step = (self.seq_len - num_ignore_step) // step_gap
        self.target_return = +np.inf

        self.best_price = torch.zeros((num_sims,), dtype=torch.float32, device=device)
        self.stop_loss_thresh = 1e-3

    def load_data(self):
        # Load training predictions with multiple fallback options
        possible_paths = [
            f"output/{self.timeframe}/train_predictions.npy",
            f"output/{self.timeframe}/predictions.npy",
            f"../output/{self.timeframe}/train_predictions.npy",
            f"../output/{self.timeframe}/predictions.npy"
        ]

        factor_path = None
        for path in possible_paths:
            if os.path.exists(path):
                factor_path = path
                print(f"Found predictions at: {path}")
                break

        if factor_path is None:
            # Quit training if no predictions found
            print("Error: No prediction files found in any of the expected locations:")
            for path in possible_paths:
                print(f"   - {path}")
            print("Please ensure the prediction files exist before running training.")
            raise FileNotFoundError("Required prediction files not found. Cannot proceed with training.")

        self.factor_ary = np.load(factor_path)
        self.factor_ary = torch.tensor(self.factor_ary, dtype=torch.float32)

        if self.factor_ary.shape[0] == 0:
            raise ValueError(f"Alpha101 factors are empty for {self.timeframe}.")

        # Try multiple paths for price data
        possible_csv_paths = [
            f"data/{self.timeframe}/BTC_{self.timeframe}_with_sentiment_risk_train_{self.timeframe}_train_70.csv",
            f"../data/{self.timeframe}/BTC_{self.timeframe}_with_sentiment_risk_train_{self.timeframe}_train_70.csv",
            f"data/{self.timeframe}/BTC_{self.timeframe}_with_sentiment_risk_train.csv",
            f"../data/{self.timeframe}/BTC_{self.timeframe}_with_sentiment_risk_train.csv"
        ]

        csv_path = None
        for path in possible_csv_paths:
            if os.path.exists(path):
                csv_path = path
                print(f"Found price data at: {path}")
                break

        if csv_path is None:
            # Quit training if no price data found
            print("Error: No price data files found in any of the expected locations:")
            for path in possible_csv_paths:
                print(f"   - {path}")
            print("Please ensure the price data files exist before running training.")
            raise FileNotFoundError("Required price data files not found. Cannot proceed with training.")
        try:
            data_df = pd.read_csv(csv_path, engine='python')
        except Exception:
            data_df = pd.read_csv(csv_path)

        required_columns = ["bids_distance_3", "asks_distance_3", "midpoint", "sentiment_score", "risk_score"]
        missing_columns = [col for col in required_columns if col not in data_df.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}.")

        self.price_ary = data_df[["bids_distance_3", "asks_distance_3", "midpoint"]].values
        self.price_ary[:, 0] = self.price_ary[:, 2] * (1 + self.price_ary[:, 0])
        self.price_ary[:, 1] = self.price_ary[:, 2] * (1 + self.price_ary[:, 0])
        self.llm_signals = data_df[["sentiment_score", "risk_score"]].values

        min_len = min(self.factor_ary.shape[0], self.price_ary.shape[0], self.llm_signals.shape[0])
        self.factor_ary = self.factor_ary[:min_len]
        self.price_ary = torch.tensor(self.price_ary[:min_len], dtype=torch.float32)
        self.llm_signals = torch.tensor(self.llm_signals[:min_len], dtype=torch.float32)

        if NORMALIZE_LLM_SIGNALS:
            llm_min = self.llm_signals.amin(dim=0, keepdim=True)
            llm_max = self.llm_signals.amax(dim=0, keepdim=True)
            llm_range = (llm_max - llm_min)
            self.llm_signals = (self.llm_signals - llm_min) / (llm_range + 1e-6)
            self.llm_signals = torch.clamp(self.llm_signals, 0.0, 1.0)

        if self.timeframe == '1sec':
            self.seq_len = 1800
        elif self.timeframe == '1min':
            self.seq_len = 64
        elif self.timeframe == '5min':
            self.seq_len = 16
        else:
            self.seq_len = 1800

        self.full_seq_len = self.price_ary.shape[0]

    def reset(self, slippage=None):
        self.slippage = slippage if isinstance(slippage, float) else self.slippage

        min_start = self.seq_len
        max_start = self.full_seq_len - self.seq_len * 2
        if min_start >= max_start:
            if self.full_seq_len > self.seq_len:
                max_start = self.full_seq_len - self.seq_len
                min_start = 0
                i0s = np.random.randint(min_start, max_start, size=self.num_sims)
            else:
                if self.full_seq_len > 0:
                    i0s = np.arange(min(self.num_sims, self.full_seq_len))
                    if len(i0s) < self.num_sims:
                        i0s = np.pad(i0s, (0, self.num_sims - len(i0s)), mode='constant', constant_values=0)
                else:
                    i0s = np.zeros(self.num_sims, dtype=np.int64)
        else:
            i0s = np.random.randint(min_start, max_start, size=self.num_sims)

        self.step_i = 0
        self.step_is = torch.tensor(i0s, dtype=torch.long, device=self.device)

        # Initialize with starting cash instead of zero
        starting_cash = 1000000.0  # $1M starting capital
        self.cash = torch.full((self.num_sims,), starting_cash, dtype=torch.float32, device=self.device)
        self.asset = torch.full((self.num_sims,), starting_cash, dtype=torch.float32, device=self.device)

        # Debug logging for asset initialization
        print(f"Debug: Environment reset - Cash: {self.cash[0].item():.2f}, Asset: {self.asset[0].item():.2f}")
        self.holding = torch.zeros((self.num_sims,), dtype=torch.long, device=self.device)
        self.position = torch.zeros((self.num_sims,), dtype=torch.long, device=self.device)
        self.empty_count = torch.zeros((self.num_sims,), dtype=torch.long, device=self.device)
        self.best_price = torch.zeros((self.num_sims,), dtype=torch.float32, device=self.device)

        step_is = self.step_is + self.step_i
        state = self.get_state(step_is)
        return state

    def step(self, action):
        try:
            if not hasattr(self, 'step_i'):
                self.step_i = 0
            self.step_i += self.step_gap
            step_is = self.step_is + self.step_i

            if action.dim() == 1:
                action = action.to(self.device)
            else:
                action = action.squeeze().to(self.device)
            action_int = action - 1

            old_cash = self.cash
            old_asset = self.asset
            old_position = self.position

            step_is_cpu = step_is.cpu()
            max_idx = self.price_ary.shape[0] - 1
            step_is_cpu = torch.clamp(step_is_cpu, 0, max_idx)
            mid_price = self.price_ary[step_is_cpu, 2].to(self.device)

            truncated = self.step_i >= (self.max_step * self.step_gap)
            if truncated:
                action_int = -old_position
            else:
                new_position = (old_position + action_int).clip(-self.max_position, self.max_position)
                action_int = new_position - old_position

            self.holding = self.holding + 1
            mask_max_holding = self.holding.gt(self.max_holding)
            if mask_max_holding.sum() > 0:
                action_int[mask_max_holding] = -old_position[mask_max_holding]
            self.holding[old_position == 0] = 0

            direction_mask1 = old_position.gt(0)
            if direction_mask1.sum() > 0:
                _best_price = torch.max(
                    torch.stack([self.best_price[direction_mask1], mid_price[direction_mask1]]),
                    dim=0,
                )[0]
                self.best_price[direction_mask1] = _best_price

            direction_mask2 = old_position.lt(0)
            if direction_mask2.sum() > 0:
                _best_price = torch.min(
                    torch.stack([self.best_price[direction_mask2], mid_price[direction_mask2]]),
                    dim=0,
                )[0]
                self.best_price[direction_mask2] = _best_price

            stop_loss_mask1 = torch.logical_and(direction_mask1, (self.best_price - mid_price).gt(self.stop_loss_thresh))
            stop_loss_mask2 = torch.logical_and(direction_mask2, (mid_price - self.best_price).gt(self.stop_loss_thresh))
            stop_loss_mask = torch.logical_or(stop_loss_mask1, stop_loss_mask2)
            if stop_loss_mask.sum() > 0:
                action_int[stop_loss_mask] = -old_position[stop_loss_mask]

            new_position = old_position + action_int
            direction = action_int.gt(0)
            cost = action_int * mid_price
            new_cash = old_cash - cost * torch.where(direction, 1 + self.slippage, 1 - self.slippage)
            new_asset = new_cash + new_position * mid_price
            reward = new_asset - old_asset

            # Debug logging for asset changes
            if self.step_i % 100 == 0:  # Debug every 100 steps
                print(f"Debug: Step {self.step_i} - Old Asset: {old_asset[0].item():.2f}, "
                      f"New Asset: {new_asset[0].item():.2f}, Reward: {reward[0].item():.6f}")

            self.cash = new_cash
            self.asset = new_asset
            self.position = new_position

            state = self.get_state(step_is)
            if truncated:
                terminal = torch.ones_like(self.position, dtype=torch.bool)
                state = self.reset()
            else:
                terminal = torch.zeros_like(self.position, dtype=torch.bool)

            return state, reward, terminal, {}

        except Exception as e:
            print(f"Error in environment step: {e}")
            state = torch.zeros((self.num_sims, self.state_dim), dtype=torch.float32).to(self.device)
            reward = torch.zeros((self.num_sims,), dtype=torch.float32).to(self.device)
            terminal = torch.zeros((self.num_sims,), dtype=torch.bool).to(self.device)
            return state, reward, terminal, {}

    def get_state(self, step_is):
        step_is_cpu = step_is.cpu()
        max_idx = min(self.factor_ary.shape[0], self.price_ary.shape[0], self.llm_signals.shape[0]) - 1
        step_is_cpu = torch.clamp(step_is_cpu, 0, max_idx)
        factor_ary = self.factor_ary[step_is_cpu, :].to(self.device)
        llm_signals = self.llm_signals[step_is_cpu, :].to(self.device)
        return torch.concat(
            (
                (self.position.float() / self.max_position)[:, None],
                (self.holding.float() / self.max_holding)[:, None],
                factor_ary,
                llm_signals,
            ),
            dim=1,
        )

In [4]:
agent = train_improved_ppo_agent('1sec')

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
🔍 Debug: Step 100 - Old Asset: 999987.00, New Asset: 999986.94, Reward: -0.062500
🔍 Debug: Step 200 - Old Asset: 999979.94, New Asset: 999979.94, Reward: 0.000000
🔍 Debug: Step 300 - Old Asset: 999977.94, New Asset: 999977.94, Reward: 0.000000
🔍 Debug: Step 400 - Old Asset: 999945.50, New Asset: 999945.50, Reward: 0.000000
🔍 Debug: Step 1782000 - Asset=$999955.00, Cash=$1058426.00, Position=-1, Price=$58471.02, Reward=0.001000
🔍 Debug: Step 500 - Old Asset: 999953.69, New Asset: 999953.62, Reward: -0.062500
🔍 Debug: Step 600 - Old Asset: 999926.94, New Asset: 999926.88, Reward: -0.062500
🔍 Debug: Step 700 - Old Asset: 999951.00, New Asset: 999951.00, Reward: 0.000000
🔍 Debug: Step 800 - Old Asset: 999905.19, New Asset: 999905.19, Reward: 0.000000
🔍 Debug: Step 900 - Old Asset: 999908.44, New Asset: 999908.38, Reward: -0.062500
🔍 Debug: Step 1000 - Old Asset: 999878.56, New Asset: 999878.50, Reward: -0.062500
🔍 Debug: Step