# Text classification

Text classification is a common NLP task that assigns a label or class to text. Some of the largest companies run text classification in production for a wide range of practical applications. One of the most popular forms of text classification is sentiment analysis, which assigns a label like 🙂 positive, 🙁 negative, or 😐 neutral to a sequence of text.

This guide will show you how to:

1. Finetune [DistilBERT](https://huggingface.co/distilbert-base-uncased) on the [IMDb](https://huggingface.co/datasets/imdb) dataset to determine whether a movie review is positive or negative.
2. Use your finetuned model for inference.

<Tip>
The task illustrated in this tutorial is supported by the following model architectures:

<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->

[ALBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/albert), [BART](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bart), [BERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bert), [BigBird](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/big_bird), [BigBird-Pegasus](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bigbird_pegasus), [BLOOM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bloom), [CamemBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/camembert), [CANINE](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/canine), [ConvBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/convbert), [CTRL](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/ctrl), [Data2VecText](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/data2vec-text), [DeBERTa](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/deberta), [DeBERTa-v2](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/deberta-v2), [DistilBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/distilbert), [ELECTRA](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/electra), [ERNIE](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/ernie), [ErnieM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/ernie_m), [ESM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/esm), [FlauBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/flaubert), [FNet](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/fnet), [Funnel Transformer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/funnel), [GPT-Sw3](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt-sw3), [OpenAI 
GPT-2](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt2), [GPTBigCode](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt_bigcode), [GPT Neo](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt_neo), [GPT NeoX](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt_neox), [GPT-J](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gptj), [I-BERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/ibert), [LayoutLM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/layoutlm), [LayoutLMv2](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/layoutlmv2), [LayoutLMv3](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/layoutlmv3), [LED](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/led), [LiLT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/lilt), [LLaMA](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/llama), [Longformer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/longformer), [LUKE](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/luke), [MarkupLM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/markuplm), [mBART](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mbart), [MEGA](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mega), [Megatron-BERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/megatron-bert), [MobileBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mobilebert), [MPNet](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mpnet), [MVP](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mvp), [Nezha](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/nezha), 
[Nyströmformer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/nystromformer), [OpenAI GPT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/openai-gpt), [OPT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/opt), [Perceiver](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/perceiver), [PLBart](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/plbart), [QDQBert](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/qdqbert), [Reformer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/reformer), [RemBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/rembert), [RoBERTa](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roberta), [RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roberta-prelayernorm), [RoCBert](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roc_bert), [RoFormer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roformer), [SqueezeBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/squeezebert), [TAPAS](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/tapas), [Transformer-XL](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/transfo-xl), [XLM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlm), [XLM-RoBERTa](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlm-roberta), [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlm-roberta-xl), [XLNet](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlnet), [X-MOD](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xmod), [YOSO](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/yoso)


<!--End of the generated tip-->

</Tip>


## Load dataset

## Preprocess
- Load the DistilBERT tokenizer to process the `text` field. (Open question: can T5-small be loaded instead?)
- Create a preprocessing function that tokenizes the text and truncates sequences so they are no longer than DistilBERT's or T5's maximum input length.


In [1]:
T5_SMALL = "t5-small"
GPT = "gpt2"  # 117M parameters per https://huggingface.co/transformers/v3.3.1/pretrained_models.html; alternative: "openai-gpt"
DISTILBERT = "distilbert-base-uncased"

In [2]:
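from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AutoTokenizer
# Alternative tokenizers:
# tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# tokenizer = T5Tokenizer.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained(GPT)
# GPT-2 has no padding token, so reuse the end-of-sequence token for padding
tokenizer.pad_token = tokenizer.eos_token
# Pad on the left so every sequence's real tokens end at the last position
tokenizer.padding_side = 'left'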

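The cell above reuses `eos_token` as `pad_token` and pads on the left. The plain-Python sketch below shows what `padding_side='left'` changes. With left padding, each row's real tokens end at the final position, so a decoder-only model that reads the last position (when generating, or when pooling the last token) sees a real token instead of padding. The helper `pad_batch` is hypothetical. 50256 is GPT-2's `<|endoftext|>` id, and the other ids are made up.

```python
def pad_batch(sequences, pad_id, side="right"):
    """Pad token-id lists to the longest sequence and build attention masks."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        pad, ones, zeros = [pad_id] * n_pad, [1] * len(seq), [0] * n_pad
        if side == "left":
            input_ids.append(pad + seq)          # padding first, real tokens last
            attention_mask.append(zeros + ones)
        else:
            input_ids.append(seq + pad)          # real tokens first, padding last
            attention_mask.append(ones + zeros)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


EOS_ID = 50256  # GPT-2's <|endoftext|> id, reused as the pad id
batch = [[40, 588, 428], [1212]]  # illustrative token ids

left = pad_batch(batch, EOS_ID, side="left")
right = pad_batch(batch, EOS_ID, side="right")
print(left["input_ids"])        # [[40, 588, 428], [50256, 50256, 1212]]
print(right["input_ids"])       # [[40, 588, 428], [1212, 50256, 50256]]
print(left["attention_mask"])   # [[1, 1, 1], [0, 0, 1]]
```

The attention mask is what lets the model ignore the pad positions. This is why reusing the end-of-sequence id as the pad id is workable.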


In [3]:
from datasets import load_dataset
from transformers import DataCollatorWithPadding

# Load the IMDb movie reviews (splits with "text" and "label" columns)
imdb = load_dataset("imdb")

def preprocess_function(examples):
    # Tokenize the review text, truncating to the tokenizer's maximum input length
    return tokenizer(examples["text"], truncation=True)

tokenized_imdb = imdb.map(preprocess_function, batched=True)

# Use `DataCollatorWithPadding`: dynamically padding each batch to its longest sequence during collation is more efficient than padding every sentence to the maximum length
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
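The collator comment above argues that padding each batch to its own longest sequence beats padding everything to a fixed maximum. The plain-Python sketch below counts the pad tokens each strategy adds. The sequence lengths are made up, and 512 is an assumed stand-in for the model's maximum input length.

```python
def count_pad_tokens(lengths, batch_size, pad_to=None):
    """Count pad tokens when batching sequences of the given lengths in order.

    pad_to=None pads each batch to its own longest sequence (dynamic padding);
    an integer pads every sequence to that fixed length instead.
    """
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        target = pad_to if pad_to is not None else max(batch)
        total += sum(target - n for n in batch)
    return total


lengths = [120, 95, 300, 80, 60, 410, 230, 150]  # hypothetical token counts
dynamic = count_pad_tokens(lengths, batch_size=4)
fixed = count_pad_tokens(lengths, batch_size=4, pad_to=512)
print(dynamic, fixed)  # 1395 2651
```

Here dynamic padding adds about half as many pad tokens. The savings grow when most sequences are much shorter than the maximum.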

## OpenPrompt

In [None]:
# optional imports
from collections import defaultdict

# mandatory imports
from datasets import load_dataset


Load the dataset

In [2]:
#raw_dataset = load_dataset('super_glue', 'cb', cache_dir="./datasets/.cache/huggingface_datasets")
raw_dataset = load_dataset('ag_news', cache_dir="./datasets/.cache/huggingface_datasets")

Found cached dataset ag_news (/home/jovyan/llm_peft_exploration/datasets/.cache/huggingface_datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)
100%|██████████| 2/2 [00:00<00:00,  4.84it/s]


In [3]:
{k: len(raw_dataset[k]) for k in raw_dataset}

{'train': 120000, 'test': 7600}

In [4]:
def check_label_distr(data_split):
    """Print the label counts of a split and return a mapping label -> example indices."""
    label_distr = defaultdict(list)
    for idx, item in enumerate(raw_dataset[data_split]):
        label_distr[item['label']].append(idx)
    label_distr_lens = {label: len(indices) for label, indices in label_distr.items()}
    print(f"label distribution in {data_split} is:\n{label_distr_lens}")
    return label_distr
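`check_label_distr` does two things: it keeps the example indices for every label and it counts them. The stdlib sketch below separates the two jobs. `label_indices` is a hypothetical helper, and `collections.Counter` alone is enough when only the counts are needed. A toy label list stands in for `raw_dataset[split]`.

```python
from collections import Counter, defaultdict


def label_indices(labels):
    """Map each label to the indices of the examples that carry it."""
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)
    return dict(by_label)


toy_labels = [2, 3, 1, 0, 2, 2, 0]  # stand-in for [item['label'] for item in raw_dataset['train']]
indices = label_indices(toy_labels)
counts = Counter(toy_labels)
print(indices)  # {2: [0, 4, 5], 3: [1], 1: [2], 0: [3, 6]}
print({label: len(idxs) for label, idxs in indices.items()})  # {2: 3, 3: 1, 1: 1, 0: 2}
```

The index lists are worth keeping because they make it easy to look up, or sample, examples of one class.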

In [5]:
DATA_SPLIT = 'train'
train_label_distr = check_label_distr(data_split=DATA_SPLIT)

label distribution in train is:
{2: 30000, 3: 30000, 1: 30000, 0: 30000}


In [6]:
DATA_SPLIT = 'test'
CLASS_LABEL = 2
test_label_distr = check_label_distr(data_split=DATA_SPLIT)
raw_dataset[DATA_SPLIT][test_label_distr[CLASS_LABEL][0]]

label distribution in test is:
{2: 1900, 3: 1900, 1: 1900, 0: 1900}


{'text': "Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.",
 'label': 2}

Cast the examples into OpenPrompt's `InputExample` structure.

In [8]:
from openprompt.data_utils import InputExample
dataset = {}
for split in ['train', 'test']:
    dataset[split] = []
    for idx, data in enumerate(raw_dataset[split]):
        input_example = InputExample(text_a = data['text'], label=int(data['label']), guid=idx)
        dataset[split].append(input_example)
print(dataset['train'][0])


{
  "guid": 0,
  "label": 2,
  "meta": {},
  "text_a": "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.",
  "text_b": "",
  "tgt_text": null
}



In [9]:
from openprompt.plms import load_plm
plm, tokenizer, model_config, WrapperClass = load_plm('t5', 't5-base')

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [10]:
# Constructing Template
# A template can be constructed from the yaml config, but it can also be constructed by directly passing arguments.
from openprompt.prompts import ManualTemplate
template_text = '{"placeholder":"text_a"}. The topic of this news headline is {"mask"}.'
mytemplate = ManualTemplate(tokenizer=tokenizer, text=template_text)


In [11]:
# To better understand how the template wraps an example, visualize one instance.
wrapped_example = mytemplate.wrap_one_example(dataset['train'][10])
print(wrapped_example)

[[{'text': "Oil and Economy Cloud Stocks' Outlook  NEW YORK (Reuters) - Soaring crude prices plus worries  about the economy and the outlook for earnings are expected to  hang over the stock market next week during the depth of the  summer doldrums.", 'loss_ids': 0, 'shortenable_ids': 1}, {'text': '. The topic of this news headline is', 'loss_ids': 0, 'shortenable_ids': 0}, {'text': '<mask>', 'loss_ids': 1, 'shortenable_ids': 0}, {'text': '.', 'loss_ids': 0, 'shortenable_ids': 0}], {'guid': 10, 'label': 2}]
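The wrapped example printed above splits the template into pieces. The filled-in `text_a` is marked shortenable (`shortenable_ids=1`), and `<mask>` is the only piece with `loss_ids=1`. The stdlib sketch below reproduces that structure for the same template string, to make the mapping explicit. It is a simplified illustration and not OpenPrompt's parser. `wrap_one_example` here is a hypothetical re-implementation that only understands `{"placeholder": ...}` and `{"mask"}` fields.

```python
import json
import re

# Matches one JSON-like template field such as {"placeholder":"text_a"} or {"mask"}
FIELD = re.compile(r'\{[^{}]*\}')


def wrap_one_example(template_text, example):
    """Split a template into pieces, filling placeholders from `example`."""
    pieces = []
    pos = 0
    for match in FIELD.finditer(template_text):
        literal = template_text[pos:match.start()].strip()
        if literal:  # fixed template text: never shortened, no loss
            pieces.append({"text": literal, "loss_ids": 0, "shortenable_ids": 0})
        field = match.group(0)
        if field == '{"mask"}':  # not valid JSON, so handle it explicitly
            pieces.append({"text": "<mask>", "loss_ids": 1, "shortenable_ids": 0})
        else:  # input text: may be truncated to fit max_seq_length
            key = json.loads(field)["placeholder"]
            pieces.append({"text": example[key], "loss_ids": 0, "shortenable_ids": 1})
        pos = match.end()
    tail = template_text[pos:].strip()
    if tail:
        pieces.append({"text": tail, "loss_ids": 0, "shortenable_ids": 0})
    return pieces


template_text = '{"placeholder":"text_a"}. The topic of this news headline is {"mask"}.'
for piece in wrap_one_example(template_text, {"text_a": "Oil and Economy Cloud Stocks' Outlook"}):
    print(piece)
```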


* The wrapped example is now ready to be passed to the tokenizer, which produces the input for the language model.
* You can tokenize the input yourself, but we recommend the wrapped tokenizer, a wrapper around the tokenizer tailored to `InputExample`.
* If you use our `load_plm` function, the wrapper class is returned for you; otherwise, choose the suitable wrapper based on the configs in `openprompt.plms.__init__.py`.
* Note that when T5 is used for classification, only `<pad><extra_id_0><eos>` needs to be passed to the decoder.
* The loss is calculated at `<extra_id_0>` (the `<mask>` token), so passing `decoder_max_length=3` saves space.
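The bullets above can be made concrete with token ids. In the T5 vocabulary, 0 is `<pad>`, 1 is `</s>` (end of sequence) and 32099 is `<extra_id_0>`. All three show up in the tokenized example printed further down: 32099 at the mask position, then 1 followed by padding 0s. The plain-Python sketch below builds the three-token decoder input and picks the vocabulary logits at the `<extra_id_0>` position, where the label words would be scored. The helper names and the `loss_ids` layout follow the bullets and are assumptions. They are not taken from OpenPrompt's T5 wrapper.

```python
PAD_ID, EOS_ID, EXTRA_ID_0 = 0, 1, 32099  # T5 vocabulary ids


def build_decoder_inputs():
    """Decoder input <pad> <extra_id_0> </s>, with the loss position marked."""
    decoder_input_ids = [PAD_ID, EXTRA_ID_0, EOS_ID]
    loss_ids = [0, 1, 0]  # score the prediction made at <extra_id_0>
    return decoder_input_ids, loss_ids


def logits_at_mask(decoder_logits, loss_ids):
    """Return the vocabulary logits at the single position flagged in loss_ids."""
    (position,) = [i for i, flag in enumerate(loss_ids) if flag == 1]
    return decoder_logits[position]


ids, loss_ids = build_decoder_inputs()
# Toy decoder output: one row of (tiny, 3-token) vocabulary logits per decoder position
toy_logits = [[0.1, 0.2, 0.3], [1.5, -0.4, 0.9], [0.0, 0.0, 0.0]]
print(ids)                                  # [0, 32099, 1]
print(logits_at_mask(toy_logits, loss_ids))  # [1.5, -0.4, 0.9]
```

Only one decoder position is scored, so a decoder length of 3 is enough and no longer target sequence is needed.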

In [12]:
wrapped_t5tokenizer = WrapperClass(max_seq_length=256, decoder_max_length=3, tokenizer=tokenizer,truncate_method="head")
# or
#from openprompt.plms import T5TokenizerWrapper
#wrapped_t5tokenizer= T5TokenizerWrapper(max_seq_length=128, decoder_max_length=3, tokenizer=tokenizer,truncate_method="head")


In [13]:
# Inspect what a tokenized example looks like
tokenized_example = wrapped_t5tokenizer.tokenize_one_example(wrapped_example, teacher_forcing=False)
print(tokenized_example)
print(tokenizer.convert_ids_to_tokens(tokenized_example['input_ids']))
print(tokenizer.convert_ids_to_tokens(tokenized_example['decoder_input_ids']))


{'input_ids': [6067, 11, 22077, 5713, 6394, 7, 31, 19269, 8747, 3, 476, 2990, 439, 41, 18844, 61, 3, 18, 180, 26605, 19058, 1596, 303, 17168, 81, 8, 2717, 11, 8, 16395, 21, 8783, 33, 1644, 12, 5168, 147, 8, 1519, 512, 416, 471, 383, 8, 4963, 13, 8, 1248, 103, 40, 17870, 7, 5, 3, 5, 37, 2859, 13, 48, 1506, 12392, 19, 32099, 3, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [24]:
# Now it's time to convert the dataset into the input format
# by looping over its examples (only the test split here; use ['train', 'test'] for both).

model_inputs = {}
for split in ['test']:  # ['train', 'test']:
    model_inputs[split] = []
    for sample in dataset[split]:
        tokenized_example = wrapped_t5tokenizer.tokenize_one_example(mytemplate.wrap_one_example(sample), teacher_forcing=False)
        model_inputs[split].append(tokenized_example)



In [15]:
# We provide a `PromptDataLoader` class that does all of the above and wraps the result in a `torch.utils.data.DataLoader`-style iterator.
from openprompt import PromptDataLoader

train_dataloader = PromptDataLoader(dataset=dataset["train"], template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3,
    batch_size=16,shuffle=True, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")
# next(iter(train_dataloader))



tokenizing: 120000it [02:51, 698.22it/s]


In [16]:
classes = ['world', 'sports', 'business', 'sci/tech']
id2label_dict = {idx:k for idx, k in enumerate(classes)}
label2id_dict = {v:k for k,v in id2label_dict.items()}

In [17]:
# Define the verbalizer
# In classification, you need to define your verbalizer, which is a mapping from logits on the vocabulary to the final label probability. Let's have a look at the verbalizer details:

from openprompt.prompts import ManualVerbalizer
import torch

# Each class can have several label words (e.g. sci/tech below)
myverbalizer = ManualVerbalizer(
    tokenizer,
    num_classes=4,
    classes=classes,
    label_words={
        "world": ["world"],
        "sports": ["sports"],
        "business": ["business"],
        "sci/tech": ["science",  "technology", "sci tech"],
    }
)

print(myverbalizer.label_words_ids)

Parameter containing:
tensor([[[  296,     0],
         [    0,     0],
         [    0,     0]],

        [[ 2100,     0],
         [    0,     0],
         [    0,     0]],

        [[  268,     0],
         [    0,     0],
         [    0,     0]],

        [[ 2056,     0],
         [  748,     0],
         [17201,  5256]]])
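The printed `label_words_ids` is a padded 3-D tensor with three axes: class (4), label word within the class (up to 3, because sci/tech has three), and sub-token within the word (up to 2, because "sci tech" is split into two pieces). The zeros are padding, not real label-word ids. The stdlib sketch below rebuilds that layout, plus a mask marking the real ids, from the ragged id lists read off the output, in the order given in `label_words`. `pad_label_words` and the mask layout are illustrative assumptions.

```python
def pad_label_words(ids_per_class, pad_id=0):
    """Pad ragged [class][word][subtoken] id lists into a dense cube plus a mask."""
    max_words = max(len(words) for words in ids_per_class)
    max_subtokens = max(len(word) for words in ids_per_class for word in words)
    padded, mask = [], []
    for words in ids_per_class:
        class_rows, class_mask = [], []
        for w in range(max_words):
            word = words[w] if w < len(words) else []  # missing words become all padding
            n_pad = max_subtokens - len(word)
            class_rows.append(word + [pad_id] * n_pad)
            class_mask.append([1] * len(word) + [0] * n_pad)
        padded.append(class_rows)
        mask.append(class_mask)
    return padded, mask


# Sub-token ids read off the printed tensor
label_word_ids = [
    [[296]],                         # world
    [[2100]],                        # sports
    [[268]],                         # business
    [[2056], [748], [17201, 5256]],  # sci/tech: science, technology, sci tech
]
padded, mask = pad_label_words(label_word_ids)
print(padded[3])  # [[2056, 0], [748, 0], [17201, 5256]]
print(mask[3])    # [[1, 0], [1, 0], [1, 1]]
```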


In [18]:
# len(tokenizer) returns the vocabulary size (including any added tokens)

In [19]:
logits = torch.randn(2, len(tokenizer))  # a random pseudo-output from the PLM: 2 examples x vocabulary size
print(myverbalizer.process_logits(logits))  # see what the verbalizer does


tensor([[-2.1145, -2.4590, -0.9576, -2.2278],
        [-1.9564, -2.5508, -1.5859, -1.7017]])
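`process_logits` maps a full-vocabulary logit vector to one score per class, which is why a `(2, len(tokenizer))` input comes back as a `(2, 4)` tensor. The stdlib sketch below shows one common way to do this. It assumes three steps, which are not confirmed against OpenPrompt's source: score each label word by its first sub-token, log-softmax the label-word logits together, and average each class's log-probabilities. Log-probabilities are always negative, which would be consistent with the printed scores. The vocabulary and ids are toy values.

```python
import math


def process_logits(vocab_logits, label_word_ids):
    """Map one vocabulary-sized logit vector to one score per class.

    label_word_ids[c] lists the first-sub-token id of each label word of class c.
    """
    # 1. Gather the logit of every label word, class by class.
    gathered = [[vocab_logits[i] for i in ids] for ids in label_word_ids]
    # 2. Log-softmax over all label words jointly (numerically stable log-sum-exp).
    flat = [x for cls in gathered for x in cls]
    m = max(flat)
    log_z = m + math.log(sum(math.exp(x - m) for x in flat))
    # 3. Average the label-word log-probabilities within each class.
    return [sum(x - log_z for x in cls) / len(cls) for cls in gathered]


vocab_logits = [0.0] * 10   # toy 10-token vocabulary
vocab_logits[3] = 2.0       # the class-1 label word gets the highest logit
label_word_ids = [[1], [3], [5], [7, 8]]  # toy ids; class 3 has two label words
scores = process_logits(vocab_logits, label_word_ids)
print(max(range(len(scores)), key=scores.__getitem__))  # 1
```

Averaging, rather than summing, keeps classes with more label words (like sci/tech) from being penalized simply for having more terms.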


Although you can combine the PLM, template and verbalizer manually, we provide a pipeline model, `PromptForClassification`, which takes batched data from the `PromptDataLoader` and produces class-wise logits.


In [28]:
from openprompt import PromptForClassification

use_cuda = True
prompt_model = PromptForClassification(plm=plm, template=mytemplate, verbalizer=myverbalizer, freeze_plm=False)
if use_cuda:
    prompt_model = prompt_model.cuda()

In [29]:
# Now the training is standard
from torch.optim import AdamW  # transformers.AdamW is deprecated in favour of the PyTorch optimizer
loss_func = torch.nn.CrossEntropyLoss()
# It's good practice not to apply weight decay to biases and layer-norm weights.
# T5 names its layer norms `layer_norm` / `final_layer_norm`, so match that spelling too.
no_decay = ['bias', 'LayerNorm.weight', 'layer_norm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in prompt_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in prompt_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
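The parameter grouping relies on substring matching against parameter names, so a pattern that matches nothing fails silently. The stdlib sketch below runs the same test on a few names written in the style of Hugging Face T5 parameters. The names are hypothetical: the real ones carry extra prefixes from the OpenPrompt wrapper and can be listed with `[n for n, _ in prompt_model.named_parameters()]`. T5 spells its layer norms `layer_norm`, so `'LayerNorm.weight'` on its own matches none of them.

```python
def split_decay_groups(param_names, no_decay):
    """Return (decayed, not_decayed) name lists using a substring test on each name."""
    decayed, not_decayed = [], []
    for name in param_names:
        if any(nd in name for nd in no_decay):
            not_decayed.append(name)
        else:
            decayed.append(name)
    return decayed, not_decayed


# Example names in the style of T5 parameters (hypothetical; real names carry wrapper prefixes)
names = [
    "shared.weight",
    "encoder.block.0.layer.0.SelfAttention.q.weight",
    "encoder.block.0.layer.0.layer_norm.weight",
    "encoder.final_layer_norm.weight",
]
_, skipped_old = split_decay_groups(names, ["bias", "LayerNorm.weight"])
_, skipped_new = split_decay_groups(names, ["bias", "LayerNorm.weight", "layer_norm.weight"])
print(skipped_old)  # []
print(skipped_new)  # ['encoder.block.0.layer.0.layer_norm.weight', 'encoder.final_layer_norm.weight']
```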

In [30]:
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-4)
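No learning-rate schedule is attached to this optimizer, so the rate stays at 1e-4 for all 20 epochs. One option is `transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)`, with `scheduler.step()` called after each `optimizer.step()`. That schedule is documented to ramp the learning rate linearly from 0 up to the optimizer's rate during warmup, then decay it linearly to 0 by the last step. The stdlib sketch below reproduces that multiplier so the numbers can be checked. The warmup length is an illustrative choice, not taken from the run below.

```python
def linear_warmup_multiplier(step, num_warmup_steps, num_training_steps):
    """LR multiplier: ramp 0 -> 1 during warmup, then decay linearly to 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


BASE_LR = 1e-4
TOTAL_STEPS = 20 * 7500   # 20 epochs x 7500 batches per epoch
WARMUP_STEPS = 1000       # illustrative choice
for step in (0, 500, 1000, 75_500, TOTAL_STEPS):
    print(step, BASE_LR * linear_warmup_multiplier(step, WARMUP_STEPS, TOTAL_STEPS))
```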



In [25]:
# Evaluate
test_dataloader = PromptDataLoader(dataset=dataset["test"], template=mytemplate, tokenizer=tokenizer,
    tokenizer_wrapper_class=WrapperClass, max_seq_length=256, decoder_max_length=3,
    batch_size=4,shuffle=False, teacher_forcing=False, predict_eos_token=False,
    truncate_method="head")


tokenizing: 7600it [00:10, 705.83it/s]


In [33]:
len(test_dataloader)

1900
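The `1900` above is the number of test batches. A loader that keeps the last partial batch yields ⌈N / batch_size⌉ batches per pass. The same arithmetic gives the 7500 training steps per epoch seen in the progress bars below, and shows which steps trigger the `step % 2000 == 1` log line. `num_batches` is a hypothetical helper. With these dataset sizes both divisions are exact, so keeping or dropping the last batch makes no difference.

```python
import math


def num_batches(num_examples, batch_size, drop_last=False):
    """Batches per epoch for a loader that optionally drops the final partial batch."""
    if drop_last:
        return num_examples // batch_size
    return math.ceil(num_examples / batch_size)


train_steps = num_batches(120_000, 16)  # ag_news train split, batch_size=16
test_steps = num_batches(7_600, 4)      # ag_news test split, batch_size=4
log_steps = [s for s in range(train_steps) if s % 2000 == 1]
print(train_steps, test_steps, log_steps)  # 7500 1900 [1, 2001, 4001, 6001]
```

The four logged steps match the four "average loss" lines printed per epoch. Because the first one comes after only two batches, its value is much noisier than the rest.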

In [50]:
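from tqdm import tqdm
train_losses = []
test_losses = []
use_cuda = True
for epoch in range(20):
    # Hugging Face from_pretrained() returns models in eval mode (dropout off);
    # call prompt_model.train() here if dropout should be active during training.
    tot_loss = 0
    for step, inputs in enumerate(tqdm(train_dataloader)):
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        loss.backward()
        tot_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        ep_train_loss = tot_loss / (step + 1)  # running mean of the loss within this epoch
        if step % 2000 == 1:
            train_losses.append(ep_train_loss)
            print("Epoch {}, average loss: {}".format(epoch, ep_train_loss), flush=True)

    # Evaluate on the test split at the end of every epoch
    prompt_model.eval()
    allpreds = []
    alllabels = []
    tot_test_loss = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for batch, inputs in enumerate(test_dataloader):
            if use_cuda:
                inputs = inputs.cuda()
            logits = prompt_model(inputs)
            labels = inputs['label']
            tot_test_loss += loss_func(logits, labels).item()
            alllabels.extend(labels.cpu().tolist())
            allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
    test_losses.append(tot_test_loss / len(test_dataloader))
    acc = sum(int(i == j) for i, j in zip(allpreds, alllabels)) / len(allpreds)
    print(f'Epoch {epoch}: test accuracy: {acc}')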



  0%|          | 1/7500 [00:00<38:03,  3.28it/s]

Epoch 0, average loss: 0.07447695569135249


 27%|██▋       | 2001/7500 [09:54<27:37,  3.32it/s]

Epoch 0, average loss: 0.09317532743272194


 53%|█████▎    | 4001/7500 [19:49<17:07,  3.41it/s]

Epoch 0, average loss: 0.09629840620888583


 80%|████████  | 6001/7500 [29:43<07:14,  3.45it/s]

Epoch 0, average loss: 0.09757871261904678


100%|██████████| 7500/7500 [37:07<00:00,  3.37it/s]


Epoch 0: test accuracy: 0.9465789473684211


  0%|          | 1/7500 [00:00<37:45,  3.31it/s]

Epoch 1, average loss: 0.031372164376080036


 27%|██▋       | 2001/7500 [09:53<27:28,  3.34it/s]

Epoch 1, average loss: 0.03638204309497708


 53%|█████▎    | 4001/7500 [19:48<17:10,  3.40it/s]

Epoch 1, average loss: 0.041061164024217466


 80%|████████  | 6001/7500 [29:40<07:23,  3.38it/s]

Epoch 1, average loss: 0.042727474881904176


100%|██████████| 7500/7500 [37:05<00:00,  3.37it/s]


Epoch 1: test accuracy: 0.945


  0%|          | 1/7500 [00:00<37:40,  3.32it/s]

Epoch 2, average loss: 0.008110608905553818


 27%|██▋       | 2001/7500 [09:53<26:58,  3.40it/s]

Epoch 2, average loss: 0.019265413387238885


 53%|█████▎    | 4001/7500 [19:46<17:05,  3.41it/s]

Epoch 2, average loss: 0.023106005005659007


 80%|████████  | 6001/7500 [29:39<07:35,  3.29it/s]

Epoch 2, average loss: 0.025498516863532664


100%|██████████| 7500/7500 [37:04<00:00,  3.37it/s]


Epoch 2: test accuracy: 0.9465789473684211


  0%|          | 1/7500 [00:00<37:47,  3.31it/s]

Epoch 3, average loss: 0.0018224656232632697


 27%|██▋       | 2001/7500 [09:53<27:27,  3.34it/s]

Epoch 3, average loss: 0.01433275319054617


 53%|█████▎    | 4001/7500 [19:47<17:25,  3.35it/s]

Epoch 3, average loss: 0.017226989501110814


 80%|████████  | 6001/7500 [29:41<07:29,  3.34it/s]

Epoch 3, average loss: 0.01908128715847471


100%|██████████| 7500/7500 [37:06<00:00,  3.37it/s]


Epoch 3: test accuracy: 0.9464473684210526


  0%|          | 1/7500 [00:00<42:51,  2.92it/s]

Epoch 4, average loss: 0.0010925912647508085


 27%|██▋       | 2001/7500 [09:53<26:43,  3.43it/s]

Epoch 4, average loss: 0.012275348992028971


 53%|█████▎    | 4001/7500 [19:47<17:08,  3.40it/s]

Epoch 4, average loss: 0.01310150157845624


 80%|████████  | 6001/7500 [29:39<07:19,  3.41it/s]

Epoch 4, average loss: 0.015475986897156149


100%|██████████| 7500/7500 [37:04<00:00,  3.37it/s]


Epoch 4: test accuracy: 0.9475


  0%|          | 1/7500 [00:00<38:04,  3.28it/s]

Epoch 5, average loss: 0.00480890175094828


 27%|██▋       | 2001/7500 [09:54<27:04,  3.38it/s]

Epoch 5, average loss: 0.011649178321212029


 53%|█████▎    | 4001/7500 [19:47<17:21,  3.36it/s]

Epoch 5, average loss: 0.01214325343143492


 80%|████████  | 6001/7500 [29:39<07:23,  3.38it/s]

Epoch 5, average loss: 0.013850588207033502


100%|██████████| 7500/7500 [37:04<00:00,  3.37it/s]


Epoch 5: test accuracy: 0.9472368421052632


  0%|          | 1/7500 [00:00<37:30,  3.33it/s]

Epoch 6, average loss: 0.0001773704971128609


 27%|██▋       | 2001/7500 [09:53<26:55,  3.40it/s]

Epoch 6, average loss: 0.010855463469311162


 53%|█████▎    | 4001/7500 [19:47<17:16,  3.38it/s]

Epoch 6, average loss: 0.012064546242813825


 80%|████████  | 6001/7500 [29:42<07:37,  3.27it/s]

Epoch 6, average loss: 0.01250786911575898


100%|██████████| 7500/7500 [37:07<00:00,  3.37it/s]


Epoch 6: test accuracy: 0.9478947368421052


  0%|          | 1/7500 [00:00<37:23,  3.34it/s]

Epoch 7, average loss: 0.0009870353387668729


 27%|██▋       | 2001/7500 [09:54<27:15,  3.36it/s]

Epoch 7, average loss: 0.007647645044765501


 53%|█████▎    | 4001/7500 [19:51<17:45,  3.28it/s]

Epoch 7, average loss: 0.009735594871383875


 80%|████████  | 6001/7500 [29:48<07:24,  3.37it/s]

Epoch 7, average loss: 0.01019961646487118


100%|██████████| 7500/7500 [37:14<00:00,  3.36it/s]


Epoch 7: test accuracy: 0.9456578947368421


  0%|          | 1/7500 [00:00<38:35,  3.24it/s]

Epoch 8, average loss: 0.0006434436654672027


 27%|██▋       | 2001/7500 [09:55<26:44,  3.43it/s]

Epoch 8, average loss: 0.010980361509342151


 53%|█████▎    | 4001/7500 [19:49<16:58,  3.43it/s]

Epoch 8, average loss: 0.010838185029570315


 80%|████████  | 6001/7500 [29:44<07:25,  3.36it/s]

Epoch 8, average loss: 0.010895906623017877


100%|██████████| 7500/7500 [37:10<00:00,  3.36it/s]


Epoch 8: test accuracy: 0.9447368421052632


  0%|          | 1/7500 [00:00<37:47,  3.31it/s]

Epoch 9, average loss: 0.010528070370128262


 27%|██▋       | 2001/7500 [09:54<27:06,  3.38it/s]

Epoch 9, average loss: 0.009318949781246576


 53%|█████▎    | 4001/7500 [19:48<17:20,  3.36it/s]

Epoch 9, average loss: 0.010102476102322058


 80%|████████  | 6001/7500 [29:43<07:30,  3.33it/s]

Epoch 9, average loss: 0.011255457109534


100%|██████████| 7500/7500 [37:07<00:00,  3.37it/s]


Epoch 9: test accuracy: 0.9467105263157894


  0%|          | 1/7500 [00:00<37:16,  3.35it/s]

Epoch 10, average loss: 8.784360215940978e-05


 27%|██▋       | 2001/7500 [09:54<26:58,  3.40it/s]

Epoch 10, average loss: 0.006546116735990147


 53%|█████▎    | 4001/7500 [19:48<18:20,  3.18it/s]

Epoch 10, average loss: 0.008889513546248137


 80%|████████  | 6001/7500 [29:40<07:17,  3.43it/s]

Epoch 10, average loss: 0.009827858608230659


100%|██████████| 7500/7500 [37:04<00:00,  3.37it/s]


Epoch 10: test accuracy: 0.9469736842105263


  0%|          | 1/7500 [00:00<37:23,  3.34it/s]

Epoch 11, average loss: 0.0011608727218117565


 27%|██▋       | 2001/7500 [09:53<26:50,  3.41it/s]

Epoch 11, average loss: 0.006660266698941149


 53%|█████▎    | 4001/7500 [19:47<16:56,  3.44it/s]

Epoch 11, average loss: 0.008020238391335437


 71%|███████   | 5336/7500 [26:22<10:28,  3.44it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 53%|█████▎    | 4001/7500 [19:48<17:10,  3.39it/s]

Epoch 12, average loss: 0.007884605233659463


 67%|██████▋   | 5012/7500 [24:48<12:08,  3.42it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 27%|██▋       | 2001/7500 [09:55<26:57,  3.40it/s]

Epoch 13, average loss: 0.008033443022561813


 53%|█████▎    | 4001/7500 [19:50<17:30,  3.33it/s]

Epoch 13, average loss: 0.008259694514089342


 80%|████████  | 6001/7500 [29:45<07:27,  3.35it/s]

Epoch 13, average loss: 0.008924862471027702


100%|██████████| 7500/7500 [37:11<00:00,  3.36it/s]


Epoch 13: test accuracy: 0.9478947368421052


  0%|          | 1/7500 [00:00<37:28,  3.34it/s]

Epoch 14, average loss: 0.00048300328489858657


 27%|██▋       | 2001/7500 [09:53<27:00,  3.39it/s]

Epoch 14, average loss: 0.005861840679875735


 53%|█████▎    | 4001/7500 [19:49<17:17,  3.37it/s]

Epoch 14, average loss: 0.006774755869287761


 80%|████████  | 6001/7500 [29:41<07:39,  3.26it/s]

Epoch 14, average loss: 0.0071073428124988425


100%|██████████| 7500/7500 [37:09<00:00,  3.36it/s]


Epoch 14: test accuracy: 0.9486842105263158


  0%|          | 1/7500 [00:00<38:16,  3.27it/s]

Epoch 15, average loss: 0.01741854052670533


 27%|██▋       | 2001/7500 [09:52<26:55,  3.40it/s]

Epoch 15, average loss: 0.006390110117993448


 53%|█████▎    | 4001/7500 [19:46<16:59,  3.43it/s]

Epoch 15, average loss: 0.0067789925776595445


 80%|████████  | 6001/7500 [29:39<07:21,  3.39it/s]

Epoch 15, average loss: 0.007404725531324859


100%|██████████| 7500/7500 [37:04<00:00,  3.37it/s]


Epoch 15: test accuracy: 0.9481578947368421


  0%|          | 1/7500 [00:00<37:41,  3.32it/s]

Epoch 16, average loss: 0.0008547779070795514


 27%|██▋       | 2001/7500 [09:53<27:28,  3.34it/s]

Epoch 16, average loss: 0.005855757648435902


 53%|█████▎    | 4001/7500 [19:47<17:38,  3.30it/s]

Epoch 16, average loss: 0.007172257324654123


 80%|████████  | 6001/7500 [29:39<07:21,  3.39it/s]

Epoch 16, average loss: 0.007530230293921482


100%|██████████| 7500/7500 [37:01<00:00,  3.38it/s]


Epoch 16: test accuracy: 0.9475


  0%|          | 1/7500 [00:00<41:02,  3.05it/s]

Epoch 17, average loss: 4.2787789425347e-05


 27%|██▋       | 2001/7500 [09:51<27:10,  3.37it/s]

Epoch 17, average loss: 0.004507046333325917


 53%|█████▎    | 4001/7500 [19:43<17:23,  3.35it/s]

Epoch 17, average loss: 0.0058147416288691365


 80%|████████  | 6001/7500 [29:34<07:18,  3.42it/s]

Epoch 17, average loss: 0.006744115019564166


100%|██████████| 7500/7500 [36:59<00:00,  3.38it/s]


Epoch 17: test accuracy: 0.9477631578947369


  0%|          | 1/7500 [00:00<37:34,  3.33it/s]

Epoch 18, average loss: 5.784257382401847e-05


 27%|██▋       | 2001/7500 [09:52<26:44,  3.43it/s]

Epoch 18, average loss: 0.004421349419179631


 53%|█████▎    | 4001/7500 [19:43<17:12,  3.39it/s]

Epoch 18, average loss: 0.004974599751183542


 80%|████████  | 6001/7500 [29:34<07:44,  3.23it/s]

Epoch 18, average loss: 0.005956028616079698


100%|██████████| 7500/7500 [36:58<00:00,  3.38it/s]


Epoch 18: test accuracy: 0.9457894736842105


  0%|          | 1/7500 [00:00<40:03,  3.12it/s]

Epoch 19, average loss: 0.002635062555782497


 27%|██▋       | 2001/7500 [09:51<26:46,  3.42it/s]

Epoch 19, average loss: 0.004935971168553347


 53%|█████▎    | 4001/7500 [19:44<17:18,  3.37it/s]

Epoch 19, average loss: 0.004854086502257933


 80%|████████  | 6001/7500 [29:35<07:29,  3.34it/s]

Epoch 19, average loss: 0.006098550923006531


100%|██████████| 7500/7500 [36:58<00:00,  3.38it/s]


Epoch 19: test accuracy: 0.9463157894736842
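From the first epoch onward, the per-epoch test accuracy stays between about 0.944 and 0.949. Epochs 11 and 12 are missing from the log because of the IOPub rate limit. When comparing runs, it helps to extract these numbers from the printed log rather than reading them by eye. The regex-based sketch below does this for a few lines copied from the output above. `best_epoch` is a hypothetical helper.

```python
import re

# Matches lines like "Epoch 14: test accuracy: 0.9486842105263158"
EPOCH_ACC = re.compile(r"Epoch (\d+): test accuracy: ([0-9.]+)")


def best_epoch(log_text):
    """Return (epoch, accuracy) with the highest test accuracy found in log_text."""
    results = [(int(e), float(a)) for e, a in EPOCH_ACC.findall(log_text)]
    return max(results, key=lambda pair: pair[1])


log = """
Epoch 0: test accuracy: 0.9465789473684211
Epoch 4: test accuracy: 0.9475
Epoch 14, average loss: 0.00048300328489858657
Epoch 14: test accuracy: 0.9486842105263158
Epoch 19: test accuracy: 0.9463157894736842
"""
print(best_epoch(log))  # (14, 0.9486842105263158)
```

The "average loss" lines are skipped automatically because they have a comma after the epoch number instead of a colon.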


In [51]:
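import os
import torch

# Save the fine-tuned weights; reload later with prompt_model.load_state_dict(torch.load(path))
os.makedirs('data/models', exist_ok=True)
torch.save(prompt_model.state_dict(), 'data/models/tensor.pt')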

Test how the accuracy improves with more batches of training data.
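One way to follow up on this note is to retrain on growing, class-balanced subsets of the training split and record the test accuracy for each. `check_label_distr` already returns a label → indices mapping (`train_label_distr`), so a balanced subset can be sampled directly from it. The stdlib sketch below does that. The helper name, seed and subset sizes are illustrative, and a toy mapping stands in for `train_label_distr`.

```python
import random


def balanced_subset(label_to_indices, per_class, seed=0):
    """Sample `per_class` example indices from every label, reproducibly."""
    rng = random.Random(seed)
    subset = []
    for label in sorted(label_to_indices):
        subset.extend(rng.sample(label_to_indices[label], per_class))
    return sorted(subset)


# Toy stand-in for train_label_distr: 4 labels x 10 examples each
toy_distr = {label: list(range(label * 10, label * 10 + 10)) for label in range(4)}
for per_class in (1, 2, 5):
    idx = balanced_subset(toy_distr, per_class)
    print(per_class, len(idx))
```

With the real data, `[dataset['train'][i] for i in balanced_subset(train_label_distr, 250)]` would give 1,000 `InputExample`s that could be passed to a fresh `PromptDataLoader`.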

In [52]:
import json
with open('test_losses.json', 'w+') as f:
    json.dump(test_losses, f)

In [None]:
import json
with open('train_losses.json', 'w+') as f:
    json.dump(train_losses, f)

In [25]:
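# `validation_dataloader` is not defined above (ag_news only has train/test splits),
# so this cell and its output appear to come from an earlier run of the notebook.
prompt_model.eval()
allpreds = []
alllabels = []
with torch.no_grad():
    for batch, inputs in enumerate(validation_dataloader):
        if use_cuda:
            inputs = inputs.cuda()
        logits = prompt_model(inputs)
        labels = inputs['label']
        alllabels.extend(labels.cpu().tolist())
        allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
acc = sum(int(i == j) for i, j in zip(allpreds, alllabels)) / len(allpreds)
acc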


0.9107142857142857

In [None]:
allpreds

In [None]:
from sklearn.metrics import accuracy_score
accuracy_score(y_true=alllabels, y_pred=allpreds)
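Accuracy does not show which topics get confused with each other. A confusion matrix built from the same `alllabels` and `allpreds` lists does. `sklearn.metrics.confusion_matrix(y_true, y_pred)` computes it with true labels as rows and predictions as columns. The stdlib sketch below builds the same table using the `classes` order defined earlier. The label and prediction lists are toy values.

```python
def confusion_matrix(y_true, y_pred, num_classes):
    """matrix[t][p] counts examples with true label t that were predicted as p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


classes = ['world', 'sports', 'business', 'sci/tech']
toy_labels = [0, 1, 2, 3, 2, 3]
toy_preds = [0, 1, 3, 3, 2, 2]
cm = confusion_matrix(toy_labels, toy_preds, len(classes))
for name, row in zip(classes, cm):
    print(f"{name:>9}", row)
```

Accuracy is the diagonal sum divided by the total (4 / 6 here). The off-diagonal cells, such as business vs. sci/tech in this toy case, show where the errors concentrate.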

In [None]:
dataset