diff --git a/head_segmentation/predict_pipeline.py b/head_segmentation/predict_pipeline.py index 2db21dd..61f82c2 100644 --- a/head_segmentation/predict_pipeline.py +++ b/head_segmentation/predict_pipeline.py @@ -11,10 +11,13 @@ class HumanHeadSegmentationPipeline: def __init__( self, model_path: str = C.HEAD_SEGMENTATION_MODEL_PATH, - image_input_resolution: int = 512, ): + ckpt = torch.load(model_path, map_location=torch.device("cpu")) + self._preprocessing_pipeline = ip.PreprocessingPipeline( - nn_image_input_resolution=image_input_resolution + nn_image_input_resolution=ckpt["hyper_parameters"][ + "nn_image_input_resolution" + ] ) self._model = mdl.HeadSegmentationModel.load_from_checkpoint( ckpt_path=model_path