-
Notifications
You must be signed in to change notification settings - Fork 5.4k
/
statsmodel_predictor.py
71 lines (60 loc) · 2.35 KB
/
statsmodel_predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# fmt: off
# __statsmodelpredictor_imports_start__
import os
from typing import Optional
import numpy as np # noqa: F401
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.base.model import Results
from statsmodels.regression.linear_model import OLSResults
import ray
from ray.air import Checkpoint
from ray.data.preprocessor import Preprocessor
from ray.train.batch_predictor import BatchPredictor
from ray.train.predictor import Predictor
# __statsmodelpredictor_imports_end__
# __statsmodelpredictor_signature_start__
class StatsmodelPredictor(Predictor):
    """Predictor that serves predictions from a fitted statsmodels results object."""
    ...
    # __statsmodelpredictor_signature_end__

    # __statsmodelpredictor_init_start__
    def __init__(
        self,
        results: Results,
        preprocessor: Optional[Preprocessor] = None,
    ):
        """Keep the fitted ``results`` and pass ``preprocessor`` to the base class.

        Args:
            results: Fitted statsmodels results; ``_predict_pandas`` calls its
                ``predict`` method.
            preprocessor: Optional preprocessor forwarded unchanged to
                ``Predictor.__init__``.
        """
        # The results object is attached before the base initializer runs.
        self.results = results
        super().__init__(preprocessor)
    # __statsmodelpredictor_init_end__

    # __statsmodelpredictor_predict_pandas_start__
    def _predict_pandas(self, data: pd.DataFrame) -> pd.DataFrame:
        """Predict on ``data`` and return a frame with one ``"predictions"`` column.

        Args:
            data: Input rows handed directly to ``self.results.predict``.

        Returns:
            The output of ``predict`` converted with ``to_frame`` into a
            DataFrame whose single column is named ``"predictions"``.
        """
        # ``predict`` is expected to yield a pandas Series for DataFrame input.
        return self.results.predict(data).to_frame(name="predictions")
    # __statsmodelpredictor_predict_pandas_end__

    # __statsmodelpredictor_from_checkpoint_start__
    @classmethod
    def from_checkpoint(
        cls,
        checkpoint: Checkpoint,
        filename: str,
    ) -> Predictor:
        """Build a predictor from a results file stored inside ``checkpoint``.

        Args:
            checkpoint: Checkpoint exposed as a directory while the file is read;
                its preprocessor (via ``get_preprocessor()``) is reused.
            filename: Name of the saved results file, relative to the
                checkpoint directory.

        Returns:
            A new instance of ``cls`` wrapping the loaded results.
        """
        # NOTE(review): the results file is presumably a pickle (loaded through
        # ``OLSResults.load``) -- only load checkpoints from trusted sources.
        with checkpoint.as_directory() as checkpoint_dir:
            fitted = OLSResults.load(os.path.join(checkpoint_dir, filename))
        preprocessor = checkpoint.get_preprocessor()
        return cls(fitted, preprocessor)
# __statsmodelpredictor_from_checkpoint_end__
# __statsmodelpredictor_model_start__
# Fetch the "Guerry" dataset of the R package "HistData" and keep its DataFrame.
data: pd.DataFrame = sm.datasets.get_rdataset(
    dataname="Guerry",
    package="HistData",
).data
# Fit ordinary least squares from a formula; the formula text refers to
# ``np.log``, which is presumably why numpy is imported without direct use.
results = smf.ols(
    formula="Lottery ~ Literacy + np.log(Pop1831)",
    data=data,
).fit()
# __statsmodelpredictor_model_end__
# __statsmodelpredictor_checkpoint_start__
# Ensure the local "checkpoint" directory exists (no error if it already does).
os.makedirs(name="checkpoint", exist_ok=True)
# Persist the fitted results inside that directory ...
results.save(fname="checkpoint/guerry.pickle")
# ... and wrap the directory as a checkpoint object.
checkpoint = Checkpoint.from_directory("checkpoint")
# __statsmodelpredictor_checkpoint_end__
# __statsmodelpredictor_predict_start__
# Create a batch predictor for StatsmodelPredictor from the checkpoint.
# ``filename`` is an extra keyword argument, presumably forwarded to
# ``StatsmodelPredictor.from_checkpoint``, which requires it.
predictor = BatchPredictor.from_checkpoint(
    checkpoint=checkpoint,
    predictor_cls=StatsmodelPredictor,
    filename="guerry.pickle",
)
# NOTE: ``data`` is the very frame the model was fitted on.
# Don't do this in practice.
dataset = ray.data.from_pandas(data)
predictor.predict(data=dataset)
# __statsmodelpredictor_predict_end__
# fmt: on