In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
torch.set_printoptions(linewidth=120)

In [20]:
def get_num_correct(preds, labels):
  """Count the rows of `preds` whose highest-scoring index equals the label.

  Args:
    preds: tensor of scores; the predicted class is the argmax along dim 1.
    labels: tensor of target class indices compared element-wise with the
      predicted classes.

  Returns:
    The number of matching positions, as a plain Python number.
  """
  predicted = preds.argmax(dim=1)
  hits = predicted.eq(labels)
  return hits.sum().item()

In [21]:
class Network(nn.Module):
  """Small convolutional classifier with 10 outputs.

  Two 5x5 convolutions (1->6 and 6->12 channels), each followed by ReLU and
  2x2 max pooling, feed three linear layers (192->120->60->10).

  The flatten step hard-codes 12*4*4 features per sample, which is what a
  1x28x28 input produces (28 -> 24 -> 12 -> 8 -> 4 along each spatial axis).
  """

  def __init__(self):
    super().__init__()
    # Convolutional feature extractor.
    self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
    self.conv2 = nn.Conv2d(6, 12, kernel_size=5)
    # Fully connected classifier head.
    self.fc1 = nn.Linear(12 * 4 * 4, 120)
    self.fc2 = nn.Linear(120, 60)
    self.out = nn.Linear(60, 10)

  def forward(self, t):
    """Return the raw (un-normalised) 10-way scores for input batch `t`."""
    # Each stage: convolution -> ReLU -> 2x2 max pool with stride 2.
    t = F.max_pool2d(F.relu(self.conv1(t)), kernel_size=2, stride=2)
    t = F.max_pool2d(F.relu(self.conv2(t)), kernel_size=2, stride=2)

    # Flatten to rows of 12*4*4 values for the linear layers.
    t = t.reshape(-1, 12 * 4 * 4)

    t = F.relu(self.fc1(t))
    t = F.relu(self.fc2(t))

    # No activation on the output layer; scores are returned as-is.
    return self.out(t)

In [22]:
# FashionMNIST training split, stored under ./data/FashionMNIST (download is
# enabled); every image is passed through ToTensor when accessed.
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    download=True,
    train=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

In [23]:
# Fresh network instance used for the single-batch walkthrough in the cells below.
mynetwork = Network()

In [24]:
# Loader yielding batches of 100 samples from train_set (no shuffle argument is passed).
train_loader = torch.utils.data.DataLoader(train_set,batch_size=100)
# Take only the first batch so a single training step can be examined by hand.
batch = next(iter(train_loader))
images,labels = batch

In [25]:
# Forward pass on the single batch, then cross-entropy loss against its labels.
preds = mynetwork(images)
loss = F.cross_entropy(preds,labels)
# Last expression of the cell: display the loss as a plain Python number.
loss.item()

2.316528797149658

In [26]:
# backward() has not been called yet, so no gradient is stored on conv1's weights (prints None).
print(mynetwork.conv1.weight.grad)

None


Now let's apply backpropagation to compute the gradients.

In [27]:
# Backpropagate from the loss computed above; this populates .grad on the network's parameters.
loss.backward()

In [28]:
# The gradient now exists; display its shape (one entry per conv1 weight).
mynetwork.conv1.weight.grad.shape

torch.Size([6, 1, 5, 5])

Let's update the weights now.

In [29]:
# Adam optimizer over all of mynetwork's parameters, with learning rate 0.01.
optimizer = optim.Adam(mynetwork.parameters(),lr=0.01)

The network's parameters are simply its learnable weights (including the biases of each layer).

In [30]:
# Correct predictions on this batch before any weight update has been applied.
get_num_correct(preds,labels)

4

With untrained weights this is essentially a random guess: only 4 out of 100 predictions were correct.

In [31]:
# Apply one parameter update using the gradients produced by loss.backward() above.
optimizer.step()

In [32]:
# Run the same batch through the updated network and recompute the loss for comparison.
preds = mynetwork(images)
loss = F.cross_entropy(preds,labels)
loss.item()

2.2883553504943848

The loss has decreased (from about 2.3165 to about 2.2884).

In [33]:
# Correct predictions on the same batch after the single update step.
get_num_correct(preds,labels)

15

The number of correct predictions also increased (from 4 to 15).

Let's now loop the training process over every batch in the training set (one full epoch).

In [40]:
# One full pass over the training set, starting from a freshly created network.
network = Network()
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
optimizer = optim.Adam(network.parameters(), lr=0.01)

# Running totals accumulated over every batch of this pass.
total_loss = total_correct = 0

for batch in train_loader:
  images, labels = batch

  # Clear the gradients of the previous batch before computing new ones.
  optimizer.zero_grad()

  preds = network(images)
  loss = F.cross_entropy(preds, labels)
  loss.backward()
  optimizer.step()

  # preds were produced before this batch's update; loss.item() is this batch's loss value.
  total_correct += get_num_correct(preds, labels)
  total_loss += loss.item()

print("epoch:", 0, "total_correct:", total_correct, "loss:", total_loss)

epoch: 0 total_correct: 47559 loss: 327.6630711257458


In [41]:
# Share of training samples predicted correctly during the pass above, in percent.
accuracy = (total_correct/len(train_set))*100
print(accuracy)

79.265


We got a training accuracy of about 79% after the first epoch.

Now let's train for multiple epochs.

In [42]:
# Train a freshly created network for five passes over the training set.
network = Network()
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
optimizer = optim.Adam(network.parameters(), lr=0.01)

for epoch in range(5):
  # Totals are reset at the start of every epoch, so each printed line covers one epoch only.
  total_loss = total_correct = 0

  for batch in train_loader:
    images, labels = batch

    # Clear the gradients of the previous batch before computing new ones.
    optimizer.zero_grad()

    preds = network(images)
    loss = F.cross_entropy(preds, labels)
    loss.backward()
    optimizer.step()

    # preds were produced before this batch's update; loss.item() is this batch's loss value.
    total_correct += get_num_correct(preds, labels)
    total_loss += loss.item()

  print("epoch:", epoch, "total_correct:", total_correct, "loss:", total_loss)


epoch: 0 total_correct: 47714 loss: 325.41009621322155
epoch: 1 total_correct: 51590 loss: 226.20172734558582
epoch: 2 total_correct: 52205 loss: 207.66187973320484
epoch: 3 total_correct: 52531 loss: 200.37564292550087
epoch: 4 total_correct: 52729 loss: 193.84976820647717


In [44]:
# total_correct holds the count from the last epoch of the loop above, so this is
# that epoch's accuracy on the training set, in percent.
accuracy = (total_correct/len(train_set))*100
print(f"Accuracy:{accuracy}")

Accuracy:87.88166666666667


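We reached a training accuracy of about 88% after just 5 epochs. Note that this is measured on the training data itself, not on a held-out test set.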