## 19단계: 변수 사용성 개선

> 이번 단계에서는 Variable 클래스를 더욱 쉽게 사용할 수 있게 해보겠습니다.

### 19.1 변수 이름 지정

앞으로 수많은 변수들이 생성될 때 이들을 구분하기 위해 변수에 이름을 붙여주자.

In [None]:
import weakref
from heapq import heappush, heappop
import numpy as np


class Variable:
    """Container for an ndarray plus the bookkeeping used by ``backward()``.

    This version of the class adds an optional ``name`` so that variables
    can be told apart.
    """

    # ---------- added in this step: variable name ----------
    def __init__(self, data, name=None):
        """Wrap ``data`` and initialise the graph-related attributes.

        Args:
            data: An ``np.ndarray`` or ``None``.
            name: Optional label for the variable.

        Raises:
            TypeError: If ``data`` is neither ``None`` nor an ``np.ndarray``.
        """
        if not isinstance(data, (type(None), np.ndarray)):
            raise TypeError(f'{type(data)}은(는) 지원하지 않습니다.')

        self.data = data
        self.name = name  # added: variable name
        self.grad = None  # gradient, filled in by backward()
        self.creator = None  # function that produced this variable (see set_creator)
        self.generation = 0  # 0 until set_creator() is called
    # ---------------------------------------------------------

    def set_creator(self, func):
        """Record ``func`` as the creator and place this variable one generation after it."""
        self.creator = func
        self.generation = func.generation + 1

    def backward(self, retain_grad=False):
        """Propagate gradients from this variable back through its creators.

        If ``self.grad`` is unset it is seeded with ``np.ones_like(self.data)``.
        Creator functions are processed in the order ``heapq`` pops them, so
        the function objects must be orderable (their ordering is defined
        outside this class).

        Args:
            retain_grad: When false (the default), the ``grad`` of every
                output of a processed function is reset to ``None`` once
                that function has been handled.
        """
        if self.grad is None:
            self.grad = np.ones_like(self.data)

        # Leaf variable: there is no creator, so there is nothing to traverse.
        # Without this guard ``None`` would be pushed on the heap and the loop
        # below would fail with AttributeError on ``None.outputs``.
        if self.creator is None:
            return

        funcs_heap = []
        seen_set = set()

        def add_func(f):
            # Each function is queued at most once, even if several of its
            # outputs are reached during the traversal.
            if f not in seen_set:
                heappush(funcs_heap, f)
                seen_set.add(f)

        add_func(self.creator)

        while funcs_heap:
            f = heappop(funcs_heap)
            gys = [output.grad for output in f.outputs]

            gxs = f.backward(*gys)
            if not isinstance(gxs, tuple):
                gxs = (gxs,)

            for x, gx in zip(f.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    # Accumulate into a new object instead of modifying the
                    # existing gradient in place.
                    x.grad = x.grad + gx

                if x.creator is not None:
                    add_func(x.creator)

            if not retain_grad:
                for y in f.outputs:
                    y.grad = None

    def cleargrad(self):
        """Reset the stored gradient to ``None``."""
        self.grad = None

### 19.2 ndarray 인스턴스 변수

`Variable`을 `ndarray`처럼 다룰 수 있도록, `shape`·`ndim`·`size`·`dtype` 같은 `ndarray`의 인스턴스 변수를 `Variable.data`에 그대로 위임하는 프로퍼티로 제공한다(프록시(proxy) 패턴).

In [2]:
class Variable:
    """Container for an ndarray plus the bookkeeping used by ``backward()``.

    This version of the class additionally exposes ``shape``, ``ndim``,
    ``size`` and ``dtype`` as read-only properties that forward to
    ``self.data``.
    """

    def __init__(self, data, name=None):
        """Wrap ``data`` and initialise the graph-related attributes.

        Args:
            data: An ``np.ndarray`` or ``None``.
            name: Optional label for the variable.

        Raises:
            TypeError: If ``data`` is neither ``None`` nor an ``np.ndarray``.
        """
        if not isinstance(data, (type(None), np.ndarray)):
            raise TypeError(f'{type(data)}은(는) 지원하지 않습니다.')

        self.data = data
        self.name = name
        self.grad = None  # gradient, filled in by backward()
        self.creator = None  # function that produced this variable (see set_creator)
        self.generation = 0  # 0 until set_creator() is called

    # -------- added in this step: ndarray proxy properties --------
    # Each property simply forwards to the same-named attribute of self.data.
    @property
    def shape(self):
        """``self.data.shape``."""
        return self.data.shape

    @property
    def ndim(self):
        """``self.data.ndim``."""
        return self.data.ndim

    @property
    def size(self):
        """``self.data.size``."""
        return self.data.size

    @property
    def dtype(self):
        """``self.data.dtype``."""
        return self.data.dtype
    # ----------------------------------------------------------------

    def set_creator(self, func):
        """Record ``func`` as the creator and place this variable one generation after it."""
        self.creator = func
        self.generation = func.generation + 1

    def backward(self, retain_grad=False):
        """Propagate gradients from this variable back through its creators.

        If ``self.grad`` is unset it is seeded with ``np.ones_like(self.data)``.
        Creator functions are processed in the order ``heapq`` pops them, so
        the function objects must be orderable (their ordering is defined
        outside this class).

        Args:
            retain_grad: When false (the default), the ``grad`` of every
                output of a processed function is reset to ``None`` once
                that function has been handled.
        """
        if self.grad is None:
            self.grad = np.ones_like(self.data)

        # Leaf variable: there is no creator, so there is nothing to traverse.
        # Without this guard ``None`` would be pushed on the heap and the loop
        # below would fail with AttributeError on ``None.outputs``.
        if self.creator is None:
            return

        funcs_heap = []
        seen_set = set()

        def add_func(f):
            # Each function is queued at most once, even if several of its
            # outputs are reached during the traversal.
            if f not in seen_set:
                heappush(funcs_heap, f)
                seen_set.add(f)

        add_func(self.creator)

        while funcs_heap:
            f = heappop(funcs_heap)
            gys = [output.grad for output in f.outputs]

            gxs = f.backward(*gys)
            if not isinstance(gxs, tuple):
                gxs = (gxs,)

            for x, gx in zip(f.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    # Accumulate into a new object instead of modifying the
                    # existing gradient in place.
                    x.grad = x.grad + gx

                if x.creator is not None:
                    add_func(x.creator)

            if not retain_grad:
                for y in f.outputs:
                    y.grad = None

    def cleargrad(self):
        """Reset the stored gradient to ``None``."""
        self.grad = None

In [3]:
x = Variable(np.array([[1, 2, 3], [4, 5, 6]]))
print(x.shape)  # thanks to @property this is read as x.shape, not called as x.shape()

(2, 3)


### 19.3 len 함수와 print 함수

`Variable` 클래스에 `__len__` 메서드와 `__repr__` 메서드를 추가해보자.

In [5]:
class Variable:
    """Container for an ndarray plus the bookkeeping used by ``backward()``.

    This version of the class additionally supports ``len()`` and provides
    a readable ``repr`` (which ``print`` falls back to).
    """

    def __init__(self, data, name=None):
        """Wrap ``data`` and initialise the graph-related attributes.

        Args:
            data: An ``np.ndarray`` or ``None``.
            name: Optional label for the variable.

        Raises:
            TypeError: If ``data`` is neither ``None`` nor an ``np.ndarray``.
        """
        if not isinstance(data, (type(None), np.ndarray)):
            raise TypeError(f'{type(data)}은(는) 지원하지 않습니다.')

        self.data = data
        self.name = name
        self.grad = None  # gradient, filled in by backward()
        self.creator = None  # function that produced this variable (see set_creator)
        self.generation = 0  # 0 until set_creator() is called

    # -------- added in this step: magic methods --------
    def __len__(self):
        """Return ``len(self.data)``."""
        return len(self.data)

    def __repr__(self):
        """Return ``'Variable(<data>)'``, or ``'Variable(None)'`` when data is None."""
        if self.data is None:
            return 'Variable(None)'
        # Indent continuation lines by 9 spaces, the width of 'Variable(',
        # so that multi-line arrays stay aligned under the first row.
        p = str(self.data).replace('\n', '\n' + ' ' * 9)
        return f'Variable({p})'
    # ----------------------------------------------------

    @property
    def shape(self):
        """``self.data.shape``."""
        return self.data.shape

    @property
    def ndim(self):
        """``self.data.ndim``."""
        return self.data.ndim

    @property
    def size(self):
        """``self.data.size``."""
        return self.data.size

    @property
    def dtype(self):
        """``self.data.dtype``."""
        return self.data.dtype

    def set_creator(self, func):
        """Record ``func`` as the creator and place this variable one generation after it."""
        self.creator = func
        self.generation = func.generation + 1

    def backward(self, retain_grad=False):
        """Propagate gradients from this variable back through its creators.

        If ``self.grad`` is unset it is seeded with ``np.ones_like(self.data)``.
        Creator functions are processed in the order ``heapq`` pops them, so
        the function objects must be orderable (their ordering is defined
        outside this class).

        Args:
            retain_grad: When false (the default), the ``grad`` of every
                output of a processed function is reset to ``None`` once
                that function has been handled.
        """
        if self.grad is None:
            self.grad = np.ones_like(self.data)

        # Leaf variable: there is no creator, so there is nothing to traverse.
        # Without this guard ``None`` would be pushed on the heap and the loop
        # below would fail with AttributeError on ``None.outputs``.
        if self.creator is None:
            return

        funcs_heap = []
        seen_set = set()

        def add_func(f):
            # Each function is queued at most once, even if several of its
            # outputs are reached during the traversal.
            if f not in seen_set:
                heappush(funcs_heap, f)
                seen_set.add(f)

        add_func(self.creator)

        while funcs_heap:
            f = heappop(funcs_heap)
            gys = [output.grad for output in f.outputs]

            gxs = f.backward(*gys)
            if not isinstance(gxs, tuple):
                gxs = (gxs,)

            for x, gx in zip(f.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    # Accumulate into a new object instead of modifying the
                    # existing gradient in place.
                    x.grad = x.grad + gx

                if x.creator is not None:
                    add_func(x.creator)

            if not retain_grad:
                for y in f.outputs:
                    y.grad = None

    def cleargrad(self):
        """Reset the stored gradient to ``None``."""
        self.grad = None

In [6]:
# len() on a Variable is answered by __len__, which returns len() of the wrapped array.
x = Variable(np.array([[1, 2, 3], [4, 5, 6]]))
print(len(x))

2


In [7]:
# print() uses __repr__ here; a 1-D array fits on a single line.
x = Variable(np.array([1, 2, 3]))
print(x)

# data=None is rendered as the fixed string 'Variable(None)'.
x = Variable(None)
print(x)

# For a multi-line array the continuation rows are indented to line up with the first row.
x = Variable(np.array([[1, 2, 3], [4, 5, 6]]))
print(x)

Variable([1 2 3])
Variable(None)
Variable([[1 2 3]
          [4 5 6]])
