In [None]:
import itertools
import random
import numpy as np
import pandas as pd
import pickle
from collections import Counter
from itertools import combinations, product
from pyeda.inter import *
import torch
import networkx as nx
from torch_geometric.data import Data
import re

In [None]:
# Seed every RNG this notebook imports so that "Restart Kernel -> Run All" is
# reproducible: circuit generation draws from `random`, the placeholder feature
# vectors draw from `np.random`, and `torch` is seeded for any torch-based code.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
class BiasControlledCircuitGenerator:
    """Generate random Boolean circuits and filter them by bias.

    The *bias* of a circuit is the fraction of the ``2 ** num_inputs`` input
    assignments for which it evaluates to true (see ``calculate_bias``).
    Each ``generate_*`` method builds one batch of random candidates with
    three construction strategies and keeps only the candidates whose bias
    falls inside the band of its class:

    * class 0, ultra-sparse: ``bias < 0.01``
    * class 1, sparse: ``0.01 <= bias <= 0.06`` (default ``target_range``)
    * class 2, dense: ``bias >= 0.06`` (default ``target_bias``)

    All random choices come from the global ``random`` module, so seeding it
    makes the generated batches reproducible.
    """

    def __init__(self, num_inputs=8):
        """Create ``num_inputs`` expression variables named ``x0 .. x{n-1}``."""
        self.num_inputs = num_inputs
        self.variables = [exprvar(f'x{i}') for i in range(num_inputs)]

    def _sample_vars(self, size):
        """Pick ``size`` distinct variables at random, capped at the number available.

        The cap keeps the hard-coded sizes used by the construction strategies
        from raising ``ValueError`` when the generator has only a few inputs.
        """
        return random.sample(self.variables, min(size, len(self.variables)))

    def calculate_bias(self, expr):
        """Calculate the bias (fraction of inputs that evaluate to True).

        Every one of the ``2 ** num_inputs`` assignments is enumerated and
        applied with ``expr.restrict``; the truthiness of the restricted
        result decides whether the assignment counts.

        Returns:
            float in ``[0, 1]``.
        """
        total_count = 2 ** self.num_inputs
        true_count = 0

        for assignment in product([0, 1], repeat=self.num_inputs):
            point = dict(zip(self.variables, assignment))
            if expr.restrict(point):
                true_count += 1

        return true_count / total_count

    def generate_ultra_sparse_circuit(self, target_bias=0.005):
        """Generate ultra-sparse circuits (Class 0: bias < 0.01).

        Builds 20 candidates with each of three strategies and keeps those
        whose bias is below 0.01.

        Args:
            target_bias: currently unused; the acceptance threshold is the
                fixed value 0.01. Kept for interface compatibility.

        Returns:
            list of ``(circuit, bias)`` tuples (possibly empty).
        """
        max_bias = 0.01

        # Method 1: one wide AND over `depth` distinct variables
        def create_deep_and_chain(depth):
            return And(*self._sample_vars(depth))

        # Method 2: AND over at least 7 variables (or 85% of the inputs if larger)
        def create_high_threshold_majority():
            threshold = max(7, int(0.85 * self.num_inputs))
            return And(*self._sample_vars(threshold))

        # Method 3: conjunctive normal form with 3-5 clauses of 2-4 literals
        def create_restrictive_cnf():
            clauses = []
            for _ in range(random.randint(3, 5)):
                clause_vars = self._sample_vars(random.randint(2, 4))
                # Each literal is positive with probability 0.7, negated otherwise
                literals = [var if random.random() < 0.7 else ~var
                            for var in clause_vars]
                clauses.append(Or(*literals))
            return And(*clauses)

        # Lower end of the AND width; clamped so randint never sees an empty range
        min_depth = min(7, self.num_inputs)
        methods = [
            lambda: create_deep_and_chain(random.randint(min_depth, self.num_inputs)),
            create_high_threshold_majority,
            create_restrictive_cnf,
        ]

        circuits = []
        for method in methods:
            for _ in range(20):  # Generate multiple candidates
                circuit = method()
                bias = self.calculate_bias(circuit)
                if bias < max_bias:
                    circuits.append((circuit, bias))

        return circuits

    def generate_sparse_circuit(self, target_range=(0.01, 0.06)):
        """Generate sparse circuits (Class 1: bias 0.01-0.06).

        Builds 30 candidates with each of three strategies and keeps those
        whose bias lies inside ``target_range``.

        Args:
            target_range: ``(low, high)`` bias bounds, inclusive on both ends.
                NOTE: the dense generator accepts ``bias >= 0.06`` by default,
                so a bias of exactly 0.06 would pass both filters.

        Returns:
            list of ``(circuit, bias)`` tuples (possibly empty).
        """
        low, high = target_range[0], target_range[1]

        # Method 1: XOR of 3-5 variables, narrowed by an AND of 2-4 variables
        def create_xor_and_hybrid():
            xor_part = Xor(*self._sample_vars(random.randint(3, 5)))
            and_part = And(*self._sample_vars(random.randint(2, 4)))
            return And(xor_part, and_part)

        # Method 2: AND over 4-6 variables
        def create_moderate_threshold():
            return And(*self._sample_vars(random.randint(4, 6)))

        # Method 3: AND of 2-3 two-variable OR groups over disjoint variables
        def create_balanced_mixed():
            or_groups = []
            remaining_vars = self.variables.copy()

            for _ in range(random.randint(2, 3)):
                if len(remaining_vars) < 2:
                    continue
                group_vars = random.sample(remaining_vars, 2)
                or_groups.append(Or(*group_vars))
                # Used variables are removed so the groups never share a variable
                for var in group_vars:
                    remaining_vars.remove(var)

            if or_groups:
                return And(*or_groups)
            # Fewer than two variables available: fall back to a plain AND
            return And(*self._sample_vars(3))

        methods = [create_xor_and_hybrid, create_moderate_threshold, create_balanced_mixed]

        circuits = []
        for method in methods:
            for _ in range(30):  # More candidates for this class
                circuit = method()
                bias = self.calculate_bias(circuit)
                if low <= bias <= high:
                    circuits.append((circuit, bias))

        return circuits

    def generate_dense_circuit(self, target_bias=0.06):
        """Generate dense circuits (Class 2: bias >= 0.06).

        Builds 25 candidates with each of three strategies and keeps those
        whose bias is at least ``target_bias``.

        Args:
            target_bias: minimum bias (inclusive) a candidate must reach.

        Returns:
            list of ``(circuit, bias)`` tuples (possibly empty).
        """
        # Method 1: OR of 2-4 branches, each an AND of 2-3 variables
        def create_or_heavy():
            or_branches = []
            for _ in range(random.randint(2, 4)):
                branch_vars = self._sample_vars(random.randint(2, 3))
                or_branches.append(And(*branch_vars))
            return Or(*or_branches)

        # Method 2: OR of 3-8 randomly chosen AND terms of equal size (2-4)
        def create_low_threshold():
            threshold = random.randint(2, 4)
            var_combinations = list(combinations(self.variables, threshold))
            num_terms = min(len(var_combinations), random.randint(3, 8))
            terms = [And(*combo)
                     for combo in random.sample(var_combinations, num_terms)]
            return Or(*terms)

        # Method 3: disjunctive normal form with 4-8 terms of 2-4 literals
        def create_rich_dnf():
            terms = []
            for _ in range(random.randint(4, 8)):
                term_vars = self._sample_vars(random.randint(2, 4))
                # Each literal is positive with probability 0.6, negated otherwise
                literals = [var if random.random() < 0.6 else ~var
                            for var in term_vars]
                terms.append(And(*literals))
            return Or(*terms)

        methods = [create_or_heavy, create_low_threshold, create_rich_dnf]

        circuits = []
        for method in methods:
            for _ in range(25):
                circuit = method()
                bias = self.calculate_bias(circuit)
                if bias >= target_bias:
                    circuits.append((circuit, bias))

        return circuits


In [None]:
def _collect_class_samples(generate_batch, class_id, label, num_circuits,
                           max_attempts, prefer_high_bias=False):
    """Collect candidates for one class and turn the selected ones into records.

    ``generate_batch`` is called repeatedly (at most ``max_attempts`` times)
    until at least ``num_circuits`` ``(circuit, bias)`` candidates have been
    accumulated. The candidates are then sorted by bias (descending when
    ``prefer_high_bias`` is true) and the first ``num_circuits`` are kept.
    """
    print(f"Generating {label} circuits...")
    candidates = []
    attempts = 0

    while len(candidates) < num_circuits and attempts < max_attempts:
        candidates.extend(generate_batch())
        attempts += 1

        if attempts % 1000 == 0:
            print(f"{label.capitalize()} progress: {len(candidates)}/{num_circuits}")

    # The attempt cap can end the loop early; say so instead of silently
    # returning an unbalanced class.
    if len(candidates) < num_circuits:
        print(f"Warning: only {len(candidates)}/{num_circuits} {label} circuits "
              f"found after {attempts} attempts")

    candidates.sort(key=lambda item: item[1], reverse=prefer_high_bias)

    records = []
    for circuit, bias in candidates[:num_circuits]:
        records.append({
            'circuit': circuit,
            'bias': bias,
            'class': class_id,
            'features': extract_circuit_features(circuit)
        })
    return records


def generate_balanced_dataset(num_circuits_per_class=1000, num_inputs=8, max_attempts=10000):
    """Generate a three-class circuit dataset with up to N samples per class.

    Args:
        num_circuits_per_class: number of samples wanted for each class.
        num_inputs: number of input variables of every generated circuit.
        max_attempts: maximum number of generator batches requested per class;
            if the cap is reached first, the class ends up with fewer samples
            and a warning is printed.

    Returns:
        list of dicts with keys ``'circuit'``, ``'bias'``, ``'class'`` and
        ``'features'``, ordered class 0, then class 1, then class 2.

    Selection is not uniform over each band: classes 0 and 1 keep the
    lowest-bias candidates, class 2 keeps the highest-bias candidates.
    NOTE: candidates are not de-duplicated, so the same circuit can appear
    more than once.
    """
    generator = BiasControlledCircuitGenerator(num_inputs)

    # (batch generator, class id, label used in progress output, sort descending?)
    class_specs = [
        (generator.generate_ultra_sparse_circuit, 0, "ultra-sparse", False),  # bias < 0.01
        (generator.generate_sparse_circuit, 1, "sparse", False),              # 0.01 <= bias <= 0.06
        (generator.generate_dense_circuit, 2, "dense", True),                 # bias >= 0.06
    ]

    dataset = []
    for generate_batch, class_id, label, prefer_high_bias in class_specs:
        dataset.extend(_collect_class_samples(
            generate_batch, class_id, label,
            num_circuits_per_class, max_attempts, prefer_high_bias))

    return dataset

def extract_circuit_features(circuit):
    """Return a 13-element feature vector for ``circuit`` (placeholder).

    The ``circuit`` argument is currently ignored: the vector is filled with
    uniform random numbers from NumPy's global RNG, so it carries no
    information about the circuit. A real implementation might describe
    depth, width, gate counts, connectivity patterns, etc.
    """
    # TODO: replace the random placeholder with actual feature extraction
    feature_dim = 13
    return np.random.random_sample(feature_dim)

def validate_class_separation(dataset):
    """Validate that circuits fall within correct bias ranges.

    Each sample's ``'bias'`` is checked against the band of its ``'class'``
    label (0: ``bias < 0.01``, 1: ``0.01 <= bias < 0.06``, 2: ``bias >= 0.06``)
    and a per-class and overall summary is printed.

    Args:
        dataset: iterable of dicts with at least the keys ``'bias'`` and
            ``'class'``; class labels must be 0, 1 or 2 (any other label
            raises ``KeyError``).

    Returns:
        True when at least 95% of all samples lie in their class's band.
        False for an empty dataset, since there is nothing to validate.
    """
    in_band = {
        0: lambda bias: bias < 0.01,
        1: lambda bias: 0.01 <= bias < 0.06,
        2: lambda bias: bias >= 0.06,
    }
    class_counts = {0: 0, 1: 0, 2: 0}
    correct_counts = {0: 0, 1: 0, 2: 0}

    for sample in dataset:
        bias = sample['bias']
        true_class = sample['class']
        class_counts[true_class] += 1
        if in_band[true_class](bias):
            correct_counts[true_class] += 1

    print("Class Separation Validation:")
    for class_id in [0, 1, 2]:
        total = class_counts[class_id]
        if total == 0:
            # A class with no samples has no meaningful percentage
            print(f"Class {class_id}: 0/0 (n/a)")
            continue
        accuracy = correct_counts[class_id] / total * 100
        print(f"Class {class_id}: {correct_counts[class_id]}/{total} ({accuracy:.1f}%)")

    total_samples = sum(class_counts.values())
    if total_samples == 0:
        print("Overall: n/a")
        return False

    overall_accuracy = sum(correct_counts.values()) / total_samples * 100
    print(f"Overall: {overall_accuracy:.1f}%")

    return overall_accuracy >= 95.0

In [None]:
%%time
# Generate the 8-input dataset: up to 1000 circuits for each of the three bias classes
dataset_8_inputs = generate_balanced_dataset(num_circuits_per_class=1000, num_inputs=8)

In [None]:
# Generate the smaller 10-input dataset: up to 250 circuits per bias class
dataset_10_inputs = generate_balanced_dataset(num_circuits_per_class=250, num_inputs=10)

In [None]:
# Concatenate both sample lists into one dataset (10-input samples first, then 8-input)
dataset = dataset_10_inputs + dataset_8_inputs