-
Notifications
You must be signed in to change notification settings - Fork 3.6k
/
distributed_sampling_xpu.py
217 lines (171 loc) · 7.21 KB
/
distributed_sampling_xpu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
"""Distributed GAT training, targeting XPU devices.

PVC has 2 tiles, each of which reports itself as a separate
device. The DDP approach allows us to employ both tiles.

Additional requirements:
    IPEX (intel_extension_for_pytorch)
    oneCCL (oneccl_bindings_for_pytorch)

We need to import both of these modules, as they extend
the torch module with XPU/oneCCL related functionality.

Run with:
    mpirun -np 2 python distributed_sampling_xpu.py
"""
import copy
import os
import os.path as osp
from typing import Any, Tuple, Union
import intel_extension_for_pytorch # noqa
import oneccl_bindings_for_pytorch # noqa
import torch
import torch.distributed as dist
import torch.nn.functional as F
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
from torch import Tensor
from torch.nn import Linear as Lin
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GATConv
from torch_geometric.profile import get_stats_summary, profileit
class GAT(torch.nn.Module):
    """Stack of ``GATConv`` layers, each paired with a linear skip connection.

    Every layer computes ``conv(x, edge_index) + skip(x)``. The first and the
    intermediate layers emit ``heads * hidden_channels`` features; the last
    layer is built with ``concat=False`` and emits ``out_channels`` features.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Per-head feature size of every non-final layer.
        out_channels: Feature size produced by the final layer.
        num_layers: Number of layers. The constructor always creates a first
            and a last layer and adds ``num_layers - 2`` layers in between,
            so the stack only matches ``num_layers`` for values >= 2.
        heads: Number of attention heads handed to every ``GATConv``.
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        num_layers: int,
        heads: int,
    ):
        super().__init__()

        self.num_layers = num_layers
        hidden_width = heads * hidden_channels

        # The input width comes from `in_channels` rather than from a
        # module-level `dataset` object, so the model can be constructed
        # without any global state being present.
        self.convs = torch.nn.ModuleList()
        self.convs.append(GATConv(in_channels, hidden_channels, heads))
        for _ in range(num_layers - 2):
            self.convs.append(GATConv(hidden_width, hidden_channels, heads))
        self.convs.append(
            GATConv(hidden_width, out_channels, heads, concat=False))

        # One linear projection per conv layer, with matching in/out widths,
        # so the two branches can be summed in `forward` and `inference`.
        self.skips = torch.nn.ModuleList()
        self.skips.append(Lin(in_channels, hidden_width))
        for _ in range(num_layers - 2):
            self.skips.append(Lin(hidden_width, hidden_width))
        self.skips.append(Lin(hidden_width, out_channels))

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        """Apply all layers to a (sampled) graph.

        ELU and dropout (``p=0.5``, active only in training mode) follow
        every layer except the one at index ``num_layers - 1``.

        Args:
            x: Node feature tensor.
            edge_index: Edge index tensor forwarded to every ``GATConv``.

        Returns:
            The output of the last layer.
        """
        for i, (conv, skip) in enumerate(zip(self.convs, self.skips)):
            x = conv(x, edge_index) + skip(x)
            if i != self.num_layers - 1:
                x = F.elu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return x

    def inference(
        self,
        x_all: Tensor,
        device: Union[str, torch.device],
        subgraph_loader: NeighborLoader,
    ) -> Tensor:
        """Compute embeddings layer by layer over the batches of a loader.

        Args:
            x_all: Features of all nodes; indexed with ``batch.n_id``.
            device: Device on which each batch is processed.
            subgraph_loader: Loader iterated once per layer.
                NOTE(review): presumably it yields every node exactly once
                and in node order, so that the concatenated result lines up
                with ``x_all`` -- confirm against the caller.

        Returns:
            The concatenated outputs of the last layer, as a CPU tensor.
        """
        pbar = tqdm(total=x_all.size(0) * self.num_layers)
        pbar.set_description("Evaluating")

        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch.
        for i in range(self.num_layers):
            xs = []
            for batch in subgraph_loader:
                x = x_all[batch.n_id].to(device)
                edge_index = batch.edge_index.to(device)
                x = self.convs[i](x, edge_index) + self.skips[i](x)
                # Keep only the first `batch_size` rows of the layer output.
                x = x[:batch.batch_size]
                if i != self.num_layers - 1:
                    x = F.elu(x)
                xs.append(x.cpu())

                pbar.update(batch.batch_size)

            # The next layer reads its inputs from this layer's outputs.
            x_all = torch.cat(xs, dim=0)

        pbar.close()

        return x_all
@profileit('xpu')
def train_step(model: Any, optimizer: Any, x: Tensor, edge_index: Tensor,
               y: Tensor, bs: int) -> float:
    """Run a single optimization step on one mini-batch.

    Only the first ``bs`` rows of the model output and of ``y`` enter the
    loss.

    Args:
        model: Callable invoked as ``model(x, edge_index)``.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called.
        x: Node features of the batch.
        edge_index: Edge indices of the batch.
        y: Labels of the batch.
            NOTE(review): assumed to hold one class index per node (shape
            ``(N,)`` or ``(N, 1)``) -- confirm against the dataset.
        bs: Number of leading rows that contribute to the loss.

    Returns:
        The loss value as a Python float. The function is wrapped by
        ``profileit('xpu')``; what a caller receives is determined by that
        decorator.
    """
    optimizer.zero_grad()
    out = model(x, edge_index)[:bs]
    # Flatten explicitly instead of calling a bare `squeeze()`: with
    # `bs == 1`, `squeeze()` would also drop the batch dimension and leave a
    # 0-dim target that no longer matches the single-row logits.
    target = y[:bs].reshape(-1)
    loss = F.cross_entropy(out, target)
    loss.backward()
    optimizer.step()
    return float(loss)
def run(rank: int, world_size: int, dataset: PygNodePropPredDataset):
    """Train a DDP-wrapped GAT on device ``xpu:{rank}``; evaluate on rank 0.

    The training indices are cut into chunks of
    ``len(train) // world_size`` entries and this process keeps chunk
    ``rank``. Training runs for 20 epochs; every fifth epoch rank 0 performs
    full inference and prints train/valid/test accuracy. The process group
    is destroyed before returning.

    Args:
        rank: Index of this process; also selects the XPU device.
        world_size: Total number of participating processes.
        dataset: Dataset providing ``get_idx_split()``, ``num_features``,
            ``num_classes`` and the graph at index 0.
    """
    device = f"xpu:{rank}"
    is_main = rank == 0

    # Keep only this rank's chunk of the training indices.
    split_idx = dataset.get_idx_split()
    all_train = split_idx["train"]
    chunk_len = all_train.size(0) // world_size
    split_idx["train"] = all_train.split(chunk_len, dim=0)[rank].clone()

    data = dataset[0].to(device, "x", "y")

    loader_opts = {"batch_size": 1024, "num_workers": 0, "pin_memory": True}
    train_loader = NeighborLoader(
        data,
        input_nodes=split_idx["train"],
        num_neighbors=[10, 10, 5],
        **loader_opts,
    )

    # Evaluation helpers exist on rank 0 only.
    if is_main:
        subgraph_loader = NeighborLoader(
            copy.copy(data),
            num_neighbors=[-1],
            **loader_opts,
        )
        evaluator = Evaluator(name="ogbn-products")

    torch.manual_seed(12345)
    net = GAT(dataset.num_features, 128, dataset.num_classes, num_layers=3,
              heads=4).to(device)
    model = DDP(net, device_ids=[device])
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(1, 21):
        model.train()
        epoch_stats = []
        for batch in train_loader:
            batch = batch.to(device)
            # The profiled step yields the loss together with its stats.
            loss, stats = train_step(
                model,
                optimizer,
                batch.x,
                batch.edge_index,
                batch.y,
                batch.batch_size,
            )
            epoch_stats.append(stats)

        # Wait for every rank to finish the epoch before reporting.
        dist.barrier()

        if is_main:
            print(f"Epoch: {epoch:02d}, Loss: {loss:.4f}")

        print(f"Epoch: {epoch:02d}, Rank: {rank}, "
              f"Stats: {get_stats_summary(epoch_stats)}")

        if is_main and epoch % 5 == 0:  # Evaluation on a single device
            model.eval()
            with torch.no_grad():
                out = model.module.inference(data.x, device, subgraph_loader)

            y_true = data.y.to(out.device)
            y_pred = out.argmax(dim=-1, keepdim=True)

            # Same three evaluator calls, in train / valid / test order.
            accs = {}
            for split in ("train", "valid", "test"):
                idx = split_idx[split]
                accs[split] = evaluator.eval({
                    "y_true": y_true[idx],
                    "y_pred": y_pred[idx],
                })["acc"]
            train_acc = accs["train"]
            val_acc = accs["valid"]
            test_acc = accs["test"]
            print(f"Train: {train_acc:.4f}, Val: {val_acc:.4f}, "
                  f"Test: {test_acc:.4f}")

        # Non-zero ranks wait here while rank 0 evaluates.
        dist.barrier()

    dist.destroy_process_group()
def get_dist_params() -> Tuple[int, int, str]:
    """Derive rank, world size and init method from the environment.

    ``MASTER_ADDR`` and ``MASTER_PORT`` are always set to ``127.0.0.1`` and
    ``29500``. If ``PMI_SIZE`` holds a positive value, ``PMI_RANK`` and
    ``PMI_SIZE`` are used; otherwise ``RANK`` (default 0) and ``WORLD_SIZE``
    (default 1) are used. The chosen values are written back to ``RANK`` and
    ``WORLD_SIZE`` as strings.

    Returns:
        A ``(rank, world_size, init_method)`` tuple in which ``rank`` and
        ``world_size`` are integers and ``init_method`` is a ``tcp://`` URL
        built from the master address and port.

    Raises:
        ValueError: If one of the consulted environment variables is set to
            something that ``int()`` cannot parse.
    """
    master_addr = "127.0.0.1"
    master_port = "29500"
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port

    mpi_rank = int(os.environ.get("PMI_RANK", -1))
    mpi_world_size = int(os.environ.get("PMI_SIZE", -1))

    if mpi_world_size > 0:
        rank, world_size = mpi_rank, mpi_world_size
    else:
        # Environment values are strings; convert them so that callers can
        # compare the rank with integers and use it as an index.
        rank = int(os.environ.get("RANK", 0))
        world_size = int(os.environ.get("WORLD_SIZE", 1))

    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    init_method = f"tcp://{master_addr}:{master_port}"
    return rank, world_size, init_method
if __name__ == "__main__":
    # Work out this process' place in the job, then join the process group
    # using the "ccl" backend.
    rank, world_size, init_method = get_dist_params()
    dist.init_process_group(
        backend="ccl",
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )

    # The dataset root is resolved relative to this script's real location.
    script_dir = osp.dirname(osp.realpath(__file__))
    path = osp.join(script_dir, "../../data", "ogbn-products")
    dataset = PygNodePropPredDataset("ogbn-products", path)

    run(rank, world_size, dataset)