# Weights and Biases Callback
> Defines a fastai Callback specifically for tracking image-to-image translation experiments in Weights and Biases.

In [1]:
#default_exp tracking.wandb

In [2]:
#export
import wandb
from fastai.vision.all import *
from fastai.callback.wandb import *
from fastai.callback.wandb import _format_metadata, _format_config
from fastai.basics import *
from fastai.vision.gan import *
from upit.models.cyclegan import *
from upit.data.unpaired import *
from upit.train.cyclegan import *
from upit.metrics import *

In [3]:
#export
class SaveModelAtEndCallback(Callback):
    "Save the learner's model once, when fitting ends, and remember where it was written"
    def __init__(self, fname='model', with_opt=False):
        # Defined up-front so that readers of `last_saved_path` get `None` (nothing saved yet)
        # rather than a missing attribute before the first fit has completed
        self.last_saved_path = None
        store_attr()

    def _save(self, name):
        "Save the model under `name` and keep the value returned by `Learner.save`"
        self.last_saved_path = self.learn.save(name, with_opt=self.with_opt)

    def after_fit(self, **kwargs):
        "Save the model under `fname` at the end of training"
        self._save(f'{self.fname}')

    @property
    def name(self):
        "Name this callback reports for itself (`save_model`)"
        return "save_model"

In [4]:
#export
def log_dataset(main_path, folder_names=None, name=None, metadata=None, description='raw dataset'):
    """Log the dataset folder `main_path` as a W&B artifact of type 'dataset'.

    `folder_names` restricts which sub-folders are added (all of them when empty or `None`);
    a sub-folder called 'models' is never added. Files sitting directly in `main_path` are
    always added. `name` defaults to the folder name. `metadata` and `description` are passed
    to `wandb.Artifact`.

    Raises `ValueError` if `wandb.init()` has not been called or `main_path` is not a directory.
    """
    # Check if wandb.init has been called in case datasets are logged manually
    if wandb.run is None:
        raise ValueError('You must call wandb.init() before log_dataset()')
    path = Path(main_path)
    if not path.is_dir():
        # Raising a bare string is itself a TypeError, so raise a real exception instead
        raise ValueError(f'path must be a valid directory: {path}')
    name = ifnone(name, path.name)
    # Work on a copy: the return value of `_format_metadata` is ignored, so it presumably
    # formats the dict in place; neither the caller's dict nor a shared default should be touched
    metadata = dict(metadata) if metadata else {}
    _format_metadata(metadata)
    artifact_dataset = wandb.Artifact(name=name, type='dataset', metadata=metadata, description=description)
    # List the directory once and reuse it for both the default selection and the upload loop
    entries = list(path.iterdir())
    # log everything in folder_names
    if not folder_names: folder_names = [p.name for p in entries if p.is_dir()]
    for p in entries:
        if p.is_dir():
            if p.name in folder_names and p.name != 'models': artifact_dataset.add_dir(str(p.resolve()), name=p.name)
        else: artifact_dataset.add_file(str(p.resolve()))
    wandb.run.use_artifact(artifact_dataset)

In [5]:
#export
class UPITWandbCallback(Callback):
    """Saves model topology, losses & metrics

    `log` is forwarded to `wandb.watch`. `log_preds` enables fetching prediction samples every
    epoch (from `valid_dl` if given, otherwise from `n_preds` items sampled with `seed`).
    `log_model` uploads the model saved at the end of training. `log_dataset` is `True` (use
    `dls.path`), a path, or `False`; `folder_names` and `dataset_name` are forwarded to the
    module-level `log_dataset`. `reorder` is forwarded to `FetchPredsCallback`.
    """
    remove_on_fetch,order = True,Recorder.order+1
    # Record if watch has been called previously (even in another instance)
    # NOTE: `before_fit` reads and writes the flag on fastai's `WandbCallback`, not this one
    _wandb_watch_called = False

    def __init__(self, log="gradients", log_preds=True, log_model=True, log_dataset=False, folder_names=None, dataset_name=None, valid_dl=None, n_preds=36, seed=12345, reorder=True):
        # Check if wandb.init has been called
        if wandb.run is None:
            raise ValueError('You must call wandb.init() before WandbCallback()')
        # W&B log step
        self._wandb_step = wandb.run.step - 1  # -1 except if the run has previously logged data (incremented at each batch)
        self._wandb_epoch = 0 if not(wandb.run.step) else math.ceil(wandb.run.summary['epoch']) # continue to next epoch
        store_attr()

    def before_fit(self):
        "Call watch method to log model topology, gradients & weights"
        self.run = not hasattr(self.learn, 'lr_finder') and not hasattr(self, "gather_preds") and rank_distrib()==0
        if not self.run: return

        # Log config parameters
        log_config = self.learn.gather_args()
        _format_config(log_config)
        try:
            wandb.config.update(log_config, allow_val_change=True)
        except Exception as e:
            print(f'WandbCallback could not log config parameters -> {e}')

        if not WandbCallback._wandb_watch_called:
            WandbCallback._wandb_watch_called = True
            # Logs model topology and optionally gradients and weights
            wandb.watch(self.learn.model, log=self.log)

        # log dataset
        assert isinstance(self.log_dataset, (str, Path, bool)), 'log_dataset must be a path or a boolean'
        if self.log_dataset is True:
            if Path(self.dls.path) == Path('.'):
                print('WandbCallback could not retrieve the dataset path, please provide it explicitly to "log_dataset"')
                self.log_dataset = False
            else:
                self.log_dataset = self.dls.path

        if self.log_dataset:
            self.log_dataset = Path(self.log_dataset)
            assert self.log_dataset.is_dir(), f'log_dataset must be a valid directory: {self.log_dataset}'
            metadata = {'path relative to learner': os.path.relpath(self.log_dataset, self.learn.path)}
            if self.folder_names:
                assert isinstance(self.folder_names, list), 'folder_names must be a list of folder names as strings'
                for name in self.folder_names: assert isinstance(name, str), 'the elements of folder_names must be strings'
            log_dataset(main_path=self.log_dataset, folder_names=self.folder_names, name=self.dataset_name, metadata=metadata)

        # log model
        if self.log_model and not hasattr(self, 'save_model'):
            print('Adding SaveModelAtEndCallback()')
            self.learn.add_cb(SaveModelAtEndCallback())
            # Remember that the callback is ours so that `after_fit` only removes what we added
            self.add_save_model = True
        else: self.add_save_model = False

        if self.log_preds:
            try:
                if not self.valid_dl:
                    # Sample from the validation set when it has items, otherwise fall back to the
                    # training set, so that `preds_dl` is defined in both cases
                    if len(self.dls.valid_ds): sample_ds = self.dls.valid_ds
                    else:
                        print('Saving training set predictions')
                        sample_ds = self.dls.train_ds
                    #Initializes the batch watched
                    wandbRandom = random.Random(self.seed)  # For repeatability
                    self.n_preds = min(self.n_preds, len(sample_ds))
                    idxs = wandbRandom.sample(range(len(sample_ds)), self.n_preds)
                    test_items = [getattr(sample_ds.items, 'iloc', sample_ds.items)[i] for i in idxs]
                    self.preds_dl = self.dls.test_dl(test_items, with_labels=True)
                else: self.preds_dl = self.valid_dl
                self.learn.add_cb(FetchPredsCallback(dl=self.preds_dl, with_input=True, with_decoded=True, reorder=self.reorder))
            except Exception as e:
                # Best effort: training continues without prediction samples
                self.log_preds = False
                print(f'WandbCallback was not able to prepare a DataLoader for logging prediction samples -> {e}')

    def after_batch(self):
        "Log hyper-parameters and training loss"
        if self.training:
            self._wandb_step += 1
            self._wandb_epoch += 1/self.n_iter
            # One entry per hyper-parameter and parameter group, e.g. `lr_0`
            hypers = {f'{k}_{i}':v for i,h in enumerate(self.opt.hypers) for k,v in h.items()}

            wandb.log({'epoch': self._wandb_epoch, 'train_loss': float(to_detach(self.smooth_loss.clone())),
                       'raw_loss': float(to_detach(self.loss.clone())), **hypers}, step=self._wandb_step)

    def log_predictions(self, preds):
        "Hook for logging prediction samples; not implemented yet, so `after_epoch` disables `log_preds` on the first call"
        raise NotImplementedError("To be implemented")

    def after_epoch(self):
        "Log validation loss and custom metrics & log prediction samples"
        # Correct any epoch rounding error and overwrite value
        self._wandb_epoch = round(self._wandb_epoch)
        wandb.log({'epoch': self._wandb_epoch}, step=self._wandb_step)
        # Log sample predictions
        if self.log_preds:
            try:
                self.log_predictions(self.learn.fetch_preds.preds)
            except Exception as e:
                self.log_preds = False
                print(f'WandbCallback was not able to get prediction samples -> {e}')
        wandb.log({n:s for n,s in zip(self.recorder.metric_names, self.recorder.log) if n not in ['train_loss', 'epoch', 'time']}, step=self._wandb_step)

    def after_fit(self):
        "Upload the saved model (if any), remove the helper callbacks and sync the last step"
        if self.log_model:
            # Tolerate a `save_model` callback that never recorded a path
            last_saved_path = getattr(self.save_model, 'last_saved_path', None)
            if last_saved_path is None:
                print('WandbCallback could not retrieve a model to upload')
            else:
                metadata = {n:s for n,s in zip(self.recorder.metric_names, self.recorder.log) if n not in ['train_loss', 'epoch', 'time']}
                log_model(last_saved_path, metadata=metadata)
        self.run = True
        self.learn.remove_cb(FetchPredsCallback)
        # Only remove the save callback this callback added itself; one registered by the user stays
        if getattr(self, 'add_save_model', False): self.learn.remove_cb(SaveModelAtEndCallback)
        wandb.log({})  # ensure sync of last step
        self._wandb_step += 1

In [6]:
#cuda
import tempfile

horse2zebra = untar_data('https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip')
folders = horse2zebra.ls().sorted()
# Address the sub-folders by name rather than by position in the sorted listing: any extra
# entry in the dataset folder (e.g. a `models` directory) would shift the indices
trainA_path = horse2zebra/'trainA'
trainB_path = horse2zebra/'trainB'
testA_path = horse2zebra/'testA'
testB_path = horse2zebra/'testB'
dls = get_dls(trainA_path, trainB_path, num_A=100, num_B=100, load_size=286)

#os.environ['WANDB_MODE'] = 'dryrun' # run offline
wandb.init()
try:
    cycle_gan = CycleGAN(3,3,64)
    learn = cycle_learner(dls, cycle_gan,opt_func=partial(Adam,mom=0.5,sqr_mom=0.999),
                        metrics=[FrechetInceptionDistance()])

    # Track the run with the callback under test, logging only the two training folders as the dataset
    learn.fit_flat_lin(1,1,2e-4,cbs=[UPITWandbCallback(log_preds=True, log_model=True, log_dataset=horse2zebra, folder_names=[trainA_path.name,trainB_path.name])])
finally:
    # Close the W&B run even if model creation or training fails
    wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mtmabraham[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Adding directory to artifact (/home/tmabraham/.fastai/data/horse2zebra/trainA)... 

Could not gather input dimensions


Done. 0.3s
[34m[1mwandb[0m: Adding directory to artifact (/home/tmabraham/.fastai/data/horse2zebra/trainB)... Done. 0.2s


Adding SaveModelAtEndCallback()
Saving training set predictions


epoch,train_loss,id_loss_A,id_loss_B,gen_loss_A,gen_loss_B,cyc_loss_A,cyc_loss_B,D_A_loss,D_B_loss,frechet_inception_distance,time
0,9.867505,1.454771,1.509735,0.430915,0.450728,3.095023,3.188569,0.400037,0.400273,91.447801,00:18
1,8.506508,1.120648,1.158904,0.289497,0.295813,2.388297,2.494975,0.258037,0.257544,94.59166,00:17


  warn("Your generator is empty.")


WandbCallback was not able to get prediction samples -> To be implemented


VBox(children=(Label(value='107.959 MB of 107.959 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0,…

0,1
D_A_loss,█▁
D_B_loss,█▁
cyc_loss_A,█▁
cyc_loss_B,█▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
frechet_inception_distance,▁█
gen_loss_A,█▁
gen_loss_B,█▁
id_loss_A,█▁

0,1
D_A_loss,0.25804
D_B_loss,0.25754
cyc_loss_A,2.3883
cyc_loss_B,2.49498
epoch,2.0
eps_0,1e-05
frechet_inception_distance,94.59166
gen_loss_A,0.2895
gen_loss_B,0.29581
id_loss_A,1.12065


In [7]:
# Grab a single batch from `dls` and run `learn.get_preds` on it with decoding enabled.
# Only the third returned value is kept (presumably the decoded predictions - confirm
# against fastai's `Learner.get_preds`).
b = dls.one_batch()
_,_,preds = learn.get_preds(dl=[b], with_decoded=True)

In [8]:
#hide
# Run nbdev's `notebook2script`; the recorded output below lists each notebook it converted.
from nbdev.export import notebook2script
notebook2script()

Converted 01_models.cyclegan.ipynb.
Converted 01b_models.junyanz.ipynb.
Converted 02_data.unpaired.ipynb.
Converted 03_train.cyclegan.ipynb.
Converted 04_inference.cyclegan.ipynb.
Converted 05_metrics.ipynb.
Converted 06_tracking.wandb.ipynb.
Converted 07_models.dualgan.ipynb.
Converted 08_train.dualgan.ipynb.
Converted 09_models.ganilla.ipynb.
Converted index.ipynb.
