In [18]:
import pylab as pl
import torch
import torch.nn as nn
import torch.nn.functional as F


class LadderNet(nn.Module):
    def __init__(self, input_shape, num_classes):
        super(LadderNet, self).__init__()
        self.input_shape = input_shape
        self.num_classes = num_classes

        # Encoder layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)

        # Decoder layers
        self.deconv4 = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(128)
        self.deconv3 = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(64)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm2d(32)
        self.deconv1 = nn.ConvTranspose2d(32, 1, kernel_size=4, padding=1)

        # Classification layer
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x_labeled, x_unlabeled):
        if(x_labeled is not None):

            # Labeled input
            h = F.relu(self.bn1(self.conv1(x_labeled)))
            h = F.max_pool2d(h, 2, 2)
            out1 = h
            h = F.relu(self.bn2(self.conv2(h)))
            h = F.max_pool2d(h, 2, 2)
            out2 = h
            h = F.relu(self.bn3(self.conv3(h)))
            h = F.max_pool2d(h, 2, 2)
            out3 = h
            encoder_output = F.relu(self.bn4(self.conv4(h)))

        else:

            # Unlabeled input
            u = F.relu(self.bn1(self.conv1(x_unlabeled)))
            u = F.max_pool2d(u, 2, 2)
            out1 = u
            u = F.relu(self.bn2(self.conv2(u)))
            u = F.max_pool2d(u, 2, 2)
            out2 = u
            u = F.relu(self.bn3(self.conv3(u)))
            u = F.max_pool2d(u, 2, 2)
            out3 = u
            encoder_output = F.relu(self.bn4(self.conv4(u)))

        encoder_output = F.avg_pool2d(encoder_output, encoder_output.size()[2:])
        decoder_input  = encoder_output
        encoder_output = encoder_output.view(-1, 256)

        # Add noise to the hidden representations
        noise = torch.randn_like(encoder_output)
        distorted_encoder_output = encoder_output + noise

        # Decoder
        u = F.relu(self.bn5(self.deconv4(distorted_encoder_output.view(-1, 256, 1, 1))))
        u = F.interpolate(u, scale_factor=3, mode='nearest')
        dout3 = u
        u = F.relu(self.bn6(self.deconv3(u)))
        u = F.interpolate(u, size=(7, 7), mode='nearest')
        dout2 = u
        u = F.relu(self.bn7(self.deconv2(u)))
        u = F.interpolate(u, size=(14, 14), mode='nearest')
        dout1 = u
        u = F.interpolate(self.deconv1(u), size=(28, 28), mode='nearest')
        x_reconstructed = torch.sigmoid(u)

        # Classification
        output = self.fc(encoder_output) if x_labeled is not None else None

        return output, x_reconstructed, out1, dout1, out2, dout2, out3, dout3

    def labeled_loss(self, output_labeled, target_labeled):
        criterion = nn.CrossEntropyLoss()
        return criterion(output_labeled, target_labeled)

    def reconstruction_loss(self, x_reconstructed, x_unlabeled, out1, dout1, out2, dout2, out3, dout3):
        mse_loss = nn.MSELoss()
        mse = mse_loss(x_reconstructed, x_unlabeled) + 0.25 * (mse_loss(out1, dout1) + mse_loss(out2, dout2) + mse_loss(out3, dout3))

        return mse

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Define the transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load the MNIST dataset
train_set = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_set = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

# Divide the training set into labeled and unlabeled sets
n_labeled = int(0.1 * len(train_set))
n_unlabeled = len(train_set) - n_labeled
train_labeled_set, train_unlabeled_set = torch.utils.data.random_split(train_set, [n_labeled, n_unlabeled])

# Create data loaders for the labeled and unlabeled sets
batch_size = 128

from torch.utils.data import Dataset


class SemisupervisedDataset(Dataset):
    def __init__(self, labeled_dataset, unlabeled_dataset):
        self.labeled_dataset = labeled_dataset
        self.unlabeled_dataset = unlabeled_dataset
        self.unlabeled_size = len(unlabeled_dataset)
        self.labeled_size = len(labeled_dataset)
        print("Dataset created")

    def __len__(self):
        return self.labeled_size

    def __getitem__(self, index):
        labeled_data, labeled_target = self.labeled_dataset[index % self.labeled_size]
        unlabeled_data, _ = self.unlabeled_dataset[index]

        return labeled_data, labeled_target, unlabeled_data



from torch.utils.data import Dataset


class SemisupervisedDataset(Dataset):
    def __init__(self, labeled_dataset, unlabeled_dataset):
        self.labeled_dataset = labeled_dataset
        self.unlabeled_dataset = unlabeled_dataset
        self.unlabeled_size = len(unlabeled_dataset)
        self.labeled_size = len(labeled_dataset)
        print("Dataset created")

    def __len__(self):
        return self.labeled_size

    def __getitem__(self, index):
        labeled_data, labeled_target = self.labeled_dataset[index % self.labeled_size]
        unlabeled_data, _ = self.unlabeled_dataset[index]

        return labeled_data, labeled_target, unlabeled_data

dataset = SemisupervisedDataset(train_labeled_set, train_unlabeled_set)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, pin_memory=True)

# Check if a GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from torch import optim

# Define the model
model = LadderNet(input_shape=(1, 28, 28), num_classes=10)

# Move the model to the device
model.to(device)

# Define the loss functions
labeled_loss_fn = model.labeled_loss
reconstruction_loss_fn = model.reconstruction_loss

num_epochs = 10

# Define the optimizer
optimizer = optim.Adam(model.parameters())


from tqdm import tqdm

# Train the model
for epoch in tqdm(range(num_epochs)):
    model.train()
    for data_labeled, target_labeled, data_unlabeled in tqdm(dataloader):
        # Train the model on the labeled data
        data_labeled = data_labeled.to(device)
        target_labeled = target_labeled.to(device)
        data_unlabeled = data_unlabeled.to(device)

        optimizer.zero_grad()
        output, x_reconstructed, out1, dout1, out2, dout2, out3, dout3 = model(data_labeled, None)
        loss_labeled = labeled_loss_fn(output, target_labeled) + reconstruction_loss_fn(data_labeled, x_reconstructed, out1, dout1, out2, dout2, out3, dout3)
        loss_labeled.backward()
        optimizer.step()

        # Train the model on the unlabeled data
        _, x_reconstructed, out1, dout1, out2, dout2, out3, dout3 = model(None, data_unlabeled)
        x_reconstructed = x_reconstructed.view(-1, 784)
        data_unlabeled = data_unlabeled.view(-1, 784)
        loss_unlabeled = reconstruction_loss_fn(data_unlabeled, x_reconstructed, out1, dout1, out2, dout2, out3, dout3)
        loss_unlabeled.backward()
        optimizer.step()

    # Test the model
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output, _, _, _, _, _, _, _ = model(data, None)
            test_loss += labeled_loss_fn(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print('Epoch: {} Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(epoch, test_loss, correct,
                                                                                       len(test_loader.dataset),
                                                                                       accuracy))

Dataset created


  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/47 [00:00<?, ?it/s][A
  4%|▍         | 2/47 [00:00<00:03, 11.75it/s][A
  9%|▊         | 4/47 [00:00<00:03, 11.15it/s][A
 13%|█▎        | 6/47 [00:00<00:03, 11.26it/s][A
 17%|█▋        | 8/47 [00:00<00:03, 11.52it/s][A
 21%|██▏       | 10/47 [00:00<00:03, 11.89it/s][A
 26%|██▌       | 12/47 [00:01<00:02, 12.18it/s][A
 30%|██▉       | 14/47 [00:01<00:02, 11.92it/s][A
 34%|███▍      | 16/47 [00:01<00:02, 12.13it/s][A
 38%|███▊      | 18/47 [00:01<00:02, 12.27it/s][A
 43%|████▎     | 20/47 [00:01<00:02, 11.86it/s][A
 47%|████▋     | 22/47 [00:01<00:02, 11.83it/s][A
 51%|█████     | 24/47 [00:02<00:01, 12.03it/s][A
 55%|█████▌    | 26/47 [00:02<00:01, 12.06it/s][A
 60%|█████▉    | 28/47 [00:02<00:01, 12.19it/s][A
 64%|██████▍   | 30/47 [00:02<00:01, 12.51it/s][A
 68%|██████▊   | 32/47 [00:02<00:01, 12.49it/s][A
 72%|███████▏  | 34/47 [00:02<00:01, 12.73it/s][A
 77%|███████▋  | 36/47 [00:02<00:00, 12.91it/s][A
 81%|

Epoch: 0 Test set: Average loss: 0.0050, Accuracy: 7709/10000 (77%)



  0%|          | 0/47 [00:00<?, ?it/s][A
  4%|▍         | 2/47 [00:00<00:03, 13.73it/s][A
  9%|▊         | 4/47 [00:00<00:03, 12.98it/s][A
 13%|█▎        | 6/47 [00:00<00:03, 13.37it/s][A
 17%|█▋        | 8/47 [00:00<00:03, 12.64it/s][A
 21%|██▏       | 10/47 [00:00<00:03, 12.29it/s][A
 26%|██▌       | 12/47 [00:00<00:02, 12.06it/s][A
 30%|██▉       | 14/47 [00:01<00:02, 12.13it/s][A
 34%|███▍      | 16/47 [00:01<00:02, 12.31it/s][A
 38%|███▊      | 18/47 [00:01<00:02, 11.78it/s][A
 43%|████▎     | 20/47 [00:01<00:02, 11.98it/s][A
 47%|████▋     | 22/47 [00:01<00:02, 11.92it/s][A
 51%|█████     | 24/47 [00:01<00:01, 11.62it/s][A
 55%|█████▌    | 26/47 [00:02<00:01, 11.69it/s][A
 60%|█████▉    | 28/47 [00:02<00:01, 11.92it/s][A
 64%|██████▍   | 30/47 [00:02<00:01, 12.15it/s][A
 68%|██████▊   | 32/47 [00:02<00:01, 12.25it/s][A
 72%|███████▏  | 34/47 [00:02<00:01, 12.32it/s][A
 77%|███████▋  | 36/47 [00:02<00:00, 12.46it/s][A
 81%|████████  | 38/47 [00:03<00:00, 12.47i

Epoch: 1 Test set: Average loss: 0.0048, Accuracy: 7834/10000 (78%)



  0%|          | 0/47 [00:00<?, ?it/s][A
  4%|▍         | 2/47 [00:00<00:03, 12.71it/s][A
  9%|▊         | 4/47 [00:00<00:03, 12.58it/s][A
 13%|█▎        | 6/47 [00:00<00:03, 12.92it/s][A
 17%|█▋        | 8/47 [00:00<00:03, 12.86it/s][A
 21%|██▏       | 10/47 [00:00<00:03, 12.30it/s][A
 26%|██▌       | 12/47 [00:00<00:02, 11.83it/s][A
 30%|██▉       | 14/47 [00:01<00:02, 11.69it/s][A
 34%|███▍      | 16/47 [00:01<00:02, 11.64it/s][A
 38%|███▊      | 18/47 [00:01<00:02, 11.99it/s][A
 43%|████▎     | 20/47 [00:01<00:02, 12.02it/s][A
 47%|████▋     | 22/47 [00:01<00:02, 12.02it/s][A
 51%|█████     | 24/47 [00:01<00:01, 11.89it/s][A
 55%|█████▌    | 26/47 [00:02<00:01, 11.35it/s][A
 60%|█████▉    | 28/47 [00:02<00:01, 11.35it/s][A
 64%|██████▍   | 30/47 [00:02<00:01, 11.65it/s][A
 68%|██████▊   | 32/47 [00:02<00:01, 11.75it/s][A
 72%|███████▏  | 34/47 [00:02<00:01, 12.41it/s][A
 77%|███████▋  | 36/47 [00:02<00:00, 12.44it/s][A
 81%|████████  | 38/47 [00:03<00:00, 12.36i

Epoch: 2 Test set: Average loss: 0.0039, Accuracy: 8180/10000 (82%)



  0%|          | 0/47 [00:00<?, ?it/s][A
  4%|▍         | 2/47 [00:00<00:03, 13.33it/s][A
  9%|▊         | 4/47 [00:00<00:03, 13.33it/s][A
 13%|█▎        | 6/47 [00:00<00:03, 12.87it/s][A
 17%|█▋        | 8/47 [00:00<00:02, 13.05it/s][A
 21%|██▏       | 10/47 [00:00<00:02, 12.56it/s][A
 26%|██▌       | 12/47 [00:00<00:02, 12.58it/s][A
 30%|██▉       | 14/47 [00:01<00:02, 12.56it/s][A
 34%|███▍      | 16/47 [00:01<00:02, 12.05it/s][A
 38%|███▊      | 18/47 [00:01<00:02, 11.87it/s][A
 43%|████▎     | 20/47 [00:01<00:02, 11.76it/s][A
 47%|████▋     | 22/47 [00:01<00:02, 11.95it/s][A
 51%|█████     | 24/47 [00:01<00:01, 12.05it/s][A
 55%|█████▌    | 26/47 [00:02<00:01, 12.05it/s][A
 60%|█████▉    | 28/47 [00:02<00:01, 11.63it/s][A
 64%|██████▍   | 30/47 [00:02<00:01, 11.52it/s][A
 68%|██████▊   | 32/47 [00:02<00:01, 11.51it/s][A
 72%|███████▏  | 34/47 [00:02<00:01, 11.93it/s][A
 77%|███████▋  | 36/47 [00:02<00:00, 12.15it/s][A
 81%|████████  | 38/47 [00:03<00:00, 12.49i

Epoch: 3 Test set: Average loss: 0.0036, Accuracy: 8376/10000 (84%)



  0%|          | 0/47 [00:00<?, ?it/s][A
  4%|▍         | 2/47 [00:00<00:03, 12.45it/s][A
  9%|▊         | 4/47 [00:00<00:03, 12.95it/s][A
 13%|█▎        | 6/47 [00:00<00:03, 12.74it/s][A
 17%|█▋        | 8/47 [00:00<00:03, 12.62it/s][A
 21%|██▏       | 10/47 [00:00<00:03, 12.06it/s][A
 26%|██▌       | 12/47 [00:00<00:02, 12.21it/s][A
 30%|██▉       | 14/47 [00:01<00:02, 12.30it/s][A
 34%|███▍      | 16/47 [00:01<00:02, 12.26it/s][A
 38%|███▊      | 18/47 [00:01<00:02, 12.07it/s][A
 43%|████▎     | 20/47 [00:01<00:02, 11.86it/s][A
 47%|████▋     | 22/47 [00:01<00:02, 11.51it/s][A
 51%|█████     | 24/47 [00:02<00:02, 11.48it/s][A
 55%|█████▌    | 26/47 [00:02<00:01, 11.45it/s][A
 60%|█████▉    | 28/47 [00:02<00:01, 11.26it/s][A
 64%|██████▍   | 30/47 [00:02<00:01, 11.18it/s][A
 68%|██████▊   | 32/47 [00:02<00:01, 11.26it/s][A
 72%|███████▏  | 34/47 [00:02<00:01, 11.23it/s][A
 77%|███████▋  | 36/47 [00:03<00:00, 11.52it/s][A
 81%|████████  | 38/47 [00:03<00:00, 11.82i

Epoch: 4 Test set: Average loss: 0.0042, Accuracy: 8125/10000 (81%)



  0%|          | 0/47 [00:00<?, ?it/s][A
  4%|▍         | 2/47 [00:00<00:03, 13.33it/s][A
  9%|▊         | 4/47 [00:00<00:03, 13.12it/s][A
 13%|█▎        | 6/47 [00:00<00:03, 13.22it/s][A
 17%|█▋        | 8/47 [00:00<00:03, 12.92it/s][A
 21%|██▏       | 10/47 [00:00<00:02, 12.76it/s][A
 26%|██▌       | 12/47 [00:00<00:02, 12.77it/s][A
 30%|██▉       | 14/47 [00:01<00:02, 12.68it/s][A
 34%|███▍      | 16/47 [00:01<00:02, 12.88it/s][A
 38%|███▊      | 18/47 [00:01<00:02, 12.53it/s][A
 43%|████▎     | 20/47 [00:01<00:02, 12.28it/s][A
 47%|████▋     | 22/47 [00:01<00:02, 11.99it/s][A
 51%|█████     | 24/47 [00:01<00:01, 11.94it/s][A
 55%|█████▌    | 26/47 [00:02<00:01, 11.89it/s][A
 60%|█████▉    | 28/47 [00:02<00:01, 11.82it/s][A
 64%|██████▍   | 30/47 [00:02<00:01, 11.69it/s][A
 68%|██████▊   | 32/47 [00:02<00:01, 11.30it/s][A
 72%|███████▏  | 34/47 [00:02<00:01, 11.27it/s][A
 77%|███████▋  | 36/47 [00:02<00:00, 11.48it/s][A
 81%|████████  | 38/47 [00:03<00:00, 11.50i

Epoch: 5 Test set: Average loss: 0.0036, Accuracy: 8511/10000 (85%)



  0%|          | 0/47 [00:00<?, ?it/s][A
  4%|▍         | 2/47 [00:00<00:03, 11.76it/s][A
  9%|▊         | 4/47 [00:00<00:03, 12.19it/s][A
 13%|█▎        | 6/47 [00:00<00:03, 12.47it/s][A
 17%|█▋        | 8/47 [00:00<00:03, 12.86it/s][A
 21%|██▏       | 10/47 [00:00<00:02, 12.73it/s][A
 26%|██▌       | 12/47 [00:00<00:02, 12.43it/s][A
 30%|██▉       | 14/47 [00:01<00:02, 12.20it/s][A
 34%|███▍      | 16/47 [00:01<00:02, 12.30it/s][A
 38%|███▊      | 18/47 [00:01<00:02, 12.36it/s][A
 43%|████▎     | 20/47 [00:01<00:02, 12.40it/s][A
 47%|████▋     | 22/47 [00:01<00:02, 12.43it/s][A
 51%|█████     | 24/47 [00:01<00:01, 11.91it/s][A
 55%|█████▌    | 26/47 [00:02<00:01, 11.80it/s][A
 60%|█████▉    | 28/47 [00:02<00:01, 11.49it/s][A
 64%|██████▍   | 30/47 [00:02<00:01, 11.39it/s][A
 68%|██████▊   | 32/47 [00:02<00:01, 11.23it/s][A
 72%|███████▏  | 34/47 [00:02<00:01, 11.54it/s][A
 77%|███████▋  | 36/47 [00:03<00:00, 11.70it/s][A
 81%|████████  | 38/47 [00:03<00:00, 11.48i

Epoch: 6 Test set: Average loss: 0.0035, Accuracy: 8470/10000 (85%)



  0%|          | 0/47 [00:00<?, ?it/s][A
  4%|▍         | 2/47 [00:00<00:03, 11.76it/s][A
  9%|▊         | 4/47 [00:00<00:03, 11.70it/s][A
 13%|█▎        | 6/47 [00:00<00:03, 12.39it/s][A
 17%|█▋        | 8/47 [00:00<00:03, 12.44it/s][A
 21%|██▏       | 10/47 [00:00<00:02, 12.67it/s][A
 26%|██▌       | 12/47 [00:00<00:02, 12.61it/s][A
 30%|██▉       | 14/47 [00:01<00:02, 12.84it/s][A
 34%|███▍      | 16/47 [00:01<00:02, 12.38it/s][A
 38%|███▊      | 18/47 [00:01<00:02, 12.42it/s][A
 43%|████▎     | 20/47 [00:01<00:02, 12.69it/s][A
 47%|████▋     | 22/47 [00:01<00:01, 12.86it/s][A
 51%|█████     | 24/47 [00:01<00:01, 13.00it/s][A
 55%|█████▌    | 26/47 [00:02<00:01, 12.85it/s][A
 60%|█████▉    | 28/47 [00:02<00:01, 12.56it/s][A
 64%|██████▍   | 30/47 [00:02<00:01, 12.42it/s][A
 68%|██████▊   | 32/47 [00:02<00:01, 12.42it/s][A
 72%|███████▏  | 34/47 [00:02<00:01, 12.13it/s][A
 77%|███████▋  | 36/47 [00:02<00:00, 12.19it/s][A
 81%|████████  | 38/47 [00:03<00:00, 12.16i

Epoch: 7 Test set: Average loss: 0.0041, Accuracy: 8366/10000 (84%)



  0%|          | 0/47 [00:00<?, ?it/s][A
  4%|▍         | 2/47 [00:00<00:03, 12.05it/s][A
  9%|▊         | 4/47 [00:00<00:03, 12.77it/s][A
 13%|█▎        | 6/47 [00:00<00:03, 13.02it/s][A
 17%|█▋        | 8/47 [00:00<00:02, 13.06it/s][A
 21%|██▏       | 10/47 [00:00<00:02, 12.85it/s][A
 26%|██▌       | 12/47 [00:00<00:02, 12.73it/s][A
 30%|██▉       | 14/47 [00:01<00:02, 12.73it/s][A
 34%|███▍      | 16/47 [00:01<00:02, 12.66it/s][A
 38%|███▊      | 18/47 [00:01<00:02, 12.61it/s][A
 43%|████▎     | 20/47 [00:01<00:02, 12.52it/s][A
 47%|████▋     | 22/47 [00:01<00:01, 12.76it/s][A
 51%|█████     | 24/47 [00:01<00:01, 12.44it/s][A
 55%|█████▌    | 26/47 [00:02<00:01, 12.92it/s][A
 60%|█████▉    | 28/47 [00:02<00:01, 12.57it/s][A
 64%|██████▍   | 30/47 [00:02<00:01, 12.79it/s][A
 68%|██████▊   | 32/47 [00:02<00:01, 12.61it/s][A
 72%|███████▏  | 34/47 [00:02<00:01, 12.55it/s][A
 77%|███████▋  | 36/47 [00:02<00:00, 12.65it/s][A
 81%|████████  | 38/47 [00:03<00:00, 12.49i

Epoch: 8 Test set: Average loss: 0.0034, Accuracy: 8576/10000 (86%)



  0%|          | 0/47 [00:00<?, ?it/s][A
  4%|▍         | 2/47 [00:00<00:03, 13.33it/s][A
  9%|▊         | 4/47 [00:00<00:03, 12.36it/s][A
 13%|█▎        | 6/47 [00:00<00:03, 12.54it/s][A
 17%|█▋        | 8/47 [00:00<00:02, 13.10it/s][A
 21%|██▏       | 10/47 [00:00<00:02, 12.88it/s][A
 26%|██▌       | 12/47 [00:00<00:02, 12.75it/s][A
 30%|██▉       | 14/47 [00:01<00:02, 12.52it/s][A
 34%|███▍      | 16/47 [00:01<00:02, 12.77it/s][A
 38%|███▊      | 18/47 [00:01<00:02, 12.68it/s][A
 43%|████▎     | 20/47 [00:01<00:02, 12.84it/s][A
 47%|████▋     | 22/47 [00:01<00:01, 12.74it/s][A
 51%|█████     | 24/47 [00:01<00:01, 12.66it/s][A
 55%|█████▌    | 26/47 [00:02<00:01, 12.71it/s][A
 60%|█████▉    | 28/47 [00:02<00:01, 12.41it/s][A
 64%|██████▍   | 30/47 [00:02<00:01, 12.67it/s][A
 68%|██████▊   | 32/47 [00:02<00:01, 12.66it/s][A
 72%|███████▏  | 34/47 [00:02<00:01, 12.61it/s][A
 77%|███████▋  | 36/47 [00:02<00:00, 12.51it/s][A
 81%|████████  | 38/47 [00:03<00:00, 12.39i

Epoch: 9 Test set: Average loss: 0.0033, Accuracy: 8628/10000 (86%)



