In [1]:
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets.dataset_dict import DatasetDict
from datasets import Dataset

import utils.dataset_processors as dataset_processors
import numpy as np
import pandas as pd
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Initialize the model name for tokenizer and also the saved model name
model_name = 'distilbert-base-uncased'
saved_model = 'distillbert-finetuned-segmented'

In [3]:
# Load the dataset
datafile = "data/essays/essays.csv"
dataset = dataset_processors.load_essays_df(datafile)

# Split the dataset (6:2:2)
train_data, temp_data = train_test_split(dataset, train_size=0.6, random_state=42)
validation_data, test_data = train_test_split(temp_data, train_size=0.5, random_state=42)

In [4]:
# Get the Big 5 labels
column_names = list(train_data.columns)
labels = [label for label in column_names if label not in ['user','text','token_len']]

# Forward and backward mapping
id2label = {idx:label for idx,label in enumerate(labels)}
label2id = {label:idx for idx,label in enumerate(labels)}

labels

['EXT', 'NEU', 'AGR', 'CON', 'OPN']

In [5]:
# Import the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    problem_type = "multi_label_classification",
    num_labels = len(labels),
    id2label = id2label,
    label2id = label2id
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Convert from essay to sentences
def split_text_with_labels(row):
    
    # Split sentences
    sentences = re.split(r"(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s", row['text'])

    return [{
        'text': sentence,
        'EXT': row['EXT'],
        'NEU': row['NEU'],
        'AGR': row['AGR'],
        'CON': row['CON'],
        'OPN': row['OPN']
    }
        for sentence in sentences       
    ]

def transform_dataframe(old_dataframe):

    # Begin the split
    split_data = []
    
    for _, row in old_dataframe.iterrows():
        split_data.extend(split_text_with_labels(row))
        
    return pd.DataFrame(split_data)

train_data = transform_dataframe(train_data)
test_data = transform_dataframe(test_data)
validation_data = transform_dataframe(validation_data)

In [7]:
print(train_data.head(20))

                                                 text  EXT  NEU  AGR  CON  OPN
0                                     I am tired now.    1    0    1    0    1
1              I don't know what I should talk about.    1    0    1    0    1
2                             I like this assignment.    1    0    1    0    1
3                               Wonder when it's due?    1    0    1    0    1
4    Kristi Urey is the most beautiful person I know.    1    0    1    0    1
5                                         I love you.    1    0    1    0    1
6   I mean I love her with all of my heart, mind, ...    1    0    1    0    1
7                   I like psychology in high school.    1    0    1    0    1
8               It was very interesting and personal.    1    0    1    0    1
9                 That draws people in or so I think.    1    0    1    0    1
10                 Psychology is the food of the sea.    1    0    1    0    1
11                            My roommate's a weirdo

In [8]:
# Convert to DatasetDict
train_dataset = Dataset.from_dict(train_data)
test_dataset = Dataset.from_dict(test_data)
valid_dataset = Dataset.from_dict(validation_data)

full_dataset_dict = DatasetDict({
    "train": train_dataset,
    "test": test_dataset,
    "valid": valid_dataset
})

print(full_dataset_dict)

DatasetDict({
    train: Dataset({
        features: ['text', 'EXT', 'NEU', 'AGR', 'CON', 'OPN'],
        num_rows: 73537
    })
    test: Dataset({
        features: ['text', 'EXT', 'NEU', 'AGR', 'CON', 'OPN'],
        num_rows: 25629
    })
    valid: Dataset({
        features: ['text', 'EXT', 'NEU', 'AGR', 'CON', 'OPN'],
        num_rows: 23811
    })
})


In [9]:
def preprocess_text(row):

    # Extract the text
    essays = row['text']

    # Clean up
    essays = [dataset_processors.preprocess_text(essay) for essay in essays]

    # Encode them using the tokenizer
    encoded_essay = tokenizer(essays, truncation=True)
    
    # Add the labels
    labels_batch = {key: row[key] for key in row.keys() if key in labels}

    # Create numpy array of batch and labels
    labels_matrix = np.zeros((len(essays), len(labels)))

    # Fill the array
    for idx, label in enumerate(labels):
        labels_matrix[:, idx] = labels_batch[label]

    # Return the encoding
    encoded_essay["labels"] = labels_matrix.tolist()
    
    return encoded_essay

# Perform the preprocessing
full_dataset_dict = full_dataset_dict.map(
    preprocess_text, batched = True, 
    remove_columns = full_dataset_dict['train'].column_names
)

Map: 100%|██████████| 73537/73537 [00:10<00:00, 7095.17 examples/s]
Map: 100%|██████████| 25629/25629 [00:03<00:00, 6979.97 examples/s]
Map: 100%|██████████| 23811/23811 [00:03<00:00, 6663.40 examples/s]


In [10]:
#print(full_dataset_dict['train'][0].keys())
print(full_dataset_dict['train'][0])
#print(full_dataset_dict['train'][0]['labels'])
#tokenizer.decode(full_dataset_dict['train'][5]['input_ids'])
#[id2label[idx] for idx, label in enumerate(full_dataset_dict['train'][5]['labels']) if label == 1.0]

{'input_ids': [101, 1045, 2572, 5458, 2085, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1], 'labels': [1.0, 0.0, 1.0, 0.0, 1.0]}


In [None]:
full_dataset_dict.set_format("torch", device="cuda:0")
model = model.to("cuda:0")

In [12]:
from transformers import TrainingArguments, Trainer

batch_size = 16
learning_rate = 2e-5
epochs = 10
metric_name = "accuracy"

print(learning_rate)

2e-05


In [13]:
args = TrainingArguments(
    saved_model,
    evaluation_strategy = "epoch",
    save_strategy = "no",
    learning_rate = learning_rate,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs = epochs,
    weight_decay = 0.01,
    metric_for_best_model = metric_name
)



In [14]:
print(labels)

['EXT', 'NEU', 'AGR', 'CON', 'OPN']


In [15]:
from sklearn.metrics import accuracy_score
import torch

def multi_label_metrics(pred_logits, gold_labels):

    # Our threshold
    threshold = 0.5

    # Apply sigmoid to logits
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(pred_logits))

    # Convert predictions to integer predictions
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1

    # Compute metrics
    y_true = gold_labels

    metrics = {
        f"{id2label[i]} - accuracy": accuracy_score(y_true[:, i], y_pred[:, i]) 
        for i in range(len(labels))
    }

    overall_accuracy = accuracy_score(y_true, y_pred)

    # Store and return as dictionary
    metrics['accuracy'] = overall_accuracy
    
    return metrics

def compute_metrics(p):

    # Get the type of predictions
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions

    # Compute the results
    results = multi_label_metrics(preds, p.label_ids)

    return results

In [16]:
# forward pass (testing)
outputs = model(
    input_ids=full_dataset_dict['train']['input_ids'][0].unsqueeze(0), 
    attention_mask=full_dataset_dict['train']['attention_mask'][0].unsqueeze(0),
    labels=full_dataset_dict['train'][0]['labels'].unsqueeze(0)
)

outputs

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

In [None]:
# Initialize the trainer before training
trainer = Trainer(
    model,
    args,
    train_dataset=full_dataset_dict["train"],
    eval_dataset=full_dataset_dict["test"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [None]:
trainer.train()

  1%|          | 502/45970 [00:47<1:03:40, 11.90it/s]

{'loss': 0.691, 'grad_norm': 0.6273068189620972, 'learning_rate': 1.9782466826190996e-05, 'epoch': 0.11}


  2%|▏         | 1001/45970 [01:37<1:04:23, 11.64it/s]

{'loss': 0.6874, 'grad_norm': 0.5319086909294128, 'learning_rate': 1.956493365238199e-05, 'epoch': 0.22}


  3%|▎         | 1502/45970 [02:26<58:59, 12.56it/s]  

{'loss': 0.688, 'grad_norm': 0.527271568775177, 'learning_rate': 1.9347400478572985e-05, 'epoch': 0.33}


  4%|▍         | 2001/45970 [03:23<59:23, 12.34it/s]  

{'loss': 0.6865, 'grad_norm': 0.46152377128601074, 'learning_rate': 1.9129867304763977e-05, 'epoch': 0.44}


  5%|▌         | 2501/45970 [04:07<1:05:03, 11.14it/s]

{'loss': 0.6864, 'grad_norm': 0.44222092628479004, 'learning_rate': 1.891233413095497e-05, 'epoch': 0.54}


  7%|▋         | 3002/45970 [04:55<1:06:00, 10.85it/s]

{'loss': 0.6853, 'grad_norm': 0.5859605669975281, 'learning_rate': 1.8694800957145966e-05, 'epoch': 0.65}


  8%|▊         | 3502/45970 [05:44<1:00:01, 11.79it/s]

{'loss': 0.6854, 'grad_norm': 0.5201916694641113, 'learning_rate': 1.847726778333696e-05, 'epoch': 0.76}


  9%|▊         | 4001/45970 [06:31<1:03:39, 10.99it/s]

{'loss': 0.6842, 'grad_norm': 0.4794982969760895, 'learning_rate': 1.8259734609527955e-05, 'epoch': 0.87}


 10%|▉         | 4502/45970 [07:20<54:39, 12.64it/s]  

{'loss': 0.6848, 'grad_norm': 0.5099848508834839, 'learning_rate': 1.804220143571895e-05, 'epoch': 0.98}


 10%|█         | 4597/45970 [07:44<58:14, 11.84it/s]  
 10%|█         | 4598/45970 [08:13<75:04:26,  6.53s/it]

{'eval_loss': 0.6918413043022156, 'eval_EXT - accuracy': 0.5262007881696515, 'eval_NEU - accuracy': 0.525849623473409, 'eval_AGR - accuracy': 0.5238987084942838, 'eval_CON - accuracy': 0.5392719185297905, 'eval_OPN - accuracy': 0.5479339810371064, 'eval_accuracy': 0.04081314136329939, 'eval_runtime': 43.0002, 'eval_samples_per_second': 596.02, 'eval_steps_per_second': 37.256, 'epoch': 1.0}


 11%|█         | 5001/45970 [08:55<1:15:44,  9.02it/s] 

{'loss': 0.6752, 'grad_norm': 0.8602599501609802, 'learning_rate': 1.7824668261909944e-05, 'epoch': 1.09}


 12%|█▏        | 5501/45970 [09:46<1:01:41, 10.93it/s]

{'loss': 0.6715, 'grad_norm': 1.0399482250213623, 'learning_rate': 1.760713508810094e-05, 'epoch': 1.2}


 13%|█▎        | 6002/45970 [10:44<59:12, 11.25it/s]  

{'loss': 0.6705, 'grad_norm': 0.8623043298721313, 'learning_rate': 1.738960191429193e-05, 'epoch': 1.31}


 14%|█▍        | 6500/45970 [11:31<55:19, 11.89it/s]  

{'loss': 0.6701, 'grad_norm': 0.9025463461875916, 'learning_rate': 1.7172068740482925e-05, 'epoch': 1.41}


 15%|█▌        | 7002/45970 [12:16<54:22, 11.95it/s]  

{'loss': 0.669, 'grad_norm': 0.8689930438995361, 'learning_rate': 1.695453556667392e-05, 'epoch': 1.52}


 16%|█▋        | 7502/45970 [13:02<51:21, 12.48it/s]  

{'loss': 0.6689, 'grad_norm': 1.208105444908142, 'learning_rate': 1.6737002392864914e-05, 'epoch': 1.63}


 17%|█▋        | 8002/45970 [13:49<56:31, 11.19it/s]  

{'loss': 0.6684, 'grad_norm': 1.156531572341919, 'learning_rate': 1.651946921905591e-05, 'epoch': 1.74}


 18%|█▊        | 8501/45970 [14:36<54:28, 11.46it/s]  

{'loss': 0.6682, 'grad_norm': 1.1387372016906738, 'learning_rate': 1.6301936045246903e-05, 'epoch': 1.85}


 20%|█▉        | 9001/45970 [15:23<56:02, 10.99it/s]  

{'loss': 0.6665, 'grad_norm': 1.193719506263733, 'learning_rate': 1.6084402871437898e-05, 'epoch': 1.96}


 20%|██        | 9194/45970 [15:54<50:05, 12.24it/s]  
 20%|██        | 9195/45970 [16:25<66:11:46,  6.48s/it]

{'eval_loss': 0.7019850015640259, 'eval_EXT - accuracy': 0.5236255803972063, 'eval_NEU - accuracy': 0.5257325685746616, 'eval_AGR - accuracy': 0.5290881423387569, 'eval_CON - accuracy': 0.5339264114869874, 'eval_OPN - accuracy': 0.5407936322135082, 'eval_accuracy': 0.04795349018689766, 'eval_runtime': 42.634, 'eval_samples_per_second': 601.14, 'eval_steps_per_second': 37.576, 'epoch': 2.0}


 21%|██        | 9502/45970 [16:52<56:32, 10.75it/s]   

{'loss': 0.6486, 'grad_norm': 2.006427049636841, 'learning_rate': 1.586686969762889e-05, 'epoch': 2.07}


 22%|██▏       | 10002/45970 [17:51<1:31:05,  6.58it/s]

{'loss': 0.6324, 'grad_norm': 1.8180471658706665, 'learning_rate': 1.5649336523819884e-05, 'epoch': 2.18}


 23%|██▎       | 10503/45970 [18:44<48:56, 12.08it/s]  

{'loss': 0.6307, 'grad_norm': 2.0798797607421875, 'learning_rate': 1.543180335001088e-05, 'epoch': 2.28}


 24%|██▍       | 11002/45970 [19:29<49:10, 11.85it/s]  

{'loss': 0.6303, 'grad_norm': 3.033079147338867, 'learning_rate': 1.5214270176201873e-05, 'epoch': 2.39}


 25%|██▌       | 11502/45970 [20:11<46:12, 12.43it/s]  

{'loss': 0.6304, 'grad_norm': 2.4744064807891846, 'learning_rate': 1.4996737002392868e-05, 'epoch': 2.5}


 26%|██▌       | 12001/45970 [20:59<51:00, 11.10it/s]  

{'loss': 0.6306, 'grad_norm': 2.2164857387542725, 'learning_rate': 1.477920382858386e-05, 'epoch': 2.61}


 27%|██▋       | 12501/45970 [21:47<50:34, 11.03it/s]  

{'loss': 0.6328, 'grad_norm': 2.3960671424865723, 'learning_rate': 1.4561670654774855e-05, 'epoch': 2.72}


 28%|██▊       | 13001/45970 [22:39<42:56, 12.80it/s]  

{'loss': 0.6309, 'grad_norm': 2.0830626487731934, 'learning_rate': 1.434413748096585e-05, 'epoch': 2.83}


 29%|██▉       | 13502/45970 [23:25<44:01, 12.29it/s]  

{'loss': 0.6305, 'grad_norm': 2.428934335708618, 'learning_rate': 1.4126604307156841e-05, 'epoch': 2.94}


 30%|███       | 13791/45970 [24:05<47:53, 11.20it/s]  
 30%|███       | 13792/45970 [24:33<56:50:51,  6.36s/it]

{'eval_loss': 0.7336689829826355, 'eval_EXT - accuracy': 0.5228452144055562, 'eval_NEU - accuracy': 0.5217136837176636, 'eval_AGR - accuracy': 0.5195286589410434, 'eval_CON - accuracy': 0.525303367279254, 'eval_OPN - accuracy': 0.5463732490538062, 'eval_accuracy': 0.04697803269733505, 'eval_runtime': 41.8309, 'eval_samples_per_second': 612.68, 'eval_steps_per_second': 38.297, 'epoch': 3.0}


 30%|███       | 14001/45970 [24:55<45:40, 11.66it/s]   

{'loss': 0.6093, 'grad_norm': 2.4835166931152344, 'learning_rate': 1.3909071133347836e-05, 'epoch': 3.05}


 32%|███▏      | 14500/45970 [25:47<54:38,  9.60it/s]  

{'loss': 0.5762, 'grad_norm': 2.3108561038970947, 'learning_rate': 1.369153795953883e-05, 'epoch': 3.15}


 33%|███▎      | 15001/45970 [26:37<42:29, 12.15it/s]  

{'loss': 0.5837, 'grad_norm': 2.8878633975982666, 'learning_rate': 1.3474004785729825e-05, 'epoch': 3.26}


 34%|███▎      | 15501/45970 [27:25<40:10, 12.64it/s]  

{'loss': 0.5797, 'grad_norm': 4.083866596221924, 'learning_rate': 1.3256471611920818e-05, 'epoch': 3.37}


 35%|███▍      | 16002/45970 [28:13<41:37, 12.00it/s]  

{'loss': 0.5797, 'grad_norm': 3.1972618103027344, 'learning_rate': 1.3038938438111812e-05, 'epoch': 3.48}


 36%|███▌      | 16502/45970 [28:58<46:29, 10.57it/s]  

{'loss': 0.585, 'grad_norm': 3.7651357650756836, 'learning_rate': 1.2821405264302807e-05, 'epoch': 3.59}


 37%|███▋      | 17003/45970 [29:47<40:01, 12.06it/s]  

{'loss': 0.5829, 'grad_norm': 4.923808574676514, 'learning_rate': 1.26038720904938e-05, 'epoch': 3.7}


 38%|███▊      | 17501/45970 [30:30<38:58, 12.17it/s]  

{'loss': 0.5836, 'grad_norm': 3.0999486446380615, 'learning_rate': 1.2386338916684794e-05, 'epoch': 3.81}


 39%|███▉      | 18001/45970 [31:16<2:25:09,  3.21it/s]

{'loss': 0.5858, 'grad_norm': 3.529186725616455, 'learning_rate': 1.2168805742875789e-05, 'epoch': 3.92}


 40%|████      | 18388/45970 [32:15<46:51,  9.81it/s]  
 40%|████      | 18389/45970 [32:39<48:55:32,  6.39s/it]

{'eval_loss': 0.782143235206604, 'eval_EXT - accuracy': 0.5203090249326935, 'eval_NEU - accuracy': 0.5275664286550392, 'eval_AGR - accuracy': 0.5170314877677631, 'eval_CON - accuracy': 0.5219868118147412, 'eval_OPN - accuracy': 0.5438760778805259, 'eval_accuracy': 0.04615864840610246, 'eval_runtime': 41.9444, 'eval_samples_per_second': 611.023, 'eval_steps_per_second': 38.193, 'epoch': 4.0}


 40%|████      | 18501/45970 [32:49<40:22, 11.34it/s]   

{'loss': 0.5696, 'grad_norm': 3.552867889404297, 'learning_rate': 1.1951272569066784e-05, 'epoch': 4.02}


 41%|████▏     | 19001/45970 [33:39<41:31, 10.83it/s]  

{'loss': 0.5243, 'grad_norm': 4.614861488342285, 'learning_rate': 1.1733739395257777e-05, 'epoch': 4.13}


 42%|████▏     | 19502/45970 [34:21<39:07, 11.28it/s]  

{'loss': 0.5305, 'grad_norm': 4.2878241539001465, 'learning_rate': 1.1516206221448771e-05, 'epoch': 4.24}


 44%|████▎     | 20003/45970 [35:13<31:57, 13.54it/s]  

{'loss': 0.5355, 'grad_norm': 4.105475902557373, 'learning_rate': 1.1298673047639766e-05, 'epoch': 4.35}


 45%|████▍     | 20502/45970 [36:04<37:43, 11.25it/s]  

{'loss': 0.5251, 'grad_norm': 3.9235007762908936, 'learning_rate': 1.108113987383076e-05, 'epoch': 4.46}


 46%|████▌     | 21001/45970 [36:51<37:12, 11.18it/s]  

{'loss': 0.5319, 'grad_norm': 4.30959939956665, 'learning_rate': 1.0863606700021753e-05, 'epoch': 4.57}


 47%|████▋     | 21502/45970 [37:40<32:17, 12.63it/s]  

{'loss': 0.5338, 'grad_norm': 5.0950846672058105, 'learning_rate': 1.0646073526212748e-05, 'epoch': 4.68}


 48%|████▊     | 22002/45970 [38:27<33:17, 12.00it/s]  

{'loss': 0.5356, 'grad_norm': 4.265661239624023, 'learning_rate': 1.0428540352403743e-05, 'epoch': 4.79}


 49%|████▉     | 22502/45970 [39:14<33:49, 11.56it/s]  

{'loss': 0.5324, 'grad_norm': 4.905956268310547, 'learning_rate': 1.0211007178594735e-05, 'epoch': 4.89}


 50%|█████     | 22985/45970 [40:16<52:54,  7.24it/s]  
 50%|█████     | 22986/45970 [40:44<49:23:06,  7.74s/it]

{'eval_loss': 0.8464130163192749, 'eval_EXT - accuracy': 0.5162121034765305, 'eval_NEU - accuracy': 0.515938975379453, 'eval_AGR - accuracy': 0.5143782433961528, 'eval_CON - accuracy': 0.520738226228101, 'eval_OPN - accuracy': 0.5349408872761325, 'eval_accuracy': 0.048616801279800226, 'eval_runtime': 41.8004, 'eval_samples_per_second': 613.128, 'eval_steps_per_second': 38.325, 'epoch': 5.0}


 50%|█████     | 23002/45970 [40:45<2:49:57,  2.25it/s] 

{'loss': 0.5291, 'grad_norm': 4.816239356994629, 'learning_rate': 9.99347400478573e-06, 'epoch': 5.0}


 51%|█████     | 23500/45970 [41:32<57:53,  6.47it/s]  

{'loss': 0.4789, 'grad_norm': 4.138972282409668, 'learning_rate': 9.775940830976725e-06, 'epoch': 5.11}


 52%|█████▏    | 24002/45970 [42:23<31:34, 11.59it/s]  

{'loss': 0.4852, 'grad_norm': 5.230184078216553, 'learning_rate': 9.55840765716772e-06, 'epoch': 5.22}


 53%|█████▎    | 24502/45970 [43:16<28:44, 12.45it/s]  

{'loss': 0.4803, 'grad_norm': 4.789518356323242, 'learning_rate': 9.340874483358712e-06, 'epoch': 5.33}


 54%|█████▍    | 25002/45970 [44:10<30:26, 11.48it/s]  

{'loss': 0.4777, 'grad_norm': 3.9592678546905518, 'learning_rate': 9.123341309549707e-06, 'epoch': 5.44}


 55%|█████▌    | 25502/45970 [44:58<33:37, 10.14it/s]  

{'loss': 0.4831, 'grad_norm': 7.437893867492676, 'learning_rate': 8.905808135740701e-06, 'epoch': 5.55}


 57%|█████▋    | 26002/45970 [45:46<26:59, 12.33it/s]  

{'loss': 0.4853, 'grad_norm': 4.1771979331970215, 'learning_rate': 8.688274961931696e-06, 'epoch': 5.66}


 58%|█████▊    | 26501/45970 [46:31<30:02, 10.80it/s]  

{'loss': 0.4817, 'grad_norm': 5.542566299438477, 'learning_rate': 8.470741788122689e-06, 'epoch': 5.76}


 59%|█████▊    | 27001/45970 [47:19<28:38, 11.04it/s]  

{'loss': 0.4841, 'grad_norm': 4.553526401519775, 'learning_rate': 8.253208614313684e-06, 'epoch': 5.87}


 60%|█████▉    | 27501/45970 [48:00<25:57, 11.86it/s]

{'loss': 0.4818, 'grad_norm': 3.8263072967529297, 'learning_rate': 8.035675440504678e-06, 'epoch': 5.98}


 60%|██████    | 27582/45970 [48:26<26:38, 11.50it/s]
 60%|██████    | 27583/45970 [48:50<39:45:39,  7.78s/it]

{'eval_loss': 0.9326357245445251, 'eval_EXT - accuracy': 0.516407194974443, 'eval_NEU - accuracy': 0.5164462132740255, 'eval_AGR - accuracy': 0.5155097740840454, 'eval_CON - accuracy': 0.5217917203168286, 'eval_OPN - accuracy': 0.5352530336727925, 'eval_accuracy': 0.04471497132154981, 'eval_runtime': 42.0198, 'eval_samples_per_second': 609.927, 'eval_steps_per_second': 38.125, 'epoch': 6.0}


 61%|██████    | 28001/45970 [49:32<25:40, 11.67it/s]   

{'loss': 0.443, 'grad_norm': 5.769805431365967, 'learning_rate': 7.818142266695671e-06, 'epoch': 6.09}


 62%|██████▏   | 28501/45970 [50:19<26:44, 10.89it/s]  

{'loss': 0.4391, 'grad_norm': 5.021852016448975, 'learning_rate': 7.600609092886666e-06, 'epoch': 6.2}


 63%|██████▎   | 29002/45970 [51:04<23:07, 12.23it/s]  

{'loss': 0.438, 'grad_norm': 4.7704620361328125, 'learning_rate': 7.38307591907766e-06, 'epoch': 6.31}


 64%|██████▍   | 29501/45970 [51:53<24:32, 11.18it/s]  

{'loss': 0.4401, 'grad_norm': 6.7290778160095215, 'learning_rate': 7.165542745268654e-06, 'epoch': 6.42}


 65%|██████▌   | 30003/45970 [52:41<1:21:22,  3.27it/s]

{'loss': 0.4385, 'grad_norm': 5.03167724609375, 'learning_rate': 6.948009571459649e-06, 'epoch': 6.53}


 66%|██████▋   | 30501/45970 [53:34<20:00, 12.89it/s]  

{'loss': 0.442, 'grad_norm': 6.766926288604736, 'learning_rate': 6.7304763976506424e-06, 'epoch': 6.63}


 67%|██████▋   | 31002/45970 [54:25<22:07, 11.27it/s]  

{'loss': 0.4404, 'grad_norm': 6.841715335845947, 'learning_rate': 6.512943223841637e-06, 'epoch': 6.74}


 69%|██████▊   | 31501/45970 [55:11<20:00, 12.05it/s]  

{'loss': 0.4357, 'grad_norm': 6.853383541107178, 'learning_rate': 6.295410050032631e-06, 'epoch': 6.85}


 70%|██████▉   | 32001/45970 [55:55<19:15, 12.09it/s]  

{'loss': 0.4414, 'grad_norm': 6.311779499053955, 'learning_rate': 6.0778768762236254e-06, 'epoch': 6.96}


 70%|███████   | 32179/45970 [56:26<56:35,  4.06it/s]  
 70%|███████   | 32180/45970 [56:56<26:11:07,  6.84s/it]

{'eval_loss': 1.0016638040542603, 'eval_EXT - accuracy': 0.5140270786999103, 'eval_NEU - accuracy': 0.5143392250965703, 'eval_AGR - accuracy': 0.5144562799953178, 'eval_CON - accuracy': 0.5196457138397909, 'eval_OPN - accuracy': 0.5375941316477428, 'eval_accuracy': 0.04424675172655976, 'eval_runtime': 41.9603, 'eval_samples_per_second': 610.792, 'eval_steps_per_second': 38.179, 'epoch': 7.0}


 71%|███████   | 32502/45970 [57:24<20:17, 11.06it/s]   

{'loss': 0.4131, 'grad_norm': 6.8449320793151855, 'learning_rate': 5.860343702414619e-06, 'epoch': 7.07}


 72%|███████▏  | 33002/45970 [58:11<19:03, 11.34it/s]  

{'loss': 0.4043, 'grad_norm': 4.878905296325684, 'learning_rate': 5.642810528605612e-06, 'epoch': 7.18}


 73%|███████▎  | 33503/45970 [58:57<16:49, 12.35it/s]  

{'loss': 0.4008, 'grad_norm': 5.9229736328125, 'learning_rate': 5.425277354796607e-06, 'epoch': 7.29}


 74%|███████▍  | 34002/45970 [59:47<15:46, 12.65it/s]  

{'loss': 0.4047, 'grad_norm': 6.516345500946045, 'learning_rate': 5.2077441809876005e-06, 'epoch': 7.4}


 75%|███████▌  | 34502/45970 [1:00:34<14:34, 13.12it/s]  

{'loss': 0.4053, 'grad_norm': 5.588491916656494, 'learning_rate': 4.990211007178595e-06, 'epoch': 7.5}


 76%|███████▌  | 35001/45970 [1:01:19<18:20,  9.97it/s]  

{'loss': 0.4038, 'grad_norm': 8.990118980407715, 'learning_rate': 4.77267783336959e-06, 'epoch': 7.61}


 77%|███████▋  | 35503/45970 [1:02:10<13:14, 13.18it/s]  

{'loss': 0.4032, 'grad_norm': 5.471193790435791, 'learning_rate': 4.5551446595605835e-06, 'epoch': 7.72}


 78%|███████▊  | 36001/45970 [1:03:04<54:17,  3.06it/s]  

{'loss': 0.4009, 'grad_norm': 3.918320655822754, 'learning_rate': 4.337611485751577e-06, 'epoch': 7.83}


 79%|███████▉  | 36501/45970 [1:03:51<13:13, 11.93it/s]  

{'loss': 0.4025, 'grad_norm': 6.463129043579102, 'learning_rate': 4.120078311942571e-06, 'epoch': 7.94}


 80%|████████  | 36776/45970 [1:04:37<17:00,  9.01it/s]  
 80%|████████  | 36777/45970 [1:05:03<16:29:49,  6.46s/it]

{'eval_loss': 1.0700851678848267, 'eval_EXT - accuracy': 0.5145733348940653, 'eval_NEU - accuracy': 0.5151976276873854, 'eval_AGR - accuracy': 0.5131686761090951, 'eval_CON - accuracy': 0.5183581099535682, 'eval_OPN - accuracy': 0.5377111865464903, 'eval_accuracy': 0.04537828241445238, 'eval_runtime': 42.2722, 'eval_samples_per_second': 606.285, 'eval_steps_per_second': 37.897, 'epoch': 8.0}


 80%|████████  | 37002/45970 [1:05:22<13:34, 11.01it/s]   

{'loss': 0.3899, 'grad_norm': 7.184511184692383, 'learning_rate': 3.902545138133566e-06, 'epoch': 8.05}


 82%|████████▏ | 37502/45970 [1:06:11<11:53, 11.87it/s]

{'loss': 0.3775, 'grad_norm': 8.467717170715332, 'learning_rate': 3.68501196432456e-06, 'epoch': 8.16}


 83%|████████▎ | 38001/45970 [1:07:02<34:00,  3.90it/s]

{'loss': 0.3771, 'grad_norm': 9.418878555297852, 'learning_rate': 3.467478790515554e-06, 'epoch': 8.27}


 84%|████████▍ | 38502/45970 [1:07:52<10:02, 12.40it/s]

{'loss': 0.3689, 'grad_norm': 7.6100921630859375, 'learning_rate': 3.2499456167065478e-06, 'epoch': 8.38}


 85%|████████▍ | 39002/45970 [1:08:41<12:51,  9.04it/s]

{'loss': 0.3735, 'grad_norm': 7.055738925933838, 'learning_rate': 3.032412442897542e-06, 'epoch': 8.48}


 86%|████████▌ | 39502/45970 [1:09:32<09:10, 11.75it/s]

{'loss': 0.3768, 'grad_norm': 8.992749214172363, 'learning_rate': 2.814879269088536e-06, 'epoch': 8.59}


 87%|████████▋ | 40002/45970 [1:10:23<08:06, 12.26it/s]

{'loss': 0.3742, 'grad_norm': 6.056766510009766, 'learning_rate': 2.5973460952795303e-06, 'epoch': 8.7}


 88%|████████▊ | 40502/45970 [1:11:13<07:07, 12.79it/s]

{'loss': 0.3778, 'grad_norm': 6.803421497344971, 'learning_rate': 2.3798129214705245e-06, 'epoch': 8.81}


 89%|████████▉ | 41002/45970 [1:11:56<06:56, 11.93it/s]

{'loss': 0.3732, 'grad_norm': 4.9665846824646, 'learning_rate': 2.1622797476615183e-06, 'epoch': 8.92}


 90%|█████████ | 41373/45970 [1:12:39<05:29, 13.94it/s]
 90%|█████████ | 41374/45970 [1:13:08<9:20:32,  7.32s/it]

{'eval_loss': 1.1256890296936035, 'eval_EXT - accuracy': 0.5146513714932304, 'eval_NEU - accuracy': 0.5135588591049202, 'eval_AGR - accuracy': 0.511764017324125, 'eval_CON - accuracy': 0.5192555308439658, 'eval_OPN - accuracy': 0.5365016192594326, 'eval_accuracy': 0.04483202622029732, 'eval_runtime': 41.0144, 'eval_samples_per_second': 624.878, 'eval_steps_per_second': 39.059, 'epoch': 9.0}


 90%|█████████ | 41502/45970 [1:13:19<05:54, 12.59it/s]  

{'loss': 0.3737, 'grad_norm': 10.302719116210938, 'learning_rate': 1.9447465738525125e-06, 'epoch': 9.03}


 91%|█████████▏| 42002/45970 [1:14:12<13:02,  5.07it/s]

{'loss': 0.3496, 'grad_norm': 6.177398681640625, 'learning_rate': 1.7272134000435066e-06, 'epoch': 9.14}


 92%|█████████▏| 42501/45970 [1:14:56<06:06,  9.48it/s]

{'loss': 0.3508, 'grad_norm': 7.236640930175781, 'learning_rate': 1.5096802262345008e-06, 'epoch': 9.25}


 94%|█████████▎| 43002/45970 [1:15:43<04:39, 10.63it/s]

{'loss': 0.3552, 'grad_norm': 5.569890975952148, 'learning_rate': 1.292147052425495e-06, 'epoch': 9.35}


 95%|█████████▍| 43503/45970 [1:16:27<03:07, 13.18it/s]

{'loss': 0.3531, 'grad_norm': 7.129018306732178, 'learning_rate': 1.074613878616489e-06, 'epoch': 9.46}


 96%|█████████▌| 44002/45970 [1:17:12<02:50, 11.52it/s]

{'loss': 0.3531, 'grad_norm': 9.848424911499023, 'learning_rate': 8.570807048074832e-07, 'epoch': 9.57}


 97%|█████████▋| 44502/45970 [1:18:02<02:20, 10.47it/s]

{'loss': 0.3585, 'grad_norm': 7.727417469024658, 'learning_rate': 6.395475309984774e-07, 'epoch': 9.68}


 98%|█████████▊| 45002/45970 [1:18:50<01:18, 12.40it/s]

{'loss': 0.3512, 'grad_norm': 6.670873641967773, 'learning_rate': 4.2201435718947146e-07, 'epoch': 9.79}


 99%|█████████▉| 45503/45970 [1:19:38<00:36, 12.95it/s]

{'loss': 0.3586, 'grad_norm': 6.320346832275391, 'learning_rate': 2.0448118338046557e-07, 'epoch': 9.9}


100%|██████████| 45970/45970 [1:20:37<00:00, 11.85it/s]
100%|██████████| 45970/45970 [1:21:05<00:00,  9.45it/s]

{'eval_loss': 1.169968605041504, 'eval_EXT - accuracy': 0.5155878106832105, 'eval_NEU - accuracy': 0.5146123531936478, 'eval_AGR - accuracy': 0.5128175114128526, 'eval_CON - accuracy': 0.5185532014514808, 'eval_OPN - accuracy': 0.5367357290569277, 'eval_accuracy': 0.04448086152405478, 'eval_runtime': 40.0307, 'eval_samples_per_second': 640.234, 'eval_steps_per_second': 40.019, 'epoch': 10.0}
{'train_runtime': 4865.7421, 'train_samples_per_second': 151.132, 'train_steps_per_second': 9.448, 'train_loss': 0.5153251616415937, 'epoch': 10.0}





TrainOutput(global_step=45970, training_loss=0.5153251616415937, metrics={'train_runtime': 4865.7421, 'train_samples_per_second': 151.132, 'train_steps_per_second': 9.448, 'total_flos': 9631412851951230.0, 'train_loss': 0.5153251616415937, 'epoch': 10.0})

In [None]:
trainer.save_model('fine-tuned-roberta-personality-normal')