Commit

Use cfg_gamma instead of use_cfg boolean flag
- Makes it possible to set the gamma parameter
- Setting it to `None` disables the functionality completely (see the usage sketch below)
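
A rough usage sketch of the new parameter (not part of the commit itself; `ZeroShotClassifier` and the `'gpt2'` model name are assumptions, based on the estimators defined in skorch/llm/classifier.py):

    from skorch.llm import ZeroShotClassifier

    # before this commit, CFG could only be toggled via a boolean flag:
    # clf = ZeroShotClassifier('gpt2', use_cfg=True)

    # now the guidance strength itself is configurable ...
    clf = ZeroShotClassifier('gpt2', cfg_gamma=1.5)  # hypothetical gamma value

    # ... and passing None disables classifier-free guidance entirely
    clf = ZeroShotClassifier('gpt2', cfg_gamma=None)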
ottonemo committed Jul 19, 2023
1 parent 26b7e7b commit 96de091
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions skorch/llm/classifier.py
@@ -221,7 +221,9 @@ class _CFGuidance(LogitsProcessor):
References
----------
- [1]: TODO
+ [1]: Sanchez, Guillaume, et al.
+      "Stay on topic with Classifier-Free Guidance."
+      arXiv preprint arXiv:2306.17806 (2023).
"""

@@ -265,14 +267,14 @@ class _CacheModelWrapper:
Set use_caching=False to disable it, e.g. for debugging.
"""
- def __init__(self, model, tokenizer, use_caching=True, use_cfg=False):
+ def __init__(self, model, tokenizer, use_caching=True, cfg_gamma=None):
self.model = model
self.tokenizer = tokenizer
self.use_caching = use_caching
self.cache = {}
self._total_calls = 0
self._uncached_calls = 0
- self.use_cfg = use_cfg
+ self.cfg_gamma = cfg_gamma

def clear(self):
self.cache.clear()
@@ -333,17 +335,19 @@ def generate_logits(self, *, label_id, **kwargs):
return recorded_logits

self._uncached_calls += 1 # mainly for debugging
- guidance = _CFGuidance(
-     model=self.model,
-     label_ids=label_id,
-     tokenizer=self.tokenizer,
- )

recorder = _LogitsRecorder(
label_ids=label_id,
tokenizer=self.tokenizer,
)
processors = [recorder]
- if self.use_cfg:
+
+ if self.cfg_gamma is not None:
+     guidance = _CFGuidance(
+         model=self.model,
+         label_ids=label_id,
+         tokenizer=self.tokenizer,
+     )
processors.insert(0, guidance)

self.model.generate(
@@ -465,7 +469,7 @@ def _fit(self, X, y, **fit_params):
classes = [str(c) for c in self.classes_]
self.label_ids_ = self.tokenizer_(classes)['input_ids']
self.cached_model_ = _CacheModelWrapper(
- self.model_, self.tokenizer_, use_caching=self.use_caching, use_cfg=self.use_cfg,
+ self.model_, self.tokenizer_, use_caching=self.use_caching, cfg_gamma=self.cfg_gamma,
)
return self

@@ -799,7 +803,7 @@ def __init__(
error_low_prob='ignore',
threshold_low_prob=0.0,
use_caching=True,
- use_cfg=False,
+ cfg_gamma=None,
):
self.model_name = model_name
self.model = model
@@ -810,7 +814,7 @@ def __init__(
self.error_low_prob = error_low_prob
self.threshold_low_prob = threshold_low_prob
self.use_caching = use_caching
- self.use_cfg = use_cfg
+ self.cfg_gamma = cfg_gamma

def check_prompt(self, prompt):
"""Check if the prompt is well formed.
Expand Down Expand Up @@ -1030,7 +1034,7 @@ def __init__(
error_low_prob='ignore',
threshold_low_prob=0.0,
use_caching=True,
- use_cfg=False,
+ cfg_gamma=None,
random_state=None,
):
self.model_name = model_name
@@ -1043,7 +1047,7 @@ def __init__(
self.error_low_prob = error_low_prob
self.threshold_low_prob = threshold_low_prob
self.use_caching = use_caching
- self.use_cfg = use_cfg
+ self.cfg_gamma = cfg_gamma
self.random_state = random_state

def check_prompt(self, prompt):
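For reference, the gamma in classifier-free guidance (Sanchez et al., cited in the docstring above) interpolates between unconditioned and prompt-conditioned logits: gamma = 1 reproduces the conditioned logits, and larger values push further toward the prompt. A minimal, self-contained sketch of that formula (an illustration, not the skorch implementation):

    import torch

    def cfg_logits(cond_logits, uncond_logits, gamma):
        # guided logits = uncond + gamma * (cond - uncond)
        return uncond_logits + gamma * (cond_logits - uncond_logits)

    # toy vocabulary of four tokens
    cond = torch.tensor([2.0, 0.5, -1.0, 0.0])
    uncond = torch.tensor([1.0, 1.0, -0.5, 0.0])
    print(cfg_logits(cond, uncond, gamma=1.5))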
