Unknown error when running TCAV on a multiclass dataset #793

Closed · terraformmachine opened this issue Jul 21, 2022 · 3 comments

@terraformmachine

I'm getting an unknown error when running TCAV on a multiclass dataset.

[screenshot of the error]

Steps to reproduce:

  1. Open Colab Notebook: https://colab.research.google.com/drive/1hFwjvr_JLJuR5KmJgmkOSp7VUI1iAlPN?usp=sharing
  2. Run the notebook to render the widget
  3. Search text for "she|her" and Create a Slice called "female"
  4. Go to the TCAV tab and select "female" slice
  5. Run TCAV

Unknown Error 😱

Dataset class:

import pandas as pd

from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types as lit_types


class EmotionData(lit_dataset.Dataset):
  """Emotion dataset (https://huggingface.co/datasets/emotion).

  text:  a string feature.
  label: a classification label, with possible values: sadness (0), joy (1),
    love (2), anger (3), fear (4), surprise (5).
  """

  LABELS = ['0', '1', '2', '3', '4', '5']

  def __init__(self, path):
    df = pd.read_csv(path)
    self._examples = [{
      'text': row['text'],
      'label': row['label']
    } for _, row in df.iterrows()]

  def spec(self):
    return {
      'text': lit_types.TextSegment(),
      'label': lit_types.CategoryLabel(vocab=self.LABELS),
    }

Model class:

import re

import torch
import transformers

from lit_nlp.api import model
from lit_nlp.api import types as lit_types
from lit_nlp.lib import utils


class EmotionModel(model.Model):

    LABELS = ["0", "1", "2", "3", "4", "5"]

    def __init__(self, model_path=None, **kw):
        self._model = transformers.AutoModelForSequenceClassification.from_pretrained(
            model_path, output_hidden_states=True, output_attentions=True
        )
        self._tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

    def predict_minibatch(self, inputs):
        texts = [x["text"] for x in inputs]
        tokenized_input = self._tokenizer.batch_encode_plus(
            texts,
            return_tensors="pt",
            add_special_tokens=True,
            max_length=128,
            padding="longest",
            truncation=True,
        )

        if torch.cuda.is_available():
            self._model.cuda()
            for tensor in tokenized_input:
                tokenized_input[tensor] = tokenized_input[tensor].cuda()

        outputs = self._model(**tokenized_input)

        batched_outputs = {
            "probas": torch.nn.functional.softmax(outputs.logits, dim=-1),
            "input_ids": tokenized_input["input_ids"],
            "ntok": torch.sum(tokenized_input["attention_mask"], dim=1),
            "cls_emb": outputs.hidden_states[-1][:, 0],
        }

        # Use the maximum predicted probability as the scalar to take
        # gradients of.
        scalar_pred_for_gradients = torch.max(
            batched_outputs["probas"], dim=1, keepdim=False, out=None
        )[0]

        arg_max = torch.argmax(batched_outputs["probas"], dim=-1).cpu().numpy()
        grad_classes = [
            ex.get("grad_class", arg_max[i]) for (i, ex) in enumerate(inputs)
        ]
        grad_classes = [
            self.LABELS.index(label) if isinstance(label, str) else label
            for label in grad_classes
        ]
        batched_outputs["grad_class"] = torch.tensor(grad_classes)

        batched_outputs["input_emb_grad"] = torch.autograd.grad(
            scalar_pred_for_gradients,
            outputs.hidden_states[0],
            grad_outputs=torch.ones_like(scalar_pred_for_gradients),
        )[0]

        for i, layer_attention in enumerate(outputs.attentions):
            batched_outputs[f"layer_{i}/attention"] = layer_attention

        detached_outputs = {
            k: v.cpu().detach().numpy() for k, v in batched_outputs.items()
        }

        for output in utils.unbatch_preds(detached_outputs):
            ntok = output.pop("ntok")
            output["tokens"] = self._tokenizer.convert_ids_to_tokens(
                output.pop("input_ids")[:ntok]
            )
            output["token_grad_sentence"] = output["input_emb_grad"][:ntok]

            output["cls_grad"] = output["input_emb_grad"][0]

            for key in output:
                if not re.match(r"layer_(\d+)/attention", key):
                    continue
                output[key] = output[key][:, :ntok, :ntok].transpose((0, 2, 1))
                output[key] = output[key].copy()

            yield output

    def input_spec(self):
        return {
            "text": lit_types.TextSegment(),
            "label": lit_types.CategoryLabel(vocab=self.LABELS, required=False),
            # "input_embs": lit_types.TokenEmbeddings(align="tokens", required=False),
            # "grad_class": lit_types.CategoryLabel(vocab=self.LABELS, required=False),
        }

    def output_spec(self):
        ret = {
            "tokens": lit_types.Tokens(),
            "probas": lit_types.MulticlassPreds(vocab=self.LABELS, parent="label"),
            "cls_emb": lit_types.Embeddings(),
            "token_grad_sentence": lit_types.TokenGradients(align="tokens"),
            "cls_grad": lit_types.Gradients(
                grad_for="cls_emb", grad_target_field_key="grad_class"
            ),
        }
        for i in range(self._model.config.num_hidden_layers):
            ret[f"layer_{i}/attention"] = lit_types.AttentionHeads(
                align_in="tokens", align_out="tokens"
            )
        return ret
@jameswex (Collaborator)

@terraformmachine thanks for the colab with the reproduction of the issue!

One issue I noticed: when you load your dataset, the 'label' field for each example is an integer class ID, but for use in LIT it should be a string from your vocab. In the future, we're looking to add dataset and model field validation to catch and flag these kinds of issues on first launch, instead of letting them surface as unexpected errors while using LIT.

If you change the setting of 'label' in the dataset to str(row['label']), then LIT will correctly see the ground-truth labels, and the metrics and classification results modules will display correctly, instead of showing the issues they had before.
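Concretely, that one-line fix in your dataset's __init__ would look like this (a sketch based on your class above):

  def __init__(self, path):
    df = pd.read_csv(path)
    self._examples = [{
      'text': row['text'],
      # Cast the integer class ID to a string so it matches the
      # CategoryLabel vocab ('0'..'5') declared in spec().
      'label': str(row['label'])
    } for _, row in df.iterrows()]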

But I don't think that actually fixes your TCAV issue. To debug it, I ran the TCAV interpreter directly in the colab, as opposed to through the UI, so I could see on which line of code TCAV failed. Here is the code I used for that:

from lit_nlp.components import tcav
from lit_nlp.lib import caching
from lit_nlp.api.dataset import IndexedDataset

# Under the covers, LIT wraps the dataset in IndexedDataset and the model in
# CachingModelWrapper, so do the same here in order to use them with the
# TCAV interpreter.
indexed_datasets = IndexedDataset.index_all(datasets, caching.input_hash)
cached_model = caching.CachingModelWrapper(models["distilbert-base-uncased-emotion"], "distilbert-base-uncased-emotion")

# Run TCAV
ids = ["ab7570637ea93bafe16a910cbb09b798", "f09151abc3763045e779602009fc1428", "a7796da2aca7a74702762877b9506e0b", "b76b3adaa5b65f2cf0072cdd15b38103",
       "e1c57d1a02d7f2e40225758bce7cfa9f", "be9b032d394de752b85481499168f8aa", "ded6753ba0e8021ba490a36b082e9971",
       "125065a3ebb6fa7cba26fbe048c53532", "a133362931f8837b9164da8cd5dde98b", "7e19494ccea7be91e1aa818192bbf763", "b5152500f9b9fe595e7c183605c21132"]

config = {
    'class_to_explain': "0",
    'concept_set_ids': ids,
    'dataset_name': "emotion",
    'grad_layer': "cls_grad",
}
t = tcav.TCAV()
t.run_with_metadata(indexed_datasets["emotion"].indexed_examples, cached_model, indexed_datasets["emotion"], config=config)

@jameswex (Collaborator)

I was able to get TCAV working for your model once I also updated the grad_class output from your predict method to return the string of the class instead of the integer index, and uncommented the optional grad_class input in the model's input spec.
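Roughly, the two edits look like this (a sketch against your class above; exact placement may differ):

# In predict_minibatch's per-example loop, just before `yield output`,
# convert the integer class index back to its vocab string:
output["grad_class"] = self.LABELS[int(output["grad_class"])]

# In input_spec, re-enable the optional grad_class field:
"grad_class": lit_types.CategoryLabel(vocab=self.LABELS, required=False),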

Please let me know if this works for you.

@terraformmachine (Author)

@jameswex that fixed it! thank you very much for your help 😃
