## Embedding
[没有思考过 Embedding，不足以谈 AI](https://zhuanlan.zhihu.com/p/643560252)

读完这篇博客，应该就能理解 embedding 是什么了，项目中的代码实现也随之变得清晰：`forward` 本质上就是一次查表操作。

In [None]:
class Layer:
    """Base class for components that own trainable parameters.

    Whenever a ``Parameter`` or another ``Layer`` is assigned to an attribute,
    the attribute name is recorded in ``_params``.  That registry is what the
    other methods walk to iterate, move, save and load the parameters.
    """

    def __init__(self):
        # Names of the attributes that hold a Parameter or a nested Layer.
        self._params: set[str] = set()

    def __setattr__(self, key, value):
        """Assign ``value`` to ``key``, registering the name if it is a Parameter or Layer."""
        if isinstance(value, (Parameter, Layer)):
            self._params.add(key)
        super().__setattr__(key, value)

    def __call__(self, *inputs: Union[Variable, np.ndarray]) -> Union[list[Variable], Variable]:
        """Run ``forward`` on ``inputs`` and return its result.

        A non-tuple result is wrapped in a one-element tuple internally; the
        caller gets the bare element back when there is exactly one output,
        otherwise the whole tuple.  Weak references to the inputs and outputs
        are stored on ``self.inputs`` / ``self.outputs``.
        """
        outputs = self.forward(*inputs)
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        self.inputs = [weakref.ref(x) for x in inputs]
        self.outputs = [weakref.ref(y) for y in outputs]
        return outputs if len(outputs) > 1 else outputs[0]

    def forward(self, *inputs):
        """Compute the layer output; subclasses must override this."""
        raise NotImplementedError()

    def params(self):
        """Yield every registered parameter, descending into nested layers."""
        for name in self._params:
            obj = self.__dict__[name]
            if isinstance(obj, Layer):
                yield from obj.params()
            else:
                yield obj

    def cleargrads(self):
        """Call ``cleargrad()`` on every parameter."""
        for param in self.params():
            param.cleargrad()

    def to_cpu(self):
        """Call ``to_cpu()`` on every parameter."""
        for param in self.params():
            param.to_cpu()

    def to_gpu(self):
        """Call ``to_gpu()`` on every parameter."""
        for param in self.params():
            param.to_gpu()

    def _flatten_params(self, params_dict, parent_key=""):
        """Fill ``params_dict`` with ``'outer/inner/name' -> parameter`` entries.

        Nested layers contribute their own entries, prefixed with the name of
        the attribute they are stored under.
        """
        for name in self._params:
            obj = self.__dict__[name]
            key = parent_key + '/' + name if parent_key else name

            if isinstance(obj, Layer):
                obj._flatten_params(params_dict, key)
            else:
                params_dict[key] = obj

    def save_weights(self, path):
        """Write all parameter arrays to ``path`` as a compressed ``.npz`` archive.

        Parameters are moved with ``to_cpu()`` first.  If writing fails, the
        partially written archive is removed and the error is re-raised.
        """
        self.to_cpu()
        params_dict = {}
        self._flatten_params(params_dict)
        array_dict = {key: param.data for key, param in params_dict.items()
                      if param is not None}
        try:
            np.savez_compressed(path, **array_dict)
        except (Exception, KeyboardInterrupt):
            # Only filesystem paths can be cleaned up; a file-like target is
            # owned by the caller.
            if isinstance(path, (str, os.PathLike)):
                written = os.fspath(path)
                # np.savez_compressed appends ".npz" when the name lacks it,
                # so the partial file may not be at ``path`` itself.
                if not written.endswith('.npz'):
                    written += '.npz'
                if os.path.exists(written):
                    os.remove(written)
            raise

    def load_weights(self, path):
        """Load parameter arrays from the ``.npz`` archive at ``path``.

        Every registered parameter has its ``data`` replaced by the array
        stored under the same flattened key.  Entries that are ``None`` are
        skipped, mirroring ``save_weights``, which never writes them.
        """
        params_dict = {}
        self._flatten_params(params_dict)
        # The context manager closes the archive's file handle when done.
        with np.load(path) as npz:
            for key, param in params_dict.items():
                if param is None:
                    continue
                param.data = npz[key]
                print(f'{key} loaded')


class Linear(Layer):
    """Layer that applies ``linear(x, W, b)``.

    ``W`` has shape ``(in_size, out_size)`` and is filled with standard-normal
    samples scaled by ``sqrt(1 / in_size)``.  When ``in_size`` is not given,
    the weight is created on the first ``forward`` call from ``x.shape[1]``.

    Args:
        out_size: Number of output features.
        nobias: When true, no bias parameter is created (``b`` is ``None``).
        dtype: dtype of the weight and bias arrays.
        in_size: Number of input features, or ``None`` to infer it lazily.
    """

    def __init__(self, out_size: int, nobias: bool = False, dtype=np.float32,
                 in_size: Union[int, None] = None):
        super().__init__()
        self.in_size = in_size
        self.out_size = out_size
        self.dtype = dtype

        # Placeholder whose data is filled by _init_W once in_size is known.
        self.W = Parameter(None, name='W')
        if self.in_size is not None:
            self._init_W()

        if nobias:
            self.b = None
        else:
            self.b = Parameter(np.zeros(out_size, dtype=dtype), name='b')

    def _init_W(self):
        """Fill ``self.W`` with scaled random values of the configured dtype."""
        fan_in, fan_out = self.in_size, self.out_size
        scaled = np.random.randn(fan_in, fan_out) * np.sqrt(1 / fan_in)
        # Cast last: multiplying a float32 array by the np.float64 scale would
        # otherwise promote the result to float64 under NumPy 2 rules.
        # Assign in place so references to the existing Parameter stay valid.
        self.W.data = scaled.astype(self.dtype)

    def forward(self, x):
        """Return ``linear(x, W, b)``, creating ``W`` first if it is still empty."""
        if self.W.data is None:
            self.in_size = x.shape[1]
            self._init_W()
        return linear(x, self.W, self.b)


class Embedding(Layer):
    """Lookup-table layer backed by a single weight matrix ``W``.

    ``W`` is created with shape ``(in_size, out_size)`` from standard-normal
    samples; the forward pass simply indexes it.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        table = np.random.randn(in_size, out_size)
        self.W = Parameter(table, name='W')

    def forward(self, x):
        """Return ``W`` indexed by ``x``."""
        return self.W[x]

In [None]:
class GetItem(Function):
    """Function that indexes its input with a fixed ``slices`` object."""

    def __init__(self, slices):
        # Index expression applied to the input in forward().
        self.slices = slices

    def forward(self, x: np.ndarray) -> np.ndarray:
        """Return ``x[self.slices]``."""
        return x[self.slices]

    def backward(self, gy: np.ndarray) -> Union[tuple[Variable, ...], Variable]:
        """Propagate ``gy`` back through ``GetItemGrad`` built for the input's shape."""
        source = self.inputs[0]
        grad_fn = GetItemGrad(self.slices, source.shape)
        return grad_fn(gy)


def get_item(x, slices):
    """Apply a ``GetItem`` function built from ``slices`` to ``x`` and return the result."""
    getter = GetItem(slices)
    return getter(x)


class GetItemGrad(Function):
    """Gradient of ``GetItem``: scatters ``gy`` into a zero array of the input shape."""

    def __init__(self, slices, in_shape):
        # Index expression used by the forward GetItem, and the shape of its input.
        self.slices = slices
        self.in_shape = in_shape

    def forward(self, gy: np.ndarray) -> Union[tuple[np.ndarray, ...], np.ndarray]:
        """Return an array of shape ``in_shape`` with ``gy`` accumulated at ``slices``."""
        xp = get_array_module(gy)
        # Match gy's dtype; the default would always allocate float64.
        gx = xp.zeros(self.in_shape, dtype=gy.dtype)
        if xp is np:
            # add.at accumulates when the same position is indexed repeatedly.
            np.add.at(gx, self.slices, gy)
        else:
            # NOTE(review): newer CuPy releases may expose this only as
            # cupyx.scatter_add - confirm against the installed version.
            xp.scatter_add(gx, self.slices, gy)
        return gx

    def backward(self, ggx: np.ndarray) -> Union[tuple[Variable, ...], Variable]:
        """Return ``ggx`` indexed with the same ``slices``."""
        return get_item(ggx, self.slices)