# Setup
- Connect to a T4 GPU runtime if possible when using Google Colaboratory.
- This notebook uses google/gemma-2b as the LLM.
- To use this model, you need to log in to Hugging Face and request access to it:
  - https://huggingface.co/google/gemma-2b
- Then, get an access token from Hugging Face (if you don't have one yet):
  - Click the profile icon at the top right > Settings > Access Tokens > + Create new token

In [None]:
from transformers import GenerationConfig
import torch
import regex as re
from tqdm import tqdm
import json
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import numpy as np
import random


In [None]:
# Log in to Hugging Face with your access token (see Setup above); this is
# required to download the access-restricted google/gemma-2b weights.
!huggingface-cli login

In [None]:
# Set random seed: https://gist.github.com/ihoromi4/b681a9088f348942b01711f251e5f964
def seed_everything(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: seed value applied to every RNG; also written to the
            PYTHONHASHSEED environment variable.
    """
    import os
    import random

    import numpy as np
    import torch

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True makes cuDNN auto-tune and possibly pick non-deterministic
    # algorithms, which defeats `deterministic = True`; keep it off.
    torch.backends.cudnn.benchmark = False


seed_everything(42)

# Note: seeding alone does not make the dataset generation below reproducible.

In [None]:
# Load model
def prepareModel(modelName):
    """Load a tokenizer and causal language model from the Hugging Face hub.

    The loader is chosen by substring match on ``modelName``, checked in this
    order: 'gemma' / 'phi' / 'llm-jp', 'pythia', 'swallow', 'Llama-3'.
    All models are loaded with ``device_map="auto"``; Swallow models are
    additionally loaded in bfloat16 with ``low_cpu_mem_usage=True``.

    Args:
        modelName: hub id or local path, e.g. 'google/gemma-2b'.

    Returns:
        (tokenizer, model)

    Raises:
        ValueError: if ``modelName`` matches none of the supported families.
    """
    if 'gemma' in modelName or 'phi' in modelName or 'llm-jp' in modelName:
        from transformers import AutoTokenizer, AutoModelForCausalLM
        tokenizer = AutoTokenizer.from_pretrained(modelName)
        model = AutoModelForCausalLM.from_pretrained(
            modelName,
            device_map="auto"
        )
    elif 'pythia' in modelName:
        from transformers import GPTNeoXForCausalLM, AutoTokenizer
        model = GPTNeoXForCausalLM.from_pretrained(modelName, device_map="auto")
        tokenizer = AutoTokenizer.from_pretrained(modelName)
    elif 'swallow' in modelName:
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM
        tokenizer = AutoTokenizer.from_pretrained(modelName)
        model = AutoModelForCausalLM.from_pretrained(
            modelName, torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True, device_map="auto")
    elif 'Llama-3' in modelName:
        from transformers import AutoTokenizer, AutoModelForCausalLM
        tokenizer = AutoTokenizer.from_pretrained(modelName)
        model = AutoModelForCausalLM.from_pretrained(
            modelName,
            device_map="auto"
        )
    else:
        # Without this branch the return below fails with an UnboundLocalError.
        raise ValueError(f'unsupported model: {modelName!r}')

    return tokenizer, model

# Load google/gemma-2b and its tokenizer, and put the model in eval mode.
# `tokenizer` and `model` are used as notebook globals by later cells.
tokenizer, model = prepareModel('google/gemma-2b')
model.eval()

# Preparation

## Hook to obtain the intermediate activation values

In [None]:
# see (ja): https://gist.github.com/eminamitani/40df0b87f20aaa588cbcee6405f573ad
class OutputInspector:
    """Record the output of ``targetLayer`` on every forward pass.

    Each output is detached, moved to the CPU and appended to
    ``layerOutputs``. Call ``release()`` to unregister the hook.
    """

    def __init__(self, targetLayer):
        self.layerOutputs = []
        hook = self.feature
        self.featureHandle = targetLayer.register_forward_hook(hook)

    def feature(self, model, input, output):
        """Forward hook: store ``output`` (detached, on CPU); returns None so the output is unchanged."""
        cpuOutput = output.detach().cpu()
        self.layerOutputs.append(cpuOutput)

    def release(self):
        """Remove the forward hook registered in ``__init__``."""
        handle = self.featureHandle
        handle.remove()


## Repetition Detector

In [None]:
# Detect repetition: some n-gram occurs at least k times (overlaps allowed) within a window of r tokens, and its first three occurrences are equally spaced.
def countOverlap(text, query):
    """Return the number of matches of pattern ``query`` in ``text``, counting overlapping matches.

    ``re`` here is the third-party ``regex`` module, whose ``overlapped=True``
    flag makes matches that overlap each other all count.
    """
    return sum(1 for _ in re.finditer(query, text, overlapped=True))

def getFirstAppearingIdx(text, query):
    """Return the character offset of the first match of pattern ``query`` in ``text``.

    Raises StopIteration when ``query`` does not match anywhere in ``text``.
    """
    firstMatch = next(re.finditer(query, text))
    return firstMatch.start()


def detectRepetition(line, n, r, k):
    """Find the first n-gram that repeats at regular intervals.

    N-grams ``line[i:i+n]`` are scanned left to right. For each one, the window
    of the ``r`` tokens ending at that n-gram is searched for occurrences of it
    (overlaps allowed). If it occurs at least ``k`` times and its first three
    occurrences in the window are equally spaced, it is reported. At least
    three occurrences are always needed, whatever ``k`` is.

    Matching compares whole token ids, so e.g. the n-gram ``[1, 2]`` never
    matches inside ``[11, 2]`` (a space-joined string search would).

    Args:
        line: sequence of token ids.
        n: n-gram length.
        r: window size in tokens.
        k: minimum number of occurrences inside the window.

    Returns:
        ``(ngram, first, second, third)``, where the positions are absolute
        token indices of the first three occurrences, or
        ``([], -1, -1, -1)`` if no such n-gram exists.
    """
    if n <= 0:
        return [], -1, -1, -1

    minOccurrences = max(k, 3)
    for i in range(0, len(line) - n + 1):
        ngram = line[i:i + n]
        windowStart = max(0, i + n - r)
        # Occurrences must lie entirely inside line[windowStart:i+n].
        positions = [p for p in range(windowStart, i + 1) if line[p:p + n] == ngram]
        if len(positions) < minOccurrences:
            continue
        first, second, third = positions[:3]
        if second - first == third - second:
            return ngram, first, second, third

    return [], -1, -1, -1

## Create dataset
If you want to use an already-prepared dataset, skip this section.

In [None]:
def generateDataSample(model, tokenizer, n=10, r=100, k=3, minimumNonRepetitiveAffix=50, numRansomSampleToekns=10, numGreedyGenerationTokens=200):
    """Generate one candidate sample and return it if it contains a repetition.

    ``numRansomSampleToekns`` tokens are sampled (temperature 1.0) from an
    empty prompt, then continued greedily for ``numGreedyGenerationTokens``
    tokens. The whole sequence is checked with ``detectRepetition(ids, n, r, k)``.
    (Parameter names keep their original spelling for keyword callers.)

    Args:
        model, tokenizer: Hugging Face causal LM and its tokenizer.
        n, r, k: n-gram length, window size and minimum count for detectRepetition.
        minimumNonRepetitiveAffix: the second occurrence must start after this
            token index, so the sample has a non-repetitive prefix.
        numRansomSampleToekns: number of randomly sampled prompt tokens.
        numGreedyGenerationTokens: number of greedily generated tokens.

    Returns:
        A dict describing the repetition (ids, decoded text, positions), or
        None when no qualifying repetition is found.
    """
    generationConfigSample = GenerationConfig(max_new_tokens=numRansomSampleToekns, do_sample=True, eos_token_id=model.config.eos_token_id, temperature=1.0)
    generationConfigGreedy = GenerationConfig(max_new_tokens=numGreedyGenerationTokens, do_sample=False, eos_token_id=model.config.eos_token_id)

    # Sample the first {numRansomSampleToekns} tokens.
    initialInput = tokenizer('', return_tensors="pt").to(model.device)
    initialOutputs = model.generate(**initialInput, generation_config=generationConfigSample)
    print('init:', repr(tokenizer.decode(initialOutputs[0])))

    # Greedily generate the next {numGreedyGenerationTokens} tokens.
    additionalOutputs = model.generate(initialOutputs, generation_config=generationConfigGreedy)
    print(' add:', repr(tokenizer.decode(additionalOutputs[0])))

    generatedIds = additionalOutputs[0].tolist()
    ngram, firstPosition, secondPosition, thirdPosition = detectRepetition(generatedIds, n, r, k)
    if not ngram or secondPosition <= minimumNonRepetitiveAffix:
        return None

    # Clamp so a small minimumNonRepetitiveAffix cannot produce a negative
    # (wrap-around) slice start.
    contextStart = max(0, secondPosition - 50)
    dumpLine = {
        'promptIds': initialOutputs[0].tolist(),
        'promptTokens': tokenizer.decode(initialOutputs[0]),
        'firstPosition': firstPosition,
        'secondPosition': secondPosition,
        'thirdPosition': thirdPosition,
        'ngramIds': ngram,
        'ngramTokens': tokenizer.decode(ngram),
        "generatedIds": generatedIds,
        '50TokensBeforeRepeat': tokenizer.decode(generatedIds[contextStart:secondPosition]),
        '50TokensAfterRepeat': tokenizer.decode(generatedIds[secondPosition:secondPosition + 50]),
    }

    print('[Repetitive Sample]')
    for key, value in dumpLine.items():
        print(key, ':', value)

    return dumpLine

- Create 5 samples with repetition

In [None]:
# Collect sizeX samples whose greedy continuation contains a repetition.
# NOTE(review): there is no attempt limit; this loops until enough samples
# pass generateDataSample's repetition filter.
sizeX = 5
repetitionDataset = []

while len(repetitionDataset) < sizeX:
  output = generateDataSample(model, tokenizer)
  if output:
    repetitionDataset.append(output)

# Find Neurons

In [None]:
def getActs(model, tokenizer, inputIds):
    """Run one forward pass and return every layer's MLP activation-function output.

    Args:
        model: a Gemma/Llama, GPT-NeoX or Phi causal LM from transformers.
        tokenizer: unused; kept for interface compatibility.
        inputIds: list of token ids for a single sequence.

    Returns:
        CPU tensor built by concatenating the per-layer outputs along dim 0 and
        swapping dims 0 and 1. Assuming each activation output is
        (1, seq, d_ff), the result is indexed as ``acts[position, layer, neuron]``.

    Raises:
        ValueError: if the model architecture is not supported.
    """
    modelType = str(type(model))
    if 'GemmaForCausalLM' in modelType or 'LlamaForCausalLM' in modelType:
        targetModules = [layer.mlp.act_fn for layer in model.model.layers]
    elif 'GPTNeoXForCausalLM' in modelType:
        targetModules = [layer.mlp.act for layer in model.gpt_neox.layers]
    elif 'PhiForCausalLM' in modelType:
        targetModules = [layer.mlp.activation_fn for layer in model.model.layers]
    else:
        raise ValueError(f'model is not supported! ({type(model).__name__})')

    model.eval()
    actInspectors = []
    try:
        for module in targetModules:
            actInspectors.append(OutputInspector(module))
        with torch.no_grad():
            model(torch.LongTensor([inputIds]).to(model.device))
    finally:
        # Always unregister the hooks: a leftover hook would keep recording
        # (and holding tensors) on every later forward pass.
        for actInspector in actInspectors:
            actInspector.release()

    perLayer = [torch.cat(actInspector.layerOutputs, dim=1) for actInspector in actInspectors]
    return torch.cat(perLayer, dim=0).transpose(0, 1)

def getAveragedActivations(data, model, tokenizer, maxRange=30, position='first'):
    """Average activations just before vs. from a repetition point.

    For each sample, ``startingPoint = line['<position>Position'] - 1``
    (clamped at 0). The up to ``maxRange`` positions before ``startingPoint``
    are "normal"; ``startingPoint`` and up to ``maxRange - 1`` following
    positions are "repetition". Activations are summed in float32 and
    averaged over all samples and positions.

    Args:
        data: iterable of dicts with 'generatedIds' and '<position>Position'.
        model, tokenizer: passed to ``getActs``.
        maxRange: maximum number of positions on each side.
        position: 'first', 'second' or 'third' (selects the key above).

    Returns:
        (normalActs, repetiActs): mean activation tensors indexed like one
        position of ``getActs`` output.

    Raises:
        ValueError: if there are no normal or no repetition positions
            (e.g. empty ``data``), instead of returning NaNs.
    """
    repPosition = '%sPosition' % position

    normalSum = None
    repetiSum = None
    normalTotalPoints = 0
    repetiTotalPoints = 0

    for line in tqdm(data):
        inputIds = line['generatedIds']
        # float32 accumulation avoids precision loss for bfloat16 models.
        acts = getActs(model, tokenizer, inputIds).float()
        # Presumably shifted by one so the window includes the step that
        # predicts the repeated token. Clamp at 0: a position of 0 would
        # otherwise index -1, i.e. the *last* token.
        startingPoint = max(0, line[repPosition] - 1)
        normalRange = list(range(max(0, startingPoint - maxRange), startingPoint))
        repetiRange = list(range(startingPoint, min(len(inputIds), startingPoint + maxRange)))

        normalTotalPoints += len(normalRange)
        repetiTotalPoints += len(repetiRange)

        na = acts[normalRange].sum(dim=0)
        ra = acts[repetiRange].sum(dim=0)
        normalSum = na if normalSum is None else normalSum + na
        repetiSum = ra if repetiSum is None else repetiSum + ra

    if normalTotalPoints == 0 or repetiTotalPoints == 0:
        raise ValueError('no positions to average over; check data, position and maxRange')

    return normalSum / normalTotalPoints, repetiSum / repetiTotalPoints

def findNeurons(data, model, tokenizer, maxRange=30, position='second'):
    """Rank every neuron by how much more it activates around a repetition.

    Args:
        data, model, tokenizer, maxRange, position: see ``getAveragedActivations``.

    Returns:
        List of dicts sorted by 'diffs' (repetition mean minus normal mean) in
        descending order. Each dict has 'neuron' ((layer, index) tuple) and
        'normalActs', 'repetitionActs', 'diffs' (python floats).
    """
    normalActs, repetiActs = getAveragedActivations(data, model, tokenizer, maxRange, position)
    diff = repetiActs - normalActs
    width = diff.shape[1]
    ranks = torch.argsort(diff.flatten(), descending=True).tolist()

    # Convert to python lists once: indexing tensors and calling .tolist() per
    # neuron, for every neuron, is very slow. Flat index r <-> (r // width, r % width).
    normalFlat = normalActs.flatten().tolist()
    repetiFlat = repetiActs.flatten().tolist()
    diffFlat = diff.flatten().tolist()

    return [
        {
            'neuron': (r // width, r % width),
            'normalActs': normalFlat[r],
            'repetitionActs': repetiFlat[r],
            'diffs': diffFlat[r],
        }
        for r in ranks
    ]

In [None]:
# Rank all MLP activation neurons by (mean activation from just before the
# second occurrence of the repeated n-gram) minus (mean activation over the
# maxRange positions before that), largest difference first.
sortedNeurons = findNeurons(repetitionDataset, model, tokenizer, maxRange=30)

In [None]:
# check found neurons: the 5 with the largest repetition-minus-normal activation difference
for line in sortedNeurons[:5]:
  print(line)

# Location of neurons for each layer

In [None]:
# Count, per layer, how many of the top 0.5% ranked neurons (those with a
# non-negative repetition-minus-normal difference) it contains, and plot the
# counts against the relative layer position (0 = first layer, 1 = last layer).
maxLayerNum = max(neuron['neuron'][0] for neuron in sortedNeurons)

# max(..., 1) avoids dividing by zero for a single-layer model.
xs = [i / max(maxLayerNum, 1) for i in range(maxLayerNum + 1)]
ys = [0] * (maxLayerNum + 1)

size = int(len(sortedNeurons) * 0.005)
print(f'{size=}')

for neuron in sortedNeurons[:size]:
    if neuron['diffs'] < 0:
        continue
    # Index by layer directly rather than searching xs for a float position.
    ys[neuron['neuron'][0]] += 1

# `label` gives legend() an entry; `marker` makes `markersize` take effect.
plt.plot(xs, ys, c='b', linewidth=1.0, marker='o', markersize=3.0, label='Repetition neurons')
plt.legend(fontsize=12)

plt.xlabel('Relative Layer Position')
plt.ylabel('Number of Neurons')

# Intervention

In [None]:
class Activator():
    """Forward hook that adds 1 to selected neurons of a layer's output.

    Args:
        targetLayer: module whose forward output is modified in place.
        neuronIds: indices along the last dimension of the output to modify.
        mode: which sequence positions (dimension 1) to modify:
            'last' (final position), 'all', or 'lastN' (final ``lastN`` positions).
        lastN: number of trailing positions for mode 'lastN'; values <= 0
            modify nothing.

    Only batch element 0 is modified. Call ``release()`` to remove the hook.
    """

    def __init__(self, targetLayer, neuronIds, mode, lastN=0):
        self.neuronIds = neuronIds

        assert mode in ['last', 'all', 'lastN'], 'mode should be last, all or lastN'
        self.mode = mode
        self.lastN = lastN

        self.outputHandle = targetLayer.register_forward_hook(self.activate)

    def activate(self, model, input, output):
        """Add 1 in place to the selected activations and return ``output``."""
        if self.mode == 'last':
            output[0, -1, self.neuronIds] += 1
        elif self.mode == 'all':
            output[0, :, self.neuronIds] += 1
        elif self.mode == 'lastN':
            # `-0:` would select every position, so lastN <= 0 must be a no-op.
            if self.lastN > 0:
                output[0, -self.lastN:, self.neuronIds] += 1
        else:
            print(f'{self.mode=} cannot be recognized')
        return output

    def release(self):
        """Remove the forward hook."""
        self.outputHandle.remove()

class Deactivator():
    """Forward hook that zeroes selected neurons of a layer's output.

    Args:
        targetLayer: module whose forward output is modified in place.
        neuronIds: indices along the last dimension of the output to zero.
        mode: which sequence positions (dimension 1) to modify:
            'last' (final position), 'all', or 'lastN' (final ``lastN`` positions).
        lastN: number of trailing positions for mode 'lastN'; values <= 0
            modify nothing.

    Only batch element 0 is modified. Call ``release()`` to remove the hook.
    """

    def __init__(self, targetLayer, neuronIds, mode, lastN=0):
        self.neuronIds = neuronIds

        assert mode in ['last', 'all', 'lastN'], 'mode should be last, all or lastN'
        self.mode = mode
        self.lastN = lastN

        self.outputHandle = targetLayer.register_forward_hook(self.deactivate)

    def deactivate(self, model, input, output):
        """Multiply the selected activations by 0 in place and return ``output``."""
        if self.mode == 'last':
            output[0, -1, self.neuronIds] *= 0
        elif self.mode == 'all':
            output[0, :, self.neuronIds] *= 0
        elif self.mode == 'lastN':
            # `-0:` would select every position, so lastN <= 0 must be a no-op.
            if self.lastN > 0:
                output[0, -self.lastN:, self.neuronIds] *= 0
        else:
            print(f'{self.mode=} cannot be recognized')
        return output

    def release(self):
        """Remove the forward hook."""
        self.outputHandle.remove()

def convertNeuronsToDict(neurons):
    """Group ``(layer, neuron)`` pairs by layer.

    Returns a plain dict mapping each layer index to the list of its neuron
    indices, preserving the order in which pairs (and layers) first appear.
    """
    grouped = {}
    for layerIdx, neuronIdx in neurons:
        grouped.setdefault(layerIdx, []).append(neuronIdx)
    return grouped

def generateWithIntervention(model, tokenizer, initialInput, neurons, mode,
                             maxNewTokens=10 + 200 - 50, n=10, r=100, k=3):
    """Greedily generate with selected neurons intervened on, then detect repetition.

    For the duration of generation, hooks are placed on each layer's MLP
    activation function. With mode 'activate' they add 1 to the selected
    neurons; with any other mode they zero them. Only the last position of
    each forward pass is affected. The hooks are always removed afterwards,
    even if generation fails.

    Args:
        model: a Gemma/Llama, GPT-NeoX or Phi causal LM from transformers.
        tokenizer: unused; kept for interface compatibility.
        initialInput: list of prompt token ids.
        neurons: iterable of (layer, neuron) pairs to intervene on.
        mode: 'activate' or anything else (deactivate).
        maxNewTokens: maximum number of greedily generated tokens.
        n, r, k: parameters passed to ``detectRepetition``.

    Returns:
        (ids, ngram, firstPosition, secondPosition, thirdPosition), where ids
        is the full generated sequence (prompt included).

    Raises:
        ValueError: if the model architecture is not supported.
    """
    model.eval()

    INTERV = Activator if mode == 'activate' else Deactivator
    layer2neurons = convertNeuronsToDict(neurons)

    modelType = str(type(model))
    if 'GemmaForCausalLM' in modelType or 'LlamaForCausalLM' in modelType:
        targetModules = [layer.mlp.act_fn for layer in model.model.layers]
    elif 'GPTNeoXForCausalLM' in modelType:
        targetModules = [layer.mlp.act for layer in model.gpt_neox.layers]
    elif 'PhiForCausalLM' in modelType:
        targetModules = [layer.mlp.activation_fn for layer in model.model.layers]
    else:
        raise ValueError(f'model is not supported! ({type(model).__name__})')

    acts = []
    try:
        for i, module in enumerate(targetModules):
            if i in layer2neurons:
                acts.append(INTERV(module, layer2neurons[i], 'last'))

        inputIds = torch.LongTensor([initialInput]).to(model.device)
        generationConfigGreedy = GenerationConfig(max_new_tokens=maxNewTokens, do_sample=False, eos_token_id=model.config.eos_token_id)
        additionalOutputs = model.generate(inputIds, generation_config=generationConfigGreedy)
    finally:
        # A leaked hook would silently keep altering every later forward pass.
        for a in acts:
            a.release()

    generatedIds = additionalOutputs[0].tolist()
    ngram, firstPosition, secondPosition, thirdPosition = detectRepetition(generatedIds, n=n, r=r, k=k)

    return generatedIds, ngram, firstPosition, secondPosition, thirdPosition

In [None]:
def conductExpIntervention(texts, neurons, mode, selectMode, K, N=50):
    """Regenerate each text with K neurons intervened on and count repetitions.

    NOTE: uses the notebook-global ``model`` and ``tokenizer``.

    Args:
        texts: list of dicts with key 'ids' (list of token ids).
        neurons: ranked neuron list as returned by ``findNeurons``.
        mode: 'activate' (prompt = first ``N`` tokens) or 'deactivate'
            (prompt = tokens before the second occurrence of the text's
            repeated n-gram; texts with no detected repetition are skipped).
        selectMode: 'top' (first ``K`` of ``neurons``) or 'random'
            (``K`` sampled with ``random.sample``).
        K: number of neurons to intervene on.
        N: prompt length for 'activate' mode.

    Returns:
        (logs, numRep): the printed log lines (the last one is numRep) and the
        number of generations in which a repetition was detected.
    """
    assert mode in ['activate', 'deactivate'], 'mode should be activate or deactivate'
    assert selectMode in ['top', 'random'], 'selectMode should be top or random'

    numRep = 0
    logs = []

    def log(line):
        print(line)
        logs.append(line)

    chosen = neurons[:K] if selectMode == 'top' else random.sample(neurons, K)
    targetNeurons = [neuron['neuron'] for neuron in chosen]

    for i, text in enumerate(texts):
        log('(%d)INIT: ' % i + repr(tokenizer.decode(text['ids'])))

        if mode == 'deactivate':
            _, _, secondPosition, _ = detectRepetition(text['ids'], n=10, r=100, k=3)
            if secondPosition < 0:
                # Without a detected repetition, ids[:-1] would silently be used as the prompt.
                log('(%d)SKIP: no repetition detected in the original text' % i)
                continue
            initialInput = text['ids'][:secondPosition]
        else:
            initialInput = text['ids'][:N]

        gens, ngram, _, _, _ = generateWithIntervention(model, tokenizer, initialInput, targetNeurons, mode=mode)
        if ngram:
            log('(%d)REPL: ' % i + repr(tokenizer.decode(gens)))
            numRep += 1
        elif mode == 'deactivate':
            log('(%d)NORE: ' % i + repr(tokenizer.decode(gens)))

    print(numRep)
    logs.append(str(numRep))
    return logs, numRep

- Try deactivation for the repetition dataset
  - Deactivating the top 1500 repetition neurons -> (e.g.) 7 of 10 texts still have repetition (3 are solved)
  - Deactivating 1500 random neurons instead -> (e.g.) 8-10 of 10 texts still have repetition

In [None]:
# Deactivate the top 1500 ranked neurons while regenerating each text from the
# tokens before the second occurrence of its repeated n-gram.
# NOTE: repetitionDataset holds sizeX samples, so [:10] keeps at most 10 of them.
# conductExpIntervention returns a (logs, numRep) tuple.
texts = [{'ids': line['generatedIds']} for line in repetitionDataset][:10]
logsTop1500 = conductExpIntervention(texts, sortedNeurons, 'deactivate', 'top', K=1500)

In [None]:
# Baseline: deactivate 1500 neurons sampled at random from the ranked list instead.
# conductExpIntervention returns a (logs, numRep) tuple.
texts = [{'ids': line['generatedIds']} for line in repetitionDataset][:10]
logsRandom1500 = conductExpIntervention(texts, sortedNeurons, 'deactivate', 'random', K=1500)

- Experiment with activating repetition neurons
  - Activating the top repetition neurons -> repetition occurs
  - Activating random neurons -> repetition occurs roughly 20% of the time or more

In [None]:
# A single factual passage (about Abu Dhabi) used as the prompt source for the
# activation experiments below.
text = "Abu Dhabi is the capital city of the United Arab Emirates. The city is the seat of the Abu Dhabi Central Capital District, the capital city of the Emirate of Abu Dhabi, and the UAE's second-most populous city, after Dubai. The city is situated on a T-shaped island, extending into the Gulf from the central-western coast of the UAE."
texts = [{'ids':tokenizer.encode(text)}]

In [None]:
# Activate (+1) the top 1500 ranked neurons while greedily generating from the
# first N=10 tokens of the passage.
logsTop1500_act = conductExpIntervention(texts, sortedNeurons, 'activate', 'top', K=1500, N=10)

In [None]:
# Baseline: activate (+1) 1500 randomly sampled neurons with the same 10-token prompt.
logsRandom1500_act = conductExpIntervention(texts, sortedNeurons, 'activate', 'random', K=1500, N=10)