In [None]:
%matplotlib widget
import re
import math
import bisect
import textwrap
import numpy as np

from pathlib import Path
import mplcursors
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, Patch

# Matches one `; key = value` pair inside a log line.
#   group 1: key  -- word characters, '/' and '.'
#   group 2: optional opening quote (' or "); \2 requires the same closing quote
#   group 3: value -- runs until the first character outside its class (e.g. the
#            next ';'), so it may carry trailing whitespace that callers strip
# findall() therefore yields (key, quote, value) tuples.
kv_pattern = re.compile(r";\s*([\w/.]+)\s*=\s*(['\"]?)([:,_\w\s/\-\".']+)\2")

In [None]:
KB = 1024

# TPU chip specifications used to draw memory limits and bank boundaries.
#   local_mem_size:  local memory size (bytes) per NPU
#   local_mem_banks: number of equally sized local-memory banks, i.e. one bank
#                    holds local_mem_size // local_mem_banks bytes
bm1684x_spec = dict(
    local_mem_size=256 * KB,
    local_mem_banks=16,
)

bm1688_spec = dict(
    local_mem_size=128 * KB,
    local_mem_banks=16,
)

In [None]:
# Regex helpers for parsing `; key = value` pairs out of log lines
def is_number_regex(s):
    """Return True when the whole of ``s`` is a decimal number literal.

    Accepted forms: an optional sign, then an integer or decimal mantissa
    (``12``, ``12.``, ``12.5``, ``.5``), then an optional exponent
    (``1e3``, ``2.5E-4``). Hexadecimal literals and surrounding whitespace
    are not accepted.
    """
    number_re = r"^[-+]?(\d+|\d+\.\d*|\.\d+)([eE][-+]?\d+)?$"
    # A Match object is always truthy and a failed match is None.
    return bool(re.fullmatch(number_re, s))

def parse_dic(line, filter=None):
    """Parse the ``; key = value`` pairs of one log line into a dict.

    Args:
        line: one line of the log file.
        filter: substrings that must ALL occur in ``line`` for it to be parsed.
            May be None (no filtering), a comma-separated string, or any
            iterable of strings (list, tuple, set, ...).

    Returns:
        None if any filter substring is missing from ``line``. Otherwise a
        dict mapping every key matched by ``kv_pattern`` to its value (the
        dict may be empty). Values are converted as follows:
          * decimal integers -> ``int``;
          * decimals / exponents such as ``1.5`` or ``1e3`` -> ``float``;
          * complete hexadecimal literals such as ``0x1F`` -> ``int``;
          * anything else -> the stripped string.
        When a key appears twice, the last occurrence wins.
    """
    if filter is None:
        required = ()
    elif isinstance(filter, str):
        required = filter.split(",")
    else:
        required = filter

    if not all(token in line for token in required):
        return None

    return {
        key: _convert_log_value(value.strip())
        for key, _quote, value in kv_pattern.findall(line)
    }


# A complete (optionally signed) hexadecimal literal, e.g. "0x1F" or "-0x10".
_HEX_LITERAL = re.compile(r"[-+]?0[xX][0-9a-fA-F]+")


def _convert_log_value(text):
    """Convert one stripped log value to int/float when it is a numeric literal."""
    if is_number_regex(text):
        # is_number_regex also accepts decimals and exponents ("1.5", "1e3"),
        # which int() rejects; fall back to float instead of raising.
        try:
            return int(text)
        except ValueError:
            return float(text)
    # Only a full hex literal is converted. A value that merely contains "0x"
    # (e.g. an op name like "conv_0x1") stays a string instead of crashing
    # inside int(..., 16).
    if _HEX_LITERAL.fullmatch(text):
        return int(text, 16)
    return text


In [None]:
# Parse the lmem-assign log file
def dump_lmem_assign_result_with_setttings(log_file):
    """Group the lmem-assign iteration results of a log file by setting.

    Every line containing both "; action = lmem_assign" and
    "; tag = iteration_result" is parsed with ``parse_dic``. The setting of a
    line is the subset of ``allocation_setting_keys`` present on it.
    Consecutive matching lines with the same setting form one "patch". A new
    patch starts whenever the setting differs from that of the previous
    matching line.

    Args:
        log_file: path of the log file to read.

    Returns:
        ``(allocation_settings, allocation_patches)``: two parallel lists.
        ``allocation_settings[i]`` is the setting dict of patch ``i``, and
        ``allocation_patches[i]`` is its list of allocation dicts, each
        restricted to ``allocation_keys``. Every setting is also printed.
    """
    allocation_keys = ["timestep_start", "timestep_end", "addr", "size", "op_type", "op_name", "lmem_type", "hold_in_lmem", "status", "allow_bank_conflict", "one_loop", "shape_secs"]
    allocation_setting_keys = ["allow_bank_conflict", "shape_secs"]
    line_filter = ["; action = lmem_assign", "; tag = iteration_result"]

    allocation_settings = []
    allocation_patches = []
    # None (rather than {}) marks "no matching line seen yet". An empty setting
    # is therefore still a valid group and is not confused with "unset".
    current_setting = None
    current_patch = []

    with open(log_file, 'r') as f:
        for line in f:
            dic = parse_dic(line, line_filter)
            if not dic:
                continue
            entry = {key: dic[key] for key in allocation_keys if key in dic}
            setting = {key: dic[key] for key in allocation_setting_keys if key in dic}
            if current_setting is not None and setting != current_setting:
                # Setting changed: close the current patch. The new patch is
                # tagged with the NEW setting right away, so a single-line
                # patch is neither merged into the next one nor dropped at
                # end-of-file.
                allocation_settings.append(current_setting)
                allocation_patches.append(current_patch)
                current_patch = []
            current_setting = setting
            current_patch.append(entry)

    if current_setting is not None:
        allocation_settings.append(current_setting)
        allocation_patches.append(current_patch)

    for settings in allocation_settings:
        print(settings)
    return allocation_settings, allocation_patches

def update_allocation(allocation_patches):
    """Stack failed allocations above the successful ones and tag each entry.

    For every patch (a list of allocation dicts as produced by
    ``dump_lmem_assign_result_with_setttings``):
      * entries with status "success" keep their logged address;
      * entries with status "failed" are placed one after another, in log
        order, starting at the highest end address (addr + size) of the
        successful entries, or at 0 if there are none;
      * entries with any other status are dropped;
      * every kept entry gets "max_timestep": the largest "timestep_end" over
        ALL entries of the patch, including dropped ones, and never below 0.

    The caller's dicts are not modified; the returned entries are shallow
    copies. This keeps ``allocation_patches`` intact and makes re-running a
    notebook cell idempotent.

    Returns:
        A list with one list per patch: the successful entries followed by the
        failed ones.
    """
    updated_patches = []
    for patch in allocation_patches:
        max_timestep = max([0] + [entry["timestep_end"] for entry in patch])
        succeeded = [dict(entry) for entry in patch if entry["status"] == "success"]
        failed = [dict(entry) for entry in patch if entry["status"] == "failed"]

        # Failed entries have no real address; put them above the used range
        # so they can be drawn without overlapping successful ones.
        next_addr = max([0] + [entry["addr"] + entry["size"] for entry in succeeded])
        for entry in failed:
            entry["addr"] = next_addr
            next_addr += entry["size"]

        kept = succeeded + failed
        for entry in kept:
            entry["max_timestep"] = max_timestep
        updated_patches.append(kept)
    return updated_patches

def find_allocation_idx(allocation_settings, settings):
    """Return the indices of every entry of ``allocation_settings`` equal to ``settings``.

    Comparison is plain ``==`` on the dicts, so keys and values must all match.
    The resulting index list is also printed.
    """
    matches = [
        position
        for position, candidate in enumerate(allocation_settings)
        if candidate == settings
    ]
    print(f"allocation idx: {matches}")
    return matches

In [None]:
# plot memory allocation
def plot_memory_allocation(allocations, allocation_setting, chip_spec=None, figsize=(12, 8), **kwargs):
    """Draw a timestep x address map of local-memory (LMEM) allocations.

    Each allocation is drawn as a rectangle spanning its live timesteps
    horizontally and ``[addr, addr + size)`` vertically. Successful
    allocations get a black outline and all others a red one. The hatch
    pattern encodes ``lmem_type``. Clicking a rectangle shows a tooltip with
    its details.

    Args:
        allocations: list of allocation dicts. Each must provide
            "timestep_start", "timestep_end", "addr", "size", "op_type",
            "op_name", "lmem_type", "hold_in_lmem", "one_loop" and "status".
            A range with timestep_start > timestep_end wraps around the end
            of the timeline.
        allocation_setting: dict of settings shown in the title.
        chip_spec: optional dict with "local_mem_size" and "local_mem_banks".
            When given, the memory limit and bank boundaries are drawn.
        figsize: size of the main figure in inches.
        **kwargs:
            timestep (int | None): if not None, only blocks live at that
                timestep are drawn, as one unit-wide column, restacked from
                address 0 in iteration order. Their logged addresses are
                ignored in this mode. ``timestep=0`` is a valid selection.
            get_lmem_per_timestep (bool): also draw a bar chart of the total
                bytes live at each timestep.

    Returns:
        The main figure, or ``(main_figure, histogram_figure)`` when
        ``get_lmem_per_timestep`` is truthy.
    """
    timestep = kwargs.get("timestep", None)
    get_lmem_per_timestep = kwargs.get("get_lmem_per_timestep", False)
    # Single-timestep mode only: next free address when restacking live blocks.
    timestep_mode_addr = 0

    plt.rcParams.update({
        'path.simplify': True,
        'path.simplify_threshold': 1.0,
        'agg.path.chunksize': 10000,
    })

    main_fig = plt.figure(figsize=figsize, dpi=100)
    ax = main_fig.gca()
    # Tooltip currently shown on THIS figure. It lives in this call's scope
    # rather than a module global, so a click in one figure never removes the
    # tooltip of another figure created by a different call.
    current_annotation = None

    all_times = [t for a in allocations for t in (a["timestep_start"], a["timestep_end"])]
    total_timesteps = max(all_times) + 1 if all_times else 1

    # Total bytes occupied at each timestep (used by the optional bar chart).
    lmem_per_timestep = [0] * total_timesteps

    colors = _get_block_colormap(len(allocations))

    max_addr = max(a["addr"] + a["size"] for a in allocations) if allocations else 0
    max_memory = bank_size = None
    if chip_spec:
        max_memory = chip_spec["local_mem_size"]
        bank_size = chip_spec["local_mem_size"] // chip_spec["local_mem_banks"]
        # Round up to whole banks so blocks beyond the limit remain visible.
        y_max = max(max_memory, math.ceil(max_addr / bank_size) * bank_size)
    else:
        # `or 1` avoids a degenerate (0, 0) y-range when nothing is allocated.
        y_max = max_addr * 1.1 or 1

    hatch_mapping = {
        "LMEM_ACTIVATION": "",
        "LMEM_WEIGHT": "/////",
        "LMEM_OPERATION": "***",
    }

    rects = []
    legend_elements = []

    for idx, allocation in enumerate(allocations):
        s = allocation["timestep_start"]
        e = allocation["timestep_end"]
        addr = allocation["addr"]
        size = allocation["size"]
        op_type = allocation["op_type"]
        op_name = allocation["op_name"]
        lmem_type = allocation["lmem_type"]
        hold_in_lmem = allocation["hold_in_lmem"]
        one_loop = allocation["one_loop"]
        status = allocation["status"]
        color = colors(idx)
        edgecolor = 'black' if status == "success" else 'red'

        # Usage accounting: a buffer held in LMEM that is not one-loop is
        # counted at every timestep. Others count only on [s, e], wrapping
        # around the end of the timeline when s > e.
        if not one_loop and hold_in_lmem:
            live_steps = range(total_timesteps)
        elif s <= e:
            live_steps = range(s, e + 1)
        else:
            live_steps = [*range(s, total_timesteps), *range(0, e + 1)]
        for t in live_steps:
            lmem_per_timestep[t] += size

        # Horizontal segments (start, duration) to draw for this block.
        if s <= e:
            time_ranges = [(s, e - s + 1)]
        else:
            time_ranges = [(s, total_timesteps - s), (0, e + 1)]
        if hold_in_lmem and not one_loop and timestep is None:
            time_ranges.append((0, total_timesteps))
        if timestep is not None:
            # Keep one unit-wide segment if the block is live at `timestep`,
            # and restack live blocks contiguously from address 0.
            is_live = any(t_start <= timestep < t_start + t_duration
                          for t_start, t_duration in time_ranges)
            time_ranges = [(timestep, 1)] if is_live else []
            if is_live:
                addr = timestep_mode_addr
                timestep_mode_addr += size

        for t_start, t_duration in time_ranges:
            rect = Rectangle(
                (t_start, addr),
                t_duration,
                size,
                facecolor=color,
                edgecolor=edgecolor,
                linewidth=0.5,
                alpha=0.7,
                picker=True,
                snap=True,
                hatch=hatch_mapping.get(lmem_type, ""),
            )

            # Extra data read back by the click handler to build the tooltip.
            rect.block_info = {
                "block_id": idx,
                "global_start": s,
                "global_end": e,
                "segment_start": t_start,
                "segment_duration": t_duration,
                "address": addr,
                "size": size,
                "color": color,
                "op_type": op_type,
                "op_name": op_name,
                "lmem_type": lmem_type,
                "hold_in_lmem": "True" if hold_in_lmem else "False"
            }
            ax.add_patch(rect)
            rects.append(rect)

        legend_elements.append(Patch(
            facecolor=color,
            label=f'Block {idx}: {size}B @ {addr}',
        ))

    cursor = mplcursors.cursor(
        rects,
        hover=False,
        highlight=False
    )

    def remove_annotation():
        """Detach the current tooltip, if any, and forget it."""
        nonlocal current_annotation
        if current_annotation is not None:
            try:
                current_annotation.remove()
            except (ValueError, AttributeError, NotImplementedError):
                pass  # already detached from the axes; nothing left to remove
            current_annotation = None

    @cursor.connect("add")
    def on_select(sel):
        nonlocal current_annotation
        # Hide mplcursors' default annotation; a custom one is drawn instead.
        sel.annotation.set_visible(False)
        remove_annotation()

        info = sel.artist.block_info

        rgb = matplotlib.colors.to_rgb(info['color'])
        # Perceived luminance picks dark or light text for readability.
        luminance = 0.299*rgb[0] + 0.587*rgb[1] + 0.114*rgb[2]
        text_color = 'black' if luminance > 0.5 else 'white'
        # Offset the tooltip towards the centre of the plot.
        xtext = 15 if info['segment_start'] < total_timesteps // 2 else -200
        ytext = 15 if info['address'] < y_max // 2 else -100

        current_annotation = ax.annotate(
            text=_build_tooltip(info),
            xy=(info['segment_start'], info['address']),
            xytext=(xtext, ytext),
            textcoords='offset points',
            bbox=dict(
                boxstyle="round,pad=0.3",
                facecolor=matplotlib.colors.to_rgba(info['color'], 0.9),
                edgecolor='black',
                linewidth=0.5
            ),
            arrowprops=dict(
                arrowstyle="->",
                connectionstyle="arc3,rad=0.3",
                color='black'
            ),
            fontsize=9,
            color=text_color
        )
        main_fig.canvas.draw_idle()

    def on_leave(event):
        # Reset the state after removal so a later leave event does not try to
        # remove the same (already detached) annotation again.
        if current_annotation is not None and current_annotation.get_visible():
            remove_annotation()
            main_fig.canvas.draw_idle()

    main_fig.canvas.mpl_connect('axes_leave_event', on_leave)

    for ts in range(total_timesteps + 1):
        ax.axvline(ts, color='gray', linestyle='--', alpha=0.3)

    if max_memory is not None:
        ax.axhline(y=max_memory,
                   color='black',
                   linestyle='--',
                   linewidth=2)

    ax.set_ylim(0, y_max)
    ax.set_ylabel("Memory Address")

    if bank_size:
        num_banks = y_max // bank_size
        ax.set_yticks([i*bank_size for i in range(num_banks+1)])

        for bank in range(1, num_banks):
            ax.axhline(y=bank*bank_size,
                       color='gray',
                       linestyle='--',
                       linewidth=1.5,
                       alpha=0.5)

        # Secondary axis labelling each bank at its vertical centre.
        ax2 = ax.twinx()
        ax2.set_ylim(ax.get_ylim())
        ax2.set_yticks([(i+0.5)*bank_size for i in range(num_banks)])
        ax2.set_yticklabels([f"Bank {i}" for i in range(num_banks)],
                            fontdict={
                                'fontsize': 8,
                                'color': 'gray',
                                'alpha': 0.7
                            })
        ax2.set_ylabel('Memory Banks')

    spec_text = f"SPEC: LmemSize(Byte)={max_memory or 'N/A'}, BankSize(Byte)={bank_size or 'N/A'}"
    setting_text = "\nSETTINGS: " + ", ".join(f"{k}={v}" for k, v in allocation_setting.items())

    ax.set_xlabel("Timestep")
    ax.set_title(f"TPU Memory Allocation \n{spec_text}{setting_text}", pad=30)
    ax.set_xticks([i + 0.5 for i in range(total_timesteps)])
    ax.set_xticklabels([str(i) for i in range(total_timesteps)])
    ax.set_xlim(0, total_timesteps)

    if max_memory:
        legend_elements.append(Patch(
            facecolor='white',
            edgecolor='black',
            linestyle='--',
            linewidth=2,
            label=f'Max Memory ({max_memory}B)'
        ))
    for lmem_type, hatch in hatch_mapping.items():
        legend_elements.append(Patch(
            facecolor='white',
            edgecolor='black',
            linestyle='-',
            hatch=hatch,
            linewidth=0.5,
            label=f'{lmem_type}'
        ))

    ax.legend(handles=legend_elements,
              bbox_to_anchor=(-0.1, 1.2, 0, 0),
              loc='upper right',
              borderaxespad=3,
              fontsize=7,
    )

    def on_resize(event):
        # Scale the legend text with the canvas width (pixels / 100 * 0.8 pt).
        # Calling ax.legend() here without handles would replace the custom
        # legend built above with an empty one, so only its text is resized.
        legend = ax.get_legend()
        if legend is None:
            return
        for text in legend.get_texts():
            text.set_fontsize(event.width / 100 * 0.8)
        event.canvas.draw_idle()

    main_fig.canvas.mpl_connect('resize_event', on_resize)

    main_fig.tight_layout()

    hist_fig = None
    if get_lmem_per_timestep:
        hist_fig = plt.figure(figsize=(figsize[0], figsize[1]//2))
        hist_ax = hist_fig.gca()

        timesteps = list(range(len(lmem_per_timestep)))
        bars = hist_ax.bar(timesteps, lmem_per_timestep, width=0.8,
                           edgecolor='black', alpha=0.7)

        for bar in bars:
            height = bar.get_height()
            if height > 0:
                hist_ax.text(bar.get_x() + bar.get_width()/2., height,
                             f'{height}B',
                             ha='center', va='bottom', fontsize=8)

        hist_ax.set_xlabel("Timestep")
        hist_ax.set_ylabel("Memory Usage (Bytes)")
        hist_ax.set_title("Memory Usage Per Timestep", pad=15)
        hist_ax.set_xticks(timesteps)
        hist_ax.grid(axis='y', linestyle='--', alpha=0.7)
        hist_fig.tight_layout()

    return (main_fig, hist_fig) if get_lmem_per_timestep else main_fig


def _get_block_colormap(n_colors):
    """Return the 'tab20' colormap resampled to ``n_colors`` (at least 1) entries.

    ``matplotlib.cm.get_cmap`` (reached as ``plt.cm.get_cmap``) was deprecated
    in Matplotlib 3.7 and removed in 3.9. The ``matplotlib.colormaps``
    registry is used when it offers ``Colormap.resampled``; older versions
    fall back to ``plt.cm.get_cmap``.
    """
    n_colors = max(1, n_colors)
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None and hasattr(registry["tab20"], "resampled"):
        return registry["tab20"].resampled(n_colors)
    return plt.cm.get_cmap("tab20", n_colors)


def _build_tooltip(info):
    """Format the tooltip text for one rectangle's ``block_info`` dict."""
    return (
        f"Block ID: {info['block_id']}\n"
        f"Timestep: {info['global_start']}→{info['global_end']}\n"
        f"Address: {info['address']}-{info['address']+info['size']}\n"
        f"Size: {info['size']}B ({info['size']/1024:.1f} KB)\n"
        f"Op Type: {info['op_type']}\n"
        f"Op Name: {textwrap.fill(info['op_name'], width=50)}\n"
        f"LMEM Type: {info['lmem_type']}\n"
        f"Hold in LMEM: {info['hold_in_lmem']}"
    )

def plot_memory_allocation_with_idx(allocation_patches, allocation_settings, allocation_idx, chip_spec=None, figsize=(12, 8), **kwargs):
    """Plot each patch whose index appears in ``allocation_idx``.

    For every selected index, the index and its setting are printed, then
    ``plot_memory_allocation`` is called. Only the ``timestep`` and
    ``get_lmem_per_timestep`` keyword options are forwarded; other keyword
    arguments are ignored. The created figures are not returned.
    """
    selected_timestep = kwargs.get("timestep", None)
    with_histogram = bool(kwargs.get("get_lmem_per_timestep", False))
    for patch_idx in allocation_idx:
        patch = allocation_patches[patch_idx]
        patch_setting = allocation_settings[patch_idx]
        print(f"allocation idx: {patch_idx}\nsettings: {patch_setting}")
        plot_memory_allocation(
            patch,
            patch_setting,
            chip_spec,
            figsize,
            timestep=selected_timestep,
            get_lmem_per_timestep=with_histogram,
        )

In [None]:
# Load the lmem-assign log. Iteration results are grouped by their setting
# (allow_bank_conflict, shape_secs); the distinct settings are printed.
# NOTE(review): the path is relative to the notebook's working directory.
log_file = "lmem_assign_result.log"
allocation_settings, allocation_patches = dump_lmem_assign_result_with_setttings(log_file)
# Place failed allocations above the successful ones and tag each entry with
# the patch's max_timestep.
allocation_patches_updated = update_allocation(allocation_patches)

In [None]:
# Pick the setting to plot by copying values from the settings printed above.
# Matching is exact dict equality, so keys and values must all agree
# (shape_secs, for example, is compared as the string "8,7,1,1,1").
# NOTE(review): allow_bank_conflict is compared with False. This only matches
# if the log value parsed to False or 0; confirm against the printed settings.
setting = {
    "allow_bank_conflict": False, # set manually
    "shape_secs": "8,7,1,1,1"     # set manually
}
allocation_idx = find_allocation_idx(allocation_settings, setting)
plot_memory_allocation_with_idx(allocation_patches_updated, allocation_settings, allocation_idx, bm1688_spec, (12, 8), get_lmem_per_timestep=True)

In [None]:
# Zoom into a single timestep: only blocks live at timestep 5 are drawn,
# restacked from address 0, for each selected setting.
plot_memory_allocation_with_idx(allocation_patches_updated, allocation_settings, allocation_idx, bm1688_spec, (12, 8), timestep=5)