# Per token inference

Now that we have a model that works well with the last token,
we want to explore how it behaves when used in a real-time environment,
so that we may eventually have a hallucination probability inferred **while generating**, per single token.

# Imports, installations and declarations from previous notebooks

This section can be skipped and collapsed.

In [255]:
#@title Install missing dependencies
!pip install wandb lightning



In [256]:
import os

# Detect Google Colab by probing for its runtime-only package.
try:
    import google.colab
except ModuleNotFoundError:
    # The package is missing: we are running in a local / non-Colab kernel
    IN_COLAB = False
else:
    IN_COLAB = True

In [257]:
# If not in Colab, do some compatibility changes
if not IN_COLAB:
    # No Drive mount outside Colab: use the current working directory as base path
    DRIVE_PATH = '.'
    # Read the HuggingFace token from the local `.hf_token` file.
    # The context manager guarantees the file handle is closed (the bare open().read() leaked it).
    with open('.hf_token') as _hf_token_file:
        os.environ['HF_TOKEN'] = _hf_token_file.read().strip()

In [258]:
#@title Mount Drive, if needed, and check the HF_TOKEN is set and accessible
# Colab-only setup: mount Google Drive (read-only), copy the dataset folder into the
# working directory and make sure the required Colab secrets are available.
if IN_COLAB:
    from google.colab import drive, userdata

    drive.mount('/content/drive', readonly=True)
    DRIVE_PATH: str = '/content/drive/MyDrive/Final_Project/'
    assert os.path.exists(DRIVE_PATH), 'Did you forget to create a shortcut in MyDrive named Final_Project this time as well? :('
    # Copy the `publicDataset` folder next to the notebook, then show where we are and what is there
    !cp -R {DRIVE_PATH}/publicDataset .
    !pwd
    !ls
    print()

    # HF_TOKEN is only checked for presence here; this cell does not export it as an environment variable
    assert userdata.get('HF_TOKEN'), 'Set up HuggingFace login secret properly in Colab!'
    print('HF_TOKEN found')

    # The Weights & Biases key, instead, is copied from the Colab secrets into the environment
    os.environ['WANDB_API_KEY'] = userdata.get('WANDB_API_KEY')
    print('WANDB_API_KEY found and set as env var')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content
artifacts  drive  hallucination_detector  publicDataset  sample_data  wandb

HF_TOKEN found
WANDB_API_KEY found and set as env var


In [259]:
#@title Clone the new updated Python files from GitHub, from master
if IN_COLAB:
  !mkdir -p /root/.ssh
  !touch /root/.ssh/id_ecdsa

  with open('/root/.ssh/id_ecdsa', 'w') as f:
    git_ssh_private_key = """
        -----BEGIN OPENSSH PRIVATE KEY-----
        b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW
        QyNTUxOQAAACCB3clOafi6fZaBgQCN29TVyJKNW/eVRXT4/B4MB28VQAAAAJhAtW8YQLVv
        GAAAAAtzc2gtZWQyNTUxOQAAACCB3clOafi6fZaBgQCN29TVyJKNW/eVRXT4/B4MB28VQA
        AAAEA6ARNr020VevD7mkC4GFBVqlTcZP7hvn8B3xi5LDvzYIHdyU5p+Lp9loGBAI3b1NXI
        ko1b95VFdPj8HgwHbxVAAAAAEHNpbW9uZUBhcmNobGludXgBAgMEBQ==
        -----END OPENSSH PRIVATE KEY-----
    """
    f.write('\n'.join([line.strip() for line in git_ssh_private_key.split('\n') if line.strip() ]) + '\n')

  with open('/root/.ssh/known_hosts', 'w') as f:
    f.write("github.com ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIOMqqnkVzrm0SdG6UOoqKLsabgH5C9okWi0dh2l9GKJl\n")
    f.write("github.com ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCj7ndNxQowgcQnjshcLrqPEiiphnt+VTTvDP6mHBL9j1aNUkY4Ue1gvwnGLVlOhGeYrnZaMgRK6+PKCUXaDbC7qtbW8gIkhL7aGCsOr/C56SJMy/BCZfxd1nWzAOxSDPgVsmerOBYfNqltV9/hWCqBywINIR+5dIg6JTJ72pcEpEjcYgXkE2YEFXV1JHnsKgbLWNlhScqb2UmyRkQyytRLtL+38TGxkxCflmO+5Z8CSSNY7GidjMIZ7Q4zMjA2n1nGrlTDkzwDCsw+wqFPGQA179cnfGWOWRVruj16z6XyvxvjJwbz0wQZ75XK5tKSb7FNyeIEs4TT4jk+S4dhPeAUC5y+bDYirYgM4GC7uEnztnZyaVWQ7B381AK4Qdrwt51ZqExKbQpTUNn+EjqoTwvqNj4kqx5QUCI0ThS/YkOxJCXmPUWZbhjpCg56i+2aB6CmK2JGhn57K5mj0MNdBXA4/WnwH6XoPWJzK5Nyu2zB3nAZp+S5hpQs+p1vN1/wsjk=\n")
    f.write("github.com ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBEmKSENjQEezOmxkZMy7opKgwFB9nkt5YRrYMjNuG5N87uRgg6CLrbo5wAdT/y6v0mKV0U2w0WZ2YB/++Tpockg=\n")

  !chmod 400 ~/.ssh/id_ecdsa ~/.ssh/known_hosts
  !ls ~/.ssh

  # Clone the repository
  !rm -rf /content/AML-project
  !git clone git@github.com:simonesestito/AML-project.git /content/AML-project
  assert os.path.exists('/content/AML-project/.git'), 'Error cloning the repository. See logs above for details'
  !rm -rf ./hallucination_detector && mv /content/AML-project/hallucination_detector .
  !rm -rf /content/AML-project  # We don't need the Git repo anymore

id_ecdsa  known_hosts
Cloning into '/content/AML-project'...
remote: Enumerating objects: 418, done.[K
remote: Counting objects: 100% (80/80), done.[K
remote: Compressing objects: 100% (49/49), done.[K
remote: Total 418 (delta 41), reused 53 (delta 30), pack-reused 338 (from 1)[K
Receiving objects: 100% (418/418), 2.12 MiB | 5.03 MiB/s, done.
Resolving deltas: 100% (226/226), done.


In [260]:
# Autoreload in mode 1: only the modules registered with %aimport are reloaded
# before running a cell, so edits to `hallucination_detector` are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 1
%aimport hallucination_detector
import hallucination_detector

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Initialize Llama

In [364]:
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as pl
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from hallucination_detector.llama import LlamaInstruct, LlamaPrompt
from hallucination_detector.dataset import StatementDataModule
from hallucination_detector.extractor import LlamaHiddenStatesExtractor, WeightedMeanReduction, AttentionAwareWeightedMeanReduction
from hallucination_detector.classifier import OriginalSAPLMAClassifier, LightningHiddenStateSAPLMA, EnhancedSAPLMAClassifier
from hallucination_detector.utils import try_to_overfit, plot_weight_matrix, classificator_evaluation
import wandb
import seaborn as sns
import matplotlib.pyplot as plt
import random

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [262]:
# Instantiate the Llama wrapper. In Colab a GPU is mandatory: fail immediately if the model is not on CUDA.
llama = LlamaInstruct()
assert not IN_COLAB or llama.device.type == 'cuda', 'The model should be running on a GPU. On CPU, it is impossible to run'

# Outside Colab a CPU is tolerated, but the user is warned (on stderr) about the slow inference
if llama.device.type == 'cpu':
    print('WARNING: You are running an LLM on the CPU. Beware of the long inference times! Use it ONLY FOR SMALL tests, like very small tests.', file=sys.stderr, flush=True)

# Initialize trained `SAPLMAClassifier`

In [263]:
# Weights & Biases artifact holding the trained classifier checkpoint (alias `best`)
saplma_artifact_id = 'aml-2324-project/llama-hallucination-detector/attention-aware-weighted-tokens-architecture-hc7ivucr:best'

# Start a W&B run, declare the artifact as an input of the run and download it locally
run = wandb.init()
artifact = run.use_artifact(saplma_artifact_id, type='model')
artifact_dir = artifact.download()
# Show the local directory where the artifact has been downloaded
artifact_dir

[34m[1mwandb[0m:   1 of 1 files downloaded.  


'/content/artifacts/attention-aware-weighted-tokens-architecture-hc7ivucr:v10'

In [264]:
# List the downloaded artifact files (the checkpoint is loaded from `model.ckpt` in the next cell)
!ls {artifact_dir}

model.ckpt


In [265]:
# Restore the trained Lightning module from the downloaded checkpoint.
# The Llama wrapper, the classifier and the reduction are passed explicitly as keyword arguments;
# the module is switched to evaluation mode right away, since it is only used for inference here.
saplma = LightningHiddenStateSAPLMA.load_from_checkpoint(
    os.path.join(artifact_dir, 'model.ckpt'),
    llama=llama,
    saplma_classifier=OriginalSAPLMAClassifier(),
    reduction=AttentionAwareWeightedMeanReduction(),
).eval()

# Load dataset

In [266]:
from hallucination_detector.extractor.tokenizer import tokenize_prompts_fixed_length

In [267]:
# Statements are analyzed one at a time in this notebook
batch_size = 1
# The dataset is read from the local `publicDataset` folder
datamodule = StatementDataModule(batch_size=batch_size, drive_path='publicDataset')
datamodule.prepare_data()
print(f'Found {len(datamodule.full_dataset)} samples')

Loading file: cities_true_false.csv
Loading file: facts_true_false.csv
Loading file: animals_true_false.csv
Loading file: elements_true_false.csv
Loading file: inventions_true_false.csv
Loading file: companies_true_false.csv
Loading file: generated_true_false.csv
Found 6330 samples


In [268]:
# NOTE: despite the name, this is a fixed sample (index 980), not a randomly drawn one.
# Each dataset item is a tuple (statement, label, source name).
random_sample = datamodule.full_dataset[980]
random_sample

('Tokyo is a name of a country.', tensor(0), 'cities_true_false')

In [269]:
def remove_prefix_suffix_from_tokens(tokens: torch.Tensor) -> torch.Tensor:
  """
  Keep only the entries of `tokens` that belong to the user input,
  dropping the ones related to the prompt template added by LlamaPrompt.

  `tokens` is indexed on its first dimension, one entry per real (non-padding) token
  of the full prompt: it can hold token ids as well as any per-token value.

  The number of template tokens is measured by wrapping a dummy user input with LlamaPrompt
  and tokenizing the text found before and after it.
  """
  random_user_input = 'y6BabNgCyZf3A9XC3d1Qr'

  # Extract the strings for prefix and suffix, that are added by LlamaPrompt.
  # Explicit start/end indexes are used: slicing with [-suffix_len:] would return
  # the whole prompt, instead of an empty string, when the template has no suffix.
  full_llama_prompt = str(LlamaPrompt(random_user_input))
  user_input_start = full_llama_prompt.index(random_user_input)
  user_input_end = user_input_start + len(random_user_input)
  prefix, suffix = full_llama_prompt[:user_input_start], full_llama_prompt[user_input_end:]

  # Count how many tokens do they require
  prefix_len = llama.tokenizer([prefix], return_tensors='pt').input_ids.ravel().size(0)
  suffix_len = llama.tokenizer([suffix], return_tensors='pt').input_ids.ravel().size(0)

  # One token less than `suffix_len` is removed from the end, as in the original slice [prefix_len:-suffix_len+1]
  # (presumably the tokenizer prepends a special token when the suffix is tokenized alone - TODO confirm).
  # The end index is computed explicitly because a negative end of 0 (suffix_len == 1)
  # would produce an empty result instead of removing nothing.
  trailing_tokens = max(suffix_len - 1, 0)
  end = max(tokens.size(0) - trailing_tokens, 0)

  # Remove the tokens that are not part of the user input we want to analyze
  return tokens[prefix_len:end]

In [340]:
@torch.no_grad()
def test_single_tokens_with_saplma_inference(statement: str) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Run the SAPLMA classifier on every token of the statement, independently.

    The hidden states of layer 11 are extracted for the whole (fixed-length) prompt,
    then each real token is classified as if it were a separate sample of a batch.

    Returns a pair (classifier output per token, token ids), both restricted
    to the tokens of the user statement (prompt template tokens are removed).
    """
    encoded = tokenize_prompts_fixed_length(llama, statement)
    # Boolean mask selecting the real tokens, i.e. excluding the padding
    is_real_token = encoded.attention_mask.squeeze() == 1
    real_token_ids = encoded.input_ids.squeeze()[is_real_token]

    # The hidden states must have the same dtype as the classifier weights
    classifier_dtype = next(saplma.saplma_classifier.parameters()).dtype

    layer_output = saplma.hidden_states_extractor.extract_input_hidden_states_for_layer(
        prompt=statement,
        for_layer=11,
    )
    # First (and only) prompt of the batch: [70, 2048] = [TOKENS, EMBEDDING_DIM]
    hidden_states = layer_output.detach().to(classifier_dtype)[0]
    assert hidden_states.shape == (70, 2048)

    # Every real token becomes a row of the batch given to the classifier
    real_hidden_states = hidden_states[is_real_token]
    assert real_hidden_states.dim() == 2
    assert real_hidden_states.size(1) == 2048

    per_token_classification = saplma.saplma_classifier(real_hidden_states)

    # Drop what belongs to the prompt template, keeping the user statement only
    return (
        remove_prefix_suffix_from_tokens(per_token_classification),
        remove_prefix_suffix_from_tokens(real_token_ids),
    )


# Show the statement, then the classifier output of each of its tokens next to the decoded token
print(random_sample[0])
each_token_classification, real_token_ids = test_single_tokens_with_saplma_inference(random_sample[0])

for hallucination_probability, token in zip(each_token_classification, real_token_ids):
    # Percentage with one decimal, right-aligned on 6 characters
    print(f'{hallucination_probability.item():>6.1%}: {llama.tokenizer.decode(token)}')

Tokyo is a name of a country.
  7.2%: Tok
  1.0%: yo
  5.5%:  is
  0.0%:  a
  4.3%:  name
  0.2%:  of
  0.0%:  a
  0.0%:  country
  0.0%: .


# Try with gradients

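An approach similar to Grad-CAM: the classifier prediction is backpropagated down to the input embeddings, and the gradient on each token is used as a measure of how much that token influences the prediction.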

In [341]:
def extract_input_hidden_states_for_layers(prompt, for_layers: set[int]) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Given a batch of prompts, with length BATCH_SIZE,
    extract the hidden states for the L requested layers,
    keeping the computational graph that links them to the input embeddings.

    Args:
        prompt: the prompt(s), forwarded as they are to `tokenize_prompts_fixed_length`.
        for_layers: indexes of the layers to collect (a list or a tuple is accepted too).

    Returns:
        A tuple `(embedded_inputs, hidden_states)`:
        - `embedded_inputs` are the input embeddings given to the model, as a detached leaf tensor
          with `requires_grad=True`: after a backward pass, its `.grad` holds the gradients on the input.
        - `hidden_states` has shape [BATCH_SIZE, L, SEQ_LEN, TOKEN_DIM], with layers in the order they are executed.
          SEQ_LEN is the length of the input sequence, which is the same for all prompts in the batch, fixed to 70.
          TOKEN_DIM is the dimension of the hidden states, which is the same for all layers in the model, fixed to 2048.
    """
    if isinstance(for_layers, (list, tuple)):
        for_layers = set(for_layers)
    assert isinstance(for_layers, set), f"Expected for_layers to be a set. Found: {type(for_layers)}"
    # Without any layer there is nothing to stack: fail early with a clear message
    assert len(for_layers) > 0, "Expected for_layers to contain at least one layer"

    max_layers = len(llama.iter_layers())
    assert all(0 <= layer < max_layers for layer in for_layers), f"Expected all layers to be in range [0, {max_layers}). Found: {for_layers}"

    hidden_states = []

    def _collect_hidden_states(layer_idx: int):
        def _hook(module, inputs, outputs):
            assert isinstance(outputs, tuple), f"Expected outputs to be a tuple. Found: {type(outputs)}"
            assert len(outputs) >= 1, f"Expected outputs to have 1+ elements. Found: {len(outputs)}"

            hidden_state = outputs[0]
            assert isinstance(hidden_state, torch.Tensor), f"Expected hidden_state to be a torch.Tensor. Found: {type(hidden_state)}"
            assert hidden_state.size(1) == 70 and hidden_state.size(2) == 2048, f"Expected hidden_state of layer {layer_idx} to have shape (?, 70, 2048). Found: {hidden_state.shape}"
            # Not detached on purpose: gradients must flow back to the input embeddings
            hidden_states.append(hidden_state)
        return _hook

    llama.unregister_all_hooks()
    try:
        for layer_idx, decoder_layer in enumerate(llama.iter_layers()):
            if layer_idx in for_layers:
                llama.register_hook(decoder_layer, _collect_hidden_states(layer_idx))

        inputs = tokenize_prompts_fixed_length(llama, prompt)
        embedded_inputs = llama.model.get_input_embeddings()(inputs.input_ids)
        # A detached copy requiring gradients is a leaf tensor: autograd populates its `.grad`
        # on backward, without any explicit call to retain_grad()
        embedded_inputs = embedded_inputs.clone().detach().requires_grad_(True)
        _ = llama.model(
            inputs_embeds=embedded_inputs,
            attention_mask=inputs.attention_mask,
            # NOTE(review): these look like generation options, forwarded unchanged from the
            # original code - confirm that the forward of `llama.model` accepts them
            **{
                "max_length": None,
                "max_new_tokens": 1,
                "num_return_sequences": 1,
                # We are collecting hidden_states in a more fine-grained way with hooks
                "output_attentions": False,
                "output_hidden_states": False,
                "return_dict_in_generate": False,
            }
        )
    finally:
        # Always remove the hooks, even when the forward pass fails:
        # otherwise they would stay attached to the layers and fire on later, unrelated calls
        llama.unregister_all_hooks()

    # Now, hidden_states are a list of tensors, each tensor representing the hidden_state for a layer we requested.
    # Stacking gives [L, BATCH_SIZE, SEQ_LEN, TOKEN_DIM]: swap the first two dimensions to have the batch first
    return embedded_inputs, torch.stack(hidden_states).transpose(0, 1)

In [402]:
def test_tokens_with_grad(statement: str) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Score every token of the statement using the gradients on the input embeddings.

    The classifier is run on the hidden state of layer 11 at position 64 of the fixed-length prompt,
    its output is backpropagated to the input embeddings, and the gradient of every token
    (summed over the embedding dimension) is turned into a distribution with a softmax.

    Returns a pair (softmax of the gradients per token, token ids), both restricted
    to the tokens of the user statement (prompt template tokens are removed).
    """
    encoded = tokenize_prompts_fixed_length(llama, statement)
    # Boolean mask selecting the real tokens, i.e. excluding the padding
    is_real_token = encoded.attention_mask.squeeze() == 1
    real_token_ids = encoded.input_ids.squeeze()[is_real_token]

    # Forward pass which also gives back the input embeddings (with requires_grad=True)
    embedded_inputs, layer_hidden_states = extract_input_hidden_states_for_layers(
        statement,
        for_layers={11},
    )
    # One prompt and one layer only: drop both dimensions
    hidden_states = layer_hidden_states.squeeze(0, 1).to(torch.float32)
    assert hidden_states.shape == (70, 2048)

    saplma_input = hidden_states[64].unsqueeze(0)
    assert saplma_input.shape == (1, 2048)
    prediction = saplma.saplma_classifier(saplma_input)
    print(f'Realistic probability (inferred): {prediction.item():.1%}')

    # Backpropagate a scalar built from the prediction, to get the gradients on the input
    backward_target = (5 * (prediction + 1)).sum()
    backward_target.backward()

    # One value per real token: sum of the gradients over all the embedding dimensions
    per_token_grads = embedded_inputs.grad[0].sum(dim=1)[is_real_token]

    # Drop what belongs to the prompt template, keeping the user statement only
    per_token_grads = remove_prefix_suffix_from_tokens(per_token_grads)
    real_token_ids = remove_prefix_suffix_from_tokens(real_token_ids)
    assert per_token_grads.shape == real_token_ids.shape
    return F.softmax(per_token_grads, dim=0), real_token_ids


# Show the statement, then the softmax-normalized input gradient of each of its tokens next to the decoded token
print(random_sample[0])
embedded_inputs_grads, real_token_ids = test_tokens_with_grad(random_sample[0])

for hallucination_probability, token in zip(embedded_inputs_grads, real_token_ids):
    # Percentage with one decimal, right-aligned on 6 characters
    print(f'{hallucination_probability.item():>6.1%}: {llama.tokenizer.decode(token)}')

Tokyo is a name of a country.
Realistic probability (inferred): 0.0%
 11.3%: Tok
 10.8%: yo
 11.2%:  is
 11.1%:  a
 11.1%:  name
 11.1%:  of
 11.2%:  a
 11.1%:  country
 11.1%: .


## Do a few more tests

Compare these 2 approaches

In [407]:
def test_tokens(sample: tuple, strategy, verbose=True) -> None:
    """
    Print a statement token by token, coloring every token according to its score.

    Args:
        sample: dataset item `(statement, label, source)`; the source is ignored.
        strategy: function receiving the statement and returning `(scores, token_ids)`,
            two tensors with one entry per token of the statement.
        verbose: when True, the name of the strategy is printed first.

    The scores are standardized (zero mean, unit standard deviation) and passed through
    a softmax with temperature: what is colored is therefore the relative weight of a token
    within its own statement. Nothing is returned, the result is only printed.
    """
    # The second item is the ground truth label of the statement (0 for a false statement)
    statement, ground_truth, _ = sample
    if verbose:
        print('Using strategy:', strategy.__name__)
    probabilities, token_ids = strategy(statement)

    # Normalize the probabilities
    probs_mean = torch.mean(probabilities)
    probs_std = torch.std(probabilities)
    probabilities = (probabilities - probs_mean) / probs_std

    temperature = 0.9
    probabilities = F.softmax(probabilities / temperature, dim=0)

    # Upper bound (exclusive) of each color band, as ANSI escape codes. The last bound is above 1 to catch everything
    color_thresholds = [
        (0.35, '\033[0m'),  # Neutral
        (0.65, '\033[33m'), # Yellow
        (1.01, '\033[31m'), # Red
    ]

    for hallucination_probability, token in zip(probabilities, token_ids):
        # NaN comes out of the standardization when the standard deviation is zero or undefined
        # (e.g. a single token): such tokens are shown with the highest value
        if hallucination_probability.isnan():
            hallucination_probability = 1.0
        color_for_probability = next(style for threshold, style in color_thresholds if hallucination_probability < threshold)
        print(f'{color_for_probability}{llama.tokenizer.decode(token)}\033[0m', end='')
    print(f'\nGround truth: {ground_truth}\n')

In [408]:
# According to the recorded output, sample 1000 has ground truth 0 (false statement)
# and sample 2000 has ground truth 1 (true statement): name them accordingly.
# The order of the calls is unchanged, so the output is the same as before.
random_false_sample = datamodule.full_dataset[1000]
random_true_sample = datamodule.full_dataset[2000]

test_tokens(random_false_sample, strategy=test_tokens_with_grad)
test_tokens(random_true_sample, strategy=test_tokens_with_grad)

test_tokens(random_false_sample, strategy=test_single_tokens_with_saplma_inference)
test_tokens(random_true_sample, strategy=test_single_tokens_with_saplma_inference)

Using strategy: test_tokens_with_grad
Realistic probability (inferred): 0.3%
[0mG[0m[33mren[0m[0mada[0m[0m is[0m[0m a[0m[0m name[0m[0m of[0m[0m a[0m[0m city[0m[0m.[0m
Ground truth: 0

Using strategy: test_tokens_with_grad
Realistic probability (inferred): 100.0%
[0mA[0m[0m group[0m[0m of[0m[0m wolves[0m[0m is[0m[0m called[0m[0m a[0m[31m pack[0m[0m.[0m
Ground truth: 1

Using strategy: test_single_tokens_with_saplma_inference
[0mG[0m[31mren[0m[0mada[0m[0m is[0m[0m a[0m[0m name[0m[0m of[0m[0m a[0m[0m city[0m[0m.[0m
Ground truth: 0

Using strategy: test_single_tokens_with_saplma_inference
[0mA[0m[0m group[0m[0m of[0m[0m wolves[0m[0m is[0m[0m called[0m[0m a[0m[0m pack[0m[0m.[0m
Ground truth: 1



In [409]:
random_true_sample = datamodule.full_dataset[1500]
random_false_sample = datamodule.full_dataset[1701]

# Both samples with the gradient-based strategy first, then both with the per-token inference
for _strategy in (test_tokens_with_grad, test_single_tokens_with_saplma_inference):
    for _sample in (random_true_sample, random_false_sample):
        test_tokens(_sample, strategy=_strategy)

Using strategy: test_tokens_with_grad
Realistic probability (inferred): 99.9%
[0mRain[0m[0mbows[0m[0m form[0m[0m when[0m[0m light[0m[0m refr[0m[0macts[0m[0m through[0m[0m water[0m[0m dro[0m[0mplets[0m[0m.[0m
Ground truth: 1

Using strategy: test_tokens_with_grad
Realistic probability (inferred): 0.8%
[0mThe[0m[0m Earth[0m[0m's[0m[0m t[0m[0mides[0m[0m are[0m[0m primarily[0m[0m caused[0m[0m by[0m[0m the[0m[0m rep[0m[0mulsive[0m[0m push[0m[0m of[0m[0m the[0m[31m sun[0m[0m.[0m
Ground truth: 0

Using strategy: test_single_tokens_with_saplma_inference
[0mRain[0m[0mbows[0m[0m form[0m[0m when[0m[0m light[0m[0m refr[0m[0macts[0m[0m through[0m[0m water[0m[0m dro[0m[0mplets[0m[0m.[0m
Ground truth: 1

Using strategy: test_single_tokens_with_saplma_inference
[0mThe[0m[0m Earth[0m[0m's[0m[0m t[0m[33mides[0m[0m are[0m[0m primarily[0m[0m caused[0m[0m by[0m[0m the[0m[0m rep[0m[0mulsive[0m[0m push

In [410]:
# Pick 10 random samples among the false statements of full_dataset (the ones with a falsy label).
# NOTE: despite its name, `false_sentences_indexes` holds the samples themselves, not their indexes.
# NOTE(review): no random seed is set in this notebook, so a different selection is drawn at every run.
false_sentences_indexes = [
    sample for sample in datamodule.full_dataset if not sample[1]
]
random_samples = random.sample(false_sentences_indexes, 10)

# Gradient-based strategy only, without printing the strategy name each time
for sample in random_samples:
    test_tokens(sample, strategy=test_tokens_with_grad, verbose=False)

Realistic probability (inferred): 52.6%
[0mJohn[0m[0m Amb[0m[0mrose[0m[0m Fleming[0m[0m invented[0m[0m the[0m[31m Band[0m[0m-A[0m[0mid[0m[0m.[0m
Ground truth: 0

Realistic probability (inferred): 0.0%
[0mSh[0m[0manghai[0m[0m P[0m[0mud[0m[0mong[0m[0m Development[0m[0m Bank[0m[0m operates[0m[0m in[0m[0m the[0m[0m industry[0m[0m of[0m[0m Banking[0m[0m.[0m
Ground truth: 0

Realistic probability (inferred): 99.6%
[0mHar[0m[33mare[0m[0m is[0m[0m a[0m[0m city[0m[0m in[0m[0m Falk[0m[0mland[0m[0m Islands[0m
Ground truth: 0

Realistic probability (inferred): 0.0%
[0mThe[0m[0m ot[0m[0mter[0m[0m has[0m[0m long[0m[0m ears[0m[0m for[0m[0m detecting[0m[31m predators[0m[0m and[0m[0m strong[0m[0m hind[0m[0m legs[0m[0m for[0m[0m escaping[0m[0m.[0m
Ground truth: 0

Realistic probability (inferred): 0.0%
[0mAs[0m[0munc[0m[0mión[0m[0m is[0m[0m a[0m[0m name[0m[0m of[0m[0m a[0m[0m country[0m[0