---
title: "Word alignment with multilingual BERT: from subwords to lexicons"
subtitle: "Multilingual NLP -- Lab 2"
author: "Philippos Triantafyllou"
date-modified: last-modified
date-format: long
lang: en
format:
    pdf:
        pdf-engine: lualatex
        documentclass: scrartcl
        fontsize: 16pt
        papersize: A3
        toccolor: blue
        classoption: 
            - "DIV=12"
            - "parskip=relative"
            - "titlepage=false"
        code-block-border-left: MediumBlue
        code-block-bg: WhiteSmoke
        template-partials:
            - "../_pandoc/doc-class.tex"
            - "../_pandoc/toc.tex"
            - "../_pandoc/before-title.tex"
toc: true
toc-depth: 3
number-depth: 1
number-sections: true
highlight-style: github
fig-cap-location: top
embed-resources: true
---

:::{.callout-note}
## Instructions

In this lab, we will explore how to extract a bilingual lexicon (a list of word translation pairs linking a source language to a target language) directly from the internal representations of a "large" multilingual language model. Specifically, we will rely on `mBERT`, a variant of `BERT` pre-trained on Wikipedia in over 100 languages (Devlin et al. 2019).

As we discussed at (too) great length in class, one of the surprising findings about `mBERT` is its ability to perform cross-lingual transfer without any explicit alignment objective or parallel data (Pires, Schlinger and Garrette 2019; Wu and Dredze 2019). Although trained solely with a masked language modelling objective on monolingual corpora, `mBERT` appears to learn a shared semantic space across languages. Words that are translations of each other tend to occupy neighbouring regions in the representation space, even when the languages do not share scripts or subwords.

Because of this emergent alignment, it is possible to recover word-level translation pairs by comparing the contextualised embeddings produced by `mBERT` for parallel sentences. In this lab, we will experiment with three different alignment strategies:

- Direct argmax alignment, where each source word is linked to the target word with the most similar embedding.
- Competitive linking (via the Hungarian algorithm), which enforces one-to-one alignments.
- Canonical Correlation Analysis (CCA), which learns linear projections to better align the representation spaces of two languages (Cao, Kitaev and Klein 2020).

To carry out this lab, we will rely on a range of multilingual resources. First, we require parallel corpora (collections of sentence pairs that are translations of each other) so that we can compare word representations across languages. Well-known examples include the No Language Left Behind (`NLLB`) dataset (NLLB Team et al. 2022), which provides large-scale, high-quality translations across more than 200 languages. In addition, we will use bilingual lexicons, that is, precompiled lists of word translation pairs such as those available in `MUSE` (Lample et al. 2018) or `PanLex` (Kamholz, Pool and Colowick 2014). These lexicons serve both as supervision signals (for instance when applying Canonical Correlation Analysis) and as gold standards to evaluate the accuracy of the alignments we obtain.
:::
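
To make the difference between the first two strategies concrete, here is a minimal sketch on made-up two-dimensional vectors. It is not the lab's implementation: the helper names (`cosine`, `argmax_alignment`, `one_to_one_alignment`) are hypothetical, and the one-to-one step uses an exhaustive search that is only workable for a handful of words. On real sentences one would presumably use a proper Hungarian solver such as `scipy.optimize.linear_sum_assignment` instead.

```python
import math
from itertools import permutations

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def similarity_matrix(src_vecs, tgt_vecs):
    """Rows are source words, columns are target words."""
    return [[cosine(s, t) for t in tgt_vecs] for s in src_vecs]

def argmax_alignment(sim):
    """Link each source word to its most similar target (targets may repeat)."""
    return [max(range(len(row)), key=row.__getitem__) for row in sim]

def one_to_one_alignment(sim):
    """Exhaustive search for the one-to-one assignment with the highest
    total similarity. Assumes len(sim) <= len(sim[0]) and tiny inputs."""
    n_src, n_tgt = len(sim), len(sim[0])
    best = max(
        permutations(range(n_tgt), n_src),
        key=lambda perm: sum(sim[i][j] for i, j in enumerate(perm)),
    )
    return list(best)

# Toy "embeddings", invented for illustration only
src = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
tgt = [[1.0, 0.05], [0.7, 0.7], [0.0, 1.0]]

sim = similarity_matrix(src, tgt)
greedy = argmax_alignment(sim)      # two source words pick the same target
linked = one_to_one_alignment(sim)  # every target is used at most once
print(greedy, linked)
```

In this toy case the first two source words both prefer target 0 under argmax, whereas the one-to-one constraint pushes the second source word to its next-best target.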

## Selecting languages

We select 5 languages that are well represented in the training of `mBERT` and 5 others that are either less represented or absent (the only language that is completely absent from `mBERT` is Kurdish). All the languages also have a bilingual lexicon from `MUSE` that we will use for the CCA.

In [1]:
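# Configure logging and silence noisy library output

import logging
logging.basicConfig(level=logging.INFO)

# Alias the transformers logger so it does not shadow the standard library module
from transformers import logging as hf_logging
hf_logging.set_verbosity_error()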

from datasets import disable_progress_bars
disable_progress_bars()

In [2]:
#| echo: true

# (language, code, link_to_MUSE)
langs = [
    ("albanian",    "sq",   "https://dl.fbaipublicfiles.com/arrival/dictionaries/en-sq.0-5000.txt"),
    ("arabic",      "ar",   "https://dl.fbaipublicfiles.com/arrival/dictionaries/en-ar.0-5000.txt"),
    ("french",      "fr",   "https://dl.fbaipublicfiles.com/arrival/dictionaries/en-fr.0-5000.txt"),
    ("german",      "de",   "https://dl.fbaipublicfiles.com/arrival/dictionaries/en-de.0-5000.txt"),
    ("greek",       "el",   "https://dl.fbaipublicfiles.com/arrival/dictionaries/en-el.0-5000.txt"),
    ("hindi",       "hi",   "https://dl.fbaipublicfiles.com/arrival/dictionaries/en-hi.0-5000.txt"),
    ("japanese",    "ja",   "https://dl.fbaipublicfiles.com/arrival/dictionaries/en-ja.0-5000.txt"),
    ("macedonian",  "mk",   "https://dl.fbaipublicfiles.com/arrival/dictionaries/en-mk.0-5000.txt"),
    ("persian",     "fa",   "https://dl.fbaipublicfiles.com/arrival/dictionaries/en-fa.0-5000.txt"),
    ("thai",        "th",   "https://dl.fbaipublicfiles.com/arrival/dictionaries/en-th.0-5000.txt")
]

In [14]:
import requests
from datasets import load_dataset

def install_datasets(langs):
    corpora = {}

    # Iterate through each language definition (name, ISO code, lexicon URL)
    for lang, code, links in langs:
        print(f"Processing {lang}...")
        corpora[lang] = {}
        print(f"\tLoading pairs...")

        # Load parallel English–target-language sentence pairs from JW300
        data = load_dataset(
            "sentence-transformers/parallel-sentences-jw300",
            f"en-{code}",
            split="train"
        )
        print(f"\tLoaded JW300 for {lang}")

        corpora[lang]['pairs'] = (
            data
            # Load up to 5,000 parallel sentence pairs
            .take(5000) # type: ignore
            .map(
                # Clean zero-width Unicode chars
                lambda x: {
                    'english': x['english'].replace('\u200b', '').replace('\u200f', '').replace('\u200c', ''),
                    'non_english': x['non_english'].replace('\u200b', '').replace('\u200f', '').replace('\u200c', '')
                }
            )
            # Drop any pair where either side becomes empty after cleaning
            # (return an explicit boolean rather than a string)
            .filter(
                lambda x: bool(x['english'].strip() and x['non_english'].strip())
            )
        )

        # Report how many valid parallel pairs remain after filtering
        print(f"\tKept {len(corpora[lang]['pairs'])} aligned sentence pairs")

        print(f"\tLoading lexicons...")

        # Download and decode the lexicon text file (fail loudly on HTTP errors)
        response = requests.get(links, timeout=60)
        response.raise_for_status()
        text = response.content.decode("utf-8", errors="ignore")

        # Parse tab-separated lexicons if tabs are present
        if '\t' in text:
            lexicon = [
                tuple(line.split('\t', 1))
                for line in text.splitlines()
                if line.strip() and '\t' in line
            ]
        # Otherwise parse whitespace-separated lexicons (French and German),
        # skipping lines that do not contain both a source and a target
        else:
            lexicon = [
                tuple(line.strip().split(None, 1))
                for line in text.splitlines()
                if len(line.split()) >= 2
            ]

        corpora[lang]["lexicon"] = lexicon
    return corpora
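
Several lexicon entries share the same source word (in the Albanian lexicon, `for` is listed with both `për` and `per`), so when the lexicon is used as a gold standard it may be convenient to group targets by source word. The following is a minimal sketch under that assumption; `parse_lexicon` and `build_gold` are hypothetical helpers, not part of the lab code.

```python
from collections import defaultdict

def parse_lexicon(text):
    """Parse 'source<sep>target' lines, where <sep> is a tab or spaces.
    Lines without both a source and a target are skipped."""
    pairs = []
    for line in text.splitlines():
        parts = line.strip().split(None, 1)
        if len(parts) == 2:
            pairs.append((parts[0], parts[1].strip()))
    return pairs

def build_gold(pairs):
    """Map each source word to the set of its accepted translations."""
    gold = defaultdict(set)
    for src, tgt in pairs:
        gold[src].add(tgt)
    return dict(gold)

# Small sample mixing tab- and space-separated lines, plus noise
sample = "and\tdhe\nfor\tpër\nfor per\n\nmalformed\n"
pairs = parse_lexicon(sample)
gold = build_gold(pairs)
print(gold)
```

A predicted pair `(src, tgt)` can then be counted as correct when `tgt in gold.get(src, set())`.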

In [15]:
import os
from datasets import Dataset

def save_datasets(corpora, data_folder="data"):
    os.makedirs(data_folder, exist_ok=True)
    
    for lang, content in corpora.items():
        lang_folder = os.path.join(data_folder, lang)
        os.makedirs(lang_folder, exist_ok=True)
        
        # Save pairs as jsonl
        pairs_path = os.path.join(lang_folder, "pairs.jsonl")
        content['pairs'].to_json(pairs_path)
        
        # Save lexicon as text file
        lexicon_path = os.path.join(lang_folder, "lexicon.txt")
        with open(lexicon_path, 'w', encoding='utf-8') as f:
            for src, tgt in content['lexicon']:
                f.write(f"{src}\t{tgt}\n")

In [16]:
def load_datasets(langs, data_folder="data"):
    corpora = {}
    
    for lang, _, _ in langs:
        lang_folder = os.path.join(data_folder, lang)
        corpora[lang] = {}
        
        # Load pairs
        pairs_path = os.path.join(lang_folder, "pairs.jsonl")
        corpora[lang]['pairs'] = Dataset.from_json(pairs_path)
        
        # Load lexicon
        lexicon_path = os.path.join(lang_folder, "lexicon.txt")
        with open(lexicon_path, 'r', encoding='utf-8') as f:
            lexicon = [tuple(line.strip().split('\t', 1)) for line in f if line.strip()]
        corpora[lang]['lexicon'] = lexicon
    
    return corpora

In [17]:
data = install_datasets(langs)
save_datasets(data)

Processing albanian...
	Loading pairs...
	Loaded JW300 for albanian
	Kept 4928 aligned sentence pairs
	Loading lexicons...
Processing arabic...
	Loading pairs...
	Loaded JW300 for arabic
	Kept 4844 aligned sentence pairs
	Loading lexicons...
Processing french...
	Loading pairs...
	Loaded JW300 for french
	Kept 4773 aligned sentence pairs
	Loading lexicons...
Processing german...
	Loading pairs...
	Loaded JW300 for german
	Kept 4942 aligned sentence pairs
	Loading lexicons...
Processing greek...
	Loading pairs...
	Loaded JW300 for greek
	Kept 4618 aligned sentence pairs
	Loading lexicons...
Processing hindi...
	Loading pairs...
	Loaded JW300 for hindi
	Kept 4053 aligned sentence pairs
	Loading lexicons...
Processing japanese...
	Loading pairs...
	Loaded JW300 for japanese
	Kept 4557 aligned sentence pairs
	Loading lexicons...
Processing macedonian...
	Loading pairs...
	Loaded JW300 for macedonian
	Kept 4890 aligned sentence pairs
	Loading lexicons...
Processing persian...
	Loading pairs

In [28]:
data = load_datasets(langs)

In [29]:
for lang, corpora in data.items():
    print(f"For {lang}:")
    for keys, vals in corpora.items():
        print(f"{keys} ({type(vals)})")
        print(f"{vals[:5]}")
    print()

For albanian:
pairs (<class 'datasets.arrow_dataset.Dataset'>)
{'english': ['Page Two', '1945 - 1995 What Have We Learned?', '3 - 14', 'Fifty years have passed since the end of World War II.', 'In what ways has humankind progressed?'], 'non_english': ['Faqja dy', '1945 - 1995 Çfarë kemi mësuar?', '3 - 14', 'Kanë kaluar pesëdhjetë vjet që nga mbarimi i Luftës II Botërore.', 'Në cilat fusha ka bërë progres njeriu?']}
lexicon (<class 'list'>)
[('and', 'dhe'), ('was', 'ishte'), ('for', 'për'), ('for', 'per'), ('that', 'që')]

For arabic:
pairs (<class 'datasets.arrow_dataset.Dataset'>)
{'english': ['Perhaps you have asked yourself this question, as well as others: “How Can I Say No to Premarital Sex? ”', '“ Why Do I Get So Depressed? ”', 'These are chapter titles in the new book Questions Young People Ask  — Answers That Work.', 'A youth from San Antonio, Texas, wrote: “I really enjoyed chapter 36, ‘ How Can I Control My TV Viewing Habits? ’', 'It will really help me. I am a compulsive TV 

For the selected languages, obtain the `mBERT` representations for all these sentences, aggregating subword embeddings in order to reconstruct word-level vectors.
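
As a small illustration of the aggregation step, the sketch below averages the vectors of subword tokens that belong to the same word. The tokenisation, the vectors and the helper name `pool_subwords` are invented for the example; the only convention assumed is that each token carries a word index, with `None` marking special tokens.

```python
def pool_subwords(token_vectors, word_ids):
    """Average token vectors that share a word id; None marks special tokens."""
    grouped = {}
    for vec, wid in zip(token_vectors, word_ids):
        if wid is None:
            continue
        grouped.setdefault(wid, []).append(vec)

    pooled = []
    for wid in sorted(grouped):
        vecs = grouped[wid]
        dim = len(vecs[0])
        pooled.append([sum(v[d] for v in vecs) / len(vecs) for d in range(dim)])
    return pooled

# Hypothetical tokenisation: [CLS] play ##ing chess [SEP]
word_ids = [None, 0, 0, 1, None]
token_vectors = [
    [9.0, 9.0],  # [CLS], ignored
    [1.0, 3.0],  # "play"
    [3.0, 5.0],  # "##ing"
    [0.0, 2.0],  # "chess"
    [9.0, 9.0],  # [SEP], ignored
]

word_vectors = pool_subwords(token_vectors, word_ids)
print(word_vectors)
```

Mean pooling is only one possible choice; taking the first subword of each word would be an alternative with the same interface.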

In [30]:
#| echo: true

from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-multilingual-cased")
model = AutoModel.from_pretrained("google-bert/bert-base-multilingual-cased")

The pipeline is as follows:

- A first function extracts the word-level embeddings for a list of sentences in a given language.
- A second function iterates over all languages in the dataset, calling the first function for each language and storing the results.
- Finally, we save the extracted embeddings to disk for later use.

Function to extract embeddings.

In [31]:
import torch

def get_word_embeddings(sentences, tokenizer, model, batch_size=16, device=None):
    # Device detection
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.to(device)
    model.eval()

    all_results = []
    total_sentences = len(sentences)
    
    # Process in batches
    for batch_idx in range(0, total_sentences, batch_size):
        batch = sentences[batch_idx:batch_idx + batch_size]

        # Tokenize batch
        encoded = tokenizer(
            list(batch),
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512,
            add_special_tokens=True
        )

        # Move to device
        input_ids = encoded['input_ids'].to(device)
        attention_mask = encoded['attention_mask'].to(device)

        # Get embeddings
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = outputs.last_hidden_state # [batch_size, seq_len, hidden_size]

        # Process each sentence in the batch
        for idx_in_batch, sentence in enumerate(batch):
            word_ids = encoded.word_ids(batch_index=idx_in_batch)
            sent_embeddings = embeddings[idx_in_batch]

            # Group tokens by word_id
            word_to_token_indices = {}
            for token_idx, word_id in enumerate(word_ids):
                if word_id is not None:
                    if word_id not in word_to_token_indices:
                        word_to_token_indices[word_id] = []
                    word_to_token_indices[word_id].append(token_idx)

            # Aggregate embeddings (mean)
            word_embeddings = []
            for word_id in sorted(word_to_token_indices.keys()):
                token_indices = word_to_token_indices[word_id]
                token_embs = sent_embeddings[token_indices]
                word_emb = token_embs.mean(dim=0)

                word_embeddings.append({
                    'word_id': word_id,
                    'embedding': word_emb.cpu().numpy(),
                    'num_tokens': len(token_indices)
                })

            all_results.append({
                'sentence': sentence,
                'word_embeddings': word_embeddings,
                'num_words': len(word_embeddings)
            })

        # Progress
        processed = min(batch_idx + batch_size, total_sentences)
        progress_pct = (processed / total_sentences) * 100
        print(f"\tProgress: {processed}/{total_sentences} ({progress_pct:.1f}%)", end='\r')

        # Memory management
        if device.type == 'cpu' and batch_idx % (batch_size * 20) == 0:
            import gc
            gc.collect()

        if device.type == 'cuda' and batch_idx % (batch_size * 10) == 0:
            torch.cuda.empty_cache()

    print()
    return all_results

Function to process all languages in the dataset.

In [32]:
def process_all_languages(data, tokenizer, model, batch_size=16, device=None, save_dir=None):
    import time
    
    embeddings_by_language = {}
    total_start_time = time.time()
    
    for lang_idx, (lang, corpora) in enumerate(data.items(), 1):
        print(f"\n{'='*70}")
        print(f"[{lang_idx}/{len(data)}] Processing {lang.upper()}")
        print('='*70)
        
        embeddings_by_language[lang] = {}
        
        if 'pairs' in corpora:
            english_sentences = corpora['pairs']['english']
            non_english_sentences = corpora['pairs']['non_english']
            
            # Process English
            print(f"\nProcessing English sentences ({len(english_sentences)} total)...")
            start_time = time.time()
            embeddings_by_language[lang]['english'] = get_word_embeddings(
                english_sentences, tokenizer, model, 
                batch_size=batch_size, device=device
            )
            english_time = time.time() - start_time
            print(f"\tCompleted in {english_time:.2f} seconds")
            
            # Process non-English
            print(f"\nProcessing {lang.capitalize()} sentences ({len(non_english_sentences)} total)...")
            start_time = time.time()
            embeddings_by_language[lang]['non_english'] = get_word_embeddings(
                non_english_sentences, tokenizer, model, 
                batch_size=batch_size, device=device
            )
            non_english_time = time.time() - start_time
            print(f"\tCompleted in {non_english_time:.2f} seconds")
            
            # Summary
            print(f"\n{lang.upper()} Summary:")
            print(f"\tEnglish: {len(embeddings_by_language[lang]['english'])} sentences")
            print(f"\t{lang.capitalize()}: {len(embeddings_by_language[lang]['non_english'])} sentences")
            print(f"\tTotal time: {english_time + non_english_time:.2f} seconds")
            
            # Save intermediate results
            if save_dir:
                import pickle
                import os
                os.makedirs(save_dir, exist_ok=True)
                
                filepath = os.path.join(save_dir, f'embeddings_{lang}.pkl')
                with open(filepath, 'wb') as f:
                    pickle.dump(embeddings_by_language[lang], f)
                print(f"\tSaved to {filepath}")
    
    total_time = time.time() - total_start_time
    print(f"\n{'='*70}")
    print(f"ALL LANGUAGES PROCESSED!")
    print(f"Total time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")
    print('='*70)
    
    return embeddings_by_language

In [None]:
#| echo: false

embeddings_by_language = process_all_languages(
    data, 
    tokenizer, 
    model, 
    batch_size=16,
    device=None,
    save_dir='./embeddings'
)
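
Since the embeddings are written to one pickle per language, they can be read back in a later session without rerunning the model. A minimal sketch, assuming the `embeddings_{lang}.pkl` naming used in `process_all_languages`; `load_embeddings` is a hypothetical helper and the round trip below uses a toy dictionary in a temporary directory.

```python
import os
import pickle
import tempfile

def load_embeddings(save_dir, lang):
    """Read back one language's embeddings from embeddings_{lang}.pkl."""
    filepath = os.path.join(save_dir, f"embeddings_{lang}.pkl")
    with open(filepath, "rb") as f:
        return pickle.load(f)

# Round trip with a toy structure that mimics the saved layout
toy = {
    "english": [{
        "sentence": "Page",
        "word_embeddings": [
            {"word_id": 0, "embedding": [0.1, 0.2], "num_tokens": 1}
        ],
        "num_words": 1,
    }],
    "non_english": [],
}

with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "embeddings_toy.pkl"), "wb") as f:
        pickle.dump(toy, f)
    restored = load_embeddings(tmp, "toy")

print(restored["english"][0]["num_words"])
```

Pickle files should only be loaded when they come from a trusted source, such as a previous run of this notebook.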

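We can inspect the structure of the extracted embeddings, here for the second English sentence of the Albanian corpus: each entry stores the sentence, its number of words, and one vector per word.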

In [None]:
import pprint

# Pretty print with indentation
pprint.pprint(embeddings_by_language['albanian']['english'][1], depth=None, width=80)

{'num_words': 8,
 'sentence': '1945 - 1995 What Have We Learned?',
 'word_embeddings': [{'embedding': array([-7.79106796e-01,  1.63706824e-01,  4.19476509e-01, -2.52812207e-01,
       -4.04073417e-01, -4.29445773e-01,  1.10392213e+00, -4.96197119e-02,
        3.29879969e-01, -7.99426436e-02, -3.16635936e-01,  8.22210729e-01,
       -2.11032480e-03, -3.24477315e-01,  8.07668447e-01,  2.79502958e-01,
       -1.08839631e+00, -4.56714988e-01, -9.12314296e-01,  1.26947331e+00,
       -9.69342887e-03, -5.57337642e-01, -1.14374235e-02, -1.07016131e-01,
        9.33045805e-01,  1.67202264e-01, -2.66185820e-01,  9.95863974e-02,
        1.43899890e-02,  2.61802226e-01, -8.93424898e-02, -3.80249396e-02,
       -6.14247978e-01,  1.38727844e+00, -1.74044847e-01, -4.24568981e-01,
       -1.05630890e-01, -3.16683412e-01, -2.36605078e-01, -4.89909708e-01,
       -7.32856870e-01,  5.59178703e-02,  1.60589063e+00, -3.09409171e-01,
       -1.55224696e-01, -3.66162717e-01,  4.71679449e-01, -1.10167023e-02
