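# Train Vec4Gloss - Denoise

Fine-tune `google/mt5-base` with `Seq2SeqTrainer` on the denoising dataset in `../data/denoising_dataset_cwn/`, where each `src` text maps to its `tgt` text. The final model is saved under `../data/models/`.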

In [1]:
## Reference: Hugging Face course, chapter 7, section 4 (Translation) - https://huggingface.co/course/chapter7/4

In [2]:
# Project name that Weights & Biases logs runs under (only read when the W&B
# integration is active during training).
%env WANDB_PROJECT=vec4gloss

env: WANDB_PROJECT=vec4gloss


In [3]:
%load_ext autoreload
%autoreload 2
import sys
if "../src" not in sys.path:
    sys.path.append("../src")

In [4]:
from datetime import datetime
from pathlib import Path
from datasets import Dataset
import numpy as np
from tqdm.auto import tqdm
from vec4gloss import check_hashes

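## Data dependencies

Each input file and the hash prefix it is expected to have (compare with the output of `check_hashes` below):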

```
../data/denoising_dataset_cwn/train/dataset.arrow ad87fa
```

In [5]:
# Print a short hash for each data dependency, so it can be compared with the
# value recorded under "Data dependencies" above.
# NOTE(review): only the train split is listed here, but the test split is
# also loaded from the same dataset directory below. Confirm this is intended.
_ = check_hashes(
    ["../data/denoising_dataset_cwn/train/dataset.arrow"]
)

../data/denoising_dataset_cwn/train/dataset.arrow ad87fa


## Prepare dataset

In [6]:
import numpy as np
from transformers import MT5ForConditionalGeneration, MT5TokenizerFast
from transformers import DataCollatorForSeq2Seq
import datasets
from datasets import load_metric

In [7]:
# Load the dataset splits saved on disk (a DatasetDict; the next cell prints
# the size of each split).
ds_denoise = datasets.load_from_disk(dataset_path="../data/denoising_dataset_cwn/")

In [8]:
# Number of examples in each split of the loaded DatasetDict.
print({split_name: split.num_rows for split_name, split in ds_denoise.items()})

{'train': 26118, 'test': 2903}


In [9]:
# Tokenizer matching the `google/mt5-base` checkpoint that is fine-tuned below.
tokenizer = MT5TokenizerFast.from_pretrained(
    pretrained_model_name_or_path="google/mt5-base"
)

KeyboardInterrupt: 

### Eyeballing the raw examples

In [None]:
# Peek at training rows 10 and 11. Slicing a `datasets.Dataset` returns a
# dict of column name -> list of values.
ds_denoise["train"][10:12]

In [None]:
# Show how the tokenizer splits the `tgt` text of training example 10,
# including any special tokens it adds when encoding.
tokenizer.convert_ids_to_tokens(tokenizer(ds_denoise["train"][10]["tgt"])["input_ids"])

## Preprocess

In [None]:
max_length = 256  # token cap applied to both the source and the target side


def preprocess_fn(batch):
    """Tokenize a batch of (src, tgt) text pairs for seq2seq training.

    Args:
        batch: a batched `datasets` slice holding parallel lists of strings
            under "src" (model input text) and "tgt" (target text).

    Returns:
        dict with every field the tokenizer produces for "src", plus a
        "labels" entry holding the token ids of "tgt". Both sides are
        truncated to `max_length` tokens.
    """
    encode_kwargs = dict(max_length=max_length, truncation=True)
    model_inputs = tokenizer(batch["src"], **encode_kwargs)
    # Targets are encoded with the tokenizer switched into target mode.
    with tokenizer.as_target_tokenizer():
        target_ids = tokenizer(batch["tgt"], **encode_kwargs)["input_ids"]
    result = dict(model_inputs)
    result["labels"] = target_ids
    return result

In [None]:
# Raw columns dropped once tokenization is done ("src"/"tgt" are consumed by
# `preprocess_fn`), so each split keeps only the tokenizer outputs plus `labels`.
drop_columns = ["cwnid", "src", "tgt"]

# `ds_denoise` is rebound to its tokenized version. Running `map` again would
# fail, because the raw text columns are already gone. Tokenize only while
# they are still present, so re-running this cell is harmless.
if set(drop_columns).issubset(ds_denoise["train"].column_names):
    ds_denoise = ds_denoise.map(
        preprocess_fn, batched=True, remove_columns=drop_columns
    )

### Eyeballing the tokenized examples

In [None]:
# Decode training example 10 after preprocessing. The first line is the
# tokenized model input; the second line is the tokenized labels.
print(
    *(
        " ".join(tokenizer.convert_ids_to_tokens(ds_denoise["train"][10][column]))
        for column in ("input_ids", "labels")
    ),
    sep="\n",
)

## Model definition

In [None]:
import torch

# Use the GPU when one is visible; otherwise fall back to CPU rather than
# crashing on a hard-coded `.to("cuda")`. Trainer also moves the model to
# `args.device` before training.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-base").to(device)

# Pads each batch to its longest sequence. Labels are padded with the
# collator's default label pad id (-100), so padding is ignored by the loss.
# Given `model`, the collator also builds `decoder_input_ids` from `labels`.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding="longest")

## Trainer

In [None]:
# Optional: log in to Weights & Biases before training. Uncomment these lines
# together with `report_to="wandb"` in the training arguments below.
# import wandb
# wandb.login()

In [None]:
# Write checkpoints to the large data volume when it is mounted on this
# machine; otherwise fall back to ./vec4gloss under the working directory.
# Both branches yield a plain `str`. The training arguments declare
# `output_dir` as a string, and a `Path` value can break JSON serialization
# of those arguments.
preferred_out_dir = Path("/mnt/md0/seantyh/vec4gloss")
out_dir = str(preferred_out_dir) if preferred_out_dir.exists() else "vec4gloss"
print(out_dir)

In [None]:
from transformers import Seq2SeqTrainingArguments

# Timestamp shared by the run name and the saved-model directory name.
timestamp = datetime.now().strftime("%y%m%d-%H%M")

args = Seq2SeqTrainingArguments(
    out_dir,
    run_name=f"denoising-{timestamp}",
    # report_to="wandb",
    # Optimisation
    num_train_epochs=3,
    learning_rate=1e-4,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    # Evaluation and logging: evaluate every 3000 steps, log every 100
    evaluation_strategy="steps",
    eval_steps=3000,
    logging_steps=100,  # use a small value (e.g. 10) when debugging
    predict_with_generate=False,  # evaluate on loss only; no generation
    # Checkpointing: one checkpoint per epoch; older ones beyond 3 are deleted
    save_strategy="epoch",
    save_total_limit=3,
)

In [None]:
train_ds, test_ds = ds_denoise["train"], ds_denoise["test"]
# For a quick debugging run, uncomment to train and evaluate on small subsets:
# train_ds = train_ds.select(range(100))
# test_ds = test_ds.select(range(200))

In [None]:
from transformers import Seq2SeqTrainer

# Combine the model, training arguments, tokenized splits and padding
# collator into a seq2seq trainer. The tokenizer is passed so the trainer
# saves it alongside the model.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=tokenizer,
)

In [None]:
# Fine-tune with the settings in `args`: evaluate every `eval_steps` steps,
# log every `logging_steps` steps, and write a checkpoint each epoch.
trainer.train()

In [None]:
# Save the final model (plus the tokenizer handed to the trainer) in a
# directory stamped with the same `timestamp` used for the run name.
trainer.save_model(str(Path("../data/models") / f"vec4gloss-denoise-{timestamp}"))