# Wandb tutorial

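What is wandb? Weights & Biases (`wandb`) is a service for tracking machine-learning experiments. It records each run's hyperparameters, metrics, images and tables, and shows them in a web dashboard where runs can be compared. It can also run hyperparameter sweeps, as shown at the end of this notebook.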

![Screenshot of the wandb dashboard](wandb_screenshot.png "wandb dashboard")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchmetrics

import wandb

# Experiment tracking

One of the most useful features of wandb is the ability to manage your experiments. We will first demonstrate basic usage of wandb to track our image classification results.

In [None]:
# Set some variables
batch_size = 4
lr = 0.001
epochs = 10
optim_name = 'ADAM'
loss_fn = 'Cross entropy'
filters = 20

# Generate a unique id for this run. It is called `run_id` (not `id`) so the
# builtin `id()` is not shadowed for the rest of the notebook.
run_id = wandb.util.generate_id()

# Hyperparameters and settings recorded with the run, so runs can later be
# compared and filtered by these values in the wandb UI.
config = {
    "batch_size": batch_size,
    "learning_rate": lr,
    "loss_function": loss_fn,
    "optimizer": optim_name,
    "run_id": run_id,
    "epochs": epochs,
    "filters": filters,
}

# Set our run name, project, and anything else we want to initialize our run with
run_name = "simple_CIFAR10_v0"
project = "ENEL_645"
note = 'Simple CIFAR10 classification wandb tutorial'
# NOTE: `entity` is the tutorial author's wandb account; replace it with your own
# username/team. To continue an existing run, pass its id and a `resume=` option.
run = wandb.init(project=project, entity='natalia-dubljevic', id=run_id,
                 name=run_name, config=config, notes=note)

#### Download and prepare our data

In [None]:
# CIFAR-10 training split; ToTensor converts each PIL image to a float tensor.
transform = transforms.ToTensor()
dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Keep only the first tenth of the images so the tutorial trains quickly.
subset_indices = list(range(len(dataset) // 10))
dataset = torch.utils.data.Subset(dataset, subset_indices)

# Random 70% / 30% train / validation split.
trainset, valset = torch.utils.data.random_split(dataset, [0.7, 0.3])
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
# One image per validation batch: the validation loops call .item() on each prediction.
val_loader = torch.utils.data.DataLoader(valset, batch_size=1, shuffle=False)

# Human-readable names for label indices 0..9.
classes = ('airplane', 'automobile', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
def img_transform(img):
    """Convert a CHW image tensor into an HWC NumPy array for ``plt.imshow``.

    The tensor is detached and copied to the CPU first, so images that live
    on a GPU or carry autograd history can be displayed as well.

    Args:
        img: Image tensor laid out as (channels, height, width).

    Returns:
        NumPy array laid out as (height, width, channels).
    """
    return np.transpose(img.detach().cpu().numpy(), (1, 2, 0))

# Preview one shuffled batch of training images with their class names.
dataiter = iter(train_loader)
images, labels = next(dataiter)

fig = plt.figure()
n_cols = len(images)  # one column per image in the batch
for i, image in enumerate(images):
    # Sizing the grid from the batch instead of a hard-coded 4 keeps this
    # cell working when batch_size is changed.
    plt.subplot(1, n_cols, i + 1)
    plt.imshow(img_transform(image))
    plt.title(classes[labels[i].item()])
    plt.axis('off')
plt.show()


#### Define our network architecture

In [None]:
class Net(nn.Module):
    """Three-stage convolutional classifier with a three-layer MLP head.

    Each stage is a 3x3 'same'-padded convolution, a ReLU and a 2x2 max-pool,
    so height and width are halved three times. ``fc1`` expects a
    ``filters x 4 x 4`` feature map, so inputs must shrink to 4x4 after the
    three pools (e.g. 32x32 CIFAR-10 images).

    Args:
        filters: Number of channels produced by each of the three convolutions.
    """

    def __init__(self, filters=20) -> None:
        super().__init__()

        conv_kwargs = dict(kernel_size=3, padding='same')
        self.conv1 = nn.Conv2d(3, filters, **conv_kwargs)
        self.conv2 = nn.Conv2d(filters, filters, **conv_kwargs)
        self.conv3 = nn.Conv2d(filters, filters, **conv_kwargs)

        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

        # Three 2x2 pools reduce a 32x32 input to 4x4.
        self.fc1 = nn.Linear(filters * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 32)
        self.fc3 = nn.Linear(32, 10)

    def forward(self, x):
        """Map an image batch to unnormalized scores (logits) for 10 classes."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(self.relu(conv(x)))

        x = torch.flatten(x, 1)
        for fc in (self.fc1, self.fc2):
            x = self.relu(fc(x))
        return self.fc3(x)

In [None]:
# Use the first CUDA GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)

### Begin the training and validation loops

In [None]:
# Build the model from the same `filters` value that was logged in `config`,
# so the run's recorded hyperparameters describe the model actually trained.
net = Net(filters=filters).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=lr)

# Columns of the per-epoch predictions table: image, predicted class name,
# true class name, then one softmax score per class.
table_columns = ['Image', 'Prediction', 'Truth'] + ["score_" + class_name for class_name in classes]
n_logged = 16  # number of validation images logged (and tabulated) per epoch

for epoch in range(epochs):
    # ---- Training ----
    net.train()
    train_loss = 0.0
    for data in train_loader:
        optimizer.zero_grad()

        imgs, labels = data[0].to(device), data[1].to(device)
        outputs = net(imgs)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Mean loss per batch; max(..., 1) avoids a crash on an empty loader.
    train_loss /= max(len(train_loader), 1)
    print(f'train loss: {train_loss:.6f}', flush=True)

    # ---- Validation ----
    net.eval()
    val_loss = 0.0
    imgs = []
    preds = []
    labels = []
    predictions_table = wandb.Table(columns=table_columns)

    with torch.no_grad():
        for i, data in enumerate(val_loader):
            img, label = data[0].to(device), data[1].to(device)
            output = net(img)  # logits of shape (1, 10): val_loader uses batch_size=1
            val_loss += criterion(output, label).item()

            # Keep plain Python ints so accuracy and the confusion matrix
            # receive integer class indices.
            pred_idx = output.argmax(dim=1).item()
            label_idx = label.item()
            preds.append(pred_idx)
            labels.append(label_idx)

            if i < n_logged:
                pred_class, label_class = classes[pred_idx], classes[label_idx]
                img_cpu = torch.squeeze(img.detach().cpu())  # (channels, H, W) on the CPU
                imgs.append(wandb.Image(img_cpu, caption=f"Pred: {pred_class}, Label: {label_class}"))

                # Table row: image, prediction, truth, per-class softmax scores.
                probs = nn.functional.softmax(torch.squeeze(output), dim=0)
                row = [wandb.Image(img_cpu), pred_class, label_class]
                row.extend(np.round(p, 4) for p in probs.tolist())
                predictions_table.add_data(*row)

    acc = torchmetrics.functional.accuracy(
        torch.tensor(preds), torch.tensor(labels),
        task="multiclass", num_classes=len(classes))

    val_loss /= max(len(val_loader), 1)
    print(f'val loss: {val_loss:.6f}', flush=True)
    wandb.log({"train_loss": train_loss,
               "val_loss": val_loss,
               "accuracy": acc.item(),
               'img': imgs,
               'prediction_table': predictions_table,
               "conf_mat": wandb.plot.confusion_matrix(
                   probs=None, y_true=labels, preds=preds,
                   class_names=list(classes))},
              step=epoch + 1)

run.finish()

## Hyperparameter sweeps

Hyperparameter sweeps let us search for the best-performing hyperparameters for a given experiment. Each sweep trains the model many times, so sweeps are computationally expensive and most practical for lighter models.

![Illustration of the wandb sweep controller](sweep_controller.png "Sweep controller")

### Define our sweep configs

In [None]:
# Config values every sweep run starts from. Parameters chosen by the sweep
# take precedence. `learning_rate` and `filters` are included so that
# `train()` can also be run outside a sweep.
config_defaults = {
    'loss_function': loss_fn,
    'optimizer': optim_name,
    'epochs': epochs,
    'batch_size': batch_size,
    'learning_rate': lr,
    'filters': filters,
}


# `metric.name` must match a key passed to `wandb.log`; the training code
# logs validation accuracy under "accuracy".
grid_sweep = {
    'method': 'grid',  # or 'random' or 'bayes'
    'name': 'grid_sweep',
    'metric': {'goal': 'maximize', 'name': 'accuracy'},
    'parameters': {
        'filters': {'values': [10, 20, 30]},
        'learning_rate': {'values': [0.01, 0.001]}
    }
}


random_sweep = {
    'method': 'random',
    'name': 'random_sweep',
    'metric': {'goal': 'maximize', 'name': 'accuracy'},
    'parameters': {
        'learning_rate': {
            # 'log_uniform' samples exp(u) with u uniform on [min, max], i.e.
            # learning rates between exp(-8) ~= 3.4e-4 and exp(-5) ~= 6.7e-3
            # whose logarithm is uniformly distributed.
            'distribution': 'log_uniform',
            'min': -8,
            'max': -5},
        'filters': {
            # an integer drawn uniformly from [10, 100]
            'distribution': 'int_uniform',
            'min': 10,
            'max': 100}
    }
}


### Set up our training function for the sweep controller

In [None]:
def train(config=None):
    """Train and validate one ``Net`` inside a wandb run (sweep entry point).

    Meant to be passed to ``wandb.agent``, whose sweep controller supplies
    ``filters`` and ``learning_rate`` through ``wandb.config``. For each epoch
    it logs the mean training loss, the mean validation loss, the validation
    accuracy, and up to 16 captioned validation images.

    Args:
        config: Optional dict of config values layered over ``config_defaults``
            (useful when calling ``train`` directly, outside a sweep). Values
            chosen by a sweep still take precedence.
    """
    run_config = dict(config_defaults)
    if config is not None:
        run_config.update(config)

    with wandb.init(config=run_config):
        # Read hyperparameters back from wandb.config so sweep-chosen values are used.
        cfg = wandb.config

        net = Net(filters=cfg.filters).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=cfg.learning_rate)

        for epoch in range(cfg.epochs):
            # ---- Training ----
            net.train()
            train_loss = 0.0
            for data in train_loader:
                optimizer.zero_grad()

                imgs, labels = data[0].to(device), data[1].to(device)
                outputs = net(imgs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

            # Mean loss per batch; max(..., 1) avoids a crash on an empty loader.
            train_loss /= max(len(train_loader), 1)

            # ---- Validation ----
            net.eval()
            val_loss = 0.0
            imgs = []
            preds = []
            labels = []
            with torch.no_grad():
                for i, data in enumerate(val_loader):
                    img, label = data[0].to(device), data[1].to(device)
                    output = net(img)  # logits of shape (1, 10): val_loader uses batch_size=1
                    val_loss += criterion(output, label).item()

                    # Plain Python ints -> integer class-index tensors for accuracy.
                    pred_idx = output.argmax(dim=1).item()
                    label_idx = label.item()
                    preds.append(pred_idx)
                    labels.append(label_idx)

                    if i < 16:
                        caption = f"Pred: {classes[pred_idx]}, Label: {classes[label_idx]}"
                        img_cpu = torch.squeeze(img.detach().cpu())
                        imgs.append(wandb.Image(img_cpu, caption=caption))

            acc = torchmetrics.functional.accuracy(
                torch.tensor(preds), torch.tensor(labels),
                task="multiclass", num_classes=len(classes))

            val_loss /= max(len(val_loader), 1)
            wandb.log({"train_loss": train_loss,
                       "val_loss": val_loss,
                       "accuracy": acc.item(),
                       'img': imgs}, step=epoch + 1)

### Let's try out a grid sweep

In [None]:
# Register the grid sweep, then run it with a local agent. With no `count`,
# the agent keeps calling `train` until the 3 x 2 grid is exhausted.
sweep_id = wandb.sweep(grid_sweep, project="ENEL_645", entity='natalia-dubljevic')
wandb.agent(sweep_id, function=train)

### We can also do a random sweep

In [None]:
# Random search has no natural end, so this agent stops after 6 sampled runs.
sweep_id = wandb.sweep(random_sweep, project="ENEL_645", entity='natalia-dubljevic')
wandb.agent(sweep_id, function=train, count=6)