In [37]:
from transformers import BertConfig, BertTokenizer, BertModel, BertForMaskedLM
from openprompt.data_utils import InputExample
from openprompt.prompts import ManualTemplate, ManualVerbalizer
from openprompt import PromptForClassification, PromptDataLoader
from openprompt.plms import load_plm
from collections import namedtuple
from mlm import MLMTokenizerWrapper
import torch


model_path = "runs/ta_pretraining/checkpoint-435"
ModelClass = namedtuple("ModelClass", ('config', 'tokenizer', 'model','wrapper'))

_MODEL_CLASSES = {
    'bert': ModelClass(**{
        'config': BertConfig,
        'tokenizer': BertTokenizer,
        'model':BertForMaskedLM,
        'wrapper': MLMTokenizerWrapper,
    })}

def get_model_class(plm_type: str):
    return _MODEL_CLASSES[plm_type]

model_class = get_model_class('bert')
model_config = BertConfig.from_pretrained(model_path)
plm = BertForMaskedLM.from_pretrained(model_path, config=model_config)
tokenizer = BertTokenizer.from_pretrained('UFNLP/gatortron-base')
WrapperClass = MLMTokenizerWrapper


# plm, tokenizer, model_config, WrapperClass = load_plm("bert", 'UFNLP/gatortron-base') # "runs/ta_pretraining/checkpoint-435"

You are using a model of type megatron-bert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Some weights of the model checkpoint at runs/ta_pretraining/checkpoint-435 were not used when initializing BertForMaskedLM: ['bert.encoder.layer.11.attention.ln.weight', 'bert.encoder.layer.18.attention.ln.bias', 'bert.encoder.layer.16.ln.bias', 'bert.encoder.layer.4.attention.ln.weight', 'bert.encoder.layer.19.ln.bias', 'bert.encoder.layer.17.attention.ln.bias', 'bert.encoder.layer.20.ln.weight', 'bert.encoder.layer.0.attention.ln.bias', 'bert.encoder.layer.1.ln.weight', 'bert.encoder.layer.14.attention.ln.weight', 'bert.encoder.layer.8.attention.ln.bias', 'bert.encoder.layer.1.attention.ln.weight', 'bert.encoder.layer.18.ln.weight', 'bert.encoder.layer.12.ln.bias', 'bert.encoder.layer.1.ln.bias', 'bert.encoder.layer.13.ln.weight', 'bert.encoder.layer.15.attention.ln.weight', 'bert.encoder.layer.3.attention.ln.weight', 'bert.enco

In [38]:
classes = [ # There are two classes in Sentiment Analysis, one for negative and one for positive
    "neutral",
    "negative",
    "positive"
]
dataset = [ # For simplicity, there's only two examples
    # text_a is the input text of the data, some other datasets may have multiple input sentences in one example.
    InputExample(
        guid = 0, # neutral
        meta = "neutral",
        text_a = "She states that pt has been compliant with meds",
    ),
    InputExample(
        guid = 1, #negative
        meta = "negative",
        text_a = "Pt remains aggressive and very threatening upon arrival, tried to hit security officer with his head while still in handcuff",
    ),
    InputExample(
        guid = 2, # positive
        meta = "positive",
        text_a = "MSE: pleasant, cooperative, euthymic, speech wnl, affect full and appropriate to content",
    )
]

print(dataset)

[{
  "guid": 0,
  "label": null,
  "meta": "neutral",
  "text_a": "She states that pt has been compliant with meds",
  "text_b": "",
  "tgt_text": null
}
, {
  "guid": 1,
  "label": null,
  "meta": "negative",
  "text_a": "Pt remains aggressive and very threatening upon arrival, tried to hit security officer with his head while still in handcuff",
  "text_b": "",
  "tgt_text": null
}
, {
  "guid": 2,
  "label": null,
  "meta": "positive",
  "text_a": "MSE: pleasant, cooperative, euthymic, speech wnl, affect full and appropriate to content",
  "text_b": "",
  "tgt_text": null
}
]


In [41]:
promptTemplate = ManualTemplate(
    text = '{"placeholder":"text_a"} It was {"mask"}',
    tokenizer = tokenizer,
)

promptVerbalizer = ManualVerbalizer(
    classes = classes,
    label_words = {
        "neutral": ["fair", "okay", "unbiased", "unknown"],
        "negative": ["bad", "awful", "terrible", "horrible"],
        "positive": ["good", "wonderful", "great", "effective"],
    },
    tokenizer = tokenizer,
)

promptModel = PromptForClassification(
    template = promptTemplate,
    plm = plm,
    verbalizer = promptVerbalizer,
)

data_loader = PromptDataLoader(
    dataset = dataset,
    tokenizer = tokenizer,
    template = promptTemplate,
    tokenizer_wrapper_class=WrapperClass,
)

# making zero-shot inference using pretrained MLM with prompt
promptModel.eval()
with torch.no_grad():
    for batch in data_loader:
        logits = promptModel(batch)
        preds = torch.argmax(logits, dim = -1)
        print(classes[preds])
# predictions would be 1, 0 for classes 'positive', 'negative'

tokenizing: 3it [00:00, 841.78it/s]


negative
negative
positive
