From 26b7e7bafe6fdc1321c96058e0676b5b12e1b654 Mon Sep 17 00:00:00 2001
From: Marian Tietz
Date: Wed, 19 Jul 2023 17:07:25 +0200
Subject: [PATCH 1/3] Initial draft implementation of CFG for LLM

Based on the paper

    Sanchez, Guillaume, et al. "Stay on topic with Classifier-Free
    Guidance." arXiv preprint arXiv:2306.17806 (2023).

this is a draft implementation of classifier-free guidance. It is only
meant for sharing internally and might very well be completely wrong.

It is debatable whether we should expose such a feature as a flag on
the classifier or make it a separate classifier class (or a mixin). In
the past we were very much against special (potentially short-lived)
feature flags, and it was much nicer to have such things implemented as
an addon/callback. We might need to do something similar here as well.
---
 skorch/llm/classifier.py  | 90 ++++++++++++++++++++++++++++++++++++--
 skorch/tests/net_cuda.pkl | Bin 23023 -> 23027 bytes
 2 files changed, 87 insertions(+), 3 deletions(-)

diff --git a/skorch/llm/classifier.py b/skorch/llm/classifier.py
index 23ea8ad98..88ebfebc6 100644
--- a/skorch/llm/classifier.py
+++ b/skorch/llm/classifier.py
@@ -185,6 +185,76 @@ def __call__(self, input_ids, scores):
         return scores
 
 
+class _CFGuidance(LogitsProcessor):
+    """Helper class to implement Classifier-Free Guidance [1],
+    which guides the sampling in a direction that takes the
+    prompt more strongly into account than generation without
+    the prompt would.
+
+    Mathematically, this is implemented by the following
+    transformation of the log-probabilities:
+
+    .. math::
+
+       \text{log} \hat{\textbf{P}}_\theta(w|c) \propto
+       \text{log} \textbf{P}_\theta(w_i|w_{j < i})
+       + \gamma (
+           \text{log} \textbf{P}_\theta(w_i|w_{j < i}, c)
+           - \text{log} \textbf{P}_\theta(w_i|w_{j < i})
+       )

[remainder of this hunk and the binary patch to skorch/tests/net_cuda.pkl omitted: garbled in the original]

From: Marian Tietz
Date: Wed, 19 Jul 2023 20:44:00 +0200
Subject: [PATCH 2/3] Use `cfg_gamma` instead of `use_cfg` boolean flag

- Makes it possible to set the gamma parameter
- Setting it to `None` disables the functionality completely
---
 skorch/llm/classifier.py | 33 +++++++++++++++++++--------------
 1 file changed, 19 insertions(+), 14 deletions(-)

diff --git a/skorch/llm/classifier.py b/skorch/llm/classifier.py
index 88ebfebc6..b2ef4963a 100644
--- a/skorch/llm/classifier.py
+++ b/skorch/llm/classifier.py
@@ -221,7 +221,9 @@ class _CFGuidance(LogitsProcessor):
 
     References
     ----------
-    [1]: TODO
+    [1]: Sanchez, Guillaume, et al.
+         "Stay on topic with Classifier-Free Guidance."
+         arXiv preprint arXiv:2306.17806 (2023).
 
     """
 
@@ -265,14 +267,14 @@ class _CacheModelWrapper:
     Set use_caching=False to disable it, e.g. for debugging.
 
     """
-    def __init__(self, model, tokenizer, use_caching=True, use_cfg=False):
+    def __init__(self, model, tokenizer, use_caching=True, cfg_gamma=None):
         self.model = model
         self.tokenizer = tokenizer
         self.use_caching = use_caching
         self.cache = {}
         self._total_calls = 0
         self._uncached_calls = 0
-        self.use_cfg = use_cfg
+        self.cfg_gamma = cfg_gamma
 
     def clear(self):
         self.cache.clear()
@@ -333,17 +335,20 @@ def generate_logits(self, *, label_id, **kwargs):
             return recorded_logits
 
         self._uncached_calls += 1  # mainly for debugging
-        guidance = _CFGuidance(
-            model=self.model,
-            label_ids=label_id,
-            tokenizer=self.tokenizer,
-        )
+
         recorder = _LogitsRecorder(
             label_ids=label_id,
             tokenizer=self.tokenizer,
         )
         processors = [recorder]
-        if self.use_cfg:
+
+        if self.cfg_gamma is not None:
+            guidance = _CFGuidance(
+                model=self.model,
+                label_ids=label_id,
+                tokenizer=self.tokenizer,
+                gamma=self.cfg_gamma,
+            )
             processors.insert(0, guidance)
 
         self.model.generate(
@@ -465,7 +470,7 @@ def _fit(self, X, y, **fit_params):
         classes = [str(c) for c in self.classes_]
         self.label_ids_ = self.tokenizer_(classes)['input_ids']
         self.cached_model_ = _CacheModelWrapper(
-            self.model_, self.tokenizer_, use_caching=self.use_caching, use_cfg=self.use_cfg,
+            self.model_, self.tokenizer_, use_caching=self.use_caching, cfg_gamma=self.cfg_gamma,
         )
         return self
 
@@ -799,7 +804,7 @@ def __init__(
         error_low_prob='ignore',
         threshold_low_prob=0.0,
         use_caching=True,
-        use_cfg=False,
+        cfg_gamma=None,
     ):
         self.model_name = model_name
         self.model = model
@@ -810,7 +815,7 @@ def __init__(
         self.error_low_prob = error_low_prob
         self.threshold_low_prob = threshold_low_prob
         self.use_caching = use_caching
-        self.use_cfg = use_cfg
+        self.cfg_gamma = cfg_gamma
 
     def check_prompt(self, prompt):
         """Check if the prompt is well formed.
@@ -1030,7 +1035,7 @@ def __init__(
         error_low_prob='ignore',
         threshold_low_prob=0.0,
         use_caching=True,
-        use_cfg=False,
+        cfg_gamma=None,
         random_state=None,
     ):
         self.model_name = model_name
@@ -1043,7 +1048,7 @@ def __init__(
         self.error_low_prob = error_low_prob
         self.threshold_low_prob = threshold_low_prob
         self.use_caching = use_caching
-        self.use_cfg = use_cfg
+        self.cfg_gamma = cfg_gamma
         self.random_state = random_state
 
     def check_prompt(self, prompt):
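For intuition, the guidance that patches 1 and 2 add can be sketched in isolation. The snippet below is not part of the patches; the tensors `cond` and `uncond` are made-up next-token logits, and only the combination mirrors `P_wi_wji + gamma * (P_wi_wjic - P_wi_wji)` from `_CFGuidance`:

    import torch

    # Hypothetical next-token logits from two forward passes:
    # `cond` is conditioned on the prompt, `uncond` is not.
    cond = torch.tensor([2.0, 0.5, -1.0])
    uncond = torch.tensor([1.0, 1.0, 1.0])

    gamma = 1.5
    # gamma=0 yields the unconditioned logits, gamma=1 reproduces the
    # conditioned ones, and gamma > 1 amplifies the difference the
    # prompt makes.
    guided = uncond + gamma * (cond - uncond)
    print(guided)  # tensor([ 2.5000,  0.2500, -2.0000])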
From da062b0a05a8d4af71811ea2c58792e65f974e66 Mon Sep 17 00:00:00 2001
From: Marian Tietz
Date: Wed, 26 Jul 2023 14:41:06 +0200
Subject: [PATCH 3/3] Rename label_id and kwargs

- `label_id` was misleading since it is actually a list of token ids
  related to a label, not a scalar value. Also, the general process of
  generating logits is not related to labels at all but only to tokens.
- `kwargs` was named to follow the transformers `generate` convention,
  but since it is passed on to `generate`, it is, in the context of
  `generate_logits`, a model input.

This should help the reader distinguish between the expected token
input (`token_ids`) and the model input (`model_inputs`).
---
 skorch/llm/classifier.py | 69 +++++++++++++++++++++++-----------------
 1 file changed, 39 insertions(+), 30 deletions(-)

diff --git a/skorch/llm/classifier.py b/skorch/llm/classifier.py
index b2ef4963a..560fde3e0 100644
--- a/skorch/llm/classifier.py
+++ b/skorch/llm/classifier.py
@@ -169,9 +169,9 @@ def _extend_inputs(inputs, extra):
 
 class _LogitsRecorder(LogitsProcessor):
     """Helper class to record logits and force the given label token ids"""
-    def __init__(self, label_ids, tokenizer):
+    def __init__(self, token_ids, tokenizer):
         self.recorded_scores = []
-        self.label_ids = label_ids
+        self.token_ids = token_ids
         self.tokenizer = tokenizer
 
     def __call__(self, input_ids, scores):
@@ -180,7 +180,7 @@ def __call__(self, input_ids, scores):
         # therefore there is no device mismatch and we save a bit of GPU memory
         self.recorded_scores.append(scores[0].clone().cpu())
         mask = torch.ones(scores.size(), dtype=torch.bool)
-        mask[0, self.label_ids[idx]] = False
+        mask[0, self.token_ids[idx]] = False
         scores[mask] = -float('inf')
         return scores
 
@@ -227,30 +227,34 @@ class _CFGuidance(LogitsProcessor):
 
     """
-    def __init__(self, model, tokenizer, label_ids, gamma=1.5):
+    def __init__(self, model, tokenizer, token_ids, gamma=1.5):
         self.model = model
         self.tokenizer = tokenizer
         self.gamma = gamma
-        self.label_ids = label_ids
-        self.recorded_scores = []
+        self.token_ids = token_ids
+        self.token_position = 0
 
     def __call__(self, input_ids, scores):
-        idx = len(self.recorded_scores)
+        idx = self.token_position
         P_wi_wjic = scores
-        model_input = {
-            'input_ids': torch.tensor(self.label_ids)[None,:].to(self.model.device),
-            'attention_mask': torch.tensor([1] * len(self.label_ids))[None, :].to(self.model.device)
+        model_inputs = {
+            'input_ids': torch.tensor(self.token_ids)[None,:].to(self.model.device),
+            'attention_mask': torch.tensor([1] * len(self.token_ids))[None, :].to(self.model.device)
         }
-        model_output = self.model.generate(**model_input, output_scores=True, return_dict_in_generate=True)
-        P_wi_wji = model_output.scores[idx]
+        model_output = self.model.generate(
+            **model_inputs,
+            max_new_tokens=len(self.token_ids),
+            output_scores=True,
+            return_dict_in_generate=True)
+        P_wi_wji = model_output.scores[idx]
 
-        # we pull the logits to CPU because they are not used as input,
-        # therefore there is no device mismatch and we save a bit of GPU memory
-        # TODO remove this by a counter since we're only using the position
-        self.recorded_scores.append(scores[0].clone().cpu())
+        # We assume that this logits processor is called in `generate_logits`,
+        # which invokes the model for each new token in the given sequence
+        # (self.token_ids), so we need to keep track of where we are.
+        self.token_position += 1
 
         scores = P_wi_wji + self.gamma * (P_wi_wjic - P_wi_wji)
 
@@ -310,34 +314,39 @@ def set_cache(self, kwargs, label_id, scores):
             key = str(input_id)
             self.cache[key] = score
 
-    def generate_logits(self, *, label_id, **kwargs):
+    def generate_logits(self, *, token_ids, **model_inputs):
+        """Generate logits for the given token ids, based on the given model
+        input. The model is forced to generate only the logits for the given
+        token ids; all other tokens are masked out so that they cannot be
+        sampled.
+        """
         self._total_calls += 1  # mainly for debugging
 
         recorded_logits = []
-        logits_cached = self.get_cache(kwargs)
+        logits_cached = self.get_cache(model_inputs)
         while logits_cached is not None:
-            if label_id[0] == self.tokenizer.eos_token_id:
+            if token_ids[0] == self.tokenizer.eos_token_id:
                 # don't extend with eos_token -- it is already there at the end,
                 # we don't need it twice
                 break
 
             recorded_logits.append(logits_cached)
-            kwargs = _extend_inputs(kwargs, label_id[:1])
-            label_id = label_id[1:]
-            logits_cached = self.get_cache(kwargs)
+            model_inputs = _extend_inputs(model_inputs, token_ids[:1])
+            token_ids = token_ids[1:]
+            logits_cached = self.get_cache(model_inputs)
 
-        if not label_id:
+        if not token_ids:
            # the whole generation was cached
            return recorded_logits
 
-        if label_id[0] == self.tokenizer.pad_token_id:
+        if token_ids[0] == self.tokenizer.pad_token_id:
             # no need to generate on pad tokens
             return recorded_logits
 
         self._uncached_calls += 1  # mainly for debugging
 
         recorder = _LogitsRecorder(
-            label_ids=label_id,
+            token_ids=token_ids,
             tokenizer=self.tokenizer,
         )
         processors = [recorder]
@@ -345,8 +354,8 @@ def generate_logits(self, *, token_ids, **model_inputs):
         if self.cfg_gamma is not None:
             guidance = _CFGuidance(
                 model=self.model,
-                label_ids=label_id,
                 tokenizer=self.tokenizer,
+                token_ids=token_ids,
                 gamma=self.cfg_gamma,
             )
             processors.insert(0, guidance)
@@ -354,10 +363,10 @@ def generate_logits(self, *, token_ids, **model_inputs):
         self.model.generate(
             logits_processor=processors,
             # TODO: should this be the max len of all labels?
-            max_new_tokens=len(label_id),
-            **kwargs
+            max_new_tokens=len(token_ids),
+            **model_inputs
         )
-        self.set_cache(kwargs, label_id, recorder.recorded_scores)
+        self.set_cache(model_inputs, token_ids, recorder.recorded_scores)
         return recorded_logits + recorder.recorded_scores[:]
 
@@ -489,7 +498,7 @@ def _predict_one(self, text):
 
         probas_all_labels = []
         for label_id in self.label_ids_:
-            logits = self.cached_model_.generate_logits(label_id=label_id, **inputs)
+            logits = self.cached_model_.generate_logits(token_ids=label_id, **inputs)
             logits = torch.vstack(logits)
             probas = torch.nn.functional.softmax(logits, dim=-1)
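From the user's side, enabling the draft guidance would presumably look like the sketch below. It assumes that the public `skorch.llm.ZeroShotClassifier` API stays as it is today and merely gains the new `cfg_gamma` parameter; the model name and texts are placeholders, and nothing here is part of the patches:

    from skorch.llm import ZeroShotClassifier

    # Sketch only: cfg_gamma=None (the default) would keep the current
    # behaviour, any other value enables the draft CFG implementation.
    clf = ZeroShotClassifier(model_name='gpt2', cfg_gamma=1.5)
    clf.fit(X=None, y=['negative', 'positive'])
    probas = clf.predict_proba(['This movie was a complete waste of time'])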