Add vocabulary checks
salvacarrion committed Dec 3, 2023
1 parent aee1c88 commit 5cd1f57
Showing 4 changed files with 200 additions and 103 deletions.
6 changes: 5 additions & 1 deletion autonmt/preprocessing/builder.py
@@ -403,7 +403,7 @@ def _create_reduced_versions(self, force_overwrite):
if fname.split('.')[0] == ds.train_name: # train.xx
with open(ref_filename, 'rb') as fin:
lines = list(islice(fin, ds.dataset_lines)) # Copy n lines efficiently
if len(lines) == ds.dataset_lines:
if len(lines) == ds.dataset_lines or ds.dataset_lines is None: # None == All lines
with open(new_filename, 'wb') as fout:
fout.writelines(lines)
else:
@@ -489,6 +489,10 @@ def _train_tokenizer(self, force_overwrite):
character_coverage=self.character_coverage, split_digits=self.split_digits)
assert os.path.exists(f"{output_file}.model")

# Check vocabs
print(f"=> Checking existing vocabularies...")
ds.check_vocab_folder_consistency()

def _encode_datasets(self, force_overwrite):
print(f"=> Building datasets...")
for ds in self: # Dataset
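For context on the first builder.py hunk: a minimal standalone sketch of the copy logic after this change, where a line count of None is treated as "copy the whole file". The helper name and arguments below are illustrative, not part of the library.

from itertools import islice

def copy_first_lines(src_path, dst_path, n=None):
    # n=None means "all lines"; islice(fin, None) simply yields the whole file.
    with open(src_path, 'rb') as fin:
        lines = list(islice(fin, n))
    # Only write the reduced file if the source had enough lines (or n is None).
    if n is None or len(lines) == n:
        with open(dst_path, 'wb') as fout:
            fout.writelines(lines)
        return True
    return False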
44 changes: 44 additions & 0 deletions autonmt/preprocessing/dataset.py
@@ -295,3 +295,47 @@ def has_split_files(self):
split_files = [self.get_split_path(f) for f in self.get_split_fnames()]
return all([os.path.exists(p) for p in split_files]), split_files
return False, []

def check_vocab_folder_consistency(self, check_extra=False, custom_vocabs=False):
default_vocab_extensions = ["model", "vocab"]

if check_extra:
default_vocab_extensions.append("vocabf")

# Ignore datasets with no vocabs
if self.subword_model in {None, "none", "bytes"}:
return True

# Check if it has a vocab folder
vocab_path = self.get_vocab_path()
if not os.path.exists(vocab_path):
raise ValueError(f"=> [ERROR CAPTURED]: Vocab path does not exist: {vocab_path}")

# Custom vocabs only need to check that all files exist
if custom_vocabs: # Any language
num_expected_files = len(default_vocab_extensions) if self.merge_vocabs else 2*len(default_vocab_extensions)
else:
# Get expected vocab files
lang_files = [f"{self.src_lang}-{self.trg_lang}"] if self.merge_vocabs else [self.src_lang, self.trg_lang]
expected_files = [f"{self.get_vocab_file(lang=lang)}.{ext}" for lang in lang_files for ext in
default_vocab_extensions]

# Check if all files exist
missing_files = [os.path.split(f)[1] for f in expected_files if not os.path.exists(f)]
if missing_files:
raise ValueError(f"=> [ERROR CAPTURED]: Missing vocab files for dataset '{self.id(as_path=True)}': {missing_files}\n\t- Vocab path: {vocab_path}")

# Get number of expected files
num_expected_files = len(expected_files)

# Check if there are extra files
existing_files = [os.path.join(vocab_path, f) for f in os.listdir(vocab_path) if f.endswith(tuple(default_vocab_extensions))]
if len(existing_files) != num_expected_files:
msg = (f"Incorrect number of vocab files for dataset '{self.id(as_path=True)}'. Expected {num_expected_files}, found {len(existing_files)}."
f"\n\t- Reason: This can lead to potential vocabulary mismatches during training."
f"\n\t- Vocab path: {vocab_path}")
if custom_vocabs:
print(f"=> [WARNING]: {msg}")
else:
raise ValueError(f"=> [PROCESS ABORTED]: {msg}")
return True
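A hedged usage sketch of the new check: the builder now triggers it from _train_tokenizer, but it can also be called on a dataset directly. The `ds` variable and the error handling below are illustrative assumptions, not part of the commit.

# Assuming `ds` is one of the datasets returned by builder.get_train_ds()
# with a trained subword model (e.g. "bpe+bytes"):
try:
    ds.check_vocab_folder_consistency()  # strict: raises on missing or mismatched vocab files
    ds.check_vocab_folder_consistency(custom_vocabs=True)  # lenient: a wrong file count only prints a warning
except ValueError as e:
    print(e)  # e.g. "=> [ERROR CAPTURED]: Missing vocab files for dataset ..."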
52 changes: 22 additions & 30 deletions examples/dev/0_test_custom_model.py
@@ -33,6 +33,7 @@ def preprocess_predict(data, ds):
else: # Target
return data["lines"]


# Preprocess functions
normalize_fn = lambda x: normalize_lines(x, seq=[NFKC(), Strip(), Lowercase()])
preprocess_raw_fn = lambda data, ds: preprocess_pairs(data["src"]["lines"], data["trg"]["lines"], normalize_fn=normalize_fn, min_len=1, max_len=None, remove_duplicates=False, shuffle_lines=False)
@@ -71,12 +72,12 @@ def main():
# {"name": "europarl", "languages": ["en-es"], "sizes": [("100k-gen", 100000)]},

# Multilingual: Spanish-French-German-Czech
# {"name": "europarl", "languages": ["en-es"], "sizes": [("100k-multi-lc", 100000)]},
{"name": "europarl", "languages": ["en-xx"], "sizes": [("100k-multi-lc", 100000)]},
{"name": "europarl", "languages": ["en-es", "en-fr", "en-de", "en-cs"], "sizes": [("100k-multi-lc", 100000)]},
# {"name": "europarl", "languages": ["en-es", "en-fr", "en-de", "en-cs"], "sizes": [("100k-multi-lc-base", 100000)]},
# {"name": "europarl", "languages": ["en-xx"], "sizes": [("original", None), ("100k-multi-lc-base", 100000)]},
],

# Set of subword models and vocab sizes to try
# encoding=None,
encoding=[
{"subword_models": ["bpe+bytes"], "vocab_sizes": [8000]},
],
@@ -89,15 +90,6 @@ def main():
merge_vocabs=False,
).build(make_plots=False, force_overwrite=False)

# # Merge datasets
# builder.merge_datasets(name="europarl", language_pair="en-xx", dataset_size_name="original",
# shuffle_lines=True, use_preprocessed_splits=False, force_overwrite=True,
# preprocess_fn=add_language_prefix)

# Create preprocessing for training and testing
tr_datasets = builder.get_train_ds()
ts_datasets = builder.get_test_ds()

builder_ts = DatasetBuilder(
# Root folder for datasets
base_path=BASE_PATH,
@@ -120,7 +112,7 @@ def main():
# {"name": "europarl", "languages": ["en-es"], "sizes": [("original", None)]},

# Multilingual: Spanish-French-German-Czech
{"name": "europarl", "languages": ["en-xx"], "sizes": [("original", None)]},
{"name": "europarl", "languages": ["en-es", "en-fr", "en-de", "en-cs"], "sizes": [("original", None)]},
],
)
# Create preprocessing for training and testing
@@ -130,24 +122,24 @@ def main():
# Train & Score a model for each dataset
scores = []
for i, train_ds in enumerate(tr_datasets, 1):
for iters in [200]:
for iters in [100]:
# Instantiate vocabs and model
src_vocab = Vocabulary(max_tokens=350).build_from_ds(ds=train_ds, lang=train_ds.src_lang)
trg_vocab = Vocabulary(max_tokens=350).build_from_ds(ds=train_ds, lang=train_ds.trg_lang)
model = Transformer(src_vocab_size=len(src_vocab), trg_vocab_size=len(trg_vocab), padding_idx=src_vocab.pad_id)

# Load checkpoint
# path = os.path.join(BASE_PATH, "health-bio-euro-legal/en-es/100k/models/autonmt/runs/health-bio-euro-legal_en-es_bpe+bytes_8000/checkpoints")
# checkpoint_path = os.path.join(path, "epoch=033-val_loss=2.086__best.pt")
# if checkpoint_path:
# print(f"\t- Loading previous checkpoint: {checkpoint_path}")
# model_state_dict = torch.load(checkpoint_path)
# model_state_dict = model_state_dict.get("state_dict", model_state_dict)
# model.load_state_dict(model_state_dict)
path = os.path.join(BASE_PATH, "europarl/en-xx/100k-multi-lc-base/models/autonmt/runs/europarl_en-xx_100k-multi-lc-base_bpe+bytes_8000/checkpoints")
checkpoint_path = os.path.join(path, "epoch=059-val_loss=2.289__best.pt")
if checkpoint_path:
print(f"\t- Loading previous checkpoint: {checkpoint_path}")
model_state_dict = torch.load(checkpoint_path)
model_state_dict = model_state_dict.get("state_dict", model_state_dict)
model.load_state_dict(model_state_dict)

# Define trainer
runs_dir = train_ds.get_runs_path(toolkit="autonmt")
run_prefix = "mnmt__" + '_'.join(train_ds.id()[:2]).replace('/', '-')
run_prefix = f"ft_en-xx->{train_ds.trg_lang}__" + '_'.join(train_ds.id()).replace('/', '-')
run_name = train_ds.get_run_name(run_prefix=run_prefix) #+ f"__{int(time.time())}"
trainer = AutonmtTranslator(model=model, src_vocab=src_vocab, trg_vocab=trg_vocab,
runs_dir=runs_dir, run_name=run_name)
@@ -159,25 +151,25 @@ def main():
print(f"\t- MODEL PREFIX: {run_prefix}")

# Train model
wandb_params = dict(project="continual-learning-new", entity="salvacarrion", reinit=True)
# trainer.fit(train_ds, max_epochs=iters, learning_rate=0.001, optimizer="adam", batch_size=256, seed=None,
# patience=15, num_workers=0, accelerator="auto", strategy="auto", save_best=True, save_last=True, print_samples=1,
# wandb_params=wandb_params)
wandb_params = dict(project="continual-learning-multi", entity="salvacarrion", reinit=True)
trainer.fit(train_ds, max_epochs=iters, learning_rate=0.001, optimizer="adam", batch_size=256, seed=None,
patience=10, num_workers=0, accelerator="auto", strategy="auto", save_best=True, save_last=True, print_samples=1,
wandb_params=wandb_params)

# Test model
m_scores = trainer.predict(ts_datasets, metrics={"bleu"}, beams=[1], load_checkpoint="best",
preprocess_fn=preprocess_predict_fn, eval_mode="all", force_overwrite=True)
m_scores = trainer.predict(ts_datasets, metrics={"bleu", "chrf", "ter"}, beams=[1], load_checkpoint="best",
preprocess_fn=preprocess_predict_fn, eval_mode="compatible", force_overwrite=False)

# Add extra metrics
for ms in m_scores:
ms['train_dataset'] = train_ds.dataset_name
ms['vocab__merged'] = train_ds.merge_vocabs
ms['max_iters'] = str(iters)
ms['max_iters'] = iters
ms['train_dataset'] = str(train_ds)
scores.append(m_scores)

# Make report
output_path = os.path.join(BASE_PATH, f".outputs/autonmt/multilingual__BASE-ALL2")
output_path = os.path.join(BASE_PATH, f".outputs/autonmt/multilingual__FT-XX->YY-v3")
df_report, df_summary = generate_report(scores=scores, output_path=output_path)

# Print summary
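The checkpoint block enabled in the example above unwraps a possible "state_dict" key before loading weights; a self-contained sketch of that pattern follows (the function name and the map_location choice are assumptions, not autonmt API).

import torch

def load_weights(model, checkpoint_path):
    # Trainer checkpoints often wrap the weights under "state_dict";
    # otherwise treat the file as a raw state dict.
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    state_dict = ckpt.get("state_dict", ckpt)
    model.load_state_dict(state_dict)
    return model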
(diff for the fourth changed file did not load)
