In [3]:

# Sample Python code to analyze
sample_code = """
def e2e_eval(gt_dir, res_dir, ignore_blank=False):
    print('start testing...')
    iou_thresh = 0.5
    val_names = os.listdir(gt_dir)
    num_gt_chars = 0
    gt_count = 0
    dt_count = 0
    hit = 0
    ed_sum = 0
    for i, val_name in enumerate(val_names):
        with open(os.path.join(gt_dir, val_name), encoding='utf-8') as f:
            gt_lines = [o.strip() for o in f.readlines()]
        gts = []
        ignore_masks = []
        for line in gt_lines:
            parts = line.strip().split('\t')
            # ignore illegal data
            if len(parts) < 9:
                continue
            assert (len(parts) < 11)
            if len(parts) == 9:
                gts.append(parts[:8] + [''])
            else:
                gts.append(parts[:8] + [parts[-1]])
            ignore_masks.append(parts[8])
        val_path = os.path.join(res_dir, val_name)
        if not os.path.exists(val_path):
            dt_lines = []
        else:
            with open(val_path, encoding='utf-8') as f:
                dt_lines = [o.strip() for o in f.readlines()]
        dts = []
        for line in dt_lines:
            # print(line)
            parts = line.strip().split("\t")
            assert (len(parts) < 10), "line error: {}".format(line)
            if len(parts) == 8:
                dts.append(parts + [''])
            else:
                dts.append(parts)
        dt_match = [False] * len(dts)
        gt_match = [False] * len(gts)
        all_ious = defaultdict(tuple)
        for index_gt, gt in enumerate(gts):
            gt_coors = [float(gt_coor) for gt_coor in gt[0:8]]
            gt_poly = polygon_from_str(gt_coors)
            for index_dt, dt in enumerate(dts):
                dt_coors = [float(dt_coor) for dt_coor in dt[0:8]]
                dt_poly = polygon_from_str(dt_coors)
                iou = polygon_iou(dt_poly, gt_poly)
                if iou >= iou_thresh:
                    all_ious[(index_gt, index_dt)] = iou
        sorted_ious = sorted(
            all_ious.items(), key=operator.itemgetter(1), reverse=True)
        sorted_gt_dt_pairs = [item[0] for item in sorted_ious]

        # matched gt and dt
        for gt_dt_pair in sorted_gt_dt_pairs:
            index_gt, index_dt = gt_dt_pair
            if gt_match[index_gt] == False and dt_match[index_dt] == False:
                gt_match[index_gt] = True
                dt_match[index_dt] = True
                if ignore_blank:
                    gt_str = strQ2B(gts[index_gt][8]).replace(" ", "")
                    dt_str = strQ2B(dts[index_dt][8]).replace(" ", "")
                else:
                    gt_str = strQ2B(gts[index_gt][8])
                    dt_str = strQ2B(dts[index_dt][8])
                if ignore_masks[index_gt] == '0':
                    ed_sum += ed(gt_str, dt_str)
                    num_gt_chars += len(gt_str)
                    if gt_str == dt_str:
                        hit += 1
                    gt_count += 1
                    dt_count += 1

        # unmatched dt
        for tindex, dt_match_flag in enumerate(dt_match):
            if dt_match_flag == False:
                dt_str = dts[tindex][8]
                gt_str = ''
                ed_sum += ed(dt_str, gt_str)
                dt_count += 1

        # unmatched gt
        for tindex, gt_match_flag in enumerate(gt_match):
            if gt_match_flag == False and ignore_masks[tindex] == '0':
                dt_str = ''
                gt_str = gts[tindex][8]
                ed_sum += ed(gt_str, dt_str)
                num_gt_chars += len(gt_str)
                gt_count += 1

    eps = 1e-9
    print('hit, dt_count, gt_count', hit, dt_count, gt_count)
    precision = hit / (dt_count + eps)
    recall = hit / (gt_count + eps)
    fmeasure = 2.0 * precision * recall / (precision + recall + eps)
    avg_edit_dist_img = ed_sum / len(val_names)
    avg_edit_dist_field = ed_sum / (gt_count + eps)
    character_acc = 1 - ed_sum / (num_gt_chars + eps)
    print('character_acc: %.2f' % (character_acc * 100) + "%")
    print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
    print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
    print('precision: %.2f' % (precision * 100) + "%")
    print('recall: %.2f' % (recall * 100) + "%")
    print('fmeasure: %.2f' % (fmeasure * 100) + "%")
"""

In [9]:
import ast

class ScopeHeatmapGenerator(ast.NodeVisitor):
    def __init__(self, code):
        self.code = code
        self.scopes = []  # A list of sets, where each set represents variables in a scope
        self.active_vars_per_line = [0] * (len(code.split('\n')) + 1)  # Initialize the count of active variables per line

    def enter_scope(self):
        if not self.scopes:
            self.scopes.append(set())  # Create the first global scope
        else:
            # New scope inherits all variables from the previous scope
            self.scopes.append(self.scopes[-1].copy())

    def exit_scope(self):
        self.scopes.pop()  # Remove the top scope when exiting

    def visit_FunctionDef(self, node):
        self.enter_scope()
        for arg in node.args.args:
            self.scopes[-1].add(arg.arg)
        self.generic_visit(node)
        self.exit_scope()

    def visit_AsyncFunctionDef(self, node):
        self.enter_scope()
        for arg in node.args.args:
            self.scopes[-1].add(arg.arg)
        self.generic_visit(node)
        self.exit_scope()

    def visit_For(self, node):
        self.enter_scope()
        if isinstance(node.target, ast.Name):
            self.scopes[-1].add(node.target.id)
        self.generic_visit(node)
        self.exit_scope()

    def visit_Assign(self, node):
        for target in node.targets:
            if isinstance(target, ast.Name):
                self.scopes[-1].add(target.id)
        self.generic_visit(node)

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node."""
        if hasattr(node, 'lineno'):
            # Count the number of variables in scope at this line
            self.active_vars_per_line[node.lineno] = len(self.scopes[-1])
        super().generic_visit(node)

    def analyze(self):
        """Analyze the given code to generate the scope heatmap."""
        # Parse the code into an AST
        tree = ast.parse(self.code)
        # Start with the global scope
        self.enter_scope()
        # Visit all nodes
        self.visit(tree)
        # Exit the global scope
        self.exit_scope()
        
        # Normalize the heatmap data to remove any lines beyond the code length
        code_length = len(self.code.split('\n'))
        self.active_vars_per_line = self.active_vars_per_line[:code_length + 1]

        return self.active_vars_per_line

# Initialize the heatmap generator with the sample code
heatmap_generator = ScopeHeatmapGenerator(sample_code)

# Generate the heatmap
heatmap_data = heatmap_generator.analyze()
heatmap_data

[0,
 0,
 3,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 10,
 10,
 11,
 12,
 13,
 14,
 15,
 0,
 15,
 15,
 15,
 15,
 15,
 0,
 15,
 15,
 14,
 14,
 15,
 0,
 15,
 15,
 16,
 17,
 0,
 18,
 18,
 18,
 18,
 0,
 18,
 17,
 18,
 19,
 19,
 20,
 21,
 21,
 22,
 23,
 24,
 24,
 24,
 20,
 20,
 21,
 0,
 0,
 22,
 22,
 22,
 22,
 22,
 22,
 23,
 24,
 0,
 24,
 24,
 24,
 24,
 24,
 24,
 24,
 24,
 24,
 0,
 0,
 21,
 21,
 22,
 23,
 23,
 23,
 0,
 0,
 21,
 21,
 22,
 23,
 23,
 23,
 23,
 0,
 11,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 17,
 17,
 17,
 17,
 17,
 17,
 0]

In [11]:
def visualize_heatmap(code, heatmap_data):
    lines = code.split('\n')
    max_heat = max(heatmap_data)
    # Normalize heat levels to a scale of 0-10
    normalized_heatmap = [int((heat / max_heat) * 10) if max_heat > 0 else 0 for heat in heatmap_data]

    # Generate visualization for each line
    for i, line in enumerate(lines, start=1):
        heat_bar = '█' * normalized_heatmap[i - 1]
        print(f"{str(i).rjust(2)}: {line[:80].ljust(80)} | {heat_bar}")

# Visualize the heatmap alongside the code
visualize_heatmap(sample_code, heatmap_data)

 1:                                                                                  | 
 2: def e2e_eval(gt_dir, res_dir, ignore_blank=False):                               | 
 3:     print('start testing...')                                                    | █
 4:     iou_thresh = 0.5                                                             | █
 5:     val_names = os.listdir(gt_dir)                                               | █
 6:     num_gt_chars = 0                                                             | ██
 7:     gt_count = 0                                                                 | ██
 8:     dt_count = 0                                                                 | ██
 9:     hit = 0                                                                      | ███
10:     ed_sum = 0                                                                   | ███
11:     for i, val_name in enumerate(val_names):                                     | ████
12:         w