In [2]:
import numpy as np
import pickle
import os
import sys
from mindreadingautobots.sequence_generators import data_io


In [3]:
def generate_multitask_sparse_parity(n_tasks, n_bits, k, n_data, p_bitflip=0.0, seed=None):
    """
    Generate a multitask sparse parity dataset.
    
    Args:
        n_tasks: Number of subtasks (distinct versions of sparse parity)
        n_bits: Total length of task bits
        k: Size of the fixed subset for each parity calculation
        n_data: Number of data points to generate
        p_bitflip: Probability of flipping bits in the task bits (not control bits)
        seed: Random seed for reproducibility
        
    Returns:
        X: Array of shape (n_data, n_tasks + n_bits + 1) containing noiseless data:
           - First n_tasks bits are control bits (one-hot encoding of task)
           - Next n_bits are task bits
           - Last bit is the output (parity of relevant task bits)
        Z: Array of same shape as X but with noise in the task bits (if p_bitflip > 0)
        task_subsets: List of k indices for each task indicating which bits to use for parity
    """
    if seed is not None:
        np.random.seed(seed)
    
    # Initialize the dataset
    total_bits = n_tasks + n_bits + 1  # control bits + task bits + output bit
    X = np.zeros((n_data, total_bits), dtype=np.int32)
    
    # Generate random task subsets (each task uses a different subset of k indices)
    task_subsets = []
    for i in range(n_tasks):
        # Generate a random subset of k indices from the task bits
        subset = np.sort(np.random.choice(n_bits, k, replace=False))
        task_subsets.append(subset)
    
    # Generate data for each example
    for i in range(n_data):
        # Randomly select a task (which control bit to activate)
        active_task = np.random.randint(0, n_tasks)
        
        # Set the control bit (one-hot encoding)
        X[i, active_task] = 1
        
        # Generate random task bits
        task_bits = np.random.randint(0, 2, n_bits)
        X[i, n_tasks:n_tasks+n_bits] = task_bits
        
        # Compute the parity of the subset corresponding to the active task
        relevant_subset = task_subsets[active_task]
        relevant_bits = task_bits[relevant_subset]
        parity = np.sum(relevant_bits) % 2
        
        # Set the output bit
        X[i, -1] = parity
    
    # Apply noise to task bits if specified
    if p_bitflip > 0:
        # Create a copy of X
        Z = np.copy(X)
        
        # Generate noise mask for task bits only
        flips = np.random.binomial(1, p_bitflip, size=(n_data, n_bits))
        
        # Apply noise to task bits only
        Z[:, n_tasks:n_tasks+n_bits] = np.logical_xor(
            X[:, n_tasks:n_tasks+n_bits], 
            flips
        ).astype(np.int32)
        
        # Recompute output bit based on noisy task bits
        for i in range(n_data):
            active_task = np.argmax(Z[i, :n_tasks])
            relevant_subset = task_subsets[active_task]
            relevant_bits = Z[i, n_tasks:n_tasks+n_bits][relevant_subset]
            Z[i, -1] = np.sum(relevant_bits) % 2

        print(X, Z, task_subsets)            
        return X, Z, task_subsets

    else:
        return X, X, task_subsets


In [4]:
def verify_examples(X, task_subsets, n_tasks, num_examples=5):
    """Verify correctness of several examples in the dataset"""
    for i in range(min(num_examples, len(X))):
        example = X[i]
        
        # Determine which task is active
        active_task = np.argmax(example[:n_tasks])
        
        # Get the task bits
        task_bits = example[n_tasks:-1]
        
        # Get the subset for the active task
        relevant_subset = task_subsets[active_task]
        relevant_bits = task_bits[relevant_subset]
        
        # Compute expected parity
        expected_parity = np.sum(relevant_bits) % 2
        actual_parity = example[-1]
        
        # Display the example
        print(f"\nExample {i+1}:")
        print(f"Full string: {''.join([str(x) for x in example[:-1]])}")
        print(f"Control bits: {''.join([str(x) for x in example[:n_tasks]])}")
        print(f"Task bits: {''.join([str(x) for x in task_bits])}")
        print(f"Active task: {active_task+1}")
        print(f"Relevant subset indices: {relevant_subset}")
        print(f"Relevant bits: {relevant_bits}")
        print(f"Expected answer: {expected_parity}")
        print(f"Actual answer: {actual_parity}")
        print(f"Correct: {expected_parity == actual_parity}")

In [20]:
# sys.path.append("../src")

n_tasks = 2  # Number of tasks
task_bits_length = 8  # Length of the task bits portion
k = 3  # Size of subset for parity calculation
n_train = 10  # Number of training examples
n_val = 4  # Number of validation examples
p_bitflips = [0.0, 0.1, 0.2, 0.3]  # Array of bit flip probabilities
seed = 42  # Random seed

for p_bitflip in p_bitflips:
    # Generate the dataset
    print(f"\nGenerating multitask_sparse_parity with {n_tasks} tasks, {task_bits_length} task bits, k={k}, noise={p_bitflip}")
    X, Z, task_subsets = generate_multitask_sparse_parity(
        n_tasks=n_tasks,
        n_bits=task_bits_length,
        k=k,
        n_data=n_train + n_val,
        p_bitflip=p_bitflip,
        seed=seed
    )

    # Print the task subsets (which bits are used for each task)
    for i, subset in enumerate(task_subsets):
        print(f"Task {i+1} uses bits {subset} for parity calculation")

    # Split into train and validation sets
    X_train = X[:n_train]
    X_val = X[n_train:]
    Z_train = Z[:n_train]
    Z_val = Z[n_train:]

    # Verify some examples
    print("\nVerifying noiseless training examples:")
    verify_examples(X_train, task_subsets, n_tasks, 2)

    if p_bitflip > 0:
        print("\nVerifying noisy training examples:")
        verify_examples(Z_train, task_subsets, n_tasks, 2)

    # Compare original and noisy examples
    if p_bitflip > 0:
        print("\nComparing original and noisy examples:")
        for i in range(2):  # Show 2 examples
            print(f"Example {i+1}:")
            print(f"Original task bits: {X_train[i, n_tasks:-1]}")
            print(f"Noisy task bits: {Z_train[i, n_tasks:-1]}")
            print(f"Bits flipped: {np.sum(X_train[i, n_tasks:-1] != Z_train[i, n_tasks:-1])}")
            print(f"Original answer: {X_train[i, -1]}")
            print(f"Noisy answer: {Z_train[i, -1]}\n")

    # Create directory for this bit flip rate
    bf_str = '0' if p_bitflip == 0 else str(int(p_bitflip * 100))
    output_dir = f'./multitask_sparse_parity_ntasks{task_bits_length}_ncontrol{n_tasks}_k{k}_ndata{n_train}_bf{bf_str}_seed{seed}'
    os.makedirs(output_dir, exist_ok=True)

    # Save the datasets
    data_io.save_numpy_as_dict(X_train, f'{output_dir}/noiseless_train.pkl')
    data_io.save_numpy_as_dict(X_val, f'{output_dir}/noiseless_val.pkl')
    if p_bitflip > 0:
        data_io.save_numpy_as_dict(Z_train, f'{output_dir}/train.pkl')
        data_io.save_numpy_as_dict(Z_val, f'{output_dir}/val.pkl')

    # Print saved data
    print("\nPrinting saved data:")
    for filename in ['noiseless_train.pkl', 'noiseless_val.pkl', 'train.pkl', 'val.pkl']:
        filepath = f'{output_dir}/{filename}'
        if os.path.exists(filepath):
            with open(filepath, 'rb') as f:
                data = pickle.load(f)
                print(f"\nContents of {filename}:")
                print(f"Number of examples: {len(data['line'])}")
                print("First example:")
                print(f"Line: {data['line'][0]}")
                print(f"Label: {data['label'][0]}")





Generating multitask_sparse_parity with 2 tasks, 8 task bits, k=3, noise=0.0
Task 1 uses bits [0 1 5] for parity calculation
Task 2 uses bits [0 3 7] for parity calculation

Verifying noiseless training examples:

Example 1:
Full string: 0110101111
Control bits: 01
Task bits: 10101111
Active task: 2
Relevant subset indices: [0 3 7]
Relevant bits: [1 0 1]
Expected answer: 0
Actual answer: 0
Correct: True

Example 2:
Full string: 0111100111
Control bits: 01
Task bits: 11100111
Active task: 2
Relevant subset indices: [0 3 7]
Relevant bits: [1 0 1]
Expected answer: 0
Actual answer: 0
Correct: True

Printing saved data:

Contents of noiseless_train.pkl:
Number of examples: 10
First example:
Line: 0110101111
Label: 0

Contents of noiseless_val.pkl:
Number of examples: 4
First example:
Line: 0110101101
Label: 0

Generating multitask_sparse_parity with 2 tasks, 8 task bits, k=3, noise=0.1
[[0 1 1 0 1 0 1 1 1 1 0]
 [0 1 1 1 1 0 0 1 1 1 0]
 [1 0 1 0 0 0 0 0 1 1 1]
 [0 1 1 1 0 1 1 0 1 0 0]
 [0 1