# Knowledge Mixing

## Overview
This notebook combines the four types of knowledge tuning example datasets (extractive summaries, detailed summaries, key facts Q&A, and document-based Q&A) into training-ready datasets. It prepares mixes at different cut sizes, using upsampling and downsampling, for use in model training workflows.

A cut size is the number of summaries kept per raw document. The notebook reads the requested cut sizes from the `CUT_SIZES` environment variable, checks which of them the summary datasets can support, and generates one consolidated dataset for each feasible cut size.

For each feasible cut size N, it writes `output/step_04/training_mix/combined_cut_{N}x.jsonl`.
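
A cut size N is treated as feasible for a summary dataset only when that dataset has, on average, at least N summaries per raw document; the key facts Q&A and document-based Q&A datasets are not checked. The sketch below illustrates that rule in plain Python. It assumes each record carries `raw_document` and `document` (summary) fields and that the average counts distinct summaries per raw document. The helpers `avg_summaries_per_raw_doc` and `feasible_cuts` are hypothetical; the notebook's own `get_avg_summaries_per_raw_doc` in `utils.knowledge_utils` may compute the average differently.

```python
from collections import defaultdict


def avg_summaries_per_raw_doc(records):
    """Average number of distinct summaries ("document") per source ("raw_document")."""
    summaries = defaultdict(set)
    for record in records:
        # A summary may appear on several rows (one per Q&A pair), so use a set
        summaries[record["raw_document"]].add(record["document"])
    if not summaries:
        return 0.0
    return sum(len(s) for s in summaries.values()) / len(summaries)


def feasible_cuts(records, cuts):
    """Keep only the cut sizes that do not exceed the average summaries per raw document."""
    avg = avg_summaries_per_raw_doc(records)
    return sorted(cut for cut in set(cuts) if avg >= cut)


# Two raw documents: one with 3 summaries, one with 1 (average 2.0)
records = [
    {"raw_document": "A", "document": "A-sum-1"},
    {"raw_document": "A", "document": "A-sum-2"},
    {"raw_document": "A", "document": "A-sum-3"},
    {"raw_document": "B", "document": "B-sum-1"},
]
print(feasible_cuts(records, [1, 2, 5]))  # [1, 2]
```

Averaging across raw documents means a cut can pass even if some individual raw documents have fewer summaries than the cut size.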

## Prerequisites

The following datasets, generated by the [knowledge generation notebook](../03_Knowledge_Generation/Knowledge_Generation.ipynb), must be in the `output/step_03/` directory:
- Extractive summaries
- Detailed summaries
- Key facts Q&A pairs
- Document-based Q&A pairs
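
Before running the notebook, you can confirm that the four input directories exist. The sketch below assumes the subdirectory names that the loading code looks for under `output/step_03/` (`extractive_summary`, `detailed_summary`, `key_facts_to_qa`, `document_based_qa`) and the same workspace layout (`Path.cwd().parent`). `missing_inputs` is a hypothetical helper, not part of the notebook.

```python
from pathlib import Path

# Subdirectory names the notebook loads from output/step_03/
EXPECTED_INPUTS = [
    "extractive_summary",
    "detailed_summary",
    "key_facts_to_qa",
    "document_based_qa",
]


def missing_inputs(step_03_dir):
    """Return the expected input subdirectories that do not exist under step_03_dir."""
    root = Path(step_03_dir)
    return [name for name in EXPECTED_INPUTS if not (root / name).is_dir()]


# Check the default location used by the notebook
print("Missing inputs:", missing_inputs(Path.cwd().parent / "output" / "step_03") or "none")
```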




## Install dependencies

In [None]:
!pip install -qqU .

In [None]:
import os
from pathlib import Path

from datasets import Dataset, concatenate_datasets, load_dataset
from dotenv import load_dotenv
from tabulate import tabulate
from transformers import AutoTokenizer
from utils.knowledge_utils import (
    count_len_in_tokens,
    generate_knowledge_qa_dataset,
    get_avg_summaries_per_raw_doc,
    sample_doc_qa,
)

## Set up paths and directories

In [None]:
load_dotenv()  # Load environment variables from .env file

WORKSPACE = Path.cwd().parent  # Path to the workspace directory

OUTPUT_DIR = WORKSPACE / "output" / "step_04"

OUTPUT_DIR.mkdir(
    parents=True, exist_ok=True
)  # Create the output directory if it doesn't exist

In [None]:
# Load configuration from environment variables
KNOWLEDGE_OUTPUT_DIR = WORKSPACE / "output" / "step_03"
TOKENIZER_MODEL = os.getenv("TOKENIZER_MODEL_NAME", "RedHatAI/Llama-3.1-8B-Instruct")
SAVE_GPT_OSS_FORMAT = os.getenv("SAVE_GPT_OSS_FORMAT", "false").lower() == "true"

# Parse cut sizes from environment variable
cut_sizes_str = os.getenv("CUT_SIZES", "5,50")
cuts = [int(x.strip()) for x in cut_sizes_str.split(",")]

# Get Q&A pairs per document
QA_PER_DOC = int(os.getenv("QA_PER_DOC", "3"))


# Input directory (step 3 outputs) and output directory for the training mixes
exp_folder = str(KNOWLEDGE_OUTPUT_DIR)
input_data_dir = exp_folder
output_dir = os.path.join(OUTPUT_DIR, "training_mix")
os.makedirs(output_dir, exist_ok=True)  # Create training_mix/ if it doesn't exist

print(f"Experiment folder: {exp_folder}")
print(f"Tokenizer model: {TOKENIZER_MODEL}")
print(f"GPT OSS format: {SAVE_GPT_OSS_FORMAT}")
print(f"Cut sizes: {cuts}")
print(f"Q&A pairs per document: {QA_PER_DOC}")
print(f"Input data directory: {input_data_dir}")
print(f"Output directory: {output_dir}")
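
`CUT_SIZES` is read as a comma-separated string such as `"5,50"`. If you need stricter parsing, a variant could ignore empty entries (for example, from a trailing comma), drop duplicates, and reject non-positive values. `parse_cut_sizes` is a hypothetical helper, not part of the notebook; it falls back to the notebook's default of `"5,50"` when the variable is unset.

```python
import os


def parse_cut_sizes(raw, default="5,50"):
    """Parse a comma-separated list of cut sizes into sorted, unique positive ints."""
    text = raw if raw is not None else default
    cuts = set()
    for part in text.split(","):
        part = part.strip()
        if not part:  # tolerate trailing or doubled commas
            continue
        value = int(part)  # raises ValueError for non-numeric entries
        if value <= 0:
            raise ValueError(f"Cut size must be positive, got {value}")
        cuts.add(value)
    return sorted(cuts)


# Read the same environment variable the notebook uses
print(parse_cut_sizes(os.getenv("CUT_SIZES")))
```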

## Load the datasets

### Define utility functions

In [None]:
def load_tokenizer(student_model):
    """Initialize and return tokenizer."""
    print(f"Loading tokenizer: {student_model}")
    return AutoTokenizer.from_pretrained(student_model, trust_remote_code=True)


def filter_gpt_oss_dataset(ds):
    """Apply GPT OSS format filtering to dataset."""
    original_size = len(ds)

    # Filter out problematic questions
    ds = ds.filter(
        lambda x: (
            "..." not in x["question"]
            and "<question>" not in x["question"]
            and "<Insert question here>" not in x["question"]
        )
    )

    # Clean response text
    ds = ds.map(
        lambda x: {
            "response": x["response"]
            .replace("[ANSWER]", "")
            .replace("[END]", "")
            .strip()
        }
    )

    filtered_size = len(ds)
    print(f"  Filtered {original_size - filtered_size} samples (kept {filtered_size})")
    return ds


def load_summary_dataset(summary_type):
    """Load a single summary dataset."""
    file_path = os.path.join(input_data_dir, summary_type)  # directory of JSON files

    # Check if the dataset directory exists
    if not Path(file_path).exists():
        print(f"⚠️  Warning: Directory not found: {file_path}")
        return None

    print(f"Loading {summary_type} from: {file_path}")
    ds = load_dataset("json", data_dir=file_path, split="train")

    if summary_type == "document_based_qa":
        ds = ds.rename_column("base_document", "raw_document")
    # Apply filtering if needed
    if SAVE_GPT_OSS_FORMAT:
        ds = filter_gpt_oss_dataset(ds)

    print(f"  Loaded {summary_type}: {len(ds)} samples")
    return ds.to_polars()


def load_all_summary_datasets():
    """Load all summary type datasets."""
    summary_types = [
        "extractive_summary",
        "detailed_summary",
        "key_facts_to_qa",
        "document_based_qa",
    ]

    summary_datasets = {}

    for summary_type in summary_types:
        dataset = load_summary_dataset(summary_type)
        if dataset is not None:
            summary_datasets[summary_type] = dataset

    if not summary_datasets:
        raise ValueError("No datasets were successfully loaded!")

    return summary_datasets
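
When `SAVE_GPT_OSS_FORMAT` is enabled, `filter_gpt_oss_dataset` drops questions that contain placeholder text (`...`, `<question>`, `<Insert question here>`) and strips the `[ANSWER]` and `[END]` markers from responses. The same rules can be checked on plain Python dictionaries without loading a Hugging Face dataset. `keep_question`, `clean_response`, and `clean_records` are hypothetical helpers for illustration.

```python
PLACEHOLDERS = ("...", "<question>", "<Insert question here>")


def keep_question(question):
    """True when the question contains none of the placeholder strings."""
    return not any(p in question for p in PLACEHOLDERS)


def clean_response(response):
    """Remove the [ANSWER] and [END] markers and surrounding whitespace."""
    return response.replace("[ANSWER]", "").replace("[END]", "").strip()


def clean_records(records):
    """Drop records with placeholder questions, then clean each kept response."""
    return [
        {**r, "response": clean_response(r["response"])}
        for r in records
        if keep_question(r["question"])
    ]


records = [
    {"question": "What is the warranty period?", "response": "[ANSWER] Two years. [END]"},
    {"question": "<Insert question here>", "response": "[ANSWER] n/a [END]"},
]
print(clean_records(records))
# [{'question': 'What is the warranty period?', 'response': 'Two years.'}]
```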

### Load all the datasets

Use the utility functions to load all datasets.

In [None]:
# Load tokenizer and datasets
try:
    tokenizer = load_tokenizer(TOKENIZER_MODEL)
    summary_datasets = load_all_summary_datasets()
    # Show the columns of each loaded dataset
    for summary_type, dataset in summary_datasets.items():
        print(f"\n{summary_type} columns:")
        for column in dataset.columns:
            print(f"  - {column}")

    print(f"\n✅ Successfully loaded {len(summary_datasets)} summary datasets")
except Exception as e:
    print(f"❌ Error during initialization: {e}")
    raise

## Combine the datasets and save to a file

### Define utility functions

In [None]:
def validate_cuts_for_datasets(summary_datasets, cuts):
    """Validate which cut sizes are feasible for each dataset."""
    feasible_cuts = set(cuts)

    print("🔍 Validating cut sizes against available data...")
    for summary_type, df in summary_datasets.items():
        # Q&A-only datasets are not sampled per raw document, so they don't constrain cuts
        if summary_type in ["key_facts_to_qa", "document_based_qa"]:
            print(f"\n📊 Skipping {summary_type}")
            continue
        print(f"\n📊 Checking {summary_type}:")

        # The average does not depend on the cut, so compute it once per dataset
        avg_summaries = get_avg_summaries_per_raw_doc(df)
        for cut in cuts:
            is_feasible = avg_summaries >= cut
            status = "✅ Feasible" if is_feasible else "❌ Too large"
            print(
                f"  Cut {cut}: {status} (avg summaries per raw doc: {avg_summaries:.1f})"
            )

            if not is_feasible:
                feasible_cuts.discard(cut)

    final_cuts = sorted(feasible_cuts)
    removed_cuts = set(cuts) - feasible_cuts
    if removed_cuts:
        print(f"\n⚠️  Removing infeasible cuts: {sorted(removed_cuts)}")

    print(f"\n✅ Final feasible cuts: {final_cuts}")
    return final_cuts


def process_single_summary_type(summary_type, df, cut, tokenizer, qa_per_doc):
    """Process a single summary type dataset."""
    try:
        print(f"  Processing {summary_type}...")
        if summary_type == "key_facts_to_qa":
            # Skip sampling for the key facts Q&A dataset: the generated summaries are
            # discarded and only the Q&A pairs are kept.
            # Generate knowledge Q&A dataset
            generated_dataset = generate_knowledge_qa_dataset(
                df,
                keep_columns=[
                    "question",
                    "document_outline",
                    "raw_document",
                    "document",
                ],
                pre_training=True,
                keep_document_in_context=False,
            )
        else:
            if summary_type != "document_based_qa":
                # Sample documents and Q&A pairs (validation already done)
                df_cut = sample_doc_qa(df, n_docs_per_raw=cut, qa_per_doc=qa_per_doc)
            else:
                df_cut = df

            # Generate knowledge Q&A dataset
            generated_dataset = generate_knowledge_qa_dataset(
                df_cut,
                keep_columns=[
                    "question",
                    "document_outline",
                    "raw_document",
                    "document",
                ],
                pre_training=True,
                keep_document_in_context=True,
            )

        # Count tokens
        generated_dataset = count_len_in_tokens(generated_dataset, tokenizer)

        # Convert back to HuggingFace dataset
        generated_dataset = Dataset.from_polars(generated_dataset)

        # Calculate statistics
        unique_docs = len(set(generated_dataset["document"]))
        unique_raw_docs = len(set(generated_dataset["raw_document"]))
        generated_cut_size = unique_docs / unique_raw_docs if unique_raw_docs > 0 else 0

        stats = {
            "samples": len(generated_dataset),
            "unique_docs": unique_docs,
            "unique_raw_docs": unique_raw_docs,
            "avg_docs_per_raw": generated_cut_size,
            "total_tokens": sum(generated_dataset["token_length"]),
        }

        print(
            f"    ✅ Processed {len(generated_dataset)} samples ({generated_cut_size:.1f} summaries per raw doc)"
        )
        return generated_dataset, stats

    except Exception as e:
        print(f"    ❌ Error processing {summary_type}: {e}")
        return None, None


def combine_and_save_datasets(all_datasets, cut_stats, cut, output_dir):
    """Combine datasets and save to file."""
    if not all_datasets:
        print(f"  ❌ No datasets processed for cut size {cut}")
        return None

    try:
        # Combine all summary types for this cut
        combined_dataset = concatenate_datasets(all_datasets)
        total_tokens = sum(combined_dataset["token_length"])

        # Save combined dataset
        output_path = os.path.join(output_dir, f"combined_cut_{cut}x.jsonl")
        combined_dataset.to_json(output_path, orient="records", lines=True)

        # Print results
        print(f"  üíæ Saved to: {output_path}")
        print(f"  üìà Total samples: {len(combined_dataset)}")
        print(f"  üî¢ Total tokens: {total_tokens:,}")

        # Print detailed statistics
        print("  üìã Summary statistics:")
        for summary_type, stats in cut_stats.items():
            print(
                f"    {summary_type}: {stats['samples']} samples, {stats['total_tokens']:,} tokens"
            )

        return (cut, total_tokens, len(combined_dataset))

    except Exception as e:
        print(f"  ❌ Error combining datasets for cut {cut}: {e}")
        return None


def process_single_cut(cut, summary_datasets, tokenizer, output_dir, qa_per_doc):
    """Process all summary types for a single cut size."""
    print(f"\nüìä Processing cut size: {cut}")
    all_datasets = []
    cut_stats = {}

    for summary_type, df in summary_datasets.items():
        dataset, stats = process_single_summary_type(
            summary_type, df, cut, tokenizer, qa_per_doc
        )

        if dataset is not None and stats is not None:
            all_datasets.append(dataset)
            cut_stats[summary_type] = stats

    return combine_and_save_datasets(all_datasets, cut_stats, cut, output_dir)


def process_and_mix_datasets(cuts, summary_datasets, tokenizer, output_dir, qa_per_doc):
    """Process and mix datasets with different cut sizes."""
    # First validate which cuts are feasible
    feasible_cuts = validate_cuts_for_datasets(summary_datasets, cuts)

    if not feasible_cuts:
        print("\n❌ No feasible cuts found! Check your data or reduce cut sizes.")
        return []

    token_count = []

    print(f"\nProcessing {len(feasible_cuts)} feasible cut sizes...")
    for cut in feasible_cuts:
        result = process_single_cut(
            cut, summary_datasets, tokenizer, output_dir, qa_per_doc
        )
        if result is not None:
            token_count.append(result)

    return token_count


def print_final_summary(token_count):
    """Print final summary table."""
    if token_count:
        print("\n" + "=" * 50)
        print("üìä FINAL SUMMARY")
        print("=" * 50)
        print(
            tabulate(
                token_count,
                headers=["Cut Size", "Total Tokens", "Total Samples"],
                tablefmt="github",
                numalign="right",
            )
        )
    else:
        print("\n❌ No datasets were successfully processed!")
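
For the extractive and detailed summary datasets, `process_single_summary_type` calls `sample_doc_qa` with `n_docs_per_raw=cut` and `qa_per_doc`, then reports the ratio of unique summaries to unique raw documents. The sketch below illustrates that kind of downsampling in plain Python. It is an assumption about the behavior, not the implementation of `sample_doc_qa` in `utils.knowledge_utils`, which may select rows differently (for example, at random). This hypothetical `sample_summaries` keeps the first summaries and questions in input order so the result is deterministic.

```python
def sample_summaries(records, n_docs_per_raw, qa_per_doc):
    """Keep up to n_docs_per_raw summaries per raw document and up to
    qa_per_doc Q&A rows per summary, preserving input order."""
    # raw_document -> summary -> list of rows (dicts keep insertion order)
    grouped = {}
    for r in records:
        by_summary = grouped.setdefault(r["raw_document"], {})
        by_summary.setdefault(r["document"], []).append(r)

    sampled = []
    for by_summary in grouped.values():
        for rows in list(by_summary.values())[:n_docs_per_raw]:
            sampled.extend(rows[:qa_per_doc])
    return sampled


def avg_docs_per_raw(records):
    """Unique summaries divided by unique raw documents, as in the notebook's stats."""
    raw_docs = {r["raw_document"] for r in records}
    docs = {r["document"] for r in records}
    return len(docs) / len(raw_docs) if raw_docs else 0


# One raw document with 3 summaries, each with 4 questions (12 rows)
records = [
    {"raw_document": "A", "document": f"A-sum-{i}", "question": f"q{j}"}
    for i in range(3)
    for j in range(4)
]
subset = sample_summaries(records, n_docs_per_raw=2, qa_per_doc=3)
print(len(subset), avg_docs_per_raw(subset))  # 6 2.0
```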

### Mix the datasets

Use the utility functions to process and mix the datasets.

In [None]:
# Process datasets
token_count = process_and_mix_datasets(
    cuts, summary_datasets, tokenizer, output_dir, QA_PER_DOC
)

# Print final summary
print_final_summary(token_count)
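
Each `combined_cut_{N}x.jsonl` file is written in JSON Lines format (`orient="records", lines=True`), with one JSON object per line. To spot-check a generated file without reloading it through `datasets`, you could count records and total tokens with the standard library. `summarize_jsonl` is a hypothetical helper; it assumes each record carries the `token_length` field that `count_len_in_tokens` adds and treats a missing value as 0.

```python
import json
from pathlib import Path


def summarize_jsonl(path):
    """Return (number of records, total token_length) for a JSON Lines file."""
    samples, tokens = 0, 0
    with Path(path).open(encoding="utf-8") as f:
        for line in f:
            if not line.strip():  # skip blank lines
                continue
            record = json.loads(line)
            samples += 1
            tokens += record.get("token_length", 0)
    return samples, tokens


# Example, assuming the notebook's output layout:
# summarize_jsonl(Path(output_dir) / "combined_cut_5x.jsonl")
```

The counts should match the "Total Samples" and "Total Tokens" columns in the final summary table for the same cut size.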

## Next Step

In this Knowledge Tuning example, the output training dataset is small enough to train on directly, so you can proceed to [Module 5: Model Training](../05_Model_Training/05_Model_Training_README.md).

**NOTE:** In a use case where the output dataset is too large for training (for example, 1 million samples), your next step would be to select a representative subset of the data, as illustrated in the [Subset Selection for Dataset Diversity example notebook](https://github.com/opendatahub-io/data-processing/blob/main/notebooks/use-cases/subset-selection.ipynb).