In [8]:
import pandas as pd
import numpy as np
import torch
import pickle
import os
import json
import gc
from torch.distributions import Bernoulli
from torch.optim import LBFGS
from tqdm import tqdm
from scipy.stats import pearsonr
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import Manager
import multiprocessing as mp

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from tueplots import bundles
bundles.icml2024()

from torchmetrics import AUROC
auroc = AUROC(task="binary")

import warnings
warnings.filterwarnings("ignore")

torch.manual_seed(0)

device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

def visualize_response_matrix(results, value, filename):
    # Extract the groups labels in the order of the columns
    group_values = results.columns.get_level_values("scenario")

    # Identify the boundaries where the group changes
    boundaries = []
    for i in range(1, len(group_values)):
        if group_values[i] != group_values[i - 1]:
            boundaries.append(i - 0.5)  # using 0.5 to place the line between columns

    # Visualize the results with a matrix: red is 0, white is -1 and blue is 1
    cmap = mcolors.ListedColormap(["white", "red", "blue"])
    bounds = [-1.5, -0.5, 0.5, 1.5]
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    # Calculate midpoints for each group label
    groups_list = list(group_values)
    group_names = []
    group_midpoints = []
    current_group = groups_list[0]
    start_index = 0
    for i, grp in enumerate(groups_list):
        if grp != current_group:
            midpoint = (start_index + i - 1) / 2.0
            group_names.append(current_group)
            group_midpoints.append(midpoint)
            current_group = grp
            start_index = i
    # Add the last group
    midpoint = (start_index + len(groups_list) - 1) / 2.0
    group_names.append(current_group)
    group_midpoints.append(midpoint)

    # Define the minimum spacing between labels (e.g., 100 units)
    min_spacing = 100
    last_label_pos = -float("inf")
    # Plot the matrix
    with plt.rc_context(bundles.icml2024(usetex=True, family="serif")):
        fig, ax = plt.subplots(figsize=(20, 10))
        cax = ax.matshow(value, aspect="auto", cmap=cmap, norm=norm)

        # Add vertical lines at each boundary
        for b in boundaries:
            ax.axvline(x=b, color="black", linewidth=0.25, linestyle="--", alpha=0.5)
        
        # Add group labels above the matrix, only if they're spaced enough apart
        for name, pos in zip(group_names, group_midpoints):
            if pos - last_label_pos >= min_spacing:
                ax.text(pos, -5, name, ha='center', va='bottom', rotation=90, fontsize=3)
                last_label_pos = pos

        # Add model labels on the y-axis
        ax.set_yticks(range(len(results.index)))
        ax.set_yticklabels(results.index, fontsize=3)

        # Add a colorbar
        cbar = plt.colorbar(cax)
        cbar.set_ticks([-1, 0, 1])
        cbar.set_ticklabels(["-1", "0", "1"])
        plt.savefig(filename, dpi=600, bbox_inches="tight")
        plt.close()

def trainer(parameters, optim, closure, n_iter=100, verbose=True):
    pbar = tqdm(range(n_iter)) if verbose else range(n_iter)
    for iteration in pbar:
        if iteration > 0:
            previous_parameters = [p.clone() for p in parameters]
            previous_loss = loss.clone()
        
        loss = optim.step(closure)
        
        if iteration > 0:
            d_loss = (previous_loss - loss).item()
            d_parameters = sum(
                torch.norm(prev - curr, p=2).item()
                for prev, curr in zip(previous_parameters, parameters)
            )
            grad_norm = sum(torch.norm(p.grad, p=2).item() for p in parameters if p.grad is not None)
            if verbose:
                pbar.set_postfix({"grad_norm": grad_norm, "d_parameter": d_parameters, "d_loss": d_loss})
            
            if d_loss < 1e-5 and d_parameters < 1e-5 and grad_norm < 1e-5:
                break
    return parameters

def compute_auc(probs, data, train_idtor, test_idtor):
    train_probs = probs[train_idtor.bool()]
    test_probs = probs[test_idtor.bool()]
    train_labels = data[train_idtor.bool()]
    test_labels = data[test_idtor.bool()]
    train_auc = auroc(train_probs, train_labels)
    test_auc = auroc(test_probs, test_labels)
    print(f"train auc: {train_auc}")
    print(f"test auc: {test_auc}")
    
    return train_auc, test_auc

def compute_cttcorr(probs, data, train_idtor, test_idtor):
    train_probs  = probs.clone()
    test_probs   = probs.clone()
    train_labels = data.clone()
    test_labels  = data.clone()

    train_mask = ~train_idtor.bool()
    train_probs[train_mask]  = float('nan')
    train_labels[train_mask] = float('nan')

    test_mask = ~test_idtor.bool()
    test_probs[test_mask]   = float('nan')
    test_labels[test_mask]  = float('nan')
    
    train_prob_ctt = torch.nanmean(train_probs, dim=1).detach().cpu().numpy()
    train_label_ctt = torch.nanmean(train_labels, dim=1).detach().cpu().numpy()
    train_mask = ~np.isnan(train_prob_ctt) & ~np.isnan(train_label_ctt)
    train_cttcorr = pearsonr(train_prob_ctt[train_mask], train_label_ctt[train_mask]).statistic
    
    test_prob_ctt = torch.nanmean(test_probs, dim=1).detach().cpu().numpy()
    test_label_ctt = torch.nanmean(test_labels, dim=1).detach().cpu().numpy()
    test_mask = ~np.isnan(test_prob_ctt) & ~np.isnan(test_label_ctt)
    test_cttcorr = pearsonr(test_prob_ctt[test_mask], test_label_ctt[test_mask]).statistic
    
    print(f"train cttcorr: {train_cttcorr}")
    print(f"test cttcorr: {test_cttcorr}")

    return train_cttcorr, test_cttcorr

In [9]:
from torch import float32


with open(f"../data/resmat.pkl", "rb") as f:
    results = pickle.load(f)

dtype = torch.float64 if device.startswith("cuda") else torch.float32

# data_withnan, missing=nan
# data_withneg1, missing=-1
# data_with0, missing=0
data_withnan = torch.tensor(results.values, dtype=dtype, device=device)
data_idtor = (~torch.isnan(data_withnan)).to(dtype)
data_withneg1 = data_withnan.nan_to_num(nan=-1.0)
data_with0 = data_withneg1 * data_idtor
data_with0 = data_with0.nan_to_num(nan=0.0)
n_test_takers, n_items = data_with0.shape
scenarios = results.columns.get_level_values("scenario").unique()

In [10]:
data_with0.unique()

tensor([0., 1.], device='cuda:0', dtype=torch.float64)

In [None]:
import torch
import numpy as np
import pandas as pd
from torch.distributions import Bernoulli
from torch.optim import LBFGS
from tqdm import tqdm
from torch.optim import Adam
import gc

# Assume prerequisite functions like trainer() and data tensors are loaded
# For this example, we will simulate the necessary data structures.

# ====================================================================================
# 1. SETUP & DATA SIMULATION (Replace with your actual data loading)
# ====================================================================================

n_test_takers = 183
n_items = 78712  # <-- FIX 1: Set n_items to the total number of questions
n_dimensions = 3

# --- Create a map from each question to its scenario ---
# CRITICAL STEP: You need to generate this map from your data.
# We will simulate it here. Assume you have a DataFrame or list that
# tells you which scenario each of the 78,712 questions belongs to.
# For example, the first 5236 questions are from 'air_bench_2024' (scenario index 0),
# the next 9558 are from 'babi_qa' (scenario index 1), etc.

question_to_scenario_map = np.fromfile("../data/scenario_map.npy", dtype=np.int32)


# ====================================================================================
# 2. EXPAND THE Q-MATRIX
# ====================================================================================

# Your original (22, 3) Q-Matrix for scenarios
q_matrix_scenario = np.array([
    [1, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0],
    [1, 0, 0], [1, 1, 0], [1, 0, 0], [1, 0, 1], [1, 0, 0],
    [1, 0, 1], [1, 0, 0], [0, 0, 1], [1, 0, 0], [0, 1, 1],
    [1, 0, 0], [1, 1, 0], [1, 1, 1], [1, 0, 0], [1, 0, 0],
    [1, 0, 0], [1, 0, 0]
])

# FIX 2: Expand the Q-Matrix to the question level using the map
print("Expanding Q-Matrix to question level...")
q_matrix_np_expanded = q_matrix_scenario[question_to_scenario_map]
Q_matrix = torch.tensor(q_matrix_np_expanded, device=device, dtype=torch.float32)

print(f"Shape of expanded Q-Matrix: {Q_matrix.shape}") # Should be (78712, 3)

# ====================================================================================
# 3. MIRT MODEL TRAINING (Now with correct shapes)
# ====================================================================================

# Apply random train/test mask to the matrix
valid_condition = False
trial = 0
while not valid_condition:
    train_idtor = torch.bernoulli(data_idtor * 0.8).int()
    test_idtor = data_idtor - train_idtor
    valid_condition = (train_idtor.sum(axis=1) != 0).all() and (train_idtor.sum(axis=0) != 0).all()
    print(f"trial {trial} valid condition: {valid_condition}")
    trial += 1

# Hyperparameters for the two stages
n_item_iterations = 100   # More iterations might be needed for the large number of items
n_theta_iterations = 100 # Thetas can often be fit with more iterations
lr_items = 0.01
lr_thetas = 0.01


# --- STAGE 1: Fit Item Parameters (a_params and ds) ---
print("Starting Stage 1: Fitting item parameters (a_params and ds)...")

# We are only training these two parameters in this stage
a_params = torch.randn(n_items, n_dimensions, requires_grad=True, device=device)
ds = torch.randn(n_items, requires_grad=True, device=device)
optim_items = Adam([a_params, ds], lr=lr_items)

# Use a fixed, random set of thetas to approximate the expectation (Monte Carlo EM)
n_mc_samples = 50 # Number of samples for the approximation. Can be reduced for speed.
thetas_nuisance = torch.randn(n_mc_samples, n_test_takers, n_dimensions, device=device)

for i in tqdm(range(n_item_iterations), desc="Stage 1: Calibrating Items"):
    optim_items.zero_grad()
    
    a_params_masked = a_params * Q_matrix
    a_params_constrained = torch.clamp(a_params_masked, min=0)
    
    # Calculate logits using the fixed 'nuisance' thetas
    # Shapes: (S, M, K) @ (K, N) -> (S, M, N), where S=n_mc_samples
    logits = torch.matmul(thetas_nuisance, a_params_constrained.T) - ds[None, None, :]
    probs = torch.sigmoid(logits)
    
    # Expand data and mask to match the nuisance sample dimension for loss calculation
    log_likelihoods = Bernoulli(probs=probs).log_prob(data_with0[None, :, :]) * train_idtor[None, :, :]
    loss = -log_likelihoods.sum() / (train_idtor.sum() * n_mc_samples)
    
    loss.backward()
    optim_items.step()

# Detach the now-calibrated item parameters to fix them for the next stage
a_params_calibrated = a_params.detach()
ds_calibrated = ds.detach()
print("Stage 1 Finished.")


# --- STAGE 2: Fit Person Parameters (thetas) ---
print("Starting Stage 2: Fitting person parameters (thetas)...")

# Initialize the real thetas we want to learn
thetas = torch.randn(n_test_takers, n_dimensions, requires_grad=True, device=device)
optim_thetas = Adam([thetas], lr=lr_thetas)

for i in tqdm(range(n_theta_iterations), desc="Stage 2: Fitting Thetas"):
    optim_thetas.zero_grad()
    
    # Use the calibrated, fixed item parameters from Stage 1
    # Note: No need for clamp here if a_params_calibrated is already positive, but it's a safe check
    a_params_constrained = torch.clamp(a_params_calibrated * Q_matrix, min=0)
    
    logits = torch.matmul(thetas, a_params_constrained.T) - ds_calibrated[None, :]
    probs = torch.sigmoid(logits)
    
    log_likelihoods = Bernoulli(probs=probs).log_prob(data_with0) * train_idtor
    loss = -log_likelihoods.sum() / train_idtor.sum()
    
    loss.backward()
    optim_thetas.step()

thetas, a_params, ds = thetas.detach(), a_params_calibrated, ds_calibrated
print("Stage 2 Finished. Training complete.")

# calculate metrics
a_params_constrained = a_params * Q_matrix
logits = torch.matmul(thetas, a_params_constrained.T) - ds[None, :]
probs = torch.sigmoid(logits)

train_auc, test_auc = compute_auc(probs, data_with0, train_idtor, test_idtor)

train_cttcorr, test_cttcorr = compute_cttcorr(probs, data_with0, train_idtor, test_idtor)

# del thetas, a_params, ds, a_params_calibrated, ds_calibrated
gc.collect()
torch.cuda.empty_cache()


Expanding Q-Matrix to question level...
Shape of expanded Q-Matrix: torch.Size([78712, 3])
trial 0 valid condition: True
Starting Stage 1: Fitting item parameters (a_params and ds)...


Stage 1: Calibrating Items:   0%|                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              | 0/200 [00:00<?, ?it/s]

Stage 1: Calibrating Items: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:41<00:00,  4.83it/s]


Stage 1 Finished.
Starting Stage 2: Fitting person parameters (thetas)...


Stage 2: Fitting Thetas: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:04<00:00, 203.23it/s]


Stage 2 Finished. Training complete.
train auc: 0.6158990263938904
test auc: 0.6071275472640991
train cttcorr: -0.7281585080721156
test cttcorr: -0.7226807794288592
