Inference for masked tokens with BERT

In [1]:
from copy import deepcopy
from pathlib import Path
from typing import List

import nltk
import pandas as pd
import torch
from tqdm.notebook import tqdm
from transformers import AutoModelForMaskedLM, AutoTokenizer

from data import Sentence, load_sentences, nlp

In [2]:
# One checkpoint id shared by the model and the tokenizer, so the two can never
# drift apart when the checkpoint is swapped.
MODEL_NAME = "google-bert/bert-base-german-cased"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)

# Vocabulary id of the tokenizer's mask token; used further down to locate the
# masked position inside the encoded input ids.
MASK_TOKEN_ID = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

Some weights of the model checkpoint at google-bert/bert-base-german-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
# Data files live in a `data/` directory next to the directory the notebook is
# run from, i.e. one level above the current working directory.
_DATA_DIR = Path.cwd().parent / "data"

INPATH = _DATA_DIR / "train.csv"
OUTPATH = _DATA_DIR / "train_augmented.csv"

def replace_nouns(sentence_row: pd.Series) -> List[pd.Series]:
    """Create augmented copies of a sentence by swapping capitalised tokens.

    Every token that is not sentence-initial and starts with an uppercase
    letter is treated as a noun candidate. Each candidate is masked in turn,
    the masked language model proposes its 50 most likely fillers, and the
    first (at most) five capitalised fillers that differ from the original
    word each yield one new row.

    Args:
        sentence_row: A row of the training frame. It is parsed with
            ``Sentence.from_row`` and its ``"sent-id"``, ``"genre"`` and
            ``"timestamp"`` entries are copied into the generated rows.

    Returns:
        A list of ``pd.Series`` rows, one per generated sentence. The list is
        empty when the sentence contains no candidate token.

    Side effects:
        Increments ``replace_nouns.running_idx`` once per generated row; the
        counter supplies the ``"sent-id"`` of the new rows and keeps counting
        across calls.
    """
    max_candidates = 50      # how many top predictions are inspected per mask
    max_rows_per_token = 5   # how many replacements are kept per masked token

    sentence = Sentence.from_row(sentence_row)

    # Skip position 0 (sentence-initial capitalisation says nothing) and empty tokens.
    noun_indices = [
        idx
        for idx, tok in enumerate(sentence.tokens)
        if idx > 0 and tok and tok[0].isupper()
    ]

    pseudo_rows = []

    for i in noun_indices:
        original = sentence.tokens[i]

        # A trailing non-alphabetic character (e.g. punctuation glued to the word)
        # is split off so the model sees it as its own token after the mask.
        suffix = "" if original[-1].isalpha() else original[-1]
        original_word = original[:-1] if suffix else original

        masked_tokens = list(sentence.tokens)
        masked_tokens[i] = "[MASK]"
        if suffix:
            masked_tokens.insert(i + 1, suffix)

        # Get predictions. No gradients are needed for inference, so skip
        # building the autograd graph.
        inputs = tokenizer(" ".join(masked_tokens), return_tensors="pt")
        mask_token_index = torch.where(inputs["input_ids"] == MASK_TOKEN_ID)[1]
        with torch.no_grad():
            token_logits = model(**inputs).logits
        mask_token_logits = token_logits[0, mask_token_index, :]
        top_tokens = torch.topk(mask_token_logits, max_candidates, dim=1).indices[0].tolist()

        accepted = 0
        for token_id in top_tokens:
            if accepted >= max_rows_per_token:
                break

            token = tokenizer.decode([token_id])

            # Keep only capitalised predictions. Compare against the word
            # *without* its split-off suffix, otherwise the original word would
            # slip through whenever it was followed by punctuation.
            if not token or not token[0].isupper() or token == original_word:
                continue

            # Re-attach the suffix so the token layout matches the source sentence.
            filled_tokens = list(sentence.tokens)
            filled_tokens[i] = token + suffix

            pseudo_rows.append(pd.Series({
                "sent-id": replace_nouns.running_idx,
                "topic": sentence.topic,
                "phrase": " ".join(filled_tokens),
                "phrase_number": f"AUG-{sentence_row['sent-id']}",
                "genre": sentence_row["genre"],
                "timestamp": sentence_row["timestamp"],
                "user": "dataaugmenter",
                "phrase_tokenized": " ".join(
                    f"{pos}:={tok}" for pos, tok in enumerate(filled_tokens)
                ),
                "statement_spans": repr(sentence.statement_spans),
                "num_statements": len(sentence.statement_spans),
            }))

            replace_nouns.running_idx += 1
            accepted += 1

    return pseudo_rows


# Running id handed out to generated rows; shared across all calls.
replace_nouns.running_idx = 0

df = pd.read_csv(INPATH)

# Gather all generated rows first and assemble the frame once at the end:
# concatenating inside the loop copies the whole frame for every new row.
augmented_rows = []
for _row_label, row in tqdm(df.iterrows(), total=len(df)):
    augmented_rows.extend(replace_nouns(row))

# The empty leading frame keeps the training columns (and their order) in the
# output even if a column never occurs in the generated rows.
df_augmented = pd.concat(
    [pd.DataFrame(columns=df.columns), pd.DataFrame(augmented_rows)],
    ignore_index=True,
)

df_augmented.to_csv(OUTPATH, index=False)

  0%|          | 0/2944 [00:00<?, ?it/s]

# POS-Tag constrained replacement

In [4]:
# Load the training split as Sentence objects (loader comes from the project's `data` module).
train_sentences = load_sentences("train")

# Touch `spacy_tokens` on every sentence up front, with a progress bar.
# NOTE(review): presumably the attribute is computed lazily and cached, so this
# loop pays the spaCy parsing cost once here — confirm in `data.Sentence`.
for sentence in tqdm(train_sentences):
    sentence.spacy_tokens

  0%|          | 0/2944 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [12]:
def _predict_mask_replacements(tokens, mask_idx, allowed_pos=None, top_k=5) -> List[str]:
    """Propose substitutes for ``tokens[mask_idx]`` using the masked language model.

    Args:
        tokens: The sentence as a list of string tokens.
        mask_idx: Index of the token to replace.
        allowed_pos: Optional collection of POS tags; when given, a prediction
            is kept only if ``nlp`` tags it (in isolation) with one of them.
        top_k: Number of highest-scoring predictions to inspect.

    Returns:
        The accepted replacement strings in model-score order. A trailing
        non-alphabetic character of the original token is re-attached to each.
    """
    original_token = tokens[mask_idx]
    masked_tokens = list(tokens)
    masked_tokens[mask_idx] = "[MASK]"

    # If the original token ends in a non-alphabetic character (e.g. a punctuation
    # mark), feed that character to the model as a separate token after the mask.
    suffix = "" if original_token[-1].isalpha() else original_token[-1]
    if suffix:
        masked_tokens.insert(mask_idx + 1, suffix)
        original_token = original_token[:-1]

    # Get predictions; inference only, so no autograd graph is needed.
    inputs = tokenizer(" ".join(masked_tokens), return_tensors="pt")
    mask_token_index = torch.where(inputs["input_ids"] == MASK_TOKEN_ID)[1]
    with torch.no_grad():
        token_logits = model(**inputs).logits
    mask_token_logits = token_logits[0, mask_token_index, :]
    top_tokens = torch.topk(mask_token_logits, top_k, dim=1).indices[0].tolist()

    valid_replacements = []
    for predicted_token_id in top_tokens:
        predicted_token = tokenizer.decode([predicted_token_id])

        # Skip non-initial subword pieces; they are not whole words.
        if predicted_token.startswith("##"):
            continue

        # Skip predictions identical to the word being replaced.
        if predicted_token == original_token:
            continue

        # Enforce the POS constraint; an empty parse cannot satisfy it.
        if allowed_pos:
            doc = nlp(predicted_token)
            if len(doc) == 0 or doc[0].pos_ not in allowed_pos:
                continue

        # Add the suffix back if it was split off above.
        valid_replacements.append(predicted_token + suffix)

    return valid_replacements


def replace_content_words(sentence: Sentence) -> List[pd.Series]:
    """Create augmented copies of a sentence by swapping single content words.

    Tokens are bucketed by the POS tag of the first spaCy token in their group
    into nouns (NOUN/PROPN), verbs (VERB/AUX) and adjectives/adverbs (ADJ/ADV).
    Each bucketed token is masked, and the best-scoring prediction whose POS
    tag falls in the same bucket produces one new row. Rows are emitted bucket
    by bucket: nouns first, then verbs, then adjectives/adverbs.

    Args:
        sentence: The sentence to augment; its ``tokens``, ``spacy_tokens``,
            ``topic`` and ``statement_spans`` attributes are read.

    Returns:
        A list of ``pd.Series`` rows, at most one per content-word position.

    Side effects:
        Increments ``replace_content_words.running_idx`` once per generated row.
    """
    spacy_tokens = sentence.spacy_tokens
    given_tokens = sentence.tokens

    # Replacement is only allowed within the same coarse word class.
    pos_groups = (
        ("NOUN", "PROPN"),  # nouns
        ("VERB", "AUX"),    # verbs
        ("ADJ", "ADV"),     # adjectives and adverbs
    )

    indices_by_group = {group: [] for group in pos_groups}
    for i, token_group in enumerate(spacy_tokens):
        # A given token may have no spaCy tokens attached; nothing to classify then.
        if len(token_group) == 0:
            continue
        pos = token_group[0].pos_
        for group in pos_groups:
            if pos in group:
                indices_by_group[group].append(i)
                break

    new_sentences = []
    for group, indices in indices_by_group.items():
        for idx in indices:
            candidates = _predict_mask_replacements(given_tokens, idx, allowed_pos=group)
            if not candidates:
                continue

            # Only the top-ranked valid replacement is used per position.
            new_sentence_tokens = list(given_tokens)
            new_sentence_tokens[idx] = candidates[0]

            # NOTE(review): "phrase_number" carries the running index here, whereas
            # replace_nouns stores the source sentence's id — confirm this is intended.
            pseudo_row = pd.Series({
                "sent-id": replace_content_words.running_idx,
                "topic": sentence.topic,
                "phrase": " ".join(new_sentence_tokens),
                "phrase_number": f"AUG-{replace_content_words.running_idx}",
                "genre": "AUG",
                "timestamp": "It's augmentation time!",
                "user": "dataaugmenter",
                "phrase_tokenized": " ".join(
                    f"{pos_idx}:={tok}" for pos_idx, tok in enumerate(new_sentence_tokens)
                ),
                "statement_spans": repr(sentence.statement_spans),
                "num_statements": len(sentence.statement_spans),
                "replaced_postag": spacy_tokens[idx][0].pos_,
            })

            new_sentences.append(pseudo_row)
            replace_content_words.running_idx += 1

    return new_sentences


# Running id handed out to generated rows; shared across all calls.
replace_content_words.running_idx = 0


# Gather all generated rows first and build the frame once: concatenating
# inside the loop copies the whole frame for every new row.
augmented_rows = []
for sentence in tqdm(train_sentences):
    augmented_rows.extend(replace_content_words(sentence))

df_augmented = pd.DataFrame(augmented_rows)

# Written next to the other data files, one level above the working directory.
OUTPATH = Path.cwd().parent / "data/train_augmented_v2.csv"
df_augmented.to_csv(OUTPATH, index=False)

  0%|          | 0/2944 [00:00<?, ?it/s]