In [1]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import sacrebleu
import sys
import os
import argparse
sys.path.append("/fs/clip-controllablemt/IWSLT2022/notebooks/")
from mbart_covariate import CMBartForConditionalGeneration
import torch

tgt_lang_to_code = {
    "hi" : "hi_IN",
    "de" : "de_DE",
    "es" : "es_XX",
    "it" : "it_IT",
    "ru" : "ru_RU",
    "ja" : "ja_XX"
}

direction_to_id = {
    "formal":1,
    "informal":2
}

In [2]:
def read_file(fname, n=None):
    data = []
    i = 0
    with open(fname) as f:
        for line in f:
            data.append(line.strip())
            i+=1
            if n is not None and i > n:
                break
    return data
    
def get_data( tgt_lang, domain, split, data_dir="internal_split"):
    source = read_file(f"{data_dir}/en-{tgt_lang}/{split}.{domain}.en")
    formal_translations = read_file(f"{data_dir}/en-{tgt_lang}/{split}.{domain}.formal.{tgt_lang}")
    informal_translations = read_file(f"{data_dir}/en-{tgt_lang}/{split}.{domain}.informal.{tgt_lang}")
    return source, formal_translations, informal_translations

In [3]:
def translate_text(text, tgt_lang, model, tokenizer,  covariate_index=None, strategy="greedy"):
    model_inputs = tokenizer(text, return_tensors="pt", padding=True)
    kwargs = {}
    if covariate_index is not None:
        kwargs["covariate_ids"] = torch.tensor([covariate_index]*len(text))
    if strategy == "greedy":
        generated_tokens = model.generate(
            **model_inputs,
            **kwargs
        )
    else:
        generated_tokens = model.generate(
            **model_inputs,
            forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang_to_code[tgt_lang]],
            max_length=50, 
            num_beams=5, 
            num_return_sequences=5, 
            early_stopping=True,
            **kwargs
        )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

In [4]:
model_dir="../models/vastai/covariate"

In [5]:
eval_direction ="formal"

In [6]:
model = CMBartForConditionalGeneration.from_pretrained(model_dir, cache_dir="/fs/clip-scratch/sweagraw/CACHE")

In [7]:
model.eval()

CMBartForConditionalGeneration(
  (model): CMBartModel(
    (covariate): Embedding(3, 1024)
    (shared): Embedding(250054, 1024, padding_idx=1)
    (encoder): MBartEncoder(
      (embed_tokens): Embedding(250054, 1024, padding_idx=1)
      (embed_positions): MBartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0): MBartEncoderLayer(
          (self_attn): MBartAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): LayerNorm((1

In [8]:
tgt_lang="de"

In [9]:
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-one-to-many-mmt", src_lang="en_XX", tgt_lang=tgt_lang_to_code[tgt_lang], cache_dir="/fs/clip-scratch/sweagraw/CACHE")

In [10]:
source = read_file("../internal_split/en-es/dev.combined.en")

In [13]:
translate_text([source[4]], tgt_lang, model, tokenizer, 1)[0] 

1 tensor([[[-4.7424, -5.0231, -0.5793,  ..., -1.6756, -1.0476, -4.7395]],

        [[-4.7424, -5.0231, -0.5793,  ..., -1.6756, -1.0476, -4.7395]],

        [[-4.7424, -5.0231, -0.5793,  ..., -1.6756, -1.0476, -4.7395]],

        [[-4.7424, -5.0231, -0.5793,  ..., -1.6756, -1.0476, -4.7395]],

        [[-4.7424, -5.0231, -0.5793,  ..., -1.6756, -1.0476, -4.7395]]])
2 tensor([[[ -5.5482,  -6.6353,  -4.2972,  ..., -10.4037, -13.9227,  -5.8926]],

        [[ -1.9137,  -2.6081,   0.8267,  ..., -16.4556,  -4.1660,  -2.2295]],

        [[ -2.2839,  -3.3162,   1.8673,  ..., -13.0749,  -7.8504,  -2.5264]],

        [[ -3.0832,  -3.4401,   1.8120,  ..., -10.6313,  -7.5578,  -3.1886]],

        [[ -4.3922,  -5.5456,  -4.3989,  ..., -11.8608,  -7.7589,  -4.8864]]])
3 tensor([[[ -6.0276,  -7.2264,  -8.0568,  ..., -17.3852, -14.8795,  -6.5836]],

        [[ -1.2559,  -1.7679,   2.6213,  ...,  -9.9650,  -3.1344,  -1.4084]],

        [[ -1.1017,  -1.7379,   3.3292,  ...,  -8.3625,  -5.9727,  -1.3238]]

23 tensor([[[ -1.6418,  -2.4584,   7.9762,  ...,  -6.0222,  -2.7805,  -2.2473]],

        [[ -0.0429,  -0.1661,  17.7253,  ...,  -0.0416,   0.1301,  -0.0636]],

        [[ -1.5832,  -2.4155,   7.6438,  ...,  -5.6599,  -2.6128,  -2.1941]],

        [[ -3.0845,  -3.7446,   4.0761,  ..., -15.1694,  -9.4091,  -3.1554]],

        [[ -2.3498,  -3.1403,   2.5075,  ..., -14.4866,  -9.7566,  -2.5350]]])
24 tensor([[[ -3.0105,  -3.6368,   4.1565,  ..., -15.3424, -10.0046,  -3.0647]],

        [[ -1.5074,  -1.8657,   7.7108,  ..., -10.8266,  -5.3680,  -1.5737]],

        [[ -2.9646,  -3.6254,   4.3295,  ..., -15.2370,  -9.8587,  -3.0174]],

        [[ -2.2658,  -3.0358,   2.8230,  ..., -14.6023, -10.1063,  -2.4420]],

        [[ -0.9081,  -1.2261,   8.2119,  ...,  -8.3735,  -4.1668,  -0.9815]]])
25 tensor([[[ -1.5382,  -1.8889,   8.0689,  ..., -10.8068,  -5.4506,  -1.5931]],

        [[ -0.0327,  -0.2110,  17.9628,  ...,  -1.2807,   1.4624,  -0.1546]],

        [[ -1.6910,  -2.0405,   7.4916,  ..

'Jo, to dává smysl, slyšel jste o tom bunkru za 10 milionů dolarů, co má?'

In [15]:
translate_text([source[4]], tgt_lang, model, tokenizer, 2)[0] 

1 tensor([[[-4.7464, -5.0151, -0.5664,  ..., -1.6514, -1.1374, -4.7402]],

        [[-4.7464, -5.0151, -0.5664,  ..., -1.6514, -1.1374, -4.7402]],

        [[-4.7464, -5.0151, -0.5664,  ..., -1.6514, -1.1374, -4.7402]],

        [[-4.7464, -5.0151, -0.5664,  ..., -1.6514, -1.1374, -4.7402]],

        [[-4.7464, -5.0151, -0.5664,  ..., -1.6514, -1.1374, -4.7402]]])
2 tensor([[[ -5.6712,  -6.7244,  -4.5681,  ..., -10.6964, -15.3725,  -5.9849]],

        [[ -1.7648,  -2.4561,   1.1433,  ..., -16.2529,  -4.8401,  -2.0909]],

        [[ -4.5754,  -5.7117,  -4.6680,  ..., -12.9725,  -9.6723,  -5.0037]],

        [[ -2.1776,  -3.3012,   2.3210,  ..., -13.7446,  -8.4125,  -2.4810]],

        [[ -3.1118,  -4.2364,   2.9570,  ...,  -6.4048, -10.9913,  -3.1200]]])
3 tensor([[[ -6.0094,  -7.2505,  -8.2941,  ..., -18.0351, -16.3461,  -6.5825]],

        [[ -1.0057,  -1.5278,   2.9936,  ...,  -9.5703,  -3.3526,  -1.1647]],

        [[ -0.9703,  -1.6195,   3.5106,  ...,  -8.4079,  -6.2536,  -1.2160]]

22 tensor([[[-1.5584e+00, -2.1660e+00,  6.3692e+00,  ..., -6.7219e+00,
          -2.9220e+00, -1.8416e+00]],

        [[-1.4917e+00, -2.2727e+00,  8.8907e+00,  ..., -3.5433e+00,
          -1.9351e+00, -2.0552e+00]],

        [[-1.4591e+00, -2.0585e+00,  6.1580e+00,  ..., -6.4600e+00,
          -2.7980e+00, -1.7495e+00]],

        [[ 4.5345e-02, -1.2175e-01,  1.8114e+01,  ...,  4.2065e-02,
           2.0281e-01, -7.5166e-03]],

        [[ 2.4338e-01, -2.6677e-01,  6.1952e+00,  ..., -1.8593e+00,
           2.3992e+00, -6.6405e-02]]])
23 tensor([[[-1.3857e+00, -2.2391e+00,  9.2652e+00,  ..., -5.5842e+00,
          -3.5953e+00, -2.0152e+00]],

        [[-2.8584e+00, -3.4720e+00,  4.1527e+00,  ..., -1.4321e+01,
          -1.0640e+01, -2.9010e+00]],

        [[-1.5109e+00, -2.0991e+00,  5.6305e+00,  ..., -6.5982e+00,
          -2.9278e+00, -1.7922e+00]],

        [[-1.3797e+00, -2.2439e+00,  8.9240e+00,  ..., -5.2480e+00,
          -3.6052e+00, -2.0049e+00]],

        [[ 1.2032e-02, -2.1490e

'Jo, to dává smysl, slyšel jsi o tom bunkru za 10 milionů dolarů, co má?'