<a href="https://colab.research.google.com/github/tomdyer10/wine_expert/blob/master/predictions_with_justification_%26_context.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Add token-specific interpretation to our wine region classifier and give additional rationale for its predictions.

In [0]:
from fastai.text import *
import pandas as pd 
import numpy as np

Mount drive and load databunch + language model databunch

In [0]:
# Load the language-model data bunch and the classifier data bunch from the wine_reviews folder.
data_lm = load_data('drive/My Drive/wine_reviews', file='data/region_clas_data_lm')
data_clas = load_data('drive/My Drive/wine_reviews', file='data/region_clas_data_clas')

Load learner and load state

In [0]:
# Build an AWD_LSTM text classifier on the classifier data, then load the saved
# 'region_classifier/fifth' state into it.
learn = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5)
learn.load('region_classifier/fifth')

# Redefine FastAI TextClassificationInterp Class

Redefining this module here in order to make a few changes to display positive **and** negative influence.

Original fast ai code found here - https://github.com/fastai/fastai/blob/master/fastai/text/interpret.py

In [0]:
import matplotlib.cm as cm

def value2rgba(x:float, cmap:Callable=cm.RdYlGn, alpha_mult:float=1.0)->Tuple:
    """Look `x` (expected in the inclusive range 0 to 1) up in `cmap` and return an `(r, g, b, a)` tuple.

    The colour channels returned by `cmap` are scaled by 255 and cast to ints; the last
    component (alpha) is multiplied by `alpha_mult`.
    """
    *channels, alpha = cmap(x)
    rgb = (np.array(channels) * 255).astype(int).tolist()
    return tuple(rgb + [alpha * alpha_mult])

def piece_attn_html(pieces:List[str], attns:List[float], sep:str=' ', **kwargs)->str:
    """Build an HTML snippet in which each piece is shaded according to its attention value.

    Every piece is HTML-escaped and wrapped in a `<span>` whose background colour comes from
    `value2rgba(attn, alpha_mult=0.5, **kwargs)` and whose tooltip shows the value to three
    decimals. The spans are joined with `sep` inside a single monospace wrapper span.
    """
    def _colored_span(piece, attn):
        safe_piece = html.escape(piece)
        rgba = str(value2rgba(attn, alpha_mult=0.5, **kwargs))
        return f'<span title="{attn:.3f}" style="background-color: rgba{rgba};">{safe_piece}</span>'

    body = sep.join(_colored_span(piece, attn) for piece, attn in zip(pieces, attns))
    return '<span style="font-family: monospace;">' + body + '</span>'

def show_piece_attn(*args, **kwargs):
    "Render the output of `piece_attn_html(*args, **kwargs)` in the notebook."
    from IPython.display import HTML, display
    markup = piece_attn_html(*args, **kwargs)
    display(HTML(markup))

def _eval_dropouts(mod):
    "Set `training = False` on `mod` and every descendant whose class name contains 'Dropout' or 'BatchNorm'."
    pending = [mod]
    while pending:
        current = pending.pop()
        class_name = current.__class__.__name__
        if any(tag in class_name for tag in ('Dropout', 'BatchNorm')):
            current.training = False
        # Descendants are visited whether or not the parent itself matched.
        pending.extend(current.children())

class TextClassificationInterpretation(ClassificationInterpretation):
    """Provides an interpretation of classification based on input sensitivity.
    This was designed for AWD-LSTM only for the moment, because Transformer already has its own attentional model.

    This variant keeps the sign of the embedding gradients (no `abs()`), so a token's score can be
    negative as well as positive.
    """

    def __init__(self, learn: Learner, preds: Tensor, y_true: Tensor, losses: Tensor, ds_type: DatasetType = DatasetType.Valid):
        "Initialise the base interpretation and keep a handle on `learn.model` for gradient computations."
        super().__init__(learn,preds,y_true,losses,ds_type)
        self.model = learn.model

    @classmethod
    def from_learner(cls, learn: Learner,  ds_type:DatasetType=DatasetType.Valid, activ:nn.Module=None):
        "Gets preds, y_true, losses to construct base class from a learner"
        # Forward `ds_type` so `show_top_losses` reads items from the same split the predictions came from.
        return cls(learn, *learn.get_preds(ds_type=ds_type, activ=activ, with_loss=True, ordered=True), ds_type=ds_type)

    def intrinsic_attention(self, text:str, class_id:int=None):
        """Calculate the intrinsic attention of the input w.r.t to an output `class_id`, or the classification given by the model if `None`.
        For reference, see the Sequential Jacobian session at https://www.cs.toronto.edu/~graves/preprint.pdf

        Returns `(tokens, attn)`: the reconstructed input item and one score per token, obtained by
        summing the gradient of the class probability over the embedding dimension and dividing by
        the largest score.
        """
        # Train mode with dropout/batchnorm switched back off - presumably needed so that backward()
        # works through the recurrent layers while the forward pass stays deterministic (TODO confirm).
        self.model.train()
        _eval_dropouts(self.model)
        self.model.zero_grad()
        self.model.reset()
        ids = self.data.one_item(text)[0]
        # Detach the embeddings and make them a leaf tensor so the gradient w.r.t. them can be read back.
        emb = self.model[0].module.encoder(ids).detach().requires_grad_(True)
        lstm_output = self.model[0].module(emb, from_embeddings=True)
        self.model.eval()
        cl = self.model[1](lstm_output + (torch.zeros_like(ids).byte(),))[0].softmax(dim=-1)
        if class_id is None: class_id = cl.argmax()
        cl[0][class_id].backward()
        # No abs() here, so that negative influences are kept as well.
        attn = emb.grad.squeeze().sum(dim=-1)
        # NOTE(review): dividing by the signed maximum flips signs when every score is negative and
        # divides by zero when the maximum is 0 - confirm whether abs().max() is wanted instead.
        attn /= attn.max()
        tokens = self.data.single_ds.reconstruct(ids[0].cpu())
        return tokens, attn

    def html_intrinsic_attention(self, text:str, class_id:int=None, **kwargs)->str:
        "Return the HTML produced by `piece_attn_html` for the tokens and scores of `text` w.r.t. `class_id`."
        text, attn = self.intrinsic_attention(text, class_id)
        return piece_attn_html(text.text.split(), to_np(attn), **kwargs)

    def show_intrinsic_attention(self, text:str, class_id:int=None, **kwargs)->None:
        "Display the tokens of `text` shaded by their scores w.r.t. `class_id`."
        text, attn = self.intrinsic_attention(text, class_id)
        show_piece_attn(text.text.split(), to_np(attn), **kwargs)

    def categorical_top_n(self, k:int, text, class_id:int):
        "Return `(tokens, scores)` for the `k` highest-scoring tokens of `text` w.r.t. `class_id`."
        text, attn = self.intrinsic_attention(text, class_id)
        # Sort once, highest score first, and split the pairs afterwards.
        ranked = sorted(zip(attn, text.text.split()), reverse=True)[:k]
        top_n = [token for _, token in ranked]
        scores = [score for score, _ in ranked]
        return top_n, scores

    def show_top_losses(self, k:int, max_len:int=70)->None:
        """
        Create a tabulation showing the first `k` texts in top_losses along with their prediction, actual,loss, and probability of
        actual class. `max_len` is the maximum number of tokens displayed.
        A non-positive `k` displays an empty table.
        """
        from IPython.display import display, HTML
        names = ['Text', 'Prediction', 'Actual', 'Loss', 'Probability']
        classes = self.data.classes
        dataset = self.data.dl(self.ds_type).dataset
        _, tl_idx = self.top_losses()
        rows = []
        for idx in tl_idx[:max(k, 0)]:
            tx,cl = dataset[idx]
            cl = cl.data
            txt = ' '.join(tx.text.split(' ')[:max_len]) if max_len is not None else tx.text
            rows.append([txt, f'{classes[self.pred_class[idx]]}', f'{classes[cl]}', f'{self.losses[idx]:.2f}',
                         f'{self.preds[idx][cl]:.2f}'])
        # Building the frame from the row list also works when no rows were selected.
        df = pd.DataFrame(rows, columns=names)
        with pd.option_context('display.max_colwidth', pd_max_colwidth()):
            display(HTML(df.to_html(index=False)))





Init interpretation learner

In [0]:
# Build the interpretation object from the trained learner, using from_learner's default dataset type.
txt_ci = TextClassificationInterpretation.from_learner(learn)

Checking data classes for interpretations

In [0]:
# List the class labels; a label's position in this list is the class id passed to the interpretation calls.
data_clas.classes

['France - Alsace',
 'France - Beaujolais',
 'France - Bordeaux',
 'France - Burgundy',
 'France - Champagne',
 'France - France Other',
 'France - Languedoc-Roussillon',
 'France - Loire Valley',
 'France - Provence',
 'France - Rhône Valley',
 'France - Southwest France',
 'Italy - Central Italy',
 'Italy - Italy Other',
 'Italy - Lombardy',
 'Italy - Northeastern Italy',
 'Italy - Piedmont',
 'Italy - Sicily & Sardinia',
 'Italy - Southern Italy',
 'Italy - Tuscany',
 'Italy - Veneto',
 'US - California',
 'US - Idaho',
 'US - Michigan',
 'US - New York',
 'US - Oregon',
 'US - Virginia',
 'US - Washington']

Show intrinsic attention of test text from our dataset, with respect to category 20 - 'US - California'

In [0]:
# Pre-tokenised review text; shade each token by its influence on class id 20 ('US - California').
text = 'his very fine xxmaj cabernet wants a little time in the cellar . xxmaj right now , it s tight in tannins , with some acidic bitterness in the finish . xxmaj the flavors are of black currants and smoky new oak . xxmaj the xxmaj morisoli xxmaj vineyard has been home to very good , ageable bottlings from the likes of xxmaj sequoia xxmaj grove and xxmaj'
txt_ci.show_intrinsic_attention(text, 20)



# Predict with Justification

The goal of this code is to act as a helper function that sits on top of the `predict()` function and delivers additional insight.

Effectively, we want this to say "I think this, because..."

In [0]:
class justify():
  """Helper that sits on top of a learner's `predict` and adds token-level justification.

  `learner` must provide `predict(input)` returning `(class, class index, confidences)`,
  `txt_ci` must provide `intrinsic_attention(input, class_id)` returning `(tokens, scores)`,
  and `classes` is the list of class labels in index order.
  """
  def __init__(self, learner, txt_ci, classes):
    self.learner = learner
    self.txt_ci = txt_ci
    self.classes = classes

  def top_k_tokens(self, text, attns, k=5, desc=True):
    "Return up to `k` `(token, score)` pairs from `text.text`, highest score first unless `desc` is False."
    ranked = sorted(zip(attns, text.text.split()), reverse=desc)
    return [(token, score) for score, token in ranked[:k]]

  def predict(self, input, k=5, desc=True):
    "Return `(predicted class, its confidence, top-k tokens)` for `input`."
    class_pred, pred_idx, pred_conf = self.learner.predict(input)
    tokens, attns = self.txt_ci.intrinsic_attention(input, pred_idx)
    top_k = self.top_k_tokens(tokens, attns, k, desc)
    return (class_pred, pred_conf[pred_idx], top_k)

  # give the k nearest categories to the actual prediction
  def nearest_cat(self, input, k_nearest=1, k=5, desc=True):
    """Explain the `k_nearest` classes ranked directly below the most confident one.

    Each explained class is a `(class name, confidence, top-k tokens)` tuple, with `k` and `desc`
    forwarded to `top_k_tokens`. Several results come back as a list of such tuples; a single
    result keeps the historical shape `([class name], confidence, top-k tokens)`. An empty list
    is returned when `k_nearest` is below 1 or there is no runner-up class.
    """
    if k_nearest < 1: return []
    _, _, pred_conf = self.learner.predict(input)
    # Rank the class names by confidence and skip the first entry (the most confident class).
    ranked = [name for _, name in sorted(zip(pred_conf, self.classes), reverse=True)]
    explained = []
    for name in ranked[1:1 + k_nearest]:
      idx = self.classes.index(name)
      tokens, attns = self.txt_ci.intrinsic_attention(input, idx)
      explained.append((name, pred_conf[idx], self.top_k_tokens(tokens, attns, k=k, desc=desc)))
    if len(explained) != 1: return explained
    # Single result: the class name stays wrapped in a list, as existing callers receive it.
    name, conf, top_k = explained[0]
    return ([name], conf, top_k)
    

# Attach a justify helper to the learner so it can be called as learn.justify.predict(...) / .nearest_cat(...).
learn.justify = justify(learn, txt_ci, data_clas.classes)

Demonstrating normal learner prediction output - not very interpretable!

In [0]:
# Plain learner prediction: the output is a (category, class index tensor, per-class probability tensor) triple.
learn.predict(text)

(Category US - California,
 tensor(20),
 tensor([1.6489e-06, 3.1518e-07, 2.2980e-04, 2.0212e-06, 5.1722e-07, 5.7369e-06,
         1.6678e-07, 3.7181e-06, 8.7425e-07, 6.0157e-07, 1.6345e-05, 1.3376e-06,
         1.2453e-07, 1.5212e-07, 2.0112e-06, 2.6189e-07, 3.0338e-07, 8.4101e-08,
         3.5306e-06, 8.7766e-07, 9.9970e-01, 1.4472e-06, 2.5810e-07, 4.0509e-06,
         7.0285e-07, 6.3190e-06, 2.0093e-05]))

Show intrinsic attention of the demo text with respect to category 20 (US - California), the predicted class.

In [0]:
# Shade each token by its influence on class id 20, which is 'US - California' in data_clas.classes.
txt_ci.show_intrinsic_attention(text, 20)



Run justify prediction from custom code, giving the top 5 tokens contributing to the decision.

In [0]:
# Prediction with justification: predicted class, its confidence, and the default top 5 tokens with scores.
learn.justify.predict(text)



(Category US - California,
 tensor(0.9997),
 [('ageable', tensor(1., device='cuda:0')),
  ('bottlings', tensor(0.8707, device='cuda:0')),
  ('wants', tensor(0.8477, device='cuda:0')),
  ('currants', tensor(0.7486, device='cuda:0')),
  ('fine', tensor(0.6192, device='cuda:0'))])

This immediately makes the prediction more interpretable and allows the user to learn a little about Californian wine, suggesting currant flavours are indicative of this region.

Next, give the closest 3 categories to the one chosen and display the 5 top tokens for each of those classes.

In [0]:
# The 3 classes ranked directly below the prediction, each with its confidence and top 5 tokens.
learn.justify.nearest_cat(text, 3)



[('France - Bordeaux',
  tensor(0.0002),
  [('morisoli', tensor(1., device='cuda:0')),
   ('vineyard', tensor(0.2879, device='cuda:0')),
   ('xxmaj', tensor(0.1686, device='cuda:0')),
   ('likes', tensor(0.1164, device='cuda:0')),
   ('oak', tensor(0.0696, device='cuda:0'))]),
 ('US - Washington',
  tensor(2.0093e-05),
  [('morisoli', tensor(1., device='cuda:0')),
   ('vineyard', tensor(0.1704, device='cuda:0')),
   ('currants', tensor(0.1175, device='cuda:0')),
   ('xxmaj', tensor(0.1147, device='cuda:0')),
   ('acidic', tensor(0.1033, device='cuda:0'))]),
 ('France - Southwest France',
  tensor(1.6345e-05),
  [('morisoli', tensor(1., device='cuda:0')),
   ('vineyard', tensor(0.4456, device='cuda:0')),
   ('grove', tensor(0.4151, device='cuda:0')),
   ('xxmaj', tensor(0.3054, device='cuda:0')),
   ('cabernet', tensor(0.1486, device='cuda:0'))])]

Interestingly, Morisoli is highly activating for these classes. This might flag an error in our model (which is another useful consequence of interpretability) or might also suggest a similarity which is often referenced in wine descriptions.

# Context in RNNs

It is clear that the above interpretation methods add significant value to these prediction models, allowing justification, error flagging and the potential for education.

However, we also see that single tokens/words are not always that instructive and might not capture the full context in which that word is used.

The code below investigates this further, particularly looking at how substituting various words throughout the description influences intrinsic attribution of predictions.

In [0]:
# Display 10 rows of a classifier batch as text/target pairs.
data_clas.show_batch(10)

text,target
"xxbos xxmaj this very fine xxmaj cabernet wants a little time in the cellar . xxmaj right now , it 's tight in tannins , with some acidic bitterness in the finish . xxmaj the flavors are of black currants and smoky new oak . xxmaj the xxmaj morisoli xxmaj vineyard has been home to very good , ageable bottlings from the likes of xxmaj sequoia xxmaj grove and xxmaj",US - California
"xxbos a terrific example of xxmaj corbières terroir at an unbeatable price . xxmaj the wine , a traditional xxmaj rhône - style blend of xxmaj syrah , xxmaj grenache and xxmaj carignan , is so loaded with aromas and flavors of the surrounding garrigue you might think that they xxunk in a bouquet garni during maturation . xxmaj sage , rosemary , menthol , thyme and bay leaves all",France - Languedoc-Roussillon
"xxbos a massive wine , decadent and splendid , and a worthy followup to the near - perfect 2008 . xxmaj densely packed in fruit , it explodes in blackberries , cassis , dark chocolate , minerals and sweet , toasty new oak . xxmaj that richness is lifted to brilliance by the tannic structure , which is world class . xxmaj so deliciously approachable , it 's drinkable now",US - California
"xxbos xxmaj hat xxmaj trick is the best of the best of xxmaj morgan 's estate vineyard , which is in the chilliest northwestern part of the xxmaj highlands . xxmaj acidity stars , giving the wine a brilliant crispness that 's so clean and fine . xxmaj barrel fermented in one - third new xxmaj french oak , the wine is incredibly rich and leesy . xxmaj the terroir",US - California
"xxbos xxmaj this is a new single - vineyard bottling for xxmaj goldeneye , grown in the cooler xxmaj deep xxmaj end western part of the valley . xxmaj it 's a feral kind of xxmaj pinot . xxmaj not for it the tame fruit of warmer climates . xxmaj this one brims with wild berries : cherries , raspberries , something animal and leathery , and mossy tastes of",US - California
"xxbos xxmaj columbia xxmaj crest makes a limited number of single vineyard reserves . xxmaj this xxmaj zinfandel comes from what is emerging as the best site for that grape in xxmaj washington . xxmaj bright , tart , and full of sappy raspberry and red fruits , this young , ageworthy xxmaj zin has great penetration and punch . xxmaj in style , it is closest to classic xxmaj",US - Washington
"xxbos xxmaj you do n't have to age this wine — xxmaj jarvis did it for you , holding it back more than five years before release , which is a very expensive proposition for a winery . xxmaj it 's a splendid xxmaj cabernet , with the tannins soft and velvety . xxmaj the blackberry and black currant fruit flavors are as fresh and vibrant as the day the",US - California
"xxbos xxmaj this is almost pure xxmaj cab , with the addition of just 2 % xxmaj malbec . xxmaj solidly in the firm , muscular , polished style of winemaker xxmaj holly xxmaj turner , this is a substantial effort that will need to be decanted if you are going to drink it any time soon . xxmaj black cherry , cassis and blue plum are swathed in milk",US - Washington
"xxbos a line up of xxunk 's 2010 and 2011 xxmaj pinot xxmaj noirs make a bold statement for xxmaj finger xxmaj lakes xxmaj pinot , but the regular label 2011 is drinking particularly well now . xxmaj lifted violet and berry notes are intoxicating on this spicy , deftly balanced wine . xxmaj rich , ripe black - cherry and raspberry flavors blend into a liquid silk on the",US - New York
"xxbos xxmaj this xxmaj cabernet is a blend of estate and purchased fruit . xxmaj in a successful year , it rivals xxmaj stag 's xxmaj leap 's best wines , including xxmaj cask 23 , and 2008 was a very successful year . xxmaj blended with 2 % xxmaj merlot , and aged in lots of new xxmaj french oak , the wine is , in a word ,",US - California


In [0]:
# First element of the first training-set item - presumably the review text of that example.
test_text = data_clas.train_ds[0][0]

In [0]:
# Plain learner prediction for the training example picked above.
learn.predict(test_text)

(Category Italy - Sicily & Sardinia,
 tensor(16),
 tensor([6.8727e-07, 5.2921e-07, 1.8685e-06, 2.1200e-06, 3.8360e-06, 3.6119e-06,
         4.8331e-06, 2.3117e-06, 6.2219e-07, 1.6509e-06, 1.5259e-06, 5.8269e-02,
         1.3068e-02, 6.9895e-02, 1.3347e-01, 3.2327e-02, 4.2442e-01, 1.6917e-01,
         8.7599e-02, 7.9983e-03, 3.5779e-03, 1.2926e-05, 2.6908e-05, 3.6262e-05,
         9.7286e-06, 6.1159e-05, 4.6936e-05]))

Show intrinsic attribution with respect to the predicted class.

In [0]:
# Class id 16 is 'Italy - Sicily & Sardinia' in data_clas.classes.
txt_ci.show_intrinsic_attention(test_text, 16)



In [0]:
# Predicted class, its confidence, and the default top 5 tokens for the training example.
learn.justify.predict(test_text)



(Category Italy - Sicily & Sardinia,
 tensor(0.4244),
 [('brimstone', tensor(1., device='cuda:0')),
  ('broom', tensor(0.5403, device='cuda:0')),
  (',', tensor(0.3044, device='cuda:0')),
  ('tropical', tensor(0.2672, device='cuda:0')),
  (',', tensor(0.2104, device='cuda:0'))])

selection of commas as top tokens suggests that it is more to do with the sentence they are a part of than the tokens themselves.

**Note from later** - this appears to be symptomatic of effective overfitting in interpretation. I will be investigating using code from this paper https://arxiv.org/pdf/1910.13294.pdf to solve this in future.

**Testing Influence**

swap each top 5 words for xxpad token to test true root of influence

In [0]:
# The review text with several tokens replaced by the xxpad token, to test where the influence really comes from.
text_swapped = 'xxbos xxmaj aromas include xxpad fruit xxpad xxpad xxpad xxpad and dried herb . xxmaj the palate is n\'t overly expressive , offering unripened apple , citrus and dried sage alongside brisk acidity .'

In [0]:
# Justified prediction for the text with xxpad substitutions.
learn.justify.predict(text_swapped)



(Category Italy - Northeastern Italy,
 tensor(0.4318),
 [('unripened', tensor(1., device='cuda:0')),
  ('herb', tensor(0.6547, device='cuda:0')),
  ('expressive', tensor(0.5916, device='cuda:0')),
  ('sage', tensor(0.5082, device='cuda:0')),
  ('brisk', tensor(0.4526, device='cuda:0'))])

This changes the prediction.

In [0]:
# The 3 runner-up classes for the substituted text, each with its confidence and top 5 tokens.
learn.justify.nearest_cat(text_swapped, 3)



[('US - California',
  tensor(0.1457),
  [('palate', tensor(1., device='cuda:0')),
   ('citrus', tensor(0.9526, device='cuda:0')),
   ('acidity', tensor(0.7178, device='cuda:0')),
   ('xxmaj', tensor(0.6662, device='cuda:0')),
   ('include', tensor(0.6642, device='cuda:0'))]),
 ('Italy - Veneto',
  tensor(0.0822),
  [('xxpad', tensor(1., device='cuda:0')),
   ('overly', tensor(0.5603, device='cuda:0')),
   ('xxwrep', tensor(0.3942, device='cuda:0')),
   ('xxpad', tensor(0.3844, device='cuda:0')),
   ('4', tensor(0.3140, device='cuda:0'))]),
 ('Italy - Piedmont',
  tensor(0.0796),
  [('xxpad', tensor(1., device='cuda:0')),
   ('alongside', tensor(0.5470, device='cuda:0')),
   ('brisk', tensor(0.3425, device='cuda:0')),
   ('xxwrep', tensor(0.3101, device='cuda:0')),
   ('citrus', tensor(0.3065, device='cuda:0'))])]

Changing single words does change predictions above but context is clearly important.

# Apples and Oranges

With original swapped text below, prediction is Northeastern Italy

In [0]:
# Baseline for the comparison: the xxpad-substituted text containing "unripened apple".
text_swapped = 'xxbos xxmaj aromas include xxpad fruit xxpad xxpad xxpad xxpad and dried herb . xxmaj the palate is n\'t overly expressive , offering unripened apple , citrus and dried sage alongside brisk acidity .'

learn.justify.predict(text_swapped)



(Category Italy - Northeastern Italy,
 tensor(0.4318),
 [('unripened', tensor(1., device='cuda:0')),
  ('herb', tensor(0.6547, device='cuda:0')),
  ('expressive', tensor(0.5916, device='cuda:0')),
  ('sage', tensor(0.5082, device='cuda:0')),
  ('brisk', tensor(0.4526, device='cuda:0'))])

'unripened' is the most expressive token according to this. However, that doesn't seem to be that descriptive as different unripened fruits will have very different flavours.



In [0]:
# Same text as the baseline except that "apple" is replaced by "oranges".
text_swapped_orange = 'xxbos xxmaj aromas include xxpad fruit xxpad xxpad xxpad xxpad and dried herb . xxmaj the palate is n\'t overly expressive , offering unripened oranges , citrus and dried sage alongside brisk acidity .'
learn.justify.predict(text_swapped_orange)



(Category Italy - Piedmont,
 tensor(0.2259),
 [('xxpad', tensor(1., device='cuda:0')),
  ('alongside', tensor(0.8027, device='cuda:0')),
  ('brisk', tensor(0.5092, device='cuda:0')),
  ('sage', tensor(0.4898, device='cuda:0')),
  ('xxwrep', tensor(0.3120, device='cuda:0'))])

This has changed the prediction, despite Apple not showing up as a key word before.

Conclusion - current RNN interpretation methods will show the contribution of a single word, but it is the context of that word which is most important.