# Overview of the Code

This file contains two major sections that build and train reinforcement learning (RL) environments using a grammar-based approach. The goal in both sections is to have an RL agent learn to generate syntactically and semantically correct code. The first part targets generating GPU PTX code from a SASS-like grammar, while the second part simplifies the task to produce a single “MAC6” instruction.

This work is inspired by the paper:

> **CuAsmRL: Optimizing GPU SASS Schedules via Deep Reinforcement Learning**  
> Guoliang He, Eiko Yoneki  
> [arXiv:2501.08071](https://doi.org/10.48550/arXiv.2501.08071)  
>  
> *Large language models (LLMs) are remarked by their substantial computational requirements. To mitigate the cost, researchers develop specialized CUDA kernels, which often fuse several tensor operations to maximize the utilization of GPUs. However, those specialized kernels may still leave performance on the table as CUDA assembly experts show that manual optimization of GPU SASS schedules can lead to better performance. This work employs an automatic approach to optimize GPU SASS schedules using deep reinforcement learning by formulating an assembly game where RL agents iteratively mutate schedules to improve throughput.*

---

## Part 1: Grammar-Based RL Environment for PTX Code Generation

### 1. Dynamic Grammar Creation
- **Grammar Definition:**  
  A function constructs a dynamic grammar with rules to generate code for a GPU kernel. The grammar includes nonterminals such as `Program`, `InstructionList`, `Instruction`, and specific instructions (`Load`, `MAC`, `Store`).
- **Components:**  
  - **Registers:** Defined as tokens `R0`, `R1`, …, `R(num_registers-1)`.
  - **Memory References:** Placeholder tokens such as `[<A0_addr>]`, `[<B0_addr>]`, `[<C0_addr>]`, etc.
- **Purpose:**  
  To produce a complete SASS-like program, which is then converted to PTX, compiled, and executed on the GPU.

### 2. Parsing and Vocabulary Utilities
- **Parsing Functions:**  
  - `find_leftmost_nonterminal` locates the leftmost nonterminal in the current symbol list.
  - `expansions_of` retrieves possible production expansions for a nonterminal.
- **Vocabulary Building:**  
  A function iterates through the grammar to build a set of terminal tokens, which are then sorted and indexed for observation purposes.
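To make the two helpers concrete, here is a minimal sketch of a leftmost derivation over a toy grammar in the same dict-of-lists format. `find_leftmost_nonterminal` and `build_vocab` mirror the helpers in the code cell further down; the toy grammar and the `derive` helper are hypothetical. Because every dictionary key counts as a nonterminal, terminals such as `END` must not also appear as keys.

```python
def find_leftmost_nonterminal(symbols, grammar):
    """Return the index of the first symbol that is a grammar key, or -1."""
    for i, sym in enumerate(symbols):
        if sym in grammar:
            return i
    return -1


def build_vocab(grammar):
    """Sorted list of terminals: symbols used in expansions that are not keys."""
    return sorted({tok for exps in grammar.values()
                   for exp in exps for tok in exp if tok not in grammar})


# Hypothetical toy grammar: a one-instruction subset of the Part 1 grammar.
TOY_GRAMMAR = {
    "Program": [["Load", "END"]],
    "Load": [["LDG", "Register", ",", "MemoryRef", ";"]],
    "Register": [["R0"], ["R1"]],
    "MemoryRef": [["[<A0_addr>]"]],
}


def derive(choices, grammar=TOY_GRAMMAR):
    """Leftmost derivation: choices[k] is the expansion index used at step k."""
    symbols = ["Program"]
    for choice in choices:
        i = find_leftmost_nonterminal(symbols, grammar)
        if i == -1:  # no nonterminals left: the program is complete
            break
        symbols = symbols[:i] + grammar[symbols[i]][choice] + symbols[i + 1:]
    return symbols


vocab = build_vocab(TOY_GRAMMAR)
token_to_id = {tok: idx for idx, tok in enumerate(vocab)}
program = derive([0, 0, 1, 0])
print(" ".join(program))                  # LDG R1 , [<A0_addr>] ; END
print([token_to_id[t] for t in program])  # [3, 5, 0, 6, 1, 2]
```

Sorting the terminals gives every token a stable index, which is what the observation vector stores.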

### 3. SASS to PTX Conversion, Compilation, and GPU Execution
- **SASS to PTX Conversion:**  
  The function `sass_to_ptx` converts a SASS-like program to a PTX kernel by:
    - Replacing placeholders (e.g., converting `<A0_addr>` to `A0_ptr`).
  - Mapping opcodes (e.g., replacing `LDG` with `ld.global.f32`).
  - Wrapping the instructions with a PTX kernel header and footer.
- **Compilation and Execution:**  
  - The PTX code is compiled using `nvcc` into a CUDA binary (cubin file).
  - A function loads the compiled module via PyCUDA, executes the kernel, and measures execution time.
  - It compares the output against ground truth computed from 4x4 matrices.
- **Matrix Loading:**  
  A helper function loads (or, if the file is missing, generates) 3000 4×4 matrices, split into three groups of 1000 (A, B, and C), and computes the ground truth $D = AB + C$ for the multiply-accumulate operation.
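The three conversion steps can be sketched as a text-level rewrite. This is not the notebook's `sass_to_ptx`; the function name `sass_to_ptx_sketch`, the opcode table, and the kernel name are assumptions. It only renames placeholders, maps opcodes, and wraps the body; valid PTX would additionally need `.reg` declarations and `ld.param`/`cvta.to.global` to turn the pointer parameters into addresses.

```python
import re

# Assumed opcode table: SASS-like mnemonic -> PTX mnemonic.
# Unknown opcodes (e.g. fma.rn.f32, already PTX) pass through unchanged.
OPCODE_MAP = {"LDG": "ld.global.f32", "STG": "st.global.f32"}

# Matches placeholders such as <A0_addr>, <B12_addr>, <C3_addr>.
PLACEHOLDER = re.compile(r"<([ABC]\d+)_addr>")


def sass_to_ptx_sketch(sass, kernel_name="agent_kernel"):
    """Text-level rewrite of a SASS-like program into a PTX kernel skeleton."""
    params, body = [], []
    for raw in sass.strip().splitlines():
        line = raw.strip().rstrip(";").strip()
        if not line or line == "END":
            continue
        # 1. Replace <X_addr> placeholders with X_ptr, recording each pointer.
        for name in PLACEHOLDER.findall(line):
            if name + "_ptr" not in params:
                params.append(name + "_ptr")
        line = PLACEHOLDER.sub(r"\1_ptr", line)
        # 2. Map the opcode. PTX stores take the address first: st [addr], reg.
        opcode, _, operands = line.partition(" ")
        ops = [o.strip() for o in operands.split(",")]
        if opcode == "STG":
            ops.reverse()
        body.append(f"    {OPCODE_MAP.get(opcode, opcode)} {', '.join(ops)};")
    # 3. Wrap the instructions in a kernel header and footer.
    decls = ",\n".join(f"    .param .u64 {p}" for p in params)
    header = (".version 7.0\n.target sm_75\n.address_size 64\n\n"
              f".visible .entry {kernel_name}(\n{decls}\n)\n{{")
    return "\n".join([header, *body, "    ret;", "}"])


sass = """LDG R0, [<A0_addr>];
LDG R1, [<B0_addr>];
STG R1, [<C0_addr>];
END"""
print(sass_to_ptx_sketch(sass))
```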

### 4. RL Environment for Grammar-Based Code Generation
- **Environment Setup:**  
  The `GrammarMultiLoadEnv` class inherits from Gym’s `Env` and allows the RL agent to select production rules from the grammar.
- **Observation and Action Spaces:**  
  - **Observations:**  
    Represented as a fixed-length vector of token indices where nonterminals are given a special value.
  - **Actions:**  
    Consist of all valid $(\text{nonterminal}, \text{expansion index})$ pairs.
- **Reward Structure:**  
  - Partial rewards are given for semantically correct instructions (e.g., a successful load or a valid MAC operation).
  - Heavy penalties (e.g., $-500$) are applied if the final parse is incomplete or if the program fails to compile/execute correctly.
- **Finalization:**  
  Upon completion, the SASS program is converted to PTX, compiled, and executed on the GPU. The final reward is determined by execution time and correctness.
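The $(\text{nonterminal}, \text{expansion index})$ action space can be sketched as a flat list of pairs. How the environment treats an action that does not match the current leftmost nonterminal is not shown here, so the masking below is an assumption (one common option; penalizing invalid actions is another). The toy grammar and helper names are hypothetical.

```python
# Hypothetical toy grammar in the dict-of-lists format used by make_grammar.
GRAMMAR = {
    "Program": [["InstructionList", "END"]],
    "InstructionList": [["InstructionList", "Instruction"], ["Instruction"]],
    "Instruction": [["Load"], ["Store"]],
    "Load": [["LDG", "Register", ",", "MemoryRef", ";"]],
    "Store": [["STG", "Register", ",", "MemoryRef", ";"]],
    "Register": [["R0"], ["R1"]],
    "MemoryRef": [["[<A0_addr>]"], ["[<C0_addr>]"]],
}

# One discrete action per (nonterminal, expansion index) pair.
ACTIONS = [(nt, i) for nt, exps in GRAMMAR.items() for i in range(len(exps))]


def leftmost_nonterminal(symbols):
    return next((s for s in symbols if s in GRAMMAR), None)


def action_mask(symbols):
    """An action is valid only if it expands the current leftmost nonterminal."""
    target = leftmost_nonterminal(symbols)
    return [nt == target for nt, _ in ACTIONS]


def apply_action(symbols, action_id):
    """Replace the leftmost nonterminal with the chosen expansion."""
    nt, idx = ACTIONS[action_id]
    if nt != leftmost_nonterminal(symbols):
        raise ValueError(f"action {ACTIONS[action_id]} is not valid here")
    pos = symbols.index(nt)  # first occurrence is the leftmost nonterminal
    return symbols[:pos] + GRAMMAR[nt][idx] + symbols[pos + 1:]


symbols = ["Program"]
print(len(ACTIONS), sum(action_mask(symbols)))  # 11 1

# Walk the derivation by always taking the last valid action.
while leftmost_nonterminal(symbols) is not None:
    mask = action_mask(symbols)
    last_valid = max(i for i, ok in enumerate(mask) if ok)
    symbols = apply_action(symbols, last_valid)
print(" ".join(symbols))  # STG R1 , [<C0_addr>] ; END
```

Taking the *first* valid action instead would expand `InstructionList -> InstructionList Instruction` forever, which is why a learned policy, or a length limit, is needed to terminate left-recursive rules.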

### 5. Training with PPO
- **RL Algorithm:**  
  The training loop employs the Proximal Policy Optimization (PPO) algorithm from Stable Baselines3, wrapped in a vectorized environment (`DummyVecEnv`).
- **Execution:**  
  After training, several test episodes are executed to print out the total reward and debug information.

---

## Part 2: Simplified Environment for a Single MAC6 Instruction

### 1. Simplified Grammar Definition
- **Grammar Rules:**  
  A reduced grammar is defined to generate only one valid instruction:
  
  $$\text{Program} \rightarrow \text{MAC6 END}$$  
  
  $$\text{MAC6} \rightarrow \text{"MAC6"}\ \text{Register},\ \text{Register},\ \text{Register},\ \text{Register},\ \text{Register},\ \text{Register};$$  
  
  $$\text{Register} \rightarrow \text{R0} \mid \text{R1} \mid \text{R2} \mid \text{R3} \mid \text{R4} \mid \text{R5}$$
  
- **Objective:**  
  To ensure the RL agent generates a single instruction with six register operands, simplifying the output format.
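As a quick size check, the reduced grammar can be written in the dict-of-lists format and derived leftmost. Every complete derivation takes exactly 8 decisions (two forced expansions plus six register choices), and the language contains $6^6 = 46{,}656$ distinct programs. The nonterminal name `MAC6Instr` and the `derive` helper are hypothetical; the nonterminal is renamed here so it cannot collide with the `MAC6` opcode token.

```python
import itertools

REGISTERS = [f"R{i}" for i in range(6)]

# The MAC6 nonterminal is named "MAC6Instr" so it is not confused with the
# "MAC6" opcode token: any dict key is treated as a nonterminal.
GRAMMAR = {
    "Program": [["MAC6Instr", "END"]],
    "MAC6Instr": [["MAC6", "Register", ",", "Register", ",", "Register", ",",
                   "Register", ",", "Register", ",", "Register", ";"]],
    "Register": [[r] for r in REGISTERS],
}


def derive(register_indices):
    """Leftmost derivation; register_indices picks the expansion at each Register."""
    symbols, steps = ["Program"], 0
    picks = iter(register_indices)
    while True:
        pos = next((i for i, s in enumerate(symbols) if s in GRAMMAR), None)
        if pos is None:
            return symbols, steps
        nt = symbols[pos]
        # Program and MAC6Instr have a single expansion, so only Register is a choice.
        choice = next(picks) if nt == "Register" else 0
        symbols = symbols[:pos] + GRAMMAR[nt][choice] + symbols[pos + 1:]
        steps += 1


tokens, steps = derive([0, 1, 2, 3, 4, 5])
print(" ".join(tokens), steps)  # MAC6 R0 , R1 , R2 , R3 , R4 , R5 ; END 8
print(len(list(itertools.product(REGISTERS, repeat=6))))  # 46656
```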

### 2. Environment Details
- **Observation and Action Spaces:**  
  Observations are fixed-length vectors of token indices, and each action selects a production-rule expansion.
- **Program Finalization and Reward:**  
  The environment finalizes the parse when the production is complete and checks:
  - The program must begin with `MAC6` and end with `END`.
  - The tokens must appear in the correct order, including commas and the semicolon.
  - A positive reward (e.g., $+10$) is granted if the generated program is valid; otherwise, a penalty is applied.
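The finalization checks above can be sketched as a token-order validator. This assumes the finished program is a flat list of terminal tokens; the function names are hypothetical, and since the section does not specify the penalty, the `-1.0` default is an assumption.

```python
REGISTERS = {f"R{i}" for i in range(6)}


def is_valid_mac6(tokens):
    """Check the token order MAC6 r , r , r , r , r , r ; END."""
    if len(tokens) != 14 or tokens[0] != "MAC6" or tokens[-2:] != [";", "END"]:
        return False
    operands = tokens[1:12]  # r , r , r , r , r , r  (11 tokens)
    regs, commas = operands[0::2], operands[1::2]
    return all(r in REGISTERS for r in regs) and all(c == "," for c in commas)


def mac6_reward(tokens, bonus=10.0, penalty=-1.0):
    """+10 for a valid program; the penalty value is an assumption."""
    return bonus if is_valid_mac6(tokens) else penalty


good = "MAC6 R0 , R1 , R2 , R3 , R4 , R5 ; END".split()
bad = "MAC6 R0 R1 , R2 , R3 , R4 , R5 ; END".split()  # missing comma
print(mac6_reward(good), mac6_reward(bad))  # 10.0 -1.0
```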

### 3. Training with PPO
- **Simplified Training:**  
  The PPO algorithm is employed in this simplified setting. The reduced complexity of the grammar and reward structure allows the RL agent to focus on learning the correct sequence to produce a valid `MAC6` instruction.

---

## Summary

- **Part 1:**  
  Demonstrates a comprehensive setup where an RL agent learns to generate a complete PTX kernel by expanding a dynamic grammar, converting SASS to PTX, compiling, and running it on the GPU. Rewards are based on semantic correctness and execution performance.

- **Part 2:**  
  Reduces the task to generating a single valid `MAC6` instruction with six registers, greatly simplifying the grammar and reward design. This serves as a stepping stone to understand the core RL task without additional complexities.

The design and approach of this code take inspiration from **CuAsmRL** by Guoliang He and Eiko Yoneki, which explores using deep reinforcement learning to optimize GPU SASS schedules. Just as CuAsmRL formulates an assembly game where the RL agent iteratively applies actions to improve kernel throughput, this code formulates a grammar-based game where the RL agent learns to generate correct code sequences. This work provides a foundational framework for integrating automated code generation and optimization techniques into compiler systems, aiming to reduce the manual effort typically required by CUDA assembly experts.


I will note that the code does not work exactly as expected. The agent's outputs are saved to the .ptx file, but the loss is not what I expect, which implies that the program still needs more fixing. The basic idea, however, is good!


In [5]:
import os
import subprocess
import numpy as np
import time
import pycuda.driver as cuda
import pycuda.autoinit  # automatically initialize CUDA

# ------------------------------------------------------------------------------
# 1. Read the RL agent's output PTX program from file.
# ------------------------------------------------------------------------------
ptx_filename = "agent_output.ptx"
if not os.path.exists(ptx_filename):
    # For demonstration, if the file doesn't exist, create a default PTX program.
    default_ptx = r"""
.version 7.0
.target sm_75
.address_size 64

.visible .entry matrix_mmac(
    .param .u64 A_ptr,
    .param .u64 B_ptr,
    .param .u64 C_ptr,
    .param .u64 D_ptr
)
{
    // 4x4 matrix multiply-accumulate: D = A*B + C (row-major float32).
    // Each thread computes one element; launch with blockDim = (4, 4, 1).
    // PTX register addresses take the form [reg] or [reg+imm], so "[A + idx]"
    // (two registers) is not legal: byte offsets are widened to 64 bits and
    // added to the base pointer first.
    .reg .u32 row, col, elem;
    .reg .u64 A, B, C, D, a_row, b_col, off, addr;
    .reg .f32 a_val, b_val, c_val, sum;

    // Load parameters and convert them to global-space addresses.
    ld.param.u64 A, [A_ptr];
    ld.param.u64 B, [B_ptr];
    ld.param.u64 C, [C_ptr];
    ld.param.u64 D, [D_ptr];
    cvta.to.global.u64 A, A;
    cvta.to.global.u64 B, B;
    cvta.to.global.u64 C, C;
    cvta.to.global.u64 D, D;

    // Thread indices: row = tid.y, col = tid.x.
    mov.u32 col, %tid.x;
    mov.u32 row, %tid.y;

    // a_row = &A[row][0] (row * 16 bytes); b_col = &B[0][col] (col * 4 bytes).
    mul.wide.u32 off, row, 16;
    add.u64 a_row, A, off;
    mul.wide.u32 off, col, 4;
    add.u64 b_col, B, off;

    // Initialize sum = 0.
    mov.f32 sum, 0f00000000;

    // Unrolled k loop: A[row][k] is at a_row + 4k, B[k][col] at b_col + 16k.
    ld.global.f32 a_val, [a_row];
    ld.global.f32 b_val, [b_col];
    fma.rn.f32 sum, a_val, b_val, sum;
    ld.global.f32 a_val, [a_row+4];
    ld.global.f32 b_val, [b_col+16];
    fma.rn.f32 sum, a_val, b_val, sum;
    ld.global.f32 a_val, [a_row+8];
    ld.global.f32 b_val, [b_col+32];
    fma.rn.f32 sum, a_val, b_val, sum;
    ld.global.f32 a_val, [a_row+12];
    ld.global.f32 b_val, [b_col+48];
    fma.rn.f32 sum, a_val, b_val, sum;

    // Add C[row][col]; its byte offset (row*4 + col) * 4 is reused for D.
    mad.lo.u32 elem, row, 4, col;
    mul.wide.u32 off, elem, 4;
    add.u64 addr, C, off;
    ld.global.f32 c_val, [addr];
    add.f32 sum, sum, c_val;

    // Store result D[row][col].
    add.u64 addr, D, off;
    st.global.f32 [addr], sum;

    ret;
}
"""
    with open(ptx_filename, "w") as f:
        f.write(default_ptx)
    print("Default PTX program written to", ptx_filename)
else:
    print("Loaded PTX program from", ptx_filename)

with open(ptx_filename, "r") as f:
    agent_ptx_program = f.read()

# ------------------------------------------------------------------------------
# 2. Compile the PTX program into a cubin file using nvcc.
# ------------------------------------------------------------------------------
cubin_filename = "agent_output.cubin"
compile_cmd = ["nvcc", "-arch=sm_75", "-cubin", ptx_filename, "-o", cubin_filename]
result = subprocess.run(compile_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
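if result.returncode != 0:
    print("Compilation failed with error:")
    print(result.stderr.decode())
    # exit(1) did not stop this notebook cell: execution continued and later
    # failed in cuModuleLoad because no cubin was written. Raise instead.
    raise RuntimeError(f"nvcc failed to compile {ptx_filename}")
print("Compilation succeeded. Cubin file generated:", cubin_filename)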

# ------------------------------------------------------------------------------
# 3. Generate or load 3000 random 4x4 matrices.
#    We'll assume they are stored in a text file "big_mats.txt" with 3000 lines,
#    each line containing 16 space-separated floats (representing a 4x4 matrix).
#    The first 1000 lines are A matrices, the next 1000 are B, and the last 1000 are C.
# ------------------------------------------------------------------------------
n = 1000
total_mats = 3 * n
big_mats_file = "big_mats.txt"
if not os.path.exists(big_mats_file):
    print(f"{big_mats_file} not found. Generating 3000 random 4x4 matrices...")
    mats = np.random.rand(total_mats, 4, 4).astype(np.float32)
    with open(big_mats_file, "w") as f:
        for mat in mats:
            f.write(" ".join(map(str, mat.flatten())) + "\n")
else:
    print(f"Loading matrices from {big_mats_file}...")
    mats = []
    with open(big_mats_file, "r") as f:
        for line in f:
            vals = list(map(float, line.strip().split()))
            if len(vals) == 16:
                mat = np.array(vals).reshape((4, 4)).astype(np.float32)
                mats.append(mat)
            if len(mats) >= total_mats:
                break
    mats = np.array(mats)
    print("Loaded", len(mats), "matrices.")

# Partition matrices: first 1000 are A, next 1000 are B, last 1000 are C.
A_mats = mats[0:n]
B_mats = mats[n:2*n]
C_mats = mats[2*n:3*n]

# For demonstration, compute ground truth for the first triple (A^1, B^1, C^1) using numpy.
A0 = A_mats[0]
B0 = B_mats[0]
C0 = C_mats[0]
numpy_result = np.matmul(A0, B0) + C0

# ------------------------------------------------------------------------------
# 4. Load the compiled cubin using PyCUDA.
# ------------------------------------------------------------------------------
import pycuda.driver as cuda
module = cuda.module_from_file(cubin_filename)
kernel_func = module.get_function("matrix_mmac")

# ------------------------------------------------------------------------------
# 5. Launch the kernel for each triple and measure total time.
#    We assume each kernel execution computes D = A*B + C on a full 4x4 matrix.
#    We use a block of (4,4,1) threads so that each thread computes one element.
#    Note: the timed loop also includes device allocation and host<->device
#    copies, so it measures end-to-end time, not kernel time alone.
# ------------------------------------------------------------------------------
block_dim = (4, 4, 1)
grid_dim = (1, 1)

start_total = time.time()
first_output = None  # to store output of the first triple

for i in range(n):
    A = A_mats[i]
    B = B_mats[i]
    C = C_mats[i]
    D = np.zeros((4, 4), dtype=np.float32)
    
    # Flatten matrices to 1D arrays.
    A_flat = A.flatten()
    B_flat = B.flatten()
    C_flat = C.flatten()
    D_flat = D.flatten()
    
    # Allocate device memory.
    A_gpu = cuda.mem_alloc(A_flat.nbytes)
    B_gpu = cuda.mem_alloc(B_flat.nbytes)
    C_gpu = cuda.mem_alloc(C_flat.nbytes)
    D_gpu = cuda.mem_alloc(D_flat.nbytes)
    
    # Copy host data to device.
    cuda.memcpy_htod(A_gpu, A_flat)
    cuda.memcpy_htod(B_gpu, B_flat)
    cuda.memcpy_htod(C_gpu, C_flat)
    
    # Launch the kernel.
    kernel_func(A_gpu, B_gpu, C_gpu, D_gpu, block=block_dim, grid=grid_dim)
    cuda.Context.synchronize()
    
    # Copy result back to host.
    cuda.memcpy_dtoh(D_flat, D_gpu)
    D_result = D_flat.reshape((4, 4))
    
    # Save the first triple's output.
    if i == 0:
        first_output = D_result

end_total = time.time()
total_compute_time = end_total - start_total
average_time = total_compute_time / n

print("\n--- GPU Compute Time ---")
print(f"Total GPU compute time for {n} triples: {total_compute_time:.6f} seconds")
print(f"Average time per triple: {average_time:.6f} seconds")

# ------------------------------------------------------------------------------
# 6. Display the first triple's matrices and compare results.
# ------------------------------------------------------------------------------
print("\n--- First Triple Matrices and Results ---")
print("Matrix A^1 (first A matrix):")
print(A0)
print("\nMatrix B^1 (first B matrix):")
print(B0)
print("\nMatrix C^1 (first C matrix):")
print(C0)
print("\nGPU Kernel Output D (first triple):")
print(first_output)
print("\nNumpy Computation (A^1 * B^1 + C^1):")
print(numpy_result)


Loaded PTX program from agent_output.ptx
Compilation failed with error:
ptxas agent_output.ptx, line 40; fatal   : Parsing error near 'idx': syntax error
ptxas fatal   : Ptx assembly aborted due to errors

Loading matrices from big_mats.txt...
Loaded 3000 matrices.


RuntimeError: cuModuleLoad failed: file not found
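Independent of PTX syntax, the kernel's per-thread index arithmetic can be checked on the host before writing any assembly. This sketch emulates one (4, 4, 1) block in pure Python, assuming row-major float32 matrices and byte addressing as in the default kernel above; `load`, `thread`, and `launch` are hypothetical helpers, not part of the notebook.

```python
import random
from array import array

N = 4    # matrix dimension
F32 = 4  # bytes per float32 element


def load(buf, byte_offset):
    """Emulate ld.global.f32 on a flat float32 buffer addressed in bytes."""
    assert byte_offset % F32 == 0
    return buf[byte_offset // F32]


def thread(tid_x, tid_y, A, B, C, D):
    """One kernel thread: row = tid.y, col = tid.x (row-major layout)."""
    row, col = tid_y, tid_x
    acc = 0.0
    for k in range(N):
        acc += load(A, (row * N + k) * F32) * load(B, (k * N + col) * F32)
    off = (row * N + col) * F32  # same byte offset for C and D
    D[off // F32] = acc + load(C, off)


def launch(A, B, C):
    """Emulate a single (4, 4, 1) block."""
    D = array("f", [0.0] * (N * N))
    for ty in range(N):
        for tx in range(N):
            thread(tx, ty, A, B, C, D)
    return D


rng = random.Random(0)
A, B, C = (array("f", [rng.random() for _ in range(N * N)]) for _ in range(3))
D = launch(A, B, C)
ref = [sum(A[i * N + k] * B[k * N + j] for k in range(N)) + C[i * N + j]
       for i in range(N) for j in range(N)]
print(max(abs(d - r) for d, r in zip(D, ref)))  # small: float32 rounding only
```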

In [1]:
import gym
from gym import spaces
import numpy as np
import time
import os

# Initialize PyCUDA and query the GPU for its maximum registers per block.
import pycuda.driver as drv
drv.init()
device = drv.Device(0)
# Query the maximum registers per block for this device.
# (This value is typically large – e.g. 65536 for many NVIDIA GPUs.)
max_registers = device.get_attribute(drv.device_attribute.MAX_REGISTERS_PER_BLOCK)
print(f"Max registers per block on this GPU: {max_registers}")

class FlexibleSassEnvDynamicRegisters(gym.Env):
    """
    This environment lets an RL agent generate a complete SASS program (token-by-token)
    for performing a matrix MAC operation. The program can interleave load, compute, and store
    instructions arbitrarily, as long as the final output matches the expected result.
    
    At initialization, the environment queries the GPU for its maximum registers per block and
    builds a vocabulary that includes tokens for registers R0 up to R(max_registers-1). Note that
    this limit counts the 32-bit registers shared by all threads of a block; a single thread can
    address at most 255, so most of these register tokens are never needed by real SASS.
    
    The intended computation is as follows (using 4×4 matrices):
         result = C;  for each of N iterations: result = A * B + result
    with A, B, and C taken from memory.
    """
    
    def __init__(self, required_mac_ops=1000, max_length=10000, matrix_file="matrices.txt"):
        super(FlexibleSassEnvDynamicRegisters, self).__init__()
        self.required_mac_ops = required_mac_ops  # e.g., 1000 MAC operations
        self.max_length = max_length             # Maximum program length (in tokens)
        self.matrix_file = matrix_file
        
        # Query the GPU for the maximum registers available per block.
        try:
            self.max_registers = drv.Device(0).get_attribute(drv.device_attribute.MAX_REGISTERS_PER_BLOCK)
        except Exception as e:
            print("Error querying GPU registers. Defaulting to 256 registers.")
            self.max_registers = 256
        
        # Build dynamic register tokens: "R0", "R1", ..., "R(max_registers-1)"
        register_tokens = [f"R{i}" for i in range(self.max_registers)]
        
        # Fixed tokens for opcodes, punctuation, etc.
        fixed_tokens = [
            "PAD",     # Padding token
            "LDG",     # Load instruction
            "MUL",     # Multiply instruction
            "FADD",    # Floating-point add (for MAC)
            "STG",     # Store instruction
            "D",       # Destination token (for store)
            "mem",     # Memory reference
            "[",       # Left bracket
            "]",       # Right bracket
            ",",       # Comma separator
            "->",      # Arrow (for data flow)
            "END",     # Program terminator
            "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
            ";"        # Instruction separator
        ]
        
        # The overall vocabulary is the fixed tokens plus all register tokens.
        self.vocab = fixed_tokens + register_tokens
        self.vocab_size = len(self.vocab)
        self.token_to_id = {token: idx for idx, token in enumerate(self.vocab)}
        self.id_to_token = {idx: token for idx, token in enumerate(self.vocab)}
        
        # The RL agent outputs one token at a time (action space is discrete over the vocabulary).
        self.action_space = spaces.Discrete(self.vocab_size)
        # The observation is the entire program (as a sequence of token IDs) padded to a fixed length.
        self.observation_space = spaces.Box(low=0, high=self.vocab_size - 1, shape=(self.max_length,), dtype=np.int32)
        
        # Load matrices from file. We expect at least 3000 randomly generated 4x4 matrices.
        self.memory = self.load_matrices(self.matrix_file)
        # For simulation, we use the first three matrices: A, B, and C.
        self.A = self.memory[0]
        self.B = self.memory[1]
        self.C = self.memory[2]
        self.ground_truth = self.compute_ground_truth()
        
        self.reset()
    
    def load_matrices(self, filename):
        """
        Load 3000 4×4 matrices from a file.
        If the file does not exist, generate random matrices and save them.
        Each line should contain 16 whitespace-separated float values.
        """
        matrices = []
        if not os.path.exists(filename):
            matrices = np.random.randn(3000, 4, 4).astype(np.float32)
            np.savetxt(filename, matrices.reshape(3000, 16))
        else:
            with open(filename, "r") as f:
                for line in f:
                    values = list(map(float, line.strip().split()))
                    if len(values) == 16:
                        matrices.append(np.array(values).reshape((4, 4)).astype(np.float32))
                    if len(matrices) >= 3000:
                        break
            matrices = np.array(matrices)
        return matrices
    
    def compute_ground_truth(self):
        """
        Compute the ground truth for the MAC operation:
          result = C; for each iteration: result = A * B + result.
        """
        result = self.C.copy()
        for _ in range(self.required_mac_ops):
            result = np.matmul(self.A, self.B) + result
        return result
    
    def reset(self):
        self.program_tokens = []  # List of token IDs generated so far.
        self.start_time = time.time()  # Start timing when the program begins.
        return self.get_observation()
    
    def step(self, action):
        # Append the chosen token (if the program hasn't exceeded max_length).
        if len(self.program_tokens) < self.max_length:
            self.program_tokens.append(action)
        else:
            return self.get_observation(), -1000, True, {"error": "Exceeded max program length"}
        
        token = self.id_to_token[action]
        if token == "END":
            # When END is emitted, parse and simulate the program.
            valid, cost, error, output = self.parse_and_simulate_program()
            self.end_time = time.time()
            exec_time = self.end_time - self.start_time  # Wall-clock time since reset() (token generation), not GPU execution time.
            # If the program is valid and output matches ground truth, reward is negative cost.
            if valid and output is not None and np.allclose(output, self.ground_truth, atol=1e-4):
                reward = -(cost + exec_time)
            else:
                reward = -1000  # Heavy penalty if invalid.
            return self.get_observation(), reward, True, {"exec_time": exec_time, "error": error}
        else:
            return self.get_observation(), 0, False, {}
    
    def get_observation(self):
        """
        Returns the current program as a fixed-length sequence of token IDs (padded with "PAD").
        """
        obs = self.program_tokens.copy()
        pad_id = self.token_to_id["PAD"]
        while len(obs) < self.max_length:
            obs.append(pad_id)
        return np.array(obs, dtype=np.int32)
    
    def parse_and_simulate_program(self):
        """
        Parse the token sequence into instructions (splitting on ";" tokens) and simulate execution.
        The END token is treated as a terminator and is ignored in the simulation.
        Allowed instructions include:
          - LDG: Load a matrix from memory into a register.
          - MUL: Multiply matrices from two registers, writing the result into a destination register.
          - FADD: Add matrices from two registers.
          - STG: Store a register’s contents into the destination memory D.
        """
        # Convert token IDs back to tokens and remove "PAD" tokens.
        tokens = [self.id_to_token[idx] for idx in self.program_tokens if self.id_to_token[idx] != "PAD"]
        # Remove any standalone END tokens from the token stream.
        tokens = [t for t in tokens if t != "END"]
        # Split instructions by the ";" separator.
        instructions = []
        current_instr = []
        for t in tokens:
            if t == ";":
                if current_instr:
                    instructions.append(current_instr)
                    current_instr = []
            else:
                current_instr.append(t)
        if current_instr:
            instructions.append(current_instr)
        
        registers = {}
        total_cost = 0.0
        mac_count = 0
        error_message = ""
        output_matrix = None
        
        cost_map = {"LDG": 1, "MUL": 2, "FADD": 2, "STG": 1}
        
        for instr in instructions:
            if not instr:
                continue
            opcode = instr[0]
            # Skip any END tokens if they accidentally appear in an instruction.
            if opcode == "END":
                continue
            if opcode == "LDG":
                if len(instr) != 6:
                    error_message = f"LDG instruction malformed: {instr}"
                    return False, total_cost, error_message, None
                reg = instr[1]
                if instr[2] != "mem" or instr[3] != "[" or instr[5] != "]":
                    error_message = f"LDG syntax error: {instr}"
                    return False, total_cost, error_message, None
                mem_index = instr[4]
                if mem_index not in ["0", "1", "2"]:
                    error_message = f"LDG invalid memory index: {mem_index}"
                    return False, total_cost, error_message, None
                if mem_index == "0":
                    registers[reg] = self.A.copy()
                elif mem_index == "1":
                    registers[reg] = self.B.copy()
                elif mem_index == "2":
                    registers[reg] = self.C.copy()
                total_cost += cost_map["LDG"]
            elif opcode == "MUL":
                if len(instr) != 6 or instr[2] != "," or instr[4] != "->":
                    error_message = f"MUL instruction malformed: {instr}"
                    return False, total_cost, error_message, None
                reg_src1 = instr[1]
                reg_src2 = instr[3]
                reg_dest = instr[5]
                if reg_src1 not in registers or reg_src2 not in registers:
                    error_message = f"MUL uses undefined registers: {instr}"
                    return False, total_cost, error_message, None
                try:
                    registers[reg_dest] = np.matmul(registers[reg_src1], registers[reg_src2])
                except Exception as e:
                    error_message = f"MUL failed: {e}"
                    return False, total_cost, error_message, None
                total_cost += cost_map["MUL"]
            elif opcode == "FADD":
                if len(instr) != 6 or instr[2] != "," or instr[4] != "->":
                    error_message = f"FADD instruction malformed: {instr}"
                    return False, total_cost, error_message, None
                reg_src1 = instr[1]
                reg_src2 = instr[3]
                reg_dest = instr[5]
                if reg_src1 not in registers or reg_src2 not in registers:
                    error_message = f"FADD uses undefined registers: {instr}"
                    return False, total_cost, error_message, None
                try:
                    registers[reg_dest] = registers[reg_src1] + registers[reg_src2]
                except Exception as e:
                    error_message = f"FADD failed: {e}"
                    return False, total_cost, error_message, None
                total_cost += cost_map["FADD"]
                mac_count += 1
            elif opcode == "STG":
                if len(instr) != 4:
                    error_message = f"STG instruction malformed: {instr}"
                    return False, total_cost, error_message, None
                reg_src = instr[1]
                if instr[2] != "->" or instr[3] != "D":
                    error_message = f"STG syntax error: {instr}"
                    return False, total_cost, error_message, None
                if reg_src not in registers:
                    error_message = f"STG uses undefined register: {instr}"
                    return False, total_cost, error_message, None
                output_matrix = registers[reg_src].copy()
                total_cost += cost_map["STG"]
            else:
                error_message = f"Unknown opcode: {opcode}"
                return False, total_cost, error_message, None
        
        if mac_count < self.required_mac_ops:
            error_message = f"Only {mac_count} MAC operations performed; expected {self.required_mac_ops}"
            return False, total_cost, error_message, None
        
        if output_matrix is None:
            error_message = "No STG instruction executed"
            return False, total_cost, error_message, None
        
        return True, total_cost, "Program valid", output_matrix
    
    def render(self, mode="human"):
        prog_tokens = [self.id_to_token[idx] for idx in self.program_tokens]
        # Exclude PAD tokens for readability.
        prog_str = " ".join([tok for tok in prog_tokens if tok != "PAD"])
        print("Current Program:")
        print(prog_str)

# -------------------------------
# Example usage:
if __name__ == "__main__":
    # For demonstration, we use a small number of MAC operations (e.g., 3) to keep the program short.
    env = FlexibleSassEnvDynamicRegisters(required_mac_ops=3, max_length=500, matrix_file="matrices.txt")
    obs = env.reset()
    done = False
    total_reward = 0
    
    # Construct a valid program manually (the RL agent would learn to generate this).
    program_tokens = []
    
    def add_instr(instr_tokens):
        # Convert tokens to IDs and append a ";" separator.
        for t in instr_tokens:
            program_tokens.append(env.token_to_id[t])
        program_tokens.append(env.token_to_id[";"])
    
    # Sample program:
    # Load operations (using registers from the dynamic vocabulary):
    add_instr(["LDG", "R0", "mem", "[", "0", "]"])  # Load matrix A into R0
    add_instr(["LDG", "R1", "mem", "[", "1", "]"])  # Load matrix B into R1
    add_instr(["LDG", "R2", "mem", "[", "2", "]"])  # Load matrix C into R2
    
    # MAC operations: each MAC consists of a MUL followed by an FADD.
    for _ in range(env.required_mac_ops):
        add_instr(["MUL", "R0", ",", "R1", "->", "R3"])  # Multiply A and B, store in R3
        add_instr(["FADD", "R2", ",", "R3", "->", "R2"])  # Add accumulator in R2 with R3
    
    # Store final result.
    add_instr(["STG", "R2", "->", "D"])
    # End program.
    program_tokens.append(env.token_to_id["END"])
    program_tokens.append(env.token_to_id[";"])
    
    # Feed tokens into the environment one by one.
    for token_id in program_tokens:
        obs, reward, done, info = env.step(token_id)
        if done:
            break
    env.render()
    print("Reward:", reward)
    print("Info:", info)


Max registers per block on this GPU: 65536
Current Program:
LDG R0 mem [ 0 ] ; LDG R1 mem [ 1 ] ; LDG R2 mem [ 2 ] ; MUL R0 , R1 -> R3 ; FADD R2 , R3 -> R2 ; MUL R0 , R1 -> R3 ; FADD R2 , R3 -> R2 ; MUL R0 , R1 -> R3 ; FADD R2 , R3 -> R2 ; STG R2 -> D ; END
Reward: -16.003533840179443
Info: {'exec_time': 0.0035338401794433594, 'error': 'Program valid'}
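The printed reward can be checked by hand. With `required_mac_ops=3` the program has 3 `LDG` (cost 1 each), 3 `MUL` and 3 `FADD` (cost 2 each), and 1 `STG` (cost 1), so the cost is 16 and the reward is $-(16 + \text{exec\_time})$. In general the cost is $4N + 4$ for $N$ MAC operations, i.e. 4004 for the default `required_mac_ops=1000`. With the 65536 registers reported above, the action space has $23 + 65536 = 65559$ tokens. The helper names below are hypothetical.

```python
# Instruction costs copied from cost_map in parse_and_simulate_program.
COST_MAP = {"LDG": 1, "MUL": 2, "FADD": 2, "STG": 1}

# Fixed (non-register) tokens of the environment's vocabulary.
FIXED_TOKENS = ["PAD", "LDG", "MUL", "FADD", "STG", "D", "mem", "[", "]", ",",
                "->", "END", *[str(d) for d in range(10)], ";"]


def static_cost(num_macs):
    """Cost of the hand-written program: 3 loads, num_macs x (MUL + FADD), 1 store."""
    opcodes = ["LDG"] * 3 + ["MUL", "FADD"] * num_macs + ["STG"]
    return sum(COST_MAP[op] for op in opcodes)


def expected_reward(num_macs, exec_time):
    """reward = -(cost + exec_time) for a valid program, as in step()."""
    return -(static_cost(num_macs) + exec_time)


print(static_cost(3))                             # 16
print(expected_reward(3, 0.0035338401794433594))  # -16.003533840179443
print(static_cost(1000))                          # 4004
print(len(FIXED_TOKENS) + 65536)                  # 65559
```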


In [2]:
import gym
import numpy as np
import os
import time
from gym import spaces
import subprocess
import pycuda.driver as cuda
import pycuda.autoinit  # auto-initialize CUDA driver
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

###############################################################################
# 1) DYNAMIC GRAMMAR CREATION (Example with permissive multi-load)
###############################################################################

def make_grammar(num_registers=100, num_mats_each=1000):
    """
    Build a grammar with:
      - Registers: R0..R(num_registers-1)
      - MemoryRefs: [<A0_addr>]..[<A{num_mats_each-1}_addr>],
                     [<B0_addr>]..[<B{num_mats_each-1}_addr>],
                     [<C0_addr>]..[<C{num_mats_each-1}_addr>]
    Grammar:
      Program         -> InstructionList END
      InstructionList -> InstructionList Instruction | Instruction
      Instruction     -> Load | MAC | Store
      Load            -> "LDG" Register "," MemoryRef ";"
      MAC             -> "fma.rn.f32" Register "," Register "," Register ";"
      Store           -> "STG" Register "," MemoryRef ";"
    "END" is a terminal. It is deliberately not a grammar key: every key is
    treated as a nonterminal by find_leftmost_nonterminal, so a rule
    END -> "END" would never finish expanding, and build_vocab would leave
    "END" out of the vocabulary.
    """
    grammar = {}
    grammar["Program"] = [["InstructionList", "END"]]
    grammar["InstructionList"] = [["InstructionList", "Instruction"], ["Instruction"]]
    grammar["Instruction"] = [["Load"], ["MAC"], ["Store"]]
    grammar["Load"] = [["LDG", "Register", ",", "MemoryRef", ";"]]
    grammar["MAC"] = [["fma.rn.f32", "Register", ",", "Register", ",", "Register", ";"]]
    grammar["Store"] = [["STG", "Register", ",", "MemoryRef", ";"]]

    # Build Register productions: R0 .. R{num_registers-1}
    reg_prods = []
    for r in range(num_registers):
        reg_prods.append([f"R{r}"])
    grammar["Register"] = reg_prods

    # Build MemoryRef productions: for A, B, and C placeholders.
    mem_prods = []
    for i in range(num_mats_each):
        mem_prods.append([f"[<A{i}_addr>]"])
    for i in range(num_mats_each):
        mem_prods.append([f"[<B{i}_addr>]"])
    for i in range(num_mats_each):
        mem_prods.append([f"[<C{i}_addr>]"])
    grammar["MemoryRef"] = mem_prods

    return grammar

###############################################################################
# 2) PARSING UTILITIES
###############################################################################

def find_leftmost_nonterminal(symbols, grammar):
    for i, sym in enumerate(symbols):
        if sym in grammar:
            return i
    return -1

def expansions_of(nonterminal, grammar):
    return grammar[nonterminal]

###############################################################################
# 3) Build Vocabulary from Grammar Terminals
###############################################################################

def build_vocab(grammar):
    """Return sorted list of all terminal tokens in the grammar."""
    terminals = set()
    for key, expansions in grammar.items():
        for expansion in expansions:
            for token in expansion:
                if token not in grammar:  # if not a nonterminal, it's a terminal
                    terminals.add(token)
    vocab = sorted(list(terminals))
    return vocab

###############################################################################
# 4) SASS => PTX, Compilation, and GPU Run (Simplified)
###############################################################################

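def sass_to_ptx(sass_str):
    """
    Convert a SASS-like program into a PTX kernel.
    - Replaces placeholders: e.g., <A0_addr> => A0_ptr.
    - Converts opcodes: LDG -> ld.global.f32, STG -> st.global.f32,
      and reorders store operands into PTX's address-first form.
    - Skips blank lines and the END marker, which have no PTX equivalent.
    - Wraps instructions in a PTX kernel header/footer.
    This is a simplified translation: it does not emit .reg declarations for
    R0..Rn or ld.param loads of the kernel arguments, so ptxas will likely
    reject the result until those are added.
    """
    import re
    opcode_map = {"LDG": "ld.global.f32", "STG": "st.global.f32"}
    lines = sass_str.strip().split("\n")
    converted = []
    for line in lines:
        l = line.strip()
        # Blank lines and END are not PTX instructions
        if not l or l == "END":
            continue
        # Replace placeholders using regex
        l = re.sub(r"<(A|B|C)(\d+)_addr>", lambda m: f"{m.group(1)}{m.group(2)}_ptr", l)
        # PTX stores take the address first: st.global.f32 [addr], reg;
        if l.startswith("STG"):
            ops = [p.strip() for p in l[len("STG"):].rstrip(";").split(",")]
            l = "STG " + ", ".join(reversed(ops))
        # Replace opcodes if found at start
        for k, v in opcode_map.items():
            if l.startswith(k):
                l = v + l[len(k):]
        if not l.endswith(";"):
            l += ";"
        converted.append("    " + l)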
    header = """.version 7.0
.target sm_75
.address_size 64

.visible .entry matrix_mmac(
    .param .u64 A0_ptr, .param .u64 B0_ptr, .param .u64 C0_ptr, .param .u64 D0_ptr
)
{
"""
    footer = """
    ret;
}
"""
    return header + "\n".join(converted) + footer

def compile_ptx(ptx_str, ptx_file="agent_kernel.ptx", cubin_file="agent_kernel.cubin"):
    with open(ptx_file, "w") as f:
        f.write(ptx_str)
    cmd = ["nvcc", "-arch=sm_75", "-cubin", ptx_file, "-o", cubin_file]
    r = subprocess.run(cmd, capture_output=True)
    if r.returncode != 0:
        return False, r.stderr.decode()
    return True, ""

def run_cubin(cubin_file, A0=None, B0=None, C0=None):
    """
    Load the compiled cubin and run the kernel on a single 4x4 problem.
    A0, B0, C0 are 4x4 float32 inputs; the expected output D0 is A0 @ B0 + C0.
    If they are omitted, random inputs are generated, which will not match the
    environment's ground truth, so pass the environment's matrices
    (e.g. A_mats[0], B_mats[0], C_mats[0]) when checking correctness.
    Returns (D0_result, success_flag, exec_time_seconds).
    """
    try:
        mod = cuda.module_from_file(cubin_file)
        kernel_func = mod.get_function("matrix_mmac")
    except Exception:
        return None, False, 0.0

    # Fall back to random 4x4 inputs only when the caller supplies none.
    A0 = np.random.rand(4, 4) if A0 is None else A0
    B0 = np.random.rand(4, 4) if B0 is None else B0
    C0 = np.random.rand(4, 4) if C0 is None else C0
    A0, B0, C0 = (np.ascontiguousarray(m, dtype=np.float32) for m in (A0, B0, C0))
    D0 = np.zeros((4, 4), dtype=np.float32)

    A0_flat = A0.flatten()
    B0_flat = B0.flatten()
    C0_flat = C0.flatten()
    D0_flat = D0.flatten()

    A0_gpu = cuda.mem_alloc(A0_flat.nbytes)
    B0_gpu = cuda.mem_alloc(B0_flat.nbytes)
    C0_gpu = cuda.mem_alloc(C0_flat.nbytes)
    D0_gpu = cuda.mem_alloc(D0_flat.nbytes)

    cuda.memcpy_htod(A0_gpu, A0_flat)
    cuda.memcpy_htod(B0_gpu, B0_flat)
    cuda.memcpy_htod(C0_gpu, C0_flat)

    start = time.time()
    kernel_func(A0_gpu, B0_gpu, C0_gpu, D0_gpu, block=(16,1,1), grid=(1,1))
    cuda.Context.synchronize()
    exec_time = time.time() - start

    cuda.memcpy_dtoh(D0_flat, D0_gpu)
    D0_result = D0_flat.reshape((4,4))
    return D0_result, True, exec_time

###############################################################################
# 5) LOADING 3000 (4x4) MATRICES (1000 each for A, B, C)
###############################################################################

def load_big_matrices(n=1000, fname="big_mats.txt"):
    """
    Load 3*n 4x4 matrices from file, generating and saving them if the file is missing.
    Returns A_mats, B_mats, C_mats, each of shape (n, 4, 4), and the ground
    truth for the first triple, A_mats[0] @ B_mats[0] + C_mats[0].
    """
    total = 3 * n
    if not os.path.exists(fname):
        mats = np.random.rand(total, 4,4).astype(np.float32)
        with open(fname, "w") as f:
            for mat in mats:
                f.write(" ".join(map(str, mat.flatten())) + "\n")
    else:
        mats = []
        with open(fname, "r") as f:
            for line in f:
                vals = list(map(float, line.strip().split()))
                if len(vals) == 16:
                    mats.append(np.array(vals).reshape((4,4)).astype(np.float32))
                if len(mats) >= total:
                    break
        mats = np.array(mats)
    A_mats = mats[0:n]
    B_mats = mats[n:2*n]
    C_mats = mats[2*n:3*n]
    gt = np.matmul(A_mats[0], B_mats[0]) + C_mats[0]
    return A_mats, B_mats, C_mats, gt

###############################################################################
# 6) GRAMMAR-BASED ENVIRONMENT WITH SEMANTIC CHECKS & PARTIAL REWARDS
###############################################################################

class GrammarMultiLoadEnv(gym.Env):
    """
    Grammar-based environment where the agent selects grammar production rules.
    - The grammar is permissive: loads, MAC, and stores may appear in any order.
    - The observation is a fixed-length vector of token indices representing the current parse.
    - Semantic checks are done on each completed instruction (partial reward).
    - At termination, the SASS program is converted to PTX, compiled, run on the GPU,
      and a final reward is assigned based on correctness and runtime.
    """
    def __init__(self, num_registers=100, n_mats_each=1000, max_expansions=30, max_obs_len=500, bigmat_file="big_mats.txt"):
        super().__init__()
        self.grammar = make_grammar(num_registers, n_mats_each)
        self.symbols = ["Program"]  # initial parse state
        self.max_expansions = max_expansions
        self.n_steps = 0
        self.done_flag = False

        # Build action space: flatten all (nonterminal, expansion_index) pairs.
        self.actions_list = []
        self.nt_index = {}
        self._build_action_space()

        # Build a vocabulary of terminal tokens from the grammar.
        self.vocab = build_vocab(self.grammar)
        self.token_to_idx = {token: i for i, token in enumerate(self.vocab)}
        self.idx_to_token = {i: token for i, token in enumerate(self.vocab)}

        # Define observation space as a fixed-length vector of token indices.
        self.max_obs_len = max_obs_len
        self.observation_space = spaces.Box(low=0, high=len(self.vocab)-1, shape=(self.max_obs_len,), dtype=np.int32)
        self.action_space = spaces.Discrete(len(self.actions_list))

        # Load big matrices and ground truth.
        self.A_mats, self.B_mats, self.C_mats, self.ground_truth = load_big_matrices(n_mats_each, bigmat_file)

        # For semantic checks, track register contents (which registers are loaded).
        self.reg_contents = {}
        for r in range(num_registers):
            self.reg_contents[f"R{r}"] = False

        self.reset()

    def _build_action_space(self):
        idx_count = 0
        for nt, expansions in self.grammar.items():
            start_idx = idx_count
            for i, _ in enumerate(expansions):
                self.actions_list.append((nt, i))
                idx_count += 1
            self.nt_index[nt] = (start_idx, idx_count)

    def _get_obs(self):
        # Convert the current parse (self.symbols) into a sequence of token indices.
        # If a symbol is nonterminal, assign a special index (we use 0).
        obs_tokens = []
        for sym in self.symbols:
            if sym in self.grammar:
                # nonterminal: assign 0 (or you can define a special token)
                obs_tokens.append(0)
            else:
                obs_tokens.append(self.token_to_idx.get(sym, 0))
        # Pad to max_obs_len.
        if len(obs_tokens) < self.max_obs_len:
            obs_tokens.extend([0] * (self.max_obs_len - len(obs_tokens)))
        else:
            obs_tokens = obs_tokens[:self.max_obs_len]
        return np.array(obs_tokens, dtype=np.int32)

    def reset(self):
        self.symbols = ["Program"]
        self.n_steps = 0
        self.done_flag = False
        for r in self.reg_contents:
            self.reg_contents[r] = False
        return self._get_obs()

    def possible_actions_for(self, nt):
        if nt not in self.nt_index:
            return []
        start, end = self.nt_index[nt]
        return list(range(start, end))

    def step(self, action):
        if self.done_flag:
            return self._get_obs(), 0.0, True, {}
        # Find leftmost nonterminal
        idx = find_leftmost_nonterminal(self.symbols, self.grammar)
        if idx < 0:
            return self._finalize_parse()
        current_nt = self.symbols[idx]
        valid_actions = self.possible_actions_for(current_nt)
        if action not in valid_actions:
            self.done_flag = True
            return self._get_obs(), -100.0, True, {"error": "InvalidExpansion"}
        # Apply chosen expansion.
        nt, exp_idx = self.actions_list[action]
        expansion = self.grammar[nt][exp_idx]
        self.symbols = self.symbols[:idx] + expansion + self.symbols[idx+1:]
        self.n_steps += 1
        # Partial semantic reward for instructions completed by this expansion.
        partial_r = self._semantic_check_last_instruction(idx)
        # Finalize exactly once: when the parse is complete or the budget is spent.
        if (find_leftmost_nonterminal(self.symbols, self.grammar) < 0
                or self.n_steps >= self.max_expansions):
            obs, final_r, done, info = self._finalize_parse()
            return obs, partial_r + final_r, done, info
        return self._get_obs(), partial_r, False, {}

    def _semantic_check_last_instruction(self, start):
        """
        Score the instructions completed by the expansion at position `start`.
        Under leftmost derivation every symbol before the leftmost nonterminal
        is final, so an expansion can only finalize symbols[start:new_leftmost].
        An instruction is scored once, when its closing ";" enters that range:
          - Load:  mark the target register as loaded (+0.1).
          - MAC:   +0.2 if every source register is loaded (the destination is
                   then marked as loaded), else -0.2.
          - Store: +0.1 if the stored register is loaded.
        """
        end = find_leftmost_nonterminal(self.symbols, self.grammar)
        if end < 0:
            end = len(self.symbols)
        reward = 0.0
        for pos in range(start, end):
            if self.symbols[pos] != ";":
                continue
            # Walk back to the previous ";" to collect this instruction's tokens.
            begin = pos - 1
            while begin >= 0 and self.symbols[begin] != ";":
                begin -= 1
            # Drop commas: parts = [opcode, operand1, operand2, ...]
            parts = [p for p in self.symbols[begin + 1:pos] if p != ","]
            if not parts:
                continue
            op, operands = parts[0], parts[1:]
            if op == "LDG" and operands:
                if operands[0] in self.reg_contents:
                    self.reg_contents[operands[0]] = True
                    reward += 0.1
            elif op == "fma.rn.f32" and len(operands) >= 2:
                dest, srcs = operands[0], operands[1:]
                if all(self.reg_contents.get(s, False) for s in srcs):
                    if dest in self.reg_contents:
                        self.reg_contents[dest] = True
                    reward += 0.2
                else:
                    reward -= 0.2
            elif op == "STG" and operands:
                if self.reg_contents.get(operands[0], False):
                    reward += 0.1
        return reward

    def _finalize_parse(self):
        self.done_flag = True
        if find_leftmost_nonterminal(self.symbols, self.grammar) >= 0:
            return self._get_obs(), -500.0, True, {"error": "IncompleteParse"}
        # Build final SASS code (join terminals that end with ';' or are 'END')
        lines = []
        instr = []
        for sym in self.symbols:
            instr.append(sym)
            if sym.endswith(";") or sym == "END":
                lines.append(" ".join(instr))
                instr = []
        sass_code = "\n".join(lines)
        ptx_str = sass_to_ptx(sass_code)
        ok, err = compile_ptx(ptx_str)  # writes agent_kernel.ptx and agent_kernel.cubin
        if not ok:
            return self._get_obs(), -500.0, True, {"error": "CompileFail", "details": err}
        t0 = time.time()
        Dmat, kernel_ok, kernel_time = run_cubin("agent_kernel.cubin")
        total_time = time.time() - t0
        if (not kernel_ok) or (Dmat is None):
            return self._get_obs(), -500.0, True, {"error": "KernelFail"}
        correct = np.allclose(Dmat, self.ground_truth, atol=1e-3)
        if not correct:
            return self._get_obs(), -500.0, True, {"error": "IncorrectResult"}
        final_reward = -total_time
        return self._get_obs(), final_reward, True, {
            "exec_time": total_time,
            "kernel_time": kernel_time,
            "correct": True
        }

###############################################################################
# 7) TRAINING LOOP WITH PPO
###############################################################################

if __name__ == "__main__":
    def make_env():
        return GrammarMultiLoadEnv(num_registers=100, n_mats_each=1000, max_expansions=30, max_obs_len=500, bigmat_file="big_mats.txt")
    env = DummyVecEnv([make_env])
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=100000)
    model.save("ppo_model")  # writes ppo_model.zip for reuse in later cells
    obs = env.reset()
    for _ in range(3):
        done = [False]
        total_reward = 0.0
        while not done[0]:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, info = env.step(action)
            total_reward += reward[0]
        print("Episode finished. Total Reward:", total_reward, "Info:", info[0])
        obs = env.reset()




Using cuda device




-----------------------------
| time/              |      |
|    fps             | 704  |
|    iterations      | 1    |
|    time_elapsed    | 2    |
|    total_timesteps | 2048 |
-----------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 540           |
|    iterations           | 2             |
|    time_elapsed         | 7             |
|    total_timesteps      | 4096          |
| train/                  |               |
|    approx_kl            | 6.7346264e-08 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -8.04         |
|    explained_variance   | 0             |
|    learning_rate        | 0.0003        |
|    loss                 | 3.92e+03      |
|    n_updates            | 10            |
|    policy_gradient_loss | -4.43e-05     |
|    value_loss           | 9.2e+03       |
------------------------------------------
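The environment flattens every (nonterminal, expansion index) pair into one discrete action and, at each step, accepts only the actions that belong to the leftmost nonterminal. The sketch below reproduces that mechanism on a toy grammar so it can be inspected without a GPU; `flatten_actions`, `leftmost_nonterminal`, `valid_action_mask`, `derive`, and the `toy` grammar are illustrative names, not part of the notebook code.

```python
def flatten_actions(grammar):
    """Give each (nonterminal, expansion_index) pair one integer action id and
    record the contiguous id range owned by each nonterminal."""
    actions, ranges = [], {}
    for nt, expansions in grammar.items():
        start = len(actions)
        actions.extend((nt, i) for i in range(len(expansions)))
        ranges[nt] = (start, len(actions))
    return actions, ranges


def leftmost_nonterminal(symbols, grammar):
    """Index of the first symbol that is a grammar key, or -1 if none."""
    for i, sym in enumerate(symbols):
        if sym in grammar:
            return i
    return -1


def valid_action_mask(symbols, grammar, ranges, n_actions):
    """True exactly for the actions that expand the leftmost nonterminal."""
    mask = [False] * n_actions
    idx = leftmost_nonterminal(symbols, grammar)
    if idx >= 0:
        lo, hi = ranges[symbols[idx]]
        mask[lo:hi] = [True] * (hi - lo)
    return mask


def derive(grammar, action_ids, start="Program"):
    """Apply action ids in order, always rewriting the leftmost nonterminal.
    Raises ValueError if an action does not belong to that nonterminal."""
    actions, ranges = flatten_actions(grammar)
    symbols = [start]
    for a in action_ids:
        idx = leftmost_nonterminal(symbols, grammar)
        if idx < 0:
            break  # derivation already complete
        lo, hi = ranges[symbols[idx]]
        if not lo <= a < hi:
            raise ValueError(f"action {a} cannot expand {symbols[idx]}")
        nt, k = actions[a]
        symbols = symbols[:idx] + grammar[nt][k] + symbols[idx + 1:]
    return symbols


# Toy grammar in the style of make_grammar(); no key reuses a terminal spelling.
toy = {
    "Program": [["Instr", "End"]],
    "Instr": [["LDG", "Reg", ";"], ["STG", "Reg", ";"]],
    "Reg": [["R0"], ["R1"]],
    "End": [["END"]],
}
actions, ranges = flatten_actions(toy)
print(ranges)  # {'Program': (0, 1), 'Instr': (1, 3), 'Reg': (3, 5), 'End': (5, 6)}
print(valid_action_mask(["Instr", "End"], toy, ranges, len(actions)))
print(derive(toy, [0, 1, 4, 5]))  # ['LDG', 'R1', ';', 'END']
```

A mask like this could, for example, be exposed through an `action_masks()` method for `MaskablePPO` from sb3-contrib, so the policy would not sample the bulk of the roughly 3,100 actions (mostly `MemoryRef` choices) that are invalid at any given step; that integration is only a suggestion and is not part of the notebook. The toy grammar also keeps nonterminal keys distinct from terminal spellings: a key spelled like a terminal (e.g. `"END"` expanding to `["END"]`) would be found again by `leftmost_nonterminal` after every expansion, so the derivation would never finish.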

In [3]:
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

# --- Assumes the following are defined in the current session (e.g. by the previous cells):
# - GrammarMultiLoadEnv: the grammar-based environment
# - sass_to_ptx(sass_str): converts a SASS-like program string into PTX code
# - model: the PPO agent trained above (otherwise a saved checkpoint is loaded)

# Create the environment with the same observation length used in training;
# max_expansions may differ because it does not affect the observation shape.
def make_env():
    # 100 registers, 1000 matrices per type, up to 50 expansions, 500-token observations.
    return GrammarMultiLoadEnv(num_registers=100, n_mats_each=1000, max_expansions=50, max_obs_len=500, bigmat_file="big_mats.txt")

# Use a raw (non-vectorized) environment: DummyVecEnv resets a sub-environment
# as soon as its episode ends, so env.envs[0].symbols read after the loop would be
# the fresh ["Program"] state rather than the generated program.
env = make_env()

# Reuse the trained model if it exists; otherwise load a checkpoint
# (assumes it was saved earlier as "ppo_model.zip").
try:
    model
except NameError:
    model = PPO.load("ppo_model.zip")

# Run one generation episode with deterministic actions.
obs = env.reset()
done = False
while not done:
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(int(action))

# The raw environment keeps its final parse until the next reset().
final_parse = env.symbols

# Convert the final parse into a SASS program.
# We assume that the SASS program is built by concatenating tokens into instructions,
# breaking at tokens that end with ";" or that equal "END".
sass_lines = []
current_instr = []
for token in final_parse:
    current_instr.append(token)
    if token.endswith(";") or token == "END":
        sass_lines.append(" ".join(current_instr))
        current_instr = []
sass_program = "\n".join(sass_lines)

print("Generated SASS program:")
print(sass_program)

# Convert the SASS program into a valid PTX program.
ptx_program = sass_to_ptx(sass_program)

print("\nGenerated PTX program:")
print(ptx_program)

# Optionally, write the PTX program to a file.
with open("generated_agent.ptx", "w") as f:
    f.write(ptx_program)
print("\nPTX program written to 'generated_agent.ptx'")


Generated SASS program:


Generated PTX program:
.version 7.0
.target sm_75
.address_size 64

.visible .entry matrix_mmac(
    .param .u64 A0_ptr, .param .u64 B0_ptr, .param .u64 C0_ptr, .param .u64 D0_ptr
)
{
    ;
    ret;
}


PTX program written to 'generated_agent.ptx'
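Both `_finalize_parse` and the generation cell above group the flat token list into one instruction per line before calling `sass_to_ptx`. For an empty or unfinished parse, such as the fresh `["Program"]` state, the grouping yields no lines, and a converter that does not skip blank input emits the lone `;` seen in the recorded PTX. The sketch below, with the illustrative helper names `group_tokens` and `to_ptx_body`, shows the grouping plus a per-line rewrite that skips blank lines and `END` and writes stores in PTX's address-first order. It covers only opcode and operand rewriting; register declarations and parameter loads are out of scope.

```python
import re

# Opcode mapping used by sass_to_ptx in the notebook.
OPCODES = {"LDG": "ld.global.f32", "STG": "st.global.f32"}


def group_tokens(tokens):
    """Split a flat token list into instruction strings ending at ';' or 'END'.
    Trailing tokens that never reach ';' or 'END' are dropped."""
    lines, current = [], []
    for tok in tokens:
        current.append(tok)
        if tok.endswith(";") or tok == "END":
            lines.append(" ".join(current))
            current = []
    return lines


def to_ptx_body(lines):
    """Rewrite SASS-like lines into PTX instruction lines (kernel body only)."""
    out = []
    for line in lines:
        line = line.strip()
        if not line or line == "END":
            continue  # nothing to emit in PTX
        # [<A0_addr>] -> [A0_ptr]
        line = re.sub(r"<(A|B|C)(\d+)_addr>", r"\1\2_ptr", line)
        op, _, rest = line.partition(" ")
        operands = [p.strip() for p in rest.rstrip(";").split(",")]
        if op == "STG":
            operands.reverse()  # PTX stores are written st.global.f32 [addr], reg;
        out.append(f"{OPCODES.get(op, op)} {', '.join(operands)};")
    return out


tokens = ["LDG", "R0", ",", "[<A0_addr>]", ";",
          "STG", "R0", ",", "[<C0_addr>]", ";", "END"]
lines = group_tokens(tokens)
print(lines)               # ['LDG R0 , [<A0_addr>] ;', 'STG R0 , [<C0_addr>] ;', 'END']
print(to_ptx_body(lines))  # ['ld.global.f32 R0, [A0_ptr];', 'st.global.f32 [C0_ptr], R0;']
print(to_ptx_body(group_tokens(["Program"])))  # []
```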




In [4]:
import gym
import numpy as np
from gym import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

###############################################################################
# 1) SIMPLE GRAMMAR CREATION: Only a MAC6 instruction followed by END.
###############################################################################

def make_grammar():
    """
    Defines a grammar where a valid program is:
      Program   -> Mac6Instr End
      Mac6Instr -> "MAC6" Register "," Register "," Register "," Register "," Register "," Register ";"
      End       -> "END"
      Register  -> one of R0, R1, R2, R3, R4, R5
    Nonterminal names (Mac6Instr, End) are kept distinct from the terminal
    tokens ("MAC6", "END"); otherwise the freshly emitted "MAC6" token would
    itself be the leftmost nonterminal and the derivation would never finish.
    """
    grammar = {}
    grammar["Program"] = [["Mac6Instr", "End"]]
    grammar["Mac6Instr"] = [[
        "MAC6", "Register", ",", "Register", ",",
        "Register", ",", "Register", ",", "Register", ",", "Register", ";"
    ]]
    grammar["End"] = [["END"]]
    grammar["Register"] = [[f"R{i}"] for i in range(6)]
    return grammar

###############################################################################
# 2) PARSING UTILITIES
###############################################################################

def find_leftmost_nonterminal(symbols, grammar):
    for i, sym in enumerate(symbols):
        if sym in grammar:
            return i
    return -1

###############################################################################
# 3) Build Vocabulary from Grammar Terminals
###############################################################################

def build_vocab(grammar):
    """Return sorted list of all terminal tokens in the grammar."""
    terminals = set()
    for key, expansions in grammar.items():
        for expansion in expansions:
            for token in expansion:
                if token not in grammar:  # terminal token
                    terminals.add(token)
    return sorted(list(terminals))

###############################################################################
# 4) SIMPLE GRAMMAR-BASED ENVIRONMENT
###############################################################################

class SimpleMAC6Env(gym.Env):
    """
    A simplified grammar-based environment. The only valid program is:
        MAC6 R?, R?, R?, R?, R?, R? ; END
    where each R? is one of R0..R5.
    
    The agent selects production rules to expand the leftmost nonterminal.
    A final reward is given if the program exactly matches the valid format.
    """
    def __init__(self, max_expansions=10, max_obs_len=30):
        super().__init__()
        self.grammar = make_grammar()
        self.symbols = ["Program"]  # start symbol
        self.max_expansions = max_expansions
        self.n_steps = 0
        self.done_flag = False
        
        # Build action space: flatten (nonterminal, expansion index) pairs.
        self.actions_list = []
        self.nt_index = {}
        self._build_action_space()
        
        # Build vocabulary and token-index mappings.
        self.vocab = build_vocab(self.grammar)
        self.token_to_idx = {token: i for i, token in enumerate(self.vocab)}
        self.idx_to_token = {i: token for i, token in enumerate(self.vocab)}
        
        # Observation: a fixed-length vector of token indices.
        self.max_obs_len = max_obs_len
        self.observation_space = spaces.Box(low=0, high=len(self.vocab)-1,
                                            shape=(self.max_obs_len,), dtype=np.int32)
        self.action_space = spaces.Discrete(len(self.actions_list))
        
        self.reset()

    def _build_action_space(self):
        idx_count = 0
        for nt, expansions in self.grammar.items():
            start_idx = idx_count
            for i, _ in enumerate(expansions):
                self.actions_list.append((nt, i))
                idx_count += 1
            self.nt_index[nt] = (start_idx, idx_count)

    def _get_obs(self):
        # Convert the current parse (list of symbols) into token indices.
        obs_tokens = []
        for sym in self.symbols:
            # Use a special index for nonterminals (here we use 0, which corresponds to
            # the first terminal in the sorted vocab; in practice you might reserve a unique index).
            if sym in self.grammar:
                obs_tokens.append(0)
            else:
                obs_tokens.append(self.token_to_idx.get(sym, 0))
        # Pad or trim the observation.
        if len(obs_tokens) < self.max_obs_len:
            obs_tokens.extend([0] * (self.max_obs_len - len(obs_tokens)))
        else:
            obs_tokens = obs_tokens[:self.max_obs_len]
        return np.array(obs_tokens, dtype=np.int32)

    def reset(self):
        self.symbols = ["Program"]
        self.n_steps = 0
        self.done_flag = False
        return self._get_obs()

    def possible_actions_for(self, nt):
        if nt not in self.nt_index:
            return []
        start, end = self.nt_index[nt]
        return list(range(start, end))

    def step(self, action):
        if self.done_flag:
            return self._get_obs(), 0.0, True, {}
        # Find leftmost nonterminal.
        idx = find_leftmost_nonterminal(self.symbols, self.grammar)
        if idx < 0:
            # Already complete.
            return self._finalize_parse()
        current_nt = self.symbols[idx]
        valid_actions = self.possible_actions_for(current_nt)
        if action not in valid_actions:
            self.done_flag = True
            return self._get_obs(), -100.0, True, {"error": "InvalidExpansion"}
        # Apply the expansion.
        nt, exp_idx = self.actions_list[action]
        expansion = self.grammar[nt][exp_idx]
        self.symbols = self.symbols[:idx] + expansion + self.symbols[idx+1:]
        self.n_steps += 1
        if self.n_steps >= self.max_expansions:
            return self._finalize_parse()
        # Check if the program is complete.
        if find_leftmost_nonterminal(self.symbols, self.grammar) < 0:
            return self._finalize_parse()
        return self._get_obs(), 0.0, False, {}

    def _finalize_parse(self):
        self.done_flag = True
        # The valid program is: MAC6 R?, R?, R?, R?, R?, R? ; END, i.e.
        #   ["MAC6", R, ",", R, ",", R, ",", R, ",", R, ",", R, ";", "END"]
        # 1 opcode + 6 registers + 5 commas + ";" + "END" = 14 tokens.
        expected_len = 14
        # Check that the sequence starts with "MAC6" and ends with "END".
        if self.symbols[0] != "MAC6" or self.symbols[-1] != "END":
            return self._get_obs(), -50.0, True, {"error": "InvalidProgramFormat"}
        # Check that exactly one MAC6 instruction was produced.
        if len(self.symbols) != expected_len:
            return self._get_obs(), -50.0, True, {"error": "WrongNumberOfTokens"}
        # Check comma and semicolon positions.
        if any(self.symbols[i] != "," for i in (2, 4, 6, 8, 10)) or self.symbols[12] != ";":
            return self._get_obs(), -50.0, True, {"error": "PunctuationError"}
        # The program is well formed. A runtime-based reward (e.g. -runtime)
        # could be used here; this demo simply returns +10.
        return self._get_obs(), 10.0, True, {"info": "ValidMAC6Program"}

###############################################################################
# 5) TRAINING LOOP WITH PPO
###############################################################################

if __name__ == "__main__":
    def make_env():
        return SimpleMAC6Env(max_expansions=10, max_obs_len=30)
    env = DummyVecEnv([make_env])
    model = PPO("MlpPolicy", env, verbose=1)
    # Train for a reduced number of timesteps for demonstration.
    model.learn(total_timesteps=20000)
    
    # Test a few episodes.
    for ep in range(3):
        obs = env.reset()
        done = [False]
        total_reward = 0.0
        while not done[0]:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, info = env.step(action)
            total_reward += reward[0]
        print("Episode finished. Total Reward:", total_reward, "Info:", info[0])


Using cuda device
-----------------------------
| time/              |      |
|    fps             | 868  |
|    iterations      | 1    |
|    time_elapsed    | 2    |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 585          |
|    iterations           | 2            |
|    time_elapsed         | 6            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0020162982 |
|    clip_fraction        | 4.88e-05     |
|    clip_range           | 0.2          |
|    entropy_loss         | -2.2         |
|    explained_variance   | -0.000834    |
|    learning_rate        | 0.0003       |
|    loss                 | 3.86e+03     |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.0241      |
|    value_loss           | 9.07e+03     |
-----------------------------------------
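The MAC6 target format is small enough to check exhaustively: six register slots with six choices each give 6^6 = 46,656 distinct programs. As a hypothetical alternative to the position-by-position checks in `_finalize_parse`, the sketch below validates the space-joined tokens with one regular expression and counts the programs it accepts; `is_valid_mac6` and `mac6_tokens` are illustrative names, not part of the notebook.

```python
import itertools
import re

# "MAC6", six registers R0..R5 separated by " , ", then " ; END".
MAC6_PROGRAM = re.compile(r"MAC6 R[0-5](?: , R[0-5]){5} ; END")


def is_valid_mac6(tokens):
    """True if the token list spells exactly one well-formed MAC6 program."""
    return MAC6_PROGRAM.fullmatch(" ".join(tokens)) is not None


def mac6_tokens(regs):
    """Build the token list for MAC6 with the given six register indices."""
    tokens = ["MAC6"]
    for i, r in enumerate(regs):
        if i:
            tokens.append(",")
        tokens.append(f"R{r}")
    return tokens + [";", "END"]


example = mac6_tokens((0, 1, 2, 3, 4, 5))
print(" ".join(example), len(example))  # MAC6 R0 , R1 , R2 , R3 , R4 , R5 ; END 14
print(is_valid_mac6(["MAC6", "R0", ";", "END"]))  # False

count = sum(is_valid_mac6(mac6_tokens(regs))
            for regs in itertools.product(range(6), repeat=6))
print(count)  # 46656
```

If every complete derivation produces a program in this set, a flat +10 only teaches the agent to choose valid expansions; preferring one register assignment over another would need a task-specific signal, such as the runtime-based reward used in Part 1.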