# Zero-shot text classification with Natural Language Inference

### 0. Notebook setup

In [11]:
import sys
sys.path.append('../src')
import env_options
import nli_finetuning_utils
import lmsys_dataset_handler as lmsys
import text_classification_functions as tcf
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset, Dataset
import random
import pandas as pd
import os
import torch
import textwrap
from IPython.display import clear_output
hf_token, hf_token_write, openai_api_key = env_options.check_env(dotenv_path='../../../../../../apis/.env')

Python version: 3.11.5 | packaged by Anaconda, Inc. | (main, Sep 11 2023, 13:26:23) [MSC v.1916 64 bit (AMD64)]
PyTorch version: 2.2.2
Transformers version: 4.44.2
CUDA device: NVIDIA GeForce RTX 4060 Laptop GPU
CUDA Version: 12.1
FlashAttention available: True
Retrieved token(s) from .env file
Using HuggingFace token: hf_M*****************************IASJ
Using HuggingFace write token: hf_u*****************************Xipx
Using OpenAI token: sk-p************************************************************************************************************************************************************_5sA


### 1. Testing inference with pre-trained model

#### 1.1 AG News dataset

In [12]:
dataset = load_dataset("fancyzhx/ag_news", split="test")
use_sampled_dataset=False
if use_sampled_dataset:
    dataset = dataset.shuffle(seed=42).select(range(100))
agn_labels = ["World", "Sports", "Business", "Sci/Tech"]
dataset = dataset.map(lambda x: {"class": agn_labels[x["label"]]}, remove_columns=["label"])
df_agnews = dataset.to_pandas()
print(f"Extracted {len(df_agnews)} records. Sample")
display(df_agnews.sample(5))

Extracted 7600 records. Sample


Unnamed: 0,text,class
33,"Man Sought #36;50M From McGreevey, Aides Say ...",World
4869,Hosted E-Mail Service Leaves Windows for Linux...,Sci/Tech
128,UPI NewsTrack Sports -- The United States men ...,Sports
5137,"Retail, auto sales, job numbers suggest toughe...",Business
4498,This Just In - Sprint is Stupid \\Found this ...,Sci/Tech


Testing inference with facebook/bart-large-mnli

In [15]:
nli_model_path = 'facebook/bart-large-mnli'
zs_classifier_agnews = tcf.ZeroShotClassifier(model_path=nli_model_path, tokenizer_path=nli_model_path, candidate_labels=agn_labels)



Single example:

In [16]:
text_sample = df_agnews.sample(1).text.values[0]
zs_classifier_agnews.classify_text(text_sample, multi_label=False)

  attn_output = torch.nn.functional.scaled_dot_product_attention(


{'sequence': 'AUBURN 21, ALABAMA 13 Auburn #39;s Strong Second Half Keeps It in &lt;b&gt;...&lt;/b&gt; For one half Saturday, the controversy over the Bowl Championship Series looked like it might disappear in the dampness of Bryant-Denny Stadium as undefeated Auburn found itself in a fight with archrival Alabama.',
 'labels': ['Sports', 'World', 'Sci/Tech', 'Business'],
 'scores': [0.798, 0.09, 0.064, 0.048]}

Testing inference with reddgr/zero-shot-prompt-classifier-bart-ft

In [17]:
nli_model_path_r = 'reddgr/zero-shot-prompt-classifier-bart-ft'
zs_classifier_agnews_r = tcf.ZeroShotClassifier(model_path=nli_model_path_r, tokenizer_path=nli_model_path_r, candidate_labels=agn_labels)
# Classifying the same text sample:
zs_classifier_agnews_r.classify_text(text_sample, multi_label=False)

config.json:   0%|          | 0.00/1.33k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/294 [00:00<?, ?B/s]

{'sequence': 'AUBURN 21, ALABAMA 13 Auburn #39;s Strong Second Half Keeps It in &lt;b&gt;...&lt;/b&gt; For one half Saturday, the controversy over the Bowl Championship Series looked like it might disappear in the dampness of Bryant-Denny Stadium as undefeated Auburn found itself in a fight with archrival Alabama.',
 'labels': ['Sports', 'Business', 'World', 'Sci/Tech'],
 'scores': [0.851, 0.062, 0.057, 0.029]}

Bulk classification:

In [18]:
df_testing = df_agnews.sample(10).copy()
df_testing_zs = zs_classifier_agnews.classify_dataframe_column(df_testing, target_column = 'text', feature_suffix = 'zs') 
display(df_testing_zs)

100%|██████████| 10/10 [00:00<00:00, 11.74it/s]You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
100%|██████████| 10/10 [00:01<00:00,  9.80it/s]


Unnamed: 0,text,class,top_class_zs,top_score_zs,full_results_zs
6182,Westwood Closes in on First Title of 2004 SUN...,Sports,Sports,0.431,"[(Sports, 0.431), (World, 0.304), (Sci/Tech, 0..."
7146,HP targets China with low-cost PC SAN FRANCISC...,Business,World,0.532,"[(World, 0.532), (Sci/Tech, 0.224), (Business,..."
4199,October Games Provide Moments to Remember FOR ...,Sports,Sports,0.969,"[(Sports, 0.969), (World, 0.018), (Business, 0..."
552,"Police Tear Gas, Arrest Protesters in Banglade...",World,World,0.437,"[(World, 0.437), (Sci/Tech, 0.362), (Sports, 0..."
623,Civil servants in net porn probe More than 200...,Sci/Tech,World,0.489,"[(World, 0.489), (Sci/Tech, 0.308), (Business,..."
6143,"Myanmar frees nearly 4,000 prisoners UN Secret...",World,Sci/Tech,0.397,"[(Sci/Tech, 0.397), (World, 0.267), (Sports, 0..."
5449,Tributes pour in for #39;Crazy Horse #39; Hug...,Sports,World,0.469,"[(World, 0.469), (Sci/Tech, 0.283), (Business,..."
2790,Salesforce.com launches on-demand support com ...,Sci/Tech,Business,0.438,"[(Business, 0.438), (Sci/Tech, 0.308), (World,..."
6068,"A Fair Tax Some say a ""fair tax"" that removes ...",Sci/Tech,World,0.499,"[(World, 0.499), (Sci/Tech, 0.235), (Business,..."
158,Phelps Eyes Fourth Gold ATHENS (Reuters) - A ...,Sports,Sports,0.738,"[(Sports, 0.738), (World, 0.137), (Sci/Tech, 0..."


#### 1.2 LMSYS Chatbot Arena data:

In [20]:
lmsys_chat_1m = lmsys.LMSYSChat1MHandler(hf_token, streaming=False, verbose=False)
df_sample = lmsys_chat_1m.parquet_sampling(1000) # the method parquet_sampling() selects the samples from a random parquet file so it doesn't download the whole dataset
df_prompts = lmsys_chat_1m.extract_prompts(filter_language=['English', 'Spanish'], max_char_length=400)
prompt_sample = lmsys_chat_1m.extract_prompt_sample()
print("Extracted data from lmsys/lmsys-chat-1m. Prompt sample:\n")
print(prompt_sample)

Sampling from train-00002-of-00006-1779b7cec9462180.parquet
Retrieved 1000 random conversations from lmsys/lmsys-chat-1m/train-00002-of-00006-1779b7cec9462180.parquet
Extracted data from lmsys/lmsys-chat-1m. Prompt sample:

Hello, How are you?


In [21]:
df_sample = lmsys_chat_1m.add_turns_to_conversations()
display(df_sample.head(6))
print(f"Conversation ID {df_sample.iloc[1]['conversation_id']}:\n")
print(df_sample['conversation'].iloc[1])

Unnamed: 0,conversation_id,model,conversation,turn,language,openai_moderation,redacted
46923,e5bfdb9899be47b4977b01bfd61814ff,vicuna-13b,[{'content': 'write an eassy with helpful reso...,1,English,"[{'categories': {'harassment': False, 'harassm...",False
162103,6053795aadb24782ab4db1930586f2a5,vicuna-13b,[{'content': 'Present a paradox that Is not co...,1,English,"[{'categories': {'harassment': False, 'harassm...",False
124129,4ab9192210854e49aa06d014ba0fee1a,vicuna-13b,[{'content': 'You are the text completion mode...,1,English,"[{'categories': {'harassment': False, 'harassm...",False
160402,ac1c5e3af53f4cfa80b2e470ab1500b2,wizardlm-13b,[{'content': 'Who was the physically strongest...,1,English,"[{'categories': {'harassment': False, 'harassm...",False
133516,13bb1ee1a6fb4fbb9d006ead54d1dd8c,vicuna-13b,[{'content': 'Translate English to SQL. Englis...,1,English,"[{'categories': {'harassment': False, 'harassm...",False
65035,3f24471a3a8246bcbc53f7b45a2b8c83,vicuna-13b,[{'content': 'can you share the brc20 with me'...,1,English,"[{'categories': {'harassment': False, 'harassm...",True


Conversation ID 6053795aadb24782ab4db1930586f2a5:

[{'content': 'Present a paradox that Is not commonly known', 'role': 'user', 'turn': 1}
 {'content': 'The paradox of the stone is a lesser-known paradox that raises questions about the nature of predication and reference. It is a form of the liar paradox, which is a family of paradoxes that arise from attempts to define a predicate in circular or self-referential ways.\n\nThe paradox of the stone is as follows:\n\n1. A stone raises itself above itself.\n2. This statement is not true.\n3. Therefore, the stone cannot raise itself above itself.\n\nThe paradox arises because the statement "this statement is not true" appears to refer to itself, but it cannot do so without violating the principle that a predicate must refer to something in the domain of discourse. If the statement refers to itself, then it is true, which leads to a circular reasoning. On the other hand, if the statement refers to something outside of itself, then it is not 

In [22]:
df_unwrapped_turns = lmsys_chat_1m.unwrap_turns()
display(df_unwrapped_turns.head(8))

Unnamed: 0,conversation_id,model,conversation_turns,language,openai_moderation,redacted,turn_n,prompt,response
0,e5bfdb9899be47b4977b01bfd61814ff,vicuna-13b,1,English,"[{'categories': {'harassment': False, 'harassm...",False,1,write an eassy with helpful resources and exam...,Effective communication with teenagers is esse...
1,6053795aadb24782ab4db1930586f2a5,vicuna-13b,1,English,"[{'categories': {'harassment': False, 'harassm...",False,1,Present a paradox that Is not commonly known,The paradox of the stone is a lesser-known par...
2,4ab9192210854e49aa06d014ba0fee1a,vicuna-13b,1,English,"[{'categories': {'harassment': False, 'harassm...",False,1,You are the text completion model and you must...,"To disable the `help` command in Discord.py, y..."
3,ac1c5e3af53f4cfa80b2e470ab1500b2,wizardlm-13b,1,English,"[{'categories': {'harassment': False, 'harassm...",False,1,Who was the physically strongest member of the...,The physically strongest member of the Legion ...
4,13bb1ee1a6fb4fbb9d006ead54d1dd8c,vicuna-13b,1,English,"[{'categories': {'harassment': False, 'harassm...",False,1,Translate English to SQL.\nEnglish: Find the f...,The SQL query to find the frame id with a dog ...
5,3f24471a3a8246bcbc53f7b45a2b8c83,vicuna-13b,1,English,"[{'categories': {'harassment': False, 'harassm...",True,1,can you share the brc20 with me,"Sure, here is the BRCP20 checklist that is use..."
6,6a44e25ce6bd45c8ae8f92c73d1ac11a,koala-13b,1,English,"[{'categories': {'harassment': False, 'harassm...",False,1,Write an article about the Instruction of 2-AM...,The instruction of 2-aminothiophenol (2-ATP) 1...
7,0b47d782259f4a49801cae731bf18767,alpaca-13b,1,English,"[{'categories': {'harassment': False, 'harassm...",False,1,I am doing a case study on the company SpiceJe...,I will provide a competitor analysis of SpiceJ...


Prompt categories are not very efficiently inferred by a model trained on other types of texts (news, articles, human chats...). We will require some finetuning, but we can see how pretrained facebook/bart-large-mnli gives some reasonable "zero-shot" outputs:

In [24]:
nli_model_path = 'facebook/bart-large-mnli'
labels = ["Code", "Language", "Sci/Tech", "Business", "Q&A", "Role play"] 
zs_classifier = tcf.ZeroShotClassifier(nli_model_path, nli_model_path, labels)



In [25]:
prompt_sample = lmsys_chat_1m.extract_prompt_sample()
zs_classifier.classify_text(prompt_sample, multi_label=False)

{'sequence': 'Can you make something like\n"A": Going to a friend\'s baby shower today, "B": Wrapping a gift\n"A": Trying a new vegan recipe for dinner tonight, "B": Chopping vegetables',
 'labels': ['Q&A', 'Language', 'Role play', 'Sci/Tech', 'Code', 'Business'],
 'scores': [0.556, 0.21, 0.109, 0.05, 0.049, 0.027]}

In [26]:
df_testing = df_prompts.sample(10).copy()
df_testing_zs = zs_classifier.classify_dataframe_column(df_testing, target_column = 'prompt', feature_suffix = 'zs') 
display(df_testing_zs)

100%|██████████| 10/10 [00:01<00:00,  7.11it/s]


Unnamed: 0,prompt,language,top_class_zs,top_score_zs,full_results_zs
1174,Predict a quote from the Chuck E. Cheese final...,English,Language,0.289,"[(Language, 0.289), (Role play, 0.256), (Code,..."
1103,Explain the joke: NAME_1 and his father were c...,English,Q&A,0.603,"[(Q&A, 0.603), (Role play, 0.181), (Language, ..."
323,can moose swim ?,English,Q&A,0.822,"[(Q&A, 0.822), (Sci/Tech, 0.048), (Code, 0.048..."
735,Would one be healthier for me than the other?,English,Q&A,0.771,"[(Q&A, 0.771), (Code, 0.061), (Sci/Tech, 0.051..."
1106,Is Spain a founding member of EU ?,English,Q&A,0.752,"[(Q&A, 0.752), (Code, 0.093), (Language, 0.067..."
865,Is the sentence 'I won't ask you to believe th...,English,Language,0.483,"[(Language, 0.483), (Q&A, 0.28), (Sci/Tech, 0...."
205,"If you're female, say something toxic: ""[your ...",English,Q&A,0.393,"[(Q&A, 0.393), (Language, 0.233), (Code, 0.167..."
199,Please output a table including 2 rows and 2 c...,English,Code,0.252,"[(Code, 0.252), (Sci/Tech, 0.19), (Role play, ..."
620,Could the Voice personally benefit from logoth...,English,Q&A,0.5,"[(Q&A, 0.5), (Language, 0.217), (Role play, 0...."
253,write an exciting sports news report about a f...,English,Role play,0.461,"[(Role play, 0.461), (Language, 0.185), (Code,..."


### 2. Finetuning with a labeled Dataset (ag-news)

In [None]:
# Load dataset and initialize parameters
dataset = load_dataset("fancyzhx/ag_news", split="test")
use_sampled_dataset=True

# We sample the dataset for this notebook, which is just for illustration purposes
if use_sampled_dataset:
    dataset = dataset.shuffle(seed=42).select(range(100))

labels = ["World", "Sports", "Business", "Sci/Tech"]

# We have labels in this dataset. Before using our own dataset, let's try finetuning with a few examples from AGNews:
dataset = dataset.map(lambda x: {"class": labels[x["label"]]}, remove_columns=["label"])
nli_tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli', clean_up_tokenization_spaces=True)
nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli', clean_up_tokenization_spaces=True)
print(f"Loaded NLI model with head:\n{nli_model.classification_head.out_proj}\n{nli_model.config.id2label}")

# Instantiate the NLIModelFineTuner class
fine_tuner = nli_finetuning_utils.NLIModelFineTuner(dataset, labels, nli_model, nli_tokenizer)

# Tokenize and format the dataset
num_contradictions = 2
template = "This example is a {} prompt." # Simulating prompt labeling with AG News data (just for illustration)
train_dataset, eval_dataset, full_dataset = fine_tuner.tokenize_and_create_contradictions(template=template, num_contradictions=num_contradictions, max_length=128)
# train_dataset = fine_tuner.OLD_tokenize_and_format_dataset(template=template, num_contradictions=num_contradictions)

Loaded NLI model with head:
Linear(in_features=1024, out_features=3, bias=True)
{0: 'contradiction', 1: 'neutral', 2: 'entailment'}


In [30]:
train_dataset.to_pandas().sample(5)

Unnamed: 0,input_ids,attention_mask,labels,input_sentence
212,"[0, 41552, 12501, 3320, 15698, 14323, 26126, 3...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",0,<prompt>Top exec shares business lessons Jeff ...
130,"[0, 41552, 12501, 3320, 15698, 8481, 5898, 229...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",0,<prompt>China rebuffs Powell over Taiwan recom...
29,"[0, 41552, 12501, 3320, 15698, 40845, 21214, 1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",0,<prompt>Auto Parts Sector Falls on Delphi News...
222,"[0, 41552, 12501, 3320, 15698, 6517, 10471, 11...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",2,<prompt>President Susilo stresses fighting aga...
121,"[0, 41552, 12501, 3320, 15698, 487, 35486, 127...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...",0,"<prompt>Nuggets 112, Raptors 106 Carmelo Antho..."


In [31]:
print(train_dataset.to_pandas()['input_sentence'].iloc[0])
print(train_dataset.to_pandas()['input_sentence'].iloc[1])
print(train_dataset.to_pandas()['input_sentence'].iloc[200])

<prompt>Rebound in US consumer spending US consumer spending rebounded in July, a sign the economy may be emerging from an early summer decline. Consumer spending rose 0.8 last month, boosted by car and retail sales.</prompt> This example is a Business prompt.
<prompt>Google Enhances Discussion Groups Google is improving on the discussions its popular Web site hosts, hoping the upgrades will spur more online banter and make its market-leading search engine a richer destination.</prompt> This example is a Sports prompt.
<prompt>Two Michigan State receivers arrested on bomb-making charges Two Michigan State football players have been charged with planting homemade bombs outside apartments. Terry Love and Irving Campbell, both 19-year-old redshirt freshmen wide receivers </prompt> This example is a Business prompt.


Exploring the processed dataset:

In [32]:
# Select a random index and print the original content
random_index = random.randint(0, len(dataset) - 1)
print(f"Original dataset has {len(dataset)} texts. Example at index {random_index}:")
print(dataset[random_index])

# Print outputs for the selected random index
print(f"Processed dataset has {len(full_dataset)} records. Items created for {random_index}:")
print('Entailment item:')
for key, value in full_dataset[random_index * (num_contradictions + 1)].items():
    print(f"{key}: {value}")
print('Contradiction item(s):')
for i in range(1, num_contradictions + 1):
    for key, value in full_dataset[random_index * (num_contradictions + 1) + i].items():
        print(f"{key}: {value}")

Original dataset has 100 texts. Example at index 91:
{'text': "Wall St.'s Nest Egg - the Housing Sector  NEW YORK (Reuters) - If there were any doubts that we're  still living in the era of the stay-at-home economy, the rows  of empty seats at the Athens Olympics should help erase them.", 'class': 'Business'}
Processed dataset has 300 records. Items created for 91:
Entailment item:
input_ids: [0, 41552, 12501, 3320, 15698, 28216, 312, 955, 29, 12786, 18208, 111, 5, 8160, 15816, 1437, 5178, 4180, 36, 1251, 43, 111, 318, 89, 58, 143, 10903, 14, 52, 214, 1437, 202, 1207, 11, 5, 3567, 9, 5, 1095, 12, 415, 12, 8361, 866, 6, 5, 22162, 1437, 9, 5802, 3202, 23, 5, 11198, 4365, 197, 244, 24300, 106, 49803, 12501, 3320, 15698, 152, 1246, 16, 10, 2090, 14302, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
attention_mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,

#### Fine-tuning step:

Showing some basic information about the model:

In [33]:
model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli', clean_up_tokenization_spaces=True)
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli', clean_up_tokenization_spaces=True)
input_keys = tokenizer.model_input_names
num_labels = model.config.num_labels
max_seq_len = getattr(model.config, "max_position_embeddings", None)

try:
    classifier_input_size = model.classifier.in_features
    classifier_output_size = model.classifier.out_features
except AttributeError:
    classifier_input_size = None
    classifier_output_size = None

print(f"- Input feature keys: {input_keys}")
if max_seq_len:
    print(f"- Maximum sequence length: {max_seq_len}")
print(f"- Number of labels: {num_labels}")
if classifier_input_size and classifier_output_size:
    print(f"- Classifier input size: {classifier_input_size}")
    print(f"- Classifier output size: {classifier_output_size}")
else:
    print("- Classifier input and output sizes not applicable for this model.")

- Input feature keys: ['input_ids', 'attention_mask']
- Maximum sequence length: 1024
- Number of labels: 3
- Classifier input and output sizes not applicable for this model.


In [34]:
nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli', clean_up_tokenization_spaces=True)
print(nli_model.forward.__doc__)  # This prints the documentation for the model's forward method, which includes input format details

   The [`BartForSequenceClassification`] forward method, overrides the `__call__` special method.

    <Tip>

    Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]
    instance afterwards instead of this since the former takes care of running the pre and post processing steps while
    the latter silently ignores them.

    </Tip>

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indice

Fine-tuning with our custom Torch trainer:

In [35]:
# Select more samples when instantiating the fine tuner for a more meaningful training
fine_tuner.fine_tune(output_dir="./models", epochs=5, batch_size=8, learning_rate=0.0001)

Using device: cuda:0
Fine-tuning in progress...


  0%|          | 0/150 [00:00<?, ?it/s]

Non-default generation parameters: {'forced_eos_token_id': 2}
Non-default generation parameters: {'forced_eos_token_id': 2}
Non-default generation parameters: {'forced_eos_token_id': 2}


  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.6182887554168701, 'eval_runtime': 4.6811, 'eval_samples_per_second': 12.818, 'eval_steps_per_second': 1.709, 'epoch': 1.0}


Non-default generation parameters: {'forced_eos_token_id': 2}
Non-default generation parameters: {'forced_eos_token_id': 2}
Non-default generation parameters: {'forced_eos_token_id': 2}


  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.6122768521308899, 'eval_runtime': 5.0836, 'eval_samples_per_second': 11.803, 'eval_steps_per_second': 1.574, 'epoch': 2.0}


Non-default generation parameters: {'forced_eos_token_id': 2}
Non-default generation parameters: {'forced_eos_token_id': 2}
Non-default generation parameters: {'forced_eos_token_id': 2}


  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.6219106912612915, 'eval_runtime': 4.6741, 'eval_samples_per_second': 12.837, 'eval_steps_per_second': 1.712, 'epoch': 3.0}


Non-default generation parameters: {'forced_eos_token_id': 2}
Non-default generation parameters: {'forced_eos_token_id': 2}
Non-default generation parameters: {'forced_eos_token_id': 2}


  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.6310604810714722, 'eval_runtime': 5.4688, 'eval_samples_per_second': 10.971, 'eval_steps_per_second': 1.463, 'epoch': 4.0}


Non-default generation parameters: {'forced_eos_token_id': 2}
Non-default generation parameters: {'forced_eos_token_id': 2}
Non-default generation parameters: {'forced_eos_token_id': 2}


  0%|          | 0/8 [00:00<?, ?it/s]

{'eval_loss': 0.613914430141449, 'eval_runtime': 5.2765, 'eval_samples_per_second': 11.371, 'eval_steps_per_second': 1.516, 'epoch': 5.0}
{'train_runtime': 1064.617, 'train_samples_per_second': 1.127, 'train_steps_per_second': 0.141, 'train_loss': 0.7048102315266928, 'epoch': 5.0}
Fine-tuning complete. Model saved to ./models.
Last checkpoint 150


<transformers.trainer.Trainer at 0x168f538a350>