Add a learned weighting between bert layers
Make the number of hidden layers an option and start the learned weights from zeros.
Generalize num_layers for Phobert and XLNet. Keep old models alive.

Includes an option to use the layers from the older versions of conparse.
AngledLuffa committed Sep 26, 2022
1 parent 7517061 commit 2d0c69e
Showing 4 changed files with 50 additions and 19 deletions.
stanza/models/common/bert_embedding.py: 45 changes (27 additions, 18 deletions)
@@ -87,7 +87,24 @@ def filter_data(model_name, data, tokenizer = None):
return filtered_data


def extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints):
def cloned_feature(feature, num_layers):
"""
Clone & detach the feature, keeping the last N layers (or averaging -2,-3,-4 if not specified)
averaging 3 of the last 4 layers worked well for non-VI languages
"""
# feature[2] is the same for bert, but it didn't work for
# older versions of transformers for xlnet
# feature = feature[2]
feature = feature.hidden_states
if num_layers is None:
feature = torch.stack(feature[-4:-1], axis=3).sum(axis=3) / 4
else:
feature = torch.stack(feature[-num_layers:], axis=3)
return feature.clone().detach()


def extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers):
"""
Extract transformer embeddings using a method specifically for phobert
@@ -129,10 +146,7 @@ def extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_
for i in range(int(math.ceil(size/128))):
with torch.no_grad():
feature = model(tokenized_sents_padded[128*i:128*i+128].clone().detach().to(device), output_hidden_states=True)
# averaging the last four layers worked well for non-VI languages
feature = feature[2]
feature = torch.stack(feature[-4:-1], axis=3).sum(axis=3) / 4
features += feature.clone().detach()
features += cloned_feature(feature, num_layers)

assert len(features)==size
assert len(features)==len(processed)
@@ -173,7 +187,7 @@ def fix_german_tokens(tokenizer, data):
new_data.append(new_sentence)
return new_data

def extract_xlnet_embeddings(model_name, tokenizer, model, data, device, keep_endpoints):
def extract_xlnet_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers):
# using attention masks makes contextual embeddings much more useful for downstream tasks
tokenized = tokenizer(data, is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=False)
#tokenized = tokenizer(data, padding="longest", is_split_into_words=True, return_offsets_mapping=False, return_attention_mask=True)
@@ -222,9 +236,7 @@ def extract_xlnet_embeddings(model_name, tokenizer, model, data, device, keep_en
# feature[2] is the same for bert, but it didn't work for
# older versions of transformers for xlnet
# feature = feature[2]
feature = feature.hidden_states
feature = torch.stack(feature[-4:-1], axis=3).sum(axis=3) / 4
features += feature.clone().detach()
features += cloned_feature(feature, num_layers)

processed = []
#process the output
@@ -238,19 +250,21 @@ def extract_xlnet_embeddings(model_name, tokenizer, model, data, device, keep_en
return processed


def extract_bert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints):
def extract_bert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers=None):
"""
Extract transformer embeddings using a generic roberta extraction
data: list of list of string (the text tokens)
num_layers: how many hidden layers to return, stacked. If None, the sum of layers -2, -3, -4 divided by 4 is returned
"""
if model_name.startswith("vinai/phobert"):
return extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints)
return extract_phobert_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers)

if isinstance(data, tuple):
data = list(data)

if model_name.startswith("xlnet"):
return extract_xlnet_embeddings(model_name, tokenizer, model, data, device, keep_endpoints)
return extract_xlnet_embeddings(model_name, tokenizer, model, data, device, keep_endpoints, num_layers)

if model_name in BAD_TOKENIZERS:
data = fix_german_tokens(tokenizer, data)
@@ -282,12 +296,7 @@ def extract_bert_embeddings(model_name, tokenizer, model, data, device, keep_end
attention_mask = torch.tensor(tokenized['attention_mask'][128*i:128*i+128], device=device)
id_tensor = torch.tensor(tokenized['input_ids'][128*i:128*i+128], device=device)
feature = model(id_tensor, attention_mask=attention_mask, output_hidden_states=True)
# feature[2] is the same for bert, but it didn't work for
# older versions of transformers for xlnet
# feature = feature[2]
feature = feature.hidden_states
feature = torch.stack(feature[-4:-1], axis=3).sum(axis=3) / 4
features += feature.clone().detach()
features += cloned_feature(feature, num_layers)

processed = []
#process the output
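For reference, a quick usage sketch of the new cloned_feature helper (illustrative only, not part of the diff; FakeOutput is a stand-in for a HuggingFace model output carrying a hidden_states tuple, and the shapes are made up):

import torch

# 13 fake hidden layers of shape (batch=2, seq_len=7, hidden_size=16),
# roughly what a 12-layer BERT returns (embedding layer + 12 transformer layers)
hidden_states = tuple(torch.randn(2, 7, 16) for _ in range(13))

class FakeOutput:
    def __init__(self, hidden_states):
        self.hidden_states = hidden_states

# historic path: layers -2, -3, -4 summed and divided by 4
historic = cloned_feature(FakeOutput(hidden_states), None)
print(historic.shape)   # torch.Size([2, 7, 16])

# new path: the last 4 layers kept in an extra trailing dimension,
# to be mixed later by the learned weighting in the parser model
stacked = cloned_feature(FakeOutput(hidden_states), 4)
print(stacked.shape)    # torch.Size([2, 7, 16, 4])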
stanza/models/constituency/lstm_model.py: 18 changes (17 additions, 1 deletion)
@@ -311,6 +311,15 @@ def __init__(self, pretrain, forward_charlm, backward_charlm, bert_model, bert_t
if bert_tokenizer is None:
raise ValueError("Cannot have a bert model without a tokenizer")
self.bert_dim = self.bert_model.config.hidden_size
if args['bert_hidden_layers']:
# The learned combination is offset by the 1/N average so that the
# default zero weights represent an average of the N layers
self.bert_layer_mix = nn.Linear(args['bert_hidden_layers'], 1, bias=False)
nn.init.zeros_(self.bert_layer_mix.weight)
else:
# an average of layers -2, -3, -4 will be used
# (for historic reasons)
self.bert_layer_mix = None
self.word_input_size = self.word_input_size + self.bert_dim

self.partitioned_transformer_module = None
@@ -616,7 +625,14 @@ def map_word(word):
# result will be len+2 for each sentence
# we will take 1:-1 if we don't care about the endpoints
bert_embeddings = extract_bert_embeddings(self.args['bert_model'], self.bert_tokenizer, self.bert_model, all_word_labels, device,
keep_endpoints=self.sentence_boundary_vectors is not SentenceBoundary.NONE)
keep_endpoints=self.sentence_boundary_vectors is not SentenceBoundary.NONE,
num_layers=self.bert_layer_mix.in_features if self.bert_layer_mix is not None else None)
if self.bert_layer_mix is not None:
# add the average so that the default behavior is to
# take an average of the N layers; anything other than
# that average needs to be learned
bert_embeddings = [self.bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / self.bert_layer_mix.in_features for feature in bert_embeddings]

all_word_inputs = [torch.cat((x, y), axis=1) for x, y in zip(all_word_inputs, bert_embeddings)]

# Extract partitioned representation
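To see why starting the mix weights from zero keeps the old behavior, here is a minimal, self-contained sketch of the combination used above (shapes are hypothetical; bert_layer_mix mirrors the layer in the diff): with all-zero weights the learned term vanishes, so the output is exactly the uniform average of the N stacked layers, and anything beyond that average has to be learned during training.

import torch
import torch.nn as nn

num_layers = 4
# a fake per-sentence feature, assumed shape (num_words, hidden_size, num_layers)
feature = torch.randn(9, 768, num_layers)

bert_layer_mix = nn.Linear(num_layers, 1, bias=False)
nn.init.zeros_(bert_layer_mix.weight)

# learned combination plus the constant 1/N average offset, as in the diff
mixed = bert_layer_mix(feature).squeeze(2) + feature.sum(axis=2) / bert_layer_mix.in_features

# with zero weights this is exactly the plain average of the N layers
assert torch.allclose(mixed, feature.mean(axis=2))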
stanza/models/constituency/trainer.py: 4 changes (4 additions, 0 deletions)
@@ -103,6 +103,10 @@ def load(filename, args=None, load_optimizer=False, foundation_cache=None):
saved_args['constituent_stack'] = StackHistory.LSTM
if 'num_tree_lstm_layers' not in saved_args:
saved_args['num_tree_lstm_layers'] = 1
if 'bert_hidden_layers' not in checkpoint['args']:
# TODO: no need to do this once the models have bert_hidden_layers in them
saved_args['bert_hidden_layers'] = None

params = checkpoint['params']

# TODO: can remove when all models have been rearranged to use the refactored lstm_stacks
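The effect of this compatibility shim in isolation (an illustrative dict stands in for a real checkpoint, assuming saved_args starts out as the checkpoint's args as the surrounding code suggests): an older model saved without the new option loads with bert_hidden_layers set to None, which leaves bert_layer_mix off and routes cloned_feature through the historic averaging path.

# illustrative: the args dict of a checkpoint saved before this commit
saved_args = {'num_tree_lstm_layers': 1}
checkpoint = {'args': saved_args}

if 'bert_hidden_layers' not in checkpoint['args']:
    saved_args['bert_hidden_layers'] = None

print(saved_args)   # {'num_tree_lstm_layers': 1, 'bert_hidden_layers': None}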
stanza/models/constituency_parser.py: 2 changes (2 additions, 0 deletions)
@@ -161,6 +161,8 @@ def parse_args(args=None):
# for VI, for example, use vinai/phobert-base
parser.add_argument('--bert_model', type=str, default=None, help="Use an external bert model (requires the transformers package)")
parser.add_argument('--no_bert_model', dest='bert_model', action="store_const", const=None, help="Don't use bert")
parser.add_argument('--bert_hidden_layers', type=int, default=4, help="How many layers of hidden state to use from the transformer")
parser.add_argument('--bert_hidden_layers_original', action='store_const', const=None, dest='bert_hidden_layers', help='Use the historic average of layers -2, -3, -4 of the Bert embedding')

parser.add_argument('--tag_embedding_dim', type=int, default=20, help="Embedding size for a tag. 0 turns off the feature")
# Smaller values also seem to work
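A trimmed-down sketch of how the two new flags interact (only these two arguments, everything else omitted): --bert_hidden_layers_original stores None into the same destination, and a None value downstream selects the historic averaging instead of the learned mix.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--bert_hidden_layers', type=int, default=4,
                    help="How many layers of hidden state to use from the transformer")
parser.add_argument('--bert_hidden_layers_original', action='store_const', const=None,
                    dest='bert_hidden_layers', help='Use the historic average of layers -2, -3, -4')

print(parser.parse_args([]))                                  # Namespace(bert_hidden_layers=4)
print(parser.parse_args(['--bert_hidden_layers', '8']))       # Namespace(bert_hidden_layers=8)
print(parser.parse_args(['--bert_hidden_layers_original']))   # Namespace(bert_hidden_layers=None)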
