4 changes: 2 additions & 2 deletions stanza/_version.py
@@ -1,4 +1,4 @@
""" Single source of truth for version number """

__version__ = "1.10.1"
__resources_version__ = '1.10.0'
__version__ = "1.11.0"
__resources_version__ = '1.11.0'
5 changes: 3 additions & 2 deletions stanza/models/common/char_model.py
@@ -283,8 +283,9 @@ def __init__(self, charlms):
super().__init__()
self.charlms = charlms

def forward(self, words):
words = [CHARLM_START + x + CHARLM_END for x in words]
def forward(self, words, wrap=True):
if wrap:
words = [CHARLM_START + x + CHARLM_END for x in words]
padded_reps = []
for charlm in self.charlms:
rep = charlm.per_char_representation(words)
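For reference, the new wrap flag only controls whether each input string is surrounded by the charlm's start/end markers before per-character representations are built; the tokenizer (see model.py below) passes its raw character windows with wrap=False. A string-level sketch, with made-up marker values standing in for CHARLM_START and CHARLM_END:

# Sketch only: the marker values below are fake, standing in for whatever
# CHARLM_START / CHARLM_END the character language model was trained with.
CHARLM_START, CHARLM_END = "<s>", "</s>"

def maybe_wrap(words, wrap=True):
    if wrap:
        words = [CHARLM_START + w + CHARLM_END for w in words]
    return words

print(maybe_wrap(["hello", "world"]))           # wrapped words, the old behavior
print(maybe_wrap(["hello world"], wrap=False))  # raw text passed through unchanged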
17 changes: 10 additions & 7 deletions stanza/models/common/doc.py
@@ -400,13 +400,16 @@ def set_mwt_expansions(self, expansions,
word.sent = sentence
word.parent = token
sentence.words.append(word)
if token.start_char is not None and token.end_char is not None and "".join(word.text for word in token.words) == token.text:
start_char = token.start_char
for word in token.words:
end_char = start_char + len(word.text)
word.start_char = start_char
word.end_char = end_char
start_char = end_char
if len(token.words) == 1:
word.start_char = token.start_char
word.end_char = token.end_char
elif token.start_char is not None and token.end_char is not None:
search_string = "^%s$" % ("\\s*".join("(%s)" % re.escape(word.text) for word in token.words))
match = re.compile(search_string).match(token.text)
if match:
for word_idx, word in enumerate(token.words):
word.start_char = match.start(word_idx+1) + token.start_char
word.end_char = match.end(word_idx+1) + token.start_char

if fake_dependencies:
sentence.build_fake_dependencies()
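The new branch aligns each MWT word back to character offsets by matching the word texts against the token text with optional whitespace between them, rather than requiring the concatenated word texts to equal the token text exactly. A standalone sketch of that alignment, using an illustrative token and offsets:

import re

# Illustrative token: an MWT whose surface form contains a space.
token_text = "de el"
token_start_char = 100
word_texts = ["de", "el"]

# Same pattern construction as in set_mwt_expansions: each word is a capture
# group, with optional whitespace allowed between consecutive words.
search_string = "^%s$" % ("\\s*".join("(%s)" % re.escape(t) for t in word_texts))
match = re.compile(search_string).match(token_text)
if match:
    for idx, text in enumerate(word_texts, start=1):
        start_char = match.start(idx) + token_start_char
        end_char = match.end(idx) + token_start_char
        print(text, start_char, end_char)   # de 100 102, then el 103 105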
83 changes: 82 additions & 1 deletion stanza/models/tokenization/data.py
@@ -236,11 +236,35 @@ def build_move_punct_set(data, move_back_prob):
continue
return move_punct

def build_known_mwt(data, mwt_expansions):
known_mwts = set()
for chunk in data:
for idx, unit in enumerate(chunk):
if unit[1] != 3:
continue
# found an MWT
prev_idx = idx - 1
while prev_idx >= 0 and chunk[prev_idx][1] == 0:
prev_idx -= 1
prev_idx += 1
while chunk[prev_idx][0].isspace():
prev_idx += 1
if prev_idx == idx:
continue
mwt = "".join(x[0] for x in chunk[prev_idx:idx+1])
if mwt not in mwt_expansions:
continue
if len(mwt_expansions[mwt]) > 2:
# TODO: could split 3 word tokens as well
continue
known_mwts.add(mwt)
return known_mwts

class DataLoader(TokenizationDataset):
"""
This is the training version of the dataset.
"""
def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None):
def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=None, vocab=None, evaluation=False, dictionary=None, mwt_expansions=None):
super().__init__(args, input_files, input_text, vocab, evaluation, dictionary)

self.vocab = vocab if vocab is not None else self.init_vocab()
@@ -262,6 +286,15 @@ def __init__(self, args, input_files={'txt': None, 'label': None}, input_text=No
else:
logger.debug('Based on the training data, no punct are eligible to be rearranged with extra whitespace')

split_mwt_prob = args.get('split_mwt_prob', 0.0)
if split_mwt_prob > 0.0 and not evaluation:
self.mwt_expansions = mwt_expansions
self.known_mwt = build_known_mwt(self.data, mwt_expansions)
if len(self.known_mwt) > 0:
logger.debug('Based on the training data, there are %d MWT which might be split at training time', len(self.known_mwt))
else:
logger.debug('Based on the training data, there are NO MWT to split at training time')

def __len__(self):
return len(self.sentence_ids)

@@ -300,6 +333,45 @@ def move_last_char(self, sentence):
return encoded
return None

def split_mwt(self, sentence):
if len(sentence[3]) <= 1 or len(sentence[3]) >= self.args['max_seqlen']:
return None

# if we find a token in the sentence which ends with label 3,
# e.g. it is an MWT,
# with some probability we split it into two tokens
# and treat the split tokens as both label 1 instead of 3
# in this manner, we teach the tokenizer not to treat the
# entire sequence of characters with added spaces as an MWT,
# which weirdly can happen in some corner cases

mwt_ends = [idx for idx, label in enumerate(sentence[1]) if label == 3]
if len(mwt_ends) == 0:
return None
random_end = random.randint(0, len(mwt_ends)-1)
mwt_end = mwt_ends[random_end]
mwt_start = mwt_end - 1
while mwt_start >= 0 and sentence[1][mwt_start] == 0:
mwt_start -= 1
mwt_start += 1
while sentence[3][mwt_start].isspace():
mwt_start += 1
if mwt_start == mwt_end:
return None
mwt = "".join(x for x in sentence[3][mwt_start:mwt_end+1])
if mwt not in self.mwt_expansions:
return None

all_units = [(x, int(y)) for x, y in zip(sentence[3], sentence[1])]
w0_units = [(x, 0) for x in self.mwt_expansions[mwt][0]]
w0_units[-1] = (w0_units[-1][0], 1)
w1_units = [(x, 0) for x in self.mwt_expansions[mwt][1]]
w1_units[-1] = (w1_units[-1][0], 1)
split_units = w0_units + [(' ', 0)] + w1_units
new_units = all_units[:mwt_start] + split_units + all_units[mwt_end+1:]
encoded = self.para_to_sentences(new_units)
return encoded

def move_punct_back(self, sentence):
if len(sentence[3]) <= 1 or len(sentence[3]) >= self.args['max_seqlen']:
return None
@@ -342,6 +414,7 @@ def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']):
drop_last_char = False if self.eval or (self.args.get('last_char_drop_prob', 0) == 0) else (random.random() < self.args.get('last_char_drop_prob', 0))
move_last_char_prob = 0.0 if self.eval else self.args.get('last_char_move_prob', 0.0)
move_punct_back_prob = 0.0 if self.eval else self.args.get('punct_move_back_prob', 0.0)
split_mwt_prob = 0.0 if self.eval else self.args.get('split_mwt_prob', 0.0)

pid, sid = id_pair if self.eval else random.choice(self.sentence_ids)
sentences = [copy([x[offset:] for x in self.sentences[pid][sid]])]
@@ -386,6 +459,14 @@ def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']):
total_len = total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3])
sentences[sentence_idx] = new_sentence[0]

if split_mwt_prob > 0.0:
for sentence_idx, sentence in enumerate(sentences):
if random.random() < split_mwt_prob:
new_sentence = self.split_mwt(sentence)
if new_sentence is not None:
total_len = total_len + len(new_sentence[0][3]) - len(sentences[sentence_idx][3])
sentences[sentence_idx] = new_sentence[0]

if drop_sents and len(sentences) > 1:
if total_len > self.args['max_seqlen']:
sentences = sentences[:-1]
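Taken together, build_known_mwt collects the MWTs seen in the training data that have a known two-word expansion, and split_mwt occasionally replaces such an MWT with its spaced-out expansion, relabeling both pieces as ordinary tokens. A self-contained toy version of that relabeling, with made-up text and expansion table:

# Per-character labels: 0 = inside a token, 1 = end of a token, 3 = end of an MWT.
# The sample text and the expansion table below are made up for illustration.
mwt_expansions = {"della": ("di", "la")}

chars  = list("della casa")
labels = [0, 0, 0, 0, 3, 0, 0, 0, 0, 1]

# Walk back from the label-3 position to find the start of the MWT.
mwt_end = labels.index(3)
mwt_start = mwt_end
while mwt_start > 0 and labels[mwt_start - 1] == 0:
    mwt_start -= 1
mwt = "".join(chars[mwt_start:mwt_end + 1])

# Replace the MWT span with its two expansion words, separated by a space,
# each ending with a plain token-end label instead of the MWT label.
w0, w1 = mwt_expansions[mwt]
split_units = ([(c, 0) for c in w0[:-1]] + [(w0[-1], 1)]
               + [(" ", 0)]
               + [(c, 0) for c in w1[:-1]] + [(w1[-1], 1)])
all_units = list(zip(chars, labels))
new_units = all_units[:mwt_start] + split_units + all_units[mwt_end + 1:]
print(new_units)   # "di la casa", with "di" and "la" labeled as separate tokens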
32 changes: 28 additions & 4 deletions stanza/models/tokenization/model.py
@@ -3,23 +3,38 @@
import torch.nn as nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence

from stanza.models.common.char_model import CharacterLanguageModelWordAdapter
from stanza.models.common.foundation_cache import load_charlm

class Tokenizer(nn.Module):
def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout):
def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout, foundation_cache=None):
super().__init__()

self.unsaved_modules = []

self.args = args
feat_dim = args['feat_dim']

self.embeddings = nn.Embedding(nchars, emb_dim, padding_idx=0)

self.rnn = nn.LSTM(emb_dim + feat_dim, hidden_dim, num_layers=self.args['rnn_layers'], bidirectional=True, batch_first=True, dropout=dropout if self.args['rnn_layers'] > 1 else 0)
self.input_dim = emb_dim + feat_dim

charmodel = None
if args is not None and args.get('charlm_forward_file', None):
charmodel_forward = load_charlm(args['charlm_forward_file'], foundation_cache=foundation_cache)
charmodels = nn.ModuleList([charmodel_forward])
charmodel = CharacterLanguageModelWordAdapter(charmodels)
self.input_dim += charmodel.hidden_dim()
self.add_unsaved_module("charmodel", charmodel)

self.rnn = nn.LSTM(self.input_dim, hidden_dim, num_layers=self.args['rnn_layers'], bidirectional=True, batch_first=True, dropout=dropout if self.args['rnn_layers'] > 1 else 0)

if self.args['conv_res'] is not None:
self.conv_res = nn.ModuleList()
self.conv_sizes = [int(x) for x in self.args['conv_res'].split(',')]

for si, size in enumerate(self.conv_sizes):
l = nn.Conv1d(emb_dim + feat_dim, hidden_dim * 2, size, padding=size//2, bias=self.args.get('hier_conv_res', False) or (si == 0))
l = nn.Conv1d(self.input_dim, hidden_dim * 2, size, padding=size//2, bias=self.args.get('hier_conv_res', False) or (si == 0))
self.conv_res.append(l)

if self.args.get('hier_conv_res', False):
@@ -42,8 +57,17 @@ def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout):

self.toknoise = nn.Dropout(self.args['tok_noise'])

def forward(self, x, feats, lengths):
def add_unsaved_module(self, name, module):
self.unsaved_modules += [name]
setattr(self, name, module)

def forward(self, x, feats, lengths, raw=None):
emb = self.embeddings(x)

if self.charmodel is not None and raw is not None:
char_emb = self.charmodel(raw, wrap=False)
emb = torch.cat([emb, char_emb], axis=2)

emb = self.dropout(emb)
feats = self.dropout_feat(feats)

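The charlm representation is concatenated with the character embedding along the feature dimension, which is why input_dim grows by charmodel.hidden_dim() and both the LSTM and the residual convolutions now take self.input_dim. A minimal shape sketch with illustrative dimensions:

import torch

# Illustrative sizes; the real emb_dim and charlm hidden size come from the config.
batch, seqlen, emb_dim, charlm_dim = 2, 5, 32, 1024

emb = torch.randn(batch, seqlen, emb_dim)           # character embeddings
char_emb = torch.randn(batch, seqlen, charlm_dim)   # charlm per-character reps
combined = torch.cat([emb, char_emb], dim=2)        # concatenate on the feature axis
print(combined.shape)   # torch.Size([2, 5, 1056]) == emb_dim + charlm hidden dim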
33 changes: 25 additions & 8 deletions stanza/models/tokenization/trainer.py
@@ -14,10 +14,11 @@
logger = logging.getLogger('stanza')

class Trainer(BaseTrainer):
def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None):
def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None, foundation_cache=None):
# TODO: make a test of the training w/ and w/o charlm
if model_file is not None:
# load everything from file
self.load(model_file)
self.load(model_file, args, foundation_cache)
else:
# build model from scratch
self.args = args
@@ -41,7 +42,7 @@ def update(self, inputs):
labels = labels.to(device)
features = features.to(device)

pred = self.model(units, features, lengths)
pred = self.model(units, features, lengths, text)

self.optimizer.zero_grad()
classes = pred.size(2)
@@ -62,13 +63,22 @@ def predict(self, inputs):
units = units.to(device)
features = features.to(device)

pred = self.model(units, features, lengths)
pred = self.model(units, features, lengths, text)

return pred.data.cpu().numpy()

def save(self, filename):
def save(self, filename, skip_modules=True):
model_state = None
if self.model is not None:
model_state = self.model.state_dict()
# skip saving modules like the pretrained charlm
if skip_modules:
skipped = [k for k in model_state.keys() if k.split('.')[0] in self.model.unsaved_modules]
for k in skipped:
del model_state[k]

params = {
'model': self.model.state_dict() if self.model is not None else None,
'model': model_state,
'vocab': self.vocab.state_dict(),
# save and load lexicon as list instead of set so
# we can use weights_only=True
@@ -81,19 +91,26 @@ def save(self, filename):
except BaseException:
logger.warning("Saving failed... continuing anyway.")

def load(self, filename):
def load(self, filename, args, foundation_cache):
try:
checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
except BaseException:
logger.error("Cannot load model from {}".format(filename))
raise
self.args = checkpoint['config']
if args is not None and args.get('charlm_forward_file', None) is not None:
if checkpoint['config'].get('charlm_forward_file') is None:
# if the saved model didn't use a charlm, we skip the charlm here
# otherwise the loaded model weights won't fit in the newly created model
self.args['charlm_forward_file'] = None
else:
self.args['charlm_forward_file'] = args['charlm_forward_file']
if self.args.get('use_mwt', None) is None:
# Default to True as many currently saved models
# were built with mwt layers
self.args['use_mwt'] = True
self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'])
self.model.load_state_dict(checkpoint['model'])
self.model.load_state_dict(checkpoint['model'], strict=False)
self.vocab = Vocab.load_state_dict(checkpoint['vocab'])
self.lexicon = checkpoint['lexicon']

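The save/load changes let the trainer drop the weights of modules registered via add_unsaved_module (such as the pretrained charlm) from the checkpoint, and then reload with strict=False so the missing keys are tolerated. A minimal self-contained sketch of the pattern, using a stand-in model:

import torch.nn as nn

# Stand-in model: "charmodel" plays the role of a pretrained module whose
# weights should not be written into the tokenizer checkpoint.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.core = nn.Linear(4, 4)
        self.charmodel = nn.Linear(8, 8)
        self.unsaved_modules = ["charmodel"]

model = Toy()
state = model.state_dict()
skipped = [k for k in state.keys() if k.split(".")[0] in model.unsaved_modules]
for k in skipped:
    del state[k]                      # checkpoint now holds only core.* weights

reloaded = Toy()
reloaded.load_state_dict(state, strict=False)   # missing charmodel keys are ignored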
27 changes: 16 additions & 11 deletions stanza/models/tokenization/utils.py
@@ -143,20 +143,25 @@ def load_lexicon(args):


def load_mwt_dict(filename):
if filename is not None:
with open(filename, 'r') as f:
mwt_dict0 = json.load(f)
"""
Returns a dict from an MWT to its most common expansion and count.

mwt_dict = dict()
for item in mwt_dict0:
(key, expansion), count = item
Other less common expansions are discarded.
"""
if filename is None:
return None

if key not in mwt_dict or mwt_dict[key][1] < count:
mwt_dict[key] = (expansion, count)
with open(filename, 'r') as f:
mwt_dict0 = json.load(f)

return mwt_dict
else:
return
mwt_dict = dict()
for item in mwt_dict0:
(key, expansion), count = item

if key not in mwt_dict or mwt_dict[key][1] < count:
mwt_dict[key] = (expansion, count)

return mwt_dict

def process_sentence(sentence, mwt_dict=None):
sent = []
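The rewritten load_mwt_dict documents the file format: each JSON entry pairs an (MWT, expansion) key with a count, and only the most frequent expansion per MWT is kept. A small illustration with made-up entries:

# Made-up entries in the same shape as the JSON file: [[mwt, expansion], count].
mwt_dict0 = [
    [["della", ["di", "la"]], 120],
    [["della", ["de", "la"]], 3],
    [["al", ["a", "il"]], 200],
]

mwt_dict = dict()
for item in mwt_dict0:
    (key, expansion), count = item
    if key not in mwt_dict or mwt_dict[key][1] < count:
        mwt_dict[key] = (expansion, count)

print(mwt_dict)   # {'della': (['di', 'la'], 120), 'al': (['a', 'il'], 200)}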