# process original dataset

## load original dataset

In [None]:
import datasets

In [None]:
# Read the PubMed-abstracts JSONL file into a `datasets.Dataset`
# (the filename suggests an MD5-deduplicated Pile dump — confirm upstream).
pile_pubmed = datasets.Dataset.from_json(
    path_or_paths="/dataset/pile/dedup-md5-pile-pubmed_abstracts.jsonl",
)

pile_pubmed

## sample dataset

In [4]:
# Fixed RNG seed and the number of abstracts to draw from the full corpus.
seed, sample_size = 2023, 10_000

In [None]:
# Deterministic subsample: shuffle every row with the fixed seed,
# then keep the first `sample_size` rows.
dataset = (
    pile_pubmed
    .shuffle(seed=seed)
    .select(range(sample_size))
)

dataset

## split sentences

In [6]:
open_list = ["[", "{", "("]
close_list = ["]", "}", ")"]

def is_balance(text):
    """Return whether every bracket in *text* is properly matched and nested.

    Only ``()``, ``[]`` and ``{}`` are considered; all other characters are
    ignored.

    Args:
        text: The string to check.

    Returns:
        ``True`` when the brackets are balanced, ``False`` otherwise —
        including when an opening bracket is still unclosed at the end.
    """
    # Each closing bracket maps to the opening bracket it must match.
    pairs = dict(zip(close_list, open_list))
    stack = []
    for c in text:
        if c in open_list:
            stack.append(c)
        elif c in pairs:
            # A closer with nothing open, or closing the wrong kind, is unbalanced.
            if not stack or stack.pop() != pairs[c]:
                return False
    # Leftover openers mean unbalanced; always return a real bool, never None.
    return not stack

In [7]:
import string

punctuation_tuple = tuple(string.punctuation)

def is_sentence_clean(sent: str, min_length: int = 5, max_length: int = 128) -> bool:
    """Return whether *sent* is usable as a standalone sentence.

    The sentence is stripped first, then rejected if any of these holds:
    its whitespace-separated word count is outside
    ``[min_length, max_length]``, it starts with an ASCII punctuation
    character, it contains a newline, or its brackets are unbalanced
    (checked with ``is_balance``).

    Args:
        sent: Candidate sentence text.
        min_length: Minimum number of whitespace-separated words (inclusive).
        max_length: Maximum number of whitespace-separated words (inclusive).

    Returns:
        ``True`` if the sentence passes every check, ``False`` otherwise.
    """
    sent = sent.strip()

    length = len(sent.split())
    if length < min_length or length > max_length:
        return False
    # Reuse the precomputed tuple instead of rebuilding it on every call.
    if sent.startswith(punctuation_tuple):
        return False
    # An internal newline usually means the splitter merged separate lines.
    if "\n" in sent:
        return False
    if not is_balance(sent):
        return False
    return True

In [8]:
from tqdm import tqdm
import spacy

In [None]:
# Ask spaCy to run on GPU 0 if one is available (prefer_gpu, unlike
# require_gpu, does not insist on it), then load the en_core_web_trf
# pipeline used below for sentence segmentation.
spacy.prefer_gpu(gpu_id=0)
nlp = spacy.load("en_core_web_trf")

In [None]:
# Split each sampled abstract into sentences and keep only the clean ones,
# recording each kept sentence's whitespace word count alongside it.
sentence_list = []
length_list = []
texts = dataset["text"]
# nlp.pipe streams the texts through the pipeline in batches, which is far
# faster than calling nlp() once per text; it yields docs in input order.
# tqdm needs `total` because pipe() returns a generator without a length.
for doc in tqdm(nlp.pipe(texts), total=len(texts)):
    for span in doc.sents:
        sent = span.text.strip()
        if is_sentence_clean(sent, 8, 128):
            sentence_list.append(sent)
            length_list.append(len(sent.split()))

In [None]:
# Number of sentences that passed is_sentence_clean (displayed as the cell output).
len(sentence_list)

In [None]:
# Mean word count of the kept sentences. If every sentence was filtered out
# the list is empty, so show NaN instead of raising ZeroDivisionError.
sum(length_list) / len(length_list) if length_list else float("nan")

## shuffle & get new dataset list

In [13]:
import random
# Seed the module-level RNG with the shared seed so the shuffle that follows is reproducible.
random.seed(a=seed)

In [14]:
# Shuffle in place using the module-level RNG seeded in the previous cell; the
# resulting order decides which sentences survive the truncation that follows.
random.shuffle(sentence_list)

In [15]:
# Target split sizes. The embedding commands further down are invoked with
# these same numbers (10000 / 5000 / 5000).
train_size = 10_000
valid_size = 5_000
test_size = 5_000
total_size = train_size + valid_size + test_size

# Slicing alone would silently keep fewer rows than the splits expect, so
# fail loudly if extraction did not produce enough clean sentences.
if len(sentence_list) < total_size:
    raise ValueError(
        f"need {total_size} clean sentences for train/valid/test, "
        f"but only {len(sentence_list)} were extracted; increase sample_size"
    )
sentence_list = sentence_list[:total_size]

In [18]:
# Wrap the selected sentences in a single-column ("text") Dataset and persist it
# to ./tmp/pile_pubmed, the path the embedding commands read from.
dataset = datasets.Dataset.from_dict({"text": sentence_list})
dataset.save_to_disk(dataset_path="./tmp/pile_pubmed")

                                                                                                  

# process new dataset

## get embeddings & save

### sup-simcse-bert-base-uncased

In [None]:
# NOTE(review): replace "your_output_dir" before running — every embedding cell
# uses the same placeholder, so give each model its own directory.
# Sizes mirror train_size / valid_size / test_size used to build ./tmp/pile_pubmed.
# The last argument has no trailing "\" so no stray backslash reaches the shell.
!torchrun --nproc_per_node=8 ../embedding/sup-simcse-bert-base-uncased.py \
    --input_dataset "./tmp/pile_pubmed" \
    --output_dataset "your_output_dir" \
    --train_size 10000 \
    --valid_size 5000 \
    --test_size 5000

### e5-large-v2

In [None]:
# NOTE(review): replace "your_output_dir" before running — every embedding cell
# uses the same placeholder, so give each model its own directory.
# Sizes mirror train_size / valid_size / test_size used to build ./tmp/pile_pubmed.
# The last argument has no trailing "\" so no stray backslash reaches the shell.
!torchrun --nproc_per_node=8 ../embedding/e5-large-v2.py \
    --input_dataset "./tmp/pile_pubmed" \
    --output_dataset "your_output_dir" \
    --train_size 10000 \
    --valid_size 5000 \
    --test_size 5000

### bge-large-en

In [None]:
# NOTE(review): replace "your_output_dir" before running — every embedding cell
# uses the same placeholder, so give each model its own directory.
# Sizes mirror train_size / valid_size / test_size used to build ./tmp/pile_pubmed.
# The last argument has no trailing "\" so no stray backslash reaches the shell.
!torchrun --nproc_per_node=8 ../embedding/bge-large-en.py \
    --input_dataset "./tmp/pile_pubmed" \
    --output_dataset "your_output_dir" \
    --train_size 10000 \
    --valid_size 5000 \
    --test_size 5000