In [170]:
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image

In [171]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [172]:
train_dataset = torchvision.datasets.FashionMNIST(root="data",
                                             train=True, 
                                             download=True,
                                             transform=transforms.ToTensor())
test_dataset = torchvision.datasets.FashionMNIST(root="data",
                                             train=False, 
                                             download=True,
                                             transform=transforms.ToTensor())

In [173]:
# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=128, 
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=128, 
                                           shuffle=False)

In [174]:
image, label = train_dataset[0]
print (image.size())
print (label)

torch.Size([1, 28, 28])
9


In [175]:
train_features, train_labels = next(iter(train_loader))
print("Train feature shape:", train_features.shape)
print("Train label shape:", train_labels.shape)

test_features, test_labels = next(iter(test_loader))
print("Test feature shape:", test_features.shape)
print("Test label shape:", test_labels.shape)

Train feature shape: torch.Size([128, 1, 28, 28])
Train label shape: torch.Size([128])
Test feature shape: torch.Size([128, 1, 28, 28])
Test label shape: torch.Size([128])


In [176]:
import torch.nn as nn
import torch.nn.functional as F

class NN(nn.Module):
    def __init__(self):
        super(NN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [177]:
import torch.optim as optim
model = NN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

In [178]:
from tqdm import tqdm

num_epochs = 100

for epoch in range(num_epochs):
    #print(f"Epoch [{epoch + 1}/{num_epochs}]")
    for batch_index, (data, targets) in enumerate(tqdm(train_loader)):
        data = data.to(device)
        targets = targets.to(device)
        
        data = data.reshape(data.shape[0], -1)

        predict = model(data)
        loss = criterion(predict, targets)

        # Backward pass: compute the gradients
        optimizer.zero_grad()
        loss.backward()

        # Optimization step: update the model parameters
        optimizer.step()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

100%|██████████| 469/469 [00:06<00:00, 67.72it/s]


Epoch [1/100], Loss: 1.1106


100%|██████████| 469/469 [00:06<00:00, 74.69it/s]


Epoch [2/100], Loss: 0.9121


100%|██████████| 469/469 [00:06<00:00, 72.06it/s]


Epoch [3/100], Loss: 0.6253


100%|██████████| 469/469 [00:06<00:00, 71.57it/s]


Epoch [4/100], Loss: 0.7331


100%|██████████| 469/469 [00:06<00:00, 73.10it/s]


Epoch [5/100], Loss: 0.6223


100%|██████████| 469/469 [00:06<00:00, 69.34it/s]


Epoch [6/100], Loss: 0.5581


100%|██████████| 469/469 [00:06<00:00, 70.24it/s]


Epoch [7/100], Loss: 0.6815


100%|██████████| 469/469 [00:07<00:00, 63.81it/s]


Epoch [8/100], Loss: 0.4772


100%|██████████| 469/469 [00:06<00:00, 70.88it/s]


Epoch [9/100], Loss: 0.4754


100%|██████████| 469/469 [00:07<00:00, 66.81it/s]


Epoch [10/100], Loss: 0.4230


100%|██████████| 469/469 [00:06<00:00, 72.18it/s]


Epoch [11/100], Loss: 0.5871


100%|██████████| 469/469 [00:06<00:00, 74.52it/s]


Epoch [12/100], Loss: 0.3336


100%|██████████| 469/469 [00:06<00:00, 74.41it/s]


Epoch [13/100], Loss: 0.5146


100%|██████████| 469/469 [00:06<00:00, 75.38it/s]


Epoch [14/100], Loss: 0.4480


100%|██████████| 469/469 [00:06<00:00, 74.60it/s]


Epoch [15/100], Loss: 0.3419


100%|██████████| 469/469 [00:06<00:00, 73.31it/s]


Epoch [16/100], Loss: 0.4499


100%|██████████| 469/469 [00:06<00:00, 71.01it/s]


Epoch [17/100], Loss: 0.5553


100%|██████████| 469/469 [00:06<00:00, 70.47it/s]


Epoch [18/100], Loss: 0.4092


100%|██████████| 469/469 [00:06<00:00, 72.42it/s]


Epoch [19/100], Loss: 0.2917


100%|██████████| 469/469 [00:06<00:00, 71.00it/s]


Epoch [20/100], Loss: 0.3992


100%|██████████| 469/469 [00:06<00:00, 70.54it/s]


Epoch [21/100], Loss: 0.4102


100%|██████████| 469/469 [00:06<00:00, 75.13it/s]


Epoch [22/100], Loss: 0.4740


100%|██████████| 469/469 [00:06<00:00, 75.02it/s]


Epoch [23/100], Loss: 0.5495


100%|██████████| 469/469 [00:06<00:00, 74.33it/s]


Epoch [24/100], Loss: 0.4136


100%|██████████| 469/469 [00:06<00:00, 75.61it/s]


Epoch [25/100], Loss: 0.3782


100%|██████████| 469/469 [00:06<00:00, 73.59it/s]


Epoch [26/100], Loss: 0.4227


100%|██████████| 469/469 [00:06<00:00, 71.26it/s]


Epoch [27/100], Loss: 0.3734


100%|██████████| 469/469 [00:06<00:00, 76.14it/s]


Epoch [28/100], Loss: 0.2730


100%|██████████| 469/469 [00:06<00:00, 75.21it/s]


Epoch [29/100], Loss: 0.5531


100%|██████████| 469/469 [00:06<00:00, 75.21it/s]


Epoch [30/100], Loss: 0.4564


100%|██████████| 469/469 [00:06<00:00, 75.38it/s]


Epoch [31/100], Loss: 0.5229


100%|██████████| 469/469 [00:06<00:00, 75.42it/s]


Epoch [32/100], Loss: 0.3934


100%|██████████| 469/469 [00:06<00:00, 75.81it/s]


Epoch [33/100], Loss: 0.2477


100%|██████████| 469/469 [00:06<00:00, 75.29it/s]


Epoch [34/100], Loss: 0.3263


100%|██████████| 469/469 [00:06<00:00, 74.52it/s]


Epoch [35/100], Loss: 0.3755


100%|██████████| 469/469 [00:06<00:00, 75.14it/s]


Epoch [36/100], Loss: 0.2774


100%|██████████| 469/469 [00:06<00:00, 75.22it/s]


Epoch [37/100], Loss: 0.3740


100%|██████████| 469/469 [00:06<00:00, 75.02it/s]


Epoch [38/100], Loss: 0.2326


100%|██████████| 469/469 [00:06<00:00, 75.07it/s]


Epoch [39/100], Loss: 0.4290


100%|██████████| 469/469 [00:06<00:00, 75.88it/s]


Epoch [40/100], Loss: 0.2480


100%|██████████| 469/469 [00:06<00:00, 75.10it/s]


Epoch [41/100], Loss: 0.2992


100%|██████████| 469/469 [00:06<00:00, 75.74it/s]


Epoch [42/100], Loss: 0.2547


100%|██████████| 469/469 [00:06<00:00, 74.97it/s]


Epoch [43/100], Loss: 0.4623


100%|██████████| 469/469 [00:06<00:00, 73.25it/s]


Epoch [44/100], Loss: 0.2781


100%|██████████| 469/469 [00:06<00:00, 73.43it/s]


Epoch [45/100], Loss: 0.4407


100%|██████████| 469/469 [00:06<00:00, 73.38it/s]


Epoch [46/100], Loss: 0.3375


100%|██████████| 469/469 [00:06<00:00, 73.70it/s]


Epoch [47/100], Loss: 0.3234


100%|██████████| 469/469 [00:06<00:00, 72.94it/s]


Epoch [48/100], Loss: 0.3390


100%|██████████| 469/469 [00:06<00:00, 73.50it/s]


Epoch [49/100], Loss: 0.2595


100%|██████████| 469/469 [00:06<00:00, 72.68it/s]


Epoch [50/100], Loss: 0.4789


100%|██████████| 469/469 [00:06<00:00, 72.98it/s]


Epoch [51/100], Loss: 0.3399


100%|██████████| 469/469 [00:06<00:00, 73.32it/s]


Epoch [52/100], Loss: 0.3644


100%|██████████| 469/469 [00:06<00:00, 69.79it/s]


Epoch [53/100], Loss: 0.2899


100%|██████████| 469/469 [00:08<00:00, 57.92it/s]


Epoch [54/100], Loss: 0.3328


100%|██████████| 469/469 [00:06<00:00, 74.65it/s]


Epoch [55/100], Loss: 0.4847


100%|██████████| 469/469 [00:06<00:00, 72.63it/s]


Epoch [56/100], Loss: 0.3502


100%|██████████| 469/469 [00:06<00:00, 71.29it/s]


Epoch [57/100], Loss: 0.3167


100%|██████████| 469/469 [00:06<00:00, 71.09it/s]


Epoch [58/100], Loss: 0.2872


100%|██████████| 469/469 [00:06<00:00, 70.96it/s]


Epoch [59/100], Loss: 0.3214


100%|██████████| 469/469 [00:06<00:00, 67.93it/s]


Epoch [60/100], Loss: 0.2619


100%|██████████| 469/469 [00:06<00:00, 68.79it/s]


Epoch [61/100], Loss: 0.4041


100%|██████████| 469/469 [00:06<00:00, 70.64it/s]


Epoch [62/100], Loss: 0.2682


100%|██████████| 469/469 [00:06<00:00, 71.69it/s]


Epoch [63/100], Loss: 0.3606


100%|██████████| 469/469 [00:06<00:00, 70.70it/s]


Epoch [64/100], Loss: 0.2117


100%|██████████| 469/469 [00:06<00:00, 69.25it/s]


Epoch [65/100], Loss: 0.2863


100%|██████████| 469/469 [00:06<00:00, 69.44it/s]


Epoch [66/100], Loss: 0.2658


100%|██████████| 469/469 [00:07<00:00, 64.80it/s]


Epoch [67/100], Loss: 0.4956


100%|██████████| 469/469 [00:06<00:00, 69.16it/s]


Epoch [68/100], Loss: 0.5226


100%|██████████| 469/469 [00:07<00:00, 65.66it/s]


Epoch [69/100], Loss: 0.3266


100%|██████████| 469/469 [00:06<00:00, 67.55it/s]


Epoch [70/100], Loss: 0.2243


100%|██████████| 469/469 [00:06<00:00, 70.37it/s]


Epoch [71/100], Loss: 0.3285


100%|██████████| 469/469 [00:06<00:00, 70.83it/s]


Epoch [72/100], Loss: 0.4276


100%|██████████| 469/469 [00:06<00:00, 70.20it/s]


Epoch [73/100], Loss: 0.3637


100%|██████████| 469/469 [00:06<00:00, 70.26it/s]


Epoch [74/100], Loss: 0.4444


100%|██████████| 469/469 [00:06<00:00, 69.78it/s]


Epoch [75/100], Loss: 0.2410


100%|██████████| 469/469 [00:06<00:00, 70.43it/s]


Epoch [76/100], Loss: 0.1861


100%|██████████| 469/469 [00:06<00:00, 70.11it/s]


Epoch [77/100], Loss: 0.2963


100%|██████████| 469/469 [00:06<00:00, 70.08it/s]


Epoch [78/100], Loss: 0.3509


100%|██████████| 469/469 [00:08<00:00, 57.50it/s]


Epoch [79/100], Loss: 0.2352


100%|██████████| 469/469 [00:06<00:00, 70.50it/s]


Epoch [80/100], Loss: 0.2086


100%|██████████| 469/469 [00:06<00:00, 72.73it/s]


Epoch [81/100], Loss: 0.1737


100%|██████████| 469/469 [00:06<00:00, 70.72it/s]


Epoch [82/100], Loss: 0.3645


100%|██████████| 469/469 [00:06<00:00, 71.44it/s]


Epoch [83/100], Loss: 0.4715


100%|██████████| 469/469 [00:06<00:00, 73.06it/s]


Epoch [84/100], Loss: 0.1701


100%|██████████| 469/469 [00:06<00:00, 69.88it/s]


Epoch [85/100], Loss: 0.3392


100%|██████████| 469/469 [00:06<00:00, 70.22it/s]


Epoch [86/100], Loss: 0.3751


100%|██████████| 469/469 [00:06<00:00, 72.17it/s]


Epoch [87/100], Loss: 0.4183


100%|██████████| 469/469 [00:06<00:00, 71.46it/s]


Epoch [88/100], Loss: 0.2672


100%|██████████| 469/469 [00:06<00:00, 69.53it/s]


Epoch [89/100], Loss: 0.2993


100%|██████████| 469/469 [00:06<00:00, 74.69it/s]


Epoch [90/100], Loss: 0.3587


100%|██████████| 469/469 [00:06<00:00, 70.78it/s]


Epoch [91/100], Loss: 0.2289


100%|██████████| 469/469 [00:06<00:00, 70.94it/s]


Epoch [92/100], Loss: 0.4317


100%|██████████| 469/469 [00:06<00:00, 73.15it/s]


Epoch [93/100], Loss: 0.2990


100%|██████████| 469/469 [00:06<00:00, 69.12it/s]


Epoch [94/100], Loss: 0.1579


100%|██████████| 469/469 [00:06<00:00, 73.33it/s]


Epoch [95/100], Loss: 0.3807


100%|██████████| 469/469 [00:06<00:00, 70.25it/s]


Epoch [96/100], Loss: 0.2919


100%|██████████| 469/469 [00:06<00:00, 70.31it/s]


Epoch [97/100], Loss: 0.2087


100%|██████████| 469/469 [00:06<00:00, 71.48it/s]


Epoch [98/100], Loss: 0.2722


100%|██████████| 469/469 [00:06<00:00, 71.17it/s]


Epoch [99/100], Loss: 0.2869


100%|██████████| 469/469 [00:06<00:00, 69.77it/s]

Epoch [100/100], Loss: 0.3382





In [182]:
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy on test set: {100 * correct / total}%')


Accuracy on test set: 89.605%
