In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
torch.set_printoptions(linewidth=120)

In [3]:
def get_num_correct(preds,labels):
  """Return, as a Python number, how many predictions match their labels.

  `preds` is reduced with argmax over dim 1, so the index of the largest
  value in each row is taken as the prediction and compared element-wise
  with `labels`.
  """
  predicted = torch.argmax(preds, dim=1)
  hits = predicted == labels
  return hits.sum().item()

In [4]:
class Network(nn.Module):
  """Small CNN classifier: two conv/ReLU/max-pool stages, then three linear layers.

  The flatten size 12*4*4 corresponds to single-channel 28x28 inputs
  (28 -> 24 -> 12 after the first stage, 12 -> 8 -> 4 after the second).
  The final layer produces 10 raw scores per sample; no softmax is applied here.
  """

  def __init__(self):
    super().__init__()
    # Feature extractor: 5x5 convolutions, channels 1 -> 6 -> 12.
    self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
    self.conv2 = nn.Conv2d(6, 12, kernel_size=5)
    # Classifier head: 192 -> 120 -> 60 -> 10.
    self.fc1 = nn.Linear(12 * 4 * 4, 120)
    self.fc2 = nn.Linear(120, 60)
    self.out = nn.Linear(60, 10)

  def forward(self,t):
    """Run `t` through the network and return the output layer's raw scores."""
    # Each convolution is followed by ReLU and a 2x2, stride-2 max pool.
    for conv in (self.conv1, self.conv2):
      t = F.max_pool2d(F.relu(conv(t)), kernel_size=2, stride=2)

    # Flatten the feature maps to rows of 192 values for the linear layers.
    t = t.reshape(-1, 12 * 4 * 4)

    # Both hidden linear layers use ReLU; the output layer has no activation.
    for hidden in (self.fc1, self.fc2):
      t = F.relu(hidden(t))

    return self.out(t)

In [5]:
# Training split of FashionMNIST. download=True fetches the archives when needed.
# torchvision creates its own FashionMNIST/raw folder beneath `root`, so the
# files land in ./data/FashionMNIST/FashionMNIST/raw (see the download log below).
# ToTensor is the only preprocessing step; it is kept inside a Compose pipeline.
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    transform=transforms.Compose([transforms.ToTensor()]),
    download=True,
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/FashionMNIST/raw



In [6]:
# Instantiate the CNN defined above; no weights are loaded, so it starts untrained.
mynetwork = Network()

In [7]:
# Serve the training set in mini-batches of 100 samples (no shuffle argument is passed).
train_loader = torch.utils.data.DataLoader(train_set,batch_size=100)
# Take only the first batch; the cells below keep reusing this same batch.
batch = next(iter(train_loader))
# The batch unpacks into the image tensor and the label tensor.
images,labels = batch

In [8]:
# Forward pass on the sampled batch, then cross-entropy between the raw
# network outputs and the labels.
preds = mynetwork(images)
loss = F.cross_entropy(preds,labels)
# .item() extracts the scalar loss as a plain Python number for display.
loss.item()

2.3083107471466064

In [9]:
# No backward pass has run yet, so .grad is still None. print() is deliberate:
# a bare None as the last expression would produce no cell output.
print(mynetwork.conv1.weight.grad)

None


Now let's apply backpropagation to compute the gradients of the loss.

In [10]:
# Back-propagate from the loss; this populates .grad on the network's
# parameters (conv1's gradient is inspected in the next cell).
loss.backward()

In [12]:
# The gradient now exists and has the same shape as conv1's weight:
# (out_channels=6, in_channels=1, kernel=5x5).
mynetwork.conv1.weight.grad.shape

torch.Size([6, 1, 5, 5])

Let's update the weights now.

In [13]:
# Adam optimizer over every tensor returned by mynetwork.parameters(), learning rate 0.01.
optimizer = optim.Adam(mynetwork.parameters(),lr=0.01)

The network parameters are the learnable tensors of the network, i.e. the weights and biases of each layer.

In [15]:
# Baseline: correct predictions on the batch of 100 before any weight update.
get_num_correct(preds,labels)

5

With untrained weights the network is effectively guessing: only 5 of the 100 predictions were correct (pure chance over 10 classes would give about 10).

In [16]:
# Apply a single Adam update using the gradients produced by loss.backward() above.
optimizer.step()

In [17]:
# Repeat the forward pass on the same batch to see the effect of the single update.
# Note: this overwrites the earlier `preds` and `loss`.
preds = mynetwork(images)
loss = F.cross_entropy(preds,labels)
loss.item()

2.266301155090332

The loss has decreased after a single optimizer step (from about 2.308 to about 2.266).

In [18]:
# Correct predictions on the same batch after the update.
get_num_correct(preds,labels)

16

The number of correct predictions also increased, from 5 to 16 out of 100.