In [1]:
import torch
import numpy as np
import random
seed=1234
# Seed the Python and NumPy generators first, then PyTorch; the PyTorch call
# stays last so the cell still displays the returned torch Generator.
for seed_fn in (random.seed, np.random.seed):
    seed_fn(seed)
torch.manual_seed(seed)

# Reference notebook (Hugging Face translation example):
# https://github.com/huggingface/notebooks/blob/master/examples/translation.ipynb

<torch._C.Generator at 0x7f125009e4f0>

In [2]:
from transformers import AutoTokenizer
import sacrebleu
import sys
import os
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, MBartForConditionalGeneration, MBart50TokenizerFast

In [3]:
from datasets import load_dataset, load_metric

# Load the "sacrebleu" metric object; compute_metrics calls metric.compute(...)
# on decoded predictions and references and reports its "score" as BLEU.
metric = load_metric("sacrebleu")

In [4]:
# Maps the short target-language codes used in the data file names to the
# language-code strings that are assigned to tokenizer.tgt_lang and looked up
# in tokenizer.lang_code_to_id.
tgt_lang_to_code = dict(
    hi="hi_IN",
    de="de_DE",
    es="es_XX",
    it="it_IT",
    ru="ru_RU",
    ja="ja_XX",
)

In [5]:
def read_file(fname):
    """Read a text file and return its lines with surrounding whitespace removed.

    Args:
        fname: Path of the file to read.

    Returns:
        A list with one string per line of the file, each passed through
        ``str.strip()``. Blank lines are kept as empty strings so that
        parallel files stay line-aligned.
    """
    # Decode explicitly as UTF-8: the target languages in this notebook
    # include Hindi and Japanese, and the platform default encoding used by a
    # bare open() is not guaranteed to be UTF-8.
    with open(fname, encoding="utf-8") as f:
        return [line.strip() for line in f]

def get_data(tgt_lang, domain, split):
    """Load the English source and both target-side files for one split.

    Args:
        tgt_lang: Short target-language code used in the directory and file names.
        domain: Domain component of the file names.
        split: Split component of the file names (e.g. "train" or "dev").

    Returns:
        A tuple ``(source, formal_translations, informal_translations)`` of
        line lists as returned by ``read_file``.
    """
    # All three files share the same directory and "<split>.<domain>" stem.
    stem = f"../internal_split/en-{tgt_lang}/{split}.{domain}"
    source = read_file(f"{stem}.en")
    formal_translations = read_file(f"{stem}.formal.{tgt_lang}")
    informal_translations = read_file(f"{stem}.informal.{tgt_lang}")
    return source, formal_translations, informal_translations

In [6]:
# Maximum token length passed to the tokenizer in encode_split; longer
# sequences are truncated there (truncation=True).
MAX_LENGTH=64
# Checkpoint identifier used to load the tokenizer.
model_name="facebook/mbart-large-50-one-to-many-mmt"
# NOTE(review): cache_dir is a hard-coded absolute path; it will need to be
# changed to run this notebook on another machine.
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/fs/clip-scratch/sweagraw/CACHE")

In [7]:
import torch
from torch.utils.data import Dataset

def encode_split(tokenizer, source, formal_translations, informal_translations ):
    """Tokenize the source sentences and both sets of target translations.

    Every call truncates to the module-level ``MAX_LENGTH``.

    Args:
        tokenizer: Tokenizer used for both source and target text.
        source: Source sentences.
        formal_translations: Formal target sentences.
        informal_translations: Informal target sentences.

    Returns:
        A tuple ``(model_inputs, model_outputs_formal, model_outputs_informal)``
        of tokenizer outputs.
    """
    tokenize_kwargs = dict(max_length=MAX_LENGTH, truncation=True)
    encoded_source = tokenizer(source, **tokenize_kwargs)
    # Target-side text is tokenized inside the tokenizer's target context.
    with tokenizer.as_target_tokenizer():
        encoded_formal = tokenizer(formal_translations, **tokenize_kwargs)
        encoded_informal = tokenizer(informal_translations, **tokenize_kwargs)
    return encoded_source, encoded_formal, encoded_informal

class FormalityData(Dataset):
    """Dataset pairing tokenized sources with targets of one formality level.

    Each item is a dict holding the source-side features, the target token ids
    under ``"labels"``, the formality covariate under ``"covariate_ids"`` and
    the target-language token id under ``"forced_bos_token_id"``.

    Note: ``__getitem__`` reads the module-level ``tokenizer`` and
    ``tgt_lang_to_code``, so both must be defined before items are fetched.
    """

    def __init__(self, model_inputs, model_outputs, formality_idx, tgt_lang):
        """Store the encodings and build the covariate tensor.

        Args:
            model_inputs: Mapping of feature name to per-example sequences;
                must contain ``"input_ids"``.
            model_outputs: Mapping containing the target ``"input_ids"``.
            formality_idx: Covariate id for this dataset (a scalar in this
                notebook).
            tgt_lang: Short target-language code, a key of ``tgt_lang_to_code``.
        """
        self.model_inputs = model_inputs
        self.model_outputs = model_outputs
        covariate = torch.tensor(formality_idx).unsqueeze(0)
        # Reverse the dimension order explicitly. This gives the same result
        # as ``.T`` for any number of dimensions, but ``.T`` is deprecated for
        # tensors that are not 2-D, and a scalar id yields a 1-D tensor here.
        self.formality_idx = covariate.permute(tuple(range(covariate.dim() - 1, -1, -1)))
        self.tgt_lang = tgt_lang

    def __len__(self):
        """Return the number of examples (length of ``input_ids``)."""
        return len(self.model_inputs["input_ids"])

    def __getitem__(self, idx):
        """Return the feature dict for example ``idx``."""
        item = {k: v[idx] for k, v in self.model_inputs.items()}
        item["labels"] = self.model_outputs["input_ids"][idx]
        # The same covariate tensor object is shared by every item.
        item["covariate_ids"] = self.formality_idx
        # Looked up through the module-level tokenizer and language-code map.
        item["forced_bos_token_id"] = tokenizer.lang_code_to_id[tgt_lang_to_code[self.tgt_lang]]
        return item

In [8]:
domain="combined"
src_lang="en"
# NOTE(review): `direction` is not referenced anywhere else in the visible notebook.
direction="informal"
# Covariate ids attached to the formal and informal versions of each example.
formality_idx = 1
informality_idx = 2
# Target languages included in both the training and the dev data.
TGT_LANGS = ["hi", "ja", "de", "es"]

# Language codes use an underscore (compare tgt_lang_to_code), so the English
# code is "en_XX"; "en-XX" does not name a language token.
tokenizer.src_lang = "en_XX"


def build_formality_datasets(split, tgt_langs=TGT_LANGS):
    """Build one formal and one informal FormalityData per target language.

    Args:
        split: Split name passed to ``get_data`` (e.g. "train" or "dev").
        tgt_langs: Short target-language codes to include, in order.

    Returns:
        A list of ``FormalityData`` objects, ordered formal then informal for
        each language. The two datasets of a language share the same
        source-side encodings.
    """
    split_datasets = []
    for tgt_lang in tgt_langs:
        # The target language must be set before encoding the target side.
        tokenizer.tgt_lang = tgt_lang_to_code[tgt_lang]
        source, formal_translations, informal_translations = get_data(tgt_lang, domain, split)
        model_inputs, model_outputs_formal, model_outputs_informal = encode_split(
            tokenizer, source, formal_translations, informal_translations
        )
        split_datasets.append(FormalityData(model_inputs, model_outputs_formal, formality_idx, tgt_lang))
        split_datasets.append(FormalityData(model_inputs, model_outputs_informal, informality_idx, tgt_lang))
    return split_datasets


train_datasets = build_formality_datasets("train")
train_dataset = torch.utils.data.ConcatDataset(train_datasets)

dev_datasets = build_formality_datasets("dev")
dev_dataset = torch.utils.data.ConcatDataset(dev_datasets)

In [9]:
from mbart_covariate import CMBartForConditionalGeneration2, MBartSeq2SeqTrainer

In [10]:
# NOTE(review): same value as the model_name assigned next to the tokenizer
# setup earlier in the notebook; this re-assignment looks redundant.
model_name="facebook/mbart-large-50-one-to-many-mmt"

In [11]:
# Alternative kept for reference: load a previously fine-tuned checkpoint
# from ../models instead of the base checkpoint.
# model = CMBartForConditionalGeneration2.from_pretrained("../models/facebook/mbart-large-50-one-to-many-mmt-finetuned-covariate-lm-en-to-xx", cache_dir="/fs/clip-scratch/sweagraw/CACHE",num_covariates=3)

# Load the base checkpoint into the covariate-aware model class. Per the load
# message recorded below, the covariate weights are not in the checkpoint and
# are newly initialized.
# NOTE(review): num_covariates=3 presumably covers ids 0-2 (formal=1,
# informal=2 in this notebook) — confirm against mbart_covariate.
model = CMBartForConditionalGeneration2.from_pretrained(model_name, cache_dir="/fs/clip-scratch/sweagraw/CACHE",num_covariates=3)

Some weights of CMBartForConditionalGeneration2 were not initialized from the model checkpoint at facebook/mbart-large-50-one-to-many-mmt and are newly initialized: ['model.covariate.weight', 'covariate.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# Freeze the whole model, then re-enable gradients only for the covariate
# module, the last 8 decoder layers and the LM head.
# (To unfreeze the entire decoder instead, call
#  model.get_decoder().requires_grad_(True).)
model.requires_grad_(False)
for trainable_module in (model.get_covariate(), model.get_decoder().layers[-8:]):
    trainable_module.requires_grad_(True)
model.lm_head.requires_grad_(True)

Linear(in_features=1024, out_features=250054, bias=False)

In [13]:
def count_params(model):
    """Return the total number of elements in the model's trainable parameters.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report how many parameters remain trainable after the freezing step above.
print(f"Training {count_params(model)} parameters...")

Training 390431744 parameters...


In [14]:
# Per-device batch sizes handed to Seq2SeqTrainingArguments below.
train_batch_size = 1
eval_batch_size=1
# NOTE(review): presumably the probability of masking a covariate id in
# collate_fn — confirm where this value is consumed.
MASK_PROB = 0.0
# Directory given to the training arguments and to trainer.save_model.
output_dir=f"../models/{model_name}-finetuned-covariate-lm-l8-{src_lang}-to-xx"
# Value read from the model config; presumably the covariate id that stands
# for a masked style — TODO confirm against mbart_covariate.
style_mask=model.config.style_mask
# NOTE(review): evaluation_strategy is "epoch" while eval_steps=100 is also
# set; eval_steps is presumably only honored with a step-based strategy —
# confirm against the TrainingArguments documentation.
args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    evaluation_strategy = "epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=eval_batch_size,
    weight_decay=0.01,
    save_total_limit=10,
    num_train_epochs=10,
    save_steps=100,
    eval_steps=100,
    gradient_accumulation_steps=2,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=False,
)

In [15]:
# Base seq2seq collator; collate_fn wraps it and post-processes the
# "covariate_ids" entry of each batch.
collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [16]:
import numpy as np

def postprocess_text(preds, labels):
    """Strip whitespace from decoded texts and wrap each label in a list.

    Args:
        preds: Decoded prediction strings.
        labels: Decoded reference strings.

    Returns:
        A tuple ``(preds, labels)`` where ``preds`` is a list of stripped
        strings and ``labels`` is a list of single-element lists, one
        stripped reference per prediction.
    """
    cleaned_preds = []
    for pred in preds:
        cleaned_preds.append(pred.strip())

    wrapped_labels = []
    for label in labels:
        wrapped_labels.append([label.strip()])

    return cleaned_preds, wrapped_labels

def compute_metrics(eval_preds):
    """Compute BLEU and mean generated length for an evaluation run.

    Relies on the module-level ``tokenizer`` and ``metric``.

    Args:
        eval_preds: A ``(preds, labels)`` pair of token-id arrays. ``preds``
            may be a tuple, in which case its first element is used.

    Returns:
        A dict with ``"bleu"`` (the metric's "score") and ``"gen_len"`` (mean
        number of non-pad prediction tokens), both rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id
    # Replace -100 with the pad id, since -100 cannot be decoded. Predictions
    # get the same treatment as labels: they may also carry -100 padding, which
    # would otherwise break decoding and be counted as real tokens in gen_len.
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

In [17]:

def collate_fn(batch, collator, mask_prob, mask_idx):
    """Collate a batch and randomly replace covariate ids with ``mask_idx``.

    Args:
        batch: List of features passed straight to ``collator``.
        collator: Callable returning a mapping that contains a
            ``"covariate_ids"`` tensor.
        mask_prob: Each covariate id is replaced when a uniform sample from
            [0, 1) is below this value, so 0 disables masking.
        mask_idx: Value written into masked positions.

    Returns:
        The collated batch, with ``"covariate_ids"`` masked, given a new
        leading dimension, and then its dimension order reversed.
    """
    batch = collator(batch)
    covariates = batch["covariate_ids"]
    mask = torch.rand_like(covariates, dtype=torch.float32) < mask_prob
    covariates = covariates.masked_fill(mask, mask_idx).unsqueeze(0)
    # Reverse the dimension order explicitly. This gives the same result as
    # ``.T`` for any number of dimensions, but ``.T`` is deprecated for
    # tensors that are not 2-D.
    batch["covariate_ids"] = covariates.permute(tuple(range(covariates.dim() - 1, -1, -1)))
    return batch

In [18]:
# Build the trainer. Batches go through collate_fn, which wraps the base
# collator and masks covariate ids with probability MASK_PROB using the
# style_mask id. Both come from the configuration cell, so changing them there
# actually takes effect here (previously a literal 0 was passed).
trainer = MBartSeq2SeqTrainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    data_collator=lambda batch: collate_fn(batch, collator, MASK_PROB, style_mask),
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

Using amp fp16 backend


In [19]:
import torch
# Ask PyTorch to release its cached GPU memory before training starts.
torch.cuda.empty_cache()

In [20]:
# Start fine-tuning.
# NOTE(review): the recorded run below aborted with "CUDA out of memory" on a
# ~12 GiB GPU before the first evaluation, so the following cells never ran.
trainer.train()

***** Running training *****
  Num examples = 3600
  Num Epochs = 10
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 2
  Total optimization steps = 9000


Epoch,Training Loss,Validation Loss


RuntimeError: CUDA out of memory. Tried to allocate 978.00 MiB (GPU 0; 11.93 GiB total capacity; 9.34 GiB already allocated; 885.38 MiB free; 10.32 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Save the model to output_dir (not executed in the recorded run).
trainer.save_model(output_dir)

In [None]:
# Evaluate on the eval dataset given to the trainer (dev_dataset); not
# executed in the recorded run.
trainer.evaluate()