In [None]:
import wandb

from fastai.vision.all import *

## Inference Benchmark

In [None]:
# --- Weights & Biases run identity ---
ENTITY = "av-demo"
PROJECT = "CamVid"
RUN_NAME = "inference-1"
JOB_TYPE = "inference"

# Model artifact whose traced file is benchmarked below.
MODEL_ARTIFACT_ID = "av-demo/CamVid/baseline-train-1-saved-model:latest"

# --- General settings ---
SEED = 42
IMAGE_SHAPE = (720, 960)  # presumably (height, width) -- TODO confirm

# --- Training-side settings (only recorded in the run config here) ---
BATCH_SIZE = 4
IMAGE_RESIZE_FACTOR = 1
VALIDATION_SPLIT_PCT = 0.2
HIDDEN_DIM = 256
BACKBONE = "mobilenetv2_100"
LEARNING_RATE = 1e-3
TRAIN_EPOCHS = 1

# --- Inference benchmark settings ---
INFERENCE_BATCH_SIZE = 8
NUM_WARMUP_ITERS = 10
NUM_INFERENCE_BENCHMARK_ITERS = 50

In [None]:
# Start the W&B run that tracks this benchmark and record every setting
# from the configuration cell so the result can be traced back to them.
# NOTE(review): the last key was previously misspelled
# "num_inference_banchmark_iters"; dashboards or queries that filter on the
# old spelling need to be updated.
run = wandb.init(
    project=PROJECT,
    name=RUN_NAME,
    entity=ENTITY,
    job_type=JOB_TYPE,
    config={
        "model_artifact_id": MODEL_ARTIFACT_ID,
        "image_shape": IMAGE_SHAPE,
        "seed": SEED,
        "batch_size": BATCH_SIZE,
        "image_resize_factor": IMAGE_RESIZE_FACTOR,
        "validation_split": VALIDATION_SPLIT_PCT,
        "hidden_dims": HIDDEN_DIM,
        "backbone": BACKBONE,
        "learning_rate": LEARNING_RATE,
        "train_epochs": TRAIN_EPOCHS,
        "inference_batch_size": INFERENCE_BATCH_SIZE,
        "num_warmup_iters": NUM_WARMUP_ITERS,
        "num_inference_benchmark_iters": NUM_INFERENCE_BENCHMARK_ITERS,
    },
)

In [None]:
def _get_traced(artifact):
    """Download a W&B model artifact and return the path to its traced model.

    Relies on the notebook-level ``run`` created by ``wandb.init`` to declare
    the artifact as an input of the current run.

    Args:
        artifact: Artifact identifier passed to ``run.use_artifact``.

    Returns:
        Path of the first file (in sorted order) matching ``*_traced.pt`` in
        the downloaded artifact directory.

    Raises:
        FileNotFoundError: If the artifact contains no ``*_traced.pt`` file.
    """
    artifact = run.use_artifact(artifact, type='model')
    artifact_dir = Path(artifact.download())
    # Sort so the choice is deterministic if several traced files are present.
    traced_files = sorted(artifact_dir.glob("*_traced.pt"))
    if not traced_files:
        raise FileNotFoundError(
            f"No '*_traced.pt' file found in artifact directory {artifact_dir}"
        )
    return traced_files[0]

In [None]:
def benchmark_inference_time(
    model_artifact: str,
    image_shape: tuple[int, int],
    batch_size: int,
    num_warmup_iters: int,
    num_iter: int,
    seed: int,
    resize_factor: int = 2,
):
    """Measure the mean per-image GPU inference time of a traced model.

    The traced model is fetched via ``_get_traced``, loaded with
    ``torch.jit.load`` and run on a random dummy batch on the CUDA device.

    Args:
        model_artifact: W&B artifact identifier of the model to benchmark.
        image_shape: Two-element image size; each element is integer-divided
            by ``resize_factor`` to get the spatial size of the dummy input.
        batch_size: Number of images in the dummy batch (must be positive).
        num_warmup_iters: Untimed forward passes run before measuring.
        num_iter: Number of timed forward passes (must be positive).
        seed: Seed for the random generator that creates the dummy input.
        resize_factor: Divisor applied to ``image_shape``; defaults to 2.

    Returns:
        Total measured time divided by ``num_iter * batch_size``, i.e. the
        mean time per image in the unit reported by
        ``torch.cuda.Event.elapsed_time`` (milliseconds).

    Raises:
        ValueError: If ``batch_size``, ``num_iter`` or ``resize_factor`` is
            not positive.
    """
    if batch_size <= 0 or num_iter <= 0:
        raise ValueError("batch_size and num_iter must be positive")
    if resize_factor <= 0:
        raise ValueError("resize_factor must be positive")

    model_file = _get_traced(model_artifact)
    model = torch.jit.load(model_file).cuda().eval()

    # Use both entries of image_shape; a dedicated generator keeps the dummy
    # input reproducible without touching the global RNG state.
    rng = torch.Generator().manual_seed(seed)
    dummy_input = torch.randn(
        batch_size,
        3,
        image_shape[0] // resize_factor,
        image_shape[1] // resize_factor,
        generator=rng,
        dtype=torch.float,
    ).to("cuda")

    starter, ender = (
        torch.cuda.Event(enable_timing=True),
        torch.cuda.Event(enable_timing=True),
    )
    timings = np.zeros(num_iter)

    # Warm-up and measurement both run without autograd bookkeeping.
    with torch.inference_mode():
        print("Warming up GPU...")
        for _ in progress_bar(range(num_warmup_iters)):
            _ = model(dummy_input)
        # Make sure all warm-up work has finished before timing starts.
        torch.cuda.synchronize()

        print(
            f"Computing inference time over {num_iter} iterations with batches of {batch_size} images..."
        )
        for step in progress_bar(range(num_iter)):
            starter.record()
            _ = model(dummy_input)
            ender.record()
            # Wait for the GPU so the elapsed time between events is valid.
            torch.cuda.synchronize()
            timings[step] = starter.elapsed_time(ender)

    return np.sum(timings) / (num_iter * batch_size)

In [None]:
# benchmark_inference_time loads the traced model from the artifact itself,
# so this cell does not need (and must not rely on) a notebook-level `model`.
torch.cuda.empty_cache()
inference_time = benchmark_inference_time(
    model_artifact=MODEL_ARTIFACT_ID,
    image_shape=IMAGE_SHAPE,
    batch_size=INFERENCE_BATCH_SIZE,
    num_warmup_iters=NUM_WARMUP_ITERS,
    num_iter=NUM_INFERENCE_BENCHMARK_ITERS,
    seed=SEED,
)

In [None]:
# Display the value returned by benchmark_inference_time
# (total measured time divided by num_iter * batch_size).
inference_time

In [None]:
# Record the benchmark result on the active W&B run, then close the run.
wandb.log(dict(inference_time=inference_time))
wandb.finish()