In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from torch.nn import CrossEntropyLoss


## One-hot encoding concepts

In [2]:
# Hand-written one-hot vector for class 0 out of three classes (NumPy version).
one_hot_numpy = np.array([1, 0, 0])

# The same encoding produced by PyTorch, for each class index in turn.
# Printed in order: tensor([1, 0, 0]), tensor([0, 1, 0]), tensor([0, 0, 1])
for class_index in range(3):
    print(F.one_hot(torch.tensor(class_index), num_classes=3))


tensor([1, 0, 0])
tensor([0, 1, 0])
tensor([0, 0, 1])


#### Cross-entropy loss in PyTorch

In [3]:
# Raw (unnormalised) scores for a single sample over two classes.
score = torch.tensor([[-0.1211, 0.1059]])

# Ground truth as a one-hot row: the sample belongs to class 0.
one_hot_target = torch.tensor([[1, 0]])

criterion = CrossEntropyLoss()
# Both tensors are cast to float64; a floating-point target with the same shape
# as the scores is interpreted as per-class probabilities. The bare final
# expression is what the cell displays.
criterion(score.to(torch.float64), one_hot_target.to(torch.float64))

tensor(0.8131, dtype=torch.float64)

## Backpropagation in PyTorch


In [4]:
# Fix the RNG seed so the random input and the randomly initialised layer
# weights (and therefore every gradient and weight printed further down) are
# identical on each Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# One random sample with 16 features (batch size 1).
sample = torch.randn(1, 16)

# Create the model and run a forward pass.
# Three stacked linear layers map 16 -> 8 -> 4 -> 2, producing two class scores.
model = nn.Sequential(
    nn.Linear(16, 8),
    nn.Linear(8, 4),
    nn.Linear(4, 2),
)

prediction = model(sample)

In [5]:
# Compute the loss against the true class, then back-propagate it.
# The single sample is labelled as class 0, stored as an int64 index tensor.
target = torch.tensor([0], dtype=torch.long)
criterion = CrossEntropyLoss()
loss = criterion(prediction, target)
# Fills in the .grad attribute of the model parameters that produced `prediction`.
loss.backward()



In [6]:
# Show the gradients stored on each layer by backward(): weight first, then bias.
for layer in model:
    print(layer.weight.grad, layer.bias.grad)


tensor([[-1.0321e-02,  4.3281e-03,  5.8020e-04, -4.4388e-03,  4.2538e-03,
          1.0143e-03,  5.2763e-03, -6.0700e-03,  4.4452e-03,  1.8668e-04,
         -1.6127e-03, -1.3079e-03,  6.3356e-06, -6.5301e-03, -6.0822e-03,
          1.0461e-02],
        [-7.7925e-02,  3.2677e-02,  4.3804e-03, -3.3512e-02,  3.2116e-02,
          7.6579e-03,  3.9835e-02, -4.5828e-02,  3.3561e-02,  1.4094e-03,
         -1.2176e-02, -9.8747e-03,  4.7833e-05, -4.9301e-02, -4.5920e-02,
          7.8976e-02],
        [ 1.2095e-01, -5.0718e-02, -6.7989e-03,  5.2015e-02, -4.9847e-02,
         -1.1886e-02, -6.1829e-02,  7.1131e-02, -5.2091e-02, -2.1876e-03,
          1.8899e-02,  1.5327e-02, -7.4243e-05,  7.6521e-02,  7.1274e-02,
         -1.2258e-01],
        [-5.5571e-02,  2.3303e-02,  3.1238e-03, -2.3898e-02,  2.2902e-02,
          5.4610e-03,  2.8407e-02, -3.2681e-02,  2.3933e-02,  1.0051e-03,
         -8.6830e-03, -7.0419e-03,  3.4111e-05, -3.5158e-02, -3.2747e-02,
          5.6320e-02],
        [ 9.8776e-03

### Updating model parameters


In [7]:
lr = 0.001  # learning rate (step size) for one manual gradient-descent step

first_layer = model[0]

# new_weight = weight - lr * gradient. The subtraction builds a new tensor, so
# the layer's own parameter is left untouched by this cell.
weight_grad = first_layer.weight.grad
weight = first_layer.weight - lr * weight_grad

# Same rule applied to the bias.
bias_grad = first_layer.bias.grad
bias = first_layer.bias - lr * bias_grad


In [8]:
# Display the manually updated first-layer weight and bias computed in the previous cell.
print(weight, bias)


tensor([[-6.2337e-03, -2.2440e-01, -1.6213e-02, -2.2529e-01,  2.0591e-01,
         -2.2772e-01,  2.3175e-02, -1.8617e-01,  1.9887e-01, -3.5142e-02,
         -1.7977e-01, -2.4464e-02,  2.0432e-01,  1.6612e-01,  1.1696e-01,
          1.6165e-01],
        [-6.7822e-02,  1.9475e-01, -1.5042e-01,  1.5298e-01, -1.8437e-01,
         -1.7940e-01,  1.5854e-01, -1.2026e-01,  1.7810e-01, -2.5168e-02,
         -2.0115e-02, -1.3009e-01, -2.7404e-02,  1.6137e-01, -2.2021e-01,
         -9.5999e-02],
        [-1.7664e-01,  6.8405e-02,  6.1880e-02, -1.9997e-01,  1.6182e-03,
          1.2510e-01, -1.0986e-01,  1.2840e-01, -2.2650e-01, -2.4172e-01,
          4.7030e-02, -2.2015e-01,  1.1567e-01, -1.1529e-01, -1.2152e-01,
         -9.8257e-02],
        [-1.3138e-02,  1.5581e-01,  9.7134e-02,  3.2372e-02,  1.7983e-01,
         -1.1459e-01, -1.4688e-01,  4.1477e-02, -1.1482e-01, -1.4018e-01,
         -6.4237e-02, -8.3823e-03, -6.9378e-02, -1.2334e-01, -2.4418e-01,
         -1.8375e-01],
        [ 1.2425e-01

In [10]:
# Let an optimizer perform the update for every parameter of the model.
# Reuse the `lr` defined for the manual update instead of repeating the literal,
# so the manual step and the optimizer step always use the same learning rate.
optimizer = optim.SGD(model.parameters(), lr=lr)

In [11]:
# Apply one SGD step: each parameter p becomes p - lr * p.grad, using the
# gradients stored by the earlier loss.backward() call.
optimizer.step()

In [13]:
# The optimizer keeps its parameters in groups; the first tensor of the first
# group is shown here after the step() call.
first_group = optimizer.param_groups[0]
print(first_group['params'][0])

Parameter containing:
tensor([[-6.2337e-03, -2.2440e-01, -1.6213e-02, -2.2529e-01,  2.0591e-01,
         -2.2772e-01,  2.3175e-02, -1.8617e-01,  1.9887e-01, -3.5142e-02,
         -1.7977e-01, -2.4464e-02,  2.0432e-01,  1.6612e-01,  1.1696e-01,
          1.6165e-01],
        [-6.7822e-02,  1.9475e-01, -1.5042e-01,  1.5298e-01, -1.8437e-01,
         -1.7940e-01,  1.5854e-01, -1.2026e-01,  1.7810e-01, -2.5168e-02,
         -2.0115e-02, -1.3009e-01, -2.7404e-02,  1.6137e-01, -2.2021e-01,
         -9.5999e-02],
        [-1.7664e-01,  6.8405e-02,  6.1880e-02, -1.9997e-01,  1.6182e-03,
          1.2510e-01, -1.0986e-01,  1.2840e-01, -2.2650e-01, -2.4172e-01,
          4.7030e-02, -2.2015e-01,  1.1567e-01, -1.1529e-01, -1.2152e-01,
         -9.8257e-02],
        [-1.3138e-02,  1.5581e-01,  9.7134e-02,  3.2372e-02,  1.7983e-01,
         -1.1459e-01, -1.4688e-01,  4.1477e-02, -1.1482e-01, -1.4018e-01,
         -6.4237e-02, -8.3823e-03, -6.9378e-02, -1.2334e-01, -2.4418e-01,
         -1.8375e-01]