In [1]:
import numpy as np

In [2]:
from typing import Union
import numpy as np

from abc import ABC, abstractmethod


class AbstractWatermarkCode(ABC):
    @classmethod
    @abstractmethod
    def from_random(cls,
                    rng: Union[np.random.Generator, list[np.random.Generator]],
                    vocab_size: int):
        pass


class AbstractReweight(ABC):
    watermark_code_type: type[AbstractWatermarkCode]

    @abstractmethod
    def reweight(self,
                 code: AbstractWatermarkCode,
                 p: np.ndarray) -> np.ndarray:
        pass


def get_gumbel_variables(rng: np.random.Generator,
                         vocab_size: int):
    u = rng.random((vocab_size,))  # ~ Unifom(0, 1)
    e = -np.log(u)  # ~ Exp(1)
    g = -np.log(e)  # ~ Gumbel(0, 1)
    return u, e, g


class DeltaGumbel_WatermarkCode(AbstractWatermarkCode):
    def __init__(self, g: np.ndarray):
        self.g = g

    @classmethod
    def from_random(
        cls,
        rng: Union[np.random.Generator, list[np.random.Generator]],
        vocab_size: int,
    ):
        if isinstance(rng, list):
            batch_size = len(rng)
            g = np.stack(
                [get_gumbel_variables(rng[i], vocab_size)[2] for i in range(batch_size)]
            )
        else:
            g = get_gumbel_variables(rng, vocab_size)[2]

        return cls(g)


class DeltaGumbel_Reweight(AbstractReweight):
    watermark_code_type = DeltaGumbel_WatermarkCode

    def __repr__(self):
        return f"DeltaGumbel_Reweight()"

    def reweight(
        self, code: AbstractWatermarkCode, p_logits: np.ndarray
    ) -> np.ndarray:

        assert isinstance(code, DeltaGumbel_WatermarkCode)

        index = np.argmax(p_logits + code.g, axis=-1)

        mask = np.arange(p_logits.shape[-1]) == index[..., None]

        modified_logits = np.where(
            mask,
            np.full_like(p_logits, 0),
            np.full_like(p_logits, float("-inf")),
        )
        return modified_logits

    def get_la_score(self, code):
        """likelihood agnostic score"""
        return np.array(np.log(2)) - np.exp(-code.g)


class WatermarkProcessor:
    def __init__(
            self,
            private_key: any,
            reweight: AbstractReweight,
            context_code_length: int,
            ignore_history=False,
    ):

        self.private_key = private_key
        self.reweight = reweight
        self.cc_length = context_code_length
        self.ignore_history = ignore_history
        self.cc_history = set()

    def get_rng_seed(self, context_code: any) -> any:
        if not self.ignore_history:
            self.cc_history.add(context_code)
        import hashlib

        m = hashlib.sha256()
        m.update(context_code)
        m.update(self.private_key)
        full_hash = m.digest()
        seed = int.from_bytes(full_hash, "big") % (2 ** 32 - 1)
        return seed

    def reset_history(self):
        self.cc_history = set()

    def _get_codes(self, context, current_pos):
        batch_size = len(context)

        if current_pos == 0:
            context_codes = [
                context[i][-self.cc_length:].tobytes() for i in range(batch_size)
            ]

        else:
            cc_pos = current_pos - self.cc_length
            
            if cc_pos < 0:
                cc_pos = 0
            else:
                cc_pos = cc_pos
            
            context_codes = [
                context[i][cc_pos:current_pos][~np.isnan(context[i][cc_pos:current_pos])].tobytes() for i in range(batch_size)
            ]

        mask, seeds = zip(
            *[
                (context_code in self.cc_history, self.get_rng_seed(context_code))
                for context_code in context_codes
            ]
        )
        return mask, seeds

    def __call__(self, 
                 mode: str = 'normal',
                 context: np.ndarray = None,
                 logits: np.ndarray = None,
                 current_pos: int = None):

        if mode == 'normal':
            current_pos = 0
        elif mode == 'order_agnoistic':
            current_pos = current_pos
        else:
            raise NotImplementedError('Current watermark processor does not support this mode')

        mask, seeds = self._get_codes(context, current_pos=current_pos)

        rng = [
            np.random.default_rng(seed) for seed in seeds
        ]

        mask = np.array(mask)

        watermark_code = self.reweight.watermark_code_type.from_random(
            rng, logits.shape[1]
        )

        reweighted_logits = self.reweight.reweight(watermark_code, logits)

        if self.ignore_history:
            return reweighted_logits
        else:
            return np.where(mask[:, None], logits, reweighted_logits)


In [3]:
input_id = np.array([[ 4, np.nan, np.nan, 15, 20, 18],
                     [ 5, np.nan, np.nan,  2,  0, 15],
                     [ 5, np.nan, np.nan, 21, 10, 17],
                     [ 6, np.nan, np.nan, 10, 15, 8],
                     [ 6, np.nan, np.nan, 9, 25, 9]])
logits = np.random.random((5, 20))

In [4]:
watermark_processor = WatermarkProcessor(private_key=b'private',
                                         reweight=DeltaGumbel_Reweight(),
                                         context_code_length=3,)

In [5]:
watermark_processor('order_agnoistic', input_id, logits, 1) # add watermark, two samples are not added

array([[      -inf,       -inf,       -inf,       -inf,       -inf,
              -inf,       -inf,       -inf,       -inf,       -inf,
              -inf,       -inf, 0.        ,       -inf,       -inf,
              -inf,       -inf,       -inf,       -inf,       -inf],
       [      -inf,       -inf,       -inf,       -inf,       -inf,
              -inf,       -inf,       -inf, 0.        ,       -inf,
              -inf,       -inf,       -inf,       -inf,       -inf,
              -inf,       -inf,       -inf,       -inf,       -inf],
       [0.93234704, 0.07242028, 0.51555684, 0.60030021, 0.74821633,
        0.90119845, 0.1461733 , 0.6156435 , 0.17875809, 0.57320531,
        0.40367377, 0.65405878, 0.40946104, 0.38813083, 0.44697785,
        0.75028194, 0.4269854 , 0.07208373, 0.05735176, 0.18283482],
       [      -inf,       -inf,       -inf,       -inf, 0.        ,
              -inf,       -inf,       -inf,       -inf,       -inf,
              -inf,       -inf,       -inf,  

In [6]:
watermark_processor('order_agnoistic', input_id, logits, 2) # no watermark should be added

array([[0.04635725, 0.59822131, 0.83203407, 0.60843493, 0.77099384,
        0.92256869, 0.88273538, 0.26352314, 0.91150715, 0.62588056,
        0.53473634, 0.93476759, 0.90233474, 0.10348885, 0.02519254,
        0.46255228, 0.58711394, 0.76478158, 0.12957853, 0.88370082],
       [0.54289593, 0.35740699, 0.54748061, 0.28296673, 0.52860427,
        0.56110425, 0.09788825, 0.94614535, 0.53121644, 0.95074177,
        0.0382386 , 0.26421784, 0.86080055, 0.26335671, 0.34633426,
        0.89990684, 0.18708494, 0.99547631, 0.25116478, 0.02393811],
       [0.93234704, 0.07242028, 0.51555684, 0.60030021, 0.74821633,
        0.90119845, 0.1461733 , 0.6156435 , 0.17875809, 0.57320531,
        0.40367377, 0.65405878, 0.40946104, 0.38813083, 0.44697785,
        0.75028194, 0.4269854 , 0.07208373, 0.05735176, 0.18283482],
       [0.55259268, 0.31150994, 0.14288953, 0.15248225, 0.96405439,
        0.17515568, 0.91402861, 0.91575409, 0.68602625, 0.21078319,
        0.74252832, 0.85478058, 0.03305952, 0

In [7]:
watermark_processor('order_agnoistic', input_id, logits, 3) # no watermark should be added

array([[0.04635725, 0.59822131, 0.83203407, 0.60843493, 0.77099384,
        0.92256869, 0.88273538, 0.26352314, 0.91150715, 0.62588056,
        0.53473634, 0.93476759, 0.90233474, 0.10348885, 0.02519254,
        0.46255228, 0.58711394, 0.76478158, 0.12957853, 0.88370082],
       [0.54289593, 0.35740699, 0.54748061, 0.28296673, 0.52860427,
        0.56110425, 0.09788825, 0.94614535, 0.53121644, 0.95074177,
        0.0382386 , 0.26421784, 0.86080055, 0.26335671, 0.34633426,
        0.89990684, 0.18708494, 0.99547631, 0.25116478, 0.02393811],
       [0.93234704, 0.07242028, 0.51555684, 0.60030021, 0.74821633,
        0.90119845, 0.1461733 , 0.6156435 , 0.17875809, 0.57320531,
        0.40367377, 0.65405878, 0.40946104, 0.38813083, 0.44697785,
        0.75028194, 0.4269854 , 0.07208373, 0.05735176, 0.18283482],
       [0.55259268, 0.31150994, 0.14288953, 0.15248225, 0.96405439,
        0.17515568, 0.91402861, 0.91575409, 0.68602625, 0.21078319,
        0.74252832, 0.85478058, 0.03305952, 0