change Tensor.stack to method
chenyuxyz committed May 24, 2024
1 parent ba116ff · commit dd9dd67
Showing 15 changed files with 40 additions and 42 deletions.
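The change in a nutshell: `Tensor.stack` is no longer a static method over a sequence; it becomes an instance method taking varargs (`def stack(self:Tensor, *args:Tensor, dim:int=0)`), so `dim` must now be passed by keyword, and call sites that hold a list splat it. A minimal sketch of the call-style migration, assuming tinygrad at this commit (variable names are illustrative):

```python
from tinygrad import Tensor

a, b, c = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])

# before this commit: static method taking a sequence
#   s = Tensor.stack([a, b, c], dim=0)

# after this commit: varargs; the first positional argument binds to self
s1 = Tensor.stack(a, b, c, dim=0)  # shape (3, 2)
s2 = a.stack(b, c, dim=0)          # equivalent method-call form

# call sites that already hold a list splat it, as done throughout this diff
tensors = [a, b, c]
s3 = Tensor.stack(*tensors, dim=0)

assert s1.shape == s2.shape == s3.shape == (3, 2)
```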
2 changes: 1 addition & 1 deletion examples/gpt2.py
@@ -35,7 +35,7 @@ def __call__(self, x:Tensor, start_pos:Variable, mask:Optional[Tensor]) -> Tensor:
self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, self.n_heads, self.head_dim, dtype=x.dtype).contiguous().realize()

# update the cache
- self.cache_kv.shrink((None, None,(start_pos,start_pos+seqlen),None,None)).assign(Tensor.stack([xk, xv])).realize()
+ self.cache_kv.shrink((None, None,(start_pos,start_pos+seqlen),None,None)).assign(Tensor.stack(xk, xv)).realize()

if start_pos > 0:
keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None))
4 changes: 2 additions & 2 deletions examples/llama3.py
@@ -92,7 +92,7 @@ def NF4Linear(block_size):
-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
]
- CODE = Tensor.stack([Tensor(c) for c in _CODE])
+ CODE = Tensor.stack(*[Tensor(c) for c in _CODE])
class _NF4Linear:
def __init__(self, in_features, out_features, bias=False):
assert not bias, "bias not supported"
@@ -103,7 +103,7 @@ def __init__(self, in_features, out_features, bias=False):
def __call__(self, x: Tensor) -> Tensor:
high_bits = self.weight
low_bits = (self.weight * 2 ** 4).contiguous()
- unpacked = Tensor.stack([high_bits, low_bits], dim=-1).div(2 ** 4, upcast=False)
+ unpacked = Tensor.stack(high_bits, low_bits, dim=-1).div(2 ** 4, upcast=False)
unscaled = CODE[unpacked].to(x.device).reshape(-1, block_size) * self.scale
return x.linear(unscaled.reshape(self.out_features, self.in_features).T)

2 changes: 1 addition & 1 deletion examples/mamba.py
@@ -85,7 +85,7 @@ def selective_scan_ref(
if i == u.shape[2] - 1:
last_state = x
ys.append(y)
- y = Tensor.stack(ys, dim=2) # (batch dim L)
+ y = Tensor.stack(*ys, dim=2) # (batch dim L)
out = y if D is None else y + u * D.reshape((-1, 1))
if z is not None:
out = out * z.silu()
2 changes: 1 addition & 1 deletion examples/mask_rcnn.py
@@ -151,7 +151,7 @@ def forward_single_image(self, masks, boxes):
for mask, box in zip(masks, boxes.bbox)
]
if len(res) > 0:
- res = torch.stack(res, dim=0)[:, None]
+ res = torch.stack(*res, dim=0)[:, None]
else:
res = masks.new_empty((0, 1, masks.shape[-2], masks.shape[-1]))
return Tensor(res.numpy())
2 changes: 1 addition & 1 deletion examples/so_vits_svc.py
@@ -468,7 +468,7 @@ def repeat_expand_2d_left(content, target_len): # content : [h, t]
if i >= temp[current_pos+1]:
current_pos += 1
cols.append(content[:, current_pos])
- return Tensor.stack(cols).transpose(0, 1)
+ return Tensor.stack(*cols).transpose(0, 1)

def load_fairseq_cfg(checkpoint_path):
assert Path(checkpoint_path).is_file()
4 changes: 2 additions & 2 deletions examples/yolov8.py
@@ -34,7 +34,7 @@ def preprocess(im, imgsz=640, model_stride=32, model_pt=True):
same_shapes = all(x.shape == im[0].shape for x in im)
auto = same_shapes and model_pt
im = Tensor([compute_transform(x, new_shape=imgsz, auto=auto, stride=model_stride) for x in im])
- im = Tensor.stack(im) if im.shape[0] > 1 else im
+ im = Tensor.stack(*im) if im.shape[0] > 1 else im
im = im[..., ::-1].permute(0, 3, 1, 2) # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
im /= 255 # 0 - 255 to 0.0 - 1.0
return im
@@ -180,7 +180,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
sx = sx.reshape(1, -1).repeat([h, 1]).reshape(-1)
sy = sy.reshape(-1, 1).repeat([1, w]).reshape(-1)

- anchor_points.append(Tensor.stack((sx, sy), -1).reshape(-1, 2))
+ anchor_points.append(Tensor.stack(sx, sy, dim=-1).reshape(-1, 2))
stride_tensor.append(Tensor.full((h * w), stride))
anchor_points = anchor_points[0].cat(anchor_points[1], anchor_points[2])
stride_tensor = stride_tensor[0].cat(stride_tensor[1], stride_tensor[2]).unsqueeze(1)
2 changes: 1 addition & 1 deletion extra/models/bert.py
@@ -37,7 +37,7 @@ def __call__(self, input_ids:Tensor, attention_mask:Tensor, token_type_ids:Tensor):
start_logits = start_logits.reshape(-1, 1)
end_logits = end_logits.reshape(-1, 1)

- return Tensor.stack([start_logits, end_logits])
+ return Tensor.stack(start_logits, end_logits)

class BertForMLPerf:
def __init__(self, hidden_size:int, intermediate_size:int, max_position_embeddings:int, num_attention_heads:int, num_hidden_layers:int, type_vocab_size:int, vocab_size:int, attention_probs_dropout_prob:float, hidden_dropout_prob:float) -> None:
4 changes: 2 additions & 2 deletions extra/models/llama.py
@@ -6,7 +6,7 @@
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
- return Tensor.stack([freqs.cos().half(), freqs.sin().half()], dim=-1).reshape(1, end, 1, dim//2, 2)
+ return Tensor.stack(freqs.cos().half(), freqs.sin().half(), dim=-1).reshape(1, end, 1, dim//2, 2)

# (a+i*b) * (c+i*d) = (ac-bd) + i*(ad+bc)
def complex_mult(A, c, d):
@@ -72,7 +72,7 @@ def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]) -> Tensor:

# update the cache
assert xk.dtype == xv.dtype == self.cache_kv.dtype, f"{xk.dtype=}, {xv.dtype=}, {self.cache_kv.dtype=}"
- self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack([xk, xv])).realize()
+ self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None, None)).assign(Tensor.stack(xk, xv)).realize()

keys = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xk
values = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None, None)) if start_pos > 0 else xv
14 changes: 7 additions & 7 deletions extra/models/mask_rcnn.py
@@ -66,7 +66,7 @@ def get_strides(shape):
# with keys as integer array for all axes
def tensor_getitem(tensor, *keys):
# something about ints is broken with gpu, cuda
- flat_keys = Tensor.stack([key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1).cast(dtypes.int32)
+ flat_keys = Tensor.stack(*[key.expand((sum(keys)).shape).reshape(-1) for key in keys], dim=1).cast(dtypes.int32)
strides = get_strides(tensor.shape)
idxs = (flat_keys * strides).sum(1)
gatherer = npgather if USE_NP_GATHER else _gather
@@ -255,7 +255,7 @@ def clip_to_image(self, remove_empty=True):
bb2 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 1]
bb3 = self.bbox.clip(min_=0, max_=self.size[0] - TO_REMOVE)[:, 2]
bb4 = self.bbox.clip(min_=0, max_=self.size[1] - TO_REMOVE)[:, 3]
- self.bbox = Tensor.stack((bb1, bb2, bb3, bb4), dim=1)
+ self.bbox = Tensor.stack(bb1, bb2, bb3, bb4, dim=1)
if remove_empty:
box = self.bbox
keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
@@ -394,7 +394,7 @@ def grid_anchors(self, grid_sizes):
shift_y, shift_x = meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
- shifts = Tensor.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
+ shifts = Tensor.stack(shift_x, shift_y, shift_x, shift_y, dim=1)

anchors.append(
(shifts.reshape(-1, 1, 4) + base_anchors.reshape(1, -1, 4)).reshape(-1, 4)
@@ -525,7 +525,7 @@ def encode(self, reference_boxes, proposals):
targets_dw = ww * Tensor.log(gt_widths / ex_widths)
targets_dh = wh * Tensor.log(gt_heights / ex_heights)

- targets = Tensor.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
+ targets = Tensor.stack(targets_dx, targets_dy, targets_dw, targets_dh, dim=1)
return targets

def decode(self, rel_codes, boxes):
@@ -556,7 +556,7 @@ def decode(self, rel_codes, boxes):
y = pred_ctr_y - 0.5 * pred_h
w = pred_ctr_x + 0.5 * pred_w - 1
h = pred_ctr_y + 0.5 * pred_h - 1
- pred_boxes = Tensor.stack([x, y, w, h]).permute(1,2,0).reshape(rel_codes.shape[0], rel_codes.shape[1])
+ pred_boxes = Tensor.stack(x, y, w, h).permute(1,2,0).reshape(rel_codes.shape[0], rel_codes.shape[1])
return pred_boxes


@@ -632,8 +632,8 @@ def forward_for_single_feature_map(self, anchors, objectness, box_regression):
box_regression_list.append(tensor_gather(box_regression[batch_idx], topk_idx[batch_idx]))
concat_anchors_list.append(tensor_gather(concat_anchors[batch_idx], topk_idx[batch_idx]))

- box_regression = Tensor.stack(box_regression_list)
- concat_anchors = Tensor.stack(concat_anchors_list)
+ box_regression = Tensor.stack(*box_regression_list)
+ concat_anchors = Tensor.stack(*concat_anchors_list)

proposals = self.box_coder.decode(
box_regression.reshape(-1, 4), concat_anchors.reshape(-1, 4)
2 changes: 1 addition & 1 deletion extra/models/rnnt.py
@@ -145,7 +145,7 @@ def do_step(self, x, hc):
new_hc = [x]
for i, cell in enumerate(self.cells):
new_hc.append(cell(new_hc[i][:x.shape[0]], hc[i]))
- return Tensor.stack(new_hc[1:]).realize()
+ return Tensor.stack(*new_hc[1:]).realize()


class StackTime:
2 changes: 1 addition & 1 deletion extra/onnx_ops.py
@@ -587,7 +587,7 @@ def AffineGrid(theta: Tensor, size: Tensor, align_corners=0):
if dim == 0: stackable = [a.reshape(dim_sz, *[1]*(len(data_sz)-1)) + size_zeros, *stackable]
elif dim == 1: stackable = [a.reshape(1, dim_sz, *[1]*(len(data_sz)-2)) + size_zeros, *stackable]
else: stackable = [a.reshape(1, dim_sz) + size_zeros, *stackable]
- original_grid = Tensor.stack(stackable, dim=len(data_sz))
+ original_grid = Tensor.stack(*stackable, dim=len(data_sz))
if original_grid.ndim == 3:
N, dim_2d, dim_homo = theta.shape
assert dim_2d == 2 and dim_homo == 3
14 changes: 7 additions & 7 deletions test/imported/test_indexing.py
@@ -131,13 +131,13 @@ def test_index(self):

# indexing with step
reference = consec((10, 10, 10))
- numpy_testing_assert_equal_helper(reference[1:5:2], Tensor.stack([reference[1], reference[3]], 0))
- numpy_testing_assert_equal_helper(reference[1:6:2], Tensor.stack([reference[1], reference[3], reference[5]], 0))
- numpy_testing_assert_equal_helper(reference[1:9:4], Tensor.stack([reference[1], reference[5]], 0))
- numpy_testing_assert_equal_helper(reference[2:4, 1:5:2], Tensor.stack([reference[2:4, 1], reference[2:4, 3]], 1))
- numpy_testing_assert_equal_helper(reference[3, 1:6:2], Tensor.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0))
- numpy_testing_assert_equal_helper(reference[None, 2, 1:9:4], Tensor.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0))
- numpy_testing_assert_equal_helper(reference[:, 2, 1:6:2], Tensor.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1))
+ numpy_testing_assert_equal_helper(reference[1:5:2], Tensor.stack(reference[1], reference[3], dim=0))
+ numpy_testing_assert_equal_helper(reference[1:6:2], Tensor.stack(reference[1], reference[3], reference[5], dim=0))
+ numpy_testing_assert_equal_helper(reference[1:9:4], Tensor.stack(reference[1], reference[5], dim=0))
+ numpy_testing_assert_equal_helper(reference[2:4, 1:5:2], Tensor.stack(reference[2:4, 1], reference[2:4, 3], dim=1))
+ numpy_testing_assert_equal_helper(reference[3, 1:6:2], Tensor.stack(reference[3, 1], reference[3, 3], reference[3, 5], dim=0))
+ numpy_testing_assert_equal_helper(reference[None, 2, 1:9:4], Tensor.stack(reference[2, 1], reference[2, 5], dim=0).unsqueeze(0))
+ numpy_testing_assert_equal_helper(reference[:, 2, 1:6:2], Tensor.stack(reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5], dim=1))

lst = [list(range(i, i+10)) for i in range(0, 100, 10)]
tensor = Tensor(lst)
4 changes: 2 additions & 2 deletions test/test_linearizer.py
@@ -393,7 +393,7 @@ def test_upcast_multireduce_nested_local_upcast(self):

def test_zero_fold(self):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
- r = Tensor.stack([a, b])
+ r = Tensor.stack(a, b)

k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
@@ -890,7 +890,7 @@ def test_masked_upcast(self):
assert k.upcasted == 1 and k.full_shape[-1] == 7

def test_masked_upcast_wino(self):
- monster = Tensor.stack([Tensor.stack([Tensor.rand(16) for _ in range(6)]) for _ in range(6)])
+ monster = Tensor.stack(*[Tensor.stack(*[Tensor.rand(16) for _ in range(6)]) for _ in range(6)])

s = create_schedule([monster.lazydata])[-1]
k = Linearizer(*s.ast)
12 changes: 6 additions & 6 deletions test/test_ops.py
@@ -1062,9 +1062,9 @@ def test_pad_slice(self):
lambda x: x.pad(((3,2),(0,1),(1,1)), value=value)[2:4,:,:])

def test_stack_slice(self):
- helper_test_op([(4)], lambda x: torch.stack([x for i in range(3)])[0,:], lambda x: Tensor.stack([x for i in range(3)])[0,:])
- helper_test_op([(5)], lambda x: torch.stack([x for i in range(3)])[0,0], lambda x: Tensor.stack([x for i in range(3)])[0,0])
- helper_test_op([(4,4)], lambda x: torch.stack([x for i in range(4)])[3], lambda x: Tensor.stack([x for i in range(4)])[3])
+ helper_test_op([(4)], lambda x: torch.stack([x for i in range(3)])[0,:], lambda x: Tensor.stack(*[x for i in range(3)])[0,:])
+ helper_test_op([(5)], lambda x: torch.stack([x for i in range(3)])[0,0], lambda x: Tensor.stack(*[x for i in range(3)])[0,0])
+ helper_test_op([(4,4)], lambda x: torch.stack([x for i in range(4)])[3], lambda x: Tensor.stack(*[x for i in range(4)])[3])

def test_transpose(self):
helper_test_op([(3,3)], lambda x: x.T)
@@ -1554,13 +1554,13 @@ def test_multicat(self):

def test_stack(self):
for dim in range(-1, 3):
- helper_test_op([(45,65,3), (45,65,3), (45,65,3)], lambda x, y, z: torch.stack((x, y, z), dim), lambda x, y, z: Tensor.stack([x, y, z], dim))
+ helper_test_op([(45,65,3), (45,65,3), (45,65,3)], lambda x, y, z: torch.stack((x, y, z), dim), lambda x, y, z: Tensor.stack(x, y, z, dim=dim))

with self.assertRaises(IndexError):
- Tensor.stack([Tensor.randn(45, 65, 3)], dim=77)
+ Tensor.stack(Tensor.randn(45, 65, 3), dim=77)

a = Tensor(3.14)
- np.testing.assert_allclose(Tensor.stack([a, a]).numpy(), Tensor([3.14, 3.14]).numpy())
+ np.testing.assert_allclose(Tensor.stack(a, a).numpy(), Tensor([3.14, 3.14]).numpy())

def test_repeat(self):
x = Tensor.randn(4, 6, 3)
12 changes: 5 additions & 7 deletions tinygrad/tensor.py
@@ -977,22 +977,20 @@ def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d)
return functools.reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])

- @staticmethod
- def stack(tensors:Sequence[Tensor], dim:int=0) -> Tensor:
+ def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
"""
- Concatenates a sequence of tensors along a new dimension.
+ Concatenates self with other `Tensor` in `args` along a new dimension specified by `dim`.
```python exec="true" source="above" session="tensor" result="python"
t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
- print(Tensor.stack([t0, t1, t2], dim=0).numpy())
+ print(t0.stack(t1, t2, dim=0).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
- print(Tensor.stack([t0, t1, t2], dim=1).numpy())
+ print(t0.stack([t1, t2], dim=1).numpy())
```
"""
- unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors]
# checks for shapes and number of dimensions delegated to cat
- return unsqueezed_tensors[0].cat(*unsqueezed_tensors[1:], dim=dim)
+ return self.unsqueeze(dim).cat(*[t.unsqueeze(dim) for t in args], dim=dim)

def repeat(self, repeats, *args) -> Tensor:
"""
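For reference, the new one-line body makes the semantics explicit: stacking is unsqueezing every operand along `dim`, then concatenating. A small sketch of that equivalence, reusing the tensors from the docstring above:

```python
from tinygrad import Tensor

t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])

stacked = t0.stack(t1, t2, dim=1)  # shape (2, 3): [[1, 3, 5], [2, 4, 6]]

# identical result via the primitives the new implementation delegates to
catted = t0.unsqueeze(1).cat(t1.unsqueeze(1), t2.unsqueeze(1), dim=1)

assert stacked.shape == catted.shape == (2, 3)
```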
