In [1]:
from lenet import IDXLeNetDataset, LeNet5
from torch.utils.data import DataLoader
import torch
import wandb
import matplotlib.pyplot as plt
# Test the model on our own MNIST dataset
from lenet import OurMNISTDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
bs = 4096

In [3]:
train_dataset = IDXLeNetDataset("data/train-images-idx3-ubyte", "data/train-labels-idx1-ubyte", padding=2) 
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)

test_dataset = IDXLeNetDataset("data/t10k-images-idx3-ubyte", "data/t10k-labels-idx1-ubyte", padding=2)
test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=True)

our_dataset = OurMNISTDataset(padding=0)
our_loader = DataLoader(our_dataset, batch_size=1, shuffle=True)

In [4]:
model = LeNet5().to("cuda")
lr = 0.001
epochs = 40


# Cross Entropy Loss
criterion = torch.nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

wandb.init(project="lenet", config={"learning_rate": lr, "epochs": epochs, "batch_size": bs}, entity="jpossaz")

wandb.watch(model, log="all", criterion=criterion)

for epoch in range(epochs):

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to("cuda"), target.to("cuda")
        optimizer.zero_grad()
        logits, probas = model(data)
        loss = criterion(logits, target) / bs
        loss.backward()
        optimizer.step()

        wandb.log({"loss": loss.item()})

    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to("cuda"), target.to("cuda")
            logits, probas = model(data)
            test_loss += criterion(logits, target).item() # sum up batch loss
            pred = probas.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_dataset)
    accuracy = 100. * correct / len(test_dataset)

    wandb.log({"test_loss": test_loss, "accuracy": accuracy})

    # Visualize convolutional filters
    filters = model.conv1.weight.data.cpu().numpy() # shape: (6, 1, 5, 5)
    wandb.log({"conv1": [wandb.Image(filters[i,0,:,:]) for i in range(filters.shape[0])]})

    # Now for conv2
    filters = model.conv2.weight.data.cpu().numpy()
    wandb.log({"conv2": [wandb.Image(filters[i,0,:,:]) for i in range(filters.shape[0])]})

    test_loss = 0
    correct = 0

    images = []
    with torch.no_grad():
        for i, (data, target) in enumerate(our_loader):
            
            data, target = data.to("cuda"), target.to("cuda")
            logits, probas = model(data)
            test_loss += criterion(logits, target).item() # sum up batch loss
            pred = probas.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

            if i > 10:
                continue

            # Make a matpliotlib figure with the probability distribution
            # and the title as the target, and also show the image
            fig, ax = plt.subplots(1, 2, figsize=(5, 2.5))
            ax[0].imshow(data[0,0,:,:].cpu().numpy())
            ax[1].bar(range(10), probas[0,:].cpu().numpy())
            ax[1].set_xticks(range(10))
            ax[1].set_title(f"Target: {target.item()}")

            images.append(wandb.Image(fig))

            plt.close(fig)

    wandb.log({"our_mnist_predictions": images})

    test_loss /= len(our_dataset)
    accuracy = 100. * correct / len(our_dataset)

    wandb.log({"our_mnist_loss": test_loss, "our_mnist_accuracy": accuracy})

wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjpossaz[0m. Use [1m`wandb login --relogin`[0m to force relogin


0,1
accuracy,▁▅▆▇▇▇▇▇████████████████████████████████
loss,█▇▄▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
our_mnist_accuracy,▂▃▂▂▂▂▁▁▂▁▁▃▂▃▄▅▅▅▅▆▅▆▆▆▆▆▆▆▆▆▆▇▇▇▇███▇▇
our_mnist_loss,▂▁▆▆▅▆▆█▆▇█▆▆▅▆▆▅▅▆▅▆▆▅▅▅▅▅▆▆▆▅▅▆▅▅▅▅▅▆▆
test_loss,█▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,98.39
loss,1e-05
our_mnist_accuracy,34.0
our_mnist_loss,3.15889
test_loss,1e-05
