In [4]:
import torch

class Module:
    """Minimal stand-in for ``torch.nn.Module`` used by the hand-written layers in this notebook.

    Subclasses register learnable tensors in ``self._parameters`` and implement
    ``forward`` (invoked through ``__call__``) and a manual ``backward``.
    """

    def __init__(self):
        self._parameters = {}   # name -> learnable tensor, yielded by parameters()
        self.training = True    # subclasses (e.g. BatchNorm) branch on this flag

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Compute the layer output; must be overridden by subclasses."""
        raise NotImplementedError

    def backward(self, *args, **kwargs):
        """Propagate an upstream gradient through the layer; must be overridden by subclasses."""
        raise NotImplementedError

    def parameters(self):
        """Yield the registered learnable tensors in registration order."""
        yield from self._parameters.values()

    def train(self, mode=True):
        """Set training mode (``mode=True``) or evaluation mode (``mode=False``).

        Returns ``self`` so calls can be chained, e.g. ``layer.eval()(x)``,
        as with ``torch.nn.Module.train``.
        """
        self.training = bool(mode)
        return self

    def eval(self):
        """Switch to evaluation mode; equivalent to ``self.train(False)``. Returns ``self``."""
        return self.train(False)

In [5]:
class BatchNorm(Module):
    """Batch normalization over the channel axis (dim 1) with a hand-written backward.

    Intended to match ``torch.nn.BatchNorm1d/2d/3d`` with ``affine=True`` and
    ``track_running_stats=True``.

    In training mode, the input is normalized with the biased per-channel batch
    statistics. The running estimates are updated as an exponential moving
    average, and the running variance uses the unbiased estimate. In eval mode,
    the running estimates are used instead.

    Args:
        num_features: number of channels ``C``; inputs have shape ``(N, C, *)``.
        eps: added to the variance before the square root for numerical stability.
        momentum: weight of the current batch in the running-statistics update.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        # Learnable per-channel scale and shift, registered so parameters() yields them.
        self.gamma = torch.ones(num_features, requires_grad=True)
        self.beta = torch.zeros(num_features, requires_grad=True)
        self._parameters['gamma'] = self.gamma
        self._parameters['beta'] = self.beta

        # Running statistics: used in eval mode, updated (not trained) in training mode.
        self.running_mean = torch.zeros(num_features)
        self.running_var = torch.ones(num_features)

        # Intermediate values saved by forward() for backward().
        self.cache = {}

    def forward(self, x):
        """Normalize ``x`` of shape ``(N, C, *)`` per channel and apply ``gamma``/``beta``.

        Raises:
            ValueError: if ``x`` has fewer than 2 dims, its dim 1 differs from
                ``num_features``, or, in training mode, a channel has at most one value.
        """
        if x.dim() < 2:
            raise ValueError(f"Expected input of shape (N, C, *), got {tuple(x.shape)}")
        if x.shape[1] != self.num_features:
            raise ValueError(
                f"Expected {self.num_features} channels at dim 1, got {x.shape[1]}")

        # Reduce over every axis except the channel axis:
        # [0] for (N, C), [0, 2] for (N, C, L), [0, 2, 3] for (N, C, H, W), ...
        dims = [0] + list(range(2, x.dim()))
        # Per-channel tensors reshaped to broadcast against x: (1, C, 1, ..., 1).
        shape = [1, -1] + [1] * (x.dim() - 2)
        gamma = self.gamma.view(*shape)
        beta = self.beta.view(*shape)

        if self.training:
            n = x.numel() // x.shape[1]  # number of values per channel
            if n <= 1:
                raise ValueError(
                    "Expected more than 1 value per channel when training, "
                    f"got input size {tuple(x.shape)}")

            batch_mean = x.mean(dim=dims, keepdim=True)
            batch_var_biased = x.var(dim=dims, unbiased=False, keepdim=True)

            x_hat = (x - batch_mean) / torch.sqrt(batch_var_biased + self.eps)

            with torch.no_grad():
                # reshape(-1) turns (1, C, 1, ...) into (C,) even when C == 1.
                self.running_mean = (1 - self.momentum) * self.running_mean \
                    + self.momentum * batch_mean.reshape(-1)
                # Running variance tracks the unbiased estimate (Bessel's correction).
                batch_var_unbiased = batch_var_biased * (n / (n - 1))
                self.running_var = (1 - self.momentum) * self.running_var \
                    + self.momentum * batch_var_unbiased.reshape(-1)

            self.cache = {
                'training': True,
                'x': x,
                'x_hat': x_hat,
                'gamma': gamma,
                'batch_mean': batch_mean,
                'batch_var': batch_var_biased,
                'dims': dims,
            }
        else:
            mean = self.running_mean.view(*shape)
            std = torch.sqrt(self.running_var.view(*shape) + self.eps)
            x_hat = (x - mean) / std
            # Cache eval-mode values too, so backward() never reuses stale training data.
            self.cache = {
                'training': False,
                'x_hat': x_hat,
                'gamma': gamma,
                'inv_std': 1. / std,
                'dims': dims,
            }

        out = gamma * x_hat + beta
        return out

    def backward(self, grad_output):
        """Backpropagate ``grad_output`` (same shape as the forward output).

        Sets ``gamma.grad`` and ``beta.grad`` (overwriting, not accumulating) and
        returns the gradient w.r.t. the input of the most recent ``forward`` call,
        in whichever mode that call ran.

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        if not self.cache:
            raise RuntimeError("backward() called before forward()")
        cache = self.cache
        dims = cache['dims']

        # Manual gradients must not build an autograd graph of their own.
        with torch.no_grad():
            self.beta.grad = grad_output.sum(dim=dims)
            self.gamma.grad = (grad_output * cache['x_hat']).sum(dim=dims)

            grad_x_hat = grad_output * cache['gamma']

            if not cache['training']:
                # Running statistics are constants, so normalization is affine in x.
                return grad_x_hat * cache['inv_std']

            x, batch_mean, batch_var = cache['x'], cache['batch_mean'], cache['batch_var']
            N = x.numel() // x.shape[1]  # values per channel (N * spatial size)

            x_centered = x - batch_mean
            inv_std = 1. / torch.sqrt(batch_var + self.eps)
            grad_var = -0.5 * (grad_x_hat * x_centered * (inv_std**3)).sum(dim=dims, keepdim=True)

            grad_mean = -1. * (grad_x_hat * inv_std).sum(dim=dims, keepdim=True) + \
                        grad_var * (-2./N) * x_centered.sum(dim=dims, keepdim=True)

            grad_x = grad_x_hat * inv_std + \
                     grad_var * (2./N) * x_centered + \
                     grad_mean / N

        return grad_x

In [6]:
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the same inputs

N, C = 8, 3
num_features = C
eps = 1e-5
momentum = 0.1

my_bn = BatchNorm(num_features, eps=eps, momentum=momentum)
torch_bn = torch.nn.BatchNorm1d(num_features, eps=eps, momentum=momentum)

# Start both layers from identical parameters and running statistics.
with torch.no_grad():
    my_bn.gamma.copy_(torch_bn.weight)
    my_bn.beta.copy_(torch_bn.bias)
    my_bn.running_mean.copy_(torch_bn.running_mean)
    my_bn.running_var.copy_(torch_bn.running_var)

x = torch.randn(N, C)
x.requires_grad_(True)
my_x = x.clone().detach().requires_grad_(True)

print("--- Проверка в режиме TRAIN ---")

my_bn.train()
torch_bn.train()
my_out = my_bn(my_x)
torch_out = torch_bn(x)

forward_checks = {
    "Выходы forward совпадают": torch.allclose(my_out, torch_out),
    "running_mean совпадают": torch.allclose(my_bn.running_mean, torch_bn.running_mean),
    "running_var совпадают": torch.allclose(my_bn.running_var, torch_bn.running_var),
}
for label, ok in forward_checks.items():
    print(f"{label}: {ok}")

grad_output = torch.randn_like(x)
torch_out.backward(grad_output)
my_grad_x = my_bn.backward(grad_output.clone())

backward_checks = {
    "Градиенты по входу dX совпадают": torch.allclose(my_grad_x, x.grad),
    "Градиенты по gamma/weight совпадают": torch.allclose(my_bn.gamma.grad, torch_bn.weight.grad),
    "Градиенты по beta/bias совпадают": torch.allclose(my_bn.beta.grad, torch_bn.bias.grad),
}
for label, ok in backward_checks.items():
    print(f"{label}: {ok}")
print("-" * 20)

print("\n--- Проверка в режиме EVAL ---")

my_bn.eval()
torch_bn.eval()

x_eval = torch.randn(N, C)
my_out_eval = my_bn(x_eval)
torch_out_eval = torch_bn(x_eval)

eval_checks = {
    "Выходы forward в режиме eval совпадают": torch.allclose(my_out_eval, torch_out_eval),
}
for label, ok in eval_checks.items():
    print(f"{label}: {ok}")
print("-" * 20)

# Extra, silent check of the (N, C, H, W) path against BatchNorm2d.
# Both layers start from their default init (gamma=1, beta=0, mean=0, var=1).
# atol allows float32 rounding in sums over N*H*W values.
x4 = torch.randn(N, C, 4, 5, requires_grad=True)
my_x4 = x4.detach().clone().requires_grad_(True)
my_bn4 = BatchNorm(num_features, eps=eps, momentum=momentum)
torch_bn4 = torch.nn.BatchNorm2d(num_features, eps=eps, momentum=momentum)
torch_out4 = torch_bn4(x4)
grad_output4 = torch.randn_like(torch_out4)
torch_out4.backward(grad_output4)
my_out4 = my_bn4(my_x4)
my_grad_x4 = my_bn4.backward(grad_output4)
checks_4d = {
    "4D forward": torch.allclose(my_out4, torch_out4, atol=1e-5),
    "4D running_mean": torch.allclose(my_bn4.running_mean, torch_bn4.running_mean, atol=1e-5),
    "4D running_var": torch.allclose(my_bn4.running_var, torch_bn4.running_var, atol=1e-5),
    "4D dX": torch.allclose(my_grad_x4, x4.grad, atol=1e-5),
    "4D dgamma": torch.allclose(my_bn4.gamma.grad, torch_bn4.weight.grad, atol=1e-5),
    "4D dbeta": torch.allclose(my_bn4.beta.grad, torch_bn4.bias.grad, atol=1e-5),
}

all_checks = {**forward_checks, **backward_checks, **eval_checks, **checks_4d}
failed = [label for label, ok in all_checks.items() if not ok]
assert not failed, f"Расхождения с PyTorch: {failed}"

--- Проверка в режиме TRAIN ---
Выходы forward совпадают: True
running_mean совпадают: True
running_var совпадают: True
Градиенты по входу dX совпадают: True
Градиенты по gamma/weight совпадают: True
Градиенты по beta/bias совпадают: True
--------------------

--- Проверка в режиме EVAL ---
Выходы forward в режиме eval совпадают: True
--------------------


## Реализация слоя Linear

In [None]:
import math

class Linear(Module):
    """Fully connected layer ``y = x @ W.T + b`` with a hand-written backward.

    ``W`` has shape ``(out_features, in_features)`` and ``b`` has shape ``(out_features,)``.
    Both are drawn from U(-1/sqrt(in_features), 1/sqrt(in_features)), the same
    bound that ``torch.nn.Linear`` uses by default.
    Inputs may be ``(in_features,)`` or ``(*, in_features)``.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        bound = 1 / math.sqrt(in_features) if in_features > 0 else 0
        self.W = torch.empty(out_features, in_features).uniform_(-bound, bound)
        self.b = torch.empty(out_features).uniform_(-bound, bound)

        self.W.requires_grad = True
        self.b.requires_grad = True

        self._parameters['W'] = self.W
        self._parameters['b'] = self.b

        # Input saved by forward() for backward().
        self.cache = {}

    @staticmethod
    def _as_2d(t):
        """View ``t`` as a matrix: ``(*, F) -> (prod(*), F)``; a 1-D ``(F,)`` becomes ``(1, F)``."""
        return t.unsqueeze(0) if t.dim() == 1 else t.flatten(0, -2)

    def forward(self, x):
        """Return ``x @ W.T + b`` for ``x`` of shape ``(*, in_features)``; caches ``x``."""
        self.cache['x'] = x
        output = x @ self.W.T + self.b
        return output

    def backward(self, grad_output):
        """Backpropagate ``grad_output`` of shape ``(*, out_features)``.

        Sets ``W.grad`` and ``b.grad`` (overwriting, not accumulating) and returns
        the gradient w.r.t. the input of the most recent ``forward`` call.

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        if 'x' not in self.cache:
            raise RuntimeError("backward() called before forward()")
        x = self.cache['x']

        # Manual gradients must not build an autograd graph of their own.
        with torch.no_grad():
            # Collapse all leading batch dims so 1-D and (B, T, in) inputs
            # produce the same parameter gradients as an equivalent (N, in) batch.
            x_2d = self._as_2d(x)
            grad_2d = self._as_2d(grad_output)

            self.W.grad = grad_2d.T @ x_2d
            self.b.grad = grad_2d.sum(dim=0)

            grad_x = grad_output @ self.W

        return grad_x

In [None]:
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the same weights and inputs

N, D_in, D_out = 16, 100, 10

my_linear = Linear(D_in, D_out)
torch_linear = torch.nn.Linear(D_in, D_out)

# Use identical parameters on both sides.
with torch.no_grad():
    my_linear.W.copy_(torch_linear.weight)
    my_linear.b.copy_(torch_linear.bias)

x = torch.randn(N, D_in)
x.requires_grad_(True)
my_x = x.clone().detach().requires_grad_(True)

my_out = my_linear(my_x)
torch_out = torch_linear(x)

grad_output = torch.randn(N, D_out)
torch_out.backward(grad_output)
my_grad_x = my_linear.backward(grad_output.clone())

linear_checks = {
    "Выходы forward совпадают": torch.allclose(my_out, torch_out),
    "Градиенты по входу dX совпадают": torch.allclose(my_grad_x, x.grad),
    "Градиенты по весам dW совпадают": torch.allclose(my_linear.W.grad, torch_linear.weight.grad),
    "Градиенты по смещению db совпадают": torch.allclose(my_linear.b.grad, torch_linear.bias.grad),
}

print("--- Проверка слоя Linear ---")
for label, ok in linear_checks.items():
    print(f"{label}: {ok}")
print("-" * 20)

failed = [label for label, ok in linear_checks.items() if not ok]
assert not failed, f"Расхождения с PyTorch: {failed}"