In [224]:
%%capture
!pip install transformers datasets prettytable

In [225]:
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import math
from torch.utils.data import DataLoader
from tabulate import tabulate
from datasets import load_dataset
from prettytable import PrettyTable

from tqdm import tqdm
from transformers import BertTokenizer

In [226]:
dataset = load_dataset("scikit-learn/imdb", split="train")
print(dataset)

Dataset({
    features: ['review', 'sentiment'],
    num_rows: 50000
})


In [227]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')



In [228]:
def preprocessing_fn(data, tokenizer):
    tokenized_output = tokenizer(
        data["review"],
        add_special_tokens=False,
        truncation=True,
        max_length=256,
        padding=False,
        return_attention_mask=False
    )
    data["review_ids"] = tokenized_output["input_ids"]
    data["label"] = [0 if sentiment == "negative" else 1 for sentiment in data["sentiment"]]
    return data

In [229]:
n_samples = 5000
dataset = dataset.shuffle(seed=42)
dataset = dataset.select(range(n_samples))
tokenized_dataset = dataset.map(lambda x: preprocessing_fn(x, tokenizer), batched=True)
tokenized_dataset.set_format(type='torch', columns=['review_ids', 'label'])
document_train_set, document_valid_set = tokenized_dataset.train_test_split(test_size=0.2).values()

In [230]:
table = PrettyTable()
table.field_names = ["Dataset", "Size", "Number of classes"]
table.add_row(["Train", len(document_train_set), len(document_train_set.unique("label"))])
table.add_row(["Validation", len(document_valid_set), len(document_valid_set.unique("label"))])
print(table)

Flattening the indices:   0%|          | 0/4000 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/1000 [00:00<?, ? examples/s]

+------------+------+-------------------+
|  Dataset   | Size | Number of classes |
+------------+------+-------------------+
|   Train    | 4000 |         2         |
| Validation | 1000 |         2         |
+------------+------+-------------------+


In [231]:
print(document_train_set.column_names)

['review', 'sentiment', 'review_ids', 'label']


In [232]:
def extract_words_contexts(document_ids, radius):
    if len(document_ids) == 0:
        return torch.empty(0), torch.empty(0, 2 * radius)
    if not isinstance(document_ids, torch.Tensor):
        document_ids = torch.tensor(document_ids)
    padding_token = -1
    padded_document_ids = torch.cat([torch.full((radius,), padding_token), document_ids, torch.full((radius,), padding_token)])
    contexts = [
        torch.cat([
            padded_document_ids[i:i+radius],
            padded_document_ids[i+radius+1:i+2*radius+1]
        ])
        for i in range(radius, len(document_ids) + radius)
    ]
    contexts = torch.stack([c if len(c) == 2 * radius else torch.cat([c, torch.full((2 * radius - len(c),), padding_token)]) for c in contexts])
    return document_ids, contexts

def flatten_dataset_to_list(dataset, radius):
    aggregated_word_ids, aggregated_contexts, aggregated_labels = [], [], []
    for document in dataset:
        word_ids, contexts = extract_words_contexts(document['review_ids'], radius)
        aggregated_word_ids.extend(word_ids.tolist())
        aggregated_contexts.extend(contexts.tolist())
        aggregated_labels.extend([document['label']] * len(word_ids))
    return aggregated_word_ids, aggregated_contexts, aggregated_labels

In [233]:
radius = 5
train_word_ids, train_contexts, train_labels = flatten_dataset_to_list(document_train_set, radius)
valid_word_ids, valid_contexts, valid_labels = flatten_dataset_to_list(document_valid_set, radius)

table = PrettyTable()
table.field_names = ["IDs", "Contexts", "Labels"]
table.add_row([len(train_word_ids), len(train_contexts), len(train_labels)])
table.add_row([len(valid_word_ids), len(valid_contexts), len(valid_labels)])
print(table)

Epoch 1/5:   0%|          | 0/25751 [26:46<?, ?batch/s]
Epoch 1/10:   0%|          | 0/25751 [25:19<?, ?batch/s]
Epoch 1/5:   0%|          | 0/25751 [14:09<?, ?batch/s]
Epoch 1/5:   2%|▏         | 517/25751 [10:07<8:14:18,  1.18s/batch, accuracy=0.5, loss=6.1]
Epoch 1/5:  10%|█         | 2587/25751 [11:41<1:44:39,  3.69batch/s, accuracy=0.502, loss=3.92]


+--------+----------+--------+
|  IDs   | Contexts | Labels |
+--------+----------+--------+
| 827575 |  827575  | 827575 |
| 203603 |  203603  | 203603 |
+--------+----------+--------+


In [234]:
class WordContextDataset(torch.utils.data.Dataset):
    def __init__(self, word_ids, contexts, labels):
        assert len(word_ids) == len(contexts) == len(labels), "Length of word_ids, contexts, and labels must be the same"
        self.word_ids = torch.tensor(word_ids, dtype=torch.long)
        self.contexts = torch.tensor(contexts, dtype=torch.long)
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.word_ids)

    def __getitem__(self, idx):
        return {
            'word_id': self.word_ids[idx],
            'context': self.contexts[idx],
            'label': self.labels[idx]
        }

In [235]:
train_set = WordContextDataset(train_word_ids, train_contexts, train_labels)
valid_set = WordContextDataset(valid_word_ids, valid_contexts, valid_labels)

In [236]:
table = PrettyTable()
table.field_names = ["Dataset", "Size", "Number of classes"]
table.add_row(["Train", len(train_set), torch.unique(train_set.labels).shape[0]])
table.add_row(["Validation", len(valid_set), torch.unique(valid_set.labels).shape[0]])
print(table)

+------------+--------+-------------------+
|  Dataset   |  Size  | Number of classes |
+------------+--------+-------------------+
|   Train    | 827575 |         2         |
| Validation | 203603 |         2         |
+------------+--------+-------------------+


In [237]:
def collate_fn(batch, vocabulary_size, K, R):
    word_ids = torch.stack([item['word_id'] for item in batch])
    contexts = torch.stack([item['context'] for item in batch])
    labels = torch.stack([item['label'] for item in batch])
    negative_contexts = torch.randint(low=0, high=vocabulary_size, size=(len(batch), 2 * K * R))
    return {
        'word_id': word_ids,
        'positive_context_ids': contexts,
        'negative_context_ids': negative_contexts,
        'labels': labels
    }

In [238]:
vocabulary_size = len(tokenizer.get_vocab())
batch_size = 32
K, R = 5, 5

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_fn(x, vocabulary_size, K, R))
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, collate_fn=lambda x: collate_fn(x, vocabulary_size, K, R))

In [239]:
train_loader.dataset[0], train_loader.dataset[1], train_loader.dataset[2]

({'word_id': tensor(2412),
  'context': tensor([ 2412,  2144,  1045,  3342,  1010,  2031,  3866, 24042,  1998,  3909]),
  'label': tensor(1)},
 {'word_id': tensor(2144),
  'context': tensor([ 2144,  1045,  3342,  1010,  1045,  3866, 24042,  1998,  3909,  1012]),
  'label': tensor(1)},
 {'word_id': tensor(1045),
  'context': tensor([ 1045,  3342,  1010,  1045,  2031, 24042,  1998,  3909,  1012,  1045]),
  'label': tensor(1)})

In [240]:
table = PrettyTable()
for i, batch in enumerate(train_loader):
    table.field_names = ["Batch", "Word ID", "Positive Contexts", "Negative Contexts", "Labels"]
    table.add_row([i, batch['word_id'].shape, batch['positive_context_ids'].shape, batch['negative_context_ids'].shape, batch['labels'].shape])
    if i == 2:
        break
print(table)

+-------+------------------+----------------------+----------------------+------------------+
| Batch |     Word ID      |  Positive Contexts   |  Negative Contexts   |      Labels      |
+-------+------------------+----------------------+----------------------+------------------+
|   0   | torch.Size([32]) | torch.Size([32, 10]) | torch.Size([32, 50]) | torch.Size([32]) |
|   1   | torch.Size([32]) | torch.Size([32, 10]) | torch.Size([32, 50]) | torch.Size([32]) |
|   2   | torch.Size([32]) | torch.Size([32, 10]) | torch.Size([32, 50]) | torch.Size([32]) |
+-------+------------------+----------------------+----------------------+------------------+


In [241]:
"""
3.2 Model
9. Write a model named Word2Vec which is a valid torch.nn.Module (i.e., write a class that inherits from the torch.nn. Module), and implement the Word2Vec model. It should be parametrized by the vocabulary size and the embeddings dimension. Use the module torch.nn. Embedding.
10. Train the model. The training should be parametrized by the batch size $B$, and the number of epochs $E$.
11. Validates its accuracy on the test set.
12. Write a function save_model that saves the model's embeddings in a file. The file name should be formated like:
"model_dim-<d>_radius-<R>_ratio-<K>-batch-<B>-epoch-<E>.ckpt".

13. Once you have a working code, you can launch a bigger training, using more documents, if it does not take too much time.

3.3 Classification task

In this section you will experiment with the classification task of the lab, augmented with your Word2Vec model.

Make sure this part is independant from the above part. You should not need to retrain the Word2Vec model and only load the embeddings from the file.

Use the notebook from the lab, with the dataset and the training script.
1. Write a function load_model that takes a path to a saved Word2Vec embeddings (with the previous formatting) and loads the checkpoint the embeddings directly to the ConvolutionModel (you can use either the state-of-the art model or the first small model).
2. Train the model, initialized with these emebeddings.
3. Compare the results with the model without this initialization.
4. Make a small ablation study on the influence of some parameters of the Word2Vec model on the classification task. Analyze the results.
"""

'\n3.2 Model\n9. Write a model named Word2Vec which is a valid torch.nn.Module (i.e., write a class that inherits from the torch.nn. Module), and implement the Word2Vec model. It should be parametrized by the vocabulary size and the embeddings dimension. Use the module torch.nn. Embedding.\n10. Train the model. The training should be parametrized by the batch size $B$, and the number of epochs $E$.\n11. Validates its accuracy on the test set.\n12. Write a function save_model that saves the model\'s embeddings in a file. The file name should be formated like:\n"model_dim-<d>_radius-<R>_ratio-<K>-batch-<B>-epoch-<E>.ckpt".\n\n13. Once you have a working code, you can launch a bigger training, using more documents, if it does not take too much time.\n\n3.3 Classification task\n\nIn this section you will experiment with the classification task of the lab, augmented with your Word2Vec model.\n\nMake sure this part is independant from the above part. You should not need to retrain the Word2V

In [242]:
class Word2Vec(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(Word2Vec, self).__init__()
        self.target_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.context_embeddings = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, word_ids, context_ids, negative_context_ids):
        target_embeds = self.target_embeddings(word_ids)
        context_embeds = self.context_embeddings(context_ids)
        negative_context_embeds = self.context_embeddings(negative_context_ids)
        positive_dot = torch.bmm(context_embeds, target_embeds.unsqueeze(2)).squeeze(2)
        positive_scores = torch.sigmoid(positive_dot)
        negative_dot = torch.bmm(negative_context_embeds, target_embeds.unsqueeze(2)).squeeze(2)
        negative_scores = torch.sigmoid(negative_dot)
        return positive_scores, negative_scores

    def compute_loss(self, positive_scores, negative_scores):
        positive_labels = torch.ones_like(positive_scores)
        negative_labels = torch.zeros_like(negative_scores)
        positive_loss = F.binary_cross_entropy(positive_scores, positive_labels)
        negative_loss = F.binary_cross_entropy(negative_scores, negative_labels)
        return positive_loss + negative_loss

In [246]:
def train_model(model, train_loader, optimizer, epochs, device):
    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}", unit='batch')
        for batch in progress_bar:
            optimizer.zero_grad()
            positive_scores, negative_scores = model(
                batch['word_id'].to(device),
                batch['positive_context_ids'].to(device),
                batch['negative_context_ids'].to(device)
            )
            loss = model.compute_loss(positive_scores, negative_scores)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            progress_bar.set_postfix(loss="{:.4f}".format(loss.item()))
        epoch_losses.append(total_loss / len(train_loader))
    return epoch_losses

In [247]:
def validate_model(model, valid_loader, device):
    model.eval()
    total_loss = 0
    progress_bar = tqdm(valid_loader, desc="Validation", unit='batch')
    with torch.no_grad():
        for batch in progress_bar:
            positive_scores, negative_scores = model(
                batch['word_id'].to(device),
                batch['positive_context_ids'].to(device),
                batch['negative_context_ids'].to(device)
            )
            loss = model.compute_loss(positive_scores, negative_scores)
            total_loss += loss.item()
            progress_bar.set_postfix(loss="{:.4f}".format(loss.item()))
    average_loss = total_loss / len(valid_loader)
    return average_loss

In [None]:
vocab_size = len(tokenizer.get_vocab())
model = Word2Vec(vocab_size, embedding_dim=128)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
device = torch.device("mps")
model.to(device)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True, collate_fn=lambda x: collate_fn(x, vocab_size, K, R))
valid_loader = DataLoader(valid_set, batch_size=64, shuffle=False, collate_fn=lambda x: collate_fn(x, vocab_size, K, R))

train_losses, train_accuracies = train_model(model, train_loader, optimizer, epochs=5, device=device)
validation_loss, validation_accuracy = validate_model(model, valid_loader, device=device)

print("Training Losses:", train_losses)
print("Training Accuracies:", train_accuracies)
print("Validation Loss:", validation_loss)
print("Validation Accuracy:", validation_accuracy)

Epoch 1/5:   4%|▍         | 496/12931 [00:06<02:32, 81.51batch/s, loss=13.0808]