In [None]:
class Function:
    """Base class of every differentiable operation.

    Calling an instance runs ``forward`` on the raw arrays of its inputs and
    wraps the results in new ``Variable`` objects.  When
    ``Config.enable_backprop`` is true it also records the links (generation,
    creator, inputs, outputs) that backpropagation later walks.  Subclasses
    override ``forward`` and ``backward``.
    """

    def __call__(self, *inputs: Union[np.ndarray, Variable]) -> Union[list[Variable], Variable]:
        # Accept plain arrays as well as Variables.
        inputs: list[Variable] = [as_variable(arg) for arg in inputs]

        results = self.forward(*[var.data for var in inputs])
        # A single-output forward may return a bare array; normalise to a tuple.
        if not isinstance(results, tuple):
            results = (results,)
        outputs = [Variable(as_array(arr)) for arr in results]

        if Config.enable_backprop:
            # The function's generation is the largest generation among its
            # inputs, so it is never ranked before any of them.  It is set
            # before set_creator, which may read it.
            self.generation = max([var.generation for var in inputs])
            for out in outputs:
                out.set_creator(self)  # link each output back to this function
            # Hold the outputs only through weak references: once nothing else
            # refers to an output it can be garbage-collected, instead of being
            # kept alive by a function <-> output reference cycle.
            self.outputs = [weakref.ref(out) for out in outputs]
            self.inputs = inputs

        if len(outputs) > 1:
            return outputs
        return outputs[0]

    def forward(self, *xs: np.ndarray) -> tuple[np.ndarray]:
        """Compute the output array(s) from raw input arrays; must be overridden."""
        raise NotImplementedError()

    def backward(self, gys: np.ndarray) -> Union[tuple[np.ndarray, ...], np.ndarray]:
        """Map output gradient(s) to input gradient(s); must be overridden."""
        raise NotImplementedError()

In [None]:
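# Concrete forward/backward implementations of individual functions. Only a
# selection of the less obvious ones is excerpted here; the backward methods
# are the parts worth reading closely.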
class Reshape(Function):
    """Reshape an array to ``shape``; the gradient is reshaped back."""

    def __init__(self, shape):
        self.shape = shape  # target shape of the forward pass

    def forward(self, x: np.ndarray) -> tuple[np.ndarray]:
        # Remember the input shape so backward can undo the reshape.
        self.x_shape = x.shape
        return x.reshape(self.shape)

    def backward(self, gy: np.ndarray) -> Union[tuple[Variable, ...], Variable]:
        # Reshape only re-arranges elements, so the gradient simply takes the
        # original input shape back.
        return reshape(gy, self.x_shape)


def reshape(x, shape):
    """Return ``x`` reshaped to ``shape`` as a Variable.

    When the shape already matches, no graph node is created and ``x`` is
    only wrapped with ``as_variable``.
    """
    if x.shape != shape:
        return Reshape(shape)(x)
    return as_variable(x)

class Transpose(Function):
    """Permute the axes of an array (reverse them when ``axes`` is None)."""

    def __init__(self, axes=None):
        self.axes = axes

    def forward(self, x):
        return x.transpose(self.axes)

    def backward(self, gy):
        # Reversing the axes is its own inverse, so reverse the gradient too.
        if self.axes is None:
            return transpose(gy)

        # Normalise negative axes, then invert the permutation with argsort.
        ndim = len(self.axes)
        normalised = [axis % ndim for axis in self.axes]
        inverse = tuple(np.argsort(normalised))
        return transpose(gy, inverse)

    # Worked example of the inverse permutation:
    #   original axis order: [0, 1, 2, 3]
    #   axes = [3, 1, 0, 2]  (axes_len = 4)
    #   np.argsort returns the indices that would sort the values ascending,
    #   so inv_axes = [2, 1, 3, 0]; transposing the gradient by inv_axes
    #   restores the layout of the forward-pass input.


def transpose(x, axes=None):
    """Permute the axes of ``x``; ``axes=None`` reverses them."""
    op = Transpose(axes)
    return op(x)


class BroadcastTo(Function):
    """Broadcast an array to ``shape``; the gradient is summed back."""

    def __init__(self, shape):
        self.shape = shape

    def forward(self, x: np.ndarray) -> Union[tuple[np.ndarray, ...], np.ndarray]:
        # Remember the pre-broadcast shape for backward.
        self.x_shape = x.shape
        return get_array_module(x).broadcast_to(x, self.shape)

    def backward(self, gy: np.ndarray) -> Union[tuple[Variable, ...], Variable]:
        # Every broadcast copy received its own gradient; summing them
        # collapses the broadcast dimensions back to the input shape.
        return sum_to(gy, self.x_shape)


def broadcast_to(x, shape):
    """Broadcast ``x`` to ``shape``.

    If ``x`` already has that shape, it is returned wrapped by
    ``as_variable`` without adding a graph node.
    """
    if x.shape != shape:
        return BroadcastTo(shape)(x)
    return as_variable(x)


class SumTo(Function):
    """Sum an array down to ``shape``; the gradient is broadcast back."""

    def __init__(self, shape):
        self.shape = shape

    def forward(self, x: np.ndarray) -> Union[tuple[np.ndarray, ...], np.ndarray]:
        # Remember the full input shape for backward.
        self.x_shape = x.shape
        return raw_sum_to(x, self.shape)

    def backward(self, gy: np.ndarray) -> Union[tuple[Variable, ...], Variable]:
        # Each summed element contributed with weight 1, so the gradient is
        # copied back over the dimensions that the sum removed.
        return broadcast_to(gy, self.x_shape)


def sum_to(x, shape):
    """Sum ``x`` down to ``shape``.

    If ``x`` already has that shape, it is returned wrapped by
    ``as_variable`` without adding a graph node.
    """
    if x.shape != shape:
        return SumTo(shape)(x)
    return as_variable(x)


class MatMul(Function):
    """Matrix multiplication ``x @ W``, batched for operands with more than 2 dims.

    For N-D operands the last two axes are the matrix axes and any leading
    axes are batch axes, which may broadcast against each other.
    """

    def forward(self, x: np.ndarray, W: np.ndarray) -> tuple[np.ndarray]:
        # For <= 2-D operands ``dot`` and ``@`` agree; for N-D operands only
        # ``@`` has batched-matmul semantics.
        if x.ndim <= 2 and W.ndim <= 2:
            y = x.dot(W)
        else:
            y = x @ W
        return y

    def backward(self, gy: np.ndarray) -> Union[tuple[Variable, ...], Variable]:
        x, W = self.inputs
        # Mathematically gx = gy · W^T and gW = x^T · gy.  Only the last two
        # axes are swapped because they are the matrix axes; the leading axes
        # are batch axes and must stay where they are.
        gx = matmul(gy, W.transpose(list(range(W.ndim - 2)) + [-1, -2]))
        gW = matmul(x.transpose(list(range(x.ndim - 2)) + [-1, -2]), gy)
        # If batch axes were broadcast in forward (e.g. x of shape (B, N, K)
        # times W of shape (K, M)), the products above still carry the
        # broadcast batch axes (gW would be (B, K, M)).  Sum them away so each
        # gradient matches its input's shape; sum_to is a no-op when the shapes
        # already agree.
        gx = sum_to(gx, x.shape)
        gW = sum_to(gW, W.shape)
        return gx, gW


def matmul(x, W):
    """Differentiable matrix product of ``x`` and ``W`` (batched for N-D inputs)."""
    op = MatMul()
    return op(x, W)


class MeanSquaredError(Function):
    """Squared error between x0 and x1, summed and divided by ``len(x0 - x1)``.

    Note: ``len`` is the size of the first axis, so for 2-D inputs this
    averages over rows rather than over all elements.
    """

    def forward(self, x0: np.ndarray, x1: np.ndarray) -> np.ndarray:
        residual = x0 - x1
        return (residual ** 2).sum() / len(residual)

    def backward(self, gy: np.ndarray) -> Union[tuple[Variable, ...], Variable]:
        left, right = self.inputs
        residual = left - right
        # d/dx0 of sum(d^2)/n is 2*d/n; the gradient w.r.t. x1 is its negation.
        grad_left = gy * residual * (2. / len(residual))
        return grad_left, -grad_left


def mean_squared_error(x, y):
    """Differentiable mean-squared-error loss between ``x`` and ``y``."""
    loss_fn = MeanSquaredError()
    return loss_fn(x, y)


class Linear(Function):
    """Affine transform ``x · W + b`` with an optional bias ``b``."""

    def forward(self, x: np.ndarray, W: np.ndarray, b: np.ndarray) -> tuple[np.ndarray]:
        y = x.dot(W)
        # b is None when linear() was called without a bias.
        if b is not None:
            y += b

        return y

    def backward(self, gy: np.ndarray) -> Union[tuple[Variable, ...], Variable]:
        x, W, b = self.inputs
        # Without a bias, b is a Variable whose data is None and gets no
        # gradient.  Otherwise the bias was broadcast in forward, so its
        # gradient is gy summed back to b's shape.
        gb = None if b.data is None else sum_to(gy, b.shape)
        # gx = gy · W^T, gW = x^T · gy (swap only the last two axes).
        gx = matmul(gy, W.transpose(list(range(W.ndim - 2)) + [-1, -2]))
        gW = matmul(x.transpose(list(range(x.ndim - 2)) + [-1, -2]), gy)
        # For x with extra leading axes, e.g. (B, N, K) against a (K, M)
        # weight, x^T · gy keeps the leading axis and is (B, K, M).  Sum such
        # axes away so each gradient has its input's shape (no-op for 2-D x).
        gx = sum_to(gx, x.shape)
        gW = sum_to(gW, W.shape)
        return gx, gW, gb


def linear(x, W, b=None):
    """Differentiable ``x · W + b``; pass ``b=None`` for no bias."""
    layer_op = Linear()
    return layer_op(x, W, b)


class Sigmoid(Function):
    """Logistic sigmoid y = 1 / (1 + exp(-x))."""

    def forward(self, x: np.ndarray) -> tuple[np.ndarray]:
        # Uses the identity sigmoid(x) = 0.5 * tanh(0.5 * x) + 0.5, which
        # avoids evaluating exp(-x) (that can overflow for large negative x).
        backend = get_array_module(x)
        return backend.tanh(x * 0.5) * 0.5 + 0.5

    def backward(self, gy: np.ndarray) -> Union[tuple[np.ndarray, ...], np.ndarray]:
        # dy/dx = y * (1 - y); the output is read back through its weakref.
        out = self.outputs[0]()
        return gy * out * (1 - out)


def sigmoid(x):
    """Differentiable element-wise logistic sigmoid."""
    op = Sigmoid()
    return op(x)


class GetItem(Function):
    """Index an array with ``slices`` (anything valid inside ``x[...]``)."""

    def __init__(self, slices):
        self.slices = slices

    def forward(self, x: np.ndarray) -> tuple[np.ndarray]:
        return x[self.slices]

    def backward(self, gy: np.ndarray) -> Union[tuple[Variable, ...], Variable]:
        # Scatter the gradient back into an array shaped like the input; this
        # is done by a separate Function so the gradient stays differentiable.
        source = self.inputs[0]
        scatter = GetItemGrad(self.slices, source.shape)
        return scatter(gy)


def get_item(x, slices):
    """Differentiable ``x[slices]``."""
    op = GetItem(slices)
    return op(x)


class GetItemGrad(Function):
    """Backward of GetItem: scatter-add ``gy`` into zeros of shape ``in_shape``.

    Implemented as its own Function so the gradient is itself differentiable;
    its backward is simply GetItem with the same ``slices``.
    """

    def __init__(self, slices, in_shape):
        self.slices = slices
        self.in_shape = in_shape

    def forward(self, gy: np.ndarray) -> Union[tuple[np.ndarray, ...], np.ndarray]:
        xp = get_array_module(gy)
        # Allocate with gy's dtype so e.g. float32 gradients are not silently
        # upcast to the default float64 of zeros().
        gx = xp.zeros(self.in_shape, dtype=gy.dtype)
        if xp is np:
            # np.add.at is unbuffered: an index that occurs several times in
            # ``slices`` accumulates every contribution instead of keeping only
            # the last one.
            np.add.at(gx, self.slices, gy)
        else:
            # Presumably the GPU backend's equivalent of np.add.at — confirm.
            xp.scatter_add(gx, self.slices, gy)
        return gx

    def backward(self, ggx: np.ndarray) -> Union[tuple[Variable, ...], Variable]:
        return get_item(ggx, self.slices)


class Softmax(Function):
    """Softmax along ``axis`` (default 1)."""

    def __init__(self, axis=1):
        self.axis = axis

    def forward(self, x: np.ndarray) -> tuple[np.ndarray]:
        backend = get_array_module(x)
        # Softmax is shift-invariant; subtracting the per-slice maximum keeps
        # exp() from overflowing.
        shifted = x - x.max(axis=self.axis, keepdims=True)
        exps = backend.exp(shifted)
        exps /= exps.sum(axis=self.axis, keepdims=True)
        return exps

    def backward(self, gy: np.ndarray) -> Union[tuple[np.ndarray, ...], np.ndarray]:
        # Jacobian-vector product of softmax: gx = y * gy - y * sum(y * gy).
        probs = self.outputs[0]()
        weighted = probs * gy
        total = weighted.sum(axis=self.axis, keepdims=True)
        weighted -= probs * total
        return weighted


def softmax(x, axis=1):
    """Differentiable softmax of ``x`` along ``axis``."""
    op = Softmax(axis=axis)
    return op(x)


class Cat(Function):
    """Concatenate several arrays along ``axis``."""

    def __init__(self, axis: int = 0):
        self.axis = axis

    def forward(self, *xs: np.ndarray) -> np.ndarray:
        backend = get_array_module(xs[0])
        return backend.concatenate(xs, axis=self.axis)

    def backward(self, gy: Variable) -> Union[tuple[Variable, ...], Variable]:
        # Cut gy back into consecutive pieces along ``axis``, one per input,
        # each as wide as that input was.
        pieces = []
        offset = 0
        for part in self.inputs:
            width = part.shape[self.axis]
            selector = [slice(None)] * gy.ndim
            selector[self.axis] = slice(offset, offset + width)
            pieces.append(gy[tuple(selector)])
            offset += width

        return tuple(pieces)


def cat(inputs, axis=0):
    """Differentiable concatenation of the sequence ``inputs`` along ``axis``."""
    op = Cat(axis=axis)
    return op(*inputs)


class Clip(Function):
    """Clamp values into the closed interval [x_min, x_max]."""

    def __init__(self, x_min, x_max):
        self.x_min = x_min
        self.x_max = x_max

    def forward(self, x: np.ndarray) -> Union[tuple[np.ndarray, ...], np.ndarray]:
        backend = get_array_module(x)
        return backend.clip(x, self.x_min, self.x_max)

    def backward(self, gy: np.ndarray) -> Union[tuple[np.ndarray, ...], np.ndarray]:
        # Gradient passes through only where the input lay inside the range
        # (where clip acted as the identity); elsewhere it is zeroed.
        raw = self.inputs[0].data
        inside = (raw >= self.x_min) * (raw <= self.x_max)
        return gy * inside


def clip(x, x_min, x_max):
    """Differentiable clamp of ``x`` into [x_min, x_max]."""
    op = Clip(x_min, x_max)
    return op(x)


def softmax_cross_entropy_simple(x, t):
    """Mean cross-entropy between softmax(x) and integer class labels ``t``.

    Straightforward version built from individual differentiable ops.
    Presumably x is (N, C) logits and t holds N class indices — TODO confirm.
    """
    x = as_variable(x)
    t = as_variable(t)
    batch = x.shape[0]

    # Clamp probabilities away from 0 so log() stays finite.
    probs = clip(softmax(x), 1e-15, 1.0)
    log_probs = log(probs)
    # For each row, pick the log-probability of its target class.
    picked = log_probs[np.arange(batch), t.data]
    return -1 * sum(picked) / batch


def accuracy(y, t):
    """Fraction of rows whose argmax along axis 1 equals the label in ``t``.

    Computed on the raw ``.data`` arrays, so the returned Variable is not
    connected to the computational graph.
    """
    y, t = as_variable(y), as_variable(t)

    predicted = y.data.argmax(axis=1).reshape(t.shape)
    hit_rate = (predicted == t.data).mean()
    return Variable(as_array(hit_rate))


def dropout(x, dropout_ratio=0.5):
    """Inverted dropout, applied only while ``Config.train`` is true.

    During training, each element is kept with probability
    ``1 - dropout_ratio`` and the survivors are divided by that probability.
    Outside training, ``x`` is returned unchanged (as a Variable).
    """
    x = as_variable(x)

    if not Config.train:
        return x

    backend = get_array_module(x)
    keep = backend.random.rand(*x.shape) > dropout_ratio
    # Rescale so the expected activation matches the no-dropout case.
    scale = backend.array(1.0 - dropout_ratio).astype(x.dtype)
    return x * keep / scale


class Stack(Function):
    """Stack equally-shaped arrays along a new axis ``axis``."""

    def __init__(self, axis: int = 0):
        self.axis = axis

    def forward(self, *xs: np.ndarray) -> Union[tuple[np.ndarray, ...], np.ndarray]:
        backend = get_array_module(xs[0])
        # stack requires identical shapes, so the first input's shape
        # describes every input.
        self.x_shape = xs[0].shape
        self.x_num = len(xs)
        return backend.stack(xs, axis=self.axis)

    def backward(self, gy: np.ndarray) -> Union[tuple[np.ndarray, ...], np.ndarray]:
        def slab(index):
            # Take the length-1 slab for input ``index`` along the stacked
            # axis, then drop that axis by reshaping to the input shape.
            selector = [slice(None)] * gy.ndim
            selector[self.axis] = slice(index, index + 1)
            return gy[tuple(selector)].reshape(self.x_shape)

        grads = []
        for index in range(self.x_num):
            grads.append(slab(index))
        return tuple(grads)


def stack(inputs, axis=0):
    """Differentiable stack of the sequence ``inputs`` along a new ``axis``."""
    op = Stack(axis=axis)
    return op(*inputs)
