From d490eccb49839f33873598024e2e3d5e95160b91 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 10 Sep 2025 13:00:11 -0700 Subject: [PATCH 1/8] Will need to bump version number to make the tokenizer upgrades --- stanza/_version.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stanza/_version.py b/stanza/_version.py index bf284cc37..5a096525e 100644 --- a/stanza/_version.py +++ b/stanza/_version.py @@ -1,4 +1,4 @@ """ Single source of truth for version number """ -__version__ = "1.10.1" -__resources_version__ = '1.10.0' +__version__ = "1.11.0" +__resources_version__ = '1.11.0' From 9ade0db64a0060d8638ee9810109f01b1d96eaa2 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 11 Sep 2023 00:54:06 -0700 Subject: [PATCH 2/8] Add the forward charlm to the tokenizer. Seems to work decently in terms of improving sentence splitting scores for certain languages. Also helps with MWT scores for Hebrew, not significantly tested on other languages yet Save & load the tokenizers without putting a charlm (if relevant) into the model file Pass the charlm_forward_file from the pipeline to the model When training a tokenizer, the run_tokenizer script finds a charlm, if possible, and attaches it to the model Ignore extra charlm when passed in from the run_ script or as part of the Pipeline if the saved model didn't use charlm --- stanza/models/common/char_model.py | 5 ++-- stanza/models/tokenization/model.py | 32 +++++++++++++++++++++--- stanza/models/tokenization/trainer.py | 34 ++++++++++++++++++++------ stanza/models/tokenizer.py | 8 ++++-- stanza/pipeline/tokenize_processor.py | 3 ++- stanza/resources/default_packages.py | 2 ++ stanza/utils/training/common.py | 27 +++++++++++++++----- stanza/utils/training/run_tokenizer.py | 11 ++++++--- 8 files changed, 96 insertions(+), 26 deletions(-) diff --git a/stanza/models/common/char_model.py b/stanza/models/common/char_model.py index 4a03be7d6..ab0741ac3 100644 --- a/stanza/models/common/char_model.py +++ b/stanza/models/common/char_model.py @@ -283,8 +283,9 @@ def __init__(self, charlms): super().__init__() self.charlms = charlms - def forward(self, words): - words = [CHARLM_START + x + CHARLM_END for x in words] + def forward(self, words, wrap=True): + if wrap: + words = [CHARLM_START + x + CHARLM_END for x in words] padded_reps = [] for charlm in self.charlms: rep = charlm.per_char_representation(words) diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py index f6349af58..173f7c292 100644 --- a/stanza/models/tokenization/model.py +++ b/stanza/models/tokenization/model.py @@ -3,23 +3,38 @@ import torch.nn as nn from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence +from stanza.models.common.char_model import CharacterLanguageModelWordAdapter +from stanza.models.common.foundation_cache import load_charlm + class Tokenizer(nn.Module): - def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout): + def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout, foundation_cache=None): super().__init__() + self.unsaved_modules = [] + self.args = args feat_dim = args['feat_dim'] self.embeddings = nn.Embedding(nchars, emb_dim, padding_idx=0) - self.rnn = nn.LSTM(emb_dim + feat_dim, hidden_dim, num_layers=self.args['rnn_layers'], bidirectional=True, batch_first=True, dropout=dropout if self.args['rnn_layers'] > 1 else 0) + self.input_dim = emb_dim + feat_dim + + charmodel = None + if args is not None and args.get('charlm_forward_file', None): 
+ charmodel_forward = load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache) + charmodels = nn.ModuleList([charmodel_forward]) + charmodel = CharacterLanguageModelWordAdapter(charmodels) + self.input_dim += charmodel.hidden_dim() + self.add_unsaved_module("charmodel", charmodel) + + self.rnn = nn.LSTM(self.input_dim, hidden_dim, num_layers=self.args['rnn_layers'], bidirectional=True, batch_first=True, dropout=dropout if self.args['rnn_layers'] > 1 else 0) if self.args['conv_res'] is not None: self.conv_res = nn.ModuleList() self.conv_sizes = [int(x) for x in self.args['conv_res'].split(',')] for si, size in enumerate(self.conv_sizes): - l = nn.Conv1d(emb_dim + feat_dim, hidden_dim * 2, size, padding=size//2, bias=self.args.get('hier_conv_res', False) or (si == 0)) + l = nn.Conv1d(self.input_dim, hidden_dim * 2, size, padding=size//2, bias=self.args.get('hier_conv_res', False) or (si == 0)) self.conv_res.append(l) if self.args.get('hier_conv_res', False): @@ -42,8 +57,17 @@ def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout): self.toknoise = nn.Dropout(self.args['tok_noise']) - def forward(self, x, feats, lengths): + def add_unsaved_module(self, name, module): + self.unsaved_modules += [name] + setattr(self, name, module) + + def forward(self, x, feats, lengths, raw=None): emb = self.embeddings(x) + + if self.charmodel is not None and raw is not None: + char_emb = self.charmodel(raw, wrap=False) + emb = torch.cat([emb, char_emb], axis=2) + emb = self.dropout(emb) feats = self.dropout_feat(feats) diff --git a/stanza/models/tokenization/trainer.py b/stanza/models/tokenization/trainer.py index 2cd60ce37..81a34b660 100644 --- a/stanza/models/tokenization/trainer.py +++ b/stanza/models/tokenization/trainer.py @@ -14,10 +14,12 @@ logger = logging.getLogger('stanza') class Trainer(BaseTrainer): - def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None): + def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None, foundation_cache=None): + # TODO: make a test of the training w/ and w/o charlm + # TODO: build the resources with the forward charlm if model_file is not None: # load everything from file - self.load(model_file) + self.load(model_file, args, foundation_cache) else: # build model from scratch self.args = args @@ -41,7 +43,7 @@ def update(self, inputs): labels = labels.to(device) features = features.to(device) - pred = self.model(units, features, lengths) + pred = self.model(units, features, lengths, text) self.optimizer.zero_grad() classes = pred.size(2) @@ -62,13 +64,22 @@ def predict(self, inputs): units = units.to(device) features = features.to(device) - pred = self.model(units, features, lengths) + pred = self.model(units, features, lengths, text) return pred.data.cpu().numpy() - def save(self, filename): + def save(self, filename, skip_modules=True): + model_state = None + if self.model is not None: + model_state = self.model.state_dict() + # skip saving modules like the pretrained charlm + if skip_modules: + skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules] + for k in skipped: + del model_state[k] + params = { - 'model': self.model.state_dict() if self.model is not None else None, + 'model': model_state, 'vocab': self.vocab.state_dict(), # save and load lexicon as list instead of set so # we can use weights_only=True @@ -81,19 +92,26 @@ def save(self, filename): except BaseException: logger.warning("Saving 
failed... continuing anyway.") - def load(self, filename): + def load(self, filename, args, foundation_cache): try: checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True) except BaseException: logger.error("Cannot load model from {}".format(filename)) raise self.args = checkpoint['config'] + if args is not None and args.get('charlm_forward_file', None) is not None: + if checkpoint['config'].get('charlm_forward_file') is None: + # if the saved model didn't use a charlm, we skip the charlm here + # otherwise the loaded model weights won't fit in the newly created model + self.args['charlm_forward_file'] = None + else: + self.args['charlm_forward_file'] = args['charlm_forward_file'] if self.args.get('use_mwt', None) is None: # Default to True as many currently saved models # were built with mwt layers self.args['use_mwt'] = True self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout']) - self.model.load_state_dict(checkpoint['model']) + self.model.load_state_dict(checkpoint['model'], strict=False) self.vocab = Vocab.load_state_dict(checkpoint['vocab']) self.lexicon = checkpoint['lexicon'] diff --git a/stanza/models/tokenizer.py b/stanza/models/tokenizer.py index 089813089..223bfd462 100644 --- a/stanza/models/tokenizer.py +++ b/stanza/models/tokenizer.py @@ -88,6 +88,10 @@ def build_argparse(): utils.add_device_args(parser) parser.add_argument('--seed', type=int, default=1234) + parser.add_argument('--charlm', action='store_true', help="Turn on contextualized char embedding using pretrained character-level language model.") + parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.") + parser.add_argument('--charlm_forward_file', type=str, default=None, help="Exact path to use for forward charlm") + parser.add_argument('--use_mwt', dest='use_mwt', default=None, action='store_true', help='Whether or not to include mwt output layers. If set to None, this will be determined by examining the training data for MWTs') parser.add_argument('--no_use_mwt', dest='use_mwt', action='store_false', help='Whether or not to include mwt output layers') @@ -166,7 +170,7 @@ def train(args): args['use_mwt'] = train_batches.has_mwt() logger.info("Found {}mwts in the training data. 
Setting use_mwt to {}".format(("" if args['use_mwt'] else "no "), args['use_mwt'])) - trainer = Trainer(args=args, vocab=vocab, lexicon=lexicon, dictionary=dictionary, device=args['device']) + trainer = Trainer(args=args, vocab=vocab, lexicon=lexicon, dictionary=dictionary, device=args['device'], foundation_cache=None) if args['load_name'] is not None: load_name = os.path.join(args['save_dir'], args['load_name']) @@ -236,7 +240,7 @@ def train(args): def evaluate(args): mwt_dict = load_mwt_dict(args['mwt_json_file']) - trainer = Trainer(model_file=args['load_name'] or args['save_name'], device=args['device']) + trainer = Trainer(args=args, model_file=args['load_name'] or args['save_name'], device=args['device'], foundation_cache=None) loaded_args, vocab = trainer.args, trainer.vocab for k in loaded_args: diff --git a/stanza/pipeline/tokenize_processor.py b/stanza/pipeline/tokenize_processor.py index 6827ecb73..5e6341387 100644 --- a/stanza/pipeline/tokenize_processor.py +++ b/stanza/pipeline/tokenize_processor.py @@ -42,7 +42,8 @@ def _set_up_model(self, config, pipeline, device): if config.get('pretokenized'): self._trainer = None else: - self._trainer = Trainer(model_file=config['model_path'], device=device) + args = {'charlm_forward_file': config.get('forward_charlm_path', None)} + self._trainer = Trainer(args=args, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache) # get and typecheck the postprocessor postprocessor = config.get('postprocessor') diff --git a/stanza/resources/default_packages.py b/stanza/resources/default_packages.py index 284a6373d..21bbc453e 100644 --- a/stanza/resources/default_packages.py +++ b/stanza/resources/default_packages.py @@ -333,6 +333,8 @@ def build_default_pretrains(default_treebanks): lemma_charlms = copy.deepcopy(pos_charlms) +tokenizer_charlms = copy.deepcopy(pos_charlms) + ner_charlms = { "en": { "conll03": "1billion", diff --git a/stanza/utils/training/common.py b/stanza/utils/training/common.py index 0e33f00c0..81bb5b7dc 100644 --- a/stanza/utils/training/common.py +++ b/stanza/utils/training/common.py @@ -9,7 +9,7 @@ from enum import Enum -from stanza.resources.default_packages import default_charlms, lemma_charlms, pos_charlms, depparse_charlms, TRANSFORMERS, TRANSFORMER_LAYERS +from stanza.resources.default_packages import default_charlms, lemma_charlms, tokenizer_charlms, pos_charlms, depparse_charlms, TRANSFORMERS, TRANSFORMER_LAYERS from stanza.models.common.constant import treebank_to_short_name from stanza.models.common.utils import ud_scores from stanza.resources.common import download, DEFAULT_MODEL_DIR, UnknownLanguageError @@ -303,14 +303,15 @@ def find_charlm_file(direction, language, charlm, model_dir=DEFAULT_MODEL_DIR): raise FileNotFoundError(f"Cannot find {direction} charlm in either {saved_path} or {resource_path} Attempted downloading {charlm} but that did not work") -def build_charlm_args(language, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR): +def build_charlm_args(language, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR, use_backward_model=True): """ If specified, return forward and backward charlm args """ if charlm: try: forward = find_charlm_file('forward', language, charlm, model_dir=model_dir) - backward = find_charlm_file('backward', language, charlm, model_dir=model_dir) + if use_backward_model: + backward = find_charlm_file('backward', language, charlm, model_dir=model_dir) except FileNotFoundError as e: # if we couldn't find sd_isra when training an SD model, # for 
example, but isra exists, we try to download the @@ -319,15 +320,17 @@ def build_charlm_args(language, charlm, base_args=True, model_dir=DEFAULT_MODEL_ short_charlm = charlm[len(language)+1:] try: forward = find_charlm_file('forward', language, short_charlm, model_dir=model_dir) - backward = find_charlm_file('backward', language, short_charlm, model_dir=model_dir) + if use_backward_model: + backward = find_charlm_file('backward', language, short_charlm, model_dir=model_dir) except FileNotFoundError as e2: raise FileNotFoundError("Tried to find charlm %s, which doesn't exist. Also tried %s, but didn't find that either" % (charlm, short_charlm)) from e logger.warning("Was asked to find charlm %s, which does not exist. Did find %s though", charlm, short_charlm) else: raise - char_args = ['--charlm_forward_file', forward, - '--charlm_backward_file', backward] + char_args = ['--charlm_forward_file', forward] + if use_backward_model: + char_args += ['--charlm_backward_file', backward] if not base_args: return char_args return ['--charlm', @@ -377,6 +380,13 @@ def choose_lemma_charlm(short_language, dataset, charlm): """ return choose_charlm(short_language, dataset, charlm, default_charlms, lemma_charlms) +def choose_tokenizer_charlm(short_language, dataset, charlm): + """ + charlm == "default" means the default charlm for this dataset or language + charlm == None is no charlm + """ + return choose_charlm(short_language, dataset, charlm, default_charlms, tokenizer_charlms) + def choose_transformer(short_language, command_args, extra_args, warn=True, layers=False): """ Choose a transformer using the default options for this language @@ -406,3 +416,8 @@ def build_depparse_charlm_args(short_language, dataset, charlm, base_args=True, charlm = choose_depparse_charlm(short_language, dataset, charlm) charlm_args = build_charlm_args(short_language, charlm, base_args, model_dir) return charlm_args + +def build_tokenizer_charlm_args(short_language, dataset, charlm, base_args=True, model_dir=DEFAULT_MODEL_DIR): + charlm = choose_tokenizer_charlm(short_language, dataset, charlm) + charlm_args = build_charlm_args(short_language, charlm, base_args, model_dir, use_backward_model=False) + return charlm_args diff --git a/stanza/utils/training/run_tokenizer.py b/stanza/utils/training/run_tokenizer.py index 7c4245f61..c48489f57 100644 --- a/stanza/utils/training/run_tokenizer.py +++ b/stanza/utils/training/run_tokenizer.py @@ -25,10 +25,13 @@ from stanza.models import tokenizer from stanza.utils.avg_sent_len import avg_sent_len from stanza.utils.training import common -from stanza.utils.training.common import Mode +from stanza.utils.training.common import Mode, add_charlm_args, build_tokenizer_charlm_args logger = logging.getLogger('stanza') +def add_tokenizer_args(parser): + add_charlm_args(parser) + def uses_dictionary(short_language): """ Some of the languages (as shown here) have external dictionaries @@ -44,7 +47,7 @@ def run_treebank(mode, paths, treebank, short_name, temp_output_file, command_args, extra_args): tokenize_dir = paths["TOKENIZE_DATA_DIR"] - short_language = short_name.split("_")[0] + short_language, dataset = short_name.split("_", 1) label_type = "--label_file" label_file = f"{tokenize_dir}/{short_name}-ud-train.toklabels" dev_type = "--txt_file" @@ -70,6 +73,8 @@ def run_treebank(mode, paths, treebank, short_name, dev_pred = temp_output_file if temp_output_file else f"{tokenize_dir}/{short_name}.dev.pred.conllu" test_pred = temp_output_file if temp_output_file else 
f"{tokenize_dir}/{short_name}.test.pred.conllu" + charlm_args = build_tokenizer_charlm_args(short_language, dataset, command_args.charlm) + if mode == Mode.TRAIN: seqlen = str(math.ceil(avg_sent_len(label_file) * 3 / 100) * 100) train_args = ([label_type, label_file, train_type, train_file, "--lang", short_language, @@ -118,7 +123,7 @@ def run_treebank(mode, paths, treebank, short_name, def main(): - common.main(run_treebank, "tokenize", "tokenizer", sub_argparse=tokenizer.build_argparse()) + common.main(run_treebank, "tokenize", "tokenizer", add_tokenizer_args, sub_argparse=tokenizer.build_argparse()) if __name__ == "__main__": main() From 638757aff1093407e3b174ac1047b9e666ef8bad Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 11 Sep 2023 20:54:56 -0700 Subject: [PATCH 3/8] Fancy save name for the tokenizer --- stanza/models/tokenizer.py | 12 +++++----- stanza/utils/training/run_tokenizer.py | 31 +++++++++++++++++++++----- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/stanza/models/tokenizer.py b/stanza/models/tokenizer.py index 223bfd462..d1dec7008 100644 --- a/stanza/models/tokenizer.py +++ b/stanza/models/tokenizer.py @@ -82,7 +82,7 @@ def build_argparse(): parser.add_argument('--shuffle_steps', type=int, default=100, help="Step interval to shuffle each paragraph in the generator") parser.add_argument('--eval_steps', type=int, default=200, help="Step interval to evaluate the model on the dev set for early stopping") parser.add_argument('--max_steps_before_stop', type=int, default=5000, help='Early terminates after this many steps if the dev scores are not improving') - parser.add_argument('--save_name', type=str, default=None, help="File name to save the model") + parser.add_argument('--save_name', type=str, default="{shorthand}_{embedding}_tokenizer.pt", help="File name to save the model") parser.add_argument('--load_name', type=str, default=None, help="File name to load a saved model") parser.add_argument('--save_dir', type=str, default='saved_models/tokenize', help="Directory to save models in") utils.add_device_args(parser) @@ -110,11 +110,13 @@ def parse_args(args=None): return args def model_file_name(args): - if args['save_name'] is not None: - save_name = args['save_name'] - else: - save_name = args['shorthand'] + "_tokenizer.pt" + embedding = "nocharlm" + if args['charlm'] and args['charlm_forward_file']: + embedding = "charlm" + save_name = args['save_name'].format(shorthand=args['shorthand'], + embedding=embedding) + logger.info("Saving to: %s", save_name) if not os.path.exists(os.path.join(args['save_dir'], save_name)) and os.path.exists(save_name): return save_name return os.path.join(args['save_dir'], save_name) diff --git a/stanza/utils/training/run_tokenizer.py b/stanza/utils/training/run_tokenizer.py index c48489f57..3197dd5b2 100644 --- a/stanza/utils/training/run_tokenizer.py +++ b/stanza/utils/training/run_tokenizer.py @@ -32,6 +32,27 @@ def add_tokenizer_args(parser): add_charlm_args(parser) + +def build_model_filename(paths, short_name, command_args, extra_args): + short_language, dataset = short_name.split("_", 1) + + # TODO: can avoid downloading the charlm at this point, since we + # might not even be training + charlm_args = build_tokenizer_charlm_args(short_language, dataset, command_args.charlm) + + train_args = ["--shorthand", short_name, + "--mode", "train"] + train_args = train_args + charlm_args + extra_args + if command_args.save_name is not None: + train_args.extend(["--save_name", command_args.save_name]) + if 
command_args.save_dir is not None: + train_args.extend(["--save_dir", command_args.save_dir]) + args = tokenizer.parse_args(train_args) + save_name = tokenizer.model_file_name(args) + return save_name + + + def uses_dictionary(short_language): """ Some of the languages (as shown here) have external dictionaries @@ -83,14 +104,14 @@ def run_treebank(mode, paths, treebank, short_name, ["--dev_conll_gold", dev_gold, "--conll_file", dev_pred, "--shorthand", short_name]) if uses_dictionary(short_language): train_args = train_args + ["--use_dictionary"] - train_args = train_args + extra_args + train_args = train_args + charlm_args + extra_args logger.info("Running train step with args: {}".format(train_args)) tokenizer.main(train_args) if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: dev_args = ["--mode", "predict", dev_type, dev_file, "--lang", short_language, "--conll_file", dev_pred, "--shorthand", short_name, "--mwt_json_file", dev_mwt] - dev_args = dev_args + extra_args + dev_args = dev_args + charlm_args + extra_args logger.info("Running dev step with args: {}".format(dev_args)) tokenizer.main(dev_args) @@ -103,7 +124,7 @@ def run_treebank(mode, paths, treebank, short_name, if mode == Mode.SCORE_TEST or mode == Mode.TRAIN: test_args = ["--mode", "predict", test_type, test_file, "--lang", short_language, "--conll_file", test_pred, "--shorthand", short_name, "--mwt_json_file", test_mwt] - test_args = test_args + extra_args + test_args = test_args + charlm_args + extra_args logger.info("Running test step with args: {}".format(test_args)) tokenizer.main(test_args) @@ -123,7 +144,7 @@ def run_treebank(mode, paths, treebank, short_name, def main(): - common.main(run_treebank, "tokenize", "tokenizer", add_tokenizer_args, sub_argparse=tokenizer.build_argparse()) - + common.main(run_treebank, "tokenize", "tokenizer", add_tokenizer_args, sub_argparse=tokenizer.build_argparse(), build_model_filename=build_model_filename) + if __name__ == "__main__": main() From d4bbebfd48297009c73331dfead5193652f73230 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 11 Sep 2025 13:26:12 -0700 Subject: [PATCH 4/8] Connect the tokenizer for default_accurate to a charlm version when rebuilding resources.json, if such a thing exists --- stanza/models/tokenization/trainer.py | 1 - stanza/resources/prepare_resources.py | 55 +++++++++++++++++++++++---- 2 files changed, 47 insertions(+), 9 deletions(-) diff --git a/stanza/models/tokenization/trainer.py b/stanza/models/tokenization/trainer.py index 81a34b660..a88c9ada4 100644 --- a/stanza/models/tokenization/trainer.py +++ b/stanza/models/tokenization/trainer.py @@ -16,7 +16,6 @@ class Trainer(BaseTrainer): def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None, foundation_cache=None): # TODO: make a test of the training w/ and w/o charlm - # TODO: build the resources with the forward charlm if model_file is not None: # load everything from file self.load(model_file, args, foundation_cache) diff --git a/stanza/resources/prepare_resources.py b/stanza/resources/prepare_resources.py index a8225e025..7d396150f 100644 --- a/stanza/resources/prepare_resources.py +++ b/stanza/resources/prepare_resources.py @@ -102,7 +102,7 @@ def split_model_name(model): lang, package = model.split('_', 1) return lang, package, processor -def split_package(package): +def split_package(package, default_use_charlm=True): if package.endswith("_finetuned"): package = package[:-10] @@ -124,7 +124,7 @@ def split_package(package): # guess it was a model 
which wasn't built with the new naming convention of putting the pretrain type at the end # assume WV and charlm... if the language / package doesn't allow for one, that should be caught later - return package, True, True + return package, True, default_use_charlm def get_pretrain_package(lang, package, model_pretrains, default_pretrains): package, uses_pretrain, _ = split_package(package) @@ -138,8 +138,8 @@ def get_pretrain_package(lang, package, model_pretrains, default_pretrains): raise RuntimeError("pretrain not specified for lang %s package %s" % (lang, package)) -def get_charlm_package(lang, package, model_charlms, default_charlms): - package, _, uses_charlm = split_package(package) +def get_charlm_package(lang, package, model_charlms, default_charlms, default_use_charlm=True): + package, _, uses_charlm = split_package(package, default_use_charlm) if not uses_charlm: return None @@ -210,6 +210,16 @@ def get_lemma_dependencies(lang, package): return dependencies +def get_tokenizer_charlm_package(lang, package): + return get_charlm_package(lang, package, tokenizer_charlms, default_charlms, default_use_charlm=False) + +def get_tokenizer_dependencies(lang, package): + dependencies = [] + charlm_package = get_tokenizer_charlm_package(lang, package) + if charlm_package is not None: + dependencies.append({'model': 'forward_charlm', 'package': charlm_package}) + return dependencies + def get_depparse_charlm_package(lang, package): return get_charlm_package(lang, package, depparse_charlms, default_charlms) @@ -285,6 +295,8 @@ def get_dependencies(processor, lang, package): return get_sentiment_dependencies(lang, package) elif processor == 'constituency': return get_con_dependencies(lang, package) + elif processor == 'tokenize': + return get_tokenizer_dependencies(lang, package) return {} def process_dirs(args): @@ -411,9 +423,14 @@ def get_default_processors(resources, lang): if lang in default_tokenizer: default_processors['tokenize'] = default_tokenizer[lang] else: - default_processors['tokenize'] = default_package - - if 'mwt' in resources[lang] and default_processors['tokenize'] in resources[lang]['mwt']: + tokenize_package = default_package + if tokenize_package not in resources[lang]['tokenize']: + tokenize_package = tokenize_package + "_nocharlm" + if tokenize_package not in resources[lang]['tokenize']: + raise AssertionError("Can't find a tokenizer package for %s! 
Tried %s and %s" % (lang, default_package, tokenize_package)) + default_processors['tokenize'] = tokenize_package + + if 'mwt' in resources[lang] and default_package in resources[lang]['mwt']: # if this doesn't happen, we just skip MWT default_processors['mwt'] = default_package @@ -484,6 +501,15 @@ def get_default_accurate(resources, lang): """ default_processors = get_default_processors(resources, lang) + tokenizer_model = default_processors['tokenize'] + if tokenizer_model.endswith('_nocharlm'): + tokenizer_model = tokenizer_model.replace('_nocharlm', '_charlm') + elif 'charlm' not in tokenizer_model: + tokenizer_model = tokenizer_model + '_charlm' + if tokenizer_model.endswith('_charlm') and tokenizer_model in resources[lang]['tokenize']: + default_processors['tokenize'] = tokenizer_model + print("TOKENIZE found a charlm version %s for %s default_accurate" % (tokenizer_model, lang)) + if 'lemma' in default_processors and default_processors['lemma'] != 'identity': lemma_model = default_processors['lemma'] lemma_model = lemma_model.replace('_nocharlm', '_charlm') @@ -594,7 +620,20 @@ def process_packages(args): # We then create a package in the packages dict for each of those treebanks if 'tokenize' in resources[lang]: for package in resources[lang]['tokenize']: - processors = {"tokenize": package} + package, _, _ = split_package(package) + if package in resources[lang][PACKAGES]: + # can happen in the case of a _nocharlm and _charlm version of the tokenizer + continue + + processors = {} + # TODO: when we rebuild all the models, make all the tokenizers say _nocharlm + if package in resources[lang]['tokenize']: + processors["tokenize"] = package + elif package + "_nocharlm" in resources[lang]['tokenize']: + processors["tokenize"] = package + "_nocharlm" + else: + raise AssertionError("Should have found a tokenizer for lang %s package %s" % (lang, package)) + if "mwt" in resources[lang] and package in resources[lang]["mwt"]: processors["mwt"] = package From 0d4193cbf709de90084ebaf68c5761656dfaaf94 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 15 Sep 2025 17:31:14 -0700 Subject: [PATCH 5/8] When converting the HE coref treebank, skip any tokens which don't have the necessary endpoints --- .../datasets/coref/convert_hebrew_iahlt.py | 55 +++++++++++-------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/stanza/utils/datasets/coref/convert_hebrew_iahlt.py b/stanza/utils/datasets/coref/convert_hebrew_iahlt.py index 4ab8db81e..f0a8777b5 100644 --- a/stanza/utils/datasets/coref/convert_hebrew_iahlt.py +++ b/stanza/utils/datasets/coref/convert_hebrew_iahlt.py @@ -37,11 +37,15 @@ # TODO: binary search for speed? 
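# Note on the two helpers below: both do a linear scan over the tokenized
# document to map a character offset from the coref annotation onto a
# (sentence index, word index) pair.  With this patch, search_mention_start
# returns (None, None) when a word has no character offsets (for example, a
# word inside an MWT that could not be aligned back to the raw text), and the
# caller counts that mention as a tokenization error instead of crashing.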
def search_mention_start(doc, mention_start): for sent_idx, sentence in enumerate(doc.sentences): - if mention_start < doc.sentences[sent_idx].words[-1].end_char: + if mention_start < doc.sentences[sent_idx].tokens[-1].end_char: break else: raise ValueError for word_idx, word in enumerate(sentence.words): + if word.end_char is None: + print("Found weirdness on sentence:\n|%s|" % sentence.text) + print(word.parent) + return None, None if mention_start < word.end_char: break else: @@ -50,7 +54,7 @@ def search_mention_start(doc, mention_start): def search_mention_end(doc, mention_end): for sent_idx, sentence in enumerate(doc.sentences): - if sent_idx + 1 == len(doc.sentences) or mention_end < doc.sentences[sent_idx+1].words[0].start_char: + if sent_idx + 1 == len(doc.sentences) or mention_end < doc.sentences[sent_idx+1].tokens[0].start_char: break for word_idx, word in enumerate(sentence.words): if word_idx + 1 == len(sentence.words) or mention_end < sentence.words[word_idx+1].start_char: @@ -60,6 +64,7 @@ def search_mention_end(doc, mention_end): def extract_doc(tokenizer, lines): # 16, 1, 5 for the train, dev, test sets broken = 0 + tok_error = 0 singletons = 0 one_words = 0 processed_docs = [] @@ -75,30 +80,34 @@ def extract_doc(tokenizer, lines): mention_start = mention[0] mention_end = mention[1] start_sent, start_word = search_mention_start(doc, mention_start) + if start_sent is None or start_word is None: + tok_error += 1 + continue end_sent, end_word = search_mention_end(doc, mention_end) assert end_sent >= start_sent if start_sent != end_sent: broken += 1 - else: - assert end_word >= start_word - if end_word == start_word: - one_words += 1 - found_mentions.append((start_sent, start_word, end_word)) - - #if cluster_idx == 0 and line_idx == 0: - # expanded_start = max(0, mention_start - 10) - # expanded_end = min(len(text), mention_end + 10) - # print("EXTRACTING MENTION: %d %d" % (mention[0], mention[1])) - # print(" context: |%s|" % text[expanded_start:expanded_end]) - # print(" mention[0]:mention[1]: |%s|" % text[mention[0]:mention[1]]) - # print(" search text: |%s|" % text[mention_start:mention_end]) - # extracted_words = doc.sentences[start_sent].words[start_word:end_word+1] - # extracted_text = " ".join([x.text for x in extracted_words]) - # print(" extracted words: |%s|" % extracted_text) - # print(" endpoints: %d %d" % (mention_start, mention_end)) - # print(" number of extracted words: %d" % len(extracted_words)) - # print(" first word endpoints: %d %d" % (extracted_words[0].start_char, extracted_words[0].end_char)) - # print(" last word endpoints: %d %d" % (extracted_words[-1].start_char, extracted_words[-1].end_char)) + continue + + assert end_word >= start_word + if end_word == start_word: + one_words += 1 + found_mentions.append((start_sent, start_word, end_word)) + + #if cluster_idx == 0 and line_idx == 0: + # expanded_start = max(0, mention_start - 10) + # expanded_end = min(len(text), mention_end + 10) + # print("EXTRACTING MENTION: %d %d" % (mention[0], mention[1])) + # print(" context: |%s|" % text[expanded_start:expanded_end]) + # print(" mention[0]:mention[1]: |%s|" % text[mention[0]:mention[1]]) + # print(" search text: |%s|" % text[mention_start:mention_end]) + # extracted_words = doc.sentences[start_sent].words[start_word:end_word+1] + # extracted_text = " ".join([x.text for x in extracted_words]) + # print(" extracted words: |%s|" % extracted_text) + # print(" endpoints: %d %d" % (mention_start, mention_end)) + # print(" number of extracted words: %d" % 
len(extracted_words)) + # print(" first word endpoints: %d %d" % (extracted_words[0].start_char, extracted_words[0].end_char)) + # print(" last word endpoints: %d %d" % (extracted_words[-1].start_char, extracted_words[-1].end_char)) if len(found_mentions) == 0: continue elif len(found_mentions) == 1: @@ -118,7 +127,7 @@ def extract_doc(tokenizer, lines): for sent_idx, start_word, end_word in all_clusters[cluster_idx]: coref_spans[sent_idx].append((cluster_idx, start_word, end_word)) processed_docs.append(CorefDoc(doc_id, sentences, coref_spans)) - print("Found %d broken across two sentences, %d singleton mentions, %d one_word mentions" % (broken, singletons, one_words)) + print("Found %d broken across two sentences, %d tok errors, %d singleton mentions, %d one_word mentions" % (broken, tok_error, singletons, one_words)) return processed_docs def read_doc(tokenizer, filename): From 6ff29fa95fc74b676fc22c6020234f0cdb395c42 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 15 Sep 2025 20:06:42 -0700 Subject: [PATCH 6/8] Apparently spaces can sometimes happen in MWT, depending on the training of the tokenizer... need to account for that when breaking down MWT --- stanza/models/common/doc.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/stanza/models/common/doc.py b/stanza/models/common/doc.py index 22c3fa7bf..9e0a0746a 100644 --- a/stanza/models/common/doc.py +++ b/stanza/models/common/doc.py @@ -400,13 +400,16 @@ def set_mwt_expansions(self, expansions, word.sent = sentence word.parent = token sentence.words.append(word) - if token.start_char is not None and token.end_char is not None and "".join(word.text for word in token.words) == token.text: - start_char = token.start_char - for word in token.words: - end_char = start_char + len(word.text) - word.start_char = start_char - word.end_char = end_char - start_char = end_char + if len(token.words) == 1: + word.start_char = token.start_char + word.end_char = token.end_char + elif token.start_char is not None and token.end_char is not None: + search_string = "^%s$" % ("\\s*".join("(%s)" % re.escape(word.text) for word in token.words)) + match = re.compile(search_string).match(token.text) + if match: + for word_idx, word in enumerate(token.words): + word.start_char = match.start(word_idx+1) + token.start_char + word.end_char = match.end(word_idx+1) + token.start_char if fake_dependencies: sentence.build_fake_dependencies() From 81ce2e11d93637efd24d8986e15354d4a133ed1b Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 17 Sep 2025 14:57:15 -0700 Subject: [PATCH 7/8] Add some doc to the load_mwt_dict function --- stanza/models/tokenization/utils.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/stanza/models/tokenization/utils.py b/stanza/models/tokenization/utils.py index 548dd6087..f3767c975 100644 --- a/stanza/models/tokenization/utils.py +++ b/stanza/models/tokenization/utils.py @@ -143,20 +143,25 @@ def load_lexicon(args): def load_mwt_dict(filename): - if filename is not None: - with open(filename, 'r') as f: - mwt_dict0 = json.load(f) + """ + Returns a dict from an MWT to its most common expansion and count. - mwt_dict = dict() - for item in mwt_dict0: - (key, expansion), count = item + Other less common expansions are discarded. 
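+    A hypothetical entry in the JSON file looks like
+    [["cannot", ["can", "not"]], 40], which becomes the dict entry
+    "cannot" -> (["can", "not"], 40); only the highest-count expansion
+    seen for each key is kept.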
+ """ + if filename is None: + return None - if key not in mwt_dict or mwt_dict[key][1] < count: - mwt_dict[key] = (expansion, count) + with open(filename, 'r') as f: + mwt_dict0 = json.load(f) - return mwt_dict - else: - return + mwt_dict = dict() + for item in mwt_dict0: + (key, expansion), count = item + + if key not in mwt_dict or mwt_dict[key][1] < count: + mwt_dict[key] = (expansion, count) + + return mwt_dict def process_sentence(sentence, mwt_dict=None): sent = [] From 52cea783431c85af68227c0f00dc4022a36ea7f4 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 17 Sep 2025 21:22:27 -0700 Subject: [PATCH 8/8] Add an augmentation to the tokenizer which splits MWTs while training and treats them as separate words, teaching the tokenizer to not treat 'can not' as a single token with the space This wound up being a problem in Hebrew when using the new charlm-attached tokenizer, although on further reflection there's no reason it couldn't happen in any language or with non-pretrained-charlm tokenizers as well --- stanza/models/tokenization/data.py | 83 +++++++++++++++++++++++++++++- stanza/models/tokenizer.py | 10 ++-- 2 files changed, 88 insertions(+), 5 deletions(-) diff --git a/stanza/models/tokenization/data.py b/stanza/models/tokenization/data.py index 210890be0..2d2c8d596 100644 --- a/stanza/models/tokenization/data.py +++ b/stanza/models/tokenization/data.py @@ -236,11 +236,35 @@ def build_move_punct_set(data, move_back_prob): continue return move_punct +def build_known_mwt(data, mwt_expansions): + known_mwts = set() + for chunk in data: + for idx, unit in enumerate(chunk): + if unit[1] != 3: + continue + # found an MWT + prev_idx = idx - 1 + while prev_idx >= 0 and chunk[prev_idx][1] == 0: + prev_idx -= 1 + prev_idx += 1 + while chunk[prev_idx][0].isspace(): + prev_idx += 1 + if prev_idx == idx: + continue + mwt = "".join(x[0] for x in chunk[prev_idx:idx+1]) + if mwt not in mwt_expansions: + continue + if len(mwt_expansions[mwt]) > 2: + # TODO: could split 3 word tokens as well + continue + known_mwts.add(mwt) + return known_mwts + class DataLoader(TokenizationDataset): """ This is the training version of the dataset. 
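+
+    At training time it can also apply random augmentations controlled by the
+    corresponding probabilities in args: dropping the last character of a block,
+    moving sentence-final punctuation, moving punctuation back next to the
+    previous word, and splitting a known MWT into its component words
+    (--split_mwt_prob).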
""" - def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None): + def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None, mwt_expansions=None): super().__init__(args, input_files, input_text, vocab, evaluation, dictionary) self.vocab = vocab if vocab is not None else self.init_vocab() @@ -262,6 +286,15 @@ def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=No else: logger.debug('Based on the training data, no punct are eligible to be rearranged with extra whitespace') + split_mwt_prob = args.get('split_mwt_prob', 0.0) + if split_mwt_prob > 0.0 and not evaluation: + self.mwt_expansions = mwt_expansions + self.known_mwt = build_known_mwt(self.data, mwt_expansions) + if len(self.known_mwt) > 0: + logger.debug('Based on the training data, there are %d MWT which might be split at training time', len(self.known_mwt)) + else: + logger.debug('Based on the training data, there are NO MWT to split at training time') + def __len__(self): return len(self.sentence_ids) @@ -300,6 +333,45 @@ def move_last_char(self, sentence): return encoded return None + def split_mwt(self, sentence): + if len(sentence[3]) <= 1 or len(sentence[3]) >= self.args['max_seqlen']: + return None + + # if we find a token in the sentence which ends with label 3, + # eg it is an MWT, + # with some probability we split it into two tokens + # and treat the split tokens as both label 1 instead of 3 + # in this manner, we teach the tokenizer not to treat the + # entire sequence of characters with added spaces as an MWT, + # which weirdly can happen in some corner cases + + mwt_ends = [idx for idx, label in enumerate(sentence[1]) if label == 3] + if len(mwt_ends) == 0: + return None + random_end = random.randint(0, len(mwt_ends)-1) + mwt_end = mwt_ends[random_end] + mwt_start = mwt_end - 1 + while mwt_start >= 0 and sentence[1][mwt_start] == 0: + mwt_start -= 1 + mwt_start += 1 + while sentence[3][mwt_start].isspace(): + mwt_start += 1 + if mwt_start == mwt_end: + return None + mwt = "".join(x for x in sentence[3][mwt_start:mwt_end+1]) + if mwt not in self.mwt_expansions: + return None + + all_units = [(x, int(y)) for x, y in zip(sentence[3], sentence[1])] + w0_units = [(x, 0) for x in self.mwt_expansions[mwt][0]] + w0_units[-1] = (w0_units[-1][0], 1) + w1_units = [(x, 0) for x in self.mwt_expansions[mwt][1]] + w1_units[-1] = (w1_units[-1][0], 1) + split_units = w0_units + [(' ', 0)] + w1_units + new_units = all_units[:mwt_start] + split_units + all_units[mwt_end+1:] + encoded = self.para_to_sentences(new_units) + return encoded + def move_punct_back(self, sentence): if len(sentence[3]) <= 1 or len(sentence[3]) >= self.args['max_seqlen']: return None @@ -342,6 +414,7 @@ def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']): drop_last_char = False if self.eval or (self.args.get('last_char_drop_prob', 0) == 0) else (random.random() < self.args.get('last_char_drop_prob', 0)) move_last_char_prob = 0.0 if self.eval else self.args.get('last_char_move_prob', 0.0) move_punct_back_prob = 0.0 if self.eval else self.args.get('punct_move_back_prob', 0.0) + split_mwt_prob = 0.0 if self.eval else self.args.get('split_mwt_prob', 0.0) pid, sid = id_pair if self.eval else random.choice(self.sentence_ids) sentences = [copy([x[offset:] for x in self.sentences[pid][sid]])] @@ -386,6 +459,14 @@ def strings_starting(id_pair, offset=0, 
pad_len=self.args['max_seqlen']): total_len = total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3]) sentences[sentence_idx] = new_sentence[0] + if split_mwt_prob > 0.0: + for sentence_idx, sentence in enumerate(sentences): + if random.random() < split_mwt_prob: + new_sentence = self.split_mwt(sentence) + if new_sentence is not None: + total_len = total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3]) + sentences[sentence_idx] = new_sentence[0] + if drop_sents and len(sentences) > 1: if total_len > self.args['max_seqlen']: sentences = sentences[:-1] diff --git a/stanza/models/tokenizer.py b/stanza/models/tokenizer.py index d1dec7008..73fcdd6af 100644 --- a/stanza/models/tokenizer.py +++ b/stanza/models/tokenizer.py @@ -73,6 +73,7 @@ def build_argparse(): parser.add_argument('--last_char_drop_prob', type=float, default=0.2, help="Probability to drop the last char of a block of text during training, uniformly at random. Idea is to fake a document ending w/o sentence final punctuation, hopefully to avoid the tokenizer learning to always tokenize the last character as a period") parser.add_argument('--last_char_move_prob', type=float, default=0.02, help="Probability to move the sentence final punctuation of a sentence during training, uniformly at random. Idea is to teach the tokenizer that a space separated sentence final punct still ends the sentence") parser.add_argument('--punct_move_back_prob', type=float, default=0.02, help="Probability to move a comma in the sentence one over, removing the previous space, during training. Idea is to teach the tokenizer that commas can appear next to words even in languages where the dataset doesn't allow it, such as Vietnamese") + parser.add_argument('--split_mwt_prob', type=float, default=0.01, help="Probably to split an MWT into its component pieces and turn it into separate words") parser.add_argument('--weight_decay', type=float, default=0.0, help="Weight decay") parser.add_argument('--max_seqlen', type=int, default=100, help="Maximum sequence length to consider at a time") parser.add_argument('--batch_size', type=int, default=32, help="Batch size to use") @@ -152,12 +153,13 @@ def train(args): dictionary=None mwt_dict = load_mwt_dict(args['mwt_json_file']) + mwt_expansions = {x: y[0] for x, y in mwt_dict.items()} train_input_files = { - 'txt': args['txt_file'], - 'label': args['label_file'] - } - train_batches = DataLoader(args, input_files=train_input_files, dictionary=dictionary) + 'txt': args['txt_file'], + 'label': args['label_file'] + } + train_batches = DataLoader(args, input_files=train_input_files, dictionary=dictionary, mwt_expansions=mwt_expansions) vocab = train_batches.vocab args['vocab_size'] = len(vocab)
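
A rough usage sketch tying the patches together (the treebank name, language
code, and sample text are illustrative assumptions, not taken from the patches
above): run_tokenizer picks up a forward charlm for training, and a
default_accurate pipeline then loads the charlm-flavored tokenizer, with MWT
pieces keeping their character offsets.

    # hypothetical training invocation; "--charlm default" asks
    # build_tokenizer_charlm_args to look up the default forward charlm
    # for the language:
    #   python -m stanza.utils.training.run_tokenizer UD_Hebrew-IAHLTwiki --charlm default

    import stanza

    # default_accurate should now resolve to the _charlm tokenizer package when one exists
    nlp = stanza.Pipeline("he", processors="tokenize,mwt", package="default_accurate")
    doc = nlp("הילד בבית")  # "בבית" is a multi-word token in Hebrew
    for sentence in doc.sentences:
        for word in sentence.words:
            # with the doc.py change, MWT pieces get start_char/end_char whenever
            # their texts can be matched back into the token text
            print(word.text, word.start_char, word.end_char)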