restore mask in init
rbroc committed Mar 10, 2020
1 parent 0f2918f commit dbf4d20
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions pliers/extractors/text.py
@@ -456,9 +456,9 @@ def __init__(self,

        model = model_class if self.framework == 'pt' else 'TF' + model_class
        self.model = getattr(transformers, model).from_pretrained(
-            pretrained_model, **model_kwargs)
+            pretrained_model, **self.model_kwargs)
        self.tokenizer = transformers.BertTokenizer.from_pretrained(
-            tokenizer, **tokenizer_kwargs)
+            tokenizer, **self.tokenizer_kwargs)
        super(BertExtractor, self).__init__()

    def _mask(self, wds, mask):
@@ -474,9 +474,9 @@ def _preprocess(self, stims, mask):
        idx = self.tokenizer.encode(tok, return_tensors=self.framework)
        return wds, ons, dur, tok, idx

-    def _extract(self, stims, **kwargs):
-        mask = kwargs['mask'] if 'mask' in kwargs else None
-        wds, ons, dur, tok, idx = self._preprocess(stims, mask)
+    def _extract(self, stims):
+        mask = self.mask or None
+        wds, ons, dur, tok, idx = self._preprocess(stims, mask=mask)
        preds = self.model(idx)
        preds = [p.detach() if self.framework == 'pt' else p for p in preds]
        data, feat, ons, dur = self._postprocess(preds, tok, wds, ons, dur)
@@ -605,6 +605,11 @@ class BertLMExtractor(BertExtractor):
            unknown tokens.
        framework (str): name deep learning framework to use. Must be 'pt'
            (PyTorch) or 'tf' (tensorflow). Defaults to 'pt'.
+        mask (int or str): Words to be masked (string) or indices of
+            words in the sequence to be masked (indexing starts at 0). Can
+            be either a single word/index or a list of words/indices.
+            If str is passed and more than one word in the input matches
+            the string, only the first one is masked.
        top_n (int): Specifies how many of the highest-probability tokens are
            to be returned. Mutually exclusive with target and threshold.
        target (str or list): Vocabulary token(s) for which probability is to
@@ -623,8 +628,7 @@ class BertLMExtractor(BertExtractor):
            See https://huggingface.co/transformers/main_classes/tokenizer.html.
    '''

-    _log_attributes = ('pretrained_model', 'framework', 'top_n', 'mask_pos',
-                       'mask_token', 'target', 'tokenizer_type', 'return_softmax')
+    _log_attributes = ('pretrained_model', 'framework', 'top_n', 'target', 'tokenizer_type', 'return_softmax')

    def __init__(self,
                 pretrained_model='bert-base-uncased',
@@ -662,6 +666,7 @@ def __init__(self,
                f'(\'{tokenizer}\').vocab.keys() to see available tokens')
        self.return_softmax = return_softmax
        self.return_true = return_true
+        self.mask = mask

    def _mask(self, wds, mask):
        if not type(mask) in [int, str]:
@@ -672,17 +677,6 @@ def _mask(self, wds, mask):
        mwds[self.mask_pos] = '[MASK]'
        return mwds

-    def _extract(self, stims, mask):
-        '''
-        Args:
-            mask (int or str): Words to be masked (string) or indices of
-                words in the sequence to be masked (indexing starts at 0). Can
-                be either a single word/index or a list of words/indices.
-                If str is passed and more than one word in the input matches
-                the string, only the first one is masked.
-        '''
-        return super()._extract(stims=stims, mask=mask)
-
    def _postprocess(self, preds, tok, wds, ons, dur):
        preds = preds[0].numpy()[:,1:-1,:]
        if self.return_softmax:
@@ -706,22 +700,28 @@ def _return_true_token(self, preds, feat, data):
        if self.mask_token in self.tokenizer.vocab:
            true_vocab_idx = self.tokenizer.vocab[self.mask_token]
            true_score = preds[0, self.mask_pos, true_vocab_idx]
-            feat += ['true_word', 'true_word_score']
-            data += [self.mask_token, true_score]
        else:
-            logging.warning('True token not in vocabulary, cannot return')
+            true_score = np.nan
+            logging.warning('True token not in vocabulary. Returning NaN')
+        feat += ['true_word', 'true_word_score']
+        data += [self.mask_token, true_score]
        return feat, data

    def _get_model_attributes(self):
-        return ['pretrained_model', 'framework', 'top_n', 'mask_pos',
-                'target', 'threshold', 'mask_token', 'tokenizer_type']
+        return ['pretrained_model', 'framework', 'top_n', 'mask',
+                'target', 'threshold', 'tokenizer_type']

# To discuss:
# What to do with SEP token? Does it need to be there?
# Return other layers and/or attentions?
# Couple of mixins (sequence coherence, probability)
# Look into the sentiment extractor
# Discuss probability mixin with Tal
-# Metadata as features / Add other field to store additional info?
+
+# To dos:
+# Metadata as features / Add other field to store additional info (?)
+# Log input sequence in LM extractor
+# NB: a bit suboptimal to set mask in init, but handier

class WordCounterExtractor(ComplexTextExtractor):

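A minimal usage sketch of the new interface (illustrative only, not part of this commit): after this change the mask is configured once in BertLMExtractor.__init__ (as a word or a 0-based index) instead of being passed to _extract, so a plain pliers transform() call needs no extra arguments. The example sentence, mask=2, and top_n=5 are assumed values, and the standard pliers transform()/to_df() workflow is assumed.

# Illustrative sketch, assuming the standard pliers transform()/to_df() workflow.
from pliers.stimuli import ComplexTextStim
from pliers.extractors.text import BertLMExtractor

# Mask the word at index 2 ('sat'); a string such as mask='sat' would instead
# mask the first matching word in the sequence.
ext = BertLMExtractor(pretrained_model='bert-base-uncased', mask=2, top_n=5)

stim = ComplexTextStim(text='the cat sat on the mat')
result = ext.transform(stim)   # mask no longer needs to be passed here
print(result.to_df())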
