In [20]:
from __future__ import annotations

import numpy as np

from typing import Tuple, Union, List, Optional, Callable



In [145]:
class Tensor:
    def __init__(self, data, device: str = "cpu", requires_grad: bool = False):
        self._data = data
        self._device = device
        self._requires_grad = requires_grad
        self._grad = None
        self._ctx = None

    def __repr__(self) -> str:
        return f"tensor({self._data}, shape={self.shape}, device={self._device}, ctx={self._ctx})"

    def __add__(self, other: Tensor) -> Tensor:
        return Add.apply(self, other)

    def __matmul__(self, other: Tensor) -> Tensor:
        return Matmul.apply(self, other)

    def __neg__(self) -> Tensor:
        return -1 * self

    @property
    def shape(self) -> Tuple[int, ...]:
        return self._data.shape
    
    @property
    def requires_grad(self) -> bool:
        return self._requires_grad

    def backward(self) -> None:
        """"""
        if self._ctx is None: return

        if self._grad is None:
            assert np.prod(self.shape) == 1, \
                "You are trying to do backward pass on unreduced tensor, this is not good..."

        order = []
        visited = set()
        def _topological_sort(v):
            if v not in visited:
                visited.add(v)
                if v._ctx is not None:
                    for child in v._ctx._children:
                        _topological_sort(child)
                order.append(v)

        _topological_sort(self)
        self._grad = np.ones(self.shape)
        
        for t in reversed(order):
            assert t._grad is not None, "need gradient!"
            children = t._ctx._children
            grads = t._ctx.grad_fn(t._grad)
            print("="*50)
            print(t)
            print(grads)
            if not isinstance(grads, (tuple, list)):
                grads = list(grads)
            if not isinstance(children, (tuple, list)):
                children = list(children)
            for grad, child in zip(grads, children):
                print(grad, child)
                child._grad = grad
                


In [146]:
class Function:
    _saved_tensors = None

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}>"
    
    def save_for_backward(self, *tensors):
        self._saved_tensors = list(tensors)

    @property
    def saved_tensors(self) -> Optional[List[Tensor]]:
        return self._saved_tensors
        
    @classmethod
    def apply(cls, *tensors, **kwargs) -> Union[Tensor, List[Tensor]]:
        func = cls()
        ctx = FunctionCtx(func, *tensors)
        result = Tensor(func.forward(*[t._data for t in tensors], **kwargs), requires_grad=ctx._requires_grad)
        result._ctx = ctx
        return result


class ReLU(Function):
    def forward(self, x):
        self.save_for_backward(x)
        return np.maximum(x, 0.0)

    def backward(self, grad):
        x, = self.saved_tensors
        return grad * (x >= 0.0)


class Add(Function):
    def forward(self, a, b):
        return a + b

    def backward(self, grad):
        return grad, grad


class Matmul(Function):
    def forward(self, a, b):
        self.save_for_backward(a, b)
        return a @ b

    def backward(self, grad):
        a, b, = self.saved_tensors
        print(a.shape, b.shape, grad.shape)
        return b @ grad, a @ grad


class Mean(Function):
    def forward(self, a, dim=None, keepdims=True):
        res = a.sum(axis=dim, keepdims=keepdims)
        self.save_for_backward(a, res)
        return res * np.prod(res.shape) / np.prod(a.shape)

    def backward(self, grad):
        a, res, = self.saved_tensors
        return np.ones(a.shape) * grad * np.prod(res.shape) / np.prod(a.shape)


class FunctionCtx:
    def __init__(self, func: Function, *tensors: Tuple[Tensor]):
        self._children = list(tensors)
        self._func = func
        self._requires_grad = any(t.requires_grad for t in tensors)

    def __repr__(self) -> str:
        return self._func.__repr__()

    def __str__(self) -> str:
        return self._func.__str__()

    @property
    def grad_fn(self) -> Callable:
        return self._func.backward
    


In [147]:
a = Tensor(np.ones((8, 10)), requires_grad=False)
b = Tensor(np.random.normal(0, 0.1, size=(10, 2)), requires_grad=True)

c = Mean.apply(a @ b, dim=None, keepdims=True)
c


tensor([[-0.35406435]], shape=(1, 1), device=cpu, ctx=<Mean>)

In [148]:
c.backward()

tensor([[-0.35406435]], shape=(1, 1), device=cpu, ctx=<Mean>)
[[0.0625 0.0625]
 [0.0625 0.0625]
 [0.0625 0.0625]
 [0.0625 0.0625]
 [0.0625 0.0625]
 [0.0625 0.0625]
 [0.0625 0.0625]
 [0.0625 0.0625]]
[0.0625 0.0625] tensor([[-0.37839131 -0.3297374 ]
 [-0.37839131 -0.3297374 ]
 [-0.37839131 -0.3297374 ]
 [-0.37839131 -0.3297374 ]
 [-0.37839131 -0.3297374 ]
 [-0.37839131 -0.3297374 ]
 [-0.37839131 -0.3297374 ]
 [-0.37839131 -0.3297374 ]], shape=(8, 2), device=cpu, ctx=<Matmul>)
(8, 10) (10, 2) (2,)


ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 2 is different from 10)

In [144]:
a._grad, b._grad, c._grad

(None, None, array([[1.]]))