Skip to content

Commit

Permalink
Improve output
Browse files Browse the repository at this point in the history
  • Loading branch information
mr-perseus committed Dec 2, 2023
1 parent 4bf607c commit 1611a61
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
9 changes: 2 additions & 7 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,15 @@
urllib.request.urlretrieve(url_model, model_path)
urllib.request.urlretrieve(url_scaler, scaler_path)

single_prediction = SinglePrediction(model_path, scaler_path, "data/preprocessing/raw_features_2024.csv")
single_prediction = SinglePrediction(model_path, scaler_path, "data/preprocessing/raw_features_2024.csv",
"data/metadata/metadata.json")


@app.route('/')
def hello():
return 'Hello World!'


@app.route('/metadata')
def metadata():
metadata_json = json.load(open("data/metadata/metadata.json"))
return metadata_json


@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
Expand Down
16 changes: 13 additions & 3 deletions deploy/single_prediction.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import json
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader
from model.scaler import Scaler
Expand All @@ -9,7 +10,10 @@


class SinglePrediction:
def __init__(self, model_path, scaler_path, raw_features_path):
def __init__(self, model_path, scaler_path, raw_features_path, metadata_path):
self.metadata = json.load(open(metadata_path))
self.labels_readable = [self.metadata["parking_sg"]["fields"][field]["label"] for field in parking_data_labels]
self.max_capacity = [self.metadata["parking_sg"]["fields"][field]["max_cap"] for field in parking_data_labels]
self.scaler = Scaler.load(scaler_path)
self.single_prediction_features = SinglePredictionFeatures(raw_features_path)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -36,11 +40,17 @@ def predict_for_date(self, date):

output_scaled_back = self.scaler.inverse_transform(pd.DataFrame(output, columns=parking_data_labels))

return [dict(zip(parking_data_labels, row)) for row in output_scaled_back.tolist()]
return {
"predictions": output_scaled_back.tolist()[0],
"labels": parking_data_labels,
"labels_readable": self.labels_readable,
"max_capacity": self.max_capacity
}


if __name__ == "__main__":
predict = SinglePrediction("model_scripted.pt", "scaler.pkl", "../data/preprocessing/raw_features_2024.csv")
predict = SinglePrediction("../model_scripted.pt", "../scaler.pkl", "../data/preprocessing/raw_features_2024.csv",
"../data/metadata/metadata.json")
print(predict.predict_for_date("2023-12-08 08:00"))
print(predict.predict_for_date("2023-12-10 18:00"))
print(predict.predict_for_date("2023-12-12 12:00"))

0 comments on commit 1611a61

Please sign in to comment.