This is an educational demo of sampling-based minimum Bayes risk (MBR) decoding for neural machine translation (NMT) [(Eikema and Aziz, 2020)](https://www.aclweb.org/anthology/2020.coling-main.398/). For a scalable implementation, see [mbr_nmt](https://github.com/Roxot/mbr-nmt) and [(Eikema and Aziz, 2021)](https://arxiv.org/abs/2108.04718).

In [1]:
from demo import SampleFromNMT, SampleFromBuffer, Utility, MBR, load_de_en

Let's start with a quick recap of concepts.

# Neural machine translation

Given an input $x$, a trained NMT model predicts a conditional distribution $Y|X=x, \theta$ over all possible translations of $x$. 

The sample space $\mathcal Y$ of the model is the set of all sequences of the form $(y_1, \ldots, y_N)$ where $y_n$ belongs to a vocabulary of known target-language symbols, $N \ge 1$ is the sequence length, and $y_N$ is a special end-of-sequence (EOS) symbol.

An outcome $y \in \mathcal Y$ is assigned probability mass:

\begin{align}
P_{Y|X}(y|x,\theta) &= \prod_{n=1}^{N} \mathrm{Cat}(y_n|f(x, y_{<n}; \theta))
\end{align}

where $y_{<n}$ is the sequence of tokens before the $n$th token, and $f(\cdot; \theta)$ is a neural network architecture with parameters $\theta$.
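
To make the factorisation concrete, here is a minimal sketch with a made-up next-token table standing in for $f(\cdot; \theta)$. The function `next_token_probs`, its probabilities, and the two-symbol vocabulary are assumptions for illustration only; a real NMT model would compute these distributions with a neural network conditioned on $x$.

```python
import math
import random

EOS = "</s>"

def next_token_probs(prefix):
    """Hypothetical stand-in for Cat(. | f(x, y_<n; theta)): a fixed toy table."""
    if len(prefix) == 0:
        return {"a": 0.6, "b": 0.4}
    if len(prefix) >= 3:
        return {EOS: 1.0}  # force termination so sequences stay short
    return {"a": 0.3, "b": 0.2, EOS: 0.5}

def log_prob(y):
    """Log of the product over positions: sum of per-step log-probabilities."""
    total = 0.0
    for n, token in enumerate(y):
        total += math.log(next_token_probs(y[:n])[token])
    return total

def ancestral_sample(rng):
    """Draw one sequence token by token, each conditioned on the prefix so far."""
    y = []
    while not y or y[-1] != EOS:
        probs = next_token_probs(tuple(y))
        tokens, weights = zip(*probs.items())
        y.append(rng.choices(tokens, weights=weights, k=1)[0])
    return tuple(y)

rng = random.Random(0)
sample = ancestral_sample(rng)
print(sample, math.exp(log_prob(sample)))
```

Each sampled sequence ends in EOS, and its probability is the product of the per-step categorical probabilities, exactly as in the equation above.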


# Maximum-a-posteriori decoding

MAP decoding picks the most probable translation.
\begin{align}
y^{\mathrm{mode}} &= \arg\max_{h \in \mathcal Y} ~ P_{Y|X}(h|x, \theta)
\end{align}

As the space $\mathcal Y$ is unbounded and the NMT model makes no Markov assumptions, this search is intractable. In practice, we approximate the decision rule by searching through a beam of probable translations.

\begin{align}
y^{\mathrm{beam}} &= \arg\max_{h \in \mathrm{beam}(x)} ~  P_{Y|X}(h|x, \theta)
\end{align}

# Minimum Bayes risk decoding

MBR decoding picks the translation that has the highest expected utility:

\begin{align}
y^{\text{mbr}} &= \arg\max_{h \in \mathcal Y} ~ \mathbb E[u(Y, h; x)|x, \theta] \\
&= \arg\max_{h \in \mathcal Y} ~ \sum_{y \in \mathcal Y} u(y, h; x) P_{Y|X}(y|x, \theta)
\end{align}

where a utility function $u(y, h; x)$ quantifies the benefit of choosing $h$ as the translation of $x$, when $y$ is the correct (or preferred) translation.
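
As a small worked illustration (a sketch, not part of the demo code), the exact decision rule can be evaluated when the distribution is tiny enough to enumerate. The three translations, their probabilities, and the Jaccard utility below are all made-up assumptions chosen so that the arithmetic is easy to follow.

```python
def jaccard(ref: str, hyp: str) -> float:
    """Toy utility u(y, h): Jaccard overlap between the token sets of ref and hyp."""
    ref_tokens, hyp_tokens = set(ref.split()), set(hyp.split())
    return len(ref_tokens & hyp_tokens) / len(ref_tokens | hyp_tokens)

# Hypothetical model distribution P(y | x) over a tiny, fully enumerable space.
toy_model = {
    "the cat sat": 0.4,
    "a cat sat down": 0.3,
    "a cat sat": 0.3,
}

def expected_utility(hyp: str) -> float:
    """Exact expectation of u(Y, hyp) under the toy distribution."""
    return sum(prob * jaccard(ref, hyp) for ref, prob in toy_model.items())

y_mode = max(toy_model, key=toy_model.get)    # most probable translation
y_mbr = max(toy_model, key=expected_utility)  # highest expected utility

print("mode:", y_mode)
print("mbr: ", y_mbr)
```

In this toy case the two rules disagree: the mode "the cat sat" has expected utility 0.67, while "a cat sat" reaches 0.725 because it overlaps well with the other probable translations.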

There are two sources of intractability in MBR decoding. First,  just like in MAP decoding, the search space (i.e., the sample space $\mathcal Y$) is unbounded. Second, for any given candidate translation $h$, the expected utility $\mathbb E[u(Y, h; x)|x, \theta]$ is intractable to compute. 

[Eikema and Aziz (2020)](https://www.aclweb.org/anthology/2020.coling-main.398/) propose to  

1. approximate the hypothesis space by a tractable subset of hypotheses $\mathcal H(x)$ obtained by sampling from the model;
2. approximate the expected value using Monte Carlo (MC).

The decision rule becomes

\begin{align}
y^{\text{smbr}} &= \arg\max_{h \in \mathcal H(x)} ~ \frac{1}{S} \sum_{s=1}^S u(y^{(s)}, h; x)
\end{align}

where $y^{(s)} \sim Y|X=x, \theta$ is a sample from the NMT model (samples can be drawn efficiently via ancestral sampling).
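
The sampling-based rule can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the demo's `MBR` class: the samples are hard-coded strings rather than draws from an NMT model, `overlap_f1` is a toy stand-in for a real utility, and the same samples serve both as the hypothesis space $\mathcal H(x)$ and as the $y^{(s)}$ in the MC average.

```python
from collections import Counter

def overlap_f1(ref: str, hyp: str) -> float:
    """Toy utility u(y, h): unigram F1 between hyp and ref."""
    ref_counts, hyp_counts = Counter(ref.split()), Counter(hyp.split())
    matches = sum((ref_counts & hyp_counts).values())
    if matches == 0:
        return 0.0
    precision = matches / sum(hyp_counts.values())
    recall = matches / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)

def sampling_mbr(samples, utility):
    """Return the hypothesis with the highest MC estimate of expected utility."""
    hyp_space = list(dict.fromkeys(samples))  # unique samples, first-seen order

    def mc_expected_utility(h):
        # (1/S) * sum_s u(y^(s), h)
        return sum(utility(y, h) for y in samples) / len(samples)

    return max(hyp_space, key=mc_expected_utility)

# Hard-coded "samples" standing in for draws from the model.
samples = ["a cat sat", "a cat sat", "the cat sat", "a dog ran"]
print(sampling_mbr(samples, overlap_f1))
```

Duplicates are kept in `samples` on purpose: they are how the MC average reflects that the model assigns more probability to some translations than to others.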

For much more on sampling-based MBR see also [Eikema and Aziz (2021)](https://arxiv.org/abs/2108.04718).

# Load pre-trained models 

Load pre-processing pipelines and fairseq models. This will take a moment.

Make sure you've downloaded the models (run `bash download-data.sh`).

In [2]:
models = load_de_en(np_seed=10, torch_seed=10)

Here, `x` is German and `y` is English.

In [3]:
example_x = 'Es war noch nie leicht, ein rationales Gespräch über den Wert von Gold zu führen.'
example_y = 'It has never been easy to have a rational conversation about the value of gold.'

In [4]:
models.pipeline_x.decode(models.pipeline_x.encode(example_x))

'Es war noch nie leicht, ein rationales Gespräch über den Wert von Gold zu führen.'

In [5]:
models.pipeline_y.decode(models.pipeline_y.encode(example_y))

'It has never been easy to have a rational conversation about the value of gold.'

# Demo


## Approximate MAP decoding

In [6]:
models.x2y.beam_search(example_x)['output'][0]

'It has never been easy to talk rationally about the value of gold.'

## Approximate MBR decoding

Let's start by creating a hypothesis space. For example, we can use the unique translations found in a large sample. 

In [7]:
hyp_space = list(set(models.x2y.ancestral_sampling(example_x, num_samples=100)['output']))
len(hyp_space)

84

As source sentences get longer or move away from the training domain, it's common to find very few duplicates among samples from an NMT model. Note that duplicates in the hypothesis space do not affect the results; they only waste computation, which is why we keep only the unique translations.
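
One caveat when deduplicating with `set`: the iteration order of a set of strings can change between Python processes (string hashing is randomised by default), so the order of the hypothesis list, and with it any tie-breaking, may differ from run to run even with fixed seeds. A minimal order-preserving alternative, offered as a sketch (`unique_in_order` is a hypothetical helper, not part of the demo package):

```python
def unique_in_order(samples):
    """Drop duplicates while keeping the first occurrence of each translation."""
    # dict keys are unique and preserve insertion order (Python 3.7+).
    return list(dict.fromkeys(samples))

print(unique_in_order(["b", "a", "b", "c", "a"]))
```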

Now we need to choose a utility function. For most of this demo we will assign utility using ChrF (but do check the last section for an example using COMET, an NN-based metric).

In [8]:
from sacrebleu import sentence_chrf

In [9]:
class ChrF(Utility):
    
    def __call__(self, src: str, hyp: str, ref: str) -> float:
        return sentence_chrf(hyp, ref).score  # note that chrf does not make use of the source sentence

chrf = ChrF()    

In [10]:
chrf('This is cool!', 'Das ist cool!', 'Cool!')

0.40623147802266224

In [11]:
chrf.batch('This is cool!', 'Das ist cool!', ['Cool!', 'Das ist cool!'])

array([0.40623148, 1.        ])

MBR decoding requires approximating the expected utility of each candidate in the hypothesis space; the mechanism of choice is MC estimation. Thus, MBR decoding requires access to samples from the model.

Ideally, for each hypothesis, we would draw samples completely independently:

In [12]:
nmt_sampler = SampleFromNMT(models.x2y, sample_size=30)

In [13]:
nmt_sampler(example_x)

['It has never been easy to talk rationally about the value of gold.',
 'It has never been easy to talk rationally about the value of gold.',
 'It has never been easy to discuss the value of gold rationally.',
 'It has never been easy to hold a rational conversation about the value of gold.',
 "It has never been easy to engage in a rational conversation about gold's value.",
 'It was never easy to have a rational conversation about the value of gold.',
 'It was never easy to talk rationally about the value of gold.',
 'It has never been easy to discuss the value of gold in a rational manner.',
 'It has never been easy to engage rationally in a conversation about the value of gold.',
 "It has never been easy to have a rational conversation on gold's value.",
 'It has never been easy to enter a rational conversation about the value of gold.',
 'It has never been easy to enter into a rational conversation on the value of gold.',
 'It has never been simple to talk rationally about the valu

Depending on the size of your hypothesis space, this will quickly become too costly, since every candidate needs its own fresh samples. A good alternative is to obtain a large collection of samples (e.g., 100 to 1000) and draw samples with replacement from this collection:

In [14]:
buffered_sampler = SampleFromBuffer(
    {example_x: models.x2y.ancestral_sampling(example_x, num_samples=100)['output']}, 
    sample_size=30
)
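
The idea behind drawing from a buffer can be sketched as follows. `ToyBufferSampler` is a hypothetical stand-in written for illustration; it is not the demo's `SampleFromBuffer`, whose internals may differ, and the buffer contents here are placeholder strings.

```python
import random

class ToyBufferSampler:
    """Hypothetical buffered sampler: resamples from pre-drawn translations."""

    def __init__(self, buffer, sample_size, seed=0):
        self.buffer = buffer            # dict: source sentence -> list of stored samples
        self.sample_size = sample_size
        self.rng = random.Random(seed)  # private RNG for reproducible draws

    def __call__(self, src):
        # Uniform draws with replacement from the stored samples for this source.
        return self.rng.choices(self.buffer[src], k=self.sample_size)

toy_sampler = ToyBufferSampler({"x": ["a", "b", "c"]}, sample_size=30)
print(toy_sampler("x")[:5])
```

Because the buffer is itself a collection of model samples, each call costs no additional forward passes through the NMT model.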

In [15]:
buffered_sampler(example_x)

['It has never been easy to have a rational discussion about the value of gold.',
 'It has never been easy to enter into a rational discussion about the value of gold.',
 "It had never been easy to conduct a rational conversation about gold's value.",
 'It was never easy to negotiate rationally about the value of gold.',
 "It's never been easy to make rational talks about the value of gold.",
 'It has never been easy to hold a rational conversation about the value of gold.',
 'It has never before been easy to discuss the value of gold rationally.',
 'It has never been easy to ask about the value of gold, as a rational conversation is simply not in itself acknowledging it.',
 'It has never been easy to command a rational conversation about the value of gold.',
 'It has never been easy to negotiate on the value of gold rationally.',
 "It has never been easy to listen rationally to gold's value.",
 'It has never been easy to discuss the value of gold in a rational way.',
 "However, it was

Next, we estimate expected utilities and rank our candidates.

In [16]:
mbr = MBR(chrf, buffered_sampler)  # This version of MBR is what Eikema and Aziz (2021) call MBR N-by-S

In [17]:
y_mbr = mbr.decode(example_x, hyp_space)
y_mbr

'It has never been easy to make a rational conversation about the value of gold.'

In [18]:
mbr.mu(example_x, y_mbr)

0.677204700537251

You can also inspect the expected utility of all candidates (note that estimates of expected utility are random variables, so some variability across runs is expected, especially for small sample sizes):

In [19]:
for h, mu in sorted(zip(hyp_space, mbr.mus(example_x, hyp_space)), key=lambda pair: pair[1], reverse=True):
    print(f"{mu:.4f}\t{h}")

0.7119	It has never been easy to carry out rational conversation about the value of gold.
0.7051	It has never been easy to hold a rational conversation about the value of gold.
0.6855	It has never been easy to enter a rational conversation about the value of gold.
0.6703	It has never been easy to make a rational conversation about the value of gold.
0.6673	It has never been easy to engage in rational conversations about the value of gold.
0.6662	Getting a rational conversation about the value of gold has never been easy.
0.6653	It has never been easy to have a rational conversation about the value of gold.
0.6558	It has never been easy to engage in a rational conversation about the value of gold.
0.6506	Raising a rational conversation about the value of gold has never been easy.
0.6490	It has never been easy to proceed with a rational conversation about the value of gold.
0.6456	It has never been easy to make a rational discussion about the value of gold.
0.6456	It has never been easy 
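
To get a feel for the variability of these estimates, one can attach a standard error to each MC average. The sketch below assumes you have the individual utilities $u(y^{(s)}, h; x)$ for one hypothesis; the numbers used here are made up, and `mc_estimate` is a hypothetical helper, not part of the demo package.

```python
import math
import statistics

def mc_estimate(utilities):
    """Return the MC mean and its standard error for a list of sampled utilities."""
    mean = statistics.fmean(utilities)
    # Sample standard deviation divided by sqrt(S).
    stderr = statistics.stdev(utilities) / math.sqrt(len(utilities))
    return mean, stderr

mean, stderr = mc_estimate([0.5, 0.7, 0.6, 0.8])
print(f"{mean:.4f} +/- {stderr:.4f}")
```

When the gap between two candidates is smaller than their standard errors, their relative ranking can easily flip from one run to the next.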

## COMET as Utility

The utility function is anything we trust for assessing the merits of a translation candidate. In this part of the demo, we will employ a modern NN-based metric: [COMET](https://github.com/Unbabel/COMET). Some of COMET's dependencies conflict with those of our trained NMT models (e.g., the fairseq version), so we have prepared a Flask app that hides the mess of maintaining different versions of Python packages: git clone Probabll's [mteval-flask](https://github.com/probabll/mteval-flask) and follow the instructions there to start an automatic evaluation server.

We pick COMET not only because it's modern: unlike most MT evaluation metrics, COMET makes use of the source sentence, which we think is the right way to evaluate adequacy.

In [20]:
import requests

In [32]:
class COMET(Utility):    
    
    def __call__(self, src: str, hyp: str, ref: str) -> float:
        jobs = {'hyps': [hyp], 'refs': [ref], 'srcs': [src]}
        # Unlike ChrF, COMET actually makes use of the source sentence! :D
        results = requests.post("http://localhost:4000/score", json=jobs, headers={'Content-Type': 'application/json'}).json()
        return results['comet'][0]
    
    def batch(self, src: str, hyp: str, refs: list):
        jobs = {'hyps': [hyp] * len(refs), 'refs': refs, 'srcs': [src] * len(refs)}
        results = requests.post("http://localhost:4000/score", json=jobs, headers={'Content-Type': 'application/json'}).json()
        return results['comet']

In [33]:
comet = COMET()

In [36]:
mbr_comet = MBR(comet, buffered_sampler)  # This version of MBR is what Eikema and Aziz (2021) call MBR N-by-S

In [37]:
for h, mu in sorted(zip(hyp_space, mbr_comet.mus(example_x, hyp_space)), key=lambda pair: pair[1], reverse=True):
    print(f"{mu:.4f}\t{h}")

0.8370	It was never easy to conduct a rational conversation about the value of gold.
0.8291	It was never easy to have a rational discussion about the value of gold.
0.8069	You have never been easy to start a rational discussion about the value of gold.
0.7906	It has never been easy to carry out rational conversation about the value of gold.
0.7861	It has not been easy to hold a rational discussion about the value of gold.
0.7856	It has never been easy to have a rational conversation about the value of gold.
0.7838	It has never been easy to engage in rational conversations about the value of gold.
0.7836	It has never been easy to make a rational conversation about the value of gold.
0.7776	It has never been easy to enter a rational conversation about the value of gold.
0.7589	It has never been easy to make a rational discussion about the value of gold.
0.7583	It has never been easy to conduct a rational discussion about the value of gold.
0.7473	It has never been easy to hold a rational