
# Experiment 3: Alpha-Beta Pruning Kane and Its Randomization, Integrated with KAN

## Step 1: Overall Design
### Objective:
To use a Kolmogorov-Arnold Network (KAN) model to analyze and interpret the equivalence between the deterministic and randomized versions of Kane, a chess AI based on the Alpha-Beta Pruning algorithm.

### Steps:
#### Alpha-Beta Pruning Implementation:

**Deterministic Alpha-Beta Pruning Implementation:**
- Implement the deterministic version of the Alpha-Beta Pruning algorithm, with a heuristic evaluation function, for Kane.

**Randomized Alpha-Beta Pruning Implementation:**
- Add seeded pseudorandom noise to the evaluation function to create a randomized version; the noise can in turn change which moves are selected.

#### Game Simulation:

**Simulate Games:**
- Simulate games using both the deterministic and randomized versions of Kane.
- Collect performance metrics after each move, including evaluation score, branching factor, search depth, move diversity, and exploration vs. exploitation.

#### Data Collection and Analysis:

**Collect and Aggregate Performance Metrics:**
- Collect evaluation consistency data.
- Collect move stability data.
- Collect search path data.
- Aggregate and analyze the results from multiple games.

**Visualization:**
- Plot comparison metrics to visualize the differences between the deterministic and randomized versions.
- Plot the equivalence curve to show the relationship between the two versions.

**Statistical Analysis:**
- Perform statistical tests (Welch's t-test and a one-way ANOVA F-test) to validate the results.
- Plot the results of the statistical analysis.

#### KAN Model Integration:

**Define and Train KAN Model:**
- Define a custom KAN model architecture using PyTorch.
- Train the KAN model on the aggregated data from both the deterministic and randomized versions.
- Evaluate the model's performance and track the equivalence score during training.

**Visualize KAN Model Results:**
- Visualize the dataset.
- Extract and visualize the symbolic formula from the trained KAN model.
- Plot the model's structure and equivalence data points.
- Visualize the weights and biases of the trained KAN model.

#### Verification and Conclusion:

**Simulate Multiple Games:**
- Run additional games to gather more data.
- Collect and aggregate performance metrics from the additional games.
- Recalculate the means and standard deviations with the combined data.
- Plot the equivalence curve again with the combined data.

## Step 2: Implement the deterministic version of the Alpha-Beta Pruning algorithm

In [None]:
import chess
import time
from collections import defaultdict
import pandas as pd
from IPython.display import clear_output, display, SVG
import chess.svg
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

# Deterministic Alpha-Beta Pruning Implementation
class KaneAlphaBetaDeterministic:
    def __init__(self, board):
        self.board = board
        self.evaluation_consistency = defaultdict(list)
        self.move_stability = defaultdict(int)
        self.search_paths = defaultdict(list)

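    def heuristic_evaluation(self, board):
        # Simple heuristic, mostly from White's point of view (positive favours White)
        pieces = board.piece_map().values()
        # Piece-count balance: every piece counts 1, regardless of type
        material_count = sum(1 if piece.color == chess.WHITE else -1 for piece in pieces)
        # Number of legal moves for the side to move (not signed by colour)
        mobility_count = board.legal_moves.count()
        # Placeholder: currently computed exactly like material_count (no piece-square tables yet)
        piece_square_score = sum(1 if piece.color == chess.WHITE else -1 for piece in pieces)
        # Number of occupied central squares (d4, e4, d5, e5), either colour
        center_control_count = sum(1 for square in board.piece_map() if square in (chess.D4, chess.E4, chess.D5, chess.E5))
        score = material_count + mobility_count + piece_square_score + center_control_count
        return score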

    def alpha_beta(self, board, depth, alpha, beta, is_maximizing_player):
        position_hash = hash(board.board_fen())
        if position_hash in self.evaluation_consistency and depth == 0:
            return self.evaluation_consistency[position_hash][-1]  # Return the last stored evaluation for consistency
        
        evaluation = self.heuristic_evaluation(board) if depth == 0 or board.is_game_over() else (
            self.alpha_beta_search(board, depth, alpha, beta, is_maximizing_player)
        )
        self.evaluation_consistency[position_hash].append(evaluation)
        return evaluation

    def alpha_beta_search(self, board, depth, alpha, beta, is_maximizing_player):
        if is_maximizing_player:
            max_eval = float('-inf')
            for move in board.legal_moves:
                board.push(move)
                eval = self.alpha_beta(board, depth - 1, alpha, beta, False)
                board.pop()
                max_eval = max(max_eval, eval)
                alpha = max(alpha, eval)
                if beta <= alpha:
                    break
            return max_eval
        else:
            min_eval = float('inf')
            for move in board.legal_moves:
                board.push(move)
                eval = self.alpha_beta(board, depth - 1, alpha, beta, True)
                board.pop()
                min_eval = min(min_eval, eval)
                beta = min(beta, eval)
                if beta <= alpha:
                    break
            return min_eval

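    def find_best_move_alpha_beta(self, depth=3):
        position_hash = hash(self.board.board_fen())
        # The heuristic is scored mostly from White's side, so White picks the
        # highest-valued move and Black the lowest; otherwise Black would
        # choose moves that favour White during self-play
        maximizing = self.board.turn == chess.WHITE
        best_move = None
        best_value = float('-inf') if maximizing else float('inf')
        for move in list(self.board.legal_moves):
            self.board.push(move)
            # After the move the opponent is to move, so the child node flips role
            move_value = self.alpha_beta(self.board, depth, float('-inf'), float('inf'), not maximizing)
            self.board.pop()
            better = move_value > best_value if maximizing else move_value < best_value
            if better:
                best_value = move_value
                best_move = move
        self.move_stability[position_hash] += 1
        return best_move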

    def track_search_path(self, board, move):
        position_hash = hash(board.board_fen())
        self.search_paths[position_hash].append(move.uci())

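You can check the pruning rule in `alpha_beta_search` on a tiny hand-built tree, away from chess. In this hypothetical sketch, internal nodes are Python lists and leaves are scores. `stats['leaves']` counts leaf evaluations. Alpha-beta should return the same value as plain minimax while evaluating fewer leaves.

```python
import math

def minimax(node, maximizing):
    # Leaves are numbers; internal nodes are lists of child nodes
    if not isinstance(node, list):
        return node
    values = [minimax(child, not maximizing) for child in node]
    return max(values) if maximizing else min(values)

def alpha_beta(node, alpha, beta, maximizing, stats):
    if not isinstance(node, list):
        stats['leaves'] += 1  # count every leaf that is actually evaluated
        return node
    if maximizing:
        value = -math.inf
        for child in node:
            value = max(value, alpha_beta(child, alpha, beta, False, stats))
            alpha = max(alpha, value)
            if beta <= alpha:
                break  # MIN above will never allow this branch
        return value
    value = math.inf
    for child in node:
        value = min(value, alpha_beta(child, alpha, beta, True, stats))
        beta = min(beta, value)
        if beta <= alpha:
            break  # MAX above already has a better option
    return value
```

The test uses the tree `[[3, 5], [6, 9], [1, 2], [0, -1]]`. On it, both functions return 6. Alpha-beta skips two of the eight leaves, because each of the last two MIN nodes is cut off after its first child.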

## Step 3: Introduce randomization into the Alpha-Beta Pruning algorithm

In [None]:
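# Pseudorandom number generator (PRNG): a linear congruential generator,
# x_{n+1} = (1103515245 * x_n + 12345) mod 2^31, scaled to [0, 1)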
class PseudoRandom:
    def __init__(self, seed):
        self.state = seed

    def random(self):
        self.state = (1103515245 * self.state + 12345) % (2**31)
        return self.state / (2**31)

# Randomized Kane: the same search as the deterministic version, with seeded
# pseudorandom noise added to the heuristic evaluation
class KaneAlphaBetaRandomization(KaneAlphaBetaDeterministic):
    def __init__(self, board, seed):
        super().__init__(board)
        self.prng = PseudoRandom(seed)

    def heuristic_evaluation(self, board):
        # Deterministic heuristic plus a random component
        score = super().heuristic_evaluation(board)
        # Uniform integer adjustment in [-5, +5]: random() lies in [0, 1),
        # so int(random() * 11) is one of 0, 1, ..., 10
        random_adjustment = int(self.prng.random() * 11) - 5
        return score + random_adjustment

    # alpha_beta, alpha_beta_search, find_best_move_alpha_beta and
    # track_search_path are inherited unchanged from KaneAlphaBetaDeterministic

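`PseudoRandom` is a plain linear congruential generator, so the randomized engine is reproducible. The same seed yields the same noise sequence, and therefore the same game, provided the engine's other state is also the same. The sketch below re-creates the generator with the same constants. It draws a uniform integer adjustment in [-5, +5] as `int(r * 11) - 5`, where `r` is in [0, 1). `PseudoRandomLCG` and `noise_sequence` are hypothetical names used only for illustration.

```python
class PseudoRandomLCG:
    # Same recurrence as PseudoRandom: x_{n+1} = (1103515245 * x_n + 12345) mod 2^31
    def __init__(self, seed):
        self.state = seed

    def random(self):
        self.state = (1103515245 * self.state + 12345) % (2 ** 31)
        return self.state / (2 ** 31)  # float in [0, 1)

def noise_sequence(seed, n):
    # Map each draw to one of the 11 integers -5, ..., +5 with equal probability
    prng = PseudoRandomLCG(seed)
    return [int(prng.random() * 11) - 5 for _ in range(n)]
```

To make runs independent, give each engine or game its own seed. Reusing one generator object across games, as the multi-game comparison later does, continues the same sequence rather than restarting it.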

## Step 4: Simulate games using both the deterministic and randomized algorithms

In [None]:
# Function to calculate additional metrics
def calculate_metrics(board):
    material_count = sum(1 if piece.color == chess.WHITE else -1 for piece in board.piece_map().values())
    mobility_count = len(list(board.legal_moves))
    piece_square_score = sum(1 if piece.color == chess.WHITE else -1 for piece in board.piece_map().values())
    center_control_count = sum(1 if square in [chess.D4, chess.E4, chess.D5, chess.E5] else 0 for square, piece in board.piece_map().items())
    return material_count, mobility_count, piece_square_score, center_control_count

def calculate_additional_metrics(board, move_scores, current_depth, is_exploratory):
    evaluation_score = sum(move_scores) / len(move_scores) if move_scores else 0
    branching_factor = len(list(board.legal_moves))
    depth_of_search = current_depth
    move_diversity = np.var(move_scores) if move_scores else 0
    exploration_vs_exploitation = 1 if is_exploratory else 0
    return evaluation_score, branching_factor, depth_of_search, move_diversity, exploration_vs_exploitation

# Function to play the game with deterministic Alpha-Beta Pruning
def play_game_alpha_beta_deterministic(kane_alpha_beta, depth=3, max_moves=55, max_runtime=600):
    steps, times, material_counts, mobility_counts, piece_square_scores, center_control_counts = [], [], [], [], [], []
    move_list, evaluation_scores, branching_factors, depths_of_search, move_diversities, exploration_vs_exploitations = [], [], [], [], [], []
    step_number = 1

    start_time = time.time()
    while not kane_alpha_beta.board.is_game_over() and step_number <= max_moves and (time.time() - start_time) <= max_runtime:
        move_start_time = time.time()
        best_move = kane_alpha_beta.find_best_move_alpha_beta(depth)
        move_end_time = time.time()

        kane_alpha_beta.board.push(best_move)
        kane_alpha_beta.track_search_path(kane_alpha_beta.board, best_move)

        move_list.append(best_move.uci())
        steps.append(step_number)
        times.append(move_end_time - move_start_time)

        material_count, mobility_count, piece_square_score, center_control_count = calculate_metrics(kane_alpha_beta.board)
        material_counts.append(material_count)
        mobility_counts.append(mobility_count)
        piece_square_scores.append(piece_square_score)
        center_control_counts.append(center_control_count)

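        # Score each legal reply: push it, search the remaining plies, then undo it.
        # White maximizes the White-relative heuristic, so the child node is a
        # maximizing node when White is to move after the reply.
        move_scores = []
        for reply in list(kane_alpha_beta.board.legal_moves):
            kane_alpha_beta.board.push(reply)
            move_scores.append(kane_alpha_beta.alpha_beta(
                kane_alpha_beta.board, max(depth - 1, 0), float('-inf'), float('inf'),
                kane_alpha_beta.board.turn == chess.WHITE))
            kane_alpha_beta.board.pop()
        # Exploration flag: neither engine has an explicit exploration mode, so this is always 0
        evaluation_score, branching_factor, depth_of_search, move_diversity, exploration_vs_exploitation = calculate_additional_metrics(
            kane_alpha_beta.board, move_scores, depth, False)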
        evaluation_scores.append(evaluation_score)
        branching_factors.append(branching_factor)
        depths_of_search.append(depth_of_search)
        move_diversities.append(move_diversity)
        exploration_vs_exploitations.append(exploration_vs_exploitation)

        clear_output(wait=True)
        display(SVG(chess.svg.board(board=kane_alpha_beta.board, size=350)))

        print(f"Move: {best_move}")
        print(f"Step: {step_number}, Time: {move_end_time - move_start_time:.3f}s, Material: {material_count}, Mobility: {mobility_count}, Piece-Square: {piece_square_score}, Center Control: {center_control_count}")

        step_number += 1
        # Pause so each position stays on screen; this counts toward max_runtime
        time.sleep(1)

    data = {
        'Step': steps,
        'Time': times,
        'Move': move_list,
        'Material Count': material_counts,
        'Mobility Count': mobility_counts,
        'Piece-Square Score': piece_square_scores,
        'Center Control Count': center_control_counts,
        'Evaluation Score': evaluation_scores,
        'Branching Factor': branching_factors,
        'Depth of Search': depths_of_search,
        'Move Diversity': move_diversities,
        'Exploration vs Exploitation': exploration_vs_exploitations
    }
    df = pd.DataFrame(data)

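    if not kane_alpha_beta.board.is_game_over():
        print("Game stopped early: move or runtime limit reached.")
    print(f"Result: {kane_alpha_beta.board.result()}")  # '*' means the game is unfinished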
    print(df)
    return df

# Initialize the boards and engines
board_deterministic = chess.Board()
kane_deterministic = KaneAlphaBetaDeterministic(board_deterministic)

# Simulate and run the deterministic game
print("Running deterministic game...")
deterministic_results = [play_game_alpha_beta_deterministic(kane_deterministic)]

# The game loop is identical for both engines; all randomness lives in
# KaneAlphaBetaRandomization.heuristic_evaluation, so the randomized game
# reuses the same function
play_game_alpha_beta_randomized = play_game_alpha_beta_deterministic

# Initialize the boards and engines
board_randomized = chess.Board()
kane_randomized = KaneAlphaBetaRandomization(board_randomized, seed=42)

# Simulate and run the randomized game
print("Running randomized game...")
randomized_results = [play_game_alpha_beta_randomized(kane_randomized)]

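In `calculate_metrics` and in the heuristic, "Material Count" counts pieces, so a queen and a pawn weigh the same. A value-weighted balance is a common alternative. The hypothetical helpers below work directly on a FEN string, so they need no chess library. They use the conventional 1/3/3/5/9 pawn values, which are an assumption rather than part of this notebook.

```python
# Conventional piece values in pawns; kings are excluded from material
PIECE_VALUES = {'p': 1, 'n': 3, 'b': 3, 'r': 5, 'q': 9, 'k': 0}

def _placement(fen):
    # The first FEN field is the piece placement, e.g. "rnbqkbnr/pppppppp/8/..."
    return fen.split()[0]

def piece_count_balance_from_fen(fen):
    # Mirrors the notebook's material_count: every piece (kings included) counts 1
    return sum(1 if ch.isupper() else -1 for ch in _placement(fen) if ch.isalpha())

def material_balance_from_fen(fen):
    # Value-weighted balance: uppercase = White (+), lowercase = Black (-)
    total = 0
    for ch in _placement(fen):
        if ch.isalpha():
            value = PIECE_VALUES[ch.lower()]
            total += value if ch.isupper() else -value
    return total
```

Consider a position where White has king and queen and Black has king and rook. The piece-count balance is 0, while the value-weighted balance is +4. With python-chess, `material_balance_from_fen(board.fen())` would give the value-weighted score for the current position.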

## Step 5: Collect and analyze performance metrics

In [None]:
# Function to ensure all results are DataFrames and handle errors
def ensure_dataframe(result):
    if isinstance(result, pd.DataFrame):
        return result
    try:
        return pd.DataFrame(result)
    except Exception as e:
        print("Error converting to DataFrame:", e)
        return pd.DataFrame()

# Convert results to DataFrames
deterministic_results = [ensure_dataframe(df) for df in deterministic_results]
randomized_results = [ensure_dataframe(df) for df in randomized_results]

# Aggregate and analyze the results
def aggregate_metrics(results):
    numeric_columns = ['Material Count', 'Mobility Count', 'Piece-Square Score', 'Center Control Count',
                       'Evaluation Score', 'Branching Factor', 'Depth of Search', 'Move Diversity',
                       'Exploration vs Exploitation']
    
    aggregated_data = pd.concat(results, ignore_index=True)
    
    mean_metrics = aggregated_data[numeric_columns].mean()
    std_metrics = aggregated_data[numeric_columns].std()
    
    return mean_metrics, std_metrics, aggregated_data

# Aggregate the results
deterministic_mean, deterministic_std, deterministic_data = aggregate_metrics(deterministic_results)
randomized_mean, randomized_std, randomized_data = aggregate_metrics(randomized_results)

# Display the aggregated metrics
print("Deterministic Mean Metrics:\n", deterministic_mean)
print("Deterministic Std Metrics:\n", deterministic_std)
print("\nRandomized Mean Metrics:\n", randomized_mean)
print("Randomized Std Metrics:\n", randomized_std)

# Display move sequences and non-numeric data
print("\nDeterministic Moves:\n", deterministic_data['Move'])
print("\nRandomized Moves:\n", randomized_data['Move'])

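Rows from the same game are strongly correlated, since each position follows from the previous one. Pooling all moves treats them as independent samples. One option is to reduce each game to its mean before computing the summary statistics. `per_game_summary` below is a hypothetical helper sketching that with `pd.concat(..., keys=...)`, which tags every row with its game index.

```python
import pandas as pd

def per_game_summary(results, metrics):
    # Tag each per-move DataFrame with its game index (outer index level 'Game')
    tagged = pd.concat(results, keys=list(range(len(results))), names=['Game', 'Row'])
    # One row per game: the mean of each metric within that game
    per_game = tagged[list(metrics)].groupby(level='Game').mean()
    # Mean and sample standard deviation across games
    return per_game.mean(), per_game.std(), per_game
```

With the single game per version collected above, the across-game standard deviation is undefined (NaN). This summary only becomes meaningful once several games are available, as in the multi-game comparison later on.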

## Step 6: Plot comparison metrics and equivalence curves

In [None]:
# Plot comparison metrics
def plot_comparison_metrics(deterministic_mean, deterministic_std, randomized_mean, randomized_std):
    metrics = deterministic_mean.index
    x = np.arange(len(metrics))
    width = 0.4

    fig, axs = plt.subplots(2, 1, figsize=(14, 10))

    # Plot means side by side: shift each series by half a bar width so the bars do not overlap
    axs[0].bar(x - width / 2, deterministic_mean, width=width, label='Deterministic Mean')
    axs[0].bar(x + width / 2, randomized_mean, width=width, label='Randomized Mean')
    axs[0].set_xticks(x)
    axs[0].set_xticklabels(metrics, rotation=45, ha='right')
    axs[0].set_ylabel('Mean Value')
    axs[0].set_title('Comparison of Mean Metrics')
    axs[0].legend()

    # Plot standard deviations side by side
    axs[1].bar(x - width / 2, deterministic_std, width=width, label='Deterministic Std')
    axs[1].bar(x + width / 2, randomized_std, width=width, label='Randomized Std')
    axs[1].set_xticks(x)
    axs[1].set_xticklabels(metrics, rotation=45, ha='right')
    axs[1].set_ylabel('Standard Deviation')
    axs[1].set_title('Comparison of Standard Deviation Metrics')
    axs[1].legend()

    plt.tight_layout()
    plt.show()

# Plot equivalence curve
def plot_equivalence_curve(deterministic_mean, deterministic_std, randomized_mean, randomized_std):
    metrics = deterministic_mean.index
    x = np.arange(len(metrics))

    # Create figure and axis
    fig, ax = plt.subplots(figsize=(14, 7))

    # Plot deterministic means with error bars for std
    ax.errorbar(x, deterministic_mean, yerr=deterministic_std, fmt='o-', label='Deterministic', color='blue', capsize=5)

    # Plot randomized means with error bars for std
    ax.errorbar(x, randomized_mean, yerr=randomized_std, fmt='o-', label='Randomized', color='green', capsize=5)

    # Fill between for deterministic std
    ax.fill_between(x, deterministic_mean - deterministic_std, deterministic_mean + deterministic_std, color='blue', alpha=0.2)

    # Fill between for randomized std
    ax.fill_between(x, randomized_mean - randomized_std, randomized_mean + randomized_std, color='green', alpha=0.2)

    # Add title and labels
    ax.set_title('Equivalence Curve for Deterministic and Randomized Alpha-Beta Algorithms')
    ax.set_xlabel('Metrics')
    ax.set_ylabel('Values')
    ax.set_xticks(x)
    ax.set_xticklabels(metrics, rotation=45, ha='right')

    # Add legend
    ax.legend()

    # Show plot
    plt.tight_layout()
    plt.show()

# Plot the comparison metrics
plot_comparison_metrics(deterministic_mean, deterministic_std, randomized_mean, randomized_std)

# Plot the equivalence curve
plot_equivalence_curve(deterministic_mean, deterministic_std, randomized_mean, randomized_std)

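The metrics here live on very different scales: mobility is typically in the tens, while the exploration flag is 0 or 1. That makes raw bars on a shared axis hard to compare. A common scale-free alternative, not part of the original notebook, is the standardized mean difference (Cohen's d) per metric. It divides the difference in means by the pooled sample standard deviation. `standardized_mean_difference` is a hypothetical helper.

```python
import numpy as np
import pandas as pd

def standardized_mean_difference(det, rnd, metrics):
    """Cohen's d per metric: (mean_det - mean_rnd) / pooled sample std.
    Metrics that are constant in both versions give NaN (0 / 0)."""
    metrics = list(metrics)
    n1, n2 = len(det), len(rnd)
    m1, m2 = det[metrics].mean(), rnd[metrics].mean()
    v1, v2 = det[metrics].var(), rnd[metrics].var()  # sample variances (ddof=1)
    pooled_sd = np.sqrt(((n1 - 1) * v1 + (n2 - 1) * v2) / (n1 + n2 - 2))
    return (m1 - m2) / pooled_sd
```

A possible use is `standardized_mean_difference(deterministic_data, randomized_data, deterministic_mean.index).plot.bar()`, which puts every metric on the same unitless axis.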

## Step 7: Perform statistical tests

In [None]:
from scipy.stats import ttest_ind, f_oneway

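# Function to perform statistical tests on the per-move data of both versions
def perform_statistical_tests(deterministic_data, randomized_data, metrics):
    results = {}
    for metric in metrics:
        # Welch's t-test: compares means without assuming equal variances
        t_stat, p_value_t = ttest_ind(deterministic_data[metric], randomized_data[metric], equal_var=False)
        # One-way ANOVA F-test; with two groups it also compares means (F = t^2 for the pooled-variance t-test)
        f_stat, p_value_f = f_oneway(deterministic_data[metric], randomized_data[metric])
        results[metric] = {
            't_stat': t_stat,
            'p_value_t': p_value_t,
            'f_stat': f_stat,
            'p_value_f': p_value_f
        }
    return results

# Perform statistical tests; metrics that are constant in both versions
# (e.g. 'Depth of Search') are expected to give NaN statistics
statistical_results = perform_statistical_tests(deterministic_data, randomized_data, deterministic_mean.index)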

# Display the results
for metric, result in statistical_results.items():
    print(f"{metric}: t-statistic = {result['t_stat']}, p-value (t-test) = {result['p_value_t']}")
    print(f"{metric}: f-statistic = {result['f_stat']}, p-value (F-test) = {result['p_value_f']}\n")

# Plot statistical analysis results
def plot_statistical_analysis(statistical_results):
    metrics = list(statistical_results.keys())
    t_stats = [result['t_stat'] for result in statistical_results.values()]
    p_values_t = [result['p_value_t'] for result in statistical_results.values()]
    x = np.arange(len(metrics))

    fig, axs = plt.subplots(2, 1, figsize=(14, 10))

    # Plot t-statistics (bars at numeric positions, labelled with metric names)
    axs[0].bar(x, t_stats, color='blue')
    axs[0].set_xticks(x)
    axs[0].set_xticklabels(metrics, rotation=45, ha='right')
    axs[0].set_ylabel('t-statistic')
    axs[0].set_title('t-statistic of Each Metric')

    # Plot p-values (t-test) with the 0.05 significance threshold
    axs[1].bar(x, p_values_t, color='green')
    axs[1].set_xticks(x)
    axs[1].set_xticklabels(metrics, rotation=45, ha='right')
    axs[1].axhline(y=0.05, color='r', linestyle='--')
    axs[1].set_ylabel('p-value (t-test)')
    axs[1].set_title('p-value (t-test) of Each Metric')

    plt.tight_layout()
    plt.show()

# Plot statistical analysis results
plot_statistical_analysis(statistical_results)

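A non-significant t-test only means that no difference was detected. It is not evidence that the two versions are equivalent. Equivalence is usually tested with two one-sided tests (TOST) against a margin fixed in advance. The sketch below is a hypothetical helper, `tost_welch`, built on `scipy.stats.ttest_ind` and its `alternative` argument (SciPy 1.6 or later). The margin is an assumption that must be chosen per metric.

```python
import numpy as np
from scipy.stats import ttest_ind

def tost_welch(a, b, margin):
    """Two one-sided Welch t-tests of |mean(a) - mean(b)| < margin.
    Returns the TOST p-value: the larger of the two one-sided p-values."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    # H0: mean(a) - mean(b) <= -margin   vs   H1: difference > -margin
    _, p_lower = ttest_ind(a + margin, b, equal_var=False, alternative='greater')
    # H0: mean(a) - mean(b) >= margin    vs   H1: difference < margin
    _, p_upper = ttest_ind(a - margin, b, equal_var=False, alternative='less')
    return max(p_lower, p_upper)
```

A TOST p-value below 0.05 supports the claim that the two means differ by less than the margin. A large value means equivalence within that margin was not shown.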

## Step 8: Run additional games to gather more data and plot the equivalence curve again

In [None]:
# Run additional games to gather more data
def compare_alpha_beta_kane_versions(deterministic_kane, randomized_kane, depth=3, games=5, max_moves=55, max_runtime=600):
    deterministic_results = []
    randomized_results = []

    for _ in range(games):
        # Play game with deterministic Kane
        deterministic_kane.board.reset()
        deterministic_data = play_game_alpha_beta_deterministic(deterministic_kane, depth, max_moves, max_runtime)
        deterministic_results.append(deterministic_data)

        # Play game with randomized Kane
        randomized_kane.board.reset()
        randomized_data = play_game_alpha_beta_randomized(randomized_kane, depth, max_moves, max_runtime)
        randomized_results.append(randomized_data)

    return deterministic_results, randomized_results

# Initialize the boards and engines
board_deterministic = chess.Board()
kane_deterministic = KaneAlphaBetaDeterministic(board_deterministic)

board_randomized = chess.Board()
kane_randomized = KaneAlphaBetaRandomization(board_randomized, seed=42)

# Compare the two versions over multiple games
deterministic_results, randomized_results = compare_alpha_beta_kane_versions(kane_deterministic, kane_randomized)

# Aggregate the results
deterministic_mean, deterministic_std, deterministic_data = aggregate_metrics(deterministic_results)
randomized_mean, randomized_std, randomized_data = aggregate_metrics(randomized_results)

# Run additional games to gather more data
additional_games = 10
deterministic_results_additional, randomized_results_additional = compare_alpha_beta_kane_versions(kane_deterministic, kane_randomized, games=additional_games)

# Aggregate the additional data
deterministic_mean_additional, deterministic_std_additional, deterministic_data_additional = aggregate_metrics(deterministic_results_additional)
randomized_mean_additional, randomized_std_additional, randomized_data_additional = aggregate_metrics(randomized_results_additional)

# Combine the original and additional data and recalculate the means and standard
# deviations over the numeric metric columns only (the 'Move' column holds text,
# so calling .mean() on the whole DataFrame raises a TypeError in pandas >= 2.0)
combined_deterministic_mean, combined_deterministic_std, combined_deterministic_data = aggregate_metrics(
    [deterministic_data, deterministic_data_additional])
combined_randomized_mean, combined_randomized_std, combined_randomized_data = aggregate_metrics(
    [randomized_data, randomized_data_additional])

# Display the combined metrics
print("Combined Deterministic Mean Metrics:\n", combined_deterministic_mean)
print("Combined Deterministic Std Metrics:\n", combined_deterministic_std)
print("\nCombined Randomized Mean Metrics:\n", combined_randomized_mean)
print("Combined Randomized Std Metrics:\n", combined_randomized_std)

# Plot the equivalence curve again with combined data
plot_equivalence_curve(combined_deterministic_mean, combined_deterministic_std, combined_randomized_mean, combined_randomized_std)

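The cell above recomputes the statistics by concatenating raw rows. When only summaries are kept (count, mean, and sample variance per batch of games), they can still be merged exactly with the pairwise update for means and variances. `combine_mean_var` is a hypothetical helper sketching that formula.

```python
import numpy as np

def combine_mean_var(n_a, mean_a, var_a, n_b, mean_b, var_b):
    """Merge two sample summaries (sample variances, ddof=1) into one."""
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    # Sum of squared deviations of each part, plus a correction for the gap between means
    m2 = var_a * (n_a - 1) + var_b * (n_b - 1) + delta ** 2 * n_a * n_b / n
    return n, mean, m2 / (n - 1)
```

The result matches `np.mean` and `np.var(..., ddof=1)` on the concatenated data, so batches of games can be summarized separately and merged later.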

## Step 9: KAN integration for Alpha-Beta Pruning Kane
### 9.1 Define a custom KAN model architecture using PyTorch

In [None]:
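# Note: despite its name, KANModel is a standard multilayer perceptron
# (learnable linear weights, fixed ReLU activations). A Kolmogorov-Arnold
# Network instead learns a univariate function (e.g. a B-spline) on every
# edge and sums those functions at each node.
class KANModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()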
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.relu(out)
        out = self.fc3(out)
        return out

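The `KANModel` above is built from `nn.Linear` layers and ReLU, which makes it a multilayer perceptron. For contrast, the sketch below shows the structure a Kolmogorov-Arnold layer uses. Every edge from input i to output j carries its own univariate function φ_ij, and each output is the sum y_j = Σ_i φ_ij(x_i). `KANLayerSketch` is a hypothetical NumPy forward pass only, with no training. It uses Gaussian radial basis functions on a fixed grid as a stand-in for the B-splines of the original KAN formulation.

```python
import numpy as np

class KANLayerSketch:
    """Hypothetical KAN-style layer: y_j = sum_i phi_ij(x_i), where each
    phi_ij is a learnable combination of Gaussian bumps on a fixed grid."""

    def __init__(self, in_dim, out_dim, num_basis=8, x_min=-2.0, x_max=2.0, seed=0):
        rng = np.random.default_rng(seed)
        self.centers = np.linspace(x_min, x_max, num_basis)   # grid points, shape (K,)
        self.width = (x_max - x_min) / (num_basis - 1)         # bump width = grid spacing
        # One coefficient vector per edge (i -> j): shape (in_dim, out_dim, K)
        self.coef = rng.normal(scale=0.1, size=(in_dim, out_dim, num_basis))

    def basis(self, x):
        # x: (N, in_dim) -> (N, in_dim, K): every bump evaluated at every input
        return np.exp(-((x[..., None] - self.centers) / self.width) ** 2)

    def edge_functions(self, x):
        # phi_ij(x_i) for every sample n, input i and output j: (N, in_dim, out_dim)
        return np.einsum('nik,iok->nio', self.basis(x), self.coef)

    def __call__(self, x):
        # Sum the edge outputs over inputs i: (N, out_dim)
        return self.edge_functions(x).sum(axis=1)
```

Because each output is a sum of one-input functions, changing a single input feature changes the output only through that feature's edges. This additive structure is what makes KAN edges individually inspectable.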

### 9.2 Train the KAN model on the collected data

In [None]:
# Prepare the data for KAN model training
def prepare_data(deterministic_data, randomized_data):
    features = ['Material Count', 'Mobility Count', 'Piece-Square Score', 'Center Control Count']
    # Rows are paired by position in the aggregated tables; games can end at
    # different move counts, so truncate both to the shorter length to keep X and y aligned
    n = min(len(deterministic_data), len(randomized_data))
    X = deterministic_data[features].values[:n]
    y = randomized_data[features].values[:n]
    return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

X, y = prepare_data(combined_deterministic_data, combined_randomized_data)

# Initialize the KAN model
input_dim = X.shape[1]
hidden_dim = 128
output_dim = y.shape[1]
kan_model = KANModel(input_dim, hidden_dim, output_dim)

# Define the loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(kan_model.parameters(), lr=0.001)

# Train the KAN model
num_epochs = 500
for epoch in range(num_epochs):
    kan_model.train()
    optimizer.zero_grad()
    outputs = kan_model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    if (epoch+1) % 50 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

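The equivalence score in 9.3 is measured on the same rows the model was trained on, and it has no reference point. `baseline_scores` is a hypothetical NumPy helper that reports held-out MSE for two simple predictors. The first is the identity map, which predicts that the randomized metrics equal the deterministic ones (perfect equivalence). The second is a least-squares linear map with a bias term. A trained model is only informative if it beats these on rows it has not seen.

```python
import numpy as np

def held_out_split(n, test_fraction=0.2, seed=0):
    # Random permutation split into (train_idx, test_idx)
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    n_test = max(1, int(round(n * test_fraction)))
    return perm[n_test:], perm[:n_test]

def mse(pred, target):
    return float(np.mean((np.asarray(pred) - np.asarray(target)) ** 2))

def baseline_scores(X, y, test_fraction=0.2, seed=0):
    """Held-out MSE of the identity predictor (y_hat = X) and of an
    ordinary least-squares linear map (with bias) fitted on the train split."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    train, test = held_out_split(len(X), test_fraction, seed)
    identity = mse(X[test], y[test])
    # Append a bias column and solve the least-squares problem on the train split
    Xb = np.hstack([X, np.ones((len(X), 1))])
    W, *_ = np.linalg.lstsq(Xb[train], y[train], rcond=None)
    linear = mse(Xb[test] @ W, y[test])
    return identity, linear
```

With the tensors from 9.2, a possible call is `baseline_scores(X.numpy(), y.numpy())`. The same split indices could be used to hold rows out of the KAN model's training loop.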

### 9.3 Evaluate the model's performance and compute the equivalence score

In [None]:
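# Evaluate the KAN model. The equivalence score is the MSE between the predicted
# and actual randomized metrics on the training rows: lower means the deterministic
# metrics map more closely onto the randomized ones.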
kan_model.eval()
with torch.no_grad():
    predictions = kan_model(X)
    equivalence_score = criterion(predictions, y).item()

print(f'Equivalence Score: {equivalence_score:.4f}')

# Visualization of dataset
plt.figure(figsize=(10, 6))
plt.scatter(X[:, 0], y[:, 0], label='Material Count')
plt.scatter(X[:, 1], y[:, 1], label='Mobility Count')
plt.scatter(X[:, 2], y[:, 2], label='Piece-Square Score')
plt.scatter(X[:, 3], y[:, 3], label='Center Control Count')
plt.legend()
plt.xlabel('Deterministic Data')
plt.ylabel('Randomized Data')
plt.title('Deterministic vs Randomized Metrics')
plt.show()

# Plot the equivalence curve to show the relationship between the deterministic and randomized versions
plt.figure(figsize=(10, 6))
plt.plot(range(len(X)), predictions[:, 0], label='Material Count Prediction', linestyle='--')
plt.plot(range(len(X)), y[:, 0], label='Material Count Actual')
plt.plot(range(len(X)), predictions[:, 1], label='Mobility Count Prediction', linestyle='--')
plt.plot(range(len(X)), y[:, 1], label='Mobility Count Actual')
plt.plot(range(len(X)), predictions[:, 2], label='Piece-Square Score Prediction', linestyle='--')
plt.plot(range(len(X)), y[:, 2], label='Piece-Square Score Actual')
plt.plot(range(len(X)), predictions[:, 3], label='Center Control Count Prediction', linestyle='--')
plt.plot(range(len(X)), y[:, 3], label='Center Control Count Actual')
plt.legend()
plt.xlabel('Data Points')
plt.ylabel('Metrics')
plt.title('Equivalence Curve for Deterministic and Randomized Alpha-Beta Metrics')
plt.show()

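# Plot the model's structure (requires the torchviz package and the Graphviz binaries);
# render() writes the graph to kan_model.png
from torchviz import make_dot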

kan_model_visual = make_dot(kan_model(X), params=dict(kan_model.named_parameters()))
kan_model_visual.format = 'png'
kan_model_visual.render('kan_model')

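The plan in Step 1 also calls for extracting a symbolic formula. In a KAN this is done edge by edge: sample a learned univariate function φ and fit it against a library of candidate expressions. `best_symbolic_fit` is a hypothetical, simplified sketch. For each candidate f it fits φ(x) ≈ a·f(x) + b by least squares and keeps the best R². A fuller version, as in the pykan library, also fits an inner affine transform, a·f(b·x + c) + d. The approach applies to an actual KAN edge, not to the linear-plus-ReLU `KANModel` defined in 9.1.

```python
import numpy as np

# Hypothetical candidate library; real symbolic libraries are much larger
CANDIDATES = {
    'x': lambda x: x,
    'x^2': lambda x: x ** 2,
    'x^3': lambda x: x ** 3,
    'sin': np.sin,
    'exp': np.exp,
}

def best_symbolic_fit(x, phi):
    """Fit phi(x) ~ a * f(x) + b for each candidate f by least squares and
    return (name, a, b, r2) for the candidate with the highest R^2."""
    x = np.asarray(x, dtype=float)
    phi = np.asarray(phi, dtype=float)
    total = np.sum((phi - phi.mean()) ** 2)
    best = None
    for name, f in CANDIDATES.items():
        A = np.column_stack([f(x), np.ones_like(x)])
        (a, b), *_ = np.linalg.lstsq(A, phi, rcond=None)
        residual = np.sum((A @ np.array([a, b]) - phi) ** 2)
        r2 = 1.0 - residual / total if total > 0 else 1.0
        if best is None or r2 > best[3]:
            best = (name, float(a), float(b), float(r2))
    return best
```

For a sampled edge such as φ(x) = 3x² + 1, the fit returns `'x^2'` with a ≈ 3 and b ≈ 1. Those fitted edge formulas are what a symbolic expression for the whole network would be assembled from.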
