In [None]:
# adapted from bertviz tutorial (https://github.com/jessevig/bertviz)
class ShapeChecker():
  """Records named tensor dimensions and raises if a later tensor disagrees.

  The first time a dimension name is seen, its size is stored in
  ``self.shapes``; every later call must use the same size for that name.
  All checks are skipped while autograd is disabled (``torch.is_grad_enabled()``
  is False, e.g. inside ``torch.no_grad()``).
  """

  def __init__(self):
    # Maps dimension name -> size, filled in on first sighting of each name.
    self.shapes = {}

  def __call__(self, tensor, names, broadcast=False):
    """Check ``tensor``'s shape against the previously recorded dimensions.

    Args:
      tensor: the tensor to inspect.
      names: axis pattern understood by ``einops.parse_shape``,
        e.g. ``'batch t units'``.
      broadcast: if True, a size-1 dimension is accepted for any name and is
        neither checked nor recorded.

    Raises:
      ValueError: if a named dimension differs from its recorded size.
    """
    if not torch.is_grad_enabled():
      return

    parsed = einops.parse_shape(tensor, names)

    for name, new_dim in parsed.items():
      old_dim = self.shapes.get(name, None)

      if broadcast and new_dim == 1:
        continue

      if old_dim is None:
        self.shapes[name] = new_dim
        continue

      if new_dim != old_dim:
        # The final fragment previously ended in `\"`, which escaped the
        # closing quote and made this whole cell a SyntaxError.
        raise ValueError(f"Shape mismatch for dimension: '{name}'\n"
                         f"    found: {new_dim}\n"
                         f"    expected: {old_dim}\n")

In [None]:
# Smoke-test the cross-attention layer on one example batch and print the
# shapes of its inputs and outputs. `attention_layer` is reused by later cells.
attention_layer = CrossAttention(UNITS)

# Fresh, randomly initialised embedding for the target tokens; index 0 is the
# padding index (padding_idx=0).
# NOTE(review): `target_text_processor` is defined in another cell; confirm
# that `vocabulary_size()` returns the target vocabulary size.
embed = nn.Embedding(target_text_processor.vocabulary_size(),
                     UNITS, padding_idx=0)
# NOTE(review): torch.tensor() copies its input and warns if ex_tar_in is
# already a torch.Tensor; presumably it is a list / NumPy array — confirm.
ex_tar_embed = embed(torch.tensor(ex_tar_in))

# The layer returns a pair; only the first element is kept here. The layer
# also stores `last_attention_weights`, whose shape is printed below.
result, _ = attention_layer(ex_tar_embed, ex_context)

print(f'Input Sequence, shape (batch, s, units): {ex_context.shape}')
print(f'Target Sequence, shape (batch, t, units): {ex_tar_embed.shape}')
print(f'Attention Result, shape (batch, t, units): {result.shape}')
print(f'Attention Weights, shape (batch, t, s):    {attention_layer.last_attention_weights.shape}')

In [None]:
# Plot the attention weights of the first target step next to the context
# padding mask. Imported here so this cell does not depend on a later cell's
# matplotlib import when run top-to-bottom.
import matplotlib.pyplot as plt

# NOTE(review): assumes last_attention_weights is a torch tensor shaped
# (batch, t, s), as printed by the previous cell — confirm.
attention_weights = attention_layer.last_attention_weights
# The weights come from a forward pass run with autograd enabled, so they may
# require grad (and may live on an accelerator); matplotlib converts inputs via
# NumPy, which fails in both cases. Detach and copy to host first.
attention_weights_np = attention_weights.detach().cpu().numpy()
mask = (ex_context_tok != 0).numpy()

fig, (ax_weights, ax_mask) = plt.subplots(1, 2)

ax_weights.pcolormesh(mask * attention_weights_np[:, 0, :])
ax_weights.set_title('Attention Weights')

ax_mask.pcolormesh(mask)
ax_mask.set_title('Mask');


In [None]:
# adapted from bertviz tutorial (https://github.com/jessevig/bertviz)
import torch

def att_map(self,
            texts, *,
            max_length=50,
            temperature=0.0):
    """Decode ``texts`` step by step and record the decoder's attention.

    Takes ``self`` so it can be attached to a translator-like object that
    provides ``self.encoder.convert_input`` and a ``self.decoder`` with
    ``get_initial_state``, ``get_next_token``, ``last_attention_weights`` and
    ``tokens_to_text``.

    Args:
        texts: batch of input texts, passed unchanged to
            ``self.encoder.convert_input`` (a plain list works; no ``.shape``
            is required).
        max_length: maximum number of decoding steps; must be >= 1.
        temperature: forwarded unchanged to ``self.decoder.get_next_token``.

    Returns:
        The value of ``self.decoder.tokens_to_text`` applied to the generated
        tokens, which are concatenated along the last dimension.

    Side effects:
        Sets ``self.last_attention_weights`` to the per-step attention weights
        concatenated along dim 1.

    Raises:
        ValueError: if ``max_length`` is less than 1. Otherwise, no step would
            run and ``torch.cat`` would fail on an empty list.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be >= 1, got {max_length}")

    # Inference only: without autograd no graph is built across the decoding
    # steps, and the stored weights are not grad-requiring, so they can be
    # converted to NumPy for plotting.
    with torch.no_grad():
        context = self.encoder.convert_input(texts)

        tokens = []
        attention_weights = []
        next_token, done, state = self.decoder.get_initial_state(context)

        for _ in range(max_length):
            next_token, done, state = self.decoder.get_next_token(
                context, next_token, done, state, temperature)

            tokens.append(next_token)
            attention_weights.append(self.decoder.last_attention_weights)

            # Stop early once `done` is set everywhere in the batch.
            if done.all():
                break

    tokens = torch.cat(tokens, dim=-1)
    self.last_attention_weights = torch.cat(attention_weights, dim=1)

    return self.decoder.tokens_to_text(tokens)





In [None]:
# adapted from https://github.com/jessevig/bertviz 

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

def _as_str(value):
  """Best-effort conversion of one decoded text value to a Python ``str``.

  Accepts a ``str``, ``bytes``, or any object with a ``.numpy()`` method
  (e.g. a scalar string tensor) whose NumPy value is ``bytes`` or ``str``.
  """
  if hasattr(value, 'numpy'):
    value = value.numpy()
  if isinstance(value, bytes):
    value = value.decode()
  return str(value)


def plot_attention(self, text, **kwargs):
  """Translate ``text`` and plot the resulting attention matrix.

  Args:
    self: translator-like object providing ``translate(list_of_texts, **kwargs)``
      that also sets ``self.last_attention_weights``.
    text: a single input sentence; must be a ``str``.
    **kwargs: forwarded unchanged to ``self.translate``.
  """
  assert isinstance(text, str)
  output = self.translate([text], **kwargs)
  # The previous `.detach().cpu().numpy().decode()` chain could not work: a
  # str has no `.detach()`, and torch tensors cannot hold text.
  output = _as_str(output[0])

  attention = self.last_attention_weights[0]
  # matplotlib converts via NumPy, which fails on grad-requiring or
  # accelerator tensors, so detach and copy to host first.
  if isinstance(attention, torch.Tensor):
    attention = attention.detach().cpu().numpy()

  # NOTE(review): `lower_and_split_punct` is assumed to be the project's
  # text-standardisation helper defined in another cell. torch has no such
  # function, so the previous `torch.lower_and_split_punct` raised
  # AttributeError.
  context = _as_str(lower_and_split_punct(text)).split()
  # The first output token is dropped, as before — presumably a start marker
  # added by the helper; TODO confirm.
  output = _as_str(lower_and_split_punct(output)).split()[1:]

  fig = plt.figure(figsize=(10, 10))
  ax = fig.add_subplot(1, 1, 1)

  ax.matshow(attention, vmin=0.0)

  fontdict = {'fontsize': 14}

  # One explicit tick per token so label i sits on matrix column/row i.
  # This replaces the '' padding plus MultipleLocator trick, which relied on
  # an off-view tick at -1.
  ax.set_xticks(range(len(context)))
  ax.set_xticklabels(context, fontdict=fontdict, rotation=90)
  ax.set_yticks(range(len(output)))
  ax.set_yticklabels(output, fontdict=fontdict)

  ax.set_xlabel('Input text')
  ax.set_ylabel('Output text')