In [1]:
import pandas as pd
import json
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sentence_transformers import SentenceTransformer, losses
import numpy as np
import cv2 as cv2
from tqdm.notebook import tqdm
from torch.utils.data import Dataset
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers.evaluation import BinaryClassificationEvaluator
from huggingface_hub import HfFolder
import datasets

  from tqdm.autonotebook import tqdm, trange


In [2]:
comparable_data = pd.read_csv("comparable_data.csv")
comparable_data.head(3)

Unnamed: 0,title,price,cat_1,cat_2,cat_3,caracteristics,img_ref,target,dealer
0,Беговая дорожка UNIXFIT MX-990X,120890,Беговые дорожки,UNIXFIT,UNIXFIT MX-990X,Тип электрическая Уровень базовый Габариты (...,images/begovye_dorozhki/1_begdorozhki_1349.jpeg,begovye_dorozhki,begdorozhki
1,"Беговая дорожка Proxima Ivetta HRC, Арт. PROT-219",139990,Беговые дорожки,Proxima,"Proxima Ivetta HRC, Арт. PROT-219",Тип электрическая Уровень базовый Габариты (...,images/begovye_dorozhki/2_begdorozhki_1463.jpeg,begovye_dorozhki,begdorozhki
2,"Беговая дорожка UNIXFIT MX-990 AC (10,1"" TFT)",159890,Беговые дорожки,UNIXFIT,"UNIXFIT MX-990 AC (10,1"" TFT)",Тип электрическая Уровень полупрофессиональны...,images/begovye_dorozhki/3_begdorozhki_1638.jpeg,begovye_dorozhki,begdorozhki


In [3]:
def prepare_data(path_to_df, info_used = 'title_only'):
    df = pd.read_csv(path_to_df)
    updated_df = pd.DataFrame()
    for _, row in tqdm(df.iterrows(), total=df.shape[0]):
        if info_used == 'title_only':
            text = row["title"]
        elif info_used == 'title_cat':
            text = (
                str(row["title"])
                + " "
                + str(row["cat_1"])
                + " "
                + str(row["cat_2"])
                + " "
                + str(row["cat_3"])
            )
        else:
            text = (
                str(row["title"])
                + " "
                + str(row["cat_1"])
                + " "
                + str(row["cat_2"])
                + " "
                + str(row["cat_3"])
                + " "
                + str(row["caracteristics"])
            )

        label = row["target"]
        new_row = {"label_string": label, "text": text}
        updated_df = pd.concat([updated_df, pd.DataFrame([new_row])], ignore_index=True)
    data_classes = list(set(updated_df["label_string"].tolist()))
    updated_df["label"] = updated_df["label_string"].apply(data_classes.index)
    updated_df = updated_df.drop(["label_string"], axis=1)
    updated_df = updated_df[["label", "text"]]

    return updated_df

In [4]:
df = prepare_data('comparable_data.csv')

  0%|          | 0/13718 [00:00<?, ?it/s]

In [5]:
df.head(5)

Unnamed: 0,label,text
0,6,Беговая дорожка UNIXFIT MX-990X
1,6,"Беговая дорожка Proxima Ivetta HRC, Арт. PROT-219"
2,6,"Беговая дорожка UNIXFIT MX-990 AC (10,1"" TFT)"
3,6,Беговая дорожка Titanium Masters Physiotech TLF
4,6,Беговая дорожка Laufstein Commercial


In [6]:
def create_cos_sim_data(data_df, use_all_combos=False, combos_mult=1024):
    labels_positive = {}
    labels_negative = {}

    # for each label create a set of same label images.
    for i in list(data_df.label.unique()):
        labels_positive[i] = data_df[data_df.label == i]["text"].to_numpy()
    # for each label create a set of image of different label.
    for i in list(data_df.label.unique()):
        labels_negative[i] = data_df[data_df.label != i]["text"].to_numpy()
    cosine_loss_dataset = pd.DataFrame()
    for i, row in tqdm(data_df.iterrows(), total=len(data_df)):
        # label and image of the index for each row in df
        label = row["label"]

        if use_all_combos:
            # probability of same label image == 0.5
            if np.random.randint(0, 2) == 0:
                for i in range(int(len(labels_positive[label]) / combos_mult)):
                    second = labels_positive[label][i]
                    dis = 1.0
                    first = row["text"]
                    new_line = {'sentence1':first,
                                'sentence2': second,
                                'score': dis}
                    cosine_loss_dataset = pd.concat([cosine_loss_dataset, pd.DataFrame([new_line])])
            else:
                for i in range(int(len(labels_positive[label]) / combos_mult)):
                    second = labels_negative[label][
                        np.random.randint(0, len(labels_negative[label]))
                    ]
                    first = row["text"]
                    dis = 0.0
                    new_line = {"sentence1": first, "sentence2": second, "score": dis}
                    cosine_loss_dataset = pd.concat(
                        [cosine_loss_dataset, pd.DataFrame([new_line])]
                    )
        else:
            if np.random.randint(0, 2) == 0:
                second = labels_positive[label][
                    np.random.randint(0, len(labels_positive[label]))
                ]
                dis = 1.0
                first = row["text"]
                new_line = {'sentence1':first,
                            'sentence2': second,
                            'score': dis}
                cosine_loss_dataset = pd.concat([cosine_loss_dataset, pd.DataFrame([new_line])])
            else:
                second = labels_negative[label][
                    np.random.randint(0, len(labels_negative[label]))
                ]
                first = row["text"]
                dis = 0.0
                new_line = {'sentence1':first,
                            'sentence2': second,
                            'score': dis}
                cosine_loss_dataset = pd.concat([cosine_loss_dataset, pd.DataFrame([new_line])])
    return cosine_loss_dataset

In [7]:
cosine_loss_dataset = create_cos_sim_data(df)

  0%|          | 0/13718 [00:00<?, ?it/s]

In [8]:
train, test = train_test_split(cosine_loss_dataset, test_size=0.002, random_state=2012)
train_dataset = datasets.Dataset.from_pandas(train, preserve_index=False)

test_dataset = datasets.Dataset.from_pandas(test, preserve_index=False)

In [9]:
train_dataset

Dataset({
    features: ['sentence1', 'sentence2', 'score'],
    num_rows: 13690
})

In [10]:
model_name = "cointegrated/rubert-tiny2"
max_seq_length = 512
num_epochs = 10
train_batch_size = 32

- **Contrastive loss.** Expects as input two texts and a label of either 0 or 1. If the label == 1, then the distance between the two embeddings is reduced. If the label == 0, then the distance between the embeddings is increased.
- **CoSENT (Cosine Sentence) loss.** It expects that each of the InputExamples consists of a pair of texts and a float valued label, representing the expected similarity score between the pair.
- **CosineSimilarityLoss** expects that the InputExamples consists of two texts and a float label. It computes the vectors u = model(sentence_A) and v = model(sentence_B) and measures the cosine-similarity between the two. By default, it minimizes the following loss: ||input_label - cos_score_transformation(cosine_sim(u,v))||_2.

In [11]:
model = SentenceTransformer(model_name)
# loss = losses.CoSENTLoss(model)
loss = losses.ContrastiveLoss(model)
#loss = losses.CosineSimilarityLoss(model)



In [12]:
binary_acc_evaluator = BinaryClassificationEvaluator(
    sentences1=test_dataset["sentence1"],
    sentences2=test_dataset["sentence2"],
    labels=test_dataset["score"],
    name="cv",
)
results = binary_acc_evaluator(model)
results

  attn_output = torch.nn.functional.scaled_dot_product_attention(


{'cv_cosine_accuracy': 0.6071428571428571,
 'cv_cosine_accuracy_threshold': 0.7078050374984741,
 'cv_cosine_f1': 0.6976744186046512,
 'cv_cosine_f1_threshold': 0.5503430366516113,
 'cv_cosine_precision': 0.5555555555555556,
 'cv_cosine_recall': 0.9375,
 'cv_cosine_ap': 0.6246544669099017,
 'cv_dot_accuracy': 0.6071428571428571,
 'cv_dot_accuracy_threshold': 0.7078051567077637,
 'cv_dot_f1': 0.6976744186046512,
 'cv_dot_f1_threshold': 0.5503429174423218,
 'cv_dot_precision': 0.5555555555555556,
 'cv_dot_recall': 0.9375,
 'cv_dot_ap': 0.6246544669099017,
 'cv_manhattan_accuracy': 0.6071428571428571,
 'cv_manhattan_accuracy_threshold': 10.729952812194824,
 'cv_manhattan_f1': 0.6976744186046512,
 'cv_manhattan_f1_threshold': 13.22286605834961,
 'cv_manhattan_precision': 0.5555555555555556,
 'cv_manhattan_recall': 0.9375,
 'cv_manhattan_ap': 0.6185276802875487,
 'cv_euclidean_accuracy': 0.6071428571428571,
 'cv_euclidean_accuracy_threshold': 0.7644381523132324,
 'cv_euclidean_f1': 0.6976744

In [13]:
output_dir = 'tiny_sent_transformer'

In [14]:
with open("config.json", "r") as f:
    json_config = json.load(f)
TOKEN = json_config["token"]

In [15]:
HfFolder.save_token(TOKEN)

In [16]:
# 5. Define the training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=output_dir,
    # Optional training parameters:
    num_train_epochs=num_epochs,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=train_batch_size,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    # Optional tracking/debugging parameters:
    evaluation_strategy="epoch",
    save_strategy="no",
    hub_token=HfFolder.get_token(),
)



In [17]:
# 6. Create the trainer & start training
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    loss=loss,
    evaluator=binary_acc_evaluator,
)
trainer.train()

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkatya_shakhova[0m ([33mshakhova[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/4280 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 0.012141643092036247, 'eval_cv_cosine_accuracy': 0.9285714285714286, 'eval_cv_cosine_accuracy_threshold': 0.8569946885108948, 'eval_cv_cosine_f1': 0.9411764705882353, 'eval_cv_cosine_f1_threshold': 0.8569946885108948, 'eval_cv_cosine_precision': 0.8888888888888888, 'eval_cv_cosine_recall': 1.0, 'eval_cv_cosine_ap': 0.9365436136575842, 'eval_cv_dot_accuracy': 0.9285714285714286, 'eval_cv_dot_accuracy_threshold': 0.8569947481155396, 'eval_cv_dot_f1': 0.9411764705882353, 'eval_cv_dot_f1_threshold': 0.8569947481155396, 'eval_cv_dot_precision': 0.8888888888888888, 'eval_cv_dot_recall': 1.0, 'eval_cv_dot_ap': 0.9365436136575842, 'eval_cv_manhattan_accuracy': 0.9285714285714286, 'eval_cv_manhattan_accuracy_threshold': 7.489104270935059, 'eval_cv_manhattan_f1': 0.9411764705882353, 'eval_cv_manhattan_f1_threshold': 7.489104270935059, 'eval_cv_manhattan_precision': 0.8888888888888888, 'eval_cv_manhattan_recall': 1.0, 'eval_cv_manhattan_ap': 0.9406644927784634, 'eval_cv_euclidean_ac

  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 0.01045686099678278, 'eval_cv_cosine_accuracy': 0.9285714285714286, 'eval_cv_cosine_accuracy_threshold': 0.8671985864639282, 'eval_cv_cosine_f1': 0.9411764705882353, 'eval_cv_cosine_f1_threshold': 0.8307956457138062, 'eval_cv_cosine_precision': 0.8888888888888888, 'eval_cv_cosine_recall': 1.0, 'eval_cv_cosine_ap': 0.9805183531746032, 'eval_cv_dot_accuracy': 0.9285714285714286, 'eval_cv_dot_accuracy_threshold': 0.8671985268592834, 'eval_cv_dot_f1': 0.9411764705882353, 'eval_cv_dot_f1_threshold': 0.8307956457138062, 'eval_cv_dot_precision': 0.8888888888888888, 'eval_cv_dot_recall': 1.0, 'eval_cv_dot_ap': 0.9805183531746032, 'eval_cv_manhattan_accuracy': 0.9285714285714286, 'eval_cv_manhattan_accuracy_threshold': 6.973616600036621, 'eval_cv_manhattan_f1': 0.9411764705882353, 'eval_cv_manhattan_f1_threshold': 8.109447479248047, 'eval_cv_manhattan_precision': 0.8888888888888888, 'eval_cv_manhattan_recall': 1.0, 'eval_cv_manhattan_ap': 0.9805183531746032, 'eval_cv_euclidean_acc

  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 0.008501644246280193, 'eval_cv_cosine_accuracy': 0.9285714285714286, 'eval_cv_cosine_accuracy_threshold': 0.8695433139801025, 'eval_cv_cosine_f1': 0.9375, 'eval_cv_cosine_f1_threshold': 0.8695433139801025, 'eval_cv_cosine_precision': 0.9375, 'eval_cv_cosine_recall': 0.9375, 'eval_cv_cosine_ap': 0.9820586622807017, 'eval_cv_dot_accuracy': 0.9285714285714286, 'eval_cv_dot_accuracy_threshold': 0.8695434331893921, 'eval_cv_dot_f1': 0.9375, 'eval_cv_dot_f1_threshold': 0.8695434331893921, 'eval_cv_dot_precision': 0.9375, 'eval_cv_dot_recall': 0.9375, 'eval_cv_dot_ap': 0.9820586622807017, 'eval_cv_manhattan_accuracy': 0.9285714285714286, 'eval_cv_manhattan_accuracy_threshold': 6.930814743041992, 'eval_cv_manhattan_f1': 0.9375, 'eval_cv_manhattan_f1_threshold': 6.930814743041992, 'eval_cv_manhattan_precision': 0.9375, 'eval_cv_manhattan_recall': 0.9375, 'eval_cv_manhattan_ap': 0.9775943765664161, 'eval_cv_euclidean_accuracy': 0.9285714285714286, 'eval_cv_euclidean_accuracy_thresh

  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 0.007320765405893326, 'eval_cv_cosine_accuracy': 0.9285714285714286, 'eval_cv_cosine_accuracy_threshold': 0.9186586141586304, 'eval_cv_cosine_f1': 0.9411764705882353, 'eval_cv_cosine_f1_threshold': 0.7524685859680176, 'eval_cv_cosine_precision': 0.8888888888888888, 'eval_cv_cosine_recall': 1.0, 'eval_cv_cosine_ap': 0.9891493055555556, 'eval_cv_dot_accuracy': 0.9285714285714286, 'eval_cv_dot_accuracy_threshold': 0.9186586141586304, 'eval_cv_dot_f1': 0.9411764705882353, 'eval_cv_dot_f1_threshold': 0.7524685859680176, 'eval_cv_dot_precision': 0.8888888888888888, 'eval_cv_dot_recall': 1.0, 'eval_cv_dot_ap': 0.9891493055555556, 'eval_cv_manhattan_accuracy': 0.9285714285714286, 'eval_cv_manhattan_accuracy_threshold': 5.53184700012207, 'eval_cv_manhattan_f1': 0.9411764705882353, 'eval_cv_manhattan_f1_threshold': 9.608419418334961, 'eval_cv_manhattan_precision': 0.8888888888888888, 'eval_cv_manhattan_recall': 1.0, 'eval_cv_manhattan_ap': 0.9891493055555556, 'eval_cv_euclidean_acc

  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 0.006469670683145523, 'eval_cv_cosine_accuracy': 0.9642857142857143, 'eval_cv_cosine_accuracy_threshold': 0.7788258790969849, 'eval_cv_cosine_f1': 0.9696969696969697, 'eval_cv_cosine_f1_threshold': 0.7788258790969849, 'eval_cv_cosine_precision': 0.9411764705882353, 'eval_cv_cosine_recall': 1.0, 'eval_cv_cosine_ap': 0.9924172794117647, 'eval_cv_dot_accuracy': 0.9642857142857143, 'eval_cv_dot_accuracy_threshold': 0.7788258790969849, 'eval_cv_dot_f1': 0.9696969696969697, 'eval_cv_dot_f1_threshold': 0.7788258790969849, 'eval_cv_dot_precision': 0.9411764705882353, 'eval_cv_dot_recall': 1.0, 'eval_cv_dot_ap': 0.9924172794117647, 'eval_cv_manhattan_accuracy': 0.9642857142857143, 'eval_cv_manhattan_accuracy_threshold': 9.076764106750488, 'eval_cv_manhattan_f1': 0.9696969696969697, 'eval_cv_manhattan_f1_threshold': 9.076764106750488, 'eval_cv_manhattan_precision': 0.9411764705882353, 'eval_cv_manhattan_recall': 1.0, 'eval_cv_manhattan_ap': 0.9882506127450981, 'eval_cv_euclidean_ac

  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 0.005286471452564001, 'eval_cv_cosine_accuracy': 0.9642857142857143, 'eval_cv_cosine_accuracy_threshold': 0.8943095207214355, 'eval_cv_cosine_f1': 0.9696969696969697, 'eval_cv_cosine_f1_threshold': 0.7802018523216248, 'eval_cv_cosine_precision': 0.9411764705882353, 'eval_cv_cosine_recall': 1.0, 'eval_cv_cosine_ap': 0.9963235294117647, 'eval_cv_dot_accuracy': 0.9642857142857143, 'eval_cv_dot_accuracy_threshold': 0.894309401512146, 'eval_cv_dot_f1': 0.9696969696969697, 'eval_cv_dot_f1_threshold': 0.7802018523216248, 'eval_cv_dot_precision': 0.9411764705882353, 'eval_cv_dot_recall': 1.0, 'eval_cv_dot_ap': 0.9963235294117647, 'eval_cv_manhattan_accuracy': 0.9642857142857143, 'eval_cv_manhattan_accuracy_threshold': 6.236815452575684, 'eval_cv_manhattan_f1': 0.9696969696969697, 'eval_cv_manhattan_f1_threshold': 9.04854679107666, 'eval_cv_manhattan_precision': 0.9411764705882353, 'eval_cv_manhattan_recall': 1.0, 'eval_cv_manhattan_ap': 0.9963235294117647, 'eval_cv_euclidean_accu

  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 0.005540285725146532, 'eval_cv_cosine_accuracy': 0.9642857142857143, 'eval_cv_cosine_accuracy_threshold': 0.8650959730148315, 'eval_cv_cosine_f1': 0.9696969696969697, 'eval_cv_cosine_f1_threshold': 0.8056637644767761, 'eval_cv_cosine_precision': 0.9411764705882353, 'eval_cv_cosine_recall': 1.0, 'eval_cv_cosine_ap': 0.9963235294117647, 'eval_cv_dot_accuracy': 0.9642857142857143, 'eval_cv_dot_accuracy_threshold': 0.8650959730148315, 'eval_cv_dot_f1': 0.9696969696969697, 'eval_cv_dot_f1_threshold': 0.8056638240814209, 'eval_cv_dot_precision': 0.9411764705882353, 'eval_cv_dot_recall': 1.0, 'eval_cv_dot_ap': 0.9963235294117647, 'eval_cv_manhattan_accuracy': 0.9642857142857143, 'eval_cv_manhattan_accuracy_threshold': 7.103158950805664, 'eval_cv_manhattan_f1': 0.9696969696969697, 'eval_cv_manhattan_f1_threshold': 8.540844917297363, 'eval_cv_manhattan_precision': 0.9411764705882353, 'eval_cv_manhattan_recall': 1.0, 'eval_cv_manhattan_ap': 0.9963235294117647, 'eval_cv_euclidean_ac

  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 0.0041037509217858315, 'eval_cv_cosine_accuracy': 1.0, 'eval_cv_cosine_accuracy_threshold': 0.7808363437652588, 'eval_cv_cosine_f1': 1.0, 'eval_cv_cosine_f1_threshold': 0.7808363437652588, 'eval_cv_cosine_precision': 1.0, 'eval_cv_cosine_recall': 1.0, 'eval_cv_cosine_ap': 1.0, 'eval_cv_dot_accuracy': 1.0, 'eval_cv_dot_accuracy_threshold': 0.7808363437652588, 'eval_cv_dot_f1': 1.0, 'eval_cv_dot_f1_threshold': 0.7808363437652588, 'eval_cv_dot_precision': 1.0, 'eval_cv_dot_recall': 1.0, 'eval_cv_dot_ap': 1.0, 'eval_cv_manhattan_accuracy': 1.0, 'eval_cv_manhattan_accuracy_threshold': 9.01974868774414, 'eval_cv_manhattan_f1': 1.0, 'eval_cv_manhattan_f1_threshold': 9.01974868774414, 'eval_cv_manhattan_precision': 1.0, 'eval_cv_manhattan_recall': 1.0, 'eval_cv_manhattan_ap': 1.0, 'eval_cv_euclidean_accuracy': 1.0, 'eval_cv_euclidean_accuracy_threshold': 0.6620033979415894, 'eval_cv_euclidean_f1': 1.0, 'eval_cv_euclidean_f1_threshold': 0.6620033979415894, 'eval_cv_euclidean_preci

  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 0.003982707858085632, 'eval_cv_cosine_accuracy': 0.9642857142857143, 'eval_cv_cosine_accuracy_threshold': 0.8336308598518372, 'eval_cv_cosine_f1': 0.9696969696969697, 'eval_cv_cosine_f1_threshold': 0.7609013915061951, 'eval_cv_cosine_precision': 0.9411764705882353, 'eval_cv_cosine_recall': 1.0, 'eval_cv_cosine_ap': 0.9963235294117647, 'eval_cv_dot_accuracy': 0.9642857142857143, 'eval_cv_dot_accuracy_threshold': 0.8336309194564819, 'eval_cv_dot_f1': 0.9696969696969697, 'eval_cv_dot_f1_threshold': 0.7609014511108398, 'eval_cv_dot_precision': 0.9411764705882353, 'eval_cv_dot_recall': 1.0, 'eval_cv_dot_ap': 0.9963235294117647, 'eval_cv_manhattan_accuracy': 1.0, 'eval_cv_manhattan_accuracy_threshold': 9.115335464477539, 'eval_cv_manhattan_f1': 1.0, 'eval_cv_manhattan_f1_threshold': 9.115335464477539, 'eval_cv_manhattan_precision': 1.0, 'eval_cv_manhattan_recall': 1.0, 'eval_cv_manhattan_ap': 1.0, 'eval_cv_euclidean_accuracy': 0.9642857142857143, 'eval_cv_euclidean_accuracy_thr

  0%|          | 0/1 [00:00<?, ?it/s]

{'eval_loss': 0.003640834940597415, 'eval_cv_cosine_accuracy': 1.0, 'eval_cv_cosine_accuracy_threshold': 0.7653387784957886, 'eval_cv_cosine_f1': 1.0, 'eval_cv_cosine_f1_threshold': 0.7653387784957886, 'eval_cv_cosine_precision': 1.0, 'eval_cv_cosine_recall': 1.0, 'eval_cv_cosine_ap': 1.0, 'eval_cv_dot_accuracy': 1.0, 'eval_cv_dot_accuracy_threshold': 0.7653387784957886, 'eval_cv_dot_f1': 1.0, 'eval_cv_dot_f1_threshold': 0.7653387784957886, 'eval_cv_dot_precision': 1.0, 'eval_cv_dot_recall': 1.0, 'eval_cv_dot_ap': 1.0, 'eval_cv_manhattan_accuracy': 1.0, 'eval_cv_manhattan_accuracy_threshold': 9.330949783325195, 'eval_cv_manhattan_f1': 1.0, 'eval_cv_manhattan_f1_threshold': 9.330949783325195, 'eval_cv_manhattan_precision': 1.0, 'eval_cv_manhattan_recall': 1.0, 'eval_cv_manhattan_ap': 1.0, 'eval_cv_euclidean_accuracy': 1.0, 'eval_cv_euclidean_accuracy_threshold': 0.6849288940429688, 'eval_cv_euclidean_f1': 1.0, 'eval_cv_euclidean_f1_threshold': 0.6849288940429688, 'eval_cv_euclidean_prec

TrainOutput(global_step=4280, training_loss=0.0033133900332673687, metrics={'train_runtime': 107.0996, 'train_samples_per_second': 1278.249, 'train_steps_per_second': 39.963, 'total_flos': 0.0, 'train_loss': 0.0033133900332673687, 'epoch': 10.0})

In [18]:
trainer.create_model_card()

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

In [19]:
trainer.push_to_hub()

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/117M [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.50k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/Shakhovak/tiny_sent_transformer/commit/92c1369f00979ecc144dced4b1788e551ac7e3df', commit_message='End of training', commit_description='', oid='92c1369f00979ecc144dced4b1788e551ac7e3df', pr_url=None, pr_revision=None, pr_num=None)

In [20]:
model.save_pretrained(output_dir)