Skip to content
This repository was archived by the owner on Nov 15, 2022. It is now read-only.

Commit 221f332

Browse files
authored
C++ _jit_tensorwise - first prototype of C++ implementation of tensorwise (#22)
1 parent 3d2c3a8 commit 221f332

File tree

12 files changed

+350
-80
lines changed

12 files changed

+350
-80
lines changed

benchmarks/jit_tensorwise.py

Lines changed: 68 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,68 @@
1+
import torch
2+
import nestedtensor
3+
import utils
4+
import time
5+
6+
7+
# f: TorchScript-compiled 2-D convolution (torch.conv2d(i, w)), additionally
# wrapped by nestedtensor._C.jit_tensorwise().
# NOTE(review): presumably jit_tensorwise lets the scripted function be applied
# to each constituent tensor of a nested tensor -- confirm against the C++ binding.
# NOTE(review): f is never called in this file; the benchmark driver below builds
# its own wrapper (`fc`) directly from torch.conv2d instead.
# The bare `N+` lines are diff-gutter residue from page extraction, and the
# original indentation of the body has been stripped.
@nestedtensor._C.jit_tensorwise()
8+
@torch.jit.script
9+
def f(i, w):
10+
return torch.conv2d(i, w)
11+
12+
# loop_f: plain Python baseline for the benchmark. Iterates over `inp1` and calls
# torch.conv2d(inp, w) on every element with the same weight `w`.
# The convolution results are discarded and nothing is returned; the function
# exists only so the driver can time the per-tensor loop.
# The bare `N+` lines are diff-gutter residue; body indentation has been stripped.
def loop_f(inp1, w):
13+
for inp in inp1:
14+
torch.conv2d(inp, w)
15+
16+
17+
# Benchmark driver: compares the jit_tensorwise-wrapped torch.conv2d on a nested
# tensor against the plain Python loop in loop_f. Requires a CUDA device, since
# the weight and the inputs are moved with .cuda().
# The bare `N+` lines are diff-gutter residue and indentation has been stripped,
# so which statements sit inside each `while` body is inferred -- TODO confirm.
if __name__ == "__main__":
18+
# Convolution weight of shape (64, 3, 9, 9), created on CPU then moved to CUDA.
w = torch.randn(64, 3, 9, 9).cuda()
19+
# unbind() splits the (128, 1, 3, 16, 16) tensor along dim 0 into 128 tensors.
inp1 = list(torch.randn(128, 1, 3, 16, 16).cuda().unbind())
20+
# Same tensors as a nested tensor; `_impl` is what gets passed to the wrapped op.
# NOTE(review): presumably `_impl` is the underlying C++ nested tensor -- confirm.
inp3 = nestedtensor.as_nested_tensor(inp1)._impl
21+
# print(sum(inp.numel() for inp in inp1))
22+
# print(inp3.numel())
23+
24+
# Wrap torch.conv2d directly (the scripted `f` defined above is not used here).
fc = nestedtensor._C.jit_tensorwise()(torch.conv2d)
25+
26+
# Timed section 1: call fc repeatedly while less than 5.0 s of wall-clock time
# has elapsed, counting calls.
t0 = time.time()
27+
count = 0
28+
while(time.time() - t0 < 5.0):
29+
# r2 is assigned but never read afterwards.
r2 = fc(inp3, w)
30+
# Wait for queued CUDA work so the count reflects completed calls.
torch.cuda.synchronize()
31+
count += 1
32+
print("jit: " + str(count))
33+
34+
# Timed section 2: same 5.0 s window for the plain for-loop baseline.
t0 = time.time()
35+
count = 0
36+
while(time.time() - t0 < 5.0):
37+
loop_f(inp1, w)
38+
torch.cuda.synchronize()
39+
count += 1
40+
print("for loop: " + str(count))
41+
42+
43+
# print(r.nested_size())
44+
45+
# na = nestedtensor._C.jit_tensorwise()(torch.mul)
46+
47+
# print("111")
48+
# out = nestedtensor.as_nested_tensor([torch.randn(1, 2)])
49+
# print(na(
50+
# nestedtensor.as_nested_tensor([torch.randn(1, 2)])._impl,
51+
# 4.0,
52+
# ))
53+
# print("222")
54+
# print('out')
55+
# print(out)
56+
57+
# nv = nestedtensor._C.jit_tensorwise()(torch.mv)
58+
# print(nv(
59+
# nestedtensor._C._ListNestedTensor([torch.randn(1, 2)]),
60+
# nestedtensor._C._ListNestedTensor([torch.randn(2)]),
61+
# ))
62+
63+
# print("333")
64+
# print(na(
65+
# torch.randn(1, 2),
66+
# torch.randn(1, 2),
67+
# ))
68+
# print("444")

benchmarks/nearest_neighbors.py

Lines changed: 13 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,11 @@
1-
from nestedtensor import torch
21
import nestedtensor
2+
import torch
33
import argparse
44
import time
55
import random
66
import pprint
77

8-
EMBED_DIM = 1024
8+
EMBED_DIM = 128
99

1010
SEED = 0
1111

@@ -60,8 +60,8 @@ def gen_algorithm_nested_mv(keys, sub_clusters):
6060
for sub_cluster in sub_clusters:
6161
new_sub_cluster = [torch.tensor(list(map(list, cluster))) for cluster in sub_cluster]
6262
new_sub_clusters.append(new_sub_cluster)
63-
nested_sub_clusters = torch.nested_tensor(sub_clusters).to_tensor(2)
64-
nested_keys = torch.nested_tensor(keys)
63+
nested_sub_clusters = nestedtensor.nested_tensor(sub_clusters).to_tensor(2)
64+
nested_keys = nestedtensor.nested_tensor(keys)
6565
def _nested_mv():
6666
return torch.mv(nested_sub_clusters, nested_keys)
6767
return _nested_mv
@@ -74,18 +74,16 @@ def gen_algorithm_nested_jit_mv(keys, sub_clusters):
7474
for cluster in sub_cluster:
7575
new_sub_cluster.append(torch.stack(cluster))
7676
new_sub_clusters.append(new_sub_cluster)
77-
nested_sub_clusters = nestedtensor._ListNestedTensor(new_sub_clusters)
78-
print("HERE")
79-
print(nested_sub_clusters.nested_size())
80-
nested_keys = nestedtensor._ListNestedTensor(keys)
81-
print(nested_keys.nested_size())
77+
nested_sub_clusters = nestedtensor.as_nested_tensor(new_sub_clusters)
78+
nested_keys = nestedtensor.as_nested_tensor(keys)
8279

80+
@nestedtensor._C.jit_tensorwise()
8381
@torch.jit.script
8482
def my_fun(x, y):
8583
return torch.mv(x, y)
8684

8785
def _nested_jit_mv():
88-
return nestedtensor._C.jit_apply_function((nested_sub_clusters, nested_keys), my_fun)
86+
return my_fun(nested_sub_clusters, nested_keys)
8987
return _nested_jit_mv
9088

9189

@@ -139,12 +137,12 @@ def benchmark_fn(fn, run_time = 15.0):
139137
gen_results_naive = gen_algorithm_naive(keys, sub_clusters)
140138
gen_results_mv = gen_algorithm_mv(keys, sub_clusters)
141139
gen_results_nested_mv = gen_algorithm_nested_mv(keys, sub_clusters)
142-
gen_results_nested_jit_mv = gen_algorithm_nested_jit_mv(keys, sub_clusters)
140+
# gen_results_nested_jit_mv = gen_algorithm_nested_jit_mv(keys, sub_clusters)
143141

144-
# print(benchmark_fn(gen_results_naive))
145-
# print(benchmark_fn(gen_results_mv))
146-
# print(benchmark_fn(gen_results_nested_mv))
147-
print(benchmark_fn(gen_results_nested_jit_mv))
142+
print(benchmark_fn(gen_results_nested_mv))
143+
print(benchmark_fn(gen_results_naive))
144+
print(benchmark_fn(gen_results_mv))
145+
# print(benchmark_fn(gen_results_nested_jit_mv))
148146
# import cProfile, pstats, io
149147
# pr = cProfile.Profile()
150148
# pr.enable()

nestedtensor/csrc/buffer_nested_tensor.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -71,8 +71,8 @@ std::pair<int64_t, TensorNode> _build_structure(
7171
for (size_t i = 0; i < nested_size.degree(); i++) {
7272
std::pair<int64_t, TensorNode> result_i = _build_structure(
7373
index, buffers, nested_size.children(i), nested_stride.children(i));
74+
index = std::get<0>(result_i);
7475
result.push_back(std::get<1>(result_i));
75-
index++;
7676
}
7777
return std::pair<int64_t, TensorNode>(index, TensorNode(result));
7878
}

nestedtensor/csrc/buffer_nested_tensor.h

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -125,8 +125,11 @@ struct TORCH_API _BufferNestedTensor {
125125
new_size.push_back(start->degree());
126126
start = start->children_data(0);
127127
}
128-
for (size_t i = 0; i < start->payload(0).size(); i++) {
129-
new_size.push_back(start->payload(0)[i]);
128+
new_size.push_back(start->size());
129+
if (start->size() > 0) {
130+
for (size_t i = 0; i < start->payload(0).size(); i++) {
131+
new_size.push_back(start->payload(0)[i]);
132+
}
130133
}
131134
return _buffer.reshape(at::IntArrayRef(new_size));
132135
}

0 commit comments

Comments
 (0)