In [None]:
import sys
import pickle
import os
import torch
from tqdm.notebook import tqdm
from tqdm import tqdm
import random
from typing import Optional
from collections.abc import Iterator
from concurrent.futures import ProcessPoolExecutor
from torch.utils.data import Subset

# Make the project packages importable from this notebook's location by
# registering every ancestor directory up to six levels above the working
# directory. Each entry is pushed to the front of sys.path, so the deepest
# ancestor ('../../../../../..') ends up first.
for _relative_root in ('..', '../..', '../../..', '../../../..', '../../../../..', '../../../../../..'):
    sys.path.insert(0, _relative_root)
del _relative_root

from stochasticLSTM.model import StochasticLSTMWeytjens

In [None]:
# Load the trained model from one checkpoint twice:
#  - `model` (p_fix=0.05) drives the Monte Carlo forward passes,
#  - `model_without_drop` (p_fix=0) gives the deterministic point prediction.
file_path_model = '../../training_variational_dropout/BPIC17/BPIC17_weytjens_rem_time_1_suffix_length5.pkl'
model, model_without_drop = (
    StochasticLSTMWeytjens.load(file_path_model, p_fix=p_fix, device='cpu')
    for p_fix in (0.05, 0)
)

# Encoded BPIC17 test split. weights_only=False unpickles arbitrary Python
# objects, so only point this at trusted files.
file_path_data_set = '../../../../../../encoded_data/compare_weytjens/BPIC_2017_all_5_test.pkl'
bpic17_test_dataset = torch.load(file_path_data_set, weights_only=False)

# Debugging tip: to evaluate only a few cases, wrap the dataset in
# torch.utils.data.Subset(bpic17_test_dataset, indices=[...]) and copy over the
# `all_categories` and `encoder_decoder` attributes, which the evaluation reads.


In [None]:
# Evaluation context shared by the helper functions below. Worker processes get
# it through `init_worker` (used as the ProcessPoolExecutor initializer); the
# sequential path calls `init_worker` directly in the main process.
global_model = global_model_without_drop = None
global_samples_per_case = global_act_categories = global_scaler_params = None


def init_worker(model, model_without_drop, samples_per_case, act_categories, scaler_params):
    """
    Install the evaluation context into this process's module globals.

    Both models are put into eval mode (stochastic model first), then the five
    arguments are stored in the corresponding `global_*` variables.
    The models are expected to already be on the CPU.
    """
    global global_model, global_model_without_drop, global_samples_per_case, global_act_categories, global_scaler_params

    for net in (model, model_without_drop):
        net.eval()

    (global_model,
     global_model_without_drop,
     global_samples_per_case,
     global_act_categories,
     global_scaler_params) = (model, model_without_drop, samples_per_case, act_categories, scaler_params)

In [None]:
def iterate_case(case: tuple[list[torch.Tensor], list[torch.Tensor]], concept_name_id: int, min_suffix_size: int) -> Iterator[tuple[int, tuple[list[torch.Tensor], list[torch.Tensor]]]]:
    """
    Yield the successive prefixes of a padded case as right-aligned windows.

    At step i (for i in range(seq_len - min_suffix_size - 1)) the window holds
    positions 0..i of every attribute, right-aligned and preceded by zeros.
    The position just before the trailing `min_suffix_size` positions, and the
    suffix itself, are therefore never part of a prefix. Leading steps whose
    `concept_name_id` value is 0 are treated as padding and yield nothing. Once
    the first non-zero activity is seen, every later step yields.

    :param case: (categorical tensors, numerical tensors[, ...]). Each tensor is
        assumed 1-D with length seq_len (TODO confirm for other attributes).
        seq_len is taken from the first categorical tensor.
    :param concept_name_id: Index of the activity attribute among the categorical tensors.
    :param min_suffix_size: Number of trailing suffix positions to exclude.
    :yield: (prefix_length, (cat_windows, num_windows)). Each window has shape
        (1, len(attribute)). Every yielded prefix is made of fresh tensors, so
        callers may safely keep them. The previous implementation mutated and
        re-yielded the same tensors on every step.
    """
    cats, nums = case[0], case[1]
    seq_len = cats[0].shape[0]
    activities = cats[concept_name_id]

    def _window(attribute: torch.Tensor, end: int) -> torch.Tensor:
        # Zero-padded (1, len) tensor with attribute[:end] right-aligned at the end.
        window = torch.zeros_like(attribute).unsqueeze(0)
        window[0, -end:] = attribute[:end]
        return window

    prefix_length = 0
    for i in range(seq_len - min_suffix_size - 1):
        # Skip leading padding until the first real activity appears.
        if not (prefix_length or activities[i]):
            continue
        prefix_length += 1
        end = i + 1
        yield prefix_length, (
            [_window(cat_attribute, end) for cat_attribute in cats],
            [_window(num_attribute, end) for num_attribute in nums],
        )

In [None]:

def _decode_prefix(prefix, concept_name_id, act_categories, unscale):
    """
    Convert a right-aligned prefix window into a list of readable events.

    Positions whose activity id is 0 (padding) are skipped. Every other position
    becomes {'concept:name': <activity label>, 'case_elapsed_time': <unscaled time>}.
    The time is read from the first numerical attribute.
    """
    activity_ids = prefix[0][concept_name_id][0].tolist()
    elapsed_times = prefix[1][0][0]
    events = []
    for idx, activity_id in enumerate(activity_ids):
        if activity_id == 0:
            continue
        # Reverse lookup label <- id in the category mapping (first matching key).
        activity = next(k for k, v in act_categories.items() if v == activity_id)
        events.append({'concept:name': activity,
                       'case_elapsed_time': unscale(elapsed_times[idx].item())})
    return events


def _evaluate_case(case_name: str,
                   full_case: tuple[list[torch.Tensor], list[torch.Tensor], str],
                   concept_name_id: int,
                   min_suffix_size: int):
    """
    Evaluate every prefix of one case with the models installed by `init_worker`.

    Reads the per-process globals `global_model`, `global_model_without_drop`,
    `global_samples_per_case`, `global_act_categories` and `global_scaler_params`.

    :param case_name: Identifier of the case; copied into every result tuple.
    :param full_case: Tuple of (categorical tensors, numerical tensors, case_name).
        The first numerical tensor is assumed to be the scaled case elapsed time.
    :param concept_name_id: Index of the activity ('concept:name') attribute, used
        both for the categorical tensors and for `global_act_categories`.
    :param min_suffix_size: Number of trailing suffix positions in the case.
    :return: List with one tuple per prefix:
        (case_name, prefix_length, prefix_prep, mc_samples, target, most_likely).
        All elapsed times are mapped back via value * std + mean.
    """
    _, nums, _ = full_case
    mean_s, std_s = global_scaler_params

    def unscale(value):
        # Inverse of the standard scaling applied to 'case_elapsed_time'.
        return value * std_s + mean_s

    # Target: the unscaled elapsed time at the position just before the trailing
    # suffix. It is the same for every prefix of this case.
    raw_target = nums[0][-1 - min_suffix_size].item()
    target = [{'case_elapsed_time': unscale(raw_target)}]

    act_categories = global_act_categories[concept_name_id][2]

    results = []
    # Pure inference: without no_grad every one of the many forward passes would
    # build and hold an autograd graph. Sampled values are unaffected.
    with torch.no_grad():
        for prefix_length, prefix in iterate_case(full_case, concept_name_id, min_suffix_size):
            # Monte Carlo samples from the stochastic model: each pass yields a
            # mean/log-variance, and one value is drawn from that normal.
            mc_samples = []
            for _ in range(global_samples_per_case):
                mean, logvar = global_model(input=prefix)
                mean = mean.squeeze(0)
                std = torch.exp(0.5 * logvar).squeeze(0)
                sample = torch.clamp(unscale(torch.normal(mean=mean, std=std)), min=0.0)
                mc_samples.append([{'case_elapsed_time': sample.item()}])

            # Deterministic point prediction from the model loaded with p_fix=0.
            mean_cet, _ = global_model_without_drop(input=prefix)
            mean_cet = torch.clamp(unscale(mean_cet.squeeze(0)), min=0.0)
            most_likely = [{'case_elapsed_time': mean_cet.item()}]

            prefix_prep = _decode_prefix(prefix, concept_name_id, act_categories, unscale)

            results.append((case_name, prefix_length, prefix_prep, mc_samples, target, most_likely))

    return results
        


In [None]:
def evaluate_seq_processing(model: StochasticLSTMWeytjens,
                            model_without_drop: StochasticLSTMWeytjens,
                            dataset,
                            samples_per_case: Optional[int] = 1000,
                            random_order: Optional[bool]= False):
    """
    Evaluate all finished cases of `dataset` in this process, one case at a time.

    A case counts as finished when its last `min_suffix_size` activity tokens all
    equal the 'EOS' id. If several dataset items share a case name, the last one wins.

    :param model: Stochastic model used for the Monte Carlo samples.
    :param model_without_drop: Model used for the deterministic prediction.
    :param dataset: Encoded dataset exposing `all_categories` and `encoder_decoder`.
    :param samples_per_case: Monte Carlo samples drawn per prefix.
    :param random_order: Shuffle the case order with `random.shuffle`, the same
        way as the parallel variant.
    :yield: Result tuples (case_name, prefix_length, prefix_prep, mc_samples,
        target, most_likely), as produced by `_evaluate_case`.
    :raises ValueError: If the dataset lacks a 'concept:name' attribute or an 'EOS' activity.
    """
    model.to('cpu')
    model_without_drop.to('cpu')

    min_suffix_size = dataset.encoder_decoder.min_suffix_size

    # Index of the activity attribute among the categorical attributes.
    concept_name = 'concept:name'
    concept_name_id = next((i for i, cat in enumerate(dataset.all_categories[0]) if cat[0] == concept_name), None)
    if concept_name_id is None:
        raise ValueError(f"Dataset has no categorical attribute named {concept_name!r}")

    # Encoded id of the EOS token in the activity vocabulary.
    eos_value = 'EOS'
    eos_id = next((v for k, v in dataset.all_categories[0][concept_name_id][2].items() if k == eos_value), None)
    if eos_id is None:
        raise ValueError(f"Activity vocabulary has no {eos_value!r} token")

    # Keep only items whose whole suffix is EOS, keyed by case name.
    cases = {}
    for event in dataset:
        suffix = event[0][concept_name_id][-min_suffix_size:]
        if torch.all(suffix == eos_id).item():
            cases[event[2]] = event

    case_items = list(cases.items())
    if random_order:
        random.shuffle(case_items)

    cat_categories, _ = model.data_set_categories
    scaler = dataset.encoder_decoder.continuous_encoders['case_elapsed_time']
    scaler_params = (scaler.mean_.item(), scaler.scale_.item())

    # Install the same globals the parallel workers use, so both paths share _evaluate_case.
    init_worker(model, model_without_drop, samples_per_case, cat_categories, scaler_params)

    for case_name, full_case in tqdm(case_items, total=len(case_items)):
        yield from _evaluate_case(case_name, full_case,
                                  min_suffix_size=min_suffix_size,
                                  concept_name_id=concept_name_id)

In [None]:
def _evaluate_case_wrapper(args):
    """
    Adapter for `executor.map`, which hands each work item over as one tuple.

    Spreads that tuple, (case_name, full_case, concept_name_id, min_suffix_size),
    into positional arguments for `_evaluate_case` and returns its result list.
    Kept at module level so worker processes can look it up by name.
    """
    positional = tuple(args)
    return _evaluate_case(*positional)

def _init_parallel_worker(threads_per_worker, *init_args):
    """
    ProcessPoolExecutor initializer: give the worker its own RNG stream, cap its
    intra-op thread count, then install the shared globals via `init_worker`.
    """
    # Forked workers inherit an identical copy of the parent's torch generator, so
    # without reseeding every worker would draw the same Monte Carlo randomness.
    torch.seed()
    torch.set_num_threads(threads_per_worker)
    init_worker(*init_args)


def evaluate_parallel_processing(model: StochasticLSTMWeytjens,
                                 model_without_drop: StochasticLSTMWeytjens,
                                 dataset,
                                 samples_per_case: Optional[int] = 1000,
                                 random_order: Optional[bool] = False,
                                 num_processes: Optional[int] = 4,
                                 threads_per_worker: Optional[int] = None):
    """
    Evaluate all finished cases of `dataset` across a pool of worker processes.

    A case counts as finished when its last `min_suffix_size` activity tokens all
    equal the 'EOS' id. Results are yielded in case order (executor.map preserves it).

    :param model: Stochastic model used for the Monte Carlo samples.
    :param model_without_drop: Model used for the deterministic prediction.
    :param dataset: Encoded dataset exposing `all_categories` and `encoder_decoder`.
    :param samples_per_case: Monte Carlo samples drawn per prefix.
    :param random_order: Shuffle the case order before dispatching.
    :param num_processes: Number of worker processes (None = executor default).
    :param threads_per_worker: torch intra-op threads per worker. None splits the
        machine's CPUs evenly between the workers, which avoids oversubscription.
    :yield: Result tuples (case_name, prefix_length, prefix_prep, mc_samples,
        target, most_likely), as produced by `_evaluate_case`.
    :raises ValueError: If the dataset lacks a 'concept:name' attribute or an 'EOS' activity.
    """
    # 1) Models must live on the CPU before they are handed to the workers.
    model.to('cpu')
    model_without_drop.to('cpu')

    min_suffix_size = dataset.encoder_decoder.min_suffix_size

    # 2) Resolve ids. A bare next() here would raise StopIteration, which Python
    #    converts into an opaque RuntimeError inside a generator (PEP 479).
    concept_name = 'concept:name'
    concept_name_id = next((i for i, cat in enumerate(dataset.all_categories[0]) if cat[0] == concept_name), None)
    if concept_name_id is None:
        raise ValueError(f"Dataset has no categorical attribute named {concept_name!r}")

    eos_value = 'EOS'
    eos_id = next((v for k, v in dataset.all_categories[0][concept_name_id][2].items() if k == eos_value), None)
    if eos_id is None:
        raise ValueError(f"Activity vocabulary has no {eos_value!r} token")

    # 3) Collect only "finished" cases (suffix entirely EOS), keyed by case name.
    cases = {}
    for event in dataset:
        suffix = event[0][concept_name_id][-min_suffix_size:]
        if torch.all(suffix == eos_id).item():
            cases[event[2]] = event

    case_items = list(cases.items())
    if random_order:
        random.shuffle(case_items)

    # 4) Constants every worker needs; installed once per worker by the initializer.
    cat_categories, _ = model.data_set_categories
    scaler = dataset.encoder_decoder.continuous_encoders['case_elapsed_time']
    scaler_params = (scaler.mean_.item(), scaler.scale_.item())
    init_args = (model, model_without_drop, samples_per_case, cat_categories, scaler_params)

    if threads_per_worker is None:
        cpu_total = os.cpu_count() or 1
        threads_per_worker = max(1, cpu_total // (num_processes or cpu_total))

    # One work item per case: the positional arguments of _evaluate_case.
    pool_inputs = [(case_name, full_case, concept_name_id, min_suffix_size) for case_name, full_case in case_items]

    executor = ProcessPoolExecutor(max_workers=num_processes,
                                   initializer=_init_parallel_worker,
                                   initargs=(threads_per_worker, *init_args))
    try:
        mapper = executor.map(_evaluate_case_wrapper, pool_inputs, chunksize=1)
        for case_results in tqdm(mapper,
                                 total=len(pool_inputs),
                                 desc="Parallel eval",
                                 unit="case"):
            yield from case_results
    finally:
        # If the consumer stops early (break / exception), drop the still-queued
        # cases instead of waiting for all of them as a plain `with` block would.
        executor.shutdown(wait=True, cancel_futures=True)


In [None]:
output_dir = '../../../../../../../../data/BPIC17/eval_weytjens_sl5/'

def save_chunk(results, i):
    """
    Pickle one chunk of evaluation results into `output_dir`.

    :param results: Dict mapping (case_name, prefix_length) to the stored result tuple.
    :param i: 0-based index, in the overall result stream, of the last result in
        this chunk. The file is numbered i + 1, i.e. by the number of results
        produced so far (e.g. results_part_050.pkl, results_part_100.pkl, ...).
    """
    # The output directory is several levels up and may not exist yet; open()
    # would otherwise fail with FileNotFoundError on the first save.
    os.makedirs(output_dir, exist_ok=True)
    chunk_number = i + 1
    filename = os.path.join(output_dir, f'results_part_{chunk_number:03d}.pkl')
    with open(filename, 'wb') as f:
        pickle.dump(results, f)
    print(f"Saved {len(results)} results to {filename}")
    sys.stdout.flush()

In [None]:
# Number of worker processes for the parallel evaluation.
num_processes = 32

# Write the accumulated results to disk after every `save_every` yielded prefixes.
save_every = 50

# Set to False to run the single-process sequential evaluation instead
# (slower, but easier to debug).
use_parallel = True

if use_parallel:
    evaluation = evaluate_parallel_processing(model=model,
                                              model_without_drop=model_without_drop,
                                              dataset=bpic17_test_dataset,
                                              num_processes=num_processes)
else:
    evaluation = evaluate_seq_processing(model=model,
                                         model_without_drop=model_without_drop,
                                         dataset=bpic17_test_dataset)

results = {}
# `results` is reset after every save, so it alone cannot detect a duplicate
# (case, prefix length) that straddles a chunk boundary; track all keys here.
seen_keys = set()

for i, (case_name, prefix_len, prefix, sampled_cets, target_cet, mean_cet) in enumerate(evaluation):
    key = (case_name, prefix_len)
    assert key not in seen_keys, f"Duplicate result for case {case_name!r}, prefix length {prefix_len}"
    seen_keys.add(key)

    results[key] = (prefix, target_cet, mean_cet, sampled_cets)

    if (i + 1) % save_every == 0:
        save_chunk(results, i)
        results = {}

# Flush the last, partially filled chunk (if any).
if results:
    save_chunk(results, i)