# Transformers Interpret Multiclass Classification Example

In [1]:
!pip install transformers
!pip install transformers-interpret


Collecting transformers
  Downloading transformers-4.34.1-py3-none-any.whl (7.7 MB)
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers)
  Downloading huggingface_hub-0.18.0-py3-none-any.whl (301 kB)
Collecting tokenizers<0.15,>=0.14 (from transformers)
  Downloading tokenizers-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)
Collecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
Col

In [2]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

## Import Industry Classification Model
This fine-tuned model by @sampathkethineedi uses a DistilBERT base to predict the professional industry a text refers to.

In [3]:
tokenizer = AutoTokenizer.from_pretrained("sampathkethineedi/industry-classification")
model = AutoModelForSequenceClassification.from_pretrained("sampathkethineedi/industry-classification")

Let's explore the classes. There are 62 unique classes, and many of them are overlapping or related industries, such as __Health Care Equipment__ and __Health Care Supplies__.

In [4]:
model.config.id2label

{0: 'Advertising',
 1: 'Aerospace & Defense',
 2: 'Apparel Retail',
 3: 'Apparel, Accessories & Luxury Goods',
 4: 'Application Software',
 5: 'Asset Management & Custody Banks',
 6: 'Auto Parts & Equipment',
 7: 'Biotechnology',
 8: 'Building Products',
 9: 'Casinos & Gaming',
 10: 'Commodity Chemicals',
 11: 'Communications Equipment',
 12: 'Construction & Engineering',
 13: 'Construction Machinery & Heavy Trucks',
 14: 'Consumer Finance',
 15: 'Data Processing & Outsourced Services',
 16: 'Diversified Metals & Mining',
 17: 'Diversified Support Services',
 18: 'Electric Utilities',
 19: 'Electrical Components & Equipment',
 20: 'Electronic Equipment & Instruments',
 21: 'Environmental & Facilities Services',
 22: 'Gold',
 23: 'Health Care Equipment',
 24: 'Health Care Facilities',
 25: 'Health Care Services',
 26: 'Health Care Supplies',
 27: 'Health Care Technology',
 28: 'Homebuilding',
 29: 'Hotels, Resorts & Cruise Lines',
 30: 'Human Resource & Employment Services',
 31: 'IT Co
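
One quick way to spot the overlapping industries mentioned above is to group label names by their leading words. The sketch below is only an illustration: `group_by_first_words` is a hypothetical helper, the dictionary is a small subset of the labels printed above rather than the full `model.config.id2label`, and matching on the first two words is an assumed heuristic.

```python
from collections import defaultdict

# Subset of the labels shown above; in the notebook you could pass
# model.config.id2label instead (assumption: it is a plain int -> str dict).
id2label = {
    12: "Construction & Engineering",
    13: "Construction Machinery & Heavy Trucks",
    22: "Gold",
    23: "Health Care Equipment",
    24: "Health Care Facilities",
    25: "Health Care Services",
    26: "Health Care Supplies",
    27: "Health Care Technology",
}


def group_by_first_words(labels, n_words=2):
    """Group label names that share their first n_words words."""
    groups = defaultdict(list)
    for _, name in sorted(labels.items()):
        key = " ".join(name.split()[:n_words])
        groups[key].append(name)
    # Keep only groups with more than one member, i.e. related classes.
    return {key: names for key, names in groups.items() if len(names) > 1}


related = group_by_first_words(id2label)
for prefix, names in related.items():
    print(prefix, "->", names)
```

With this subset, only the "Health Care" labels end up in a shared group; a looser heuristic (for example `n_words=1`) would also pair the two "Construction" labels.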

Import __SequenceClassificationExplainer__ from `transformers_interpret`. This class should work with most, if not all, language models from the `transformers` package that have a sequence classification head.


In [5]:
from transformers_interpret import SequenceClassificationExplainer

In [6]:
sample_text = """
Stocks ended a choppy session mixed as investors digested a host of corporate earnings results and considered policymakers’ next moves to support the still virus-stricken economy.
The S&P 500 shook off earlier declines to narrowly eke out a record closing high. The Dow ended a tick below its recent record closing level."""

In [20]:
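# Alternative sample texts. Each assignment overwrites the previous one,
# so only the last value of sample_text is used by the cells that follow.
sample_text = """Rafale is a better machine when compared with F16"""
sample_text = """This movie deserve 5 stars but food served in theatre was pathetic"""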


In [21]:
multiclass_explainer = SequenceClassificationExplainer(model=model, tokenizer=tokenizer)

In [22]:
# call the explainer
word_attributions = multiclass_explainer(text=sample_text)

In [23]:
# seems to be an appropriate prediction
multiclass_explainer.predicted_class_name

'Movies & Entertainment'

In [24]:
# Label of the class the attributions were computed for (here, the predicted class)
model.config.id2label[multiclass_explainer.selected_index]

'Movies & Entertainment'

In [25]:
multiclass_explainer.pred_probs

tensor(0.9965)
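
The value above reads as a probability for the explained class. Assuming it is the softmax probability the model assigns to that class (an assumption about the explainer's internals, though consistent with the outputs in this notebook), the relationship between raw logits and such a probability can be sketched in plain Python. The logits below are hypothetical and are not taken from the model.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(value - shift) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for three classes (not produced by the real model).
logits = [6.0, 0.5, -3.5]
probs = softmax(logits)

# Index of the most probable class, analogous to the predicted class.
predicted_index = max(range(len(probs)), key=probs.__getitem__)
print(predicted_index, [round(p, 4) for p in probs])
```

A large gap between the top logit and the rest is what produces a probability close to 1 for the predicted class and values close to 0 for the others.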

In [26]:
# look at the raw word attributions
#word_attributions
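
The raw attributions are a list of `(token, score)` pairs over WordPiece tokens, so a word such as "choppy" shows up as `chop` and `##py`. One possible way to read them at the word level is to fold each continuation piece into the token before it and sum the scores. `merge_wordpieces` is a hypothetical helper, not part of `transformers_interpret`, and summing is only one reasonable aggregation choice. The sample values are copied from the stock-market attribution listing at the end of this notebook.

```python
def merge_wordpieces(token_scores):
    """Merge '##' continuation pieces into the preceding token, summing scores."""
    merged = []
    for token, score in token_scores:
        if token.startswith("##") and merged:
            prev_token, prev_score = merged[-1]
            merged[-1] = (prev_token + token[2:], prev_score + score)
        else:
            merged.append((token, score))
    return merged


# A few pairs copied from the word_attributions output in this notebook.
sample = [
    ("[CLS]", 0.0),
    ("stocks", -0.6854859058840344),
    ("ended", -0.02903205216337953),
    ("a", 0.11090246049261213),
    ("chop", 0.006652637417738755),
    ("##py", -0.20624943153114753),
    ("session", -0.07065561556653746),
]

merged = merge_wordpieces(sample)
for token, score in merged:
    print(f"{token:10s} {score:+.4f}")
```

In the notebook the same helper could be applied to `word_attributions` directly, assuming it has the list-of-tuples shape shown in the output.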

## Visualizing Explanations
With a single call to the `visualize()` method we get a nice inline display of which inputs are driving the activations that led to this prediction. **Note: the algorithm used to calculate attributions is Layer Integrated Gradients. To read more about it, click [here](https://captum.ai/docs/algorithms).**
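
To give a feel for the idea behind integrated gradients, here is a toy sketch on a three-input logistic model. It is not the layer variant and not Captum's implementation: the model, weights, zero baseline, and step count are all assumptions chosen for illustration. It approximates the path integral of the gradient from a baseline to the input with a midpoint Riemann sum and checks the completeness property, namely that the attributions sum to the difference between the model output at the input and at the baseline.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def model(x, w):
    """Toy differentiable model: sigmoid of a weighted sum."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)))


def grad(x, w):
    """Analytic gradient of the toy model with respect to each input."""
    p = model(x, w)
    return [p * (1.0 - p) * wi for wi in w]


def integrated_gradients(x, baseline, w, steps=200):
    """Midpoint Riemann-sum approximation of integrated gradients."""
    totals = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        # Point on the straight line from the baseline to the input.
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(grad(point, w)):
            totals[i] += g
    # Scale the averaged gradient by the input-baseline difference.
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, totals)]


w = [2.0, -1.0, 0.5]          # hypothetical weights
x = [1.0, 1.0, 1.0]           # input to explain
baseline = [0.0, 0.0, 0.0]    # reference input carrying "no signal"

attributions = integrated_gradients(x, baseline, w)
completeness_gap = abs(sum(attributions) - (model(x, w) - model(baseline, w)))
print(attributions, completeness_gap)
```

Inputs with positive weights receive positive attributions and the input with a negative weight receives a negative one, which mirrors how positive and negative token scores are read in the visualization.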

In [27]:
html = multiclass_explainer.visualize()

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 40.0 | Movies & Entertainment (1.00) | Movies & Entertainment | 1.94 | [CLS] this movie deserve 5 stars but food served in theatre was pathetic [SEP] |


## Explaining The Same Text For A Different Class
Let's say we think this text could also fall somewhat under the class __Restaurants__, since it mentions food. It is also possible to get an explanation (attributions) for the text with respect to that class by passing `class_name`.

In [29]:
word_attributions = multiclass_explainer(sample_text, class_name="Restaurants")

In [31]:
# look at the raw word attributions
#word_attributions
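
When comparing explanations across classes, it can help to rank tokens by the magnitude of their attribution instead of scanning the full list. The sketch below assumes the list-of-`(token, score)` shape shown in the `word_attributions` output. `top_tokens` is a hypothetical helper, and the sample values are copied from the stock-market listing at the end of this notebook, not from the run above.

```python
def top_tokens(token_scores, k=3, skip=("[CLS]", "[SEP]")):
    """Return the k tokens with the largest absolute attribution."""
    candidates = [(t, s) for t, s in token_scores if t not in skip]
    return sorted(candidates, key=lambda pair: abs(pair[1]), reverse=True)[:k]


# A few pairs copied from the word_attributions output in this notebook.
sample = [
    ("[CLS]", 0.0),
    ("stocks", -0.6854859058840344),
    ("ended", -0.02903205216337953),
    ("##py", -0.20624943153114753),
    ("mixed", -0.13288257535449538),
    ("earnings", 0.19780804001186278),
    ("results", 0.2649820413816817),
]

ranked = top_tokens(sample)
for token, score in ranked:
    direction = "toward" if score > 0 else "away from"
    print(f"{token:10s} {score:+.4f}  pushes {direction} the explained class")
```

Running the same ranking once per class of interest gives a compact side-by-side view of which tokens matter for each class.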

In [32]:
# Label of the class the attributions were computed for
model.config.id2label[multiclass_explainer.selected_index]

'Restaurants'

In [33]:
# the model's predicted class is unchanged
multiclass_explainer.predicted_class_name

'Movies & Entertainment'

In [34]:
multiclass_explainer.pred_probs

tensor(5.8132e-05)

The model assigns __Restaurants__ a probability of only about 5.8e-05 for this text, so the attributions below show which tokens push toward or away from that class rather than describing a likely prediction.

In [35]:
html = multiclass_explainer.visualize()

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 52.0 | Movies & Entertainment (0.00) | Restaurants | -0.8 | [CLS] this movie deserve 5 stars but food served in theatre was pathetic [SEP] |


What if we get attributions for a class that makes no sense in this context, such as __Restaurants__ for the Rafale sample text (the output below was produced with that input)?


In [32]:
word_attributions = multiclass_explainer(sample_text, class_name="Restaurants")

There isn't much to this prediction. It is worth noting, however, that the word "choppy" (from the stock-market sample text) had a more positive impact in this instance, which seems plausible given the industry.

In [33]:
html = multiclass_explainer.visualize()

| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|---|---|---|---|
| 52.0 | Aerospace & Defense (0.00) | Restaurants | -1.2 | [CLS] raf ##ale is a better machine when compared with f1 ##6 [SEP] |


In [None]:
word_attributions

[('[CLS]', 0.0),
 ('stocks', -0.6854859058840344),
 ('ended', -0.02903205216337953),
 ('a', 0.11090246049261213),
 ('chop', 0.006652637417738755),
 ('##py', -0.20624943153114753),
 ('session', -0.07065561556653746),
 ('mixed', -0.13288257535449538),
 ('as', 0.07492890890956264),
 ('investors', -0.12119187012995146),
 ('digest', -0.12445948275050642),
 ('##ed', -0.09281772083050198),
 ('a', 0.04926221669843679),
 ('host', 0.002458038605145266),
 ('of', -0.04180982353010906),
 ('corporate', 0.09596853145329277),
 ('earnings', 0.19780804001186278),
 ('results', 0.2649820413816817),
 ('and', 0.10765607812143425),
 ('considered', 0.08802590573220752),
 ('policy', 0.11527451534122934),
 ('##makers', -0.0979763265607011),
 ('’', -0.04952380439753834),
 ('next', -0.08798391545875107),
 ('moves', 0.016330984853300038),
 ('to', -0.013984838237578378),
 ('support', -0.0746662851801564),
 ('the', 0.018439130760925827),
 ('still', -0.07228677398653408),
 ('virus', 0.07444827768777922),
 ('-', -0.06