In [None]:
import os
import json
import re
from tqdm import tqdm
from datasketch import MinHash, MinHashLSH
from multiprocessing import Pool, Manager, cpu_count

# MinHash / LSH tuning: permutations per signature and the similarity
# threshold passed to the LSH index.
num_perm = 128
threshold = 0.9

# Shard locations: raw JSONL shards are read from INPUT_DIR and the
# deduplicated copies are written to OUTPUT_DIR.
INPUT_DIR = os.path.expanduser(
    "/home/blu-bridge004/Desktop/pytorch/preprocessed06_10"
)
OUTPUT_DIR = os.path.expanduser(
    "/home/blu-bridge004/Desktop/pytorch/deduplicated_shards25_test3"
)

# Make sure the destination exists before any worker tries to write to it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Patterns are compiled once at import time; normalize() runs once per record.
_DIGIT_RUN_RE = re.compile(r"\d+")
_PUNCTUATION_RE = re.compile(r"[^\w\s]")
_WHITESPACE_RE = re.compile(r"\s+")


def normalize(text):
    """Return a canonical form of *text* for duplicate comparison.

    Steps, in order:
      1. lowercase;
      2. replace every run of digits with a single ``0``;
      3. drop every character that is neither a word character nor
         whitespace (underscores are word characters and are kept);
      4. collapse whitespace runs to one space and strip the ends.

    Whitespace is collapsed *after* punctuation is removed so that deleting a
    stand-alone punctuation token (``"a - b"``) cannot leave a double space
    behind; otherwise such text would not compare equal to ``"a b"``.

    Args:
        text: The raw document text.

    Returns:
        The normalized string (possibly empty).
    """
    text = text.lower()
    text = _DIGIT_RUN_RE.sub("0", text)  # replace numbers with 0
    text = _PUNCTUATION_RE.sub("", text)  # remove punctuation
    text = _WHITESPACE_RE.sub(" ", text)  # collapse whitespace (must be last)
    return text.strip()

def get_minhash(text):
    """Build a MinHash signature from the distinct tokens of *text*.

    The text is split on whitespace and each unique token is fed once,
    UTF-8 encoded, into a ``MinHash`` created with the module-level
    ``num_perm`` setting.

    Args:
        text: The (already normalized) document text.

    Returns:
        The populated ``MinHash`` object.
    """
    signature = MinHash(num_perm=num_perm)
    unique_tokens = set(text.split())
    for token in unique_tokens:
        signature.update(token.encode('utf8'))
    return signature

def dedup_worker(args):
    """Deduplicate one JSONL shard and write the surviving records.

    Each line is parsed as JSON and its ``"text"`` field is normalized.
    A record is dropped when its normalized text was already kept from this
    shard (exact match) or when the shared MinHash LSH index returns any
    candidate for its signature (near-duplicate, possibly from another
    shard). Records with empty text are skipped silently; lines that raise
    while being processed are reported on stdout and skipped.

    Args:
        args: A ``(filename, global_seen)`` tuple. ``filename`` is a shard
            name inside ``INPUT_DIR``. ``global_seen`` is a mapping with a
            ``'lock'`` entry (usable as a context manager) and a ``'lsh'``
            entry holding the LSH index; ``main`` passes a manager dict proxy.

    Returns:
        ``(filename, duplicate_lines)`` where ``duplicate_lines`` lists the
        1-based line numbers that were dropped as duplicates.
    """
    filename, global_seen = args
    input_path = os.path.join(INPUT_DIR, filename)
    output_path = os.path.join(OUTPUT_DIR, filename.replace(".jsonl", ".dedup.jsonl"))

    local_seen = set()
    documents = []
    duplicate_lines = []

    with open(input_path, "r", encoding="utf-8") as in_f:
        for line_number, line in enumerate(in_f, start=1):
            try:
                data = json.loads(line)
                text = data.get("text", "").strip()
                if not text:
                    continue

                norm_text = normalize(text)
                if norm_text in local_seen:
                    duplicate_lines.append(line_number)
                    continue

                mh = get_minhash(norm_text)

                with global_seen['lock']:
                    # Reading a plain object out of a manager dict proxy
                    # yields a *copy*. Mutating that copy does not reach the
                    # manager, so the index must be fetched, updated and then
                    # stored back while the lock is held.
                    # NOTE: this ships the whole index on every lookup; a
                    # dedicated manager-registered index would avoid that.
                    lsh = global_seen['lsh']
                    is_duplicate = bool(lsh.query(mh))

                    if not is_duplicate:
                        lsh.insert(f"{filename}_{line_number}", mh)
                        # Re-assignment is what propagates the insert.
                        global_seen['lsh'] = lsh

                if is_duplicate:
                    duplicate_lines.append(line_number)
                else:
                    local_seen.add(norm_text)
                    documents.append(data)

            except Exception as e:
                # Best-effort: one malformed line must not abort the shard.
                print(f"Error in {filename} at line {line_number}: {e}")

    with open(output_path, "w", encoding="utf-8") as out_f:
        for doc in documents:
            out_f.write(json.dumps(doc, ensure_ascii=False) + "\n")

    return filename, duplicate_lines

def main():
    """Run ``dedup_worker`` over every ``.jsonl`` shard in ``INPUT_DIR``.

    Shards are processed in sorted name order by a process pool sized to the
    CPU count. A manager dict carrying a ``MinHashLSH`` index (``'lsh'``) and
    a manager lock (``'lock'``) is handed to every worker. After all shards
    finish, the duplicate line numbers of each shard that had any are printed.
    """
    shard_names = sorted(
        name for name in os.listdir(INPUT_DIR) if name.endswith(".jsonl")
    )

    manager = Manager()
    global_seen = manager.dict()
    global_seen['lsh'] = MinHashLSH(threshold=threshold, num_perm=num_perm)
    global_seen['lock'] = manager.Lock()

    # One task per shard; every task carries the same shared-state proxy.
    tasks = [(name, global_seen) for name in shard_names]

    with Pool(cpu_count()) as pool:
        progress = tqdm(pool.imap(dedup_worker, tasks), total=len(shard_names))
        results = list(progress)

    for filename, duplicates in results:
        if not duplicates:
            continue
        print(f"Duplicate lines in {filename}: {duplicates}")

# Script entry point: run the deduplication only when executed directly.
if __name__ == "__main__":
    main()

  0%|          | 0/2 [00:00<?, ?it/s]