In [1]:
from lmm.benchmark_processors.benchmark_processor import BenchmarkConfig, EvaluationResult

# Benchmarks and models whose prediction files are collected below.
benchmarks = ["gqa"]
models = ['llava-v1.5-7b', 'instructblip-vicuna-7b', 'prism-clip+7b', 'prism-dinosiglip+7b', 'prism-siglip+7b']
benchmark_configurations: dict[str, BenchmarkConfig] = {}

for bm in benchmarks:
    # One EvaluationResult per model, pointing at that model's prediction file.
    # NOTE(review): the file name is hard-coded to "gqa-..." even though the
    # directory is derived from `bm` -- confirm before adding other benchmarks.
    results: list[EvaluationResult] = [
        EvaluationResult(
            prediction_file=f"./data/{bm}/{model}/gqa-formatted-predictions.json",
            model=model,
        )
        for model in models
    ]

    benchmark_configurations[bm] = BenchmarkConfig(
        name=bm,
        results=results,
        question_file=f"./data/{bm}/questions.json",
        subscenario_keyword="structural_type",
    )

# Bare last expression so the notebook displays the assembled configuration.
benchmark_configurations

{'gqa': BenchmarkConfig(name='gqa', results=[EvaluationResult(prediction_file=PosixPath('data/gqa/llava-v1.5-7b/gqa-formatted-predictions.json'), model='llava-v1.5-7b'), EvaluationResult(prediction_file=PosixPath('data/gqa/instructblip-vicuna-7b/gqa-formatted-predictions.json'), model='instructblip-vicuna-7b'), EvaluationResult(prediction_file=PosixPath('data/gqa/prism-clip+7b/gqa-formatted-predictions.json'), model='prism-clip+7b'), EvaluationResult(prediction_file=PosixPath('data/gqa/prism-dinosiglip+7b/gqa-formatted-predictions.json'), model='prism-dinosiglip+7b'), EvaluationResult(prediction_file=PosixPath('data/gqa/prism-siglip+7b/gqa-formatted-predictions.json'), model='prism-siglip+7b')], question_file=PosixPath('data/gqa/questions.json'), subscenario_keyword='structural_type', models=['llava-v1.5-7b', 'instructblip-vicuna-7b', 'prism-clip+7b', 'prism-dinosiglip+7b', 'prism-siglip+7b'])}

In [8]:
from lmm.generate_irt_train_data import generate_irt_train_data


# Build the IRT training data from the benchmark configurations, then expose
# the attributes the following cells read under short module-level names.
data = generate_irt_train_data(benchmark_configurations)
Y = data.train_data
scenarios = data.scenarios
scenarios_position = data.scenarios_position
subscenarios_position = data.subscenarios_position

Config: ['llava-v1.5-7b', 'instructblip-vicuna-7b', 'prism-clip+7b', 'prism-dinosiglip+7b', 'prism-siglip+7b']


100%|██████████| 12578/12578 [00:00<00:00, 2041244.18it/s]
100%|██████████| 12578/12578 [00:00<00:00, 2522156.89it/s]
100%|██████████| 12578/12578 [00:00<00:00, 1774980.01it/s]
100%|██████████| 12578/12578 [00:00<00:00, 2093158.06it/s]
100%|██████████| 12578/12578 [00:00<00:00, 2251737.41it/s]


In [10]:
import numpy as np 

bm = "gqa"

# Start from a weight of one for every column of Y.
balance_weights = np.ones(Y.shape[1])

# Re-weight so that every sub-scenario contributes equally to the benchmark
# average: an item in a sub-scenario with n_i items gets N / (n_sub * n_i),
# where N is the number of benchmark items and n_sub the number of sub-scenarios.
N = len(scenarios_position[bm])
n_sub = len(scenarios[bm])
for sub in scenarios[bm]:
    sub_positions = subscenarios_position[bm][sub]
    balance_weights[sub_positions] = N/(n_sub*len(sub_positions))


In [11]:
# Reference value: average, over sub-scenarios, of each row's mean within the sub-scenario.
per_sub_means = [Y[:,subscenarios_position[bm][sub]].mean(axis=1) for sub in scenarios[bm]]
accs1 = np.mean(per_sub_means, axis=0)

# Same quantity obtained through the per-item balance weights.
weighted_Y = balance_weights*Y
accs2 = weighted_Y[:,scenarios_position[bm]].mean(axis=1)

# Displayed value: mean absolute gap between the two computations.
np.abs(accs1 - accs2).mean()

3.295141937087465e-14

In [19]:
# The first four rows of Y are used for fitting; the remaining rows are held out.
Y_bin_train, Y_bin_test = Y[:4], Y[4:]

In [24]:
from tutorials.irt import * 
from tqdm import tqdm

scenarios = data.scenarios

Ds = [5,10] # Dimensions to try
device = 'cpu' # Either 'cuda' or 'cpu' 
epochs = 2000  # Number of epochs for IRT model training (py-irt default is 2000)
lr = .1  # Learning rate for IRT model training (py-irt default is .1)

val_ind = list(range(0,Y_bin_train.shape[0],5)) # Validation indices (every 5th training row)
train_ind = [i for i in range(Y_bin_train.shape[0]) if i not in val_ind]

# Paths are the same for every D, so they are defined once; the model directory
# is overwritten on each iteration and read back immediately afterwards.
dataset_name = 'data/irt_val_dataset.jsonlines'
model_name = 'data/irt_val_model/'

# Saving the training dataset in the needed format
create_irt_dataset(Y_bin_train[train_ind], dataset_name)

# Seen / unseen items for validation: even columns are observed, odd columns are predicted.
# This split does not depend on D, so it is computed once.
seen_items = list(range(0, Y_bin_train.shape[1], 2))
unseen_items = list(range(1, Y_bin_train.shape[1], 2))

# Validation rows, extracted once instead of re-indexing Y_bin_train for every use.
Y_val = Y_bin_train[val_ind]

# Unseen items per scenario. A set makes the membership test O(1); the original
# list lookup was quadratic in the number of items and was repeated for every D.
unseen_by_scenario = {}
for scenario in scenarios.keys():
    scenario_items = set(scenarios_position[scenario])
    unseen_by_scenario[scenario] = [u for u in unseen_items if u in scenario_items]

# Trying different Ds
errors = []   # one averaged validation error per D
errors2 = []  # per D: list of validation errors, one per scenario

for D in tqdm(Ds):
    # Train the IRT model with D dimensions, then load its parameters
    train_irt_model(dataset_name, model_name, D, lr, epochs, device)
    A, B, Theta = load_irt_parameters(model_name)

    # Estimate ability parameters for the validation set from the seen items only
    thetas = [estimate_ability_parameters(Y_val[j][seen_items], A[:, :, seen_items], B[:, :, seen_items]) for j in range(len(val_ind))]

    # Balance-weighted predicted item curves, one per validation row (reused for every scenario)
    weighted_curves = [balance_weights*item_curve(thetas[j], A, B) for j in range(len(val_ind))]

    # Compute validation errors for each scenario and update the errors list (in the end, we give the same weight for all scenarios)
    # NOTE(review): the prediction is balance-weighted while the ground truth below is a plain
    # mean over the same items -- confirm this asymmetry is intended (behaviour kept as-is).
    errors2.append([])
    for scenario in scenarios.keys():
        ind = unseen_by_scenario[scenario]
        errors2[-1].append(np.mean([abs(weighted_curves[j][0,ind].mean()-Y_val[j,ind].mean()) for j in range(len(val_ind))]))
    errors.append(np.mean(errors2[-1]))

  0%|          | 0/2 [00:00<?, ?it/s]

[23:10:11] config: model_type='multidim_2pl' epochs=2000              cli.py:109
           priors='hierarchical' initializers=[] dims=5 lr=0.1                  
           lr_decay=0.9999 dropout=0.5 hidden=100 vocab_size=None               
           log_every=200 seed=42 deterministic=True                             
           data_path: data/irt_val_dataset.jsonlines                  cli.py:111
           output directory: data/irt_val_model/                      cli.py:112
[23:10:11] amortized: False                                       dataset.py:112
[23:10:11] Vocab size: None                                       training.py:90
           Training Model...                                          cli.py:116
           args: {'device': 'cpu', 'num_items': 12578,           training.py:134
           'num_subjects': 3}                                                   
           Parsed Model Args: {'device': 'cpu', 'num_items':     training.py:147
           12578, 'num_subje

 50%|█████     | 1/2 [00:22<00:22, 22.70s/it]

[23:10:34] config: model_type='multidim_2pl' epochs=2000              cli.py:109
           priors='hierarchical' initializers=[] dims=10 lr=0.1                 
           lr_decay=0.9999 dropout=0.5 hidden=100 vocab_size=None               
           log_every=200 seed=42 deterministic=True                             
           data_path: data/irt_val_dataset.jsonlines                  cli.py:111
           output directory: data/irt_val_model/                      cli.py:112
[23:10:34] amortized: False                                       dataset.py:112
[23:10:34] Vocab size: None                                       training.py:90
           Training Model...                                          cli.py:116
           args: {'device': 'cpu', 'num_items': 12578,           training.py:134
           'num_subjects': 3}                                                   
           Parsed Model Args: {'device': 'cpu', 'num_items':     training.py:147
           12578, 'num_subje

100%|██████████| 2/2 [00:52<00:00, 26.21s/it]


In [25]:
# Keep the dimension whose averaged validation error is smallest.
ind_D = np.argmin(errors)
D = Ds[ind_D]

In [26]:
# Write all rows of Y_bin_train (not just the validation-time training subset)
# to the dataset file that the final IRT model is trained on.
create_irt_dataset(Y_bin_train, 'data/irt_dataset.jsonlines')

In [27]:
# Fit the final IRT model with the selected dimension D, reusing the
# learning rate, epoch count and device chosen for validation.
train_irt_model(
    dataset_name='data/irt_dataset.jsonlines',
    model_name='data/irt_model',
    D=D,
    lr=lr,
    epochs=epochs,
    device=device,
)

[23:12:19] config: model_type='multidim_2pl' epochs=2000              cli.py:109
           priors='hierarchical' initializers=[] dims=5 lr=0.1                  
           lr_decay=0.9999 dropout=0.5 hidden=100 vocab_size=None               
           log_every=200 seed=42 deterministic=True                             
           data_path: data/irt_dataset.jsonlines                      cli.py:111
           output directory: data/irt_model                           cli.py:112
[23:12:19] amortized: False                                       dataset.py:112
[23:12:19] Vocab size: None                                       training.py:90
           Training Model...                                          cli.py:116
           args: {'device': 'cpu', 'num_items': 12578,           training.py:134
           'num_subjects': 4}                                                   
           Parsed Model Args: {'device': 'cpu', 'num_items':     training.py:147
           12578, 'num_subje

In [28]:
import pickle

def get_lambda(b, v):
    """Return b**2 / (v + b**2).

    Args:
        b: first term; it enters the ratio squared.
        v: second term; it is added to b**2 in the denominator.

    Returns:
        The ratio b**2 / (v + b**2): 0 when b is 0 (and v is not), 1 when v is 0
        (and b is not). No guard is applied when both are 0.
    """
    b_squared = b**2
    return b_squared/(v+b_squared)

number_item = 100  # passed to KMeans as n_clusters when the anchor points are selected

lambds = {}

# errors2[ind_D] holds one validation error per scenario, in scenarios.keys() order.
for scenario, val_error in zip(scenarios.keys(), errors2[ind_D]):
    # Variance across the scenario's items, computed per training row and then averaged.
    v = np.var(Y_bin_train[:,scenarios_position[scenario]], axis=1).mean()
    b = np.mean(val_error)
    lambds[scenario] = get_lambda(b, v/(4*number_item))

# Persist the per-scenario lambdas so they can be reloaded later.
with open('data/lambds.pickle', 'wb') as lambds_file:
    pickle.dump(lambds, lambds_file, protocol=pickle.HIGHEST_PROTOCOL)

# Generate anchor points

In [29]:
# Item representation used when clustering items into anchor points.
# The clustering cell accepts exactly two values: 'correct.' (columns of
# Y_bin_train) or 'irt' (loaded IRT parameters); anything else raises
# NotImplementedError there.
clustering = 'irt' # 'correct.' or 'irt'

In [30]:
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import pairwise_distances

random_state = 42

anchor_points = {}   # scenario -> positions (within the scenario) of the selected anchor items
anchor_weights = {}  # scenario -> weight of each anchor (total balance weight of its cluster)

# The IRT parameters do not depend on the scenario, so load them and build the
# item feature matrix once instead of re-reading the model for every scenario.
if clustering=='irt':
    A, B, _ = load_irt_parameters('data/irt_model/')
    irt_features = np.vstack((A.squeeze(), B.squeeze().reshape((1,-1)))).T

for scenario in scenarios.keys():
    positions = scenarios_position[scenario]

    # One row of X per item of the scenario
    if clustering=='correct.':
        X = Y_bin_train[:,positions].T
    elif clustering=='irt':
        X = irt_features[positions]
    else:
        raise NotImplementedError 
        
    #Normalizing balance_weights, so their sum is one within each scenario.
    # Out-of-place division: the global balance_weights must never be modified here,
    # whatever kind of index `positions` is (an in-place `/=` would write through a view).
    scenario_weights = balance_weights[positions]
    norm_balance_weights = scenario_weights/scenario_weights.sum()

    # Fitting the KMeans model
    kmeans = KMeans(n_clusters=number_item, n_init="auto", random_state=random_state)
    kmeans.fit(X, sample_weight=norm_balance_weights)

    # Calculating anchor points: for each cluster centre, the closest item in X
    anchor_points[scenario] = pairwise_distances(kmeans.cluster_centers_, X, metric='euclidean').argmin(axis=1)

    # Calculating anchor weights: summed normalized weight of the items assigned to each cluster
    anchor_weights[scenario] = np.array([np.sum(norm_balance_weights[kmeans.labels_==c]) for c in range(number_item)])

In [31]:
# Bundle anchor positions and weights, then persist them together.
anchor = dict(anchor_points=anchor_points, anchor_weights=anchor_weights)

with open('data/anchor.pickle', 'wb') as anchor_file:
    pickle.dump(anchor, anchor_file, protocol=pickle.HIGHEST_PROTOCOL)

In [32]:
# Check on the held-out rows: anchor-weighted estimate vs. balance-weighted score over all items.
for scenario in scenarios.keys():
    positions = scenarios_position[scenario]

    # Held-out responses restricted to the scenario's anchor items.
    Y_anchor = Y_bin_test[:,positions][:,anchor_points[scenario]]
    Y_hat = (Y_anchor*anchor_weights[scenario]).sum(axis=1)
    Y_true = (balance_weights*Y_bin_test)[:,positions].mean(axis=1)

    avg_error = np.abs(Y_hat-Y_true).mean()
    print(f"scenario: {scenario}, avg. error: {avg_error:.3f}")

scenario: gqa, avg. error: 0.021


# Estimate Performance

In [34]:
# Parameters of the final IRT model.
A, B, _ = load_irt_parameters('data/irt_model/')

# Seen items: the anchor points of every scenario, mapped from within-scenario
# positions back to column indices of the response matrix.
seen_items = np.hstack([np.array(scenarios_position[scenario])[anchor_points[scenario]] for scenario in scenarios.keys()]).tolist()

# Unseen items: every other column. A set gives O(1) membership tests instead of
# scanning the seen_items list once per column.
seen_lookup = set(seen_items)
unseen_items = [i for i in range(Y_bin_train.shape[1]) if i not in seen_lookup]

In [35]:
# Estimate one ability vector per held-out row, using only its responses on the seen (anchor) items.
thetas = []
for j in tqdm(range(Y_bin_test.shape[0])):
    seen_responses = Y_bin_test[j][seen_items]
    thetas.append(estimate_ability_parameters(seen_responses, A[:, :, seen_items], B[:, :, seen_items]))

100%|██████████| 1/1 [00:00<00:00, 83.07it/s]


In [36]:
pirt_preds = {}

# Balance-weighted held-out responses; independent of scenario and row, so computed once.
weighted_test = balance_weights*Y_bin_test

for scenario in scenarios.keys():

    # Set lookup keeps the seen/unseen split linear in the number of items.
    scenario_items = set(scenarios_position[scenario])
    ind_seen = [u for u in seen_items if u in scenario_items]
    ind_unseen = [u for u in unseen_items if u in scenario_items]

    # Share of the scenario covered by its anchor points. Derived from this scenario's own
    # anchors rather than from a `Y_anchor` array left over from another cell's loop.
    pirt_lambd = len(anchor_points[scenario])/len(scenarios_position[scenario])

    pirt_pred = []
    
    for j in range(Y_bin_test.shape[0]):
        # Observed part: mean of the weighted responses on the seen items.
        data_part = weighted_test[j,ind_seen].mean()
        # Model part: mean of the weighted IRT-predicted values on the unseen items.
        irt_part = (balance_weights*item_curve(thetas[j], A, B))[0,ind_unseen].mean()
        pirt_pred.append(pirt_lambd*data_part + (1-pirt_lambd)*irt_part) 
        
    pirt_preds[scenario] = np.array(pirt_pred) # Predictions
    true = weighted_test[:,scenarios_position[scenario]].mean(axis=1) # True performance
    
    print(f"scenario: {scenario}, avg. error: {np.abs(pirt_preds[scenario]-true).mean():.3f}")

scenario: gqa, avg. error: 0.010


In [37]:
# Reload the per-scenario lambdas written earlier by this notebook.
with open('data/lambds.pickle', 'rb') as lambds_file:
    lambds = pickle.load(lambds_file)

In [38]:
# Anchor-only estimate: weighted sum of the held-out responses on the anchor items.
preds = {}
for scenario in scenarios.keys():
    positions = scenarios_position[scenario]
    Y_anchor = Y_bin_test[:,positions][:,anchor_points[scenario]]

    preds[scenario] = (Y_anchor*anchor_weights[scenario]).sum(axis=1) # Predictions
    true = (balance_weights*Y_bin_test)[:,positions].mean(axis=1) # True performance

    avg_error = np.abs(preds[scenario]-true).mean()
    print(f"scenario: {scenario}, avg. error: {avg_error:.3f}")

scenario: gqa, avg. error: 0.021


In [39]:
# Combined estimate: blend of the anchor-only and p-IRT predictions, weighted by the scenario's lambda.
gpirt_preds = {}
for scenario in scenarios.keys():
    lambd = lambds[scenario]
    gpirt_preds[scenario] = lambd*preds[scenario] + (1-lambd)*pirt_preds[scenario]
    true = (balance_weights*Y_bin_test)[:,scenarios_position[scenario]].mean(axis=1) # True performance

    avg_error = np.abs(gpirt_preds[scenario]-true).mean()
    print(f"Prediction: {gpirt_preds[scenario]} vs True: {true}")
    print(f"scenario: {scenario}, avg. error: {avg_error:.3f}")

Prediction: [0.7159501] vs True: [0.73548393]
scenario: gqa, avg. error: 0.020
