#mBART model with frozen encoder & embeddings

In [1]:
#instal libraries
!pip install transformers
!pip install sacrebleu
!pip install sacremoses
!pip install sentencepiece

Collecting sacrebleu
  Downloading sacrebleu-2.4.0-py3-none-any.whl (106 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.3/106.3 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting portalocker (from sacrebleu)
  Downloading portalocker-2.8.2-py3-none-any.whl (17 kB)
Collecting colorama (from sacrebleu)
  Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)
Installing collected packages: portalocker, colorama, sacrebleu
Successfully installed colorama-0.4.6 portalocker-2.8.2 sacrebleu-2.4.0
Collecting sacremoses
  Downloading sacremoses-0.1.1-py3-none-any.whl (897 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: sacremoses
Successfully installed sacremoses-0.1.1
Collecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[

In [2]:
!pip install datasets==2.13.1
!pip install evaluate
!pip install accelerate

Collecting datasets==2.13.1
  Downloading datasets-2.13.1-py3-none-any.whl (486 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m486.2/486.2 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.7,>=0.3.0 (from datasets==2.13.1)
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets==2.13.1)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
INFO: pip is looking at multiple versions of multiprocess to determine which version is compatible with other requirements. This could take a while.
  Downloading multiprocess-0.70.14-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[

In [9]:
#delete ouput directory if re-running code
!rm -rf models

In [3]:
from torch import nn
from typing import Callable, Dict, Iterable, List, Tuple, Union
# 1 copy this functions into the run
# 2 we need to https://pytorch.org/docs/stable/generated/torch.Tensor.requires_grad_.html
def freeze_params(model: nn.Module):
    """Set req grad False for each of model.parameters()"""
    for par in model.parameters():
        par.requires_grad = False

def freeze_embeds(model):
    """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
    model_type = model.config.model_type

    if model_type == "t5":
        freeze_params(model.shared)
        #funct model shared params dir(model)
        for d in [model.encoder, model.decoder]:
          freeze_params(d.embed_positions)
          freeze_params(d.embed_tokens)
            #funct embed tokens freeze_params(?)
    elif model_type == "fsmt":
        freeze_params(model.shared)
        for d in [model.model.encoder, model.model.decoder]:
          freeze_params(d.embed_positions)
          freeze_params(d.embed_tokens)
            #funct positions embeddings x.embed_positions? dir(d)?
            #funct token embeddings  embed_tokens
    else:
        freeze_params(model.model.shared)
        #funct model shared params dir(model)
        for d in [model.model.encoder, model.model.decoder]:
            freeze_params(d.embed_positions)
            freeze_params(d.embed_tokens)

def grad_status(model: nn.Module) -> Iterable:
    return (par.requires_grad for par in model.parameters())


def any_requires_grad(model: nn.Module) -> bool:
    return any(grad_status(model))

def lmap(f: Callable, x: Iterable) -> List:
    """list(map(f, x))"""
    return list(map(f, x))

def assert_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    n_require_grad = sum(lmap(int, model_grads))
    npars = len(model_grads)
    assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad"


def assert_not_all_frozen(model):
    model_grads: List[bool] = list(grad_status(model))
    npars = len(model_grads)
    assert any(model_grads), f"none of {npars} weights require grad"

IndentationError: ignored

#Model training - 3 epochs

In [10]:
#upload data
#facebook/mbart-large-50-one-to-many-mmt
!python run_translationfreeze_no_trainer.py \
    --model_name_or_path facebook/mbart-large-50-one-to-many-mmt \
    --source_lang en_XX \
    --target_lang de_DE \
    --forced_bos_token de_DE \
    --learning_rate 3e-05 \
    --num_warmup_steps 2500 \
    --num_train_epochs 3 \
    --num_beams 5 \
    --max_source_length 250 \
    --max_target_length 250 \
    --per_device_eval_batch_size 2 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --train_file en-de.emea.1k.train.json \
    --validation_file en-de.emea.500.valid.json \
    --freeze_embeds \
    --freeze_encoder \
    --output_dir models/mbart50_freeze


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  "num_beams": 5,
  "pad_token_id": 1
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "max_length": 200,
  "num_beams": 5,
  "pad_token_id": 1
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "max_length": 200,
  "num_beams": 5,
  "pad_token_id": 1
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "max_length": 200,
  "num_beams": 5,
  "pad_token_id": 1
}

Generate config GenerationConfig {
  "bos_token_id": 0,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "max_length": 200,
  "num_beams": 5,
  "pad_token_id": 1
}

#Translating medline file

In [12]:
#translating medline using frozen embeddings & encoder

from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
src_file = 'medline_en-de.en.txt'
output_file = "translation_frozen_embeddings_encoder_mBART.txt"
checkpoint = 'models//mbart50_freeze'
mnmt = 'de_DE'
src_lines = []
device = 'cuda'

model = MBartForConditionalGeneration.from_pretrained(checkpoint)
model = model.to(device)
tokenizer = MBart50TokenizerFast.from_pretrained(checkpoint)


tokenizer.src_lang = "en_XX"

with open(src_file, 'r', encoding='utf-8') as src, open(output_file, 'w', encoding='utf-8') as output:
    for line in src:
        line = line.strip()
        encoded = tokenizer(line, return_tensors="pt").to(device)
        generated_tokens = model.generate(
            **encoded,
            forced_bos_token_id=tokenizer.lang_code_to_id[mnmt]
        )
        translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        translated_text = translation[0] + '\n'
        output.write(translated_text)

#Evaluation using bleu score, chrF, comet

In [15]:
!pip install sacrebleu
!pip install unbabel-comet
import comet



In [17]:
#compute BLEU and chrF scores for fine-tuned model
!sacrebleu medline_en-de.de.txt -l en-de -i translation_frozen_embeddings_encoder_mBART.txt  -m bleu chrf

[
{
 "name": "BLEU",
 "score": 23.9,
 "signature": "nrefs:1|case:mixed|eff:no|tok:13a|smooth:exp|version:2.4.0",
 "verbose_score": "48.2/26.6/18.7/13.7 (BP = 1.000 ratio = 1.016 hyp_len = 1248 ref_len = 1228)",
 "nrefs": "1",
 "case": "mixed",
 "eff": "no",
 "tok": "13a",
 "smooth": "exp",
 "version": "2.4.0"
},
{
 "name": "chrF2",
 "score": 47.2,
 "signature": "nrefs:1|case:mixed|eff:yes|nc:6|nw:0|space:no|version:2.4.0",
 "nrefs": "1",
 "case": "mixed",
 "eff": "yes",
 "nc": "6",
 "nw": "0",
 "space": "no",
 "version": "2.4.0"
}
]
[0m

In [19]:
#comet-score
!comet-score -s medline_en-de.en.txt -t translation_frozen_embeddings_encoder_mBART.txt -r medline_en-de_reference_de.txt

Seed set to 1
Fetching 5 files:   0% 0/5 [00:00<?, ?it/s]Fetching 5 files: 100% 5/5 [00:00<00:00, 15947.92it/s]
Lightning automatically upgraded your loaded checkpoint from v1.8.3.post1 to v2.1.3. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../root/.cache/huggingface/hub/models--Unbabel--wmt22-comet-da/snapshots/371e9839ca4e213dde891b066cf3080f75ec7e72/checkpoints/model.ckpt`
Encoder model frozen.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/saving.py:177: Found keys that are not in the model state dict but in the checkpoint: ['encoder.model.embeddings.position_ids']
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100% 4/4 [00:03<00:00,  1.33it/s]
translation_frozen_embeddings_encoder_mBART.txt	Segment 0	score: 0.7211
translation_fro

#Results

bleu: 23.9

chrF: 47.2

comet: 0.716