In [51]:
import numpy as np
from abc import ABC, abstractmethod
from typing import Union

class AbstractWatermarkCode(ABC):
    @classmethod
    @abstractmethod
    def from_random(cls,
                    rng: Union[np.random.Generator, list[np.random.Generator]],
                    vocab_size: int):
        pass

class AbstractReweight(ABC):
    watermark_code_type: type[AbstractWatermarkCode]

    @abstractmethod
    def reweight(self,
                 code: AbstractWatermarkCode,
                 p: np.ndarray) -> np.ndarray:
        pass

    @abstractmethod
    def get_la_score(self,
                     code: AbstractWatermarkCode) -> np.ndarray:
        pass

def get_gumbel_variables(rng: np.random.Generator,
                         vocab_size: int):
    u = rng.random((vocab_size,))  # ~ Unifom(0, 1)
    e = -np.log(u)  # ~ Exp(1)
    g = -np.log(e)  # ~ Gumbel(0, 1)
    return u, e, g


class DeltaGumbel_WatermarkCode(AbstractWatermarkCode):
    def __init__(self, g: np.ndarray):
        self.g = g

    @classmethod
    def from_random(
            cls,
            rng: Union[np.random.Generator, list[np.random.Generator]],
            vocab_size: int,
    ):
        if isinstance(rng, list):
            batch_size = len(rng)
            g = np.stack(
                [get_gumbel_variables(rng[i], vocab_size)[2] for i in range(batch_size)]
            )
        else:
            g = get_gumbel_variables(rng, vocab_size)[2]

        return cls(g)


class DeltaGumbel_Reweight(AbstractReweight):
    watermark_code_type = DeltaGumbel_WatermarkCode

    def __repr__(self):
        return f"DeltaGumbel_Reweight()"

    def reweight(
            self, code: AbstractWatermarkCode, p_logits: np.ndarray
    ) -> np.ndarray:
        assert isinstance(code, DeltaGumbel_WatermarkCode)

        index = np.argmax(p_logits + code.g, axis=-1)

        mask = np.arange(p_logits.shape[-1]) == index[..., None]

        modified_logits = np.where(
            mask,
            np.full_like(p_logits, 0),
            np.full_like(p_logits, float("-inf")),
        )
        return modified_logits

    def get_la_score(self, code):
        """likelihood agnostic score"""
        return np.array(np.log(2)) - np.exp(-code.g)

class WatermarkDetector:
    def __init__(
            self,
            private_key: any,
            reweight: AbstractReweight,
            context_code_length: int,
            vocab_size: int = 20,
            ignore_history: bool = False
    ):
        self.private_key = private_key
        self.cc_length = context_code_length
        self.vocab_size = vocab_size
        self.reweight = reweight
        self.ignore_history = ignore_history
        self.cc_history = set()

    def reset_history(self):
        self.cc_history = set()

    def get_rng_seed(self, context_code: any) -> any:
        if not self.ignore_history:
            self.cc_history.add(context_code)
        import hashlib

        m = hashlib.sha256()
        m.update(context_code)
        m.update(self.private_key)
        full_hash = m.digest()
        seed = int.from_bytes(full_hash, "big") % (2 ** 32 - 1)
        return seed

    def _get_codes(self, context):
        batch_size = len(context)

        context_codes = [
            context[i][-self.cc_length:].tobytes() for i in range(batch_size)
        ]
        # print(context[0][-self.cc_length:])
        mask, seeds = zip(
            *[
                (context_code in self.cc_history, self.get_rng_seed(context_code))
                for context_code in context_codes
            ]
        )
        return mask, seeds

    def detect(self,
               input_ids: np.ndarray):
        """
        :param input_ids: sequences after tokenization
        :return: scores, a higher score means a seq is likely to be watermarked
        """

        scores = []
        for i in range(input_ids.shape[1]):
            score = self.get_la_score(input_ids[:, :i], input_ids[:, i], self.vocab_size)
            
            ti = score-np.log(2)
            scores.append(ti)
        assert np.all(ti<=0), ti
        tis = np.array(scores)
        uis = np.exp(tis)
        Ubar=np.mean(uis)
        print("Ubar", Ubar)

        if np.mean(uis)==0:
            final_score=0
            final_p_value=1
            return final_score, final_p_value
        avgS = lambda Ubar, lamb: Ubar*lamb+np.log(lamb/np.expm1(lamb))
        import scipy.optimize
        sol=scipy.optimize.minimize(lambda l:-avgS(Ubar, l), 0.5,method='SLSQP', bounds=[(0,10)])
        final_score=-sol.fun*input_ids.shape[1]
        final_p_value = np.exp(-final_score)
        return final_score, final_p_value
        
        A = np.mean(tis)
        final_score = (-1-A-np.log(-A))*input_ids.shape[1]
        final_p_value = np.exp(-final_score)

        print("optimal score is", final_score)
        print("optimal p-value is", final_p_value)
        return final_score, final_p_value

    def get_la_score(
            self,
            input_ids: np.ndarray,
            labels: np.ndarray,
            vocab_size: int,
    ) -> np.ndarray:
        assert "get_la_score" in dir(
            self.reweight
        ), "Reweight does not support likelihood agnostic detection"
        mask, seeds = self._get_codes(input_ids)
        rng = [
            np.random.default_rng(seed) for seed in seeds
        ]
        
        mask = np.array(mask)
        watermark_code = self.reweight.watermark_code_type.from_random(rng, vocab_size)
        all_scores = self.reweight.get_la_score(watermark_code)
        scores = all_scores[np.arange(all_scores.shape[0]), labels]
        
        scores = np.logical_not(mask) * scores
        return scores

In [52]:
amino_acids = 'ARNDCQEGHILKMFPSTWYVX-'
# Create a dictionary mapping each amino acid to its index
aa_to_index = {aa: idx for idx, aa in enumerate(amino_acids)}
class AminoAcidTokenizer:
    def __init__(self, aa_to_index):
        self.aa_to_index = aa_to_index
        self.index_to_aa = {idx: aa for aa, idx in aa_to_index.items()}
        
    def encode(self, sequence):
        # Encode a sequence of amino acids to indices
        return np.array([self.aa_to_index.get(aa, self.aa_to_index.get('X')) for aa in sequence]).reshape(1,-1)
        
    def decode(self, indices):
        # Decode a list of indices back into an amino acid sequence
        return ''.join(self.index_to_aa.get(idx, 'X') for idx in indices)
tokenizer = AminoAcidTokenizer(aa_to_index)

In [53]:
# sequence = "STLETTPKGSEGVNVTKLECSLPSGVKTAQWSGPRSAPNGVFVGSDASISFDPGKTVTLMSAASGTGYKCLAMLKIQAYFNETADLPCQFANSQNQSLSELVVFWQDQENLVLNEVYLGKEKFDSVHSKYMGRTSFDSDSWTLRLHNLQIKDKGLYQCIIHHKKPTGMIRIHQMNSELSVLA"
# sequence = "KIQAYFNETADLPCQFANSQNQSLSELVVFWQDQENLVLNEVYLGSEKFDSVHSKYMGRTSFDSDSWTLRLHNLQIKDKGLYQCIIHHKKPTGMIRIHQMNSELSVLA"
# sequence = "STLEPTRKGSEGVNVTKLECSLPSGVKSAQWSGPGSGPQGVFVIGDAXISSDPGKTVTLKSAASGTGYKCLAML"
sequence = "NMSKEEIKEIMEKIDKIEKELEKISKGLTKEEREELLERIRKEINELFEISGKDFLPKEELEKLQKTLEELREGRNPDEKIEKFLEVLKKIRERLERALN"
detector = WatermarkDetector(b"private key",
                             DeltaGumbel_Reweight(),
                             context_code_length=5,
                             vocab_size=21)

print(detector.detect(tokenizer.encode(sequence)))

Ubar 0.4923574826583142
(array([1.36656653e-15]), array([1.]))


In [5]:
import re

sample_re = re.compile(r'Generating sample (\d+) \.\.\.')
entropy_re = re.compile(r'WaterMarked Entropy: ([\d.]+)')
timestep_re = re.compile(r'TIMESTEP \[(\d+)/(\d+)\].*current PLDDT: ([\d.]+)')

# List to hold extracted data
extracted_data = []
input_file_path = './secondary_structure_wm.log'
output_file_path = './secondary_structure_wm.csv'

with open(input_file_path, 'r') as file:
    sample_number = None
    entropy = None
    sequence = None
    expect_sequence = False
    for line in file:
        line = line.strip()
        if not line:  # Skip empty lines
            continue
        # Match each line against the regular expressions
        sample_match = sample_re.search(line)
        entropy_match = entropy_re.search(line)
        timestep_match = timestep_re.search(line)
        
        detector = WatermarkDetector(b"privatee key",
                                     DeltaGumbel_Reweight(),
                                     context_code_length=5,
                                     vocab_size=21)
        
        if sample_match:
            sample_number = sample_match.group(1)
        elif entropy_match:
            entropy = entropy_match.group(1)
            expect_sequence = True
        elif timestep_match:
            current_iteration, total_iterations, current_plddt = timestep_match.groups()
            modified_iteration = int(total_iterations) + 1 - int(current_iteration)
            extracted_data.append({
                'sample': sample_number,
                'iteration': int(modified_iteration),
                'entropy': float(entropy),
                'pLDDT': float(current_plddt),
                'watermark': ((detector.detect(tokenizer.encode(sequence), method='optimized')[0])/len(sequence))[0]
            })
            expect_sequence = False
        elif expect_sequence:
            sequence = line
            expect_sequence = False

# Create a DataFrame from the extracted data
import pandas as pd
df = pd.DataFrame(extracted_data)
print(df)
# df.to_csv(output_file_path)

  avgS = lambda Ubar, lamb: Ubar*lamb+np.log(lamb/np.expm1(lamb))


       sample  iteration   entropy   pLDDT     watermark
0      000000          1  2.722656  0.2717  2.080855e-04
1      000000          2  2.740234  0.3022  6.619150e-07
2      000000          3  2.490234  0.4558  5.477644e-05
3      000000          4  2.410156  0.4575  3.575793e-03
4      000000          5  2.314453  0.4785 -7.802704e-04
...       ...        ...       ...     ...           ...
12495  000499         21  0.508301  0.9092 -4.802248e-04
12496  000499         22  0.834961  0.9062 -4.171029e-03
12497  000499         23  0.807617  0.9058 -4.171029e-03
12498  000499         24  1.127930  0.9014 -4.171029e-03
12499  000499         25  0.757812  0.9043 -4.171029e-03

[12500 rows x 5 columns]
