Commit
Fix use of Falcon language model for text generation (#3592)
* fix falcon

* change model parameters to not use additional requirements

* remove additional test dependencies
CloseChoice committed Apr 5, 2024
1 parent 87a6cf3 commit bbbe821
Showing 2 changed files with 37 additions and 0 deletions.
4 changes: 4 additions & 0 deletions shap/models/_teacher_forcing.py
@@ -1,3 +1,5 @@
import inspect

import numpy as np
import scipy.special

@@ -282,6 +284,8 @@ def model_inference(self, inputs, output_ids):
inputs["position_ids"] = (inputs["attention_mask"].long().cumsum(-1) - 1)
inputs["position_ids"].masked_fill_(inputs["attention_mask"] == 0, 0)
# model inference
expected_parameters = list(inspect.signature(self.similarity_model.forward).parameters)
inputs = {k: v for k, v in inputs.items() if k in expected_parameters}
outputs = self.similarity_model(**inputs, return_dict=True)
logits = outputs.logits.detach().cpu().numpy().astype('float64')
elif self.similarity_model_type == "tf":
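Note: the fix works by inspecting the similarity model's forward() signature and discarding any prepared inputs it does not accept, since some models (such as Falcon) take a narrower set of keyword arguments than the generic input preparation produces. A minimal, self-contained sketch of that pattern (filter_to_signature, toy_forward, and the input names below are illustrative, not taken from shap or transformers):

import inspect

def filter_to_signature(fn, inputs):
    # keep only the entries whose keys match a parameter of fn's signature
    expected = set(inspect.signature(fn).parameters)
    return {k: v for k, v in inputs.items() if k in expected}

def toy_forward(input_ids, attention_mask=None, return_dict=True):
    # stand-in for a model forward() with a narrow signature
    return {"logits": [input_ids, attention_mask]}

inputs = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "token_type_ids": [0, 0, 0]}
# calling toy_forward(**inputs) directly would raise TypeError because of
# token_type_ids; filtering first makes the call safe
outputs = toy_forward(**filter_to_signature(toy_forward, inputs))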
33 changes: 33 additions & 0 deletions tests/models/test_teacher_forcing_logits.py
@@ -6,6 +6,39 @@
import shap


def test_falcon():
    """Tests that TeacherForcing works for a (tiny) Falcon causal LM."""
    transformers = pytest.importorskip("transformers")
    requests = pytest.importorskip("requests")
    name = "fxmarty/really-tiny-falcon-testing"
    try:
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
        # load_in_8bit=False and low_cpu_mem_usage=False keep the test free of
        # the optional bitsandbytes and accelerate dependencies
        model = transformers.AutoModelForCausalLM.from_pretrained(
            name, trust_remote_code=True, load_in_8bit=False, low_cpu_mem_usage=False
        )
    except requests.exceptions.RequestException:
        pytest.xfail(reason="Connection error to transformers model")

    model = model.eval()

    s = ["I enjoy walking with my cute dog"]
    gen_dict = dict(
        max_new_tokens=100,
        num_beams=5,
        renormalize_logits=True,
        no_repeat_ngram_size=8,
    )

    # generation settings are supplied via the model config's
    # task_specific_params["text-generation"] entry
    model.config.task_specific_params = dict()
    model.config.task_specific_params["text-generation"] = gen_dict
    shap_model = shap.models.TeacherForcing(model, tokenizer)

    explainer = shap.Explainer(shap_model, tokenizer)
    shap_values = explainer(s)
    # the summed attributions are NaN if any single value is NaN
    assert not np.isnan(np.sum(shap_values.values))


def test_method_get_teacher_forced_logits_for_encoder_decoder_model():
"""Tests if get_teacher_forced_logits() works for encoder-decoder models."""
transformers = pytest.importorskip("transformers")
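Note: the test injects its decoding settings through model.config.task_specific_params rather than passing them to TeacherForcing directly, which suggests shap's text-generation wrappers look generation parameters up there. A hedged sketch of that lookup (get_generation_kwargs and ToyConfig are invented names for illustration, not shap API):

def get_generation_kwargs(config, task="text-generation"):
    # task_specific_params may be absent or None on a HF config object
    params = getattr(config, "task_specific_params", None) or {}
    return dict(params.get(task, {}))

class ToyConfig:
    task_specific_params = {"text-generation": {"max_new_tokens": 100, "num_beams": 5}}

assert get_generation_kwargs(ToyConfig()) == {"max_new_tokens": 100, "num_beams": 5}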
