In [30]:
# Imports
import torch
import torchvision
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
import torch.nn.functional as F  # All functions that don't have any parameters
from torch.utils.data import (
    DataLoader, )  # Gives easier dataset managment and creates mini batches
import torchvision.datasets as datasets  # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms  # Transformations we can perform on our dataset
from torch.utils.tensorboard import SummaryWriter


class CNN(nn.Module):
    def __init__(self, in_channels=1, num_classes=10):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1,
                               out_channels=8,
                               kernel_size=(3, 3),
                               padding=(1, 1))
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.conv2 = nn.Conv2d(in_channels=8,
                               out_channels=16,
                               kernel_size=(3, 3),
                               stride=(1, 1),
                               padding=(1, 1))
        self.fc1 = nn.Linear(16 * 7 * 7, 250)
        self.fc2 = nn.Linear(250, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)
        x2 = F.relu(self.fc1(x))
        x = self.fc2(x2)
        return [x, x2]


def checkpoint(state, filename='my_checkpoint_cnn.pth'):
    print('saving checkpoint')
    torch.save(state, filename)


# set device to cuda
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# hyperparameters
in_channels = 1
num_classes = 10
#learning_rate=0.001
#batch_size=64
num_epochs = 1

# Load Data
train_dataset = datasets.MNIST(root="/data",
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
test_dataset = datasets.MNIST(root="/data",
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=True)
#initialisation of model
model = CNN(in_channels=in_channels, num_classes=num_classes).to(device=device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
####tensorboard
batch_sizes = [32, 64, 128, 1024]
learning_rates = [0.1, 0.01, 0.001, 0.0001]
####progress bar

from tqdm import tqdm  ## progress bar
for batch_size in batch_sizes:
    for learning_rate in learning_rates:
        writer = SummaryWriter(f'runs/MNIST/MiniBatchSize {batch_size} LR {learning_rate}')
        step = 0
        train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
        model = CNN(in_channels=in_channels, num_classes=num_classes).to(device=device)
        model.train()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        for epoch in range(num_epochs):
            loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=False)
            losses=[]
            accuracies=[]
            for batch_idx, (data, targets) in loop:
                #         get data to cuda if possible

                data = data.to(device=device)
                targets = targets.to(device=device)
                #         forward pass
                scores, x2 = model(data)
                #         loss calculation
                loss = criterion(scores, targets)
                losses.append(loss)
                #         zeroing gradients
                optimizer.zero_grad()
                loss.backward()

                #         gradient descent
                optimizer.step()

                a, predictions = scores.max(1)
                num_correct = (predictions == targets).sum()
                runningtrainacc = float(num_correct) / float(data.shape[0])
                accuracies.append(runningtrainacc)
                writer.add_scalar('training loss', loss, global_step=step)
                writer.add_scalar('training acc', runningtrainacc, global_step=step)
                step += 1
        
                loop.set_description(f'Epoch[{epoch}/{num_epochs}]')
                loop.set_postfix(loss=loss.item(),acc=runningtrainacc)
            writer.add_hparams(
                {"lr": learning_rate, "bsize": batch_size},
                {
                    "accuracy": sum(accuracies) / len(accuracies),
                    "loss": sum(losses) / len(losses),
                },
            )

            if epoch % 2 == 0:
                check_point = {
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
            checkpoint(check_point)


def check_accuracy(loader, model):
    num_correct = 0
    num_samples = 0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)
            scores, _ = model(x)
            a, predictions = scores.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)
        print(
            f'Got{num_correct}/{num_samples} with accuracy {float(num_correct/num_samples)*100:.2f}'
        )
    model.train()


check_accuracy(train_loader, model)
check_accuracy(test_loader, model)

Epoch[0/1]:   0%|▏                                               | 9/1875 [00:00<00:20, 89.35it/s, acc=0.531, loss=1.7]

saving checkpoint


Epoch[0/1]:   0%|                                                        | 0/1875 [00:00<?, ?it/s, acc=0.125, loss=2.3]

saving checkpoint


Epoch[0/1]:   0%|                                                       | 0/1875 [00:00<?, ?it/s, acc=0.125, loss=2.28]

saving checkpoint


Epoch[0/1]:   1%|▎                                               | 7/938 [00:00<00:14, 62.64it/s, acc=0.109, loss=2.35]

saving checkpoint


Epoch[0/1]:   1%|▎                                               | 7/938 [00:00<00:13, 69.49it/s, acc=0.219, loss=1.96]

saving checkpoint


Epoch[0/1]:   0%|                                                        | 0/938 [00:00<?, ?it/s, acc=0.141, loss=2.27]

saving checkpoint


Epoch[0/1]:   0%|                                                        | 0/938 [00:00<?, ?it/s, acc=0.0781, loss=2.3]

saving checkpoint


Epoch[0/1]:   0%|                                                        | 0/469 [00:00<?, ?it/s, acc=0.188, loss=2.31]

saving checkpoint


Epoch[0/1]:   1%|▍                                               | 4/469 [00:00<00:12, 38.56it/s, acc=0.211, loss=2.08]

saving checkpoint


Epoch[0/1]:   0%|                                                        | 0/469 [00:00<?, ?it/s, acc=0.266, loss=2.25]

saving checkpoint


Epoch[0/1]:   1%|▌                                               | 5/469 [00:00<00:09, 47.76it/s, acc=0.164, loss=2.29]

saving checkpoint


Epoch[0/1]:   0%|                                                                               | 0/59 [00:00<?, ?it/s]

saving checkpoint


Epoch[0/1]:   2%|▊                                               | 1/59 [00:00<00:06,  8.87it/s, acc=0.0889, loss=2.31]

saving checkpoint


Epoch[0/1]:   0%|                                                                               | 0/59 [00:00<?, ?it/s]

saving checkpoint


Epoch[0/1]:   0%|                                                                               | 0/59 [00:00<?, ?it/s]

saving checkpoint


                                                                                                                       

saving checkpoint
Got33848/60000 with accuracy 56.41
Got5848/10000 with accuracy 58.48


In [None]:
# Imports
import torch
import torchvision
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
import torch.nn.functional as F  # All functions that don't have any parameters
from torch.utils.data import (
    DataLoader,
)  # Gives easier dataset managment and creates mini batches
import torchvision.datasets as datasets  # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms  # Transformations we can perform on our dataset


In [5]:
train_dataset = datasets.MNIST(
    root="/data", train=True, transform=transforms.ToTensor(), download=True
)
for (x,y) in train_dataset:
    print(y)
    

5
0
4
1
9
2
1
3
1
4
3
5
3
6
1
7
2
8
6
9
4
0
9
1
1
2
4
3
2
7
3
8
6
9
0
5
6
0
7
6
1
8
7
9
3
9
8
5
9
3
3
0
7
4
9
8
0
9
4
1
4
4
6
0
4
5
6
1
0
0
1
7
1
6
3
0
2
1
1
7
9
0
2
6
7
8
3
9
0
4
6
7
4
6
8
0
7
8
3
1
5
7
1
7
1
1
6
3
0
2
9
3
1
1
0
4
9
2
0
0
2
0
2
7
1
8
6
4
1
6
3
4
5
9
1
3
3
8
5
4
7
7
4
2
8
5
8
6
7
3
4
6
1
9
9
6
0
3
7
2
8
2
9
4
4
6
4
9
7
0
9
2
9
5
1
5
9
1
2
3
2
3
5
9
1
7
6
2
8
2
2
5
0
7
4
9
7
8
3
2
1
1
8
3
6
1
0
3
1
0
0
1
7
2
7
3
0
4
6
5
2
6
4
7
1
8
9
9
3
0
7
1
0
2
0
3
5
4
6
5
8
6
3
7
5
8
0
9
1
0
3
1
2
2
3
3
6
4
7
5
0
6
2
7
9
8
5
9
2
1
1
4
4
5
6
4
1
2
5
3
9
3
9
0
5
9
6
5
7
4
1
3
4
0
4
8
0
4
3
6
8
7
6
0
9
7
5
7
2
1
1
6
8
9
4
1
5
2
2
9
0
3
9
6
7
2
0
3
5
4
3
6
5
8
9
5
4
7
4
2
7
3
4
8
9
1
9
2
8
7
9
1
8
7
4
1
3
1
1
0
2
3
9
4
9
2
1
6
8
4
7
7
4
4
9
2
5
7
2
4
4
2
1
9
7
2
8
7
6
9
2
2
3
8
1
6
5
1
1
0
2
6
4
5
8
3
1
5
1
9
2
7
4
4
4
8
1
5
8
9
5
6
7
9
9
3
7
0
9
0
6
6
2
3
9
0
7
5
4
8
0
9
4
1
2
8
7
1
2
6
1
0
3
0
1
1
8
2
0
3
9
4
0
5
0
6
1
7
7
8
1
9
2
0
5
1
2
2
7
3
5
4
9
7
1
8
3
9
6
0
3
1
1
2
6
3
5
7
6
8


9
9
5
1
3
9
0
9
4
5
9
1
1
7
7
4
4
5
4
3
1
3
8
1
7
0
6
2
0
4
8
8
1
6
8
1
2
3
6
3
2
9
0
1
9
1
1
7
2
1
5
5
4
6
5
4
7
4
7
8
5
4
3
4
0
2
2
5
2
1
2
6
7
1
3
5
1
5
4
3
4
6
3
7
0
1
0
2
4
9
7
2
9
6
7
2
8
6
8
0
7
0
9
0
6
9
3
4
1
7
7
4
9
2
6
6
3
3
3
3
1
3
0
4
2
7
6
9
0
3
5
1
6
3
6
4
8
7
7
3
3
8
5
9
8
3
3
4
0
7
1
6
2
0
3
8
7
4
8
3
9
2
0
7
3
8
4
0
7
1
8
0
0
5
1
4
2
1
3
7
4
7
5
1
6
5
7
5
8
9
9
8
5
0
1
5
9
6
9
0
7
4
1
9
0
2
9
8
0
7
5
3
9
7
7
7
1
7
3
4
6
7
3
0
2
3
0
0
9
4
0
8
6
9
1
3
7
7
2
0
8
7
3
6
4
9
7
6
6
7
6
9
3
8
2
7
3
0
8
5
9
9
4
2
8
4
8
8
7
4
1
8
0
1
8
6
7
7
7
7
5
2
9
3
2
6
3
6
1
4
1
6
4
4
5
1
5
2
0
5
4
3
8
1
7
7
4
5
9
6
2
2
3
6
6
7
4
8
1
4
2
4
4
8
1
2
5
0
1
2
9
8
3
5
2
7
0
4
0
9
2
7
9
4
3
5
6
4
0
9
1
8
2
1
3
7
4
2
5
2
6
7
7
5
8
8
9
1
0
1
1
0
2
3
3
1
4
4
5
0
6
4
7
0
8
1
9
1
0
3
1
3
2
3
3
3
4
2
5
7
6
2
7
6
8
4
9
1
8
6
0
7
5
7
6
0
6
1
0
6
8
4
0
4
2
2
3
9
7
6
9
7
4
2
7
3
1
1
9
0
1
5
7
6
1
5
0
4
0
1
4
9
1
3
7
6
5
2
7
2
1
8
3
8
3
8
3
3
1
6
6
0
9
5
7
7
4
9
3
5
0
8
2
5
5
1
2
7
6
9
0
0
8
2
9
6
4
5
3
4


4
6
8
6
9
1
1
1
4
2
0
6
5
3
5
8
4
2
0
4
7
4
6
1
0
3
1
6
7
0
0
9
6
9
8
3
7
7
9
7
8
8
6
0
0
4
8
2
1
2
7
6
7
8
1
4
3
1
1
6
0
8
0
1
7
2
8
5
9
1
3
7
8
6
4
4
7
1
2
9
3
3
6
9
9
6
2
4
2
5
4
9
6
1
5
6
5
1
1
9
3
5
3
2
7
1
8
8
2
5
5
7
7
4
9
9
8
9
2
7
1
2
3
9
0
3
1
1
2
6
3
1
4
4
5
2
6
6
7
1
8
4
9
9
0
2
1
0
2
2
3
7
4
6
5
2
6
8
7
9
8
1
9
6
0
2
1
9
2
3
4
8
6
3
7
6
8
7
9
5
3
1
8
4
0
9
7
1
1
9
0
7
7
9
5
8
5
8
6
3
9
6
0
3
1
1
0
6
0
4
8
8
3
1
4
0
3
1
1
3
0
6
9
5
5
1
3
2
4
3
9
7
3
9
7
4
6
7
9
2
2
7
4
3
5
4
7
7
2
1
1
7
6
1
4
4
9
2
4
8
9
9
4
1
1
6
2
9
2
7
1
7
3
2
2
4
9
1
4
7
3
3
8
9
2
1
2
0
2
2
1
7
2
8
8
9
6
4
5
5
1
9
6
6
7
9
2
2
1
4
3
6
7
6
5
2
9
5
3
4
8
5
7
1
5
4
7
1
0
9
7
5
4
6
8
1
8
8
3
1
7
8
6
5
9
2
9
3
4
9
8
5
4
1
1
8
0
0
6
3
6
8
0
6
1
8
2
6
3
8
4
4
6
7
7
3
8
9
0
1
1
4
2
3
3
3
4
5
5
9
6
3
7
7
8
6
9
8
0
0
1
6
2
5
3
8
5
8
6
5
7
9
4
1
0
2
7
2
2
5
9
8
4
5
2
9
9
2
8
0
4
3
5
7
8
1
8
3
0
8
3
5
5
4
4
6
3
1
3
6
4
9
6
1
5
8
1
8
7
2
8
3
4
0
5
3
7
6
2
3
3
4
6
7
6
3
7
9
7
7
7
1
6
3
0
2
1
0
0
6
1
3
6
3
5
6
3
6
1
3


6
1
7
1
8
2
9
0
2
6
5
6
1
8
6
7
4
7
3
6
9
4
9
1
0
8
9
3
7
4
1
5
6
0
4
3
3
3
6
5
2
0
2
6
0
0
0
1
8
7
2
2
5
9
9
4
8
5
6
0
5
9
7
5
0
5
0
7
4
1
7
8
4
2
3
0
2
8
2
5
4
1
1
3
3
8
4
6
7
2
7
9
7
3
2
5
0
6
9
2
8
4
5
4
6
9
4
8
3
4
8
2
2
8
8
7
3
5
8
7
9
8
5
1
8
1
0
6
5
5
4
6
7
1
1
4
3
9
1
8
7
5
9
0
0
4
9
3
1
8
7
8
3
0
8
1
0
0
3
7
9
0
1
2
6
9
4
7
3
7
2
1
1
2
8
5
6
0
4
5
5
0
0
5
1
5
5
9
5
3
6
3
5
3
0
1
1
1
2
9
3
8
4
8
5
5
6
0
7
2
8
3
9
1
0
7
1
4
2
6
3
7
4
0
5
1
6
5
7
9
9
2
0
7
1
6
2
1
3
0
4
3
7
2
9
7
4
5
7
6
6
4
3
6
4
0
0
2
9
9
7
5
1
7
9
7
3
0
8
8
4
3
7
8
3
2
0
4
9
4
9
4
1
9
1
7
4
0
2
1
0
5
6
2
2
5
1
7
1
2
1
6
1
3
7
3
2
5
4
5
7
4
5
2
2
1
9
3
4
3
5
4
8
9
9
6
7
3
0
3
0
9
7
1
5
2
1
9
1
8
7
7
6
9
6
2
6
5
8
2
2
2
2
5
7
2
7
6
4
9
0
4
2
0
4
3
2
6
3
1
3
0
3
3
1
9
1
8
6
0
9
8
5
8
5
9
3
7
0
5
8
9
3
6
9
8
6
2
0
7
1
9
2
4
3
9
4
2
5
1
6
3
7
4
8
3
9
0
0
8
1
1
2
1
3
8
4
8
5
2
6
5
7
3
8
8
9
9
0
4
1
2
2
1
3
2
4
7
5
8
6
8
7
5
8
6
9
4
3
3
4
7
1
4
4
8
6
4
6
5
7
6
3
7
3
8
2
5
1
8
9
3
4
0
7
7
3
5
4
3
9
7
1
9
1
8
0
1
1
4


8
8
6
9
6
2
8
9
3
5
9
1
6
5
4
6
9
8
1
3
2
5
6
6
3
0
8
9
8
9
2
3
7
1
0
1
2
0
8
3
3
4
6
5
6
6
9
7
2
8
3
9
9
0
9
1
5
2
6
3
5
8
6
0
6
1
8
2
6
3
2
7
6
8
8
9
7
9
4
0
9
1
9
4
9
7
2
3
2
2
9
9
9
0
4
5
8
6
6
7
8
3
8
9
8
9
6
3
5
8
5
1
7
2
1
9
7
6
7
7
5
5
9
6
7
3
4
4
7
9
9
1
4
2
2
9
8
3
5
4
8
4
0
3
0
1
8
7
1
3
4
4
1
9
9
6
8
9
7
8
9
1
4
1
0
6
6
6
3
8
6
2
6
8
6
4
3
8
2
0
1
0
1
1
5
2
5
3
4
4
1
7
2
8
9
9
8
0
3
1
3
2
5
3
0
4
1
5
1
6
9
7
6
8
1
9
3
0
4
1
1
3
1
4
2
7
1
8
2
9
3
8
5
6
2
5
1
0
7
6
8
8
7
9
9
4
7
1
8
9
8
0
6
4
7
8
5
9
4
1
7
4
7
0
1
2
7
1
9
5
8
4
4
0
6
7
3
6
9
0
6
1
5
7
6
0
3
6
1
8
6
4
5
7
4
9
4
8
1
6
5
0
3
8
1
1
1
7
1
7
6
1
8
3
3
2
3
3
1
1
7
4
2
2
0
0
0
0
2
7
7
8
8
4
1
6
8
9
3
3
0
8
6
4
7
7
2
2
1
5
4
6
4
3
2
6
2
9
3
6
8
3
7
2
3
2
1
4
9
6
5
0
8
5
6
5
3
1
4
3
3
3
0
9
0
7
2
8
8
7
7
2
3
7
5
9
6
8
0
2
0
1
6
3
8
1
1
3
5
0
0
1
7
2
5
3
7
4
9
7
7
8
3
9
3
0
6
1
9
2
3
3
5
4
4
5
8
6
8
7
6
8
9
9
6
0
8
1
5
2
6
3
6
7
8
8
1
9
4
7
1
7
9
0
2
3
9
1
3
8
8
9
8
8
4
6
0
5
2
4
6
7
7
3
3
1
0
9
2
1
0
4
2
5
4
5
2
6
8
0


8
8
5
9
2
0
9
1
7
2
3
3
1
9
3
8
2
1
8
0
5
9
8
5
2
7
1
5
0
1
1
8
5
6
6
9
4
0
6
4
1
1
9
9
3
3
8
8
8
4
4
7
2
0
1
1
5
9
6
2
9
8
9
7
5
8
4
9
3
6
0
0
2
6
4
5
0
5
0
3
8
3
6
3
3
9
2
8
3
0
6
6
4
1
3
0
8
0
4
6
0
2
2
1
4
1
9
3
8
2
8
7
2
7
4
8
1
8
5
7
6
0
7
2
9
0
6
3
8
1
1
3
1
7
2
2
5
4
7
9
6
4
7
3
1
6
6
5
2
3
6
2
1
5
2
5
2
9
5
4
3
1
9
7
4
2
2
6
3
5
0
0
6
1
9
2
5
3
6
4
3
5
9
6
2
7
2
8
8
9
8
0
7
1
1
2
1
3
3
4
5
5
1
6
7
7
0
8
0
9
1
0
1
1
9
2
6
3
4
4
0
5
3
6
0
7
6
8
7
9
1
8
1
9
3
5
1
7
4
0
3
3
6
1
3
6
8
8
8
4
9
1
8
7
3
6
5
5
3
6
8
4
8
2
6
7
5
8
6
1
6
3
4
4
0
7
4
2
2
0
9
5
4
0
0
1
3
9
8
2
7
3
3
2
4
3
3
5
3
5
9
7
7
8
9
4
2
9
1
9
1
7
3
1
9
1
2
7
3
8
1
3
6
4
3
8
5
6
4
3
8
8
8
0
9
9
2
6
1
2
5
0
1
1
1
0
5
6
5
2
3
3
5
8
8
9
6
0
0
7
9
2
1
3
7
4
9
5
6
5
3
2
1
8
4
5
6
4
7
6
3
6
7
6
0
7
5
9
4
1
1
8
3
2
5
1
2
2
0
6
7
5
0
3
4
4
5
7
0
9
6
4
4
0
1
0
2
5
5
9
3
0
8
1
8
2
3
3
8
4
1
5
6
6
1
7
6
8
4
9
6
0
9
1
9
2
4
3
3
4
0
5
8
6
8
7
0
8
6
9
7
0
3
1
7
2
2
3
4
4
3
5
2
6
6
7
1
7
9
6
5
4
4
3
1
0
9
6
5
6
6
6
0
9
7
0
0
0
9
1


KeyboardInterrupt: 

In [10]:
def load_checkpoint(checkpoint):
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
load_checkpoint(torch.load('my_checkpoint_cnn.pth'))

In [12]:
tr

tensor([[0.3400, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 3.3408],
        [1.0919, 0.0000, 2.2912,  ..., 0.0000, 1.4870, 3.7958],
        [0.8183, 0.0000, 0.0000,  ..., 0.0000, 1.7577, 0.0000],
        ...,
        [0.0000, 0.0000, 2.1585,  ..., 0.0000, 0.9813, 1.9573],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.4128, 3.3074],
        [0.0000, 0.0000, 2.0019,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0', grad_fn=<ReluBackward0>)

In [30]:
#pytorch data augmentation and Image transformations
#https://pytorch.org/docs/stable/torchvision/transforms.html for more tranformations
import torch 
import os
import torchvision.transforms as transforms
from torchvision.utils import save_image
from custom_dataloading import CatsAndDogDataset

In [33]:
my_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.5),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomCrop(224,224)
    transforms.ToTensor()
])

dataset = CatsAndDogDataset(csv_file=os.path.join(os.getcwd(),'catanddog','catsanddog.csv'),root_dir=os.path.join(os.getcwd(),'catanddog'),transform=my_transforms)



torch.Size([3, 374, 500])
torch.Size([3, 280, 300])
torch.Size([3, 396, 312])
torch.Size([3, 414, 500])
torch.Size([3, 375, 499])
torch.Size([3, 144, 175])
torch.Size([3, 375, 499])
torch.Size([3, 499, 327])
torch.Size([3, 199, 187])
torch.Size([3, 375, 499])
torch.Size([3, 287, 300])
torch.Size([3, 376, 499])


In [5]:
# Imports
import torch
import torchvision
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
import torch.nn.functional as F  # All functions that don't have any parameters
from torch.utils.data import (
    DataLoader,
)  # Gives easier dataset managment and creates mini batches
import torchvision.datasets as datasets  # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms  # Transformations we can perform on our dataset

train_dataset = datasets.MNIST(
    root="/data", train=True, transform=transforms.ToTensor(), download=True
)
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_dataset = datasets.MNIST(
    root="/data", train=False, transform=transforms.ToTensor(), download=True
)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=True)

In [7]:
from tqdm import tqdm## progress bar

for epoch in range(5):
    loop=tqdm(enumerate(train_loader),total=len(train_loader),leave=False)
    for batch_idx,(data,targets) in loop:
        print(targets.shape)
        

  2%|█▍                                                                              | 17/938 [00:00<00:05, 167.11it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


  5%|████▎                                                                           | 50/938 [00:00<00:05, 165.31it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 10%|███████▊                                                                        | 92/938 [00:00<00:04, 181.91it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 14%|███████████                                                                    | 132/938 [00:00<00:04, 188.52it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 18%|██████████████▍                                                                | 172/938 [00:00<00:03, 192.74it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 20%|████████████████                                                               | 191/938 [00:01<00:03, 191.47it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 25%|███████████████████▍                                                           | 231/938 [00:01<00:03, 191.21it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 29%|██████████████████████▊                                                        | 271/938 [00:01<00:03, 192.75it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 33%|██████████████████████████▎                                                    | 312/938 [00:01<00:03, 196.33it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 38%|█████████████████████████████▋                                                 | 352/938 [00:01<00:03, 192.26it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 40%|███████████████████████████████▎                                               | 372/938 [00:01<00:02, 192.44it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 44%|██████████████████████████████████▋                                            | 412/938 [00:02<00:02, 192.56it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 48%|██████████████████████████████████████▏                                        | 453/938 [00:02<00:02, 196.85it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 53%|█████████████████████████████████████████▌                                     | 493/938 [00:02<00:02, 191.25it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 57%|████████████████████████████████████████████▉                                  | 533/938 [00:02<00:02, 192.06it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 61%|████████████████████████████████████████████████▎                              | 573/938 [00:02<00:01, 193.62it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 63%|█████████████████████████████████████████████████▉                             | 593/938 [00:03<00:01, 193.94it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 67%|█████████████████████████████████████████████████████▎                         | 633/938 [00:03<00:01, 195.12it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 72%|████████████████████████████████████████████████████████▊                      | 675/938 [00:03<00:01, 198.68it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


 76%|████████████████████████████████████████████████████████████▏                  | 715/938 [00:03<00:01, 196.93it/s]

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])


                                                                                                                       

torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])
torch.Size([64])




KeyboardInterrupt: 

In [24]:
datasets[1]

TypeError: 'module' object is not subscriptable