## 39단계: 합계 함수

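> 이번 단계에서는 DeZero에 합계를 구하는 함수 `sum`을 추가합니다. 먼저 sum 함수의 역전파를 살펴본 다음, 이를 구현하고 `axis`와 `keepdims` 인수를 지원하도록 확장합니다.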

### 39.1 sum 함수의 역전파

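<img src="images/그림 39-1.png" width="450"/>

<br/>

<img src="images/그림 39-2.png" width="450"/>

<br/>

<img src="images/그림 39-3.png" width="500"/>

즉, 합계의 각 입력 원소에 대한 미분은 모두 1이므로, sum 함수의 역전파는 출력 쪽에서 전해진 기울기를 입력 변수의 형상과 같아지도록 복사하기만 하면 됩니다. 아래 구현에서는 이 복사를 `broadcast_to` 함수로 처리합니다.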


### 39.2 sum 함수 구현

In [1]:
# dezero/functions.py
from dezero.core import Function

class Sum(Function):
    """Add up every element of the input, producing a single value.

    ``forward`` stores the input's shape so that ``backward`` can spread the
    upstream gradient back over an array of that shape.
    """

    def forward(self, x):
        # Remember the shape; backward needs it to rebuild the gradient.
        self.x_shape = x.shape
        return x.sum()

    def backward(self, gy):
        # d(sum)/dx_i is 1 for every element, so each input element just
        # receives a copy of the upstream gradient.
        return broadcast_to(gy, self.x_shape)

def sum(x):
    """Return the sum of all elements of ``x`` as a DeZero ``Sum`` function call."""
    summation = Sum()
    return summation(x)

In [2]:
import numpy as np
from dezero import Variable
import dezero.functions as F

# 1-D input: the sum is a scalar, and backprop gives every element gradient 1.
x = Variable(np.array([1, 2, 3, 4, 5, 6]))
y = F.sum(x)
y.backward()
print(y, x.grad, sep='\n')

Variable(21)
Variable([1 1 1 1 1 1])


In [3]:
# 2-D input: still summed to a scalar; the gradient keeps the (2, 3) shape.
x = Variable(np.array([[1, 2, 3], [4, 5, 6]]))
y = F.sum(x)
y.backward()
print(y, x.grad, sep='\n')

Variable(21)
Variable([[1 1 1]
          [1 1 1]])


### 39.3 axis와 keepdims

In [4]:
# Summing along axis 0 collapses the rows: (2, 3) -> (3,).
x = np.arange(1, 7).reshape(2, 3)
y = x.sum(axis=0)
print(y)
print(x.shape, ' -> ', y.shape)

[5 7 9]
(2, 3)  ->  (3,)


<img src="images/그림 39-4.png" width="350"/>

<br/>

<img src="images/그림 39-5.png" width="500"/>

In [5]:
# keepdims=True keeps every reduced axis as length 1, so the result is (1, 1).
x = np.arange(1, 7).reshape(2, 3)
y = x.sum(keepdims=True)
print(y)
print(y.shape)

[[21]]
(1, 1)


In [6]:
# dezero/functions.py
from dezero.core import Function
from dezero import utils

class Sum(Function):
    """Sum the input's elements, optionally along ``axis``.

    Args:
        axis: forwarded unchanged to ``ndarray.sum``; ``None`` (the default
            used by ``sum``) reduces over every element.
        keepdims: forwarded unchanged to ``ndarray.sum``; when True the
            reduced axes are kept with length 1.
    """

    def __init__(self, axis, keepdims):
        self.axis, self.keepdims = axis, keepdims

    def forward(self, x):
        # Backward needs the original shape to size the gradient.
        self.x_shape = x.shape
        return x.sum(axis=self.axis, keepdims=self.keepdims)

    def backward(self, gy):
        # NOTE(review): reshape_sum_backward presumably re-inserts the axes
        # removed by the forward sum (when keepdims is False) so that gy can
        # be broadcast; confirm against dezero/utils.py.
        restored = utils.reshape_sum_backward(
            gy, self.x_shape, self.axis, self.keepdims)
        # Every input element contributed with weight 1, so copy the
        # gradient out to the full input shape.
        return broadcast_to(restored, self.x_shape)


def sum(x, axis=None, keepdims=False):
    """Differentiable sum of ``x``; ``axis``/``keepdims`` behave as in ``ndarray.sum``."""
    reducer = Sum(axis=axis, keepdims=keepdims)
    return reducer(x)

In [7]:
# dezero/core.py

class Variable:
    ...

    def sum(self, axis=None, keepdims=False):
        """Method form of ``dezero.functions.sum``: ``x.sum(...)`` is ``F.sum(x, ...)``."""
        return dezero.functions.sum(self, axis=axis, keepdims=keepdims)

In [8]:
from dezero import Variable
import dezero.functions as F

# Column-wise sum (axis=0); the gradient is 1 for every input element.
x = Variable(np.array([[1, 2, 3], [4, 5, 6]]))
y = F.sum(x, axis=0)
y.backward()
print(y, x.grad, sep='\n')

# Method form with keepdims=True: a 4-D input reduces to shape (1, 1, 1, 1).
x = Variable(np.random.randn(2, 3, 4, 5))
y = x.sum(keepdims=True)
print(y.shape)

Variable([5 7 9])
Variable([[1 1 1]
          [1 1 1]])
(1, 1, 1, 1)
