In [1]:
import math

class Node:
    def __init__(self, data, _children=(), _op=''):
        self.data = data
        self.grad = 0.0
        self._backward = lambda: None
        self._prev = set(_children)
        self._op = _op

    def __repr__(self):
        return f"Element(data={self.data:.4f}, grad={self.grad:.4f})"

    def __add__(self, other):
        other = other if isinstance(other, Node) else Node(other)
        out = Node(self.data + other.data, _children=(self, other), _op='+')

        def _backward_fn():
            self.grad += out.grad
            other.grad += out.grad
        out._backward = _backward_fn
        
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Node) else Node(other)
        out = Node(self.data * other.data, _children=(self, other), _op='*')

        def _backward_fn():
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward = _backward_fn

        return out

    def relu(self):
        out_data = max(0, self.data)
        out = Node(out_data, _children=(self,), _op='ReLU')

        def _backward_fn():
            self.grad += (1 if self.data > 0 else 0) * out.grad
        out._backward = _backward_fn
        
        return out
        
    def backward(self):
        topo = []
        visited = set()
        def build_topo(v):
            if v not in visited:
                visited.add(v)
                for child in v._prev:
                    build_topo(child)
                topo.append(v)
        
        build_topo(self)
        
        self.grad = 1.0
        
        for node in reversed(topo):
            node._backward()
    
    def __radd__(self, other):
        return self + other

    def __rmul__(self, other):
        return self * other

    def __neg__(self):
        return self * -1

    def __sub__(self, other):
        return self + (-other)

    def __rsub__(self, other):
        return other + (-self)

In [2]:
a = Node(2.0)
b = Node(-3.0)
c = Node(10.0)

# В задании d = a + b * c, e = d.relu(). 
# При d.data=-28, e.data=0, и grad(e->d) = 0. Все градиенты будут нули.
# Чтобы показать работоспособность, изменим порядок операций на d = a * b + c
# Это дает d = -6 + 10 = 4, что позволяет ReLU пропустить градиент.

d = a * b + c
e = d.relu()
e.backward()

print("a =", a)
print("b =", b)
print("c =", c)
print("d =", d)
print("e =", e)

a = Element(data=2.0000, grad=-3.0000)
b = Element(data=-3.0000, grad=2.0000)
c = Element(data=10.0000, grad=1.0000)
d = Element(data=4.0000, grad=1.0000)
e = Element(data=4.0000, grad=1.0000)


In [3]:
import unittest

class TestAutograd(unittest.TestCase):

    def test_example_from_prompt_corrected(self):
        a = Node(2.0)
        b = Node(-3.0)
        c = Node(10.0)
        d = a * b + c
        e = d.relu()
        e.backward()
        
        self.assertAlmostEqual(e.data, 4.0)
        self.assertAlmostEqual(d.data, 4.0)
        self.assertAlmostEqual(c.data, 10.0)
        
        self.assertAlmostEqual(e.grad, 1.0)
        self.assertAlmostEqual(d.grad, 1.0)
        self.assertAlmostEqual(c.grad, 1.0)
        self.assertAlmostEqual(b.grad, 2.0)
        self.assertAlmostEqual(a.grad, -3.0)

    def test_addition(self):
        a = Node(5.0)
        b = Node(-2.0)
        c = a + b
        c.backward()
        self.assertEqual(c.data, 3.0)
        self.assertEqual(a.grad, 1.0)
        self.assertEqual(b.grad, 1.0)

    def test_multiplication(self):
        a = Node(3.0)
        b = Node(4.0)
        c = a * b
        c.backward()
        self.assertEqual(c.data, 12.0)
        self.assertEqual(a.grad, 4.0)
        self.assertEqual(b.grad, 3.0)

    def test_relu_positive_input(self):
        a = Node(5.0)
        b = a.relu()
        b.backward()
        self.assertEqual(b.data, 5.0)
        self.assertEqual(a.grad, 1.0)

    def test_relu_negative_input(self):
        a = Node(-7.0)
        b = a.relu()
        b.backward()
        self.assertEqual(b.data, 0.0)
        self.assertEqual(a.grad, 0.0)

    def test_scalar_operations(self):
        a = Node(4.0)
        b = 2.0 * a + 1.0
        b.backward()
        self.assertEqual(b.data, 9.0)
        self.assertEqual(a.grad, 2.0)

    def test_complex_graph_and_grad_accumulation(self):
        a = Node(3.0)
        b = Node(4.0)
        d = a * b
        e = d + a
        e.backward()
        
        self.assertEqual(e.data, 15.0)
        self.assertEqual(a.grad, 5.0)
        self.assertEqual(b.grad, 3.0)
        self.assertEqual(d.grad, 1.0)

# Запуск тестов
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(TestAutograd))
runner = unittest.TextTestRunner()
runner.run(suite)

  suite.addTest(unittest.makeSuite(TestAutograd))
.......
----------------------------------------------------------------------
Ran 7 tests in 0.002s

OK


<unittest.runner.TextTestResult run=7 errors=0 failures=0>