diff --git a/dspy/predict/predict.py b/dspy/predict/predict.py index e064a966e9..301b68595d 100644 --- a/dspy/predict/predict.py +++ b/dspy/predict/predict.py @@ -2,12 +2,13 @@ from pydantic import BaseModel +from dspy.clients.base_lm import BaseLM +from dspy.clients.lm import LM from dspy.predict.parameter import Parameter from dspy.primitives.prediction import Prediction from dspy.primitives.program import Module from dspy.signatures.signature import ensure_signature from dspy.utils.callback import with_callbacks -from dspy.clients.lm import LM class Predict(Module, Parameter): @@ -80,7 +81,7 @@ def forward(self, **kwargs): # Get the right LM to use. lm = kwargs.pop("lm", self.lm) or dspy.settings.lm - assert isinstance(lm, dspy.LM), "No LM is loaded." + assert isinstance(lm, BaseLM), "No LM is loaded." # If temperature is 0.0 but its n > 1, set temperature to 0.7. temperature = config.get("temperature")