# Text classification

Text classification is a common NLP task that assigns a label or class to text. Some of the largest companies run text classification in production for a wide range of practical applications. One of the most popular forms of text classification is sentiment analysis, which assigns a label like 🙂 positive, 🙁 negative, or 😐 neutral to a sequence of text.

This guide will show you how to:

1. Finetune [DistilBERT](https://huggingface.co/distilbert-base-uncased) on the [IMDb](https://huggingface.co/datasets/imdb) dataset to determine whether a movie review is positive or negative.
2. Use your finetuned model for inference.

<Tip>
The task illustrated in this tutorial is supported by the following model architectures:

<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->

[ALBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/albert), [BART](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bart), [BERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bert), [BigBird](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/big_bird), [BigBird-Pegasus](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bigbird_pegasus), [BLOOM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bloom), [CamemBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/camembert), [CANINE](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/canine), [ConvBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/convbert), [CTRL](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/ctrl), [Data2VecText](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/data2vec-text), [DeBERTa](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/deberta), [DeBERTa-v2](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/deberta-v2), [DistilBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/distilbert), [ELECTRA](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/electra), [ERNIE](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/ernie), [ErnieM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/ernie_m), [ESM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/esm), [FlauBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/flaubert), [FNet](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/fnet), [Funnel Transformer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/funnel), [GPT-Sw3](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt-sw3), [OpenAI 
GPT-2](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt2), [GPTBigCode](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt_bigcode), [GPT Neo](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt_neo), [GPT NeoX](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt_neox), [GPT-J](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gptj), [I-BERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/ibert), [LayoutLM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/layoutlm), [LayoutLMv2](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/layoutlmv2), [LayoutLMv3](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/layoutlmv3), [LED](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/led), [LiLT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/lilt), [LLaMA](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/llama), [Longformer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/longformer), [LUKE](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/luke), [MarkupLM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/markuplm), [mBART](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mbart), [MEGA](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mega), [Megatron-BERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/megatron-bert), [MobileBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mobilebert), [MPNet](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mpnet), [MVP](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mvp), [Nezha](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/nezha), 
[Nyströmformer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/nystromformer), [OpenAI GPT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/openai-gpt), [OPT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/opt), [Perceiver](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/perceiver), [PLBart](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/plbart), [QDQBert](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/qdqbert), [Reformer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/reformer), [RemBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/rembert), [RoBERTa](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roberta), [RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roberta-prelayernorm), [RoCBert](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roc_bert), [RoFormer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roformer), [SqueezeBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/squeezebert), [TAPAS](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/tapas), [Transformer-XL](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/transfo-xl), [XLM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlm), [XLM-RoBERTa](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlm-roberta), [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlm-roberta-xl), [XLNet](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlnet), [X-MOD](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xmod), [YOSO](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/yoso)


<!--End of the generated tip-->

</Tip>


## Load dataset

In [None]:
from datasets import load_dataset
imdb = load_dataset("imdb")  # `imdb` is used by the preprocessing cell below


## Preprocess
- Load a DistilBERT tokenizer to process the `text` field.
  Can we load T5-small instead?
- Create a preprocessing function to tokenize text and truncate sequences to be no longer than the maximum input length of DistilBERT or T5.


In [4]:
T5_SMALL = "t5-small"
GPT = "gpt2"  # 117M parameters as per https://huggingface.co/transformers/v3.3.1/pretrained_models.html # "openai-gpt"
DISTILBERT = "distilbert-base-uncased"

In [5]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# tokenizer = T5Tokenizer.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained(GPT)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

In [6]:
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)
tokenized_imdb = imdb.map(preprocess_function, batched=True)

# Use `DataCollatorWithPadding` as it is more efficient to dynamically pad the sentences to the longest length in a batch during collation, instead of padding to max length
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


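To make the dynamic-padding idea concrete, here is a minimal pure-Python sketch. It is not the `DataCollatorWithPadding` implementation; it only illustrates padding each batch to its own longest sequence, on the left to mirror `tokenizer.padding_side = 'left'` above. The helper name `pad_batch` and the pad id `0` are illustrative assumptions.

```python
def pad_batch(sequences, pad_id=0, side="left"):
    """Pad token-id lists to the longest length in this batch only."""
    longest = max(len(seq) for seq in sequences)
    padded, masks = [], []
    for seq in sequences:
        n_pad = longest - len(seq)
        if side == "left":
            padded.append([pad_id] * n_pad + list(seq))
            masks.append([0] * n_pad + [1] * len(seq))
        else:
            padded.append(list(seq) + [pad_id] * n_pad)
            masks.append([1] * len(seq) + [0] * n_pad)
    return {"input_ids": padded, "attention_mask": masks}


# Three toy sequences of different lengths (hypothetical token ids).
batch = pad_batch([[11, 12, 13], [21], [31, 32]])
print(batch["input_ids"])
print(batch["attention_mask"])
```

A batch of short reviews therefore never pays for the length of the longest review in the whole dataset, which is the efficiency argument made in the comment above.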


## Evaluate

In [9]:
import evaluate
import numpy as np
accuracy = evaluate.load("accuracy")

# Create a function that passes the predictions and labels to `accuracy.compute`
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)
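
As a sanity check on what `compute_metrics` calculates, the following sketch reproduces the argmax-then-compare logic in plain Python, without `evaluate` or NumPy. The function names and the toy logits are assumptions made for illustration.

```python
def argmax(row):
    """Index of the largest value in a list of scores."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy_from_logits(logits, labels):
    """Fraction of rows whose highest-scoring class equals the label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


# Hypothetical logits for four reviews and two classes (0 = negative, 1 = positive).
toy_logits = [[2.0, -1.0], [0.1, 0.3], [-0.5, 1.5], [1.2, 0.9]]
toy_labels = [0, 1, 0, 0]
toy_accuracy = accuracy_from_logits(toy_logits, toy_labels)
print(toy_accuracy)
```

Three of the four toy predictions match their labels, so the sketch reports an accuracy of 0.75.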

## OpenPrompt

In [None]:
# optional imports
from collections import defaultdict

# mandatory imports
from datasets import load_dataset


Load the dataset

In [18]:
raw_dataset = load_dataset('super_glue', 'cb', cache_dir="./datasets/.cache/huggingface_datasets")



In [19]:
print(raw_dataset['train'][1])
{k: len(raw_dataset[k]) for k in raw_dataset}

{'premise': 'It is part of their religion, a religion I do not scoff at as it holds many elements which match our own even though it lacks the truth of ours. At one of their great festivals they have the ritual of driving out the devils from their bodies. First the drummers come on - I may say that no women are allowed to take part in this ritual and the ladies here will perhaps agree with me that they are fortunate in that omission.', 'hypothesis': 'no women are allowed to take part in this ritual', 'idx': 1, 'label': 0}


{'train': 250, 'validation': 56, 'test': 250}

In [20]:
def check_label_distr(data_split):
    # Group example indices by label, then report how many examples each label has
    labels = defaultdict(list)
    for idx, item in enumerate(raw_dataset[data_split]):
        labels[item['label']].append(idx)
    label_distr = {label: len(indices) for label, indices in labels.items()}
    print(f"label distribution in {data_split} is:\n{label_distr}")

In [21]:
DATA_SPLIT = 'train'
check_label_distr(data_split=DATA_SPLIT)

label distribution in train is:
{0: 115, 1: 119, 2: 16}


In [22]:
DATA_SPLIT = 'test'
check_label_distr(data_split=DATA_SPLIT)

label distribution in test is:
{-1: 250}
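
Every test example carries the label `-1`, which is consistent with the SuperGLUE test labels being withheld; that is presumably why the notebook evaluates on the validation split later. The sketch below counts labels with `collections.Counter` and flags a split that has no usable labels. The constant `UNLABELED` and both helper names are assumptions made for this illustration.

```python
from collections import Counter

UNLABELED = -1  # assumption: placeholder value used when labels are withheld


def label_distribution(labels):
    """Count how many examples carry each label."""
    return dict(Counter(labels))


def is_unlabeled(labels):
    """True when every example carries the placeholder label."""
    return all(label == UNLABELED for label in labels)


# Toy label lists whose counts are copied from the outputs above.
train_labels = [0] * 115 + [1] * 119 + [2] * 16
test_labels = [-1] * 250

train_distr = label_distribution(train_labels)
print(train_distr)
print(is_unlabeled(train_labels), is_unlabeled(test_labels))
```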


In [5]:
from openprompt.data_utils import InputExample

dataset = {}
for split in ['train', 'validation', 'test']:
    dataset[split] = []
    for data in raw_dataset[split]:
        input_example = InputExample(text_a = data['premise'], text_b = data['hypothesis'], label=int(data['label']), guid=data['idx'])
        dataset[split].append(input_example)
print(dataset['train'][0])




{
  "guid": 0,
  "label": 0,
  "meta": {},
  "text_a": "It was a complex language. Not written down but handed down. One might say it was peeled down.",
  "text_b": "the language was peeled down",
  "tgt_text": null
}



In [7]:
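# Show the first three training examples with label 1
label_1_indices = [idx for idx, item in enumerate(raw_dataset['train']) if item['label'] == 1]
for example in label_1_indices[:3]:
    print(raw_dataset['train'][example])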

{'premise': "He's weird enough to have undressed me without thinking, according to some mad notion of the ``proper'' thing to do. Perhaps he thought I couldn't lie in bed with my clothes on.", 'hypothesis': "she couldn't lie in bed with her clothes on", 'idx': 23, 'label': 1}
{'premise': "I should dearly have liked to know whether they were Europeans or Americans, but I couldn't hear the accents. They appeared to be arguing. I hoped the white men weren't telling him to eliminate all witnesses because I don't believe it would have needed much persuasion.", 'hypothesis': 'eliminating all witnesses would have needed much persuasion', 'idx': 26, 'label': 1}
{'premise': "But the damage was done as far as my faith was concerned, which is probably why I went mad. So anyway, that Christmas Eve night confirmed my worst fears, it was like a kind of ``royal flush'' for the infant Jimbo. All three kings - Pa Santa and the King of Kings - all down the pan together... And to be honest I don't believ

In [8]:
from openprompt.plms import load_plm
plm, tokenizer, model_config, WrapperClass = load_plm('t5', 't5-base')

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [9]:
# Constructing Template
# A template can be constructed from the yaml config, but it can also be constructed by directly passing arguments.
from openprompt.prompts import ManualTemplate
template_text = '{"placeholder":"text_a"} Question: {"placeholder":"text_b"}? Is it correct? {"mask"}.'
mytemplate = ManualTemplate(tokenizer=tokenizer, text=template_text)


In [10]:
# To better understand how the template wraps an example, we visualize one instance.
wrapped_example = mytemplate.wrap_one_example(dataset['train'][10])
print(wrapped_example)

[[{'text': "``Do you mind if I use your phone?'' Ronni could see that Guido's brain was whirring.", 'loss_ids': 0, 'shortenable_ids': 1}, {'text': ' Question:', 'loss_ids': 0, 'shortenable_ids': 0}, {'text': " Guido's brain was whirring", 'loss_ids': 0, 'shortenable_ids': 1}, {'text': '? Is it correct?', 'loss_ids': 0, 'shortenable_ids': 0}, {'text': '<mask>', 'loss_ids': 1, 'shortenable_ids': 0}, {'text': '.', 'loss_ids': 0, 'shortenable_ids': 0}], {'guid': 10, 'label': 0}]
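
The structure of the wrapped example can be mimicked by hand. The sketch below is not OpenPrompt's template parser; it hard-codes the segments of the template used above so the roles of the fields are easier to see. In the output, `loss_ids` is 1 only on the `<mask>` segment and `shortenable_ids` is 1 only on the two placeholder texts, which suggests those are the segments that may be truncated. The function name `wrap_example` is an assumption.

```python
def wrap_example(text_a, text_b):
    """Segments for: {text_a} Question: {text_b}? Is it correct? {mask}."""
    return [
        {"text": text_a, "loss_ids": 0, "shortenable_ids": 1},
        {"text": " Question:", "loss_ids": 0, "shortenable_ids": 0},
        {"text": " " + text_b, "loss_ids": 0, "shortenable_ids": 1},
        {"text": "? Is it correct?", "loss_ids": 0, "shortenable_ids": 0},
        {"text": "<mask>", "loss_ids": 1, "shortenable_ids": 0},
        {"text": ".", "loss_ids": 0, "shortenable_ids": 0},
    ]


segments = wrap_example("It was a complex language.", "the language was peeled down")
prompt = "".join(segment["text"] for segment in segments)
print(prompt)
```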


* The wrapped example is now ready to be passed to the tokenizer, which produces the input for the language model.
* You can use the tokenizer to tokenize the input yourself, but we recommend using the wrapped tokenizer, a wrapper around the tokenizer that is tailored to `InputExample`.
* If you use our `load_plm` function, the wrapper is already provided; otherwise you should choose a suitable wrapper based on the configs in `openprompt.plms.__init__.py`.
* Note that when T5 is used for classification, we only need to pass `<pad><extra_id_0><eos>` to the decoder.
* The loss is calculated at `<extra_id_0>` (the `<mask>` token), so passing `decoder_max_length=3` saves space.
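
The point about computing the loss only at the mask position can be sketched as a simple selection step. This is an illustration under stated assumptions, not OpenPrompt code: `decoder_logits` is a made-up 3-position output over a toy 4-token vocabulary, and `loss_ids = [0, 1, 0]` is an illustrative vector for a decoder length of 3 with the mask in the middle.

```python
def select_at_loss_positions(decoder_logits, loss_ids):
    """Keep only the logit rows at positions where loss_ids is 1."""
    return [row for row, flag in zip(decoder_logits, loss_ids) if flag == 1]


# Hypothetical decoder output: 3 positions over a toy 4-token vocabulary.
decoder_logits = [
    [0.1, 0.2, 0.3, 0.4],   # position 0
    [2.0, -1.0, 0.5, 0.0],  # position 1: the mask position
    [0.0, 0.0, 0.0, 0.0],   # position 2
]
loss_ids = [0, 1, 0]
mask_logits = select_at_loss_positions(decoder_logits, loss_ids)
print(mask_logits)
```

Only one row survives the selection, so a longer decoder sequence would add computation without adding any positions that contribute to the classification loss.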

In [11]:
wrapped_t5tokenizer = WrapperClass(max_seq_length=128, decoder_max_length=3, tokenizer=tokenizer,truncate_method="head")
# or
#from openprompt.plms import T5TokenizerWrapper
#wrapped_t5tokenizer= T5TokenizerWrapper(max_seq_length=128, decoder_max_length=3, tokenizer=tokenizer,truncate_method="head")


In [12]:
# You can see what a tokenized example looks like by running:
tokenized_example = wrapped_t5tokenizer.tokenize_one_example(wrapped_example, teacher_forcing=False)
print(tokenized_example)
print(tokenizer.convert_ids_to_tokens(tokenized_example['input_ids']))
print(tokenizer.convert_ids_to_tokens(tokenized_example['decoder_input_ids']))


{'input_ids': [3, 2, 4135, 25, 809, 3, 99, 27, 169, 39, 951, 58, 31, 31, 10297, 29, 23, 228, 217, 24, 15703, 26, 32, 31, 7, 2241, 47, 3, 12124, 52, 1007, 5, 11860, 10, 15703, 26, 32, 31, 7, 2241, 47, 3, 12124, 52, 1007, 3, 58, 27, 7, 34, 2024, 58, 32099, 3, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'decoder_input_ids': [0, 32099, 0], 'loss_ids': [0, 1, 0]}
['▁', '<unk>', 'Do', '▁you', '▁mind', '▁', 'if', '▁I', '

In [13]:
# Now it's time to convert the whole dataset into the input format!
# Simply loop over the dataset to achieve it!

model_inputs = {}
for split in ['train', 'validation', 'test']:
    model_inputs[split] = []
    for sample in dataset[split]:
        tokenized_example = wrapped_t5tokenizer.tokenize_one_example(mytemplate.wrap_one_example(sample), teacher_forcing=False)
        model_inputs[split].append(tokenized_example)



Token indices sequence length is longer than the specified maximum sequence length for this model (519 > 512). Running this sequence through the model will result in indexing errors


### Why is the `max_seq_length` below different from the wrapped tokenizer above?

In [14]:
# We provide a `PromptDataLoader` class that handles all of the above and wraps it in a `torch.utils.data.DataLoader`-style iterator.
from openprompt import PromptDataLoader

train_dataloader = PromptDataLoader(dataset=dataset["train"], template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3,
    batch_size=4,shuffle=True, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")
# next(iter(train_dataloader))



tokenizing: 250it [00:00, 432.44it/s]


In [15]:
# Define the verbalizer
# In classification, you need to define your verbalizer, which is a mapping from logits on the vocabulary to the final label probability. Let's have a look at the verbalizer details:

from openprompt.prompts import ManualVerbalizer
import torch

# A verbalizer can contain multiple label words per class; here each class has a single label word
myverbalizer = ManualVerbalizer(tokenizer, num_classes=3,
                        label_words=[["yes"], ["no"], ["maybe"]])

print(myverbalizer.label_words_ids)

Parameter containing:
tensor([[[4273]],

        [[ 150]],

        [[2087]]])


In [16]:
#len(tokenizer) returns the vocab_size!

In [17]:
logits = torch.randn(2,len(tokenizer)) # creating a pseudo output from the plm, and
print(myverbalizer.process_logits(logits)) # see what the verbalizer does


tensor([[-1.5613, -0.3632, -2.3569],
        [-1.4797, -2.1906, -0.4148]])
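
The values printed above behave like log-probabilities normalized over the three label words: the exponentials of each row sum to roughly 1. A minimal pure-Python sketch of that behavior follows. It assumes one label word per class and ignores any calibration or aggregation the library may apply, so treat it as an illustration rather than the `ManualVerbalizer` implementation. The toy vocabulary and ids are made up.

```python
import math


def process_logits(vocab_logits, label_word_ids):
    """Project vocabulary logits onto label words and log-normalize over classes."""
    # One label word per class here, so projection is a simple lookup.
    label_logits = [vocab_logits[word_id] for word_id in label_word_ids]
    # Numerically stable log-softmax over the label words only.
    peak = max(label_logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in label_logits))
    return [x - log_norm for x in label_logits]


# Toy vocabulary of 6 tokens; ids 1, 3, 4 stand in for "yes", "no", "maybe".
toy_vocab_logits = [0.5, 2.0, -1.0, 1.0, 0.0, 3.0]
toy_label_word_ids = [1, 3, 4]
class_log_probs = process_logits(toy_vocab_logits, toy_label_word_ids)
print(class_log_probs)
print(sum(math.exp(x) for x in class_log_probs))
```

Note that the token with the largest logit in the toy vocabulary (id 5) is ignored entirely, because it is not a label word; only the relative scores of the label words matter.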


Although you can manually combine the PLM, template, and verbalizer, we provide a pipeline
model that takes batched data from the `PromptDataLoader` and produces class-wise logits.


In [None]:
from openprompt import PromptForClassification

use_cuda = True
prompt_model = PromptForClassification(plm=plm, template=mytemplate, verbalizer=myverbalizer, freeze_plm=False)
if use_cuda:
    prompt_model = prompt_model.cuda()

In [19]:
# Now the training is standard
from torch.optim import AdamW  # `transformers.AdamW` is deprecated in favor of the PyTorch implementation
from transformers import get_linear_schedule_with_warmup
loss_func = torch.nn.CrossEntropyLoss()
no_decay = ['bias', 'LayerNorm.weight']
# It's good practice to apply no weight decay to bias and LayerNorm parameters
optimizer_grouped_parameters = [
    {'params': [p for n, p in prompt_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in prompt_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
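
The grouping above relies on substring matching against parameter names. The sketch below applies the same filter to a handful of hypothetical names (they are not taken from any particular model) so the effect is visible without loading a model. The helper name `split_by_weight_decay` is an assumption.

```python
def split_by_weight_decay(param_names, no_decay=("bias", "LayerNorm.weight")):
    """Return (decayed, not_decayed) name lists using substring matching."""
    decayed = [n for n in param_names if not any(nd in n for nd in no_decay)]
    not_decayed = [n for n in param_names if any(nd in n for nd in no_decay)]
    return decayed, not_decayed


# Hypothetical parameter names, for illustration only.
names = [
    "encoder.layer.0.attention.query.weight",
    "encoder.layer.0.attention.query.bias",
    "encoder.layer.0.LayerNorm.weight",
    "encoder.layer.0.LayerNorm.bias",
    "classifier.weight",
]
decayed, not_decayed = split_by_weight_decay(names)
print(decayed)
print(not_decayed)
```

Because the match is on substrings, it may be worth printing both lists for your own model to confirm that its normalization layers are named the way the filter expects.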

In [20]:
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-4)



In [21]:
from tqdm import tqdm 
for epoch in range(10):
    tot_loss = 0
    for step, inputs in enumerate(tqdm(train_dataloader)):
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        loss.backward()
        tot_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        if step % 100 == 1:
            print("Epoch {}, average loss: {}".format(epoch, tot_loss/(step+1)), flush=True)


Epoch 0, average loss: 1.363224446773529
Epoch 1, average loss: 0.0500437468290329
Epoch 2, average loss: 0.03132111858576536
Epoch 3, average loss: 0.0017303311615251005
Epoch 4, average loss: 0.000288311031908961
Epoch 5, average loss: 0.00013476229287334718
Epoch 6, average loss: 0.0006233839376363903
Epoch 7, average loss: 0.0002590718795545399
Epoch 8, average loss: 0.0001658460678299889
Epoch 9, average loss: 4.914044984616339e-05
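
It helps to check what the logged "average loss" actually covers. With 250 training examples and `batch_size=4`, and assuming the last partial batch is kept, each epoch has 63 steps, so the condition `step % 100 == 1` fires only at step 1. The sketch below works through that arithmetic; both helper names are assumptions.

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Number of batches when the last partial batch is kept."""
    return math.ceil(num_examples / batch_size)


def logging_steps(num_steps, every=100, offset=1):
    """Steps at which `step % every == offset` fires."""
    return [step for step in range(num_steps) if step % every == offset]


n_steps = steps_per_epoch(250, 4)
fired = logging_steps(n_steps)
print(n_steps, fired)
```

Each printed value is therefore `tot_loss / 2`, the mean over the first two batches of the epoch, rather than an epoch-level average. Logging at the final step, or using a smaller interval, would give a more representative number.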


In [23]:
# Evaluate
validation_dataloader = PromptDataLoader(dataset=dataset["validation"], template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3,
    batch_size=4,shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")


tokenizing: 56it [00:00, 375.16it/s]


In [25]:
allpreds = []
alllabels = []
# Gradients are not needed for evaluation
with torch.no_grad():
    for batch, inputs in enumerate(validation_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        alllabels.extend(labels.cpu().tolist())
        allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
acc = sum(int(i == j) for i, j in zip(allpreds, alllabels)) / len(allpreds)
acc


0.9107142857142857

In [None]:
allpreds

In [None]:
from sklearn.metrics import accuracy_score
accuracy_score(y_true=alllabels, y_pred=allpreds)
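
Since label 2 is rare in the training split (16 of 250 examples), a single accuracy number can hide how the model does on that class. One possible follow-up is a per-class breakdown, sketched here in plain Python. The function name is an assumption, and the toy lists are hypothetical, not results from this notebook; in practice you would pass `alllabels` and `allpreds`.

```python
from collections import defaultdict


def per_class_accuracy(y_true, y_pred):
    """Fraction of correctly predicted examples for each true label."""
    total = defaultdict(int)
    correct = defaultdict(int)
    for truth, pred in zip(y_true, y_pred):
        total[truth] += 1
        correct[truth] += int(truth == pred)
    return {label: correct[label] / total[label] for label in sorted(total)}


# Hypothetical labels and predictions, not results from this notebook.
toy_true = [0, 0, 0, 1, 1, 1, 2, 2]
toy_pred = [0, 0, 1, 1, 1, 1, 0, 2]
breakdown = per_class_accuracy(toy_true, toy_pred)
print(breakdown)
```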

In [None]:
dataset