<a href="https://colab.research.google.com/github/timsetsfire/wandb-examples/blob/main/General_Training_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PyTorch + W&B

* Experiment Tracking 
* Artifacts (and usage)
* Tables 

In [1]:
%%capture
!pip install wandb --upgrade

In [2]:
import wandb
from wandb.beta import workflows as wb
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [3]:
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from tqdm.notebook import tqdm

# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
random.seed(hash("setting random seeds") % 2**32 - 1)
np.random.seed(hash("improves reproducibility") % 2**32 - 1)
torch.manual_seed(hash("by removing stochasticity") % 2**32 - 1)
torch.cuda.manual_seed_all(hash("so runs are repeatable") % 2**32 - 1)

# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# remove slow mirror from list of MNIST mirrors
torchvision.datasets.MNIST.mirrors = [mirror for mirror in torchvision.datasets.MNIST.mirrors
                                      if not mirror.startswith("http://yann.lecun.com")]

In [4]:
# Conventional and convolutional neural network
class ConvNet(nn.Module):
    def __init__(self, kernels, classes=10):
        super(ConvNet, self).__init__()
        
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, kernels[0], kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, kernels[1], kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7 * 7 * kernels[-1], classes)
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out

In [5]:
def make(config):
    # Make the data
    train, test = get_data(train=True), get_data(train=False)
    train_loader = make_loader(train, batch_size=config.batch_size)
    test_loader = make_loader(test, batch_size=config.batch_size)

    # Make the model
    model = ConvNet(config.kernels, config.classes).to(device)

    # Make the loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=config.learning_rate)
    
    return model, train_loader, test_loader, criterion, optimizer

def get_data(slice=5, train=True):
    full_dataset = torchvision.datasets.MNIST(root=".",
                                              train=train, 
                                              transform=transforms.ToTensor(),
                                              download=True)
    #  equiv to slicing with [::slice] 
    sub_dataset = torch.utils.data.Subset(
      full_dataset, indices=range(0, len(full_dataset), slice))
    
    return sub_dataset


def make_loader(dataset, batch_size):
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=batch_size, 
                                         shuffle=True,
                                         pin_memory=True, num_workers=2)
    return loader

In [6]:
def train_log(loss, example_ct, epoch):
    # Where the magic happens
    wandb.log({"epoch": epoch, "loss": loss}, step=example_ct)
    print(f"Loss after " + str(example_ct).zfill(5) + f" examples: {loss:.3f}")

def train_batch(images, labels, model, optimizer, criterion):
    images, labels = images.to(device), labels.to(device)
    # Forward pass ➡
    outputs = model(images)
    loss = criterion(outputs, labels)
    # Backward pass ⬅
    optimizer.zero_grad()
    loss.backward()
    # Step with optimizer
    optimizer.step()
    return loss


def test_model(model, test_loader, log_predictions = True):
    model.eval()
    # Run the model on some test examples

    with torch.no_grad():
        correct, total = 0, 0
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            if log_predictions:
              wandb_images = []
              for image in images.numpy():
                temp = wandb.Image(image)
                wandb_images.append(temp) 
              data = {"images":wandb_images, "predicted": predicted.numpy().tolist(), "labels": labels.numpy().tolist()}
              import pandas as pd 
              df = pd.DataFrame(data)
              scores_df = pd.DataFrame( outputs.numpy().tolist(), columns = [f"p{i}" for i in range(outputs.shape[1])])
              wandb_df = df.join(scores_df)          
              wandb_df["labels"] = wandb_df["labels"].apply(lambda x: "d-" + str(x))
              wandb.log({"Predictions vs Actuals": wandb.Table(dataframe = wandb_df)})

        print(f"Accuracy of the model on the {total} " +
              f"test images: {100 * correct / total}%")
        
        wandb.log({"test_accuracy": correct / total})          


    # # Save the model in the exchangeable ONNX format
    # torch.onnx.export(model, images, "model.onnx")
    # wandb.save("model.onnx")

## Get and Log Data

Here we are going throught the exercise of getting the data.  The data will be split into 3 subsets: training, validation, and test.  Once this is complete, we will log said data to W&B.  

In [7]:
entity, project_name = None, "MNIST_EXAMPLE"

In [9]:
with wandb.init(project = project_name, job_type = "data-acquisition") as run:

  train, test = get_data(train=True), get_data(train=False)
  train, validation = torch.utils.data.random_split(train, [10000, 2000])

  torch.save(train, 'training_data.pt')
  torch.save(validation, 'validation_data.pt')
  torch.save(test, 'test_data.pt')

  train_artifact = wandb.Artifact(name = "mnist-training-data", type = "dataset")
  train_artifact.add_file("./training_data.pt")

  validation_artifact = wandb.Artifact(name = "mnist-validation-data", type = "dataset")
  validation_artifact.add_file("./validation_data.pt")

  test_artifact = wandb.Artifact(name = "mnist-test-data", type = "dataset")
  test_artifact.add_file("./test_data.pt")  
  
  run.log_artifact(train_artifact)
  run.log_artifact(validation_artifact)
  run.log_artifact(test_artifact)


[34m[1mwandb[0m: Currently logged in as: [33mtim-w[0m. Use [1m`wandb login --relogin`[0m to force relogin


Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



VBox(children=(Label(value='90.738 MB of 90.738 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

## Train Model

When training the model, we'll use the `training` and `validation` dataset to assess our model.  We won't touch the `test` dataset until later.  

In [13]:
config = dict(
    epochs=5,
    classes=10,
    kernels=[16, 32],
    batch_size=128,
    learning_rate=0.01,
    dataset="MNIST",
    architecture="CNN")

In [14]:
with wandb.init(project = project_name, job_type = "training", config = config) as run:
   
  train_artifact = run.use_artifact(f"{project_name}/mnist-training-data:latest", type = "dataset")
  validation_artifact = run.use_artifact(f"{project_name}/mnist-validation-data:latest", type = "dataset")

  train_directory = train_artifact.download(".")
  validation_directory = validation_artifact.download(".")

  train = torch.load(train_artifact.file())
  validation = torch.load(validation_artifact.file())

  train_loader = make_loader(train, batch_size= wandb.config.batch_size)
  validation_loader = make_loader(validation, batch_size= wandb.config.batch_size)

  # Make the model
  model = ConvNet(wandb.config.kernels, wandb.config.classes).to(device)

  # Make the loss and optimizer
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(
      model.parameters(), lr=wandb.config.learning_rate)
  
  wandb.watch(model, criterion, log="all", log_freq=10)

  # Run training and track with wandb
  total_batches = len(train_loader) * wandb.config.epochs
  example_ct = 0  # number of examples seen
  batch_ct = 0
  for epoch in tqdm(range(wandb.config.epochs)):
      for _, (images, labels) in enumerate(train_loader):

          loss = train_batch(images, labels, model, optimizer, criterion)
          example_ct +=  len(images)
          batch_ct += 1

          # Report metrics every 25th batch
          if ((batch_ct + 1) % 25) == 0:
              train_log(loss, example_ct, epoch)
      
      test_model(model, validation_loader, log_predictions = False)

  model_artifact = wandb.Artifact(name = "mnist-model", type = "model")
  torch.save(model.state_dict(), "model.pt")
  model_artifact.add_file("model.pt")

  run.log_artifact(model_artifact)



  0%|          | 0/5 [00:00<?, ?it/s]

Loss after 03072 examples: 0.428
Loss after 06272 examples: 0.307
Loss after 09472 examples: 0.117
Accuracy of the model on the 2000 test images: 95.0%
Loss after 12560 examples: 0.058
Loss after 15760 examples: 0.067
Loss after 18960 examples: 0.048
Accuracy of the model on the 2000 test images: 97.05%
Loss after 22048 examples: 0.101
Loss after 25248 examples: 0.038
Loss after 28448 examples: 0.051
Accuracy of the model on the 2000 test images: 96.7%
Loss after 31536 examples: 0.046
Loss after 34736 examples: 0.010
Loss after 37936 examples: 0.084
Accuracy of the model on the 2000 test images: 97.4%
Loss after 41024 examples: 0.051
Loss after 44224 examples: 0.050
Loss after 47424 examples: 0.018
Accuracy of the model on the 2000 test images: 97.8%


VBox(children=(Label(value='0.200 MB of 0.200 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▃▃▃▅▅▅▆▆▆███
loss,█▆▃▂▂▂▃▁▂▂▁▂▂▂▁
test_accuracy,▁▆▅▇█

0,1
epoch,4.0
loss,0.01828
test_accuracy,0.978


## Test Model

In [15]:
with wandb.init(project = project_name, job_type = "evaluation", config = config) as run:


  artifact = run.use_artifact(f'tim-w/{project_name}/mnist-model:v0', type='model')
  artifact_dir = artifact.download("./my-model")
 

  my_model = ConvNet(wandb.config.kernels, wandb.config.classes).to(device)
  my_model.load_state_dict(torch.load(f"{artifact_dir}/model.pt"))

  test_artifact = run.use_artifact(f"{project_name}/mnist-test-data:latest", type = "dataset")
  test_directory = validation_artifact.download(".")

  test = torch.load(f"{test_directory}/test_data.pt")

  test_loader = make_loader(test, batch_size=wandb.config.batch_size)

  test_model(my_model, test_loader, True)

  ## registering model in the registry
  # run.link_artifact(artifact, f'tim-w/model-registry/MNIST', aliases = ["latest", "production", "needs-validation"])
  # artm = wb.log_model(my_model, name = "mnist-model-v2")
  # run.link_artifact(artm, f'tim-w/model-registry/MNIST', aliases = ["latest", "production", "needs-validation"])
        

Accuracy of the model on the 2000 test images: 98.0%


VBox(children=(Label(value='1.862 MB of 1.862 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
test_accuracy,▁

0,1
test_accuracy,0.98
