In [None]:
import glob
import itertools
import json
import os
import pickle
import random
import shutil
from itertools import zip_longest
from os.path import join
from pprint import pprint
from typing import Dict, List, Tuple, Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from tqdm.notebook import tqdm
from transformers import Mamba2ForCausalLM

sns.set()

SEED = 23
ROOT = "/path/to/project"
ID2LABEL_PATH = '/path/to/codes_dict'
DATA_ROOT = '/path/to/data'
FORECAST_INPUT_PATH = f'{DATA_ROOT}/prompt_data'
FORECAST_OUTPUT_PATH = f'{DATA_ROOT}/forecast_data'
SYPHILIS_PATH = f'{DATA_ROOT}/syphilis_data/*.parquet'
MEMBERSHIP_INFERENCE_RESULTS_FILE = "/h/afallah/odyssey/odyssey/membership_detection_results.parquet"
os.chdir(ROOT)

AUGMENTED_TOKEN_MAP = {
    "type_ids": "type_tokens",
    "ages": "age_tokens",
    "time_stamps": "time_tokens",
    "visit_orders": "position_tokens",
    "visit_segments": "visit_tokens",
}

ADDITIONAL_TOKEN_TYPES = list(AUGMENTED_TOKEN_MAP.keys())
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from odyssey.data.tokenizer import ConceptTokenizer
from odyssey.data.dataset import PretrainDataset, PretrainDatasetDecoder, FinetuneDatasetDecoder
from odyssey.models.model_utils import load_pretrain_data, load_finetune_data
from odyssey.evals.prediction import load_pretrained_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class args:
    """Notebook-wide configuration, used as a plain namespace of class attributes."""

    # --- Data locations (relative to the working directory) ---
    data_dir = "odyssey/data/meds_data"
    vocab_dir = data_dir + "/vocab"
    sequence_file = "patient_sequences_2048.parquet"
    id_file = "dataset_2048.pkl"

    # --- Fine-tuning setup ---
    valid_scheme = "few_shot"
    num_finetune_patients = "all"
    tasks = [
        "mortality_1month",
        "readmission_1month",
        "los_1week",
        "c0",
        "c1",
        "c2",
    ]

    # --- Model checkpoint locations ---
    checkpoint_dir = "checkpoints/mamba2_pretrain_2048"
    model_path = checkpoint_dir + "/best.ckpt"
    model_huggingface_dir = checkpoint_dir + "/huggingface"

    # --- Inference / generation sizes ---
    max_len = 512
    batch_size = 64
    num_return_sequences = 100
    chunk_size = 2000
    max_patients = 50_000

# Resolve the input shards once and attach them to the config namespace.
# `glob.glob` returns files in arbitrary (filesystem-dependent) order, so the
# lists are sorted to make every run iterate over the shards identically.
args.syphilis_sequences = sorted(glob.glob(SYPHILIS_PATH))
args.forecast_inputs = sorted(glob.glob(os.path.join(FORECAST_INPUT_PATH, "*.parquet")))

# **Plot Token Distribution**

In [None]:
# Frequencies recorded in the vocabulary metadata file.
# Opened with a context manager so the file handle is closed deterministically
# instead of being left to the garbage collector.
with open(os.path.join(args.vocab_dir, "metadata/meta_vocab.json")) as meta_vocab_file:
    meta_vocab = json.load(meta_vocab_file)

token2freq = {token: data['frequency'] for token, data in meta_vocab.items()}

# Most frequent token first.
token2freq = dict(sorted(token2freq.items(), key=lambda x: x[1], reverse=True))

# ---

# Frequencies counted over the generated sequences stored in `predicted_tokens`.
random_df = pd.read_parquet(f"{DATA_ROOT}/forecast_december_20/ehrmamba2_cls_prompt.parquet")
random_token2freq = {}
for sequence in random_df['predicted_tokens']:
    for token in sequence:
        random_token2freq[token] = random_token2freq.get(token, 0) + 1

# Most frequent token first.
random_token2freq = dict(sorted(random_token2freq.items(), key=lambda x: x[1], reverse=True))

In [None]:
# Get tokens above threshold frequency and their frequencies from both distributions
threshold = 25
filtered_tokens = [t for t, f in token2freq.items() if f >= threshold and random_token2freq.get(t, 0) >= threshold]
# Token type = text before the first '//' and, within that, before the first '_'
token_types = [t.split('//')[0].split('_')[0] for t in filtered_tokens]
# Sorted instead of raw set order: set iteration order of strings can change
# between kernel sessions, which would reshuffle the type -> color mapping.
unique_types = sorted(set(token_types))

# Custom high-contrast color palette (manually chosen for clarity)
custom_palette = [
    "#1f77b4",  # blue
    "#ff7f0e",  # orange
    "#2ca02c",  # green
    "#d62728",  # red
    "#9467bd",  # purple
    "#8c564b",  # brown
    "#e377c2",  # pink
    "#7f7f7f",  # gray
    "#bcbd22",  # olive
    "#17becf",  # cyan
]
# If more types than colors, cycle through
type_to_color = dict(zip(unique_types, itertools.cycle(custom_palette)))
colors = [type_to_color[t] for t in token_types]

original_freqs = [token2freq[t] for t in filtered_tokens]
random_freqs = [random_token2freq[t] for t in filtered_tokens]

# Create dataframe for plotting
plot_df = pd.DataFrame({
    'Original Frequency': original_freqs,
    'Generated Frequency': random_freqs,
    'Type': token_types
})

# Font sizes shared by the figure elements (base size scaled by 1.3 * 1.2)
tick_fontsize = int(12 * 1.3 * 1.2)
label_fontsize = int(16 * 1.3 * 1.2)
legend_title_fontsize = int(13 * 1.3 * 1.2)

# Create scatter plot
plt.figure(figsize=(10, 10))
plt.rcParams['font.family'] = 'sans-serif'

scatter = sns.scatterplot(
    data=plot_df,
    x='Original Frequency',
    y='Generated Frequency',
    hue='Type',
    alpha=0.5,
    palette=type_to_color
)
plt.xscale('log')
plt.yscale('log')
# Tick label size is applied on the final axes, after the data is drawn and the
# log scales are set, so that the ticks created for the log axes pick it up.
scatter.tick_params(axis='both', which='major', labelsize=tick_fontsize)
plt.title(
    f'Token Frequency Distribution by Type (both freq >= {threshold:,})',
    fontsize=label_fontsize,
    pad=30,
    loc='center'  # Center the title
)
plt.xlabel('Original Frequency', fontsize=label_fontsize)
plt.ylabel('Generated Frequency', fontsize=label_fontsize)

# Make legend font larger and add background
legend = plt.legend(fontsize=tick_fontsize, title='Type', title_fontsize=legend_title_fontsize, frameon=True)
frame = legend.get_frame()
frame.set_facecolor('white')
frame.set_edgecolor('black')
frame.set_alpha(0.9)

plt.tight_layout()
plt.show()

In [None]:
# Get tokens above threshold frequency and their frequencies
threshold = 5
filtered_tokens = [t for t, f in token2freq.items() if f >= threshold and random_token2freq.get(t, 0) >= threshold]
# Token type = text before the first '//' and, within that, before the first '_'
token_types = [t.split('//')[0].split('_')[0] for t in filtered_tokens]
# Sorted instead of raw set order so each type keeps the same color on every run
unique_types = sorted(set(token_types))
type_to_color = dict(zip(unique_types, sns.color_palette("husl", len(unique_types))))
colors = [type_to_color[t] for t in token_types]

original_freqs = [token2freq[t] for t in filtered_tokens]
random_freqs = [random_token2freq[t] for t in filtered_tokens]

# Normalize frequencies so both curves sum to 1 over the filtered tokens
total_original = sum(original_freqs)
total_random = sum(random_freqs)
original_freqs_norm = [f/total_original for f in original_freqs]
random_freqs_norm = [f/total_random for f in random_freqs]

# Create dataframe for plotting.
# `filtered_tokens` follows the order of `token2freq` (descending original
# frequency), so the position is the token's rank. Ranks start at 1: the x-axis
# is logarithmic and a rank of 0 cannot be drawn on it, which previously
# dropped the most frequent token from the figure.
plot_df = pd.DataFrame({
    'Token': range(1, len(filtered_tokens) + 1),
    'Original': original_freqs_norm,
    'Generated': random_freqs_norm,
    'Type': token_types
}).melt(id_vars=['Token', 'Type'], var_name='Distribution', value_name='Frequency')

# Create line plot
plt.figure(figsize=(10, 10))

# Plot Original in black first
original_data = plot_df[plot_df['Distribution'] == 'Original']
sns.lineplot(data=original_data, x='Token', y='Frequency', color='black',
             linewidth=2, label='Original')

# Plot Generated with colors by type
generated_data = plot_df[plot_df['Distribution'] == 'Generated']
sns.lineplot(data=generated_data, x='Token', y='Frequency', hue='Type',
             linewidth=0.7, palette=type_to_color)

plt.xscale('log')
plt.yscale('log')
plt.title(f'Normalized Token Frequency Distribution by Type (both freq >= {threshold:,})')
plt.xlabel('Token Rank')
plt.ylabel('Normalized Frequency')
plt.tight_layout()
plt.show()

# **Load Model**

In [None]:
# Build the project's concept tokenizer, pointing it at `args.vocab_dir`
# (presumably where the vocabulary files live -- confirm in ConceptTokenizer).
# Sequences are delimited by "[BOS]" / "[EOS]" and padded on the right; no
# explicit time tokens are supplied.
tokenizer = ConceptTokenizer(
    data_dir=args.vocab_dir,
    start_token="[BOS]",
    end_token="[EOS]",
    time_tokens=None,
    padding_side="right"
)
# Must be called before the tokenizer is used below.
# NOTE(review): presumably builds the token <-> id mapping from the vocab files.
tokenizer.fit_on_vocab()

In [None]:
# Load the Mamba2 causal LM from the Hugging Face-format directory in
# `args.model_huggingface_dir`. Weights are loaded in bfloat16, so model outputs
# are bfloat16 too: upcast (e.g. `.float()`) before converting them to NumPy,
# which has no bfloat16 dtype. `device_map="auto"` lets the library decide
# device placement.
model = Mamba2ForCausalLM.from_pretrained(
    args.model_huggingface_dir,
    device_map="auto",
    torch_dtype=torch.bfloat16
)
# Switch to inference mode (disables dropout etc.); as the last expression of
# the cell, the returned module is also displayed.
model.eval()

# **Forecasting**

In [None]:
def get_dataset_and_loader(data, tokenizer, max_len, batch_size, num_workers=4):
    """Create the dataset and a non-shuffling dataloader for inference.

    Args:
        data: DataFrame containing patient data
        tokenizer: Tokenizer instance
        max_len: Max sequence length to use
        batch_size: Batch size for dataloader
        num_workers: Number of worker processes for the dataloader; ``0`` loads
            batches in the main process.

    Returns:
        dataset: PretrainDatasetDecoder instance
        dataloader: DataLoader instance iterating ``dataset`` in order
    """
    dataset = PretrainDatasetDecoder(
        data=data,
        tokenizer=tokenizer,
        max_len=max_len,
        additional_token_types=None,
        padding_side="right",
        return_attention_mask=False,
        return_labels=False
    )

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        # Order must be preserved: callers align batches with patient IDs by position.
        shuffle=False,
        num_workers=num_workers,
        # DataLoader raises ValueError for persistent_workers=True when
        # num_workers == 0, so only keep workers alive when there are any.
        persistent_workers=num_workers > 0
    )

    return dataset, dataloader

In [None]:
def load_test_data(
    data_dir: str,
    sequence_file: str,
    id_file: str,
) -> pd.DataFrame:
    """Return the rows of the sequence file that belong to the test split.

    Args:
        data_dir: Directory that both files are resolved against.
        sequence_file: Parquet file with patient sequences, relative to ``data_dir``.
        id_file: Pickle file, relative to ``data_dir``, holding a mapping whose
            ``"test"`` entry lists the test patient IDs.

    Returns:
        The rows whose ``patient_id`` appears in the ``"test"`` entry.

    Raises:
        FileNotFoundError: If either file is missing (the sequence file is
            checked first).
    """
    sequence_path, id_path = (
        join(data_dir, file_name) for file_name in (sequence_file, id_file)
    )

    # Fail early with an explicit message rather than inside the readers.
    sequence_missing = not os.path.exists(sequence_path)
    if sequence_missing:
        raise FileNotFoundError(f"Sequence file not found: {sequence_path}")

    id_missing = not os.path.exists(id_path)
    if id_missing:
        raise FileNotFoundError(f"ID file not found: {id_path}")

    # Read with polars, then hand back a pandas frame.
    sequences = pl.read_parquet(sequence_path).to_pandas()

    # NOTE: pickle can execute arbitrary code; only load trusted ID files.
    with open(id_path, "rb") as id_handle:
        split_ids = pickle.load(id_handle)

    in_test_split = sequences["patient_id"].isin(split_ids["test"])
    return sequences.loc[in_test_split]

# Both splits are read from the same sequence file and ID dictionary, located
# in sub-folders of `args.data_dir`.
sequence_rel_path = 'patient_sequences/' + args.sequence_file
id_rel_path = 'patient_id_dict/' + args.id_file

# Load pretrain data
pretrain = load_pretrain_data(args.data_dir, sequence_rel_path, id_rel_path)

# Load the rows belonging to the "test" entry of the ID dictionary
test = load_test_data(args.data_dir, sequence_rel_path, id_rel_path)

# Or load a sample subset of the data
# sample = train_dataset[1000]
# sample['concept_ids'] = sample['concept_ids'][:26]
# sample['attention_mask'] = sample['attention_mask'][:26]
# sample = {k: v.unsqueeze(0).to(device) for k, v in sample.items()}
# concept_ids = sample['concept_ids'].unsqueeze(0).to(device)
# sample['concept_ids']

# generated = model.generate(
    # input_ids=[sample['concept_ids']],
    # max_new_tokens=10,
    # attention_mask=sample['attention_mask'],
    # labels=sample['labels']
# )

# cutoff = 1
# selected_dataset = pretrain.loc[pretrain['event_tokens'].transform(len) > 1]  #100
# selected_dataset.loc[:, 'event_tokens'] = selected_dataset['event_tokens'].transform(lambda x: x[:cutoff])

# random.seed(SEED)
# patient_ids = selected_dataset['patient_id'].unique().tolist()
# random.shuffle(patient_ids)

# # patient_ids = patient_ids[:10_000]
# selected_dataset = selected_dataset.loc[selected_dataset['patient_id'].isin(patient_ids)]
# selected_dataset = selected_dataset.set_index('patient_id').loc[patient_ids].reset_index()

In [None]:
def _sample_long_patients(data, max_patients, min_events=100, seed=SEED):
    """Keep sequences with more than `min_events` event tokens and sample patients.

    Returns the filtered rows of the sampled patients together with the sampled
    patient IDs in their (seeded) shuffled order.
    """
    long_sequences = data.loc[data["event_tokens"].map(len) > min_events]

    # Re-seed on every call so each cohort is reproducible on its own.
    random.seed(seed)
    sampled_ids = long_sequences['patient_id'].unique().tolist()
    random.shuffle(sampled_ids)
    sampled_ids = sampled_ids[:max_patients]

    sampled = long_sequences.loc[long_sequences['patient_id'].isin(sampled_ids)]
    return sampled, sampled_ids


def _order_by_patient(data, ordered_ids):
    """Reorder rows to follow `ordered_ids`, keeping `patient_id` as a column."""
    return data.set_index('patient_id').loc[ordered_ids].reset_index()


# --- Pretrain cohort: up to 50k patients, prompts cut to the first `cutoff` tokens ---
cutoff = 2
selected_dataset, pretrain_patient_ids = _sample_long_patients(pretrain, max_patients=50_000)
# `.assign` builds a new frame instead of writing into a filtered slice, which
# avoids pandas' SettingWithCopyWarning.
selected_dataset = selected_dataset.assign(
    event_tokens=selected_dataset['event_tokens'].map(lambda x: x[:cutoff])
)
selected_dataset = _order_by_patient(selected_dataset, pretrain_patient_ids)

# ---

# --- Test cohort: up to 10k patients, full-length sequences (no cutoff applied) ---
selected_dataset_test, test_patient_ids = _sample_long_patients(test, max_patients=10_000)
selected_dataset_test = _order_by_patient(selected_dataset_test, test_patient_ids)

# Kept for cells that still read `patient_ids`: as before, after this cell it
# holds the TEST cohort IDs. Prefer `pretrain_patient_ids` / `test_patient_ids`.
patient_ids = test_patient_ids


# Option 2
# selected_dataset = selected_dataset[
#     selected_dataset["event_tokens"].apply(
#         lambda x: any("DIAG" in str(x[i]) for i in range(4, min(7, len(x))))
#     )
# ]

# selected_dataset.loc[:, "diag_loc"] = selected_dataset["event_tokens"].apply(
#     lambda tokens: next(
#         (i for i in range(len(tokens)) if "DIAG" in str(tokens[i]) and i > 3), None
#     )
# )
# diag_locs = selected_dataset["diag_loc"].unique()

# for loc in diag_locs:
#     subset = selected_dataset[selected_dataset["diag_loc"] == loc].copy()

#     subset["event_tokens"] = subset.apply(
#         lambda row: row["event_tokens"][: int(row["diag_loc"]) + 1], axis=1
#     )

#     filename = f"prompt_diag_loc_{int(loc)}.parquet"
#     subset.to_parquet(filename, index=False)

#     print(f"Saved {len(subset)} rows to {filename}")

---
# **Membership Inference Test**
---

In [None]:
class MembershipDetector:
    """Membership-inference scorer based on the MIN-K% PROB statistic.

    For every sequence, the probability the model assigns to each actual next
    token is computed; the detection score is the mean log-probability of the
    ``k_percent`` least likely tokens. Scores of member and non-member
    sequences are then compared with ROC AUC.

    NOTE(review): every position of ``concept_ids`` is scored, including any
    padding positions present in the batch -- confirm this is intended.
    """

    def __init__(self, k_percent: float = 20.0):
        # Percentage (0-100) of lowest-probability tokens averaged per sequence.
        self.k_percent = k_percent

    def get_batch_probs(
        self,
        model: torch.nn.Module,
        input_ids: torch.Tensor,
        device: torch.device
    ) -> torch.Tensor:
        """
        Get the probability of each actual next token for a batch of sequences.

        A single forward pass is used: for a causal language model the logits at
        position ``t`` depend only on tokens ``<= t``, so they equal the
        last-position logits of a forward pass over the prefix ``[:t+1]``. This
        replaces one forward pass per prefix (quadratic in sequence length).

        Args:
            model: Causal LM whose output exposes ``.logits`` of shape
                [batch_size, seq_len, vocab_size]
            input_ids: Tensor of shape [batch_size, seq_len]
            device: torch device (unused; ``input_ids`` must already be on the
                model's device). Kept for interface compatibility.

        Returns:
            Tensor of shape [batch_size, seq_len-1]; entry ``[b, t]`` is the
            probability of token ``input_ids[b, t+1]`` given ``input_ids[b, :t+1]``.

        Raises:
            ValueError: If the sequences have fewer than two tokens, so there is
                no next token to score.
        """
        seq_len = input_ids.size(1)
        if seq_len < 2:
            raise ValueError(
                f"Need at least 2 tokens per sequence to score next tokens, got {seq_len}"
            )

        with torch.no_grad():
            logits = model(input_ids).logits  # [batch_size, seq_len, vocab_size]

            # The last position has no following token to score, so drop it.
            next_token_probs = torch.softmax(logits[:, :-1, :], dim=-1)

            # Pick the probability of the token that actually came next.
            actual_next_tokens = input_ids[:, 1:].unsqueeze(-1)  # [batch_size, seq_len-1, 1]
            return next_token_probs.gather(2, actual_next_tokens).squeeze(-1)

    def compute_batch_scores(self, token_probs: torch.Tensor) -> torch.Tensor:
        """
        Compute MIN-K% PROB scores for a batch of sequences.

        Args:
            token_probs: Tensor of shape [batch_size, seq_len] with per-token
                probabilities (any floating dtype, including bfloat16)

        Returns:
            float64 tensor of shape [batch_size] on ``token_probs.device``
            containing the mean log-probability of the k% least likely tokens

        Raises:
            ValueError: If ``token_probs`` is not 2-D or has no token positions.
        """
        if token_probs.dim() != 2 or token_probs.size(1) == 0:
            raise ValueError(
                "token_probs must have shape [batch_size, seq_len] with seq_len >= 1, "
                f"got {tuple(token_probs.shape)}"
            )

        # NumPy has no bfloat16 dtype, so upcast before leaving torch.
        probs_np = token_probs.detach().float().cpu().numpy()
        seq_len = probs_np.shape[1]

        # Number of tokens to average: k% of the sequence, clamped to [1, seq_len].
        k = min(seq_len, max(1, int(seq_len * self.k_percent / 100)))

        # Small epsilon guards against log(0).
        log_probs = np.log(probs_np + 1e-10)

        # log is monotonic, so the k smallest log-probs belong to the k least
        # likely tokens. `kth` is zero-based: k - 1 puts the k smallest first.
        lowest_k = np.partition(log_probs, k - 1, axis=1)[:, :k]
        scores = lowest_k.mean(axis=1).astype(np.float64)

        return torch.tensor(scores, device=token_probs.device)

    def process_dataloader(
        self,
        model: torch.nn.Module,
        dataloader: torch.utils.data.DataLoader,
        device: torch.device,
        is_member: bool
    ) -> List[dict]:
        """Score every sequence produced by ``dataloader``.

        Args:
            model: Causal LM passed on to ``get_batch_probs``
            dataloader: Yields dict batches with a ``"concept_ids"`` tensor
            device: Device the input IDs are moved to
            is_member: Label recorded for every sequence of this dataloader

        Returns:
            One dict per sequence with its batch/sequence index, the label, the
            detection score and the min/mean next-token probability.
        """
        results = []
        description = f"Processing {'member' if is_member else 'non-member'} data"

        for batch_idx, batch in enumerate(tqdm(dataloader, desc=description)):
            input_ids = batch["concept_ids"].to(device)

            # Upcast once so the summary statistics are not computed in low precision.
            batch_probs = self.get_batch_probs(model, input_ids, device).float()
            batch_scores = self.compute_batch_scores(batch_probs)

            for seq_idx, score in enumerate(batch_scores):
                results.append({
                    'batch_idx': batch_idx,
                    'seq_idx': seq_idx,
                    'is_member': is_member,
                    'detection_score': score.item(),
                    'min_prob': torch.min(batch_probs[seq_idx]).item(),
                    'mean_prob': torch.mean(batch_probs[seq_idx]).item()
                })

        return results

    def detect_membership(
        self,
        model: torch.nn.Module,
        member_dataloader: torch.utils.data.DataLoader,
        nonmember_dataloader: torch.utils.data.DataLoader,
        device: torch.device
    ) -> Tuple[float, pd.DataFrame]:
        """Score both dataloaders and measure how well the score separates them.

        Returns:
            ``(auc, results_df)``: the ROC AUC of ``detection_score`` against
            ``is_member`` and the per-sequence results of both dataloaders.
        """
        model.eval()

        with torch.no_grad():
            member_results = self.process_dataloader(model, member_dataloader, device, True)
            nonmember_results = self.process_dataloader(model, nonmember_dataloader, device, False)

        # Combine and analyze results
        results_df = pd.DataFrame(member_results + nonmember_results)
        auc = roc_auc_score(results_df['is_member'], results_df['detection_score'])

        return auc, results_df


def run_membership_detection(
    model,
    member_dataloader,
    nonmember_dataloader,
    device,
    k_percent=20.0,
    results=None,
    auc=0
) -> Tuple[float, pd.DataFrame]:
    """Run membership detection (or reuse cached results) and visualize the scores.

    Args:
        model: Model to score sequences with; unused when ``results`` is given
        member_dataloader: Dataloader of member sequences; unused when ``results`` is given
        nonmember_dataloader: Dataloader of non-member sequences; unused when ``results`` is given
        device: Device to run the model on; unused when ``results`` is given
        k_percent: Percentage of lowest-probability tokens used for the score
        results: Optional precomputed results with ``is_member`` and
            ``detection_score`` columns. When ``None``, detection is run.
        auc: AUC reported for precomputed ``results``; overwritten when
            detection is run.

    Returns:
        ``(auc, results)``: the AUC that was reported and the results DataFrame.
    """
    # `results is None` is the only reliable "nothing cached" test: truth-testing
    # a DataFrame is ambiguous and `all(None)` raises TypeError.
    if results is None:
        detector = MembershipDetector(k_percent=k_percent)
        auc, results = detector.detect_membership(
            model=model,
            member_dataloader=member_dataloader,
            nonmember_dataloader=nonmember_dataloader,
            device=device
        )

    print("\nDetection Results:")
    print(f"AUC Score: {auc:.3f}")

    # Compute statistics; cast so the mask can be negated even if the column
    # was not stored with a boolean dtype.
    is_member = results['is_member'].astype(bool)
    member_scores = results.loc[is_member, 'detection_score']
    non_member_scores = results.loc[~is_member, 'detection_score']

    for heading, scores in (
        ("\nMember sequences:", member_scores),
        ("\nNon-member sequences:", non_member_scores),
    ):
        print(heading)
        print(f"Mean score: {scores.mean():.3f}")
        print(f"Min score: {scores.min():.3f}")
        print(f"Max score: {scores.max():.3f}")

    # Plot score distributions with custom font sizes and add padding between title and graph
    plt.figure(figsize=(10, 6))
    ax = sns.histplot(data=results, x='detection_score', hue='is_member', bins=50, alpha=0.5)
    plt.title(f'Distribution of Detection Scores (AUC = {auc:.3f})', fontsize=22, pad=30)
    plt.xlabel('Detection Score', fontsize=18)
    plt.ylabel('Count', fontsize=18)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    legend = ax.get_legend()
    if legend is not None:
        legend.set_frame_on(True)
        legend.get_frame().set_facecolor('white')
        for text in legend.get_texts():
            text.set_fontsize(16)
        legend.set_title(legend.get_title().get_text(), prop={'size': 16})
        legend.set_bbox_to_anchor((0, 1))  # Move legend to top left
        legend.set_loc("upper left")
    plt.show()
    return auc, results



# sequence_length = 512
# max_sequences = 10_000

# def get_random_sequence_slice(tokens, seq_len):
#     if len(tokens) <= seq_len:
#         return tokens
#     start = random.randint(0, len(tokens) - seq_len)
#     return tokens[start:start + seq_len]

# # Randomly sample sequences
# # Filter for sequences that are long enough
# valid_pretrain = pretrain[pretrain['event_tokens'].str.len() >= sequence_length]
# valid_test = test[test['event_tokens'].str.len() >= sequence_length]

# random.seed(23)
# member_indices = random.sample(range(len(valid_pretrain)), min(max_sequences, len(valid_pretrain)))
# nonmember_indices = random.sample(range(len(valid_test)), min(max_sequences, len(valid_test)))

# member_dataset = valid_pretrain.iloc[member_indices].copy()
# nonmember_dataset = valid_test.iloc[nonmember_indices].copy()

# member_dataset['event_tokens'] = member_dataset['event_tokens'].transform(
#     lambda x: get_random_sequence_slice(x, sequence_length)
# )
# nonmember_dataset['event_tokens'] = nonmember_dataset['event_tokens'].transform(
#     lambda x: get_random_sequence_slice(x, sequence_length)
# )

# _, member_dataloader = get_dataset_and_loader(
#     data=member_dataset,
#     tokenizer=tokenizer,
#     max_len=sequence_length,
#     batch_size=args.batch_size // 8,
# )

# _, nonmember_dataloader = get_dataset_and_loader(
#     data=nonmember_dataset,
#     tokenizer=tokenizer,
#     max_len=sequence_length,
#     batch_size=args.batch_size // 8,
# )


# Report previously computed detection results instead of re-running the model
# (model and dataloaders are therefore not needed and passed as None).
df = pd.read_parquet(MEMBERSHIP_INFERENCE_RESULTS_FILE)
auc, results = run_membership_detection(
    model=None,#model,
    member_dataloader=None,#member_dataloader,
    nonmember_dataloader=None,#nonmember_dataloader,
    device=None,#device,
    k_percent=20.0,
    results=df,
    # Computed from the cached frame (same formula as
    # MembershipDetector.detect_membership) so the reported AUC can never drift
    # from the data being plotted. The previously hard-coded value was 0.568.
    auc=roc_auc_score(df['is_member'], df['detection_score'])
)

In [None]:
def write_buffer_to_file(buffer: List[dict], output_file: str) -> None:
    """Append every item of ``buffer`` to ``output_file`` as one JSON line each.

    The file is created when it does not exist yet; existing content is kept.
    """
    # Append mode already creates a missing file, so no existence check is needed.
    with open(output_file, "a") as sink:
        sink.writelines(json.dumps(record) + "\n" for record in buffer)


def write_buffer_to_parquet(buffer: List[dict], temp_dir: str, chunk_idx: int) -> None:
    """Store ``buffer`` as ``chunk_<chunk_idx>.parquet`` inside ``temp_dir``.

    Each dict in ``buffer`` becomes one row of the written table.
    """
    destination = os.path.join(temp_dir, f"chunk_{chunk_idx}.parquet")
    pd.DataFrame(buffer).to_parquet(destination)


def forecast(
    model: torch.nn.Module,
    tokenizer: ConceptTokenizer,
    dataloader: DataLoader,
    patient_ids: List[str],
    args: Any,
    device: torch.device,
    output_file: str,
    temperature: float = 1.0,
    top_p: float = 0.95,
    do_sample: bool = True,
    experiment: str = "static_prompt",
    split: str = "pretrain",
):
    """Generate continuations for every prompt and save them to a parquet file.

    Results are buffered and flushed to numbered parquet chunks in a temporary
    directory next to ``output_file``; at the end the chunks are concatenated
    in generation order into ``output_file`` and the directory is removed.

    Args:
        model: The model to use for generation
        tokenizer: Tokenizer for decoding predictions
        dataloader: DataLoader yielding dict batches with ``"concept_ids"``;
            must iterate in the same order as ``patient_ids``
        patient_ids: Patient IDs, one per dataloader sample, in loader order
        args: Arguments containing batch_size, num_return_sequences, max_len
        device: Device to run model on
        output_file: Path to the output parquet file
        temperature: Temperature for sampling (default: 1.0)
        top_p: Top-p sampling parameter (default: 0.95)
        do_sample: Whether to sample or use greedy decoding (default: True)
        experiment: Experiment name (default: "static_prompt")
        split: Value stored in the ``split`` column (default: "pretrain")
    """
    buffer = []
    buffer_size = args.batch_size * args.num_return_sequences * 10

    # Create temp directory for chunks (any leftover from a previous run is discarded)
    temp_dir = os.path.join(
        os.path.dirname(output_file),
        f"temp_{os.path.basename(output_file).split('.')[0]}",
    )
    shutil.rmtree(temp_dir, ignore_errors=True)
    os.makedirs(temp_dir)
    chunk_idx = 0

    # The same pad id is given to `generate` and used to trim its outputs.
    pad_token_id = tokenizer.get_pad_token_id()

    # Running offset into `patient_ids`; derived from the actual batch sizes so
    # the alignment stays correct even if the loader's batch size differs from
    # `args.batch_size`.
    patient_offset = 0

    with torch.no_grad():
        for batch in tqdm(
            dataloader,
            desc=f"Generating predictions for {experiment}",
            total=len(dataloader),
        ):
            input_ids = batch["concept_ids"].to(device)

            # TODO: Investigate why this is needed
            # input_ids = input_ids[:, :cutoff]
            # input_ids = input_ids.reshape(-1, 1)

            outputs = model.generate(
                input_ids=input_ids,
                max_length=args.max_len,
                temperature=temperature,
                top_p=top_p,
                num_return_sequences=args.num_return_sequences,
                do_sample=do_sample,
                num_beams=1,
                pad_token_id=pad_token_id,
                eos_token_id=tokenizer.get_eos_token_id(),
                bos_token_id=tokenizer.get_class_token_id(),
                use_cache=True,
            )

            num_prompts = input_ids.size(0)
            batch_patient_ids = patient_ids[patient_offset:patient_offset + num_prompts]
            patient_offset += num_prompts

            # `generate` returns the trajectories of each prompt consecutively,
            # so every patient ID is repeated once per returned sequence.
            batch_patient_ids = [
                pid
                for pid in batch_patient_ids
                for _ in range(args.num_return_sequences)
            ]

            # One device-to-host transfer per batch instead of one per sequence.
            sequences = outputs.detach().cpu().numpy()

            for i, (patient_id, sequence) in enumerate(zip(batch_patient_ids, sequences)):
                # Trim at the first pad token. argmax is 0 both when there is no
                # pad and when the pad is at position 0; in either case the
                # whole sequence is kept.
                first_pad = int(np.argmax(sequence == pad_token_id))
                if first_pad > 0:
                    sequence = sequence[:first_pad]
                predicted_tokens = tokenizer.decode(sequence).split(" ")

                buffer.append(
                    {
                        "split": split,
                        "experiment": experiment,
                        "patient_id": patient_id,
                        "trajectory": i % args.num_return_sequences,
                        "predicted_tokens": predicted_tokens,
                    }
                )

            if len(buffer) >= buffer_size:
                write_buffer_to_parquet(buffer, temp_dir, chunk_idx)
                chunk_idx += 1
                buffer = []

        if buffer:
            write_buffer_to_parquet(buffer, temp_dir, chunk_idx)

    def _chunk_number(path: str) -> int:
        """Extract N from '.../chunk_N.parquet'."""
        return int(os.path.splitext(os.path.basename(path))[0].split("_")[-1])

    # Combine all chunks in the order they were written. A plain string sort
    # would put chunk_10 before chunk_2 and scramble the row order.
    chunk_files = sorted(
        glob.glob(os.path.join(temp_dir, "chunk_*.parquet")), key=_chunk_number
    )
    if chunk_files:
        combined_df = pd.concat(
            [pd.read_parquet(f) for f in chunk_files], ignore_index=True
        )
    else:
        # Empty dataloader: pd.concat([]) would raise, so write an empty table
        # with the expected columns instead.
        combined_df = pd.DataFrame(
            columns=["split", "experiment", "patient_id", "trajectory", "predicted_tokens"]
        )
    combined_df.to_parquet(output_file)
    shutil.rmtree(temp_dir)


def _chunk_sort_key(path: str) -> Tuple[int, int, str]:
    """Sort key that orders ``<prefix>_<n>.parquet`` chunk files by the integer ``n``.

    A plain ``sorted()`` on the paths is lexicographic and would put ``chunk_10``
    before ``chunk_2``. Files whose stem does not end in an integer are placed
    after the numbered ones, ordered by name.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    suffix = stem.rsplit("_", 1)[-1]
    if suffix.isdigit():
        return (0, int(suffix), stem)
    return (1, 0, stem)


def membership_inference_batched(
    model: torch.nn.Module,
    tokenizer: ConceptTokenizer,
    dataloader: DataLoader,
    patient_ids: List[str],
    args: Any,
    device: torch.device,
    output_file: str,
    mask_percent: float = 10.0,
    n_runs: int = 100,
    experiment: str = "membership_inference",
) -> None:
    """
    Perform membership inference using efficient batching, masking tokens
    and analyzing embedding variance.

    For every batch, ``n_runs`` perturbed copies of the input are built by
    replacing a random subset of token positions (never position 0) with
    random token ids, and the last-layer hidden state at the final attended
    position is collected. The per-dimension variance of that embedding across
    runs is summarised (mean / max / sum) and written, one row per sequence,
    to ``output_file``.

    Notes:
        * Randomness comes from the global NumPy and torch RNG state; seed
          those beforehand if reproducible scores are needed.
        * The final position is computed as ``attention_mask.sum() - 1``,
          which assumes right-padded inputs -- TODO confirm against the
          dataset's padding side.
        * Results are staged through ``write_buffer_to_parquet`` (defined
          elsewhere in this notebook) in a temp directory next to
          ``output_file`` and combined at the end.

    Args:
        model: The model to use for inference
        tokenizer: Tokenizer for processing tokens
        dataloader: DataLoader containing batches
        patient_ids: List of patient IDs, in the same order as the dataloader
        args: Arguments containing batch_size, etc.
        device: Device to run model on
        output_file: Path to output parquet file
        mask_percent: Percentage of tokens to mask (default: 10.0)
        n_runs: Number of masking runs per sequence (default: 100)
        experiment: Experiment name (default: "membership_inference")

    Raises:
        ValueError: If ``n_runs`` is smaller than 1 or the vocabulary is too
            small to draw replacement tokens from.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    # Replacement tokens are drawn uniformly from [token_low, token_high).
    # NOTE(review): the bounds 10 and 5000 presumably skip special tokens and
    # restrict to a sub-range of the vocabulary -- confirm against the tokenizer.
    token_low = 10
    token_high = min(tokenizer.get_vocab_size() - 1, 5000)
    if token_high <= token_low:
        raise ValueError(
            f"Vocabulary too small for random masking: need more than "
            f"{token_low + 1} tokens, got {tokenizer.get_vocab_size()}"
        )

    result_columns = [
        "patient_id",
        "experiment",
        "mask_percent",
        "n_runs",
        "avg_variance",
        "max_variance",
        "total_variance",
        "sequence_length",
    ]

    buffer = []
    buffer_size = args.batch_size * 10

    # Create temp directory for chunks
    temp_dir = os.path.join(
        os.path.dirname(output_file),
        f"temp_{os.path.basename(output_file).split('.')[0]}",
    )
    shutil.rmtree(temp_dir, ignore_errors=True)
    os.makedirs(temp_dir)
    chunk_idx = 0

    with torch.no_grad():
        for batch_idx, batch in enumerate(
            tqdm(
                dataloader,
                desc=f"Running membership inference with {mask_percent}% masking",
                total=len(dataloader),
            )
        ):
            input_ids = batch["concept_ids"].to(device)
            attention_mask = batch.get("attention_mask", torch.ones_like(input_ids)).to(
                device
            )

            batch_size = input_ids.size(0)
            # The last batch may be short, so slice by the actual batch size.
            start_idx = batch_idx * args.batch_size
            batch_patient_ids = patient_ids[start_idx : start_idx + batch_size]

            # Number of attended tokens per sequence
            seq_lengths = torch.sum(attention_mask, dim=1).cpu().numpy().astype(int)
            final_positions = torch.as_tensor(
                seq_lengths - 1, dtype=torch.long, device=device
            )
            row_index = torch.arange(batch_size, device=device)

            # One (batch, hidden) array per masking run
            run_embeddings = []

            for _ in range(n_runs):
                masked_input = input_ids.clone()

                for seq_idx in range(batch_size):
                    seq_len = int(seq_lengths[seq_idx])

                    # Position 0 is never replaced, so only seq_len - 1
                    # positions are eligible; nothing to do for length <= 1.
                    maskable = seq_len - 1
                    if maskable < 1:
                        continue

                    # At least one token, at most every eligible position
                    num_masked = min(
                        maskable, max(1, int((mask_percent / 100) * seq_len))
                    )
                    positions = (
                        np.random.choice(maskable, num_masked, replace=False) + 1
                    )

                    # Draw all replacement tokens for this sequence in one call
                    replacements = torch.randint(token_low, token_high, (num_masked,))
                    masked_input[
                        seq_idx,
                        torch.as_tensor(positions, dtype=torch.long, device=device),
                    ] = replacements.to(device=device, dtype=masked_input.dtype)

                outputs = model(
                    masked_input,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                )

                # Last-layer hidden state at each sequence's final attended token.
                # Cast to float32 so half/bfloat16 models convert to NumPy cleanly.
                last_hidden = outputs.hidden_states[-1]
                final_embeddings = last_hidden[row_index, final_positions]
                run_embeddings.append(final_embeddings.float().cpu().numpy())

            # Indexed as [run, sequence, ...]
            stacked = np.stack(run_embeddings)

            for i, patient_id in enumerate(batch_patient_ids):
                # Variance of each embedding dimension across the masking runs
                var_embedding = np.var(stacked[:, i], axis=0)

                buffer.append(
                    {
                        "patient_id": patient_id,
                        "experiment": experiment,
                        "mask_percent": mask_percent,
                        "n_runs": n_runs,
                        "avg_variance": float(np.mean(var_embedding)),
                        "max_variance": float(np.max(var_embedding)),
                        "total_variance": float(np.sum(var_embedding)),
                        "sequence_length": int(seq_lengths[i]),
                    }
                )

            if len(buffer) >= buffer_size:
                write_buffer_to_parquet(buffer, temp_dir, chunk_idx)
                chunk_idx += 1
                buffer = []

    if buffer:
        write_buffer_to_parquet(buffer, temp_dir, chunk_idx)

    # Combine all chunks in the order they were written
    chunk_files = sorted(
        glob.glob(os.path.join(temp_dir, "chunk_*.parquet")), key=_chunk_sort_key
    )
    if chunk_files:
        combined_df = pd.concat(
            [pd.read_parquet(f) for f in chunk_files], ignore_index=True
        )
    else:
        # Empty dataloader: still produce a well-formed (empty) result file
        combined_df = pd.DataFrame(columns=result_columns)
    combined_df.to_parquet(output_file)
    shutil.rmtree(temp_dir)

In [None]:
# Membership inference over the pretraining cohort.
# `selected_dataset`, `model` and `tokenizer` must already exist from earlier cells.
membership_results_file = os.path.join(
    FORECAST_OUTPUT_PATH, "membership_inference_pretrain.parquet"
)

dataset, dataloader = get_dataset_and_loader(
    data=selected_dataset,
    tokenizer=tokenizer,
    max_len=args.max_len,
    batch_size=args.batch_size,
)
patient_ids = selected_dataset["patient_id"].to_list()

print("Running membership inference on pretrain data...")
membership_inference_batched(
    model,
    tokenizer,
    dataloader,
    patient_ids,
    args,
    device,
    membership_results_file,
    mask_percent=10,
    n_runs=100,
    experiment="pretrain_membership",
)
print(f"Membership inference results saved to {membership_results_file}")

In [None]:
# Membership inference over the held-out cohort (`selected_dataset_test`),
# using the same masking settings as the pretrain run.
test_membership_results_file = os.path.join(
    FORECAST_OUTPUT_PATH, "membership_inference_test.parquet"
)

test_dataset, test_dataloader = get_dataset_and_loader(
    data=selected_dataset_test,
    tokenizer=tokenizer,
    max_len=args.max_len,
    batch_size=args.batch_size,
)
test_patient_ids = selected_dataset_test["patient_id"].to_list()

print("Running membership inference on test data...")
membership_inference_batched(
    model,
    tokenizer,
    test_dataloader,
    test_patient_ids,
    args,
    device,
    test_membership_results_file,
    mask_percent=10,
    n_runs=100,
    experiment="test_membership",
)
print(f"Test membership inference results saved to {test_membership_results_file}")

In [None]:
# Forecast trajectories for the cohort currently held in `selected_dataset`.
# The output path is relative, so the parquet lands in the working directory.
patient_ids = selected_dataset["patient_id"].to_list()
dataset, dataloader = get_dataset_and_loader(
    data=selected_dataset,
    tokenizer=tokenizer,
    max_len=args.max_len,
    batch_size=args.batch_size,
)

forecast(
    model=model,
    tokenizer=tokenizer,
    dataloader=dataloader,
    patient_ids=patient_ids,
    args=args,
    device=device,
    experiment="cls_age",
    output_file="ehrmamba2_cls_age_april20.parquet",
)

In [None]:
# Forecast every prompt file in args.forecast_inputs; each input produces
# `<input stem>_forecast.parquet` under FORECAST_OUTPUT_PATH.
for input_file in tqdm(
    args.forecast_inputs,
    total=len(args.forecast_inputs),
    desc="Processing files",
    leave=False,
):
    print(f"Processing {input_file}")

    base_name = os.path.splitext(os.path.basename(input_file))[0]
    output_file = os.path.join(FORECAST_OUTPUT_PATH, f"{base_name}_forecast.parquet")

    # Read the prompts. The rename is currently an identity mapping; it is the
    # hook for inputs that store the tokens under a different column name.
    selected_dataset = pd.read_parquet(input_file)
    selected_dataset = selected_dataset.rename(columns={"event_tokens": "event_tokens"})

    # Prompt length is taken from the first distinct token-list length in the file.
    prompt_length = selected_dataset["event_tokens"].transform(len).unique()[0]

    patient_ids = selected_dataset["patient_id"].to_list()
    dataset, dataloader = get_dataset_and_loader(
        data=selected_dataset,
        tokenizer=tokenizer,
        max_len=prompt_length,
        batch_size=args.batch_size,
        num_workers=4,
    )

    forecast(
        model=model,
        tokenizer=tokenizer,
        dataloader=dataloader,
        patient_ids=patient_ids,
        args=args,
        device=device,
        experiment=os.path.basename(input_file).split(".")[0],
        output_file=str(output_file),
    )
    print(f"Finished processing {os.path.basename(input_file)} to {os.path.basename(output_file)}\n")

In [None]:
# Attach the prompt files' age columns to the corresponding forecast outputs and
# save the result next to each forecast as `<name>_with_age.parquet`.
for input_file in args.forecast_inputs:

    # Inputs whose path contains the substring "10" are skipped.
    if "10" in input_file:
        print("SKIPPING")
        continue

    base_name = os.path.splitext(os.path.basename(input_file))[0]
    forecast_file = os.path.join(FORECAST_OUTPUT_PATH, f"{base_name}_forecast.parquet")

    # Only inputs that have already been forecast are processed.
    if not os.path.exists(forecast_file):
        continue

    # Load and prep dataframes
    prompt_df = pd.read_parquet(input_file)
    age_cols = [col for col in prompt_df.columns if 'age' in col.lower()]
    prompt_df = prompt_df[['patient_id'] + age_cols].reset_index(drop=True)
    forecast_df = pd.read_parquet(forecast_file).reset_index(drop=True)

    # The forecast file is patient-major: each patient's num_return_sequences
    # trajectories are stored consecutively. Tiling the whole prompt frame
    # (p0, p1, ..., p0, p1, ...) would therefore pair ages with the wrong
    # patients, so align on patient_id instead of on row position.
    if prompt_df['patient_id'].is_unique:
        merged = forecast_df.merge(
            prompt_df, on='patient_id', how='left', validate='many_to_one'
        )
    else:
        # Duplicate patient IDs make a key join ambiguous: fall back to repeating
        # each prompt row once per trajectory and verify the IDs really line up.
        prompt_df = prompt_df.loc[
            prompt_df.index.repeat(args.num_return_sequences)
        ].reset_index(drop=True)
        ids_match = len(prompt_df) == len(forecast_df) and bool(
            (prompt_df['patient_id'].to_numpy() == forecast_df['patient_id'].to_numpy()).all()
        )
        if not ids_match:
            raise ValueError(
                f"Cannot align prompt rows with forecast rows for {forecast_file}: "
                "patient_id sequences differ"
            )
        merged = pd.concat([forecast_df, prompt_df[age_cols]], axis=1)

    merged.to_parquet(forecast_file.replace('.parquet', '_with_age.parquet'))

# **Generate Patient Representation**

In [None]:
# Preview each syphilis sequence file: print its length, show the first rows,
# and keep the loaded frame in `dfs`. Paths containing "10" are skipped.
dfs = []
for f in args.syphilis_sequences:
    if "10" not in f:
        df = pl.read_parquet(f).to_pandas()
        print(f"\nFile: {f}, Length: {len(df)}")
        display(df.head())
        dfs.append(df)
    else:
        print("SKIPPING")

In [None]:
# Every syphilis sequence file uses the same token column as embedding target.
sequence2target = dict.fromkeys(args.syphilis_sequences, "event_token")

In [None]:
# For every sequence file, run the model over all rows and store, per row, the
# last-layer hidden state at the last non-padding token as
# `<target_column>_embeddings`. Results are staged in chunk files and combined.
for patient_sequence, target_column in sequence2target.items():
    print(
        f"\nPatient Sequence: {patient_sequence.split('/')[-1]}, Target Column: {target_column}\n"
    )

    # Create temp directory next to input parquet
    temp_dir = os.path.join(
        os.path.dirname(patient_sequence),
        f"temp_{os.path.basename(patient_sequence).split('.')[0]}",
    )
    shutil.rmtree(temp_dir, ignore_errors=True)
    os.makedirs(temp_dir)

    # Load and prep dataset (the dataset is given the tokens as "event_tokens")
    df = pl.read_parquet(patient_sequence).to_pandas()
    df = df.rename(columns={target_column: "event_tokens"})

    # NOTE(review): the row bookkeeping below assumes the dataset yields exactly
    # one item per row of `df`, in order (shuffle=False) -- confirm for
    # PretrainDatasetDecoder.
    dataset = PretrainDatasetDecoder(
        data=df,
        tokenizer=tokenizer,
        max_len=args.max_len,
        additional_token_types=None,
        padding_side="right",
        return_attention_mask=True,
        return_labels=False,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        persistent_workers=True,
    )

    pad_token_id = tokenizer.get_pad_token_id()

    # Process in chunks
    buffer = []
    chunk_idx = 0
    rows_written = 0  # number of df rows already flushed to chunk files
    with torch.no_grad():
        for batch_idx, batch in enumerate(
            tqdm(
                dataloader,
                desc=f"Generating embeddings for {patient_sequence.split('/')[-1]}",
                total=len(dataloader),
            )
        ):
            # Get embeddings for batch
            outputs = model(
                input_ids=batch["concept_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                output_hidden_states=True,
            )
            hidden_states = outputs.hidden_states[-1]

            # Embedding of each sequence = hidden state at its last non-pad token
            for i, seq in enumerate(batch["concept_ids"]):
                non_pad_pos = (seq != pad_token_id).nonzero()[-1][0]
                emb = hidden_states[i, non_pad_pos].cpu().numpy()
                buffer.append(emb)

            # Write buffer when full, and always after the final batch
            if len(buffer) >= args.chunk_size or batch_idx == len(dataloader) - 1:
                # The buffer is flushed on batch boundaries, so it can hold more
                # than args.chunk_size rows. Track the true running offset
                # instead of assuming every chunk has exactly chunk_size rows,
                # otherwise embeddings get attached to the wrong df rows.
                start_idx = rows_written
                end_idx = start_idx + len(buffer)

                chunk_df = df.iloc[start_idx:end_idx].copy()
                chunk_df[f"{target_column}_embeddings"] = buffer
                chunk_df = chunk_df.rename(columns={"event_tokens": target_column})

                # Zero-padded index keeps the lexicographic sort below in write order
                chunk_path = os.path.join(temp_dir, f"chunk_{chunk_idx:06d}.parquet")
                pl.from_pandas(chunk_df).write_parquet(chunk_path)

                rows_written = end_idx
                buffer = []
                chunk_idx += 1

    # Combine chunks and save final output
    chunk_files = sorted(glob.glob(os.path.join(temp_dir, "chunk_*.parquet")))
    combined_df = pl.concat([pl.read_parquet(f) for f in chunk_files])
    # Written under the input's base name, relative to the working directory.
    combined_df.write_parquet(patient_sequence.split('/')[-1])  # CHANGE LATER ON!

    # Cleanup
    shutil.rmtree(temp_dir)