In [22]:
import pandas as pd
import numpy as np
import torch
import pickle
import os
import json
import gc
from torch.distributions import Bernoulli
from torch.optim import LBFGS
from tqdm import tqdm
from scipy.stats import pearsonr
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import Manager
import multiprocessing as mp

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from tueplots import bundles
bundles.icml2024()

from torchmetrics import AUROC
auroc = AUROC(task="binary")

import warnings
warnings.filterwarnings("ignore")

torch.manual_seed(0)

device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

def visualize_response_matrix(results, value, filename):
    # Extract the groups labels in the order of the columns
    group_values = results.columns.get_level_values("scenario")

    # Identify the boundaries where the group changes
    boundaries = []
    for i in range(1, len(group_values)):
        if group_values[i] != group_values[i - 1]:
            boundaries.append(i - 0.5)  # using 0.5 to place the line between columns

    # Visualize the results with a matrix: red is 0, white is -1 and blue is 1
    cmap = mcolors.ListedColormap(["white", "red", "blue"])
    bounds = [-1.5, -0.5, 0.5, 1.5]
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    # Calculate midpoints for each group label
    groups_list = list(group_values)
    group_names = []
    group_midpoints = []
    current_group = groups_list[0]
    start_index = 0
    for i, grp in enumerate(groups_list):
        if grp != current_group:
            midpoint = (start_index + i - 1) / 2.0
            group_names.append(current_group)
            group_midpoints.append(midpoint)
            current_group = grp
            start_index = i
    # Add the last group
    midpoint = (start_index + len(groups_list) - 1) / 2.0
    group_names.append(current_group)
    group_midpoints.append(midpoint)

    # Define the minimum spacing between labels (e.g., 100 units)
    min_spacing = 100
    last_label_pos = -float("inf")
    # Plot the matrix
    with plt.rc_context(bundles.icml2024(usetex=True, family="serif")):
        fig, ax = plt.subplots(figsize=(20, 10))
        cax = ax.matshow(value, aspect="auto", cmap=cmap, norm=norm)

        # Add vertical lines at each boundary
        for b in boundaries:
            ax.axvline(x=b, color="black", linewidth=0.25, linestyle="--", alpha=0.5)
        
        # Add group labels above the matrix, only if they're spaced enough apart
        for name, pos in zip(group_names, group_midpoints):
            if pos - last_label_pos >= min_spacing:
                ax.text(pos, -5, name, ha='center', va='bottom', rotation=90, fontsize=3)
                last_label_pos = pos

        # Add model labels on the y-axis
        ax.set_yticks(range(len(results.index)))
        ax.set_yticklabels(results.index, fontsize=3)

        # Add a colorbar
        cbar = plt.colorbar(cax)
        cbar.set_ticks([-1, 0, 1])
        cbar.set_ticklabels(["-1", "0", "1"])
        plt.savefig(filename, dpi=600, bbox_inches="tight")
        plt.close()

def trainer(parameters, optim, closure, n_iter=100, verbose=True):
    pbar = tqdm(range(n_iter)) if verbose else range(n_iter)
    for iteration in pbar:
        if iteration > 0:
            previous_parameters = [p.clone() for p in parameters]
            previous_loss = loss.clone()
        
        loss = optim.step(closure)
        
        if iteration > 0:
            d_loss = (previous_loss - loss).item()
            d_parameters = sum(
                torch.norm(prev - curr, p=2).item()
                for prev, curr in zip(previous_parameters, parameters)
            )
            grad_norm = sum(torch.norm(p.grad, p=2).item() for p in parameters if p.grad is not None)
            if verbose:
                pbar.set_postfix({"grad_norm": grad_norm, "d_parameter": d_parameters, "d_loss": d_loss})
            
            if d_loss < 1e-5 and d_parameters < 1e-5 and grad_norm < 1e-5:
                break
    return parameters

def compute_auc(probs, data, train_idtor, test_idtor):
    train_probs = probs[train_idtor.bool()]
    test_probs = probs[test_idtor.bool()]
    train_labels = data[train_idtor.bool()]
    test_labels = data[test_idtor.bool()]
    train_auc = auroc(train_probs, train_labels)
    test_auc = auroc(test_probs, test_labels)
    print(f"train auc: {train_auc}")
    print(f"test auc: {test_auc}")
    
    return train_auc, test_auc

def compute_cttcorr(probs, data, train_idtor, test_idtor):
    train_probs  = probs.clone()
    test_probs   = probs.clone()
    train_labels = data.clone()
    test_labels  = data.clone()

    train_mask = ~train_idtor.bool()
    train_probs[train_mask]  = float('nan')
    train_labels[train_mask] = float('nan')

    test_mask = ~test_idtor.bool()
    test_probs[test_mask]   = float('nan')
    test_labels[test_mask]  = float('nan')
    
    train_prob_ctt = torch.nanmean(train_probs, dim=1).detach().cpu().numpy()
    train_label_ctt = torch.nanmean(train_labels, dim=1).detach().cpu().numpy()
    train_mask = ~np.isnan(train_prob_ctt) & ~np.isnan(train_label_ctt)
    train_cttcorr = pearsonr(train_prob_ctt[train_mask], train_label_ctt[train_mask]).statistic
    
    test_prob_ctt = torch.nanmean(test_probs, dim=1).detach().cpu().numpy()
    test_label_ctt = torch.nanmean(test_labels, dim=1).detach().cpu().numpy()
    test_mask = ~np.isnan(test_prob_ctt) & ~np.isnan(test_label_ctt)
    test_cttcorr = pearsonr(test_prob_ctt[test_mask], test_label_ctt[test_mask]).statistic
    
    print(f"train cttcorr: {train_cttcorr}")
    print(f"test cttcorr: {test_cttcorr}")

    return train_cttcorr, test_cttcorr

In [23]:
from torch import float32


with open(f"../data/resmat.pkl", "rb") as f:
    results = pickle.load(f)

dtype = torch.float64 if device.startswith("cuda") else torch.float32

# data_withnan, missing=nan
# data_withneg1, missing=-1
# data_with0, missing=0
data_withnan = torch.tensor(results.values, dtype=dtype, device=device)
data_idtor = (~torch.isnan(data_withnan)).to(dtype)
data_withneg1 = data_withnan.nan_to_num(nan=-1.0)
data_with0 = data_withneg1 * data_idtor
data_with0 = data_with0.nan_to_num(nan=0.0)
n_test_takers, n_items = data_with0.shape
scenarios = results.columns.get_level_values("scenario").unique()

In [24]:
data_with0.unique()

tensor([0., 1.], device='cuda:0', dtype=torch.float64)

In [25]:
import torch
import numpy as np
import pandas as pd
from torch.distributions import Bernoulli
from torch.optim import LBFGS
from tqdm import tqdm
from torch.optim import Adam
import gc

# Assume prerequisite functions like trainer() and data tensors are loaded
# For this example, we will simulate the necessary data structures.

# ====================================================================================
# 1. SETUP & DATA SIMULATION (Replace with your actual data loading)
# ====================================================================================

n_test_takers = 183
n_items = 78712  # <-- FIX 1: Set n_items to the total number of questions
n_dimensions = 3

# --- Create a map from each question to its scenario ---
# CRITICAL STEP: You need to generate this map from your data.
# We will simulate it here. Assume you have a DataFrame or list that
# tells you which scenario each of the 78,712 questions belongs to.
# For example, the first 5236 questions are from 'air_bench_2024' (scenario index 0),
# the next 9558 are from 'babi_qa' (scenario index 1), etc.

question_to_scenario_map = np.fromfile("../data/scenario_map.npy", dtype=np.int32)


# ====================================================================================
# 2. EXPAND THE Q-MATRIX
# ====================================================================================

# Your original (22, 3) Q-Matrix for scenarios
q_matrix_scenario = np.array([
    [1, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0],
    [1, 0, 0], [1, 1, 0], [1, 0, 0], [1, 0, 1], [1, 0, 0],
    [1, 0, 1], [1, 0, 0], [0, 0, 1], [1, 0, 0], [0, 1, 1],
    [1, 0, 0], [1, 1, 0], [1, 1, 1], [1, 0, 0], [1, 0, 0],
    [1, 0, 0], [1, 0, 0]
])

# FIX 2: Expand the Q-Matrix to the question level using the map
print("Expanding Q-Matrix to question level...")
q_matrix_np_expanded = q_matrix_scenario[question_to_scenario_map]
Q_matrix = torch.tensor(q_matrix_np_expanded, device=device, dtype=torch.float32)

print(f"Shape of expanded Q-Matrix: {Q_matrix.shape}") # Should be (78712, 3)
import torch
import numpy as np
import pandas as pd
from torch.distributions import Bernoulli
from torch.optim import LBFGS
from tqdm import tqdm
import gc

# Assume prerequisite data tensors are loaded...
# device, data_with0, data_idtor, Q_matrix, n_test_takers, n_items, n_dimensions

# ====================================================================================
# 3. MIRT MODEL TRAINING (Corrected for LBFGS)
# ====================================================================================

# Apply random train/test mask to the matrix
valid_condition = False
trial = 0
while not valid_condition:
    train_idtor = torch.bernoulli(data_idtor * 0.8).int()
    test_idtor = data_idtor - train_idtor
    valid_condition = (train_idtor.sum(axis=1) != 0).all() and (train_idtor.sum(axis=0) != 0).all()
    print(f"trial {trial} valid condition: {valid_condition}")
    trial += 1

# --- STAGE 1: Fit Item Parameters (a_params and ds) ---
print("Starting Stage 1: Fitting item parameters (a_params and ds)...")

# Use a fixed, random set of thetas to approximate the expectation (Monte Carlo EM)
n_mc_samples = 150
thetas_nuisance = torch.randn(n_mc_samples, n_test_takers, n_dimensions, device=device)

# Initialize parameters
a_params = torch.randn(n_items, n_dimensions, requires_grad=True, device=device)
ds = torch.randn(n_items, requires_grad=True, device=device)

# Process items in batches to make LBFGS memory-efficient
B = 50000
for i in tqdm(range(0, n_items, B), desc="Stage 1: Calibrating Item Batches"):
    
    # <<< FIX: Create new, independent LEAF tensors for each batch
    # We clone the data from the main tensor, detach it from the graph,
    # and then set requires_grad=True to make it a new leaf.
    current_B = min(B, n_items - i)
    a_params_batch = a_params[i:i+current_B].clone().detach().requires_grad_(True)
    ds_batch = ds[i:i+current_B].clone().detach().requires_grad_(True)

    # Select corresponding data and Q-matrix for the batch
    data_batch = data_with0[:, i:i+current_B]
    train_idtor_batch = train_idtor[:, i:i+current_B]
    Q_matrix_batch = Q_matrix[i:i+current_B, :]

    # Define the optimizer for the current batch (now with valid leaf tensors)
    optim_items = LBFGS([a_params_batch, ds_batch], lr=0.1, max_iter=20, history_size=10, line_search_fn="strong_wolfe")

    # Define the closure function required by LBFGS
    def closure_items():
        optim_items.zero_grad()
        
        a_params_masked = a_params_batch * Q_matrix_batch
        a_params_constrained = torch.clamp(a_params_masked, min=0)
        
        logits = torch.matmul(thetas_nuisance, a_params_constrained.T) - ds_batch[None, None, :]
        probs = torch.sigmoid(logits)
        
        log_likelihoods = Bernoulli(probs=probs).log_prob(data_batch[None, :, :]) * train_idtor_batch[None, :, :]
        loss = -log_likelihoods.sum() / (train_idtor_batch.sum() * n_mc_samples)
        
        loss.backward()
        return loss

    # The step function now calls the closure
    optim_items.step(closure_items)
    
    # <<< FIX: Copy the optimized data from the batch tensor back to the main tensor
    # We use .data to do this without affecting the gradient history.
    a_params.data[i:i+current_B] = a_params_batch.data
    ds.data[i:i+current_B] = ds_batch.data


# Detach the now-calibrated item parameters to fix them for the next stage
a_params_calibrated = a_params.detach()
ds_calibrated = ds.detach()
print("Stage 1 Finished.")


# --- STAGE 2: Fit Person Parameters (thetas) ---
print("Starting Stage 2: Fitting person parameters (thetas)...")

# Initialize the real thetas we want to learn
thetas = torch.randn(n_test_takers, n_dimensions, requires_grad=True, device=device)
optim_thetas = LBFGS([thetas], lr=0.1, max_iter=20, history_size=10, line_search_fn="strong_wolfe")

# <<< CHANGE: Define the closure for the theta optimization
def closure_thetas():
    optim_thetas.zero_grad()
    
    thetas_constrained = torch.clamp(thetas, min=0) 
    a_params_constrained = torch.clamp(a_params_calibrated * Q_matrix, min=0)
    
    logits = torch.matmul(thetas_constrained, a_params_constrained.T) - ds_calibrated[None, :]
    probs = torch.sigmoid(logits)
    
    log_likelihoods = Bernoulli(probs=probs).log_prob(data_with0) * train_idtor
    loss = -log_likelihoods.sum() / train_idtor.sum()
    
    loss.backward()
    return loss

# <<< CHANGE: A single call to step() is needed, as max_iter is handled internally
optim_thetas.step(closure_thetas)

# Detach final parameters
thetas_final = thetas.detach()
a_params_final = a_params_calibrated
ds_final = ds_calibrated
print("Stage 2 Finished. Training complete.")

# --- 4. EVALUATION ---
# Apply the final constraints before calculating metrics
thetas_constrained_final = torch.clamp(thetas_final, min=0)
a_params_constrained_final = torch.clamp(a_params_final * Q_matrix, min=0)

# Use the correctly constrained parameters for prediction
logits = torch.matmul(thetas_constrained_final, a_params_constrained_final.T) - ds_final[None, :]
probs = torch.sigmoid(logits)

train_auc, test_auc = compute_auc(probs, data_with0, train_idtor, test_idtor)
train_cttcorr, test_cttcorr = compute_cttcorr(probs, data_with0, train_idtor, test_idtor)

# --- 5. POST-HOC ANALYSIS ---
print("\n--- Post-Hoc Analysis ---")

# Ensure the final thetas are on the CPU and converted to a NumPy array
thetas_np = thetas_final.cpu().numpy()

# Calculate the correlation matrix
# rowvar=False is important: it tells numpy that your variables are columns, not rows.
theta_correlation_matrix = np.corrcoef(thetas_np, rowvar=False)

print("Theta Correlation Matrix:")
print(theta_correlation_matrix)
print("\nInterpretation:")
print("- Values close to 0 mean the dimensions are distinct.")
print("- Values close to 1 or -1 mean the dimensions are highly related.")
import torch

# --- 6. MODEL FIT CALCULATION ---
print("\n--- Model Fit Indices ---")

# 1. Define Model Complexity (k = number of parameters)
num_item_params = n_items * (n_dimensions + 1)  # (a * K) + d
num_person_params = n_test_takers * n_dimensions # theta * K
k = num_item_params + num_person_params
print(f"Number of parameters (k): {k}")

# 2. Define Number of Observations (n = number of responses in training set)
n = train_idtor.sum().item()
print(f"Number of observations (n): {n}")

# 3. Calculate Final Log-Likelihood on the Training Data
with torch.no_grad(): # We don't need to compute gradients here
    # Use the final, constrained parameters
    thetas_constrained_final = torch.clamp(thetas_final, min=0)
    a_params_constrained_final = torch.clamp(a_params_final * Q_matrix, min=0)
    
    # Calculate probabilities for the training data
    logits = torch.matmul(thetas_constrained_final, a_params_constrained_final.T) - ds_final[None, :]
    probs = torch.sigmoid(logits)
    
    # Calculate the log-likelihood only on the training responses
    log_likelihood = Bernoulli(probs=probs).log_prob(data_with0) * train_idtor
    total_log_likelihood = log_likelihood.sum()

# 4. Calculate AIC and BIC
aic = 2 * k - 2 * total_log_likelihood
bic = k * torch.log(torch.tensor(n, dtype=torch.float32)) - 2 * total_log_likelihood

print(f"\nTotal Log-Likelihood: {total_log_likelihood.item():.2f}")
print(f"AIC: {aic.item():.2f}")
print(f"BIC: {bic.item():.2f}")
print("\nReminder: Lower AIC/BIC values indicate a better model fit.")
# Clean up memory
del thetas, a_params, ds, a_params_calibrated, ds_calibrated, thetas_final, a_params_final, ds_final, train_auc, test_auc, train_cttcorr, test_cttcorr
gc.collect()
torch.cuda.empty_cache()

Expanding Q-Matrix to question level...
Shape of expanded Q-Matrix: torch.Size([78712, 3])
trial 0 valid condition: True
Starting Stage 1: Fitting item parameters (a_params and ds)...


Stage 1: Calibrating Item Batches:   0%|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         | 0/2 [00:00<?, ?it/s]

Stage 1: Calibrating Item Batches: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:15<00:00,  7.69s/it]


Stage 1 Finished.
Starting Stage 2: Fitting person parameters (thetas)...
Stage 2 Finished. Training complete.
train auc: 0.8100907802581787
test auc: 0.7944523096084595
train cttcorr: 0.703482087615883
test cttcorr: 0.6817904579591438

--- Post-Hoc Analysis ---
Theta Correlation Matrix:
[[1.         0.10632631 0.0273906 ]
 [0.10632631 1.         0.23484477]
 [0.0273906  0.23484477 1.        ]]

Interpretation:
- Values close to 0 mean the dimensions are distinct.
- Values close to 1 or -1 mean the dimensions are highly related.

--- Model Fit Indices ---
Number of parameters (k): 315397
Number of observations (n): 4267308

Total Log-Likelihood: -2215706.13
AIC: 5062206.25
BIC: 9246418.75

Reminder: Lower AIC/BIC values indicate a better model fit.
