In [13]:
import pandas as pd
import numpy as np
import itertools
from scipy.optimize import minimize
import matplotlib.pyplot as plt

In [14]:
# define partition rules
def partition(condition):
    if condition == 1:
        # 单超平面分割，4种情况
        return [(i,) for i in range(4)]
    else:
        # 两个超平面分割，6种情况
        two_planes = list(itertools.combinations(range(4), 2))
        
        # 三个超平面分割，24种情况
        three_planes = []
        for m in range(4):
            other_dims = [i for i in range(4) if i != m]
            for n1, n2 in itertools.combinations(other_dims, 2):
                three_planes.append((m, n1, n2))
                three_planes.append((m, n2, n1))
        
        return two_planes + three_planes

# generate centers
def generate_centers():
    centers = {}
    
    # 条件1的中心点
    for rule in partition(1):
        dim = rule[0]
        centers[rule] = ([0.25 if i == dim else 0.5 for i in range(4)],
                         [0.75 if i == dim else 0.5 for i in range(4)])
    
    # 条件2的中心点
    for rule in partition(2):
        if len(rule) == 2:  # 两个超平面
            dim1, dim2 = rule
            centers[rule] = tuple([0.25 + 0.5*i if d in rule else 0.5 for d in range(4)]
                                  for i in range(2) for j in range(2))
        else:  # 三个超平面
            m, n1, n2 = rule
            centers[rule] = (
                [0.25 if d == m else (0.25 if d == n1 else 0.5) for d in range(4)],
                [0.25 if d == m else (0.75 if d == n1 else 0.5) for d in range(4)],
                [0.75 if d == m else (0.25 if d == n2 else 0.5) for d in range(4)],
                [0.75 if d == m else (0.75 if d == n2 else 0.5) for d in range(4)]
            )
    
    return centers

# generate all possible centers
all_centers = generate_centers()

In [15]:
# define model parameters
class ModelParams:
    def __init__(self, k, beta, phi):
        self.k = k
        self.beta = beta
        self.phi = phi

# get conters depending on k and condition
def get_centers(k, condition):
    rules = partition(condition)
    if 1 <= k <= len(rules):
        return all_centers[rules[k-1]]
    else:
        raise ValueError(f"Invalid k for condition {condition}. Must be between 1 and {len(rules)}.")

# define likelihood
def likelihood(params, data, condition):
    k, beta, phi = params.k, params.beta, params.phi
    
    x = data[['feature1', 'feature2', 'feature3', 'feature4']].values
    c = data['choice'].values
    r = data['feedback'].values

    # calculate distances between x and centers
    centers = get_centers(k, condition)
    distances = np.array([np.linalg.norm(x - np.array(center), axis=1) for center in centers])

    # calculate choosing probablity
    probs = np.exp(-beta * distances)
    probs /= np.sum(probs, axis=0, keepdims=True)
    softmax_probs = np.exp(phi * probs) / np.sum(np.exp(phi * probs))
    p_c = softmax_probs[c - 1, np.arange(len(c))]
    
    return np.where(r == 1, p_c, 1 - p_c)

# define prior
def prior(params, condition):
    max_k = len(partition(condition))
    k_prior = 1/max_k if 1 <= params.k <= max_k else 0
    beta_prior = np.exp(-params.beta) if params.beta > 0 else 0
    phi_prior = np.exp(-params.phi) if params.phi > 0 else 0
    return k_prior * beta_prior * phi_prior

# define posterior
def posterior(params, data, condition):
    prior_value = prior(params, condition)
    log_prior = np.log(prior_value) if prior_value > 0 else -np.inf
    log_likelihood = np.sum(np.log(likelihood(params, data, condition)))
    return -(log_prior + log_likelihood)  # negative log posterior

# New function to predict choice
def predict_choice(params, x, condition):
    k, beta, phi = params.k, params.beta, params.phi
    centers = get_centers(k, condition)
    distances = np.array([np.linalg.norm(x - np.array(center)) for center in centers])
    probs = np.exp(-beta * distances)
    probs /= np.sum(probs)
    softmax_probs = np.exp(phi * probs) / np.sum(np.exp(phi * probs))
    return np.argmax(softmax_probs) + 1  # +1 because choices are 1-indexed

# fit model
def fit_model(data):
    condition = data['condition'].iloc[0]
    max_k = len(partition(condition))
    
    best_params = None  # 最佳的k和beta参数
    best_posterior = -np.inf  # 最佳参数对应的后验概率
    k_posteriors = {}  # 每个k值的边缘后验概率
    
    for k in range(1, max_k + 1):
        initial_beta, initial_phi = 1, 5
        result = minimize(
            lambda params: posterior(ModelParams(k, params[0], params[1]), data, condition),
            [initial_beta, initial_phi],
            bounds=[(0, 30), (0, 20)]
        )
           
        beta_opt, phi_opt = result.x
        posterior_opt = -result.fun
        k_posteriors[k] = posterior_opt

        if posterior_opt > best_posterior:
            best_posterior = posterior_opt
            best_params = ModelParams(k=k, beta=beta_opt, phi=phi_opt)
    
    # Normalize log probabilities
    max_log_posterior = max(k_posteriors.values())
    k_posteriors = {k: np.exp(log_p - max_log_posterior) for k, log_p in k_posteriors.items()}
    
    # Normalize to sum to 1
    total = sum(k_posteriors.values())
    k_posteriors = {k: p / total for k, p in k_posteriors.items()}
    
    return best_params, best_posterior, k_posteriors

# fit model trial by trial
def fit_model_for_steps(data):
    num_trials = len(data)
    step_results = []
    predictions = []
    cumulative_accuracy = []
    correct_count = 0
    
    for step in range(1, num_trials):
        trial_data = data.iloc[:step]
        fitted_params, best_posterior, k_posteriors = fit_model(trial_data)
        step_results.append({
            'k': fitted_params.k,
            'beta': fitted_params.beta,
            'phi': fitted_params.phi,
            'best_posterior': best_posterior,
            'k_posteriors': k_posteriors
        })
        
        # Predict next trial
        next_trial = data.iloc[step]
        x = next_trial[['feature1', 'feature2', 'feature3', 'feature4']].values
        predicted_choice = predict_choice(fitted_params, x, data['condition'].iloc[0])
        actual_choice = next_trial['choice']
        correct = predicted_choice == actual_choice
        correct_count += correct
        accuracy = correct_count / (step + 1)
        
        predictions.append({
            'trial': step + 1,
            'predicted_choice': predicted_choice,
            'actual_choice': actual_choice,
            'correct': correct
        })
        cumulative_accuracy.append(accuracy)
    
    return step_results, predictions, cumulative_accuracy

In [16]:
# extract data
data = pd.read_csv('Task2.csv')

# fit model by subjects
results = {}
for iSub, subject_data in data.groupby('iSub'):
    try:
        step_results, predictions, cumulative_accuracy = fit_model_for_steps(subject_data)
        condition = subject_data['condition'].iloc[0]
        results[iSub] = {
            'step_results': step_results, 
            'predictions': predictions,
            'cumulative_accuracy': cumulative_accuracy,
            'condition': condition
        }
    except Exception as e:
        print(f"Error fitting model for subject {iSub}: {str(e)}")
        continue

  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x) - f0
  df = fun(x

In [8]:
# print final fitted parameters
for iSub, subject_info in results.items():
    step_results = subject_info['step_results']
    final_result = step_results[-1]
    print(f"Subject {iSub}:")
    print(f"  Final Fitted k: {final_result['k']}")
    print(f"  Final Fitted beta: {final_result['beta']:.4f}")
    print(f"  Final Fitted phi: {final_result['phi']:.4f}")
    print(f"  Final log posterior: {final_result['best_posterior']:.4f}")
    print()

Subject 1:
  Final Fitted k: 1
  Final Fitted beta: 7.8061
  Final Fitted phi: 4.0017
  Final log posterior: -561.1046

Subject 4:
  Final Fitted k: 1
  Final Fitted beta: 11.5045
  Final Fitted phi: 3.8921
  Final log posterior: -605.3996

Subject 6:
  Final Fitted k: 7
  Final Fitted beta: 15.9559
  Final Fitted phi: 3.8856
  Final log posterior: -941.9268

Subject 11:
  Final Fitted k: 7
  Final Fitted beta: 17.2786
  Final Fitted phi: 4.2083
  Final log posterior: -837.0439

Subject 21:
  Final Fitted k: 7
  Final Fitted beta: 30.0000
  Final Fitted phi: 11.8013
  Final log posterior: -3571.9790

Subject 26:
  Final Fitted k: 7
  Final Fitted beta: 22.0774
  Final Fitted phi: 4.6936
  Final log posterior: -1200.9128

Subject 27:
  Final Fitted k: 7
  Final Fitted beta: 30.0000
  Final Fitted phi: 12.4630
  Final log posterior: -2981.8823



In [17]:
# plot parameters over trials
def plot_params_over_trials(step_results, iSub):
    num_steps = len(step_results)
    k_values = [result['k'] for result in step_results]
    beta_values = [result['beta'] for result in step_results]
    phi_values = [result['phi'] for result in step_results]
    
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 18))
    
    # Plot k values
    ax1.plot(range(1, num_steps + 1), k_values, marker='o')
    ax1.set_title(f'Subject {iSub} - k value over trials')
    ax1.set_xlabel('Number of trials')
    ax1.set_ylabel('k value')
    ax1.grid(True)
    
    # Plot beta values
    ax2.plot(range(1, num_steps + 1), beta_values, marker='o')
    ax2.set_title(f'Subject {iSub} - beta value over trials')
    ax2.set_xlabel('Number of trials')
    ax2.set_ylabel('beta value')
    ax2.grid(True)

    # Plot beta values
    ax3.plot(range(1, num_steps + 1), phi_values, marker='o')
    ax3.set_title(f'Subject {iSub} - phi value over trials')
    ax3.set_xlabel('Number of trials')
    ax3.set_ylabel('phi value')
    ax3.grid(True)

    plt.tight_layout()
    plt.savefig(f'M0201_params_over_trials_subject_{iSub}_.png')
    plt.close()

In [18]:
def plot_posterior_probabilities(step_results, condition, iSub):
    num_steps = len(step_results)
    max_k = max(k for result in step_results for k in result['k_posteriors'].keys())

    # Prepare data for plotting
    k_posteriors = {k: np.zeros(num_steps) for k in range(1, max_k + 1)}

    for step, result in enumerate(step_results):
        for k in range(1, max_k + 1):
            k_posteriors[k][step] = result['k_posteriors'].get(k, 0)

    # Create a figure
    fig, ax = plt.subplots(figsize=(12, 6))
    fig.suptitle(f'Posterior Probabilities for k (Subject {iSub}, Condition {condition})', fontsize=16)

    # Plot k posteriors
    for k in range(1, max_k + 1):
        if (condition == 1 and k == 1) or (condition != 1 and k == 7):
            ax.plot(range(1, num_steps + 1), k_posteriors[k], label=f'k={k}', linewidth=3, color='red')
        else:
            ax.plot(range(1, num_steps + 1), k_posteriors[k], label=f'k={k}')
    
    ax.set_xlabel('Trial')
    ax.set_ylabel('Posterior Probability')
    ax.set_title('Posterior Probabilities for k')
    ax.legend()

    plt.tight_layout()
    plt.savefig(f'M0201_posteriors_subject_{iSub}.png')
    plt.close()


In [19]:
for iSub, subject_info in results.items():
    step_results = subject_info['step_results']
    condition = subject_info['condition']
    plot_params_over_trials(step_results, iSub)
    plot_posterior_probabilities(step_results, condition, iSub)

In [20]:
# Plot cumulative accuracy for each subject
plt.figure(figsize=(12, 8))
for iSub, result in results.items():
    plt.plot(range(2, len(result['cumulative_accuracy']) + 2), result['cumulative_accuracy'], label=f'Subject {iSub}')

plt.xlabel('Trial')
plt.ylabel('Cumulative Prediction Accuracy')
plt.title('Prediction Accuracy Over Trials')
plt.legend()
plt.grid(True)
plt.savefig('M0201_prediction_accuracy.png')
plt.close()