In [184]:
import os
import numpy as np
import matplotlib.pyplot as plt
from flax.traverse_util import flatten_dict
from flax.core import frozen_dict
from brax.io import model
from brax.training.acme.running_statistics import RunningStatisticsState


In [185]:
def load_policy_params(param_path):
    """
    Load a Brax checkpoint and return one parameter tree as a FrozenDict for plotting.

    Accepted layouts of the loaded object:
      * FrozenDict -> returned unchanged.
      * dict       -> frozen and returned.
      * tuple      -> scanned left to right; the FIRST FrozenDict or dict element is
                      returned (a dict is frozen first). RunningStatisticsState elements
                      are skipped with a message; any other element is skipped silently.

    For the checkpoint inspected in this notebook the tuple is
    (RunningStatisticsState, dict, dict), so element 1 is returned. That element 1 is
    the policy network (not the value network) is an assumption about Brax's save
    order -- TODO confirm.

    NOTE(review): `model.load_params` presumably unpickles the file; only load
    checkpoints from trusted sources.

    Raises:
        ValueError: the tuple contains no FrozenDict/dict element.
        TypeError: the loaded object is not a FrozenDict, dict, or tuple.
    """
    loaded = model.load_params(param_path)

    # FrozenDict is tested before dict so an already-frozen tree is passed through untouched.
    if isinstance(loaded, frozen_dict.FrozenDict):
        print("Detected FrozenDict.")
        return loaded

    if isinstance(loaded, dict):
        print("Detected plain dict, wrapping as FrozenDict.")
        return frozen_dict.freeze(loaded)

    if not isinstance(loaded, tuple):
        raise TypeError(f"Unrecognized parameter type: {type(loaded)}")

    print("Detected tuple, searching for valid weight params...")
    for idx, component in enumerate(loaded):
        if isinstance(component, frozen_dict.FrozenDict):
            print(f"Found FrozenDict at index {idx}")
            return component
        if isinstance(component, dict):
            print(f"Found dict at index {idx}, attempting wrap...")
            return frozen_dict.freeze(component)
        if isinstance(component, RunningStatisticsState):
            print(f"Skipping normalizer state at index {idx}")

    raise ValueError("No valid network parameters found for visualization")


In [186]:
def visualize_weight_heatmaps(params, output_dir='viz_params', max_layers=20, cmap='coolwarm',
                              vmin=-1.0, vmax=1.0):
    """
    Save one heatmap PNG per 2-D array found in a (Frozen)dict parameter tree.

    Leaves are flattened with '/'-joined key paths; every leaf whose `np.asarray`
    form has ndim == 2 is plotted, in iteration order, until `max_layers` plots
    have been written. Non-2-D leaves (e.g. bias vectors) are skipped.

    Args:
        params: nested dict or FrozenDict of arrays. A plain dict is frozen first.
        output_dir: directory for the PNGs; created if missing. File names are the
            key path with '/' replaced by '_', plus '.png'.
        max_layers: maximum number of heatmaps to write. Values <= 0 write nothing.
        cmap: matplotlib colormap name.
        vmin, vmax: color limits shared by every plot (default [-1, 1]) so colors are
            comparable across layers; values outside the range saturate to the end colors.

    Returns:
        List of file paths written, in the order they were saved.

    Raises:
        AssertionError: `params` has no `keys` attribute.
    """
    # Validate before touching the filesystem so bad input does not leave a stray directory.
    if isinstance(params, dict):
        print("Wrapping dict as FrozenDict...")
        params = frozen_dict.freeze(params)

    assert hasattr(params, 'keys'), f"Invalid parameter type: {type(params)}"

    os.makedirs(output_dir, exist_ok=True)

    saved_paths = []
    if max_layers <= 0:
        return saved_paths

    for name, value in flatten_dict(params, sep='/').items():
        weights = np.asarray(value)
        if weights.ndim != 2:
            continue

        filepath = os.path.join(output_dir, name.replace('/', '_') + '.png')
        fig, ax = plt.subplots(figsize=(6, 4))
        try:
            image = ax.imshow(weights, aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax)
            fig.colorbar(image, ax=ax, label='Weight Value')
            ax.set_title(f'Layer: {name}')
            # imshow puts axis 0 (rows) on y and axis 1 (columns) on x; the labels assume
            # kernels are stored (input, output) -- TODO confirm for non-Dense layers.
            ax.set_xlabel('Output Dim')
            ax.set_ylabel('Input Dim')
            fig.tight_layout()
            fig.savefig(filepath)
        finally:
            # Always release the figure, even if saving fails, to avoid accumulating open figures.
            plt.close(fig)

        print(f"Saved heatmap: {filepath}")
        saved_paths.append(filepath)
        if len(saved_paths) >= max_layers:
            break

    return saved_paths


In [187]:
def plot_summary_heatmap(params, output_path='summary_heatmap.png', cmap='coolwarm', max_layers=50,
                         scale=2.0, vmin=-1.0, vmax=1.0):
    """
    Stack every 2-D array in a parameter tree into one summary heatmap image.

    Matrices are drawn top to bottom in flatten order, each at `scale` pixels per
    entry in both directions, so their on-image sizes reflect their shapes. Each
    matrix gets its own band above it for a title of the form "<path> (rows, cols)",
    so titles never overlap the matrix above.

    Args:
        params: nested dict or FrozenDict of arrays. A plain dict is frozen first.
        output_path: PNG path to write; its parent directory is created if missing.
        cmap: matplotlib colormap name.
        max_layers: maximum number of 2-D matrices to include. Values <= 0 include none.
        scale: pixels per matrix entry at 100 dpi; also scales the title font (9 * scale pt).
        vmin, vmax: shared color limits (default [-1, 1]) for cross-layer comparability.

    Returns:
        None. Prints a message and writes nothing if no 2-D arrays are collected.
    """
    if isinstance(params, dict):
        print("Wrapping dict as FrozenDict...")
        params = frozen_dict.freeze(params)

    # Collect up to `max_layers` 2-D arrays; the bound is checked before appending so
    # max_layers <= 0 collects nothing.
    matrices = []
    for name, value in flatten_dict(params, sep='/').items():
        if len(matrices) >= max_layers:
            break
        arr = np.asarray(value)
        if arr.ndim == 2:
            matrices.append((f"{name} {arr.shape}", arr))

    if not matrices:
        print("No 2D parameters found for summary heatmap.")
        return

    dpi = 100
    inch_per_entry = scale / dpi
    title_fontsize = 9 * scale
    # Vertical room per title, in inches: ~1.5x the font size for the text itself plus
    # matplotlib's default 6 pt title pad (rcParams['axes.titlepad']).
    title_band_in = (1.5 * title_fontsize + 6.0) / 72.0

    # max(..., 1) guards against division by zero when every matrix has zero columns.
    max_width = max(max(arr.shape[1] for _, arr in matrices), 1)
    fig_width = max_width * inch_per_entry
    fig_height = sum(arr.shape[0] * inch_per_entry + title_band_in for _, arr in matrices)

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure(figsize=(fig_width, fig_height), dpi=dpi)
    try:
        top = 1.0  # figure-fraction y of the next free row, moving downward
        for label, mat in matrices:
            rows, cols = mat.shape
            top -= title_band_in / fig_height  # reserve the title band
            rel_height = rows * inch_per_entry / fig_height
            rel_width = cols / max_width       # keeps entries square across matrices
            ax = fig.add_axes([0.0, top - rel_height, rel_width, rel_height])
            ax.imshow(mat, aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax)
            ax.set_title(label, fontsize=title_fontsize)
            ax.axis('off')
            top -= rel_height

        # bbox_inches='tight' expands the canvas for titles wider than the narrow figure.
        fig.savefig(output_path, dpi=dpi, bbox_inches='tight')
    finally:
        plt.close(fig)

    print(f"Saved high-res summary heatmap: {output_path}")


In [188]:
# Inspect the top-level structure of the saved checkpoint before deciding which
# element holds the weights to plot. `model` is provided by the imports cell at the
# top of the notebook, so it is not re-imported here.
raw = model.load_params('./params_good')
print("Loaded type:", type(raw))

if isinstance(raw, tuple):
    print("Tuple length:", len(raw))
    for i, part in enumerate(raw):
        print(f" - Index {i}: type={type(part)}")


Loaded type: <class 'tuple'>
Tuple length: 3
 - Index 0: type=<class 'brax.training.acme.running_statistics.RunningStatisticsState'>
 - Index 1: type=<class 'dict'>
 - Index 2: type=<class 'dict'>


In [189]:


# Drill into element 1 of the checkpoint tuple (the first dict in the listing from
# the previous cell) and show one level of its nesting.
params = model.load_params("params_good")
index1 = params[1]

print("index 1 type:", type(index1))
print("Top-level keys:", list(index1.keys()))

for key, value in index1.items():
    print(f"\nKey: {key}")
    print("  Type:", type(value))
    if not hasattr(value, 'keys'):
        # Leaf value: cap the repr at 300 characters so large arrays don't flood the output.
        print("  Content:", str(value)[:300])
        continue
    print("  Subkeys:", list(value.keys()))


index 1 type: <class 'dict'>
Top-level keys: ['params']

Key: params
  Type: <class 'dict'>
  Subkeys: ['hidden_0', 'hidden_1', 'hidden_2', 'hidden_3', 'hidden_4']


In [190]:
# Checkpoint to visualize. NOTE(review): the structure inspection above used
# 'params_good' while this cell plots 'params_bad' -- confirm that is intended.
param_path = 'params_bad'
# Single source of truth for where this cell's plots go; the summary image is
# derived from it so the two outputs cannot drift apart.
viz_dir = "viz_policy"

params = load_policy_params(param_path)
visualize_weight_heatmaps(params, output_dir=viz_dir)
plot_summary_heatmap(params, output_path=os.path.join(viz_dir, 'summary.png'))

Detected tuple, searching for valid weight params...
Skipping normalizer state at index 0
Found dict at index 1, attempting wrap...
Saved heatmap: viz_policy/params_hidden_0_kernel.png
Saved heatmap: viz_policy/params_hidden_1_kernel.png
Saved heatmap: viz_policy/params_hidden_2_kernel.png
Saved heatmap: viz_policy/params_hidden_3_kernel.png
Saved heatmap: viz_policy/params_hidden_4_kernel.png
Saved high-res summary heatmap: viz_policy/summary.png
