In [119]:
import os

import flatbuffers
import numpy as np

import FlatBufTaskGraph.TaskGraph

In [120]:
# Location of the FlatBuffers-serialized TaskGraph parsed below. Set the
# TASKGRAPH_FBUF environment variable to read a different file; the default is
# the original author's local path, so existing runs behave the same.
TASKGRAPH_PATH = os.environ.get("TASKGRAPH_FBUF", "/home/weiyangw/Documents/taskgraph.fbuf")
with open(TASKGRAPH_PATH, "rb") as f:
    buf = f.read()

In [121]:
# Parse the root TaskGraph table from the bytes read above; the second argument is
# the buffer offset of the root table (flatbuffers generated GetRootAs convention).
tg = FlatBufTaskGraph.TaskGraph.TaskGraph.GetRootAs(buf, 0)

In [122]:
def get_all_routes(tg):
    """Return the routing table as ``{(src node, dst node): [hop nodes...]}``.

    Each hop list includes both endpoints (e.g. ``(0, 2) -> [0, 1, 2]``), so a
    route with ``len(path) - 1`` physical hops.

    A route may carry several candidate paths; only the LAST one is kept, which is
    what the earlier implementation produced as well (each path overwrote the
    previous one). Only that path is decoded now instead of all of them.
    Routes with no paths are omitted. If two routes share the same
    ``(src, dst)`` pair, the later one (with at least one path) wins.

    NOTE(review): if multipath routing matters downstream, collect every path
    instead of just the last one.
    """
    result = {}
    for i in range(tg.RoutesLength()):
        route = tg.Routes(i)
        n_paths = route.PathsLength()
        if n_paths == 0:
            # Nothing to record for a route without paths.
            continue
        path = route.Paths(n_paths - 1)
        hops = [path.Hopnode(k) for k in range(path.HopnodeLength())]
        result[(route.Fromnode(), route.Tonode())] = hops
    return result
        

In [123]:
# Device type codes, compared against Device.Type() values decoded by get_dev.
# By name prefix: 0-1 compute, 2-4 memory, 5-14 communication devices.
# NOTE(review): these must match the numeric codes the task-graph producer writes;
# since they are unpacked from range(), append any new code at the end.
(
    DEVICE_COMP_GPU,               # 0
    DEVICE_COMP_CPU,               # 1
    DEVICE_MEM_SYSTEM,             # 2
    DEVICE_MEM_Z_COPY,             # 3
    DEVICE_MEM_GPU_FB,             # 4
    DEVICE_COMM_MEMBUS_COMM,       # 5
    DEVICE_COMM_UPI_IN_COMM,       # 6
    DEVICE_COMM_UPI_OUT_COMM,      # 7
    DEVICE_COMM_NIC_IN_COMM,       # 8
    DEVICE_COMM_NIC_OUT_COMM,      # 9
    DEVICE_COMM_PCI_TO_HOST_COMM,  # 10
    DEVICE_COMM_PCI_TO_DEV_COMM,   # 11
    DEVICE_COMM_NVLINK_COMM,       # 12
    DEVICE_COMM_NW_COMM,           # 13
    DEVICE_COMM_NW_NOMINAL,        # 14
) = range(15)
def get_dev(tg):
    """Decode every device of the task graph.

    Returns ``{device id: (node id, device type, device property, bandwidth)}``,
    where the type is one of the DEVICE_* codes. If a device id repeats, the
    later device wins.
    """
    return {
        dev.Deviceid(): (dev.Nodeid(), dev.Type(), dev.Deviceproperty(), dev.Bandwidth())
        for dev in map(tg.Devices, range(tg.DevicesLength()))
    }

def get_nwnominal(tg, devices):
    """Map each nominal network-comm device to its (src node, dst node) pair.

    ``devices`` is the ``get_dev`` output: ``id -> (node, type, property, bandwidth)``.
    For devices whose type equals DEVICE_COMM_NW_NOMINAL, the property field encodes
    the endpoint pair as ``src * tg.Nnode() + dst``; it is split back apart here.
    Returns ``{device id: (src, dst)}``.
    """
    node_count = tg.Nnode()
    return {
        dev_id: divmod(info[2], node_count)
        for dev_id, info in devices.items()
        if info[1] == DEVICE_COMM_NW_NOMINAL
    }

In [124]:
# Task type codes, compared against Task.Type() values decoded by get_tasks.
# NOTE(review): must match the producer's numeric codes; append new ones at the end.
(
    TASK_FORWARD,        # 0
    TASK_BACKWARD,       # 1
    TASK_COMM,           # 2
    TASK_UPDATE,         # 3
    TASK_BARRIER,        # 4
    TASK_NOMINAL_COMM,   # 5
    TASK_ALLREDUCE,      # 6
) = range(7)
def get_tasks(tg):
    """Decode every task of the task graph.

    Returns ``{task id: (device id, runtime, xfer size, task type, next tasks)}``
    with the type being one of the TASK_* codes and ``next tasks`` a list.
    NOTE(review): for TASK_ALLREDUCE tasks the traffic-matrix helpers read the last
    element as the list of participating nodes -- confirm against the producer.
    """
    def _decode(task):
        # One (key, value) pair per task; the successor vector becomes a list.
        successors = [task.Nexttasks(k) for k in range(task.NexttasksLength())]
        record = (task.Deviceid(), task.Runtime(), task.Xfersize(), task.Type(), successors)
        return task.Taskid(), record

    return dict(_decode(tg.Tasks(idx)) for idx in range(tg.TasksLength()))

In [125]:
def get_ring(tg):
    """Decode the ring descriptions of the task graph.

    Returns ``{ring size: [jumps of ring path 0, jumps of ring path 1, ...]}``, where
    each entry is a list built from ``JumpsAsNumpy()``. If two rings report the same
    size, the later one replaces the earlier one.
    """
    rings = {}
    n_rings = tg.RingsLength()
    for idx in range(n_rings):
        ring = tg.Rings(idx)
        size = ring.Ringsz()
        paths = []
        for k in range(ring.RingpathsLength()):
            paths.append(list(ring.Ringpaths(k).JumpsAsNumpy()))
        rings[size] = paths
    return rings

In [126]:
def get_logical_tm(tasks, rings, nom_devs, tg, verbose=True):
    """Build the logical (node-to-node) traffic matrix.

    Parameters
    ----------
    tasks : dict
        ``get_tasks`` output: task id -> (device id, runtime, xfer size, type, next tasks).
    rings : dict
        ``get_ring`` output: ring size -> list of jump lists, one per ring split.
    nom_devs : dict
        ``get_nwnominal`` output: nominal-comm device id -> (src node, dst node).
    tg : TaskGraph
        Only ``tg.Nnode()`` is read (used as the node-id modulus).
    verbose : bool, default True
        If true, print ``<all-reduce bytes> <nominal-comm bytes>`` totals.

    Returns
    -------
    dict
        ``(src node, dst node) -> bytes``.
        * TASK_NOMINAL_COMM tasks add their xfer size to their (src, dst) pair.
        * TASK_ALLREDUCE tasks with more than one participant charge
          ``2 * (ringsz - 1) * xfer`` bytes, split evenly over the ring splits and
          the ``ringsz`` logical steps of each split. Each step goes from node
          ``n`` to ``(n + sum(jumps)) % Nnode``.
        Other task types are ignored.

    Raises
    ------
    KeyError
        If a nominal task's device is missing from ``nom_devs`` or ``rings`` has
        no entry for an all-reduce's ring size.
    """
    nnodes = tg.Nnode()
    result = {}  # (src node, dst node) -> bytes
    tot = 0      # bytes carried by nominal comm tasks
    totr = 0     # bytes carried by all-reduce tasks
    for t in tasks.values():
        if t[3] == TASK_NOMINAL_COMM:
            ndev = nom_devs[t[0]]
            result[ndev] = result.get(ndev, 0) + t[2]
            tot += t[2]
        elif t[3] == TASK_ALLREDUCE:
            participants = t[-1]
            ringsz = len(participants)
            if ringsz <= 1:
                # A lone participant (or none) exchanges nothing; ringsz == 0 used
                # to crash on the rings lookup / division below.
                continue
            ring_paths = rings[ringsz]
            nsplit = len(ring_paths)
            totr += 2 * (ringsz - 1) * t[2]
            xfersize = 2 * (ringsz - 1) * t[2] / nsplit / ringsz
            # NOTE(review): curr_node is not reset per split, so each split starts
            # where the previous one ended. That equals participants[0] only when
            # ringsz * sum(jumps) is a multiple of nnodes -- confirm this is intended.
            curr_node = participants[0]
            for rdesc in ring_paths:
                # A logical ring step advances by the sum of the split's jumps.
                total_hop = sum(rdesc)
                for _ in range(ringsz):
                    next_node = (curr_node + total_hop) % nnodes
                    key = (curr_node, next_node)
                    result[key] = result.get(key, 0) + xfersize
                    curr_node = next_node
    if verbose:
        print(totr, tot)
    return result
            

In [127]:
def get_physical_tm(tasks, rings, nom_devs, routes, tg):
    """Build the physical (per-link) traffic matrix.

    Returns ``{(node, next node): bytes}`` for every link that carries traffic:

    * TASK_NOMINAL_COMM: the task's xfer size is charged to every consecutive
      hop pair of ``routes[nom_devs[device id]]``.
    * TASK_ALLREDUCE with more than one participant: ``2 * (ringsz - 1) * xfer``
      bytes are split evenly over the ring splits in ``rings[ringsz]`` and the
      ``ringsz`` logical steps of each. Every jump of a split is then charged
      individually as the link ``n -> (n + jump) % tg.Nnode()``.

    Other task types are ignored. NOTE(review): the walking node is not reset
    between ring splits -- confirm that every split is expected to close.
    """
    nnodes = tg.Nnode()
    link_load = {}

    def bump(link, amount):
        # The first contribution stores the amount; later ones accumulate.
        if link not in link_load:
            link_load[link] = amount
        else:
            link_load[link] += amount

    for task in tasks.values():
        kind = task[3]
        if kind == TASK_NOMINAL_COMM:
            hops = routes[nom_devs[task[0]]]
            for here, there in zip(hops[:-1], hops[1:]):
                bump((here, there), task[2])
            continue
        if kind != TASK_ALLREDUCE:
            continue
        members = task[-1]
        size = len(members)
        if size == 1:
            continue
        splits = rings[size]
        share = 2 * (size - 1) * task[2] / len(splits) / size
        node = members[0]
        for jumps in splits:
            for _ in range(size):
                for step in jumps:
                    nxt = (node + step) % nnodes
                    bump((node, nxt), share)
                    node = nxt
    return link_load

In [128]:
def get_hop_to_traffic(tasks, rings, nom_devs, routes, tg):
    """Histogram of traffic bytes keyed by physical hop count.

    Returns ``{hop count: bytes}``:

    * TASK_NOMINAL_COMM: the task's xfer size is counted under
      ``len(route) - 1`` hops (routes list both endpoints).
    * TASK_ALLREDUCE with more than one participant: each ring split charges its
      per-step share ``2 * (ringsz - 1) * xfer / nsplit / ringsz`` once per
      participant, under ``len(jumps)`` hops.

    Other task types are ignored.
    """
    _node_count = tg.Nnode()  # unused: hop buckets do not depend on the node count
    by_hops = {}
    for task in tasks.values():
        if task[3] == TASK_NOMINAL_COMM:
            hops = len(routes[nom_devs[task[0]]]) - 1
            by_hops[hops] = by_hops[hops] + task[2] if hops in by_hops else task[2]
        elif task[3] == TASK_ALLREDUCE:
            members = len(task[-1])
            if members == 1:
                continue
            ring_paths = rings[members]
            share = 2 * (members - 1) * task[2] / len(ring_paths) / members
            for jumps in ring_paths:
                hops = len(jumps)
                for _ in range(members):
                    by_hops[hops] = by_hops[hops] + share if hops in by_hops else share
    return by_hops

In [129]:
# Decode the task graph sections, then bucket communicated bytes by physical hop count.
tasks = get_tasks(tg)
rings = get_ring(tg)
devs = get_dev(tg)
ndevs = get_nwnominal(tg, devs)
routes = get_all_routes(tg)
r = get_hop_to_traffic(tasks, rings, ndevs, routes, tg)
# Hop-weighted bytes of multi-hop traffic relative to the bytes of single-hop traffic.
# Raises KeyError if no traffic used exactly one hop (r has no key 1).
a = sum([k * v for k, v in r.items() if k != 1]) / r[1]
print(a)

0.2889866765906471


In [130]:
# Re-decode the task graph and build the logical traffic matrix.
# get_logical_tm prints "<all-reduce bytes> <nominal-comm bytes>" (the output below).
tasks = get_tasks(tg)
rings = get_ring(tg)
devs = get_dev(tg)
ndevs = get_nwnominal(tg, devs)
routes = get_all_routes(tg)
# NOTE(review): the commented-out lines below are leftover exploration; the max/min
# prints rebuild the matrices repeatedly -- reuse ltm/ptm or delete them.
# ptm = get_physical_tm(tasks, rings, ndevs,routes , tg)
ltm = get_logical_tm(tasks, rings, ndevs , tg)
# print(ptm)
# print(max([v for v in get_physical_tm(tasks, rings, ndevs,routes , tg).values() if v != 0]), min([v for v in get_physical_tm(tasks, rings, ndevs,routes , tg).values() if v != 0]))
# print(max([v for v in get_logical_tm(tasks, rings, ndevs, tg).values() if v != 0]), min([v for v in get_logical_tm(tasks, rings, ndevs , tg).values() if v != 0]))

345255592952 34091302912


In [252]:
# Re-decode the task graph and display the raw {hop count: bytes} histogram
# (the bare last expression is rendered as the cell output).
tasks = get_tasks(tg)
rings = get_ring(tg)
devs = get_dev(tg)
ndevs = get_nwnominal(tg, devs)
routes = get_all_routes(tg)
get_hop_to_traffic(tasks, rings, ndevs, routes, tg)

{1: 353845527544.0,
 5: 6442450944,
 4: 36507222016,
 3: 53687091200,
 2: 31138512896}

In [194]:
# Inspect the routing table from get_all_routes: {(src, dst): hop nodes incl. both endpoints}.
routes

{(0, 1): [0, 1],
 (0, 2): [0, 1, 2],
 (0, 3): [0, 1, 2, 3],
 (0, 4): [0, 1, 2, 3, 4],
 (0, 5): [0, 1, 2, 3, 4, 5],
 (0, 6): [0, 1, 2, 3, 4, 5, 6],
 (0, 7): [0, 1, 2, 3, 4, 5, 6, 7],
 (0, 8): [0, 53, 106, 31, 84, 9, 8],
 (0, 9): [0, 53, 106, 31, 84, 9],
 (0, 10): [0, 1, 54, 107, 32, 85, 10],
 (0, 11): [0, 1, 54, 107, 32, 85, 10, 11],
 (0, 12): [0, 75, 22, 97, 44, 119, 118, 65, 12],
 (0, 13): [0, 75, 22, 97, 44, 119, 66, 13],
 (0, 14): [0, 75, 22, 97, 44, 119, 66, 13, 14],
 (0, 15): [0, 127, 126, 125, 124, 123, 122, 121, 68, 15],
 (0, 16): [0, 127, 126, 125, 124, 123, 70, 17, 16],
 (0, 17): [0, 127, 126, 125, 124, 123, 70, 17],
 (0, 18): [0, 127, 126, 125, 124, 71, 18],
 (0, 19): [0, 127, 74, 73, 20, 19],
 (0, 20): [0, 127, 74, 73, 20],
 (0, 21): [0, 127, 74, 21],
 (0, 22): [0, 75, 22],
 (0, 23): [0, 75, 76, 23],
 (0, 24): [0, 75, 76, 77, 24],
 (0, 25): [0, 75, 76, 77, 78, 25],
 (0, 26): [0, 75, 76, 77, 78, 79, 26],
 (0, 27): [0, 127, 126, 125, 50, 103, 102, 27],
 (0, 28): [0, 127, 126, 