# SynthID Text Integration Tests

This colab provides integration tests to ensure that the SynthID Text library
works correctly with the Hugging Face Transformers library. Run this notebook
end-to-end to execute the integration tests on CPU and GPU using a GPT-2 model.
The GPU tests raise an error if no GPU is available.

## Test setup

In [None]:
# @title Install and import the required Python libraries
#
# @markdown _This may require you to restart your session._

! pip install -q synthid-text

from collections.abc import Mapping, Sequence
from typing import Any

from synthid_text import g_value_expectations
from synthid_text import logits_processing
from synthid_text import synthid_mixin
import torch
import tqdm
import transformers

In [None]:
# @title Define integration test logic

def initiate_test_model(
    *,
    ngram_len: int,
    keys: Sequence[int],
    sampling_table_size: int,
    sampling_table_seed: int,
    context_history_size: int,
    device: torch.device,
    vocab_size: int,
    pretrained_model_name_or_path: str,
) -> 'TestModel':
  """Loads a GPT-2 LM head model wired up with test-only logits warpers.

  The returned model class combines `synthid_mixin.SynthIDSparseTopKMixin`
  with `transformers.GPT2LMHeadModel` and defines `_construct_warper_list` so
  that the warper list is: a processor that replaces all scores with ones,
  followed by a `SynthIDLogitsProcessor` built from the arguments below.

  Args:
    ngram_len: Forwarded to `SynthIDLogitsProcessor` as `ngram_len`.
    keys: Forwarded to `SynthIDLogitsProcessor` as `keys`.
    sampling_table_size: Forwarded as `sampling_table_size`.
    sampling_table_seed: Forwarded as `sampling_table_seed`.
    context_history_size: Forwarded as `context_history_size`.
    device: Forwarded to `SynthIDLogitsProcessor` as `device` and used as the
      `device_map` when loading the pretrained weights.
    vocab_size: Value forced as `top_k` for the `SynthIDLogitsProcessor`,
      replacing any `top_k` supplied through `extra_params`.
    pretrained_model_name_or_path: Passed to `TestModel.from_pretrained`.

  Returns:
    The model returned by `TestModel.from_pretrained`.
  """
  config = {
      'ngram_len': ngram_len,
      'keys': keys,
      'sampling_table_size': sampling_table_size,
      'sampling_table_seed': sampling_table_seed,
      'context_history_size': context_history_size,
      'device': device,
  }

  class TestLogitsProcessor(transformers.LogitsProcessor):
    """Test logits processor used for bias testing later in the colab."""

    @torch.no_grad()
    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
    ) -> torch.FloatTensor:
      """Create and return uniform scores for bias testing."""
      # `input_ids` is ignored; only the shape/dtype/device of `scores` matter.
      return torch.ones_like(scores, device=scores.device)

  class TestModelMixin(synthid_mixin.SynthIDSparseTopKMixin):
    """SynthID mixin variant that prepends the uniform-scores processor."""

    def _construct_warper_list(
        self,
        extra_params: Mapping[str, Any],
        watermark_config: Mapping[str, Any] = config,
    ) -> transformers.LogitsProcessorList:
      """Instantiate warpers list.

      Args:
        extra_params: Extra keyword arguments for `SynthIDLogitsProcessor`.
          Not modified; `top_k` is overridden in a copy.
        watermark_config: Watermarking keyword arguments; defaults to the
          config assembled from the enclosing function's arguments.

      Returns:
        A `LogitsProcessorList` holding `TestLogitsProcessor` followed by the
        `SynthIDLogitsProcessor`.
      """
      warpers = transformers.LogitsProcessorList()
      # Add a logits warper that converts logits to uniform for testing.
      warpers.append(TestLogitsProcessor())
      # Build a merged copy rather than assigning into `extra_params`: it is
      # typed as a read-only Mapping and is owned by the caller.
      processor_kwargs = {**extra_params, 'top_k': vocab_size}
      warpers.append(
          logits_processing.SynthIDLogitsProcessor(
              **watermark_config, **processor_kwargs
          )
      )
      return warpers

  class TestModel(TestModelMixin, transformers.GPT2LMHeadModel):
    """GPT-2 LM head model using the test warper list defined above."""

  return TestModel.from_pretrained(
      pretrained_model_name_or_path,
      device_map=device,
  )

def test_mean_g_value_matches_theoretical_integrated(
    *,
    vocab_size: int,
    ngram_len: int,
    keys: Sequence[int],
    device: torch.device,
    pretrained_model_name_or_path: str,
    batch_size: int,
    num_repeats: int,
    outputs_len: int,
    atol: float,
    sampling_table_size: int = 2**16,
    sampling_table_seed: int = 0,
    context_history_size: int = 1024,
) -> tuple[torch.Tensor, float, torch.Tensor]:
  """Tests the value of the mean g-value in the sampling loop.

  Generates `num_repeats` batches with the test model, computes the masked
  mean g-value of each batch, averages those means over the repeats, and
  compares the result against the theoretical expectation.

  Args:
    vocab_size: Used as `top_k` for sampling and for the SynthID logits
      processor, and passed to `g_value_expectations.expected_mean_g_value`.
    ngram_len: N-gram length for the watermarking config; the prompt length
      is `ngram_len - 1`.
    keys: Watermarking keys forwarded to the logits processor and the model.
    device: Device for the model, the logits processor and the inputs.
    pretrained_model_name_or_path: Checkpoint to load the test model from.
    batch_size: Number of sequences generated per repeat.
    num_repeats: Number of `generate` calls whose mean g-values are averaged.
    outputs_len: Number of tokens generated after the prompt.
    atol: Absolute tolerance for the comparison (relative tolerance is 0).
    sampling_table_size: Forwarded to the watermarking config.
    sampling_table_seed: Forwarded to the watermarking config.
    context_history_size: Forwarded to the watermarking config.

  Returns:
    A 3-tuple `(mean_g_values, expected_mean_g_value, is_close)`:
    the float64 tensor of mean g-values averaged over repeats, the value
    returned by `g_value_expectations.expected_mean_g_value` (presumably a
    float), and the tensor produced by `torch.isclose` for the two.
  """
  model = initiate_test_model(
      ngram_len=ngram_len,
      keys=keys,
      sampling_table_size=sampling_table_size,
      sampling_table_seed=sampling_table_seed,
      context_history_size=context_history_size,
      device=device,
      vocab_size=vocab_size,
      pretrained_model_name_or_path=pretrained_model_name_or_path,
  )
  # Turn off EOS token stopping.
  generation_config = transformers.GenerationConfig.from_model_config(
      model.config
  )
  generation_config.eos_token_id = None
  generation_config.stop_strings = None

  # Stand-alone processor used only to score the generated tokens below.
  # NOTE(review): temperature=0.7 differs from the 1.0 passed to generate();
  # presumably it does not affect compute_g_values or
  # compute_context_repetition_mask -- confirm.
  logits_processor = logits_processing.SynthIDLogitsProcessor(
      ngram_len=ngram_len,
      keys=keys,
      sampling_table_size=sampling_table_size,
      sampling_table_seed=sampling_table_seed,
      context_history_size=context_history_size,
      device=device,
      top_k=vocab_size,
      temperature=0.7,
  )

  # All-zero prompt of length ngram_len - 1 for every sequence in the batch.
  inputs_len = ngram_len - 1
  input_ids = torch.zeros(
      (batch_size, inputs_len),
      dtype=torch.int64,
      device=device,
  )
  inputs = {
      'input_ids': input_ids,
      'attention_mask': torch.ones_like(input_ids, device=device),
  }

  expected_mean_g_value = g_value_expectations.expected_mean_g_value(
      vocab_size=vocab_size,
  )
  mean_g_values_repeats = []

  # Fixed seed so the sampled sequences are reproducible across runs.
  torch.manual_seed(0)
  for _ in tqdm.tqdm(range(num_repeats)):
    outputs = model.generate(
        **inputs,
        do_sample=True,
        temperature=1.0,
        top_k=vocab_size,
        max_length=inputs_len + outputs_len,
        generation_config=generation_config,
    )
    # Score only the newly generated tokens, not the prompt.
    generated_ids = outputs[:, inputs_len:]
    g_values = logits_processor.compute_g_values(
        input_ids=generated_ids,
    )
    context_repetition_mask = logits_processor.compute_context_repetition_mask(
        input_ids=generated_ids,
    ).unsqueeze(dim=2)
    # Masked mean over dim 0 of this batch's g-values, accumulated in float64.
    mean_g_values = torch.masked.mean(
        g_values,
        mask=context_repetition_mask,
        dim=0,
        keepdim=True,
        dtype=torch.float64,
    )
    mean_g_values_repeats.append(mean_g_values)

  # Average the per-repeat means.
  mean_g_values = torch.concat(mean_g_values_repeats, dim=0).mean(dim=0)
  is_close = torch.isclose(
      mean_g_values,
      torch.tensor(expected_mean_g_value, dtype=torch.float64),
      atol=atol,
      rtol=0,
  )

  return mean_g_values, expected_mean_g_value, is_close

In [None]:
# @title Define all test parameters

# Keyword arguments shared by every parameterized test case.
COMMON_CONFIG = {
    'batch_size': 400,
    'outputs_len': 30,
    'keys': [38],
    'atol': 0.01,
    'pretrained_model_name_or_path': (
        'trl-internal-testing/tiny-random-GPT2LMHeadModel'
    ),
}

# Test case: vocabulary of 15 tokens with 6-grams, 100 repeats.
VOCAB_15_NGRAM_6_CONFIG = {
    'vocab_size': 15,
    'ngram_len': 6,
    'num_repeats': 100,
    **COMMON_CONFIG,
}

# Test case: vocabulary of 1000 tokens with 10-grams, 50 repeats.
VOCAB_1000_NGRAM_10_CONFIG = {
    'vocab_size': 1000,
    'ngram_len': 10,
    'num_repeats': 50,
    **COMMON_CONFIG,
}

## Test Invocations

In [None]:
# @title Run parameterized tests on CPU

# Run every parameterized configuration on the CPU.
DEVICE = torch.device('cpu')

for test_config in (VOCAB_15_NGRAM_6_CONFIG, VOCAB_1000_NGRAM_10_CONFIG):
  # Print the returned tuple directly so that no name keeps the result
  # alive after the iteration ends.
  print(
      test_mean_g_value_matches_theoretical_integrated(
          **test_config, device=DEVICE
      )
  )

In [None]:
# @title Run parameterized tests on GPU

# Fail up front when the runtime has no CUDA device.
if not torch.cuda.is_available():
  raise RuntimeError(
      'Attempted to run tests on a GPU when no GPU is available.'
  )

# Run every parameterized configuration on the first CUDA device.
DEVICE = torch.device('cuda:0')

for test_config in (VOCAB_15_NGRAM_6_CONFIG, VOCAB_1000_NGRAM_10_CONFIG):
  # Print the returned tuple directly so that no name keeps the result
  # alive after the iteration ends.
  print(
      test_mean_g_value_matches_theoretical_integrated(
          **test_config, device=DEVICE
      )
  )