In [1]:
%pip install darts

Collecting darts
  Downloading darts-0.30.0-py3-none-any.whl (917 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m917.3/917.3 kB[0m [31m13.9 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting xgboost>=1.6.0
  Downloading xgboost-2.1.1-py3-none-manylinux_2_28_x86_64.whl (153.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m153.9/153.9 MB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting holidays>=0.11.1
  Downloading holidays-0.55-py3-none-any.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m55.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pytorch-lightning>=1.5.0
  Downloading pytorch_lightning-2.4.0-py3-none-any.whl (815 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m815.2/815.2 kB[0m [31m51.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting xarray>=0.17.0
  Downloading xarray-2024.7.0-py3-none-any.whl (1.2 MB)
[2K     [90m━━━━━━━━━━

In [2]:
import digitalhub as dh
import pandas as pd
import os

  from pandas.core.computation.check import NUMEXPR_INSTALLED
  from pandas.core import (


In [3]:
# DigitalHub project that groups the functions, runs and model artifacts of this demo.
PROJECT = "demo-ml"
# Reuse the project if it already exists, otherwise create it (per get_or_create_project).
project = dh.get_or_create_project(PROJECT)

In [36]:
%%writefile "train-model.py"


from digitalhub_runtime_python import handler

import pandas as pd
import numpy as np

from darts import TimeSeries
from darts.datasets import AirPassengersDataset
from darts.models import NBEATSModel
from darts.metrics import mape, smape, mae

from zipfile import ZipFile

@handler()
def train_model(project):
    """Train an N-BEATS forecaster on AirPassengers and log it to the project.

    The last ``val_size`` points of the series are held out. The model is fit on
    the remainder and forecasts exactly the held-out horizon. That forecast is
    scored against the held-out points. The saved model (the ``.pt`` file plus
    its ``.pt.ckpt`` checkpoint) is zipped and logged as the ``darts_model``
    artifact together with the metrics.

    Args:
        project: DigitalHub project passed in by the runtime; used to log the model.
    """
    val_size = 36
    input_chunk_length = 24
    output_chunk_length = 12
    n_epochs = 200
    model_file = "predictor_model.pt"
    archive_file = model_file + ".zip"

    series = AirPassengersDataset().load()
    train, val = series[:-val_size], series[-val_size:]

    model = NBEATSModel(
        input_chunk_length=input_chunk_length,
        output_chunk_length=output_chunk_length,
        n_epochs=n_epochs,
        random_state=0,  # fixed seed so repeated runs produce the same model
    )
    model.fit(train)
    # Forecast exactly the held-out horizon so every predicted point has a ground truth.
    pred = model.predict(n=len(val))

    # model.save() writes the model file and its "<file>.ckpt" checkpoint; both are
    # shipped together because loading the model later needs the checkpoint alongside.
    model.save(model_file)
    with ZipFile(archive_file, "w") as z:
        z.write(model_file)
        z.write(model_file + ".ckpt")

    # Score on the held-out part explicitly (darts would intersect a longer actual
    # series with the forecast anyway). Cast to plain floats so the metric values
    # are ordinary Python numbers when logged.
    metrics = {
        "mape": float(mape(val, pred)),
        "smape": float(smape(val, pred)),
        "mae": float(mae(val, pred)),
    }

    project.log_model(
        name="darts_model",
        kind="model",
        source=archive_file,
        algorithm="darts.models.NBEATSModel",
        framework="darts",
        metrics=metrics,
    )

Overwriting train-model.py


In [37]:
# Register the training script as a Python function; the remote environment
# installs darts at the same version the serving image pins.
train_fn = project.new_function(
    name="train-darts",
    kind="python",
    python_version="PYTHON3_9",
    source={
        "source": "train-model.py",  # file written by the %%writefile cell above
        "handler": "train_model",
    },
    requirements=["darts==0.30.0"],
)

In [20]:
# Build the remote runtime for the training function (installs its requirements).
# NOTE(review): in the saved notebook this cell ran (In [20]) before train_fn was last
# redefined (In [37]); re-run it after changing the function definition — TODO confirm
# whether DigitalHub builds are tied to a specific function version.
build_run = train_fn.run(action="build", local_execution=False)

In [38]:
# Execute train_model remotely as a job; it logs the "darts_model" artifact to the project.
train_run = train_fn.run(action="job", local_execution=False)

2024-08-26 13:34:37,303 - INFO - Validating task.
2024-08-26 13:34:37,304 - INFO - Validating run.
2024-08-26 13:34:37,304 - INFO - Starting task.
2024-08-26 13:34:37,305 - INFO - Configuring execution.
2024-08-26 13:34:37,306 - INFO - Composing function arguments.
2024-08-26 13:34:37,389 - INFO - Executing run.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name            | Type             | Params | Mode 
-------------------------------------------------------------
0 | criterion       | MSELoss          | 0      | train
1 | train_criterion | MSELoss          | 0      | train
2 | val_criterion   | MSELoss          | 0      | train
3 | train_metrics   | MetricCollection | 0      | train
4 | val_metrics     | MetricCollection | 0      | train
5 | stacks          | ModuleList       | 6.2 M  | train
-------------------------------------------------------------
6.2 M     Trainable params
1.4 K     Non-trainable params


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Predicting: |          | 0/? [00:00<?, ?it/s]

2024-08-26 13:35:26,481 - INFO - Task completed, returning run status.


In [39]:
# Fetch the model logged by the training job and show where its zip artifact is stored.
model = project.get_model("darts_model")
model.spec.path

's3://datalake/demo-ml/model/darts_model/b1b98698-9a39-4b03-b8fc-74ef041b7189/predictor_model.pt.zip'

In [94]:
%%writefile "serve_darts_model.py"

from darts.models import NBEATSModel
from zipfile import ZipFile
from darts import TimeSeries
import json
import pandas as pd

def init(context):
    """Load the trained darts model and attach it to ``context.model``.

    Downloads the ``darts_model`` artifact logged by the training job, unpacks
    the zip (``predictor_model.pt`` plus its ``predictor_model.pt.ckpt``
    checkpoint) into ``extracted_model/`` and restores the model from it.

    Args:
        context: serving context; must expose ``project`` (used to fetch the
            model artifact). ``context.model`` is set as a side effect.
    """
    model_name = "darts_model"
    local_path_model = "extracted_model/"

    model = context.project.get_model(model_name)
    # download() is indexed like a list of local paths; the artifact is a single zip
    # (assumed to be the first entry — TODO confirm for multi-file artifacts).
    path = model.download()[0]

    # The .pt file and its .ckpt must end up side by side for NBEATSModel.load().
    with ZipFile(path, "r") as zip_ref:
        zip_ref.extractall(local_path_model)

    name_model_local = local_path_model + "predictor_model.pt"
    # load() is a static method that restores the hyper-parameters saved with the
    # model, so there is no need to build a throw-away NBEATSModel(24, 12) first.
    # map_location="cpu" keeps loading working in the CPU-only serving image even
    # if the checkpoint was produced on a GPU.
    mm = NBEATSModel.load(name_model_local, map_location="cpu")

    setattr(context, "model", mm)

def serve(context, event):
    """Forecast the 24 points (2 x output_chunk_length) following the posted series.

    Expected request body (JSON)::

        {"inference_input": [{"date": <epoch milliseconds>, "value": <number>}, ...]}

    The series must be long enough for the model's input window (24 points for the
    model trained in this demo). Returns the forecast as a list of
    ``{"date": <epoch ms>, "value": <float>}`` records.
    """
    body = event.body
    # The body may arrive raw (bytes or str) or already decoded into a dict;
    # json.loads accepts both bytes and str.
    if isinstance(body, (bytes, str)):
        body = json.loads(body)
    context.logger.info(f"Received event: {body}")
    inference_input = body["inference_input"]

    pdf = pd.DataFrame(inference_input)
    # Dates travel as epoch milliseconds; unit="ms" decodes them as UTC timestamps.
    pdf["date"] = pd.to_datetime(pdf["date"], unit="ms")

    ts = TimeSeries.from_dataframe(pdf, time_col="date", value_cols="value")

    output_chunk_length = 12
    # A horizon longer than output_chunk_length is produced auto-regressively by darts.
    result = context.model.predict(n=output_chunk_length * 2, series=ts)
    # Round-trip through pandas JSON so timestamps are encoded as epoch ms
    # (to_json's default for orient="records"), matching the input encoding.
    jsonstr = result.pd_dataframe().reset_index().to_json(orient="records")
    return json.loads(jsonstr)

Overwriting serve_darts_model.py


In [95]:
# Register the serving script: `init` prepares context.model, `serve` handles each request.
func = project.new_function(
    name="serve_darts_model",
    kind="python",
    python_version="PYTHON3_9",
    base_image="python:3.9",
    source={
        "source": "serve_darts_model.py",
        "handler": "serve",
        "init_function": "init",
    },
)

In [96]:
# Build the serving image. The torch install uses the CPU-only wheel index
# (presumably to avoid the much larger CUDA wheels); darts is pinned to the
# version used for training.
run_build_model_serve = func.run(
    action="build",
    instructions=[
        "pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu",
        "pip3 install darts==0.30.0",
    ],
)

In [97]:
# Deploy the serving function; its URL is read later from run_serve's status.
run_serve = func.run(action="serve")

In [98]:
import json
from datetime import datetime, timezone

# Needed in the kernel itself: the earlier import lives only inside train-model.py.
from darts.datasets import AirPassengersDataset

# Use the last 24 months of the series as the model's input window.
series = AirPassengersDataset().load()
val = series[-24:]
json_value = json.loads(val.to_json())


def to_epoch_ms(iso_date):
    """Convert a naive ISO timestamp from TimeSeries.to_json() to epoch milliseconds.

    The timestamp is pinned to UTC because the service decodes it with
    ``pd.to_datetime(unit="ms")`` (UTC). Calling ``datetime.timestamp()`` on a naive
    datetime would instead use the kernel's local timezone and shift every date.
    """
    parsed = datetime.strptime(iso_date, "%Y-%m-%dT%H:%M:%S.%f")
    return parsed.replace(tzinfo=timezone.utc).timestamp() * 1000


inference_input = [
    {"value": row[0], "date": to_epoch_ms(date)}
    for row, date in zip(json_value["data"], json_value["index"])
]
inference_input

[{'value': 360.0, 'date': -347155200000.0},
 {'value': 342.0, 'date': -344476800000.0},
 {'value': 406.0, 'date': -342057600000.0},
 {'value': 396.0, 'date': -339379200000.0},
 {'value': 420.0, 'date': -336787200000.0},
 {'value': 472.0, 'date': -334108800000.0},
 {'value': 548.0, 'date': -331516800000.0},
 {'value': 559.0, 'date': -328838400000.0},
 {'value': 463.0, 'date': -326160000000.0},
 {'value': 407.0, 'date': -323568000000.0},
 {'value': 362.0, 'date': -320889600000.0},
 {'value': 405.0, 'date': -318297600000.0},
 {'value': 417.0, 'date': -315619200000.0},
 {'value': 391.0, 'date': -312940800000.0},
 {'value': 419.0, 'date': -310435200000.0},
 {'value': 461.0, 'date': -307756800000.0},
 {'value': 472.0, 'date': -305164800000.0},
 {'value': 535.0, 'date': -302486400000.0},
 {'value': 622.0, 'date': -299894400000.0},
 {'value': 606.0, 'date': -297216000000.0},
 {'value': 508.0, 'date': -294537600000.0},
 {'value': 461.0, 'date': -291945600000.0},
 {'value': 390.0, 'date': -28926

In [99]:
import requests

SERVICE_URL = run_serve.refresh().status.to_dict()["service"]["url"]
REQUEST_TIMEOUT_S = 60  # seconds; avoids hanging forever if the service is unreachable

with requests.post(
    f"http://{SERVICE_URL}",
    json={"inference_input": inference_input},
    timeout=REQUEST_TIMEOUT_S,
) as response:
    # Surface HTTP errors instead of failing later while decoding an error body.
    response.raise_for_status()
    res = response.json()
res


[{'date': -283996800000, 'value': 448.4713515232}, {'date': -281318400000, 'value': 416.9437679985}, {'date': -278899200000, 'value': 488.1964291872}, {'date': -276220800000, 'value': 498.348927999}, {'date': -273628800000, 'value': 509.5799418856}, {'date': -270950400000, 'value': 614.4077203825}, {'date': -268358400000, 'value': 703.8526003896}, {'date': -265680000000, 'value': 691.8522541339}, {'date': -263001600000, 'value': 606.4765462227}, {'date': -260409600000, 'value': 532.2036568194}, {'date': -257731200000, 'value': 454.02867762}, {'date': -255139200000, 'value': 497.8013444336}, {'date': -252460800000, 'value': 524.1175493405}, {'date': -249782400000, 'value': 490.7081232278}, {'date': -247363200000, 'value': 572.8023540875}, {'date': -244684800000, 'value': 584.0529239754}, {'date': -242092800000, 'value': 595.3217350383}, {'date': -239414400000, 'value': 707.6626039279}, {'date': -236822400000, 'value': 812.7016954389}, {'date': -234144000000, 'value': 805.2495930826}, {'