In [None]:
import pandas as pd
from transformers import AutoTokenizer
import matplotlib.pyplot as plt

# ======================
# CONFIG
# ======================
# Input CSV locations, keyed first by dataset name and then by split name.
# Both datasets follow a fixed naming pattern, so the paths are generated
# from the split names instead of being listed one by one.
FILES = {
    "rond": {
        split: f"./ROND_{split}.csv" for split in ("train", "val", "test")
    },
    "bluescrubs": {
        split: f"./bluescrubs_{split}_clean.csv"
        for split in ("train", "val", "test")
    },
}

# Tokenizers for BERT and Longformer.
# Each one is loaded by its pretrained checkpoint identifier and is later
# passed to preprocess_bluescrubs so that chunk boundaries are measured in
# that model's own tokens.
# NOTE(review): from_pretrained presumably fetches the files from the
# Hugging Face Hub (or a local cache) on first use -- confirm that network
# access or a pre-populated cache is available where this runs.
bert_tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
longformer_tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")


# ======================
# FUNCTION: Truncate + Chunk
# ======================
def truncate_and_chunk(text, tokenizer, max_tokens=512, stride=50):
    """Split *text* into overlapping chunks of at most *max_tokens* tokens.

    The text is tokenized without special tokens and cut into windows of
    ``max_tokens`` token ids. Window starts are ``max_tokens - stride``
    apart, so consecutive windows share ``stride`` token ids. Each window
    is decoded back into a string.

    Args:
        text: The string to split.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            returning a sequence of token ids, and
            ``decode(ids, skip_special_tokens=True)`` returning a string.
        max_tokens: Maximum number of token ids per chunk; must be >= 1.
        stride: Number of token ids shared by consecutive chunks; must be
            smaller than ``max_tokens``.

    Returns:
        The decoded chunk strings in document order. The list is empty
        when the text tokenizes to no tokens.

    Raises:
        ValueError: If ``max_tokens`` is below 1 or ``stride`` is not
            smaller than ``max_tokens``.
    """
    if max_tokens < 1:
        raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")
    if stride >= max_tokens:
        # A zero step makes range() fail with an unrelated message and a
        # negative step silently yields no chunks at all; reject both.
        raise ValueError(
            f"stride ({stride}) must be smaller than max_tokens ({max_tokens})"
        )

    # NOTE(review): windows are measured without special tokens. If the
    # chunks are re-tokenized downstream with special tokens added, a full
    # window may exceed max_tokens -- confirm against the training code.
    tokens = tokenizer.encode(text, add_special_tokens=False)
    step = max_tokens - stride
    chunks = []

    for start in range(0, len(tokens), step):
        window = tokens[start:start + max_tokens]
        chunks.append(tokenizer.decode(window, skip_special_tokens=True))

        # This window already reaches the end of the document. Any further
        # window would only repeat the overlap region, so stop here.
        if start + max_tokens >= len(tokens):
            break

    return chunks


# ======================
# PIPELINE for BlueScrubs (classification)
# ======================
def preprocess_bluescrubs(file_path, tokenizer, max_tokens, name=""):
    """Chunk every document of a BlueScrubs CSV and report chunk statistics.

    Reads the CSV at *file_path*, splits the ``input`` column of each row
    with ``truncate_and_chunk`` and repeats the row's ``output`` value for
    every chunk produced. A short summary is printed and, when at least one
    row was read, a histogram of chunks per document is shown.

    Args:
        file_path: Path of the CSV to read; it must contain the columns
            ``input`` and ``output``.
        tokenizer: Tokenizer forwarded to ``truncate_and_chunk``.
        max_tokens: Maximum number of tokens per chunk.
        name: Label used in the printed summary and the plot title.

    Returns:
        A DataFrame with one row per chunk and the columns ``text`` (the
        chunk) and ``label`` (the ``output`` value of the source row).

    Raises:
        KeyError: If the ``input`` or ``output`` column is missing.
    """
    df = pd.read_csv(file_path)

    new_texts = []
    new_labels = []
    chunk_counts = []

    # Walk the two columns directly instead of using iterrows(), which
    # builds a Series per row and may coerce dtypes.
    # NOTE: str() turns a missing value (NaN) into the literal text "nan",
    # exactly as the row-wise version did.
    for text, label in zip(df["input"], df["output"]):
        chunks = truncate_and_chunk(str(text), tokenizer, max_tokens=max_tokens)
        chunk_counts.append(len(chunks))

        # Every chunk inherits the label of the document it came from.
        new_texts.extend(chunks)
        new_labels.extend([label] * len(chunks))

    new_df = pd.DataFrame({"text": new_texts, "label": new_labels})

    # Guard the statistics so that an empty CSV reports zeros instead of
    # failing on a division by zero or on max() of an empty list.
    doc_count = len(chunk_counts)
    avg_chunks = sum(chunk_counts) / doc_count if doc_count else 0.0
    max_chunks = max(chunk_counts, default=0)

    # Validation
    print(f"\n📊 Validation for {name}:")
    print(f"  Total original docs: {len(df)}")
    print(f"  Total new docs (after chunking): {len(new_df)}")
    print(f"  Avg chunks per doc: {avg_chunks:.2f}")
    print(f"  Max chunks for a single doc: {max_chunks}")

    # There is nothing to plot when no rows were read.
    if chunk_counts:
        plt.hist(chunk_counts, bins=50, color='skyblue', edgecolor='black')
        plt.title(f"Chunk distribution per document ({name})")
        plt.xlabel("Chunks per document")
        plt.ylabel("Frequency")
        plt.show()

    return new_df


# ======================
# ROND (summarization) - No chunking needed
# ======================
def load_rond(file_path):
    """Read a ROND CSV and return only its ``input`` and ``output`` columns.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        A DataFrame holding the ``input`` column followed by the ``output``
        column; every other column of the file is dropped.
    """
    frame = pd.read_csv(file_path)
    wanted_columns = ["input", "output"]  # source text + summary
    return frame.loc[:, wanted_columns]


# ======================
# Run preprocessing
# ======================

# ROND (no chunking): each split is loaded as-is.
rond_train = load_rond(file_path=FILES["rond"]["train"])
rond_val = load_rond(file_path=FILES["rond"]["val"])
rond_test = load_rond(file_path=FILES["rond"]["test"])

# BlueScrubs, chunked with the BERT tokenizer (512-token windows).
bluescrubs_train_bert = preprocess_bluescrubs(
    file_path=FILES["bluescrubs"]["train"],
    tokenizer=bert_tokenizer,
    max_tokens=512,
    name="BERT - Train",
)
bluescrubs_val_bert = preprocess_bluescrubs(
    file_path=FILES["bluescrubs"]["val"],
    tokenizer=bert_tokenizer,
    max_tokens=512,
    name="BERT - Val",
)
bluescrubs_test_bert = preprocess_bluescrubs(
    file_path=FILES["bluescrubs"]["test"],
    tokenizer=bert_tokenizer,
    max_tokens=512,
    name="BERT - Test",
)

# BlueScrubs, chunked with the Longformer tokenizer (4096-token windows).
bluescrubs_train_long = preprocess_bluescrubs(
    file_path=FILES["bluescrubs"]["train"],
    tokenizer=longformer_tokenizer,
    max_tokens=4096,
    name="Longformer - Train",
)
bluescrubs_val_long = preprocess_bluescrubs(
    file_path=FILES["bluescrubs"]["val"],
    tokenizer=longformer_tokenizer,
    max_tokens=4096,
    name="Longformer - Val",
)
bluescrubs_test_long = preprocess_bluescrubs(
    file_path=FILES["bluescrubs"]["test"],
    tokenizer=longformer_tokenizer,
    max_tokens=4096,
    name="Longformer - Test",
)


# ======================
# Save outputs
# ======================
# Write every processed frame to its CSV without the index column.
# Order is kept: ROND splits, then BlueScrubs (BERT), then BlueScrubs
# (Longformer).
for processed_frame, output_path in (
    # ROND
    (rond_train, "./rond_train_processed.csv"),
    (rond_val, "./rond_val_processed.csv"),
    (rond_test, "./rond_test_processed.csv"),
    # BlueScrubs, BERT chunks
    (bluescrubs_train_bert, "./bluescrubs_train_chunked_bert.csv"),
    (bluescrubs_val_bert, "./bluescrubs_val_chunked_bert.csv"),
    (bluescrubs_test_bert, "./bluescrubs_test_chunked_bert.csv"),
    # BlueScrubs, Longformer chunks
    (bluescrubs_train_long, "./bluescrubs_train_chunked_longformer.csv"),
    (bluescrubs_val_long, "./bluescrubs_val_chunked_longformer.csv"),
    (bluescrubs_test_long, "./bluescrubs_test_chunked_longformer.csv"),
):
    processed_frame.to_csv(output_path, index=False)
