diff --git a/diffcse/tool.py b/diffcse/tool.py index 12ff764..855cc36 100644 --- a/diffcse/tool.py +++ b/diffcse/tool.py @@ -37,11 +37,9 @@ def __init__(self, model_name_or_path: str, if pooler is not None: self.pooler = pooler - elif "unsup" in model_name_or_path: - logger.info("Use `cls_before_pooler` for unsupervised models. If you want to use other pooling policy, specify `pooler` argument.") - self.pooler = "cls_before_pooler" else: - self.pooler = "cls" + logger.info("Use `cls_before_pooler` for DiffCSE models. If you want to use other pooling policy, specify `pooler` argument.") + self.pooler = "cls_before_pooler" def encode(self, sentence: Union[str, List[str]], device: str = None,