In [10]:
import torch
import random
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from transformers import GPT2Tokenizer, GPT2LMHeadModel

In [11]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


#### Downloading model

In [12]:
model_path = "data/saved_models/fine_tuned_model_with_regularization"
tokenizer_path = "gpt2"

In [13]:
tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_path)
model = GPT2LMHeadModel.from_pretrained(model_path, output_hidden_states=True).to(device)
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.return_dict_in_generate = False

The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


In [14]:
def prompt_model(input_text):
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
    pad_token_id = tokenizer.eos_token_id
    output = model.generate(input_ids, max_new_tokens=1, attention_mask=attention_mask, pad_token_id=pad_token_id)
    return tokenizer.decode(output[0])

In [15]:
print(prompt_model("51 + 213 ="))

The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


51 + 213 = 264


#### Fitting a Helix

In [16]:
def get_samples_from_segment(left, right, number):
    data = []
    for a in range(left, right + 1):
        for b in range(left, right + 1):
            c = a + b
            prompt = f"{a} + {b} = {c}"
            data.append((a, b, c, prompt))
    return random.sample(data, number)

def get_activation(prompt, layer_idx):
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    hidden_states = outputs.hidden_states
    return hidden_states[layer_idx][0, 0].cpu().numpy()

In [17]:
layer_idx = 1
target_variance = 0.995

In [18]:
H_all = []
a_vals = []
prompts_all = get_samples_from_segment(0, 249, 2000)
for a, b, c, prompt in tqdm(prompts_all):
    a_vals.append(a)
    activation = get_activation(prompt, layer_idx)
    H_all.append(activation)
H_all = np.array(H_all)

100%|██████████| 2000/2000 [00:12<00:00, 164.09it/s]


In [19]:
pca = PCA(n_components=None)
H_pca = pca.fit_transform(H_all)
cumulative_variance = np.cumsum(pca.explained_variance_ratio_)
pca_dim = np.argmax(cumulative_variance >= target_variance) + 1
H_pca = H_pca[:, :pca_dim]

print(f"PCA dim: {pca_dim}")
print("H_pca.shape:", H_pca.shape)
# in H_pca columns are orthogonal

PCA dim: 13
H_pca.shape: (2000, 13)


In [20]:
periods = [2, 5, 10, 100]
def build_fourier_features(a):
    features = [a]
    for period in periods:
        features.append(np.cos(2 * np.pi * a / period))
        features.append(np.sin(2 * np.pi * a / period))
    return np.array(features)

def build_polynomial_features(a):
    features = [1]
    for i in range(1, 9):
        features.append(a ** i)
    return np.array(features)


B = np.array([build_fourier_features(a) for a in a_vals])

In [21]:
reg = LinearRegression(fit_intercept=False)
reg.fit(B, H_pca)
C_pca = reg.coef_
r2 = reg.score(B, H_pca)
print(C_pca.shape)
print(f"R^2 в PCA-пространстве: {r2:.4f}")

(13, 9)
R^2 в PCA-пространстве: 0.3993


In [22]:
VT = pca.components_
VT_k = pca.components_[:pca_dim, :]

In [23]:
C = VT_k.T @ C_pca
H_helix = B @ C.T

#### Evaluating the Quality of the Helical Fit

In [24]:
def get_patched_logit(a, b, a_bad, patched_activation, layer_idx=1):

    prompt_bad = f"{a_bad} + {b} ="
    inputs = tokenizer(prompt_bad, return_tensors='pt').to(device)
    target_token = tokenizer.encode(f"51 + 12 = {str(a + b)}")[4]

    def hook(module, inputs):
        modified_inputs = (inputs[0].clone(),)
        modified_inputs[0][0, 0] = torch.tensor(patched_activation, device=device)
        return modified_inputs
    
    # https://github.com/huggingface/transformers/blob/v4.53.1/src/transformers/models/gpt2/modeling_gpt2.py#L1090
    handle = model.transformer.h[layer_idx].register_forward_pre_hook(hook)
    
    with torch.no_grad():
        outputs = model(**inputs)

    handle.remove()
    
    logits = F.softmax(outputs.logits[0, -1])
    logit = logits[target_token].item()
    return logit

In [25]:
def get_logit(a, b, a_bad):

    prompt_bad = f"{a_bad} + {b} ="
    inputs = tokenizer(prompt_bad, return_tensors='pt').to(device)
    target_token = tokenizer.encode(f"51 + 12 = {str(a + b)}")[4]
    # we need to do this to get right token
    
    with torch.no_grad():
        outputs = model(**inputs)
    
    logits = F.softmax(outputs.logits[0, -1])
    logit = logits[target_token].item()
    return logit

In [27]:
layer_idx = 1
right_patching_improving = 0
our_patching_improving = 0
n = 300

for i, (a, b, c, prompt) in tqdm(enumerate(prompts_all)):
    if i >= n:
        break
    a_vals.append(a)
    activation = get_activation(prompt, layer_idx)

    right_patching = get_patched_logit(a, b, (b + 100) % 250, activation, layer_idx)
    our_patching = get_patched_logit(a, b, (b + 100) % 250, H_helix[i], layer_idx)
    without_patching = get_logit(a, b, (b + 100) % 250)
    right_patching_improving += right_patching - without_patching
    our_patching_improving += our_patching - without_patching

print("Clean patching:", right_patching_improving / n)
print("Helix patching:", our_patching_improving / n)


  logits = F.softmax(outputs.logits[0, -1])
  logits = F.softmax(outputs.logits[0, -1])
300it [00:07, 42.74it/s]

Clean patching: 0.9341645966353733
Helix patching: 0.3341626773095443





- Clean patching: 0.94735
- При target_variance = 0.995:
    - Helix fit -> Helix patching: 0.3232
    - Polynomial fit -> Polynomial patching: 0.0414

- При target_variance = 0.95
    - Helix patching: 0.3259

- При target_variance = 0.75
    - Helix patching: 0.1396

- При target_variance = 0.5
    - Helix patching: 0.0233


- Используя gpt2 без SFT мы получим:
    - Clean patching: -3.586612067010719e-05
    - Helix patching: -0.0006099231596814067