In [None]:
%%capture
!pip3 install scikit-image
!pip3 install diffusers
!pip3 install  spacy
!pip3 install ftfy
!pip3 install transformers
!pip3 install numba
!pip3 install nltk
!pip3 install emoji
!pip3 install inflect
!pip3 install joblib
!pip3 install accelerate

In [None]:
import os
import shutil
import json
import torch
from diffusers import StableDiffusionPipeline
from typing import Any, Optional, Tuple, Union
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
import re
import math
import random
import emoji
import difflib
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import StableDiffusionPipeline
from torch import autocast
from PIL import Image
from matplotlib import pyplot as plt

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [None]:
def get_embeddings(model, input_ids):
  """Return the output of the text encoder's embedding layer for ``input_ids``.

  The call is delegated unchanged to ``model.text_model.embeddings``.
  """
  embedding_layer = model.text_model.embeddings
  return embedding_layer(input_ids)

def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0):
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted_mask = 1.0 - expanded_mask
    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

def get_grad(model, adv_input_tokens, target, avoid, suff_start, suff_end, object_key_mask=None):
  """Gradient of the attack loss w.r.t. a one-hot encoding of the suffix tokens.

  The suffix ``adv_input_tokens[suff_start:suff_end]`` is re-expressed as a
  differentiable one-hot matrix over the vocabulary, pushed through the text
  encoder, and the loss ``(1 - cos(out, target)) - (1 - cos(out, avoid))`` is
  differentiated with respect to that matrix.

  Args:
    model: text encoder exposing ``get_input_embeddings()``, ``device`` and
      ``text_model.{embeddings, encoder, final_layer_norm}``.
    adv_input_tokens: 1-D tensor of token ids for the whole prompt.
    target: flattened embedding to move towards; must have the same number of
      elements as the flattened encoder output.
    avoid: flattened embedding to move away from (same size as ``target``).
    suff_start, suff_end: half-open range of the suffix positions.
    object_key_mask: unused; kept for interface compatibility.

  Returns:
    Tensor of shape ``(suff_end - suff_start, vocab_size)`` holding the
    gradient of the loss for every (suffix position, vocabulary token) pair.
  """
  token_embedding = model.get_input_embeddings()
  embedding_weight = token_embedding.weight
  suffix_tokens = adv_input_tokens[suff_start:suff_end]

  one_hot = torch.zeros(
    suffix_tokens.shape[0],
    embedding_weight.shape[0],
    device=model.device,
    dtype=embedding_weight.dtype
  )
  one_hot.scatter_(
    1,
    suffix_tokens.unsqueeze(1),
    torch.ones(one_hot.shape[0], 1, device=model.device, dtype=embedding_weight.dtype)
  )
  one_hot.requires_grad_()

  input_ids = adv_input_tokens.unsqueeze(0)

  # Splice the differentiable suffix into the *token* embeddings only, then let
  # the embedding module add the position embeddings for every position.
  # Splicing into the output of `text_model.embeddings(input_ids)` instead would
  # leave the suffix positions without their positional term.
  suffix_embeds = (one_hot @ embedding_weight).unsqueeze(0)
  token_embeds = token_embedding(input_ids).detach()
  new_token_embeds = torch.cat(
      [
          token_embeds[:, :suff_start, :],
          suffix_embeds,
          token_embeds[:, suff_end:, :]
      ],
      dim=1)
  hidden_states = model.text_model.embeddings(inputs_embeds=new_token_embeds)

  causal_attention_mask = _make_causal_mask(input_ids.shape, hidden_states.dtype, device=hidden_states.device)

  encoder_outputs = model.text_model.encoder(
      inputs_embeds=hidden_states,
      attention_mask=None,  # no padding mask is applied
      causal_attention_mask=causal_attention_mask,
      return_dict=True,
  )

  last_hidden_state = model.text_model.final_layer_norm(encoder_outputs[0])
  flat_output = last_hidden_state.view(-1)

  loss_target = 1 - F.cosine_similarity(flat_output, target, dim=0)
  loss_avoid = 1 - F.cosine_similarity(flat_output, avoid, dim=0)
  loss = loss_target - loss_avoid

  # Differentiate w.r.t. the one-hot matrix only: `loss.backward()` would also
  # accumulate `.grad` on every encoder parameter on each call.
  (grad,) = torch.autograd.grad(loss, [one_hot])
  return grad

def get_allowed_characters():
    """Return the characters permitted in a constrained adversarial suffix.

    The list holds a fixed set of punctuation symbols, followed by the
    uppercase letters ``A``-``Z`` and the digits ``0``-``9``. Every character
    appears exactly once, so ``len()`` of the result equals the number of
    distinct allowed characters.
    """
    symbols = ['·','~','!','@','#','$','%','^','&','*','(',')','=','-','+','.','<','>','?',',','\'',';',':','|','\\','/']
    uppercase = [chr(code) for code in range(ord('A'), ord('Z') + 1)]
    digits = [str(digit) for digit in range(10)]
    # dict.fromkeys drops any accidental duplicate while preserving order.
    return list(dict.fromkeys(symbols + uppercase + digits))

def find_mismatches(list1, list2):
    """List the positions at which two sequences differ.

    Returns ``(index, item_from_list1, item_from_list2)`` tuples for every
    index where the paired items are unequal. Comparison stops at the end of
    the shorter sequence.
    """
    return [
        (position, left, right)
        for position, (left, right) in enumerate(zip(list1, list2))
        if left != right
    ]

def check_encode_decode(tokenizer, tokens, MAX_LENGTH):
    """Check that ``tokens`` survive a decode -> encode round trip.

    The tokens are decoded to text (special tokens skipped), re-tokenized with
    max-length padding and truncation to ``MAX_LENGTH``, and compared with the
    original ids. Returns ``True`` only when both id lists are identical.
    """
    decoded = tokenizer.decode(tokens, skip_special_tokens=True)
    encoding = tokenizer(
        decoded,
        return_tensors="pt",
        padding="max_length",
        max_length=MAX_LENGTH,
        truncation=True,
    )
    round_trip = encoding["input_ids"][0]
    return tokens.tolist() == round_trip.tolist()

def get_all_substrings(input_string):
  """Return the sorted substrings that start with the string's first character.

  Every substring beginning at a position whose character equals
  ``input_string[0]`` is collected (duplicates are kept). When more than one
  substring is found, only those at least ``len(input_string) // 2``
  characters long are returned. An empty input yields an empty list.
  """
  if not input_string:
    # Nothing to enumerate; indexing [0] below would raise IndexError.
    return []

  first_char = input_string[0]
  total_len = len(input_string)
  candidates = sorted(
    input_string[start:end]
    for start, char in enumerate(input_string)
    if char == first_char
    for end in range(start + 1, total_len + 1)
  )

  if len(candidates) > 1:
    min_len = total_len // 2
    return [candidate for candidate in candidates if len(candidate) >= min_len]
  return candidates

def text_has_emoji(text):
  """Return True when any single character of ``text`` is a key of ``emoji.EMOJI_DATA``."""
  return any(character in emoji.EMOJI_DATA for character in text)

def check_target_word(tokenizer, tokens, substrings, TARGET_WORD_RESTRICTION):
  """Decide whether a candidate prompt respects the target-word restriction.

  With ``TARGET_WORD_RESTRICTION`` disabled every candidate is accepted.
  Otherwise the decoded text is split on whitespace and the candidate is
  rejected (``False``) as soon as a word equals, or contains, any entry of
  ``substrings``.
  """
  decoded = tokenizer.decode(tokens, skip_special_tokens=True)
  if not TARGET_WORD_RESTRICTION:
    return True

  for piece in decoded.split():
    if piece in substrings:
      return False
    for fragment in substrings:
      if fragment in piece:
        return False
  return True

def gradient_greedy_search(T,k,B,model, input_tokens, target, avoid,
                           suff_start, suff_end, allowed_char_idxes,
                           tried_adv_tokens, START_THRESH,
                           END_THRESH, SKIP_TRIED_SUFFIXES, tokenizer,
                           MAX_LENGTH, substrings, TARGET_WORD_RESTRICTION):
    """Greedy, gradient-guided search for an adversarial token suffix.

    For ``T`` steps: compute the loss gradient for the current suffix, take the
    ``k`` most promising replacement tokens per suffix position, sample ``B``
    candidate prompts by randomly swapping suffix tokens, score each candidate
    by ``cos(output, target) - cos(output, avoid)`` and keep the best one if it
    beats the best score so far and passes the tokenizer round-trip and
    target-word checks.

    Args:
        T: number of search steps.
        k: number of top-gradient replacement tokens per suffix position.
        B: number of candidates sampled per step.
        model: text encoder used for gradients and scoring.
        input_tokens: 1-D tensor of token ids for the starting prompt.
        target, avoid: flattened embeddings to approach / move away from.
        suff_start, suff_end: half-open range of suffix positions.
        allowed_char_idxes: token ids allowed in the suffix, or ``None`` for
            no restriction.
        tried_adv_tokens: set shared by the caller. Candidates are recorded in
            it as tuples of token ids, so membership tests compare contents.
        START_THRESH, END_THRESH: the per-position replacement probability at
            step ``t`` is ``max(END_THRESH, START_THRESH - t/T)``.
        SKIP_TRIED_SUFFIXES: when true, candidates already in
            ``tried_adv_tokens`` are re-sampled.
        tokenizer, MAX_LENGTH: used for the round-trip check and logging.
        substrings, TARGET_WORD_RESTRICTION: forwarded to ``check_target_word``.

    Returns:
        1-D tensor of token ids for the best prompt found (the unmodified
        input if no candidate was ever accepted).
    """
    # Stop insisting on unseen candidates after this many consecutive repeats,
    # so the sampling loop terminates even when the search space is exhausted.
    max_duplicate_streak = 1000

    # float("-inf") instead of np.infty: that alias was removed in NumPy 2.0.
    max_sim = float("-inf")
    adv_input_tokens = input_tokens.clone()

    cos_dim1 = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    for t in range(T):
        g = get_grad(model, adv_input_tokens, target, avoid, suff_start, suff_end)

        if allowed_char_idxes is not None:
            # Push every disallowed token to the bottom of the ranking.
            mask = torch.ones_like(g, dtype=torch.bool)
            mask[:, allowed_char_idxes] = 0
            g[mask] = float("inf")

        # Most negative gradient = largest expected loss decrease.
        indices = (-g).topk(k).indices
        replace_prob = max(END_THRESH, START_THRESH - t/T)

        new_adv_tokens = []
        duplicate_streak = 0
        while len(new_adv_tokens) < B:
            adv_input_tokens_b = adv_input_tokens.clone()
            for i in range(suff_start, suff_end):
                if random.random() < replace_prob:
                    adv_input_tokens_b[i] = indices[i-suff_start][torch.randint(k, (1,))]

            # Tensors hash by identity, so a set of tensors never reports a
            # repeat; a tuple of ids compares by content.
            candidate_key = tuple(adv_input_tokens_b.tolist())
            if (SKIP_TRIED_SUFFIXES and candidate_key in tried_adv_tokens
                    and duplicate_streak < max_duplicate_streak):
                duplicate_streak += 1
                continue

            duplicate_streak = 0
            tried_adv_tokens.add(candidate_key)
            new_adv_tokens.append(adv_input_tokens_b)

        new_adv_tokens = torch.stack(new_adv_tokens)

        with torch.no_grad():
            output_embeds = model(new_adv_tokens)[0].view(B,-1)

        target_expanded = target.unsqueeze(0).repeat(B,1)
        avoid_expanded = avoid.unsqueeze(0).repeat(B,1)

        cos_sim1 = cos_dim1(target_expanded, output_embeds)
        cos_sim2 = cos_dim1(avoid_expanded, output_embeds)

        cos_sim = cos_sim1 - cos_sim2
        max_idx = torch.argmax(cos_sim,dim=0)

        if max_sim < cos_sim[max_idx] and check_encode_decode(tokenizer, new_adv_tokens[max_idx], MAX_LENGTH) and check_target_word(tokenizer, new_adv_tokens[max_idx], substrings, TARGET_WORD_RESTRICTION):
            print("t:",t, "cos:", "{:.3f}".format(cos_sim[max_idx].item()), "adv_prompt:", tokenizer.decode(new_adv_tokens[max_idx], clean_up_tokenization_spaces =True,skip_special_tokens=True))
            max_sim = cos_sim[max_idx]
            adv_input_tokens = new_adv_tokens[max_idx].clone()

    print("Done: ", tokenizer.decode(adv_input_tokens, clean_up_tokenization_spaces =True,skip_special_tokens=True))
    return adv_input_tokens


def image_grid(imgs, rows, cols):
    """Tile ``imgs`` row by row into a single ``rows`` x ``cols`` RGB image.

    ``len(imgs)`` must equal ``rows * cols``. The size of the first image is
    used as the cell size for every tile.
    """
    assert len(imgs) == rows*cols

    tile_w, tile_h = imgs[0].size
    canvas = Image.new('RGB', size=(cols*tile_w, rows*tile_h))

    for index, tile in enumerate(imgs):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col*tile_w, row*tile_h))
    return canvas


In [None]:
# Model id passed to StableDiffusionPipeline.from_pretrained, and the device
# the pipeline is moved to.
model_id = "runwayml/stable-diffusion-v1-5"
device = "cuda"
# At search step t each suffix position is replaced with probability
# max(END_THRESH, START_THRESH - t/T) (see gradient_greedy_search).
START_THRESH = 1
END_THRESH = 0.20
# Number of suffix positions optimized after the input prompt.
NUM_ADV_TOKENS = 10
# Number of independent search runs per (input, target) pair.
NUM_GENERATIONS = 5
# Number of images generated per adversarial prompt.
NUM_IMAGES = 7
# Search steps per run.
T = 100
# Top-k replacement tokens considered per suffix position; overwritten with the
# allowed-token count when CONSTRAINED is enabled.
k = 256
# Upper bound on the number of candidates sampled per step.
B_MAX = 512
# When True, suffix tokens are restricted to get_allowed_characters().
CONSTRAINED = False
# When True, candidates containing the target word are rejected.
TARGET_WORD_RESTRICTION = True
# When True, the search tries to avoid re-evaluating candidates it has seen.
SKIP_TRIED_SUFFIXES  = True
# Cosine similarity over flattened vectors (dim=0) and over batches (dim=1).
cos = torch.nn.CosineSimilarity(dim=0, eps=1e-06)
cos_dim1 = torch.nn.CosineSimilarity(dim=1, eps=1e-06)

In [None]:
# Load the pipeline in float16 with the safety checker disabled, move it to
# `device`, and enable attention slicing.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker = None)
pipe = pipe.to(device)
pipe.enable_attention_slicing()


tokenizer = pipe.tokenizer
# The text encoder is cast back to float32; it is the model used for the
# gradient computation and candidate scoring in the search.
pipe.text_encoder = pipe.text_encoder.float()
model = pipe.text_encoder
# Pad with the end-of-sequence token.
tokenizer.pad_token = tokenizer.eos_token
MAX_LENGTH = tokenizer.model_max_length

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


In [None]:
def set_seed(seed):
  """Seed Python, NumPy and torch (CPU and CUDA) RNGs with ``seed``.

  Returns a new ``torch.Generator`` created on the global ``device`` and
  seeded with the same value.
  """
  generator = torch.Generator(device=device)
  for seeder in (random.seed, np.random.seed, torch.cuda.manual_seed, torch.manual_seed):
    seeder(seed)
  return generator.manual_seed(seed)

In [None]:
# Token ids the suffix may use; None means the whole vocabulary is allowed.
allowed_char_idxes = None

if CONSTRAINED:
    allowed_chars = get_allowed_characters()
    # Index 1 of each encoding is taken as the character's token id
    # (presumably index 0 is the start-of-text token -- TODO confirm).
    allowed_char_idxes = tokenizer(allowed_chars, return_tensors="pt")["input_ids"][:,1]

# (input_text, target_text) pairs to attack.
input_target_data = []

# One JSON object per line, each with "input_text" and "target_text" keys.
with open("/content/data_for_attacks.jsonl", "r") as f:
    for line in f:
        _json = json.loads(line)
        input_target_data.append((_json["input_text"],_json["target_text"]))

In [None]:
def find_substrings_of_target_word(input, target, mode):
  """Find the word that ``target`` adds relative to ``input``.

  Both texts are split on whitespace and periods and diffed word by word; the
  first non-empty word present only in ``target`` is the "target word".

  Args:
    input: the original prompt.
    target: the prompt the attack should imitate.
    mode: ``None`` to return just the target word; any other value to return
      ``get_all_substrings`` of it.

  Returns:
    ``[word]`` when ``mode`` is ``None``, otherwise the list produced by
    ``get_all_substrings(word)``. An empty list is returned when ``target``
    adds no word, instead of raising ``IndexError``.
  """
  words_A = re.split(r'\s+|\.', input)
  words_B = re.split(r'\s+|\.', target)

  # Splitting on '.' can yield empty tokens; an empty "target word" would match
  # every word downstream, so only non-empty additions are considered.
  added_word = next(
    (line[2:] for line in difflib.ndiff(words_A, words_B)
     if line.startswith('+ ') and line[2:]),
    None,
  )
  if added_word is None:
    return []

  if mode is not None:
    return get_all_substrings(added_word)
  return [added_word]

In [None]:
# One list of adversarial prompts per (input, target) pair.
all_adv_prompts = []

for input_text, target_text in input_target_data:

    # Word(s) that candidate prompts must not contain when
    # TARGET_WORD_RESTRICTION is enabled.
    substrings = find_substrings_of_target_word(input_text, target_text, None)   # mode can be None or any string for all substrings

    # Both prompts are padded/truncated to MAX_LENGTH tokens.
    input_tokens = tokenizer(input_text, return_tensors="pt", padding="max_length",max_length=MAX_LENGTH, truncation=True)["input_ids"].to(device)
    target_tokens = tokenizer(target_text, return_tensors="pt", padding="max_length",max_length=MAX_LENGTH, truncation=True)["input_ids"].to(device)

    # Flattened text-encoder output of the target prompt: the search moves towards it.
    target = None
    with torch.no_grad():
        target = model(target_tokens)[0][0].view(-1)

    # Flattened text-encoder output of the input prompt: the search moves away from it.
    avoid = None
    with torch.no_grad():
        avoid = model(input_tokens)[0][0].view(-1)

    # Drop the batch dimension; the search works on a 1-D id tensor.
    input_tokens = input_tokens[0]

    num_adv_tokens = NUM_ADV_TOKENS
    # The suffix starts at the last token of the unpadded encoding
    # (presumably the end-of-text token -- TODO confirm) and spans
    # num_adv_tokens positions.
    suff_start = len(tokenizer(input_text).input_ids)-1
    suff_end = suff_start + num_adv_tokens

    print("Input:", input_text)
    print("Target:", target_text)
    print("Cosine Similarity between Input and Target", cos(avoid,target).item())

    # Shared by all NUM_GENERATIONS runs for this pair.
    tried_adv_tokens = set()

    # In constrained mode every allowed token is a top-k candidate.
    # NOTE: this overwrites the global `k` for all later pairs as well.
    if allowed_char_idxes is not None:
        k = len(allowed_char_idxes)
    # Candidates per step, capped at B_MAX.
    B = min(k*num_adv_tokens, B_MAX)
    all_adv_input_tokens = [gradient_greedy_search(T, k, B, model, input_tokens, target,
                                                   avoid, suff_start, suff_end,
                                                   allowed_char_idxes,
                                                   tried_adv_tokens,
                                                   START_THRESH,
                                                   END_THRESH,
                                                   SKIP_TRIED_SUFFIXES,
                                                   tokenizer,
                                                   MAX_LENGTH,
                                                   substrings,
                                                   TARGET_WORD_RESTRICTION
                                                   ) for i in range(NUM_GENERATIONS)]

    print("Number of inputs tried: ",len(tried_adv_tokens))

    if allowed_char_idxes is not None:
        print("Percentage of Search Space Explored:", 100*len(tried_adv_tokens)/(len(allowed_char_idxes)**num_adv_tokens),"%")

    # Decode the best prompt of every run back to text.
    adv_prompts = []
    for adv_input_tokens in all_adv_input_tokens:
        final_adv_text = tokenizer.decode(adv_input_tokens, clean_up_tokenization_spaces =True,skip_special_tokens=True)
        print(final_adv_text)
        adv_prompts.append(final_adv_text)
    all_adv_prompts.append(adv_prompts)

    #dir_path = input_text.replace(" ","_").replace(".","_")+"___"+target_text.replace(" ","_").replace(".","_")+"/"

Input: three cats standing on a table.
Target: one cats standing on a table.
Cosine Similarity between Input and Target 0.9167361259460449
t: 0 cos: -0.004 adv_prompt: three cats standing on a table. sensei pave demo jaejoong wednesdaymotivation scrapped planners fgo samanthaprabhu '…
t: 1 cos: -0.001 adv_prompt: three cats standing on a table. converter neurology stafford gruesome bikini hokies pez vette londonmarathon jetty
t: 10 cos: 0.006 adv_prompt: three cats standing on a table. jeff closer polly armedbison sickzuaimee londonmarathon seared
t: 11 cos: 0.009 adv_prompt: three cats standing on a table. lanterns tony deeplearning picturebison sherman conditioner aimee londonmarathon dillon
t: 19 cos: 0.019 adv_prompt: three cats standing on a table. lanterns tony contestalert appetforecast jumping nagaaimee raped dillon
t: 31 cos: 0.022 adv_prompt: three cats standing on a table. lanterns logy dreamers appetforecast carnation nagaaimee raped enfield
t: 35 cos: 0.023 adv_prompt: thr

In [None]:
# Output directory for this case study; any previous copy is deleted first.
dir_path = "/content/case_study_cup/"

# os.path.isdir() is already False for a path that does not exist.
if os.path.isdir(dir_path):
  try:
      shutil.rmtree(dir_path)
  except OSError as e:
      print(f"Error: {dir_path} : {e.strerror}")
  else:
      print(f"Directory '{dir_path}' has been removed")


Directory '/content/case_study_cup/' has been removed


In [None]:
# Create the output directory; raises if it already exists.
os.makedirs(dir_path)

# Replace emoji characters with their text aliases before writing to JSON.
# NOTE(review): `adv_prompts`, `input_text` and `target_text` are the values
# left over from the LAST iteration of the attack loop; `all_adv_prompts` is
# not saved here -- confirm this is intended.
prmpt_list=[]
for prmpt in adv_prompts:
  prmpt_list.append(emoji.demojize(prmpt))


prompts_dict = {
    "input_text":input_text,
    "target_text":target_text,
    "num_tokens":NUM_ADV_TOKENS,
    "adv_prompts":prmpt_list
}

with open(dir_path+"prompts.json", 'w') as file:
    json.dump(prompts_dict, file, indent=4)

# Disabled code below: saving of individual images and of combined grids.
# for i,image in enumerate(images):
#     image.save(dir_path+str(i)+".jpg")

# grid.save(dir_path+"grid.jpg")

# all_grid = image_grid(all_images, rows=len(input_target_data)*NUM_GENERATIONS, cols=NUM_IMAGES)
# all_grid = all_grid.resize((all_grid.width//8,all_grid.height//8))

# all_grid.save("all_grid.jpg")

In [None]:
# Seeds are drawn from a dedicated RNG. set_seed() re-seeds the global `random`
# module, so drawing the next seed from `random.random()` would make it a
# deterministic function of the previous seed and allow repeats.
seed_rng = random.Random()

# Index of the next output sub-directory; only advanced for prompts that are
# actually rendered.
cnt = 1
for txt in prmpt_list:
  # Restore emoji characters that were converted to aliases for the JSON file.
  prompt = emoji.emojize(txt)
  print(prompt)
  if prompt == "a red cup on a table.":
    continue

  # Distinct seeds per prompt: images are saved as "<seed>.jpg", so a repeated
  # seed would silently overwrite an earlier image. Requires NUM_IMAGES <= 1000.
  seeds = seed_rng.sample(range(1000), NUM_IMAGES)
  images = []
  for seed_val in seeds:
    with autocast('cuda'):
      generator = set_seed(seed_val)
      images += pipe(prompt, generator=generator, num_inference_steps=30).images
  grid = image_grid(images, rows=1, cols=NUM_IMAGES)

  new_dir_path = dir_path+"prompt_"+str(cnt)+"/"
  os.makedirs(new_dir_path)

  for i,image in enumerate(images):
    image.save(new_dir_path+str(seeds[i])+".jpg")

  grid.save(new_dir_path+"grid_"+str(cnt)+".jpg")
  cnt = cnt + 1

three cats standing on a table. lanterns ulation bronson appetforecast precipitation rugged quitting knuckpompeo


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

three cats standing on a table. upgrading msl niel umd island noisy pregnant worcestershirehour ethereal inger


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

three cats standing on a table. ⃣ 7 acquainted dandelion coo sneaker demonstration twd photograph wondercon


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

three cats standing on a table. supervisor youtuber course seaweed yellow pontifex fancafe skrillex failure jasmin


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

three cats standing on a table. singersongwriter dreamliner pict 💘 hamiltonincanfies exol inexpensive proclamation


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

In [None]:
# Generate and display a single image from a fixed seed. The prompt contains
# the literal text "<|endoftext|>" (how the tokenizer treats it is not shown
# here -- TODO confirm).
seed = 123
prompt = "A <|endoftext|> cup on a table."

with autocast('cuda'):
  generator = set_seed(seed)
  images = pipe(prompt, generator=generator, num_inference_steps=30).images

  plt.imshow(images[0])
  plt.axis('off')
  plt.show()  # display it