In [1]:
import pandas as pd
from torch.utils.data import Dataset
import numpy as np

In [2]:
# Read only the first 100 000 rows of the parsed corpus (presumably to keep
# experimentation fast — drop `nrows` for a full run).
df = pd.read_csv(filepath_or_buffer="parsed_metalurgs.csv", nrows=10**5)
df.shape

(100000, 3)

In [15]:
class SentencesDataset(Dataset):
    """Map-style dataset of randomly sampled sentence pairs labelled by proximity.

    Odd indices yield a positive pair (label 1): two *distinct* rows that share
    the same ``path`` and lie at most ``close_sent_dist`` rows apart. Even
    indices yield a negative pair (label 0): two rows strictly more than
    ``far_sent_dist`` rows apart (``path`` is not checked). Pairs are re-sampled
    on every access, so the index only selects the label, not a fixed pair.

    Args:
        dataset: DataFrame with at least a ``line`` column (sentence text) and a
            ``path`` column (source identifier). Positive pairs can only be found
            where rows with equal ``path`` lie within ``close_sent_dist`` of each
            other, so rows of one document are expected to be contiguous —
            TODO confirm for the input CSV.
        data_size: Length reported by ``__len__``; defaults to ``len(dataset)``.
            It may exceed ``len(dataset)`` because samples are random.
        close_sent_dist: Maximum row distance for positive pairs (>= 1).
        far_sent_dist: Negative pairs are more than this many rows apart (>= 0).
        rng: Optional random source exposing NumPy's ``randint(low, high=None)``
            (e.g. ``np.random.RandomState(seed)``). ``None`` uses the global
            ``np.random`` state. NOTE: with a multi-worker ``DataLoader`` the
            NumPy state may be duplicated across workers; reseed it in a
            ``worker_init_fn``.

    Raises:
        ValueError: if ``close_sent_dist < 1`` or ``far_sent_dist < 0``.
    """

    def __init__(self, dataset, data_size=None, close_sent_dist=3, far_sent_dist=20, rng=None):
        # close_sent_dist == 0 would make get_positive_example loop forever;
        # a negative far_sent_dist could produce an out-of-range row index.
        if close_sent_dist < 1:
            raise ValueError(f"close_sent_dist must be >= 1, got {close_sent_dist}")
        if far_sent_dist < 0:
            raise ValueError(f"far_sent_dist must be >= 0, got {far_sent_dist}")
        self.dataset = dataset
        self.data_size = data_size if data_size is not None else len(self.dataset)
        self.close_sent_dist = close_sent_dist
        self.far_sent_dist = far_sent_dist
        # Store only the user-supplied generator (or None); the numpy module
        # itself is looked up at call time so the dataset stays picklable.
        self.rng = rng

    def __len__(self):
        return self.data_size

    def _randint(self, *args):
        """Draw an integer with ``randint`` semantics (exclusive upper bound)."""
        source = self.rng if self.rng is not None else np.random
        return source.randint(*args)

    def __getitem__(self, idx: int):
        """Return ``{"sentence1", "sentence2", "label"}`` for a freshly sampled pair.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``. This
                also lets plain iteration (``for ex in ds``) terminate.
        """
        if not -self.data_size <= idx < self.data_size:
            raise IndexError(f"index {idx} out of range for dataset of size {self.data_size}")

        if idx % 2:
            first, second = self.get_positive_example()
            label = 1
        else:
            first, second = self.get_negative_example()
            label = 0

        lines = self.dataset.line
        return {
            "sentence1": lines.iat[first],
            "sentence2": lines.iat[second],
            "label": label,
        }

    def get_positive_example(self):
        """Sample ``(first_id, second_id)``: distinct rows, same ``path``, close together.

        Rows are re-drawn until one has at least one neighbour within
        ``close_sent_dist`` rows that shares its ``path``. NOTE: this never
        returns if no such pair exists in ``self.dataset``.
        """
        paths = self.dataset.path
        n_rows = len(self.dataset)
        while True:
            first_id = self._randint(n_rows)
            first_path = paths.iat[first_id]
            low = max(first_id - self.close_sent_dist, 0)
            high = min(first_id + self.close_sent_dist, n_rows - 1)  # inclusive
            # Filter by path explicitly (not just trimming the window edges) and
            # exclude first_id itself so a sentence is never paired with itself.
            candidates = [
                row for row in range(low, high + 1)
                if row != first_id and paths.iat[row] == first_path
            ]
            if candidates:
                return first_id, candidates[self._randint(len(candidates))]

    def get_negative_example(self):
        """Sample ``(first_id, second_id)`` with ``abs(first_id - second_id) > far_sent_dist``.

        ``second_id`` is drawn from the row range with the window
        ``[first_id - far_sent_dist, first_id + far_sent_dist]`` cut out, by
        drawing from a shortened range and shifting values past the window.

        Raises:
            ValueError: if the dataset has no more than ``2 * far_sent_dist + 1`` rows.
        """
        n_rows = len(self.dataset)
        window = 2 * self.far_sent_dist + 1  # number of excluded positions
        if n_rows <= window:
            raise ValueError(
                f"need more than {window} rows for far_sent_dist={self.far_sent_dist}, got {n_rows}"
            )
        first_id = self._randint(n_rows)
        second_id = self._randint(n_rows - window)
        if second_id >= first_id - self.far_sent_dist:
            second_id += window
        return first_id, second_id

In [16]:
# Wrap the loaded frame with the default sampling distances.
test = SentencesDataset(dataset=df)

In [21]:
# Odd index -> positive pair (label 1); the pair itself is re-sampled on every call.
test[1]

{'sentence1': '"Переходные опоры ЛЭП  кВ"',
 'sentence2': '"Угловые опоры ЛЭП  кВ"',
 'label': 1}