In [None]:
# Remove the local ./wandb directory (presumably stale run files from earlier
# executions of this notebook) before starting new runs.
!rm -rf wandb

import wandb
import os
from concurrent.futures import ProcessPoolExecutor
import random
import math

PROJECT = 'triton-ensemble-demo'
ENTITY = 'megatruong'

# Plain-Python equivalent of `%env`: wandb.init() picks these up from the
# environment, and this cell stays valid when exported to a .py script.
os.environ['WANDB_PROJECT'] = PROJECT
os.environ['WANDB_ENTITY'] = ENTITY

# Demo hyperparameters. The random values are drawn once, when this cell runs,
# so every run that passes `config` to wandb.init() logs the same values.
config = {
    "learning_rate": 0.01 * random.random(),
    "batch_size": 128,
    "momentum": 0.1 * random.random(),
    "dropout": 0.4 * random.random(),
    "dataset": ["hello", "world", 2],
    "model": "trained-model",
}


def add_refs(target, *sources):
    """Add every file of each source artifact to ``target`` as a reference.

    Each reference is stored under ``<source base name>/<entry name>``, so the
    files of different source artifacts land in separate folders of ``target``.

    Args:
        target: artifact receiving the references (via ``add_reference``).
        *sources: artifacts whose ``manifest.entries`` are referenced.
    """
    for source in sources:
        # "name:v3" -> "name". Unlike the old ``split(':v')`` unpacking, this
        # also accepts a bare "name" or a non-version alias like "name:latest".
        # Computed once per source rather than once per entry.
        art_name = source.name.partition(':')[0]
        for name in source.manifest.entries:
            target.add_reference(source.get_path(name), f"{art_name}/{name}")
            
            
def _synthetic_metrics(step, learning_rate, momentum, displacement1, displacement2):
    """Return one step of fake acc/val_acc/loss/val_loss values for the demo charts.

    Accuracies grow and losses shrink with ``log(1 + step)``, plus uniform noise
    from ``random.random()``; the displacements shift whole curves so that
    separate runs look different.
    """
    # Dict literals evaluate left to right, so random draws happen in the same
    # order as the original inline expression.
    return {
        "acc": .1 + 0.4 * (math.log(1 + step + random.random()) + random.random() * learning_rate
                           + random.random() + displacement1 + random.random() * momentum),
        "val_acc": .1 + 0.5 * (math.log(1 + step + random.random()) + random.random() * learning_rate
                               - random.random() + displacement1),
        "loss": .1 + 0.08 * (3.5 - math.log(1 + step + random.random()) + random.random() * momentum
                             + random.random() + displacement2),
        "val_loss": .1 + 0.04 * (4.5 - math.log(1 + step + random.random()) + random.random() * learning_rate
                                 - random.random() + displacement2),
    }


def log_component_model(component, eval_steps=1000, registry_path="model-registry/Text Detection"):
    """Log one ensemble component as a W&B run with synthetic metrics and a model artifact.

    Starts a run with job type ``generate_<component>`` and logs ``eval_steps``
    steps of fake metrics. It then uploads the directory
    ``ensemble_model/<component>`` as a ``model`` artifact named after the
    component and links that artifact to ``registry_path``.

    Args:
        component: component name; also the artifact name and the
            subdirectory of ``ensemble_model/`` that gets uploaded.
        eval_steps: number of metric steps to log (default 1000).
        registry_path: link target passed to ``run.link_artifact``.
    """
    with wandb.init(config=config, job_type=f"generate_{component}") as run:
        # Constant for the whole run; read once instead of on every step.
        learning_rate = run.config.learning_rate
        momentum = run.config.momentum
        displacement1 = random.random() * 2
        displacement2 = random.random() * 4
        # Only metrics are logged per step; no checkpoints are written.
        for step in range(eval_steps):
            run.log(_synthetic_metrics(step, learning_rate, momentum, displacement1, displacement2))
        art = wandb.Artifact(component, type='model')
        art.add_dir(f'ensemble_model/{component}')
        run.log_artifact(art)
        run.link_artifact(art, registry_path)
        

def log_packaged_model(components):
    """Log the full ``ensemble_model`` directory as one packaged model artifact.

    The local ``ensemble_model`` directory is uploaded as-is. The ``:latest``
    artifact of each named component is then attached by reference via
    ``add_refs``, and the run's code is logged as well.

    Args:
        components: names of previously logged component artifacts.
    """
    with wandb.init(job_type="package_components") as run:
        package = wandb.Artifact('ensemble_model', type='model')
        package.add_dir('ensemble_model')
        # NOTE(review): add_dir above likely also picks up the per-component
        # subdirectories, which add_refs then adds again as references under
        # "<component>/..." paths -- confirm the overlap is intended.
        used = []
        for component in components:
            used.append(run.use_artifact(f"{component}:latest"))
        add_refs(package, *used)
        run.log_artifact(package)
        run.log_code()
        

def log_component_models(components):
    """Log each ensemble component as its own W&B run, one after another.

    Args:
        components: iterable of component names; each one is passed to
            ``log_component_model`` in order.
    """
    # Runs are logged sequentially. If this is ever parallelised with
    # ProcessPoolExecutor.map, consume the returned iterator (e.g. wrap it in
    # list()): worker exceptions are only raised when results are retrieved.
    for component in components:
        log_component_model(component)
        
        
# Triton ensemble steps (presumably in pipeline order). Each name is expected to
# have a matching subdirectory under ensemble_model/.
ensemble_components = [
    'detection_preprocessing',
    'text_detection',
    'detection_postprocessing',
    'text_recognition',
    'recognition_postprocessing',
]

log_component_models(ensemble_components)
# Later pipeline stages, currently disabled:
# log_packaged_model(ensemble_components)
# deploy_to_triton('ensemble_model')  # NOTE(review): deploy_to_triton is not defined in the visible notebook

In [None]:
# !export WANDB_PROJECT={PROJECT} && rm -rf wandb && python log_component_model.py && rm -rf wandb && python log_packaged_model.py

In [None]:
# Check which W&B API key is active without writing the secret into the
# notebook's saved output: display only its last 4 characters.
_api_key = wandb.api.api_key
f"W&B API key configured (...{_api_key[-4:]})" if _api_key else "No W&B API key configured"