In [1]:
import cxgnncomp as cxgc
import numpy as np
import torch
from os import path

# --- Benchmark configuration -------------------------------------------------
dset = "friendster"
infeat = 384
num_device = 4

# Read the node count from the first line of the dataset's metadata file.
# The context manager closes the handle deterministically instead of leaving
# it open for the lifetime of the kernel.
with open(path.join(f"../../../../data/{dset}/processed", "num_nodes.txt")) as num_nodes_file:
    total_num_node = int(num_nodes_file.readline())
# Round up to the next multiple of num_device so the node range splits evenly.
total_num_node = (total_num_node + num_device - 1) // num_device * num_device
# Check against num_device itself (not a hard-coded 4) so the invariant stays
# meaningful when the device count is changed above.
assert total_num_node % num_device == 0

batches = []

# Prepare one sampled batch per device, then replicate its graph structure
# (ptr / idx / sub_to_full) onto every device so any device can work on any
# batch without an extra copy inside the timed region.
for i in range(num_device):
    feat, ptr, idx, b = cxgc.prepare_graph(
        dset=dset,
        feat_len=infeat,
        num_head=1,
        num_seeds=1000,
        is_full_graph=0,
        need_edge_index=False,
        device=i)
    # The returned features are not used (the Batch is built with x=None);
    # drop the reference right away.
    feat = None
    batch = cxgc.Batch(x=None, ptr=ptr, idx=idx, num_node_in_layer=b["num_node_in_layer"])
    batch.ptrs = [batch.ptr.to(dev) for dev in range(num_device)]
    batch.idxs = [batch.idx.to(dev) for dev in range(num_device)]
    batch.sub_to_fulls = [b["sub_to_full"].to(dev) for dev in range(num_device)]
    batches.append(batch)

# One random feature shard per device: every node, infeat // num_device columns.
# NOTE(review): no RNG seed is set, so values differ between runs.
local_feats = [torch.randn(total_num_node, infeat // num_device, device=i) for i in range(num_device)]


  from .autonotebook import tqdm as notebook_tqdm


Tuner lazy True
Loading cache table
dataset prepared torch.Size([157659]) torch.Size([1595972]) torch.Size([1425008, 384]) tensor([   1000,   11755,  157658, 1425008])
dataset prepared torch.Size([162129]) torch.Size([1640952]) torch.Size([1461853, 384]) tensor([   1000,   12107,  162128, 1461853])
dataset prepared torch.Size([151527]) torch.Size([1531929]) torch.Size([1372459, 384]) tensor([   1000,   11375,  151526, 1372459])
dataset prepared torch.Size([146311]) torch.Size([1478264]) torch.Size([1324591, 384]) tensor([   1000,   11028,  146310, 1324591])


In [2]:
import cxgnncomp_backend

class Trainer():
    """Single-process, multi-GPU driver that fills ``batch.x`` for one layer.

    Three feature layouts are supported, selected by ``type``:

    * ``"ddp"`` - each device's shard is viewed as whole rows of ``infeat``
      columns; device ``j`` owns the rows ``[local_starts[j], local_ends[j])``.
      Owned rows are copied to the batch's device, then multiplied by the
      weight and passed through ``sage_sum_forward``.
    * ``"tp"``  - each device's shard is viewed as ``infeat // num_device``
      columns; every device gathers, aggregates and multiplies by its own
      ``(infeat // num_device, hidden)`` weight, and the partial products are
      summed on the target device.
    * ``"opt"`` - same gather/aggregate as ``"tp"``, but the aggregated column
      slices are copied to the target device, concatenated along dim 1 and
      multiplied once by a full ``(infeat, hidden)`` weight.

    ``set_hidden`` must be called before ``generate_x`` because it creates the
    weights.
    """

    # Layouts accepted by the ``type`` argument.
    _SUPPORTED_TYPES = ("ddp", "tp", "opt")

    def __init__(self, num_device, total_num_node, infeat, local_feats, type="ddp"):
        """Store the configuration and reshape the per-device feature shards.

        Args:
            num_device: number of devices; devices are addressed as 0..num_device-1.
            total_num_node: number of nodes in the full graph.
            infeat: full input feature width.
            local_feats: one tensor per device; each is reshaped (not copied
                where a view is possible) to the layout ``type`` requires.
            type: ``"ddp"``, ``"tp"`` or ``"opt"``.

        Raises:
            ValueError: if ``type`` is not a supported layout. Previously an
                unknown value was accepted silently and left ``local_feats``
                unset.
        """
        if type not in self._SUPPORTED_TYPES:
            raise ValueError(
                f"unknown trainer type {type!r}; expected one of {self._SUPPORTED_TYPES}")
        self.num_device = num_device
        self.total_num_node = total_num_node
        self.infeat = infeat
        self.type = type
        if type == "ddp":
            # Row partition: bounds are node ids.
            self.local_starts, self.local_ends = self._split_bounds(total_num_node, num_device)
            self.local_feats = [item.reshape(-1, infeat) for item in local_feats]
            # Not used anywhere in this class; kept so existing readers of the
            # attribute keep working.
            self.convs = []
        else:
            # Column partition ("tp" / "opt"): bounds are feature-column indices.
            self.local_starts, self.local_ends = self._split_bounds(infeat, num_device)
            self.local_feats = [item.reshape(-1, infeat // num_device) for item in local_feats]

    @staticmethod
    def _split_bounds(total, parts):
        """Return (starts, ends) splitting ``range(total)`` into ``parts`` contiguous chunks."""
        starts = [i * total // parts for i in range(parts)]
        ends = [(i + 1) * total // parts for i in range(parts)]
        return starts, ends

    def set_hidden(self, hidden):
        """Set the output width and draw fresh random weights, one per device.

        ``"ddp"`` and ``"opt"`` use full ``(infeat, hidden)`` weights; ``"tp"``
        uses ``(infeat // num_device, hidden)`` weights.

        Raises:
            ValueError: if ``self.type`` is not a supported layout.
        """
        self.hidden = hidden
        if self.type in ["ddp", "opt"]:
            rows = self.infeat
        elif self.type == "tp":
            rows = self.infeat // self.num_device
        else:
            raise ValueError(f"unknown trainer type {self.type!r}")
        self.weights = [torch.randn(rows, hidden, device=i) for i in range(self.num_device)]

    def generate_x(self, batches):
        """Fill ``batch.x`` for every batch using the configured layout.

        Args:
            batches: one batch per device; each must carry ``ptrs``, ``idxs``,
                ``sub_to_fulls`` (one entry per device) and ``num_node_in_layer``.

        Raises:
            ValueError: if ``self.type`` is not a supported layout. This is an
                explicit raise rather than ``assert False`` so the check
                survives ``python -O``.
        """
        if self.type == "ddp":
            self.generate_x_ddp(batches)
        elif self.type == "tp":
            self.generate_x_tp(batches)
        elif self.type == "opt":
            self.generate_x_opt(batches)
        else:
            raise ValueError(f"unknown trainer type {self.type!r}")

    def generate_x_ddp(self, batches):
        """Row-partitioned path: collect owned rows on each batch's device, then compute."""
        # Phase 1: for batch i, every device j contributes the rows it owns.
        for i, batch in enumerate(batches):
            feats = []
            for j in range(self.num_device):
                # Transfer direction is j -> i.
                torch.cuda.set_device(j)
                sub_to_full = batch.sub_to_fulls[j]
                owned = torch.logical_and(
                    sub_to_full >= self.local_starts[j],
                    sub_to_full < self.local_ends[j])
                needed = sub_to_full[owned]
                feats.append(self.local_feats[j][needed - self.local_starts[j]].to(i))
            torch.cuda.set_device(i)
            # NOTE(review): rows end up grouped by owning device; this matches
            # the order of sub_to_full only if sub_to_full is sorted - confirm.
            batch.x = torch.cat(feats, dim=0)

        # Phase 2: local compute. Apply the matmul on whichever side of
        # sage_sum_forward sees the narrower feature width.
        for i, batch in enumerate(batches):
            torch.cuda.set_device(i)
            if self.infeat > self.hidden:
                x = torch.mm(batch.x, self.weights[i])
                batch.x = cxgnncomp_backend.sage_sum_forward(
                    x,
                    batch.ptrs[i],
                    batch.idxs[i],
                    batch.num_node_in_layer[-2]
                )
            else:
                x = cxgnncomp_backend.sage_sum_forward(
                    batch.x,
                    batch.ptrs[i],
                    batch.idxs[i],
                    batch.num_node_in_layer[-2]
                )
                batch.x = torch.mm(x, self.weights[i])

    def _gather_and_aggregate(self, batch, dev_it):
        """Select ``batch``'s rows from device ``dev_it``'s shard and run sage_sum_forward there.

        Leaves ``dev_it`` as the current CUDA device.
        """
        torch.cuda.set_device(dev_it)
        out = torch.index_select(
            self.local_feats[dev_it],
            dim=0,
            index=batch.sub_to_fulls[dev_it]
        )
        return cxgnncomp_backend.sage_sum_forward(
            out,
            batch.ptrs[dev_it],
            batch.idxs[dev_it],
            batch.num_node_in_layer[-2]
        )

    def generate_x_tp(self, batches):
        """Column-partitioned path: per-device partial products summed on the target device."""
        for tar_it in range(self.num_device):
            batch = batches[tar_it]
            arr_node_feat = []
            for dev_it in range(self.num_device):
                out = self._gather_and_aggregate(batch, dev_it)
                arr_node_feat.append(torch.mm(out, self.weights[dev_it]))
            # Start from the target device's own partial product and accumulate
            # the others into it in place.
            batch.x = arr_node_feat[tar_it]
            for dev_it in range(self.num_device):
                if dev_it == tar_it:
                    continue
                batch.x += arr_node_feat[dev_it].to(tar_it)

    def generate_x_opt(self, batches):
        """Column-partitioned path: ship aggregated slices, concatenate, one full matmul."""
        for tar_it in range(self.num_device):
            batch = batches[tar_it]
            arr_node_feat = [
                self._gather_and_aggregate(batch, dev_it)
                for dev_it in range(self.num_device)
            ]
            # Reassemble the full feature width on the target device.
            collect_feat = [part.to(tar_it) for part in arr_node_feat]
            batch.x = torch.cat(collect_feat, dim=1)
            batch.x = torch.mm(batch.x, self.weights[tar_it])

In [3]:

# Sweep the hidden width for each layout and print the timing returned by
# cxgc.prof. The loop variable is named `mode`, not `type`: cell variables are
# kernel globals, and rebinding `type` would break every later `type(...)` call.
for mode in ["opt", "ddp", "tp"]:
    hidden = 32
    print(mode)
    trainer = Trainer(num_device, total_num_node, infeat, local_feats, type=mode)
    # trainer.generate_x(batches)
    # print("dgl\tp3\tNollie")
    while hidden <= 1024:
        # Fresh random weights for this hidden width.
        trainer.set_hidden(hidden)
        # NOTE(review): the two label arguments stay "dgl"/"ddp" for every
        # mode; with display=False they appear unused - confirm.
        t = cxgc.prof("dgl", "ddp", lambda: trainer.generate_x(batches), display=False)[0]
        print(t)
        if hidden == 1024:
            # Show one result at the largest width.
            print(batches[0].x)
        # P3:
        # dgl = num_src * infeat * (num_device - 1) / num_device
        # p3 = num_dst * hidden * (num_device - 1)
        # our = num_dst * min(hidden * (num_device - 1), infeat * (num_device - 1) / num_device)
        # print(f"{dgl}\t{p3}\t{our}\t{hidden}")
        hidden *= 2

opt
(37.94892883300781, 37.94432067871094, 37.95353698730469)
(38.2740478515625, 38.265445709228516, 38.282649993896484)
(39.10041427612305, 39.08628463745117, 39.11454391479492)
(40.94464111328125, 40.83773422241211, 41.051544189453125)
(45.428733825683594, 45.319374084472656, 45.5380973815918)
(58.4007682800293, 58.4007682800293, 58.4007682800293)
tensor([[ -36.2889,   61.6108,   32.8701,  ...,   -1.5449,  -41.2632,
           -8.6654],
        [  18.5453,  -53.3738,   33.6478,  ...,  -70.9697,   46.8874,
          -68.3772],
        [ -17.5429,  -15.9698,  -30.8200,  ...,  -37.9601,  -12.1896,
           -5.4655],
        ...,
        [ 116.5622,  -53.3625,  -10.4726,  ...,   29.2781,   65.5850,
          -49.5062],
        [ 122.8694, -110.2708, -115.7789,  ...,   53.3824,   58.5736,
           12.2467],
        [  33.4942,   16.0401,   14.8758,  ...,   10.3647,    3.9851,
          142.4099]], device='cuda:0')
ddp
(270.7476501464844, 270.7476501464844, 270.7476501464844)
(272.6369