In [1]:
import numpy as np
from abc import ABC, abstractmethod
from typing import Union

class AbstractWatermarkCode(ABC):
    """Interface for a watermark code, i.e. the randomness derived from an RNG.

    Concrete subclasses must provide :meth:`from_random`.
    """

    @classmethod
    @abstractmethod
    def from_random(
        cls,
        rng: Union[np.random.Generator, list[np.random.Generator]],
        vocab_size: int,
    ):
        """Build a code by drawing from ``rng``.

        :param rng: a single generator, or a list of generators
        :param vocab_size: number of entries in the vocabulary
        :return: this base implementation returns ``None``; subclasses
            return an instance of ``cls``
        """

class AbstractReweight(ABC):
    """Interface for a reweighting scheme driven by a watermark code."""

    # Concrete code class that instances of this reweight operate on.
    watermark_code_type: type[AbstractWatermarkCode]

    @abstractmethod
    def reweight(
        self,
        code: AbstractWatermarkCode,
        p: np.ndarray,
    ) -> np.ndarray:
        """Return a modified version of ``p`` according to ``code``.

        :param code: watermark code to apply
        :param p: array to be reweighted (its interpretation is defined by
            the concrete subclass)
        :return: array produced by the concrete subclass
        """

    @abstractmethod
    def get_la_score(
        self,
        code: AbstractWatermarkCode,
    ) -> np.ndarray:
        """Return the likelihood-agnostic score array derived from ``code``."""

def get_gumbel_variables(rng: np.random.Generator,
                         vocab_size: int):
    """Draw one uniform sample per vocabulary entry and derive two transforms.

    :param rng: generator the uniform samples are drawn from
    :param vocab_size: number of samples to draw
    :return: tuple ``(u, e, g)`` of arrays with shape ``(vocab_size,)`` where
        ``e = -log(u)`` and ``g = -log(e)``
    """
    uniform = rng.random((vocab_size,))  # ~ Uniform(0, 1)
    exponential = -np.log(uniform)       # ~ Exp(1)
    gumbel = -np.log(exponential)        # ~ Gumbel(0, 1)
    return uniform, exponential, gumbel


class DeltaGumbel_WatermarkCode(AbstractWatermarkCode):
    """Watermark code that stores the third output of ``get_gumbel_variables``."""

    def __init__(self, g: np.ndarray):
        # Array returned as ``g`` by get_gumbel_variables (stacked for a list of RNGs).
        self.g = g

    @classmethod
    def from_random(
            cls,
            rng: Union[np.random.Generator, list[np.random.Generator]],
            vocab_size: int,
    ):
        """Create a code from one generator or from a list of generators.

        :param rng: a single generator gives ``g`` of shape ``(vocab_size,)``;
            a list of generators gives ``g`` of shape ``(len(rng), vocab_size)``
        :param vocab_size: number of values drawn per generator
        :return: a new instance of ``cls``
        """
        if not isinstance(rng, list):
            _, _, noise = get_gumbel_variables(rng, vocab_size)
            return cls(noise)

        # One row per generator, stacked along a new leading axis.
        per_generator = [
            get_gumbel_variables(generator, vocab_size)[2] for generator in rng
        ]
        return cls(np.stack(per_generator))


class DeltaGumbel_Reweight(AbstractReweight):
    """Reweight that keeps only the entry maximising ``p_logits + code.g``."""

    watermark_code_type = DeltaGumbel_WatermarkCode

    def __repr__(self):
        return "DeltaGumbel_Reweight()"

    def reweight(
            self, code: AbstractWatermarkCode, p_logits: np.ndarray
    ) -> np.ndarray:
        """Return logits that are 0 at the selected entry and -inf elsewhere.

        :param code: must be a ``DeltaGumbel_WatermarkCode``
        :param p_logits: logits with the vocabulary on the last axis
        :return: array built from ``p_logits``' dtype, 0 at the argmax of
            ``p_logits + code.g`` along the last axis and ``-inf`` elsewhere
        """
        assert isinstance(code, DeltaGumbel_WatermarkCode)

        winner = np.argmax(p_logits + code.g, axis=-1)

        # True exactly where the vocabulary position equals the winning index.
        vocab_positions = np.arange(p_logits.shape[-1])
        is_winner = vocab_positions == winner[..., None]

        kept = np.zeros_like(p_logits)
        dropped = np.full_like(p_logits, float("-inf"))
        return np.where(is_winner, kept, dropped)

    def get_la_score(self, code):
        """likelihood agnostic score: ``log(2) - exp(-code.g)``, elementwise"""
        log_two = np.array(np.log(2))
        return log_two - np.exp(-code.g)

class WatermarkDetector:
    """Likelihood-agnostic watermark detector.

    For every position of a token sequence the detector hashes the preceding
    ``context_code_length`` tokens together with ``private_key`` into an RNG
    seed, regenerates the watermark code from that seed and reads the score of
    the token that actually appears. The per-token scores are combined into
    one sequence-level score and p-value by :meth:`detect`.

    NOTE: unless ``ignore_history`` is set, every context code seen is stored
    in ``cc_history`` and repeated contexts are scored as 0. The history is
    kept across calls to :meth:`detect`; call :meth:`reset_history` (or use a
    fresh detector) before scoring an unrelated sequence.

    :param private_key: bytes mixed into every seed hash
    :param reweight: reweight object exposing ``watermark_code_type`` and
        ``get_la_score``
    :param context_code_length: number of trailing context tokens hashed
    :param vocab_size: vocabulary size used when regenerating codes; must be
        larger than every token id passed to :meth:`detect`
    :param ignore_history: when True, context codes are never recorded and
        therefore never masked
    """

    def __init__(
            self,
            private_key: bytes,
            reweight: "AbstractReweight",
            context_code_length: int,
            vocab_size: int = 20,
            ignore_history: bool = False
    ):
        self.private_key = private_key
        self.cc_length = context_code_length
        self.vocab_size = vocab_size
        self.reweight = reweight
        self.ignore_history = ignore_history
        self.cc_history = set()

    def reset_history(self):
        """Forget every context code recorded so far."""
        self.cc_history = set()

    def get_rng_seed(self, context_code: bytes) -> int:
        """Derive an RNG seed from ``context_code`` and the private key.

        Side effect: records ``context_code`` in ``cc_history`` unless
        ``ignore_history`` is set.

        :param context_code: bytes identifying the context
        :return: SHA-256 of ``context_code + private_key`` reduced modulo
            ``2**32 - 1``
        """
        import hashlib

        if not self.ignore_history:
            self.cc_history.add(context_code)

        m = hashlib.sha256()
        m.update(context_code)
        m.update(self.private_key)
        # The modulus is kept as-is: changing it would change every seed.
        return int.from_bytes(m.digest(), "big") % (2 ** 32 - 1)

    def _get_codes(self, context):
        """Return ``(mask, seeds)`` for each row of ``context``.

        ``mask[i]`` is True when row ``i``'s context code had already been
        recorded before this row was processed; ``seeds[i]`` is the RNG seed
        for that row. Both are tuples and are empty for an empty batch.
        """
        masks, seeds = [], []
        for row in context:
            context_code = row[-self.cc_length:].tobytes()
            # Membership must be tested BEFORE get_rng_seed records the code.
            masks.append(context_code in self.cc_history)
            seeds.append(self.get_rng_seed(context_code))
        return tuple(masks), tuple(seeds)

    @staticmethod
    def _avg_score(Ubar, lamb):
        """Per-token objective ``Ubar*lamb + log(lamb / (exp(lamb) - 1))``.

        The log term tends to 0 as ``lamb -> 0``; that limit is returned
        explicitly so the optimiser never sees ``log(0/0) = nan`` on its
        lower bound.
        """
        lamb = np.asarray(lamb, dtype=float)
        at_zero = lamb == 0
        safe = np.where(at_zero, 1.0, lamb)
        log_term = np.where(at_zero, 0.0, np.log(safe / np.expm1(safe)))
        return Ubar * lamb + log_term

    def detect(self,
               input_ids: np.ndarray):
        """
        Score a tokenized batch for the presence of the watermark.

        With ``t_i = score_i - log(2)`` and ``u_i = exp(t_i)``, the mean
        ``Ubar`` of all ``u_i`` is plugged into ``_avg_score`` which is
        maximised over ``lamb`` in ``[0, 10]``; the maximum times the sequence
        length is the score and ``exp(-score)`` the p-value. Prints ``Ubar``.

        :param input_ids: sequences after tokenization, shape (batch, length)
        :return: ``(final_score, final_p_value)``; a higher score means a seq
            is likely to be watermarked. ``(0, 1)`` is returned when there is
            nothing to score or when ``Ubar`` is 0.
        """
        import scipy.optimize

        seq_len = input_ids.shape[1]
        if input_ids.size == 0:
            # No tokens, no evidence.
            return 0, 1

        log_two = np.log(2)
        tis = np.array([
            self.get_la_score(input_ids[:, :i], input_ids[:, i], self.vocab_size) - log_two
            for i in range(seq_len)
        ])
        # Every t_i (not just the last one) has to be non-positive.
        assert np.all(tis <= 0), tis

        uis = np.exp(tis)
        Ubar = np.mean(uis)
        print("Ubar", Ubar)

        if Ubar == 0:
            return 0, 1

        sol = scipy.optimize.minimize(
            lambda l: -self._avg_score(Ubar, l), 0.5, bounds=[(0, 10)]
        )
        raw_score = np.asarray(-sol.fun * seq_len).reshape(-1)[0]
        # The objective is 0 at lamb = 0, so the true maximum is never
        # negative; clamp optimiser noise so the p-value cannot exceed 1.
        final_score = max(0.0, float(raw_score))
        final_p_value = np.exp(-final_score)
        return final_score, final_p_value

    def get_la_score(
            self,
            input_ids: np.ndarray,
            labels: np.ndarray,
            vocab_size: int,
    ) -> np.ndarray:
        """Likelihood-agnostic score of ``labels`` given the context ``input_ids``.

        :param input_ids: context tokens, one row per batch item
        :param labels: token id to score for each batch item
        :param vocab_size: vocabulary size used to regenerate the codes
        :return: one score per batch item; 0 where the context code had
            already been seen
        """
        assert hasattr(
            self.reweight, "get_la_score"
        ), "Reweight does not support likelihood agnostic detection"

        mask, seeds = self._get_codes(input_ids)
        rng = [np.random.default_rng(seed) for seed in seeds]

        watermark_code = self.reweight.watermark_code_type.from_random(rng, vocab_size)
        all_scores = self.reweight.get_la_score(watermark_code)
        scores = all_scores[np.arange(all_scores.shape[0]), labels]

        # Zero out positions whose context was already in the history.
        return np.logical_not(np.array(mask)) * scores

In [2]:
amino_acids = 'ACDEFGHIKLMNPQRSTVWYX'  # One-letter codes; 'X' (last) is also the fallback symbol
# Create a dictionary mapping each amino acid to its index
aa_to_index = {aa: idx for idx, aa in enumerate(amino_acids)}
class AminoAcidTokenizer:
    """Character-level tokenizer driven by a symbol -> index mapping.

    :param aa_to_index: mapping from one-character symbol to integer index;
        the index of ``'X'`` is used for symbols missing from the mapping
    """

    def __init__(self, aa_to_index):
        self.aa_to_index = aa_to_index
        self.index_to_aa = {idx: aa for aa, idx in aa_to_index.items()}

    def encode(self, sequence):
        """Encode a sequence of symbols to an integer array of shape ``(1, len)``.

        Unknown symbols are mapped to the index of ``'X'``.
        """
        unknown = self.aa_to_index.get('X')
        # dtype=int matches what numpy infers for a non-empty list of ints and
        # keeps an empty sequence from turning into a float array.
        return np.array(
            [self.aa_to_index.get(aa, unknown) for aa in sequence], dtype=int
        ).reshape(1, -1)

    def decode(self, indices):
        """Decode indices back into a string of symbols.

        Accepts a flat iterable of indices as well as the ``(1, n)`` array
        returned by :meth:`encode`; the input is flattened before lookup.
        Indices without a symbol decode to ``'X'``.
        """
        flat = np.asarray(list(indices)).reshape(-1).tolist()
        return ''.join(self.index_to_aa.get(idx, 'X') for idx in flat)
tokenizer = AminoAcidTokenizer(aa_to_index)

In [3]:
# Three sequences that are tokenized and passed to the detector in later cells.
# NOTE(review): "7ZEE" is presumably a PDB identifier and "t" a sampling
# temperature — confirm the provenance of these designs.
# Each sequence contains a run of 'X', which the tokenizer maps to index 20.
# 7ZEE design t=0.1
seq_1 = "MSLEEETATLDHPNVRIADPARVAEILAALRAGGADALRVVSDFDGTLSLVTKDGVPQPSLDDVLYNSPYISEEAKAKLDALDAEYTPIFNDPNLTVEQKLPYAKEYKTKKLEILTTENIKKSQIKEAVEKSGVKLREGAKRFFTLLEEHGVPLVIFSDGIGDIVEELIKSNNLLYPNIKIVANFFKYDENGNLVGFEGKLVTRFNKNATLEXXXXXXXXXXSHTHVILLGDSLSDINMTDGLPGITTELRIGFLNSDIEKNLEKFLATFDIVLVQDESLYVVNGILEEILG"
# 7ZEE design t=0.3
seq_2 = "LPVSAEKASLEHPHVRIADPARVAAILAALRAGGADALRVVSDFDGTLSLAKKNGVPQPSLNDVLKNSDVVSDEAKAKLKEIDEKYLPILNDPNLSKEEKLPYAKEYTTEKLEILKTENIKKSQIKEVVEKSGVKLREGAKRFFTLLEEHGVPLVIFSSGIGDIAEELIKSNNLLHSNITIVANFFKYDENGNLVGFEGKLVNKLNKNAKNLXXXXXXXXXXAHTHVILLGDSLSDIEMTEGLPGVTTELRIGFLNDSIEEKLEEFLARFDIVLVDDESLFVVNGILDDVLG"
# 7ZEE design t=0.5
seq_3 = "LPVSAEKASLEHPHVRIADPKRVADILEQLREGGSDRLAIVSDFDRTLSASFKDGVPQPSMDDVLKNSDVVSDEAKAEFAKLDAEYTPIFDDPNLTVAEKIPFAQKYYAEKLAILTKEEIKESQIAEMVRKSNVRLKEGAKRFFNLANEHKIPLYIFSAGIGDIKKELIRENGLYHDNIHLISNFFKFNEEGKLVGFEGALVTRFNKNMNNYXXXXXXXXXXNRTHVILIGDSLDDLDMHKGMEGITTLLSIGFLRSDIETNLKKFLDSFDIVLVQDESLYVVNGILDYITG"

In [4]:
# Reference token ids for seq_1, stored as a single-row nested list.
# NOTE(review): presumably captured from the pipeline that produced the
# sequence — confirm where these ids come from.
original_1 = [[10, 15, 9, 3, 3, 3, 16, 0, 16, 9, 2, 6, 12, 11, 17, 14, 7, 0, 2, 12, 0, 14, 17, 0,
             3, 7, 9, 0, 0, 9, 14, 0, 5, 5, 0, 2, 0, 9, 14, 17, 17, 15, 2, 4, 2, 5, 16, 9,
             15, 9, 17, 16, 8, 2, 5, 17, 12, 13, 12, 15, 9, 2, 2, 17, 9, 19, 11, 15, 12, 19,
             7, 15, 3, 3, 0, 8, 0, 8, 9, 2, 0, 9, 2, 0, 3, 19, 16, 12, 7, 4, 11, 2, 12, 11,
             9, 16, 17, 3, 13, 8, 9, 12, 19, 0, 8, 3, 19, 8, 16, 8, 8, 9, 3, 7, 9, 16, 16,
             3, 11, 7, 8, 8, 15, 13, 7, 8, 3, 0, 17, 3, 8, 15, 5, 17, 8, 9, 14, 3, 5, 0, 8,
             14, 4, 4, 16, 9, 9, 3, 3, 6, 5, 17, 12, 9, 17, 7, 4, 15, 2, 5, 7, 5, 2, 7, 17,
             3, 3, 9, 7, 8, 15, 11, 11, 9, 9, 19, 12, 11, 7, 8, 7, 17, 0, 11, 4, 4, 8, 19,
             2, 3, 11, 5, 11, 9, 17, 5, 4, 3, 5, 8, 9, 17, 16, 14, 4, 11, 8, 11, 0, 16, 9,
             3, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 15, 6, 16, 6, 17, 7, 9, 9, 5, 2, 15,
             9, 15, 2, 7, 11, 10, 16, 2, 5, 9, 12, 5, 7, 16, 16, 3, 9, 14, 7, 5, 4, 9, 11,
             15, 2, 7, 3, 8, 11, 9, 3, 8, 4, 9, 0, 16, 4, 2, 7, 17, 9, 17, 13, 2, 3, 15, 9,
             19, 17, 17, 11, 5, 7, 9, 3, 3, 7, 9, 5]]

# Convert the list to a NumPy array
original_1 = np.array(original_1)
# Regression check: prints True when tokenizer.encode(seq_1) matches the reference ids.
print(np.allclose(tokenizer.encode(seq_1), original_1))

True


In [5]:
# Reference token ids for seq_2, stored as a single-row nested list.
# NOTE(review): presumably captured from the pipeline that produced the
# sequence — confirm where these ids come from.
original_2 = [[9, 12, 17, 15, 0, 3, 8, 0, 15, 9, 3, 6, 12, 6, 17, 14, 7, 0, 2, 12, 0, 14, 17, 0,
            0, 7, 9, 0, 0, 9, 14, 0, 5, 5, 0, 2, 0, 9, 14, 17, 17, 15, 2, 4, 2, 5, 16, 9,
            15, 9, 0, 8, 8, 11, 5, 17, 12, 13, 12, 15, 9, 11, 2, 17, 9, 8, 11, 15, 2, 17, 17,
            15, 2, 3, 0, 8, 0, 8, 9, 8, 3, 7, 2, 3, 8, 19, 9, 12, 7, 9, 11, 2, 12, 11, 9, 15,
            8, 3, 3, 8, 9, 12, 19, 0, 8, 3, 19, 16, 16, 3, 8, 9, 3, 7, 9, 8, 16, 3, 11, 7,
            8, 8, 15, 13, 7, 8, 3, 17, 17, 3, 8, 15, 5, 17, 8, 9, 14, 3, 5, 0, 8, 14, 4, 4,
            16, 9, 9, 3, 3, 6, 5, 17, 12, 9, 17, 7, 4, 15, 15, 5, 7, 5, 2, 7, 0, 3, 3, 9,
            7, 8, 15, 11, 11, 9, 9, 6, 15, 11, 7, 16, 7, 17, 0, 11, 4, 4, 8, 19, 2, 3, 11,
            5, 11, 9, 17, 5, 4, 3, 5, 8, 9, 17, 11, 8, 9, 11, 8, 11, 0, 8, 11, 9, 20, 20, 20,
            20, 20, 20, 20, 20, 20, 20, 0, 6, 16, 6, 17, 7, 9, 9, 5, 2, 15, 9, 15, 2, 7, 3,
            10, 16, 3, 5, 9, 12, 5, 17, 16, 16, 3, 9, 14, 7, 5, 4, 9, 11, 2, 15, 7, 3, 3,
            8, 9, 3, 3, 4, 9, 0, 14, 4, 2, 7, 17, 9, 17, 2, 2, 3, 15, 9, 4, 17, 17, 11, 5,
            7, 9, 2, 2, 17, 9, 5]]

# Convert the list to a NumPy array
original_2 = np.array(original_2)
# Regression check: prints True when tokenizer.encode(seq_2) matches the reference ids.
print(np.allclose(tokenizer.encode(seq_2), original_2))

True


In [6]:
# Reference token ids for seq_3, stored as a single-row nested list.
# NOTE(review): presumably captured from the pipeline that produced the
# sequence — confirm where these ids come from.
original_3 = [[9, 12, 17, 15, 0, 3, 8, 0, 15, 9, 3, 6, 12, 6, 17, 14, 7, 0, 2, 12, 8, 14, 17, 0,
            2, 7, 9, 3, 13, 9, 14, 3, 5, 5, 15, 2, 14, 9, 0, 7, 17, 15, 2, 4, 2, 14, 16, 9,
            15, 0, 15, 4, 8, 2, 5, 17, 12, 13, 12, 15, 10, 2, 2, 17, 9, 8, 11, 15, 2, 17, 17,
            15, 2, 3, 0, 8, 0, 3, 4, 0, 8, 9, 2, 0, 3, 19, 16, 12, 7, 4, 2, 2, 12, 11, 9, 16,
            17, 0, 3, 8, 7, 12, 4, 0, 13, 8, 19, 19, 0, 3, 8, 9, 0, 7, 9, 16, 8, 3, 3, 7, 8, 3,
            15, 13, 7, 0, 3, 10, 17, 14, 8, 15, 11, 17, 14, 9, 8, 3, 5, 0, 8, 14, 4, 4, 11, 9, 0,
            11, 3, 6, 8, 7, 12, 9, 19, 7, 4, 15, 0, 5, 7, 5, 2, 7, 8, 8, 3, 9, 7, 14, 3, 11, 5, 9,
            19, 6, 2, 11, 7, 6, 9, 7, 15, 11, 4, 4, 8, 4, 11, 3, 3, 5, 8, 9, 17, 5, 4, 3, 5, 0, 9,
            17, 16, 14, 4, 11, 8, 11, 10, 11, 11, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 11,
            14, 16, 6, 17, 7, 9, 7, 5, 2, 15, 9, 2, 2, 9, 2, 10, 6, 8, 5, 10, 3, 5, 7, 16, 16, 9,
            9, 15, 7, 5, 4, 9, 14, 15, 2, 7, 3, 16, 11, 9, 8, 8, 4, 9, 2, 15, 4, 2, 7, 17, 9, 17,
            13, 2, 3, 15, 9, 19, 17, 17, 11, 5, 7, 9, 2, 19, 7, 16, 5]]

# Convert the list to a NumPy array
original_3 = np.array(original_3)
# Regression check: prints True when tokenizer.encode(seq_3) matches the reference ids.
print(np.allclose(tokenizer.encode(seq_3), original_3))

True


In [7]:
# A fresh detector is built for each sequence because the context-code history
# lives on the instance and would otherwise carry over from a previous detect().
# vocab_size must exceed the largest token id: 'X' encodes to 20, so the class
# default of 20 would index out of range.
detector = WatermarkDetector(
    private_key=b"private key",
    reweight=DeltaGumbel_Reweight(),
    context_code_length=5,
    vocab_size=30,
)
# Last expression of the cell: displays (score, p-value) for seq_1.
detector.detect(tokenizer.encode(seq_1))

Ubar 0.5067641594786917


(0.08016515144748818, 0.9229639049741365)

In [8]:
# A fresh detector is built for each sequence because the context-code history
# lives on the instance and would otherwise carry over from a previous detect().
# vocab_size must exceed the largest token id: 'X' encodes to 20, so the class
# default of 20 would index out of range.
detector = WatermarkDetector(
    private_key=b"private key",
    reweight=DeltaGumbel_Reweight(),
    context_code_length=5,
    vocab_size=30,
)
# Last expression of the cell: displays (score, p-value) for seq_2.
detector.detect(tokenizer.encode(seq_2))

Ubar 0.5747777000532985


(9.86334274150776, 5.204807529922428e-05)

In [9]:
# A fresh detector is built for each sequence because the context-code history
# lives on the instance and would otherwise carry over from a previous detect().
# vocab_size must exceed the largest token id: 'X' encodes to 20, so the class
# default of 20 would index out of range.
detector = WatermarkDetector(
    private_key=b"private key",
    reweight=DeltaGumbel_Reweight(),
    context_code_length=5,
    vocab_size=30,
)
# Last expression of the cell: displays (score, p-value) for seq_3.
detector.detect(tokenizer.encode(seq_3))

Ubar 0.6532691784403173


(42.391178039335486, 3.8881704573624797e-19)