In [1]:
import numpy as np

In [2]:
from typing import Union
import numpy as np

from abc import ABC, abstractmethod


class AbstractWatermarkCode(ABC):
    @classmethod
    @abstractmethod
    def from_random(cls,
                    rng: Union[np.random.Generator, list[np.random.Generator]],
                    vocab_size: int):
        pass


class AbstractReweight(ABC):
    watermark_code_type: type[AbstractWatermarkCode]

    @abstractmethod
    def reweight(self,
                 code: AbstractWatermarkCode,
                 p: np.ndarray) -> np.ndarray:
        pass


def get_gumbel_variables(rng: np.random.Generator,
                         vocab_size: int):
    u = rng.random((vocab_size,))  # ~ Unifom(0, 1)
    e = -np.log(u)  # ~ Exp(1)
    g = -np.log(e)  # ~ Gumbel(0, 1)
    return u, e, g


class DeltaGumbel_WatermarkCode(AbstractWatermarkCode):
    def __init__(self, g: np.ndarray):
        self.g = g

    @classmethod
    def from_random(
        cls,
        rng: Union[np.random.Generator, list[np.random.Generator]],
        vocab_size: int,
    ):
        if isinstance(rng, list):
            batch_size = len(rng)
            g = np.stack(
                [get_gumbel_variables(rng[i], vocab_size)[2] for i in range(batch_size)]
            )
        else:
            g = get_gumbel_variables(rng, vocab_size)[2]

        return cls(g)


class DeltaGumbel_Reweight(AbstractReweight):
    watermark_code_type = DeltaGumbel_WatermarkCode

    def __repr__(self):
        return f"DeltaGumbel_Reweight()"

    def reweight(
        self, code: AbstractWatermarkCode, p_logits: np.ndarray
    ) -> np.ndarray:

        assert isinstance(code, DeltaGumbel_WatermarkCode)

        index = np.argmax(p_logits + code.g, axis=-1)

        mask = np.arange(p_logits.shape[-1]) == index[..., None]

        modified_logits = np.where(
            mask,
            np.full_like(p_logits, 0),
            np.full_like(p_logits, float("-inf")),
        )
        return modified_logits

    def get_la_score(self, code):
        """likelihood agnostic score"""
        return np.array(np.log(2)) - np.exp(-code.g)


class WatermarkProcessor:
    def __init__(
            self,
            private_key: any,
            reweight: AbstractReweight,
            context_code_length: int,
            ignore_history=False,
    ):

        self.private_key = private_key
        self.reweight = reweight
        self.cc_length = context_code_length
        self.ignore_history = ignore_history
        self.cc_history = set()

    def get_rng_seed(self, context_code: any) -> any:
        if not self.ignore_history:
            self.cc_history.add(context_code)
        import hashlib

        m = hashlib.sha256()
        m.update(context_code)
        m.update(self.private_key)
        full_hash = m.digest()
        seed = int.from_bytes(full_hash, "big") % (2 ** 32 - 1)
        return seed

    def reset_history(self):
        self.cc_history = set()

    def _get_codes(self, context, current_pos):
        batch_size = len(context)

        if current_pos == 0:
            context_codes = [
                context[i][-self.cc_length:].tobytes() for i in range(batch_size)
            ]

        else:
            cc_pos = current_pos - self.cc_length
            
            if cc_pos < 0:
                cc_pos = 0
            else:
                cc_pos = cc_pos
            
            context_codes = [
                context[i][cc_pos:current_pos][~np.isnan(context[i][cc_pos:current_pos])].tobytes() for i in range(batch_size)
            ]

        mask, seeds = zip(
            *[
                (context_code in self.cc_history, self.get_rng_seed(context_code))
                for context_code in context_codes
            ]
        )
        return mask, seeds

    def __call__(self, 
                 mode: str = 'normal',
                 context: np.ndarray = None,
                 logits: np.ndarray = None,
                 current_pos: int = None):

        if mode == 'normal':
            current_pos = 0
        elif mode == 'order_agnoistic':
            current_pos = current_pos
        else:
            raise NotImplementedError('Current watermark processor does not support this mode')

        mask, seeds = self._get_codes(context, current_pos=current_pos)

        rng = [
            np.random.default_rng(seed) for seed in seeds
        ]

        mask = np.array(mask)

        watermark_code = self.reweight.watermark_code_type.from_random(
            rng, logits.shape[1]
        )

        reweighted_logits = self.reweight.reweight(watermark_code, logits)

        if self.ignore_history:
            return reweighted_logits
        else:
            return np.where(mask[:, None], logits, reweighted_logits)


class WatermarkDetector:
    def __init__(
            self,
            private_key: any,
            reweight: AbstractReweight,
            context_code_length: int,
            vocab_size: int = 20,
            ignore_history: bool = False
    ):
        self.private_key = private_key
        self.cc_length = context_code_length
        self.vocab_size = vocab_size
        self.reweight = reweight
        self.ignore_history = ignore_history
        self.cc_history = set()

    def reset_history(self):
        self.cc_history = set()

    def get_rng_seed(self, context_code: any) -> any:
        if not self.ignore_history:
            self.cc_history.add(context_code)
        import hashlib

        m = hashlib.sha256()
        m.update(context_code)
        m.update(self.private_key)
        full_hash = m.digest()
        seed = int.from_bytes(full_hash, "big") % (2 ** 32 - 1)
        return seed

    def _get_codes(self, context):
        batch_size = len(context)

        context_codes = [
            context[i][-self.cc_length:].tobytes() for i in range(batch_size)
        ]

        mask, seeds = zip(
            *[
                (context_code in self.cc_history, self.get_rng_seed(context_code))
                for context_code in context_codes
            ]
        )
        return mask, seeds

    def detect(self,
               input_ids: np.ndarray):
        """
        :param input_ids: sequences after tokenization
        :return:
        """
        
        scores = []
        for i in range(input_ids.shape[1]):

            score = self.get_la_score(input_ids[:,:i], input_ids[:,i], self.vocab_size)
            scores.append(score)
            
        scores = np.array(scores) 
        return np.sum(scores, axis=0)

    def get_la_score(
        self,
        input_ids: np.ndarray,
        labels: np.ndarray,
        vocab_size: int,
    ) -> np.ndarray:
        assert "get_la_score" in dir(
            self.reweight
        ), "Reweight does not support likelihood agnostic detection"
        mask, seeds = self._get_codes(input_ids)
        rng = [
            np.random.default_rng(seed) for seed in seeds
        ]
        mask = np.array(mask)
        watermark_code = self.reweight.watermark_code_type.from_random(rng, vocab_size)
        all_scores = self.reweight.get_la_score(watermark_code)
        scores = all_scores[np.arange(all_scores.shape[0]), labels]
        scores = np.logical_not(mask) * scores
        return scores

In [3]:
input_id = np.array([[ 4, np.nan, np.nan, 15, 20, 18],
                     [ 5, np.nan, np.nan,  2,  0, 15],
                     [ 5, np.nan, np.nan, 21, 10, 17],
                     [ 6, np.nan, np.nan, 10, 15, 8],
                     [ 6, np.nan, np.nan, 9, 25, 9]])
logits = np.random.random((5, 20))

In [4]:
watermark_processor = WatermarkProcessor(private_key=b'private',
                                         reweight=DeltaGumbel_Reweight(),
                                         context_code_length=3,)

In [5]:
logits

array([[0.36316354, 0.59542443, 0.24230033, 0.63879487, 0.67896723,
        0.82720807, 0.88676025, 0.15824899, 0.00889607, 0.72545931,
        0.93644745, 0.44424959, 0.36290092, 0.08732449, 0.1200061 ,
        0.34865266, 0.99849804, 0.84025134, 0.36790185, 0.53772777],
       [0.95197296, 0.70819715, 0.07972865, 0.83036348, 0.61822959,
        0.16988685, 0.40061496, 0.68030451, 0.41058003, 0.26533909,
        0.86764167, 0.02757654, 0.61631643, 0.05568713, 0.02220829,
        0.70892512, 0.95568788, 0.99658983, 0.63780764, 0.6710573 ],
       [0.20165967, 0.39573435, 0.02946278, 0.2102426 , 0.01116232,
        0.45755235, 0.05935861, 0.44202719, 0.24094968, 0.67877363,
        0.73265155, 0.00550702, 0.32155552, 0.16857977, 0.09366355,
        0.2702031 , 0.09740327, 0.90906146, 0.67867739, 0.8113312 ],
       [0.50224427, 0.75140448, 0.14995572, 0.16796419, 0.1662573 ,
        0.42600325, 0.34660815, 0.99811681, 0.01562505, 0.38494506,
        0.73954064, 0.08743451, 0.93793213, 0

In [6]:
watermark_processor('order_agnoistic', input_id, logits, 1) # add watermark, two samples are not added

array([[      -inf,       -inf,       -inf,       -inf,       -inf,
              -inf, 0.        ,       -inf,       -inf,       -inf,
              -inf,       -inf,       -inf,       -inf,       -inf,
              -inf,       -inf,       -inf,       -inf,       -inf],
       [      -inf,       -inf,       -inf,       -inf,       -inf,
              -inf,       -inf,       -inf, 0.        ,       -inf,
              -inf,       -inf,       -inf,       -inf,       -inf,
              -inf,       -inf,       -inf,       -inf,       -inf],
       [0.20165967, 0.39573435, 0.02946278, 0.2102426 , 0.01116232,
        0.45755235, 0.05935861, 0.44202719, 0.24094968, 0.67877363,
        0.73265155, 0.00550702, 0.32155552, 0.16857977, 0.09366355,
        0.2702031 , 0.09740327, 0.90906146, 0.67867739, 0.8113312 ],
       [      -inf,       -inf,       -inf,       -inf, 0.        ,
              -inf,       -inf,       -inf,       -inf,       -inf,
              -inf,       -inf,       -inf,  

In [7]:
watermark_processor('order_agnoistic', input_id, logits, 2) # no watermark should be added

array([[0.36316354, 0.59542443, 0.24230033, 0.63879487, 0.67896723,
        0.82720807, 0.88676025, 0.15824899, 0.00889607, 0.72545931,
        0.93644745, 0.44424959, 0.36290092, 0.08732449, 0.1200061 ,
        0.34865266, 0.99849804, 0.84025134, 0.36790185, 0.53772777],
       [0.95197296, 0.70819715, 0.07972865, 0.83036348, 0.61822959,
        0.16988685, 0.40061496, 0.68030451, 0.41058003, 0.26533909,
        0.86764167, 0.02757654, 0.61631643, 0.05568713, 0.02220829,
        0.70892512, 0.95568788, 0.99658983, 0.63780764, 0.6710573 ],
       [0.20165967, 0.39573435, 0.02946278, 0.2102426 , 0.01116232,
        0.45755235, 0.05935861, 0.44202719, 0.24094968, 0.67877363,
        0.73265155, 0.00550702, 0.32155552, 0.16857977, 0.09366355,
        0.2702031 , 0.09740327, 0.90906146, 0.67867739, 0.8113312 ],
       [0.50224427, 0.75140448, 0.14995572, 0.16796419, 0.1662573 ,
        0.42600325, 0.34660815, 0.99811681, 0.01562505, 0.38494506,
        0.73954064, 0.08743451, 0.93793213, 0

In [8]:
watermark_processor('order_agnoistic', input_id, logits, 3) # no watermark should be added

array([[0.36316354, 0.59542443, 0.24230033, 0.63879487, 0.67896723,
        0.82720807, 0.88676025, 0.15824899, 0.00889607, 0.72545931,
        0.93644745, 0.44424959, 0.36290092, 0.08732449, 0.1200061 ,
        0.34865266, 0.99849804, 0.84025134, 0.36790185, 0.53772777],
       [0.95197296, 0.70819715, 0.07972865, 0.83036348, 0.61822959,
        0.16988685, 0.40061496, 0.68030451, 0.41058003, 0.26533909,
        0.86764167, 0.02757654, 0.61631643, 0.05568713, 0.02220829,
        0.70892512, 0.95568788, 0.99658983, 0.63780764, 0.6710573 ],
       [0.20165967, 0.39573435, 0.02946278, 0.2102426 , 0.01116232,
        0.45755235, 0.05935861, 0.44202719, 0.24094968, 0.67877363,
        0.73265155, 0.00550702, 0.32155552, 0.16857977, 0.09366355,
        0.2702031 , 0.09740327, 0.90906146, 0.67867739, 0.8113312 ],
       [0.50224427, 0.75140448, 0.14995572, 0.16796419, 0.1662573 ,
        0.42600325, 0.34660815, 0.99811681, 0.01562505, 0.38494506,
        0.73954064, 0.08743451, 0.93793213, 0

### Naive tokenizer construction

In [9]:
amino_acids = 'ACDEFGHIKLMNPQRSTVWYX'  # List of amino acids
# Create a dictionary mapping each amino acid to its index
aa_to_index = {aa: idx for idx, aa in enumerate(amino_acids)}
class AminoAcidTokenizer:
    def __init__(self, aa_to_index):
        self.aa_to_index = aa_to_index
        self.index_to_aa = {idx: aa for aa, idx in aa_to_index.items()}
        
    def encode(self, sequence):
        # Encode a sequence of amino acids to indices
        return [self.aa_to_index.get(aa, self.aa_to_index.get('X')) for aa in sequence]
        
    def decode(self, indices):
        # Decode a list of indices back into an amino acid sequence
        return ''.join(self.index_to_aa.get(idx, 'X') for idx in indices)
tokenizer = AminoAcidTokenizer(aa_to_index)
encoded_input = tokenizer.encode("MVDADTQKALDFI")
print(encoded_input)

[10, 17, 2, 0, 2, 16, 13, 8, 0, 9, 2, 4, 7]


### Read Tokenizer File to construct original tokenizer

In [10]:
from transformers import PreTrainedTokenizerFast
tokenizer = PreTrainedTokenizerFast(tokenizer_file="./tokenizer.json")
# generated with cmd line:
# python3 sample_wm.py --model progen2-large --t 0.8 --p 0.9 --max-length 100 --num-samples 1 --context "2" --rng-seed 0
encoded_input = tokenizer.encode("GGPPVVRLDEIRGRGIRVEFDAGGVVRLLTERLERAAAEPALRVPLALRASRGRREIRLQADRVGPLRALTERLVELSAPERVAVRLAPGDEGGVEVLPRVEALVLAVLDPDREVEVAVLAGKARSEISLVAGRRLLERRVASEPLPAAAVLAFLGEALREELLVPLDALRGLARRVAEGGGAELRLTLLSRARAEVVL")
# encoded_input = tokenizer.encode("2GGPPVVRLD")
print(encoded_input)

  from .autonotebook import tqdm as notebook_tqdm


[11, 11, 19, 19, 25, 25, 21, 15, 8, 9, 13, 21, 11, 21, 11, 13, 21, 25, 9, 10, 8, 5, 11, 11, 25, 25, 21, 15, 15, 23, 9, 21, 15, 9, 21, 5, 5, 5, 9, 19, 5, 15, 21, 25, 19, 15, 5, 15, 21, 5, 22, 21, 11, 21, 21, 9, 13, 21, 15, 20, 5, 8, 21, 25, 11, 19, 15, 21, 5, 15, 23, 9, 21, 15, 25, 9, 15, 22, 5, 19, 9, 21, 25, 5, 25, 21, 15, 5, 19, 11, 8, 9, 11, 11, 25, 9, 25, 15, 19, 21, 25, 9, 5, 15, 25, 15, 5, 25, 15, 8, 19, 8, 21, 9, 25, 9, 25, 5, 25, 15, 5, 11, 14, 5, 21, 22, 9, 13, 22, 15, 25, 5, 11, 21, 21, 15, 15, 9, 21, 21, 25, 5, 22, 9, 19, 15, 19, 5, 5, 5, 25, 15, 5, 10, 15, 11, 9, 5, 15, 21, 9, 9, 15, 15, 25, 19, 15, 8, 5, 15, 21, 11, 15, 5, 21, 21, 25, 5, 9, 11, 11, 11, 5, 9, 15, 21, 15, 23, 15, 15, 22, 21, 5, 21, 5, 9, 25, 25, 15]


In [11]:
len(tokenizer)

30

### Detection with Private Key

In [12]:
detector = WatermarkDetector(b"private key",
                             DeltaGumbel_Reweight(),
                             context_code_length=5,
                             vocab_size=30)

In [13]:
input_ids = np.array(encoded_input).reshape(1,-1)

In [14]:
detector.detect(input_ids)

array([63.62662829])

In [15]:
# np.log(2) - (1/30) # theoretically maximum expectation score

### Without watermark in the model, we can not detect

In [16]:
# generated with cmd line:
# python3 sample.py --model progen2-large --t 0.8 --p 0.9 --max-length 100 --num-samples 1 --context "2" --rng-seed 0
encoded_input = tokenizer.encode("DTAEEETDAALEADSLEAAVRKAEDIDLEGKGLAALLAEIEPFGSETALRELGTLMGEARSVALSSVTGHATKAVSGLKALTTAVSQAEAGASIYLPKQ")
print(encoded_input)

[8, 23, 5, 9, 9, 9, 23, 8, 5, 5, 15, 9, 5, 8, 22, 15, 9, 5, 5, 25, 21, 14, 5, 9, 8, 13, 8, 15, 9, 11, 14, 11, 15, 5, 5, 15, 15, 5, 9, 13, 9, 19, 10, 11, 22, 9, 23, 5, 15, 21, 9, 15, 11, 23, 15, 16, 11, 9, 5, 21, 22, 25, 5, 15, 22, 22, 25, 23, 11, 12, 5, 23, 14, 5, 25, 22, 11, 15, 14, 5, 15, 23, 23, 5, 25, 22, 20, 5, 9, 5, 11, 5, 22, 13, 28, 15, 19, 14, 20]


In [17]:
detector = WatermarkDetector(b"private key",
                             DeltaGumbel_Reweight(),
                             context_code_length=5,
                             vocab_size=30)
input_ids = np.array(encoded_input).reshape(1,-1)

In [18]:
detector.detect(input_ids)

array([-26.23357289])

### Without Correct Key, Score Drops

In [19]:
# generated with cmd line:
# python3 sample_wm.py --model progen2-large --t 0.8 --p 0.9 --max-length 100 --num-samples 1 --context "2" --rng-seed 0
encoded_input = tokenizer.encode("GGPPVVRLDEIRGRGIRVEFDAGGVVRLLTERLERAAAEPALRVPLALRASRGRREIRLQADRVGPLRALTERLVELSAPERVAVRLAPGDEGGVEVLPRVEALVLAVLDPDREVEVAVLAGKARSEISLVAGRRLLERRVASEPLPAAAVLAFLGEALREELLVPLDALRGLARRVAEGGGAELRLTLLSRARAEVVL")
input_ids = np.array(encoded_input).reshape(1,-1)
detector = WatermarkDetector(b"private keey",
                             DeltaGumbel_Reweight(),
                             context_code_length=5,
                             vocab_size=30)
detector.detect(input_ids)

array([-69.73725497])

### Watermark is Relatively Robust to Modification

In [20]:
from transformers import PreTrainedTokenizerFast
tokenizer = PreTrainedTokenizerFast(tokenizer_file="./tokenizer.json")
# select the head of the protein
encoded_input = tokenizer.encode("GGPPVVRLDEIRGRGIRVEFDAGGVVRLLTERLERAAAEPALRVPLAL")
input_ids = np.array(encoded_input).reshape(1,-1)
detector = WatermarkDetector(b"private key",
                             DeltaGumbel_Reweight(),
                             context_code_length=5,
                             vocab_size=30)
detector.detect(input_ids)

array([23.63254586])

In [21]:
# modify the sequence by changing A to G
encoded_input = tokenizer.encode("GGPPVVRLDEIRGRGIRVEFDGGGVVRLLTERLERGGGEPGLRVPLGL")
print(encoded_input)
input_ids = np.array(encoded_input).reshape(1,-1)
detector = WatermarkDetector(b"private key",
                             DeltaGumbel_Reweight(),
                             context_code_length=5,
                             vocab_size=30)
detector.detect(input_ids) # we still have some watermark in it!

[11, 11, 19, 19, 25, 25, 21, 15, 8, 9, 13, 21, 11, 21, 11, 13, 21, 25, 9, 10, 8, 11, 11, 11, 25, 25, 21, 15, 15, 23, 9, 21, 15, 9, 21, 11, 11, 11, 9, 19, 11, 15, 21, 25, 19, 15, 11, 15]


array([5.64354916])

### Detect multiple sequences at once

In [22]:
detector = WatermarkDetector(b"private key",
                             DeltaGumbel_Reweight(),
                             context_code_length=5,
                             vocab_size=30)
detector.detect(np.concatenate([input_ids, input_ids])) # we still have some watermark in it!


array([5.64354916, 0.        ])