Skip to content

Commit

Permalink
Add an option to train the depparse w/o upos... can test just a transformer, for example
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Apr 8, 2024
1 parent 909159c commit 15b136b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
17 changes: 12 additions & 5 deletions stanza/models/depparse/model.py
Expand Up @@ -41,7 +41,8 @@ def add_unsaved_module(name, module):
input_size += self.args['word_emb_dim'] * 2

if self.args['tag_emb_dim'] > 0:
self.upos_emb = nn.Embedding(len(vocab['upos']), self.args['tag_emb_dim'], padding_idx=0)
if self.args.get('use_upos', True):
self.upos_emb = nn.Embedding(len(vocab['upos']), self.args['tag_emb_dim'], padding_idx=0)
if self.args.get('use_xpos', True):
if not isinstance(vocab['xpos'], CompositeVocab):
self.xpos_emb = nn.Embedding(len(vocab['xpos']), self.args['tag_emb_dim'], padding_idx=0)
Expand All @@ -50,7 +51,8 @@ def add_unsaved_module(name, module):

for l in vocab['xpos'].lens():
self.xpos_emb.append(nn.Embedding(l, self.args['tag_emb_dim'], padding_idx=0))
input_size += self.args['tag_emb_dim']
if self.args.get('use_upos', True) or self.args.get('use_xpos', True):
input_size += self.args['tag_emb_dim']

if self.args.get('use_ufeats', True):
self.ufeats_emb = nn.ModuleList()
Expand Down Expand Up @@ -159,16 +161,21 @@ def pack(x):
inputs += [word_emb, lemma_emb]

if self.args['tag_emb_dim'] > 0:
pos_emb = self.upos_emb(upos)
if self.args.get('use_upos', True):
pos_emb = self.upos_emb(upos)
else:
pos_emb = 0

if self.args.get('use_xpos', True):
if isinstance(self.vocab['xpos'], CompositeVocab):
for i in range(len(self.vocab['xpos'])):
pos_emb += self.xpos_emb[i](xpos[:, :, i])
else:
pos_emb += self.xpos_emb(xpos)
pos_emb = pack(pos_emb)
inputs += [pos_emb]

if self.args.get('use_upos', True) or self.args.get('use_xpos', True):
pos_emb = pack(pos_emb)
inputs += [pos_emb]

if self.args.get('use_ufeats', True):
feats_emb = 0
Expand Down
1 change: 1 addition & 0 deletions stanza/models/parser.py
Expand Up @@ -56,6 +56,7 @@ def build_argparse():
parser.add_argument('--word_emb_dim', type=int, default=75)
parser.add_argument('--char_emb_dim', type=int, default=100)
parser.add_argument('--tag_emb_dim', type=int, default=50)
parser.add_argument('--no_upos', dest='use_upos', action='store_false', default=True, help="Don't use upos tags as part of the tag embedding")
parser.add_argument('--no_xpos', dest='use_xpos', action='store_false', default=True, help="Don't use xpos tags as part of the tag embedding")
parser.add_argument('--no_ufeats', dest='use_ufeats', action='store_false', default=True, help="Don't use ufeats as part of the tag embedding")
parser.add_argument('--transformed_dim', type=int, default=125)
Expand Down

0 comments on commit 15b136b

Please sign in to comment.