In [1]:
import os
import random
import math
import numpy as np
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
# import torchvision
# import torchvision.transforms as transforms
# import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import wandb

# save weights and biases api key to .env file in project directory
assert os.getenv('WANDB_API_KEY')
assert torch.cuda.is_available()

# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
random.seed(hash("setting random seeds") % 2**32 - 1)
np.random.seed(hash("improves reproducibility") % 2**32 - 1)  # noqa: NPY002
torch.manual_seed(hash("by removing stochasticity") % 2**32 - 1)
torch.cuda.manual_seed_all(hash("so runs are repeatable") % 2**32 - 1)

# Device configuration
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mshane-kercheval[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
x, y = fetch_openml('mnist_784', version=1, return_X_y=True, parser='auto')
x = torch.tensor(x.values, dtype=torch.float32)
y = torch.tensor(y.astype(int).values, dtype=torch.long)

# need to make this dynamic based on Fully Connected vs Convolutional
# Reshape data to have channel dimension
# MNIST images are 28x28, so we reshape them to [batch_size, 1, 28, 28]
x = x.reshape(-1, 1, 28, 28)

# 80% train; 10% validation; 10% test
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)
x_test, x_val, y_test, y_val = train_test_split(x_test, y_test, test_size=0.5, random_state=42)

print(f"Training set  : X-{x_train.shape}, y-{y_train.shape}")
print(f"Validation set: X-{x_val.shape}, y-{y_val.shape}")
print(f"Test set      : X-{x_test.shape}, y-{y_test.shape}")

Training set  : X-torch.Size([56000, 1, 28, 28]), y-torch.Size([56000])
Validation set: X-torch.Size([7000, 1, 28, 28]), y-torch.Size([7000])
Test set      : X-torch.Size([7000, 1, 28, 28]), y-torch.Size([7000])


In [4]:
class ConvNet(nn.Module):
    """Convolutional neural network (two convolutional layers)."""

    def __init__(self, kernels: list, classes: int = 10):
        super().__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(1, kernels[0], kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, kernels[1], kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7 * 7 * kernels[-1], classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        return self.fc(out)

In [17]:
def make_loader(x: torch.tensor, y: torch.tensor, batch_size: int) -> DataLoader:
    """Make a DataLoader from a given dataset."""
    return DataLoader(
        dataset=TensorDataset(x, y),
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=2,
    )


def make(config: dict) -> tuple:
    """Make the model, data, and optimization objects."""
    # Make the data
    train_loader = make_loader(x_train, y_train, batch_size=config.batch_size)
    validation_loader = make_loader(x_val, y_val, batch_size=config.batch_size)
    test_loader = make_loader(x_test, y_test, batch_size=config.batch_size)

    # Make the model
    model = ConvNet(config.kernels, config.classes).to(DEVICE)

    # Make the loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    return (
        model,
        train_loader,
        validation_loader,
        test_loader,
        criterion,
        optimizer,
    )


def train_log(training_loss: float, validation_loss: float, example_ct: int, epoch: int) -> None:
    """Logs loss to the console and wandb."""
    # Where the magic happens
    wandb.log(
        {
            'epoch': epoch,
            'training_loss': training_loss,
            'validation_loss': validation_loss
        },
        step=example_ct,
    )
    print(
        f"Training/Validation Loss after {str(example_ct).zfill(5)} examples: ",
        f"{training_loss:.3f} | {validation_loss:.3f}",
    )


def calculate_average_loss(
        data_loader: DataLoader,
        model: nn.Module,
        loss_func: callable) -> float:
    """Calculates the average loss over a dataset."""
    running_loss = 0
    total_samples = 0
    with torch.no_grad():
        for x, y in data_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)  # noqa: PLW2901
            loss = loss_func(model(x), y)
            # weighted average of the loss adjusted for the batch size
            running_loss += loss.item() * x.shape[0]
            total_samples += x.shape[0]
    return running_loss / total_samples


def train(
        model: nn.Module,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        criterion: callable,
        optimizer: torch.optim.Optimizer,
        config: dict) -> None:
    """Trains the model for the number of epochs specified in the config."""
    model.train()
    # Tell wandb to watch what the model gets up to: gradients, weights, and more!
    wandb.watch(model, criterion, log="all", log_freq=10)

    # Run training and track with wandb
    example_ct = 0  # number of examples seen

    log_interval = 30 # i.e. every 30 batches
    total_batches = len(train_loader)
    log_interval = max(1, math.floor(total_batches / log_interval))

    for epoch in tqdm(range(config.epochs)):
        running_training_loss = 0
        total_train_samples = 0
        for batch_index, (x_batch, y_batch) in enumerate(train_loader):
            x_batch, y_batch = x_batch.to(DEVICE), y_batch.to(DEVICE)  # noqa: PLW2901
            # ➡ Forward pass
            outputs = model(x_batch)
            loss = criterion(outputs, y_batch)
            # ⬅ Backward pass & optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            example_ct += len(x_batch)
            # weighted average of the training loss
            running_training_loss += loss.item() * x_batch.shape[0]
            total_train_samples += x_batch.shape[0]
            # Report metrics every 25th batch
            if batch_index % log_interval == 0:
                avg_training_loss = running_training_loss / total_train_samples
                running_training_loss = 0
                total_train_samples = 0
                model.eval()
                average_validation_loss = calculate_average_loss(
                    data_loader=validation_loader, model=model, loss_func=criterion,
                )
                train_log(avg_training_loss, average_validation_loss, example_ct, epoch)
                model.train()


def test(model: nn.Module, test_loader: DataLoader, criterion: callable) -> None:
    """Tests the model on the test set. Logs the accuracy to the console and to wandb."""
    model.eval()
    avg_test_loss = calculate_average_loss(data_loader=test_loader, model=model, loss_func=criterion)  # noqa
    print(f"Average Loss on test set: {avg_test_loss:.3f}")
    wandb.log({'test_loss': avg_test_loss})

    # Save the model in the exchangeable ONNX format
    x, _ = next(iter(test_loader))
    torch.onnx.export(model, x.to(DEVICE) , 'model.onnx')
    wandb.save('model.onnx')


def model_pipeline(config: dict) -> nn.Module:
    """Builds the model and runs it."""
    # tell wandb to get started
    project = config.pop('project'); assert project
    tags = config.pop('tags', None)
    notes = config.pop('notes', None)
    with wandb.init(project=project, config=config, tags=tags, notes=notes):
      config = wandb.config
      # make the model, data, and optimization problem
      model, train_loader, validation_loader, test_loader, criterion, optimizer = make(config)
      print(model)
      # and use them to train the model
      train(model, train_loader, validation_loader, criterion, optimizer, config)
      # and test its final performance
      test(model, test_loader, criterion)

    return model

In [18]:
config = {
    'project': 'pytorch-demo',
    'tags': ['pytorch', 'demo'],
    'notes': 'First run with a simple CNN',
    'epochs': 5,
    'classes': 10,
    'kernels': [16, 32],
    'batch_size': 64,
    'learning_rate': 0.005,
    'dataset': 'MNIST',
    'architecture': 'CNN',
}
# Build, train and analyze the model with the pipeline
model = model_pipeline(config)

ConvNet(
  (layer1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Linear(in_features=1568, out_features=10, bias=True)
)


  0%|          | 0/5 [00:00<?, ?it/s]

Training/Validation Loss after 00064 examples:  14.934 | 91.637
Training/Validation Loss after 01920 examples:  8.008 | 1.051
Training/Validation Loss after 03776 examples:  0.748 | 0.530
Training/Validation Loss after 05632 examples:  0.467 | 0.340
Training/Validation Loss after 07488 examples:  0.352 | 0.309
Training/Validation Loss after 09344 examples:  0.258 | 0.248
Training/Validation Loss after 11200 examples:  0.289 | 0.238
Training/Validation Loss after 13056 examples:  0.231 | 0.212
Training/Validation Loss after 14912 examples:  0.211 | 0.214
Training/Validation Loss after 16768 examples:  0.190 | 0.225
Training/Validation Loss after 18624 examples:  0.201 | 0.189
Training/Validation Loss after 20480 examples:  0.186 | 0.194
Training/Validation Loss after 22336 examples:  0.245 | 0.193
Training/Validation Loss after 24192 examples:  0.223 | 0.197
Training/Validation Loss after 26048 examples:  0.156 | 0.156
Training/Validation Loss after 27904 examples:  0.174 | 0.235
Traini

 20%|██        | 1/5 [00:10<00:41, 10.31s/it]

Training/Validation Loss after 55744 examples:  0.203 | 0.169
Training/Validation Loss after 56064 examples:  0.066 | 0.191
Training/Validation Loss after 57920 examples:  0.134 | 0.160
Training/Validation Loss after 59776 examples:  0.171 | 0.158
Training/Validation Loss after 61632 examples:  0.161 | 0.152
Training/Validation Loss after 63488 examples:  0.143 | 0.142
Training/Validation Loss after 65344 examples:  0.153 | 0.173
Training/Validation Loss after 67200 examples:  0.158 | 0.155
Training/Validation Loss after 69056 examples:  0.154 | 0.157
Training/Validation Loss after 70912 examples:  0.157 | 0.156
Training/Validation Loss after 72768 examples:  0.159 | 0.165
Training/Validation Loss after 74624 examples:  0.160 | 0.141
Training/Validation Loss after 76480 examples:  0.159 | 0.177
Training/Validation Loss after 78336 examples:  0.104 | 0.143
Training/Validation Loss after 80192 examples:  0.121 | 0.177
Training/Validation Loss after 82048 examples:  0.203 | 0.146
Training

 40%|████      | 2/5 [00:20<00:30, 10.24s/it]

Training/Validation Loss after 111744 examples:  0.134 | 0.143
Training/Validation Loss after 112064 examples:  0.002 | 0.279
Training/Validation Loss after 113920 examples:  0.131 | 0.169
Training/Validation Loss after 115776 examples:  0.138 | 0.160
Training/Validation Loss after 117632 examples:  0.145 | 0.164
Training/Validation Loss after 119488 examples:  0.144 | 0.151
Training/Validation Loss after 121344 examples:  0.161 | 0.151
Training/Validation Loss after 123200 examples:  0.123 | 0.165
Training/Validation Loss after 125056 examples:  0.185 | 0.151
Training/Validation Loss after 126912 examples:  0.165 | 0.189
Training/Validation Loss after 128768 examples:  0.107 | 0.180
Training/Validation Loss after 130624 examples:  0.140 | 0.203
Training/Validation Loss after 132480 examples:  0.139 | 0.143
Training/Validation Loss after 134336 examples:  0.103 | 0.149
Training/Validation Loss after 136192 examples:  0.122 | 0.143
Training/Validation Loss after 138048 examples:  0.110 

 60%|██████    | 3/5 [00:31<00:21, 10.51s/it]

Training/Validation Loss after 167744 examples:  0.141 | 0.134
Training/Validation Loss after 168064 examples:  0.078 | 0.124
Training/Validation Loss after 169920 examples:  0.109 | 0.120
Training/Validation Loss after 171776 examples:  0.097 | 0.125
Training/Validation Loss after 173632 examples:  0.101 | 0.130
Training/Validation Loss after 175488 examples:  0.119 | 0.134
Training/Validation Loss after 177344 examples:  0.109 | 0.135
Training/Validation Loss after 179200 examples:  0.124 | 0.141
Training/Validation Loss after 181056 examples:  0.111 | 0.125
Training/Validation Loss after 182912 examples:  0.109 | 0.120
Training/Validation Loss after 184768 examples:  0.119 | 0.157
Training/Validation Loss after 186624 examples:  0.070 | 0.114
Training/Validation Loss after 188480 examples:  0.143 | 0.160
Training/Validation Loss after 190336 examples:  0.176 | 0.150
Training/Validation Loss after 192192 examples:  0.144 | 0.133
Training/Validation Loss after 194048 examples:  0.130 

 80%|████████  | 4/5 [00:41<00:10, 10.50s/it]

Training/Validation Loss after 223744 examples:  0.110 | 0.135
Training/Validation Loss after 224064 examples:  0.197 | 0.118
Training/Validation Loss after 225920 examples:  0.110 | 0.141
Training/Validation Loss after 227776 examples:  0.113 | 0.179
Training/Validation Loss after 229632 examples:  0.097 | 0.143
Training/Validation Loss after 231488 examples:  0.130 | 0.156
Training/Validation Loss after 233344 examples:  0.134 | 0.150
Training/Validation Loss after 235200 examples:  0.109 | 0.145
Training/Validation Loss after 237056 examples:  0.122 | 0.123
Training/Validation Loss after 238912 examples:  0.130 | 0.137
Training/Validation Loss after 240768 examples:  0.124 | 0.123
Training/Validation Loss after 242624 examples:  0.106 | 0.147
Training/Validation Loss after 244480 examples:  0.123 | 0.120
Training/Validation Loss after 246336 examples:  0.089 | 0.116
Training/Validation Loss after 248192 examples:  0.119 | 0.139
Training/Validation Loss after 250048 examples:  0.085 

100%|██████████| 5/5 [00:51<00:00, 10.40s/it]

Training/Validation Loss after 279744 examples:  0.117 | 0.115





Average Loss on test set: 13.518667%




0,1
epoch,▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆████████
test_loss,▁
training_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_loss,█▃▂▂▂▂▁▁▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,4.0
test_loss,0.13519
training_loss,0.11668
validation_loss,0.11477
