In [1]:
%cd ..

/home/pablo/long-transformers


In [2]:
import sys

sys.path.append('.')

In [3]:
import torch
import pytorch_lightning as pl

from src.trainers.listops_trainer import ListopsModule

  warn(
  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


## Investigating success of convolution based models in Listops

The idea of this notebook is to investigate the characteristics of the examples that our network is learning. For example, we know that around 17% of accuracy can be reached just by always predicting 0 or 9, since most of the examples with MIN and MAX as root will take those values respectively. Based on the root value, we can make the following predictions to get 36% accuracy:
* MIN -> 0
* MAX -> 9
* MED -> 5
* SM -> random

Are there any tricks of this kind that the network can learn to reach around 61% accuracy as convolution based architectures learn?

We managed to get to 50% accuracy with GMLP. Let's try to retrieve the model from the checkpoint and see which examples it is predicting correctly. The things I want to check are:

* Confusion matrix
* Distribution of correctly predicted labels vs incorrectly predicted labels.
* Distribution of correctly predicted root ops vs incorrectly predicted root ops.

In [4]:
ckpt_path = "out/listops/gmlp/lightning_logs/version_98/checkpoints/epoch=19-val_acc=0.49.ckpt"
hparams_path = "out/listops/gmlp/lightning_logs/version_98/hparams.yaml"

In [5]:
module = ListopsModule.load_from_checkpoint(ckpt_path, hparams_path=hparams_path)

In [6]:
model = module.model_with_head

In [7]:
data_module = module.data_module

In [8]:
data_module.val_dataloader()

<torch.utils.data.dataloader.DataLoader at 0x7fcc1a3b40d0>

In [9]:
# group examples in validation set by whether the model got them right or wrong
batch_size = 32
correct_examples = []
wrong_examples = []
for idx, batch in enumerate(data_module.val_dataloader()):
    # send all batch tensors to cuda:2
    cuda_batch = dict()
    for key, tensor in batch.items():
        cuda_batch[key] = batch[key].to("cuda:2")

    loss, logits, labels = module._step(cuda_batch, idx)
    correct = torch.argmax(logits, dim=1) == labels

    for i, is_correct in enumerate(correct):
        pred = torch.argmax(logits[i]).item()
        label = labels[i].item()
        if is_correct:
            correct_examples.append((i + idx * batch_size, pred, label))
        else:
            wrong_examples.append((i + idx * batch_size, pred, label))

In [10]:
correct_examples[0]

(0, 9, 9)

In [11]:
wrong_examples[0]

(1, 1, 0)

In [12]:
len(correct_examples), len(wrong_examples)

(978, 1022)

In [13]:
# print the confusion matrix
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

all_examples = correct_examples + wrong_examples
y_true = [example[2] for example in all_examples]
y_pred = [example[1] for example in all_examples]
cm = confusion_matrix(y_true, y_pred)
cm

array([[259,  22,  40,  15,   4,   2,   0,   1,   0,   0],
       [ 33,  60,  43,  22,   2,   0,   0,   0,   0,   2],
       [ 28,   1,  66,  42,   9,   1,   1,   0,   1,   1],
       [ 17,   0,  37,  75,  31,   3,   7,   0,   0,   2],
       [ 16,   2,  32,  33,  76,  17,   9,   2,   2,   5],
       [ 19,   1,  21,  20,  28,  37,  19,   1,   2,   8],
       [ 25,   2,  31,   7,  19,  23,  54,   7,   2,   5],
       [ 15,   1,  21,  12,   8,  10,  23,  42,   2,  24],
       [ 16,   1,  25,   6,   2,   4,   5,   8,  78,  26],
       [  7,   0,  28,  11,   1,   3,   8,   4,  26, 231]])

In [14]:
for idx, pred, label in correct_examples:
    sequence = data_module.val_dataset[idx]["sequence"]
    first_token = sequence.split(" ", 1)[0]
    sequence_len = len(sequence.split(" "))
    print(first_token, label, sequence_len, sequence)

[MAX 9 1744 [MAX 1 [MIN 3 [MED 6 1 ] 4 5 ] 0 [SM [MED 1 6 2 1 0 8 2 3 ] 9 [MED 0 3 7 ] 2 [SM [MIN 8 9 1 ] 3 0 2 5 [MED 7 [SM 4 [MAX 4 1 [MAX 9 4 3 ] 1 9 2 ] 3 ] ] [MED 1 [MIN 0 [MIN 7 3 9 2 6 3 2 9 [SM 9 2 [MIN 4 3 5 2 6 2 [MIN 7 2 9 3 6 6 ] [MED 3 5 ] 2 ] [MED 6 6 5 9 2 [SM 0 2 9 8 0 7 ] 6 ] 1 9 3 2 ] 9 ] 7 ] 0 [MED 1 4 2 0 1 [MAX [MED 8 2 [MIN [MIN 1 5 6 9 9 2 2 5 4 0 ] 1 [MED 2 6 4 ] 7 2 [SM 3 4 ] ] 7 ] [MIN 1 8 3 6 ] 4 6 6 ] 8 1 ] 6 [SM 1 7 5 5 5 6 9 8 4 ] 1 5 ] 8 1 1 ] ] [MAX 0 [MAX 7 [MED 9 [MIN 2 [SM 9 [MIN [MAX 2 6 8 [MED 6 7 6 7 7 6 3 ] 2 ] 6 ] ] [MIN 0 4 2 1 ] ] 4 1 ] ] [MIN 8 2 5 6 7 ] 9 ] [MED 2 8 5 0 [SM [MIN 3 8 6 [SM 7 7 [MAX 7 [SM 3 5 0 ] 3 9 ] 6 [MAX [MED [MED 3 [MED 1 3 1 4 9 2 0 5 3 0 ] 1 5 [SM 3 7 5 2 5 8 3 ] [MAX 2 2 9 5 7 8 3 3 3 0 ] 7 9 8 5 ] 8 2 ] 8 8 1 9 ] 9 4 5 ] [MAX 0 6 6 7 8 3 5 4 ] 5 9 7 ] 9 [MAX 3 [MAX [SM 7 0 [MIN [MAX 1 [SM 2 5 1 4 6 1 ] 6 5 5 ] [SM 3 3 6 [MAX 1 9 9 3 2 7 5 0 0 0 ] 8 1 4 9 6 ] ] 3 6 0 [MED 4 2 [MAX 3 8 4 3 1 3 8 ] 7 [MED 2 [MIN 2 4 9 2 

In [15]:
# for each possible first token, print the number of examples where the model got it right
first_token_counts = dict()
first_token_correct = dict()

for idx, pred, label in correct_examples:
    sequence = data_module.val_dataset[idx]["sequence"]
    first_token = sequence.split(" ", 1)[0]
    first_token_counts[first_token] = first_token_counts.get(first_token, 0) + 1
    first_token_correct[first_token] = first_token_correct.get(first_token, 0) + 1

for idx, pred, label in wrong_examples:
    sequence = data_module.val_dataset[idx]["sequence"]
    first_token = sequence.split(" ", 1)[0]
    first_token_counts[first_token] = first_token_counts.get(first_token, 0) + 1

In [16]:
for first_token, count in first_token_counts.items():
    num_correct = first_token_correct.get(first_token, 0)
    acc = num_correct / count
    print(first_token, acc)

[MAX 0.7165354330708661
[SM 0.118
[MIN 0.7169811320754716
[MED 0.41359223300970877


In [23]:
# for each possible pair (first token, label), print the number of examples where the model got it right
pair_counts = dict()
pair_correct = dict()

for idx, pred, label in correct_examples:
    sequence = data_module.val_dataset[idx]["sequence"]
    first_token = sequence.split(" ", 1)[0]
    pair = (first_token, label)
    pair_counts[pair] = pair_counts.get(pair, 0) + 1
    pair_correct[pair] = pair_correct.get(pair, 0) + 1

for idx, pred, label in wrong_examples:
    sequence = data_module.val_dataset[idx]["sequence"]
    first_token = sequence.split(" ", 1)[0]
    pair = (first_token, label)
    pair_counts[pair] = pair_counts.get(pair, 0) + 1

for pair, count in sorted(pair_counts.items()):
    num_correct = pair_correct.get(pair, 0)
    acc = num_correct / count
    print(pair, f"{num_correct:9.4f} / {count:9.4f} = {acc:.4f}")

('[MAX', 1)    0.0000 /    2.0000 = 0.0000
('[MAX', 2)    0.0000 /    1.0000 = 0.0000
('[MAX', 3)    1.0000 /    3.0000 = 0.3333
('[MAX', 4)    2.0000 /    8.0000 = 0.2500
('[MAX', 5)    6.0000 /   19.0000 = 0.3158
('[MAX', 6)   17.0000 /   27.0000 = 0.6296
('[MAX', 7)   30.0000 /   61.0000 = 0.4918
('[MAX', 8)   77.0000 /  116.0000 = 0.6638
('[MAX', 9)  231.0000 /  271.0000 = 0.8524
('[MED', 0)    1.0000 /    6.0000 = 0.1667
('[MED', 1)    3.0000 /   16.0000 = 0.1875
('[MED', 2)   14.0000 /   51.0000 = 0.2745
('[MED', 3)   46.0000 /   88.0000 = 0.5227
('[MED', 4)   69.0000 /  122.0000 = 0.5656
('[MED', 5)   31.0000 /   90.0000 = 0.3444
('[MED', 6)   36.0000 /   83.0000 = 0.4337
('[MED', 7)   12.0000 /   49.0000 = 0.2449
('[MED', 8)    1.0000 /    7.0000 = 0.1429
('[MED', 9)    0.0000 /    3.0000 = 0.0000
('[MIN', 0)  238.0000 /  276.0000 = 0.8623
('[MIN', 1)   57.0000 /  102.0000 = 0.5588
('[MIN', 2)   26.0000 /   46.0000 = 0.5652
('[MIN', 3)   15.0000 /   24.0000 = 0.6250
('[MIN', 4)

Perhaps the model is learning to solve X number of levels (at most 1 or 2) and then predicting according to the basic rule, with exceptions of the type: if there are no ones, then MIN will return 2, and so on.