# In [1]:
from pathlib import Path
import wandb
import pickle
import torch

def log_file_artifact(wandb_run, path, name, type):
    """Wrap one local file in a W&B artifact and log it to ``wandb_run``.

    Args:
        wandb_run: run object whose ``log_artifact`` receives the artifact.
        path: local file to attach to the artifact.
        name: artifact name.
        type: artifact type string (e.g. ``'model'``). The name shadows the
            builtin but is kept because callers pass it by keyword.

    Returns:
        Whatever ``wandb_run.log_artifact`` returns.
    """
    artifact_type = type
    single_file_artifact = wandb.Artifact(name, type=artifact_type)
    single_file_artifact.add_file(path)
    logged = wandb_run.log_artifact(single_file_artifact)
    return logged

# Shared output folder for every export helper in this file. It is created up
# front (missing parents included, no error if it already exists) so the save
# calls below can write straight into it.
exports_path = Path('./exports')
exports_path.mkdir(exist_ok=True, parents=True)

def log_model_state(wandb_run, model, name='model-state'):
    """Save ``model.state_dict()`` under ``exports_path`` and log it to W&B.

    Args:
        wandb_run: run passed through to ``log_file_artifact``.
        model: object exposing ``state_dict()`` (presumably a
            ``torch.nn.Module`` -- not checked here).
        name: artifact name and file stem. The default ``'model-state'``
            reproduces the previous fixed ``model-state.pt`` output. Pass a
            distinct name to log several models in one run without
            overwriting the same file.

    Returns:
        Path of the saved ``.pt`` file.
    """
    path = exports_path / f'{name}.pt'
    torch.save(model.state_dict(), path)
    log_file_artifact(wandb_run, path, name, type='model')
    return path

def log_learner(wandb_run, learn, name='learn'):
    """Export a learner to ``exports_path/<name>.pkl`` and log it to W&B.

    Args:
        wandb_run: run passed through to ``log_file_artifact``.
        learn: object with an ``export(path)`` method (presumably a fastai
            ``Learner`` -- verify against callers).
        name: artifact name and file stem. The default ``'learn'`` reproduces
            the previous fixed ``learn.pkl`` output.

    Returns:
        Path of the exported ``.pkl`` file (relative to the working directory,
        as before).
    """
    path = exports_path / f'{name}.pkl'
    # fastai's Learner.export writes to ``learn.path / fname``. A relative
    # path would therefore land under the learner's data directory rather
    # than ./exports, and the artifact below would point at a missing or
    # stale file. Joining an absolute path onto learn.path yields that
    # absolute path unchanged.
    learn.export(path.resolve())
    log_file_artifact(wandb_run, path, name, type='model')
    return path

def log_preprocessor(wandb_run, pp, name):
    """Pickle ``pp`` to ``exports_path/<name>.pkl`` and log it as an artifact.

    Args:
        wandb_run: run passed through to ``log_file_artifact``.
        pp: any picklable preprocessing object.
        name: artifact name, also used as the file stem.

    Returns:
        Path of the pickle file that was written.
    """
    target = exports_path / f'{name}.pkl'
    with target.open('wb') as stream:
        pickle.dump(pp, stream)
    log_file_artifact(wandb_run, target, name, type='preprocessor')
    return target
