Using models from HF to do a sentiment classification

In [1]:
from transformers import pipeline
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
classifier = pipeline("sentiment-analysis")

In [None]:
pred = classifier("Join the growing community on the Hub, forum, or Discord today")
pred

In [None]:
pred = classifier(["Join the growing community on the Hub, forum, or Discord today", 
"At Microsoft, AI software development is guided by a set of six principles, designed to ensure that AI applications provide amazing solutions to difficult problems without any unintended negative consequences."])
pred

specify model name

In [None]:
model_name = 'distilbert-base-uncased-finetuned-sst-2-english'
classifier = pipeline("sentiment-analysis", model=model_name)

In [None]:
pred = classifier(["Join the growing community on the Hub, forum, or Discord today", 
"At Microsoft, AI software development is guided by a set of six principles, designed to ensure that AI applications provide amazing solutions to difficult problems without any unintended negative consequences."])
pred

In [None]:
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
model_name = 'distilbert-base-uncased-finetuned-sst-2-english'
classifier = pipeline("sentiment-analysis", model=model_name)

In [None]:
pred = classifier(["Join the growing community on the Hub, forum, or Discord today", 
"At Microsoft, AI software development is guided by a set of six principles, designed to ensure that AI applications provide amazing solutions to difficult problems without any unintended negative consequences."])
pred

using tokenizer

In [None]:
tokens = tokenizer.tokenize("Join the growing community on the Hub, forum, or Discord today")
input_ids = tokenizer("Join the growing community on the Hub, forum, or Discord today")
input_ids

In [None]:
token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(token_ids)

In [None]:
x_train = ["Join the growing community on the Hub, forum, or Discord today", 
"At Microsoft, AI software development is guided by a set of six principles, designed to ensure that AI applications provide amazing solutions to \
difficult problems without any unintended negative consequences."]

batch = tokenizer(x_train, padding=True, truncation=True, max_length=512, return_tensors='pt')  # pt for pytorch

In [None]:
print(batch)

In [None]:
with torch.no_grad():
    outputs = model(**batch)

print(outputs)
predictions = F.softmax(outputs.logits, dim=1)
labels = torch.argmax(predictions, dim=1)
print(labels)
labels_text = [model.config.id2label[label_id] for label_id in labels.tolist()]
labels_text

want a loss

In [None]:
with torch.no_grad():
    outputs = model(**batch, labels=torch.tensor([1, 0]))

print(outputs)
predictions = F.softmax(outputs.logits, dim=1)
labels = torch.argmax(predictions, dim=1)
print(labels)
labels_text = [model.config.id2label[label_id] for label_id in labels.tolist()]
labels_text

save and load

In [None]:
save_directory = "saved"

In [None]:
tokenizer.save_pretrained(save_directory)

In [None]:
model.save_pretrained(save_directory)

In [None]:
tokenizer = AutoTokenizer.from_pretrained(save_directory)
model = AutoModelForSequenceClassification.from_pretrained(save_directory)

using a German sentiment classification model

In [None]:
model_name = "oliverguhr/german-sentiment-bert"

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

In [None]:
texts = ["Mit keinem guten Ergebniss","Das ist gar nicht mal so gut",
    "Total awesome!","nicht so schlecht wie erwartet",
    "Der Test verlief positiv.","Sie fährt ein grünes Auto."]

batch = tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors='pt')  # pt for pytorch
print(batch)

with torch.no_grad():
    outputs = model(**batch)
    # predictions = F.softmax(outputs.logits, dim=1)
    labels = torch.argmax(outputs.logits, dim=1)
    print(labels)
    labels_text = [model.config.id2label[label_id] for label_id in labels.tolist()]
    print(labels_text)

fine-tuning

1. prepare dataset

2. load pretrained tokenizer, call it with dataset -> encoding

3. build pytorch dataset w encoding

4. load pretrained model

5.  a) load trainer and train it

    b) or use native pytorch training pipeline

In [2]:
from torch.utils.data import Dataset
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from transformers import Trainer, TrainingArguments
from pathlib import Path
from sklearn.model_selection import train_test_split

In [3]:
model_name = "distilbert-base-uncased"

In [4]:
def read_imdb_split(split_dir):
    split_dir = Path(split_dir)
    texts = []
    labels = []
    for label_dir in ['pos', 'neg']:
        for text_file in (split_dir/label_dir).iterdir():
            texts.append(text_file.read_text())
            labels.append(0 if label_dir == 'neg' else 1)

    return texts, labels

In [5]:
train_texts, train_labels = read_imdb_split('/home/tian/Projects/d2l/data/aclImdb/train')
test_texts, test_labels = read_imdb_split('/home/tian/Projects/d2l/data/aclImdb/test')

In [6]:
train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=2)

In [7]:
class IMDbDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, index) :
        item = {key: torch.tensor(val[index])for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[index])
        return item
    
    def __len__(self):
        return len(self.labels)

In [8]:
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name) # a generate tokenizer

In [9]:
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
test_encodings = tokenizer(test_texts, truncation=True, padding=True)

In [10]:
train_dataset = IMDbDataset(train_encodings, train_labels)
val_dataset = IMDbDataset(val_encodings, val_labels)
test_dataset = IMDbDataset(test_encodings, test_labels)

In [11]:
training_arguments = TrainingArguments(
    output_dir='./results',               # output directory
    num_train_epochs=2,                   # max traning epochs
    per_device_train_batch_size=16,       # batch size per device during training
    per_device_eval_batch_size=64,        # batch size per device during evaluation
    warmup_steps=500,                     # number of warmup steps for learning rate scheduler
    learning_rate=5e-5,                   # lr
    weight_decay=0.01,
    logging_dir='./logs',                 # directory for storing logs
    logging_steps=10    
)

In [14]:
model = DistilBertForSequenceClassification.from_pretrained(model_name)
trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.bias', 'pre_classi

In [15]:
trainer.train()

***** Running training *****
  Num examples = 24998
  Num Epochs = 2
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 3126
  Number of trainable parameters = 66955010
  0%|          | 10/3126 [00:03<16:05,  3.23it/s]

{'loss': 0.6983, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.01}


  1%|          | 20/3126 [00:06<15:41,  3.30it/s]

{'loss': 0.6979, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.01}


  1%|          | 30/3126 [00:09<15:39,  3.29it/s]

{'loss': 0.6941, 'learning_rate': 3e-06, 'epoch': 0.02}


  1%|▏         | 40/3126 [00:12<15:38,  3.29it/s]

{'loss': 0.6934, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.03}


  2%|▏         | 50/3126 [00:15<15:35,  3.29it/s]

{'loss': 0.6823, 'learning_rate': 5e-06, 'epoch': 0.03}


  2%|▏         | 60/3126 [00:18<15:33,  3.29it/s]

{'loss': 0.6888, 'learning_rate': 6e-06, 'epoch': 0.04}


  2%|▏         | 70/3126 [00:21<15:29,  3.29it/s]

{'loss': 0.668, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.04}


  3%|▎         | 80/3126 [00:24<15:27,  3.28it/s]

{'loss': 0.6694, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.05}


  3%|▎         | 90/3126 [00:28<15:24,  3.29it/s]

{'loss': 0.6383, 'learning_rate': 9e-06, 'epoch': 0.06}


  3%|▎         | 100/3126 [00:31<15:22,  3.28it/s]

{'loss': 0.5891, 'learning_rate': 1e-05, 'epoch': 0.06}


  4%|▎         | 110/3126 [00:34<15:18,  3.28it/s]

{'loss': 0.5137, 'learning_rate': 1.1000000000000001e-05, 'epoch': 0.07}


  4%|▍         | 120/3126 [00:37<15:14,  3.29it/s]

{'loss': 0.4725, 'learning_rate': 1.2e-05, 'epoch': 0.08}


  4%|▍         | 130/3126 [00:40<15:11,  3.29it/s]

{'loss': 0.3921, 'learning_rate': 1.3000000000000001e-05, 'epoch': 0.08}


  4%|▍         | 140/3126 [00:43<15:09,  3.28it/s]

{'loss': 0.313, 'learning_rate': 1.4000000000000001e-05, 'epoch': 0.09}


  5%|▍         | 150/3126 [00:46<15:05,  3.29it/s]

{'loss': 0.3548, 'learning_rate': 1.5e-05, 'epoch': 0.1}


  5%|▌         | 160/3126 [00:49<15:03,  3.28it/s]

{'loss': 0.3815, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.1}


  5%|▌         | 170/3126 [00:52<15:02,  3.27it/s]

{'loss': 0.3284, 'learning_rate': 1.7000000000000003e-05, 'epoch': 0.11}


  6%|▌         | 180/3126 [00:55<14:58,  3.28it/s]

{'loss': 0.3788, 'learning_rate': 1.8e-05, 'epoch': 0.12}


  6%|▌         | 190/3126 [00:58<14:56,  3.27it/s]

{'loss': 0.3152, 'learning_rate': 1.9e-05, 'epoch': 0.12}


  6%|▋         | 200/3126 [01:01<14:53,  3.27it/s]

{'loss': 0.3104, 'learning_rate': 2e-05, 'epoch': 0.13}


  7%|▋         | 210/3126 [01:04<14:50,  3.27it/s]

{'loss': 0.3118, 'learning_rate': 2.1e-05, 'epoch': 0.13}


  7%|▋         | 220/3126 [01:07<14:49,  3.27it/s]

{'loss': 0.3187, 'learning_rate': 2.2000000000000003e-05, 'epoch': 0.14}


  7%|▋         | 230/3126 [01:10<14:44,  3.27it/s]

{'loss': 0.2821, 'learning_rate': 2.3000000000000003e-05, 'epoch': 0.15}


  8%|▊         | 240/3126 [01:13<14:40,  3.28it/s]

{'loss': 0.3184, 'learning_rate': 2.4e-05, 'epoch': 0.15}


  8%|▊         | 250/3126 [01:16<14:39,  3.27it/s]

{'loss': 0.3353, 'learning_rate': 2.5e-05, 'epoch': 0.16}


  8%|▊         | 260/3126 [01:19<14:36,  3.27it/s]

{'loss': 0.395, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.17}


  9%|▊         | 270/3126 [01:23<14:32,  3.27it/s]

{'loss': 0.2205, 'learning_rate': 2.7000000000000002e-05, 'epoch': 0.17}


  9%|▉         | 280/3126 [01:26<14:30,  3.27it/s]

{'loss': 0.2477, 'learning_rate': 2.8000000000000003e-05, 'epoch': 0.18}


  9%|▉         | 290/3126 [01:29<14:27,  3.27it/s]

{'loss': 0.3279, 'learning_rate': 2.9e-05, 'epoch': 0.19}


 10%|▉         | 300/3126 [01:32<14:24,  3.27it/s]

{'loss': 0.2936, 'learning_rate': 3e-05, 'epoch': 0.19}


 10%|▉         | 310/3126 [01:35<14:21,  3.27it/s]

{'loss': 0.2578, 'learning_rate': 3.1e-05, 'epoch': 0.2}


 10%|█         | 320/3126 [01:38<14:18,  3.27it/s]

{'loss': 0.3455, 'learning_rate': 3.2000000000000005e-05, 'epoch': 0.2}


 11%|█         | 330/3126 [01:41<14:16,  3.26it/s]

{'loss': 0.3837, 'learning_rate': 3.3e-05, 'epoch': 0.21}


 11%|█         | 340/3126 [01:44<14:12,  3.27it/s]

{'loss': 0.3261, 'learning_rate': 3.4000000000000007e-05, 'epoch': 0.22}


 11%|█         | 350/3126 [01:47<14:09,  3.27it/s]

{'loss': 0.3239, 'learning_rate': 3.5e-05, 'epoch': 0.22}


 12%|█▏        | 360/3126 [01:50<14:06,  3.27it/s]

{'loss': 0.2263, 'learning_rate': 3.6e-05, 'epoch': 0.23}


 12%|█▏        | 370/3126 [01:53<14:05,  3.26it/s]

{'loss': 0.3639, 'learning_rate': 3.7e-05, 'epoch': 0.24}


 12%|█▏        | 380/3126 [01:56<13:59,  3.27it/s]

{'loss': 0.2658, 'learning_rate': 3.8e-05, 'epoch': 0.24}


 12%|█▏        | 390/3126 [01:59<13:57,  3.27it/s]

{'loss': 0.3844, 'learning_rate': 3.9000000000000006e-05, 'epoch': 0.25}


 13%|█▎        | 400/3126 [02:02<13:55,  3.26it/s]

{'loss': 0.4273, 'learning_rate': 4e-05, 'epoch': 0.26}


 13%|█▎        | 410/3126 [02:05<13:51,  3.27it/s]

{'loss': 0.4582, 'learning_rate': 4.1e-05, 'epoch': 0.26}


 13%|█▎        | 420/3126 [02:08<13:48,  3.26it/s]

{'loss': 0.2962, 'learning_rate': 4.2e-05, 'epoch': 0.27}


 14%|█▍        | 430/3126 [02:12<13:46,  3.26it/s]

{'loss': 0.2722, 'learning_rate': 4.3e-05, 'epoch': 0.28}


 14%|█▍        | 440/3126 [02:15<13:43,  3.26it/s]

{'loss': 0.3764, 'learning_rate': 4.4000000000000006e-05, 'epoch': 0.28}


 14%|█▍        | 450/3126 [02:18<13:38,  3.27it/s]

{'loss': 0.288, 'learning_rate': 4.5e-05, 'epoch': 0.29}


 15%|█▍        | 460/3126 [02:21<13:37,  3.26it/s]

{'loss': 0.2401, 'learning_rate': 4.600000000000001e-05, 'epoch': 0.29}


 15%|█▌        | 470/3126 [02:24<13:34,  3.26it/s]

{'loss': 0.3321, 'learning_rate': 4.7e-05, 'epoch': 0.3}


 15%|█▌        | 480/3126 [02:27<13:31,  3.26it/s]

{'loss': 0.2302, 'learning_rate': 4.8e-05, 'epoch': 0.31}


 16%|█▌        | 490/3126 [02:30<13:28,  3.26it/s]

{'loss': 0.2645, 'learning_rate': 4.9e-05, 'epoch': 0.31}


 16%|█▌        | 500/3126 [02:33<13:24,  3.26it/s]Saving model checkpoint to ./results/checkpoint-500
Configuration saved in ./results/checkpoint-500/config.json


{'loss': 0.3394, 'learning_rate': 5e-05, 'epoch': 0.32}


Model weights saved in ./results/checkpoint-500/pytorch_model.bin
 16%|█▋        | 510/3126 [02:37<13:52,  3.14it/s]

{'loss': 0.3658, 'learning_rate': 4.980959634424981e-05, 'epoch': 0.33}


 17%|█▋        | 520/3126 [02:40<13:20,  3.26it/s]

{'loss': 0.3867, 'learning_rate': 4.9619192688499624e-05, 'epoch': 0.33}


 17%|█▋        | 530/3126 [02:43<13:14,  3.27it/s]

{'loss': 0.3205, 'learning_rate': 4.942878903274943e-05, 'epoch': 0.34}


 17%|█▋        | 540/3126 [02:46<13:11,  3.27it/s]

{'loss': 0.3452, 'learning_rate': 4.923838537699924e-05, 'epoch': 0.35}


 18%|█▊        | 550/3126 [02:49<13:08,  3.27it/s]

{'loss': 0.2404, 'learning_rate': 4.9047981721249046e-05, 'epoch': 0.35}


 18%|█▊        | 560/3126 [02:52<13:04,  3.27it/s]

{'loss': 0.3395, 'learning_rate': 4.885757806549886e-05, 'epoch': 0.36}


 18%|█▊        | 570/3126 [02:55<13:03,  3.26it/s]

{'loss': 0.2695, 'learning_rate': 4.866717440974867e-05, 'epoch': 0.36}


 19%|█▊        | 580/3126 [02:59<13:00,  3.26it/s]

{'loss': 0.2735, 'learning_rate': 4.847677075399848e-05, 'epoch': 0.37}


 19%|█▉        | 590/3126 [03:02<12:56,  3.27it/s]

{'loss': 0.2677, 'learning_rate': 4.828636709824829e-05, 'epoch': 0.38}


 19%|█▉        | 600/3126 [03:05<12:54,  3.26it/s]

{'loss': 0.1957, 'learning_rate': 4.8095963442498096e-05, 'epoch': 0.38}


 20%|█▉        | 610/3126 [03:08<12:50,  3.26it/s]

{'loss': 0.3031, 'learning_rate': 4.7905559786747903e-05, 'epoch': 0.39}


 20%|█▉        | 620/3126 [03:11<12:49,  3.26it/s]

{'loss': 0.3485, 'learning_rate': 4.771515613099772e-05, 'epoch': 0.4}


 20%|██        | 630/3126 [03:14<12:45,  3.26it/s]

{'loss': 0.2493, 'learning_rate': 4.7524752475247525e-05, 'epoch': 0.4}


 20%|██        | 640/3126 [03:17<12:41,  3.26it/s]

{'loss': 0.2957, 'learning_rate': 4.733434881949734e-05, 'epoch': 0.41}


 21%|██        | 650/3126 [03:20<12:38,  3.26it/s]

{'loss': 0.3773, 'learning_rate': 4.7143945163747146e-05, 'epoch': 0.42}


 21%|██        | 660/3126 [03:23<12:36,  3.26it/s]

{'loss': 0.2887, 'learning_rate': 4.6953541507996954e-05, 'epoch': 0.42}


 21%|██▏       | 670/3126 [03:26<12:33,  3.26it/s]

{'loss': 0.2713, 'learning_rate': 4.676313785224676e-05, 'epoch': 0.43}


 22%|██▏       | 680/3126 [03:29<12:30,  3.26it/s]

{'loss': 0.2507, 'learning_rate': 4.6572734196496575e-05, 'epoch': 0.44}


 22%|██▏       | 690/3126 [03:32<12:26,  3.26it/s]

{'loss': 0.1485, 'learning_rate': 4.638233054074638e-05, 'epoch': 0.44}


 22%|██▏       | 700/3126 [03:35<12:24,  3.26it/s]

{'loss': 0.3612, 'learning_rate': 4.61919268849962e-05, 'epoch': 0.45}


 23%|██▎       | 710/3126 [03:38<12:21,  3.26it/s]

{'loss': 0.2119, 'learning_rate': 4.6001523229246004e-05, 'epoch': 0.45}


 23%|██▎       | 720/3126 [03:41<12:17,  3.26it/s]

{'loss': 0.2267, 'learning_rate': 4.581111957349581e-05, 'epoch': 0.46}


 23%|██▎       | 730/3126 [03:45<12:14,  3.26it/s]

{'loss': 0.3402, 'learning_rate': 4.562071591774562e-05, 'epoch': 0.47}


 24%|██▎       | 740/3126 [03:48<12:10,  3.26it/s]

{'loss': 0.3316, 'learning_rate': 4.543031226199543e-05, 'epoch': 0.47}


 24%|██▍       | 750/3126 [03:51<12:08,  3.26it/s]

{'loss': 0.2752, 'learning_rate': 4.523990860624524e-05, 'epoch': 0.48}


 24%|██▍       | 760/3126 [03:54<12:05,  3.26it/s]

{'loss': 0.2563, 'learning_rate': 4.5049504950495054e-05, 'epoch': 0.49}


 25%|██▍       | 770/3126 [03:57<12:02,  3.26it/s]

{'loss': 0.2566, 'learning_rate': 4.485910129474486e-05, 'epoch': 0.49}


 25%|██▍       | 780/3126 [04:00<11:59,  3.26it/s]

{'loss': 0.333, 'learning_rate': 4.466869763899467e-05, 'epoch': 0.5}


 25%|██▌       | 790/3126 [04:03<11:56,  3.26it/s]

{'loss': 0.3259, 'learning_rate': 4.4478293983244476e-05, 'epoch': 0.51}


 26%|██▌       | 800/3126 [04:06<11:53,  3.26it/s]

{'loss': 0.246, 'learning_rate': 4.428789032749429e-05, 'epoch': 0.51}


 26%|██▌       | 810/3126 [04:09<11:50,  3.26it/s]

{'loss': 0.2049, 'learning_rate': 4.40974866717441e-05, 'epoch': 0.52}


 26%|██▌       | 820/3126 [04:12<11:46,  3.26it/s]

{'loss': 0.1862, 'learning_rate': 4.390708301599391e-05, 'epoch': 0.52}


 27%|██▋       | 830/3126 [04:15<11:43,  3.26it/s]

{'loss': 0.3135, 'learning_rate': 4.371667936024372e-05, 'epoch': 0.53}


 27%|██▋       | 840/3126 [04:18<11:39,  3.27it/s]

{'loss': 0.3318, 'learning_rate': 4.3526275704493527e-05, 'epoch': 0.54}


 27%|██▋       | 850/3126 [04:21<11:37,  3.26it/s]

{'loss': 0.3613, 'learning_rate': 4.3335872048743334e-05, 'epoch': 0.54}


 28%|██▊       | 860/3126 [04:24<11:35,  3.26it/s]

{'loss': 0.273, 'learning_rate': 4.314546839299315e-05, 'epoch': 0.55}


 28%|██▊       | 870/3126 [04:28<11:32,  3.26it/s]

{'loss': 0.2181, 'learning_rate': 4.2955064737242955e-05, 'epoch': 0.56}


 28%|██▊       | 880/3126 [04:31<11:28,  3.26it/s]

{'loss': 0.2836, 'learning_rate': 4.276466108149277e-05, 'epoch': 0.56}


 28%|██▊       | 890/3126 [04:34<11:25,  3.26it/s]

{'loss': 0.262, 'learning_rate': 4.257425742574258e-05, 'epoch': 0.57}


 29%|██▉       | 900/3126 [04:37<11:23,  3.26it/s]

{'loss': 0.2113, 'learning_rate': 4.238385376999239e-05, 'epoch': 0.58}


 29%|██▉       | 910/3126 [04:40<11:20,  3.26it/s]

{'loss': 0.3333, 'learning_rate': 4.219345011424219e-05, 'epoch': 0.58}


 29%|██▉       | 920/3126 [04:43<11:16,  3.26it/s]

{'loss': 0.2957, 'learning_rate': 4.2003046458492006e-05, 'epoch': 0.59}


 30%|██▉       | 930/3126 [04:46<11:12,  3.26it/s]

{'loss': 0.2796, 'learning_rate': 4.181264280274181e-05, 'epoch': 0.6}


 30%|███       | 940/3126 [04:49<11:11,  3.26it/s]

{'loss': 0.3853, 'learning_rate': 4.162223914699163e-05, 'epoch': 0.6}


 30%|███       | 950/3126 [04:52<11:07,  3.26it/s]

{'loss': 0.2207, 'learning_rate': 4.1431835491241434e-05, 'epoch': 0.61}


 31%|███       | 960/3126 [04:55<11:03,  3.26it/s]

{'loss': 0.3143, 'learning_rate': 4.124143183549125e-05, 'epoch': 0.61}


 31%|███       | 970/3126 [04:58<11:01,  3.26it/s]

{'loss': 0.2497, 'learning_rate': 4.105102817974105e-05, 'epoch': 0.62}


 31%|███▏      | 980/3126 [05:01<10:58,  3.26it/s]

{'loss': 0.1529, 'learning_rate': 4.086062452399086e-05, 'epoch': 0.63}


 32%|███▏      | 990/3126 [05:04<10:56,  3.25it/s]

{'loss': 0.2832, 'learning_rate': 4.067022086824067e-05, 'epoch': 0.63}


 32%|███▏      | 1000/3126 [05:07<10:52,  3.26it/s]Saving model checkpoint to ./results/checkpoint-1000
Configuration saved in ./results/checkpoint-1000/config.json


{'loss': 0.3255, 'learning_rate': 4.0479817212490485e-05, 'epoch': 0.64}


Model weights saved in ./results/checkpoint-1000/pytorch_model.bin
 32%|███▏      | 1010/3126 [05:12<11:14,  3.14it/s]

{'loss': 0.2108, 'learning_rate': 4.028941355674029e-05, 'epoch': 0.65}


 33%|███▎      | 1020/3126 [05:15<10:46,  3.26it/s]

{'loss': 0.1843, 'learning_rate': 4.0099009900990106e-05, 'epoch': 0.65}


 33%|███▎      | 1030/3126 [05:18<10:42,  3.26it/s]

{'loss': 0.1657, 'learning_rate': 3.990860624523991e-05, 'epoch': 0.66}


 33%|███▎      | 1040/3126 [05:21<10:39,  3.26it/s]

{'loss': 0.2603, 'learning_rate': 3.971820258948972e-05, 'epoch': 0.67}


 34%|███▎      | 1050/3126 [05:24<10:36,  3.26it/s]

{'loss': 0.4957, 'learning_rate': 3.952779893373953e-05, 'epoch': 0.67}


 34%|███▍      | 1060/3126 [05:27<10:33,  3.26it/s]

{'loss': 0.3104, 'learning_rate': 3.933739527798934e-05, 'epoch': 0.68}


 34%|███▍      | 1070/3126 [05:30<10:30,  3.26it/s]

{'loss': 0.2189, 'learning_rate': 3.914699162223915e-05, 'epoch': 0.68}


 35%|███▍      | 1080/3126 [05:33<10:28,  3.26it/s]

{'loss': 0.2236, 'learning_rate': 3.8956587966488964e-05, 'epoch': 0.69}


 35%|███▍      | 1090/3126 [05:36<10:24,  3.26it/s]

{'loss': 0.2637, 'learning_rate': 3.876618431073877e-05, 'epoch': 0.7}


 35%|███▌      | 1100/3126 [05:39<10:21,  3.26it/s]

{'loss': 0.1892, 'learning_rate': 3.857578065498858e-05, 'epoch': 0.7}


 36%|███▌      | 1110/3126 [05:42<10:18,  3.26it/s]

{'loss': 0.2489, 'learning_rate': 3.8385376999238386e-05, 'epoch': 0.71}


 36%|███▌      | 1120/3126 [05:45<10:16,  3.26it/s]

{'loss': 0.2006, 'learning_rate': 3.81949733434882e-05, 'epoch': 0.72}


 36%|███▌      | 1130/3126 [05:48<10:12,  3.26it/s]

{'loss': 0.2377, 'learning_rate': 3.800456968773801e-05, 'epoch': 0.72}


 36%|███▋      | 1140/3126 [05:51<10:09,  3.26it/s]

{'loss': 0.3651, 'learning_rate': 3.7814166031987815e-05, 'epoch': 0.73}


 37%|███▋      | 1150/3126 [05:55<10:06,  3.26it/s]

{'loss': 0.2798, 'learning_rate': 3.762376237623763e-05, 'epoch': 0.74}


 37%|███▋      | 1160/3126 [05:58<10:03,  3.26it/s]

{'loss': 0.2247, 'learning_rate': 3.743335872048743e-05, 'epoch': 0.74}


 37%|███▋      | 1170/3126 [06:01<09:59,  3.26it/s]

{'loss': 0.2788, 'learning_rate': 3.7242955064737243e-05, 'epoch': 0.75}


 38%|███▊      | 1180/3126 [06:04<09:56,  3.26it/s]

{'loss': 0.2089, 'learning_rate': 3.705255140898705e-05, 'epoch': 0.75}


 38%|███▊      | 1190/3126 [06:07<09:55,  3.25it/s]

{'loss': 0.2626, 'learning_rate': 3.6862147753236865e-05, 'epoch': 0.76}


 38%|███▊      | 1200/3126 [06:10<09:51,  3.26it/s]

{'loss': 0.2097, 'learning_rate': 3.667174409748667e-05, 'epoch': 0.77}


 39%|███▊      | 1210/3126 [06:13<09:48,  3.26it/s]

{'loss': 0.1905, 'learning_rate': 3.6481340441736486e-05, 'epoch': 0.77}


 39%|███▉      | 1220/3126 [06:16<09:44,  3.26it/s]

{'loss': 0.3015, 'learning_rate': 3.6290936785986294e-05, 'epoch': 0.78}


 39%|███▉      | 1230/3126 [06:19<09:41,  3.26it/s]

{'loss': 0.1751, 'learning_rate': 3.61005331302361e-05, 'epoch': 0.79}


 40%|███▉      | 1240/3126 [06:22<09:38,  3.26it/s]

{'loss': 0.1525, 'learning_rate': 3.591012947448591e-05, 'epoch': 0.79}


 40%|███▉      | 1250/3126 [06:25<09:34,  3.26it/s]

{'loss': 0.2094, 'learning_rate': 3.571972581873572e-05, 'epoch': 0.8}


 40%|████      | 1260/3126 [06:28<09:33,  3.26it/s]

{'loss': 0.2629, 'learning_rate': 3.552932216298553e-05, 'epoch': 0.81}


 41%|████      | 1270/3126 [06:31<09:29,  3.26it/s]

{'loss': 0.2747, 'learning_rate': 3.5338918507235344e-05, 'epoch': 0.81}


 41%|████      | 1280/3126 [06:34<09:26,  3.26it/s]

{'loss': 0.1383, 'learning_rate': 3.514851485148515e-05, 'epoch': 0.82}


 41%|████▏     | 1290/3126 [06:38<09:23,  3.26it/s]

{'loss': 0.2039, 'learning_rate': 3.495811119573496e-05, 'epoch': 0.83}


 42%|████▏     | 1300/3126 [06:41<09:20,  3.26it/s]

{'loss': 0.165, 'learning_rate': 3.4767707539984766e-05, 'epoch': 0.83}


 42%|████▏     | 1310/3126 [06:44<09:18,  3.25it/s]

{'loss': 0.299, 'learning_rate': 3.457730388423458e-05, 'epoch': 0.84}


 42%|████▏     | 1320/3126 [06:47<09:14,  3.26it/s]

{'loss': 0.2257, 'learning_rate': 3.438690022848439e-05, 'epoch': 0.84}


 43%|████▎     | 1330/3126 [06:50<09:10,  3.26it/s]

{'loss': 0.2864, 'learning_rate': 3.41964965727342e-05, 'epoch': 0.85}


 43%|████▎     | 1340/3126 [06:53<09:07,  3.26it/s]

{'loss': 0.2339, 'learning_rate': 3.400609291698401e-05, 'epoch': 0.86}


 43%|████▎     | 1350/3126 [06:56<09:05,  3.26it/s]

{'loss': 0.2445, 'learning_rate': 3.3815689261233816e-05, 'epoch': 0.86}


 44%|████▎     | 1360/3126 [06:59<09:02,  3.26it/s]

{'loss': 0.1643, 'learning_rate': 3.3625285605483624e-05, 'epoch': 0.87}


 44%|████▍     | 1370/3126 [07:02<08:57,  3.26it/s]

{'loss': 0.2444, 'learning_rate': 3.343488194973344e-05, 'epoch': 0.88}


 44%|████▍     | 1380/3126 [07:05<08:56,  3.26it/s]

{'loss': 0.3121, 'learning_rate': 3.3244478293983245e-05, 'epoch': 0.88}


 44%|████▍     | 1390/3126 [07:08<08:53,  3.25it/s]

{'loss': 0.2163, 'learning_rate': 3.305407463823306e-05, 'epoch': 0.89}


 45%|████▍     | 1400/3126 [07:11<08:50,  3.25it/s]

{'loss': 0.2041, 'learning_rate': 3.2863670982482866e-05, 'epoch': 0.9}


 45%|████▌     | 1410/3126 [07:14<08:46,  3.26it/s]

{'loss': 0.2135, 'learning_rate': 3.2673267326732674e-05, 'epoch': 0.9}


 45%|████▌     | 1420/3126 [07:17<08:43,  3.26it/s]

{'loss': 0.2395, 'learning_rate': 3.248286367098248e-05, 'epoch': 0.91}


 46%|████▌     | 1430/3126 [07:21<08:40,  3.26it/s]

{'loss': 0.2319, 'learning_rate': 3.2292460015232295e-05, 'epoch': 0.91}


 46%|████▌     | 1440/3126 [07:24<08:37,  3.26it/s]

{'loss': 0.2426, 'learning_rate': 3.21020563594821e-05, 'epoch': 0.92}


 46%|████▋     | 1450/3126 [07:27<08:34,  3.26it/s]

{'loss': 0.2025, 'learning_rate': 3.191165270373192e-05, 'epoch': 0.93}


 47%|████▋     | 1460/3126 [07:30<08:31,  3.26it/s]

{'loss': 0.193, 'learning_rate': 3.1721249047981724e-05, 'epoch': 0.93}


 47%|████▋     | 1470/3126 [07:33<08:27,  3.26it/s]

{'loss': 0.2168, 'learning_rate': 3.153084539223153e-05, 'epoch': 0.94}


 47%|████▋     | 1480/3126 [07:36<08:25,  3.26it/s]

{'loss': 0.3184, 'learning_rate': 3.134044173648134e-05, 'epoch': 0.95}


 48%|████▊     | 1490/3126 [07:39<08:22,  3.26it/s]

{'loss': 0.2528, 'learning_rate': 3.115003808073115e-05, 'epoch': 0.95}


 48%|████▊     | 1500/3126 [07:42<08:18,  3.26it/s]Saving model checkpoint to ./results/checkpoint-1500
Configuration saved in ./results/checkpoint-1500/config.json


{'loss': 0.2765, 'learning_rate': 3.095963442498096e-05, 'epoch': 0.96}


Model weights saved in ./results/checkpoint-1500/pytorch_model.bin
 48%|████▊     | 1510/3126 [07:46<08:34,  3.14it/s]

{'loss': 0.2091, 'learning_rate': 3.0769230769230774e-05, 'epoch': 0.97}


 49%|████▊     | 1520/3126 [07:49<08:12,  3.26it/s]

{'loss': 0.2214, 'learning_rate': 3.057882711348058e-05, 'epoch': 0.97}


 49%|████▉     | 1530/3126 [07:52<08:09,  3.26it/s]

{'loss': 0.2275, 'learning_rate': 3.0388423457730392e-05, 'epoch': 0.98}


 49%|████▉     | 1540/3126 [07:55<08:06,  3.26it/s]

{'loss': 0.2202, 'learning_rate': 3.01980198019802e-05, 'epoch': 0.99}


 50%|████▉     | 1550/3126 [07:58<08:03,  3.26it/s]

{'loss': 0.2264, 'learning_rate': 3.000761614623001e-05, 'epoch': 0.99}


 50%|████▉     | 1560/3126 [08:01<08:00,  3.26it/s]

{'loss': 0.1967, 'learning_rate': 2.9817212490479818e-05, 'epoch': 1.0}


 50%|█████     | 1570/3126 [08:04<07:50,  3.31it/s]

{'loss': 0.3014, 'learning_rate': 2.962680883472963e-05, 'epoch': 1.0}


 51%|█████     | 1580/3126 [08:07<07:53,  3.26it/s]

{'loss': 0.2189, 'learning_rate': 2.9436405178979436e-05, 'epoch': 1.01}


 51%|█████     | 1590/3126 [08:11<07:50,  3.26it/s]

{'loss': 0.1983, 'learning_rate': 2.924600152322925e-05, 'epoch': 1.02}


 51%|█████     | 1600/3126 [08:14<07:48,  3.26it/s]

{'loss': 0.1073, 'learning_rate': 2.9055597867479057e-05, 'epoch': 1.02}


 52%|█████▏    | 1610/3126 [08:17<07:45,  3.26it/s]

{'loss': 0.1002, 'learning_rate': 2.8865194211728868e-05, 'epoch': 1.03}


 52%|█████▏    | 1620/3126 [08:20<07:42,  3.26it/s]

{'loss': 0.0793, 'learning_rate': 2.8674790555978675e-05, 'epoch': 1.04}


 52%|█████▏    | 1630/3126 [08:23<07:39,  3.26it/s]

{'loss': 0.1426, 'learning_rate': 2.848438690022849e-05, 'epoch': 1.04}


 52%|█████▏    | 1640/3126 [08:26<07:36,  3.26it/s]

{'loss': 0.213, 'learning_rate': 2.8293983244478294e-05, 'epoch': 1.05}


 53%|█████▎    | 1650/3126 [08:29<07:32,  3.26it/s]

{'loss': 0.1835, 'learning_rate': 2.8103579588728108e-05, 'epoch': 1.06}


 53%|█████▎    | 1660/3126 [08:32<07:29,  3.26it/s]

{'loss': 0.0723, 'learning_rate': 2.7913175932977915e-05, 'epoch': 1.06}


 53%|█████▎    | 1670/3126 [08:35<07:27,  3.26it/s]

{'loss': 0.2043, 'learning_rate': 2.7722772277227726e-05, 'epoch': 1.07}


 54%|█████▎    | 1680/3126 [08:38<07:23,  3.26it/s]

{'loss': 0.1955, 'learning_rate': 2.7532368621477533e-05, 'epoch': 1.07}


 54%|█████▍    | 1690/3126 [08:41<07:20,  3.26it/s]

{'loss': 0.1698, 'learning_rate': 2.7341964965727347e-05, 'epoch': 1.08}


 54%|█████▍    | 1700/3126 [08:44<07:16,  3.26it/s]

{'loss': 0.0931, 'learning_rate': 2.715156130997715e-05, 'epoch': 1.09}


 55%|█████▍    | 1710/3126 [08:47<07:14,  3.26it/s]

{'loss': 0.1111, 'learning_rate': 2.6961157654226965e-05, 'epoch': 1.09}


 55%|█████▌    | 1720/3126 [08:50<07:11,  3.26it/s]

{'loss': 0.1622, 'learning_rate': 2.6770753998476773e-05, 'epoch': 1.1}


 55%|█████▌    | 1730/3126 [08:54<07:08,  3.26it/s]

{'loss': 0.0762, 'learning_rate': 2.6580350342726583e-05, 'epoch': 1.11}


 56%|█████▌    | 1740/3126 [08:57<07:05,  3.26it/s]

{'loss': 0.1558, 'learning_rate': 2.638994668697639e-05, 'epoch': 1.11}


 56%|█████▌    | 1750/3126 [09:00<07:02,  3.26it/s]

{'loss': 0.2216, 'learning_rate': 2.6199543031226205e-05, 'epoch': 1.12}


 56%|█████▋    | 1760/3126 [09:03<06:58,  3.26it/s]

{'loss': 0.0744, 'learning_rate': 2.6009139375476012e-05, 'epoch': 1.13}


 57%|█████▋    | 1770/3126 [09:06<06:56,  3.25it/s]

{'loss': 0.1775, 'learning_rate': 2.5818735719725816e-05, 'epoch': 1.13}


 57%|█████▋    | 1780/3126 [09:09<06:53,  3.26it/s]

{'loss': 0.0808, 'learning_rate': 2.562833206397563e-05, 'epoch': 1.14}


 57%|█████▋    | 1790/3126 [09:12<06:50,  3.26it/s]

{'loss': 0.1597, 'learning_rate': 2.5437928408225438e-05, 'epoch': 1.15}


 58%|█████▊    | 1800/3126 [09:15<06:46,  3.26it/s]

{'loss': 0.1425, 'learning_rate': 2.5247524752475248e-05, 'epoch': 1.15}


 58%|█████▊    | 1810/3126 [09:18<06:43,  3.26it/s]

{'loss': 0.1394, 'learning_rate': 2.5057121096725056e-05, 'epoch': 1.16}


 58%|█████▊    | 1820/3126 [09:21<06:40,  3.26it/s]

{'loss': 0.1675, 'learning_rate': 2.486671744097487e-05, 'epoch': 1.16}


 59%|█████▊    | 1830/3126 [09:24<06:37,  3.26it/s]

{'loss': 0.1826, 'learning_rate': 2.4676313785224677e-05, 'epoch': 1.17}


 59%|█████▉    | 1840/3126 [09:27<06:34,  3.26it/s]

{'loss': 0.2063, 'learning_rate': 2.4485910129474488e-05, 'epoch': 1.18}


 59%|█████▉    | 1850/3126 [09:30<06:32,  3.25it/s]

{'loss': 0.1538, 'learning_rate': 2.42955064737243e-05, 'epoch': 1.18}


 60%|█████▉    | 1860/3126 [09:33<06:28,  3.26it/s]

{'loss': 0.105, 'learning_rate': 2.4105102817974106e-05, 'epoch': 1.19}


 60%|█████▉    | 1870/3126 [09:37<06:25,  3.26it/s]

{'loss': 0.1025, 'learning_rate': 2.3914699162223917e-05, 'epoch': 1.2}


 60%|██████    | 1880/3126 [09:40<06:22,  3.25it/s]

{'loss': 0.1742, 'learning_rate': 2.3724295506473727e-05, 'epoch': 1.2}


 60%|██████    | 1890/3126 [09:43<06:19,  3.26it/s]

{'loss': 0.1818, 'learning_rate': 2.3533891850723535e-05, 'epoch': 1.21}


 61%|██████    | 1900/3126 [09:46<06:17,  3.25it/s]

{'loss': 0.2033, 'learning_rate': 2.3343488194973345e-05, 'epoch': 1.22}


 61%|██████    | 1910/3126 [09:49<06:13,  3.26it/s]

{'loss': 0.1279, 'learning_rate': 2.3153084539223156e-05, 'epoch': 1.22}


 61%|██████▏   | 1920/3126 [09:52<06:10,  3.26it/s]

{'loss': 0.09, 'learning_rate': 2.2962680883472963e-05, 'epoch': 1.23}


 62%|██████▏   | 1930/3126 [09:55<06:06,  3.26it/s]

{'loss': 0.1332, 'learning_rate': 2.2772277227722774e-05, 'epoch': 1.23}


 62%|██████▏   | 1940/3126 [09:58<06:04,  3.26it/s]

{'loss': 0.0933, 'learning_rate': 2.2581873571972585e-05, 'epoch': 1.24}


 62%|██████▏   | 1950/3126 [10:01<06:00,  3.26it/s]

{'loss': 0.151, 'learning_rate': 2.2391469916222392e-05, 'epoch': 1.25}


 63%|██████▎   | 1960/3126 [10:04<05:58,  3.26it/s]

{'loss': 0.1239, 'learning_rate': 2.2201066260472203e-05, 'epoch': 1.25}


 63%|██████▎   | 1970/3126 [10:07<05:55,  3.25it/s]

{'loss': 0.1561, 'learning_rate': 2.2010662604722014e-05, 'epoch': 1.26}


 63%|██████▎   | 1980/3126 [10:10<05:51,  3.26it/s]

{'loss': 0.175, 'learning_rate': 2.182025894897182e-05, 'epoch': 1.27}


 64%|██████▎   | 1990/3126 [10:13<05:48,  3.26it/s]

{'loss': 0.1142, 'learning_rate': 2.1629855293221632e-05, 'epoch': 1.27}


 64%|██████▍   | 2000/3126 [10:16<05:45,  3.26it/s]Saving model checkpoint to ./results/checkpoint-2000
Configuration saved in ./results/checkpoint-2000/config.json


{'loss': 0.1697, 'learning_rate': 2.1439451637471443e-05, 'epoch': 1.28}


Model weights saved in ./results/checkpoint-2000/pytorch_model.bin
 64%|██████▍   | 2010/3126 [10:21<05:55,  3.14it/s]

{'loss': 0.1321, 'learning_rate': 2.124904798172125e-05, 'epoch': 1.29}


 65%|██████▍   | 2020/3126 [10:24<05:39,  3.26it/s]

{'loss': 0.2173, 'learning_rate': 2.105864432597106e-05, 'epoch': 1.29}


 65%|██████▍   | 2030/3126 [10:27<05:35,  3.26it/s]

{'loss': 0.0936, 'learning_rate': 2.086824067022087e-05, 'epoch': 1.3}


 65%|██████▌   | 2040/3126 [10:30<05:33,  3.26it/s]

{'loss': 0.089, 'learning_rate': 2.067783701447068e-05, 'epoch': 1.31}


 66%|██████▌   | 2050/3126 [10:33<05:30,  3.26it/s]

{'loss': 0.1509, 'learning_rate': 2.048743335872049e-05, 'epoch': 1.31}


 66%|██████▌   | 2060/3126 [10:36<05:26,  3.26it/s]

{'loss': 0.1578, 'learning_rate': 2.02970297029703e-05, 'epoch': 1.32}


 66%|██████▌   | 2070/3126 [10:39<05:24,  3.26it/s]

{'loss': 0.1553, 'learning_rate': 2.0106626047220107e-05, 'epoch': 1.32}


 67%|██████▋   | 2080/3126 [10:42<05:20,  3.26it/s]

{'loss': 0.2236, 'learning_rate': 1.9916222391469915e-05, 'epoch': 1.33}


 67%|██████▋   | 2090/3126 [10:45<05:17,  3.26it/s]

{'loss': 0.1236, 'learning_rate': 1.9725818735719726e-05, 'epoch': 1.34}


 67%|██████▋   | 2100/3126 [10:48<05:15,  3.26it/s]

{'loss': 0.1182, 'learning_rate': 1.9535415079969536e-05, 'epoch': 1.34}


 67%|██████▋   | 2110/3126 [10:51<05:11,  3.26it/s]

{'loss': 0.0927, 'learning_rate': 1.9345011424219344e-05, 'epoch': 1.35}


 68%|██████▊   | 2120/3126 [10:54<05:09,  3.25it/s]

{'loss': 0.1597, 'learning_rate': 1.9154607768469154e-05, 'epoch': 1.36}


 68%|██████▊   | 2130/3126 [10:57<05:05,  3.26it/s]

{'loss': 0.1707, 'learning_rate': 1.8964204112718965e-05, 'epoch': 1.36}


 68%|██████▊   | 2140/3126 [11:00<05:02,  3.26it/s]

{'loss': 0.1619, 'learning_rate': 1.8773800456968772e-05, 'epoch': 1.37}


 69%|██████▉   | 2150/3126 [11:04<04:59,  3.25it/s]

{'loss': 0.2254, 'learning_rate': 1.8583396801218583e-05, 'epoch': 1.38}


 69%|██████▉   | 2160/3126 [11:07<04:56,  3.26it/s]

{'loss': 0.1715, 'learning_rate': 1.8392993145468394e-05, 'epoch': 1.38}


 69%|██████▉   | 2170/3126 [11:10<04:53,  3.26it/s]

{'loss': 0.1334, 'learning_rate': 1.82025894897182e-05, 'epoch': 1.39}


 70%|██████▉   | 2180/3126 [11:13<04:50,  3.26it/s]

{'loss': 0.1176, 'learning_rate': 1.8012185833968012e-05, 'epoch': 1.39}


 70%|███████   | 2190/3126 [11:16<04:47,  3.25it/s]

{'loss': 0.1299, 'learning_rate': 1.7821782178217823e-05, 'epoch': 1.4}


 70%|███████   | 2200/3126 [11:19<04:44,  3.26it/s]

{'loss': 0.1149, 'learning_rate': 1.763137852246763e-05, 'epoch': 1.41}


 71%|███████   | 2210/3126 [11:22<04:41,  3.26it/s]

{'loss': 0.1291, 'learning_rate': 1.744097486671744e-05, 'epoch': 1.41}


 71%|███████   | 2220/3126 [11:25<04:38,  3.26it/s]

{'loss': 0.133, 'learning_rate': 1.725057121096725e-05, 'epoch': 1.42}


 71%|███████▏  | 2230/3126 [11:28<04:34,  3.26it/s]

{'loss': 0.2044, 'learning_rate': 1.706016755521706e-05, 'epoch': 1.43}


 72%|███████▏  | 2240/3126 [11:31<04:31,  3.26it/s]

{'loss': 0.234, 'learning_rate': 1.686976389946687e-05, 'epoch': 1.43}


 72%|███████▏  | 2250/3126 [11:34<04:28,  3.26it/s]

{'loss': 0.1086, 'learning_rate': 1.667936024371668e-05, 'epoch': 1.44}


 72%|███████▏  | 2260/3126 [11:37<04:25,  3.26it/s]

{'loss': 0.0895, 'learning_rate': 1.6488956587966488e-05, 'epoch': 1.45}


 73%|███████▎  | 2270/3126 [11:40<04:23,  3.25it/s]

{'loss': 0.2111, 'learning_rate': 1.62985529322163e-05, 'epoch': 1.45}


 73%|███████▎  | 2280/3126 [11:43<04:19,  3.25it/s]

{'loss': 0.2954, 'learning_rate': 1.610814927646611e-05, 'epoch': 1.46}


 73%|███████▎  | 2290/3126 [11:47<04:16,  3.26it/s]

{'loss': 0.1366, 'learning_rate': 1.5917745620715916e-05, 'epoch': 1.47}


 74%|███████▎  | 2300/3126 [11:50<04:14,  3.25it/s]

{'loss': 0.1171, 'learning_rate': 1.5727341964965727e-05, 'epoch': 1.47}


 74%|███████▍  | 2310/3126 [11:53<04:10,  3.26it/s]

{'loss': 0.1238, 'learning_rate': 1.5536938309215538e-05, 'epoch': 1.48}


 74%|███████▍  | 2320/3126 [11:56<04:07,  3.26it/s]

{'loss': 0.121, 'learning_rate': 1.534653465346535e-05, 'epoch': 1.48}


 75%|███████▍  | 2330/3126 [11:59<04:04,  3.26it/s]

{'loss': 0.0964, 'learning_rate': 1.5156130997715156e-05, 'epoch': 1.49}


 75%|███████▍  | 2340/3126 [12:02<04:01,  3.26it/s]

{'loss': 0.1461, 'learning_rate': 1.4965727341964967e-05, 'epoch': 1.5}


 75%|███████▌  | 2350/3126 [12:05<03:58,  3.25it/s]

{'loss': 0.1453, 'learning_rate': 1.4775323686214776e-05, 'epoch': 1.5}


 75%|███████▌  | 2360/3126 [12:08<03:55,  3.26it/s]

{'loss': 0.0848, 'learning_rate': 1.4584920030464585e-05, 'epoch': 1.51}


 76%|███████▌  | 2370/3126 [12:11<03:51,  3.26it/s]

{'loss': 0.1047, 'learning_rate': 1.4394516374714396e-05, 'epoch': 1.52}


 76%|███████▌  | 2380/3126 [12:14<03:49,  3.26it/s]

{'loss': 0.1756, 'learning_rate': 1.4204112718964205e-05, 'epoch': 1.52}


 76%|███████▋  | 2390/3126 [12:17<03:45,  3.26it/s]

{'loss': 0.1928, 'learning_rate': 1.4013709063214015e-05, 'epoch': 1.53}


 77%|███████▋  | 2400/3126 [12:20<03:42,  3.26it/s]

{'loss': 0.1537, 'learning_rate': 1.3823305407463824e-05, 'epoch': 1.54}


 77%|███████▋  | 2410/3126 [12:23<03:39,  3.26it/s]

{'loss': 0.1315, 'learning_rate': 1.3632901751713633e-05, 'epoch': 1.54}


 77%|███████▋  | 2420/3126 [12:27<03:36,  3.26it/s]

{'loss': 0.1552, 'learning_rate': 1.3442498095963444e-05, 'epoch': 1.55}


 78%|███████▊  | 2430/3126 [12:30<03:33,  3.26it/s]

{'loss': 0.0765, 'learning_rate': 1.3252094440213253e-05, 'epoch': 1.55}


 78%|███████▊  | 2440/3126 [12:33<03:30,  3.26it/s]

{'loss': 0.1119, 'learning_rate': 1.3061690784463062e-05, 'epoch': 1.56}


 78%|███████▊  | 2450/3126 [12:36<03:27,  3.26it/s]

{'loss': 0.2282, 'learning_rate': 1.2871287128712873e-05, 'epoch': 1.57}


 79%|███████▊  | 2460/3126 [12:39<03:24,  3.26it/s]

{'loss': 0.041, 'learning_rate': 1.2680883472962682e-05, 'epoch': 1.57}


 79%|███████▉  | 2470/3126 [12:42<03:21,  3.26it/s]

{'loss': 0.1131, 'learning_rate': 1.2490479817212491e-05, 'epoch': 1.58}


 79%|███████▉  | 2480/3126 [12:45<03:18,  3.25it/s]

{'loss': 0.1202, 'learning_rate': 1.2300076161462302e-05, 'epoch': 1.59}


 80%|███████▉  | 2490/3126 [12:48<03:15,  3.26it/s]

{'loss': 0.1923, 'learning_rate': 1.210967250571211e-05, 'epoch': 1.59}


 80%|███████▉  | 2500/3126 [12:51<03:12,  3.26it/s]Saving model checkpoint to ./results/checkpoint-2500
Configuration saved in ./results/checkpoint-2500/config.json


{'loss': 0.1952, 'learning_rate': 1.191926884996192e-05, 'epoch': 1.6}


Model weights saved in ./results/checkpoint-2500/pytorch_model.bin
 80%|████████  | 2510/3126 [12:55<03:16,  3.14it/s]

{'loss': 0.0847, 'learning_rate': 1.172886519421173e-05, 'epoch': 1.61}


 81%|████████  | 2520/3126 [12:58<03:06,  3.25it/s]

{'loss': 0.1137, 'learning_rate': 1.153846153846154e-05, 'epoch': 1.61}


 81%|████████  | 2530/3126 [13:01<03:02,  3.26it/s]

{'loss': 0.1459, 'learning_rate': 1.1348057882711349e-05, 'epoch': 1.62}


 81%|████████▏ | 2540/3126 [13:04<02:59,  3.26it/s]

{'loss': 0.0816, 'learning_rate': 1.1157654226961158e-05, 'epoch': 1.63}


 82%|████████▏ | 2550/3126 [13:07<02:56,  3.26it/s]

{'loss': 0.1225, 'learning_rate': 1.0967250571210967e-05, 'epoch': 1.63}


 82%|████████▏ | 2560/3126 [13:10<02:53,  3.26it/s]

{'loss': 0.1163, 'learning_rate': 1.0776846915460777e-05, 'epoch': 1.64}


 82%|████████▏ | 2570/3126 [13:14<02:50,  3.26it/s]

{'loss': 0.0808, 'learning_rate': 1.0586443259710586e-05, 'epoch': 1.64}


 83%|████████▎ | 2580/3126 [13:17<02:47,  3.26it/s]

{'loss': 0.1388, 'learning_rate': 1.0396039603960395e-05, 'epoch': 1.65}


 83%|████████▎ | 2590/3126 [13:20<02:44,  3.26it/s]

{'loss': 0.1852, 'learning_rate': 1.0205635948210206e-05, 'epoch': 1.66}


 83%|████████▎ | 2600/3126 [13:23<02:41,  3.26it/s]

{'loss': 0.1741, 'learning_rate': 1.0015232292460015e-05, 'epoch': 1.66}


 83%|████████▎ | 2610/3126 [13:26<02:38,  3.26it/s]

{'loss': 0.196, 'learning_rate': 9.824828636709824e-06, 'epoch': 1.67}


 84%|████████▍ | 2620/3126 [13:29<02:35,  3.26it/s]

{'loss': 0.144, 'learning_rate': 9.634424980959635e-06, 'epoch': 1.68}


 84%|████████▍ | 2630/3126 [13:32<02:32,  3.26it/s]

{'loss': 0.1462, 'learning_rate': 9.444021325209444e-06, 'epoch': 1.68}


 84%|████████▍ | 2640/3126 [13:35<02:29,  3.26it/s]

{'loss': 0.113, 'learning_rate': 9.253617669459253e-06, 'epoch': 1.69}


 85%|████████▍ | 2650/3126 [13:38<02:26,  3.26it/s]

{'loss': 0.1404, 'learning_rate': 9.063214013709064e-06, 'epoch': 1.7}


 85%|████████▌ | 2660/3126 [13:41<02:23,  3.26it/s]

{'loss': 0.1485, 'learning_rate': 8.872810357958873e-06, 'epoch': 1.7}


 85%|████████▌ | 2670/3126 [13:44<02:19,  3.26it/s]

{'loss': 0.0849, 'learning_rate': 8.682406702208684e-06, 'epoch': 1.71}


 86%|████████▌ | 2680/3126 [13:47<02:16,  3.26it/s]

{'loss': 0.0772, 'learning_rate': 8.492003046458493e-06, 'epoch': 1.71}


 86%|████████▌ | 2690/3126 [13:50<02:13,  3.26it/s]

{'loss': 0.1158, 'learning_rate': 8.301599390708302e-06, 'epoch': 1.72}


 86%|████████▋ | 2700/3126 [13:53<02:10,  3.26it/s]

{'loss': 0.133, 'learning_rate': 8.111195734958112e-06, 'epoch': 1.73}


 87%|████████▋ | 2710/3126 [13:57<02:07,  3.26it/s]

{'loss': 0.1324, 'learning_rate': 7.920792079207921e-06, 'epoch': 1.73}


 87%|████████▋ | 2720/3126 [14:00<02:04,  3.26it/s]

{'loss': 0.1802, 'learning_rate': 7.73038842345773e-06, 'epoch': 1.74}


 87%|████████▋ | 2730/3126 [14:03<02:01,  3.26it/s]

{'loss': 0.1779, 'learning_rate': 7.53998476770754e-06, 'epoch': 1.75}


 88%|████████▊ | 2740/3126 [14:06<01:58,  3.26it/s]

{'loss': 0.1599, 'learning_rate': 7.34958111195735e-06, 'epoch': 1.75}


 88%|████████▊ | 2750/3126 [14:09<01:55,  3.26it/s]

{'loss': 0.2061, 'learning_rate': 7.15917745620716e-06, 'epoch': 1.76}


 88%|████████▊ | 2760/3126 [14:12<01:52,  3.26it/s]

{'loss': 0.1181, 'learning_rate': 6.968773800456969e-06, 'epoch': 1.77}


 89%|████████▊ | 2770/3126 [14:15<01:49,  3.25it/s]

{'loss': 0.0897, 'learning_rate': 6.778370144706779e-06, 'epoch': 1.77}


 89%|████████▉ | 2780/3126 [14:18<01:46,  3.25it/s]

{'loss': 0.1473, 'learning_rate': 6.587966488956589e-06, 'epoch': 1.78}


 89%|████████▉ | 2790/3126 [14:21<01:43,  3.26it/s]

{'loss': 0.1434, 'learning_rate': 6.397562833206398e-06, 'epoch': 1.79}


 90%|████████▉ | 2800/3126 [14:24<01:40,  3.26it/s]

{'loss': 0.1359, 'learning_rate': 6.207159177456207e-06, 'epoch': 1.79}


 90%|████████▉ | 2810/3126 [14:27<01:37,  3.26it/s]

{'loss': 0.2329, 'learning_rate': 6.016755521706017e-06, 'epoch': 1.8}


 90%|█████████ | 2820/3126 [14:30<01:33,  3.26it/s]

{'loss': 0.0629, 'learning_rate': 5.826351865955827e-06, 'epoch': 1.8}


 91%|█████████ | 2830/3126 [14:33<01:30,  3.26it/s]

{'loss': 0.0968, 'learning_rate': 5.635948210205636e-06, 'epoch': 1.81}


 91%|█████████ | 2840/3126 [14:37<01:27,  3.25it/s]

{'loss': 0.1019, 'learning_rate': 5.445544554455446e-06, 'epoch': 1.82}


 91%|█████████ | 2850/3126 [14:40<01:24,  3.26it/s]

{'loss': 0.1639, 'learning_rate': 5.2551408987052555e-06, 'epoch': 1.82}


 91%|█████████▏| 2860/3126 [14:43<01:21,  3.26it/s]

{'loss': 0.0787, 'learning_rate': 5.064737242955065e-06, 'epoch': 1.83}


 92%|█████████▏| 2870/3126 [14:46<01:18,  3.26it/s]

{'loss': 0.1646, 'learning_rate': 4.8743335872048744e-06, 'epoch': 1.84}


 92%|█████████▏| 2880/3126 [14:49<01:15,  3.25it/s]

{'loss': 0.0751, 'learning_rate': 4.683929931454684e-06, 'epoch': 1.84}


 92%|█████████▏| 2890/3126 [14:52<01:12,  3.26it/s]

{'loss': 0.1081, 'learning_rate': 4.493526275704494e-06, 'epoch': 1.85}


 93%|█████████▎| 2900/3126 [14:55<01:09,  3.26it/s]

{'loss': 0.1141, 'learning_rate': 4.303122619954303e-06, 'epoch': 1.86}


 93%|█████████▎| 2910/3126 [14:58<01:06,  3.26it/s]

{'loss': 0.1244, 'learning_rate': 4.112718964204113e-06, 'epoch': 1.86}


 93%|█████████▎| 2920/3126 [15:01<01:03,  3.26it/s]

{'loss': 0.1338, 'learning_rate': 3.922315308453923e-06, 'epoch': 1.87}


 94%|█████████▎| 2930/3126 [15:04<01:00,  3.26it/s]

{'loss': 0.1779, 'learning_rate': 3.7319116527037316e-06, 'epoch': 1.87}


 94%|█████████▍| 2940/3126 [15:07<00:57,  3.25it/s]

{'loss': 0.2046, 'learning_rate': 3.5415079969535415e-06, 'epoch': 1.88}


 94%|█████████▍| 2950/3126 [15:10<00:54,  3.26it/s]

{'loss': 0.1159, 'learning_rate': 3.351104341203351e-06, 'epoch': 1.89}


 95%|█████████▍| 2960/3126 [15:13<00:50,  3.26it/s]

{'loss': 0.1779, 'learning_rate': 3.160700685453161e-06, 'epoch': 1.89}


 95%|█████████▌| 2970/3126 [15:16<00:47,  3.26it/s]

{'loss': 0.1294, 'learning_rate': 2.9702970297029703e-06, 'epoch': 1.9}


 95%|█████████▌| 2980/3126 [15:20<00:44,  3.25it/s]

{'loss': 0.1063, 'learning_rate': 2.7798933739527798e-06, 'epoch': 1.91}


 96%|█████████▌| 2990/3126 [15:23<00:41,  3.26it/s]

{'loss': 0.1442, 'learning_rate': 2.5894897182025897e-06, 'epoch': 1.91}


 96%|█████████▌| 3000/3126 [15:26<00:38,  3.26it/s]Saving model checkpoint to ./results/checkpoint-3000
Configuration saved in ./results/checkpoint-3000/config.json


{'loss': 0.0894, 'learning_rate': 2.399086062452399e-06, 'epoch': 1.92}


Model weights saved in ./results/checkpoint-3000/pytorch_model.bin
 96%|█████████▋| 3010/3126 [15:30<00:36,  3.14it/s]

{'loss': 0.1636, 'learning_rate': 2.208682406702209e-06, 'epoch': 1.93}


 97%|█████████▋| 3020/3126 [15:33<00:32,  3.26it/s]

{'loss': 0.1174, 'learning_rate': 2.0182787509520185e-06, 'epoch': 1.93}


 97%|█████████▋| 3030/3126 [15:36<00:29,  3.25it/s]

{'loss': 0.1229, 'learning_rate': 1.8278750952018281e-06, 'epoch': 1.94}


 97%|█████████▋| 3040/3126 [15:39<00:26,  3.26it/s]

{'loss': 0.249, 'learning_rate': 1.6374714394516378e-06, 'epoch': 1.94}


 98%|█████████▊| 3050/3126 [15:42<00:23,  3.26it/s]

{'loss': 0.0913, 'learning_rate': 1.4470677837014473e-06, 'epoch': 1.95}


 98%|█████████▊| 3060/3126 [15:45<00:20,  3.26it/s]

{'loss': 0.2097, 'learning_rate': 1.2566641279512567e-06, 'epoch': 1.96}


 98%|█████████▊| 3070/3126 [15:48<00:17,  3.26it/s]

{'loss': 0.0855, 'learning_rate': 1.0662604722010662e-06, 'epoch': 1.96}


 99%|█████████▊| 3080/3126 [15:51<00:14,  3.26it/s]

{'loss': 0.0588, 'learning_rate': 8.75856816450876e-07, 'epoch': 1.97}


 99%|█████████▉| 3090/3126 [15:54<00:11,  3.26it/s]

{'loss': 0.1023, 'learning_rate': 6.854531607006854e-07, 'epoch': 1.98}


 99%|█████████▉| 3100/3126 [15:57<00:07,  3.26it/s]

{'loss': 0.2192, 'learning_rate': 4.950495049504951e-07, 'epoch': 1.98}


 99%|█████████▉| 3110/3126 [16:00<00:04,  3.26it/s]

{'loss': 0.0987, 'learning_rate': 3.0464584920030465e-07, 'epoch': 1.99}


100%|█████████▉| 3120/3126 [16:03<00:01,  3.26it/s]

{'loss': 0.0889, 'learning_rate': 1.1424219345011426e-07, 'epoch': 2.0}


100%|██████████| 3126/3126 [16:05<00:00,  3.92it/s]

Training completed. Do not forget to share your model on huggingface.co/models =)


100%|██████████| 3126/3126 [16:05<00:00,  3.24it/s]

{'train_runtime': 965.6626, 'train_samples_per_second': 51.774, 'train_steps_per_second': 3.237, 'train_loss': 0.222401846269347, 'epoch': 2.0}





TrainOutput(global_step=3126, training_loss=0.222401846269347, metrics={'train_runtime': 965.6626, 'train_samples_per_second': 51.774, 'train_steps_per_second': 3.237, 'train_loss': 0.222401846269347, 'epoch': 2.0})

fine-tuning w classic pytorch training loop

In [None]:
from torch.utils.data import DataLoader
from transformers import AdamW

In [None]:
device = torch.device('cuda')

In [None]:
model2 = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
model2.to(device)

In [None]:
model2.train()

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
optimizer = AdamW(model.parameters(), lr=5e-5)
num_train_epochs = 2

In [None]:
for epoch in range(num_train_epochs):
    for batch in train_dataloader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model2(input_ids, attention_mask=attention_mask, labels=labels)

        loss = outputs[0]
        loss.backward()
        optimizer.step()

Inference

In [None]:
model.eval()