# MetricX 

This notebook is the result of my looking inside the MT5 and MetricX models to understand them well enough to reimplement them.


1. MetricX23 and MetricX24 are both based on MT5 models
2. Both have an identical model architecture, and both concatenate their inputs into one long sequence. Both also work in reference-free (QE) and reference-based modes.
3. However, the way the inputs are prepared (joined) differs. See the following:

| Model | QE (reference-free) | Reference-based |
| ------ | --- | -------- |
| MetricX23 | `candidate: $TEXT source: $TEXT`  | `candidate: $TEXT reference: $TEXT` |
| MetricX24 | `source: $TEXT candidate: $TEXT` | `source: $TEXT candidate: $TEXT reference: $TEXT`  |
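The table can also be encoded directly as data. The sketch below is a hypothetical, table-driven alternative to the `make_input` helper defined later (`TEMPLATES` and `format_input` are my own illustrative names, not part of MetricX). It treats every field that appears in a template as required.

```python
import string

# Hypothetical lookup table mirroring the input-format table above:
# (model family, is_qe) -> template
TEMPLATES = {
    ("metricx-23", True): "candidate: {candidate} source: {source}",
    ("metricx-23", False): "candidate: {candidate} reference: {reference}",
    ("metricx-24", True): "source: {source} candidate: {candidate}",
    ("metricx-24", False): "source: {source} candidate: {candidate} reference: {reference}",
}


def format_input(model_id, is_qe, source="", candidate="", reference=""):
    """Build the MetricX input string for a model id such as 'google/metricx-24-...'."""
    family = next((f for f in ("metricx-23", "metricx-24") if f in model_id), None)
    if family is None:
        raise ValueError(f"Unsupported model_id: {model_id}")
    template = TEMPLATES[(family, bool(is_qe))]
    values = {"source": source.strip(), "candidate": candidate.strip(), "reference": reference.strip()}
    # every field named in the template must be non-empty
    needed = {name for _, name, _, _ in string.Formatter().parse(template) if name}
    missing = sorted(name for name in needed if not values[name])
    if missing:
        raise ValueError(f"{model_id} ({'QE' if is_qe else 'ref-based'}) needs: {missing}")
    return template.format(**values)


print(format_input("google/metricx-23-qe-large-v2p0", True,
                   source="Good morning", candidate="good morning."))
```

Keeping the formats in one dict makes it harder for the two model families to drift apart when a port is compared against the Python code.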


In [1]:
import sys
from pathlib import Path
import os
import torch
import transformers

# print widely, dont wrap lines
torch.set_printoptions(linewidth=240, precision=4)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
print(DEVICE)

# Add the current directory to the path so we can import the model
MYDIR = Path.cwd().resolve()
print(MYDIR)
if str(MYDIR) not in sys.path:
    sys.path.append(str(MYDIR))

from metricx_model import MT5ForRegression



cuda
/mnt/home/tg/work/repos/tahoma/scripts/debug


In [2]:

tokenizer_id = "google/mt5-base"
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_id)
print("tokenizer loaded")

#model_id = 'google/metricx-23-large-v2p0'
# model_id = 'google/metricx-24-hybrid-large-v2p6'
# model = MT5ForRegression.from_pretrained(model_id)
# print("model loaded")

data = """Good morning	good morning.	Good Morning
morning	good morning.	Good Morning!
Evening	good evening.	Good evening!
Good night	good night.	Good night!"""
data = [x.split('\t') for x in data.split('\n')]


# (model_id, is_qe): is_qe=1 means reference-free QE mode, 0 means reference-based
model_ids = [
    ('google/metricx-24-hybrid-large-v2p6', 0),
    ('google/metricx-24-hybrid-large-v2p6', 1),
    ('google/metricx-23-large-v2p0', 0),
    ('google/metricx-23-qe-large-v2p0', 1),
]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


tokenizer loaded


In [3]:
def make_input(model_id, source=None, candidate=None, reference=None, is_qe=True):
    candidate = candidate and candidate.strip()
    assert candidate, "candidate is required"
    source = source and source.strip()
    reference = reference and reference.strip()

    if 'metricx-24' in model_id:
        assert source, f"source is required for {model_id}"
        text = f"source: {source} candidate: {candidate}"
        if not is_qe:
            assert reference, f"reference is required for {model_id} ref-based"
            text += f" reference: {reference}"
        return text

    elif 'metricx-23' in model_id:
        text = f"candidate: {candidate}"
        if is_qe:
            assert source, f"source is required for {model_id} QE"
            text += f" source: {source}"
        else:
            assert reference, f"reference is required for {model_id} ref-based"
            text += f" reference: {reference}"
    else:
        raise ValueError(f"Invalid model_id: {model_id}")
    return text

def tokenize_input(text, device=DEVICE):
    res = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
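    # remove the trailing EOS token. Slicing off the last column is only correct for a
    # single text (or an unpadded batch): in a padded batch, shorter rows end with pad tokens.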
    res['input_ids'] = res['input_ids'][:, :-1]
    res['attention_mask'] = res['attention_mask'][:, :-1]
    res = {k: v.to(device) for k, v in res.items()}
    return res
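`tokenize_input` drops the last column, which removes the EOS token only from the longest row when several texts are padded into one batch. Below is a per-row sketch, assuming right padding (the tokenizer default) and the mT5 vocabulary's pad id 0 and EOS id 1. `remove_eos` is a hypothetical helper that works on plain lists so the logic is easy to port to C++.

```python
def remove_eos(input_ids, attention_mask, eos_id=1, pad_id=0):
    """Drop the trailing EOS of every row in a right-padded batch.

    Returns new (input_ids, attention_mask) lists that are one column shorter.
    """
    new_ids, new_mask = [], []
    for ids, mask in zip(input_ids, attention_mask):
        ids, mask = list(ids), list(mask)
        n = sum(mask)  # number of real tokens, EOS included (right padding assumed)
        assert ids[n - 1] == eos_id, f"expected EOS at position {n - 1}, got {ids[n - 1]}"
        ids[n - 1], mask[n - 1] = pad_id, 0  # turn the EOS into padding
        # now the last column is padding in every row, so it can be dropped
        new_ids.append(ids[:-1])
        new_mask.append(mask[:-1])
    return new_ids, new_mask


# toy ids: row 0 is the longest, row 1 is padded with one pad token
print(remove_eos([[10, 11, 12, 1], [20, 21, 1, 0]],
                 [[1, 1, 1, 1], [1, 1, 1, 0]]))
```

With the real tokenizer output you would pass `res['input_ids'].tolist()` and `res['attention_mask'].tolist()` and convert the results back with `torch.tensor`.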


------
# Intercept


Intercept all forward calls so that the intermediate activations can be compared against my reimplementation.


For simplicity, pick just one model and one input.

In [7]:
model_id, is_qe = model_ids[0]
s, t, r = data[0]
try:
    # if a model is already loaded, delete it to free up memory
    del model
except NameError:
    pass
model = MT5ForRegression.from_pretrained(model_id)
model = model.to(DEVICE)

text = make_input(model_id, source=s, candidate=t, reference=r, is_qe=is_qe)
inp = tokenize_input(text)
out = model(**inp).predictions[0].item()
print(out)


log_file = Path('intercept.log')
with log_file.open('w') as f:  # `f`, not `out`, so the score above is not shadowed
    f.write(f"Intercepting {model_id} {'QE' if is_qe else 'Ref-based'}\n")
    f.write(f"text:{text}\n")
    for k, v in inp.items():
        f.write(f"{k}: {v}\n")
    f.write(f"{'='*80}\n")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


0.3308029770851135


In [8]:
#type(model.encoder)
import functools
#from transformers.models.mt5.modeling_mt5 import MT5Stack, MT5Block

def intercepted_forward(self, *args, **kwargs):
    non_nulls = {k for k, v in kwargs.items() if v is not None}
    _name_ = getattr(self, '_name_', 'unknown')
    tag = f'{type(self).__qualname__} {_name_}'
    print(f"Intercept {tag}\t posargs: {len(args)} kwargs: {non_nulls}")
    tname = None
    tensor = None
    # guess the most important tensor to log: first posarg or kwarg named 'input_ids'
    if 'input_ids' in kwargs:
        tensor = kwargs['input_ids']
        tname = 'input_ids'
    elif 'input' in kwargs:
        tensor = kwargs['input']
        tname = 'input'
    elif args:
        tensor = args[0]
        tname = 'arg0'
    else:
        print(f"Could not detect a tensor to log for {tag}")
    if type(self).__name__ in ('Linear', 'Dropout'):  # skip some basic modules to avoid spam
        tensor = None
    if isinstance(tensor, torch.Tensor):  # the first posarg is not always a tensor
        tag += f' {tname} Shape:[{tensor.shape}] AbsSum: {tensor.abs().sum().item()} max: {tensor.max().item()} min: {tensor.min().item()}'
        log_msg = f'\n\n{tag}\n    ' + str(tensor)
        with log_file.open('a') as f:
            f.write(log_msg)

    return self._forward_orig(*args, **kwargs)

def intercept_all_forwards(model):
    for name, module in model.named_modules():
        if hasattr(module, 'forward'):
            module._name_ = name
            if not hasattr(module, '_forward_orig'):
                # avoid recursive interception
                module._forward_orig = module.forward
            module.forward = functools.partial(intercepted_forward, module)

intercept_all_forwards(model.encoder)

if True:  # run the model once with interception enabled
    out = model(**inp)
    out = out.predictions[0].item()
    print(out)
    with log_file.open('a') as f:
        f.write(f"Final output: {out}\n")

Intercept MT5Stack 	 posargs: 0 kwargs: {'return_dict', 'attention_mask', 'input_ids'}
Intercept Embedding embed_tokens	 posargs: 1 kwargs: set()
Intercept Dropout dropout	 posargs: 1 kwargs: set()
Intercept MT5Block block.0	 posargs: 1 kwargs: {'cache_position', 'output_attentions', 'use_cache', 'return_dict', 'attention_mask'}
Intercept MT5LayerSelfAttention block.0.layer.0	 posargs: 1 kwargs: {'output_attentions', 'use_cache', 'attention_mask', 'cache_position'}
Intercept MT5LayerNorm block.0.layer.0.layer_norm	 posargs: 1 kwargs: set()
Intercept MT5Attention block.0.layer.0.SelfAttention	 posargs: 1 kwargs: {'use_cache', 'mask', 'output_attentions', 'cache_position'}
Intercept Linear block.0.layer.0.SelfAttention.q	 posargs: 1 kwargs: set()
Intercept Linear block.0.layer.0.SelfAttention.k	 posargs: 1 kwargs: set()
Intercept Linear block.0.layer.0.SelfAttention.v	 posargs: 1 kwargs: set()
Intercept Embedding block.0.layer.0.SelfAttention.relative_attention_bias	 posargs: 1 kwargs: s
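A less invasive alternative to overwriting `module.forward` is PyTorch's forward hooks: `register_forward_hook` calls `hook(module, args, output)` after each `forward`, and the returned handle's `remove()` detaches it again. The sketch below logs each module's *output* (whereas `intercepted_forward` logs its input). `attach_stat_hooks` is a hypothetical helper, and the toy `nn.Sequential` stands in for `model.encoder`.

```python
import torch
import torch.nn as nn


def attach_stat_hooks(model, records, skip=(nn.Dropout,)):
    """Register a forward hook on every submodule that appends
    (name, class name, output shape, output abs-sum) to `records`.
    Returns the hook handles so the hooks can be removed later."""
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, skip):
            continue

        # `name=name` binds the current loop value; PyTorch calls hook(module, args, output)
        def hook(mod, args, output, name=name):
            # many HF blocks return tuples; log the first element (the hidden states)
            out = output[0] if isinstance(output, tuple) else output
            if isinstance(out, torch.Tensor):  # non-tensor outputs are skipped
                records.append((name or "<root>", type(mod).__name__,
                                tuple(out.shape), out.detach().abs().sum().item()))

        handles.append(module.register_forward_hook(hook))
    return handles


# toy stand-in for model.encoder (named `toy` so the MetricX `model` is not overwritten)
toy = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
records = []
handles = attach_stat_hooks(toy, records)
with torch.no_grad():
    toy(torch.ones(1, 4))
for h in handles:
    h.remove()  # restore normal behaviour
for rec in records:
    print(rec)
```

Hooks fire after each submodule finishes, so the root module is recorded last. Unlike the monkey patch, removing the handles fully restores the model without a `_forward_orig` attribute.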

### Intercept MT5Attention

My reimplementation did not get this right, so I dig deeper here by monkey patching `MT5Attention.forward` to see where my C++ code diverged from this Python code.
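Written out from the `forward_debug` code below, this is what one `MT5Attention` call computes for encoder self-attention, assuming no KV cache, no dropout (eval mode) and no head mask. $X$ is the layer-normed hidden state and $h = 1, \dots, H$ indexes the `n_heads` heads of width $d_{kv}$ (`key_value_proj_dim`).

$$
Q_h = X W_Q^{(h)}, \qquad K_h = X W_K^{(h)}, \qquad V_h = X W_V^{(h)}
$$

$$
S_h = Q_h K_h^\top + B_h + M, \qquad A_h = \mathrm{softmax}_{\text{row}}(S_h)\, V_h
$$

$$
\mathrm{out} = \mathrm{concat}(A_1, \dots, A_H)\, W_O
$$

$B_h[i,j]$ is the learned `relative_attention_bias` entry for bucket $(j - i)$ and head $h$. Only the layer with `has_relative_attention_bias` (block 0) computes it; later blocks receive the same `position_bias`, with the mask $M$ already added, as an argument. $M[i,j]$ is 0 for real keys and a very large negative number for padded keys. Note that there is no $1/\sqrt{d_{kv}}$ factor on $Q_h K_h^\top$, unlike the standard Transformer, so a port that adds it will diverge.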

In [None]:
import torch
import torch.nn as nn


log_file = Path('intercept.attn.log')

def forward_debug(
    self,
    hidden_states,
    mask=None,
    key_value_states=None,
    position_bias=None,
    past_key_value=None,
    layer_head_mask=None,
    query_length=None,
    use_cache=False,
    output_attentions=False,
    cache_position=None,
):
    """
    Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
    """
    # Input is (batch_size, seq_length, dim)
    # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
    batch_size, seq_length = hidden_states.shape[:2]

    # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
    is_cross_attention = key_value_states is not None

    query_states = self.q(hidden_states)
    query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

    if past_key_value is not None:
        is_updated = past_key_value.is_updated.get(self.layer_idx)
        if is_cross_attention:
            # after the first generated id, we can subsequently re-use all key/value_states from cache
            curr_past_key_value = past_key_value.cross_attention_cache
        else:
            curr_past_key_value = past_key_value.self_attention_cache

    current_states = key_value_states if is_cross_attention else hidden_states
    if is_cross_attention and past_key_value is not None and is_updated:
        # reuse k,v, cross_attentions
        key_states = curr_past_key_value.key_cache[self.layer_idx]
        value_states = curr_past_key_value.value_cache[self.layer_idx]
    else:
        key_states = self.k(current_states)
        value_states = self.v(current_states)
        key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        if past_key_value is not None:
            # save all key/value_states to cache to be re-used for fast auto-regressive generation
            cache_position = cache_position if not is_cross_attention else None
            key_states, value_states = curr_past_key_value.update(
                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
            )
            # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
            if is_cross_attention:
                past_key_value.is_updated[self.layer_idx] = True

    # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
    scores = torch.matmul(query_states, key_states.transpose(3, 2))

    if position_bias is None:
        key_length = key_states.shape[-2]
        # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
        real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
        if not self.has_relative_attention_bias:
            position_bias = torch.zeros(
                (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
            )
            if self.gradient_checkpointing and self.training:
                position_bias.requires_grad = True
        else:
            position_bias = self.compute_bias(
                real_seq_length, key_length, device=scores.device, cache_position=cache_position
            )
            position_bias = position_bias[:, :, -seq_length:, :]

        if mask is not None:
            causal_mask = mask[:, :, :, : key_states.shape[-2]]
            print(">> Causal mask\n", causal_mask)
            position_bias = position_bias + causal_mask

    if self.pruned_heads:
        mask = torch.ones(position_bias.shape[1])
        mask[list(self.pruned_heads)] = 0
        position_bias_masked = position_bias[:, mask.bool()]
    else:
        position_bias_masked = position_bias

    with log_file.open('a') as f:
        f.write(f"\nMT5Attention Position Bias\n" + str(position_bias_masked))

    scores += position_bias_masked

    # (batch_size, n_heads, seq_length, key_length)
    attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
    attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

    # Mask heads if we want to
    if layer_head_mask is not None:
        attn_weights = attn_weights * layer_head_mask
    with log_file.open('a') as f:
        f.write(f"\nMT5Attention attn_weights\n" + str(attn_weights))

    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.view(batch_size, -1, self.inner_dim)
    with log_file.open('a') as f:
        f.write(f"\nMT5Attention attn_output before proj\n" + str(attn_output))
    attn_output = self.o(attn_output)
    with log_file.open('a') as f:
        f.write(f"\nMT5Attention attn_output after proj\n" + str(attn_output))

    outputs = (attn_output, past_key_value, position_bias)

    if output_attentions:
        outputs = outputs + (attn_weights,)
    return outputs


# monkey patch the forward method of the attention layer
from transformers.models.mt5.modeling_mt5 import MT5Attention
MT5Attention.forward = forward_debug
if False:
    out = model(**inp)
    out = out.predictions[0].item()
    print(out)
    with log_file.open('a') as f:
        f.write(f"\n\nFinal output: {out}\n")
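In the code above, `position_bias` comes from `self.compute_bias`, which maps each signed offset `key_pos - query_pos` to one of `num_buckets` buckets and looks the bucket up in the per-head `relative_attention_bias` embedding. Below is a scalar sketch of that bucketing, written from my reading of T5's `_relative_position_bucket` in transformers; the defaults (32 buckets, max distance 128) are the usual `relative_attention_num_buckets` / `relative_attention_max_distance` config values and should be checked against the checkpoint's config. The reference computes the logarithm in float32, so this float64 version could in principle disagree exactly at a bucket boundary.

```python
import math


def relative_position_bucket(rel, bidirectional=True, num_buckets=32, max_distance=128):
    """Map a signed offset rel = key_pos - query_pos to a bucket index."""
    bucket = 0
    if bidirectional:
        # encoder: half of the buckets for rel > 0, the other half for rel <= 0
        num_buckets //= 2
        if rel > 0:
            bucket += num_buckets
        rel = abs(rel)
    else:
        # decoder: only look backwards; future offsets collapse to 0
        rel = -min(rel, 0)
    max_exact = num_buckets // 2
    if rel < max_exact:
        return bucket + rel  # small offsets get one bucket each
    # larger offsets share log-spaced buckets up to max_distance, then saturate
    large = max_exact + int(
        math.log(rel / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
    )
    return bucket + min(large, num_buckets - 1)


def bucket_matrix(query_len, key_len, **kwargs):
    """Bucket index for every (query i, key j) pair, i.e. bucket(j - i)."""
    return [[relative_position_bucket(j - i, **kwargs) for j in range(key_len)]
            for i in range(query_len)]


for row in bucket_matrix(4, 4):
    print(row)
```

In `compute_bias` the bucket matrix indexes `relative_attention_bias` (an `nn.Embedding` with `num_buckets` rows and `n_heads` columns), and the result is permuted to shape `(1, n_heads, query_len, key_len)`, which is what gets logged as "MT5Attention Position Bias" (after the mask is added).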


------

## Scores From All Models

In [4]:
for model_id, is_qe in model_ids:
    model = MT5ForRegression.from_pretrained(model_id)
    model = model.to(DEVICE)
    texts = []
    print(f"==== {model_id} {is_qe and 'QE' or 'Ref-based'}====")
    for s, m, r in data:
        text = make_input(model_id, source=s, candidate=m, reference=r, is_qe=is_qe)
        texts.append(text)

        inp = tokenize_input(text)
        #print(inp)
        out = model(**inp)
        score = out.predictions[0].item()
        print(f'{score:.4f}\t{text}')
    del model



Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


==== google/metricx-24-hybrid-large-v2p6 Ref-based====


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


0.3308	source: Good morning candidate: good morning. reference: Good Morning
0.0000	source: morning candidate: good morning. reference: Good Morning!
0.0674	source: Evening candidate: good evening. reference: Good evening!
0.2105	source: Good night candidate: good night. reference: Good night!
==== google/metricx-24-hybrid-large-v2p6 QE====
0.4040	source: Good morning candidate: good morning.
0.2129	source: morning candidate: good morning.
0.6364	source: Evening candidate: good evening.
0.3974	source: Good night candidate: good night.
==== google/metricx-23-large-v2p0 Ref-based====
0.1915	candidate: good morning. reference: Good Morning
0.1159	candidate: good morning. reference: Good Morning!
0.1934	candidate: good evening. reference: Good evening!
0.1730	candidate: good night. reference: Good night!
==== google/metricx-23-qe-large-v2p0 QE====
0.7389	candidate: good morning. source: Good morning
4.9070	candidate: good morning. source: morning
0.6479	candidate: good evening. source: Eve

-----

### MT5

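This didn't work as expected, and in hindsight that is not surprising: unlike the original T5, `google/mt5-base` was pre-trained only on mC4 with the span-corruption objective and no supervised multitask mixture, so it never learned T5's task prefixes such as `translate English to German:` or `cola sentence:`. It produces span-corruption-style targets instead, which is why every output starts with the sentinel `<extra_id_0>`. The checkpoint has to be fine-tuned before it is useful on a downstream task.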

In [12]:
model_id = "google/mt5-base"
tokenizer = transformers.T5Tokenizer.from_pretrained(model_id)
pipe = transformers.pipeline("translation", model=model_id, tokenizer=tokenizer)
model = pipe.model

Device set to use cuda:0


In [13]:
inputs = [
    "translate English to German: That is good.",
    "cola sentence: The course is jumping well.",
    "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
    "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi"
]
outputs = pipe(inputs, max_length=64)
for out in outputs:
    print(out)


{'translation_text': '<extra_id_0> good.'}
{'translation_text': '<extra_id_0> -'}
{'translation_text': '<extra_id_0> ssb'}
{'translation_text': '<extra_id_0>, a few weeks'}
