# Process the original dataset

## Load the original dataset

In [None]:
import datasets

In [None]:
# Load the raw Wikipedia dump (a JSON Lines file, judging by the extension).
wikipedia = datasets.Dataset.from_json(path_or_paths="/dataset/wikipedia.jsonl")

wikipedia

## sample dataset

In [5]:
seed = 2023  # reused by wikipedia.shuffle(...) and random.seed(...) below
# NOTE: despite its name this is the exclusive *end* index of the slice taken
# from the shuffled dump: 250_000 rows are skipped, then 50_000 rows are kept.
sample_size = 50_000 + 250_000

In [None]:
# Shuffle deterministically, then keep rows [250_000, sample_size) of the
# shuffled dump (the first 250_000 shuffled rows are skipped).
dataset = (
    wikipedia
    .shuffle(seed=seed)
    .select(indices=range(250_000, sample_size))
)

dataset

## Split articles into sentences

In [43]:
open_list = ["[", "{", "("]
close_list = ["]", "}", ")"]

def is_balance(text):
    """Return True if every bracket in *text* is properly matched and nested.

    Only the bracket pairs listed in ``open_list`` / ``close_list`` are
    checked; every other character is ignored. Always returns a bool.
    """
    # Map each closing bracket to the opening bracket it must close.
    matching = dict(zip(close_list, open_list))
    stack = []
    for c in text:
        if c in open_list:
            stack.append(c)
        elif c in matching:
            # A closer with nothing open, or closing the wrong kind of
            # bracket, means the text is unbalanced.
            if not stack or stack.pop() != matching[c]:
                return False
    # Leftover openers were never closed. (Previously this case fell off the
    # end of the function and returned None instead of False.)
    return not stack

In [44]:
import string

punctuation_tuple = tuple(string.punctuation)

def is_sentence_clean(sent: str, min_length: int = 5, max_length: int = 128) -> bool:
    """Return True if *sent* looks like a clean, standalone sentence.

    After stripping surrounding whitespace, the sentence must:
      * contain between ``min_length`` and ``max_length`` whitespace-separated
        words (inclusive),
      * not start with a character from ``string.punctuation``,
      * contain no internal newline,
      * have balanced brackets according to ``is_balance``.
    """
    sent = sent.strip()

    length = len(sent.split())
    if not min_length <= length <= max_length:
        return False
    # Reuse the module-level tuple instead of rebuilding it on every call.
    if sent.startswith(punctuation_tuple):
        return False
    if "\n" in sent:
        return False
    # bool() guards against is_balance returning a non-bool value.
    return bool(is_balance(sent))

In [45]:
from tqdm import tqdm
import spacy

In [None]:
# Ask spaCy to use GPU 4 (the return value is not checked here), then load
# the transformer-based English pipeline used for sentence splitting.
spacy.prefer_gpu(gpu_id=4)
nlp = spacy.load(name="en_core_web_trf")

In [None]:
# Split every sampled article into sentences and keep only the ones accepted
# by is_sentence_clean (8-128 whitespace-separated words, etc.).
# nlp.pipe processes the articles in batches, which is much faster than
# calling nlp() once per article; lower batch_size if GPU memory runs out.
sentence_list = []
length_list = []
texts = dataset["text"]
for doc in tqdm(nlp.pipe(texts, batch_size=16), total=len(texts)):
    for span in doc.sents:
        # Store the stripped text; is_sentence_clean strips only its local copy.
        sent = span.text.strip()
        if is_sentence_clean(sent, 8, 128):
            sentence_list.append(sent)
            length_list.append(len(sent.split()))

In [None]:
# Number of sentences that passed is_sentence_clean.
len(sentence_list)

In [None]:
# Mean sentence length in whitespace-separated words. Shows NaN instead of
# raising ZeroDivisionError when no sentence survived the filter.
sum(length_list) / len(length_list) if length_list else float("nan")

In [None]:
import datasets

In [None]:
# Reload the medium-length test set for inspection.
# NOTE: this directory is only written by the "test medium" save cell further
# down, so that cell must have run before this one.
dataset = datasets.load_from_disk(dataset_path="./tmp/wikipedia_test_medium/")

dataset

In [None]:
# Attach a per-row word count (whitespace split), computed with 32 processes,
# and show the mean. NOTE: this rebinds the global length_list from the
# sentence-splitting cell.
dataset = dataset.map(
    lambda row: {"length": len(row["text"].split())},
    num_proc=32,
)
length_list = dataset["length"]
sum(length_list) / len(length_list)

## shuffle & get new dataset list

In [6]:
import random

# Seed Python's global RNG so the sentence shuffle below is reproducible.
random.seed(a=seed)

In [12]:
# Shuffle in place with the global RNG. This is reproducible only if no other
# random.* calls run between seeding and this cell.
random.shuffle(sentence_list)

In [13]:
train_size = 4_000_000
valid_size = 5_000
test_size = 5_000

# The embedding commands below hard-code these split sizes, so fail loudly
# instead of silently slicing out a shorter list.
total_size = train_size + valid_size + test_size
if len(sentence_list) < total_size:
    raise ValueError(
        f"need {total_size:,} sentences for the train/valid/test split, "
        f"but only {len(sentence_list):,} are available"
    )

dataset_list = sentence_list[:total_size]

In [14]:
from itertools import islice

# Sentences not used by the train/valid/test pool; the extra test sets are
# drawn only from here, so they never overlap dataset_list.
remaining_sentences = sentence_list[train_size + valid_size + test_size:]

# medium: 36-49 words, long: 66+ words (whitespace split). islice stops as
# soon as test_size matches are found, rather than filtering the whole
# remainder and then truncating. The result is identical.
medium_test_list = list(islice(
    (s for s in remaining_sentences if 35 < len(s.split()) < 50), test_size
))
long_test_list = list(islice(
    (s for s in remaining_sentences if len(s.split()) > 65), test_size
))

In [15]:
# Persist the shuffled train/valid/test pool as a single-column ("text") dataset.
dataset = datasets.Dataset.from_dict(dict(text=dataset_list))
dataset.save_to_disk("./tmp/wikipedia")

                                                                                                      

In [16]:
# test medium: sentences of 36-49 words taken from outside dataset_list
dataset = datasets.Dataset.from_dict(dict(text=medium_test_list))
dataset.save_to_disk("./tmp/wikipedia_test_medium")

                                                                                                

In [17]:
# test long: sentences of more than 65 words taken from outside dataset_list
dataset = datasets.Dataset.from_dict(dict(text=long_test_list))
dataset.save_to_disk("./tmp/wikipedia_test_long")

                                                                                               

# process new dataset

## Compute embeddings & save

### sup-simcse-bert-base-uncased

In [None]:
# wikipedia (train/valid/test pool)
# NOTE(review): every run shares the "your_output_dir" placeholder; use a
# distinct directory per run, or outputs may collide.
!torchrun --nproc_per_node=8 ../embedding/sup-simcse-bert-base-uncased.py \
    --input_dataset "./tmp/wikipedia" \
    --output_dataset "your_output_dir" \
    --train_size 4000000 \
    --valid_size 5000 \
    --test_size 5000

In [None]:
# wikipedia_test_medium
# NOTE(review): replace "your_output_dir" with a directory distinct from the
# other runs, or outputs may collide.
!torchrun --nproc_per_node=8 ../embedding/sup-simcse-bert-base-uncased.py \
    --input_dataset "./tmp/wikipedia_test_medium" \
    --output_dataset "your_output_dir" \
    --train_size 0 \
    --valid_size 0 \
    --test_size 5000

In [None]:
# wikipedia_test_long
# NOTE(review): replace "your_output_dir" with a directory distinct from the
# other runs, or outputs may collide.
!torchrun --nproc_per_node=8 ../embedding/sup-simcse-bert-base-uncased.py \
    --input_dataset "./tmp/wikipedia_test_long" \
    --output_dataset "your_output_dir" \
    --train_size 0 \
    --valid_size 0 \
    --test_size 5000

### e5-large-v2

In [None]:
# wikipedia (train/valid/test pool)
# NOTE(review): every run shares the "your_output_dir" placeholder; use a
# distinct directory per run, or outputs may collide.
!torchrun --nproc_per_node=8 ../embedding/e5-large-v2.py \
    --input_dataset "./tmp/wikipedia" \
    --output_dataset "your_output_dir" \
    --train_size 4000000 \
    --valid_size 5000 \
    --test_size 5000

In [None]:
# wikipedia_test_medium
# NOTE(review): replace "your_output_dir" with a directory distinct from the
# other runs, or outputs may collide.
!torchrun --nproc_per_node=8 ../embedding/e5-large-v2.py \
    --input_dataset "./tmp/wikipedia_test_medium" \
    --output_dataset "your_output_dir" \
    --train_size 0 \
    --valid_size 0 \
    --test_size 5000

In [None]:
# wikipedia_test_long
# NOTE(review): replace "your_output_dir" with a directory distinct from the
# other runs, or outputs may collide.
!torchrun --nproc_per_node=8 ../embedding/e5-large-v2.py \
    --input_dataset "./tmp/wikipedia_test_long" \
    --output_dataset "your_output_dir" \
    --train_size 0 \
    --valid_size 0 \
    --test_size 5000

### bge-large-en

In [None]:
# wikipedia (train/valid/test pool)
# NOTE(review): every run shares the "your_output_dir" placeholder; use a
# distinct directory per run, or outputs may collide.
!torchrun --nproc_per_node=8 ../embedding/bge-large-en.py \
    --input_dataset "./tmp/wikipedia" \
    --output_dataset "your_output_dir" \
    --train_size 4000000 \
    --valid_size 5000 \
    --test_size 5000

In [None]:
# wikipedia_test_medium
# NOTE(review): replace "your_output_dir" with a directory distinct from the
# other runs, or outputs may collide.
!torchrun --nproc_per_node=8 ../embedding/bge-large-en.py \
    --input_dataset "./tmp/wikipedia_test_medium" \
    --output_dataset "your_output_dir" \
    --train_size 0 \
    --valid_size 0 \
    --test_size 5000

In [None]:
# wikipedia_test_long
# NOTE(review): replace "your_output_dir" with a directory distinct from the
# other runs, or outputs may collide.
!torchrun --nproc_per_node=8 ../embedding/bge-large-en.py \
    --input_dataset "./tmp/wikipedia_test_long" \
    --output_dataset "your_output_dir" \
    --train_size 0 \
    --valid_size 0 \
    --test_size 5000