In [2]:
import os
import json

def collect_chains(base_path, label, chain_source):
    """
    Collect all chains from a given directory and assign a label and source.
    :param base_path: Path to the directory containing JSON files.
    :param label: Valid or invalid chain type (e.g., "valid", "invalid_type1").
    :param chain_source: Chain source (e.g., "latest" or "influential").
    :return: A list of dictionaries containing chain metadata and content.
    """
    chains = []
    for root, _, files in os.walk(base_path):
        for file in files:
            if file.endswith(".json"):
                file_path = os.path.join(root, file)
                with open(file_path, "r") as f:
                    content = json.load(f)
                chains.append({
                    "file_name": file,
                    "chain_label": label,
                    "chain_source": chain_source,
                    "file_path": file_path,
                    "content": content
                })
    return chains

# Define all paths with updated chain sources
data_paths = [
    {"path": "../ground_truth_path/result_chains", "label": "valid", "source": "latest"},
    {"path": "../ground_truth_path/invalid_chains_type1", "label": "invalid_type1", "source": "latest"},
    {"path": "../ground_truth_path/invalid_chains_type2", "label": "invalid_type2", "source": "latest"},
    {"path": "../ground_truth_path_fred/result_chains", "label": "valid", "source": "influential"},
    {"path": "../ground_truth_path_fred/invalid_chains_type1", "label": "invalid_type1", "source": "influential"},
    {"path": "../ground_truth_path_fred/invalid_chains_type2", "label": "invalid_type2", "source": "influential"},
]

# Collect all chains
all_chains = []
for data_path in data_paths:
    chains = collect_chains(data_path["path"], data_path["label"], data_path["source"])
    all_chains.extend(chains)

# Save the collected data as a JSON object
output_path = "./collected_chains_summary.json"
with open(output_path, "w") as f:
    json.dump(all_chains, f, indent=4)

print(f"Chains summary saved to {output_path}")


Chains summary saved to ./collected_chains_summary.json


In [1]:
import json
from collections import Counter, defaultdict

# Load the collected chains summary
input_path = "./collected_chains_summary.json"
with open(input_path, "r") as f:
    data = json.load(f)

# Initialize counters and containers for statistics
total_chains = len(data)
chain_labels = Counter()
chain_sources = Counter()
review_id_counts = Counter()
content_lengths = []
chains_ending_2023_2024 = Counter()  # Tracks chains ending in 2023/2024 by chain_label

# Extract statistics
for chain in data:
    chain_labels[chain["chain_label"]] += 1
    chain_sources[chain["chain_source"]] += 1
    
    # Extract review_id from file_name (e.g., "CD000017")
    review_id = chain["file_name"].split("_")[2]
    review_id_counts[review_id] += 1
    
    # Calculate content length
    content_length = len(chain["content"])
    content_lengths.append(content_length)
    
    # Check if chain ends in 2023 or 2024
    last_paper = chain["content"][-1] if chain["content"] else None
    if last_paper and isinstance(last_paper, dict):
        year = last_paper.get("year")  # Adjust key based on your data
        if year in {2023, 2024}:
            chains_ending_2023_2024[chain["chain_label"]] += 1

# Compute additional stats
unique_review_ids = len(review_id_counts)
avg_content_length = sum(content_lengths) / total_chains if total_chains > 0 else 0
total_chains_ending_2023_2024 = sum(chains_ending_2023_2024.values())

# Display the statistics
print("### Collected Data Summary ###")
print(f"Total Chains: {total_chains}")
print(f"Chains Ending in 2023 or 2024 by Label: {dict(chains_ending_2023_2024)}")
print(f"Total Chains Ending in 2023 or 2024: {total_chains_ending_2023_2024}")
print(f"Chain Labels Distribution: {dict(chain_labels)}")
print(f"Chain Sources Distribution: {dict(chain_sources)}")
print(f"Unique Review IDs: {unique_review_ids}")
print(f"Average Content Length: {avg_content_length:.2f}")
print("\nTop 5 Review IDs by Chain Count:")
for review_id, count in review_id_counts.most_common(5):
    print(f"  {review_id}: {count} chains")


### Collected Data Summary ###
Total Chains: 2018
Chains Ending in 2023 or 2024 by Label: {'valid': 201, 'invalid_type1': 814, 'invalid_type2': 269}
Total Chains Ending in 2023 or 2024: 1284
Chain Labels Distribution: {'valid': 379, 'invalid_type1': 1184, 'invalid_type2': 455}
Chain Sources Distribution: {'latest': 927, 'influential': 1091}
Unique Review IDs: 198
Average Content Length: 11.29

Top 5 Review IDs by Chain Count:
  CD001461: 27 chains
  CD002800: 25 chains
  CD000951: 21 chains
  CD006000: 21 chains
  CD002048: 20 chains


In [31]:
import json
import random
from collections import Counter, defaultdict

def extract_review_id(file_name):
    """
    Extract the review_id from the file name.
    Assumes the file name has the format 'temporal_chain_<REVIEW_ID>_...'
    """
    return file_name.split("_")[2]  # Adjust index based on your naming convention

def filter_short_chains(data):
    """
    Filter out chains with a length of 1.
    :param data: List of chains (e.g., JSON list).
    :return: Filtered list of chains.
    """
    return [chain for chain in data if len(chain["content"]) > 1]

def stratified_split(data, train_ratio=0.7, valid_ratio=0.15, test_ratio=0.15, seed=4321):
    """
    Stratified and balanced split of chains into train, validation, and test sets,
    ensuring that chains with the same review_id do not appear in multiple splits.
    :param data: List of data items (e.g., JSON list).
    :param train_ratio: Proportion of data to use for training.
    :param valid_ratio: Proportion of data to use for validation.
    :param test_ratio: Proportion of data to use for testing.
    :param seed: Random seed for reproducibility.
    :return: train, validation, and test splits.
    """
    # Group data by review_id
    review_groups = defaultdict(list)
    for chain in data:
        review_id = extract_review_id(chain["file_name"])
        review_groups[review_id].append(chain)
    
    # Further group by chain_label
    label_groups = defaultdict(list)
    for review_id, chains in review_groups.items():
        chain_label = chains[0]["chain_label"]  # Assume all chains under a review_id have the same label
        label_groups[chain_label].append((review_id, chains))
    
    # Shuffle each label group
    random.seed(seed)
    for label in label_groups:
        random.shuffle(label_groups[label])
    
    # Create splits
    train, valid, test = [], [], []
    for label, review_data in label_groups.items():
        total = len(review_data)
        train_size = int(total * train_ratio)
        valid_size = int(total * valid_ratio)
        
        train_groups = review_data[:train_size]
        valid_groups = review_data[train_size:train_size + valid_size]
        test_groups = review_data[train_size + valid_size:]
        
        # Flatten the groups back into lists
        train.extend([chain for _, chains in train_groups for chain in chains])
        valid.extend([chain for _, chains in valid_groups for chain in chains])
        test.extend([chain for _, chains in test_groups for chain in chains])
    
    return train, valid, test

def calculate_statistics(split_data, split_name):
    """
    Calculate and print statistics for a given split, including chains ending in 2023 or 2024
    and chain length stats.
    :param split_data: The data of the split.
    :param split_name: Name of the split (e.g., "train").
    """
    total_chains = len(split_data)
    chain_labels = Counter(chain["chain_label"] for chain in split_data)
    chain_sources = Counter(chain["chain_source"] for chain in split_data)
    review_ids = {extract_review_id(chain["file_name"]) for chain in split_data}
    unique_review_ids = len(review_ids)
    
    # Chains ending in 2023 or 2024
    chains_ending_2023_2024 = sum(
        1 for chain in split_data
        if chain["content"] and isinstance(chain["content"][-1], dict) and chain["content"][-1].get("year") in {2023, 2024}
    )
    
    # Chain length stats
    chain_lengths = [len(chain["content"]) for chain in split_data if chain["content"]]
    avg_chain_length = sum(chain_lengths) / len(chain_lengths) if chain_lengths else 0
    
    # Chain length by type
    chain_length_by_label = defaultdict(list)
    for chain in split_data:
        if chain["content"]:
            chain_length_by_label[chain["chain_label"]].append(len(chain["content"]))
    
    print(f"### Statistics for {split_name} ###")
    print(f"Total Chains: {total_chains}")
    print(f"Chain Labels Distribution: {dict(chain_labels)}")
    print(f"Chain Sources Distribution: {dict(chain_sources)}")
    print(f"Unique Review IDs: {unique_review_ids}")
    print(f"Chains Ending in 2023 or 2024: {chains_ending_2023_2024}")
    print(f"Average Chain Length: {avg_chain_length:.2f}")
    print("Chain Length by Type:")
    for label, lengths in chain_length_by_label.items():
        print(f"  {label}: Min={min(lengths)}, Max={max(lengths)}, Avg={sum(lengths)/len(lengths):.2f}")
    print("")

# Load the collected chains summary
input_path = "./collected_chains_summary.json"
with open(input_path, "r") as f:
    data = json.load(f)
data = filter_short_chains(data)

# Perform stratified split
train_data, valid_data, test_data = stratified_split(data)

# Calculate and display statistics
calculate_statistics(train_data, "Train")
calculate_statistics(valid_data, "Validation")
calculate_statistics(test_data, "Test")


### Statistics for Train ###
Total Chains: 1364
Chain Labels Distribution: {'valid': 259, 'invalid_type1': 793, 'invalid_type2': 312}
Chain Sources Distribution: {'latest': 634, 'influential': 730}
Unique Review IDs: 138
Chains Ending in 2023 or 2024: 850
Average Chain Length: 11.23
Chain Length by Type:
  valid: Min=2, Max=27, Avg=8.95
  invalid_type1: Min=3, Max=27, Avg=12.46
  invalid_type2: Min=2, Max=25, Avg=9.98

### Statistics for Validation ###
Total Chains: 326
Chain Labels Distribution: {'valid': 52, 'invalid_type1': 203, 'invalid_type2': 71}
Chain Sources Distribution: {'latest': 133, 'influential': 193}
Unique Review IDs: 29
Chains Ending in 2023 or 2024: 240
Average Chain Length: 12.03
Chain Length by Type:
  valid: Min=2, Max=20, Avg=10.63
  invalid_type1: Min=3, Max=20, Avg=12.87
  invalid_type2: Min=4, Max=28, Avg=10.65

### Statistics for Test ###
Total Chains: 320
Chain Labels Distribution: {'valid': 60, 'invalid_type1': 188, 'invalid_type2': 72}
Chain Sources Distrib

In [32]:

def calculate_statistics(data, dataset_name="Dataset"):
    """
    Calculate and display statistics for a dataset.
    :param data: List of chains.
    :param dataset_name: Name of the dataset (e.g., "Updated Dataset").
    """
    total_chains = len(data)
    chain_labels = Counter(chain["chain_label"] for chain in data)
    chain_sources = Counter(chain["chain_source"] for chain in data)
    chain_lengths = [len(chain["content"]) for chain in data]
    avg_chain_length = sum(chain_lengths) / len(chain_lengths) if chain_lengths else 0

    # Chains Ending in 2023 or 2024
    chains_ending_2023_2024 = sum(
        1 for chain in data
        if chain["content"] and isinstance(chain["content"][-1], dict) and chain["content"][-1].get("year") in {2023, 2024}
    )

    # Chain Length by Type
    chain_length_by_label = defaultdict(list)
    for chain in data:
        chain_length_by_label[chain["chain_label"]].append(len(chain["content"]))

    print(f"### Statistics for {dataset_name} ###")
    print(f"Total Chains: {total_chains}")
    print(f"Chain Labels Distribution: {dict(chain_labels)}")
    print(f"Chain Sources Distribution: {dict(chain_sources)}")
    print(f"Chains Ending in 2023 or 2024: {chains_ending_2023_2024}")
    print(f"Average Chain Length: {avg_chain_length:.2f}")
    print("Chain Length by Type:")
    for label, lengths in chain_length_by_label.items():
        print(f"  {label}: Min={min(lengths)}, Max={max(lengths)}, Avg={sum(lengths)/len(lengths):.2f}")
    print("")

def balance_valid_and_invalid_chains(data, target_ratio=1.0, window_size=5, step_size=1, min_chain_length=3):
    """
    Dynamically split valid chains and balance them with invalid chains until the desired ratio is achieved.
    Marks split chains with a flag for distinction.
    :param data: List of chains.
    :param target_ratio: Desired ratio of valid to negative chains.
    :param window_size: Maximum length of each subchain.
    :param step_size: Step size for overlapping subchains.
    :param min_chain_length: Minimum chain length to include in splitting.
    :return: Updated dataset with balanced valid and negative chains.
    """
    updated_data = []
    valid_chains = []
    negative_count = sum(1 for chain in data if chain["chain_label"] in {"invalid_type1", "invalid_type2"})
    valid_count = 0

    for chain in data:
        if chain["chain_label"] == "valid" and len(chain["content"]) >= min_chain_length:
            # Mark original valid chain
            chain["generated_from_split"] = False
            valid_chains.append(chain)
            valid_count += 1

            # Generate overlapping subchains
            content = chain["content"]
            for start_idx in range(0, len(content) - window_size + 1, step_size):
                subchain_content = content[start_idx : start_idx + window_size]
                subchain = {
                    "file_name": f"{chain['file_name']}_split_{start_idx}",
                    "chain_label": chain["chain_label"],
                    "chain_source": chain["chain_source"],
                    "file_path": chain["file_path"],  # Keep original path for reference
                    "content": subchain_content,
                    "generated_from_split": True,  # Mark as split
                }
                valid_chains.append(subchain)
                valid_count += 1

                # Stop splitting if valid chains exceed the target ratio
                if valid_count >= target_ratio * negative_count:
                    break
        else:
            # Keep the chain as is
            chain["generated_from_split"] = False  # Negative chains are not split
            updated_data.append(chain)

    # Combine valid and negative chains
    updated_data.extend(valid_chains)
    return updated_data


# Load the original dataset
input_path = "./collected_chains_summary.json"
with open(input_path, "r") as f:
    data = json.load(f)
data = filter_short_chains(data)

# Adjusted function for balancing chains
data_with_balanced_valids = balance_valid_and_invalid_chains(
    data, target_ratio=1.0, window_size=5, step_size=1, min_chain_length=5
)

# Calculate and display statistics for the updated dataset
calculate_statistics(data_with_balanced_valids, "Balanced Dataset with 50-50 Valid and Invalid Chains")



### Statistics for Balanced Dataset with 50-50 Valid and Invalid Chains ###
Total Chains: 3523
Chain Labels Distribution: {'valid': 1884, 'invalid_type1': 1184, 'invalid_type2': 455}
Chain Sources Distribution: {'latest': 1772, 'influential': 1751}
Chains Ending in 2023 or 2024: 1482
Average Chain Length: 8.61
Chain Length by Type:
  valid: Min=2, Max=27, Avg=5.83
  invalid_type1: Min=3, Max=27, Avg=12.52
  invalid_type2: Min=2, Max=28, Avg=9.97



In [33]:
# Save the updated dataset
output_path = "./balanced_chains_50_50.json"
with open(output_path, "w") as f:
    json.dump(data_with_balanced_valids, f, indent=4)

print(f"Balanced dataset saved to {output_path}.")

Balanced dataset saved to ./balanced_chains_50_50.json.


In [34]:
import json

# Load the balanced dataset
input_path = "./balanced_chains_50_50.json"
with open(input_path, "r") as f:
    data = json.load(f)

# Filter out short chains (length <= 1)
data = filter_short_chains(data)

# Perform stratified split
train_data, valid_data, test_data = stratified_split(data, train_ratio=0.7, valid_ratio=0.15, test_ratio=0.15)

# Calculate and display statistics for each split
calculate_statistics(train_data, "Train")
calculate_statistics(valid_data, "Validation")
calculate_statistics(test_data, "Test")

### Statistics for Train ###
Total Chains: 2490
Chain Labels Distribution: {'valid': 1331, 'invalid_type1': 834, 'invalid_type2': 325}
Chain Sources Distribution: {'latest': 1199, 'influential': 1291}
Chains Ending in 2023 or 2024: 1047
Average Chain Length: 8.65
Chain Length by Type:
  valid: Min=2, Max=27, Avg=5.84
  invalid_type1: Min=3, Max=27, Avg=12.62
  invalid_type2: Min=2, Max=28, Avg=9.98

### Statistics for Validation ###
Total Chains: 523
Chain Labels Distribution: {'valid': 296, 'invalid_type1': 165, 'invalid_type2': 62}
Chain Sources Distribution: {'latest': 276, 'influential': 247}
Chains Ending in 2023 or 2024: 215
Average Chain Length: 8.07
Chain Length by Type:
  valid: Min=2, Max=21, Avg=5.71
  invalid_type1: Min=3, Max=21, Avg=11.94
  invalid_type2: Min=2, Max=21, Avg=9.02

### Statistics for Test ###
Total Chains: 510
Chain Labels Distribution: {'valid': 257, 'invalid_type1': 185, 'invalid_type2': 68}
Chain Sources Distribution: {'latest': 297, 'influential': 213}


In [35]:
# Save the splits
output_dir = "./balanced_splits/"
import os
os.makedirs(output_dir, exist_ok=True)

with open(f"{output_dir}/train.json", "w") as f:
    json.dump(train_data, f, indent=4)

with open(f"{output_dir}/valid.json", "w") as f:
    json.dump(valid_data, f, indent=4)

with open(f"{output_dir}/test.json", "w") as f:
    json.dump(test_data, f, indent=4)

print(f"Balanced splits saved to {output_dir}.")

Balanced splits saved to ./balanced_splits/.


#### alpaca format

In [1]:
from datasets import Dataset, DatasetDict

def format_reasoning_chain(content):
    """
    Formats the reasoning chain as a numbered list with paper title and abstract.
    """
    formatted_chain = []
    for i, paper in enumerate(content, start=1):
        title = paper.get("title", "No Title Provided")
        abstract = paper.get("abstract", "No Abstract Provided")
        formatted_chain.append(f"{i}. PAPER TITLE: {title}\n   ABSTRACT: {abstract}")
    return "\n".join(formatted_chain)


def convert_to_alpaca_format(entry):
    """
    Converts a dataset entry into Alpaca format with 'instruction', 'input', and 'output'.
    """
    label = "invalid" if entry["chain_label"] in {"invalid_type1", "invalid_type2"} else entry["chain_label"]
    formatted_chain = format_reasoning_chain(entry["content"])  # Format reasoning chain

    return {
        "instruction": f"Hypotheses are frequently the starting point when undertaking the empirical portion of the scientific process. They guide the types of data collected, analyses conducted, and inferences drawn. You are a scientist. Your task is to evaluate the relevance of a reasoning chain derived from a series of related papers. Each paper in the chain builds on the previous one, starting from a source paper and ending with a final paper. A reasoning chain is considered 'valid' if the logical connections between all the papers in the chain are coherent, and the final hypothesis logically inspired from or depends on the source paper, either directly or through intermediate papers. A reasoning chain is 'invalid' if any logical break exists in the chain, such as unrelated intermediate papers, or  if the final hypothesis does not logically depend on the previous papers. Your task is to classify the given reasoning chain as either 'valid' or 'invalid'. Your evaluation should be based on the logical relevance and progression between the papers.  Answer should be only 'valid' or 'invalid'.",
        "input": f"Reasoning Chain:\n{formatted_chain}",
        "output": label,  # Expected answer ('valid' or 'invalid')
    }
    # sometimes you will be given the target hypothesis - add this to the prompt
    # 1. extract target hypothesis
    # 2. 50 percent of data with conditioned on hypothesis and fine tune the model
    # 3. 


In [2]:
import json

def remove_2024_papers(data):
    """
    Removes papers published in 2024 from the reasoning chains in the dataset.
    :param data: List of chains (each chain is a dictionary with a 'content' field).
    :return: Updated dataset with 2024 papers removed from the chains.
    """
    updated_data = []
    for chain in data:
        # Filter out entries in the content field with year == 2024
        chain["content"] = [paper for paper in chain["content"] if paper.get("year") != 2024]
        
        # Add the chain to the updated dataset only if it still contains papers
        if chain["content"]:
            updated_data.append(chain)
    return updated_data


# Load split datasets
with open("./balanced_splits/train.json", "r") as f:
    train_data = json.load(f)

with open("./balanced_splits/valid.json", "r") as f:
    valid_data = json.load(f)

with open("./balanced_splits/test.json", "r") as f:
    test_data = json.load(f)

# Remove 2024 papers from each split
train_data = remove_2024_papers(train_data)
valid_data = remove_2024_papers(valid_data)
test_data = remove_2024_papers(test_data)

def count_long_chains(data, length_threshold=25):
    """
    Counts the number of chains with a length above the specified threshold.
    :param data: List of chains (each chain is a dictionary with a 'content' field).
    :param length_threshold: Minimum length to qualify as a "long chain."
    :return: Count of chains with a length above the threshold.
    """
    return sum(1 for chain in data if len(chain["content"]) > length_threshold)

# Count long chains in each dataset
train_long_chains = count_long_chains(train_data)
valid_long_chains = count_long_chains(valid_data)
test_long_chains = count_long_chains(test_data)

print(f"Number of chains with length above 25:")
print(f"Train: {train_long_chains}")
print(f"Valid: {valid_long_chains}")
print(f"Test: {test_long_chains}")

def remove_long_chains(data, length_threshold=25):
    """
    Removes chains with a length above the specified threshold from the dataset.
    :param data: List of chains (each chain is a dictionary with a 'content' field).
    :param length_threshold: Maximum allowed length for a chain to be included.
    :return: Updated dataset with long chains removed.
    """
    return [chain for chain in data if len(chain["content"]) <= length_threshold]

# Remove long chains from the train dataset
#train_data = remove_long_chains(train_data)

# Convert split datasets into Hugging Face datasets
train_dataset = Dataset.from_list(train_data)
valid_dataset = Dataset.from_list(valid_data)
test_dataset = Dataset.from_list(test_data)


Number of chains with length above 25:
Train: 14
Valid: 0
Test: 0


In [3]:
# Apply the conversion to Alpaca format
alpaca_train_dataset = train_dataset.map(
    lambda x: convert_to_alpaca_format(x),
    remove_columns=train_dataset.column_names
)

alpaca_valid_dataset = valid_dataset.map(
    lambda x: convert_to_alpaca_format(x),
    remove_columns=valid_dataset.column_names
)

alpaca_test_dataset = test_dataset.map(
    lambda x: convert_to_alpaca_format(x),
    remove_columns=test_dataset.column_names
)

# Combine into a DatasetDict
alpaca_dataset = DatasetDict({
    "train": alpaca_train_dataset,
    "valid": alpaca_valid_dataset,
    "test": alpaca_test_dataset,
})



Map:   0%|          | 0/2490 [00:00<?, ? examples/s]

Map:   0%|          | 0/523 [00:00<?, ? examples/s]

Map:   0%|          | 0/510 [00:00<?, ? examples/s]

In [4]:
output_dir = "./alpaca_formatted_splits/"
import os
os.makedirs(output_dir, exist_ok=True)

alpaca_dataset["train"].to_json(f"{output_dir}/alpaca_train.jsonl")
alpaca_dataset["valid"].to_json(f"{output_dir}/alpaca_valid.jsonl")
alpaca_dataset["test"].to_json(f"{output_dir}/alpaca_test.jsonl")

print(f"Alpaca datasets saved to {output_dir}")


Creating json from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Alpaca datasets saved to ./alpaca_formatted_splits/


### alapca input to mistralLite format

In [1]:
from datasets import load_dataset
import os
import json

# Input and output directories
input_dir = "./alpaca_multi_task3/"
output_dir = "./mistral_format_multi_task3/"
os.makedirs(output_dir, exist_ok=True)

def convert_to_mistral_format(entry):
    """
    Converts an Alpaca-formatted entry to the MistralLite format.
    """
    instruction = entry["instruction"]
    input_text = entry["input"]
    output_text = entry["output"]

    # Create the MistralLite prompt format
    if input_text.strip():
        prompt = f"<|prompter|>{instruction}\n{input_text}</s><|assistant|>{output_text}"
    else:
        prompt = f"<|prompter|>{instruction}</s><|assistant|>{output_text}"
    
    return prompt

# Process and save datasets
splits = ["train", "valid", "test"]
for split in splits:
    # Load the dataset
    alpaca_dataset = load_dataset("json", data_files={split: f"{input_dir}/alpaca_{split}.jsonl"})
    
    mistral_data = []
    for entry in alpaca_dataset[split]:
        mistral_entry = convert_to_mistral_format(entry)
        mistral_data.append(mistral_entry)

    # Save to new directory in MistralLite format
    mistral_output_file = os.path.join(output_dir, f"mistral_{split}.jsonl")
    with open(mistral_output_file, "w") as f:
        for mistral_entry in mistral_data:
            json_line = json.dumps({"text": mistral_entry})  # Wrap each line in a JSON object
            f.write(json_line + "\n")  # Write each JSON object as a new line


    print(f"MistralLite {split} dataset saved to {mistral_output_file}")


Generating train split: 0 examples [00:00, ? examples/s]

MistralLite train dataset saved to ./mistral_format_multi_task3/mistral_train.jsonl


Generating valid split: 0 examples [00:00, ? examples/s]

MistralLite valid dataset saved to ./mistral_format_multi_task3/mistral_valid.jsonl


Generating test split: 0 examples [00:00, ? examples/s]

MistralLite test dataset saved to ./mistral_format_multi_task3/mistral_test.jsonl
