In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms

# Allow up to 120 characters per line when tensors are printed, so wide rows
# are not wrapped in the cell output.
torch.set_printoptions(linewidth=120)

In [32]:
# FashionMNIST training split, stored under ./data/FashionMNIST (downloaded on
# first use). Each image is converted to a tensor by ToTensor when accessed.
train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

In [33]:
class Network(nn.Module):
    """Small CNN classifier: two conv layers followed by three linear layers.

    The layer sizes assume single-channel 28x28 inputs. Each 5x5 convolution
    (no padding) shrinks the spatial size by 4 and each 2x2 max-pool halves
    it: 28 -> 24 -> 12 -> 8 -> 4. That leaves 12 feature maps of 4x4, which is
    why ``fc1`` takes ``12*4*4`` input features.
    """

    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Run a batch of images through the network.

        Args:
            t: image tensor of shape ``(batch, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, 10)`` holding one unnormalised score
            (logit) per class. No softmax is applied here; apply
            ``F.softmax(..., dim=1)`` on the result if probabilities are
            needed.
        """
        # (1) input layer: the tensor is used as-is.

        # (2) hidden conv layer
        t = self.conv1(t)
        t = F.relu(t)
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # (3) hidden conv layer
        t = self.conv2(t)
        t = F.relu(t)
        # kernel_size must be 2 here: with kernel_size=1 and stride=2 the op
        # only subsamples every other activation instead of taking the max of
        # each 2x2 window, while still producing the same 4x4 output shape.
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # (4) hidden linear layer: flatten the 12 maps of 4x4 per sample
        t = t.reshape(-1, 12*4*4)
        t = self.fc1(t)
        t = F.relu(t)

        # (5) hidden linear layer
        t = self.fc2(t)
        t = F.relu(t)

        # (6) output layer
        t = self.out(t)

        return t

In [34]:
# Switch autograd gradient tracking off. Called as a plain function (not as a
# context manager), so the setting stays in effect for the cells that follow;
# call torch.set_grad_enabled(True) again before doing any training.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x1f5e2509588>

In [35]:
# Build the network with its default-initialised weights; nothing in the
# visible cells trains it or sets a random seed, so the predictions below are
# those of an untrained model.
net = Network()

In [36]:
# Serve the training set in batches of 10 samples.
data_loader = torch.utils.data.DataLoader(train_set, batch_size=10)

In [37]:
# Take the first batch the loader yields (a fresh iterator is created here).
batch = next(iter(data_loader))

In [38]:
# Split the batch into its image tensor and its label tensor.
images, labels = batch

In [39]:
# Sanity check: 10 single-channel 28x28 images per batch.
images.shape

torch.Size([10, 1, 28, 28])

In [40]:
# Sanity check: one label per image in the batch.
labels.shape

torch.Size([10])

In [41]:
# Forward pass: calling the module runs Network.forward on the batch.
preds = net(images)

In [13]:
# One row per image, one column per output of the final 10-unit layer.
preds.shape

torch.Size([10, 10])

In [14]:
# Display the raw network outputs for the batch.
preds

tensor([[ 0.1026,  0.0799, -0.1155,  0.0405, -0.1070,  0.0330,  0.0585,  0.0857, -0.1004, -0.0275],
        [ 0.1028,  0.0854, -0.1217,  0.0402, -0.1043,  0.0289,  0.0584,  0.0929, -0.1027, -0.0301],
        [ 0.1005,  0.0856, -0.1157,  0.0418, -0.1049,  0.0308,  0.0587,  0.0927, -0.0996, -0.0223],
        [ 0.1018,  0.0842, -0.1163,  0.0399, -0.1050,  0.0301,  0.0586,  0.0910, -0.0997, -0.0242],
        [ 0.1045,  0.0842, -0.1186,  0.0425, -0.0962,  0.0247,  0.0564,  0.0921, -0.1013, -0.0326],
        [ 0.1028,  0.0896, -0.1169,  0.0389, -0.1008,  0.0286,  0.0540,  0.0888, -0.0986, -0.0266],
        [ 0.1028,  0.0800, -0.1067,  0.0420, -0.1171,  0.0374,  0.0580,  0.0795, -0.1107, -0.0229],
        [ 0.1010,  0.0866, -0.1200,  0.0419, -0.0977,  0.0256,  0.0551,  0.0900, -0.0989, -0.0304],
        [ 0.1067,  0.0878, -0.1120,  0.0379, -0.1107,  0.0321,  0.0519,  0.0885, -0.0995, -0.0196],
        [ 0.1015,  0.0915, -0.1125,  0.0371, -0.1086,  0.0274,  0.0489,  0.0909, -0.0946, -0.0227]])

In [17]:
# Index of the largest score in each row, i.e. the predicted class per image.
preds.argmax(dim=1)

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [18]:
# Ground-truth labels for the same batch, for comparison with the predictions.
labels

tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5])

In [21]:
# Compare predicted classes with the labels element-wise and count the matches.
preds.argmax(dim=1).eq(labels).sum()

tensor(3)

In [28]:
def get_num_correct(preds,labels):
    """Count how many rows of ``preds`` have their largest score at the labelled class.

    Args:
        preds: 2-D tensor of scores, one row per sample and one column per class.
        labels: 1-D tensor of class indices, one per row of ``preds``.

    Returns:
        A 0-dimensional tensor holding the number of matching predictions.
    """
    predicted_classes = preds.argmax(dim=1)
    is_match = predicted_classes.eq(labels)
    return is_match.sum()

In [29]:
# Same count as the inline expression above, now via the helper.
get_num_correct(preds,labels)

tensor(3)