Skip to content

Commit

Permalink
allow setup(self, weights: str)
Browse files Browse the repository at this point in the history
Signed-off-by: technillogue <technillogue@gmail.com>
  • Loading branch information
technillogue committed Mar 5, 2024
1 parent 46d0eb3 commit 92cda9d
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions python/cog/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@


class BasePredictor(ABC):
def setup(self, weights: Optional[Union[CogFile, CogPath]] = None) -> None:
def setup(self, weights: Optional[Union[CogFile, CogPath, str]] = None) -> None:
"""
An optional method to prepare the model so multiple predictions run efficiently.
"""
Expand All @@ -70,7 +70,7 @@ def run_setup(predictor: BasePredictor) -> None:
predictor.setup()
return

weights: Union[io.IOBase, Path, None]
weights: Union[io.IOBase, Path, str, None]

weights_url = os.environ.get("COG_WEIGHTS")
weights_path = "weights"
Expand All @@ -85,6 +85,9 @@ def run_setup(predictor: BasePredictor) -> None:
elif weights_type == CogPath:
# TODO: So this can be a url. evil!
weights = cast(CogPath, CogPath.validate(weights_url))
# allow people to download weights themselves
elif weights_type == str:
weights = weights_url
else:
raise ValueError(
f"Predictor.setup() has an argument 'weights' of type {weights_type}, but only File and Path are supported"
Expand Down

0 comments on commit 92cda9d

Please sign in to comment.