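# Get predictions on the test set

Loads the trained scikit-learn model checkpoint, runs batch inference with Ray over the test parquet data (with the label column removed), and writes the input features together with the model predictions to parquet.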

In [1]:
import os

import numpy as np
import pandas as pd
import ray
from ray.data.context import DatasetContext
from ray.train.batch_predictor import BatchPredictor
from ray.train.sklearn import SklearnCheckpoint, SklearnPredictor

## Configuration

In [15]:
# Root directories. Override them via environment variables to run on another machine;
# the defaults are the original author's local paths.
PROJECT_ROOT = os.environ.get('OFF_PROJECT_ROOT', '/Users/rgareev/projects/mlops-openfoodfacts')
DATA_ROOT = os.environ.get('OFF_DATA_ROOT', '/Users/rgareev/data/openfoodfacts')
# Dataset / training-run identifier shared by every input and output path below.
RUN_ID = '20220831-dev'

# Directory of the trained model checkpoint (loaded with SklearnCheckpoint.from_directory).
INPUT_MODEL_PATH = f'{PROJECT_ROOT}/wrk/trainings/{RUN_ID}/model'
# Parquet file holding the test split to score.
INPUT_DATA_PATH = f'{DATA_ROOT}/wrk/{RUN_ID}/test.parquet'
# Label column, dropped before prediction.
# TODO this script should not deal with labels at all
LABEL_COLUMN = 'nova_group'

# Output directory for the input features joined with the model predictions.
# NOTE(review): the directory is named `model` although it receives predictions — confirm intended.
OUTPUT_DATA_PATH = f'{PROJECT_ROOT}/wrk/testings/{RUN_ID}/model'

## Script

In [3]:
# Load the test split as a Ray Dataset (the first Ray call starts a local Ray instance).
input_ds = ray.data.read_parquet(paths=INPUT_DATA_PATH)

2022-09-11 18:46:00,979	INFO worker.py:1509 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8266


In [4]:
# Inspect the on-disk schema. The pandas metadata shows `code` stored as the pandas
# index column (index_columns: ["code"]) rather than as a regular data column.
input_ds.schema()

product_name: string
nova_group: int8
ingredients_list: list<item: string>
  child 0, item: string
code: string
-- schema metadata --
pandas: '{"index_columns": ["code"], "column_indexes": [{"name": null, "f' + 684

In [5]:
# Split the data into 10 blocks; the later drop/predict/write stages each process 10 blocks.
input_ds = input_ds.repartition(num_blocks=10)

Read: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  9.06it/s]
Repartition: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 16.31it/s]


In [7]:
# Turn off tensor-extension casting in the current Ray DatasetContext before the
# pandas-based stages (drop_columns, predict) run.
# NOTE(review): presumably needed so the list-valued `ingredients_list` column passes
# through unchanged — confirm.
ctx = DatasetContext.get_current()
ctx.enable_tensor_extension_casting = False

In [8]:
# Drop the label column so it is not passed to the predictor.
# NOTE(review): rebinding `input_ds` makes this cell non-idempotent — re-running it would
# try to drop a column that is already gone.
input_ds = input_ds.drop_columns(cols=[LABEL_COLUMN])

Map_Batches: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 12.59it/s]


In [9]:
# Load the trained scikit-learn model from its checkpoint directory and wrap it in a
# BatchPredictor for dataset-wide inference.
model_checkpoint = SklearnCheckpoint.from_directory(path=INPUT_MODEL_PATH)
predictor = BatchPredictor(checkpoint=model_checkpoint, predictor_cls=SklearnPredictor)

In [12]:
# Run batch inference over the feature columns. The joined schema further below shows
# the `predictions: int8` column that this output contributes.
# NOTE(review): `predictor.predict(input_ds, keep_columns='code')` did not work.
# Possible causes (unconfirmed): `keep_columns` expects a list of column names, and
# `code` is stored as the pandas index (see index_columns metadata) rather than as a
# regular column. TODO find a way to carry `code` through to the output.
model_output_ds = predictor.predict(input_ds)

Map Progress (2 actors 1 pending): 100%|███████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  4.29it/s]


In [13]:
# Re-wrap both datasets as Arrow-backed datasets, then pair rows side by side.
# NOTE(review): presumably the re-wrapping is needed because zip() requires both sides to
# have the same block type — confirm. zip() pairs rows positionally, so both datasets must
# keep the same row order and block sizes (both derive from the same 10-block input).
features_arrow_ds = ray.data.from_arrow_refs(input_ds.to_arrow_refs())
predictions_arrow_ds = ray.data.from_arrow_refs(model_output_ds.to_arrow_refs())
result_ds = features_arrow_ds.zip(predictions_arrow_ds)

In [14]:
# Inspect the joined schema (input features + predictions).
# NOTE(review): the product identifier `code`, present in the input schema, is missing
# here, so the written predictions cannot be matched to products by `code` — TODO.
result_ds.schema()

product_name: string
ingredients_list: list<item: string>
  child 0, item: string
predictions: int8
-- schema metadata --
pandas: '{"index_columns": [{"kind": "range", "name": null, "start": 0, "' + 524

In [16]:
# Persist the joined features + predictions as parquet files (one numbered file per
# block) under OUTPUT_DATA_PATH.
result_ds.write_parquet(path=OUTPUT_DATA_PATH)

Write Progress: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 18.11it/s]


In [17]:
!ls -alh $OUTPUT_DATA_PATH

total 17776
drwxr-xr-x  12 rgareev  staff   384B Sep 11 18:50 .
drwxr-xr-x   3 rgareev  staff    96B Sep 11 18:50 ..
-rw-r--r--   1 rgareev  staff   895K Sep 11 18:50 b386d41642984276b426fdccbd958a30_000000.parquet
-rw-r--r--   1 rgareev  staff   878K Sep 11 18:50 b386d41642984276b426fdccbd958a30_000001.parquet
-rw-r--r--   1 rgareev  staff   886K Sep 11 18:50 b386d41642984276b426fdccbd958a30_000002.parquet
-rw-r--r--   1 rgareev  staff   890K Sep 11 18:50 b386d41642984276b426fdccbd958a30_000003.parquet
-rw-r--r--   1 rgareev  staff   892K Sep 11 18:50 b386d41642984276b426fdccbd958a30_000004.parquet
-rw-r--r--   1 rgareev  staff   879K Sep 11 18:50 b386d41642984276b426fdccbd958a30_000005.parquet
-rw-r--r--   1 rgareev  staff   889K Sep 11 18:50 b386d41642984276b426fdccbd958a30_000006.parquet
-rw-r--r--   1 rgareev  staff   896K Sep 11 18:50 b386d41642984276b426fdccbd958a30_000007.parquet
-rw-r--r--   1 rgareev  staff   881K Sep 11 18:50 b386d41642984276b42