In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision

import MyModel

import Utils

from tqdm import tqdm

In [2]:
EPOCHS = 10

In [3]:
transforms = torchvision.transforms.ToTensor()

In [4]:
trainDataset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transforms)

In [5]:
testDataset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transforms)

In [6]:
trainLoader = DataLoader(dataset=trainDataset, batch_size=4)

In [7]:
testLoader = DataLoader(dataset=testDataset, batch_size=4)

In [8]:
criterion = nn.NLLLoss()

In [9]:
device = Utils.getDevice()

In [10]:
model = MyModel.MyModel()

In [11]:
model.to(device)

MyModel(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=1024, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (dropout): Dropout(p=0.4)
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (logsoftmax): LogSoftmax()
)

In [12]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [14]:
for i in range(EPOCHS):

    running_acc = 0
    running_loss = 0
    model.train()
    for img, label in tqdm(trainLoader):
        img = img.to(device)
        label = label.to(device)
    
        optimizer.zero_grad()
        pred = model(img)
        loss = criterion(pred, label)
        loss.backward()
        optimizer.step()
    
        pred = torch.topk(pred, 1)[1].view(1, pred.size()[0])
        running_loss += loss.item()
        running_acc += torch.sum(pred == label).item()
        
    train_loss = running_loss / len(trainLoader)
    train_acc = running_acc / len(trainLoader)
    
    print('Training Loss : ', train_loss)
    print('Training Acc : ', train_acc)
    
    test_loss = 0
    test_acc = 0
    
    model.eval()
    for img, label in tqdm(testLoader):
        img = img.to(device)
        label = label.to(device)
    
        pred = model(img)
        loss = criterion(pred, label)
    
        pred = torch.topk(pred, 1)[1].view(1, pred.size()[0])
        test_loss += loss.item()
        test_acc += torch.sum(pred == label).item()
    
    test_loss = test_loss / len(testLoader)
    test_acc = test_acc / len(testLoader)
    
    print('Test Loss : ', test_loss)
    print('Train Acc : ', test_acc)

  0%|          | 0/15000 [00:00<?, ?it/s]
  4%|▎         | 91/2500 [00:00<00:02, 896.05it/s]

Training Loss :  0.0001490251064300537
Training Acc :  0.0002


100%|██████████| 2500/2500 [00:02<00:00, 967.43it/s]

Test Loss :  2.302908491230011
Train Acc :  0.3844





AssertionError: 