In [1]:
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
from torchvision.datasets import FashionMNIST
from torchvision import transforms

# Fashion-MNIST train and test splits. Both are downloaded into
# ./data/FashionMNIST when missing, and every image is converted with
# transforms.ToTensor().
fashion_mnist_train = FashionMNIST("./data/FashionMNIST", train=True,
                                   download=True, transform=transforms.ToTensor())
fashion_mnist_test = FashionMNIST("./data/FashionMNIST", train=False,
                                  download=True, transform=transforms.ToTensor())

batch_size = 128
# Training batches are drawn in a fresh random order every epoch.
train_loader = DataLoader(fashion_mnist_train, batch_size=batch_size,
                          shuffle=True, num_workers=14)
# The test split is only used to compute accuracy, which does not depend on
# sample order, so it is iterated sequentially instead of being reshuffled on
# every evaluation pass.
test_loader = DataLoader(fashion_mnist_test, batch_size=batch_size,
                         shuffle=False, num_workers=14)

In [3]:
class FlattenLayer(nn.Module):
    """Module that merges every dimension after the first into a single one."""

    def forward(self, x):
        """Return ``x`` flattened from dimension 1 onwards.

        Dimension 0 is left untouched, so an input of shape ``(N, a, b, ...)``
        comes back with shape ``(N, a * b * ...)``.
        """
        return x.flatten(start_dim=1)

# Convolutional feature extractor: two stages of
# conv(5x5) -> max-pool(2) -> ReLU -> batch-norm -> channel dropout,
# followed by flattening to one feature vector per sample.
conv_net = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5),
    nn.MaxPool2d(2),
    nn.ReLU(),
    nn.BatchNorm2d(32),
    nn.Dropout2d(0.25),
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5),
    nn.MaxPool2d(2),
    nn.ReLU(),
    nn.BatchNorm2d(64),
    nn.Dropout2d(0.25),
    FlattenLayer()
)

# Push one dummy 1x28x28 image through the extractor to discover the size of
# the flattened feature vector instead of computing it by hand.
# The probe runs in eval mode and without autograd: in training mode the
# forward pass would fold the all-ones dummy image into the BatchNorm running
# statistics and apply dropout, neither of which is wanted for a shape check.
test_input = torch.ones(1, 1, 28, 28)
conv_net.eval()
with torch.no_grad():
    conv_output_size = conv_net(test_input).size()[-1]
# Restore the default training mode the modules were constructed in.
conv_net.train()

# Classifier head mapping the flattened features to 10 class scores.
mlp = nn.Sequential(
    nn.Linear(conv_output_size, 200),
    nn.ReLU(),
    nn.BatchNorm1d(200),
    nn.Dropout(0.25),
    nn.Linear(200, 10)
)

# Full model: feature extractor followed by the classifier head.
net = nn.Sequential(
    conv_net,
    mlp
)

In [6]:
def eval_net(net, data_loader, device="cpu"):
    """Return the classification accuracy of ``net`` over ``data_loader``.

    The network is switched to evaluation mode and left in that mode.

    Args:
        net: module whose output has one score per class along dimension 1.
        data_loader: iterable yielding ``(inputs, labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        Fraction of samples whose highest-scoring class equals the label,
        as a Python float.
    """
    net.eval()
    targets, predictions = [], []
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            scores = net(inputs)
            # Index of the largest score in each row is the predicted class.
            predicted = scores.max(1)[1]
            targets.append(labels)
            predictions.append(predicted)
    all_targets = torch.cat(targets)
    all_predictions = torch.cat(predictions)
    n_correct = (all_targets == all_predictions).float().sum()
    accuracy = n_correct / len(all_targets)
    return accuracy.item()

# Module-level per-epoch histories; train_net appends one value to each of
# these lists at the end of every epoch.
train_losses, train_acc, val_acc = [], [], []

def train_net(net, train_loader, test_loader, optimizer_cls=optim.Adam,
              loss_fn=nn.CrossEntropyLoss(), n_iter=10, device="cpu", writer=None):
    """Train ``net`` for ``n_iter`` epochs and record per-epoch metrics.

    After every epoch the mean training loss, the training accuracy and the
    accuracy on ``test_loader`` are appended to the module-level lists
    ``train_losses``, ``train_acc`` and ``val_acc``, printed, and optionally
    logged through ``writer``.

    Args:
        net: model to optimise; it must already live on ``device``.
        train_loader: iterable with ``len()`` yielding ``(inputs, labels)``.
        test_loader: loader passed to ``eval_net`` after every epoch.
        optimizer_cls: optimizer class, built with default hyper-parameters
            from ``net.parameters()``.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar loss.
        n_iter: number of epochs.
        device: device every batch is moved to.
        writer: optional object exposing ``add_scalar`` / ``add_scalars``.

    Raises:
        ValueError: if ``train_loader`` yields no batches in an epoch.
    """
    optimizer = optimizer_cls(net.parameters())
    for epoch in range(n_iter):
        running_loss = 0.0
        # eval_net leaves the model in eval mode, so re-enable training
        # behaviour at the start of every epoch.
        net.train()
        n = 0
        n_acc = 0
        n_batches = 0
        for xx, yy in tqdm.tqdm(train_loader, total=len(train_loader)):
            xx = xx.to(device)
            yy = yy.to(device)
            h = net(xx)
            loss = loss_fn(h, yy)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n += len(xx)
            n_batches += 1
            # Training accuracy is measured on the training-mode forward pass
            # that was just used for the update.
            y_pred = h.argmax(1)
            n_acc += (y_pred == yy).float().sum().item()
        if n_batches == 0:
            raise ValueError("train_loader yielded no batches")
        # Average over the number of batches actually processed. Dividing by
        # the last batch index would be off by one and would divide by zero
        # for a single-batch loader.
        train_losses.append(running_loss / n_batches)
        train_acc.append(n_acc / n)
        val_acc.append(eval_net(net, test_loader, device))
        print("epoch: {}\ttrain_loss: {:.3f}\ttrain_acc: {:.3f}\tval_acc: {:.3f}".format(
             epoch, train_losses[-1], train_acc[-1], val_acc[-1]), flush=True)
        if writer is not None:
            writer.add_scalar("train_loss", train_losses[-1], epoch)
            writer.add_scalars("accuracy", {"train":train_acc[-1],
                                            "validation":val_acc[-1]}, epoch)

In [7]:
from tensorboardX import SummaryWriter

# TensorBoard event files for this run are written under ./tmp/cnn.
writer = SummaryWriter("./tmp/cnn")

# Train on the first GPU when one is available; otherwise fall back to the
# CPU so the notebook still runs on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
net.to(device)
try:
    train_net(net, train_loader, test_loader, n_iter=20,
              device=device, writer=writer)
finally:
    # Close the writer even if training is interrupted so that buffered
    # events are flushed to disk.
    writer.close()

100%|██████████| 469/469 [00:01<00:00, 265.74it/s]


epoch: 0	train_loss: 0.321	train_acc: 0.883	val_acc: 0.893


100%|██████████| 469/469 [00:01<00:00, 266.28it/s]


epoch: 1	train_loss: 0.281	train_acc: 0.896	val_acc: 0.900


100%|██████████| 469/469 [00:01<00:00, 262.95it/s]


epoch: 2	train_loss: 0.260	train_acc: 0.904	val_acc: 0.896


100%|██████████| 469/469 [00:01<00:00, 260.27it/s]


epoch: 3	train_loss: 0.241	train_acc: 0.911	val_acc: 0.908


100%|██████████| 469/469 [00:01<00:00, 263.28it/s]


epoch: 4	train_loss: 0.232	train_acc: 0.914	val_acc: 0.906


100%|██████████| 469/469 [00:01<00:00, 273.02it/s]


epoch: 5	train_loss: 0.221	train_acc: 0.918	val_acc: 0.910


100%|██████████| 469/469 [00:01<00:00, 278.34it/s]


epoch: 6	train_loss: 0.211	train_acc: 0.921	val_acc: 0.914


100%|██████████| 469/469 [00:01<00:00, 266.47it/s]


epoch: 7	train_loss: 0.202	train_acc: 0.923	val_acc: 0.916


100%|██████████| 469/469 [00:01<00:00, 271.09it/s]


epoch: 8	train_loss: 0.196	train_acc: 0.927	val_acc: 0.918


100%|██████████| 469/469 [00:01<00:00, 269.89it/s]


epoch: 9	train_loss: 0.188	train_acc: 0.929	val_acc: 0.916


100%|██████████| 469/469 [00:01<00:00, 273.19it/s]


epoch: 10	train_loss: 0.183	train_acc: 0.932	val_acc: 0.917


100%|██████████| 469/469 [00:01<00:00, 235.06it/s]


epoch: 11	train_loss: 0.179	train_acc: 0.933	val_acc: 0.916


100%|██████████| 469/469 [00:01<00:00, 262.79it/s]


epoch: 12	train_loss: 0.173	train_acc: 0.935	val_acc: 0.921


100%|██████████| 469/469 [00:01<00:00, 276.67it/s]


epoch: 13	train_loss: 0.168	train_acc: 0.937	val_acc: 0.919


100%|██████████| 469/469 [00:01<00:00, 273.84it/s]


epoch: 14	train_loss: 0.166	train_acc: 0.937	val_acc: 0.921


100%|██████████| 469/469 [00:01<00:00, 267.47it/s]


epoch: 15	train_loss: 0.161	train_acc: 0.939	val_acc: 0.920


100%|██████████| 469/469 [00:01<00:00, 274.48it/s]


epoch: 16	train_loss: 0.153	train_acc: 0.943	val_acc: 0.919


100%|██████████| 469/469 [00:01<00:00, 268.42it/s]


epoch: 17	train_loss: 0.153	train_acc: 0.942	val_acc: 0.918


100%|██████████| 469/469 [00:01<00:00, 260.79it/s]


epoch: 18	train_loss: 0.151	train_acc: 0.944	val_acc: 0.919


100%|██████████| 469/469 [00:01<00:00, 265.33it/s]


epoch: 19	train_loss: 0.149	train_acc: 0.944	val_acc: 0.921
