In [None]:
from google.cloud import bigquery
from google.oauth2 import service_account


class Dataset:
    """Minimal container pairing a sequence of texts with a sequence of labels.

    Indexing with ``key`` (an int or a slice) returns a dict with a ``"label"``
    entry and a ``"text"`` entry, each obtained by indexing the stored sequence
    with the same ``key``. The results follow the stored containers' own
    indexing semantics: a single element for an int, a sub-list for a slice
    when lists are stored.
    """

    def __init__(self, texts, labels):
        # Public attributes are singular (`text`, `label`) even though they
        # hold whole sequences.
        self.label = labels
        self.text = texts

    def __getitem__(self, key):
        label = self.label[key]
        text = self.text[key]
        return {"label": label, "text": text}


def get_dataset(key_path=None):
    """Load every post from the BigQuery table ``reddit_texts.posts_clean``.

    Args:
        key_path: path to a service-account JSON key file. When omitted, the
            ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable is used,
            falling back to the historical local key file so existing runs
            keep working.

    Returns:
        A ``Dataset`` whose ``text`` and ``label`` lists hold each row's
        ``text`` and ``subreddit`` columns, aligned by position.
    """
    import os

    if key_path is None:
        # TODO: drop this machine-specific fallback once the env var is set
        # everywhere the notebook runs.
        key_path = os.environ.get(
            "GOOGLE_APPLICATION_CREDENTIALS",
            "/Users/yco/dev/myreddit/dbt-user-creds.json",
        )
    credentials = service_account.Credentials.from_service_account_file(key_path)

    client = bigquery.Client(
        credentials=credentials,
        project=credentials.project_id,
    )

    # Fetch only the two columns that are used, and walk the result exactly
    # once. Iterating a QueryJob calls `result()` again each time, which
    # re-downloads the whole table. Without ORDER BY, two separate reads are
    # also not guaranteed to return rows in the same order, which could
    # misalign texts and labels.
    rows = client.query(
        "SELECT text, subreddit FROM `reddit_texts.posts_clean`"
    ).result()
    texts, labels = [], []
    for row in rows:
        texts.append(row["text"])
        labels.append(row["subreddit"])
    return Dataset(texts, labels)

In [None]:
# Load the posts from BigQuery into memory as a Dataset (texts + subreddit
# labels). Later cells slice it, e.g. dataset[:3]["text"].
dataset = get_dataset()

In [None]:
import spacy

# Trained spaCy text-classification pipeline (textcat ensemble) loaded from disk.
textcat_spacy = spacy.load("../models/subreddit_classif/textcat_ens/2022/02/16/model-best")

# Tokenizer that shares the model's vocab. It is constructed with no
# special-case rules and no prefix/suffix/infix patterns, so it presumably
# splits on whitespace only, unlike `textcat_spacy.tokenizer`. `tok_adapter`
# uses it to decide which spans SHAP masks.
tokenizer_spacy = spacy.tokenizer.Tokenizer(textcat_spacy.vocab)

# Discover the label names by scoring a throwaway string once. The order of
# `doc.cats` fixes the column order of `predict`'s output.
doc = textcat_spacy("hi")
classes = [label for label in doc.cats]

# Define a function to predict
def predict(texts):
    # convert texts to bare strings
    texts = [str(text) for text in texts]
    results = []
    for doc in textcat_spacy.pipe(texts):
        # results.append([{'label': cat, 'score': doc.cats[cat]} for cat in doc.cats])
        results.append([doc.cats[cat] for cat in classes])
    return results


def tok_adapter(text, return_offsets_mapping=False):
    """Tokenize ``text`` into the dict layout of a Hugging Face tokenizer.

    SHAP's Text masker expects that layout, so this adapter lets it use the
    module-level spaCy ``tokenizer_spacy``.

    Args:
        text: string to tokenize.
        return_offsets_mapping: when true, also return character spans.

    Returns:
        ``{"input_ids": [...]}``, where each id is the token's ``norm`` value.
        When requested, the dict also has ``"offset_mapping"``: a list of
        ``(start, end)`` character offsets, one per token.
    """
    tokens = list(tokenizer_spacy(text))
    encoded = {"input_ids": [token.norm for token in tokens]}
    if not return_offsets_mapping:
        return encoded
    encoded["offset_mapping"] = [
        (token.idx, token.idx + len(token)) for token in tokens
    ]
    return encoded

In [None]:
import shap

# SHAP explainer wrapped around the spaCy classifier:
# - `predict` plays the model: a batch of texts in, one row of class scores out.
# - the Text masker hides tokens using the spans that `tok_adapter` reports.
# - "permutation" is the algorithm SHAP uses for transformers text models.
# - `output_names` lists the classes. It is currently not propagated to the
#   permutation explainer, which is why the plots lack label names.
# - `max_evals` is high to make it less likely that the explainer fails on
#   inputs with many tokens.
explainer = shap.Explainer(
    predict,
    algorithm="permutation",
    masker=shap.maskers.Text(tokenizer=tok_adapter),
    max_evals=1500,
    output_names=classes,
)

In [None]:
# Explain the first three texts and render SHAP's text plot of per-token attributions.
shap.plots.text(explainer(dataset[:3]["text"]))

In [None]:
import rubrix as rb
from tqdm.auto import tqdm

# How many samples to explain and log. Each explanation can take up to
# `max_evals` model calls, so keep this small for interactive runs.
N_EXPLAIN_SAMPLES = 10
RUBRIX_DATASET = "textcat_ens_explainations"


def build_explained_record(sample, explanation):
    """Build a Rubrix record for ``sample`` with SHAP token attributions.

    Args:
        sample: the raw text that was explained.
        explanation: the SHAP explanation for that single text
            (``explainer([sample])[0]``). ``explanation.data`` holds the tokens
            and ``explanation.values`` holds one row of per-class values per
            token.

    Returns:
        An ``rb.TextClassificationRecord`` with the model's score for every
        class in ``classes`` and, for each token, its attribution towards the
        top-scoring class.
    """
    doc = textcat_spacy(sample)
    scores = [doc.cats[cat] for cat in classes]
    # Index of the first highest score (ties resolve to the earliest class).
    predicted_idx = max(range(len(classes)), key=scores.__getitem__)
    # Rubrix attributions are keyed by label name. The integer index would not
    # match any label in `prediction`, so look up the class name.
    predicted_label = classes[predicted_idx]
    token_attributions = [
        rb.TokenAttributions(
            token=token,
            attributions={predicted_label: float(values[predicted_idx])},
        )
        for token, values in zip(explanation.data, explanation.values)
    ]
    return rb.TextClassificationRecord(
        inputs=sample,
        prediction=list(zip(classes, scores)),
        prediction_agent="textcat_ens",
        explanation={"text": token_attributions},
        multi_label=False,
    )


records = []
skipped = []
for sample in tqdm(dataset[:N_EXPLAIN_SAMPLES]["text"]):
    try:
        shap_values = explainer([sample])
    except Exception as exc:
        # Skipping failures (e.g. inputs with too many tokens for max_evals) is
        # deliberate, but record why. Unlike a bare `except:`, this still lets
        # KeyboardInterrupt stop the loop.
        skipped.append((sample, exc))
        continue
    records.append(build_explained_record(sample, shap_values[0]))

if skipped:
    print(
        f"Skipped {len(skipped)} sample(s) the explainer failed on; "
        f"first error: {skipped[0][1]!r}"
    )

rb.delete(RUBRIX_DATASET)
rb.log(records, name=RUBRIX_DATASET)