In [103]:
from transformers import BertForSequenceClassification
from scipy.special import softmax
import pandas as pd
from IPython.display import Image
from transformers import BertTokenizer
from scipy.special import softmax
from IPython.display import Image
from datasets import load_dataset
import torch
import numpy as np
from openprompt import PromptDataLoader, PromptForClassification
from openprompt.data_utils import InputExample
from openprompt.plms import load_plm
from openprompt.prompts import ManualTemplate, ManualVerbalizer, ManualTemplate

In [104]:
def forward(model, encoding, n):
    """Run a sequence-classification model and return its top-``n`` classes.

    Args:
        model: Callable invoked as ``model(**encoding)``; the result must
            expose a ``logits`` tensor.
        encoding: Mapping of keyword arguments forwarded to ``model``.
        n: Number of highest-scoring classes to return.

    Returns:
        A list of ``(class_index, percentage)`` tuples for the first row of
        the logits, ordered from most to least likely. Percentages are the
        softmax of the logits multiplied by 100.
    """
    # Only the first row of the batch is scored.
    first_row_logits = model(**encoding).logits.detach().numpy()[0]
    percentages = softmax(first_row_logits) * 100
    # sorted() is stable, so tied classes keep ascending index order.
    ranked = sorted(enumerate(percentages), key=lambda pair: pair[1], reverse=True)
    return ranked[:n]

def predict(model, inference_text, tokenizer,n):
    """Classify one text with a prompt-based model and return the top ``n``.

    The text is wrapped in an ``InputExample`` and fed through a
    ``PromptDataLoader`` built from the module-level ``promptTemplate`` and
    ``WrapperClass``; both must be defined before this function is called.

    Args:
        model: Callable invoked with each batch from the data loader; its
            result is treated as a logits tensor.
        inference_text: Text to classify.
        tokenizer: Tokenizer handed to the ``PromptDataLoader``.
        n: Number of highest-scoring classes to return.

    Returns:
        A list of ``(class_index, percentage)`` tuples for the first row of
        the logits, ordered from most to least likely.
    """
    example = InputExample(text_a = inference_text)
    inference_dataloader = PromptDataLoader(
        dataset=[example],
        template=promptTemplate,
        tokenizer=tokenizer,
        tokenizer_wrapper_class=WrapperClass,
        max_seq_length=250,
        decoder_max_length=3,
        batch_size=1,
        shuffle=False,
        teacher_forcing=False,
        predict_eos_token=False,
        truncate_method="head",
    )

    # The logits of the last batch yielded by the loader are the ones scored.
    for inputs in inference_dataloader:
        logits = model(inputs)

    first_row = logits.detach().numpy().tolist()[0]
    percentages = softmax(first_row) * 100
    # sorted() is stable, so tied classes keep ascending index order.
    ranked = sorted(enumerate(percentages), key=lambda pair: pair[1], reverse=True)
    return ranked[:n]

def create_input_text_list(input_text, max_length=10):
    """Split ``input_text`` into display lines made of whole words.

    The text is wrapped so that it occupies roughly three lines, but a line
    is never limited to fewer than ``max_length`` words. No word is dropped:
    joining the returned lines with spaces yields the same word sequence as
    ``input_text.split()``.

    Args:
        input_text: Text to wrap.
        max_length: Minimum per-line word limit (default 10).

    Returns:
        A list of strings, one per display line. Text that fits on a single
        line is returned unchanged as a one-element list.
    """
    words = input_text.split()
    # Aim for about three lines, never fewer than max_length words per line;
    # the floor of 1 keeps the slicing step below valid for any max_length.
    line_max = max(1, max_length, len(words) // 3)

    if len(words) <= line_max:
        return [input_text]

    return [
        " ".join(words[start:start + line_max])
        for start in range(0, len(words), line_max)
    ]

def create_pretty_string(model_names, input_list, model_labels, top_n):
    """Render one text table per input, comparing the models' predictions.

    Each table has an "Input" column holding the input text (wrapped into
    lines by ``create_input_text_list``) followed by one column per model.
    Tables are delimited by ``######`` lines and the header is underlined
    with a rule exactly as wide as the header line.

    Args:
        model_names: Column titles, one per model.
        input_list: Input texts; one table is rendered for each.
        model_labels: One list of label strings per model, in the same order
            as ``model_names``. The same labels are used for every table.
        top_n: Number of label rows to show per model.

    Returns:
        The concatenated tables as a single string.

    Raises:
        ValueError: If ``model_names`` and ``model_labels`` differ in length.
    """
    if len(model_names) != len(model_labels):
        raise ValueError("model_names and model_labels must have the same length")

    border = "  |  "
    input_title = "Input"

    # Model column widths do not depend on the input text, so compute once.
    column_widths = [
        len(max(list(labels) + [name], key=len))
        for name, labels in zip(model_names, model_labels)
    ]

    output_string = ""
    for input_text in input_list:
        input_text_list = create_input_text_list(input_text)
        first_column_length = len(max(input_text_list + [input_title], key=len))

        header = input_title.ljust(first_column_length)
        for name, width in zip(model_names, column_widths):
            header += border + name.ljust(width)
        header += border

        output_string += "######\n"
        output_string += header + "\n"
        # Size the rule from this table's header only, not from everything
        # rendered so far.
        output_string += "\u2500" * len(header) + "\n"

        # Show every wrapped text line even when there are more than top_n.
        for row_idx in range(max(top_n, len(input_text_list))):
            if row_idx < len(input_text_list):
                text_cell = input_text_list[row_idx]
            else:
                text_cell = ""
            output_string += text_cell.ljust(first_column_length) + border

            for labels, width in zip(model_labels, column_widths):
                # Leave the cell blank when a model has fewer labels than rows.
                has_label = row_idx < top_n and row_idx < len(labels)
                label_cell = labels[row_idx] if has_label else ""
                output_string += label_cell.ljust(width) + border
            output_string += "\n"

        output_string += "######\n"
    return output_string

def pretty_inference(model_list, model_names, input_list, tokenizer_list, top_n):
    """Run every model on every input and print one comparison table each.

    Models whose name contains ``"prompting"`` are evaluated with
    ``predict``; all others are tokenized and evaluated with ``forward``.
    Predicted class indices are translated to names through the module-level
    ``mappings`` DataFrame (columns ``index`` and ``name``).

    Args:
        model_list: Models to compare.
        model_names: Display names, parallel to ``model_list``.
        input_list: Texts to classify.
        tokenizer_list: Tokenizers, parallel to ``model_list``.
        top_n: Number of predictions to show per model.

    Returns:
        None. The tables are written to standard output.
    """
    for input_text in input_list:
        model_labels = [[] for _ in model_list]
        for idx, model in enumerate(model_list):
            tokenizer = tokenizer_list[idx]
            if "prompting" in model_names[idx]:
                predictions = predict(model, input_text, tokenizer, top_n)
            else:
                encoding = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
                predictions = forward(model, encoding, top_n)
            for class_index, percentage in predictions:
                pk_name = mappings.loc[mappings["index"] == class_index]["name"].values[0]
                model_labels[idx].append(f"{pk_name}:{percentage:.2f}%")
        # Render only the current input: these labels belong to this text
        # alone. The rendered table must be printed, otherwise it is lost.
        print(create_pretty_string(model_names, [input_text], model_labels, top_n))

In [105]:
# Dependencies: pretrained GPT-2 pieces, the label mapping and the dataset.
(
    plm,
    prompt_tokenizer,
    model_config,
    WrapperClass,
) = load_plm("gpt2", "gpt2")

# CSV with (at least) the columns "index" and "name".
mappings = pd.read_csv("data/pokemon_mapping.csv")

# NOTE: despite its name this dict is keyed by class index and maps to names.
name_to_label_dict = (
    mappings[["name", "index"]]
    .set_index("index")
    .to_dict()["name"]
)

pokemon_descriptions = load_dataset("data/dataset/", delimiter=";")

# NOTE: holds the distinct label values of the train split, not their count.
NUM_CLASSES = np.unique(pokemon_descriptions["train"]["labels"])

Using pad_token, but it is not set yet.
Using custom data configuration dataset-294e9b13f49dafc6
Found cached dataset csv (C:/Users/fst/.cache/huggingface/datasets/csv/dataset-294e9b13f49dafc6/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


  0%|          | 0/1 [00:00<?, ?it/s]

In [106]:
# Template: the input text followed by a phrase ending in the mask slot.
promptTemplate = ManualTemplate(
    tokenizer=prompt_tokenizer,
    text='{"placeholder":"text_a"} the pokemon is {"mask"}',
)

# Verbalizer: ties each class to its label word(s).
promptVerbalizer = ManualVerbalizer(
    tokenizer=prompt_tokenizer,
    classes=NUM_CLASSES,
    label_words=name_to_label_dict,
)

# Prompt classifier around the pretrained language model (PLM frozen).
promptLoadedModel = PromptForClassification(
    plm=plm,
    template=promptTemplate,
    verbalizer=promptVerbalizer,
    freeze_plm=True,
)

# Restore the trained weights. torch.load unpickles the file, so only load
# checkpoints from a trusted source.
promptLoadedModel.load_state_dict(
    state_dict=torch.load("prompting/checkp_copy/gpt2_trained_model.cp")
)

<All keys matched successfully>

In [107]:
# Class index -> name table read by pretty_inference.
mappings = pd.read_csv("data/pokemon_mapping.csv")

# Models and tokenizers are parallel lists: entry i of one goes with entry i
# of the other.
model = [
    BertForSequenceClassification.from_pretrained("saved-model-base/"),
    BertForSequenceClassification.from_pretrained("saved-model/"),
    promptLoadedModel,
]
tokenizer = [
    BertTokenizer.from_pretrained("saved-model-base/"),
    BertTokenizer.from_pretrained("saved-model/"),
    prompt_tokenizer,
]

# Free-text descriptions to classify.
input_text = [
    "Walking stone monster with a huge body.",
    "Walking stone monster with a huge body. It hates water.",
    "Walking stone monster with a huge body. It hates water. Favorit attack is earthshake",
    "Insect with sharp claws only found in the safari zone",
    "only wakes up to eat",
    "A rock pokemon which looks like a stone snake",
    "A stone like snake",
    "The pokemon has a small Flower on the head and likes to sing. During the night it is sleeping.",
    "Many believe that all other Pokémon are descendants of this one",
    "It was the result of various experiments of team rocket",
    "A snake dragon like pokemon with a long tail. It is an higher evolution and is really strong. One of the top five is using this pokemon",
    "It is yellow and it's cheeks have red circles. It has long ears and likes thunder. Ash is his best friend",
    "A psychic pokemon with spoons",
    "Red legendary dragon with fire",
]

# Names containing "prompting" are routed through predict() by
# pretty_inference; the others go through forward().
pretty_inference(
    model_list=model,
    tokenizer_list=tokenizer,
    model_names=["bert-base", "bert-large", "gpt2-prompting"],
    input_list=input_text,
    top_n=5,
)

tokenizing: 1it [00:00, 499.74it/s]


######
                                         |  bert-base        |  bert-large         |  gpt2-prompting     |  
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                                         |  Exeggutor:4.62%  |  Dragonite:26.77%   |  Wartortle:48.29%   |  
                                         |  Mr. Mime:2.59%   |  Aerodactyl:21.14%  |  Wigglytuff:48.29%  |  
Walking stone monster with a huge body.  |  Golem:2.42%      |  Haunter:4.37%      |  Grimer:1.51%       |  
                                         |  Marowak:2.27%    |  Arcanine:4.06%     |  Golem:0.70%        |  
                                         |  Machop:2.22%     |  Gyarados:3.11%     |  Graveler:0.24%     |  

######



tokenizing: 1it [00:00, 499.62it/s]


######
                                                         |  bert-base        |  bert-large      |  gpt2-prompting     |  
─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                                                         |  Exeggutor:2.73%  |  Onix:18.61%     |  Wartortle:47.80%   |  
                                                         |  Haunter:2.25%    |  Geodude:18.31%  |  Wigglytuff:47.80%  |  
Walking stone monster with a huge body. It hates water.  |  Golem:2.12%      |  Golem:15.95%    |  Metapod:0.96%      |  
                                                         |  Dugtrio:1.86%    |  Rhyhorn:8.45%   |  Grimer:0.93%       |  
                                                         |  Meowth:1.76%     |  Rhydon:4.69%    |  Onix:0.92%         |  

######



tokenizing: 1it [00:00, 499.74it/s]


######
                                                                                      |  bert-base         |  bert-large        |  gpt2-prompting    |  
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                                                                                      |  Poliwrath:3.15%   |  Exeggutor:44.72%  |  Onix:73.81%       |  
                                                                                      |  Scyther:2.83%     |  Nidoqueen:26.21%  |  Golem:5.13%       |  
Walking stone monster with a huge body. It hates water. Favorit attack is earthshake  |  Vulpix:2.21%      |  Pinsir:6.37%      |  Horsea:4.84%      |  
                                                                                      |  Hitmonchan:2.13%  |  Ekans:3.98%       |  Wartortle:4.09%   |  
                                                                   

tokenizing: 1it [00:00, 1005.35it/s]


######
                                                       |  bert-base         |  bert-large         |  gpt2-prompting    |  
──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                                                       |  Pidgey:3.15%      |  Butterfree:28.46%  |  Sandslash:12.84%  |  
                                                       |  Pinsir:2.79%      |  Caterpie:24.06%    |  Rattata:10.19%    |  
Insect with sharp claws only found in the safari zone  |  Spearow:2.47%     |  Beedrill:22.04%    |  Spearow:8.57%     |  
                                                       |  Aerodactyl:2.38%  |  Venomoth:7.47%     |  Sandshrew:7.74%   |  
                                                       |  Pidgeotto:2.31%   |  Weedle:5.30%       |  Bulbasaur:7.42%   |  

######



tokenizing: 1it [00:00, 1832.37it/s]


######
                      |  bert-base        |  bert-large      |  gpt2-prompting  |  
───────────────────────────────────────────────────────────────────────────────────────────
                      |  Abra:8.88%       |  Snorlax:98.87%  |  Oddish:46.59%   |  
                      |  Snorlax:7.38%    |  Drowzee:0.14%   |  Onix:12.87%     |  
only wakes up to eat  |  Lickitung:4.65%  |  Psyduck:0.08%   |  Drowzee:9.05%   |  
                      |  Clefairy:3.52%   |  Weezing:0.07%   |  Ditto:9.05%     |  
                      |  Eevee:3.29%      |  Slowpoke:0.05%  |  Abra:6.45%      |  

######



tokenizing: 1it [00:00, 1002.46it/s]


######
                                               |  bert-base        |  bert-large       |  gpt2-prompting   |  
──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                                               |  Jynx:3.29%       |  Onix:97.44%      |  Geodude:32.81%   |  
                                               |  Graveler:2.76%   |  Graveler:0.26%   |  Onix:28.03%      |  
A rock pokemon which looks like a stone snake  |  Golem:2.59%      |  Sandslash:0.20%  |  Articuno:12.55%  |  
                                               |  Geodude:2.51%    |  Geodude:0.18%    |  Rhyhorn:5.44%    |  
                                               |  Sandslash:2.29%  |  Golem:0.10%      |  Rhydon:5.44%     |  

######



tokenizing: 1it [00:00, 1001.27it/s]


######
                    |  bert-base        |  bert-large       |  gpt2-prompting   |  
───────────────────────────────────────────────────────────────────────────────────────────
                    |  Golem:3.52%      |  Onix:62.83%      |  Onix:38.99%      |  
                    |  Jynx:2.97%       |  Pinsir:5.77%     |  Articuno:18.32%  |  
A stone like snake  |  Marowak:2.80%    |  Dugtrio:5.42%    |  Moltres:9.35%    |  
                    |  Gengar:2.67%     |  Ekans:4.38%      |  Gengar:3.91%     |  
                    |  Sandslash:2.53%  |  Exeggutor:2.81%  |  Haunter:3.01%    |  

######



tokenizing: 1it [00:00, 499.86it/s]


######
                                                                                                |  bert-base         |  bert-large         |  gpt2-prompting     |  
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                                                                                                |  Venusaur:5.31%    |  Oddish:28.47%      |  Jigglypuff:27.01%  |  
                                                                                                |  Jigglypuff:4.49%  |  Jigglypuff:15.30%  |  Jynx:27.01%        |  
The pokemon has a small Flower on the head and likes to sing. During the night it is sleeping.  |  Tangela:4.34%     |  Vileplume:10.03%   |  Mewtwo:7.08%       |  
                                                                                                |  Vileplume:2.93%   |  Venonat:4.61%      |  Mew:7.08%         

tokenizing: 1it [00:00, 499.50it/s]


######
                                                                 |  bert-base        |  bert-large     |  gpt2-prompting  |  
─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                                                                 |  Ninetales:6.86%  |  Mew:98.76%     |  Mewtwo:49.53%   |  
                                                                 |  Raichu:3.39%     |  Golbat:0.14%   |  Mew:49.53%      |  
Many believe that all other Pokémon are descendants of this one  |  Eevee:2.69%      |  Mewtwo:0.10%   |  Chansey:0.58%   |  
                                                                 |  Cubone:2.63%     |  Lapras:0.09%   |  Lapras:0.13%    |  
                                                                 |  Pikachu:2.24%    |  Weezing:0.07%  |  Dratini:0.09%   |  

######



tokenizing: 1it [00:00, 499.68it/s]


######
                                                         |  bert-base        |  bert-large        |  gpt2-prompting  |  
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                                                         |  Magnemite:4.74%  |  Blastoise:29.68%  |  Mewtwo:49.90%   |  
                                                         |  Geodude:3.50%    |  Mewtwo:14.84%     |  Mew:49.90%      |  
It was the result of various experiments of team rocket  |  Magmar:3.34%     |  Charizard:13.18%  |  Alakazam:0.06%  |  
                                                         |  Magneton:3.08%   |  Machamp:5.65%     |  Abra:0.03%      |  
                                                         |  Porygon:2.75%    |  Kingler:4.95%     |  Paras:0.02%     |  

######



tokenizing: 1it [00:00, 499.14it/s]


######
                                                                                                                                         |  bert-base        |  bert-large        |  gpt2-prompting   |  
─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                                                                                                                                         |  Tauros:4.05%     |  Dragonite:81.83%  |  Gyarados:96.35%  |  
                                                                                                                                         |  Gyarados:3.76%   |  Electabuzz:2.15%  |  Magmar:0.74%     |  
A snake dragon like pokemon with a long tail. It is an higher evolution and is really strong. One of the top five is using this pokemon  |  Marowak:3.56%    |  Mewtwo:2.09%     

tokenizing: 1it [00:00, 333.33it/s]


######
                                                                                                           |  bert-base        |  bert-large        |  gpt2-prompting   |  
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                                                                                                           |  Pikachu:3.56%    |  Zapdos:96.98%     |  Pikachu:95.66%   |  
                                                                                                           |  Growlithe:3.52%  |  Moltres:0.57%     |  Bulbasaur:1.12%  |  
It is yellow and it's cheeks have red circles. It has long ears and likes thunder. Ash is his best friend  |  Marowak:3.21%    |  Tauros:0.23%      |  Raichu:0.30%     |  
                                                                                                           |  Tauros:3.06%   

tokenizing: 1it [00:00, 500.75it/s]


######
                               |  bert-base        |  bert-large       |  gpt2-prompting  |  
─────────────────────────────────────────────────────────────────────────────────────────────────────
                               |  Alakazam:15.91%  |  Alakazam:98.03%  |  Kadabra:99.60%  |  
                               |  Mr. Mime:4.04%   |  Kadabra:0.98%    |  Alakazam:0.30%  |  
A psychic pokemon with spoons  |  Kadabra:4.00%    |  Gengar:0.14%     |  Gengar:0.06%    |  
                               |  Gengar:2.26%     |  Mewtwo:0.13%     |  Drowzee:0.01%   |  
                               |  Hypno:1.68%      |  Rattata:0.08%    |  Ditto:0.01%     |  

######



tokenizing: 1it [00:00, 1001.74it/s]


######
                                |  bert-base         |  bert-large       |  gpt2-prompting   |  
────────────────────────────────────────────────────────────────────────────────────────────────────────
                                |  Charizard:10.81%  |  Moltres:87.28%   |  Moltres:65.48%   |  
                                |  Moltres:8.26%     |  Arcanine:1.68%   |  Articuno:12.77%  |  
Red legendary dragon with fire  |  Zapdos:4.05%      |  Magmar:1.38%     |  Marowak:11.35%   |  
                                |  Dragonair:2.67%   |  Articuno:1.36%   |  Magmar:2.87%     |  
                                |  Articuno:2.33%    |  Dragonair:0.88%  |  Magikarp:2.87%   |  

######

