In [1]:
import numpy as np

In [2]:
from typing import Union
import numpy as np

from abc import ABC, abstractmethod


class AbstractWatermarkCode(ABC):
    """Interface for the pseudo-random code that drives a reweighting scheme.

    Concrete subclasses implement ``from_random``, an alternative constructor
    that draws the code either from a single generator or from a list of
    generators (e.g. one per batch row).
    """

    @classmethod
    @abstractmethod
    def from_random(cls,
                    rng: Union[np.random.Generator, list[np.random.Generator]],
                    vocab_size: int):
        """Build a code of width ``vocab_size`` from ``rng``."""


class AbstractReweight(ABC):
    """Interface for schemes that rewrite a distribution using a watermark code."""

    # The AbstractWatermarkCode subclass this reweight consumes; processors and
    # detectors build codes through ``watermark_code_type.from_random(...)``.
    watermark_code_type: type[AbstractWatermarkCode]

    @abstractmethod
    def reweight(self,
                 code: AbstractWatermarkCode,
                 p: np.ndarray) -> np.ndarray:
        """Return the reweighted version of ``p`` under ``code``."""


def get_gumbel_variables(rng: np.random.Generator,
                         vocab_size: int):
    """Draw one row of Gumbel(0, 1) noise together with its intermediates.

    Returns a tuple ``(u, e, g)`` of arrays of shape ``(vocab_size,)`` where
    ``u ~ Uniform[0, 1)``, ``e = -log(u) ~ Exp(1)`` and
    ``g = -log(e) ~ Gumbel(0, 1)``.
    """
    uniform = rng.random((vocab_size,))
    exponential = -np.log(uniform)
    gumbel = -np.log(exponential)
    return uniform, exponential, gumbel


class DeltaGumbel_WatermarkCode(AbstractWatermarkCode):
    """Watermark code holding Gumbel(0, 1) noise ``g`` for the Gumbel-max trick.

    ``g`` has shape ``(vocab_size,)`` when built from a single generator and
    ``(len(rng), vocab_size)`` when built from a list of generators.
    """

    def __init__(self, g: np.ndarray):
        self.g = g

    @classmethod
    def from_random(
        cls,
        rng: Union[np.random.Generator, list[np.random.Generator]],
        vocab_size: int,
    ):
        """Draw the Gumbel noise from ``rng`` (one stacked row per generator)."""
        if not isinstance(rng, list):
            return cls(get_gumbel_variables(rng, vocab_size)[2])
        rows = [get_gumbel_variables(generator, vocab_size)[2] for generator in rng]
        return cls(np.stack(rows))


class DeltaGumbel_Reweight(AbstractReweight):
    """Delta reweighting driven by the Gumbel-max trick.

    ``reweight`` puts all probability mass on ``argmax(p_logits + g)``,
    expressed as logits: 0 for the selected token and -inf everywhere else.
    """

    watermark_code_type = DeltaGumbel_WatermarkCode

    def __repr__(self):
        return "DeltaGumbel_Reweight()"

    def reweight(
        self, code: AbstractWatermarkCode, p_logits: np.ndarray
    ) -> np.ndarray:
        """Return one-hot logits selecting ``argmax(p_logits + code.g)`` per row."""
        assert isinstance(code, DeltaGumbel_WatermarkCode)

        winner = np.argmax(p_logits + code.g, axis=-1)
        vocab_positions = np.arange(p_logits.shape[-1])
        is_winner = vocab_positions == np.expand_dims(winner, -1)

        zeros = np.zeros_like(p_logits)
        negative_inf = np.full_like(p_logits, -np.inf)
        return np.where(is_winner, zeros, negative_inf)

    def get_la_score(self, code):
        """Likelihood-agnostic score ``log(2) - exp(-g)`` for every token.

        Because ``g = -log(e)`` (see ``get_gumbel_variables``), ``exp(-g)``
        recovers the underlying Exp(1) draw ``e``.
        """
        offset = np.array(np.log(2))
        return offset - np.exp(-code.g)


class WatermarkProcessor:
    """Apply a reweighting watermark to next-token logits.

    For every batch row a context code is built from the row's tokens (the
    trailing ``context_code_length`` tokens in normal mode, or the tokens just
    before ``current_pos`` with NaN placeholders dropped in order-agnostic
    mode). It is hashed with ``private_key`` to seed a per-row RNG, from which
    the reweight's watermark code is drawn. Context codes are remembered in
    ``cc_history`` across calls and across the rows of a batch. A row whose
    code was already seen keeps its original logits, unless
    ``ignore_history`` is set.
    """

    # Accepted spellings of the order-agnostic mode; the misspelled
    # 'order_agnoistic' is kept for backward compatibility.
    _ORDER_AGNOSTIC_MODES = ('order_agnoistic', 'order_agnostic')

    def __init__(
            self,
            private_key: bytes,
            reweight: "AbstractReweight",
            context_code_length: int,
            ignore_history: bool = False,
    ):
        self.private_key = private_key
        self.reweight = reweight
        self.cc_length = context_code_length
        self.ignore_history = ignore_history
        self.cc_history = set()

    def get_rng_seed(self, context_code: bytes) -> int:
        """Hash ``context_code`` and the private key into an integer seed.

        Unless ``ignore_history`` is set, ``context_code`` is also recorded in
        ``cc_history``.
        """
        if not self.ignore_history:
            self.cc_history.add(context_code)
        import hashlib

        m = hashlib.sha256()
        m.update(context_code)
        m.update(self.private_key)
        full_hash = m.digest()
        seed = int.from_bytes(full_hash, "big") % (2 ** 32 - 1)
        return seed

    def reset_history(self):
        """Forget every context code seen so far."""
        self.cc_history = set()

    def _get_codes(self, context, current_pos=None):
        """Return ``(mask, seeds)`` tuples with one entry per row of ``context``.

        :param context: per-row token arrays.
        :param current_pos: ``None`` for normal mode (use the trailing
            ``cc_length`` tokens), or the position being generated in
            order-agnostic mode (use ``[current_pos - cc_length, current_pos)``
            with NaN placeholders removed).
        :return: ``mask[i]`` is True when row ``i``'s context code had already
            been recorded; ``seeds[i]`` is the RNG seed for row ``i``.
        """
        context_codes = []
        for row in context:
            if current_pos is None:
                # Explicit start index: ``row[-0:]`` would return the whole
                # row when cc_length == 0 instead of an empty context.
                window = row[max(len(row) - self.cc_length, 0):]
            else:
                window = row[max(current_pos - self.cc_length, 0):current_pos]
                window = window[~np.isnan(window)]
            context_codes.append(window.tobytes())

        mask, seeds = [], []
        for context_code in context_codes:
            # Check membership before get_rng_seed records the code, so a code
            # repeated within the same batch is masked on its later rows.
            mask.append(context_code in self.cc_history)
            seeds.append(self.get_rng_seed(context_code))
        return tuple(mask), tuple(seeds)

    def __call__(self,
                 mode: str = 'normal',
                 context: np.ndarray = None,
                 logits: np.ndarray = None,
                 current_pos: int = None):
        """Return ``logits`` with the watermark applied row by row.

        :param mode: ``'normal'`` hashes the trailing ``context_code_length``
            tokens of each row; ``'order_agnoistic'`` (also accepted as
            ``'order_agnostic'``) hashes the tokens just before
            ``current_pos``, skipping NaN placeholders.
        :param context: per-row token ids, one row per batch element.
        :param logits: logits of shape ``(batch, vocab_size)``.
        :param current_pos: position being generated; required in
            order-agnostic mode and ignored in normal mode.
        :return: array like ``logits``; rows whose context code was already
            seen keep their original logits unless ``ignore_history`` is set.
        :raises ValueError: order-agnostic mode without ``current_pos``.
        :raises NotImplementedError: any other ``mode``.
        """
        if mode == 'normal':
            # None (not 0) marks normal mode, so order-agnostic position 0
            # is not confused with it.
            code_pos = None
        elif mode in self._ORDER_AGNOSTIC_MODES:
            if current_pos is None:
                raise ValueError("current_pos is required in order-agnostic mode")
            code_pos = current_pos
        else:
            raise NotImplementedError('Current watermark processor does not support this mode')

        mask, seeds = self._get_codes(context, current_pos=code_pos)
        rng = [np.random.default_rng(seed) for seed in seeds]

        watermark_code = self.reweight.watermark_code_type.from_random(
            rng, logits.shape[1]
        )
        reweighted_logits = self.reweight.reweight(watermark_code, logits)

        if self.ignore_history:
            return reweighted_logits
        # Rows whose context code was seen before keep the unmodified logits.
        return np.where(np.array(mask)[:, None], logits, reweighted_logits)


class WatermarkDetector:
    """Score token sequences for the presence of a reweighting watermark.

    Context codes are hashed with ``private_key`` in the same way as
    ``WatermarkProcessor.get_rng_seed``. ``cc_history`` is shared across the
    rows of a batch and is not cleared between ``detect`` calls. Call
    ``reset_history()`` (or build a new detector) before scoring unrelated
    inputs.
    """

    def __init__(
            self,
            private_key: bytes,
            reweight: "AbstractReweight",
            context_code_length: int,
            vocab_size: int = 20,
            ignore_history: bool = False
    ):
        self.private_key = private_key
        self.cc_length = context_code_length
        self.vocab_size = vocab_size
        self.reweight = reweight
        self.ignore_history = ignore_history
        self.cc_history = set()

    def reset_history(self):
        """Forget every context code seen so far."""
        self.cc_history = set()

    def get_rng_seed(self, context_code: bytes) -> int:
        """Hash ``context_code`` and the private key into an integer seed.

        Unless ``ignore_history`` is set, ``context_code`` is also recorded in
        ``cc_history``.
        """
        if not self.ignore_history:
            self.cc_history.add(context_code)
        import hashlib

        m = hashlib.sha256()
        m.update(context_code)
        m.update(self.private_key)
        full_hash = m.digest()
        seed = int.from_bytes(full_hash, "big") % (2 ** 32 - 1)
        return seed

    def _get_codes(self, context):
        """Return ``(mask, seeds)`` from the trailing ``cc_length`` tokens of each row.

        ``mask[i]`` is True when row ``i``'s context code had already been
        recorded before it was reached.
        """
        mask, seeds = [], []
        for row in context:
            # Explicit start index: ``row[-0:]`` would return the whole row
            # when cc_length == 0 instead of an empty context.
            context_code = row[max(len(row) - self.cc_length, 0):].tobytes()
            mask.append(context_code in self.cc_history)
            seeds.append(self.get_rng_seed(context_code))
        return tuple(mask), tuple(seeds)

    def detect(self,
               input_ids: np.ndarray):
        """Return the summed likelihood-agnostic score of each sequence.

        :param input_ids: sequences after tokenization, shape
            ``(batch, seq_len)``; a 1-D array is treated as one sequence.
        :return: float array of shape ``(batch,)``. Positions whose context
            code was already in ``cc_history`` contribute 0.
        """
        input_ids = np.atleast_2d(input_ids)
        batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
        if seq_len == 0:
            # np.sum over an empty list would collapse to a scalar.
            return np.zeros(batch_size)

        # Token i is scored using the tokens before it as context.
        scores = [
            self.get_la_score(input_ids[:, :i], input_ids[:, i], self.vocab_size)
            for i in range(seq_len)
        ]
        return np.sum(np.array(scores), axis=0)

    def get_la_score(
        self,
        input_ids: np.ndarray,
        labels: np.ndarray,
        vocab_size: int,
    ) -> np.ndarray:
        """Score the observed next token of each row given its preceding tokens.

        :param input_ids: context tokens, shape ``(batch, prefix_len)``.
        :param labels: observed next token of each row, shape ``(batch,)``.
        :param vocab_size: width of the watermark code to draw.
        :return: shape ``(batch,)``; 0 where the context code was already seen.
        """
        assert hasattr(
            self.reweight, "get_la_score"
        ), "Reweight does not support likelihood agnostic detection"
        mask, seeds = self._get_codes(input_ids)
        rng = [np.random.default_rng(seed) for seed in seeds]
        watermark_code = self.reweight.watermark_code_type.from_random(rng, vocab_size)
        all_scores = self.reweight.get_la_score(watermark_code)
        scores = all_scores[np.arange(all_scores.shape[0]), labels]
        # WatermarkProcessor leaves rows with a repeated context code
        # unwatermarked, so those positions are zeroed here as well.
        return np.logical_not(np.array(mask)) * scores

In [3]:
# Toy batch of 5 contexts; NaN marks positions not filled yet, which the
# order-agnostic mode of WatermarkProcessor drops when building context codes.
input_id = np.array([[ 4, np.nan, np.nan, 15, 20, 18],
                     [ 5, np.nan, np.nan,  2,  0, 15],
                     [ 5, np.nan, np.nan, 21, 10, 17],
                     [ 6, np.nan, np.nan, 10, 15, 8],
                     [ 6, np.nan, np.nan, 9, 25, 9]])
# Seeded generator so the logits are reproducible on Restart & Run All.
logits = np.random.default_rng(0).random((5, 20))

In [4]:
# Generation-side processor with a 3-token context window.
watermark_processor = WatermarkProcessor(
    b'private',
    DeltaGumbel_Reweight(),
    context_code_length=3,
)

In [5]:
logits  # raw logits before any watermarking, for comparison with the calls below

array([[0.27345002, 0.46073045, 0.96927645, 0.75624814, 0.69989047,
        0.50814657, 0.43346905, 0.74601261, 0.96998015, 0.37814076,
        0.39418337, 0.58310508, 0.20646198, 0.64546119, 0.24977696,
        0.79664081, 0.18938096, 0.47139425, 0.0873729 , 0.28474509],
       [0.32283926, 0.08903497, 0.21233506, 0.10430039, 0.3206153 ,
        0.17804705, 0.04651974, 0.53754152, 0.10692629, 0.80250703,
        0.55071335, 0.25694895, 0.93563718, 0.49728107, 0.65539614,
        0.53114406, 0.51073647, 0.95689727, 0.95227567, 0.62795639],
       [0.13787495, 0.03095741, 0.10545715, 0.56647594, 0.74173664,
        0.22204499, 0.1169336 , 0.9389509 , 0.87800631, 0.11468302,
        0.20173521, 0.49373343, 0.28704077, 0.73549827, 0.86923247,
        0.02852061, 0.45394992, 0.2598513 , 0.35724781, 0.03543314],
       [0.71464493, 0.13757047, 0.4150354 , 0.42264515, 0.06106554,
        0.98146016, 0.36787342, 0.03697618, 0.2226742 , 0.38606002,
        0.88420054, 0.43256975, 0.43924154, 0

In [6]:
# Position 1: each row's context is just its first token ([4], [5], [5], [6], [6]).
# Rows 2 and 4 repeat a context already seen in this batch, so they keep their
# original logits; the other three rows are watermarked.
watermark_processor(mode='order_agnoistic', context=input_id, logits=logits, current_pos=1)

array([[      -inf,       -inf,       -inf,       -inf,       -inf,
              -inf,       -inf, 0.        ,       -inf,       -inf,
              -inf,       -inf,       -inf,       -inf,       -inf,
              -inf,       -inf,       -inf,       -inf,       -inf],
       [      -inf,       -inf,       -inf,       -inf,       -inf,
              -inf,       -inf,       -inf, 0.        ,       -inf,
              -inf,       -inf,       -inf,       -inf,       -inf,
              -inf,       -inf,       -inf,       -inf,       -inf],
       [0.13787495, 0.03095741, 0.10545715, 0.56647594, 0.74173664,
        0.22204499, 0.1169336 , 0.9389509 , 0.87800631, 0.11468302,
        0.20173521, 0.49373343, 0.28704077, 0.73549827, 0.86923247,
        0.02852061, 0.45394992, 0.2598513 , 0.35724781, 0.03543314],
       [      -inf,       -inf,       -inf,       -inf,       -inf,
              -inf,       -inf,       -inf,       -inf,       -inf,
              -inf,       -inf,       -inf,  

In [7]:
# Position 2: token 1 is NaN and is dropped, so each context is again just the
# first token. All five codes were recorded by the previous cell, so no row is
# watermarked (this relies on that cell having run first).
watermark_processor(mode='order_agnoistic', context=input_id, logits=logits, current_pos=2)

array([[0.27345002, 0.46073045, 0.96927645, 0.75624814, 0.69989047,
        0.50814657, 0.43346905, 0.74601261, 0.96998015, 0.37814076,
        0.39418337, 0.58310508, 0.20646198, 0.64546119, 0.24977696,
        0.79664081, 0.18938096, 0.47139425, 0.0873729 , 0.28474509],
       [0.32283926, 0.08903497, 0.21233506, 0.10430039, 0.3206153 ,
        0.17804705, 0.04651974, 0.53754152, 0.10692629, 0.80250703,
        0.55071335, 0.25694895, 0.93563718, 0.49728107, 0.65539614,
        0.53114406, 0.51073647, 0.95689727, 0.95227567, 0.62795639],
       [0.13787495, 0.03095741, 0.10545715, 0.56647594, 0.74173664,
        0.22204499, 0.1169336 , 0.9389509 , 0.87800631, 0.11468302,
        0.20173521, 0.49373343, 0.28704077, 0.73549827, 0.86923247,
        0.02852061, 0.45394992, 0.2598513 , 0.35724781, 0.03543314],
       [0.71464493, 0.13757047, 0.4150354 , 0.42264515, 0.06106554,
        0.98146016, 0.36787342, 0.03697618, 0.2226742 , 0.38606002,
        0.88420054, 0.43256975, 0.43924154, 0

In [8]:
# Position 3: tokens 1-2 are NaN and dropped, leaving the same first-token
# contexts as before; they are all in the history, so no row is watermarked.
watermark_processor(mode='order_agnoistic', context=input_id, logits=logits, current_pos=3)

array([[0.27345002, 0.46073045, 0.96927645, 0.75624814, 0.69989047,
        0.50814657, 0.43346905, 0.74601261, 0.96998015, 0.37814076,
        0.39418337, 0.58310508, 0.20646198, 0.64546119, 0.24977696,
        0.79664081, 0.18938096, 0.47139425, 0.0873729 , 0.28474509],
       [0.32283926, 0.08903497, 0.21233506, 0.10430039, 0.3206153 ,
        0.17804705, 0.04651974, 0.53754152, 0.10692629, 0.80250703,
        0.55071335, 0.25694895, 0.93563718, 0.49728107, 0.65539614,
        0.53114406, 0.51073647, 0.95689727, 0.95227567, 0.62795639],
       [0.13787495, 0.03095741, 0.10545715, 0.56647594, 0.74173664,
        0.22204499, 0.1169336 , 0.9389509 , 0.87800631, 0.11468302,
        0.20173521, 0.49373343, 0.28704077, 0.73549827, 0.86923247,
        0.02852061, 0.45394992, 0.2598513 , 0.35724781, 0.03543314],
       [0.71464493, 0.13757047, 0.4150354 , 0.42264515, 0.06106554,
        0.98146016, 0.36787342, 0.03697618, 0.2226742 , 0.38606002,
        0.88420054, 0.43256975, 0.43924154, 0

### Naive tokenizer construction

In [9]:
# The 20 standard amino acids followed by 'X' for unknown residues.
amino_acids = 'ACDEFGHIKLMNPQRSTVWYX'
# Map each letter to its position in `amino_acids`.
aa_to_index = dict(zip(amino_acids, range(len(amino_acids))))
class AminoAcidTokenizer:
    """Character-level tokenizer between amino-acid strings and index lists.

    Letters missing from ``aa_to_index`` encode to the index of 'X' (or to
    ``None`` if 'X' is itself missing); unknown indices decode to 'X'.
    """

    def __init__(self, aa_to_index):
        self.aa_to_index = aa_to_index
        self.index_to_aa = dict((idx, aa) for aa, idx in aa_to_index.items())

    def encode(self, sequence):
        """Return the list of indices for the letters in ``sequence``."""
        fallback = self.aa_to_index.get('X')
        return list(map(lambda aa: self.aa_to_index.get(aa, fallback), sequence))

    def decode(self, indices):
        """Return the amino-acid string spelled by ``indices``."""
        return ''.join([self.index_to_aa.get(idx, 'X') for idx in indices])
# Sanity check: encode a short peptide with the naive tokenizer.
tokenizer = AminoAcidTokenizer(aa_to_index)
encoded_input = tokenizer.encode(
    "MVDADTQKALDFI"
)
print(encoded_input)

[10, 17, 2, 0, 2, 16, 13, 8, 0, 9, 2, 4, 7]


### Load the original tokenizer from its JSON file

In [10]:
from transformers import PreTrainedTokenizerFast

# Replace the naive tokenizer with the one stored in ./tokenizer.json.
tokenizer = PreTrainedTokenizerFast(tokenizer_file="./tokenizer.json")
encoded_input = tokenizer.encode(
    "GGPPVVRLDEIRGRGIRVEFDAGGVVRLLTERLERAAAEPALRVPLALRASRGRREIRLQADRVGPLRALTERLVELSAPERVAVRLAPGDEGGVEVLPRVEALVLAVLDPDREVEVAVLAGKARSEISLVAGRRLLERRVASEPLPAAAVLAFLGEALREELLVPLDALRGLARRVAEGGGAELRLTLLSRARAEVVL"
)
print(encoded_input)

  from .autonotebook import tqdm as notebook_tqdm


[11, 11, 19, 19, 25, 25, 21, 15, 8, 9, 13, 21, 11, 21, 11, 13, 21, 25, 9, 10, 8, 5, 11, 11, 25, 25, 21, 15, 15, 23, 9, 21, 15, 9, 21, 5, 5, 5, 9, 19, 5, 15, 21, 25, 19, 15, 5, 15, 21, 5, 22, 21, 11, 21, 21, 9, 13, 21, 15, 20, 5, 8, 21, 25, 11, 19, 15, 21, 5, 15, 23, 9, 21, 15, 25, 9, 15, 22, 5, 19, 9, 21, 25, 5, 25, 21, 15, 5, 19, 11, 8, 9, 11, 11, 25, 9, 25, 15, 19, 21, 25, 9, 5, 15, 25, 15, 5, 25, 15, 8, 19, 8, 21, 9, 25, 9, 25, 5, 25, 15, 5, 11, 14, 5, 21, 22, 9, 13, 22, 15, 25, 5, 11, 21, 21, 15, 15, 9, 21, 21, 25, 5, 22, 9, 19, 15, 19, 5, 5, 5, 25, 15, 5, 10, 15, 11, 9, 5, 15, 21, 9, 9, 15, 15, 25, 19, 15, 8, 5, 15, 21, 11, 15, 5, 21, 21, 25, 5, 9, 11, 11, 11, 5, 9, 15, 21, 15, 23, 15, 15, 22, 21, 5, 21, 5, 9, 25, 25, 15]


In [11]:
len(tokenizer)  # number of tokens in the loaded vocabulary; the detectors below use vocab_size=30

30

### Detection with Private Key

In [12]:
# Detector keyed with the private key: 5-token context over a 30-token vocabulary.
detector = WatermarkDetector(
    b"private key",
    DeltaGumbel_Reweight(),
    context_code_length=5,
    vocab_size=30,
)

In [13]:
# Add a leading batch axis: the detector expects (batch, seq_len).
input_ids = np.array(encoded_input)[np.newaxis, :]

In [14]:
detector.detect(input_ids)  # summed score; compare with the wrong-key run below

array([63.62662829])

In [15]:
# np.log(2) - 1/30  # theoretical maximum of the *expected* per-token score, reached when the model is uniform over the 30-token vocabulary

### Without Correct Key, Score Drops

In [16]:
# Same settings, but with a slightly wrong key ("keey"): the seeds no longer
# match, and the score drops.
detector = WatermarkDetector(
    b"private keey",
    DeltaGumbel_Reweight(),
    context_code_length=5,
    vocab_size=30,
)
detector.detect(input_ids)

array([-69.73725497])

### Watermark is Relatively Robust to Modification

In [17]:
# Score only the head of the protein. This reuses the tokenizer loaded from
# ./tokenizer.json earlier instead of re-importing and re-loading it.
encoded_input = tokenizer.encode("GGPPVVRLDEIRGRGIRVEFDAGGVVRLLTERLERAAAEPALRVPLAL")
input_ids = np.array(encoded_input).reshape(1, -1)
detector = WatermarkDetector(
    b"private key",
    DeltaGumbel_Reweight(),
    context_code_length=5,
    vocab_size=30,
)
detector.detect(input_ids)

array([23.63254586])

In [18]:
# modify the sequence by changing A to G
encoded_input = tokenizer.encode("GGPPVVRLDEIRGRGIRVEFDGGGVVRLLTERLERGGGEPGLRVPLGL")
print(encoded_input)
input_ids = np.array(encoded_input).reshape(1,-1)
detector = WatermarkDetector(b"private key",
                             DeltaGumbel_Reweight(),
                             context_code_length=5,
                             vocab_size=30)
detector.detect(input_ids) # we still have some watermark in it!

[11, 11, 19, 19, 25, 25, 21, 15, 8, 9, 13, 21, 11, 21, 11, 13, 21, 25, 9, 10, 8, 11, 11, 11, 25, 25, 21, 15, 15, 23, 9, 21, 15, 9, 21, 11, 11, 11, 9, 19, 11, 15, 21, 25, 19, 15, 11, 15]


array([5.64354916])

### Detect multiple sequences at once

In [19]:
# Score a batch of two identical sequences. cc_history is shared across the
# rows of a batch, so each context code of the second row has already been
# recorded while scoring the first row; the duplicate row therefore scores 0.
detector = WatermarkDetector(
    b"private key",
    DeltaGumbel_Reweight(),
    context_code_length=5,
    vocab_size=30,
)
detector.detect(np.concatenate([input_ids, input_ids]))


array([5.64354916, 0.        ])