In [1]:
import numpy as np
import pandas as pd

import seaborn as sns
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split, SubsetRandomSampler


from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

In [2]:
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7f48a113e990>

In [3]:
%matplotlib inline
sns.set_style('darkgrid')

In [4]:
iris = load_iris()

X = iris.data
y = iris.target

In [5]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.10, random_state=42)

In [6]:
train_dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train).long())
test_dataset = TensorDataset(torch.Tensor(X_test), torch.Tensor(y_test).long())

In [7]:
train_dataloader = DataLoader(train_dataset, batch_size=8, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=1)

In [8]:
class IrisClassifier(nn.Module):
    def __init__(self):
        super(IrisClassifier, self).__init__()
        
        self.fc1 = nn.Linear(in_features=4, out_features=5)
        self.fc2 = nn.Linear(in_features=5, out_features=6)
        self.fc3 = nn.Linear(in_features=6, out_features=3)
        
        self.relu = nn.ReLU()
        
        self.gradient = None
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        
        x = self.fc2(x)
        x = self.relu(x)
        
        x.register_hook(self.extract_gradient)
        
        x = self.fc3(x)
        
        
        return x
    
    
    
    def extract_gradient(self, grad):
        self.gradient = grad
        print(grad.shape)
        
    def get_gradient(self):
        return self.gradient

In [9]:
model = IrisClassifier()
print(model)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

IrisClassifier(
  (fc1): Linear(in_features=4, out_features=5, bias=True)
  (fc2): Linear(in_features=5, out_features=6, bias=True)
  (fc3): Linear(in_features=6, out_features=3, bias=True)
  (relu): ReLU()
)


In [10]:
for i in model.named_parameters():
    print(i)

('fc1.weight', Parameter containing:
tensor([[-0.0037,  0.2682, -0.4115, -0.3680],
        [-0.1926,  0.1341, -0.0099,  0.3964],
        [-0.0444,  0.1323, -0.1511, -0.0983],
        [-0.4777, -0.3311, -0.2061,  0.0185],
        [ 0.1977,  0.3000, -0.3390, -0.2177]], requires_grad=True))
('fc1.bias', Parameter containing:
tensor([ 0.1816,  0.4152, -0.1029,  0.3742, -0.0806], requires_grad=True))
('fc2.weight', Parameter containing:
tensor([[ 0.0473,  0.4049, -0.4149, -0.2815, -0.1132],
        [-0.1743,  0.3864, -0.2899, -0.2059, -0.3124],
        [-0.4188, -0.2611,  0.3844,  0.1996,  0.2168],
        [ 0.0235, -0.2293,  0.0757, -0.4176, -0.3231],
        [-0.2306,  0.2822,  0.2622, -0.1983, -0.0161],
        [ 0.2860,  0.4446,  0.1775,  0.0604,  0.2999]], requires_grad=True))
('fc2.bias', Parameter containing:
tensor([-0.2633,  0.0833, -0.3467, -0.3100, -0.2310,  0.2024],
       requires_grad=True))
('fc3.weight', Parameter containing:
tensor([[ 0.1642, -0.2418,  0.1233,  0.2241, -0.0

In [12]:
print("Begin training.")

for e in tqdm(range(1, 3)):
    
    # TRAINING
    epoch_loss = 0
    
    model.train()
    for X_batch, y_batch in train_dataloader:

        optimizer.zero_grad()
        
        y_pred = model(X_batch)        
        
        loss = criterion(y_pred, y_batch)
        
        loss.backward()
        optimizer.step()      
        
        epoch_loss += loss.item()
        
                              
    
    print(f'Epoch {e+0:03}: | Train Loss: {epoch_loss/len(train_dataloader):.5f}')

Begin training.


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
Epoch 001: | Train Loss: 1.06078
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
torch.Size([8, 6])
Epoch 002: | Train Loss: 0.88591



We can print out certain information about this model. Let's see what we can do.

In [13]:
for params in model.named_parameters():
    print(params[0], "  :  ", params[1].shape)

fc1.weight   :   torch.Size([5, 4])
fc1.bias   :   torch.Size([5])
fc2.weight   :   torch.Size([6, 5])
fc2.bias   :   torch.Size([6])
fc3.weight   :   torch.Size([3, 6])
fc3.bias   :   torch.Size([3])


In [14]:
model.gradient

tensor([[ 0.0258, -0.0085, -0.0059,  0.0323,  0.0109, -0.0211],
        [-0.0258, -0.0119, -0.0003, -0.0166, -0.0071,  0.0189],
        [-0.0260, -0.0123, -0.0004, -0.0165, -0.0071,  0.0190],
        [ 0.0246, -0.0108, -0.0065,  0.0330,  0.0109, -0.0205],
        [-0.0270, -0.0133, -0.0006, -0.0168, -0.0073,  0.0197],
        [ 0.0258, -0.0095, -0.0062,  0.0332,  0.0111, -0.0212],
        [ 0.0234, -0.0123, -0.0068,  0.0329,  0.0107, -0.0197],
        [-0.0259, -0.0123, -0.0004, -0.0165, -0.0071,  0.0190]])

In [15]:
model.eval()

IrisClassifier(
  (fc1): Linear(in_features=4, out_features=5, bias=True)
  (fc2): Linear(in_features=5, out_features=6, bias=True)
  (fc3): Linear(in_features=6, out_features=3, bias=True)
  (relu): ReLU()
)

In [17]:
test_dataloader_iter = iter(test_dataloader)

In [18]:
test_sample, test_label = next(test_dataloader_iter)

In [20]:
pred = model(test_sample)
pred.shape

torch.Size([1, 3])

In [22]:
pred = F.softmax(pred, dim = 1)
pred

tensor([[0.1072, 0.3800, 0.5129]], grad_fn=<SoftmaxBackward>)

In [24]:
pred[:, 1].backward()

torch.Size([1, 6])


In [25]:
model.gradient

tensor([[-0.0836,  0.0213,  0.0171, -0.0991, -0.0347,  0.0680]])