In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

import transformers

from torch.utils.data import DataLoader
from model_util import SequenceDataset, DistilBertData

import numpy as np
from sklearn.metrics import accuracy_score

import seaborn as sns
import matplotlib.pyplot as plt

EPOCHS = 20
MAX_STRING_LENGTH = 29
batch_size=8
SAVE_SPACE = True # deletes data structures not needed from dataset

PROBLEM = "04.03.TLT.4.1.6" #"04.03.TLT.2.1.2"
RESULTS_PATH = "../trained_models/"

ACCEPTOR_NAME = "distilbert_problem_{}.pk".format(PROBLEM) # how to save model
DATASET_CONTAINER_PATH = "dataset_problem_{}.pk".format(PROBLEM) # where to save dataset metadata 

TRAIN_DATA_PATH = "../data/abbadingo/Mid/{}_Train.txt.dat".format(PROBLEM)
TEST_DATA_PATH = "../data/abbadingo/Mid/{}_TestSR.txt.dat".format(PROBLEM)

  from .autonotebook import tqdm as notebook_tqdm


## Prepare data

In [2]:
train_dataset = SequenceDataset(TRAIN_DATA_PATH, maxlen=MAX_STRING_LENGTH)
#dataset.initialize(DATASET_CONTAINER_PATH)
train_dataset.encode_sequences()
train_dataset.save_state(os.path.join(RESULTS_PATH, DATASET_CONTAINER_PATH))

Alphabet size:  4
Sequences loaded. Some examples: 
[['a', 'a', 'a', 'a', 'a', 'a', 'b', 'a', 'b', 'b', 'b', 'd', 'c', 'c', 'b', 'b', 'b', 'b', 'a', 'a'], ['a', 'a', 'a', 'a', 'a', 'b', 'a', 'b', 'b', 'd', 'd', 'c', 'd', 'c', 'c', 'b', 'd', 'b', 'd', 'c'], ['a', 'a', 'a', 'a', 'a', 'c', 'c', 'a', 'b', 'a', 'c', 'c', 'a', 'a', 'b', 'd', 'a', 'd', 'd', 'b']]
The symbol dictionary: {'a': 0, 'b': 1, 'd': 2, 'c': 3}


## Define model

In [3]:
def make_dict(**kwargs):
    return kwargs

init_dict = make_dict(
    vocab_size=train_dataset.alphabet_size+3,
    max_position_embeddings=train_dataset.maxlen+2,
    sinusoidal_pos_embds=True,
    n_layers=2,
    n_heads=4,
    dim=train_dataset.alphabet_size*3,
    hidden_dim=train_dataset.alphabet_size*2,
    activation="gelu",
    dropout=0.1,
    attention_dropout=0.1,
    seq_classif_dropout=0.2,
    pad_token_id=train_dataset.PAD
)

model=transformers.DistilBertForSequenceClassification(transformers.DistilBertConfig(**init_dict))
print(model)

Custom message: Loading distilbert
DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(7, 12, padding_idx=6)
      (position_embeddings): Embedding(33, 12)
      (LayerNorm): LayerNorm((12,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-1): 2 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=12, out_features=12, bias=True)
            (k_lin): Linear(in_features=12, out_features=12, bias=True)
            (v_lin): Linear(in_features=12, out_features=12, bias=True)
            (out_lin): Linear(in_features=12, out_features=12, bias=True)
          )
          (sa_layer_norm): LayerNorm((12,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1

## Prepare training

In [4]:
if SAVE_SPACE:
    del train_dataset.one_hot_seq
    del train_dataset.one_hot_seq_sr
    del train_dataset.ordinal_seq_sr

In [5]:
train_input_ids = torch.cat(list(torch.unsqueeze(x, 0) for x in train_dataset.ordinal_seq))
train_labels = train_dataset.labels

In [6]:
def construct_attn_mask(lengths, maxlen):
    """
    Lengths is a list. For each sequence in input_ids it gives the length
    """
    res = torch.ones((len(lengths), maxlen))
    for i, l in enumerate(lengths):
        res[i, l:] = 0
    return res

train_attn_mask = construct_attn_mask(train_dataset.sequence_lengths, train_dataset.maxlen)

In [7]:
def get_forward_dict(x, y, mask, output_attentions=False):
    forward_dict = make_dict(
        input_ids=x, # the training data?
        labels=y, # the training labels
        attention_mask=mask, # TODO: we can do this to improve the models I suppose
        head_mask=None,
        output_attentions=output_attentions,
        output_hidden_states=False,
        return_dict=True,
    )
    return forward_dict

In [8]:
train_data = DistilBertData(train_input_ids, train_labels, train_attn_mask)
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=batch_size)

## Model training

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)# for whole sequence: lr=0.00001
#loss_fn = nn.BCELoss()
loss_fn = nn.CrossEntropyLoss()

running_loss = 0.
last_loss = 0.
divisor = 0.

for i in range(1, EPOCHS+1):
    print("Epoch: ", i)
    for j, (x_batch, y_batch, mask_batch) in enumerate(train_dataloader):
        optimizer.zero_grad()
        model_input = get_forward_dict(x_batch, y_batch, mask_batch)
        outputs = model(**model_input)
        logits_before_softmax = outputs.logits
        #loss = loss_fn(F.softmax(logits_before_softmax, dim=1), torch.argmax(y_batch, dim=1))
        loss = loss_fn(logits_before_softmax, torch.argmax(y_batch, dim=1))
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        divisor += float( list(x_batch.size())[0] )
        if j % 100 == 0:
            last_loss = running_loss / divisor # loss per batch
            print('  batch {} loss: {}'.format(j, last_loss))
            running_loss = 0.
            divisor = 0.

Epoch:  1
  batch 0 loss: 0.0866336077451706
  batch 100 loss: 0.08664443984627723
  batch 200 loss: 0.08663883507251739
  batch 300 loss: 0.08665703825652599
  batch 400 loss: 0.08663988970220089
  batch 500 loss: 0.08663864806294441
  batch 600 loss: 0.0866508735716343


In [None]:
torch.save(model, os.path.join("..", "trained_models", "model_{}.pk".format(PROBLEM)))

### Checking the output

In [None]:
mapping = {v: k for k, v in train_dataset.symbol_dict.items()}
mapping[train_dataset.SOS] = "<SOS>"
mapping[train_dataset.EOS] = "<EOS>"
mapping[train_dataset.PAD] = "-"

mapping

In [None]:
def convert_ids_to_chars(mapping, tensor_1d, eos_symbol: int = None):
    """
    Converts the input_ids tensor to a list representing the original input.
    Mapping is dict mapping int to str/char.

    tensor_1d: The sequence as provided to the model.
    """

    res = list()
    eos_idx = None
    try:
        tensor_1d = tensor_1d.detach().numpy()
    except:
        tensor_1d = tensor_1d.numpy()
    for i, x in enumerate(tensor_1d):
        if eos_symbol is not None and x==eos_symbol:
            eos_idx = i
        res.append(mapping[x])
    return res, eos_idx

def map_sequences(mapping, x, eos_symbol: int=None):
    """
    x is two dimensional tensor (bsize, len_of_sequence)
    """
    for xx in list(x.detach()):
        s, eos_idx = convert_ids_to_chars(mapping, xx, eos_symbol)
        print("Mapped sequence: {}".format(" ".join(s)))

In [None]:
with torch.no_grad():
    for j, (x_batch, y_batch, mask_batch) in enumerate(train_dataloader):
        model_input = get_forward_dict(x_batch, y_batch, mask_batch)
        outputs = model(**model_input)

        if j==100:
            print(y_batch, "\n", F.softmax(outputs.logits, dim=1), "\n", x_batch.detach().numpy())
            map_sequences(mapping=mapping, x=x_batch, eos_symbol=train_dataset.EOS)

In [None]:
train_dataset.label_dict

## Model testing

### 1. Training accuracy

In [None]:
train_attn_mask.dtype, train_input_ids.dtype

In [None]:
with torch.no_grad():
    train_res = model(**get_forward_dict(train_input_ids, None, train_attn_mask)).logits
train_res = np.array(torch.argmax(train_res, dim=1))
train_res.shape, np.unique(train_res, return_counts=True)

In [None]:
accuracy_score(train_labels, train_res)

### 2. Test accuracy

In [None]:
test_dataset = SequenceDataset(TEST_DATA_PATH, maxlen=MAX_STRING_LENGTH)
test_dataset.initialize(os.path.join(RESULTS_PATH, DATASET_CONTAINER_PATH))
test_dataset.encode_sequences()

In [None]:
if SAVE_SPACE:
    del test_dataset.one_hot_seq
    del test_dataset.one_hot_seq_sr
    del test_dataset.ordinal_seq_sr

In [None]:
test_input_ids = torch.cat(list(torch.unsqueeze(x, 0) for x in test_dataset.ordinal_seq))
test_labels = test_dataset.labels

test_attn_mask = construct_attn_mask(test_dataset.sequence_lengths, test_dataset.maxlen)

with torch.no_grad():
    test_res = model(**get_forward_dict(test_input_ids, None, test_attn_mask)).logits
test_res = np.array(torch.argmax(test_res, dim=1))
test_res.shape, np.unique(test_res, return_counts=True)

In [None]:
accuracy_score(test_labels, test_res)

## Try a couple of sequences on your own

### Do sequences and just look at the predictions

In [None]:
symbol_dict = {k: v for k, v in train_dataset.symbol_dict.items()}
symbol_dict["<SOS>"] = train_dataset.SOS
symbol_dict["<EOS>"] = train_dataset.EOS
symbol_dict["<PAD>"] = train_dataset.PAD

def encode_sequences(sequences: list, symbol_dict: dict, maxlen: int):
    """
    Encodes the sequences and returns a tensor. sequences is list of list.
    Shape of result: (len(sequences, maxlen+2)), with +2 for SOS and EOS

    => maxlen must be maximum length without SOS and EOS!!!
    """
    res = torch.ones((len(sequences), maxlen+2), dtype=torch.int64) * symbol_dict["<PAD>"]
    lengths = list()
    for i, seq in enumerate(sequences):
        lengths.append(len(seq)+2) # plus 2 for SOS, EOS
        res[i, 0] = symbol_dict["<SOS>"]
        for j, symbol in enumerate(seq):
            res[i, j+1] = symbol_dict[symbol]
        res[i, j+2] = symbol_dict["<EOS>"]
    return res, lengths
symbol_dict

In [None]:
int_to_char_map = {v: k for k, v in symbol_dict.items()}

def map_to_chars(sequences, int_to_char_map):
    for i, seq in enumerate(sequences):
        for j, s in enumerate(seq):
            sequences[i][j] = int_to_char_map[s]
    return sequences

In [None]:
sequences = [
    [0],
    [0, 0],
    [0, 0, 1],
    [1], 
    [1, 0],
    [1, 0, 0]
]

encoded_s = map_to_chars(sequences, int_to_char_map)
x_input, lengths = encode_sequences(encoded_s, symbol_dict, MAX_STRING_LENGTH)
x_mask = construct_attn_mask(lengths, train_dataset.maxlen)

with torch.no_grad():
    outputs = model(**get_forward_dict(x_input, None, x_mask, output_attentions=True))
encoded_s, F.softmax(outputs.logits, dim=1), train_dataset.label_dict

### Now let's look at the attention-matrix. Do we see something?

In [None]:
attentions = outputs.attentions[0].detach().numpy()
attentions.shape

In [None]:
import copy

def label_attention_matrix(attentions, encoded_sequences):
    attn_res = None
    label_res = list()
    for i, seq in enumerate(encoded_sequences):
        if attn_res is None:
            attn_res = np.transpose(attentions[i, :, :len(seq)], (1, 0, -1))
        else:
            attn_res = np.vstack((attn_res, np.transpose(attentions[i, :, :len(seq)], (1, 0, -1)) ))
        label_res.extend(seq)
    return attn_res, label_res

attn_matrix, attn_labels = label_attention_matrix(attentions, encoded_s)
attn_matrix.shape, len(attn_labels)

import copy

class Stack():
    def __init__(self):
        self.list = list()
    
    def push(self, x):
        self.list.append(x)

    def pop(self):
        x = self.list[-1]
        del self.list[-1]
        return x

    def __str__(self):
        return " ".join(self.list)

test = Stack()
test.push("a")
test.push("b")
print(test)
print(test.pop())
print(test)

def get_sequences_dfs(alphabet: list, maxlen: int):
    """
    Returns all lists until maxlen.
    alphabet is list, maxlen is int
    """
    return 
    # TODO: if I want an extensive coverage
    
    res = list()
    tracker = [alphabet[0]] * maxlen
    print("Starting with ", tracker)

    alph_to_idx = {s, i for i, s in enumerate(alphabet)}
    idx_to_alph = {v: k for k, v in alph_to_idx.items()}
    while True:
        res.append(copy.copy(tracker))
        break

## Plot the attention outputs

In [None]:
output = model(**get_forward_dict(train_input_ids, None, train_attn_mask, output_attentions=True))
np_output = output.attentions[0].detach().numpy()
np_output.shape, np.argmax(np_output, axis=-1)#np_output[1, 1]

In [None]:
train_input_ids[0]

In [None]:
def plot_heatmaps(sequences, idx, attn_output, dataset, mapping):
    """
    Plots the heatmaps given by the output of the model.

    sequence: The sequence given as input to the model.
    idx: The index within the output.
    output: As returned by the distilbert model.
    dataset: The sequence-dataset.
    mapping: The mapping to get the sequence as we'd like to have it.
    """

    attn = attn_output[idx]
    num_heads = attn.shape[0]
    print("Number of heads: ", num_heads)

    sequence_list, eos_idx = convert_ids_to_chars(mapping=mapping, tensor_1d=sequences[idx], eos_symbol=dataset.EOS)
    for i in range(1, num_heads+1):
        plt.figure(figsize=(7, 5))
        sns.heatmap(attn[i-1, :eos_idx+1, :eos_idx+1], vmin=0, vmax=1)
        plt.plot()

plot_heatmaps(train_input_ids, 0, np_output, train_dataset, mapping)

## All attention outputs

In [None]:
train_res = model(**get_forward_dict(train_input_ids, None, train_attn_mask, output_attentions=True))
train_predictions = np.argmax( train_res.logits.detach().numpy(), axis=-1)
train_predictions.shape

In [None]:
def filter_attention(sequences, attentions, lengths, mapping, highlight_func):
    """
    This function gives the attention at the relevant indices. It stacks all attentions together,
    and returns them along with a mask for the attentions of interest and the unmapped symbols of 
    those.

    In the mask, uninteresting fields need a zero. Interesting fields can be numbered by ascending integers.

    IMPORTANT: If you want to do e.g. only at TP, ..., then you'll have to pre-filter that. sequences, attentions, 
    lengths are assumed to have same lengths in first dimension.

    sequences: np.array
    attentions: np.array
    lengths: list with the lengths.
    mapping: The casual mapping, from int to char
    highlight_func: A function that takes in a sequences and returns the indices of the symbols of interest.
    """
    attn_stack = None
    symbols = list()
    field_mask = list()
    
    for attn, seq, l in zip(attentions, sequences, lengths):
        if len(attn.shape) > 2:            
            # we have multiple heads
            attn = np.mean(attn, axis=0) # (seq_length, seq_length)
        attn = attn[:l]
        attn_stack = attn if attn_stack is None else np.vstack((attn_stack, attn))

        seq = list(seq[:l])
        
        #convert_ids_to_chars(mapping, torch.LongTensor(seq), eos_symbol=None) # for debugging
        
        symbols.extend(seq)
        field_mask.extend(highlight_func(seq, mapping, PROBLEM))
    return attn_stack, symbols, field_mask

In [None]:
from highlighter_functions.tlt import highlight_tlt

attn, symbols, mask = filter_attention(
                                        train_input_ids.detach().numpy(), 
                                        train_res.attentions[0].detach().numpy(),
                                        list(np.array(train_dataset.sequence_lengths)),
                                        mapping,
                                        highlight_tlt
                                        )

In [None]:
idxs_of_interest = np.where(np.array(mask)!=0)[0]

In [None]:
from sklearn.decomposition import PCA

pca = PCA(n_components=3)
attn_transformed = pca.fit_transform(attn[idxs_of_interest])

sns.scatterplot(x=attn_transformed[:, 0], y=attn_transformed[:, 1], hue=np.array(mask)[idxs_of_interest])

plt.savefig("attn_{}.png".format(PROBLEM))

In [None]:
sns.scatterplot(x=attn_transformed[:, 0], y=attn_transformed[:, 2], hue=np.array(mask)[idxs_of_interest])

In [None]:
sns.scatterplot(x=attn_transformed[:, 1], y=attn_transformed[:, 2], hue=np.array(mask)[idxs_of_interest])

## Attention outputs of TP

### Get the relevant fields

In [None]:
test_res = model(**get_forward_dict(test_input_ids, None, test_attn_mask))
test_predictions = np.argmax( test_res.logits.detach().numpy(), axis=-1)
test_predictions.shape

In [None]:
# find the positive label
train_dataset.label_dict

In [None]:
POSITIVE_LABEL = train_dataset.label_dict["1"]
NEGATIVE_LABEL = train_dataset.label_dict["0"]

positive_prediction_idxs = np.where(test_predictions==POSITIVE_LABEL)[0]
positive_prediction_idxs

In [None]:
positive_label_idxs = np.where(np.array(test_labels)==POSITIVE_LABEL)[0]
positive_label_idxs

In [None]:
_, _ = convert_ids_to_chars(mapping, test_input_ids[0], eos_symbol=train_dataset.PAD)

In [None]:
TP_indices = set(list(positive_prediction_idxs)).intersection(set(list(positive_label_idxs)))
len(TP_indices)

In [None]:
tp_res = model(**get_forward_dict(test_input_ids[list(TP_indices)], None, test_attn_mask[list(TP_indices)], output_attentions=True))
tp_res.attentions[0].detach().numpy().shape

In [None]:
tp_attn, tp_symbols, tp_mask = filter_attention(
                                                               test_input_ids[list(TP_indices)].detach().numpy(), 
                                                               tp_res.attentions[0].detach().numpy(),
                                                               list(np.array(test_dataset.sequence_lengths)[list(TP_indices)]),
                                                               mapping,
                                                               highlight_tlt
                                                               )

In [None]:
tp_attn.shape, len(tp_symbols), len(tp_mask)

### Get information about that stuff

In [None]:
idxs_of_interest = np.where(np.array(tp_mask)!=0)[0]

In [None]:
from sklearn.decomposition import PCA

pca = PCA(n_components=2)
attn_transformed = pca.fit_transform(tp_attn[idxs_of_interest])

sns.scatterplot(x=attn_transformed[:, 0], y=attn_transformed[:, 1], hue=np.array(tp_mask)[idxs_of_interest])

In [None]:
zero_idx = np.where(np.array(tp_mask)==0)[0]
one_idx = np.where(np.array(tp_mask)==1)[0]
two_idx = np.where(np.array(tp_mask)==2)[0]

sns.scatterplot(x=attn_transformed[zero_idx, 0], y=attn_transformed[zero_idx, 1], hue=np.array(tp_mask)[zero_idx], palette="mako")
sns.scatterplot(x=attn_transformed[two_idx, 0], y=attn_transformed[two_idx, 1], hue=np.array(tp_mask)[two_idx])
sns.scatterplot(x=attn_transformed[one_idx, 0], y=attn_transformed[one_idx, 1], hue=np.array(tp_mask)[one_idx], palette="rocket")
#plt.savefig("rep.png")
plt.show()

In [None]:
from sklearn.manifold import TSNE

tsne = TSNE(n_components=2)
attn_transformed = tsne.fit_transform(tp_attn[idxs_of_interest])

sns.scatterplot(x=attn_transformed[:, 0], y=attn_transformed[:, 1], hue=np.array(tp_mask)[idxs_of_interest])

## Test how well a classifier can separate the point clouds

### 1. Try a nearest neighbor

In [None]:
from sklearn.neighbors import NearestNeighbors

nn = NearestNeighbors()
nn = nn.fit(tp_attn)

In [None]:
distances, neigh_idxs = nn.kneighbors(tp_attn)
distances.shape, neigh_idxs.shape

In [None]:
distances[:5], neigh_idxs[:5]

In [None]:
test = neigh_idxs[one_idx, 1]
np.array(tp_mask)[test]

## Attn outputs of TN

In [None]:
negative_prediction_idxs = np.where(test_predictions==NEGATIVE_LABEL)[0]
negative_label_idxs = np.where(np.array(test_labels)==NEGATIVE_LABEL)[0]

negative_prediction_idxs, negative_label_idxs

In [None]:
TN_indices = set(list(negative_prediction_idxs)).intersection(set(list(negative_label_idxs)))

tn_res = model(**get_forward_dict(test_input_ids[list(TN_indices)], None, test_attn_mask[list(TN_indices)], output_attentions=True))
tn_res.attentions[0].detach().numpy().shape

len(TN_indices)

In [None]:
def first_char_highlighter_func(seq, mapping):
    """
    Only highlights first char of the sequence.

    seq: np.array or list
    """
    res = [0] * len(seq)
    res[1] = 1
    return res

tn_attn, tn_symbols, tn_mask = filter_attention(
                                                               test_input_ids[list(TN_indices)].detach().numpy(), 
                                                               tn_res.attentions[0].detach().numpy(),
                                                               list(np.array(test_dataset.sequence_lengths)[list(TN_indices)]),
                                                               mapping,
                                                               first_char_highlighter_func
                                                               )

In [None]:
pca = PCA(n_components=2)
tp_attn_transformed = pca.fit_transform(tp_attn)
tn_attn_transformed = pca.fit_transform(tn_attn)

tp_one_idx = np.where(np.array(tp_mask)==1)[0]
tn_one_idx = np.where(np.array(tn_mask)==1)[0]

sns.scatterplot(x=tn_attn_transformed[tn_one_idx, 0], y=tn_attn_transformed[tn_one_idx, 1], hue=np.array(tn_mask)[tn_one_idx], palette="rocket")
sns.scatterplot(x=tp_attn_transformed[tp_one_idx, 0], y=tp_attn_transformed[tp_one_idx, 1], hue=np.array(tp_mask)[tp_one_idx], palette="mako")
#plt.savefig("rep.png")
plt.show()

In [None]:
tn_a_idxs = np.where(np.logical_and(np.array(tn_symbols)==train_dataset.symbol_dict["a"], np.array(tn_mask) == 1))[0]
tn_no_a_idxs = np.where(np.logical_and(np.array(tn_symbols)!=train_dataset.symbol_dict["a"], np.array(tn_mask) == 1))[0]

In [None]:
sns.scatterplot(x=tn_attn_transformed[tn_no_a_idxs, 0], y=tn_attn_transformed[tn_no_a_idxs, 1], hue=np.array([1]*len(tn_no_a_idxs)), palette="mako")
sns.scatterplot(x=tp_attn_transformed[tp_one_idx, 0], y=tp_attn_transformed[tp_one_idx, 1], hue=np.array([2]*len(tp_one_idx)), palette="husl")
sns.scatterplot(x=tn_attn_transformed[tn_a_idxs, 0], y=tn_attn_transformed[tn_a_idxs, 1], hue=np.array([0]*len(tn_a_idxs)), palette="rocket")
plt.savefig("rep.png")
plt.show()

In [None]:
np.array([0]*len(tn_a_idxs))

In [None]:
tn_a_idxs

In [None]:
tn_mask == 1