In [1]:
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets.dataset_dict import DatasetDict
from datasets import Dataset

import utils.dataset_processors as dataset_processors
import numpy as np
import pandas as pd
import re

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def get_model_names(model_name):
    return {
        'distillbert': 'distilbert-base-uncased', # 66M
        'xlnet': 'xlnet-base-cased', # 110M
        'bert': 'bert-base-uncased', # 110M
        'roberta': 'roberta-base', # 125M
        'albert': 'albert-base-v2', # 11M
        'electra': 'google/electra-small-discriminator', # 14M
        'big-bird': 'google/bigbird-roberta-base', # 125M
        'longformer': 'allenai/longformer-base-4096' # 149M
    }[model_name]

In [2]:
# Initialize the model name for tokenizer and also the saved model name
model_name = 'distilbert-base-uncased'
saved_model = 'distillbert-finetuned-segmented'

# Add boolean for sentence segmentation:
segment_sentences = False

# Add filename that will be saved later:

In [3]:
# Load the dataset
datafile = "data/essays/essays.csv"
dataset = dataset_processors.load_essays_df(datafile)

# Split the dataset (6:2:2)
train_data, temp_data = train_test_split(dataset, train_size=0.6, random_state=42)
validation_data, test_data = train_test_split(temp_data, train_size=0.5, random_state=42)

In [None]:
# Get the Big 5 labels
column_names = list(train_data.columns)
labels = [label for label in column_names if label not in ['user','text','token_len']]

# Forward and backward mapping
id2label = {idx:label for idx,label in enumerate(labels)}
label2id = {label:idx for idx,label in enumerate(labels)}

labels

['EXT', 'NEU', 'AGR', 'CON', 'OPN']

In [5]:
# Import the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    problem_type = "multi_label_classification",
    num_labels = len(labels),
    id2label = id2label,
    label2id = label2id
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Convert from essay to sentences
def split_text_with_labels(row):
    
    # Split sentences
    sentences = re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s", row['text'])

    return [{
        'text': sentence,
        'EXT': row['EXT'],
        'NEU': row['NEU'],
        'AGR': row['AGR'],
        'CON': row['CON'],
        'OPN': row['OPN']
    }
        for sentence in sentences       
    ]

def transform_dataframe(old_dataframe):

    # Begin the split
    split_data = []
    
    for _, row in old_dataframe.iterrows():
        split_data.extend(split_text_with_labels(row))
        
    return pd.DataFrame(split_data)

if segment_sentences:
    train_data = transform_dataframe(train_data)
    test_data = transform_dataframe(test_data)
    validation_data = transform_dataframe(validation_data)

In [7]:
print(train_data.head(20))

                 user                                               text  \
1244  2000_691500.txt  I am tired now. I don't know what I should tal...   
579   1998_894233.txt  I never thought that college would be this ove...   
1660  2002_651875.txt       Well, my first psychology writing assignm...   
1456  2000_940609.txt  I swear that office space is the funniest movi...   
892   1999_821141.txt  Why is'nt my roomate quieter doesn't he unders...   
2124     2003_349.txt  As I look at my clock in the lower right hand ...   
780   1999_633874.txt  We went to Barton Springs today as a sorority,...   
2052     2003_223.txt  I can overhear the sound of my roommate watchi...   
159   1997_936626.txt  It is now 12:32 and so I cannot wait until 12:...   
1047  2000_075257.txt  I'm so excited that my computer is doing what ...   
2127     2003_357.txt  Wow, that clock starts right off the bat. I al...   
1444  2000_933743.txt  Last night around 11pm or so a friend of mine ...   
2017     200

In [8]:
# Convert to DatasetDict
train_dataset = Dataset.from_dict(train_data)
test_dataset = Dataset.from_dict(test_data)
valid_dataset = Dataset.from_dict(validation_data)

full_dataset_dict = DatasetDict({
    "train": train_dataset,
    "test": test_dataset,
    "valid": valid_dataset
})

print(full_dataset_dict)

DatasetDict({
    train: Dataset({
        features: ['user', 'text', 'token_len', 'EXT', 'NEU', 'AGR', 'CON', 'OPN'],
        num_rows: 1480
    })
    test: Dataset({
        features: ['user', 'text', 'token_len', 'EXT', 'NEU', 'AGR', 'CON', 'OPN'],
        num_rows: 494
    })
    valid: Dataset({
        features: ['user', 'text', 'token_len', 'EXT', 'NEU', 'AGR', 'CON', 'OPN'],
        num_rows: 493
    })
})


In [9]:
def preprocess_text(row):

    # Extract the text
    essays = row['text']

    # Clean up
    essays = [dataset_processors.preprocess_text(essay) for essay in essays]

    # Encode them using the tokenizer
    encoded_essay = tokenizer(essays, truncation=True)
    
    # Add the labels
    labels_batch = {key: row[key] for key in row.keys() if key in labels}

    # Create numpy array of batch and labels
    labels_matrix = np.zeros((len(essays), len(labels)))

    # Fill the array
    for idx, label in enumerate(labels):
        labels_matrix[:, idx] = labels_batch[label]

    # Return the encoding
    encoded_essay["labels"] = labels_matrix.tolist()
    
    return encoded_essay

# Perform the preprocessing
full_dataset_dict = full_dataset_dict.map(
    preprocess_text, batched = True, 
    remove_columns = full_dataset_dict['train'].column_names
)

Map:   0%|          | 0/1480 [00:00<?, ? examples/s]

Map: 100%|██████████| 1480/1480 [00:04<00:00, 327.92 examples/s]
Map: 100%|██████████| 494/494 [00:01<00:00, 338.18 examples/s]
Map: 100%|██████████| 493/493 [00:01<00:00, 332.83 examples/s]


In [10]:
#print(full_dataset_dict['train'][0].keys())
print(full_dataset_dict['train'][0])
#print(full_dataset_dict['train'][0]['labels'])
#tokenizer.decode(full_dataset_dict['train'][5]['input_ids'])
#[id2label[idx] for idx, label in enumerate(full_dataset_dict['train'][5]['labels']) if label == 1.0]

{'input_ids': [101, 1045, 2572, 5458, 2085, 1012, 1045, 2123, 1005, 1056, 2113, 2054, 1045, 2323, 2831, 2055, 1012, 1045, 2066, 2023, 8775, 1012, 4687, 2043, 2009, 1005, 1055, 2349, 1029, 19031, 3775, 24471, 3240, 2003, 1996, 2087, 3376, 2711, 1045, 2113, 1012, 1045, 2293, 2017, 1012, 1045, 2812, 1045, 2293, 2014, 2007, 2035, 1997, 2026, 2540, 1010, 2568, 1010, 2303, 1010, 1998, 3969, 1012, 1045, 2066, 6825, 1999, 2152, 2082, 1012, 2009, 2001, 2200, 5875, 1998, 3167, 1012, 2008, 9891, 2111, 1999, 2030, 2061, 1045, 2228, 1012, 6825, 2003, 1996, 2833, 1997, 1996, 2712, 1012, 2026, 18328, 1005, 1055, 1037, 6881, 2080, 1012, 2002, 11651, 1037, 6045, 2000, 3422, 16608, 1051, 1005, 9848, 1012, 2026, 2060, 18328, 2288, 2010, 3274, 2013, 12418, 2651, 1012, 1045, 2066, 19031, 3775, 1012, 1045, 2089, 2025, 2022, 1037, 6047, 2158, 1010, 2021, 1045, 2113, 2054, 2293, 2003, 1012, 2619, 1005, 1055, 2183, 2000, 2022, 3374, 1010, 2066, 2035, 2017, 12566, 1998, 2035, 2017, 11754, 2063, 1998, 2035, 2017

In [11]:
full_dataset_dict.set_format("torch", device="cuda:0")
model = model.to("cuda:0")

In [12]:
from transformers import TrainingArguments, Trainer

batch_size = 16
learning_rate = 2e-5
epochs = 10
metric_name = "accuracy"

print(learning_rate)

2e-05


In [13]:
args = TrainingArguments(
    saved_model,
    evaluation_strategy = "epoch",
    save_strategy = "no",
    learning_rate = learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs = epochs,
    weight_decay = 0.01,
    metric_for_best_model = metric_name
)



In [14]:
print(labels)

['EXT', 'NEU', 'AGR', 'CON', 'OPN']


In [15]:
from sklearn.metrics import accuracy_score
import torch

def multi_label_metrics(pred_logits, gold_labels):

    # Our threshold
    threshold = 0.5

    # Apply sigmoid to logits
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(pred_logits))

    # Convert predictions to integer predictions
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1

    # Compute metrics
    y_true = gold_labels

    metrics = {
        f"{id2label[i]} - accuracy": accuracy_score(y_true[:, i], y_pred[:, i]) 
        for i in range(len(labels))
    }

    overall_accuracy = accuracy_score(y_true, y_pred)

    # Store and return as dictionary
    metrics['accuracy'] = overall_accuracy
    
    return metrics

def compute_metrics(p):

    # Get the type of predictions
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions

    # Compute the results
    results = multi_label_metrics(preds, p.label_ids)

    return results

In [16]:
# forward pass (testing)
outputs = model(
    input_ids=full_dataset_dict['train']['input_ids'][0].unsqueeze(0), 
    attention_mask=full_dataset_dict['train']['attention_mask'][0].unsqueeze(0),
    labels=full_dataset_dict['train'][0]['labels'].unsqueeze(0)
)

outputs

SequenceClassifierOutput(loss=tensor(0.7369, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), logits=tensor([[-0.0444, -0.0161, -0.0475,  0.1298, -0.2149]], device='cuda:0',
       grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [17]:
# Initialize the trainer before training
trainer = Trainer(
    model,
    args,
    train_dataset=full_dataset_dict["train"],
    eval_dataset=full_dataset_dict["test"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [18]:
trainer.train()

  1%|          | 501/45970 [00:47<1:03:49, 11.87it/s]

{'loss': 0.691, 'grad_norm': 0.5171191096305847, 'learning_rate': 1.9782466826190996e-05, 'epoch': 0.11}


  2%|▏         | 1002/45970 [01:37<1:04:12, 11.67it/s]

{'loss': 0.6879, 'grad_norm': 0.6280295848846436, 'learning_rate': 1.956493365238199e-05, 'epoch': 0.22}


  3%|▎         | 1502/45970 [02:26<1:00:38, 12.22it/s]

{'loss': 0.6882, 'grad_norm': 0.5294129848480225, 'learning_rate': 1.9347400478572985e-05, 'epoch': 0.33}


  4%|▍         | 2001/45970 [03:21<57:35, 12.73it/s]  

{'loss': 0.6866, 'grad_norm': 0.4595237374305725, 'learning_rate': 1.9129867304763977e-05, 'epoch': 0.44}


  5%|▌         | 2501/45970 [04:03<1:01:11, 11.84it/s]

{'loss': 0.6862, 'grad_norm': 0.4855089485645294, 'learning_rate': 1.891233413095497e-05, 'epoch': 0.54}


  7%|▋         | 3002/45970 [04:50<1:03:28, 11.28it/s]

{'loss': 0.6854, 'grad_norm': 0.5492913126945496, 'learning_rate': 1.8694800957145966e-05, 'epoch': 0.65}


  8%|▊         | 3502/45970 [05:36<58:12, 12.16it/s]  

{'loss': 0.6852, 'grad_norm': 0.5762941241264343, 'learning_rate': 1.847726778333696e-05, 'epoch': 0.76}


  9%|▊         | 4002/45970 [06:22<1:00:13, 11.62it/s]

{'loss': 0.6844, 'grad_norm': 0.5434798002243042, 'learning_rate': 1.8259734609527955e-05, 'epoch': 0.87}


 10%|▉         | 4501/45970 [07:09<53:20, 12.96it/s]  

{'loss': 0.6844, 'grad_norm': 0.6064297556877136, 'learning_rate': 1.804220143571895e-05, 'epoch': 0.98}


                                                      
 10%|█         | 4598/45970 [08:01<73:25:33,  6.39s/it]

{'eval_loss': 0.6906459331512451, 'eval_EXT - accuracy': 0.5233914705997113, 'eval_NEU - accuracy': 0.526161769870069, 'eval_AGR - accuracy': 0.5310780756174646, 'eval_CON - accuracy': 0.5368137656560927, 'eval_OPN - accuracy': 0.5519138475945218, 'eval_accuracy': 0.042061726949939524, 'eval_runtime': 42.041, 'eval_samples_per_second': 609.619, 'eval_steps_per_second': 38.106, 'epoch': 1.0}


 11%|█         | 5001/45970 [08:42<1:13:34,  9.28it/s] 

{'loss': 0.676, 'grad_norm': 0.7779818773269653, 'learning_rate': 1.7824668261909944e-05, 'epoch': 1.09}


 12%|█▏        | 5501/45970 [09:31<1:00:29, 11.15it/s]

{'loss': 0.6711, 'grad_norm': 1.1152278184890747, 'learning_rate': 1.760713508810094e-05, 'epoch': 1.2}


 13%|█▎        | 6002/45970 [10:24<58:35, 11.37it/s]  

{'loss': 0.6707, 'grad_norm': 0.8384998440742493, 'learning_rate': 1.738960191429193e-05, 'epoch': 1.31}


 14%|█▍        | 6500/45970 [11:10<52:46, 12.47it/s]  

{'loss': 0.6708, 'grad_norm': 0.8987204432487488, 'learning_rate': 1.7172068740482925e-05, 'epoch': 1.41}


 15%|█▌        | 7002/45970 [11:53<54:45, 11.86it/s]  

{'loss': 0.6701, 'grad_norm': 0.9368668794631958, 'learning_rate': 1.695453556667392e-05, 'epoch': 1.52}


 16%|█▋        | 7502/45970 [12:38<52:11, 12.28it/s]  

{'loss': 0.6684, 'grad_norm': 1.134456992149353, 'learning_rate': 1.6737002392864914e-05, 'epoch': 1.63}


 17%|█▋        | 8001/45970 [13:25<58:21, 10.84it/s]  

{'loss': 0.6694, 'grad_norm': 1.135223388671875, 'learning_rate': 1.651946921905591e-05, 'epoch': 1.74}


 18%|█▊        | 8502/45970 [14:13<54:17, 11.50it/s]  

{'loss': 0.6685, 'grad_norm': 1.0298538208007812, 'learning_rate': 1.6301936045246903e-05, 'epoch': 1.85}


 20%|█▉        | 9001/45970 [15:00<58:40, 10.50it/s]  

{'loss': 0.6677, 'grad_norm': 1.0304205417633057, 'learning_rate': 1.6084402871437898e-05, 'epoch': 1.96}


                                                      
 20%|██        | 9195/45970 [16:01<65:06:16,  6.37s/it]

{'eval_loss': 0.7057145237922668, 'eval_EXT - accuracy': 0.5218697569159936, 'eval_NEU - accuracy': 0.5187092746498108, 'eval_AGR - accuracy': 0.5329119356978423, 'eval_CON - accuracy': 0.5318974599086972, 'eval_OPN - accuracy': 0.5414959616059932, 'eval_accuracy': 0.047095087596082566, 'eval_runtime': 41.9021, 'eval_samples_per_second': 611.64, 'eval_steps_per_second': 38.232, 'epoch': 2.0}


 21%|██        | 9502/45970 [16:30<1:00:34, 10.03it/s] 

{'loss': 0.6482, 'grad_norm': 1.8893413543701172, 'learning_rate': 1.586686969762889e-05, 'epoch': 2.07}


 22%|██▏       | 10001/45970 [17:26<1:41:28,  5.91it/s]

{'loss': 0.6333, 'grad_norm': 1.7941211462020874, 'learning_rate': 1.5649336523819884e-05, 'epoch': 2.18}


 23%|██▎       | 10501/45970 [18:17<49:32, 11.93it/s]  

{'loss': 0.6323, 'grad_norm': 2.015155076980591, 'learning_rate': 1.543180335001088e-05, 'epoch': 2.28}


 24%|██▍       | 11001/45970 [19:01<48:10, 12.10it/s]  

{'loss': 0.6322, 'grad_norm': 2.5702438354492188, 'learning_rate': 1.5214270176201873e-05, 'epoch': 2.39}


 25%|██▌       | 11502/45970 [19:43<45:58, 12.49it/s]  

{'loss': 0.6328, 'grad_norm': 2.213710308074951, 'learning_rate': 1.4996737002392868e-05, 'epoch': 2.5}


 26%|██▌       | 12002/45970 [20:29<48:57, 11.56it/s]  

{'loss': 0.6334, 'grad_norm': 2.140180826187134, 'learning_rate': 1.477920382858386e-05, 'epoch': 2.61}


 27%|██▋       | 12502/45970 [21:16<48:09, 11.58it/s]  

{'loss': 0.635, 'grad_norm': 2.2775559425354004, 'learning_rate': 1.4561670654774855e-05, 'epoch': 2.72}


 28%|██▊       | 13001/45970 [22:05<42:30, 12.93it/s]  

{'loss': 0.6317, 'grad_norm': 2.0462839603424072, 'learning_rate': 1.434413748096585e-05, 'epoch': 2.83}


 29%|██▉       | 13503/45970 [22:50<42:09, 12.84it/s]  

{'loss': 0.633, 'grad_norm': 2.5613489151000977, 'learning_rate': 1.4126604307156841e-05, 'epoch': 2.94}


                                                       
 30%|███       | 13792/45970 [23:56<64:01:45,  7.16s/it]

{'eval_loss': 0.7324612736701965, 'eval_EXT - accuracy': 0.5236255803972063, 'eval_NEU - accuracy': 0.5227671778063911, 'eval_AGR - accuracy': 0.5213235007218385, 'eval_CON - accuracy': 0.525303367279254, 'eval_OPN - accuracy': 0.5423153458972259, 'eval_accuracy': 0.04857778298021772, 'eval_runtime': 40.1007, 'eval_samples_per_second': 639.116, 'eval_steps_per_second': 39.949, 'epoch': 3.0}


 30%|███       | 14001/45970 [24:17<45:19, 11.76it/s]   

{'loss': 0.6109, 'grad_norm': 2.302736282348633, 'learning_rate': 1.3909071133347836e-05, 'epoch': 3.05}


 32%|███▏      | 14500/45970 [25:07<54:46,  9.57it/s]  

{'loss': 0.579, 'grad_norm': 2.9896769523620605, 'learning_rate': 1.369153795953883e-05, 'epoch': 3.15}


 33%|███▎      | 15001/45970 [25:54<42:02, 12.28it/s]  

{'loss': 0.5843, 'grad_norm': 2.8129184246063232, 'learning_rate': 1.3474004785729825e-05, 'epoch': 3.26}


 34%|███▎      | 15501/45970 [26:41<40:09, 12.64it/s]  

{'loss': 0.5833, 'grad_norm': 4.217611312866211, 'learning_rate': 1.3256471611920818e-05, 'epoch': 3.37}


 35%|███▍      | 16002/45970 [27:28<41:35, 12.01it/s]  

{'loss': 0.5818, 'grad_norm': 3.756267547607422, 'learning_rate': 1.3038938438111812e-05, 'epoch': 3.48}


 36%|███▌      | 16502/45970 [28:12<45:46, 10.73it/s]  

{'loss': 0.5892, 'grad_norm': 4.2466630935668945, 'learning_rate': 1.2821405264302807e-05, 'epoch': 3.59}


 37%|███▋      | 17003/45970 [28:59<39:30, 12.22it/s]  

{'loss': 0.5843, 'grad_norm': 3.655155897140503, 'learning_rate': 1.26038720904938e-05, 'epoch': 3.7}


 38%|███▊      | 17503/45970 [29:41<36:05, 13.14it/s]  

{'loss': 0.5847, 'grad_norm': 3.5651936531066895, 'learning_rate': 1.2386338916684794e-05, 'epoch': 3.81}


 39%|███▉      | 18001/45970 [30:23<1:54:59,  4.05it/s]

{'loss': 0.5878, 'grad_norm': 2.6874818801879883, 'learning_rate': 1.2168805742875789e-05, 'epoch': 3.92}


                                                       
 40%|████      | 18389/45970 [31:38<43:57:06,  5.74s/it]

{'eval_loss': 0.7814905047416687, 'eval_EXT - accuracy': 0.5194896406414609, 'eval_NEU - accuracy': 0.5275274103554567, 'eval_AGR - accuracy': 0.5180459635569082, 'eval_CON - accuracy': 0.5177728354598307, 'eval_OPN - accuracy': 0.539623083226033, 'eval_accuracy': 0.04627570330484997, 'eval_runtime': 37.6736, 'eval_samples_per_second': 680.29, 'eval_steps_per_second': 42.523, 'epoch': 4.0}


 40%|████      | 18501/45970 [31:47<38:46, 11.81it/s]   

{'loss': 0.5714, 'grad_norm': 3.0686984062194824, 'learning_rate': 1.1951272569066784e-05, 'epoch': 4.02}


 41%|████▏     | 19002/45970 [32:33<38:17, 11.74it/s]  

{'loss': 0.5274, 'grad_norm': 4.508354187011719, 'learning_rate': 1.1733739395257777e-05, 'epoch': 4.13}


 42%|████▏     | 19502/45970 [33:13<37:46, 11.68it/s]

{'loss': 0.5297, 'grad_norm': 5.109891414642334, 'learning_rate': 1.1516206221448771e-05, 'epoch': 4.24}


 44%|████▎     | 20003/45970 [34:00<30:22, 14.25it/s]  

{'loss': 0.5362, 'grad_norm': 4.097936153411865, 'learning_rate': 1.1298673047639766e-05, 'epoch': 4.35}


 45%|████▍     | 20501/45970 [34:46<36:58, 11.48it/s]  

{'loss': 0.5276, 'grad_norm': 4.229053974151611, 'learning_rate': 1.108113987383076e-05, 'epoch': 4.46}


 46%|████▌     | 21001/45970 [35:29<34:45, 11.97it/s]  

{'loss': 0.5312, 'grad_norm': 4.897073268890381, 'learning_rate': 1.0863606700021753e-05, 'epoch': 4.57}


 47%|████▋     | 21501/45970 [36:14<30:59, 13.16it/s]  

{'loss': 0.5329, 'grad_norm': 6.195981025695801, 'learning_rate': 1.0646073526212748e-05, 'epoch': 4.68}


 48%|████▊     | 22001/45970 [36:58<32:29, 12.29it/s]  

{'loss': 0.539, 'grad_norm': 5.244258403778076, 'learning_rate': 1.0428540352403743e-05, 'epoch': 4.79}


 49%|████▉     | 22503/45970 [37:41<30:15, 12.93it/s]  

{'loss': 0.5341, 'grad_norm': 4.684493064880371, 'learning_rate': 1.0211007178594735e-05, 'epoch': 4.89}


                                                       
 50%|█████     | 22986/45970 [39:03<43:21:57,  6.79s/it]

{'eval_loss': 0.8470182418823242, 'eval_EXT - accuracy': 0.5144172616957353, 'eval_NEU - accuracy': 0.5153537008857154, 'eval_AGR - accuracy': 0.5196457138397909, 'eval_CON - accuracy': 0.5176557805610832, 'eval_OPN - accuracy': 0.5327558624995122, 'eval_accuracy': 0.044871044519879826, 'eval_runtime': 37.7659, 'eval_samples_per_second': 678.628, 'eval_steps_per_second': 42.419, 'epoch': 5.0}


 50%|█████     | 23002/45970 [39:04<2:35:00,  2.47it/s] 

{'loss': 0.5335, 'grad_norm': 4.728116989135742, 'learning_rate': 9.99347400478573e-06, 'epoch': 5.0}


 51%|█████     | 23500/45970 [39:47<48:16,  7.76it/s]  

{'loss': 0.4829, 'grad_norm': 4.388766765594482, 'learning_rate': 9.775940830976725e-06, 'epoch': 5.11}


 52%|█████▏    | 24002/45970 [40:34<30:03, 12.18it/s]  

{'loss': 0.4829, 'grad_norm': 4.932460784912109, 'learning_rate': 9.55840765716772e-06, 'epoch': 5.22}


 53%|█████▎    | 24503/45970 [41:21<26:06, 13.70it/s]  

{'loss': 0.4803, 'grad_norm': 4.490340709686279, 'learning_rate': 9.340874483358712e-06, 'epoch': 5.33}


 54%|█████▍    | 25003/45970 [42:09<27:34, 12.67it/s]  

{'loss': 0.483, 'grad_norm': 4.659549713134766, 'learning_rate': 9.123341309549707e-06, 'epoch': 5.44}


 55%|█████▌    | 25502/45970 [42:53<30:58, 11.01it/s]  

{'loss': 0.4851, 'grad_norm': 7.018158912658691, 'learning_rate': 8.905808135740701e-06, 'epoch': 5.55}


 57%|█████▋    | 26002/45970 [43:37<24:54, 13.36it/s]  

{'loss': 0.4858, 'grad_norm': 5.047472953796387, 'learning_rate': 8.688274961931696e-06, 'epoch': 5.66}


 58%|█████▊    | 26501/45970 [44:19<28:27, 11.40it/s]  

{'loss': 0.4829, 'grad_norm': 5.007820129394531, 'learning_rate': 8.470741788122689e-06, 'epoch': 5.76}


 59%|█████▊    | 27001/45970 [45:04<27:18, 11.58it/s]  

{'loss': 0.4846, 'grad_norm': 5.993010520935059, 'learning_rate': 8.253208614313684e-06, 'epoch': 5.87}


 60%|█████▉    | 27501/45970 [45:43<24:27, 12.59it/s]

{'loss': 0.4835, 'grad_norm': 5.1200480461120605, 'learning_rate': 8.035675440504678e-06, 'epoch': 5.98}


                                                     
 60%|██████    | 27583/45970 [46:28<34:13:36,  6.70s/it]

{'eval_loss': 0.9181484580039978, 'eval_EXT - accuracy': 0.5147684263919778, 'eval_NEU - accuracy': 0.5185532014514808, 'eval_AGR - accuracy': 0.5131686761090951, 'eval_CON - accuracy': 0.5137539506028327, 'eval_OPN - accuracy': 0.5286979593429318, 'eval_accuracy': 0.0427640563424246, 'eval_runtime': 37.443, 'eval_samples_per_second': 684.48, 'eval_steps_per_second': 42.785, 'epoch': 6.0}


 61%|██████    | 28001/45970 [47:06<23:51, 12.55it/s]   

{'loss': 0.4435, 'grad_norm': 5.7683916091918945, 'learning_rate': 7.818142266695671e-06, 'epoch': 6.09}


 62%|██████▏   | 28501/45970 [47:49<25:03, 11.62it/s]  

{'loss': 0.4359, 'grad_norm': 5.255147457122803, 'learning_rate': 7.600609092886666e-06, 'epoch': 6.2}


 63%|██████▎   | 29003/45970 [48:31<21:27, 13.18it/s]  

{'loss': 0.4372, 'grad_norm': 6.070585250854492, 'learning_rate': 7.38307591907766e-06, 'epoch': 6.31}


 64%|██████▍   | 29501/45970 [49:16<23:10, 11.84it/s]  

{'loss': 0.4381, 'grad_norm': 8.117278099060059, 'learning_rate': 7.165542745268654e-06, 'epoch': 6.42}


 65%|██████▌   | 30003/45970 [50:00<1:04:05,  4.15it/s]

{'loss': 0.4384, 'grad_norm': 6.087441921234131, 'learning_rate': 6.948009571459649e-06, 'epoch': 6.53}


 66%|██████▋   | 30501/45970 [50:47<18:39, 13.81it/s]  

{'loss': 0.4426, 'grad_norm': 7.768171310424805, 'learning_rate': 6.7304763976506424e-06, 'epoch': 6.63}


 67%|██████▋   | 31003/45970 [51:34<20:16, 12.31it/s]  

{'loss': 0.439, 'grad_norm': 6.367941856384277, 'learning_rate': 6.512943223841637e-06, 'epoch': 6.74}


 69%|██████▊   | 31501/45970 [52:17<18:39, 12.93it/s]  

{'loss': 0.4389, 'grad_norm': 5.343508720397949, 'learning_rate': 6.295410050032631e-06, 'epoch': 6.85}


 70%|██████▉   | 32001/45970 [52:58<18:11, 12.80it/s]  

{'loss': 0.4414, 'grad_norm': 6.30684757232666, 'learning_rate': 6.0778768762236254e-06, 'epoch': 6.96}


                                                       
 70%|███████   | 32180/45970 [53:53<26:05:42,  6.81s/it]

{'eval_loss': 0.9996153116226196, 'eval_EXT - accuracy': 0.5148854812907253, 'eval_NEU - accuracy': 0.5233134340005463, 'eval_AGR - accuracy': 0.5139880604003277, 'eval_CON - accuracy': 0.51106168793164, 'eval_OPN - accuracy': 0.5364235826602677, 'eval_accuracy': 0.043700495532404696, 'eval_runtime': 37.6126, 'eval_samples_per_second': 681.394, 'eval_steps_per_second': 42.592, 'epoch': 7.0}


 71%|███████   | 32502/45970 [54:19<19:25, 11.55it/s]   

{'loss': 0.4115, 'grad_norm': 6.220955848693848, 'learning_rate': 5.860343702414619e-06, 'epoch': 7.07}


 72%|███████▏  | 33002/45970 [55:02<18:08, 11.92it/s]  

{'loss': 0.4001, 'grad_norm': 5.504371166229248, 'learning_rate': 5.642810528605612e-06, 'epoch': 7.18}


 73%|███████▎  | 33503/45970 [55:45<15:44, 13.20it/s]  

{'loss': 0.4011, 'grad_norm': 9.580866813659668, 'learning_rate': 5.425277354796607e-06, 'epoch': 7.29}


 74%|███████▍  | 34003/45970 [56:30<14:31, 13.74it/s]  

{'loss': 0.4037, 'grad_norm': 8.232292175292969, 'learning_rate': 5.2077441809876005e-06, 'epoch': 7.4}


 75%|███████▌  | 34502/45970 [57:13<13:56, 13.71it/s]  

{'loss': 0.4041, 'grad_norm': 5.037126541137695, 'learning_rate': 4.990211007178595e-06, 'epoch': 7.5}


 76%|███████▌  | 35002/45970 [57:56<16:55, 10.80it/s]

{'loss': 0.4007, 'grad_norm': 5.9925031661987305, 'learning_rate': 4.77267783336959e-06, 'epoch': 7.61}


 77%|███████▋  | 35503/45970 [58:42<12:39, 13.79it/s]

{'loss': 0.3991, 'grad_norm': 5.4212470054626465, 'learning_rate': 4.5551446595605835e-06, 'epoch': 7.72}


 78%|███████▊  | 36003/45970 [59:30<33:25,  4.97it/s]

{'loss': 0.4009, 'grad_norm': 6.063460350036621, 'learning_rate': 4.337611485751577e-06, 'epoch': 7.83}


 79%|███████▉  | 36503/45970 [1:00:14<12:11, 12.94it/s]  

{'loss': 0.4033, 'grad_norm': 6.339635372161865, 'learning_rate': 4.120078311942571e-06, 'epoch': 7.94}


                                                       
 80%|████████  | 36777/45970 [1:01:18<14:37:04,  5.72s/it]

{'eval_loss': 1.0621198415756226, 'eval_EXT - accuracy': 0.5151976276873854, 'eval_NEU - accuracy': 0.5179679269577432, 'eval_AGR - accuracy': 0.5131296578095127, 'eval_CON - accuracy': 0.5119981271216201, 'eval_OPN - accuracy': 0.5311561122166296, 'eval_accuracy': 0.04389558703031722, 'eval_runtime': 37.5375, 'eval_samples_per_second': 682.757, 'eval_steps_per_second': 42.677, 'epoch': 8.0}


 80%|████████  | 37002/45970 [1:01:35<12:41, 11.77it/s]   

{'loss': 0.3907, 'grad_norm': 8.881650924682617, 'learning_rate': 3.902545138133566e-06, 'epoch': 8.05}


 82%|████████▏ | 37502/45970 [1:02:21<11:05, 12.73it/s]

{'loss': 0.3731, 'grad_norm': 8.917471885681152, 'learning_rate': 3.68501196432456e-06, 'epoch': 8.16}


 83%|████████▎ | 38002/45970 [1:03:07<24:59,  5.31it/s]

{'loss': 0.3789, 'grad_norm': 7.838942050933838, 'learning_rate': 3.467478790515554e-06, 'epoch': 8.27}


 84%|████████▍ | 38502/45970 [1:03:53<09:26, 13.17it/s]

{'loss': 0.3691, 'grad_norm': 5.01944637298584, 'learning_rate': 3.2499456167065478e-06, 'epoch': 8.38}


 85%|████████▍ | 39002/45970 [1:04:37<12:14,  9.49it/s]

{'loss': 0.3715, 'grad_norm': 6.9027934074401855, 'learning_rate': 3.032412442897542e-06, 'epoch': 8.48}


 86%|████████▌ | 39502/45970 [1:05:24<08:40, 12.43it/s]

{'loss': 0.3727, 'grad_norm': 8.098807334899902, 'learning_rate': 2.814879269088536e-06, 'epoch': 8.59}


 87%|████████▋ | 40003/45970 [1:06:09<07:34, 13.12it/s]

{'loss': 0.3742, 'grad_norm': 7.864850044250488, 'learning_rate': 2.5973460952795303e-06, 'epoch': 8.7}


 88%|████████▊ | 40502/45970 [1:06:55<06:50, 13.33it/s]

{'loss': 0.372, 'grad_norm': 7.484764099121094, 'learning_rate': 2.3798129214705245e-06, 'epoch': 8.81}


 89%|████████▉ | 41003/45970 [1:07:36<06:10, 13.41it/s]

{'loss': 0.374, 'grad_norm': 7.836643695831299, 'learning_rate': 2.1622797476615183e-06, 'epoch': 8.92}


                                                       
 90%|█████████ | 41374/45970 [1:08:44<8:39:44,  6.79s/it]

{'eval_loss': 1.1250104904174805, 'eval_EXT - accuracy': 0.5124663467166101, 'eval_NEU - accuracy': 0.5146903897928128, 'eval_AGR - accuracy': 0.5123492918178626, 'eval_CON - accuracy': 0.5097350657458348, 'eval_OPN - accuracy': 0.5313902220141247, 'eval_accuracy': 0.04233485504701705, 'eval_runtime': 38.0288, 'eval_samples_per_second': 673.937, 'eval_steps_per_second': 42.126, 'epoch': 9.0}


 90%|█████████ | 41502/45970 [1:08:54<05:39, 13.15it/s]  

{'loss': 0.3704, 'grad_norm': 8.627129554748535, 'learning_rate': 1.9447465738525125e-06, 'epoch': 9.03}


 91%|█████████▏| 42001/45970 [1:09:42<12:54,  5.13it/s]

{'loss': 0.3464, 'grad_norm': 6.298225402832031, 'learning_rate': 1.7272134000435066e-06, 'epoch': 9.14}


 92%|█████████▏| 42503/45970 [1:10:24<05:19, 10.86it/s]

{'loss': 0.3505, 'grad_norm': 6.922345161437988, 'learning_rate': 1.5096802262345008e-06, 'epoch': 9.25}


 94%|█████████▎| 43001/45970 [1:11:08<04:32, 10.89it/s]

{'loss': 0.3545, 'grad_norm': 6.366048812866211, 'learning_rate': 1.292147052425495e-06, 'epoch': 9.35}


 95%|█████████▍| 43503/45970 [1:11:49<02:57, 13.86it/s]

{'loss': 0.3515, 'grad_norm': 7.91326379776001, 'learning_rate': 1.074613878616489e-06, 'epoch': 9.46}


 96%|█████████▌| 44002/45970 [1:12:32<02:43, 12.01it/s]

{'loss': 0.3551, 'grad_norm': 9.168742179870605, 'learning_rate': 8.570807048074832e-07, 'epoch': 9.57}


 97%|█████████▋| 44502/45970 [1:13:18<02:13, 10.99it/s]

{'loss': 0.3531, 'grad_norm': 5.689018249511719, 'learning_rate': 6.395475309984774e-07, 'epoch': 9.68}


 98%|█████████▊| 45002/45970 [1:14:02<01:14, 12.96it/s]

{'loss': 0.351, 'grad_norm': 7.872915744781494, 'learning_rate': 4.2201435718947146e-07, 'epoch': 9.79}


 99%|█████████▉| 45503/45970 [1:14:47<00:35, 13.32it/s]

{'loss': 0.3579, 'grad_norm': 5.9566731452941895, 'learning_rate': 2.0448118338046557e-07, 'epoch': 9.9}


                                                       
100%|██████████| 45970/45970 [1:16:09<00:00, 10.06it/s]

{'eval_loss': 1.163111686706543, 'eval_EXT - accuracy': 0.5138710055015803, 'eval_NEU - accuracy': 0.5157438838815405, 'eval_AGR - accuracy': 0.5143782433961528, 'eval_CON - accuracy': 0.5131296578095127, 'eval_OPN - accuracy': 0.5310000390182996, 'eval_accuracy': 0.04210074524952202, 'eval_runtime': 37.6444, 'eval_samples_per_second': 680.818, 'eval_steps_per_second': 42.556, 'epoch': 10.0}
{'train_runtime': 4569.2885, 'train_samples_per_second': 160.938, 'train_steps_per_second': 10.061, 'train_loss': 0.5155778935921822, 'epoch': 10.0}





TrainOutput(global_step=45970, training_loss=0.5155778935921822, metrics={'train_runtime': 4569.2885, 'train_samples_per_second': 160.938, 'train_steps_per_second': 10.061, 'total_flos': 9631412851951230.0, 'train_loss': 0.5155778935921822, 'epoch': 10.0})

In [20]:
trainer.save_model(saved_model)