In [1]:
from datasets import load_dataset, Dataset, concatenate_datasets
import polars as pl
from transformers import AutoTokenizer
from tabulate import tabulate
from dotenv import load_dotenv
import os
from pathlib import Path
from knowledge_mixing_utils import sample_doc_qa, generate_knowledge_qa_dataset, count_len_in_tokens, get_avg_summaries_per_raw_doc

  from .autonotebook import tqdm as notebook_tqdm
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Load environment variables from .env file
load_dotenv()

# Experiment configuration; each setting falls back to the default given here
# when its environment variable is unset.
exp_folder = os.getenv('OUTPUT_DATA_FOLDER', 'generated_output_data')
student_model = os.getenv('STUDENT_MODEL', 'meta-llama/Llama-3.1-8B-Instruct')
# Only the (case-insensitive) string "true" enables the GPT OSS clean-up pass
save_gpt_oss_format = os.getenv('SAVE_GPT_OSS_FORMAT', 'false').lower() == 'true'

# Parse the comma-separated cut sizes. Blank entries (for example the trailing
# comma in "3,5,") are skipped instead of crashing on int('').
cut_sizes_str = os.getenv('CUT_SIZES', '10,20')
cuts = [int(token) for token in (part.strip() for part in cut_sizes_str.split(',')) if token]

# Get Q&A pairs per document
qa_per_doc = int(os.getenv('QA_PER_DOC', '3'))

# Inputs are read straight from the experiment folder; the mixed training
# files are written to a 'training_mix' subfolder of it.
input_data_dir = exp_folder
output_dir = os.path.join(exp_folder, 'training_mix')

print(f"Experiment folder: {exp_folder}")
print(f"Student model: {student_model}")
print(f"GPT OSS format: {save_gpt_oss_format}")
print(f"Cut sizes: {cuts}")
print(f"Q&A pairs per document: {qa_per_doc}")
print(f"Input data directory: {input_data_dir}")
print(f"Output directory: {output_dir}")

# Create output directory if it doesn't exist
Path(output_dir).mkdir(parents=True, exist_ok=True)

Experiment folder: output_data
Student model: meta-llama/Llama-3.1-8B-Instruct
GPT OSS format: False
Cut sizes: [3, 5, 7]
Q&A pairs per document: 3
Input data directory: output_data
Output directory: output_data/training_mix


In [None]:
def load_tokenizer(student_model, trust_remote_code=True):
    """Initialize and return the tokenizer for ``student_model``.

    Args:
        student_model: Model name or path passed to
            ``AutoTokenizer.from_pretrained``.
        trust_remote_code: Forwarded unchanged to ``from_pretrained``.
            Defaults to ``True`` so existing callers behave as before; pass
            ``False`` for models that do not need custom tokenizer code.

    Returns:
        Whatever ``AutoTokenizer.from_pretrained`` returns.
    """
    print(f"Loading tokenizer: {student_model}")
    # SECURITY: with trust_remote_code=True the model repository may run its
    # own Python code on this machine. Only keep it enabled for sources you trust.
    return AutoTokenizer.from_pretrained(student_model, trust_remote_code=trust_remote_code)


def filter_gpt_oss_dataset(ds):
    """Remove placeholder questions and strip answer markers from responses.

    Rows whose 'question' contains '...', '<question>' or
    '<Insert question here>' are dropped. In the rows that remain, the
    '[ANSWER]' and '[END]' markers are deleted from 'response' and the result
    is stripped of surrounding whitespace. The number of dropped and kept rows
    is printed.

    Args:
        ds: Object exposing ``filter``, ``map`` and ``len`` over rows that have
            'question' and 'response' string fields.

    Returns:
        The result of applying ``filter`` and then ``map`` to ``ds``.
    """
    placeholder_markers = ('...', '<question>', '<Insert question here>')
    response_markers = ('[ANSWER]', '[END]')

    def _has_real_question(row):
        # Keep the row only when none of the template leftovers appear
        question = row['question']
        return not any(marker in question for marker in placeholder_markers)

    def _strip_response_markers(row):
        # Remove each marker in turn, then trim the surrounding whitespace
        text = row['response']
        for marker in response_markers:
            text = text.replace(marker, '')
        return {'response': text.strip()}

    size_before = len(ds)
    cleaned = ds.filter(_has_real_question).map(_strip_response_markers)
    size_after = len(cleaned)

    print(f"  Filtered {size_before - size_after} samples (kept {size_after})")
    return cleaned


def load_summary_dataset(summary_type):
    """Load one summary dataset from ``input_data_dir``.

    The dataset is looked up at ``<input_data_dir>/<summary_type>`` and read
    with ``load_dataset('json', data_dir=..., split="train")``. When
    ``save_gpt_oss_format`` is set, the rows are additionally passed through
    ``filter_gpt_oss_dataset``.

    Args:
        summary_type: Name of the summary type, used as the path component.

    Returns:
        ``ds.to_polars()`` of the loaded dataset, or ``None`` (after printing
        a warning) when the path does not exist.
    """
    source_path = os.path.join(input_data_dir, "{}".format(summary_type))

    if Path(source_path).exists():
        print(f"Loading {summary_type} from: {source_path}")
        records = load_dataset('json', data_dir=source_path, split="train")

        # GPT OSS output gets the extra clean-up pass
        records = filter_gpt_oss_dataset(records) if save_gpt_oss_format else records

        print(f"  Loaded {summary_type}: {len(records)} samples")
        return records.to_polars()

    # Nothing on disk for this summary type
    print(f"⚠️  Warning: File not found: {source_path}")
    return None


def load_all_summary_datasets():
    """Load every known summary type and return the ones that were found.

    Tries "extractive_summary", "detailed_summary" and "key_facts_to_qa", in
    that order, through ``load_summary_dataset``.

    Returns:
        Dict mapping summary type name to its loaded dataset, containing only
        the types for which ``load_summary_dataset`` returned something other
        than ``None``.

    Raises:
        ValueError: If none of the summary types could be loaded.
    """
    wanted_types = ("extractive_summary", "detailed_summary", "key_facts_to_qa")

    # Attempt each type in order and keep only the ones that produced data
    attempts = ((name, load_summary_dataset(name)) for name in wanted_types)
    loaded = {name: frame for name, frame in attempts if frame is not None}

    if not loaded:
        raise ValueError("No datasets were successfully loaded!")

    return loaded


# Load tokenizer and datasets, then print the schema of each loaded dataset
try:
    tokenizer = load_tokenizer(student_model)
    summary_datasets = load_all_summary_datasets()

    # After loading each dataset, show which columns it carries
    for summary_type, dataset in summary_datasets.items():
        print(f" Columns: {list(dataset.columns)}")

        # An empty frame has no first row, so guard the [0] lookup instead of
        # letting it raise IndexError (e.g. when filtering removed every row).
        first_rows = dataset.head(1).to_dicts()
        sample_keys = list(first_rows[0].keys()) if first_rows else []
        print(f" Sample record keys: {sample_keys}")

    print(f"\n✅ Successfully loaded {len(summary_datasets)} summary datasets")
except Exception as e:
    # Report the failure, then re-raise so it is not silently swallowed
    print(f"❌ Error during initialization: {e}")
    raise

Loading tokenizer: meta-llama/Llama-3.1-8B-Instruct
Loading extractive_summary from: output_data/extractive_summary
  Loaded extractive_summary: 451 samples
Loading detailed_summary from: output_data/detailed_summary
  Loaded detailed_summary: 508 samples
Loading key_facts_summary from: output_data/key_facts_summary
  Loaded key_facts_summary: 418 samples

✅ Successfully loaded 3 summary datasets


In [None]:
def validate_cuts_for_datasets(summary_datasets, cuts):
    """Validate which cut sizes are feasible for each dataset.

    A cut is feasible for a dataset when the value returned by
    ``get_avg_summaries_per_raw_doc`` for that dataset is at least the cut.
    A cut that is infeasible for any checked dataset is removed. The
    "key_facts_to_qa" dataset is skipped: ``process_single_summary_type`` does
    not sample it by cut size, so it cannot make a cut infeasible.

    Args:
        summary_datasets: Dict mapping summary type name to its dataset.
        cuts: Iterable of integer cut sizes to validate.

    Returns:
        Sorted list of the distinct cut sizes that are feasible for every
        checked dataset.
    """
    feasible_cuts = set(cuts)

    print("🔍 Validating cut sizes against available data...")
    for summary_type, df in summary_datasets.items():
        if summary_type == "key_facts_to_qa":
            print(f"\n📊 Skipping {summary_type}:")
            continue
        print(f"\n📊 Checking {summary_type}:")

        # The average depends only on the dataset, so compute it once per
        # dataset rather than once per cut.
        avg_summaries = get_avg_summaries_per_raw_doc(df)

        for cut in cuts:
            is_feasible = avg_summaries >= cut
            status = "✅ Feasible" if is_feasible else "❌ Too large"
            print(f"  Cut {cut}: {status} (avg summaries per raw doc: {avg_summaries:.1f})")

            if not is_feasible:
                feasible_cuts.discard(cut)

    final_cuts = sorted(feasible_cuts)

    # Decide from the set difference, not from list lengths: duplicate entries
    # in `cuts` would otherwise trigger a spurious "Removing ... []" message.
    removed_cuts = set(cuts) - feasible_cuts
    if removed_cuts:
        print(f"\n⚠️  Removing infeasible cuts: {sorted(removed_cuts)}")

    print(f"\n✅ Final feasible cuts: {final_cuts}")
    return final_cuts


def process_single_summary_type(summary_type, df, cut, tokenizer, qa_per_doc):
    """Turn one summary dataset into a token-counted Q&A dataset plus stats.

    For "key_facts_to_qa" the frame is used as-is (no sampling) and the Q&A
    dataset is generated with ``keep_document_in_context=False``. Every other
    summary type is first sampled with ``sample_doc_qa`` using ``cut`` and
    ``qa_per_doc``, then generated with ``keep_document_in_context=True``.

    Args:
        summary_type: Name of the summary type being processed.
        df: Dataset for that summary type.
        cut: Passed to ``sample_doc_qa`` as ``n_docs_per_raw``.
        tokenizer: Passed to ``count_len_in_tokens``.
        qa_per_doc: Passed to ``sample_doc_qa`` as ``qa_per_doc``.

    Returns:
        ``(dataset, stats)`` where ``stats`` has the keys 'samples',
        'unique_docs', 'unique_raw_docs', 'avg_docs_per_raw' and
        'total_tokens'. Any exception is reported and yields ``(None, None)``.
    """
    try:
        print(f"  Processing {summary_type}...")

        # Key-facts data skips sampling: only its Q&A pairs are kept, and the
        # document is left out of the context.
        is_key_facts = summary_type == "key_facts_to_qa"
        if is_key_facts:
            source_df = df
        else:
            source_df = sample_doc_qa(df, n_docs_per_raw=cut, qa_per_doc=qa_per_doc)

        qa_frame = generate_knowledge_qa_dataset(
            source_df,
            keep_columns=["question", "document_outline", 'raw_document', 'document'],
            pre_training=True,
            keep_document_in_context=not is_key_facts
        )

        # Add token counts, then convert to a HuggingFace dataset
        qa_frame = count_len_in_tokens(qa_frame, tokenizer)
        hf_dataset = Dataset.from_polars(qa_frame)

        n_docs = len(set(hf_dataset['document']))
        n_raw_docs = len(set(hf_dataset['raw_document']))
        if n_raw_docs > 0:
            docs_per_raw = n_docs / n_raw_docs
        else:
            docs_per_raw = 0

        stats = dict(
            samples=len(hf_dataset),
            unique_docs=n_docs,
            unique_raw_docs=n_raw_docs,
            avg_docs_per_raw=docs_per_raw,
            total_tokens=sum(hf_dataset['token_length']),
        )

        print(f"    ✅ Processed {len(hf_dataset)} samples ({docs_per_raw:.1f} summaries per raw doc)")
        return hf_dataset, stats

    except Exception as e:
        # Best effort: report and let the caller skip this summary type
        print(f"    ❌ Error processing {summary_type}: {e}")
        return None, None


def combine_and_save_datasets(all_datasets, cut_stats, cut, output_dir):
    """Concatenate the per-type datasets for one cut and write them as JSONL.

    The combined dataset is written to
    ``<output_dir>/combined_cut_<cut>x.jsonl`` and a short report (path,
    sample count, token count, per-type statistics) is printed.

    Args:
        all_datasets: Datasets to concatenate; may be empty.
        cut_stats: Dict mapping summary type to a stats dict with 'samples'
            and 'total_tokens'.
        cut: Cut size, used in the file name and the returned tuple.
        output_dir: Directory the JSONL file is written to.

    Returns:
        ``(cut, total_tokens, total_samples)`` on success; ``None`` when there
        is nothing to combine or when any step raises.
    """
    if all_datasets:
        try:
            # Merge every summary type for this cut into one dataset
            merged = concatenate_datasets(all_datasets)
            n_tokens = sum(merged['token_length'])

            destination = os.path.join(output_dir, f"combined_cut_{cut}x.jsonl")
            merged.to_json(destination, orient="records", lines=True)

            n_samples = len(merged)
            print(f"  💾 Saved to: {destination}")
            print(f"  📈 Total samples: {n_samples}")
            print(f"  🔢 Total tokens: {n_tokens:,}")

            # Per-summary-type breakdown
            print("  📋 Summary statistics:")
            for type_name, type_stats in cut_stats.items():
                print(f"    {type_name}: {type_stats['samples']} samples, {type_stats['total_tokens']:,} tokens")

            return (cut, n_tokens, n_samples)

        except Exception as e:
            print(f"  ❌ Error combining datasets for cut {cut}: {e}")
            return None

    print(f"  ❌ No datasets processed for cut size {cut}")
    return None


def process_single_cut(cut, summary_datasets, tokenizer, output_dir, qa_per_doc):
    """Build and save the combined dataset for one cut size.

    Each summary type is run through ``process_single_summary_type``; types
    that come back as ``None`` are left out. The remaining datasets and their
    stats are handed to ``combine_and_save_datasets``.

    Returns:
        Whatever ``combine_and_save_datasets`` returns for this cut.
    """
    print(f"\n📊 Processing cut size: {cut}")

    processed = []
    stats_by_type = {}
    for type_name, frame in summary_datasets.items():
        hf_dataset, type_stats = process_single_summary_type(type_name, frame, cut, tokenizer, qa_per_doc)

        # Skip summary types whose processing failed
        if hf_dataset is None or type_stats is None:
            continue

        processed.append(hf_dataset)
        stats_by_type[type_name] = type_stats

    return combine_and_save_datasets(processed, stats_by_type, cut, output_dir)


def process_and_mix_datasets(cuts, summary_datasets, tokenizer, output_dir, qa_per_doc):
    """Validate the requested cut sizes and build one mix per feasible cut.

    The cuts are first narrowed by ``validate_cuts_for_datasets``; each
    remaining cut is then processed with ``process_single_cut``.

    Returns:
        List of the non-``None`` results of ``process_single_cut``, in the
        order the feasible cuts were processed. Empty when no cut is feasible.
    """
    usable_cuts = validate_cuts_for_datasets(summary_datasets, cuts)

    if usable_cuts:
        print(f"\nProcessing {len(usable_cuts)} feasible cut sizes...")
        outcomes = [
            process_single_cut(cut, summary_datasets, tokenizer, output_dir, qa_per_doc)
            for cut in usable_cuts
        ]
        # Cuts that failed to produce a combined dataset are left out
        return [outcome for outcome in outcomes if outcome is not None]

    print("\n❌ No feasible cuts found! Check your data or reduce cut sizes.")
    return []


def print_final_summary(token_count):
    """Print the per-cut totals as a table, or a failure note if there are none.

    Args:
        token_count: Rows passed to ``tabulate`` under the headers
            "Cut Size", "Total Tokens" and "Total Samples". A falsy value
            prints the failure note instead.
    """
    if not token_count:
        print("\n❌ No datasets were successfully processed!")
        return

    rule = "=" * 50
    print("\n" + rule)
    print("📊 FINAL SUMMARY")
    print(rule)

    table = tabulate(
        token_count,
        headers=["Cut Size", "Total Tokens", "Total Samples"],
        tablefmt="github",
        numalign="right"
    )
    print(table)


# Validate the configured cut sizes and build one training mix per feasible cut
token_count = process_and_mix_datasets(
    cuts=cuts,
    summary_datasets=summary_datasets,
    tokenizer=tokenizer,
    output_dir=output_dir,
    qa_per_doc=qa_per_doc,
)

# Show the per-cut totals as a table
print_final_summary(token_count=token_count)

🔍 Validating cut sizes against available data...

📊 Checking extractive_summary:
  Cut 3: ✅ Feasible (avg summaries per raw doc: 7.0)
  Cut 5: ✅ Feasible (avg summaries per raw doc: 7.0)
  Cut 7: ✅ Feasible (avg summaries per raw doc: 7.0)

📊 Checking detailed_summary:
  Cut 3: ✅ Feasible (avg summaries per raw doc: 7.0)
  Cut 5: ✅ Feasible (avg summaries per raw doc: 7.0)
  Cut 7: ✅ Feasible (avg summaries per raw doc: 7.0)

📊 Checking key_facts_summary:
  Cut 3: ✅ Feasible (avg summaries per raw doc: 6.0)
  Cut 5: ✅ Feasible (avg summaries per raw doc: 6.0)
  Cut 7: ❌ Too large (avg summaries per raw doc: 6.0)

⚠️  Removing infeasible cuts: [7]

✅ Final feasible cuts: [3, 5]

Processing 2 feasible cut sizes...

📊 Processing cut size: 3
  Processing extractive_summary...



A later expression might fail because the output type is not known. Set return_dtype=pl.self_dtype() if the type is unchanged, or set the proper output data type.
  knowledge_ds = generated_dataset.with_columns(base_columns)


    ✅ Processed 45 samples (3.0 summaries per raw doc)
  Processing detailed_summary...
    ✅ Processed 45 samples (3.0 summaries per raw doc)
  Processing key_facts_summary...



A later expression might fail because the output type is not known. Set return_dtype=pl.self_dtype() if the type is unchanged, or set the proper output data type.
  knowledge_ds = generated_dataset.with_columns(base_columns)

A later expression might fail because the output type is not known. Set return_dtype=pl.self_dtype() if the type is unchanged, or set the proper output data type.
  knowledge_ds = generated_dataset.with_columns(base_columns)


    ✅ Processed 45 samples (3.0 summaries per raw doc)


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 34.30ba/s]

  💾 Saved to: output_data/training_mix/combined_cut_3x.jsonl
  📈 Total samples: 135
  🔢 Total tokens: 393,042
  📋 Summary statistics:
    extractive_summary: 45 samples, 134,210 tokens
    detailed_summary: 45 samples, 107,682 tokens
    key_facts_summary: 45 samples, 151,150 tokens

📊 Processing cut size: 5
  Processing extractive_summary...




A later expression might fail because the output type is not known. Set return_dtype=pl.self_dtype() if the type is unchanged, or set the proper output data type.
  knowledge_ds = generated_dataset.with_columns(base_columns)


    ✅ Processed 75 samples (5.0 summaries per raw doc)
  Processing detailed_summary...



A later expression might fail because the output type is not known. Set return_dtype=pl.self_dtype() if the type is unchanged, or set the proper output data type.
  knowledge_ds = generated_dataset.with_columns(base_columns)


    ✅ Processed 75 samples (5.0 summaries per raw doc)
  Processing key_facts_summary...



A later expression might fail because the output type is not known. Set return_dtype=pl.self_dtype() if the type is unchanged, or set the proper output data type.
  knowledge_ds = generated_dataset.with_columns(base_columns)


    ✅ Processed 75 samples (5.0 summaries per raw doc)


Creating json from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 20.11ba/s]

  💾 Saved to: output_data/training_mix/combined_cut_5x.jsonl
  📈 Total samples: 225
  🔢 Total tokens: 653,808
  📋 Summary statistics:
    extractive_summary: 75 samples, 226,191 tokens
    detailed_summary: 75 samples, 176,539 tokens
    key_facts_summary: 75 samples, 251,078 tokens

📊 FINAL SUMMARY
|   Cut Size |   Total Tokens |   Total Samples |
|------------|----------------|-----------------|
|          3 |         393042 |             135 |
|          5 |         653808 |             225 |



