In [None]:
! pip install datasets > /dev/null

In [None]:
import warnings
warnings.filterwarnings("ignore")

import pandas as pd

from transformers import AutoModelForMaskedLM, AutoTokenizer,TrainingArguments, Trainer, LineByLineTextDataset, DataCollatorForLanguageModeling

In [None]:
# Build the pretraining corpus: a plain-text file holding one excerpt per line.
train = pd.read_csv("../input/commonlitreadabilityprize/train.csv")
test = pd.read_csv("../input/commonlitreadabilityprize/test.csv")

# Pool train and test rows; only the `excerpt` text column is used below.
data = pd.concat([train, test], axis=0)

# Flatten each excerpt onto a single line by collapsing every run of
# whitespace (including embedded newlines) into one space. Simply deleting
# "\n" would fuse the last word of a line with the first word of the next
# ("end\nof" -> "endof") and corrupt the text the model is trained on.
data.loc[:, "excerpt"] = data.excerpt.apply(lambda x: " ".join(x.split()))
excerpts = "\n".join(data.excerpt.values.tolist())

# Write with an explicit encoding so the bytes on disk do not depend on the
# platform's default locale.
with open("./pretrain_data.txt", "w", encoding="utf-8") as f:
    f.write(excerpts)

print("Pretrain data created")

In [None]:
# Load the pretrained masked-language model and its tokenizer.
checkpoint = "roberta-base"
model = AutoModelForMaskedLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


# Tokenize the corpus file written earlier: one example per line,
# with block_size=128 as the per-example length limit.
dtrain = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path="./pretrain_data.txt",
    block_size=128
)

# The evaluation set is the very same corpus with the very same settings, so
# reuse the already-tokenized dataset instead of reading and tokenizing the
# file a second time.
# NOTE(review): because evaluation runs on the training text, `eval_loss` is
# not a held-out metric. Point `dvalid` at a separate file if a real
# validation signal is wanted.
dvalid = dtrain

# Collator configured for masked-language modelling (mlm=True) with a
# masking probability of 0.15.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15
)

In [None]:
# Training configuration: 5 epochs, evaluation every 200 optimizer steps, and
# the checkpoint with the lowest eval loss reloaded when training finishes.
training_args = TrainingArguments(
    output_dir=f"./{checkpoint}_chk", #select model path for checkpoint
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    evaluation_strategy= 'steps',
    eval_steps=200,
    # Checkpoints must be written on the same steps that are evaluated,
    # otherwise there is no saved checkpoint with a recorded metric for
    # `load_best_model_at_end` to pick. The library default (500) is not a
    # multiple of `eval_steps`, so the cadence is set explicitly.
    save_steps=200,
    save_total_limit=2,
    metric_for_best_model='eval_loss',
    greater_is_better=False,  # lower loss is better
    load_best_model_at_end =True,
    prediction_loss_only=True,
    report_to = "none"  # no external experiment-tracking integrations
)

# Wire the model, datasets and MLM collator into the Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dtrain,
    eval_dataset=dvalid
)

In [None]:
# Run the training loop, then persist the resulting model and the tokenizer
# into two separate directories named after the base checkpoint.
print("Starting with training...")
trainer.train()

model_dir = f"./clrp-itpt-model-{checkpoint}"
tokenizer_dir = f"./clrp-itpt-tokenizer-{checkpoint}"

trainer.save_model(model_dir)
tokenizer.save_pretrained(tokenizer_dir)