Skip to content

Commit

Permalink
pad fusion tests (#4570)
Browse files Browse the repository at this point in the history
* what breaks

* Revert "what breaks"

This reverts commit e79f679.

* simplest case

* one unsafe op

* expand+pad, shrink+pad

* safe case

* refactor
  • Loading branch information
Qazalin committed May 14, 2024
1 parent 7afca52 commit 355e1c1
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions test/test_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
# NOTE: this has overlap with external_test_opt.py

import unittest
import numpy as np
from typing import List, Optional, Union
from tinygrad.engine.realize import run_schedule
from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, LoadOps, ReduceOps
from tinygrad.helpers import DEBUG, flatten
Expand Down Expand Up @@ -775,5 +777,35 @@ def test_partial_fuse4(self):
f = (b - d).sum() - e
check_schedule([c, d, e, f], 3)

def test_pad_reduce_safe(self):
Tensor.manual_seed(0)
a = Tensor.rand(3, 4, 5).realize()
b = Tensor.rand(3, 4, 5).realize()
out = (a + b).pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum())

def test_pad_reduce_usafe(self):
Tensor.manual_seed(0)
a = Tensor.rand(3, 4, 5).realize()
out = a.log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum())

def test_shrink_pad_safe(self):
a = Tensor.ones((3, )).contiguous().realize()
b = Tensor.ones((3, )).contiguous().realize()
out = (a + b).shrink(((0, 1),)).pad(((0, 1),)).contiguous()
run_schedule(check_schedule(out, 1))
np.testing.assert_equal(out.numpy(), [2, 0])

# TODO: should not shuffle unsafe pad ops through any pads, even if buffer is shrunk overall (#3437)
def test_shrink_pad_unsafe(self):
a = Tensor.ones((3, )).contiguous().realize()
out = a.exp2().shrink(((0, 1),)).pad(((0, 1),)).contiguous()
run_schedule(check_schedule(out, 1))
with self.assertRaises(AssertionError):
np.testing.assert_equal(out.numpy(), [2, 0])

if __name__ == '__main__':
unittest.main(verbosity=2)

0 comments on commit 355e1c1

Please sign in to comment.