# Imports and setup

In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, pipeline
import evaluate
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
import numpy as np

from data_handling import normalize_sentence

device = "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_sentences(filename="data/train/sentence.gold"):
    with open(filename, "r", encoding="UTF-8") as file:
        sentences = file.readlines()
    for sentence in sentences:
        yield normalize_sentence(sentence)


def load_own_dataset(tokenizer, input_file="data/train/sentence.input"):
    with open(input_file, "r", encoding="UTF-8") as file:
        input_sentences = file.readlines()

    input_tokenized = [tokenizer.encode(normalize_sentence(sentence)) for sentence in input_sentences]

    return input_tokenized


In [3]:
model_name = "t5-base"
model_name = "Helsinki-NLP/opus-mt-en-cs"


retrain_tokenizer = False
tokenizer_name = "Helsinki-NLP/opus-mt-en-cs"
#tokenizer_name = "t5-base"


base_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, return_tensors="pt")
base_seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_name)


# Tokenizer Training

In [4]:
training_corpus = load_sentences()

if retrain_tokenizer:
    tokenizer = base_tokenizer.train_new_from_iterator(training_corpus, 52000)
else:
    tokenizer = base_tokenizer

base_seq2seq.config.eos_token_id = tokenizer.eos_token_id
base_seq2seq.config.pad_token_id = tokenizer.pad_token_id

In [5]:
sample_pangram = "Ó, náhlý déšť již zvířil prach a čilá laň teď běží s houfcem gazel Ualdewara k exkluzívním úkrytům!"

print("Pangram with base tokenizer: ")
encodings_base = base_tokenizer.encode(sample_pangram)
print(base_tokenizer.decode(encodings_base))
print(f"{len(encodings_base)} tokens")

print("\nFine-tuned tokenizer:")
encodings_used = tokenizer.encode(sample_pangram)
print(tokenizer.decode(encodings_used))
print(f"{len(encodings_used)} tokens")


Pangram with base tokenizer: 
Ó, náhlý déšť již zvířil prach a čilá laň teď běží s houfcem gazel Ualdewara k exkluzívním úkrytům!</s>
59 tokens

Fine-tuned tokenizer:
Ó, náhlý déšť již zvířil prach a čilá laň teď běží s houfcem gazel Ualdewara k exkluzívním úkrytům!</s>
59 tokens


# Seq2Seq Model Training

## GECCC Dataset 

## Own dataset

In [24]:
dataset_dict = load_from_disk("data/czech_news_errors")
dataset_dict

DatasetDict({
    train: Dataset({
        features: ['target', 'source'],
        num_rows: 1483008
    })
    test: Dataset({
        features: ['target', 'source'],
        num_rows: 317788
    })
    dev: Dataset({
        features: ['target', 'source'],
        num_rows: 317788
    })
})

In [7]:
max_length = 128

def preprocess_data(dataset):
    normalized_inputs = [normalize_sentence(sentence) for sentence in dataset["source"]]
    normalized_gold = [normalize_sentence(sentence) for sentence in dataset["target"]]

    model_inputs = tokenizer(normalized_inputs, text_target=normalized_gold, max_length=max_length, truncation=True, padding="max_length")
    return model_inputs

In [8]:
tokenized_datasets = dataset_dict.map(
    preprocess_data,
    batched=True
)

In [25]:
"""
mt_metrics = evaluate.combine(
    ["bleu", "chrf"], force_prefix=True
)


def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    
    predictions = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    references = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    outputs = mt_metrics.compute(predictions=predictions,
                             references=references)

    return outputs"""




metric = evaluate.load("sacrebleu")

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    # In case the model returns more than the prediction logits
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100s in the labels as we can't decode them
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    return {"bleu": result["score"]}

training_args = Seq2SeqTrainingArguments(
    output_dir='./model',
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    logging_steps=1,
    save_steps=5,
    eval_steps=1,
    max_steps=120,
    eval_strategy="no", #"steps",
    save_strategy="no", #"steps",
    predict_with_generate=True,
    report_to=None,
    metric_for_best_model="chr_f_score",
    load_best_model_at_end=True
)

trainer = Seq2SeqTrainer(
    model=base_seq2seq,
    args=training_args,
    train_dataset=tokenized_datasets["train"].with_format("torch"),
    eval_dataset=tokenized_datasets["dev"].with_format("torch"),
    compute_metrics=compute_metrics
)


max_steps is given, it will override any value given in num_train_epochs


In [26]:
trainer.train()

  1%|          | 1/120 [00:12<24:18, 12.26s/it]

{'loss': 0.5234, 'grad_norm': 2.68664288520813, 'learning_rate': 4.958333333333334e-05, 'epoch': 0.0}


  2%|▏         | 2/120 [00:18<16:36,  8.45s/it]

{'loss': 0.5101, 'grad_norm': 2.5996463298797607, 'learning_rate': 4.9166666666666665e-05, 'epoch': 0.0}


  2%|▎         | 3/120 [00:24<14:43,  7.55s/it]

{'loss': 0.4378, 'grad_norm': 2.461297035217285, 'learning_rate': 4.875e-05, 'epoch': 0.0}


  3%|▎         | 4/120 [00:30<13:16,  6.86s/it]

{'loss': 0.5171, 'grad_norm': 13.707554817199707, 'learning_rate': 4.8333333333333334e-05, 'epoch': 0.0}


  4%|▍         | 5/120 [00:35<12:14,  6.38s/it]

{'loss': 0.4104, 'grad_norm': 2.714751720428467, 'learning_rate': 4.791666666666667e-05, 'epoch': 0.0}


  5%|▌         | 6/120 [00:41<11:31,  6.07s/it]

{'loss': 0.3444, 'grad_norm': 2.1749844551086426, 'learning_rate': 4.75e-05, 'epoch': 0.0}


  6%|▌         | 7/120 [00:46<11:09,  5.93s/it]

{'loss': 0.3462, 'grad_norm': 1.8854520320892334, 'learning_rate': 4.708333333333334e-05, 'epoch': 0.0}


  7%|▋         | 8/120 [00:52<10:58,  5.88s/it]

{'loss': 0.3969, 'grad_norm': 2.5468862056732178, 'learning_rate': 4.666666666666667e-05, 'epoch': 0.0}


  8%|▊         | 9/120 [01:00<11:46,  6.37s/it]

{'loss': 0.4097, 'grad_norm': 2.414731979370117, 'learning_rate': 4.6250000000000006e-05, 'epoch': 0.0}


  8%|▊         | 10/120 [01:06<11:47,  6.43s/it]

{'loss': 0.209, 'grad_norm': 1.6890816688537598, 'learning_rate': 4.5833333333333334e-05, 'epoch': 0.0}


  9%|▉         | 11/120 [01:13<11:37,  6.40s/it]

{'loss': 0.4245, 'grad_norm': 3.2410264015197754, 'learning_rate': 4.541666666666667e-05, 'epoch': 0.0}


 10%|█         | 12/120 [01:18<10:46,  5.99s/it]

{'loss': 0.4573, 'grad_norm': 3.015742540359497, 'learning_rate': 4.5e-05, 'epoch': 0.0}


 11%|█         | 13/120 [01:22<09:59,  5.60s/it]

{'loss': 0.334, 'grad_norm': 3.6554157733917236, 'learning_rate': 4.458333333333334e-05, 'epoch': 0.0}


 12%|█▏        | 14/120 [01:27<09:33,  5.41s/it]

{'loss': 0.2913, 'grad_norm': 1.8242318630218506, 'learning_rate': 4.4166666666666665e-05, 'epoch': 0.0}


 12%|█▎        | 15/120 [01:33<09:49,  5.62s/it]

{'loss': 0.3181, 'grad_norm': 2.2917802333831787, 'learning_rate': 4.375e-05, 'epoch': 0.0}


 13%|█▎        | 16/120 [01:39<09:28,  5.46s/it]

{'loss': 0.5187, 'grad_norm': 3.079606533050537, 'learning_rate': 4.3333333333333334e-05, 'epoch': 0.0}


 14%|█▍        | 17/120 [01:45<09:52,  5.75s/it]

{'loss': 0.3812, 'grad_norm': 2.359492778778076, 'learning_rate': 4.291666666666667e-05, 'epoch': 0.0}


 15%|█▌        | 18/120 [01:51<10:02,  5.91s/it]

{'loss': 0.4493, 'grad_norm': 2.3060858249664307, 'learning_rate': 4.25e-05, 'epoch': 0.0}


 16%|█▌        | 19/120 [02:01<11:50,  7.03s/it]

{'loss': 0.3876, 'grad_norm': 2.2197580337524414, 'learning_rate': 4.208333333333334e-05, 'epoch': 0.0}


 17%|█▋        | 20/120 [02:08<11:32,  6.92s/it]

{'loss': 0.2923, 'grad_norm': 2.1346163749694824, 'learning_rate': 4.166666666666667e-05, 'epoch': 0.0}


 18%|█▊        | 21/120 [02:15<11:37,  7.04s/it]

{'loss': 0.487, 'grad_norm': 2.8677213191986084, 'learning_rate': 4.125e-05, 'epoch': 0.0}


 18%|█▊        | 22/120 [02:20<10:42,  6.55s/it]

{'loss': 1.1458, 'grad_norm': 27.756139755249023, 'learning_rate': 4.0833333333333334e-05, 'epoch': 0.0}


 19%|█▉        | 23/120 [02:29<11:36,  7.18s/it]

{'loss': 0.4939, 'grad_norm': 3.2472004890441895, 'learning_rate': 4.041666666666667e-05, 'epoch': 0.0}


 20%|██        | 24/120 [02:35<11:10,  6.99s/it]

{'loss': 0.4015, 'grad_norm': 2.518772602081299, 'learning_rate': 4e-05, 'epoch': 0.0}


 21%|██        | 25/120 [02:41<10:14,  6.47s/it]

{'loss': 0.3977, 'grad_norm': 2.4719486236572266, 'learning_rate': 3.958333333333333e-05, 'epoch': 0.0}


 22%|██▏       | 26/120 [02:46<09:37,  6.15s/it]

{'loss': 0.3264, 'grad_norm': 2.106884717941284, 'learning_rate': 3.9166666666666665e-05, 'epoch': 0.0}


 22%|██▎       | 27/120 [02:51<09:07,  5.89s/it]

{'loss': 0.3228, 'grad_norm': 2.0161643028259277, 'learning_rate': 3.875e-05, 'epoch': 0.0}


 23%|██▎       | 28/120 [02:57<08:57,  5.84s/it]

{'loss': 0.4964, 'grad_norm': 2.7875406742095947, 'learning_rate': 3.8333333333333334e-05, 'epoch': 0.0}


 24%|██▍       | 29/120 [03:02<08:36,  5.67s/it]

{'loss': 0.2628, 'grad_norm': 1.8256276845932007, 'learning_rate': 3.791666666666667e-05, 'epoch': 0.0}


 25%|██▌       | 30/120 [03:08<08:31,  5.69s/it]

{'loss': 0.5695, 'grad_norm': 3.2678616046905518, 'learning_rate': 3.7500000000000003e-05, 'epoch': 0.0}


 26%|██▌       | 31/120 [03:13<08:11,  5.52s/it]

{'loss': 0.4916, 'grad_norm': 2.9669196605682373, 'learning_rate': 3.708333333333334e-05, 'epoch': 0.0}


 27%|██▋       | 32/120 [03:18<07:51,  5.36s/it]

{'loss': 0.3715, 'grad_norm': 2.230949878692627, 'learning_rate': 3.6666666666666666e-05, 'epoch': 0.0}


 28%|██▊       | 33/120 [03:23<07:36,  5.25s/it]

{'loss': 0.3755, 'grad_norm': 3.175701379776001, 'learning_rate': 3.625e-05, 'epoch': 0.0}


 28%|██▊       | 34/120 [03:28<07:19,  5.12s/it]

{'loss': 0.4399, 'grad_norm': 2.5880911350250244, 'learning_rate': 3.5833333333333335e-05, 'epoch': 0.0}


 29%|██▉       | 35/120 [03:33<07:08,  5.04s/it]

{'loss': 0.3241, 'grad_norm': 2.5480823516845703, 'learning_rate': 3.541666666666667e-05, 'epoch': 0.0}


 30%|███       | 36/120 [03:38<07:05,  5.07s/it]

{'loss': 0.4928, 'grad_norm': 2.77695894241333, 'learning_rate': 3.5e-05, 'epoch': 0.0}


 31%|███       | 37/120 [03:43<06:58,  5.05s/it]

{'loss': 0.537, 'grad_norm': 3.074765920639038, 'learning_rate': 3.458333333333333e-05, 'epoch': 0.0}


 32%|███▏      | 38/120 [03:48<06:50,  5.01s/it]

{'loss': 0.6124, 'grad_norm': 2.820174217224121, 'learning_rate': 3.4166666666666666e-05, 'epoch': 0.0}


 32%|███▎      | 39/120 [03:53<06:46,  5.02s/it]

{'loss': 0.3673, 'grad_norm': 2.4933133125305176, 'learning_rate': 3.375000000000001e-05, 'epoch': 0.0}


 33%|███▎      | 40/120 [03:59<06:56,  5.21s/it]

{'loss': 0.4119, 'grad_norm': 2.5386929512023926, 'learning_rate': 3.3333333333333335e-05, 'epoch': 0.0}


 34%|███▍      | 41/120 [04:04<06:46,  5.15s/it]

{'loss': 0.7499, 'grad_norm': 3.9958255290985107, 'learning_rate': 3.291666666666667e-05, 'epoch': 0.0}


 35%|███▌      | 42/120 [04:09<06:49,  5.25s/it]

{'loss': 0.3467, 'grad_norm': 2.0401852130889893, 'learning_rate': 3.2500000000000004e-05, 'epoch': 0.0}


 36%|███▌      | 43/120 [04:15<06:51,  5.35s/it]

{'loss': 0.3609, 'grad_norm': 2.525691270828247, 'learning_rate': 3.208333333333334e-05, 'epoch': 0.0}


 37%|███▋      | 44/120 [04:20<06:44,  5.32s/it]

{'loss': 0.4882, 'grad_norm': 2.7924983501434326, 'learning_rate': 3.1666666666666666e-05, 'epoch': 0.0}


 38%|███▊      | 45/120 [04:25<06:28,  5.18s/it]

{'loss': 0.7291, 'grad_norm': 4.025518417358398, 'learning_rate': 3.125e-05, 'epoch': 0.0}


 38%|███▊      | 46/120 [04:30<06:15,  5.08s/it]

{'loss': 0.3893, 'grad_norm': 2.6238467693328857, 'learning_rate': 3.0833333333333335e-05, 'epoch': 0.0}


 39%|███▉      | 47/120 [04:34<06:04,  4.99s/it]

{'loss': 0.4203, 'grad_norm': 2.7471847534179688, 'learning_rate': 3.0416666666666666e-05, 'epoch': 0.0}


 40%|████      | 48/120 [04:39<05:58,  4.98s/it]

{'loss': 0.6291, 'grad_norm': 3.848888874053955, 'learning_rate': 3e-05, 'epoch': 0.0}


 41%|████      | 49/120 [04:44<05:51,  4.95s/it]

{'loss': 0.4546, 'grad_norm': 2.5541131496429443, 'learning_rate': 2.9583333333333335e-05, 'epoch': 0.0}


 42%|████▏     | 50/120 [04:49<05:50,  5.00s/it]

{'loss': 0.4022, 'grad_norm': 2.224252939224243, 'learning_rate': 2.916666666666667e-05, 'epoch': 0.0}


 42%|████▎     | 51/120 [04:54<05:46,  5.02s/it]

{'loss': 0.2436, 'grad_norm': 1.8495457172393799, 'learning_rate': 2.8749999999999997e-05, 'epoch': 0.0}


 43%|████▎     | 52/120 [04:59<05:41,  5.02s/it]

{'loss': 0.401, 'grad_norm': 3.6769204139709473, 'learning_rate': 2.8333333333333335e-05, 'epoch': 0.0}


 44%|████▍     | 53/120 [05:04<05:34,  4.99s/it]

{'loss': 0.4013, 'grad_norm': 2.388005018234253, 'learning_rate': 2.791666666666667e-05, 'epoch': 0.0}


 45%|████▌     | 54/120 [05:09<05:31,  5.02s/it]

{'loss': 0.4624, 'grad_norm': 2.710179090499878, 'learning_rate': 2.7500000000000004e-05, 'epoch': 0.0}


 46%|████▌     | 55/120 [05:14<05:24,  4.99s/it]

{'loss': 0.2402, 'grad_norm': 2.0686097145080566, 'learning_rate': 2.7083333333333332e-05, 'epoch': 0.0}


 47%|████▋     | 56/120 [05:19<05:18,  4.98s/it]

{'loss': 0.6434, 'grad_norm': 3.0266613960266113, 'learning_rate': 2.6666666666666667e-05, 'epoch': 0.0}


 48%|████▊     | 57/120 [05:24<05:11,  4.95s/it]

{'loss': 0.6094, 'grad_norm': 3.339745283126831, 'learning_rate': 2.625e-05, 'epoch': 0.0}


 48%|████▊     | 58/120 [05:29<05:03,  4.89s/it]

{'loss': 0.5317, 'grad_norm': 2.9965262413024902, 'learning_rate': 2.5833333333333336e-05, 'epoch': 0.0}


 49%|████▉     | 59/120 [05:34<04:56,  4.87s/it]

{'loss': 0.67, 'grad_norm': 3.5709047317504883, 'learning_rate': 2.5416666666666667e-05, 'epoch': 0.0}


 50%|█████     | 60/120 [05:39<04:50,  4.84s/it]

{'loss': 0.3908, 'grad_norm': 2.60565447807312, 'learning_rate': 2.5e-05, 'epoch': 0.0}


 51%|█████     | 61/120 [05:43<04:46,  4.85s/it]

{'loss': 0.375, 'grad_norm': 2.3894433975219727, 'learning_rate': 2.4583333333333332e-05, 'epoch': 0.0}


 52%|█████▏    | 62/120 [05:48<04:41,  4.86s/it]

{'loss': 0.6576, 'grad_norm': 4.287093639373779, 'learning_rate': 2.4166666666666667e-05, 'epoch': 0.0}


 52%|█████▎    | 63/120 [05:54<04:45,  5.01s/it]

{'loss': 0.5188, 'grad_norm': 2.713465690612793, 'learning_rate': 2.375e-05, 'epoch': 0.0}


 53%|█████▎    | 64/120 [05:59<04:40,  5.01s/it]

{'loss': 0.3214, 'grad_norm': 1.8584692478179932, 'learning_rate': 2.3333333333333336e-05, 'epoch': 0.0}


 54%|█████▍    | 65/120 [06:04<04:37,  5.05s/it]

{'loss': 0.5352, 'grad_norm': 5.522486686706543, 'learning_rate': 2.2916666666666667e-05, 'epoch': 0.0}


 55%|█████▌    | 66/120 [06:09<04:35,  5.11s/it]

{'loss': 0.2911, 'grad_norm': 8.739320755004883, 'learning_rate': 2.25e-05, 'epoch': 0.0}


 56%|█████▌    | 67/120 [06:14<04:28,  5.07s/it]

{'loss': 0.5884, 'grad_norm': 3.3358757495880127, 'learning_rate': 2.2083333333333333e-05, 'epoch': 0.0}


 57%|█████▋    | 68/120 [06:19<04:19,  4.99s/it]

{'loss': 0.5057, 'grad_norm': 2.9963653087615967, 'learning_rate': 2.1666666666666667e-05, 'epoch': 0.0}


 57%|█████▊    | 69/120 [06:24<04:11,  4.94s/it]

{'loss': 0.3816, 'grad_norm': 2.7741591930389404, 'learning_rate': 2.125e-05, 'epoch': 0.0}


 58%|█████▊    | 70/120 [06:28<04:04,  4.89s/it]

{'loss': 0.505, 'grad_norm': 3.182297468185425, 'learning_rate': 2.0833333333333336e-05, 'epoch': 0.0}


 59%|█████▉    | 71/120 [06:33<04:01,  4.92s/it]

{'loss': 0.4208, 'grad_norm': 2.348478317260742, 'learning_rate': 2.0416666666666667e-05, 'epoch': 0.0}


 60%|██████    | 72/120 [06:38<03:54,  4.89s/it]

{'loss': 0.4669, 'grad_norm': 3.64544677734375, 'learning_rate': 2e-05, 'epoch': 0.0}


 61%|██████    | 73/120 [06:43<03:49,  4.89s/it]

{'loss': 0.4032, 'grad_norm': 2.7006115913391113, 'learning_rate': 1.9583333333333333e-05, 'epoch': 0.0}


 62%|██████▏   | 74/120 [06:48<03:44,  4.89s/it]

{'loss': 0.4092, 'grad_norm': 3.375107526779175, 'learning_rate': 1.9166666666666667e-05, 'epoch': 0.0}


 62%|██████▎   | 75/120 [06:53<03:40,  4.90s/it]

{'loss': 0.4887, 'grad_norm': 2.5091476440429688, 'learning_rate': 1.8750000000000002e-05, 'epoch': 0.0}


 63%|██████▎   | 76/120 [06:58<03:34,  4.87s/it]

{'loss': 0.5746, 'grad_norm': 3.195418119430542, 'learning_rate': 1.8333333333333333e-05, 'epoch': 0.0}


 64%|██████▍   | 77/120 [07:03<03:28,  4.86s/it]

{'loss': 0.4088, 'grad_norm': 2.124907970428467, 'learning_rate': 1.7916666666666667e-05, 'epoch': 0.0}


 65%|██████▌   | 78/120 [07:07<03:23,  4.85s/it]

{'loss': 0.6545, 'grad_norm': 3.2009527683258057, 'learning_rate': 1.75e-05, 'epoch': 0.0}


 66%|██████▌   | 79/120 [07:13<03:24,  4.98s/it]

{'loss': 0.3581, 'grad_norm': 2.0931408405303955, 'learning_rate': 1.7083333333333333e-05, 'epoch': 0.0}


 67%|██████▋   | 80/120 [07:18<03:21,  5.05s/it]

{'loss': 0.3643, 'grad_norm': 2.2046923637390137, 'learning_rate': 1.6666666666666667e-05, 'epoch': 0.0}


 68%|██████▊   | 81/120 [07:23<03:17,  5.06s/it]

{'loss': 0.6342, 'grad_norm': 3.4459969997406006, 'learning_rate': 1.6250000000000002e-05, 'epoch': 0.0}


 68%|██████▊   | 82/120 [07:28<03:10,  5.01s/it]

{'loss': 0.2447, 'grad_norm': 1.6055723428726196, 'learning_rate': 1.5833333333333333e-05, 'epoch': 0.0}


 69%|██████▉   | 83/120 [07:33<03:02,  4.93s/it]

{'loss': 1.6285, 'grad_norm': 5.951422214508057, 'learning_rate': 1.5416666666666668e-05, 'epoch': 0.0}


 70%|███████   | 84/120 [07:37<02:55,  4.89s/it]

{'loss': 0.4048, 'grad_norm': 2.6030707359313965, 'learning_rate': 1.5e-05, 'epoch': 0.0}


 71%|███████   | 85/120 [07:42<02:52,  4.93s/it]

{'loss': 0.2762, 'grad_norm': 2.454455614089966, 'learning_rate': 1.4583333333333335e-05, 'epoch': 0.0}


 72%|███████▏  | 86/120 [07:47<02:45,  4.88s/it]

{'loss': 0.3769, 'grad_norm': 2.431744337081909, 'learning_rate': 1.4166666666666668e-05, 'epoch': 0.0}


 72%|███████▎  | 87/120 [07:54<03:04,  5.58s/it]

{'loss': 0.451, 'grad_norm': 3.335707187652588, 'learning_rate': 1.3750000000000002e-05, 'epoch': 0.0}


 73%|███████▎  | 88/120 [08:01<03:10,  5.96s/it]

{'loss': 0.4748, 'grad_norm': 2.520878791809082, 'learning_rate': 1.3333333333333333e-05, 'epoch': 0.0}


 74%|███████▍  | 89/120 [08:07<03:00,  5.81s/it]

{'loss': 0.4902, 'grad_norm': 2.4723095893859863, 'learning_rate': 1.2916666666666668e-05, 'epoch': 0.0}


 75%|███████▌  | 90/120 [08:14<03:04,  6.15s/it]

{'loss': 0.2615, 'grad_norm': 1.7988883256912231, 'learning_rate': 1.25e-05, 'epoch': 0.0}


 76%|███████▌  | 91/120 [08:25<03:42,  7.66s/it]

{'loss': 0.4095, 'grad_norm': 2.4683988094329834, 'learning_rate': 1.2083333333333333e-05, 'epoch': 0.0}


 77%|███████▋  | 92/120 [08:31<03:20,  7.18s/it]

{'loss': 0.3545, 'grad_norm': 3.665443181991577, 'learning_rate': 1.1666666666666668e-05, 'epoch': 0.0}


 78%|███████▊  | 93/120 [08:36<03:00,  6.68s/it]

{'loss': 0.4263, 'grad_norm': 2.7706830501556396, 'learning_rate': 1.125e-05, 'epoch': 0.0}


 78%|███████▊  | 94/120 [08:42<02:41,  6.22s/it]

{'loss': 0.3856, 'grad_norm': 2.175520181655884, 'learning_rate': 1.0833333333333334e-05, 'epoch': 0.0}


 79%|███████▉  | 95/120 [08:47<02:25,  5.83s/it]

{'loss': 0.5032, 'grad_norm': 3.2240471839904785, 'learning_rate': 1.0416666666666668e-05, 'epoch': 0.0}


 80%|████████  | 96/120 [08:51<02:11,  5.47s/it]

{'loss': 0.2947, 'grad_norm': 2.259200096130371, 'learning_rate': 1e-05, 'epoch': 0.0}


 81%|████████  | 97/120 [08:56<02:00,  5.24s/it]

{'loss': 0.4744, 'grad_norm': 2.618772029876709, 'learning_rate': 9.583333333333334e-06, 'epoch': 0.0}


 82%|████████▏ | 98/120 [09:01<01:52,  5.13s/it]

{'loss': 0.409, 'grad_norm': 2.3881988525390625, 'learning_rate': 9.166666666666666e-06, 'epoch': 0.0}


 82%|████████▎ | 99/120 [09:06<01:45,  5.02s/it]

{'loss': 0.3766, 'grad_norm': 2.545694351196289, 'learning_rate': 8.75e-06, 'epoch': 0.0}


 83%|████████▎ | 100/120 [09:10<01:37,  4.89s/it]

{'loss': 0.4245, 'grad_norm': 2.5265276432037354, 'learning_rate': 8.333333333333334e-06, 'epoch': 0.0}


 84%|████████▍ | 101/120 [09:15<01:34,  5.00s/it]

{'loss': 0.5956, 'grad_norm': 3.335174560546875, 'learning_rate': 7.916666666666667e-06, 'epoch': 0.0}


 85%|████████▌ | 102/120 [09:20<01:29,  4.99s/it]

{'loss': 0.6078, 'grad_norm': 2.96943736076355, 'learning_rate': 7.5e-06, 'epoch': 0.0}


 86%|████████▌ | 103/120 [09:25<01:25,  5.00s/it]

{'loss': 0.4382, 'grad_norm': 2.7796213626861572, 'learning_rate': 7.083333333333334e-06, 'epoch': 0.0}


 87%|████████▋ | 104/120 [09:32<01:26,  5.39s/it]

{'loss': 0.3265, 'grad_norm': 2.2501399517059326, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.0}


 88%|████████▊ | 105/120 [09:37<01:18,  5.26s/it]

{'loss': 0.5157, 'grad_norm': 2.9366579055786133, 'learning_rate': 6.25e-06, 'epoch': 0.0}


 88%|████████▊ | 106/120 [09:42<01:12,  5.19s/it]

{'loss': 0.3711, 'grad_norm': 2.2623443603515625, 'learning_rate': 5.833333333333334e-06, 'epoch': 0.0}


 89%|████████▉ | 107/120 [09:47<01:06,  5.13s/it]

{'loss': 0.3637, 'grad_norm': 2.1348791122436523, 'learning_rate': 5.416666666666667e-06, 'epoch': 0.0}


 90%|█████████ | 108/120 [09:52<01:02,  5.18s/it]

{'loss': 0.3692, 'grad_norm': 2.7635340690612793, 'learning_rate': 5e-06, 'epoch': 0.0}


 91%|█████████ | 109/120 [09:58<00:58,  5.31s/it]

{'loss': 0.3131, 'grad_norm': 1.8908228874206543, 'learning_rate': 4.583333333333333e-06, 'epoch': 0.0}


 92%|█████████▏| 110/120 [10:03<00:54,  5.47s/it]

{'loss': 0.4782, 'grad_norm': 2.644212484359741, 'learning_rate': 4.166666666666667e-06, 'epoch': 0.0}


 92%|█████████▎| 111/120 [10:09<00:49,  5.46s/it]

{'loss': 0.3118, 'grad_norm': 2.2117741107940674, 'learning_rate': 3.75e-06, 'epoch': 0.0}


 93%|█████████▎| 112/120 [10:14<00:43,  5.43s/it]

{'loss': 1.0354, 'grad_norm': 3.9562907218933105, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.0}


 94%|█████████▍| 113/120 [10:19<00:37,  5.38s/it]

{'loss': 0.4871, 'grad_norm': 2.6594200134277344, 'learning_rate': 2.916666666666667e-06, 'epoch': 0.0}


 95%|█████████▌| 114/120 [10:25<00:32,  5.40s/it]

{'loss': 0.5301, 'grad_norm': 3.3167057037353516, 'learning_rate': 2.5e-06, 'epoch': 0.0}


 96%|█████████▌| 115/120 [10:30<00:26,  5.32s/it]

{'loss': 0.44, 'grad_norm': 2.607652187347412, 'learning_rate': 2.0833333333333334e-06, 'epoch': 0.0}


 97%|█████████▋| 116/120 [10:35<00:20,  5.25s/it]

{'loss': 0.3389, 'grad_norm': 2.2633886337280273, 'learning_rate': 1.6666666666666667e-06, 'epoch': 0.0}


 98%|█████████▊| 117/120 [10:41<00:16,  5.57s/it]

{'loss': 0.3823, 'grad_norm': 2.1887028217315674, 'learning_rate': 1.25e-06, 'epoch': 0.0}


 98%|█████████▊| 118/120 [10:48<00:11,  5.86s/it]

{'loss': 0.2909, 'grad_norm': 1.925498604774475, 'learning_rate': 8.333333333333333e-07, 'epoch': 0.0}


 99%|█████████▉| 119/120 [10:54<00:05,  5.96s/it]

{'loss': 0.1961, 'grad_norm': 1.686539888381958, 'learning_rate': 4.1666666666666667e-07, 'epoch': 0.0}


100%|██████████| 120/120 [11:00<00:00,  5.79s/it]

{'loss': 0.4289, 'grad_norm': 2.4045698642730713, 'learning_rate': 0.0, 'epoch': 0.0}


100%|██████████| 120/120 [11:00<00:00,  5.50s/it]

{'train_runtime': 660.3332, 'train_samples_per_second': 0.727, 'train_steps_per_second': 0.182, 'train_loss': 0.4516807119051615, 'epoch': 0.0}





TrainOutput(global_step=120, training_loss=0.4516807119051615, metrics={'train_runtime': 660.3332, 'train_samples_per_second': 0.727, 'train_steps_per_second': 0.182, 'total_flos': 16271215165440.0, 'train_loss': 0.4516807119051615, 'epoch': 0.0003236664940445365})

In [11]:
from transformers import GenerationConfig
gen_config = GenerationConfig(**{'max_length': 512, 'num_beams': 4, 'bad_words_ids': [[62508]], 'forced_eos_token_id': 0})
gen_config.save_pretrained("./model")

In [28]:
trainer.save_model("./model")
tokenizer.save_pretrained("./model")

Non-default generation parameters: {'max_length': 512, 'num_beams': 4, 'bad_words_ids': [[62508]], 'forced_eos_token_id': 0}


('./model\\tokenizer_config.json',
 './model\\special_tokens_map.json',
 './model\\vocab.json',
 './model\\source.spm',
 './model\\target.spm',
 './model\\added_tokens.json')

# Using the model

In [29]:
model = pipeline("translation", "./model")

In [30]:
sample_sentences = dataset_dict["test"].select(list(range(10)))

for datapoint in sample_sentences:
    sentence = datapoint["source"]    
    ground_truth = datapoint["target"]
    model_result = model(sentence)[0]["translation_text"]

    print(f"Input: [{sentence}]\nCorrect: [{ground_truth}]\nModel: [{model_result}]\n")

Input: [miloš zeman tenkrát sice odmítl takové "řeči o levicových inletektulech".]
Correct: [Miloš Zeman tenkrát sice odmítl takové "řeči o levicových intelektuálech".]
Model: [Miloš zeman tenkrát sice odmítl takové "řeči o levičových vstupukulech."]

Input: [A je jedno jestli se aůtoři snažili nebo nesnažili, zkrátka sejim to vbůec nepovedlo.]
Correct: [A je jedno jestli se autoři snažili nebo nesnažili, zkrátka se jim to vůbec nepovedlo.]
Model: [A je jedno jeli se autoři snažili nebo nesnažili, zkrátka sejim to všechno nepovedlo.]

Input: [v prběhu tohoto nde k sobě oba antagonisté najdou cestu, objeví v sobě to dobré a poznají teplo upřmíných lidských citů.]
Correct: [V průběhu tohoto dne k sobě oba antagonisté najdou cestu, objeví v sobě to dobré a poznají teplo upřímných lidských citů.]
Model: [V překladu toto nde k sobě Oba antagonisté najdou cestu, objeví v sobě to dobré a poznají teplo upjatých lidských citů.]

Input: [Mohou to být alkalická mýdla, prací  nebo čizticíí prostře