In [1]:
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.optim import SGD
%matplotlib inline
device = "cuda" if torch.cuda.is_available() else "cpu"
from torchvision import datasets
data_folder = '~/data/FMNIST' # This can be any directory you
# want to download FMNIST to
fmnist = datasets.FashionMNIST(data_folder, download=True, \
train=True)
tr_images = fmnist.data
tr_targets = fmnist.targets

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /root/data/FMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting /root/data/FMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz to /root/data/FMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /root/data/FMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting /root/data/FMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /root/data/FMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /root/data/FMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting /root/data/FMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /root/data/FMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /root/data/FMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting /root/data/FMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /root/data/FMNIST/FashionMNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [2]:
class FMinistDataset(Dataset):
  def __init__(self, x, y):
    x = x.float()
    x = x.view(-1,28*28)
    self.x, self.y = x, y
  
  def __len__(self):
    return len(self.x)
  
  def __getitem__(self, ix):
    return self.x[ix].to(device), self.y[ix].to(device)
  


In [3]:
def get_data():
  train = FMinistDataset(tr_images, tr_targets)
  trn_dl = DataLoader(train, batch_size=32, shuffle=True)
  return trn_dl

In [4]:
def get_model():
  model = nn.Sequential(
      nn.Linear(28*28, 1000), 
      nn.ReLU(), 
      nn.Linear(1000, 10)
  )
  loss = nn.CrossEntropyLoss()
  opt = SGD(model.parameters(), lr=0.01)
  return model, loss, opt

In [5]:
def train_batch(x, y, model, opt, loss_fn):
  model.train()
  batch_loss = loss(model(x), y)
  batch_loss.backward()
  opt.step()
  opt.zero_grad()
  return batch_loss.item()


In [6]:
@torch.no_grad()
def accuracy(x, y, model):
  model.eval()
  prediction = model(x)
  max_values, argmaxes = prediction.max(-1)
  is_correct = argmaxes == y 
  return is_correct.cpu().numpy().tolist()


In [7]:
trn_dl = get_data()
model, loss, opt = get_model()
losses, accuracies = [], []
for epoch in range(10):
  print(epoch)
  epoch_losses, epoch_accuracies = [], []
  for ix, batch in enumerate(iter(trn_dl)):
    x, y = batch
    batch_loss = train_batch(x, y, model, opt, loss)
    epoch_losses.append(batch_loss)
  losses.append(np.array(epoch_losses).mean())
  for ix, batch in enumerate(iter(trn_dl)):
    x, y = batch
    is_correct = accuracy(x, y, model)
    epoch_accuracies.extend(is_correct)
  accuracies.append(np.mean(epoch_accuracies))






0


RuntimeError: ignored

In [None]:
epochs = np.arange(10)+1
plt.figure(figsize=(20,5))
plt.subplot(121)
plt.title('Loss value over increasing epochs')
plt.plot(epochs, losses, label='Training Loss')
plt.legend()
plt.subplot(122)
plt.title('Accuracy value over increasing epochs')
plt.plot(epochs, accuracies, label='Training Accuracy')
plt.gca().set_yticklabels(['{:.0f}%'.format(x*100) \
for x in plt.gca().get_yticks()])
plt.legend()

In [None]:
device