In [1]:
#| default_exp tests.test_model

In [1]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F

from fastcore.utils import *
from tinypytorch.core import *
from tinypytorch.data import get_local_data
from tinypytorch.model import Lin, ReLU, MSE, initialize_parameters, log_softmax, nll, cross_entropy, Model

import pytest

In [2]:
#| hide
from nbdev.showdoc import *
from fastcore.test import *

#### Tensors for tests

In [3]:
#| export
# Shared fixture: the 12 consecutive floats -4.0 ... 7.0 (float32).
A = torch.arange(-4, 8, dtype=torch.float32)

In [4]:
#| export
# Lay the 12 values out row-major as a 4 x 3 matrix.
A = A.reshape(4, 3)

In [52]:
# Display A: values -4 ... 7 as a 4x3 float tensor.
A

tensor([[-4., -3., -2.],
        [-1.,  0.,  1.],
        [ 2.,  3.,  4.],
        [ 5.,  6.,  7.]])

In [53]:
#| export
# Second fixture tensor: the 12 consecutive floats 13.0 ... 24.0 (float32).
B = torch.arange(start=13, end=25, dtype=torch.float32)

In [54]:
#| export
# Same 4 x 3 row-major layout as A, so A and B can be compared elementwise.
B = B.reshape(4, 3)

In [55]:
# Display B: values 13 ... 24 as a 4x3 float tensor.
B

tensor([[13., 14., 15.],
        [16., 17., 18.],
        [19., 20., 21.],
        [22., 23., 24.]])

#### Sample Data

In [56]:
#| export
# Load the local dataset splits (x_train is 50000 x 784 per the shape shown below it).
# NOTE(review): as an exported top-level statement this runs whenever the generated
# test module is imported, i.e. on every pytest collection.
x_train, y_train, x_valid, y_valid = get_local_data()

In [57]:
#| export
# Mini-batch size used to slice xb / yb from the training split.
bs = 64

In [58]:
# Inspect the training inputs: one row per sample, 784 columns.
x_train.shape

torch.Size([50000, 784])

In [59]:
#| export
# First mini-batch of `bs` samples and the matching targets.
xb, yb = x_train[:bs], y_train[:bs]

In [60]:
# Batch shape: (bs, 784).
xb.shape

torch.Size([64, 784])

In [61]:
#| export
# Layer sizes for Model. n_in must equal the column count of xb (784) because the
# model is called on xb; nh is the hidden width; n_out is presumably the number of
# target classes (10) — confirm against the dataset.
n_in, nh, n_out = 784, 50, 10

#### Linear Layer

In [8]:
#| export
@pytest.fixture
def model():
    """Provide a fresh Model sized by the module-level (n_in, nh, n_out) for each test."""
    net = Model(n_in, nh, n_out)
    return net

In [63]:
#| export
def test_intialize_parameters_should_return_true():
    """initialize_parameters(m, nh) returns (w1, b1, w2, b2) with shapes
    w1 (m, nh), b1 (nh,), w2 (nh, 1), b2 (1,).

    NOTE(review): the test name contains a typo ("intialize"); kept so existing
    test selections keep working.
    """
    # Distinct local names avoid shadowing the module-level `nh` used by other tests.
    n_inputs = 5  # input features, i.e. rows of w1
    n_hidden = 3  # hidden units (width of w1's output), not a number of layers
    w1, b1, w2, b2 = initialize_parameters(m=n_inputs, nh=n_hidden)
    assert w1.shape == (n_inputs, n_hidden)
    assert b1.shape == (n_hidden,)
    assert w2.shape == (n_hidden, 1)
    assert b2.shape == (1,)

In [12]:
#| export
def test_linear_forward_pass():
    """Placeholder for checking Lin's forward pass against a torch reference.

    The original cell had an empty body, which is a syntax error and made the
    whole exported test module unimportable.
    """
    # Skip explicitly so the missing coverage is visible in the pytest report.
    pytest.skip("Lin forward pass test not implemented yet")

#### Model

In [9]:
#| export
def test_model_params_shape(model):
    """Model's parameters w1, b1, w2, b2 have the shapes implied by (n_in, nh, n_out)."""
    expected_shapes = {
        "w1": (n_in, nh),
        "b1": (nh,),
        "w2": (nh, n_out),
        "b2": (n_out,),
    }
    for attr, shape in expected_shapes.items():
        assert getattr(model, attr).shape == shape

In [10]:
#| export
# TODO: fix — the failure this refers to is not visible here; confirm whether it still applies.
def test_model_forward_pass(model):
    """A forward pass maps a (batch, n_in) input to a (batch, n_out) output.

    The expected shape is stated directly. Building a randomly-initialised
    nn.Linear just to read off its output shape is indirect and consumes RNG state.
    """
    out = model(xb)
    assert out.shape == (xb.shape[0], n_out)

In [11]:
#| export
def test_model_backward_pass():
    """Placeholder for gradient checks of Model's backward pass (not implemented)."""
    # An empty `pass` body would be reported as a passing test; skip explicitly instead.
    pytest.skip("Model backward pass test not implemented yet")

#### Loss Functions

##### Mean Squared Error

In [67]:
#| export
def test_mse_forward_pass():
    """MSE().forward(A, B) matches torch's F.mse_loss (default 'mean' reduction)."""
    output = MSE().forward(A, B)
    result = F.mse_loss(A, B)
    # Tolerance-based comparison; rely on truthiness rather than `== True` (PEP 8).
    assert is_near_tensor(output, result)

In [68]:
#| export
def test_mse_backward_pass():
    """Placeholder for gradient checks of MSE's backward pass (not implemented)."""
    # An empty `pass` body would be reported as a passing test; skip explicitly instead.
    pytest.skip("MSE backward pass test not implemented yet")

##### Log Softmax

In [69]:
#| export
@pytest.mark.parametrize("test_input", (A, xb), ids=["A", "xb"])
def test_log_softmax_forward_pass(test_input):
    """log_softmax matches torch's F.log_softmax over dim=1 for every parametrized input.

    Previously the body ignored `test_input` and always used A, so the xb case
    never exercised anything new.
    """
    expected = F.log_softmax(test_input, dim=1)
    assert is_near_tensor(log_softmax(test_input), expected)

In [70]:
# Project implementation on the fixture tensor; compare with torch's F.log_softmax(A, dim=1).
log_softmax(A)

tensor([[-2.4076, -1.4076, -0.4076],
        [-2.4076, -1.4076, -0.4076],
        [-2.4076, -1.4076, -0.4076],
        [-2.4076, -1.4076, -0.4076]])

In [71]:
# Reference: torch's log_softmax over dim=1 (each row normalised independently).
F.log_softmax(A, dim=1)

tensor([[-2.4076, -1.4076, -0.4076],
        [-2.4076, -1.4076, -0.4076],
        [-2.4076, -1.4076, -0.4076],
        [-2.4076, -1.4076, -0.4076]])

In [72]:
# Exact float equality: the last column compares False although the printed values match,
# i.e. the results differ below display precision — a tolerance-based check
# (is_near_tensor) is needed when comparing against torch.
log_softmax(A) == F.log_softmax(A, dim=1)

tensor([[ True,  True, False],
        [ True,  True, False],
        [ True,  True, False],
        [ True,  True, False]])

In [73]:
# Sanity check: repeated calls on the same input give bit-identical results.
log_softmax(A) == log_softmax(A)

tensor([[True, True, True],
        [True, True, True],
        [True, True, True],
        [True, True, True]])

##### Negative Log Likelihood

In [86]:
#| export
@pytest.mark.parametrize(
    ("test_input", "target"),
    [
        # Raw scores: F.nll_loss only indexes, negates and averages, so any float input
        # is a valid comparison point.
        (
            torch.tensor([[0, 1, 2], [5, 0, 4]], dtype=torch.float),
            torch.tensor([1, 0]),
        ),
        # Realistic case: log-probabilities, as produced by log_softmax.
        (
            F.log_softmax(torch.tensor([[0, 1, 2], [5, 0, 4]], dtype=torch.float), dim=1),
            torch.tensor([1, 0]),
        ),
    ],
    ids=["raw-scores", "log-probs"],
)
def test_nll(test_input, target):
    """nll(input, target) matches torch's F.nll_loss.

    `test_input` is a (samples, classes) float tensor and `target` holds one
    class index per row.
    """
    output = nll(test_input, target)
    result = F.nll_loss(test_input, target)
    assert is_near_tensor(output, result)

In [87]:
# Project log_softmax applied to the raw input batch (inspection only).
log_softmax(xb)

tensor([[-6.8610, -6.8610, -6.8610,  ..., -6.8610, -6.8610, -6.8610],
        [-6.8844, -6.8844, -6.8844,  ..., -6.8844, -6.8844, -6.8844],
        [-6.8022, -6.8022, -6.8022,  ..., -6.8022, -6.8022, -6.8022],
        ...,
        [-6.7952, -6.7952, -6.7952,  ..., -6.7952, -6.7952, -6.7952],
        [-6.8814, -6.8814, -6.8814,  ..., -6.8814, -6.8814, -6.8814],
        [-6.9820, -6.9820, -6.9820,  ..., -6.9820, -6.9820, -6.9820]])

In [85]:
# Input batch shape: (bs, 784).
xb.shape

torch.Size([64, 784])

In [83]:
# NOTE(review): xb holds raw inputs, not log-probabilities, and yb's values are used as
# column indices into its 784 features, so the -0. result is not a meaningful loss.
# The two shape lines in the output are printed by nll itself.
nll(xb, yb)

inp.shape=torch.Size([64, 784])
targ.shape=torch.Size([64])


tensor(-0.)

In [78]:
# Log-probabilities of the raw input batch, normalised across its 784 columns.
sm_pred = F.log_softmax(xb, dim=1)

In [79]:
# Same shape as xb: log_softmax is applied row-wise and keeps every column.
sm_pred.shape

torch.Size([64, 784])

In [80]:
# NOTE(review): this treats xb's 784 input columns as class scores and indexes them with yb,
# so it only exercises F.nll_loss mechanics — it is not a meaningful classification loss.
nll_loss = F.nll_loss(sm_pred, yb)

In [81]:
# One target per sample in the batch: shape (bs,).
yb.shape

torch.Size([64])

In [82]:
# F.nll_loss defaults to 'mean' reduction, so the loss is a 0-d scalar.
nll_loss.shape

torch.Size([])

In [193]:
# Inspect the scalar loss value.
nll_loss

tensor(6.8455)

In [160]:
# Toy batch of random scores: 3 samples x 5 classes. A dedicated seeded generator makes the
# values (and any loss computed from them) reproducible without touching the global RNG.
# NOTE(review): the name `input` shadows the Python builtin of the same name.
input = torch.randn(3, 5, generator=torch.Generator().manual_seed(0))

In [127]:
# One class index per row of `input`; each must be < 5 because `input` has 5 columns.
target = torch.tensor([1, 0, 4])

In [128]:
# NLL of the log-softmaxed scores, i.e. the cross-entropy of `input` w.r.t. `target`.
# Uses `F` (imported at the top); the previous lowercase `f` was never defined here.
output = F.nll_loss(F.log_softmax(input, dim=1), target)

In [162]:
# Inspect the toy loss value (it depends on the random scores in `input`).
output

tensor(2.1823)

In [186]:
# Shape check of the log-softmaxed batch; uses the module alias `F` (`f` is undefined).
F.log_softmax(xb, dim=1).shape

torch.Size([64, 784])

In [183]:
# Recompute the batch NLL in one expression; uses the module alias `F` (`f` is undefined).
output = F.nll_loss(F.log_softmax(xb, dim=1), yb)

In [184]:
# Same value as nll_loss computed from sm_pred (6.8455).
output

tensor(6.8455)

In [194]:
# NOTE(review): `input` holds raw randn scores, not log-probabilities, so this is not a proper
# NLL (F.nll_loss only indexes, negates and averages). If a real loss was intended, apply
# F.log_softmax first or use F.cross_entropy(input, target).
output = F.nll_loss(input, target)

##### Cross-entropy Loss

In [163]:
# Hand-picked unnormalised scores: 2 samples x 3 classes.
pred = torch.tensor([[0, 1, 2], [5, 0, 4]], dtype=torch.float)

In [164]:
# Target class index for each row of pred.
targ = torch.tensor([2, 1])

In [166]:
# Reference loss from torch (default 'mean' reduction over the 2 samples); the same pred/targ
# are used in test_cross_entropy_backward_pass.
result = F.cross_entropy(pred, targ)

In [167]:
# Expected value for the cross-entropy test.
result

tensor(2.8629)

In [28]:
#| export
def test_cross_entropy_backward_pass():
    """cross_entropy(pred, targ) matches torch's F.cross_entropy on a small hand-made batch.

    NOTE(review): despite its name this checks the forward loss value only; consider
    renaming to test_cross_entropy_forward_pass.
    """
    pred = torch.tensor([[0, 1, 2], [5, 0, 4]], dtype=torch.float)
    targ = torch.tensor([2, 1])
    # Exact `==` on float tensors is brittle (differences below display precision make it
    # fail), so compare with a tolerance like the other forward-pass tests.
    assert is_near_tensor(cross_entropy(pred, targ), F.cross_entropy(pred, targ))

#### Activation Functions

##### ReLU

In [26]:
#| export
def test_relu_forward_pass():
    """ReLU().forward(x) equals F.relu(x) - 0.5, i.e. a ReLU shifted down by 0.5."""
    output = ReLU().forward(A)
    result = F.relu(A) - 0.5  # the project's ReLU is shifted, not the plain torch ReLU
    assert is_near_tensor(output, result)

In [27]:
#| export
def test_relu_backward_pass():
    """Placeholder for gradient checks of ReLU's backward pass (not implemented)."""
    # An empty `pass` body would be reported as a passing test; skip explicitly instead.
    pytest.skip("ReLU backward pass test not implemented yet")

In [173]:
# Project ReLU applied to the fixture tensor.
output = ReLU().forward(A)

In [174]:
# Negative inputs map to -0.5 and positives to x - 0.5: a ReLU shifted down by 0.5.
output

tensor([[-0.5000, -0.5000, -0.5000],
        [-0.5000, -0.5000,  0.5000],
        [ 1.5000,  2.5000,  3.5000],
        [ 4.5000,  5.5000,  6.5000]])

In [175]:
# Re-display A for side-by-side comparison with the ReLU output.
A

tensor([[-4., -3., -2.],
        [-1.,  0.,  1.],
        [ 2.,  3.,  4.],
        [ 5.,  6.,  7.]])

In [176]:
# Reference: torch's standard (unshifted) ReLU applied to the same tensor.
relu = torch.nn.ReLU()
result = relu(A)

In [177]:
# Shifting torch's ReLU down by 0.5 matches ReLU().forward(A) from tinypytorch.
result - 0.5

tensor([[-0.5000, -0.5000, -0.5000],
        [-0.5000, -0.5000,  0.5000],
        [ 1.5000,  2.5000,  3.5000],
        [ 4.5000,  5.5000,  6.5000]])