# Weights and Biases integration

In [1]:
#| default_exp wandb

In [2]:
#| export
from torch.utils.data import Dataset
from pathlib import Path
from dataclasses import dataclass

import wandb
from fastprogress import progress_bar

In [3]:
#| export
@dataclass
class WandbConfig:
    "Minimal W&B settings: the project to log to and an optional entity"
    project: str                   # W&B project name
    entity: "str | None" = None    # optional W&B entity; None when not given

In [4]:
config = WandbConfig(project="capetorch", entity="capecape")

## Log a Torchvision dataset to W&B

In [13]:
#| export
class ImageDatasetLogger:
    "Log a torchvision dataset with Images as a Table"
    def __init__(self, 
                 ds: Dataset, 
                 image_mode:str="RGB", 
                 n:int=None, 
                 log_raw:bool=True,
                 log2workspace:bool=True, 
                 artifact_name:str="image_dataset",
                 table_name:str="image_dataset_table"
                ):
        """
        ds: dataset whose items unpack as `(image, label)`; needs `len` and integer indexing
        image_mode: mode forwarded to `wandb.Image` for every image
        n: log only the first `n` samples (all samples when None)
        log_raw: in `log_all`, also add the dataset's raw folder to the artifact (only when `n` is None)
        log2workspace: also log the table to the run workspace
        artifact_name: name of the `wandb.Artifact` (type "data") created here
        table_name: name of the table inside the artifact and in the workspace
        """
        self.ds = ds
        self.image_mode = image_mode
        self.n = n
        self.log_raw = log_raw
        self.log2workspace = log2workspace
        self.artifact_name = artifact_name
        self.table_name = table_name
        
        # create artifact (files/tables are added to it later, logged in `_log`)
        self.ds_at = wandb.Artifact(self.artifact_name, type="data")
        
    def __repr__(self):
        return repr(self.ds)
    
    def add_raw(self):
        "Add the parent directory of `ds.raw_folder` to the artifact"
        try:
            path = Path(self.ds.raw_folder).parent
            self.ds_at.add_dir(path)
        except Exception as err:
            # Chain the original error so the real cause (missing attribute, bad path, ...)
            # stays visible instead of being swallowed.
            raise Exception("Error finding the folder for the dataset") from err
    
    def _num_items(self) -> int:
        "Number of samples to log: all of `ds`, capped at `n` when it is set"
        total = len(self.ds)
        return total if self.n is None else min(self.n, total)
    
    def _dataset_table(self):
        "Create a `wandb.Table` with one (image, label) row per logged sample"
        ds_table = wandb.Table(columns=["image", "label"])
        # Iterate only the samples we will log so the progress bar total is accurate when `n` is set.
        pbar = progress_bar(range(self._num_items()), leave=False)
        pbar.comment = "Creating W&B Table with validation DL"
        for i in pbar:
            img, lbl = self.ds[i]
            ds_table.add_data(wandb.Image(img, mode=self.image_mode), lbl)
        return ds_table
    
    def add_dataset_table(self):
        "Build the dataset table and add it to the artifact; must be called inside a wandb run"
        # Check first so we don't build a potentially large table only to fail afterwards.
        assert wandb.run is not None, "Execute this function within a wandb run"
        self.ds_table = self._dataset_table()
        self.ds_at.add(self.ds_table, self.table_name)
            
    def _log(self):
        "Log the artifact to W&B and, if requested, the table to the workspace"
        wandb.log_artifact(self.ds_at)
        
        # log table also to workspace
        if self.log2workspace:
            wandb.log({self.table_name: self.ds_table})
            
    def log_table(self):
        "Add the dataset table to the artifact and log it"
        self.add_dataset_table()
        self._log()
        
    def log_all(self):
        "Log the table and, for a full (un-truncated) dataset with `log_raw`, the raw folder too"
        # The raw files only make sense alongside the complete dataset.
        if self.log_raw and self.n is None:
            self.add_raw()
        
        self.log_table()

In [14]:
import torchvision as tv
ds = tv.datasets.FashionMNIST(root=".", train=False, download=True)

In [17]:
wdsl = ImageDatasetLogger(ds=ds, n=100)  # only the first 100 samples; raw folder is skipped when n is set

In [18]:
# Open a run and log the dataset table (and artifact) built by `wdsl`.
# NOTE(review): only `config.project` is passed; `config.entity` is not used, so the run
# goes to the default entity — confirm that is intended.
with wandb.init(project=config.project):
    wdsl.log_all()

## Log predictions

In [25]:
# Public API client, used below to fetch the logged dataset artifact outside of a run.
api = wandb.Api()

In [29]:
# Build the artifact path from the settings used to log it (project from `config`,
# artifact name from `wdsl`) instead of repeating them as a hard-coded string.
at = api.artifact(f"{config.project}/{wdsl.artifact_name}:latest")

In [39]:
# Download the table stored in the artifact under the logger's table name.
at.get(wdsl.table_name)

[34m[1mwandb[0m:   101 of 101 files downloaded.  


<wandb.data_types.Table at 0x7f2422598520>

In [53]:
def not_none(o):
    "True when `o` is not None; lists and tuples are checked recursively, element by element"
    if not isinstance(o, (list, tuple)):
        return o is not None
    return all(map(not_none, o))

In [55]:
not_none(o=[1, 2, None])  # one element is None -> False

False

In [56]:
not_none(o=2)  # plain non-None value -> True

True

In [None]:
class ModelPredictionsLogger:
    """
    Log model predictions on the validation datasets, it references the previously logged dataset
    """
    def __init__(self, ds, n_preds=None, ds_data_at=None, ds_table_name=None, log2workspace=True):
        """
        ds: the dataset the predictions were computed on (stored only)
        n_preds: default number of rows to log when `log` is called without `n` (all rows when None)
        ds_data_at: name of the previously logged dataset artifact (e.g. from `ImageDatasetLogger`)
        ds_table_name: name of the table inside `ds_data_at`
        log2workspace: also log the predictions table to the run workspace
        """
        self.ds = ds
        self.n_preds = n_preds
        self.ds_data_at = ds_data_at
        self.ds_table_name = ds_table_name
        self.log2workspace = log2workspace
        # Fetched lazily in `log`: `wandb.use_artifact` needs an active run, which may not
        # exist yet when the logger is constructed.
        self.ds_table = None
        
    def _get_reference_table(self):
        "Fetch the referenced dataset table, or None if the artifact or table name is missing"
        if self.ds_data_at is None or self.ds_table_name is None:
            return None
        artifact = wandb.use_artifact(self.ds_data_at, type='data')
        return artifact.get(self.ds_table_name)
    
    def _init_preds_table(self, num_classes=10):
        "Create an empty predictions table: image, label, predicted class, then one `prob_i` column per class"
        self.preds_table = wandb.Table(columns=["image", "label", "preds"]+[f"prob_{i}" for i in range(num_classes)])
        
    def create_preds_table(self, preds, n=None):
        """
        Add one row per reference-table row (only the first `n` when given) to `self.preds_table`.
        preds: array/tensor of shape (n_samples, num_classes). Row `i` of `preds` is assumed to
        correspond to row `i` of the reference table — TODO confirm ordering with the caller.
        Requires `_init_preds_table` to have been called first.
        """
        if self.ds_table is None:
            print("No val table reference found")
            return
        import torch  # local import: the file does not import torch at module level
        if isinstance(preds, torch.Tensor):
            # detach so tensors that require grad can also be converted
            preds = preds.detach().cpu().numpy()
        
        # Look columns up by name rather than by hard-coded position.
        columns = self.ds_table.columns
        img_col, lbl_col = columns.index("image"), columns.index("label")
        table_idxs = self.ds_table.get_index()
        for idx in progress_bar(table_idxs[:n], leave=False):
            pred = preds[idx]
            row = self.ds_table.data[idx]
            self.preds_table.add_data(row[img_col], row[lbl_col], pred.argmax(), *pred)
    
    def log(self, preds, n=None, table_name="preds_table", aliases=None):
        """
        Log `preds` as a predictions table that references the dataset table.
        preds: array/tensor of shape (n_samples, num_classes)
        n: number of rows to log; defaults to `self.n_preds` (all rows when both are None)
        table_name: name of the table in the artifact and in the workspace
        aliases: artifact aliases, `["latest"]` when not given
        Raises ValueError if `preds` is not rank 2; must be called inside an active wandb run.
        """
        if len(preds.shape) != 2:
            raise ValueError("The preds tensor must have rank 2 (n_samples, num_classes)")
        num_classes = preds.shape[1]
        if n is None:
            n = self.n_preds
        
        # `use_artifact` below needs a run, so check before fetching the reference.
        assert wandb.run is not None, "Execute this inside a wandb run"
        
        # get the validation data from the reference (fetched once, then reused)
        if self.ds_table is None:
            self.ds_table = self._get_reference_table()
        if self.ds_table is None:
            print("No val table reference found")
            return
            
        # create and populate the Predictions Table
        self._init_preds_table(num_classes=num_classes)
        self.create_preds_table(preds, n=n)
        
        # Log to W&B
        pred_artifact = wandb.Artifact(f"run_{wandb.run.id}_preds", type="evaluation")
        pred_artifact.add(self.preds_table, table_name)
        wandb.log_artifact(pred_artifact, aliases=aliases or ["latest"])
        
        # Log the Table to the workspace under the same name used in the artifact
        if self.log2workspace:
            wandb.log({table_name: self.preds_table})