# Modèle d'attention

Le mécanisme d'attention est une sorte d'apprentissage du pooling, ou plutôt de la moyenne : l'idée est de pondérer chaque élément de la somme avec un poids issu de paramètres appris.

Ce TP envisage 4 formes d'attention, après avoir testé un modèle de base uniforme:
1. Modèle uniforme: il s'agit d'un pooling classique (qui sera codé comme une attention uniforme)
2. Attention simple: apprentissage d'un vecteur de pondération des éléments
3. Attention personnalisée : les mots définissent leurs fonctions d'attention
4. RNN et attention

In [None]:
from collections import namedtuple
import os
import click
from torch.utils.tensorboard import SummaryWriter
import logging
import re
from pathlib import Path
from tqdm import tqdm
import numpy as np
import time
import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib

logging.basicConfig(level=logging.INFO)


## Classe de gestion des données textuelles

1. Récupération d'embedding glove
    1. Téléchargement:
    ```wget http://nlp.stanford.edu/data/glove.6B.zip```
    2. Lecture des fichiers
2. Récupération des données imdb (classification d'opinion)
3. Traitement des données

In [None]:
# Load the pre-trained GloVe embeddings
EMB_SIZE = 50 # 100, 200 or 300
PATH = "./data/glove/glove.6B/" # directory where the embeddings were downloaded

vocab,embeddings = [],[]
# Decode explicitly as UTF-8 (the encoding GloVe text files are distributed in)
# so that loading does not depend on the platform's default encoding.
with open(PATH+'glove.6B.{:d}d.txt'.format(EMB_SIZE),'rt',encoding='utf-8') as fi:
    # Stream the file line by line instead of keeping the whole text in memory.
    for line in fi:
        line = line.rstrip('\n')
        if not line:
            continue  # ignore blank lines (e.g. a trailing newline at end of file)
        # Each line reads "<word> <v1> <v2> ... <vN>", separated by single spaces.
        i_word, *i_values = line.split(' ')
        vocab.append(i_word)
        embeddings.append([float(val) for val in i_values])

In [None]:
print(vocab[0]) # first word
print(len(embeddings[0]), embeddings[0]) # first embedding (its length, then its values)

In [None]:
# Fetch the IMDB data through the Hugging Face `datasets` library
from datasets import load_dataset
dataset = load_dataset('imdb')

# Show the first training example: its text, then its label
# dataset["train"][0]
print(dataset["train"][0]['text'])
print(dataset["train"][0]['label'])

In [None]:
class FolderText(Dataset):
    """Dataset that tokenizes its documents on the fly.

    The raw texts and labels are copied out of `data` once at construction;
    tokenization only happens when an item is requested.
    """

    def __init__(self, data, tokenizer):
        """Store the texts and labels of `data`.

        Args:
            data: iterable of records exposing the keys "text" and "label".
            tokenizer: callable applied to a text in `__getitem__`.
        """
        self.tokenizer = tokenizer
        self.txts = []
        self.labels = []
        # A single pass over the records: each record is materialized once
        # instead of being fetched by index twice (once per field).
        for record in data:
            self.txts.append(record["text"])
            self.labels.append(record["label"])

    def __len__(self):
        """Return the number of documents."""
        return len(self.labels)

    def __getitem__(self, ix):
        """Return `(tokenizer(text), label)` for document `ix`."""
        return self.tokenizer(self.txts[ix]), self.labels[ix]

    def get_txt(self,ix):
        """Return `(raw text, label)` for document `ix`, without tokenizing."""
        return self.txts[ix], self.labels[ix]


In [None]:
# Dataset formatting: token pattern, out-of-vocabulary entry and lookup table
WORDS = re.compile(r"\S+")  # a token is any run of non-whitespace characters

embedding_size = len(embeddings[0])
# Register the out-of-vocabulary token as the last vocabulary entry ...
vocab.append("__OOV__")
OOVID = len(vocab) - 1
word2id = dict(zip(vocab, range(len(vocab))))
# ... and give it an all-zero embedding row, appended at the same index.
embeddings = np.vstack([embeddings, np.zeros((1, embedding_size))])

def tokenizer(t):
    """Lower-case `t`, split it with the `WORDS` pattern and map tokens to ids.

    Tokens missing from `word2id` are mapped to `OOVID`.
    """
    tokens = WORDS.findall(t.lower())
    return [word2id.get(token, OOVID) for token in tokens]

logging.info("Loading embeddings")
logging.info("Get the IMDB dataset")

# Wrap both IMDB splits so that documents are tokenized when accessed
train_data = FolderText(dataset["train"], tokenizer)
test_data = FolderText(dataset["test"], tokenizer)
# Reverse lookup (id -> word), used to turn token ids back into text
id2word = {ix: word for word, ix in word2id.items()}


In [None]:
# Quick check that the pieces above work together: tokenize a sentence, then
# rebuild it from the ids. The second sentence ends with an unknown word.
for sent in ("this movie was great", "this movie was qslkjgf"):
    ind = tokenizer(sent)
    print(ind)
    print("check reconstruction :", " ".join([id2word[i] for i in ind]))

# Modélisation

1. Jouer avec la fonction ```softmax```
2. Construction d'un modèle de base


Note : à chaque étape, on aurait pu/dû implémenter un masque sur les séquences pour tenir compte du padding (d'où la fonction ci-dessous).

In [None]:
# def masked_softmax(x,lens=None):
#     #X : B x N
#     x = x.view(x.size(0),x.size(1))
#     if lens is None:
#         lens = torch.zeros(x.size(0),1).fill_(x.size(1))
#     mask  = torch.arange(x.size(1),device=x.device).view(1,-1) < lens.view(-1,1)
#     x[~mask] = float('-inf') # indices pas dans le mask
#     # print("MASK", lens,"\n", mask,"\n", x)
#     return x.softmax(1)

In [None]:
# Tool walkthrough: the softmax function
x = torch.tensor([[10, 1, 0.5, 0.2, 0, 0]])
print(x)
# Normalize along dimension 1: exp(x_i) / sum_j exp(x_j)
print(torch.softmax(x, dim=1))


In [None]:
class ModelBase(nn.Module):
    """Baseline classifier: uniform pooling written as a uniform attention.

    Token embeddings are averaged with uniform attention weights (padding
    excluded) and the pooled vector is fed to a linear classification layer.
    """

    def __init__(self,embeddings,label_count):
        """Build the model.

        Args:
            embeddings: 2-D tensor of pre-trained embeddings, one row per token id.
            label_count: number of output classes.
        """
        super().__init__()
        self.emb_layer = nn.Embedding.from_pretrained(embeddings)
        self.linear = nn.Linear(embeddings.size(1),label_count)

    def emb(self,x):
        """Map token ids (B x N) to their embeddings (B x N x E)."""
        return self.emb_layer(x)

    def forward(self,x):
        """Return the class scores for a batch of token ids.

        Args:
            x: B x N tensor of token ids (padded).

        Returns:
            B x label_count tensor of unnormalized scores.
        """
        x_emb = self.emb(x)                                # B x N x E
        weights = self.attention(x_emb)                    # B x N
        pooled = (weights.unsqueeze(-1) * x_emb).sum(1)    # B x E
        yhat = self.linear(pooled)
        return yhat

    def attention(self,x):
        """Return uniform attention weights over the non-padding positions.

        Args:
            x: B x N x E tensor of embedded tokens, as returned by `emb`.

        Returns:
            B x N tensor whose rows sum to 1.
        """
        # A position whose embedding is entirely zero is treated as padding.
        # NOTE(review): this assumes the padding id has an all-zero embedding
        # row; any other token with a zero embedding is ignored as well.
        is_pad = (x == 0).all(dim=2)                       # B x N
        a = torch.zeros(x.shape[:2], dtype=x.dtype, device=x.device)
        a = a.masked_fill(is_pad, float("-inf"))
        # A row made only of padding would give softmax(-inf, ...) = NaN;
        # fall back to plain uniform weights for such rows.
        all_pad = is_pad.all(dim=1, keepdim=True)
        a = torch.where(all_pad, torch.zeros_like(a), a)
        return a.softmax(1) # B x N

In [None]:
#  TODO 

In [None]:


class Learner:
    """Base class for supervised learning.

    Holds a model, its Adam optimizer and a global iteration counter, and
    logs training / test metrics to TensorBoard.
    """

    def __init__(self, model, model_id: str):
        """Create the optimizer for `model`.

        Args:
            model: module exposing `forward`, `emb` and `attention`.
            model_id: run name, used as the TensorBoard sub-directory.
        """
        super().__init__()
        self.model = model
        self.optim = torch.optim.Adam(model.parameters(),lr=1e-3)
        self.model_id = model_id
        self.iteration = 0

    @staticmethod
    def _attention_entropy(probs):
        """Return the entropy of each attention distribution (one per row)."""
        # The small constant keeps log() finite where a weight is exactly 0.
        return -(probs*(probs+1e-10).log()).sum(1)

    def _evaluate(self, model, loader, device):
        """Compute mean loss, accuracy and attention entropies over `loader`.

        The model is put in eval mode for the duration of the call and its
        previous mode is restored afterwards.

        Returns:
            `(mean_loss, accuracy, entropies)`, or `None` when `loader`
            yields no example (nothing to average over).
        """
        loss_nagg = nn.CrossEntropyLoss(reduction='sum')
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                lst_entropies = []
                cumloss = 0
                cumcorrect = 0
                count = 0
                for x, y, _lens in loader:
                    x, y = x.to(device), y.to(device)
                    yhat = model(x)
                    cumloss += loss_nagg(yhat, y)
                    cumcorrect += (yhat.argmax(1) == y).sum()
                    count += x.shape[0]
                    probs = model.attention(model.emb(x))
                    lst_entropies.append(self._attention_entropy(probs))
        finally:
            model.train(was_training)

        if count == 0:
            return None
        return float(cumloss) / count, int(cumcorrect) / count, torch.cat(lst_entropies)

    def run(self,train_loader, test_loader, epochs, test_iterations, device,entropy_pen=0.):
        """Run a model during `epochs` epochs.

        Args:
            train_loader: iterable of `(x, y, lens)` training batches.
            test_loader: iterable of `(x, y, lens)` test batches.
            epochs: number of passes over `train_loader`.
            test_iterations: evaluate on `test_loader` every that many
                training iterations.
            device: device on which the model and the batches are placed.
            entropy_pen: weight of the attention-entropy term added to the
                training loss.
        """
        writer = SummaryWriter(f"/tmp/runs/{self.model_id}")
        model = self.model.to(device)
        loss = nn.CrossEntropyLoss()

        # Close the writer even if training is interrupted, so that pending
        # events are flushed to disk.
        try:
            model.train()
            for _epoch in tqdm(range(epochs)):
                # Iterate over batches
                for x, y, _lens in train_loader:
                    # Move the batch to the device once and reuse it
                    x, y = x.to(device), y.to(device)
                    self.optim.zero_grad()
                    yhat = model(x)
                    l = loss(yhat, y)
                    # Entropy of the attention weights, used as a penalty
                    probs = model.attention(model.emb(x))
                    entrop = self._attention_entropy(probs).mean()
                    total_l = l+entropy_pen*entrop
                    total_l.backward()
                    self.optim.step()
                    writer.add_scalar('loss/train', l, self.iteration)
                    writer.add_scalar('loss/entrop',entrop,self.iteration)
                    writer.add_scalar('loss/total_train',total_l,self.iteration)
                    self.iteration += 1

                    if self.iteration % test_iterations == 0:
                        metrics = self._evaluate(model, test_loader, device)
                        if metrics is not None:
                            test_loss, test_acc, entropies = metrics
                            writer.add_scalar('loss/test', test_loss, self.iteration)
                            writer.add_scalar('correct/test', test_acc, self.iteration)
                            writer.add_histogram('entropy', entropies, self.iteration)
        finally:
            writer.close()


        

In [None]:
def collate(batch):
    """Collate function for DataLoader.

    Turns a list of `(token ids, label)` samples into a padded id tensor
    (batch first, padded with the global `PAD`), a label tensor and a tensor
    holding the original length of each sequence.
    """
    sequences = []
    targets = []
    for sample in batch:
        sequences.append(torch.LongTensor(sample[0]))
        targets.append(sample[1])
    lengths = [len(seq) for seq in sequences]

    padded = torch.nn.utils.rnn.pad_sequence(
        sequences, batch_first=True, padding_value=PAD
    )
    return padded, torch.LongTensor(targets), torch.Tensor(lengths)


In [None]:
EPOCH = 30 # number of epochs
NITER = 50 # evaluate on the test set every NITER training iterations
BATCH_SIZE = 128

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
PAD = word2id["__OOV__"] # global variable read by collate
embeddings = torch.Tensor(embeddings)
model = ModelBase(embeddings, 2) # two labels
# model1 = ModelAttention1(embeddings, 2)
# model2 = ModelAttention2(embeddings, 2)
# model4 = ModelAttention4(embeddings, 2)

train_loader = DataLoader(train_data, shuffle=True,batch_size=BATCH_SIZE, collate_fn=collate)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, collate_fn=collate,shuffle=False)

# The run is identified by its start time (used as the TensorBoard run name)
learner = Learner(model, time.asctime())
learner.run(train_loader,test_loader,EPOCH,NITER,device,0.)

# Exploitation des résultats

1. Visualiser l'attention sur les textes
2. Calculer les performances et proposer des nouvelles phrases

In [None]:
def colorize(words, color_array):
    """Return an HTML string where each word is highlighted by its weight.

    Args:
        words: list of words.
        color_array: array of numbers between 0 and 1, of length equal to
            `words`; each value picks a background color in the 'Reds' colormap.

    Returns:
        A string of concatenated `<span>` elements, one per word.
    """
    # Local import so that the function carries its own dependency.
    from html import escape

    cmap = matplotlib.colormaps['Reds']
    template = '<span class="barcode" style="color: black; background-color: {}">{}</span>'
    spans = []
    for word, color in zip(words, color_array):
        hex_color = matplotlib.colors.rgb2hex(cmap(color)[:3])
        # Escape the word: raw review text may contain markup (e.g. "<br />")
        # that would otherwise be interpreted by the browser.
        spans.append(template.format(hex_color, '&nbsp;' + escape(word) + '&nbsp;'))
    return ''.join(spans)


In [None]:
from IPython.display import display, HTML

# Formatting of the output text:
sent = "this movie was great"
# att  = np.ones(len(sent.split(" "))) # np.random.rand(len(sent.split(" ")))
# One random weight per word, only to try out the rendering
att  =  np.random.rand(len(sent.split(" ")))
print(att)
display(HTML(colorize(sent.split(" "),att)))


In [None]:
# Try the model on a single sentence
from IPython.display import display, HTML

mod = model # model selection

sent = "this movie was not so bad"
ind = tokenizer(sent)
print("check :", " ".join([id2word[i] for i in ind]))
ind = torch.tensor(ind).unsqueeze(0) # add the batch dimension: 1 x N
print(ind)
# Inference only: no gradient tracking needed
with torch.no_grad():
    res = mod(ind.to(device))
    att = mod.attention(mod.emb(ind.to(device)))
print(res)
print(att)

# Split on any whitespace so that the words line up with the tokens produced
# by the tokenizer; squeeze only the batch dimension so that a one-word
# sentence still yields a 1-D array of weights.
display(HTML(colorize(sent.split(),att.to("cpu").squeeze(0).numpy())))



In [1]:
# NOTE(review): this cell appears to be a truncated remnant of a script that
# strips the solutions from the notebook. As it stands it is not valid Python
# (unbalanced parentheses) and `txt` / `f2` are not defined in any visible
# cell -- confirm whether it can simply be removed.
###  TODO )"," TODO ",\
    txt, flags=re.DOTALL))
f2.close()

### </CORRECTION> ###