In [4]:
from openprompt.prompts import MixedTemplate
from openprompt.pipeline_base import PromptForGeneration, PromptDataLoader
from openprompt.plms import load_plm
from openprompt import plms
from openprompt.plms import *
from transformers import GPTJConfig, GPTJModel, GPTJForCausalLM
# Register a "gptj" entry in OpenPrompt's model table so that load_plm("gptj", ...)
# can look it up. The entry pairs the GPT-J config/model classes with GPT2Tokenizer
# and the LMTokenizerWrapper.
plms._MODEL_CLASSES["gptj"] = ModelClass(
    config=GPTJConfig,
    tokenizer=GPT2Tokenizer,
    model=GPTJForCausalLM,
    wrapper=LMTokenizerWrapper,
)
import pandas as pd
from openprompt.data_utils import InputExample
from tqdm import tqdm

In [None]:
# When True, the model and every batch are moved to the GPU in the generation cells.
use_cuda = True

# Alternative (model family, checkpoint) pairs kept for reference; swap into load_plm:
#   "bert",    "bert-large-uncased"
#   "roberta", "roberta-large"
#   "gpt2",    "gpt2-large"
#   "t5",      "google/flan-t5-large"
plm, tokenizer, model_config, wrapper = load_plm(
    "gptj",
    "EleutherAI/gpt-j-6b",
)

Downloading:   0%|          | 0.00/930 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/22.5G [00:00<?, ?B/s]

In [None]:
# Load the annotation table and discard the column that is not used for prompting.
data = pd.read_csv('./cleaned_annotations.csv').drop(columns=['alternate_names'])

# Lower-case every text column that is later interpolated into a prompt.
for text_column in ('raw', 'bird', 'scientific_name'):
    data[text_column] = data[text_column].str.lower()

# Few-shot demonstrations, filled in by create_sentences() below.
sentences = []


def create_sentences(row):
    """Append two worked demonstrations for ``row`` to the module-level ``sentences`` list.

    Args:
        row: mapping-like record (a DataFrame row or a dict) providing the keys
            ``'raw'``, ``'bird'`` and ``'scientific_name'``.

    Returns:
        None. One "common name" and one "scientific name" demonstration string
        are appended to ``sentences``.
    """
    sentences.append(f"{row['raw']}.\nhere the common name of the bird is {row['bird']}.")
    sentences.append(f"{row['raw']}.\nhere the scientific name of {row['bird']} is {row['scientific_name']}.")


# Build the demonstrations from the first two rows. An explicit loop is used
# instead of DataFrame.apply: apply was only being called for its side effect on
# `sentences` (its result was discarded), and older pandas releases documented
# that apply may invoke the function twice on the first row, which would
# duplicate the first demonstration.
for demo_row in data[:2].to_dict(orient='records'):
    create_sentences(demo_row)

# chain of thought: all demonstrations are joined into one prefix shared by every prompt
cot_example = '\n'.join(sentences)

# Prompt layout: demonstrations, then the raw annotation (text_a), then the
# sentence whose masked slot the model has to complete.
bird_query_suffix = '\n{"placeholder":"text_a"}\nhere the common name of the bird is {"mask"}.'
bird_template_text = cot_example + bird_query_suffix


def _bird_example(row):
    """Wrap one annotation row as an InputExample targeting the common bird name."""
    return InputExample(text_a=row['raw'], guid=row.name, tgt_text=row['bird'])


bird_examples = data.apply(_bird_example, axis=1)
bird_input_data = list(bird_examples.values)
template = MixedTemplate(plm, tokenizer, text=bird_template_text)


In [None]:
# %%
# Decoding settings forwarded verbatim to model.generate().
# NOTE(review): do_sample is False, so temperature / top_k / top_p are presumably
# ignored by the decoder — confirm against the transformers generate() docs.
generation_arguments = dict(
    max_new_tokens=10,
    min_length=1,
    temperature=1.0,
    do_sample=False,
    top_k=0,
    top_p=0.9,
    repetition_penalty=2.0,
    num_beams=5,
)

# NOTE(review): the demonstrations are built from rows 0-1 (data[:2]) but the
# evaluated examples start at row 3, so row 2 is never used — confirm this is intended.
dataloader = PromptDataLoader(
    dataset=bird_input_data[3:],
    template=template,
    tokenizer=tokenizer,
    tokenizer_wrapper_class=wrapper,
    max_seq_length=512,
    decoder_max_length=4,
    batch_size=1,
)
# %%
# Generation-only model: the PLM is frozen and kept in eval mode.
model = PromptForGeneration(plm, template, freeze_plm=True, plm_eval_mode=True, tokenizer=tokenizer)
if use_cuda:
    model = model.cuda()
# %%
# One prediction per example (batch_size is 1, so only output_sentence[0] is read).
preds = []
for inputs in tqdm(dataloader, total=len(dataloader)):
    if use_cuda:
        inputs = inputs.cuda()
    _, output_sentence = model.generate(inputs, **generation_arguments)
    # Keep the text before the first '.', then before the first ' or ', trimmed.
    first_sentence = output_sentence[0].partition('.')[0]
    first_option = first_sentence.split(' or ')[0]
    preds.append(first_option.strip())

# %%

In [None]:
# Persist the predicted common names as a single-column CSV without the index.
bird_predictions = pd.DataFrame(preds)
bird_predictions.to_csv('cot_gptj.csv', index=False)

In [None]:
# Second-stage prompt: same demonstrations, then the raw annotation (text_a) and a
# sentence asking for the scientific name of the predicted common name (text_b).
# The wording "here the scientific name of ..." matches the demonstrations produced
# by create_sentences(); the previous text dropped "the" and so broke the pattern.
scn_template_text = cot_example + '\n{"placeholder":"text_a"}\nhere the scientific name of {"placeholder":"text_b"} is {"mask"}.'

# Number of leading rows that were not sent through the common-name stage
# (that stage's dataloader used bird_input_data[3:]).
n_skipped_rows = 3


def _scn_example(row):
    """Wrap one annotation row as an InputExample targeting the scientific name.

    ``text_b`` is the common name predicted for this row. ``preds`` is indexed by
    the row's index label minus ``n_skipped_rows``, which assumes the default
    0..n-1 integer index of ``data``. Rows before the offset have no prediction,
    so their gold common name is used instead of letting a negative index wrap
    around to the last predictions. Those rows are dropped again later by the
    ``scn_input_data[3:]`` slice.
    """
    position = int(row.name) - n_skipped_rows
    bird_name = preds[position] if position >= 0 else row['bird']
    return InputExample(text_a=row['raw'], text_b=bird_name, guid=row.name, tgt_text=row['scientific_name'])


# NOTE: `preds` must still hold the common-name predictions when this cell runs;
# the scientific-name generation cell rebinds `preds`, so re-running this cell
# afterwards would feed scientific names in as text_b.
scn_input_data = list(data.apply(_scn_example, axis=1).values)
scn_template = MixedTemplate(plm, tokenizer, text=scn_template_text)


In [None]:
# %%
# Decoding settings forwarded verbatim to model.generate() (same values as the
# common-name stage).
generation_arguments = {
    "max_new_tokens": 10,
    "min_length": 1,
    "temperature": 1.0,
    "do_sample": False,
    "top_k": 0,
    "top_p": 0.9,
    "repetition_penalty": 2.0,
    "num_beams": 5,
}

# Skip the same three leading rows as the common-name stage so that example k
# here lines up with prediction k from that stage.
dataloader = PromptDataLoader(dataset=scn_input_data[3:],
                              template=scn_template,
                              tokenizer=tokenizer,
                              tokenizer_wrapper_class=wrapper,
                              max_seq_length=512,
                              decoder_max_length=4,
                              batch_size=1)
# %%
# Build the model around scn_template, the same template the dataloader above
# encodes with. The previous code passed the common-name `template` here.
model = PromptForGeneration(plm, scn_template, freeze_plm=True, plm_eval_mode=True, tokenizer=tokenizer)
if use_cuda:
    model = model.cuda()
# %%
# NOTE: this rebinds `preds`, replacing the common-name predictions with the
# scientific-name predictions.
preds = []
for inputs in tqdm(dataloader, total=len(dataloader)):
    if use_cuda:
        inputs = inputs.cuda()
    _, output_sentence = model.generate(inputs, **generation_arguments)
    # batch_size is 1: keep the text before the first '.', then before ' or ', trimmed.
    preds.append(output_sentence[0].split('.')[0].split(' or ')[0].strip())

In [None]:
# Persist the predicted scientific names as a single-column CSV without the index.
scn_predictions = pd.DataFrame(preds)
scn_predictions.to_csv('cot_gptj_scn.csv', index=False)