In [1]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import BertTokenizer, BertForMaskedLM
import numpy as np

In [2]:
cache_dir = "/ssd_scratch/sweta.jena"
dataset_name="imdb"
class_names = ["negative", "positive"]
device='cuda'

In [3]:
dataset = load_dataset("imdb")
train_data = dataset["train"] 

In [4]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",cache_dir=cache_dir)
model = BertForMaskedLM.from_pretrained("bert-base-uncased", output_hidden_states=True,
                                        cache_dir=cache_dir).to(device)
model.eval()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [5]:
def get_cls_embedding(text):
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
        outputs = model(**inputs)
        return outputs.hidden_states[-1][0, 0].cpu()


In [6]:
pos_embs = []
neg_embs = []


In [7]:
for example in tqdm(train_data, total=len(train_data)):
    text = example["text"]
    label = example["label"]
    print(text)
    print(label)

    break

  0%|          | 0/25000 [00:00<?, ?it/s]

I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far between, eve




In [8]:
for example in tqdm(train_data, total=len(train_data)):
    text = example["text"]
    label = example["label"]

    emb = get_cls_embedding(text)

    if label == 1:
        pos_embs.append(emb)
    else:
        neg_embs.append(emb)

pos_embs = torch.stack(pos_embs)
neg_embs = torch.stack(neg_embs)

100%|██████████| 25000/25000 [09:01<00:00, 46.14it/s]


In [9]:
steering_vector_pos_neg = torch.tensor(pos_embs.mean(dim=0) - neg_embs.mean(dim=0), dtype=torch.float32).to(device)
torch.save(steering_vector_pos_neg, "steering_vector_pos_neg.pt")

  steering_vector_pos_neg = torch.tensor(pos_embs.mean(dim=0) - neg_embs.mean(dim=0), dtype=torch.float32).to(device)


In [10]:
steering_vector_neg_pos = torch.tensor(neg_embs.mean(dim=0) - pos_embs.mean(dim=0), dtype=torch.float32).to(device)
torch.save(steering_vector_neg_pos, "steering_vector_neg_pos.pt")

  steering_vector_neg_pos = torch.tensor(neg_embs.mean(dim=0) - pos_embs.mean(dim=0), dtype=torch.float32).to(device)


In [11]:
# steering_vector_pos_neg=steering_vector_pos_neg.to(device)
# steering_vector_neg_pos=steering_vector_neg_pos.to(device)

# find best layer

In [12]:
steering_vector_pos_neg=torch.load("steering_vector_pos_neg.pt").to(device)
steering_vector_neg_pos=torch.load("steering_vector_neg_pos.pt").to(device)

In [13]:
positive_words = ["great", "amazing", "excellent", "wonderful", "fantastic", "good"]
negative_words = ["terrible", "awful", "bad", "horrible", "dreadful", "poor"]

pos_ids = torch.tensor([tokenizer.convert_tokens_to_ids(w) for w in positive_words]).to(device)
neg_ids = torch.tensor([tokenizer.convert_tokens_to_ids(w) for w in negative_words]).to(device)

In [14]:
def steering_score_for_layer(text, steering_vec, layer, alpha=3.0):
    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    mask_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0,1]

    h = outputs.hidden_states[layer][0, mask_index] 
    h = h + alpha * steering_vec
    logits = model.cls(h)
    probs = logits.softmax(-1)

    pos_score = probs[pos_ids].sum().item()
    neg_score = probs[neg_ids].sum().item()

    return pos_score - neg_score


In [15]:
text = "The movie was [MASK]."
scores = []

for layer in range(model.config.num_hidden_layers+1):   # BERT has 13 hidden state outputs: 0..12 including embedding
    score = steering_score_for_layer(text, steering_vector_pos_neg, layer, alpha=3.0)
    scores.append(score)
    print(f"Layer {layer}: {score:.4f}")

best_layer = np.argmax(scores)
print("Best steering layer:", best_layer)


Layer 0: 0.0010
Layer 1: 0.0032
Layer 2: 0.0068
Layer 3: 0.0054
Layer 4: 0.0090
Layer 5: 0.0084
Layer 6: 0.0060
Layer 7: 0.0053
Layer 8: 0.0114
Layer 9: 0.0160
Layer 10: 0.0473
Layer 11: 0.1408
Layer 12: 0.4122
Best steering layer: 12


# Test

In [16]:
text = "The movie was very [MASK]."
inputs = tokenizer(text, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

outputs = model(**inputs, output_hidden_states=True)

mask_index = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero()[0,1]

layer = 12
hidden = outputs.hidden_states[layer][0, mask_index]

alpha = 2.0
hidden = hidden + alpha * steering_vector_pos_neg

logits = model.cls(hidden.unsqueeze(0))
probs = logits.softmax(dim=-1)
top = torch.topk(probs, 10)

print([tokenizer.decode(i) for i in top.indices[0]])

['good', 'successful', 'happy', 'popular', 'beautiful', 'interesting', 'entertaining', 'exciting', 'short', 'enjoyable']


In [17]:
text = "The movie was very [MASK]."
inputs = tokenizer(text, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

outputs = model(**inputs, output_hidden_states=True)

mask_index = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero()[0,1]

layer = 10
hidden = outputs.hidden_states[layer][0, mask_index]

alpha = 2.0
hidden = hidden + alpha * steering_vector_pos_neg

logits = model.cls(hidden.unsqueeze(0))
probs = logits.softmax(dim=-1)
top = torch.topk(probs, 10)

print([tokenizer.decode(i) for i in top.indices[0]])

['good', 'successful', 'popular', 'short', 'quiet', 'slow', 'close', 'fast', 'expensive', 'busy']


In [18]:
text = "The movie was very [MASK]."
inputs = tokenizer(text, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

outputs = model(**inputs, output_hidden_states=True)

mask_index = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero()[0,1]

layer = 12
hidden = outputs.hidden_states[layer][0, mask_index]

alpha = 2.0
hidden = hidden + alpha * steering_vector_neg_pos

logits = model.cls(hidden.unsqueeze(0))
probs = logits.softmax(dim=-1)
top = torch.topk(probs, 10)

print([tokenizer.decode(i) for i in top.indices[0]])

['good', 'bad', 'funny', 'popular', 'boring', 'interesting', 'successful', 'entertaining', 'short', 'pretty']


In [19]:
text = "The movie was very [MASK]."
inputs = tokenizer(text, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

outputs = model(**inputs, output_hidden_states=True)

mask_index = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero()[0,1]

layer = 10
hidden = outputs.hidden_states[layer][0, mask_index]

alpha = 2.0
hidden = hidden + alpha * steering_vector_neg_pos

logits = model.cls(hidden.unsqueeze(0))
probs = logits.softmax(dim=-1)
top = torch.topk(probs, 10)

print([tokenizer.decode(i) for i in top.indices[0]])

['popular', 'bad', 'short', 'good', 'expensive', 'quiet', 'slow', 'successful', 'poor', 'fast']


In [20]:
mlm_tokenizer = tokenizer
mlm_model = model

In [21]:
def generate_candidates_steered(text, word, steering_vec, alpha=2.0, layer=12, top_k=10):
    # Encode once, correctly
    encoding = mlm_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512
    )

    tokens = mlm_tokenizer.convert_ids_to_tokens(encoding["input_ids"][0])

    # Check word presence
    if word not in tokens:
        return []

    # Mask the target word
    idx = tokens.index(word)
    tokens[idx] = mlm_tokenizer.mask_token
    masked_text = mlm_tokenizer.convert_tokens_to_string(tokens)

    # Re-encode masked text
    inputs = mlm_tokenizer(masked_text, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = mlm_model(**inputs, output_hidden_states=True)

    # Mask position
    mask_index = (inputs["input_ids"] == mlm_tokenizer.mask_token_id).nonzero()[0,1]

    # Get hidden representation at chosen layer
    hidden = outputs.hidden_states[layer][0, mask_index]

    # ✅ Apply steering vector
    hidden = hidden + alpha * steering_vec.to(device)

    # Pass through MLM head
    logits = mlm_model.cls(hidden.unsqueeze(0))
    probs = torch.softmax(logits, dim=-1)

    # Select top-k tokens
    top_tokens = torch.topk(probs, top_k).indices[0].tolist()
    candidates = [mlm_tokenizer.decode([t]).strip() for t in top_tokens]

    return candidates


In [22]:
generate_candidates_steered(
    text="The movie was bad.",
    word="bad",
    steering_vec=steering_vector_pos_neg,
    alpha=2.5,       # steering strength
    layer=9,         # semantic layer
    top_k=10
)

['gone',
 'done',
 'silent',
 'finished',
 'underway',
 'quiet',
 'completed',
 'alive',
 'ready',
 'perfect']

In [23]:
generate_candidates_steered(
    text="The movie was bad.",
    word="bad",
    steering_vec=steering_vector_pos_neg,
    alpha=2.5,       # steering strength
    layer=11,         # semantic layer
    top_k=10
)

['perfect',
 'good',
 'finished',
 'wonderful',
 'beautiful',
 'rocked',
 'gone',
 'done',
 'silent',
 'slow']

In [24]:
generate_candidates_steered(
    text="The movie was bad.",
    word="bad",
    steering_vec=steering_vector_pos_neg,
    alpha=2.5,       # steering strength
    layer=12,         # semantic layer
    top_k=10
)

['good',
 'perfect',
 'wonderful',
 'beautiful',
 'great',
 'amazing',
 'over',
 'incredible',
 'excellent',
 'fantastic']

In [25]:
generate_candidates_steered(
    text="The movie was good.",
    word="good",
    steering_vec=steering_vector_neg_pos,
    alpha=2.5,       # steering strength
    layer=9,         # semantic layer
    top_k=10
)

['gone',
 'done',
 'underway',
 'quiet',
 'incomplete',
 'finished',
 'empty',
 'ruined',
 'bad',
 'loaded']

In [26]:
generate_candidates_steered(
    text="The movie was good.",
    word="good",
    steering_vec=steering_vector_neg_pos,
    alpha=2.5,       # steering strength
    layer=10,         # semantic layer
    top_k=10
)

['gone',
 'finished',
 'bad',
 'slow',
 'ruined',
 'expensive',
 'silent',
 'awful',
 'stopped',
 'short']

In [27]:
generate_candidates_steered(
    text="The movie was good.",
    word="good",
    steering_vec=steering_vector_neg_pos,
    alpha=2.5,       # steering strength
    layer=12,         # semantic layer
    top_k=10
)

['awful',
 'over',
 'bad',
 'horrible',
 'good',
 'terrible',
 'boring',
 'ridiculous',
 'great',
 'ruined']