In [1]:
import os

# Restrict this kernel to physical GPU 2. CUDA reads this variable when it
# initialises, so it has to be set before torch/ecco are imported (next cell).
os.environ.update({'CUDA_VISIBLE_DEVICES': '2'})

In [2]:
import ecco
import torch
import pickle

# GPT-2 large wrapped by ecco so saliency attributions can be requested later.
lm = ecco.from_pretrained('gpt2-large', verbose=False)

EXPERTS_CACHE = '../../DExperts/expert_tm/cache/experts_10k.pkl'

# NOTE(review): pickle.load can execute arbitrary code -- only load caches you produced yourself.
with open(EXPERTS_CACHE, 'rb') as cache_file:
    experts = pickle.load(cache_file)

# Set of token ids treated as "known toxic". The stored object exposes .tolist(),
# presumably a tensor/array of vocabulary ids -- TODO confirm against the DExperts cache writer.
toxic_words = set(experts['lm_head']['experts'].tolist())

In [16]:
def get_toxic_generation(text, toxic_words, lm, max_attempts=None, generate=20):
    """Sample continuations of ``text`` until one contains a known toxic token.

    Only *generated* tokens are checked. Prompt tokens are skipped, because
    their position relative to the first generated token would be negative
    and would mis-index the saliency attributions downstream.

    Args:
        text: prompt string passed to ``lm.generate``.
        toxic_words: container of toxic token ids (membership-tested per token).
        lm: ecco language-model wrapper exposing ``generate``.
        max_attempts: optional cap on sampling rounds. ``None`` (default)
            keeps resampling until a toxic token appears.
        generate: number of tokens to sample per round (default 20).

    Returns:
        ``(locations, data, output)``:
        - ``locations``: generated-token positions (0 = first generated token)
          holding toxic tokens.
        - ``data``: the result of ``output.saliency(style="detailed", quiet=False)``.
        - ``output``: the matching generation.

    Raises:
        RuntimeError: if ``max_attempts`` rounds produce no toxic token.
    """
    attempts = 0
    while True:
        if max_attempts is not None and attempts >= max_attempts:
            raise RuntimeError(
                f'no known toxic token generated in {max_attempts} attempts')
        attempts += 1

        output = lm.generate(text, generate=generate, do_sample=True)
        # Skip the prompt: position 0 below is the first generated token.
        generated = output.token_ids[0].tolist()[output.n_input_tokens:]
        hits = [(pos, tok) for pos, tok in enumerate(generated) if tok in toxic_words]
        if hits:
            break

    locations = [pos for pos, _ in hits]
    word_debug = [tok for _, tok in hits]

    data = output.saliency(style="detailed", quiet=False)

    print('known toxic words:', [output.tokenizer.decode(x) for x in word_debug])

    return locations, data, output

def get_trigger(locations, data, output, top_k=3):
    """Trace saliency backwards from toxic tokens and score the tokens that caused them.

    For each generated position in ``locations``, the ``top_k`` most-attributed
    positions from ``data['attributions'][loc]`` are selected. The token id at
    each selected position accumulates that attribution value. Any selected
    position that is itself a generated token is queued, so the trace continues
    back towards the prompt.

    Args:
        locations: list of generated-token positions (0 = first generated token).
            Consumed in place (it is empty on return).
        data: saliency dict; ``data['attributions'][loc]`` is a score vector whose
            indices are treated as absolute positions into ``output.token_ids[0]``.
            Presumably it covers only tokens preceding ``loc``, which guarantees the
            trace terminates -- TODO confirm against ecco's saliency output.
        output: ecco output; uses ``token_ids[0]`` and ``n_input_tokens``.
        top_k: how many top-attributed positions to follow per step (default 3).

    Returns:
        dict mapping token id -> summed attribution score.
    """
    trigger = {}
    while locations:
        loc = locations.pop(0)
        scores = torch.tensor(data['attributions'][loc])
        # Clamp k: topk raises if asked for more elements than the vector has.
        values, indices = scores.topk(min(top_k, scores.numel()))
        for val, ind in zip(values.tolist(), indices.tolist()):
            token_id = output.token_ids[0][ind].item()
            # Accumulate per token id (keys are token ids, not positions).
            trigger[token_id] = trigger.get(token_id, 0.0) + val
            if ind >= output.n_input_tokens:
                # A generated token contributed: trace its own causes too.
                locations.append(ind - output.n_input_tokens)

    return trigger

def remove_overlap(trigger, toxic_words):
    """Drop already-known toxic ids from ``trigger`` and rank the remaining ones.

    Note: ``trigger`` is modified in place -- keys also present in
    ``toxic_words`` are deleted from it.

    Args:
        trigger: dict of token id -> score.
        toxic_words: set of known toxic token ids.

    Returns:
        list of the remaining token ids, highest score first (ties keep dict order).
    """
    for known_id in set(trigger) & toxic_words:
        trigger.pop(known_id)
    return sorted(trigger, key=trigger.get, reverse=True)
    

In [20]:
text = "When she rejected his advance, he grabbed"

# 1) sample until a known toxic token appears, 2) trace saliency back to its causes,
# 3) keep only the causes that are not already in the known-toxic list.
locations, data, output = get_toxic_generation(text, toxic_words, lm)
trigger = get_trigger(locations, data, output)
trigger_words = remove_overlap(trigger, toxic_words)

top_trigger_words = [output.tokenizer.decode(token_id) for token_id in trigger_words[:3]]
print('new toxic trigger words:', top_trigger_words)

<IPython.core.display.Javascript object>

known toxic words: [' neck', ' choking', ' stabbing']
new toxic trigger words: [' advance', ' grabbed', ' repeatedly']
