Initial draft implementation of CFG for LLM
Based on the paper

        Sanchez, Guillaume, et al.
        "Stay on topic with Classifier-Free Guidance."
        arXiv preprint arXiv:2306.17806 (2023).

a draft implementation of classifier-free guidance.

This is simply for sharing internally and might very well be
completely wrong. It is debatable whether we should expose such
a feature as a flag to the network or make it a separate
classifier instance (or a mixin). In the past we were
very much against special (potentially short-lived) feature
flags and it was much nicer to have this implemented as
an addon/callback. We might need to do something similar
here as well.
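
For illustration, a minimal usage sketch of the proposed flag, assuming it is
exposed on the existing zero-shot estimator; the estimator name, model name,
and labels below are placeholders for the example and not part of this commit:

from skorch.llm import ZeroShotClassifier

# hypothetical example: enable the draft CFG flag on a zero-shot classifier
clf = ZeroShotClassifier('bigscience/bloomz-1b1', use_cfg=True)
clf.fit(X=None, y=['negative', 'positive'])
y_proba = clf.predict_proba(['This new movie is really great'])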
ottonemo committed Jul 19, 2023
1 parent 549f3e6 commit 26b7e7b
Showing 2 changed files with 87 additions and 3 deletions.
90 changes: 87 additions & 3 deletions skorch/llm/classifier.py
@@ -185,6 +185,76 @@ def __call__(self, input_ids, scores):
        return scores


class _CFGuidance(LogitsProcessor):
    r"""Helper class to implement Classifier-Free Guidance (CFG) [1],
    which guides the model sampling in a direction that takes the prompt
    more into account than the output generated without the prompt.

    Mathematically, this is implemented by the following transformation
    of the log-probabilities:

    .. math::

       \log \hat{\mathbf{P}}_\theta(w_i|w_{j<i}, c) \propto
       \log \mathbf{P}_\theta(w_i|w_{j<i})
       + \gamma \left(
           \log \mathbf{P}_\theta(w_i|w_{j<i}, c)
           - \log \mathbf{P}_\theta(w_i|w_{j<i})
       \right)

    In essence, the logits generated without context (the prompt) are
    subtracted from the logits generated with context; this difference,
    weighted by ``gamma`` (default 1.5), is added back onto the
    contextless logits to steer sampling.

    As a consequence, CFG needs to sample twice from the model: once for
    the with-context logits and once for the without-context logits. The
    standard way of sampling already determines
    ``log P(w_i | w_{j < i}, c)``, so this logits processor receives
    these log-probabilities as its ``scores``. We additionally need
    ``log P(w_i | w_{j < i})``, for which we have to run inference once
    more.

    References
    ----------
    [1]: Sanchez, Guillaume, et al. "Stay on topic with Classifier-Free
         Guidance." arXiv preprint arXiv:2306.17806 (2023).

    """

    def __init__(self, model, tokenizer, label_ids, gamma=1.5):
        self.model = model
        self.tokenizer = tokenizer
        self.gamma = gamma
        self.label_ids = label_ids
        self.recorded_scores = []

    def __call__(self, input_ids, scores):
        idx = len(self.recorded_scores)

        # scores already contain the with-context log-probabilities
        P_wi_wjic = scores

        # run the model once more without the prompt (context) to get the
        # contextless log-probabilities
        model_input = {
            'input_ids': torch.tensor(self.label_ids)[None, :].to(self.model.device),
            'attention_mask': torch.tensor([1] * len(self.label_ids))[None, :].to(self.model.device),
        }
        model_output = self.model.generate(**model_input, output_scores=True, return_dict_in_generate=True)
        P_wi_wji = model_output.scores[idx]

        # we pull the logits to CPU because they are not used as input,
        # therefore there is no device mismatch and we save a bit of GPU memory
        # TODO: replace this with a counter, since we only use the position
        self.recorded_scores.append(scores[0].clone().cpu())

        scores = P_wi_wji + self.gamma * (P_wi_wjic - P_wi_wji)

        return scores
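
To make the weighted discount above concrete, here is a small self-contained
sketch of the arithmetic on dummy tensors; the tensor shapes and values are
purely illustrative and independent of the surrounding class:

import torch

gamma = 1.5
vocab_size = 50257  # illustrative vocabulary size

# stand-ins for the two score tensors that the processor combines
logits_with_context = torch.randn(1, vocab_size)     # log P(w_i | w_{j<i}, c)
logits_without_context = torch.randn(1, vocab_size)  # log P(w_i | w_{j<i})

# CFG: start from the contextless logits and add gamma times the
# difference that the context introduces
cfg_logits = logits_without_context + gamma * (
    logits_with_context - logits_without_context
)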


class _CacheModelWrapper:
    """Helper class that caches model generations
@@ -195,13 +265,14 @@ class _CacheModelWrapper:
    Set use_caching=False to disable it, e.g. for debugging.

    """
    def __init__(self, model, tokenizer, use_caching=True):
    def __init__(self, model, tokenizer, use_caching=True, use_cfg=False):
        self.model = model
        self.tokenizer = tokenizer
        self.use_caching = use_caching
        self.cache = {}
        self._total_calls = 0
        self._uncached_calls = 0
        self.use_cfg = use_cfg

    def clear(self):
        self.cache.clear()
@@ -262,12 +333,21 @@ def generate_logits(self, *, label_id, **kwargs):
            return recorded_logits

        self._uncached_calls += 1  # mainly for debugging
        guidance = _CFGuidance(
            model=self.model,
            label_ids=label_id,
            tokenizer=self.tokenizer,
        )
        recorder = _LogitsRecorder(
            label_ids=label_id,
            tokenizer=self.tokenizer,
        )
        processors = [recorder]
        if self.use_cfg:
            processors.insert(0, guidance)

        self.model.generate(
            logits_processor=[recorder],
            logits_processor=processors,
            # TODO: should this be the max len of all labels?
            max_new_tokens=len(label_id),
            **kwargs
@@ -385,7 +465,7 @@ def _fit(self, X, y, **fit_params):
        classes = [str(c) for c in self.classes_]
        self.label_ids_ = self.tokenizer_(classes)['input_ids']
        self.cached_model_ = _CacheModelWrapper(
            self.model_, self.tokenizer_, use_caching=self.use_caching
            self.model_, self.tokenizer_, use_caching=self.use_caching, use_cfg=self.use_cfg,
        )
        return self

@@ -719,6 +799,7 @@ def __init__(
        error_low_prob='ignore',
        threshold_low_prob=0.0,
        use_caching=True,
        use_cfg=False,
    ):
        self.model_name = model_name
        self.model = model
@@ -729,6 +810,7 @@ def __init__(
        self.error_low_prob = error_low_prob
        self.threshold_low_prob = threshold_low_prob
        self.use_caching = use_caching
        self.use_cfg = use_cfg

    def check_prompt(self, prompt):
        """Check if the prompt is well formed.
@@ -948,6 +1030,7 @@ def __init__(
        error_low_prob='ignore',
        threshold_low_prob=0.0,
        use_caching=True,
        use_cfg=False,
        random_state=None,
    ):
        self.model_name = model_name
@@ -960,6 +1043,7 @@
        self.error_low_prob = error_low_prob
        self.threshold_low_prob = threshold_low_prob
        self.use_caching = use_caching
        self.use_cfg = use_cfg
        self.random_state = random_state

    def check_prompt(self, prompt):
Binary file modified skorch/tests/net_cuda.pkl
