# wandb

## wandb我最爱的炼丹伴侣操作指南

> [wandb我最爱的炼丹伴侣操作指南](https://www.bilibili.com/video/BV17A41167WX/?share_source=copy_web&vd_source=724ca2fcd803a56b1646d6d28e65b820)

### Intro

wandb全称weights&bias，大号tensorboard，优势：

- log存储在云端，便分享不丢失。
- 可以存代码，数据集和模型的版本，随时复现  (`wandb.Artifact`)
- 交互式表格，进行case分析 (`wandb.Table`)
- 可以自动化模型调参 (`wandb.sweep`)
  - 最重要的一点
  - 高效优雅

官方介绍：

- Experiments
- Artifacts
- Tables
- Sweeps
- Reports

对常用的framework都做了集成。

总体而言，四个核心功能：

- 实验跟踪
- 版本管理
- case分析
- 超参调优

### Experience

##### 0. 注册wandb

In [1]:
import os
os.environ["WANDB_API_KEY"] = "xxxxxxxxxxxxxxxxx"
os.environ["WANDB_NOTEBOOK_NAME"] = "wandb"
os.environ['WANDB_DISABLE_CODE'] = 'true'

import wandb
wandb.login()

[34m[1mwandb[0m: Network error (SSLError), entering retry loop.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


True

#### 1.实验跟踪

wandb提供了类似TensorBoard的实验跟踪能力，主要包括：

- 模型配置超参数的记录
- 模型训练过程中loss，metric等各种指标的记录和可视化
- 图像的可视化（wandb.Image）
- 其他各种Media（wandb.Video, wandb.Audio, wandb.Html, 3D点云等）

In [2]:
import os, PIL
import numpy as np
from torch.utils.data import DataLoader, Dataset
import torch
from torch import nn
import torchvision
from torchvision import transforms
import datetime
from argparse import Namespace

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = Namespace(
    project_name = 'wandb_demo',

    batch_size = 512,
    hidden_layer_width = 64,
    dropout_p = 0.1,

    lr = 1e-4,
    optim_type = 'Adam',

    epochs = 15,
    ckpt_path = 'checkpoint.pt'
)

In [3]:
def create_dataloaders(config):
    transform = transforms.Compose([transforms.ToTensor()])
    ds_train = torchvision.datasets.MNIST(root="./mnist/",train=True,download=True,transform=transform)
    ds_val = torchvision.datasets.MNIST(root="./mnist/",train=False,download=True,transform=transform)

    ds_train_sub = torch.utils.data.Subset(ds_train, indices=range(0, len(ds_train), 5))
    dl_train =  torch.utils.data.DataLoader(ds_train_sub, batch_size=config.batch_size, shuffle=True,
                                            num_workers=2,drop_last=True)
    dl_val =  torch.utils.data.DataLoader(ds_val, batch_size=config.batch_size, shuffle=False, 
                                          num_workers=2,drop_last=True)
    return dl_train,dl_val

In [4]:
def create_net(config):
    net = nn.Sequential()
    net.add_module("conv1",nn.Conv2d(in_channels=1,out_channels=config.hidden_layer_width,kernel_size = 3))
    net.add_module("pool1",nn.MaxPool2d(kernel_size = 2,stride = 2)) 
    net.add_module("conv2",nn.Conv2d(in_channels=config.hidden_layer_width,
                                     out_channels=config.hidden_layer_width,kernel_size = 5))
    net.add_module("pool2",nn.MaxPool2d(kernel_size = 2,stride = 2))
    net.add_module("dropout",nn.Dropout2d(p = config.dropout_p))
    net.add_module("adaptive_pool",nn.AdaptiveMaxPool2d((1,1)))
    net.add_module("flatten",nn.Flatten())
    net.add_module("linear1",nn.Linear(config.hidden_layer_width,config.hidden_layer_width))
    net.add_module("relu",nn.ReLU())
    net.add_module("linear2",nn.Linear(config.hidden_layer_width,10))
    net.to(device)
    return net

In [5]:
def train_epoch(model,dl_train,optimizer):
    model.train()
    for step, batch in enumerate(dl_train):
        features,labels = batch
        features,labels = features.to(device),labels.to(device)

        preds = model(features)
        loss = nn.CrossEntropyLoss()(preds,labels)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
    return model

In [6]:
def eval_epoch(model,dl_val):
    model.eval()
    accurate = 0
    num_elems = 0
    for batch in dl_val:
        features,labels = batch
        features,labels = features.to(device),labels.to(device)
        with torch.no_grad():
            preds = model(features)
        predictions = preds.argmax(dim=-1)
        accurate_preds =  (predictions==labels)
        num_elems += accurate_preds.shape[0]
        accurate += accurate_preds.long().sum()

    val_acc = accurate.item() / num_elems
    return val_acc

In [7]:
def train(config = config):
    dl_train, dl_val = create_dataloaders(config)
    model = create_net(config); 
    optimizer = torch.optim.__dict__[config.optim_type](params=model.parameters(), lr=config.lr)
    #======================================================================
    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    wandb.init(project=config.project_name, config = config.__dict__, name = nowtime, save_code=True, mode="offline")
    model.run_id = wandb.run.id
    #======================================================================
    model.best_metric = -1.0
    for epoch in range(1,config.epochs+1):
        model = train_epoch(model,dl_train,optimizer)
        val_acc = eval_epoch(model,dl_val)
        if val_acc>model.best_metric:
            model.best_metric = val_acc
            torch.save(model.state_dict(),config.ckpt_path)   
        nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print(f"epoch【{epoch}】@{nowtime} --> val_acc= {100 * val_acc:.2f}%")
        #======================================================================
        wandb.log({'epoch':epoch, 'val_acc': val_acc, 'best_val_acc':model.best_metric})
        #======================================================================        
    #======================================================================
    wandb.finish()
    #======================================================================
    return model  

In [8]:
model = train(config) ##3,2,1 点火🔥🔥

epoch【1】@2024-07-12 15:30:46 --> val_acc= 29.52%
epoch【2】@2024-07-12 15:30:49 --> val_acc= 29.91%
epoch【3】@2024-07-12 15:30:52 --> val_acc= 37.94%
epoch【4】@2024-07-12 15:30:55 --> val_acc= 50.54%
epoch【5】@2024-07-12 15:30:58 --> val_acc= 59.88%
epoch【6】@2024-07-12 15:31:01 --> val_acc= 68.98%
epoch【7】@2024-07-12 15:31:03 --> val_acc= 76.30%
epoch【8】@2024-07-12 15:31:06 --> val_acc= 79.64%
epoch【9】@2024-07-12 15:31:09 --> val_acc= 82.93%
epoch【10】@2024-07-12 15:31:12 --> val_acc= 85.29%
epoch【11】@2024-07-12 15:31:15 --> val_acc= 86.87%
epoch【12】@2024-07-12 15:31:18 --> val_acc= 88.03%
epoch【13】@2024-07-12 15:31:21 --> val_acc= 88.80%
epoch【14】@2024-07-12 15:31:24 --> val_acc= 89.71%
epoch【15】@2024-07-12 15:31:27 --> val_acc= 90.42%


0,1
best_val_acc,▁▁▂▃▄▆▆▇▇▇█████
epoch,▁▁▂▃▃▃▄▅▅▅▆▇▇▇█
val_acc,▁▁▂▃▄▆▆▇▇▇█████

0,1
best_val_acc,0.90419
epoch,15.0
val_acc,0.90419
