In [None]:
!pip install wandb --upgrade

In [1]:
import wandb
import torch
import numpy as np
from torch import nn
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models

In [3]:
# Authenticate with Weights & Biases (prompts for an API key if needed);
# the returned flag is displayed as the cell output.
logged_in = wandb.login()
logged_in

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


wandb: Paste an API key from your profile and hit enter: ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [2]:
# Hyperparameters for one run. run_pipeline passes this dict to wandb.init
# and then reads the values back through wandb.config.
config = {
    'num_epochs': 2,
    'num_classes': 10,          # size of the replaced ResNet head
    'batch_size': 64,
    'img_size': 224,            # images are resized to img_size x img_size
    'lr': 3e-4,                 # Adam learning rate
    'dataset': 'cifar10',
    'architecture': 'resnet18',
}

In [3]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Shared classification loss used by the training/validation loop.
loss_fn = nn.CrossEntropyLoss()
def get_accuracy(preds, y):
    """Return the fraction of rows whose arg-max prediction equals the label.

    Args:
        preds: scores/logits with the class dimension at dim 1 (argmax is
            taken over dim 1).
        y: integer class labels, one per row of ``preds``, on the same device.

    Returns:
        A 0-dim float tensor on the same device as ``preds``; use ``.item()``
        for a Python float. An empty batch yields NaN.
    """
    # Compute on whatever device the inputs already live on rather than
    # building the divisor on the global ``device`` (which breaks if the two
    # differ) with the legacy ``torch.FloatTensor`` constructor.
    correct = preds.argmax(dim=1).eq(y)
    return correct.float().mean()

In [4]:
def prepare_data(config, root="data/"):
    """Build the CIFAR-10 train and validation DataLoaders.

    Images are resized to ``config.img_size`` x ``config.img_size`` and
    converted to tensors; no normalization or augmentation is applied.

    Args:
        config: object exposing ``img_size`` and ``batch_size`` attributes
            (e.g. ``wandb.config``), and optionally ``dataset``.
        root: directory where CIFAR-10 is downloaded / cached.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.

    Raises:
        ValueError: if ``config.dataset`` is set to anything other than
            ``'cifar10'``. This function can only load CIFAR-10, and silently
            ignoring the setting would train on the wrong data.
    """
    dataset_name = getattr(config, 'dataset', 'cifar10')
    if str(dataset_name).lower() != 'cifar10':
        raise ValueError(f"prepare_data only supports dataset='cifar10', got {dataset_name!r}")

    transform = transforms.Compose(
        [
            transforms.Resize((config.img_size, config.img_size)),
            transforms.ToTensor(),
        ]
    )
    train_data = datasets.CIFAR10(root, train=True, download=True, transform=transform)
    val_data = datasets.CIFAR10(root, train=False, download=True, transform=transform)

    # Pinned host memory only speeds up host-to-GPU copies. Without CUDA it has
    # no benefit (and recent PyTorch versions warn), so enable it only when a
    # GPU is available.
    loader_kwargs = dict(
        batch_size=config.batch_size,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_data, shuffle=False, **loader_kwargs)
    return train_loader, val_loader

In [5]:
def prepare_model(config):
    """Create a randomly initialised ResNet-18 and an Adam optimizer for it.

    The final fully-connected layer is replaced so the network emits
    ``config.num_classes`` logits, the model is moved to ``device``, and Adam
    is configured with learning rate ``config.lr``.

    Returns:
        ``(model, optimizer)``.
    """
    net = models.resnet18(pretrained=False)
    head_in = net.fc.in_features
    net.fc = nn.Linear(head_in, config.num_classes)
    # The torch.optim docs recommend moving a model to its device before
    # constructing the optimizer for it.
    net = net.to(device)
    adam = torch.optim.Adam(net.parameters(), lr=config.lr)
    return net, adam

In [6]:
def loop(model, loader, is_train, epoch, optimizer=None):
    """Run one pass over ``loader``, either training or evaluating ``model``.

    Args:
        model: network to run; put in train/eval mode via ``model.train(is_train)``.
        loader: iterable of ``(x, y)`` batches; ``len(loader)`` sizes the progress bar.
        is_train: if True, gradients are enabled and ``optimizer`` is stepped
            after every batch; otherwise the pass runs without gradients.
        epoch: epoch index, used only in the progress-bar label.
        optimizer: required when ``is_train`` is True; ignored otherwise.

    Returns:
        ``(mean_loss, mean_acc)`` averaged over every sample seen, as floats.
        Each batch is weighted by its size, so a short final batch does not
        skew the epoch figures. Both are NaN if the loader yields no samples.

    Raises:
        ValueError: if ``is_train`` is True and no optimizer is given.
    """
    if is_train and optimizer is None:
        raise ValueError('loop() needs an optimizer when is_train=True')

    model.train(is_train)
    split = 'train' if is_train else ' val '  # padded so train/val bars line up

    # Running sums give O(1) work per step and sample-weighted averages.
    loss_sum = 0.0
    correct_sum = 0.0
    n_seen = 0
    mean_loss = mean_acc = float('nan')

    pbar = tqdm(loader, total=len(loader))
    for x, y in pbar:
        x = x.to(device)
        y = y.to(device)
        batch_size = y.shape[0]
        if batch_size == 0:
            continue  # nothing to score; a mean over zero samples would be NaN

        with torch.set_grad_enabled(is_train):
            preds = model(x)
            loss = loss_fn(preds, y)
            acc = get_accuracy(preds, y)

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # loss_fn (CrossEntropyLoss, default 'mean' reduction) and get_accuracy
        # both return per-batch means; scale them back to per-batch sums.
        loss_sum += loss.item() * batch_size
        correct_sum += acc.item() * batch_size
        n_seen += batch_size
        mean_loss = loss_sum / n_seen
        mean_acc = correct_sum / n_seen
        pbar.set_description(f'{split}: epoch={epoch}, loss={mean_loss:.4f}, acc={mean_acc:.4f}')
    return mean_loss, mean_acc

In [7]:
def fit(config, model, train_loader, val_loader, optimizer):
    """Train ``model`` for ``config.num_epochs`` epochs, validating after each.

    Registers the model with ``wandb.watch`` (log='all', every 20 steps), then
    for each epoch runs a training pass and a validation pass and sends the
    resulting loss/accuracy pairs to W&B via ``wandb.log``.

    Returns:
        The same ``model`` object that was passed in, after training.
    """
    wandb.watch(model, loss_fn, log='all', log_freq=20)
    for epoch in range(config.num_epochs):
        train_loss, train_acc = loop(model, train_loader, is_train=True, epoch=epoch, optimizer=optimizer)
        val_loss, val_acc = loop(model, val_loader, is_train=False, epoch=epoch)
        epoch_metrics = {
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
        }
        wandb.log(epoch_metrics)
    return model

In [8]:
def run_pipeline(config):
    """Run one W&B-tracked experiment end to end.

    Opens a run in W&B project ``'test0'`` with ``config``, builds the data
    loaders and model from ``wandb.config``, trains, then exports the network
    to ``model.onnx`` and uploads that file to the run with ``wandb.save``.
    The ``with`` block closes the run when it finishes.
    """
    with wandb.init(project='test0', config=config):
        # Read hyperparameters back through wandb.config so the values the run
        # recorded are the ones actually used.
        run_cfg = wandb.config

        train_loader, val_loader = prepare_data(run_cfg)
        print('got the data!!')
        model, optimizer = prepare_model(run_cfg)
        print('got the model!!')
        model = fit(run_cfg, model, train_loader, val_loader, optimizer)
        print('training done!!')

        # ONNX export traces the network with a single fixed-size example image.
        onnx_path = 'model.onnx'
        example_input = torch.zeros(1, 3, run_cfg.img_size, run_cfg.img_size, device=device)
        torch.onnx.export(model, example_input, onnx_path)
        wandb.save(onnx_path)
        print('model saved!!')

In [9]:
# Launch the full experiment with the hyperparameters defined above.
run_pipeline(config=config)

[34m[1mwandb[0m: Currently logged in as: [33mzer0sh0t[0m (use `wandb login --relogin` to force relogin)


Files already downloaded and verified
Files already downloaded and verified
got the data!!


  0%|          | 0/782 [00:00<?, ?it/s]

got the model!!


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
train: epoch=0, loss=1.2141, acc=0.5621: 100%|██████████| 782/782 [02:42<00:00,  4.81it/s]
 val : epoch=0, loss=1.2846, acc=0.5639: 100%|██████████| 157/157 [00:17<00:00,  9.13it/s]
train: epoch=1, loss=0.7517, acc=0.7352: 100%|██████████| 782/782 [02:42<00:00,  4.81it/s]
 val : epoch=1, loss=0.8480, acc=0.6997: 100%|██████████| 157/157 [00:17<00:00,  9.18it/s]


training done!!
model saved!!


VBox(children=(Label(value=' 42.65MB of 42.65MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
epoch,1.0
train_loss,0.75167
train_acc,0.73519
val_loss,0.84798
val_acc,0.69974
_runtime,367.0
_timestamp,1624722039.0
_step,1.0


0,1
epoch,▁█
train_loss,█▁
train_acc,▁█
val_loss,█▁
val_acc,▁█
_runtime,▁█
_timestamp,▁█
_step,▁█
