diff --git a/skorch/llm/classifier.py b/skorch/llm/classifier.py
index 88ebfebc6..b2ef4963a 100644
--- a/skorch/llm/classifier.py
+++ b/skorch/llm/classifier.py
@@ -221,7 +221,9 @@ class _CFGuidance(LogitsProcessor):
 
     References
     ----------
-    [1]: TODO
+    [1]: Sanchez, Guillaume, et al.
+         "Stay on topic with Classifier-Free Guidance."
+         arXiv preprint arXiv:2306.17806 (2023).
 
     """
 
@@ -265,14 +267,14 @@ class _CacheModelWrapper:
     Set use_caching=False to disable it, e.g. for debugging.
 
     """
-    def __init__(self, model, tokenizer, use_caching=True, use_cfg=False):
+    def __init__(self, model, tokenizer, use_caching=True, cfg_gamma=None):
         self.model = model
         self.tokenizer = tokenizer
         self.use_caching = use_caching
         self.cache = {}
         self._total_calls = 0
         self._uncached_calls = 0
-        self.use_cfg = use_cfg
+        self.cfg_gamma = cfg_gamma
 
     def clear(self):
         self.cache.clear()
@@ -333,17 +335,20 @@ def generate_logits(self, *, label_id, **kwargs):
             return recorded_logits
 
         self._uncached_calls += 1  # mainly for debugging
-        guidance = _CFGuidance(
-            model=self.model,
-            label_ids=label_id,
-            tokenizer=self.tokenizer,
-        )
+
         recorder = _LogitsRecorder(
             label_ids=label_id,
             tokenizer=self.tokenizer,
         )
         processors = [recorder]
-        if self.use_cfg:
+
+        if self.cfg_gamma is not None:
+            guidance = _CFGuidance(
+                model=self.model,
+                label_ids=label_id,
+                tokenizer=self.tokenizer,
+                gamma=self.cfg_gamma,
+            )
             processors.insert(0, guidance)
 
         self.model.generate(
@@ -465,7 +470,7 @@ def _fit(self, X, y, **fit_params):
         classes = [str(c) for c in self.classes_]
         self.label_ids_ = self.tokenizer_(classes)['input_ids']
         self.cached_model_ = _CacheModelWrapper(
-            self.model_, self.tokenizer_, use_caching=self.use_caching, use_cfg=self.use_cfg,
+            self.model_, self.tokenizer_, use_caching=self.use_caching, cfg_gamma=self.cfg_gamma,
         )
         return self
 
@@ -799,7 +804,7 @@ def __init__(
             error_low_prob='ignore',
             threshold_low_prob=0.0,
             use_caching=True,
-            use_cfg=False,
+            cfg_gamma=None,
     ):
         self.model_name = model_name
         self.model = model
@@ -810,7 +815,7 @@
         self.error_low_prob = error_low_prob
         self.threshold_low_prob = threshold_low_prob
         self.use_caching = use_caching
-        self.use_cfg = use_cfg
+        self.cfg_gamma = cfg_gamma
 
     def check_prompt(self, prompt):
         """Check if the prompt is well formed.
@@ -1030,7 +1035,7 @@
             error_low_prob='ignore',
             threshold_low_prob=0.0,
             use_caching=True,
-            use_cfg=False,
+            cfg_gamma=None,
             random_state=None,
     ):
         self.model_name = model_name
@@ -1043,7 +1048,7 @@
         self.error_low_prob = error_low_prob
         self.threshold_low_prob = threshold_low_prob
         self.use_caching = use_caching
-        self.use_cfg = use_cfg
+        self.cfg_gamma = cfg_gamma
         self.random_state = random_state
 
     def check_prompt(self, prompt):
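
For context, a minimal usage sketch of the new cfg_gamma parameter, assuming
this patch is applied. The model name and labels are illustrative placeholders,
and cfg_gamma=1.5 is an arbitrary example value; the default of None leaves
classifier-free guidance disabled, matching the old use_cfg=False behavior:

    from skorch.llm import ZeroShotClassifier

    # Hypothetical usage of the cfg_gamma parameter introduced by this patch.
    # cfg_gamma=None (the default) skips the _CFGuidance processor entirely;
    # any float enables classifier-free guidance with that guidance strength.
    clf = ZeroShotClassifier('bigscience/bloomz-1b1', cfg_gamma=1.5)
    clf.fit(X=None, y=['negative', 'positive'])
    print(clf.predict(['This movie was a waste of time']))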
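
To motivate replacing the boolean use_cfg with a float gamma: in the cited
paper [1], classifier-free guidance blends unconditional and conditional
logits with a tunable strength, so a single on/off flag cannot express it.
The sketch below shows the paper's combination rule, not necessarily the
exact arithmetic inside _CFGuidance:

    import torch

    def cfg_combine(logits_cond, logits_uncond, gamma):
        """Guidance rule from Sanchez et al. (2023).

        gamma == 1.0 recovers the plain conditional logits; gamma > 1.0
        pushes the distribution further toward the conditioned prompt.
        """
        return logits_uncond + gamma * (logits_cond - logits_uncond)

    # Sanity check: gamma=1.0 is a no-op.
    cond, uncond = torch.randn(10), torch.randn(10)
    assert torch.allclose(cfg_combine(cond, uncond, 1.0), cond)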