Draft: Initial draft implementation of CFG for LLM #996

Open
wants to merge 3 commits into master
136 changes: 117 additions & 19 deletions skorch/llm/classifier.py
@@ -169,9 +169,9 @@ def _extend_inputs(inputs, extra):

class _LogitsRecorder(LogitsProcessor):
"""Helper class to record logits and force the given label token ids"""
def __init__(self, label_ids, tokenizer):
def __init__(self, token_ids, tokenizer):
self.recorded_scores = []
self.label_ids = label_ids
self.token_ids = token_ids
self.tokenizer = tokenizer

def __call__(self, input_ids, scores):
@@ -180,11 +180,87 @@ def __call__(self, input_ids, scores):
# therefore there is no device mismatch and we save a bit of GPU memory
self.recorded_scores.append(scores[0].clone().cpu())
mask = torch.ones(scores.size(), dtype=torch.bool)
mask[0, self.label_ids[idx]] = False
mask[0, self.token_ids[idx]] = False
scores[mask] = -float('inf')
return scores


class _CFGuidance(LogitsProcessor):
    r"""Helper class to implement Classifier-Free Guidance (CFG) [1],
    which guides the model's sampling towards outputs that take the
    prompt (context) more strongly into account than unconditioned
    generation would.

Mathematically this is implemented by the following
transformation of the log-probabilities:

.. math::

        \log \hat{\textbf{P}}_\theta(w_i|w_{j<i}, c) \propto
        \log \textbf{P}_\theta(w_i|w_{j<i})
        + \gamma \big(
            \log \textbf{P}_\theta(w_i|w_{j<i}, c)
            - \log \textbf{P}_\theta(w_i|w_{j<i})
        \big)

    In essence, the logits generated without the context (the prompt)
    are subtracted from the logits generated with the context. This
    difference, weighted by gamma (default 1.5), is then added back to
    the contextless logits to steer the sampling.

    As a consequence, CFG needs to sample twice from the model, once
    for the with-context logits and once for the without-context logits.

    The standard way of sampling already determines
    ``log P(w_i | w_{j < i}, c)`` -- this is what we normally have.
    Therefore, this logits processor receives those log-probabilities
    as its scores. What is still missing is ``log P(w_i | w_{j < i})``,
    for which we need to run inference once more.

References
----------

[1]: Sanchez, Guillaume, et al.
"Stay on topic with Classifier-Free Guidance."
arXiv preprint arXiv:2306.17806 (2023).

"""

def __init__(self, model, tokenizer, token_ids, gamma=1.5):
self.model = model
self.tokenizer = tokenizer
self.gamma = gamma
self.token_ids = token_ids
self.token_position = 0

def __call__(self, input_ids, scores):
idx = self.token_position

        # the incoming scores are the with-context log-probabilities,
        # i.e. log P(w_i | w_{j<i}, c), from the surrounding generate call
        P_wi_wjic = scores

        # run the model again on the label token ids alone, i.e. without
        # the prompt, to obtain the without-context scores
        model_inputs = {
            'input_ids': torch.tensor(self.token_ids)[None, :].to(self.model.device),
            'attention_mask': torch.tensor([1] * len(self.token_ids))[None, :].to(self.model.device),
        }

model_output = self.model.generate(
**model_inputs,
max_new_tokens=len(self.token_ids),
output_scores=True,
return_dict_in_generate=True)
P_wi_wji = model_output.scores[idx]

# We assume that this logits processor is called in `generate_logits`
# which invokes the model for each new token in the given sequence
# (self.token_ids). Thus we need to keep track where we are.
self.token_position += 1

        # CFG: start from the without-context scores and add gamma times
        # the difference that the context makes (see class docstring)
        scores = P_wi_wji + self.gamma * (P_wi_wjic - P_wi_wji)

return scores
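
To make the transformation concrete, the score combination implemented above can be reproduced in isolation. The following is a minimal, standalone sketch (illustration only, not part of this diff; the vocabulary size and score values are made up):

import torch

gamma = 1.5  # guidance strength, the default used by _CFGuidance

# with-context scores, log P(w_i | w_{j<i}, c), for one step and a toy vocabulary of 5 tokens
scores_with_context = torch.tensor([[0.1, 2.0, -1.0, 0.5, 0.0]])
# without-context scores, log P(w_i | w_{j<i}), for the same step
scores_without_context = torch.tensor([[0.3, 0.5, -0.5, 0.4, 0.1]])

# start from the contextless scores and add gamma times the difference the context makes
guided = scores_without_context + gamma * (scores_with_context - scores_without_context)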


class _CacheModelWrapper:
"""Helper class that caches model generations

@@ -195,13 +271,14 @@ class _CacheModelWrapper:
Set use_caching=False to disable it, e.g. for debugging.

"""
def __init__(self, model, tokenizer, use_caching=True):
def __init__(self, model, tokenizer, use_caching=True, cfg_gamma=None):
self.model = model
self.tokenizer = tokenizer
self.use_caching = use_caching
self.cache = {}
self._total_calls = 0
self._uncached_calls = 0
self.cfg_gamma = cfg_gamma

def clear(self):
self.cache.clear()
@@ -237,42 +314,59 @@ def set_cache(self, kwargs, label_id, scores):
key = str(input_id)
self.cache[key] = score

def generate_logits(self, *, label_id, **kwargs):
    def generate_logits(self, *, token_ids, **model_inputs):
        """Generate logits for the given token ids based on the given model inputs.

        The model is forced to generate exactly these token ids; the
        scores of all other tokens are set to ``-inf`` so that they
        cannot be sampled.

        """
self._total_calls += 1 # mainly for debugging

recorded_logits = []
logits_cached = self.get_cache(kwargs)
logits_cached = self.get_cache(model_inputs)
while logits_cached is not None:
if label_id[0] == self.tokenizer.eos_token_id:
if token_ids[0] == self.tokenizer.eos_token_id:
# don't extend with eos_token -- it is already there at the end,
# we don't need it twice
break

recorded_logits.append(logits_cached)
kwargs = _extend_inputs(kwargs, label_id[:1])
label_id = label_id[1:]
logits_cached = self.get_cache(kwargs)
model_inputs = _extend_inputs(model_inputs, token_ids[:1])
token_ids = token_ids[1:]
logits_cached = self.get_cache(model_inputs)

if not label_id:
if not token_ids:
# the whole generation was cached
return recorded_logits

if label_id[0] == self.tokenizer.pad_token_id:
if token_ids[0] == self.tokenizer.pad_token_id:
# no need to generate on pad tokens
return recorded_logits

self._uncached_calls += 1 # mainly for debugging

recorder = _LogitsRecorder(
label_ids=label_id,
token_ids=token_ids,
tokenizer=self.tokenizer,
)
processors = [recorder]

if self.cfg_gamma is not None:
guidance = _CFGuidance(
model=self.model,
tokenizer=self.tokenizer,
token_ids=token_ids,
gamma=self.cfg_gamma,
)
processors.insert(0, guidance)

self.model.generate(
logits_processor=[recorder],
logits_processor=processors,
# TODO: should this be the max len of all labels?
max_new_tokens=len(label_id),
**kwargs
max_new_tokens=len(token_ids),
**model_inputs
)
self.set_cache(kwargs, label_id, recorder.recorded_scores)
self.set_cache(model_inputs, token_ids, recorder.recorded_scores)
return recorded_logits + recorder.recorded_scores[:]
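
For context, this is roughly how the wrapper is driven once cfg_gamma is set. A sketch under the assumption of a Hugging Face causal LM; the model name, prompt, and label below are placeholders, not taken from this diff:

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')

wrapper = _CacheModelWrapper(model, tokenizer, use_caching=True, cfg_gamma=1.5)

# encode the prompt and the label whose tokens the model is forced to generate
inputs = tokenizer('The sentiment of "great movie" is: ', return_tensors='pt')
label_token_ids = tokenizer('positive')['input_ids']

# one score tensor per forced label token; with cfg_gamma set, each score
# has already been adjusted by the _CFGuidance processor
logits = wrapper.generate_logits(token_ids=label_token_ids, **inputs)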


@@ -385,7 +479,7 @@ def _fit(self, X, y, **fit_params):
classes = [str(c) for c in self.classes_]
self.label_ids_ = self.tokenizer_(classes)['input_ids']
self.cached_model_ = _CacheModelWrapper(
self.model_, self.tokenizer_, use_caching=self.use_caching
self.model_, self.tokenizer_, use_caching=self.use_caching, cfg_gamma=self.cfg_gamma,
)
return self

@@ -404,7 +498,7 @@ def _predict_one(self, text):

probas_all_labels = []
for label_id in self.label_ids_:
logits = self.cached_model_.generate_logits(label_id=label_id, **inputs)
logits = self.cached_model_.generate_logits(token_ids=label_id, **inputs)
logits = torch.vstack(logits)
probas = torch.nn.functional.softmax(logits, dim=-1)

@@ -719,6 +813,7 @@ def __init__(
error_low_prob='ignore',
threshold_low_prob=0.0,
use_caching=True,
cfg_gamma=None,
):
self.model_name = model_name
self.model = model
@@ -729,6 +824,7 @@ def __init__(
self.error_low_prob = error_low_prob
self.threshold_low_prob = threshold_low_prob
self.use_caching = use_caching
self.cfg_gamma = cfg_gamma

def check_prompt(self, prompt):
"""Check if the prompt is well formed.
@@ -948,6 +1044,7 @@ def __init__(
error_low_prob='ignore',
threshold_low_prob=0.0,
use_caching=True,
cfg_gamma=None,
random_state=None,
):
self.model_name = model_name
@@ -960,6 +1057,7 @@ def __init__(
self.error_low_prob = error_low_prob
self.threshold_low_prob = threshold_low_prob
self.use_caching = use_caching
self.cfg_gamma = cfg_gamma
self.random_state = random_state

def check_prompt(self, prompt):
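
At the user-facing level, the new cfg_gamma argument is simply passed through like any other estimator parameter. A usage sketch, assuming skorch's ZeroShotClassifier and an arbitrary example model:

from skorch.llm import ZeroShotClassifier

clf = ZeroShotClassifier('bigscience/bloomz-1b1', cfg_gamma=1.5)
clf.fit(X=None, y=['negative', 'positive'])
# predict_proba now uses CFG-adjusted logits when computing the label probabilities
y_proba = clf.predict_proba(['This movie was a pleasant surprise.'])
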
Binary file modified skorch/tests/net_cuda.pkl
Binary file not shown.