In [2]:
import pandas as pd
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import confusion_matrix
from datasets import Dataset
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Read the CSV file using pandas
df = pd.read_csv('training_dataset.csv', encoding='utf-8')
dataset = Dataset.from_pandas(df)

# Define training and testing split
dataset = dataset.train_test_split(test_size=0.2)

# Load the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [5]:
# Tokenize the dataset
def tokenize_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Prepare the dataset for PyTorch
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

Map: 100%|██████████| 333447/333447 [01:53<00:00, 2945.58 examples/s]
Map: 100%|██████████| 83362/83362 [00:27<00:00, 3009.99 examples/s]


In [6]:
# Intilize the model and move it to the GPU
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=6)
model.to(torch.device('cuda'))

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    eval_strategy='epoch',
    optim="adamw_torch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=True,
)

# Initialize the trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
)

# Train the model
trainer.train()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  1%|          | 500/62523 [04:08<8:37:26,  2.00it/s]

{'loss': 0.7709, 'grad_norm': 11.126877784729004, 'learning_rate': 1.9841018505190092e-05, 'epoch': 0.02}


  2%|▏         | 1000/62523 [08:27<9:24:49,  1.82it/s]

{'loss': 0.2293, 'grad_norm': 5.901440620422363, 'learning_rate': 1.9681077363530224e-05, 'epoch': 0.05}


  2%|▏         | 1500/62523 [12:40<8:12:17,  2.07it/s] 

{'loss': 0.1913, 'grad_norm': 83.10171508789062, 'learning_rate': 1.9521136221870354e-05, 'epoch': 0.07}


  3%|▎         | 2000/62523 [16:45<8:09:29,  2.06it/s] 

{'loss': 0.1579, 'grad_norm': 14.162734031677246, 'learning_rate': 1.9361195080210483e-05, 'epoch': 0.1}


  4%|▍         | 2500/62523 [20:58<8:22:18,  1.99it/s] 

{'loss': 0.1513, 'grad_norm': 1.4201369285583496, 'learning_rate': 1.9201573820833936e-05, 'epoch': 0.12}


  5%|▍         | 3000/62523 [25:11<8:17:26,  1.99it/s] 

{'loss': 0.1262, 'grad_norm': 1.0418074131011963, 'learning_rate': 1.9041632679174065e-05, 'epoch': 0.14}


  6%|▌         | 3500/62523 [29:24<7:53:57,  2.08it/s] 

{'loss': 0.1234, 'grad_norm': 13.620994567871094, 'learning_rate': 1.8882011419797518e-05, 'epoch': 0.17}


  6%|▋         | 4000/62523 [33:33<8:07:51,  2.00it/s] 

{'loss': 0.1241, 'grad_norm': 1.55431067943573, 'learning_rate': 1.8722070278137647e-05, 'epoch': 0.19}


  7%|▋         | 4500/62523 [37:39<7:47:55,  2.07it/s] 

{'loss': 0.1189, 'grad_norm': 2.5551462173461914, 'learning_rate': 1.856212913647778e-05, 'epoch': 0.22}


  8%|▊         | 5000/62523 [41:38<7:36:54,  2.10it/s] 

{'loss': 0.1204, 'grad_norm': 1.0490080118179321, 'learning_rate': 1.840218799481791e-05, 'epoch': 0.24}


  9%|▉         | 5500/62523 [45:37<7:32:14,  2.10it/s] 

{'loss': 0.1168, 'grad_norm': 1.1592341661453247, 'learning_rate': 1.8242246853158038e-05, 'epoch': 0.26}


 10%|▉         | 6000/62523 [49:37<7:28:49,  2.10it/s] 

{'loss': 0.1108, 'grad_norm': 1.482081651687622, 'learning_rate': 1.808230571149817e-05, 'epoch': 0.29}


 10%|█         | 6500/62523 [53:36<7:24:35,  2.10it/s] 

{'loss': 0.1205, 'grad_norm': 1.0092833042144775, 'learning_rate': 1.7922364569838303e-05, 'epoch': 0.31}


 11%|█         | 7000/62523 [57:35<7:20:40,  2.10it/s] 

{'loss': 0.1199, 'grad_norm': 3.386826992034912, 'learning_rate': 1.7762743310461752e-05, 'epoch': 0.34}


 12%|█▏        | 7500/62523 [1:01:34<7:16:36,  2.10it/s]

{'loss': 0.1092, 'grad_norm': 0.7086572051048279, 'learning_rate': 1.760280216880188e-05, 'epoch': 0.36}


 13%|█▎        | 8000/62523 [1:05:33<7:12:41,  2.10it/s] 

{'loss': 0.109, 'grad_norm': 0.9333259463310242, 'learning_rate': 1.7442861027142014e-05, 'epoch': 0.38}


 14%|█▎        | 8500/62523 [1:09:32<7:08:10,  2.10it/s] 

{'loss': 0.109, 'grad_norm': 49.56344985961914, 'learning_rate': 1.7282919885482143e-05, 'epoch': 0.41}


 14%|█▍        | 9000/62523 [1:13:32<7:05:02,  2.10it/s] 

{'loss': 0.1004, 'grad_norm': 3.629791498184204, 'learning_rate': 1.7122978743822272e-05, 'epoch': 0.43}


 15%|█▌        | 9500/62523 [1:17:31<7:01:08,  2.10it/s] 

{'loss': 0.104, 'grad_norm': 4.650318145751953, 'learning_rate': 1.6963037602162408e-05, 'epoch': 0.46}


 16%|█▌        | 10000/62523 [1:21:36<7:13:12,  2.02it/s]

{'loss': 0.1026, 'grad_norm': 5.914760112762451, 'learning_rate': 1.6803096460502537e-05, 'epoch': 0.48}


 17%|█▋        | 10500/62523 [1:25:45<7:09:33,  2.02it/s] 

{'loss': 0.1082, 'grad_norm': 1.4643604755401611, 'learning_rate': 1.6643155318842667e-05, 'epoch': 0.5}


 18%|█▊        | 11000/62523 [1:29:54<7:05:01,  2.02it/s] 

{'loss': 0.0962, 'grad_norm': 5.307196617126465, 'learning_rate': 1.64832141771828e-05, 'epoch': 0.53}


 18%|█▊        | 11500/62523 [1:34:03<7:00:27,  2.02it/s] 

{'loss': 0.1001, 'grad_norm': 0.8609771132469177, 'learning_rate': 1.632359291780625e-05, 'epoch': 0.55}


 19%|█▉        | 12000/62523 [1:38:11<6:57:00,  2.02it/s] 

{'loss': 0.1001, 'grad_norm': 1.5149198770523071, 'learning_rate': 1.6163651776146378e-05, 'epoch': 0.58}


 20%|█▉        | 12500/62523 [1:42:20<6:53:04,  2.02it/s] 

{'loss': 0.1039, 'grad_norm': 1.8400063514709473, 'learning_rate': 1.600403051676983e-05, 'epoch': 0.6}


 21%|██        | 13000/62523 [1:46:29<6:48:30,  2.02it/s] 

{'loss': 0.1084, 'grad_norm': 1.151090383529663, 'learning_rate': 1.584408937510996e-05, 'epoch': 0.62}


 22%|██▏       | 13500/62523 [1:50:37<6:44:23,  2.02it/s] 

{'loss': 0.0961, 'grad_norm': 1.559792399406433, 'learning_rate': 1.5684148233450092e-05, 'epoch': 0.65}


 22%|██▏       | 14000/62523 [1:54:46<6:40:37,  2.02it/s] 

{'loss': 0.098, 'grad_norm': 2.36411190032959, 'learning_rate': 1.552420709179022e-05, 'epoch': 0.67}


 23%|██▎       | 14500/62523 [1:58:55<6:36:37,  2.02it/s] 

{'loss': 0.0975, 'grad_norm': 0.006750921718776226, 'learning_rate': 1.5364265950130354e-05, 'epoch': 0.7}


 24%|██▍       | 15000/62523 [2:03:04<6:32:38,  2.02it/s] 

{'loss': 0.0999, 'grad_norm': 1.807051420211792, 'learning_rate': 1.5204324808470483e-05, 'epoch': 0.72}


 25%|██▍       | 15500/62523 [2:07:12<6:27:57,  2.02it/s] 

{'loss': 0.095, 'grad_norm': 0.9082040786743164, 'learning_rate': 1.5044383666810616e-05, 'epoch': 0.74}


 26%|██▌       | 16000/62523 [2:11:21<6:23:46,  2.02it/s] 

{'loss': 0.1009, 'grad_norm': 3.3242833614349365, 'learning_rate': 1.4884442525150747e-05, 'epoch': 0.77}


 26%|██▋       | 16500/62523 [2:15:30<6:19:53,  2.02it/s] 

{'loss': 0.0913, 'grad_norm': 1.0911779403686523, 'learning_rate': 1.4725141148057517e-05, 'epoch': 0.79}


 27%|██▋       | 17000/62523 [2:19:39<6:15:58,  2.02it/s] 

{'loss': 0.0989, 'grad_norm': 1.0604248046875, 'learning_rate': 1.4565200006397646e-05, 'epoch': 0.82}


 28%|██▊       | 17500/62523 [2:23:48<6:11:46,  2.02it/s] 

{'loss': 0.1034, 'grad_norm': 1.3030966520309448, 'learning_rate': 1.4405258864737777e-05, 'epoch': 0.84}


 29%|██▉       | 18000/62523 [2:27:56<6:07:21,  2.02it/s] 

{'loss': 0.0935, 'grad_norm': 0.9969227910041809, 'learning_rate': 1.4245317723077909e-05, 'epoch': 0.86}


 30%|██▉       | 18500/62523 [2:32:05<6:03:21,  2.02it/s] 

{'loss': 0.0952, 'grad_norm': 1.6677194833755493, 'learning_rate': 1.408537658141804e-05, 'epoch': 0.89}


 30%|███       | 19000/62523 [2:36:14<5:59:14,  2.02it/s] 

{'loss': 0.0934, 'grad_norm': 1.0115875005722046, 'learning_rate': 1.392575532204149e-05, 'epoch': 0.91}


 31%|███       | 19500/62523 [2:40:22<5:54:46,  2.02it/s]

{'loss': 0.0957, 'grad_norm': 51.354248046875, 'learning_rate': 1.376581418038162e-05, 'epoch': 0.94}


 32%|███▏      | 20000/62523 [2:44:31<5:51:08,  2.02it/s]

{'loss': 0.0925, 'grad_norm': 1.616455316543579, 'learning_rate': 1.3605873038721751e-05, 'epoch': 0.96}


 33%|███▎      | 20500/62523 [2:48:40<5:46:53,  2.02it/s]

{'loss': 0.0925, 'grad_norm': 0.9186195731163025, 'learning_rate': 1.3445931897061882e-05, 'epoch': 0.98}


                                                         
 33%|███▎      | 20841/62523 [3:04:40<4:56:16,  2.34it/s]

{'eval_loss': 0.0937456339597702, 'eval_runtime': 789.8746, 'eval_samples_per_second': 105.538, 'eval_steps_per_second': 6.597, 'epoch': 1.0}


 34%|███▎      | 21000/62523 [3:05:59<5:43:27,  2.01it/s]    

{'loss': 0.0896, 'grad_norm': 0.0033360968809574842, 'learning_rate': 1.3285990755402015e-05, 'epoch': 1.01}


 34%|███▍      | 21500/62523 [3:10:07<5:38:32,  2.02it/s]

{'loss': 0.0896, 'grad_norm': 2.0100739002227783, 'learning_rate': 1.3126049613742144e-05, 'epoch': 1.03}


 35%|███▌      | 22000/62523 [3:14:16<5:34:32,  2.02it/s]

{'loss': 0.087, 'grad_norm': 1.4837398529052734, 'learning_rate': 1.2966108472082274e-05, 'epoch': 1.06}


 36%|███▌      | 22500/62523 [3:18:25<5:30:13,  2.02it/s]

{'loss': 0.0914, 'grad_norm': 0.0070676179602742195, 'learning_rate': 1.2806167330422405e-05, 'epoch': 1.08}


 37%|███▋      | 23000/62523 [3:22:33<5:25:56,  2.02it/s]

{'loss': 0.0934, 'grad_norm': 17.78093910217285, 'learning_rate': 1.2646226188762536e-05, 'epoch': 1.1}


 38%|███▊      | 23500/62523 [3:26:42<5:21:43,  2.02it/s]

{'loss': 0.086, 'grad_norm': 1.394652009010315, 'learning_rate': 1.2486604929385986e-05, 'epoch': 1.13}


 38%|███▊      | 24000/62523 [3:30:51<5:18:00,  2.02it/s]

{'loss': 0.1001, 'grad_norm': 1.7807763814926147, 'learning_rate': 1.2326663787726118e-05, 'epoch': 1.15}


 39%|███▉      | 24500/62523 [3:35:00<5:13:46,  2.02it/s]

{'loss': 0.0873, 'grad_norm': 1.122314214706421, 'learning_rate': 1.2166722646066249e-05, 'epoch': 1.18}


 40%|███▉      | 25000/62523 [3:39:08<5:09:46,  2.02it/s]

{'loss': 0.0867, 'grad_norm': 0.6305559873580933, 'learning_rate': 1.200678150440638e-05, 'epoch': 1.2}


 41%|████      | 25500/62523 [3:43:17<5:05:35,  2.02it/s]

{'loss': 0.0862, 'grad_norm': 0.7416844964027405, 'learning_rate': 1.184684036274651e-05, 'epoch': 1.22}


 42%|████▏     | 26000/62523 [3:47:26<5:01:13,  2.02it/s]

{'loss': 0.0925, 'grad_norm': 1.4886800050735474, 'learning_rate': 1.1687538985653279e-05, 'epoch': 1.25}


 42%|████▏     | 26500/62523 [3:51:34<4:57:29,  2.02it/s]

{'loss': 0.0884, 'grad_norm': 3.125500440597534, 'learning_rate': 1.1527597843993412e-05, 'epoch': 1.27}


 43%|████▎     | 27000/62523 [3:55:43<4:53:15,  2.02it/s]

{'loss': 0.1033, 'grad_norm': 1.7894127368927002, 'learning_rate': 1.1367656702333542e-05, 'epoch': 1.3}


 44%|████▍     | 27500/62523 [3:59:52<4:48:50,  2.02it/s]

{'loss': 0.0928, 'grad_norm': 0.002468700520694256, 'learning_rate': 1.1207715560673673e-05, 'epoch': 1.32}


 45%|████▍     | 28000/62523 [4:04:00<4:44:47,  2.02it/s]

{'loss': 0.0888, 'grad_norm': 1.5479793548583984, 'learning_rate': 1.1047774419013804e-05, 'epoch': 1.34}


 46%|████▌     | 28500/62523 [4:08:09<4:40:48,  2.02it/s]

{'loss': 0.0901, 'grad_norm': 1.6317914724349976, 'learning_rate': 1.0887833277353935e-05, 'epoch': 1.37}


 46%|████▋     | 29000/62523 [4:12:18<4:36:32,  2.02it/s]

{'loss': 0.0928, 'grad_norm': 0.002985725412145257, 'learning_rate': 1.0727892135694064e-05, 'epoch': 1.39}


 47%|████▋     | 29500/62523 [4:16:27<4:32:40,  2.02it/s]

{'loss': 0.0977, 'grad_norm': 0.015489925630390644, 'learning_rate': 1.0568270876317517e-05, 'epoch': 1.42}


 48%|████▊     | 30000/62523 [4:20:35<4:28:31,  2.02it/s]

{'loss': 0.0883, 'grad_norm': 0.9838432669639587, 'learning_rate': 1.0408329734657648e-05, 'epoch': 1.44}


 49%|████▉     | 30500/62523 [4:24:44<4:24:15,  2.02it/s]

{'loss': 0.0847, 'grad_norm': 1.8102596998214722, 'learning_rate': 1.0248388592997777e-05, 'epoch': 1.46}


 50%|████▉     | 31000/62523 [4:28:53<4:20:02,  2.02it/s]

{'loss': 0.088, 'grad_norm': 1.7164726257324219, 'learning_rate': 1.0088447451337908e-05, 'epoch': 1.49}


 50%|█████     | 31500/62523 [4:33:01<4:16:02,  2.02it/s]

{'loss': 0.083, 'grad_norm': 1.2576600313186646, 'learning_rate': 9.928506309678039e-06, 'epoch': 1.51}


 51%|█████     | 32000/62523 [4:37:10<4:11:52,  2.02it/s]

{'loss': 0.0934, 'grad_norm': 2.1019577980041504, 'learning_rate': 9.76856516801817e-06, 'epoch': 1.54}


 52%|█████▏    | 32500/62523 [4:41:19<4:07:59,  2.02it/s]

{'loss': 0.0867, 'grad_norm': 2.1173839569091797, 'learning_rate': 9.6086240263583e-06, 'epoch': 1.56}


 53%|█████▎    | 33000/62523 [4:45:27<4:03:42,  2.02it/s]

{'loss': 0.0922, 'grad_norm': 0.9361430406570435, 'learning_rate': 9.448682884698431e-06, 'epoch': 1.58}


 54%|█████▎    | 33500/62523 [4:49:36<3:59:42,  2.02it/s]

{'loss': 0.0846, 'grad_norm': 2.2528891563415527, 'learning_rate': 9.288741743038564e-06, 'epoch': 1.61}


 54%|█████▍    | 34000/62523 [4:53:45<3:55:42,  2.02it/s]

{'loss': 0.0859, 'grad_norm': 15.638832092285156, 'learning_rate': 9.128800601378693e-06, 'epoch': 1.63}


 55%|█████▌    | 34500/62523 [4:57:54<3:51:04,  2.02it/s]

{'loss': 0.0859, 'grad_norm': 1.295509696006775, 'learning_rate': 8.968859459718824e-06, 'epoch': 1.66}


 56%|█████▌    | 35000/62523 [5:02:02<3:47:13,  2.02it/s]

{'loss': 0.0925, 'grad_norm': 2.164778470993042, 'learning_rate': 8.809238200342275e-06, 'epoch': 1.68}


 57%|█████▋    | 35500/62523 [5:06:11<3:43:06,  2.02it/s]

{'loss': 0.0865, 'grad_norm': 0.005021695513278246, 'learning_rate': 8.649297058682406e-06, 'epoch': 1.7}


 58%|█████▊    | 36000/62523 [5:10:20<3:38:44,  2.02it/s]

{'loss': 0.0882, 'grad_norm': 1.053728699684143, 'learning_rate': 8.489675799305857e-06, 'epoch': 1.73}


 58%|█████▊    | 36500/62523 [5:14:28<3:34:41,  2.02it/s]

{'loss': 0.087, 'grad_norm': 1.8314355611801147, 'learning_rate': 8.329734657645986e-06, 'epoch': 1.75}


 59%|█████▉    | 37000/62523 [5:18:37<3:30:39,  2.02it/s]

{'loss': 0.0868, 'grad_norm': 1.0927671194076538, 'learning_rate': 8.169793515986117e-06, 'epoch': 1.78}


 60%|█████▉    | 37500/62523 [5:22:46<3:26:39,  2.02it/s]

{'loss': 0.0914, 'grad_norm': 1.7072558403015137, 'learning_rate': 8.00985237432625e-06, 'epoch': 1.8}


 61%|██████    | 38000/62523 [5:26:55<3:22:27,  2.02it/s]

{'loss': 0.0826, 'grad_norm': 1.4776606559753418, 'learning_rate': 7.84991123266638e-06, 'epoch': 1.82}


 62%|██████▏   | 38500/62523 [5:31:03<3:18:19,  2.02it/s]

{'loss': 0.087, 'grad_norm': 1.2283473014831543, 'learning_rate': 7.68997009100651e-06, 'epoch': 1.85}


 62%|██████▏   | 39000/62523 [5:35:14<3:18:17,  1.98it/s]

{'loss': 0.0813, 'grad_norm': 0.6628310680389404, 'learning_rate': 7.53002894934664e-06, 'epoch': 1.87}


 63%|██████▎   | 39500/62523 [5:39:24<3:10:10,  2.02it/s]

{'loss': 0.0928, 'grad_norm': 5.11175537109375, 'learning_rate': 7.370087807686772e-06, 'epoch': 1.9}


 64%|██████▍   | 40000/62523 [5:43:33<3:06:00,  2.02it/s]

{'loss': 0.0874, 'grad_norm': 1.233674168586731, 'learning_rate': 7.210146666026903e-06, 'epoch': 1.92}


 65%|██████▍   | 40500/62523 [5:47:42<3:01:56,  2.02it/s]

{'loss': 0.0872, 'grad_norm': 1.6561059951782227, 'learning_rate': 7.050205524367033e-06, 'epoch': 1.94}


 66%|██████▌   | 41000/62523 [5:51:51<2:58:01,  2.02it/s]

{'loss': 0.0787, 'grad_norm': 1.236829400062561, 'learning_rate': 6.890264382707165e-06, 'epoch': 1.97}


 66%|██████▋   | 41500/62523 [5:55:59<2:53:34,  2.02it/s]

{'loss': 0.0831, 'grad_norm': 2.6982531547546387, 'learning_rate': 6.7303232410472954e-06, 'epoch': 1.99}


                                                         
 67%|██████▋   | 41682/62523 [6:10:41<2:26:51,  2.37it/s]

{'eval_loss': 0.0912330374121666, 'eval_runtime': 790.8501, 'eval_samples_per_second': 105.408, 'eval_steps_per_second': 6.589, 'epoch': 2.0}


 67%|██████▋   | 42000/62523 [6:13:19<2:49:48,  2.01it/s]    

{'loss': 0.0816, 'grad_norm': 2.090859889984131, 'learning_rate': 6.5703820993874254e-06, 'epoch': 2.02}


 68%|██████▊   | 42500/62523 [6:17:28<2:45:11,  2.02it/s]

{'loss': 0.0841, 'grad_norm': 1.4562145471572876, 'learning_rate': 6.410760840010877e-06, 'epoch': 2.04}


 69%|██████▉   | 43000/62523 [6:21:37<2:41:03,  2.02it/s]

{'loss': 0.086, 'grad_norm': 1.963714599609375, 'learning_rate': 6.2508196983510075e-06, 'epoch': 2.06}


 70%|██████▉   | 43500/62523 [6:25:46<2:37:01,  2.02it/s]

{'loss': 0.0821, 'grad_norm': 2.3427228927612305, 'learning_rate': 6.0908785566911375e-06, 'epoch': 2.09}


 70%|███████   | 44000/62523 [6:29:54<2:33:01,  2.02it/s]

{'loss': 0.0818, 'grad_norm': 2.4014880657196045, 'learning_rate': 5.930937415031269e-06, 'epoch': 2.11}


 71%|███████   | 44500/62523 [6:34:03<2:28:48,  2.02it/s]

{'loss': 0.0791, 'grad_norm': 2.3125569820404053, 'learning_rate': 5.7709962733714e-06, 'epoch': 2.14}


 72%|███████▏  | 45000/62523 [6:38:14<2:31:24,  1.93it/s]

{'loss': 0.0817, 'grad_norm': 0.818370521068573, 'learning_rate': 5.61105513171153e-06, 'epoch': 2.16}


 73%|███████▎  | 45500/62523 [6:42:26<2:20:38,  2.02it/s]

{'loss': 0.0844, 'grad_norm': 1.6671435832977295, 'learning_rate': 5.451433872334981e-06, 'epoch': 2.18}


 74%|███████▎  | 46000/62523 [6:46:38<2:17:50,  2.00it/s]

{'loss': 0.0802, 'grad_norm': 2.332866668701172, 'learning_rate': 5.291492730675112e-06, 'epoch': 2.21}


 74%|███████▍  | 46500/62523 [6:50:50<2:16:42,  1.95it/s]

{'loss': 0.084, 'grad_norm': 1.883328914642334, 'learning_rate': 5.131551589015242e-06, 'epoch': 2.23}


 75%|███████▌  | 47000/62523 [6:55:01<2:08:26,  2.01it/s]

{'loss': 0.084, 'grad_norm': 2.4423251152038574, 'learning_rate': 4.971610447355374e-06, 'epoch': 2.26}


 76%|███████▌  | 47500/62523 [6:59:10<2:04:00,  2.02it/s]

{'loss': 0.0878, 'grad_norm': 1.5336520671844482, 'learning_rate': 4.811989187978824e-06, 'epoch': 2.28}


 77%|███████▋  | 48000/62523 [7:03:21<2:00:39,  2.01it/s]

{'loss': 0.0815, 'grad_norm': 0.0016026643570512533, 'learning_rate': 4.652048046318955e-06, 'epoch': 2.3}


 78%|███████▊  | 48500/62523 [7:07:31<1:57:16,  1.99it/s]

{'loss': 0.0816, 'grad_norm': 1.8553959131240845, 'learning_rate': 4.492106904659086e-06, 'epoch': 2.33}


 78%|███████▊  | 49000/62523 [7:11:41<1:51:37,  2.02it/s]

{'loss': 0.0775, 'grad_norm': 1.4142200946807861, 'learning_rate': 4.332165762999217e-06, 'epoch': 2.35}


 79%|███████▉  | 49500/62523 [7:15:52<1:47:48,  2.01it/s]

{'loss': 0.0758, 'grad_norm': 1.3815594911575317, 'learning_rate': 4.1722246213393476e-06, 'epoch': 2.38}


 80%|███████▉  | 50000/62523 [7:20:06<1:44:21,  2.00it/s]

{'loss': 0.0755, 'grad_norm': 1.17353093624115, 'learning_rate': 4.012603361962798e-06, 'epoch': 2.4}


 81%|████████  | 50500/62523 [7:24:16<1:39:17,  2.02it/s]

{'loss': 0.0779, 'grad_norm': 3.0190272331237793, 'learning_rate': 3.852662220302929e-06, 'epoch': 2.42}


 82%|████████▏ | 51000/62523 [7:28:28<1:35:11,  2.02it/s]

{'loss': 0.0782, 'grad_norm': 0.9775117635726929, 'learning_rate': 3.6927210786430596e-06, 'epoch': 2.45}


 82%|████████▏ | 51500/62523 [7:32:40<1:31:13,  2.01it/s]

{'loss': 0.0766, 'grad_norm': 1.2408939599990845, 'learning_rate': 3.53309981926651e-06, 'epoch': 2.47}


 83%|████████▎ | 52000/62523 [7:36:54<1:27:31,  2.00it/s]

{'loss': 0.0796, 'grad_norm': 1.1758334636688232, 'learning_rate': 3.3731586776066413e-06, 'epoch': 2.5}


 84%|████████▍ | 52500/62523 [7:41:07<1:23:52,  1.99it/s]

{'loss': 0.0804, 'grad_norm': 1.7407315969467163, 'learning_rate': 3.213217535946772e-06, 'epoch': 2.52}


 85%|████████▍ | 53000/62523 [7:45:28<1:21:15,  1.95it/s]

{'loss': 0.0868, 'grad_norm': 0.001510003232397139, 'learning_rate': 3.0535962765702225e-06, 'epoch': 2.54}


 86%|████████▌ | 53500/62523 [7:49:51<1:17:07,  1.95it/s]

{'loss': 0.0792, 'grad_norm': 1.764146089553833, 'learning_rate': 2.893655134910353e-06, 'epoch': 2.57}


 86%|████████▋ | 54000/62523 [7:54:12<1:13:35,  1.93it/s]

{'loss': 0.0804, 'grad_norm': 0.0012682265369221568, 'learning_rate': 2.733713993250484e-06, 'epoch': 2.59}


 87%|████████▋ | 54500/62523 [7:58:32<1:07:48,  1.97it/s]

{'loss': 0.0815, 'grad_norm': 2.0612003803253174, 'learning_rate': 2.5737728515906146e-06, 'epoch': 2.62}


 88%|████████▊ | 55000/62523 [8:02:48<1:09:35,  1.80it/s]

{'loss': 0.0799, 'grad_norm': 1.5895118713378906, 'learning_rate': 2.413831709930746e-06, 'epoch': 2.64}


 89%|████████▉ | 55500/62523 [8:07:11<58:50,  1.99it/s]  

{'loss': 0.0822, 'grad_norm': 0.89723801612854, 'learning_rate': 2.2538905682708763e-06, 'epoch': 2.66}


 90%|████████▉ | 56000/62523 [8:11:24<54:56,  1.98it/s]  

{'loss': 0.0852, 'grad_norm': 2.0537662506103516, 'learning_rate': 2.093949426611007e-06, 'epoch': 2.69}


 90%|█████████ | 56500/62523 [8:15:44<50:12,  2.00it/s]  

{'loss': 0.0847, 'grad_norm': 1.396697759628296, 'learning_rate': 1.934328167234458e-06, 'epoch': 2.71}


 91%|█████████ | 57000/62523 [8:19:58<46:24,  1.98it/s]  

{'loss': 0.08, 'grad_norm': 16.216459274291992, 'learning_rate': 1.7743870255745888e-06, 'epoch': 2.73}


 92%|█████████▏| 57500/62523 [8:24:11<42:02,  1.99it/s]  

{'loss': 0.0773, 'grad_norm': 0.001132825156673789, 'learning_rate': 1.6144458839147196e-06, 'epoch': 2.76}


 93%|█████████▎| 58000/62523 [8:28:21<37:09,  2.03it/s]  

{'loss': 0.0788, 'grad_norm': 2.414628028869629, 'learning_rate': 1.4545047422548502e-06, 'epoch': 2.78}


 94%|█████████▎| 58500/62523 [8:32:31<34:39,  1.93it/s]  

{'loss': 0.0833, 'grad_norm': 0.0012236343463882804, 'learning_rate': 1.294563600594981e-06, 'epoch': 2.81}


 94%|█████████▍| 59000/62523 [8:36:51<30:41,  1.91it/s]

{'loss': 0.0805, 'grad_norm': 1.0917831659317017, 'learning_rate': 1.134622458935112e-06, 'epoch': 2.83}


 95%|█████████▌| 59500/62523 [8:41:15<26:02,  1.93it/s]

{'loss': 0.085, 'grad_norm': 1.7412432432174683, 'learning_rate': 9.746813172752428e-07, 'epoch': 2.85}


 96%|█████████▌| 60000/62523 [8:45:37<21:18,  1.97it/s]

{'loss': 0.0783, 'grad_norm': 1.0135109424591064, 'learning_rate': 8.147401756153736e-07, 'epoch': 2.88}


 97%|█████████▋| 60500/62523 [8:49:59<18:13,  1.85it/s]

{'loss': 0.0803, 'grad_norm': 1.38777756690979, 'learning_rate': 6.547990339555044e-07, 'epoch': 2.9}


 98%|█████████▊| 61000/62523 [8:54:18<13:11,  1.92it/s]

{'loss': 0.0815, 'grad_norm': 1.5566856861114502, 'learning_rate': 4.948578922956353e-07, 'epoch': 2.93}


 98%|█████████▊| 61500/62523 [8:58:37<08:48,  1.93it/s]

{'loss': 0.0796, 'grad_norm': 2.177053451538086, 'learning_rate': 3.3491675063576604e-07, 'epoch': 2.95}


 99%|█████████▉| 62000/62523 [9:02:51<04:19,  2.02it/s]

{'loss': 0.0864, 'grad_norm': 0.09250642359256744, 'learning_rate': 1.7497560897589686e-07, 'epoch': 2.97}


100%|█████████▉| 62500/62523 [9:07:06<00:11,  1.97it/s]

{'loss': 0.0817, 'grad_norm': 1.2306817770004272, 'learning_rate': 1.5034467316027702e-08, 'epoch': 3.0}


                                                       
100%|██████████| 62523/62523 [9:21:01<00:00,  1.86it/s]

{'eval_loss': 0.09125309437513351, 'eval_runtime': 820.9237, 'eval_samples_per_second': 101.547, 'eval_steps_per_second': 6.348, 'epoch': 3.0}
{'train_runtime': 33661.2577, 'train_samples_per_second': 29.718, 'train_steps_per_second': 1.857, 'train_loss': 0.09934960999019109, 'epoch': 3.0}





TrainOutput(global_step=62523, training_loss=0.09934960999019109, metrics={'train_runtime': 33661.2577, 'train_samples_per_second': 29.718, 'train_steps_per_second': 1.857, 'total_flos': 2.632102289241477e+17, 'train_loss': 0.09934960999019109, 'epoch': 3.0})

In [9]:
predicted = trainer.predict(tokenized_datasets['test']).predictions.argmax(-1)
expected = tokenized_datasets['test']['labels']
cm = confusion_matrix(predicted, expected)
print(cm)
print(f'Accuracy: {cm.trace() / cm.sum()}')

100%|██████████| 5211/5211 [13:35<00:00,  6.39it/s]

[[23677    16     1   470   122     0]
 [   12 28074  1948    16     6   259]
 [    0    69  4964     0     0     0]
 [  180    17     0 10865   520     0]
 [  328    14     0    38  8275     1]
 [   23     4     1     0   787  2675]]
Accuracy: 0.9420359396367649





In [10]:
# Save the model
model.save_pretrained('sentiment_analysis/model/saved_model')
tokenizer.save_pretrained('sentiment_analysis/model/saved_tokenizer')

('sentiment_analysis/model/saved_tokenizer\\tokenizer_config.json',
 'sentiment_analysis/model/saved_tokenizer\\special_tokens_map.json',
 'sentiment_analysis/model/saved_tokenizer\\vocab.txt',
 'sentiment_analysis/model/saved_tokenizer\\added_tokens.json')