### Logs
* 30.08: use 2048 for max_len and max_position_embeddings
* 03.09: use 512 instead

In [2]:
!pip install transformers -q
!pip install sentencepiece -q
!pip install datasets -q
!pip install accelerate -U -q



In [3]:
import pandas as pd
from datasets import Dataset, disable_progress_bar
from transformers import AutoTokenizer

# Tokenizer that ships with the DeBERTa-v3-large checkpoint.
tok = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")

# Prompts table and student-summaries table, read from the local input folder.
pdf, sdf = map(
    pd.read_csv,
    ("./input/prompts_train.csv", "./input/summaries_train.csv"),
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
from transformers import AutoModelForSequenceClassification
#model = AutoModelForSequenceClassification.from_pretrained("microsoft/deberta-v3-large")
## use the pretrained model
# from transformers import AutoConfig
# config = AutoConfig.from_pretrained('./input/pretrain/pretrained_model/')
# model = AutoModelForSequenceClassification.from_pretrained('./input/pretrain/pretrained_model', config = config)

In [9]:
# Join the prompts table with the summaries table on their shared prompt_id
# column, then preview the first rows of the merged frame.
df = pd.merge(pdf, sdf, on="prompt_id")
df.head()

Unnamed: 0,prompt_id,prompt_question,prompt_title,prompt_text,student_id,text,content,wording
0,39c16e,Summarize at least 3 elements of an ideal trag...,On Tragedy,Chapter 13 \r\nAs the sequel to what has alrea...,00791789cc1f,1 element of an ideal tragedy is that it shoul...,-0.210614,-0.471415
1,39c16e,Summarize at least 3 elements of an ideal trag...,On Tragedy,Chapter 13 \r\nAs the sequel to what has alrea...,0086ef22de8f,The three elements of an ideal tragedy are: H...,-0.970237,-0.417058
2,39c16e,Summarize at least 3 elements of an ideal trag...,On Tragedy,Chapter 13 \r\nAs the sequel to what has alrea...,0094589c7a22,Aristotle states that an ideal tragedy should ...,-0.387791,-0.584181
3,39c16e,Summarize at least 3 elements of an ideal trag...,On Tragedy,Chapter 13 \r\nAs the sequel to what has alrea...,00cd5736026a,One element of an Ideal tragedy is having a co...,0.088882,-0.59471
4,39c16e,Summarize at least 3 elements of an ideal trag...,On Tragedy,Chapter 13 \r\nAs the sequel to what has alrea...,00d98b8ff756,The 3 ideal of tragedy is how complex you need...,-0.687288,-0.460886


## Train

In [14]:
%%writefile train.py

import os
import logging
import warnings
from dataclasses import dataclass, field
from typing import Optional

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    AutoConfig,
    set_seed,
    Trainer,
    TrainingArguments,
    HfArgumentParser,
    DataCollatorWithPadding,
)
from datasets import Dataset, disable_progress_bar
import pandas as pd
import numpy as np

# Keep the training output quiet: ignore Python warnings and disable logging
# at ERROR level and below.
warnings.simplefilter("ignore")
logging.disable(logging.ERROR)

# Process-wide environment switches for the tokenizers library, TensorFlow's
# C++ logging and the Weights & Biases project name.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "TF_CPP_MIN_LOG_LEVEL": "3",
        "WANDB_PROJECT": "kaggle-commonlit-eval-student-summaries-0309",
    }
)

# Turn off the progress bars of the `datasets` library.
disable_progress_bar()

@dataclass
class Config:
    """Script-specific command-line options.

    Parsed together with ``TrainingArguments`` by ``HfArgumentParser``; every
    field becomes a ``--<field_name>`` flag and its ``help`` metadata is the
    text shown by ``--help``.
    """

    # Hub id or local path used to load the tokenizer and the model config.
    model_name_or_path: Optional[str] = field(
        default="microsoft/deberta-v3-base",
        metadata={"help": "Model name or path"},
    )

    # Directory holding the competition CSV files.
    data_dir: Optional[str] = field(
        default="/kaggle/input/commonlit-evaluate-student-summaries",
        metadata={"help": "Data directory"},
    )

    # Upper bound on the tokenized sequence length (longer inputs are truncated).
    max_seq_length: Optional[int] = field(
        default=512,
        metadata={"help": "Max sequence length"},
    )

    # Input-composition switches.
    add_prompt_question: Optional[bool] = field(
        default=False,
        metadata={"help": "Add prompt question into input"},
    )

    add_prompt_text: Optional[bool] = field(
        default=False,
        metadata={"help": "Add prompt text into input"},
    )

    # Index of the fold held out for validation.
    fold: Optional[int] = field(
        default=0,
        metadata={"help": "Fold"},
    )

    # Worker processes used when mapping the tokenizer over the datasets.
    num_proc: Optional[int] = field(
        default=4,
        metadata={"help": "Number of processes"},
    )

    # Written to both hidden and attention dropout of the model config.
    dropout: Optional[float] = field(
        default=0.,
        metadata={"help": "Amount of dropout to apply"},
    )

    # Written to ``max_position_embeddings`` of the model config.
    # (The help text used to be a copy of the dropout help.)
    max_position_embeddings: Optional[int] = field(
        default=512,
        metadata={"help": "Value of max_position_embeddings set in the model config"},
    )


# Spell auto correction
from spellchecker import SpellChecker

def correct_spelling(input_text, spell=None):
    """Return ``input_text`` with each whitespace-separated word spell-corrected.

    Each token is split into a leading run of non-letters, a core that starts
    and ends with a letter, and a trailing run of non-letters. Only the
    lower-cased core is passed to the spell checker; the surrounding
    punctuation is re-attached unchanged.

    A token is kept exactly as written when it contains no letters, when the
    checker has no suggestion (``correction`` returns ``None``), or when the
    suggestion equals the lower-cased core (already spelled correctly).

    Args:
        input_text: Text to correct. It is split on whitespace, so runs of
            whitespace become single spaces in the result.
        spell: Optional object exposing ``correction(word)``. Defaults to a
            new ``SpellChecker()``; pass one in to reuse it across calls.

    Returns:
        The corrected tokens joined by single spaces.
    """
    if spell is None:
        spell = SpellChecker()

    corrected_words = []
    for token in input_text.split():
        # Find the span between the first and the last alphabetic character.
        start, end = 0, len(token)
        while start < end and not token[start].isalpha():
            start += 1
        while end > start and not token[end - 1].isalpha():
            end -= 1

        core = token[start:end]
        if not core:
            # No letters at all (e.g. "-", "3", "..."): nothing to check.
            corrected_words.append(token)
            continue

        lowered = core.lower()
        suggestion = spell.correction(lowered)
        if suggestion is None or suggestion == lowered:
            # Unknown to the checker, or already correct: keep the original
            # spelling and casing untouched.
            corrected_words.append(token)
            continue

        # Preserve the original capitalization of the first letter.
        if core[0].isupper():
            suggestion = suggestion.capitalize()
        corrected_words.append(token[:start] + suggestion + token[end:])

    return " ".join(corrected_words)

def tokenize(example, tokenizer, config):
    """Encode one merged prompt/summary row as a sequence pair with targets.

    The first segment is the student summary (``example["text"]``). The
    second segment is the prompt title, prompt text and prompt question
    joined with the tokenizer's separator token. The pair is truncated to
    ``config.max_seq_length`` and is not padded here.

    ``config.add_prompt_question`` and ``config.add_prompt_text`` are not
    read by this function; the prompt fields are always included.

    Returns:
        A dict with the tokenizer's output fields plus ``"labels"``, the
        two-element list ``[content, wording]``.
    """
    prompt_fields = ("prompt_title", "prompt_text", "prompt_question")
    context = tokenizer.sep_token.join([example[name] for name in prompt_fields])
    targets = [example[name] for name in ("content", "wording")]

    encoded = tokenizer(
        example["text"],
        context,
        padding=False,
        truncation=True,
        max_length=config.max_seq_length,
    )

    features = dict(encoded)
    features["labels"] = targets
    return features




def compute_mcrmse(eval_pred):
    """
    Mean columnwise root mean squared error (MCRMSE).

    ``eval_pred`` unpacks into ``(predictions, labels)``. An RMSE is computed
    per column along axis 0; column 0 is reported as content and column 1 as
    wording, and their mean is the MCRMSE.
    https://www.kaggle.com/competitions/commonlit-evaluate-student-summaries/overview/evaluation
    """
    predictions, targets = eval_pred

    squared_error = np.square(predictions - targets)
    per_column = np.sqrt(np.mean(squared_error, axis=0))

    return dict(
        content_rmse=per_column[0],
        wording_rmse=per_column[1],
        mcrmse=np.mean(per_column),
    )


def main():
    """Fine-tune a two-target regression model on one fold and log the best metric.

    Steps:
      1. Parse ``Config`` and ``TrainingArguments`` from the command line and seed.
      2. Log in to Weights & Biases if ``wandb`` is among the report targets
         (best effort: a failed login is reported and training continues).
      3. Load tokenizer and model config from ``config.model_name_or_path``,
         override dropout / label / position settings, and load the weights
         from the local ``./input/pretrain/pretrained_model/`` checkpoint.
      4. Read and merge the prompt and summary CSVs, assign one fold per
         prompt id, and split into train / validation by ``config.fold``.
      5. Tokenize both splits, train with ``Trainer`` using MCRMSE as the
         model-selection metric, then save the model config (with
         ``best_metric`` attached) to ``training_args.output_dir``.
    """
    parser = HfArgumentParser((Config, TrainingArguments))

    config, training_args = parser.parse_args_into_dataclasses()

    set_seed(training_args.seed)

    if "wandb" in training_args.report_to:
        import wandb

        # On Kaggle the API key can instead be fetched with
        # kaggle_secrets.UserSecretsClient().get_secret("wandb") and passed as
        # wandb.login(key=key).
        try:
            wandb.login()
        except Exception as exc:
            # Catch Exception rather than everything, so Ctrl-C / SystemExit
            # still stop the script, and report why the login failed.
            print(f"Could not log in to WandB: {exc}")

    tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)
    model_config = AutoConfig.from_pretrained(config.model_name_or_path)

    model_config.update({
        "hidden_dropout_prob": config.dropout,
        "attention_probs_dropout_prob": config.dropout,
        "num_labels": 2,
        "problem_type": "regression",
        "max_position_embeddings": config.max_position_embeddings,
        "cfg": config.__dict__,
    })

    print(model_config)

    # The weights are loaded from the local pretrained checkpoint, not from
    # config.model_name_or_path (which only supplies tokenizer and config).
    print('use pretrained_model')
    model = AutoModelForSequenceClassification.from_pretrained('./input/pretrain/pretrained_model/', config = model_config)

    # The CSV locations are fixed here; config.data_dir is not used.
    pdf = pd.read_csv("./input/prompts_train.csv")
    sdf = pd.read_csv("./input/summaries_train.csv")

    df = pdf.merge(sdf, on="prompt_id")

    # 4 prompt ids, 4 folds
    id2fold = {
        "814d6b": 0,
        "39c16e": 1,
        "3b9047": 2,
        "ebad26": 3,
    }

    df["fold"] = df["prompt_id"].map(id2fold)

    train_ds = Dataset.from_pandas(df[df["fold"] != config.fold])
    val_ds = Dataset.from_pandas(df[df["fold"] == config.fold])

    # Both splits are tokenized row by row with identical settings.
    map_kwargs = dict(
        batched=False,
        num_proc=config.num_proc,
        fn_kwargs={"tokenizer": tokenizer, "config": config},
    )
    train_ds = train_ds.map(tokenize, **map_kwargs)
    val_ds = val_ds.map(tokenize, **map_kwargs)

    data_collator = DataCollatorWithPadding(
        tokenizer=tokenizer,
        pad_to_multiple_of=16 if training_args.fp16 else None,
    )

    # NOTE(review): these overrides are applied after the arguments were
    # parsed, so they take effect regardless of the command-line flags
    # (bf16 is switched on even when --fp16 was passed) - confirm the
    # intended precision.
    training_args.bf16 =True
    training_args.gradient_accumulation_steps = 1
    training_args.load_best_model_at_end = True
    training_args.greater_is_better = False
    training_args.metric_for_best_model = 'mcrmse'
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_mcrmse,
    )

    trainer.train()

    # Store the best validation metric in the model config saved to output_dir.
    model.config.best_metric = trainer.state.best_metric
    model.config.save_pretrained(training_args.output_dir)

    trainer.log({"eval_best_mcrmse": trainer.state.best_metric})


if __name__ == "__main__":
    main()

Overwriting train.py


In [15]:
from pathlib import Path

seed = 42

fold = 0

output = f"output_fold{fold}_seed{seed}_0309"

!python train.py \
  --model_name_or_path "microsoft/deberta-v3-large" \
  --add_prompt_question True \
  --fold $fold \
  --data_dir "./" \
  --output_dir $output \
  --fp16 \
  --num_train_epochs 4 \
  --dataloader_num_workers 4 \
  --learning_rate 2e-6 \
  --weight_decay 0.01 \
  --warmup_ratio 0 \
  --optim "adamw_torch" \
  --per_device_train_batch_size 2 \
  --per_device_eval_batch_size 2 \
  --evaluation_strategy "steps" \
  --eval_steps 150 \
  --save_strategy "steps" \
  --save_steps 150 \
  --save_total_limit 1 \
  --report_to "wandb" \
  --metric_for_best_model "mcrmse" \
  --greater_is_better False \
  --logging_steps 10 \
  --log_level "error" \
  --disable_tqdm True \
  --ddp_find_unused_parameters False \
  --dropout 0 \
  --seed $seed


output_dir = Path.cwd() / output
# add json files
for json_file in output_dir.glob("checkpoint*/*token*.json"):
    json_file.rename(output_dir/json_file.name)

# model files
for model_file in output_dir.glob("checkpoint*/*model*"):
    model_file.rename(output_dir/model_file.name)

# remove optimizer states and other files
to_delete = str(list(output_dir.glob("checkpoint*"))[0])
!rm -r $to_delete

[34m[1mwandb[0m: Currently logged in as: [33mpeng_sun[0m. Use [1m`wandb login --relogin`[0m to force relogin
DebertaV2Config {
  "_name_or_path": "microsoft/deberta-v3-large",
  "attention_probs_dropout_prob": 0.0,
  "cfg": {
    "add_prompt_question": true,
    "add_prompt_text": false,
    "data_dir": "./",
    "dropout": 0.0,
    "fold": 0,
    "max_position_embeddings": 512,
    "max_seq_length": 512,
    "model_name_or_path": "microsoft/deberta-v3-large",
    "num_proc": 4
  },
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-07,
  "max_position_embeddings": 512,
  "max_relative_positions": -1,
  "model_type": "deberta-v2",
  "norm_rel_ebd": "layer_norm",
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "pad_token_id": 0,
  "pooler_dropout": 0,
  "pooler_hidden_act": "gelu",
  "pooler_hidden_size": 1024,
  "pos_att_type": [
    "p2c",
    "c2p"
  ],
  "posit

{'loss': 0.2375, 'learning_rate': 1.8804025074232926e-06, 'epoch': 0.24}
{'loss': 0.267, 'learning_rate': 1.8787528868360276e-06, 'epoch': 0.24}
{'loss': 0.485, 'learning_rate': 1.8771032662487628e-06, 'epoch': 0.25}
{'eval_loss': 0.6607676148414612, 'eval_content_rmse': 0.626237690448761, 'eval_wording_rmse': 0.9640341401100159, 'eval_mcrmse': 0.7951359152793884, 'eval_runtime': 38.9638, 'eval_samples_per_second': 28.308, 'eval_steps_per_second': 14.167, 'epoch': 0.25}
{'loss': 0.1576, 'learning_rate': 1.8754536456614978e-06, 'epoch': 0.25}
{'loss': 0.2829, 'learning_rate': 1.8738040250742327e-06, 'epoch': 0.25}
{'loss': 0.2829, 'learning_rate': 1.872154404486968e-06, 'epoch': 0.26}
{'loss': 0.3114, 'learning_rate': 1.870504783899703e-06, 'epoch': 0.26}
{'loss': 0.3271, 'learning_rate': 1.8688551633124382e-06, 'epoch': 0.26}
{'loss': 0.3631, 'learning_rate': 1.8672055427251732e-06, 'epoch': 0.27}
{'loss': 0.4725, 'learning_rate': 1.8655559221379081e-06, 'epoch': 0.27}
{'loss': 0.1651,

{'loss': 0.3496, 'learning_rate': 1.728637413394919e-06, 'epoch': 0.54}
{'eval_loss': 0.5881047248840332, 'eval_content_rmse': 0.5621434450149536, 'eval_wording_rmse': 0.927471935749054, 'eval_mcrmse': 0.7448077201843262, 'eval_runtime': 39.3792, 'eval_samples_per_second': 28.01, 'eval_steps_per_second': 14.018, 'epoch': 0.54}
{'loss': 0.21, 'learning_rate': 1.726987792807654e-06, 'epoch': 0.55}
{'loss': 0.2576, 'learning_rate': 1.7253381722203893e-06, 'epoch': 0.55}
{'loss': 0.2346, 'learning_rate': 1.7236885516331243e-06, 'epoch': 0.55}
{'loss': 0.2683, 'learning_rate': 1.7220389310458595e-06, 'epoch': 0.56}
{'loss': 0.325, 'learning_rate': 1.7203893104585945e-06, 'epoch': 0.56}
{'loss': 0.2188, 'learning_rate': 1.7187396898713295e-06, 'epoch': 0.56}
{'loss': 0.1907, 'learning_rate': 1.7170900692840647e-06, 'epoch': 0.57}
{'loss': 0.3022, 'learning_rate': 1.7154404486967997e-06, 'epoch': 0.57}
{'loss': 0.2356, 'learning_rate': 1.713790828109535e-06, 'epoch': 0.57}
{'loss': 0.2391, 'l

{'loss': 0.2594, 'learning_rate': 1.5785219399538106e-06, 'epoch': 0.84}
{'loss': 0.1814, 'learning_rate': 1.5768723193665456e-06, 'epoch': 0.85}
{'loss': 0.1727, 'learning_rate': 1.5752226987792806e-06, 'epoch': 0.85}
{'loss': 0.1017, 'learning_rate': 1.5735730781920158e-06, 'epoch': 0.85}
{'loss': 0.4497, 'learning_rate': 1.5719234576047508e-06, 'epoch': 0.86}
{'loss': 0.1606, 'learning_rate': 1.570273837017486e-06, 'epoch': 0.86}
{'loss': 0.2622, 'learning_rate': 1.568624216430221e-06, 'epoch': 0.86}
{'loss': 0.3792, 'learning_rate': 1.566974595842956e-06, 'epoch': 0.87}
{'loss': 0.2381, 'learning_rate': 1.5653249752556912e-06, 'epoch': 0.87}
{'loss': 0.2322, 'learning_rate': 1.5636753546684262e-06, 'epoch': 0.87}
{'loss': 0.2729, 'learning_rate': 1.5620257340811614e-06, 'epoch': 0.88}
{'loss': 0.2287, 'learning_rate': 1.5603761134938964e-06, 'epoch': 0.88}
{'loss': 0.2374, 'learning_rate': 1.5587264929066314e-06, 'epoch': 0.88}
{'loss': 0.198, 'learning_rate': 1.5570768723193666e-0

{'loss': 0.1173, 'learning_rate': 1.4267568459254372e-06, 'epoch': 1.15}
{'loss': 0.282, 'learning_rate': 1.4251072253381722e-06, 'epoch': 1.15}
{'loss': 0.2493, 'learning_rate': 1.4234576047509072e-06, 'epoch': 1.15}
{'loss': 0.2564, 'learning_rate': 1.4218079841636424e-06, 'epoch': 1.16}
{'loss': 0.3041, 'learning_rate': 1.4201583635763774e-06, 'epoch': 1.16}
{'loss': 0.1983, 'learning_rate': 1.4185087429891126e-06, 'epoch': 1.16}
{'loss': 0.1072, 'learning_rate': 1.4168591224018476e-06, 'epoch': 1.17}
{'loss': 0.2318, 'learning_rate': 1.4152095018145826e-06, 'epoch': 1.17}
{'loss': 0.1517, 'learning_rate': 1.4135598812273178e-06, 'epoch': 1.17}
{'loss': 0.1328, 'learning_rate': 1.4119102606400528e-06, 'epoch': 1.18}
{'loss': 0.2032, 'learning_rate': 1.410260640052788e-06, 'epoch': 1.18}
{'loss': 0.1586, 'learning_rate': 1.408611019465523e-06, 'epoch': 1.18}
{'loss': 0.2862, 'learning_rate': 1.406961398878258e-06, 'epoch': 1.19}
{'eval_loss': 0.5492208003997803, 'eval_content_rmse': 

{'loss': 0.1928, 'learning_rate': 1.2749917518970637e-06, 'epoch': 1.45}
{'loss': 0.3242, 'learning_rate': 1.2733421313097987e-06, 'epoch': 1.45}
{'loss': 0.2454, 'learning_rate': 1.271692510722534e-06, 'epoch': 1.46}
{'loss': 0.1818, 'learning_rate': 1.270042890135269e-06, 'epoch': 1.46}
{'loss': 0.213, 'learning_rate': 1.2683932695480039e-06, 'epoch': 1.46}
{'loss': 0.1885, 'learning_rate': 1.266743648960739e-06, 'epoch': 1.47}
{'loss': 0.1663, 'learning_rate': 1.265094028373474e-06, 'epoch': 1.47}
{'loss': 0.1522, 'learning_rate': 1.2634444077862093e-06, 'epoch': 1.47}
{'loss': 0.1744, 'learning_rate': 1.2617947871989443e-06, 'epoch': 1.48}
{'loss': 0.2179, 'learning_rate': 1.2601451666116793e-06, 'epoch': 1.48}
{'loss': 0.1473, 'learning_rate': 1.2584955460244145e-06, 'epoch': 1.48}
{'eval_loss': 0.6025023460388184, 'eval_content_rmse': 0.5816174149513245, 'eval_wording_rmse': 0.9309810400009155, 'eval_mcrmse': 0.7562992572784424, 'eval_runtime': 37.8909, 'eval_samples_per_second':

{'loss': 0.2157, 'learning_rate': 1.1232266578686902e-06, 'epoch': 1.76}
{'loss': 0.2897, 'learning_rate': 1.1217419993401517e-06, 'epoch': 1.76}
{'loss': 0.2405, 'learning_rate': 1.1200923787528869e-06, 'epoch': 1.76}
{'loss': 0.1497, 'learning_rate': 1.1184427581656219e-06, 'epoch': 1.77}
{'loss': 0.1269, 'learning_rate': 1.116793137578357e-06, 'epoch': 1.77}
{'loss': 0.3692, 'learning_rate': 1.115143516991092e-06, 'epoch': 1.77}
{'loss': 0.1942, 'learning_rate': 1.1134938964038269e-06, 'epoch': 1.77}
{'loss': 0.1449, 'learning_rate': 1.111844275816562e-06, 'epoch': 1.78}
{'loss': 0.14, 'learning_rate': 1.110194655229297e-06, 'epoch': 1.78}
{'eval_loss': 0.5464901924133301, 'eval_content_rmse': 0.5695561170578003, 'eval_wording_rmse': 0.8766903281211853, 'eval_mcrmse': 0.7231231927871704, 'eval_runtime': 37.9035, 'eval_samples_per_second': 29.1, 'eval_steps_per_second': 14.563, 'epoch': 1.78}
{'loss': 0.2178, 'learning_rate': 1.1085450346420323e-06, 'epoch': 1.78}
{'loss': 0.1681, 'l

{'loss': 0.1162, 'learning_rate': 9.716265258990432e-07, 'epoch': 2.06}
{'loss': 0.1421, 'learning_rate': 9.699769053117782e-07, 'epoch': 2.06}
{'loss': 0.1296, 'learning_rate': 9.683272847245132e-07, 'epoch': 2.07}
{'loss': 0.2153, 'learning_rate': 9.666776641372484e-07, 'epoch': 2.07}
{'loss': 0.1067, 'learning_rate': 9.650280435499834e-07, 'epoch': 2.07}
{'loss': 0.2275, 'learning_rate': 9.633784229627186e-07, 'epoch': 2.08}
{'loss': 0.1282, 'learning_rate': 9.617288023754536e-07, 'epoch': 2.08}
{'eval_loss': 0.4799107313156128, 'eval_content_rmse': 0.5621523857116699, 'eval_wording_rmse': 0.802375078201294, 'eval_mcrmse': 0.6822637319564819, 'eval_runtime': 37.5597, 'eval_samples_per_second': 29.367, 'eval_steps_per_second': 14.697, 'epoch': 2.08}
{'loss': 0.1282, 'learning_rate': 9.600791817881886e-07, 'epoch': 2.08}
{'loss': 0.1818, 'learning_rate': 9.584295612009238e-07, 'epoch': 2.09}
{'loss': 0.1671, 'learning_rate': 9.567799406136588e-07, 'epoch': 2.09}
{'loss': 0.1974, 'lear

{'loss': 0.0828, 'learning_rate': 8.182118112834048e-07, 'epoch': 2.37}
{'loss': 0.1906, 'learning_rate': 8.165621906961399e-07, 'epoch': 2.37}
{'loss': 0.1199, 'learning_rate': 8.14912570108875e-07, 'epoch': 2.37}
{'loss': 0.1624, 'learning_rate': 8.132629495216101e-07, 'epoch': 2.38}
{'eval_loss': 0.4165438711643219, 'eval_content_rmse': 0.49685874581336975, 'eval_wording_rmse': 0.7656492590904236, 'eval_mcrmse': 0.6312540173530579, 'eval_runtime': 37.5722, 'eval_samples_per_second': 29.357, 'eval_steps_per_second': 14.692, 'epoch': 2.38}
{'loss': 0.1747, 'learning_rate': 8.11613328934345e-07, 'epoch': 2.38}
{'loss': 0.1983, 'learning_rate': 8.099637083470802e-07, 'epoch': 2.38}
{'loss': 0.109, 'learning_rate': 8.083140877598153e-07, 'epoch': 2.39}
{'loss': 0.2406, 'learning_rate': 8.066644671725504e-07, 'epoch': 2.39}
{'loss': 0.1464, 'learning_rate': 8.050148465852854e-07, 'epoch': 2.39}
{'loss': 0.109, 'learning_rate': 8.033652259980203e-07, 'epoch': 2.4}
{'loss': 0.1351, 'learnin

{'loss': 0.1471, 'learning_rate': 6.64962058726493e-07, 'epoch': 2.67}
{'eval_loss': 0.4226114749908447, 'eval_content_rmse': 0.5139490365982056, 'eval_wording_rmse': 0.7622857093811035, 'eval_mcrmse': 0.6381173729896545, 'eval_runtime': 37.1106, 'eval_samples_per_second': 29.722, 'eval_steps_per_second': 14.874, 'epoch': 2.67}
{'loss': 0.1797, 'learning_rate': 6.63312438139228e-07, 'epoch': 2.68}
{'loss': 0.1638, 'learning_rate': 6.61662817551963e-07, 'epoch': 2.68}
{'loss': 0.1681, 'learning_rate': 6.60013196964698e-07, 'epoch': 2.68}
{'loss': 0.1751, 'learning_rate': 6.583635763774331e-07, 'epoch': 2.69}
{'loss': 0.2271, 'learning_rate': 6.567139557901682e-07, 'epoch': 2.69}
{'loss': 0.1228, 'learning_rate': 6.550643352029032e-07, 'epoch': 2.69}
{'loss': 0.1368, 'learning_rate': 6.534147146156383e-07, 'epoch': 2.7}
{'loss': 0.1591, 'learning_rate': 6.517650940283734e-07, 'epoch': 2.7}
{'loss': 0.1035, 'learning_rate': 6.501154734411085e-07, 'epoch': 2.7}
{'loss': 0.2088, 'learning_r

{'loss': 0.1663, 'learning_rate': 5.148465852853844e-07, 'epoch': 2.97}
{'loss': 0.1201, 'learning_rate': 5.131969646981195e-07, 'epoch': 2.98}
{'loss': 0.1163, 'learning_rate': 5.115473441108544e-07, 'epoch': 2.98}
{'loss': 0.1445, 'learning_rate': 5.098977235235895e-07, 'epoch': 2.98}
{'loss': 0.089, 'learning_rate': 5.082481029363246e-07, 'epoch': 2.99}
{'loss': 0.1181, 'learning_rate': 5.065984823490597e-07, 'epoch': 2.99}
{'loss': 0.1267, 'learning_rate': 5.049488617617948e-07, 'epoch': 2.99}
{'loss': 0.2224, 'learning_rate': 5.032992411745298e-07, 'epoch': 3.0}
{'loss': 0.1332, 'learning_rate': 5.016496205872649e-07, 'epoch': 3.0}
{'loss': 0.1558, 'learning_rate': 5e-07, 'epoch': 3.0}
{'loss': 0.1202, 'learning_rate': 4.983503794127351e-07, 'epoch': 3.01}
{'loss': 0.0953, 'learning_rate': 4.967007588254702e-07, 'epoch': 3.01}
{'loss': 0.0719, 'learning_rate': 4.950511382382052e-07, 'epoch': 3.01}
{'loss': 0.1419, 'learning_rate': 4.934015176509403e-07, 'epoch': 3.02}
{'loss': 0.0

{'loss': 0.2261, 'learning_rate': 3.6143187066974597e-07, 'epoch': 3.28}
{'loss': 0.1188, 'learning_rate': 3.59782250082481e-07, 'epoch': 3.28}
{'loss': 0.087, 'learning_rate': 3.581326294952161e-07, 'epoch': 3.29}
{'loss': 0.0665, 'learning_rate': 3.564830089079511e-07, 'epoch': 3.29}
{'loss': 0.1235, 'learning_rate': 3.548333883206862e-07, 'epoch': 3.29}
{'loss': 0.1159, 'learning_rate': 3.531837677334213e-07, 'epoch': 3.3}
{'loss': 0.0939, 'learning_rate': 3.5153414714615636e-07, 'epoch': 3.3}
{'loss': 0.1138, 'learning_rate': 3.4988452655889146e-07, 'epoch': 3.3}
{'loss': 0.1682, 'learning_rate': 3.482349059716265e-07, 'epoch': 3.31}
{'loss': 0.1007, 'learning_rate': 3.465852853843616e-07, 'epoch': 3.31}
{'loss': 0.1704, 'learning_rate': 3.449356647970966e-07, 'epoch': 3.31}
{'loss': 0.1717, 'learning_rate': 3.432860442098317e-07, 'epoch': 3.32}
{'eval_loss': 0.413396954536438, 'eval_content_rmse': 0.5012569427490234, 'eval_wording_rmse': 0.7586402297019958, 'eval_mcrmse': 0.629948

{'loss': 0.0887, 'learning_rate': 2.0983173870009896e-07, 'epoch': 3.58}
{'loss': 0.099, 'learning_rate': 2.0818211811283404e-07, 'epoch': 3.59}
{'loss': 0.1598, 'learning_rate': 2.065324975255691e-07, 'epoch': 3.59}
{'loss': 0.0924, 'learning_rate': 2.0488287693830418e-07, 'epoch': 3.59}
{'loss': 0.0947, 'learning_rate': 2.0323325635103923e-07, 'epoch': 3.6}
{'loss': 0.1677, 'learning_rate': 2.0158363576377433e-07, 'epoch': 3.6}
{'loss': 0.2286, 'learning_rate': 1.999340151765094e-07, 'epoch': 3.6}
{'loss': 0.2084, 'learning_rate': 1.9828439458924446e-07, 'epoch': 3.61}
{'loss': 0.1512, 'learning_rate': 1.9663477400197953e-07, 'epoch': 3.61}
{'loss': 0.1386, 'learning_rate': 1.949851534147146e-07, 'epoch': 3.61}
{'eval_loss': 0.4195033609867096, 'eval_content_rmse': 0.5284751057624817, 'eval_wording_rmse': 0.7481448650360107, 'eval_mcrmse': 0.6383099555969238, 'eval_runtime': 36.7162, 'eval_samples_per_second': 30.041, 'eval_steps_per_second': 15.034, 'epoch': 3.61}
{'loss': 0.1279, '

{'loss': 0.1764, 'learning_rate': 5.80666446717255e-08, 'epoch': 3.89}
{'loss': 0.1038, 'learning_rate': 5.641702408446057e-08, 'epoch': 3.89}
{'loss': 0.0691, 'learning_rate': 5.4767403497195644e-08, 'epoch': 3.89}
{'loss': 0.1136, 'learning_rate': 5.311778290993071e-08, 'epoch': 3.9}
{'loss': 0.0663, 'learning_rate': 5.1468162322665786e-08, 'epoch': 3.9}
{'loss': 0.1438, 'learning_rate': 4.9818541735400854e-08, 'epoch': 3.9}
{'loss': 0.2977, 'learning_rate': 4.816892114813593e-08, 'epoch': 3.91}
{'loss': 0.1439, 'learning_rate': 4.6519300560871e-08, 'epoch': 3.91}
{'eval_loss': 0.41671204566955566, 'eval_content_rmse': 0.5133568644523621, 'eval_wording_rmse': 0.7549096345901489, 'eval_mcrmse': 0.6341332197189331, 'eval_runtime': 37.343, 'eval_samples_per_second': 29.537, 'eval_steps_per_second': 14.782, 'epoch': 3.91}
{'loss': 0.1213, 'learning_rate': 4.486967997360607e-08, 'epoch': 3.91}
{'loss': 0.0847, 'learning_rate': 4.3220059386341145e-08, 'epoch': 3.92}
{'loss': 0.1039, 'learn

In [16]:
from spellchecker import SpellChecker
import re

def correct_spelling(input_text, spell=None):
    """Return ``input_text`` with each whitespace-separated word spell-corrected.

    Each token is split into a leading run of non-letters, a core that starts
    and ends with a letter, and a trailing run of non-letters. Only the
    lower-cased core is passed to the spell checker; the surrounding
    punctuation is re-attached unchanged.

    A token is kept exactly as written (never dropped) when it contains no
    letters, when the checker has no suggestion (``correction`` returns
    ``None``), or when the suggestion equals the lower-cased core.

    Args:
        input_text: Text to correct. It is split on whitespace, so runs of
            whitespace become single spaces in the result.
        spell: Optional object exposing ``correction(word)``. Defaults to a
            new ``SpellChecker()``; pass one in to reuse it across calls.

    Returns:
        The corrected tokens joined by single spaces.
    """
    if spell is None:
        spell = SpellChecker()

    corrected_words = []
    for token in input_text.split():
        # Find the span between the first and the last alphabetic character.
        start, end = 0, len(token)
        while start < end and not token[start].isalpha():
            start += 1
        while end > start and not token[end - 1].isalpha():
            end -= 1

        core = token[start:end]
        if not core:
            # No letters at all (e.g. "-", "3", "..."): keep it in the text
            # unchanged instead of discarding it.
            corrected_words.append(token)
            continue

        lowered = core.lower()
        suggestion = spell.correction(lowered)
        if suggestion is None or suggestion == lowered:
            # Unknown to the checker, or already correct: keep the original
            # spelling and casing untouched.
            corrected_words.append(token)
            continue

        # Preserve the original capitalization of the first letter.
        if core[0].isupper():
            suggestion = suggestion.capitalize()
        corrected_words.append(token[:start] + suggestion + token[end:])

    return " ".join(corrected_words)

In [17]:
#df.loc[0, 'prompt_text']