In [1]:
import torch
import torchvision
import torchvision.transforms as transforms 

torch.set_printoptions(linewidth=120)

In [2]:
 train_set = torchvision.datasets.FashionMNIST(    # Extract
     root = '.data/FashionMNIST',
     train = True,
     download = True,
     transform = transforms.Compose([             # Transform
         transforms.ToTensor()
     ])
 )

In [3]:
import torch.nn as nn 
import torch.nn.functional as F

class Network(nn.Module):
    def __init__(self):
        super(Network,self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
        
        self.fc1 = nn.Linear(in_features=12*4*4 , out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)
        
    def forward(self, t):
        # (1) input layer
        t = t
        
        # (2) hiddden conv layer
        t = self.conv1(t)
        t = F.relu(t)
        t = F.max_pool2d(t,kernel_size=2,stride=2)
        
        # (3) hidden conv layer
        t = self.conv2(t)
        t = F.relu(t)
        t = F.max_pool2d(t,kernel_size=2,stride=2)
        
        # (4) hidden linear layer
        t = t.reshape(-1,12*4*4)
        t = self.fc1(t)
        t = F.relu(t)
        
        # (5) hidden linear layer
        t = self.fc2(t)
        t = F.relu(t)
        
        # (6) output layer
        t = self.out(t)
        # t = F.softmax(t,dim=1) (since loss is cross_entropy which automatically uses softmax)
        
        return t

In [4]:
network = Network()

train_loader = torch.utils.data.DataLoader(train_set, batch_size = 100) #Load
all_preds = torch.tensor([])
optimizer = torch.optim.Adam(network.parameters(), lr=0.01)

for batch in train_loader:
    images, labels = batch
    preds = network(images)
    loss = F.cross_entropy(preds, labels)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    all_preds = torch.cat((all_preds,preds),dim=0)
    

In [5]:
all_preds.shape

torch.Size([60000, 10])

In [6]:
all_preds.argmax(dim=1)

tensor([6, 6, 6,  ..., 3, 0, 5])

In [7]:
 stacked = torch.stack( (train_set.targets,all_preds.argmax(dim=1)),dim=1 )

In [8]:
print(stacked)

tensor([[9, 6],
        [0, 6],
        [0, 6],
        ...,
        [3, 3],
        [0, 0],
        [5, 5]])


In [9]:
cnt = torch.zeros((10,10),dtype=torch.int64)
for i in stacked:
    x,y = i.tolist()
    cnt[x][y] += 1 

In [10]:
print(cnt)

tensor([[4740,  116,  146,  401,   43,   17,  448,    1,   87,    1],
        [  16, 5548,   17,  323,   33,    6,   51,    0,    6,    0],
        [  92,   59, 4010,   80, 1018,   24,  601,    1,  110,    5],
        [ 287,  357,   81, 4805,  268,   10,  173,    2,   16,    1],
        [  46,   93,  760,  308, 4186,   11,  550,    0,   44,    2],
        [   8,    5,    2,   10,    1, 5363,   21,  385,   60,  145],
        [1345,   82, 1344,  263, 1039,   31, 1765,    1,  127,    3],
        [   8,   14,    0,    0,    0,  435,   21, 5113,   14,  395],
        [  45,   49,  102,   37,   44,  120,   99,   18, 5451,   35],
        [   6,   29,    2,    5,    0,  161,   22,  293,   14, 5468]])


In [11]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(train_set.targets, all_preds.argmax(dim=1))
cm

array([[4740,  116,  146,  401,   43,   17,  448,    1,   87,    1],
       [  16, 5548,   17,  323,   33,    6,   51,    0,    6,    0],
       [  92,   59, 4010,   80, 1018,   24,  601,    1,  110,    5],
       [ 287,  357,   81, 4805,  268,   10,  173,    2,   16,    1],
       [  46,   93,  760,  308, 4186,   11,  550,    0,   44,    2],
       [   8,    5,    2,   10,    1, 5363,   21,  385,   60,  145],
       [1345,   82, 1344,  263, 1039,   31, 1765,    1,  127,    3],
       [   8,   14,    0,    0,    0,  435,   21, 5113,   14,  395],
       [  45,   49,  102,   37,   44,  120,   99,   18, 5451,   35],
       [   6,   29,    2,    5,    0,  161,   22,  293,   14, 5468]])