In [14]:
from datasets import load_dataset

In [15]:
test_data_files = {
    "arxiv": "arxiv_test.csv",
    "blogs": "blogs_test.csv",
    "british": "british_test.csv",
    "darkreddit": "darkreddit_test.csv",
    "imdb": "imdb_test.csv",
    "pan11": "pan11_test.csv",
    "pan13": "pan13_test.csv",
    "pan14": "pan14_test.csv",
    "pan15": "pan15_test.csv",
    "pan20": "pan20_test.csv",
    "reuters": "reuters_test.csv",
    "victorian": "victorian_test.csv",
}

In [None]:
test_datasets = {name: load_dataset('swan07/authorship-verification', data_files={"test": file}, split='test') for name, file in test_data_files.items()}
train_dataset = load_dataset("swan07/authorship-verification", data_files="*_train.csv", download_mode="force_redownload")
val_dataset = load_dataset("swan07/authorship-verification", data_files="*_val.csv", download_mode="force_redownload")

In [19]:
train_subset = train_dataset
test_subset = test_datasets
val_subset = val_dataset

In [20]:
train_subset = train_subset.rename_column("same", "score")
val_subset = val_subset.rename_column("same", "score")

for name, dataset in test_subset.items():
    test_subset[name] = dataset.rename_column("same", "score")  


ValueError: Original column name same not in the dataset. Current columns in the dataset: ['text1', 'text2', 'score']

In [21]:
train_subset = train_subset['train']

In [22]:
val_subset = val_subset['train']

In [41]:
from sentence_transformers import SentenceTransformerTrainingArguments

args = SentenceTransformerTrainingArguments(
    output_dir="minilm",
    num_train_epochs=2,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    warmup_ratio=0.1,
    fp16=True,
    logging_dir='./logs', 
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="minilm-new",
    resume_from_checkpoint=True,  # Add this line to resume from the last checkpoint

)

In [26]:
import torch
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import ContrastiveLoss
from sentence_transformers.evaluation import BinaryClassificationEvaluator

In [27]:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L12-v2')




In [28]:
from sentence_transformers.losses import ContrastiveLoss

loss = ContrastiveLoss(model)


In [None]:
dev_evaluator = BinaryClassificationEvaluator(
    sentences1=val_subset["text1"],
    sentences2=val_subset["text2"],
    labels=val_subset["score"],
    name="all-nli-dev",
    show_progress_bar=True,
)
dev_evaluator(model)

In [None]:
# 7. Create a trainer & train
#epoch 2
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_subset,
    eval_dataset=val_subset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train(resume_from_checkpoint=True)

In [35]:
test_subset

{'arxiv': Dataset({
     features: ['text1', 'text2', 'score'],
     num_rows: 106
 }),
 'blogs': Dataset({
     features: ['text1', 'text2', 'score'],
     num_rows: 8840
 }),
 'british': Dataset({
     features: ['text1', 'text2', 'score'],
     num_rows: 173
 }),
 'darkreddit': Dataset({
     features: ['text1', 'text2', 'score'],
     num_rows: 412
 }),
 'imdb': Dataset({
     features: ['text1', 'text2', 'score'],
     num_rows: 4648
 }),
 'pan11': Dataset({
     features: ['text1', 'text2', 'score'],
     num_rows: 698
 }),
 'pan13': Dataset({
     features: ['text1', 'text2', 'score'],
     num_rows: 18
 }),
 'pan14': Dataset({
     features: ['text1', 'text2', 'score'],
     num_rows: 400
 }),
 'pan15': Dataset({
     features: ['text1', 'text2', 'score'],
     num_rows: 200
 }),
 'pan20': Dataset({
     features: ['score', 'text1', 'text2'],
     num_rows: 13704
 }),
 'reuters': Dataset({
     features: ['text1', 'text2', 'score'],
     num_rows: 181
 }),
 'victorian': Dataset

In [36]:
for name, test_dataset in test_datasets.items():
    test_evaluator = BinaryClassificationEvaluator(
        sentences1=test_dataset["text1"],
        sentences2=test_dataset["text2"],
        labels=test_dataset["score"],
        name=f"{name}-test",
        show_progress_bar=True,
    )
    evaluation_result = test_evaluator(model)
    print(f"Evaluation result for {name}: {evaluation_result}")

Batches:   0%|          | 0/7 [00:00<?, ?it/s]

Evaluation result for arxiv: {'arxiv-test_cosine_accuracy': 0.8301886792452831, 'arxiv-test_cosine_accuracy_threshold': 0.6947098970413208, 'arxiv-test_cosine_f1': 0.859375, 'arxiv-test_cosine_f1_threshold': 0.6901332139968872, 'arxiv-test_cosine_precision': 0.8088235294117647, 'arxiv-test_cosine_recall': 0.9166666666666666, 'arxiv-test_cosine_ap': 0.9059438222711578, 'arxiv-test_dot_accuracy': 0.8301886792452831, 'arxiv-test_dot_accuracy_threshold': 0.6947098970413208, 'arxiv-test_dot_f1': 0.859375, 'arxiv-test_dot_f1_threshold': 0.6901331543922424, 'arxiv-test_dot_precision': 0.8088235294117647, 'arxiv-test_dot_recall': 0.9166666666666666, 'arxiv-test_dot_ap': 0.9059438222711578, 'arxiv-test_manhattan_accuracy': 0.8301886792452831, 'arxiv-test_manhattan_accuracy_threshold': 11.346436500549316, 'arxiv-test_manhattan_f1': 0.8524590163934426, 'arxiv-test_manhattan_f1_threshold': 11.891170501708984, 'arxiv-test_manhattan_precision': 0.8387096774193549, 'arxiv-test_manhattan_recall': 0.86

Batches:   0%|          | 0/550 [00:00<?, ?it/s]

Evaluation result for blogs: {'blogs-test_cosine_accuracy': 0.6916289592760181, 'blogs-test_cosine_accuracy_threshold': 0.7527225613594055, 'blogs-test_cosine_f1': 0.713355944632872, 'blogs-test_cosine_f1_threshold': 0.6621318459510803, 'blogs-test_cosine_precision': 0.5979714153988013, 'blogs-test_cosine_recall': 0.8839164016356201, 'blogs-test_cosine_ap': 0.7724042084291309, 'blogs-test_dot_accuracy': 0.6916289592760181, 'blogs-test_dot_accuracy_threshold': 0.7527225613594055, 'blogs-test_dot_f1': 0.713355944632872, 'blogs-test_dot_f1_threshold': 0.6621319055557251, 'blogs-test_dot_precision': 0.5979714153988013, 'blogs-test_dot_recall': 0.8839164016356201, 'blogs-test_dot_ap': 0.7724044139768769, 'blogs-test_manhattan_accuracy': 0.6921945701357466, 'blogs-test_manhattan_accuracy_threshold': 11.203267097473145, 'blogs-test_manhattan_f1': 0.7121798415330755, 'blogs-test_manhattan_f1_threshold': 12.767942428588867, 'blogs-test_manhattan_precision': 0.5990390576565406, 'blogs-test_manha

Batches:   0%|          | 0/11 [00:00<?, ?it/s]

Evaluation result for british: {'british-test_cosine_accuracy': 0.7687861271676301, 'british-test_cosine_accuracy_threshold': 0.7162458300590515, 'british-test_cosine_f1': 0.8053097345132744, 'british-test_cosine_f1_threshold': 0.6518328189849854, 'british-test_cosine_precision': 0.7109375, 'british-test_cosine_recall': 0.9285714285714286, 'british-test_cosine_ap': 0.8657974477850441, 'british-test_dot_accuracy': 0.7687861271676301, 'british-test_dot_accuracy_threshold': 0.7162458896636963, 'british-test_dot_f1': 0.8053097345132744, 'british-test_dot_f1_threshold': 0.6518328785896301, 'british-test_dot_precision': 0.7109375, 'british-test_dot_recall': 0.9285714285714286, 'british-test_dot_ap': 0.8657974477850441, 'british-test_manhattan_accuracy': 0.7572254335260116, 'british-test_manhattan_accuracy_threshold': 11.92885971069336, 'british-test_manhattan_f1': 0.8071748878923766, 'british-test_manhattan_f1_threshold': 12.846293449401855, 'british-test_manhattan_precision': 0.72, 'british

Batches:   0%|          | 0/19 [00:00<?, ?it/s]

Evaluation result for darkreddit: {'darkreddit-test_cosine_accuracy': 0.6529126213592233, 'darkreddit-test_cosine_accuracy_threshold': 0.8751990795135498, 'darkreddit-test_cosine_f1': 0.6983546617915904, 'darkreddit-test_cosine_f1_threshold': 0.7424333691596985, 'darkreddit-test_cosine_precision': 0.5601173020527859, 'darkreddit-test_cosine_recall': 0.9271844660194175, 'darkreddit-test_cosine_ap': 0.7347184664311759, 'darkreddit-test_dot_accuracy': 0.6529126213592233, 'darkreddit-test_dot_accuracy_threshold': 0.8751991987228394, 'darkreddit-test_dot_f1': 0.6983546617915904, 'darkreddit-test_dot_f1_threshold': 0.7424334287643433, 'darkreddit-test_dot_precision': 0.5601173020527859, 'darkreddit-test_dot_recall': 0.9271844660194175, 'darkreddit-test_dot_ap': 0.7347184664311759, 'darkreddit-test_manhattan_accuracy': 0.6601941747572816, 'darkreddit-test_manhattan_accuracy_threshold': 8.842761993408203, 'darkreddit-test_manhattan_f1': 0.6998087954110899, 'darkreddit-test_manhattan_f1_thresho

Batches:   0%|          | 0/291 [00:00<?, ?it/s]

Evaluation result for imdb: {'imdb-test_cosine_accuracy': 0.7480636833046471, 'imdb-test_cosine_accuracy_threshold': 0.7538917064666748, 'imdb-test_cosine_f1': 0.7618308766485649, 'imdb-test_cosine_f1_threshold': 0.7126221656799316, 'imdb-test_cosine_precision': 0.6910626319493315, 'imdb-test_cosine_recall': 0.8487467588591184, 'imdb-test_cosine_ap': 0.8109513525642077, 'imdb-test_dot_accuracy': 0.7480636833046471, 'imdb-test_dot_accuracy_threshold': 0.7538917064666748, 'imdb-test_dot_f1': 0.7618308766485649, 'imdb-test_dot_f1_threshold': 0.7126221656799316, 'imdb-test_dot_precision': 0.6910626319493315, 'imdb-test_dot_recall': 0.8487467588591184, 'imdb-test_dot_ap': 0.8109513525642077, 'imdb-test_manhattan_accuracy': 0.7480636833046471, 'imdb-test_manhattan_accuracy_threshold': 10.824899673461914, 'imdb-test_manhattan_f1': 0.7636697247706423, 'imdb-test_manhattan_f1_threshold': 12.343918800354004, 'imdb-test_manhattan_precision': 0.6635841836734694, 'imdb-test_manhattan_recall': 0.899

Batches:   0%|          | 0/44 [00:00<?, ?it/s]

Evaluation result for pan11: {'pan11-test_cosine_accuracy': 0.6146131805157593, 'pan11-test_cosine_accuracy_threshold': 0.7716286182403564, 'pan11-test_cosine_f1': 0.692144373673036, 'pan11-test_cosine_f1_threshold': 0.6748872995376587, 'pan11-test_cosine_precision': 0.5659722222222222, 'pan11-test_cosine_recall': 0.8907103825136612, 'pan11-test_cosine_ap': 0.6618253164528178, 'pan11-test_dot_accuracy': 0.6146131805157593, 'pan11-test_dot_accuracy_threshold': 0.7716286778450012, 'pan11-test_dot_f1': 0.692144373673036, 'pan11-test_dot_f1_threshold': 0.6748872995376587, 'pan11-test_dot_precision': 0.5659722222222222, 'pan11-test_dot_recall': 0.8907103825136612, 'pan11-test_dot_ap': 0.6618253164528178, 'pan11-test_manhattan_accuracy': 0.6146131805157593, 'pan11-test_manhattan_accuracy_threshold': 10.506285667419434, 'pan11-test_manhattan_f1': 0.692063492063492, 'pan11-test_manhattan_f1_threshold': 12.561086654663086, 'pan11-test_manhattan_precision': 0.5647668393782384, 'pan11-test_manhat

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Evaluation result for pan13: {'pan13-test_cosine_accuracy': 0.6111111111111112, 'pan13-test_cosine_accuracy_threshold': 0.9473857283592224, 'pan13-test_cosine_f1': 0.6, 'pan13-test_cosine_f1_threshold': 0.8083878755569458, 'pan13-test_cosine_precision': 0.5, 'pan13-test_cosine_recall': 0.75, 'pan13-test_cosine_ap': 0.5628945707070707, 'pan13-test_dot_accuracy': 0.6111111111111112, 'pan13-test_dot_accuracy_threshold': 0.9473857879638672, 'pan13-test_dot_f1': 0.6, 'pan13-test_dot_f1_threshold': 0.8083878755569458, 'pan13-test_dot_precision': 0.5, 'pan13-test_dot_recall': 0.75, 'pan13-test_dot_ap': 0.5628945707070707, 'pan13-test_manhattan_accuracy': 0.6111111111111112, 'pan13-test_manhattan_accuracy_threshold': 5.042365074157715, 'pan13-test_manhattan_f1': 0.6, 'pan13-test_manhattan_f1_threshold': 9.648418426513672, 'pan13-test_manhattan_precision': 0.5, 'pan13-test_manhattan_recall': 0.75, 'pan13-test_manhattan_ap': 0.5420612373737375, 'pan13-test_euclidean_accuracy': 0.6111111111111112

Batches:   0%|          | 0/13 [00:00<?, ?it/s]

Evaluation result for pan14: {'pan14-test_cosine_accuracy': 0.625, 'pan14-test_cosine_accuracy_threshold': 0.7652945518493652, 'pan14-test_cosine_f1': 0.7008849557522123, 'pan14-test_cosine_f1_threshold': 0.39485129714012146, 'pan14-test_cosine_precision': 0.5424657534246575, 'pan14-test_cosine_recall': 0.99, 'pan14-test_cosine_ap': 0.5812960852351552, 'pan14-test_dot_accuracy': 0.625, 'pan14-test_dot_accuracy_threshold': 0.76529461145401, 'pan14-test_dot_f1': 0.7008849557522123, 'pan14-test_dot_f1_threshold': 0.39485129714012146, 'pan14-test_dot_precision': 0.5424657534246575, 'pan14-test_dot_recall': 0.99, 'pan14-test_dot_ap': 0.5812960852351552, 'pan14-test_manhattan_accuracy': 0.6225, 'pan14-test_manhattan_accuracy_threshold': 10.523223876953125, 'pan14-test_manhattan_f1': 0.7008849557522123, 'pan14-test_manhattan_f1_threshold': 17.252708435058594, 'pan14-test_manhattan_precision': 0.5424657534246575, 'pan14-test_manhattan_recall': 0.99, 'pan14-test_manhattan_ap': 0.583800424339302

Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Evaluation result for pan15: {'pan15-test_cosine_accuracy': 0.6, 'pan15-test_cosine_accuracy_threshold': 0.8211225271224976, 'pan15-test_cosine_f1': 0.695364238410596, 'pan15-test_cosine_f1_threshold': 0.5785237550735474, 'pan15-test_cosine_precision': 0.5357142857142857, 'pan15-test_cosine_recall': 0.9905660377358491, 'pan15-test_cosine_ap': 0.5971832064782955, 'pan15-test_dot_accuracy': 0.6, 'pan15-test_dot_accuracy_threshold': 0.8211225271224976, 'pan15-test_dot_f1': 0.695364238410596, 'pan15-test_dot_f1_threshold': 0.5785236358642578, 'pan15-test_dot_precision': 0.5357142857142857, 'pan15-test_dot_recall': 0.9905660377358491, 'pan15-test_dot_ap': 0.5971832064782955, 'pan15-test_manhattan_accuracy': 0.59, 'pan15-test_manhattan_accuracy_threshold': 9.681791305541992, 'pan15-test_manhattan_f1': 0.6976744186046512, 'pan15-test_manhattan_f1_threshold': 14.060454368591309, 'pan15-test_manhattan_precision': 0.5384615384615384, 'pan15-test_manhattan_recall': 0.9905660377358491, 'pan15-test

Batches:   0%|          | 0/755 [00:00<?, ?it/s]

KeyboardInterrupt: 