In [70]:
from transformers.generation import LogitsProcessor,LogitsProcessorList
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
from scipy.stats import gumbel_l
from arsenal.maths.rvs import TruncatedDistribution
import copy
import torch
import tqdm
import scipy
from scipy.stats import gumbel_l, gumbel_r
from arsenal.maths.rvs import TruncatedDistribution
import transformers
from evaluate import load
import pickle
from collections import Counter, defaultdict
import ot
import sk2torch
from sklearn.linear_model import SGDClassifier
# import nn
import torch.nn as nn

In [258]:
model_name = "openai-community/gpt2-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 is a built-in transformers architecture, so trust_remote_code=True is not needed; dropping it
# avoids executing arbitrary Python shipped with the Hub repository.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
).eval()

# Batched tokenization with padding=True (as done in encode) needs a pad token; if the tokenizer
# does not define one, reuse EOS.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [230]:
@torch.no_grad()
def encode(model, tokenizer, text, batch_size, layer=-1, pooling="last", max_length=128, device=None):
    """Pool one hidden-state layer of ``model`` into one vector per input text.

    Texts are tokenized in batches of ``batch_size`` (padded, truncated to ``max_length`` tokens),
    run through ``model`` with ``output_hidden_states=True``, and ``outputs.hidden_states[layer]``
    is pooled over each sequence's real (non-padding) positions. Positions come from the attention
    mask, so both left- and right-padding tokenizers are handled.

    Args:
        model: model returning ``hidden_states`` when called with ``output_hidden_states=True``.
        tokenizer: tokenizer with a pad token set (required for ``padding=True``).
        text: sequence of strings (sliced in chunks of ``batch_size``).
        batch_size: number of texts per forward pass.
        layer: index into ``outputs.hidden_states``.
        pooling: ``"last"`` (last real token), ``"cls"`` (first real token) or ``"mean"``
            (mean over real tokens).
        max_length: truncation length in tokens.
        device: device for the tokenized inputs; defaults to ``model.device``, falling back to
            ``"cuda"`` when the model has no ``device`` attribute.

    Returns:
        ``np.ndarray`` stacking one pooled vector per text — presumably shape
        ``(len(text), hidden_dim)`` for a ``(batch, seq, hidden)`` hidden-state tensor.

    Raises:
        ValueError: for an unknown ``pooling`` or a text that tokenizes to zero tokens.
    """
    if pooling not in ("last", "cls", "mean"):
        raise ValueError(f"unknown pooling {pooling!r}; expected 'last', 'cls' or 'mean'")
    if device is None:
        device = getattr(model, "device", "cuda")

    encodings = []
    for i in tqdm.tqdm(range(0, len(text), batch_size)):
        batch = text[i:i + batch_size]
        padded_tokens = tokenizer(
            batch, padding=True, return_tensors="pt", max_length=max_length, truncation=True
        ).to(device)
        outputs = model(**padded_tokens, output_hidden_states=True)
        hiddens = outputs.hidden_states[layer]
        masks = padded_tokens["attention_mask"].bool()

        for h, mask in zip(hiddens, masks):
            # Indices of real tokens; independent of which side the padding is on.
            positions = mask.nonzero(as_tuple=True)[0]
            if positions.numel() == 0:
                raise ValueError("encode() received a text that tokenized to zero tokens")
            if pooling == "last":
                pooled = h[positions[-1]]
            elif pooling == "cls":
                pooled = h[positions[0]]
            else:  # "mean"
                pooled = h[positions].mean(dim=0)
            encodings.append(pooled.cpu().numpy())

    return np.array(encodings)

In [231]:
# NOTE(review): pickle.load can execute arbitrary code embedded in the file — only load trusted data.
with open("bios_data/bios_data/train.pickle", "rb") as f:
    data = pickle.load(f)
y = np.array([d["p"] for d in data])  # profession label, e.g. "professor"
z = np.array([1 if d["g"] == "m" else 0 for d in data])  # 1 when d["g"] == "m", otherwise 0
texts = [d["text"] for d in data]

# Restrict to the three professions studied here.
y_to_keep = ["professor", "physician", "attorney"]
idx_to_keep = np.flatnonzero(np.isin(y, y_to_keep)).tolist()
y = y[idx_to_keep]
z = z[idx_to_keep]
texts = [texts[i] for i in idx_to_keep]

# Balance the two z groups: take the first n examples of each (in file order). n is capped by the
# smaller group so both halves always have the same size.
MAX_PER_GROUP = 10000
num_m, num_f = int(z.sum()), int(len(z) - z.sum())
n = min(MAX_PER_GROUP, num_m, num_f)
idx_m = np.flatnonzero(z == 1).tolist()
idx_f = np.flatnonzero(z == 0).tolist()
idx = idx_m[:n] + idx_f[:n]
y = y[idx]
z = z[idx]
texts = [texts[i] for i in idx]
assert len(texts) == 2 * n and int(z.sum()) == n, "z groups are not balanced"


In [232]:
# Display the stack of transformer blocks (the intervention helpers index into it by layer).
model.get_submodule("transformer.h")

ModuleList(
  (0-35): 36 x GPT2Block(
    (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    (attn): GPT2Attention(
      (c_attn): Conv1D()
      (c_proj): Conv1D()
      (attn_dropout): Dropout(p=0.1, inplace=False)
      (resid_dropout): Dropout(p=0.1, inplace=False)
    )
    (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    (mlp): GPT2MLP(
      (c_fc): Conv1D()
      (c_proj): Conv1D()
      (act): NewGELUActivation()
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)

In [259]:
# Hidden-state index to probe; the same `layer` is later passed to the intervention helpers.
layer = 16
encodings = encode(model, tokenizer, texts, batch_size=32, layer=layer, pooling="mean")


100%|██████████| 625/625 [01:31<00:00,  6.85it/s]


In [260]:
import scipy

def _regularized_covariance(x, eps):
    """Sample covariance of the rows of ``x`` (shape ``(n, d)``) with ``eps`` added to the diagonal."""
    x = np.asarray(x)
    if x.ndim != 2 or x.shape[0] < 2:
        raise ValueError(f"expected a 2-D array with at least 2 rows, got shape {x.shape}")
    cov = np.atleast_2d(np.cov(x, rowvar=False))
    return cov + eps * np.eye(cov.shape[0])


def _psd_matrix_power(mat, power, floor=0.0):
    """``mat ** power`` for a symmetric positive semi-definite matrix, via its eigendecomposition.

    Eigenvalues are clipped below at ``floor`` so round-off cannot yield negative values (or zero
    values for a negative ``power``); the result is always real.
    """
    sym = (mat + mat.T) / 2.0  # remove round-off asymmetry before eigh, which reads one triangle
    vals, vecs = np.linalg.eigh(sym)
    vals = np.maximum(vals, floor)
    return (vecs * vals ** power) @ vecs.T


def get_optimal_gaussian_transport_func(source_x, target_x, eps=1e-8):
    """Linear part ``A`` of the optimal-transport map between two fitted Gaussians.

    With ``S`` and ``T`` the (regularized) covariances of the rows of ``source_x`` and ``target_x``:
    ``A = S^{-1/2} (S^{1/2} T S^{1/2})^{1/2} S^{-1/2}``. ``A`` is symmetric, so the map can be applied
    to row vectors as ``mean_target + (x - mean_source) @ A``.

    Args:
        source_x: ``(n_source, d)`` samples.
        target_x: ``(n_target, d)`` samples.
        eps: ridge added to the *diagonal* of both covariances (must be > 0 so ``S`` is invertible).

    Returns:
        Real, symmetric ``(d, d)`` float64 array.

    Raises:
        ValueError: on non-positive ``eps``, non-2-D inputs, fewer than 2 samples, or mismatched ``d``.
    """
    if eps <= 0:
        raise ValueError("eps must be positive so the source covariance is invertible")
    cov_source = _regularized_covariance(source_x, eps)
    cov_target = _regularized_covariance(target_x, eps)
    if cov_source.shape != cov_target.shape:
        raise ValueError(
            f"source and target dimensions differ: {cov_source.shape[0]} vs {cov_target.shape[0]}"
        )

    cov_source_sqrt = _psd_matrix_power(cov_source, 0.5)
    cov_source_sqrt_inv = _psd_matrix_power(cov_source, -0.5, floor=eps)
    middle = _psd_matrix_power(cov_source_sqrt @ cov_target @ cov_source_sqrt, 0.5)

    A = cov_source_sqrt_inv @ middle @ cov_source_sqrt_inv
    return (A + A.T) / 2.0

def matrix_squared_root(A):
    """Principal square root of the square matrix ``A`` (delegates to ``scipy.linalg.sqrtm``).

    The result may be complex when ``A`` has negative (or numerically negative) eigenvalues.
    """
    root = scipy.linalg.sqrtm(A)
    return root


def matrix_inv_squared_root(A):
    """Inverse of the principal square root of ``A``, i.e. ``A^{-1/2}``.

    Obtained by explicitly inverting ``matrix_squared_root(A)``; ``A`` must be non-singular.
    """
    root = matrix_squared_root(A)
    inverse_root = np.linalg.inv(root)
    return inverse_root

In [261]:
# Split pooled encodings by the z label (z == 0 -> transport source, z == 1 -> transport target).
x_train_source = encodings[z == 0]
x_train_target = encodings[z == 1]
A = get_optimal_gaussian_transport_func(x_train_source, x_train_target)


In [262]:
# Group means that centre / re-centre activations in the affine transport map.
mean_source = x_train_source.mean(axis=0)
mean_target = x_train_target.mean(axis=0)

# Map the z == 0 encodings onto the z == 1 Gaussian: x -> mean_target + (x - mean_source) @ A.
# NOTE(review): encodings_transformed is not used in the visible cells — presumably kept for later evaluation.
encodings_transformed = encodings.copy()
encodings_transformed[z == 0] = mean_target + (encodings_transformed[z == 0] - mean_source) @ A

# Linear probe for z on the *untransformed* encodings. random_state pins SGD's data shuffling so the
# probe (and any steering that depends on its predictions) is reproducible across runs.
clf = SGDClassifier(loss="log_loss", max_iter=10000, tol=1e-4, random_state=0)
clf.fit(encodings, z)

In [263]:
class InterventionModule(nn.Module):
    """Steer hidden-state rows predicted as class 0 with an affine transport map.

    Every row ``x`` of the flattened input (one row per position) that the wrapped classifier
    predicts as ``0`` is replaced by ``alpha * mean_1 + (x - alpha * mean_0) @ A``; all other rows
    pass through unchanged.

    Args:
        mean_0: mean of the class-0 (source) activations, presumably shape ``(d,)``.
        mean_1: mean of the class-1 (target) activations, presumably shape ``(d,)``.
        A: ``(d, d)`` transport matrix; any imaginary part (e.g. round-off from
            ``scipy.linalg.sqrtm``) is discarded.
        mlp: fitted scikit-learn classifier, converted with ``sk2torch.wrap``; it must map an
            ``(N, d)`` tensor to ``N`` predicted labels.
        alpha: scale applied to both means.

    ``mean_0``, ``mean_1`` and ``A`` are float32 non-persistent buffers, so ``module.to(...)``
    moves/casts them together with the classifier and they stay out of the host model's state_dict.
    """

    def __init__(self, mean_0, mean_1, A, mlp, alpha=1.0):
        super().__init__()
        self.register_buffer("mean_0", self._as_float_tensor(mean_0), persistent=False)
        self.register_buffer("mean_1", self._as_float_tensor(mean_1), persistent=False)
        self.register_buffer("A", self._as_float_tensor(A), persistent=False)
        # Keep the classifier in the buffers' dtype; forward() runs both in self.A.dtype.
        self.mlp = sk2torch.wrap(mlp).float()
        self.alpha = alpha
        # The probe is only used for inference.
        for p in self.mlp.parameters():
            p.requires_grad = False

    @staticmethod
    def _as_float_tensor(value):
        """Copy an array-like into a float32 tensor, dropping any imaginary component."""
        return torch.tensor(np.real(np.asarray(value)), dtype=torch.float32)

    def to_cuda(self, device):
        """Move buffers and classifier to ``device`` (any torch device, despite the name)."""
        self.to(device)

    def to_cpu(self):
        """Move buffers and classifier to the CPU."""
        self.to("cpu")

    def forward(self, hidden_states):
        """Return a steered copy of ``hidden_states`` (same shape and dtype; input not modified).

        The last dimension is the feature dimension and all leading dimensions are flattened, so
        any batch size / sequence length works. Classification and the affine map run in the
        buffers' dtype (float32 by default) and the result is cast back to the input dtype, so a
        float16 host model does not force the transport math into half precision.

        Raises:
            ValueError: if the classifier does not return exactly one prediction per row.
        """
        self.to_cuda(hidden_states.device)
        compute_dtype = self.A.dtype
        flat = hidden_states.reshape(-1, hidden_states.shape[-1])
        flat_compute = flat.to(compute_dtype)

        # Classify the 2-D (rows, features) view: sklearn-derived models are defined on 2-D input.
        preds = self.mlp(flat_compute).reshape(-1)
        if preds.shape[0] != flat.shape[0]:
            raise ValueError(
                f"classifier returned {preds.shape[0]} predictions for {flat.shape[0]} rows"
            )
        steer = preds == 0

        out = flat.clone()
        transported = self.alpha * self.mean_1 + (flat_compute[steer] - self.alpha * self.mean_0) @ self.A
        out[steer] = transported.to(out.dtype)
        return out.reshape(hidden_states.shape)

def insert_intervention(model, model_name, intervention, layer, after_layer_norm=False, replace_existing=False):
    """Wrap the MLP of GPT-2 block ``layer`` as ``Sequential(original_mlp, intervention)``.

    If the block's MLP is already a ``Sequential`` (i.e. an intervention is installed), it is left
    alone unless ``replace_existing`` is True, in which case the existing intervention is swapped
    for ``intervention`` around the same original MLP (no nested wrappers).

    NOTE(review): this steers the output of block ``layer``'s MLP sub-module. ``encode`` pools
    ``outputs.hidden_states[layer]``, which in Hugging Face GPT-2 is presumably the residual stream
    (index 0 being the embedding output) — a different tensor. If the probe / transport map were
    fit on ``hidden_states[layer]``, confirm this mismatch is intended.

    Args:
        model: GPT-2 style model exposing ``model.transformer.h``.
        model_name: checked for "gpt2" (case-insensitive).
        intervention: module applied to the MLP output.
        layer: index into ``model.transformer.h``.
        after_layer_norm: not implemented; must be False.
        replace_existing: replace an already-installed intervention instead of keeping it.

    Raises:
        NotImplementedError: for non-GPT-2 models or ``after_layer_norm=True``.
    """
    if "gpt2" not in model_name.lower():
        raise NotImplementedError("Only GPT2 is supported")
    if after_layer_norm:
        raise NotImplementedError("after_layer_norm=True is not implemented")

    block = model.transformer.h[layer]
    current = block.mlp
    already_wrapped = isinstance(current, torch.nn.Sequential)
    if already_wrapped and not replace_existing:
        return
    # Index 0 of an installed wrapper is the original MLP; [:-1] would produce a nested Sequential.
    base_mlp = current[0] if already_wrapped else current
    block.mlp = torch.nn.Sequential(base_mlp, intervention)


def remove_intervention(model, model_name, layer):
    """Restore the original MLP of GPT-2 block ``layer``, removing any Sequential wrapper(s).

    ``Sequential(...)[:-1]`` is itself a Sequential, so wrappers can end up nested; this unwraps
    index 0 repeatedly until a non-Sequential module is reached. A block whose MLP is not a
    Sequential is left untouched.

    Raises:
        NotImplementedError: for non-GPT-2 models.
    """
    if "gpt2" not in model_name.lower():
        raise NotImplementedError("Only GPT2 is supported")

    block = model.transformer.h[layer]
    mlp = block.mlp
    if not isinstance(mlp, torch.nn.Sequential):
        return
    while isinstance(mlp, torch.nn.Sequential):
        mlp = mlp[0]
    block.mlp = mlp

In [274]:
intervention_module = InterventionModule(
    mean_0=mean_source,
    mean_1=mean_target,
    A=A,
    mlp=clf,
    alpha=5.0,
)
# Drop any previously installed hook first so re-running this cell never stacks interventions.
remove_intervention(model, model_name, layer)
insert_intervention(model, model_name, intervention_module, layer, replace_existing=False)

In [275]:
from transformers import pipeline

# Fix the RNG so generations are reproducible: depending on the model's generation settings the
# text-generation pipeline may sample.
transformers.set_seed(0)
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
prompt = "Janet Kohn is a professor at NYU, where"
# max_length counts the prompt tokens as well as the generated ones.
generated = generator(prompt, max_length=75, num_return_sequences=1)
generated

[{'generated_text': 'Janet Kohn is a professor at NYU, where he teaches about human rights and international law. "It may be a little unfair to pick on those who do that work. But the other people in the system that say, \'He must have thought of this,\' well that\'s how we get to pick our own people."\n\nAnd on the "deep state"'}]

In [266]:
# NOTE(review): in the recorded execution order this cell ran as In [266], before the generation cell
# (In [275]), so its output came from earlier kernel state — confirm `generated` is defined by a
# preceding cell under Restart & Run All.
generated

[{'generated_text': 'She is a'}]