```
// Copyright 2020 Twitter, Inc.
// SPDX-License-Identifier: Apache-2.0
```

# Multilingual Alignment of mBERT via Translation Pair Prediction

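This notebook adds the Translation Pair Prediction pre-training task to multilingual models, using mBERT as the example. Each pair is either positive (a sentence and its translation) or negative (a sentence and an unrelated one). A shared BERT encoder (the "two towers") encodes both sentences, and the model learns to give positive pairs a higher dot-product score between their pooled embeddings than negative pairs.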


## Set up libraries


In [1]:
# Pin the library versions this notebook was written against: the imports below
# rely on transformers 3.x module paths (e.g. `transformers.modeling_bert`).
# NOTE(review): scikit-learn and tqdm are imported later but are not pinned here.
%pip install transformers==3.5.1 datasets==1.1.2 torch==1.4.0 seqeval==1.2.2 gensim==3.8.1

Collecting torch==1.4.0
  Downloading torch-1.4.0-cp37-cp37m-manylinux1_x86_64.whl (753.4 MB)
[K     |████████████████████████████████| 753.4 MB 7.1 kB/s 
Installing collected packages: torch
  Attempting uninstall: torch
    Found existing installation: torch 1.9.0+cu102
    Uninstalling torch-1.9.0+cu102:
      Successfully uninstalled torch-1.9.0+cu102
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchvision 0.10.0+cu102 requires torch==1.9.0, but you have torch 1.4.0 which is incompatible.
torchtext 0.10.0 requires torch==1.9.0, but you have torch 1.4.0 which is incompatible.[0m
Successfully installed torch-1.4.0


## Define parameters


In [None]:
# --- Data selection ---------------------------------------------------------
# WIKIDATA_MODEL=True trains on every matching file; False holds out the last
# two files for validation and test (see the dataset setup cell).
WIKIDATA_MODEL = True
EN_ONLY_MODEL = False  # NOTE(review): not referenced by the cells shown here.

# Prefix of the data files to load (matched as f"{WIKIDATA_PREFIX}*.json").
# Options: en_ja, en_ar, en_hi, en_hi_ja_ar_equal, en_hi_ja_ar
WIKIDATA_PREFIX = "en_ja"

# --- Paths ------------------------------------------------------------------
HOMEDIR = "./"
DATADIR = HOMEDIR + "/tatoeba"
# Model name or path handed to `from_pretrained` for the tokenizer and model.
pre_trained_model_path = "bert-base-multilingual-uncased"


## Set up helpers


In [2]:
import json
import random
import re
import warnings
from dataclasses import dataclass
from itertools import combinations
from pathlib import Path
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

import numpy as np
import torch
from sklearn.metrics import classification_report
from torch.nn import BCEWithLogitsLoss, CosineEmbeddingLoss, CrossEntropyLoss, MSELoss
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import IterableDataset
from tqdm import tqdm, trange
from transformers import (
    BertForNextSentencePrediction,
    BertTokenizerFast,
    DataCollatorWithPadding,
    EvalPrediction,
    RobertaTokenizerFast,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_bert import BertModel, BertOnlyNSPHead, BertPreTrainedModel
from transformers.modeling_outputs import NextSentencePredictorOutput
from transformers.tokenization_utils_base import (
    BatchEncoding,
    PaddingStrategy,
    PreTrainedTokenizerBase,
)


In [3]:
class BertTwoTowerHead(torch.nn.Module):
    """Learned projection applied to BERT's pooled output in each tower.

    A single affine layer (no activation) mapping ``(..., hidden_size)`` to
    ``(..., hidden_size)``; ``config`` only needs a ``hidden_size`` attribute.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.projection = torch.nn.Linear(width, width)

    def forward(self, pooled_output):
        return self.projection(pooled_output)


class BertTwoTowerLoss(torch.nn.Module):
    """Pair loss computed from the two tower embeddings ``x1`` and ``x2``.

    Args:
        loss_type: ``"bce"`` scores each pair by the dot product ``sum(x1 * x2)``
            and applies ``BCEWithLogitsLoss`` against the labels cast to float;
            ``"cosine"`` applies ``CosineEmbeddingLoss(margin)`` after mapping
            every label ``<= 0`` to ``-1`` (that loss expects targets in {1, -1}).
        margin: only used by the ``"cosine"`` loss.

    Raises:
        NotImplementedError: if ``loss_type`` is neither ``"bce"`` nor ``"cosine"``.
    """

    def __init__(self, loss_type, margin=0.0):
        super().__init__()
        factories = (
            ("bce", lambda: self._bce_loss()),
            ("cosine", lambda: self._cosine_loss(margin)),
        )
        for name, make_loss in factories:
            if loss_type == name:
                self.loss_compute = make_loss()
                break
        else:
            raise NotImplementedError(f"loss_type={loss_type} not implemented")

    def _bce_loss(self):
        """Return a closure: BCE-with-logits on per-pair dot products."""
        criterion = BCEWithLogitsLoss()

        def _pair_bce(x1, x2, labels):
            scores = torch.sum(x1 * x2, dim=-1).view(-1)
            targets = labels.view(-1) * 1.0
            return criterion(scores, targets)

        return _pair_bce

    def _cosine_loss(self, margin):
        """Return a closure: CosineEmbeddingLoss with 0/1 labels mapped to -1/1."""
        criterion = CosineEmbeddingLoss(margin=margin)

        def _pair_cosine(x1, x2, labels):
            flat_x1 = x1.view(-1, x1.shape[-1])
            flat_x2 = x2.view(-1, x2.shape[-1])
            flat_labels = labels.view(-1)
            # Cosine targets must be 1 / -1: keep positives, send everything else to -1.
            minus_one = torch.tensor(-1, device=flat_labels.device)
            targets = torch.where(flat_labels > 0, flat_labels, minus_one)
            return criterion(flat_x1, flat_x2, targets)

        return _pair_cosine

    def forward(self, x1, x2, labels):
        return self.loss_compute(x1, x2, labels)


class BertForTwoTowerPrediction(BertPreTrainedModel):
    """BERT two-tower model that scores sentence pairs by a dot product.

    Each row of a batch is a single tokenized sentence, and consecutive rows
    ``(0, 1), (2, 3), ...`` form one pair. Both sentences share the same BERT
    encoder and :class:`BertTwoTowerHead` projection. A pair's score is the dot
    product of the two projected pooled outputs.

    Args:
        config: BERT configuration.
        loss_type: ``"bce"`` or ``"cosine"``, forwarded to :class:`BertTwoTowerLoss`.
        margin: margin for the ``"cosine"`` loss.
    """

    def __init__(self, config, loss_type="bce", margin=0.0):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertTwoTowerHead(config)
        self.loss_fn = BertTwoTowerLoss(loss_type, margin)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        r"""
        Encode a batch of sentence pairs laid out as consecutive rows and score each pair.

        The batch size must be even: rows ``2k`` and ``2k + 1`` are the two
        sentences of pair ``k``. All arguments except ``labels`` are forwarded
        to :class:`BertModel`.

        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            One label per row, with both rows of a pair carrying the same value.
            Only the first row's label of each pair is used:

            - 1 marks a positive pair (e.g. a sentence and its translation),
            - 0 marks a negative pair (one sentence drawn from an unrelated item).

            When omitted, no loss is computed.

        Returns:
            :class:`NextSentencePredictorOutput` (or a tuple when ``return_dict``
            is false). ``logits`` has shape ``(batch_size // 2,)`` and holds the
            dot-product score of each pair. ``loss`` is only set when ``labels``
            is given.
        """

        if "next_sentence_label" in kwargs:
            warnings.warn(
                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("next_sentence_label")

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project BERT's pooled output, then regroup rows into (n_pairs, 2, hidden).
        pooled_output = self.cls(outputs[1])
        pooled_output_pair = pooled_output.reshape(-1, 2, *pooled_output.shape[1:])
        x1 = pooled_output_pair[:, 0]
        x2 = pooled_output_pair[:, 1]
        # One dot-product score per pair.
        seq_relationship_scores = (x1 * x2).sum(dim=-1)

        next_sentence_loss = None
        if labels is not None:
            # Both rows of a pair share a label; keep the first row's label per pair.
            # (Reshaping only here keeps inference without labels working.)
            pair_labels = labels.reshape(-1, 2, *labels.shape[1:])[:, 0]
            next_sentence_loss = self.loss_fn(x1, x2, pair_labels)

        if not return_dict:
            output = (seq_relationship_scores,) + outputs[2:]
            return (
                ((next_sentence_loss,) + output)
                if next_sentence_loss is not None
                else output
            )

        return NextSentencePredictorOutput(
            loss=next_sentence_loss,
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


In [4]:
# An http:// or https:// URL, running up to the next space character.
URL_REGEX = re.compile(r"http[s]?://[^ ]+")


def clean_text(text):
    """Normalize a text snippet before tokenization.

    Each newline becomes the marker `` [LF] ``, and every URL matched by
    ``URL_REGEX`` is replaced by ``[URL]``.
    """
    without_newlines = " [LF] ".join(text.split("\n"))
    return URL_REGEX.sub("[URL]", without_newlines)


def read_file(file_path, num_samples=5):
    """Yield a random sample of cleaned texts for each JSON line of ``file_path``.

    Each non-blank line is parsed as JSON. Its ``"unique_label_desc"`` list is
    cleaned with ``clean_text``, shuffled with ``np.random`` (so results follow
    the global NumPy seed) and truncated to ``num_samples`` entries. Lines that
    end up with no texts are skipped.

    Args:
        file_path: path to a JSON-lines file. It is read as UTF-8 because the
            data is multilingual; the platform default encoding may not be UTF-8.
        num_samples: maximum number of texts kept per line (default 5).

    Yields:
        ``(texts, len(texts))`` for every line with at least one text.
    """
    with open(file_path, encoding="utf-8") as fp:
        for raw_line in fp:
            raw_line = raw_line.strip()
            if not raw_line:
                continue
            record = json.loads(raw_line)
            texts = [clean_text(t) for t in record["unique_label_desc"]]
            texts = np.random.permutation(texts)[:num_samples].tolist()
            n_texts = len(texts)
            if n_texts < 1:
                continue
            yield texts, n_texts


def get_tweet_pairs(tweet_list):
    """Yield every unordered pair ``[a, b]`` of items from ``tweet_list``.

    The items are first shuffled with ``np.random`` (so the order follows the
    global NumPy seed). The shuffle is applied to a copy: the caller's list,
    which ``get_all_pairs`` also indexes when drawing negatives, is left
    untouched, and tuples are accepted as well as lists.
    """
    shuffled = list(tweet_list)
    np.random.shuffle(shuffled)
    for first, second in combinations(shuffled, 2):
        yield [first, second]


def get_all_pairs(url_tweets, n_url_tweets, num_negatives=1):
    """Yield labelled positive and negative text pairs from grouped texts.

    Args:
        url_tweets: sequence of groups. Each group is a list of texts that belong
            together (e.g. the ``texts`` lists produced by ``read_file``).
        n_url_tweets: number of texts in each group, i.e. ``len(url_tweets[i])``.
        num_negatives: number of negative pairs emitted after each positive pair.

    Yields:
        ``(pair, labels)``. ``pair`` is a two-element list of texts. ``labels``
        is ``[1, 1]`` for a positive pair (both texts from the same group) and
        ``[0, 0]`` for a negative pair (one text swapped for a text from another
        group). The label is repeated because each pair is later split into two
        examples, one per sentence.

    All randomness comes from the global ``np.random`` state, and the draws
    happen in a fixed order.

    NOTE(review): with a single group that has 2+ texts there is nothing to
    draw negatives from, and ``np.random.choice`` on the empty index array
    raises ``ValueError``.
    """
    n_urls = len(url_tweets)
    n_url_tweets = np.array(n_url_tweets)
    max_n_tweets = max(n_url_tweets)
    for i, tweet_list in enumerate(url_tweets):
        n_tweets = n_url_tweets[i]
        # Number of positive pairs that combinations(tweet_list, 2) will produce.
        n_pairs = n_tweets * (n_tweets - 1) // 2
        neg_url_indexes = np.concatenate(
            [np.arange(i), np.arange(i + 1, n_urls)]
        )  # Every other group: negatives never come from group i itself
        # Pre-draw, for each positive pair and each negative: the source group ...
        neg_url_samples = np.random.choice(neg_url_indexes, (n_pairs, num_negatives))
        # ... which side of the positive pair is replaced (1 -> second, 0 -> first) ...
        neg_url_positions = np.random.randint(
            2, size=(n_pairs, num_negatives)
        )  # Side to replace
        # ... and a text index, folded into the chosen group's size with `%`
        # (not exactly uniform unless the group size divides max_n_tweets).
        neg_tweet_indexes = np.random.randint(
            max_n_tweets, size=(n_pairs, num_negatives)
        )
        neg_tweet_indexes = neg_tweet_indexes % n_url_tweets[neg_url_samples]
        for j, pair in enumerate(get_tweet_pairs(tweet_list)):
            yield pair, [
                1,
                1,
            ]  # Positive pair (label 1); note BERT NSP uses the opposite convention (0 = continuation)
            for nj in range(num_negatives):
                neg_pair_position = neg_url_positions[j][nj]
                neg_url_idx = neg_url_samples[j][nj]
                neg_tweet_idx = neg_tweet_indexes[j][nj]  # already reduced modulo the group size above
                neg_tweet = url_tweets[neg_url_idx][neg_tweet_idx]
                neg_pair = (
                    [pair[0], neg_tweet]
                    if neg_pair_position == 1
                    else [neg_tweet, pair[1]]
                )
                yield neg_pair, [
                    0,
                    0,
                ]  # Negative pair (label 0): one text comes from a different group


def read_sentencepair_data(file_paths):
    """Yield ``(pair, labels)`` examples for every file in ``file_paths``.

    Each file is read with ``read_file`` and turned into positive/negative pairs
    by ``get_all_pairs``. A file with no usable lines is skipped; without the
    check, ``zip(*[])`` would fail to unpack into two values.
    """
    for file_path in file_paths:
        groups = list(read_file(file_path))
        if not groups:
            continue
        url_tweets, n_url_tweets = zip(*groups)
        yield from get_all_pairs(url_tweets, n_url_tweets)


In [5]:
# Binary label vocabulary stored in the model config (see train_model).
label2id = {False: 0, True: 1}

# Inverse mapping: class id -> label.
id2label = dict((index, name) for name, index in label2id.items())


class SentencePairDataset(IterableDataset):
    """Streams tokenized sentence pairs for the two-tower model.

    Every pair produced by ``read_sentencepair_data`` is emitted as two
    consecutive examples (one per sentence), each carrying the pair's label.
    A sequential batch of even size therefore keeps both halves of a pair
    together.

    Args:
        file_paths: JSON-lines files to read.
        tokenizer: callable tokenizer applied to each sentence (truncated to 512 tokens).
        test_mode: if true, ``np.random`` is seeded with 1337 and iteration
            stops after ``_TEST_MODE_MAX_PAIRS`` pairs.
    """

    # Number of pairs prepare_tokenized_examples yields at most in test mode.
    _TEST_MODE_MAX_PAIRS = 10001

    def __init__(self, file_paths, tokenizer, test_mode=False):
        self.file_paths = file_paths
        self.tokenizer = tokenizer
        self.test_mode = test_mode
        self._setup()

    @staticmethod
    def _count_pairs(n_url_tweets):
        """Number of pairs ``get_all_pairs`` yields with one negative per positive.

        A group of ``n`` texts gives ``n * (n - 1) / 2`` positive pairs, each
        followed by one negative pair, i.e. ``n * (n - 1)`` pairs in total.
        """
        counts = np.asarray(n_url_tweets, dtype=np.int64)
        return int((counts * (counts - 1)).sum())

    def _setup(self):
        """Pre-compute ``dataset_length``, the number of examples ``__iter__`` yields."""
        n_pairs = 0
        for file_path in tqdm(self.file_paths):
            groups = list(read_file(file_path))
            if not groups:
                continue  # Nothing to count for a file without usable lines.
            _, n_url_tweets = zip(*groups)
            n_pairs += self._count_pairs(n_url_tweets)
        if self.test_mode:
            n_pairs = min(n_pairs, self._TEST_MODE_MAX_PAIRS)
        # Each pair is emitted as two examples (one per sentence), so the length
        # seen by the DataLoader/Trainer (and thus the LR schedule) is twice the pair count.
        self.dataset_length = 2 * n_pairs

    def prepare_tokenized_examples(self, batch_size=100000):
        """Yield one tokenized example dict per sentence, tokenizing in chunks.

        Sentences are buffered and tokenized together once the buffer holds
        more than ``batch_size`` entries. Pairs are always added whole, so a
        chunk never splits a pair.

        Yields:
            dicts mapping each tokenizer output key to a tensor, plus ``"label"``.
        """
        chunk_pairs = []
        chunk_labels = []

        def _process_chunk(chunk_pairs, chunk_labels):
            encodings = self.tokenizer(chunk_pairs, max_length=512, truncation=True)
            for i, label in enumerate(chunk_labels):
                item = {key: torch.tensor(val[i]) for key, val in encodings.items()}
                item["label"] = torch.tensor(label)
                yield item

        if self.test_mode:
            np.random.seed(1337)
        for i, (pair, label) in enumerate(read_sentencepair_data(self.file_paths)):
            if self.test_mode and i >= self._TEST_MODE_MAX_PAIRS:
                break
            if len(chunk_pairs) > batch_size:
                yield from _process_chunk(chunk_pairs, chunk_labels)
                chunk_labels = []
                chunk_pairs = []
            chunk_pairs.extend(pair)
            chunk_labels.extend(label)
        if chunk_pairs:
            yield from _process_chunk(chunk_pairs, chunk_labels)

    def __iter__(self):
        yield from self.prepare_tokenized_examples(batch_size=100000)

    def __len__(self):
        return self.dataset_length


def compute_metrics(p: "EvalPrediction"):
    """Turn Trainer predictions into a flat dict of classification metrics.

    Handles 2-D logits (class = argmax over axis 1) and the 1-D per-pair scores
    returned by ``BertForTwoTowerPrediction`` (a positive score counts as class
    1, matching the BCE-with-logits loss). When there is one label per sentence
    row but one prediction per pair, the first label of each pair is used.

    Returns:
        dict of float metrics built from sklearn's ``classification_report``.
        Nested per-class entries are flattened to ``"<class>_<stat>"`` keys,
        with spaces replaced by underscores.
    """
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    preds = np.asarray(preds)
    if preds.ndim > 1:
        pred_ids = np.argmax(preds, axis=1)
    else:
        pred_ids = (preds > 0).astype(np.int64)
    label_ids = np.asarray(p.label_ids)
    if label_ids.shape[0] == 2 * pred_ids.shape[0]:
        # Labels are per row, predictions per pair: keep the first row's label of each pair.
        label_ids = label_ids[::2]
    # Trainer expects a mapping of metrics, not the printable report string.
    report = classification_report(label_ids, pred_ids, output_dict=True)
    metrics = {}
    for name, value in report.items():
        if isinstance(value, dict):
            for stat, stat_value in value.items():
                metrics[f"{name}_{stat}".replace(" ", "_")] = float(stat_value)
        else:
            metrics[name.replace(" ", "_")] = float(value)
    return metrics


## Set up datasets


In [6]:
# Sort the matches so the train/val/test split below is reproducible:
# Path.glob yields files in arbitrary, filesystem-dependent order.
file_paths = sorted(Path(DATADIR).expanduser().glob(f"./{WIKIDATA_PREFIX}*.json"))
model_dir = pre_trained_model_path

# `model_max_length` is the current name of the deprecated `max_len` argument.
# NOTE(review): `truncation`/`padding` here are only stored as tokenizer init kwargs;
# truncation is applied per call in SentencePairDataset and padding by the data collator.
tokenizer = BertTokenizerFast.from_pretrained(
    str(model_dir), model_max_length=512, truncation=True, padding=True
)

if WIKIDATA_MODEL:
    # Train on every file; nothing is held out.
    train_file_paths = file_paths
    val_file_paths = []
    test_file_paths = []
else:
    train_file_paths = file_paths[:-2]
    val_file_paths = file_paths[-2:-1]  # Single file
    test_file_paths = file_paths[-1:]  # Single file

# test_mode=True: quick run over at most the first 10001 pairs with a fixed seed.
# test_mode=False: use the full data.
test_mode = False
train_dataset = SentencePairDataset(train_file_paths, tokenizer, test_mode=test_mode)
val_dataset = SentencePairDataset(val_file_paths, tokenizer, test_mode=test_mode)


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=871891.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1715180.0, style=ProgressStyle(descript…

100%|██████████| 1/1 [00:00<00:00, 104.22it/s]
0it [00:00, ?it/s]







In [7]:
# Print the first 12 (pair, label) examples generated from the first data file,
# then print the collected pairs again without labels.
data = []
for i, (pair, label) in zip(range(12), read_sentencepair_data(file_paths[:1])):
    print(i, label, pair)
    data.append(pair)

print("***" * 10)
for i, pair in enumerate(data):
    print(i, pair)


0 [1, 1] ['きみにちょっとしたものをもってきたよ。', 'I brought you a little something.']
1 [0, 0] ['きみにちょっとしたものをもってきたよ。', '何をしているの。']
2 [1, 1] ['何かしてみましょう。', "Let's try something."]
3 [0, 0] ['何かしてみましょう。', '何をしていますか。']
4 [1, 1] ['I have to go to sleep.', '私は眠らなければなりません。']
5 [0, 0] ['What are you doing?', '私は眠らなければなりません。']
6 [1, 1] ['I have to go to sleep.', 'そろそろ寝なくちゃ。']
7 [0, 0] ['君はどうするの。', 'そろそろ寝なくちゃ。']
8 [1, 1] ['私は眠らなければなりません。', 'I have to sleep.']
9 [0, 0] ['何してるの？', 'I have to sleep.']
10 [1, 1] ['What are you doing?', '何してるの？']
11 [0, 0] ['私は眠らなければなりません。', '何してるの？']
******************************
0 ['きみにちょっとしたものをもってきたよ。', 'I brought you a little something.']
1 ['きみにちょっとしたものをもってきたよ。', '何をしているの。']
2 ['何かしてみましょう。', "Let's try something."]
3 ['何かしてみましょう。', '何をしていますか。']
4 ['I have to go to sleep.', '私は眠らなければなりません。']
5 ['What are you doing?', '私は眠らなければなりません。']
6 ['I have to go to sleep.', 'そろそろ寝なくちゃ。']
7 ['君はどうするの。', 'そろそろ寝なくちゃ。']
8 ['私は眠らなければなりません。', 'I have to sleep.']
9 ['何してるの？', 'I have to sleep.'

## Set up the model


In [8]:
loss_type = "bce"  # "bce" or "cosine" (see BertTwoTowerLoss)
loss_margin = 0  # Only used by the cosine loss; values tried: 0, 0.5, -0.5, -1
model_prefix = f"{loss_type}_{loss_margin}" if loss_type == "cosine" else f"{loss_type}"
extra_model_prefix = "from_tt_wm"  # "from_tt_wd" # None #"from_tt" # "from_wm"
if extra_model_prefix:
    model_prefix = f"{model_prefix}_{extra_model_prefix}"

multi_type = "wiki"  # "curated" # "wikimat" # "wiki" # "tatoeba"
model_prefix = f"{WIKIDATA_PREFIX}_{model_prefix}"
finetuned_model_dir = str(
    Path(f"{HOMEDIR}/multi_{multi_type}_2t_{model_prefix}_model").expanduser()
)
logging_dir = str(
    Path(f"{HOMEDIR}/multi_{multi_type}_2t_{model_prefix}_logs").expanduser()
)

# NOTE(review): with padding=True (pad to the longest in the batch), max_length
# is presumably ignored by the collator; it only matters for padding="max_length".
data_collator = DataCollatorWithPadding(tokenizer, padding=True, max_length=514)

num_epochs = 1
if test_mode:
    num_epochs = 1000
elif WIKIDATA_MODEL and WIKIDATA_PREFIX.startswith("en_"):
    num_epochs = 3

training_args = TrainingArguments(
    output_dir=str(finetuned_model_dir),  # output directory
    num_train_epochs=num_epochs,  # total number of training epochs
    # Keep this even (the author used a multiple of 4): consecutive rows form one sentence pair.
    per_device_train_batch_size=32,
    warmup_steps=500,  # number of warmup steps for learning rate scheduler
    logging_dir=str(logging_dir),  # directory for storing logs
    logging_steps=10,
    save_total_limit=2,
    prediction_loss_only=True,
    learning_rate=1e-6,
    # Names of the *input keys* that hold labels (DataCollatorWithPadding renames
    # "label" to "labels"), not the label values themselves.
    label_names=["labels"],
)


In [9]:
def train_model():
    """Fine-tune the two-tower model on ``train_dataset`` and save it.

    Loads ``BertForTwoTowerPrediction`` from ``model_dir`` with the notebook's
    label maps and loss settings and trains it with ``training_args``. The
    model and tokenizer are then written to ``finetuned_model_dir``.

    Returns:
        ``(model, trainer)``.
    """
    model_kwargs = dict(
        num_labels=len(label2id),
        id2label=id2label,
        label2id=label2id,
        loss_type=loss_type,
        margin=loss_margin,
    )
    model = BertForTwoTowerPrediction.from_pretrained(str(model_dir), **model_kwargs)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    trainer.train()

    trainer.save_model(finetuned_model_dir)
    tokenizer.save_pretrained(finetuned_model_dir)
    return model, trainer


In [10]:
%%time
# Fine-tune and save the model; the returned model and trainer stay in the notebook namespace.
model, trainer = train_model()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=625.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=672271273.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertForTwoTowerPrediction: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTwoTowerPrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTwoTowerPrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTwoTowerPrediction were not initialized from the model checkpoint at be

Step,Training Loss


CPU times: user 26.1 s, sys: 6.86 s, total: 32.9 s
Wall time: 35.1 s


In [11]:
# ! rm -rf multi_wiki_2t_en_ja_bce_from_tt_wm_*
