# In [None]:
from typing import List, Tuple

import polars as pl
import torch

from clinical_zeroshot_labeler.labeler import SequenceLabeler


class SimpleGenerativeModel:
    """Mock generative model that replays fixed, per-sample token sequences.

    Every call to :meth:`generate_next_token` emits one ``(token, time, value)``
    triple per batch element. When a sample's sequence has been fully emitted,
    its last triple is emitted again on each later call without advancing.
    """

    def __init__(self, sequences: List[List[Tuple[int, float, float]]]):
        """
        Args:
            sequences: One list per batch element; each list holds the
                (token, time, value) tuples to emit, in order.
        """
        self.sequences = sequences
        # One read cursor per sample, all starting at the first triple.
        self.current_positions = [0 for _ in sequences]
        self.batch_size = len(sequences)

    def _next_triple(self, idx: int) -> Tuple[int, float, float]:
        """Return sample ``idx``'s next triple, advancing its cursor unless exhausted."""
        seq = self.sequences[idx]
        pos = self.current_positions[idx]
        if pos >= len(seq):
            # Exhausted sample: keep replaying the final triple; cursor stays put.
            token, stamp, value = seq[-1]
            return token, stamp, value
        token, stamp, value = seq[pos]
        self.current_positions[idx] = pos + 1
        return token, stamp, value

    def generate_next_token(
        self, prompt: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Emit the next (token, time, value) for every batch element.

        Args:
            prompt: Accepted to mimic a real model's interface; not used here.

        Returns:
            Three 1-D tensors with one entry per sequence: token ids
            (``torch.long``), times (``torch.float``) and values (``torch.float``).
        """
        triples = [self._next_triple(idx) for idx in range(len(self.sequences))]
        return (
            torch.tensor([tok for tok, _, _ in triples], dtype=torch.long),
            torch.tensor([ts for _, ts, _ in triples], dtype=torch.float),
            torch.tensor([val for _, _, val in triples], dtype=torch.float),
        )

    def is_finished(self) -> bool:
        """Return True once every sample's sequence has been fully emitted."""
        for pos, seq in zip(self.current_positions, self.sequences):
            if pos < len(seq):
                return False
        return True


def main():
    """Run the zero-shot labeling demo on two hand-written toy sequences.

    Builds a YAML task config and a code vocabulary, replays the toy
    (token, time, value) sequences through ``SimpleGenerativeModel``, feeds
    every generated step to a ``SequenceLabeler``, and prints the per-step
    status followed by the final labels.
    """
    # Task definition: trigger on hospital_discharge; the "target" window runs
    # from input.end to start + 4d, requires `lab: (1, None)`, and is labeled
    # by the `abnormal_lab` predicate (LAB codes with value_min 2.0, inclusive).
    task_config = """
    predicates:
        hospital_discharge:
            code: {regex: "HOSPITAL_DISCHARGE//.*"}
        lab:
            code: {regex: "LAB//.*"}
        abnormal_lab:
            code: {regex: "LAB//.*"}
            value_min: 2.0
            value_min_inclusive: True

    trigger: hospital_discharge

    windows:
        input:
            start: NULL
            end: trigger
            start_inclusive: True
            end_inclusive: True
            index_timestamp: end
        target:
            start: input.end
            end: start + 4d
            start_inclusive: False
            end_inclusive: True
            has:
                lab: (1, None)
            label: abnormal_lab
    """

    # Vocabulary: the row index becomes the token id (PAD=0,
    # HOSPITAL_DISCHARGE//MEDICAL=1, LAB//NORMAL=2, LAB//HIGH=3); the toy
    # sequences below use these ids.
    vocab_codes = [
        "PAD",
        "HOSPITAL_DISCHARGE//MEDICAL",
        "LAB//NORMAL",
        "LAB//HIGH",
    ]
    metadata_df = pl.DataFrame({"code": vocab_codes}).with_row_index("code/vocab_index")

    # Each triple is (token id, time, value); both samples start with a discharge at t=0.
    sequences = [
        # Sample 1: normal lab (1.5) at t=1, then a HIGH lab (2.5) at t=2.
        [(1, 0.0, 0.0), (2, 1.0, 1.5), (3, 2.0, 2.5)],
        # Sample 2: two normal labs (1.5 and 1.8), neither reaching 2.0.
        [(1, 0.0, 0.0), (2, 1.0, 1.5), (2, 2.0, 1.8)],
    ]
    model = SimpleGenerativeModel(sequences)

    n_samples = len(sequences)
    labeler = SequenceLabeler.from_yaml_str(task_config, metadata_df, batch_size=n_samples)

    # Seed with an all-PAD (id 0) prompt; later steps pass the previous step's tokens.
    prompt = torch.zeros(n_samples, dtype=torch.long)
    while not (labeler.is_finished() or model.is_finished()):
        step_tokens, step_times, step_values = model.generate_next_token(prompt)
        step_status = labeler.process_step(step_tokens, step_times, step_values)
        print(f"Step status: {step_status}")
        prompt = step_tokens

    final_labels = labeler.get_labels()
    print(f"\nFinal labels: {final_labels}")

    for sample_no, label in enumerate(final_labels, start=1):
        verdict = "Abnormal lab detected" if label else "No abnormal labs"
        print(f"Sequence {sample_no}: {verdict}")


# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()