In [6]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
from pathlib import Path
import os
if Path.cwd().name == "test_scripts":
    os.chdir(Path.cwd().parent)

In [8]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx

import os
import shutil
from pathlib import Path
import pickle
from typing import Optional
import subprocess


# from rich import print
from rich.pretty import pprint as pp

In [9]:
from src.analysis.parallelism_v4 import update_graph
from src.analysis.latency.engine import DFG, StreamIOType

In [10]:
model_fp = Path("./src/test_files_256/new_grad_lastdim_small.pkl")

with open(model_fp, "rb") as f:
    G_data = pickle.load(f)

print(G_data)

DiGraph with 108 nodes and 131 edges


In [11]:
model_dfg_fp = model_fp.with_stem(f"{model_fp.stem}_dfg")
try:
    with open(model_dfg_fp, "rb") as f:
        print(f"Loading cached DFG from {model_dfg_fp.absolute()}")
        dfg: DFG = pickle.load(f)
except FileNotFoundError:
    dfg = update_graph(G_data)
    with open(model_dfg_fp, "wb") as f:
        print(f"Saving DFG to {model_dfg_fp.absolute()}")
        pickle.dump(dfg, f)

Updating nodes: 100%|██████████| 108/108 [00:51<00:00,  2.10it/s]


Loading cached DFG from /export/hdd/scratch/rsarkar30/INR/inr-dsp/src/test_files_256/new_grad_lastdim_small_dfg.pkl


In [12]:
assert not dfg.has_cycle(), "Original DFG should not have cycles"

In [13]:
assert dfg.with_depths({stream.name: 2 for stream in dfg.nodes.streams}).has_cycle(), "New DFG should have a cycle"

In [24]:
IGNORED_STREAM_SUFFIXES = ("__out_stream", "__temp_stream")
depths = {stream.name: 2 for stream in dfg.nodes.streams if not stream.name.endswith(IGNORED_STREAM_SUFFIXES)}
while True:
    participants =  dfg.with_depths(depths).get_cycle_participants()
    if not participants:
        break
    participants = set(node.stream for node in participants if node.stream.name in depths)
    print(f"Deadlock found involving {len(participants)} stream(s) ({len(participants) / len(depths):.1%} of total)")
    for participant in participants:
        depths[participant.name] *= 2
    print("    Increased participant depths; new maximum depth:", max(depths.values()))
print("Found a configuration with no deadlock!")
print("     Maximum depth:", max(depths.values()))
print("     Sum of depths:", sum(depths.values()))
print("    Total increase:", sum(depths.values()) - (2 * len(depths)))

Deadlock found involving 76 stream(s) (52.4% of total)
    Increased participant depths; new maximum depth: 4
Deadlock found involving 76 stream(s) (52.4% of total)
    Increased participant depths; new maximum depth: 8
Deadlock found involving 76 stream(s) (52.4% of total)
    Increased participant depths; new maximum depth: 16
Deadlock found involving 76 stream(s) (52.4% of total)
    Increased participant depths; new maximum depth: 32
Deadlock found involving 76 stream(s) (52.4% of total)
    Increased participant depths; new maximum depth: 64
Deadlock found involving 76 stream(s) (52.4% of total)
    Increased participant depths; new maximum depth: 128
Deadlock found involving 76 stream(s) (52.4% of total)
    Increased participant depths; new maximum depth: 256
Deadlock found involving 76 stream(s) (52.4% of total)
    Increased participant depths; new maximum depth: 512
Deadlock found involving 76 stream(s) (52.4% of total)
    Increased participant depths; new maximum depth: 102

In [35]:
reduced = 0
attempted = 0
for i, (name, depth) in enumerate(depths.items()):
    if depth > 2:
        attempted += 1
        test_depths = {**depths, name: 2}
        if not dfg.with_depths(test_depths).has_cycle():
            reduced += 1
            print(f"({i + 1}/{len(depths)}) Stream {name} can be safely reduced to depth 16384. ({reduced}/{attempted} reduced so far)")
            depths = test_depths
        else:
            print(f"({i + 1}/{len(depths)}) Stream {name} cannot be safely reduced to depth 16384. ({reduced}/{attempted} reduced so far)")
    # else:
    #     print(f"({i + 1}/{len(depths)}) Stream {name} is already depth <=16384. ({reduced}/{attempted} reduced so far)")

(9/145) Stream fn_Cos_2__in_0_stream cannot be safely reduced to depth 16384. (0/1 reduced so far)
(15/145) Stream fn_Cos_10__in_0_stream cannot be safely reduced to depth 16384. (0/2 reduced so far)
(19/145) Stream fn_Mul_23__in_1_stream cannot be safely reduced to depth 16384. (0/3 reduced so far)
(23/145) Stream fn_Cos_9__in_0_stream cannot be safely reduced to depth 16384. (0/4 reduced so far)
(25/145) Stream fn_Cos_1__in_0_stream cannot be safely reduced to depth 16384. (0/5 reduced so far)
(33/145) Stream fn_Mul_9__in_1_stream cannot be safely reduced to depth 16384. (0/6 reduced so far)
(37/145) Stream fn_Mul_7__in_1_stream cannot be safely reduced to depth 16384. (0/7 reduced so far)
(38/145) Stream fn_Mul_17__in_1_stream cannot be safely reduced to depth 16384. (0/8 reduced so far)
(44/145) Stream fn_Mul_25__in_1_stream cannot be safely reduced to depth 16384. (0/9 reduced so far)
(55/145) Stream fn_Cos_11__in_0_stream cannot be safely reduced to depth 16384. (0/10 reduced so 

In [36]:
{k: v for k, v in depths.items() if v > 2}

{'fn_Cos_2__in_0_stream': 32768,
 'fn_Cos_10__in_0_stream': 32768,
 'fn_Mul_23__in_1_stream': 32768,
 'fn_Cos_9__in_0_stream': 32768,
 'fn_Cos_1__in_0_stream': 32768,
 'fn_Mul_9__in_1_stream': 32768,
 'fn_Mul_7__in_1_stream': 32768,
 'fn_Mul_17__in_1_stream': 32768,
 'fn_Mul_25__in_1_stream': 32768,
 'fn_Cos_11__in_0_stream': 32768,
 'fn_Mul_5__in_1_stream': 32768,
 'fn_Cos_6__in_0_stream': 32768,
 'fn_Cos_5__in_0_stream': 32768,
 'fn_Mul_15__in_1_stream': 32768,
 'fn_Mul_21__in_1_stream': 32768,
 'fn_Cos_3__in_0_stream': 32768,
 'fn_Mul_13__in_1_stream': 32768,
 'fn_Cos_7__in_0_stream': 32768}

In [9]:
best_so_far = dfg.with_depths({'fn_Cos_2__in_0_stream': 32768,
 'fn_Cos_10__in_0_stream': 32768,
 'fn_Mul_23__in_1_stream': 32768,
 'fn_Cos_9__in_0_stream': 32768,
 'fn_Cos_1__in_0_stream': 32768,
 'fn_Mul_9__in_1_stream': 32768,
 'fn_Mul_7__in_1_stream': 32768,
 'fn_Mul_17__in_1_stream': 32768,
 'fn_Mul_25__in_1_stream': 32768,
 'fn_Cos_11__in_0_stream': 32768,
 'fn_Mul_5__in_1_stream': 32768,
 'fn_Cos_6__in_0_stream': 32768,
 'fn_Cos_5__in_0_stream': 32768,
 'fn_Mul_15__in_1_stream': 32768,
 'fn_Mul_21__in_1_stream': 32768,
 'fn_Cos_3__in_0_stream': 32768,
 'fn_Mul_13__in_1_stream': 32768,
 'fn_Cos_7__in_0_stream': 32768})

In [None]:
from datetime import datetime
print(f"Starting at {datetime.now()!s}")
print("Latency:", best_so_far.get_latency(), "cycles")
print(f"Finished at {datetime.now()!s}")

Starting at 2023-04-10 01:06:22.736887


In [10]:
# [(n, np.prod(x["shape"])) for n, x in G_data.nodes.items()]

In [15]:
# for i in range(0, 1024 * 1024, 1024):
#     print(f"Trying depth {i}... ", end="")
#     if dfg.with_depths({stream.name: i for stream in dfg.nodes.streams}).has_cycle():
#         print("cycle found")
#     else:
#         print("no cycle found")
#         break

# Results: no cycle found at depth 22528

Trying depth 0... cycle found
Trying depth 1024... cycle found
Trying depth 2048... cycle found
Trying depth 3072... cycle found
Trying depth 4096... cycle found
Trying depth 5120... cycle found
Trying depth 6144... cycle found
Trying depth 7168... cycle found
Trying depth 8192... cycle found
Trying depth 9216... cycle found
Trying depth 10240... cycle found
Trying depth 11264... cycle found
Trying depth 12288... cycle found
Trying depth 13312... cycle found
Trying depth 14336... cycle found
Trying depth 15360... cycle found
Trying depth 16384... cycle found
Trying depth 17408... cycle found
Trying depth 18432... cycle found
Trying depth 19456... cycle found
Trying depth 20480... cycle found
Trying depth 21504... cycle found
Trying depth 22528... no cycle found


In [22]:
# for i in range(0, 1024 * 1024, 1024):
#     print(f"Trying depth {i}... ", end="")
#     if dfg.with_depths({stream.name: (i if not stream.name.endswith(("__out_stream", "__temp_stream")) else 2) for stream in dfg.nodes.streams}).has_cycle():
#         print("cycle found")
#     else:
#         print("no cycle found")
#         break

# Results: no cycle found at depth 32768

Trying depth 0... cycle found
Trying depth 1024... cycle found
Trying depth 2048... cycle found
Trying depth 3072... cycle found
Trying depth 4096... cycle found
Trying depth 5120... cycle found
Trying depth 6144... cycle found
Trying depth 7168... cycle found
Trying depth 8192... cycle found
Trying depth 9216... cycle found
Trying depth 10240... cycle found
Trying depth 11264... cycle found
Trying depth 12288... cycle found
Trying depth 13312... cycle found
Trying depth 14336... cycle found
Trying depth 15360... cycle found
Trying depth 16384... cycle found
Trying depth 17408... cycle found
Trying depth 18432... cycle found
Trying depth 19456... cycle found
Trying depth 20480... cycle found
Trying depth 21504... cycle found
Trying depth 22528... cycle found
Trying depth 23552... cycle found
Trying depth 24576... cycle found
Trying depth 25600... cycle found
Trying depth 26624... cycle found
Trying depth 27648... cycle found
Trying depth 28672... cycle found
Trying depth 29696... cycle

In [27]:
# for i in range(32765, 32767):
#     print(f"Trying depth {i}... ", end="")
#     if dfg.with_depths({stream.name: (i if not stream.name.endswith(("__out_stream", "__temp_stream")) else 2) for stream in dfg.nodes.streams}).has_cycle():
#         print("cycle found")
#     else:
#         print("no cycle found")
#         break

# Results:
# Trying depth 32765... cycle found
# Trying depth 32766... no cycle found

Trying depth 32765... cycle found
Trying depth 32766... no cycle found


In [None]:
# {stream.name: (i if not stream.name.endswith(("__out_stream", "__temp_stream")) else 2) for stream in dfg.nodes.streams}

In [13]:
# lo = 1024
# hi = 524800
# while hi > lo:
#     mid = (hi + lo) // 2
#     print(f"Trying depth {mid}..", end="")
#     dfg_modified = dfg.with_depths({stream.name: (mid if not stream.name.endswith(("__out_stream", "__temp_stream")) else 2) for stream in dfg.nodes.streams})
#     print(". ", end="")
#     if dfg_modified.has_cycle():
#         print("cycle found")
#         lo = mid + 1
#     else:
#         print("no cycle found")
#         hi = mid

# Results (on bs=4096):
# Trying depth 262912... cycle found
# Trying depth 393856... cycle found
# Trying depth 459328... cycle found
# Trying depth 492064... cycle found
# Trying depth 508432... cycle found
# Trying depth 516616... cycle found
# Trying depth 520708... cycle found
# Trying depth 522754... cycle found
# Trying depth 523777... cycle found
# Trying depth 524289... no cycle found
# Trying depth 524033... cycle found
# Trying depth 524161... cycle found
# Trying depth 524225... cycle found
# Trying depth 524257... cycle found
# Trying depth 524273... cycle found
# Trying depth 524281... cycle found
# Trying depth 524285... cycle found
# Trying depth 524287... no cycle found
# Trying depth 524286... no cycle found

Trying depth 262912... cycle found
Trying depth 393856... cycle found
Trying depth 459328... cycle found
Trying depth 492064... cycle found
Trying depth 508432... cycle found
Trying depth 516616... cycle found
Trying depth 520708... cycle found
Trying depth 522754... cycle found
Trying depth 523777... cycle found
Trying depth 524289... no cycle found
Trying depth 524033... cycle found
Trying depth 524161... cycle found
Trying depth 524225... cycle found
Trying depth 524257... cycle found
Trying depth 524273... cycle found
Trying depth 524281... cycle found
Trying depth 524285... cycle found
Trying depth 524287... no cycle found
Trying depth 524286... no cycle found


In [11]:
from scipy.sparse.csgraph import connected_components
# dfg2 = dfg.with_depths({stream.name: 2 for stream in dfg.nodes.streams})
cc = connected_components(dfg.graph, connection="strong")

[autoreload of src.analysis.latency.engine failed: Traceback (most recent call last):
  File "/usr/scratch/rsarkar30/miniconda/envs/INSPNet/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 257, in check
    superreload(m, reload, self.old_objects)
  File "/usr/scratch/rsarkar30/miniconda/envs/INSPNet/lib/python3.10/site-packages/IPython/extensions/autoreload.py", line 455, in superreload
    module = reload(module)
  File "/usr/scratch/rsarkar30/miniconda/envs/INSPNet/lib/python3.10/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 619, in _exec
  File "<frozen importlib._bootstrap_external>", line 879, in exec_module
  File "<frozen importlib._bootstrap_external>", line 1017, in get_code
  File "<frozen importlib._bootstrap_external>", line 947, in source_to_code
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/export/hdd/scratch/rsarkar30/INR/inr-dsp/sr

In [18]:
np.intc

numpy.int32

In [12]:
unique_component_labels, component_label_counts = np.unique(component_labels, return_counts=True)
cycles = unique_component_labels[component_label_counts > 1]
cycles

array([29776], dtype=int32)

In [15]:
node_ids, = np.nonzero(component_labels == cycles[0])

In [18]:
sum(component_labels == cycles[0])

18873402

In [None]:
# from src.analysis.latency.engine import DFGNodeTable, DFG_ROOT, DFGNode
# from typing import List, Union
# import numpy.typing as npt
# def lookup_many_reverse(self: DFGNodeTable, node_ids: npt.NDArray[np.int64]):
#     nodes: List[Union[DFGNode, None]] = [None] * len(node_ids)
#     root_idxs, = np.nonzero(node_ids == 0)
#     for i in root_idxs:
#         nodes[i] = DFG_ROOT
#     for stream, lo in self.forward_table.items():
#         hi = lo + stream.num_writes * 2
#         mask = (node_ids >= lo) & (node_ids < hi)
#         idxs, = np.nonzero(mask)