In [30]:
import torch
from torch import nn

from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [31]:
training_data = datasets.FashionMNIST(root='data', train=True, download=True, transform=ToTensor())

test_data = datasets.FashionMNIST(root='data', train=False, download=True, transform=ToTensor())

In [32]:
batch_size = 64
# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [39]:
class FashionNet(nn.Module):
    
    def __init__(self,):
        super(FashionNet, self).__init__()
        self.flatten = nn.Flatten()
        
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(), 
            # nn.Linear(512, 512),
            # nn.ReLU(),
            # nn.Linear(512, 512),
            # nn.ReLU(),
            nn.Linear(512, 10)
        )
        
    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

In [40]:
model = FashionNet().to(device)
print(model)

FashionNet(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=10, bias=True)
  )
)


In [41]:
loss_function = nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0)

In [42]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        
        pred = model(X)
        loss = loss_fn(pred, y)
        
        # back-propagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [43]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [45]:
epochs = 50
from tqdm import trange

for t in trange(epochs):
    print(f'Epoch {t+1} \n----------------------------------------')
    train(train_dataloader, model, loss_function, optim)
    test(test_dataloader, model, loss_function)
    
print('Done \n')

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1 
----------------------------------------
loss: 1.183741  [    0/60000]
loss: 1.128579  [12800/60000]
loss: 1.128815  [25600/60000]
loss: 1.124442  [38400/60000]
loss: 1.148059  [51200/60000]


  2%|▏         | 1/50 [00:05<04:29,  5.50s/it]

Test Error: 
 Accuracy: 66.4%, Avg loss: 1.120983 

Epoch 2 
----------------------------------------
loss: 1.082904  [    0/60000]
loss: 1.031069  [12800/60000]
loss: 1.039182  [25600/60000]
loss: 1.043858  [38400/60000]
loss: 1.071384  [51200/60000]


  4%|▍         | 2/50 [00:11<04:27,  5.57s/it]

Test Error: 
 Accuracy: 66.9%, Avg loss: 1.049277 

Epoch 3 
----------------------------------------
loss: 1.005514  [    0/60000]
loss: 0.955862  [12800/60000]
loss: 0.971036  [25600/60000]
loss: 0.982234  [38400/60000]
loss: 1.011732  [51200/60000]


  6%|▌         | 3/50 [00:16<04:24,  5.63s/it]

Test Error: 
 Accuracy: 67.6%, Avg loss: 0.993711 

Epoch 4 
----------------------------------------
loss: 0.944298  [    0/60000]
loss: 0.896305  [12800/60000]
loss: 0.917664  [25600/60000]
loss: 0.933834  [38400/60000]
loss: 0.964290  [51200/60000]


  8%|▊         | 4/50 [00:22<04:18,  5.63s/it]

Test Error: 
 Accuracy: 68.1%, Avg loss: 0.949466 

Epoch 5 
----------------------------------------
loss: 0.894641  [    0/60000]
loss: 0.847925  [12800/60000]
loss: 0.874737  [25600/60000]
loss: 0.894911  [38400/60000]
loss: 0.925666  [51200/60000]


 10%|█         | 5/50 [00:28<04:14,  5.65s/it]

Test Error: 
 Accuracy: 68.8%, Avg loss: 0.913336 

Epoch 6 
----------------------------------------
loss: 0.853299  [    0/60000]
loss: 0.807723  [12800/60000]
loss: 0.839261  [25600/60000]
loss: 0.862766  [38400/60000]
loss: 0.893596  [51200/60000]


 12%|█▏        | 6/50 [00:33<04:08,  5.64s/it]

Test Error: 
 Accuracy: 69.4%, Avg loss: 0.883148 

Epoch 7 
----------------------------------------
loss: 0.818138  [    0/60000]
loss: 0.773692  [12800/60000]
loss: 0.809151  [25600/60000]
loss: 0.835665  [38400/60000]
loss: 0.866394  [51200/60000]


 14%|█▍        | 7/50 [00:39<04:03,  5.65s/it]

Test Error: 
 Accuracy: 70.2%, Avg loss: 0.857447 

Epoch 8 
----------------------------------------
loss: 0.787772  [    0/60000]
loss: 0.744402  [12800/60000]
loss: 0.783132  [25600/60000]
loss: 0.812409  [38400/60000]
loss: 0.843048  [51200/60000]


 16%|█▌        | 8/50 [00:45<03:58,  5.68s/it]

Test Error: 
 Accuracy: 70.9%, Avg loss: 0.835181 

Epoch 9 
----------------------------------------
loss: 0.761172  [    0/60000]
loss: 0.718858  [12800/60000]
loss: 0.760294  [25600/60000]
loss: 0.792115  [38400/60000]
loss: 0.822740  [51200/60000]


 18%|█▊        | 9/50 [00:50<03:51,  5.66s/it]

Test Error: 
 Accuracy: 71.7%, Avg loss: 0.815622 

Epoch 10 
----------------------------------------
loss: 0.737568  [    0/60000]
loss: 0.696300  [12800/60000]
loss: 0.740028  [25600/60000]
loss: 0.774147  [38400/60000]
loss: 0.804842  [51200/60000]


 20%|██        | 10/50 [00:56<03:46,  5.67s/it]

Test Error: 
 Accuracy: 72.2%, Avg loss: 0.798209 

Epoch 11 
----------------------------------------
loss: 0.716447  [    0/60000]
loss: 0.676226  [12800/60000]
loss: 0.721771  [25600/60000]
loss: 0.758073  [38400/60000]
loss: 0.788915  [51200/60000]


 22%|██▏       | 11/50 [01:02<03:41,  5.69s/it]

Test Error: 
 Accuracy: 72.8%, Avg loss: 0.782536 

Epoch 12 
----------------------------------------
loss: 0.697349  [    0/60000]
loss: 0.658190  [12800/60000]
loss: 0.705241  [25600/60000]
loss: 0.743488  [38400/60000]
loss: 0.774619  [51200/60000]


 24%|██▍       | 12/50 [01:07<03:36,  5.69s/it]

Test Error: 
 Accuracy: 73.3%, Avg loss: 0.768308 

Epoch 13 
----------------------------------------
loss: 0.679928  [    0/60000]
loss: 0.641863  [12800/60000]
loss: 0.690167  [25600/60000]
loss: 0.730231  [38400/60000]
loss: 0.761706  [51200/60000]


 26%|██▌       | 13/50 [01:13<03:29,  5.67s/it]

Test Error: 
 Accuracy: 73.8%, Avg loss: 0.755282 

Epoch 14 
----------------------------------------
loss: 0.663983  [    0/60000]
loss: 0.626984  [12800/60000]
loss: 0.676314  [25600/60000]
loss: 0.718077  [38400/60000]
loss: 0.749913  [51200/60000]


 28%|██▊       | 14/50 [01:19<03:23,  5.65s/it]

Test Error: 
 Accuracy: 74.2%, Avg loss: 0.743289 

Epoch 15 
----------------------------------------
loss: 0.649255  [    0/60000]
loss: 0.613395  [12800/60000]
loss: 0.663519  [25600/60000]
loss: 0.706847  [38400/60000]
loss: 0.739072  [51200/60000]


 30%|███       | 15/50 [01:24<03:18,  5.67s/it]

Test Error: 
 Accuracy: 74.6%, Avg loss: 0.732177 

Epoch 16 
----------------------------------------
loss: 0.635635  [    0/60000]
loss: 0.600890  [12800/60000]
loss: 0.651660  [25600/60000]
loss: 0.696406  [38400/60000]
loss: 0.729168  [51200/60000]


 32%|███▏      | 16/50 [01:30<03:13,  5.68s/it]

Test Error: 
 Accuracy: 74.8%, Avg loss: 0.721840 

Epoch 17 
----------------------------------------
loss: 0.622919  [    0/60000]
loss: 0.589410  [12800/60000]
loss: 0.640616  [25600/60000]
loss: 0.686674  [38400/60000]
loss: 0.720017  [51200/60000]


 34%|███▍      | 17/50 [01:36<03:07,  5.69s/it]

Test Error: 
 Accuracy: 75.2%, Avg loss: 0.712177 

Epoch 18 
----------------------------------------
loss: 0.611100  [    0/60000]
loss: 0.578784  [12800/60000]
loss: 0.630308  [25600/60000]
loss: 0.677582  [38400/60000]
loss: 0.711479  [51200/60000]


 36%|███▌      | 18/50 [01:41<03:01,  5.68s/it]

Test Error: 
 Accuracy: 75.6%, Avg loss: 0.703130 

Epoch 19 
----------------------------------------
loss: 0.600069  [    0/60000]
loss: 0.568924  [12800/60000]
loss: 0.620718  [25600/60000]
loss: 0.669046  [38400/60000]
loss: 0.703580  [51200/60000]


 38%|███▊      | 19/50 [01:47<02:57,  5.71s/it]

Test Error: 
 Accuracy: 75.9%, Avg loss: 0.694627 

Epoch 20 
----------------------------------------
loss: 0.589698  [    0/60000]
loss: 0.559772  [12800/60000]
loss: 0.611688  [25600/60000]
loss: 0.661002  [38400/60000]
loss: 0.696264  [51200/60000]


 40%|████      | 20/50 [01:53<02:50,  5.69s/it]

Test Error: 
 Accuracy: 76.1%, Avg loss: 0.686618 

Epoch 21 
----------------------------------------
loss: 0.579993  [    0/60000]
loss: 0.551267  [12800/60000]
loss: 0.603212  [25600/60000]
loss: 0.653450  [38400/60000]
loss: 0.689453  [51200/60000]


 42%|████▏     | 21/50 [01:59<02:45,  5.71s/it]

Test Error: 
 Accuracy: 76.4%, Avg loss: 0.679056 

Epoch 22 
----------------------------------------
loss: 0.570802  [    0/60000]
loss: 0.543300  [12800/60000]
loss: 0.595279  [25600/60000]
loss: 0.646323  [38400/60000]
loss: 0.683043  [51200/60000]


 44%|████▍     | 22/50 [02:04<02:39,  5.71s/it]

Test Error: 
 Accuracy: 76.7%, Avg loss: 0.671905 

Epoch 23 
----------------------------------------
loss: 0.562182  [    0/60000]
loss: 0.535847  [12800/60000]
loss: 0.587701  [25600/60000]
loss: 0.639607  [38400/60000]
loss: 0.677092  [51200/60000]


 46%|████▌     | 23/50 [02:10<02:33,  5.69s/it]

Test Error: 
 Accuracy: 76.8%, Avg loss: 0.665135 

Epoch 24 
----------------------------------------
loss: 0.554029  [    0/60000]
loss: 0.528898  [12800/60000]
loss: 0.580640  [25600/60000]
loss: 0.633268  [38400/60000]
loss: 0.671518  [51200/60000]


 48%|████▊     | 24/50 [02:16<02:28,  5.72s/it]

Test Error: 
 Accuracy: 77.1%, Avg loss: 0.658717 

Epoch 25 
----------------------------------------
loss: 0.546314  [    0/60000]
loss: 0.522359  [12800/60000]
loss: 0.573910  [25600/60000]
loss: 0.627254  [38400/60000]
loss: 0.666267  [51200/60000]


 50%|█████     | 25/50 [02:22<02:23,  5.73s/it]

Test Error: 
 Accuracy: 77.3%, Avg loss: 0.652620 

Epoch 26 
----------------------------------------
loss: 0.539024  [    0/60000]
loss: 0.516205  [12800/60000]
loss: 0.567556  [25600/60000]
loss: 0.621489  [38400/60000]
loss: 0.661412  [51200/60000]


 52%|█████▏    | 26/50 [02:28<02:22,  5.93s/it]

Test Error: 
 Accuracy: 77.6%, Avg loss: 0.646817 

Epoch 27 
----------------------------------------
loss: 0.532057  [    0/60000]
loss: 0.510402  [12800/60000]
loss: 0.561530  [25600/60000]
loss: 0.616086  [38400/60000]
loss: 0.656795  [51200/60000]


 54%|█████▍    | 27/50 [02:34<02:18,  6.01s/it]

Test Error: 
 Accuracy: 77.8%, Avg loss: 0.641298 

Epoch 28 
----------------------------------------
loss: 0.525452  [    0/60000]
loss: 0.504925  [12800/60000]
loss: 0.555814  [25600/60000]
loss: 0.610937  [38400/60000]
loss: 0.652457  [51200/60000]


 56%|█████▌    | 28/50 [02:40<02:12,  6.03s/it]

Test Error: 
 Accuracy: 78.0%, Avg loss: 0.636039 

Epoch 29 
----------------------------------------
loss: 0.519121  [    0/60000]
loss: 0.499751  [12800/60000]
loss: 0.550360  [25600/60000]
loss: 0.606018  [38400/60000]
loss: 0.648390  [51200/60000]


 58%|█████▊    | 29/50 [02:46<02:06,  6.03s/it]

Test Error: 
 Accuracy: 78.2%, Avg loss: 0.631024 

Epoch 30 
----------------------------------------
loss: 0.513107  [    0/60000]
loss: 0.494838  [12800/60000]
loss: 0.545163  [25600/60000]
loss: 0.601389  [38400/60000]
loss: 0.644576  [51200/60000]


 60%|██████    | 30/50 [02:52<02:00,  6.04s/it]

Test Error: 
 Accuracy: 78.4%, Avg loss: 0.626232 

Epoch 31 
----------------------------------------
loss: 0.507317  [    0/60000]
loss: 0.490173  [12800/60000]
loss: 0.540207  [25600/60000]
loss: 0.596934  [38400/60000]
loss: 0.640963  [51200/60000]


 62%|██████▏   | 31/50 [02:58<01:54,  6.04s/it]

Test Error: 
 Accuracy: 78.6%, Avg loss: 0.621661 

Epoch 32 
----------------------------------------
loss: 0.501824  [    0/60000]
loss: 0.485737  [12800/60000]
loss: 0.535462  [25600/60000]
loss: 0.592678  [38400/60000]
loss: 0.637545  [51200/60000]


 64%|██████▍   | 32/50 [03:04<01:46,  5.91s/it]

Test Error: 
 Accuracy: 78.6%, Avg loss: 0.617282 

Epoch 33 
----------------------------------------
loss: 0.496500  [    0/60000]
loss: 0.481538  [12800/60000]
loss: 0.530895  [25600/60000]
loss: 0.588574  [38400/60000]
loss: 0.634326  [51200/60000]


 66%|██████▌   | 33/50 [03:10<01:40,  5.94s/it]

Test Error: 
 Accuracy: 78.8%, Avg loss: 0.613092 

Epoch 34 
----------------------------------------
loss: 0.491438  [    0/60000]
loss: 0.477537  [12800/60000]
loss: 0.526575  [25600/60000]
loss: 0.584738  [38400/60000]
loss: 0.631276  [51200/60000]


 68%|██████▊   | 34/50 [03:16<01:33,  5.86s/it]

Test Error: 
 Accuracy: 79.0%, Avg loss: 0.609074 

Epoch 35 
----------------------------------------
loss: 0.486567  [    0/60000]
loss: 0.473718  [12800/60000]
loss: 0.522362  [25600/60000]
loss: 0.580999  [38400/60000]
loss: 0.628386  [51200/60000]


 70%|███████   | 35/50 [03:22<01:28,  5.93s/it]

Test Error: 
 Accuracy: 79.1%, Avg loss: 0.605219 

Epoch 36 
----------------------------------------
loss: 0.481842  [    0/60000]
loss: 0.470079  [12800/60000]
loss: 0.518361  [25600/60000]
loss: 0.577449  [38400/60000]
loss: 0.625600  [51200/60000]


 72%|███████▏  | 36/50 [03:28<01:23,  5.98s/it]

Test Error: 
 Accuracy: 79.1%, Avg loss: 0.601521 

Epoch 37 
----------------------------------------
loss: 0.477341  [    0/60000]
loss: 0.466558  [12800/60000]
loss: 0.514487  [25600/60000]
loss: 0.574008  [38400/60000]
loss: 0.622958  [51200/60000]


 74%|███████▍  | 37/50 [03:34<01:19,  6.10s/it]

Test Error: 
 Accuracy: 79.3%, Avg loss: 0.597971 

Epoch 38 
----------------------------------------
loss: 0.472958  [    0/60000]
loss: 0.463225  [12800/60000]
loss: 0.510799  [25600/60000]
loss: 0.570748  [38400/60000]
loss: 0.620463  [51200/60000]


 76%|███████▌  | 38/50 [03:40<01:12,  6.07s/it]

Test Error: 
 Accuracy: 79.4%, Avg loss: 0.594555 

Epoch 39 
----------------------------------------
loss: 0.468733  [    0/60000]
loss: 0.460023  [12800/60000]
loss: 0.507178  [25600/60000]
loss: 0.567614  [38400/60000]
loss: 0.618070  [51200/60000]


 78%|███████▊  | 39/50 [03:46<01:06,  6.08s/it]

Test Error: 
 Accuracy: 79.6%, Avg loss: 0.591265 

Epoch 40 
----------------------------------------
loss: 0.464691  [    0/60000]
loss: 0.456931  [12800/60000]
loss: 0.503670  [25600/60000]
loss: 0.564625  [38400/60000]
loss: 0.615690  [51200/60000]


 80%|████████  | 40/50 [03:52<01:00,  6.10s/it]

Test Error: 
 Accuracy: 79.8%, Avg loss: 0.588104 

Epoch 41 
----------------------------------------
loss: 0.460751  [    0/60000]
loss: 0.453982  [12800/60000]
loss: 0.500319  [25600/60000]
loss: 0.561700  [38400/60000]
loss: 0.613521  [51200/60000]


 82%|████████▏ | 41/50 [03:59<00:55,  6.12s/it]

Test Error: 
 Accuracy: 79.9%, Avg loss: 0.585054 

Epoch 42 
----------------------------------------
loss: 0.456958  [    0/60000]
loss: 0.451126  [12800/60000]
loss: 0.497053  [25600/60000]
loss: 0.558915  [38400/60000]
loss: 0.611409  [51200/60000]


 84%|████████▍ | 42/50 [04:05<00:48,  6.11s/it]

Test Error: 
 Accuracy: 79.9%, Avg loss: 0.582116 

Epoch 43 
----------------------------------------
loss: 0.453265  [    0/60000]
loss: 0.448393  [12800/60000]
loss: 0.493936  [25600/60000]
loss: 0.556257  [38400/60000]
loss: 0.609321  [51200/60000]


 86%|████████▌ | 43/50 [04:11<00:42,  6.05s/it]

Test Error: 
 Accuracy: 80.0%, Avg loss: 0.579273 

Epoch 44 
----------------------------------------
loss: 0.449711  [    0/60000]
loss: 0.445751  [12800/60000]
loss: 0.490850  [25600/60000]
loss: 0.553628  [38400/60000]
loss: 0.607338  [51200/60000]


 88%|████████▊ | 44/50 [04:16<00:35,  5.94s/it]

Test Error: 
 Accuracy: 80.1%, Avg loss: 0.576538 

Epoch 45 
----------------------------------------
loss: 0.446259  [    0/60000]
loss: 0.443219  [12800/60000]
loss: 0.487910  [25600/60000]
loss: 0.551146  [38400/60000]
loss: 0.605399  [51200/60000]


 90%|█████████ | 45/50 [04:22<00:29,  5.81s/it]

Test Error: 
 Accuracy: 80.2%, Avg loss: 0.573892 

Epoch 46 
----------------------------------------
loss: 0.442895  [    0/60000]
loss: 0.440761  [12800/60000]
loss: 0.485076  [25600/60000]
loss: 0.548724  [38400/60000]
loss: 0.603525  [51200/60000]


 92%|█████████▏| 46/50 [04:27<00:23,  5.77s/it]

Test Error: 
 Accuracy: 80.3%, Avg loss: 0.571332 

Epoch 47 
----------------------------------------
loss: 0.439651  [    0/60000]
loss: 0.438403  [12800/60000]
loss: 0.482283  [25600/60000]
loss: 0.546348  [38400/60000]
loss: 0.601686  [51200/60000]


 94%|█████████▍| 47/50 [04:33<00:17,  5.77s/it]

Test Error: 
 Accuracy: 80.4%, Avg loss: 0.568856 

Epoch 48 
----------------------------------------
loss: 0.436491  [    0/60000]
loss: 0.436122  [12800/60000]
loss: 0.479581  [25600/60000]
loss: 0.544110  [38400/60000]
loss: 0.600008  [51200/60000]


 96%|█████████▌| 48/50 [04:39<00:11,  5.75s/it]

Test Error: 
 Accuracy: 80.5%, Avg loss: 0.566456 

Epoch 49 
----------------------------------------
loss: 0.433421  [    0/60000]
loss: 0.433886  [12800/60000]
loss: 0.477018  [25600/60000]
loss: 0.541974  [38400/60000]
loss: 0.598304  [51200/60000]


 98%|█████████▊| 49/50 [04:45<00:05,  5.70s/it]

Test Error: 
 Accuracy: 80.6%, Avg loss: 0.564135 

Epoch 50 
----------------------------------------
loss: 0.430422  [    0/60000]
loss: 0.431723  [12800/60000]
loss: 0.474453  [25600/60000]
loss: 0.539845  [38400/60000]
loss: 0.596634  [51200/60000]


100%|██████████| 50/50 [04:50<00:00,  5.81s/it]

Test Error: 
 Accuracy: 80.6%, Avg loss: 0.561885 

Done 




