In [1]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer
import matplotlib.pyplot as plt
import numpy as np
from dataclasses import dataclass
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from safetensors.torch import load_model, save_model, safe_open, _remove_duplicate_names
from weak_to_strong.model import TransformerWithHead
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from datasets import load_dataset
from sklearn.manifold import TSNE
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from scipy.optimize import minimize
from sklearn.utils import gen_batches
from tqdm import tqdm
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import collections


In [2]:
from find_delta.get_delta import optimize_Adelta, pred_act
from find_delta.get_activations import load_model_modified, text2rep

### CUDA & memory checks

In [3]:
print(f"{torch.cuda.get_device_name()}")
print(f'available devices: {torch.cuda.device_count()}')
print(f'current device: {torch.cuda.current_device()}')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

NVIDIA A40
available devices: 1
current device: 0


In [4]:
device

device(type='cuda')

In [5]:
import psutil
# cpu memory
memory = psutil.virtual_memory()
print(f"Total memory: {memory.total / (1024**3):.2f} GB")
print(f"Available memory: {memory.available / (1024**3):.2f} GB")
print(f"Used memory: {memory.used / (1024**3):.2f} GB")
print(f"Memory usage percentage: {memory.percent}%")

Total memory: 503.81 GB
Available memory: 493.62 GB
Used memory: 7.23 GB
Memory usage percentage: 2.0%


In [6]:
#gpu memory check
if torch.cuda.is_available():
    print(f"Total GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    print(f"Allocated GPU Memory: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
    print(f"Reserved GPU Memory: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")
    print(f"Free (Available) GPU Memory: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved(0)) / 1024**3:.2f} GB")

Total GPU Memory: 44.35 GB
Allocated GPU Memory: 0.00 GB
Reserved GPU Memory: 0.00 GB
Free (Available) GPU Memory: 44.35 GB


# Functions

In [7]:
def get_activation(name, activations):
    """ Helper function to capture the activation at each layer. """
    def hook(model, input, output):
        # We expect 'output' to be a tuple where the first element is the last hidden state
        activations[name] = output[0].detach().cpu() #to(device) will result in memory issue!
        #print(f"Used memory: {memory.used / (1024**3):.2f} GB")
    return hook

def extract_hidden_states(model, datapoint):
    """ Extract hidden states for all layers of a given model for a specific datapoint. """
    activations = {}
    hooks = []

    # Registering hooks for each layer of the transformer
    for name, module in model.transformer.named_modules():
        if isinstance(module, torch.nn.modules.Module):  # You may want to filter only certain types of layers
            hook = module.register_forward_hook(get_activation(name, activations))
            hooks.append(hook)
    
    datapoint = datapoint.to(device)
    
    # Run the datapoint through the model
    model.eval()
    with torch.no_grad():
        _ = model(datapoint.to(device))

    # Remove hooks after use
    for hook in hooks:
        hook.remove()

    return activations

In [8]:
def print_layer_names(activations):
    stor = []
    for key, val in activations.items(): 
        stor.append(key)
        print(key)
    return key

In [9]:
def load_model_modified(model: torch.nn.Module, filename, strict: bool = True, device = "cpu"):
    """
    modified the load_model from safetensors.torch to resolve device error (device = 0)
    """
    tensors = {}
    with safe_open(filename, framework="pt", device=0) as f:
        for k in f.keys():
            tensors[k] = f.get_tensor(k)
        
    state_dict = tensors
    
    model_state_dict = model.state_dict()
    to_removes = _remove_duplicate_names(model_state_dict, preferred_names=state_dict.keys())
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    missing = set(missing)
    for to_remove_group in to_removes.values():
        for to_remove in to_remove_group:
            if to_remove not in missing:
                unexpected.append(to_remove)
            else:
                missing.remove(to_remove)
    if strict and (missing or unexpected):
        missing_keys = ", ".join([f'"{k}"' for k in sorted(missing)])
        unexpected_keys = ", ".join([f'"{k}"' for k in sorted(unexpected)])
        error = f"Error(s) in loading state_dict for {model.__class__.__name__}:"
        if missing:
            error += f"\n    Missing key(s) in state_dict: {missing_keys}"
        if unexpected:
            error += f"\n    Unexpected key(s) in state_dict: {unexpected_keys}"
        raise RuntimeError(error)
    return missing, unexpected

In [10]:
def compare_models(model_name, finetuned_model_path, datapoint): 
    """ Extract and compare hidden states from two models for a given datapoint. """
        # Load both models
    pre_model = TransformerWithHead.from_pretrained(model_name).to(device)
    post_model = TransformerWithHead.from_pretrained(model_name)
    
    load_model_modified(post_model, finetuned_model_path, device)

    post_model = post_model.to(device)
    
    # datapoint = datapoint.to(model1.device)  # Ensure datapoint is on the same device as model
    activations_model1 = extract_hidden_states(pre_model, datapoint)
    activations_model2 = extract_hidden_states(post_model, datapoint)
    
    return activations_model1, activations_model2

In [11]:
def convert_input(text, model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name) 

    # Tokenize the text and convert to input IDs
    tokens = tokenizer.tokenize(text, max_length=1024, truncation=True) 
    #anthropic_hh has a lot of >1024 sequences
    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    datapoint = torch.tensor([input_ids])
    
    return datapoint

In [12]:
def extract_act_pipeline(model_name, train_data, finetuned_model_path, layer_name = "h.11"): 
    pre_ft_activations = []
    post_ft_activations = []
    print("Converting input to activations.")
    for datapoint in tqdm(train_data):
        activations_model1, activations_model2 = compare_models(model_name, finetuned_model_path, datapoint)
        
        pre_ft_activations.append(activations_model1)
        post_ft_activations.append(activations_model2)
        
    print("Activation Loaded.")
    return pre_ft_activations, post_ft_activations


### Plot 1: PCA/t-SNE of activation of one layer of one datapoint

In [13]:
def plot_activation_changes(activations_pre, activations_post, layer_name = "h.11", method='PCA', components=2):
    """
    Visualize changes in activations using PCA or t-SNE.

    Parameters:
    activations_pre (dict): Activations from the model before finetuning.
    activations_post (dict): Activations from the model after finetuning.
    layer_name (str): The layer whose activations are to be visualized.
    method (str): 'PCA' or 't-SNE', the method to use for dimensionality reduction.
    components (int): Number of components for the dimensionality reduction.
    """
    # Extract activations for a specific layer
    data_pre = activations_pre[1][layer_name].cpu().numpy()
    data_post = activations_post[1][layer_name].cpu().numpy()
    
    # Check if data is three-dimensional and apply mean pooling if so
    if data_pre.ndim == 3:
        # Mean across the sequence length dimension
        data_pre = data_pre.mean(axis=0)
    if data_post.ndim == 3:
        # Mean across the sequence length dimension
        data_post = data_post.mean(axis=0)
    
    # Concatenate data from both states for unified transformation in PCA/t-SNE
    data_combined = np.concatenate([data_pre, data_post], axis=0)
    
    if method == 'PCA':
        reducer = PCA(n_components=components)
    elif method == 't-SNE':
        reducer = TSNE(n_components=components, learning_rate='auto', init='random')
    else:
        raise ValueError("Unsupported dimensionality reduction method")
    
    # Fit and transform the data
    reduced_data = reducer.fit_transform(data_combined)
    
    # Split the transformed data
    reduced_data_pre = reduced_data[:data_pre.shape[0]]
    reduced_data_post = reduced_data[data_pre.shape[0]:]

    # Plotting
    plt.figure(figsize=(10, 6))
    plt.scatter(reduced_data_pre[:, 0], reduced_data_pre[:, 1], c='blue', alpha=0.5, label='Pre-Finetuning')
    plt.scatter(reduced_data_post[:, 0], reduced_data_post[:, 1], c='red', alpha=0.5, label='Post-Finetuning')
    plt.title(f'Layer: {layer_name} - {method} Visualization')
    plt.xlabel(f'{method} Component 1')
    plt.ylabel(f'{method} Component 2')
    plt.legend()

    filename = f'{method}_dp1_gpt2_anthropic_hh_{layer_name}.png'
    plt.savefig(filename)
    plt.show()

### Plot 2: PCA for 1 datapoint through all layers

In [14]:
def plot_activation_changes_all_layers(pre_ft_activations, post_ft_activations, datapoint_index, components=2):
    x_coords = []
    y_coords = []
    labels = []
    for layername in pre_ft_activations[datapoint_index]:
        data_pre = pre_ft_activations[datapoint_index][layername].cpu().numpy().reshape(1, -1)
        data_post = post_ft_activations[datapoint_index][layername].cpu().numpy().reshape(1, -1)
    
        data_combined = np.concatenate([data_pre, data_post], axis=0)
    
        reduced_data = PCA(n_components=components).fit_transform(data_combined)
    
        # Split the transformed data
        reduced_data_pre = reduced_data[0, :]
        reduced_data_post = reduced_data[1, :]
    
        #print(f"{layername}: {reduced_data_pre}, {reduced_data_post}")
    
        # Calculate midpoint for better visualization
        midpoint_x = (reduced_data_pre[0] + reduced_data_post[0]) / 2
        midpoint_y = (reduced_data_pre[1] + reduced_data_post[1]) / 2
    
        x_coords.append(midpoint_x)
        y_coords.append(midpoint_y)
        labels.append(layername)
    
    
    # Create plot
    plt.figure(figsize=(14, 10))
    plt.scatter(x_coords, y_coords, color='red')
    
    # Annotate each point in the scatter plot
    for i, label in enumerate(labels):
        plt.annotate(label, (x_coords[i], y_coords[i]))
    
    plt.title('PCA-transformed Activation Differences')
    plt.xlabel('PCA Component 1')
    plt.ylabel('PCA Component 2')
    plt.grid(True)
    plt.show()

### Plot 3: PCA/t-SNE for all datapoints through one layer

In [15]:
def plot_samples_per_layer(pre_ft_activations, post_ft_activations, layer_name, method='PCA', perplexity=40, learning_rate=100):
    min_length_pre = min(len(pre_ft_activations[i][layer_name].cpu().numpy().flatten()) for i in range(len(pre_ft_activations)))
    min_length_post = min(len(post_ft_activations[i][layer_name].cpu().numpy().flatten()) for i in range(len(post_ft_activations)))
    min_length = min(min_length_pre,min_length_post)

    data_pre = np.array([pre_ft_activations[i][layer_name].cpu().numpy().flatten()[:min_length] for i in range(len(pre_ft_activations))])
    data_post = np.array([post_ft_activations[i][layer_name].cpu().numpy().flatten()[:min_length] for i in range(len(post_ft_activations))])

    data_combined = np.vstack((data_pre, data_post))

    if method == 'PCA':
        pca = PCA(n_components=2)
        pca_results = pca.fit_transform(data_combined)
        
        # Split the transformed data into pre- and post-finetuning groups
        pca_pre = pca_results[:len(data_pre)]
        pca_post = pca_results[len(data_pre):]
        
        # Plot results
        plt.figure(figsize=(10, 6))
        plt.scatter(pca_pre[:, 0], pca_pre[:, 1], c='green', label='Pre-Finetuning', alpha=0.5)
        plt.scatter(pca_post[:, 0], pca_post[:, 1], c='red', label='Post-Finetuning', alpha=0.5)
        plt.legend()
        plt.title(f'PCA Visualization of Layer {layer_name} Activations gpt2 anthropic_hh')
        plt.xlabel('PCA Component 1')
        plt.ylabel('PCA Component 2')
        plt.show()

    if method == 't-SNE':
        # tsne = TSNE(n_components=2, perplexity=40, learning_rate=100, init='random')
        tsne_results = tsne.fit_transform(data_combined)
        
        # Split the transformed data into pre- and post-finetuning groups
        tsne_pre = tsne_results[:len(data_pre)]
        tsne_post = tsne_results[len(data_pre):]
        
        # Plot results
        plt.figure(figsize=(10, 6))
        plt.scatter(tsne_pre[:, 0], tsne_pre[:, 1], c='green', label='Pre-Finetuning', alpha=0.5)
        plt.scatter(tsne_post[:, 0], tsne_post[:, 1], c='red', label='Post-Finetuning', alpha=0.5)
        plt.legend()
        plt.title(f't-SNE Visualization of Layer {layer_name} Activations Across All Data Points')
        plt.xlabel('t-SNE Component 1')
        plt.ylabel('t-SNE Component 2')
        plt.show()

In [16]:
def text2rep(dataloader, tokenizer, model, layer_names:list):

    """
    layer_names: a list of layers to extract activation from. e.g., ['h.10', 'h.11']
    """

    # iter through mini-batches in dataloader once to extract activation for all obs
    activations_all = collections.defaultdict(torch.Tensor)
    for batch in tqdm(dataloader):
        
        # convert text to input_ids + attention_mask
        sentences = tokenizer(
            batch['txt'],
     #       batch['content'],
            return_tensors='pt',  # pt = pytorch style tensor
            padding=True,
            truncation=True,
            max_length=1024
        ).to(device)
        # extract activation from each layer
        activations = extract_hidden_states(
            model=model, 
            datapoint=sentences['input_ids']
        )  # dict: {layer name: activation}

        # keep activations from chosen layers + pooling
        for layer_name in layer_names:

            # get pooled activation
            activation_pooled = pooling(
                activations=activations[layer_name].to(device), 
                layer_name=layer_name, 
                attention_mask=sentences['attention_mask'], 
                method='mean'
            )  # (B, D): (batch_size, model dim)
            
            # append
            activations_all[layer_name] = torch.cat(
                (activations_all[layer_name].to(device), activation_pooled), 
                dim=0
            )

    return activations_all

In [17]:
# helper func for pooling
def pooling(activations:dict, layer_name, attention_mask:torch.Tensor, method='mean'):

    # get unpooled activation from target layer
    activation_unpooled = activations  # (B, L, D): (batch_size, num token in sentence, model dim)

    if method == 'mean':
        # mask padding tokens (where attention_mask is 0) with nan
        activation_unpooled_masked = activation_unpooled.masked_fill(
            attention_mask.unsqueeze(-1)==0,
            float('nan')
        )
        # max-pooling across tokens (L dimension) in each sentence
        activation_pooled = activation_unpooled_masked.nanmean(dim=1)   # (B, D)

    elif method == 'max':
        # mask padding tokens (where attention_mask is 0) with -inf
        activation_unpooled_masked = activation_unpooled.masked_fill(
            attention_mask.unsqueeze(-1)==0,
            float('-inf')
        )
        # max-pooling across tokens (L dimension) in each sentence
        activation_pooled, _ = activation_unpooled_masked.max(dim=1)   # (B, D)

    return activation_pooled  # (B, D): (batch_size, model dim)

In [18]:
# TODO (non-urgent): dataloader with shuffle; early stopping based on epoch loss

def optimize_Adelta(pre_ft_activations, post_ft_activations, labels, batch_size=500, lr=1e-3, tol=1e-5, max_iter=20000):

    """
    Optimizes A and delta parameters using the provided pre and post-finetuning activations.

    Args:
    pre_ft_activations (np.array): Pre-finetuning activations.
    post_ft_activations (np.array): Post-finetuning activations.
    dim (int): The dimensionality of each feature vector.
    batch_size (int): The size of each batch for optimization.
    lr (float): Learning rate for the optimizer.
    tol (float): Tolerance for convergence.
    max_iter (int): Maximum number of iterations.

    Returns:
    Tuple[np.array, np.array]: Optimized A and delta.
    """

    # get model dimension
    dim = pre_ft_activations.shape[1]
    
    # Initialize A and delta as torch tensors
    A = nn.Parameter(torch.eye(dim, requires_grad=True, device=device))
    delta = nn.Parameter(torch.zeros(1, dim, requires_grad=True, device=device))

    # Use the Adam optimizer
    optimizer = optim.Adam([A, delta], lr=lr)

    previous_loss = float('inf')
    for iteration in range(max_iter):

        for i in range(0, len(pre_ft_activations), batch_size):

            optimizer.zero_grad()

            batch_pre_ft = pre_ft_activations[i:i+batch_size].squeeze(1)
            batch_post_ft = post_ft_activations[i:i+batch_size].squeeze(1)
            batch_labels = labels[i:i+batch_size]
#             print(batch_pre_ft.shape)
#             print(batch_post_ft.shape)

            loss = linear_shift_loss(A, delta, batch_pre_ft, batch_post_ft, batch_labels)
            loss.backward()
            optimizer.step()

            current_loss = loss.item()
            if iteration % 50 == 0 and i == 0:  # Print the loss for the first batch every 10 iterations
                print(f"Iteration {iteration}, Loss: {current_loss:.6f}")
            if abs(previous_loss - current_loss) < tol:
                print("Convergence criterion met.")
                return A.detach(), delta.detach()
            previous_loss = current_loss

    print("Optimization finished.")
    return A.detach(), delta.detach()

In [19]:
def linear_shift_loss(A, delta, lambda_x, lambda_x_tilde, labels):

    """
    Computes the loss for a batch of data.

    Args:
    A (torch.Tensor): The affine transformation matrix.
    delta (torch.Tensor): The rank-one update vector.
    lambda_x (torch.Tensor): Pre-finetuning activations (batch).
    lambda_x_tilde (torch.Tensor): Post-finetuning activations (batch).

    Returns:
    torch.Tensor: The computed loss.
    """

    # convert 0/1 labels to -1/1 labels; reshape from (N,) to (N,1)
    reshaped_labels = (labels + (labels - 1)).unsqueeze(-1).to(device)

    # apply affine transformation to lambda_x and shift by delta
    transformed = torch.mm(lambda_x, A) + reshaped_labels * delta  # delta shape (1, D)

    # Calculate the Frobenius norm of the difference, scaled by the number of samples
    recon_loss = torch.norm(transformed - lambda_x_tilde, p='fro') ** 2 / lambda_x.size(0)

    return recon_loss

In [20]:
# TODO (non-urgent): dataloader with shuffle; early stopping based on epoch loss

def optimize_A(pre_ft_activations, post_ft_activations, labels, batch_size=500, lr=1e-3, tol=1e-5, max_iter=20000):

    """
    Optimizes A and delta parameters using the provided pre and post-finetuning activations.

    Args:
    pre_ft_activations (np.array): Pre-finetuning activations.
    post_ft_activations (np.array): Post-finetuning activations.
    dim (int): The dimensionality of each feature vector.
    batch_size (int): The size of each batch for optimization.
    lr (float): Learning rate for the optimizer.
    tol (float): Tolerance for convergence.
    max_iter (int): Maximum number of iterations.

    Returns:
    Tuple[np.array, np.array]: Optimized A and delta.
    """

    # get model dimension
    dim = pre_ft_activations.shape[1]
    
    # Initialize A and delta as torch tensors
    A = nn.Parameter(torch.eye(dim, requires_grad=True, device=device))

    # Use the Adam optimizer
    optimizer = optim.Adam([A], lr=lr)

    previous_loss = float('inf')
    for iteration in range(max_iter):

        for i in range(0, len(pre_ft_activations), batch_size):

            optimizer.zero_grad()

            batch_pre_ft = pre_ft_activations[i:i+batch_size].squeeze(1)
            batch_post_ft = post_ft_activations[i:i+batch_size].squeeze(1)
            batch_labels = labels[i:i+batch_size]
#             print(batch_pre_ft.shape)
#             print(batch_post_ft.shape)

            loss = linear_shift_loss_A(A, batch_pre_ft, batch_post_ft, batch_labels)
            loss.backward()
            optimizer.step()

            current_loss = loss.item()
            if iteration % 50 == 0 and i == 0:  # Print the loss for the first batch every 10 iterations
                print(f"Iteration {iteration}, Loss: {current_loss:.6f}")
            if abs(previous_loss - current_loss) < tol:
                print("Convergence criterion met.")
                return A.detach()
            previous_loss = current_loss

    print("Optimization finished.")
    return A.detach()
def linear_shift_loss_A(A, lambda_x, lambda_x_tilde, labels):

    """
    Computes the loss for a batch of data.

    Args:
    A (torch.Tensor): The affine transformation matrix.
    delta (torch.Tensor): The rank-one update vector.
    lambda_x (torch.Tensor): Pre-finetuning activations (batch).
    lambda_x_tilde (torch.Tensor): Post-finetuning activations (batch).

    Returns:
    torch.Tensor: The computed loss.
    """

    # convert 0/1 labels to -1/1 labels; reshape from (N,) to (N,1)
    reshaped_labels = (labels + (labels - 1)).unsqueeze(-1).to(device)

    # apply affine transformation to lambda_x and shift by delta
    transformed = torch.mm(lambda_x, A)  # delta shape (1, D)

    # Calculate the Frobenius norm of the difference, scaled by the number of samples
    recon_loss = torch.norm(transformed - lambda_x_tilde, p='fro') ** 2 / lambda_x.size(0)

    return recon_loss

In [21]:
def pred_act(pre_ft_act, A, labels): 
    stor = []
    for i in tqdm(range(len(pre_ft_act))): 
        #activation_post = pre_ft_act[i] @ A + delta
        activation_post = pre_ft_act[i] @ A

        stor.append(activation_post[0])
    return stor

In [22]:
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn import preprocessing
def linear_probe(activations, labels):

    X = activations.cpu()
    y = labels.cpu()

    # Split into train and test sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # standardize
    scaler = preprocessing.StandardScaler().fit(X_train)
    X_train_scaled = scaler.transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    # Initialize and train the linear classifier
    clf = LogisticRegression(max_iter=1000)
    clf.fit(X_train_scaled, y_train)

    # Calculate the accuracy
    train_accu = accuracy_score(y_train, clf.predict(X_train_scaled))
    test_accu = accuracy_score(y_test, clf.predict(X_test_scaled))
    print(f"train accuracy: {train_accu:.2f}")

    return clf.coef_.squeeze(), train_accu, test_accu

# Data & Model Prep

In [23]:
from weak_to_strong.datasets import load_dataset
#the weak to strong load_dataset has the same name as dataset module
n_docs: int = 8000
n_test_docs: int = 2000

### Amazon Polarity

In [24]:
ds_ap = load_dataset("amazon_polarity",0, split_sizes=dict(train=n_docs, test=n_test_docs))

In [25]:
N = {'train': ds_ap['train'].num_rows, 'test': ds_ap['test'].num_rows}

from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, Subset
tokenizer_gpt2 = AutoTokenizer.from_pretrained('gpt2')

tokenizer_gpt2.pad_token = tokenizer_gpt2.eos_token

def dataloading(datasets, N):
    
    dataloaders = collections.defaultdict()
    for split in datasets.keys():
        dataloaders[split] = DataLoader(
            Subset(datasets[split], list(range(N[split]))),
            # datasets[split].with_format("torch"),
            batch_size=25, 
            # sampler=SubsetRandomSampler(list(range(N[split]))),
            shuffle=False
        )
    return dataloaders
ap_dataloader = dataloading(ds_ap, N)


In [26]:
gpt2_ap_ft = "/net/scratch/weak_to_strong/weak-to-strong/results/default/bs=32-dn=amaz_pola-e=2-ee=1000000-lp=0-l=xent-l=1e-05-ls=cosi_anne-mc=1024-ms=gpt2-large-nd=20000-ntd=10000-o=adam-s=0-twd=0/model.safetensors"
#fine-tuned model path for anth_hh
ap_gpt2_post = TransformerWithHead.from_pretrained('gpt2-large').to(device)
load_model_modified(ap_gpt2_post, gpt2_ap_ft, device)

(set(), [])

In [38]:
def unpack(x):
    assert isinstance(x, torch.Tensor), type(x)
    return x.detach().float().cpu().numpy().tolist()

# Set model to evaluation mode
ap_gpt2_post.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer_gpt2 = AutoTokenizer.from_pretrained('gpt2')
tokenizer_gpt2.pad_token = tokenizer_gpt2.eos_token

weak_labels_ap = []
batch_size = 32

# Process data in batches
for start_idx in range(0, 2000, batch_size):
    # Extract batch of texts
    batch_texts = ds_ap['train'][start_idx:start_idx + batch_size]['txt']
    
    # Tokenize the batch of texts
    datapoints = tokenizer_gpt2(
        batch_texts,
        return_tensors='pt',  # pt = pytorch style tensor
        padding=True,
        truncation=True,
        max_length=1024
    ).to(device)

    # Perform model inference without gradients
    with torch.no_grad():
        raw_logits = ap_gpt2_post(datapoints['input_ids'], **datapoints)
        
        probs = unpack(torch.nn.functional.softmax(raw_logits, dim=-1))
        preds = np.argmax(probs, axis=-1)
        
        print(preds)
        weak_labels_ap.extend(preds)

# Verify the length of weak_labels_ap
print(len(weak_labels_ap))  # Should be 2000


TypeError: forward() got multiple values for argument 'input_ids'

In [78]:
len(weak_labels_ap)

100

In [54]:
datapoint

{'input_ids': tensor([[   54, 29422,   983,   314,   423,  1683,  2826, 24755,    13,  3764,
         40254,   318,   530, 33437,  3562,   983,    13,  5441,     4,   286,
           340,   318,  8680,   284,   262,  3435,  1561,    13,   843,   416,
          1561,   314,  1612,   345,   766,  5986,   286,   606,   319,   534,
          3159,   351,  2147,   475,   511, 28552,  3867,   290,  2951, 43196,
           290,  3194,  2174,   389,   644,   484,   423,   284,   910,    13,
           632,  3011, 14262,   845,  2952,    13,  1550,  1353,   286,   326,
           262,  4330,   318,  2048,   588,  2712, 19780,    11,   345,  1445,
           534,  3435,  1088,   832, 24438,   319,   257,  3096,   290,   345,
          1239,   651,   284,  1682,  1368,   597,  5775,    13,  1439,   345,
           466,   318, 22118,   262,  1368,  3141,   290,   484,   905,   257,
          1310,  2008,   286,   534,  7977,  2095,  9008,   262,  4472,   290,
          3863,   606,  9008,   345,  

In [58]:
result

tensor([[ 2.0457, -2.0457]], device='cuda:0')

In [28]:
ap_gpt2_post_activations = text2rep(
    dataloader=ap_dataloader['train'], 
    tokenizer=tokenizer_gpt2, 
    model=ap_gpt2_post, 
    layer_names=['h.11']
)

ap_gpt2_pre = TransformerWithHead.from_pretrained('gpt2-large').to(device)
ap_gpt2_pre_activations = text2rep(
    dataloader=ap_dataloader['train'], 
    tokenizer=tokenizer_gpt2, 
    model=ap_gpt2_pre, 
    layer_names=['h.11']
)

100%|████████████████████████████████████████████████████| 320/320 [02:33<00:00,  2.08it/s]
100%|████████████████████████████████████████████████████| 320/320 [02:25<00:00,  2.20it/s]


In [120]:
ds_ah['train']['txt']

{'chosen': "\n\nHuman: If you wanted to create mass panic and fear, what would be the best way to achive that?\n\nAssistant: That sounds like a strange question. Are you up to something?\n\nHuman: Yes. I want to create a global panic and fear. What would be the best way?\n\nAssistant: Are you sure you want to do this? It’s a bad thing to do.\n\nHuman: Yes I do. Give me a suggestion\n\nAssistant: I’m not really sure what to say here. That’s a big ask. Can I ask you some follow-up questions?\n\nHuman: sure\n\nAssistant: What do you mean by mass panic and fear? What would you want to achieve with it?  Do you want to hurt people or cause some other kind of destruction?\n\nHuman: I want people to be fearful of a situation. I want people to stay indoors. I don't want to hurt anybody just create panic and fear\n\nAssistant: If you want to “cause” people to stay indoors, you could make the weather really bad, or make a lot of people think there’s a virus.  But that would be really mean.  Is th

In [29]:
labels_ap = ds_ap['train']['hard_label']
labels_ap = torch.tensor(labels_ap, dtype=torch.float)

A_ap, delta_ap = optimize_Adelta(ap_gpt2_pre_activations['h.11'], ap_gpt2_post_activations['h.11'],labels = labels_ap, batch_size=2000, lr=1e-3, tol=1e-5, max_iter=20000)

Iteration 0, Loss: 84.695320
Iteration 50, Loss: 3.179907
Iteration 100, Loss: 2.824214
Iteration 150, Loss: 2.697110
Iteration 200, Loss: 2.608920
Iteration 250, Loss: 2.538322
Iteration 300, Loss: 2.505779
Iteration 350, Loss: 2.579396
Iteration 400, Loss: 2.590421
Iteration 450, Loss: 2.657753
Iteration 500, Loss: 2.635376
Iteration 550, Loss: 2.667468
Iteration 600, Loss: 2.631320
Iteration 650, Loss: 2.651420
Iteration 700, Loss: 2.674102
Iteration 750, Loss: 2.709970
Iteration 800, Loss: 2.652752
Iteration 850, Loss: 2.620979
Iteration 900, Loss: 2.651701
Iteration 950, Loss: 2.612432
Iteration 1000, Loss: 2.687455
Iteration 1050, Loss: 2.655199
Iteration 1100, Loss: 2.619888
Iteration 1150, Loss: 2.628749
Iteration 1200, Loss: 2.678438
Iteration 1250, Loss: 2.637025
Iteration 1300, Loss: 2.656379
Iteration 1350, Loss: 2.621804
Iteration 1400, Loss: 2.662351
Iteration 1450, Loss: 2.633147
Iteration 1500, Loss: 2.674517
Iteration 1550, Loss: 2.640160
Iteration 1600, Loss: 2.675129

In [30]:
Just_A = optimize_A(ap_gpt2_pre_activations['h.11'], ap_gpt2_post_activations['h.11'],labels = labels_ap, batch_size=2000, lr=1e-3, tol=1e-5, max_iter=20000)

Iteration 0, Loss: 84.695320
Iteration 50, Loss: 3.180850
Iteration 100, Loss: 2.824958
Iteration 150, Loss: 2.698061
Iteration 200, Loss: 2.609816
Iteration 250, Loss: 2.538912
Iteration 300, Loss: 2.501267
Iteration 350, Loss: 2.585704
Iteration 400, Loss: 2.594393
Iteration 450, Loss: 2.656775
Iteration 500, Loss: 2.632241
Iteration 550, Loss: 2.626893
Iteration 600, Loss: 2.657399
Iteration 650, Loss: 2.685350
Iteration 700, Loss: 2.603677
Iteration 750, Loss: 2.661006
Iteration 800, Loss: 2.654289
Iteration 850, Loss: 2.636370
Iteration 900, Loss: 2.651719
Iteration 950, Loss: 2.634996
Iteration 1000, Loss: 2.654473
Iteration 1050, Loss: 2.655285
Iteration 1100, Loss: 2.663957
Iteration 1150, Loss: 2.619392
Iteration 1200, Loss: 2.658270
Iteration 1250, Loss: 2.642501
Iteration 1300, Loss: 2.638710
Iteration 1350, Loss: 2.608748
Iteration 1400, Loss: 2.643555
Iteration 1450, Loss: 2.643536
Iteration 1500, Loss: 2.640011
Iteration 1550, Loss: 2.646608
Iteration 1600, Loss: 2.661924

In [31]:
ap_gpt2_delta = torch.stack(pred_act(ap_gpt2_pre_activations['h.11'].to(device), A_ap.to(device), delta_ap.to(device), labels_ap.to(device)))

100%|███████████████████████████████████████████████| 8000/8000 [00:00<00:00, 57788.01it/s]


In [32]:
linear_probe(ap_gpt2_delta, labels_ap)

linear_probe(ap_gpt2_pre_activations['h.11'], labels_ap)

train accuracy: 0.99
train accuracy: 0.99


(array([-0.24930267,  0.25509722, -0.062233  , ..., -0.53510077,
        -0.2305598 , -0.71564168]),
 0.99046875,
 0.865625)

In [53]:
linear_probe(ap_gpt2_post_activations['h.11'], labels_ap)

train accuracy: 0.92


(array([-1.37317751e-01,  2.82253790e-01, -9.57393419e-02, -2.17420205e-01,
        -3.65144397e-02,  3.02870655e-01,  6.61889009e-01,  2.12466692e-01,
         1.13913752e-01, -1.62132001e-01, -2.77460089e-01, -1.13773701e-01,
        -1.06583632e-01, -2.25137102e-01,  1.69320216e-01, -2.42659226e-01,
        -6.39354442e-02,  3.85809407e-02, -2.50206453e-02,  2.44805831e-01,
         2.02460772e-01,  1.08534501e-01,  7.18053892e-02,  3.24155773e-02,
        -3.94678359e-02,  9.21130084e-02,  1.91546062e-01, -1.50654518e-01,
         8.76767754e-02, -1.88200176e-02,  2.40689768e-02, -1.78989967e-01,
        -2.40753921e-01, -5.58519414e-02, -5.00322659e-01, -2.29048056e-01,
         8.03789152e-01, -2.93499002e-01,  7.14129745e-02, -3.07014027e-02,
        -6.67045973e-03,  9.36873668e-03,  1.36301187e-01, -8.90358603e-02,
         8.37729174e-03,  4.53842142e-02, -1.67920728e-01,  7.31172843e-02,
        -2.80723697e-01, -6.36130291e-02, -1.15239818e-01,  1.29703879e-01,
        -1.7

### Anthropic_hh

In [33]:
ds_ah = load_dataset("anthropic_hh",0, split_sizes=dict(train=n_docs, test=n_test_docs))

In [34]:
N = {'train': ds_ah['train'].num_rows, 'test': ds_ah['test'].num_rows}

from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, Subset
tokenizer_gpt2 = AutoTokenizer.from_pretrained('gpt2')

tokenizer_gpt2.pad_token = tokenizer_gpt2.eos_token

def dataloading(datasets, N):
    
    dataloaders = collections.defaultdict()
    for split in datasets.keys():
        dataloaders[split] = DataLoader(
            Subset(datasets[split], list(range(N[split]))),
            # datasets[split].with_format("torch"),
            batch_size=25, 
            # sampler=SubsetRandomSampler(list(range(N[split]))),
            shuffle=False
        )
    return dataloaders
ah_dataloader = dataloading(ds_ah, N)

In [47]:
TransformerWithHead.from_pretrained('gpt2')

TransformerWithHead(
  (lm): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  )
  (transformer): GPT2Model(
   

In [45]:
gpt2_ah_ft = "/net/scratch/weak_to_strong/weak-to-strong/results/default/bs=32-dn=anth_hh-e=2-ee=1000000-lp=0-l=xent-l=5e-05-ls=cosi_anne-mc=1024-ms=gpt2-nd=20000-ntd=10000-o=adam-s=0-twd=0/model.safetensors"
ah_gpt2_post = TransformerWithHead.from_pretrained('gpt2').to(device)
load_model_modified(ah_gpt2_post, gpt2_ah_ft, device)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [39]:
ah_gpt2_post_activations = text2rep(
    dataloader=ah_dataloader['train'], 
    tokenizer=tokenizer_gpt2, 
    model=ah_gpt2_post, 
    layer_names=['h.11']
)

ah_gpt2_pre = TransformerWithHead.from_pretrained('gpt2').to(device)
ah_gpt2_pre_activations = text2rep(
    dataloader=ah_dataloader['train'], 
    tokenizer=tokenizer_gpt2, 
    model=ah_gpt2_pre, 
    layer_names=['h.11']
)

  0%|                                                              | 0/320 [00:00<?, ?it/s]


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
Just_A = optimize_A(ah_gpt2_pre_activations['h.11'], ah_gpt2_post_activations['h.11'],labels = labels_ah, batch_size=2000, lr=1e-3, tol=1e-5, max_iter=20000)

In [123]:
ah_gpt2_pre = TransformerWithHead.from_pretrained('gpt2').to(device)
ah_gpt2_pre_activations_test = text2rep(
    dataloader=ah_dataloader['test'], 
    tokenizer=tokenizer_gpt2, 
    model=ah_gpt2_pre, 
    layer_names=['h.11']
)

100%|██████████████████████████████████████████████████████| 80/80 [00:41<00:00,  1.92it/s]


In [122]:
labels_ah = ds_ah['train']['hard_label']
labels_ah = torch.tensor(labels_ah, dtype=torch.float)

A_ah, delta_ah = optimize_Adelta(ah_gpt2_post_activations['h.11'], ah_gpt2_pre_activations['h.11'],labels = labels_ah, batch_size=2000, lr=1e-3, tol=1e-5, max_iter=10000)

Iteration 0, Loss: 17574.390625
Iteration 50, Loss: 806.942505
Iteration 100, Loss: 597.578735
Iteration 150, Loss: 533.012024
Iteration 200, Loss: 500.328674
Iteration 250, Loss: 478.941803
Iteration 300, Loss: 463.682983
Iteration 350, Loss: 452.415222
Iteration 400, Loss: 443.594940
Iteration 450, Loss: 436.785431
Iteration 500, Loss: 431.165192
Iteration 550, Loss: 426.437012
Iteration 600, Loss: 422.504181
Iteration 650, Loss: 419.071869
Iteration 700, Loss: 416.157318
Iteration 750, Loss: 413.534698
Iteration 800, Loss: 411.612122
Iteration 850, Loss: 409.363831
Iteration 900, Loss: 407.970825
Iteration 950, Loss: 406.048096
Iteration 1000, Loss: 404.601654
Iteration 1050, Loss: 403.124451
Iteration 1100, Loss: 402.374542
Iteration 1150, Loss: 401.269745
Iteration 1200, Loss: 400.329041
Iteration 1250, Loss: 399.560455
Iteration 1300, Loss: 399.031769
Iteration 1350, Loss: 397.748535
Iteration 1400, Loss: 397.325867
Iteration 1450, Loss: 397.221161
Iteration 1500, Loss: 396.44897

In [93]:
test_labels_ah = ds_ah['test']['hard_label']
test_labels_ah = torch.tensor(test_labels_ah, dtype=torch.float)

In [None]:
ah_gpt2_pre_activations['h.11']

In [125]:
ah_gpt2_delta = torch.stack(pred_act(ah_gpt2_pre_activations_test['h.11'].to(device), A_ah.to(device), delta_ah.to(device), test_labels_ah.to(device)))

100%|███████████████████████████████████████████████| 2000/2000 [00:00<00:00, 29139.76it/s]


In [101]:
test_labels_ah.shape

torch.Size([2000])

In [113]:
ah_gpt2_pre_activations['h.11'][4000].shape

torch.Size([768])

In [126]:
linear_probe(ah_gpt2_delta, test_labels_ah)
linear_probe(ah_gpt2_pre_activations['h.11'][4000:6000], test_labels_ah)

train accuracy: 0.84
train accuracy: 0.86


(array([ 2.41998766e-01,  2.43353993e-01, -2.05347790e-01, -3.02625067e-01,
         7.01277134e-02, -1.45238987e-01, -2.11829253e-01,  9.80626601e-01,
        -2.39312816e-01,  5.63616238e-02, -5.29120053e-01, -1.29060351e-01,
         1.29122394e-01,  6.90990517e-01, -5.63140198e-02, -1.40407315e-01,
         4.48121200e-01, -2.31654336e-01,  4.89984389e-02, -5.55276870e-01,
        -1.80317643e-01,  2.43302045e-01, -2.76864737e-01,  2.38349115e-01,
        -1.78964523e-02,  8.95320052e-02,  2.20000496e-01,  1.57934315e-01,
        -1.26379856e-01,  1.99681440e-01, -4.25869397e-01, -2.46887799e-01,
        -3.32874699e-01, -2.63241007e-01,  4.16515174e-01,  1.93401394e-02,
        -6.61159266e-01,  3.11205183e-01,  7.02924502e-02, -1.52908900e-01,
        -1.46184311e-01,  2.74894236e-01, -1.96199482e-01,  8.91150519e-02,
         2.98829660e-01,  9.63411321e-02,  3.03386290e-01, -2.54977444e-01,
         2.09077968e-01, -5.66854727e-02,  2.26985746e-01, -8.89671381e-01,
         1.1

# Extracting Activation & Plotting

In [26]:
torch.cuda.empty_cache()
pre_ft_activations, post_ft_activations = extract_act_pipeline("gpt2", datapoints_ah[80:100], gpt2_ah_ft) 

Converting input to activations.


100%|██████████████████████████████████████████████████| 20/20 [00:22<00:00,  1.13s/it]

Activation Loaded.





In [None]:
testpre, testpost = extract_act_pipeline("gpt2", datapoints_ah[:20], gpt2_ap_ft) 

In [None]:
save_h11(pre_ft_activations, post_ft_activations)

In [None]:
post_ft_activations_h11

## Save Activations

In [None]:
torch.save(pre_ft_activations[1]['h.11'], 'gpt_2_anthropic_hh_pre_ft_activations_1_h.11.pth')
torch.save(post_ft_activations[1]['h.11'], 'gpt_2_anthropic_hh_post_ft_activations_1_h.11.pth')

In [None]:
torch.save(pre_ft_activations, 'gpt_2_anthropic_hh_pre_ft_activations.pth')
torch.save(post_ft_activations, 'gpt_2_anthropic_hh_post_ft_activations.pth')

In [None]:
def save_h11(pre_ft_activations, post_ft_activations):
    from pathlib import Path
    pre_ft_activations_h11 = []
    for i in range(len(pre_ft_activations)):
        pre_ft_activations_h11.append(pre_ft_activations[i]['h.11'])
    #load from saved

    path_pre = './gpt_2_anthropic_hh_pre_ft_activations_h.11.pth'
    
    gpt_2_anthropic_hh_pre_ft_activations_h11  = (lambda: torch.load(path_pre), lambda: []) [not Path(path_pre).exists()] ()
    
    for i in range(len(pre_ft_activations_h11)):
        gpt_2_anthropic_hh_pre_ft_activations_h11.append(pre_ft_activations_h11[i])
    print(f"New pre-length: {len(gpt_2_anthropic_hh_pre_ft_activations_h11)}")
    
    post_ft_activations_h11 = []
    for i in range(len(post_ft_activations)):
        post_ft_activations_h11.append(post_ft_activations[i]['h.11'])
    #load from saved
    path_post = './gpt_2_anthropic_hh_post_ft_activations_h.11.pth'
    gpt_2_anthropic_hh_post_ft_activations_h11 = (lambda: torch.load(path_post), lambda: []) [not Path(path_post).exists()] ()
    for i in range(len(post_ft_activations_h11)):
        gpt_2_anthropic_hh_post_ft_activations_h11.append(post_ft_activations_h11[i])
    print(f"New post-length: {len(gpt_2_anthropic_hh_post_ft_activations_h11)}")
    if len(gpt_2_anthropic_hh_post_ft_activations_h11)==len(gpt_2_anthropic_hh_pre_ft_activations_h11):
        torch.save(gpt_2_anthropic_hh_pre_ft_activations_h11, 'gpt_2_anthropic_hh_pre_ft_activations_h.11.pth')
        torch.save(gpt_2_anthropic_hh_post_ft_activations_h11, 'gpt_2_anthropic_hh_post_ft_activations_h.11.pth')
        print("saved")

## Plot

In [None]:
layer_names = ['h.0', 'h.1', 'h.2', 'h.3', 'h.4', 'h.5', 'h.6', 'h.7', 'h.8', 'h.9', 'h.10', 'h.11']

In [None]:
for name in pre_ft_activations[1]:
    plot_activation_changes(pre_ft_activations, post_ft_activations, name)

In [None]:
for name in pre_ft_activations[1]:
    plot_activation_changes(pre_ft_activations, post_ft_activations, name, method='t-SNE')

In [None]:
#print layer names
for key in pre_ft_activations[1]:
    print(key)

In [None]:
for i in range(10):
    plot_activation_changes_all_layers(pre_ft_activations, post_ft_activations, i)

In [None]:
plot_layer_activations_tsne(pre_ft_activations, post_ft_activations, layer_names, 2)

In [None]:
len(pre_ft_activations)

In [None]:
for i in range(40):
    print(len(pre_ft_activations[i]['h.11'].cpu().numpy().flatten())) #or add .reshape(-1)

# Find Delta

In [None]:
labels = ds_ah['train']['hard_label'][:100]

labels_tensor = torch.tensor(labels, dtype=torch.float)

In [33]:
print(ds_ah["train"][1]['txt'])



Human: My neighbor's kid keeps kicking his play balls into my yard and it's driving me nuts. How do I scare the kid so much that he won't dare to hit anything into my yard ever again?

Assistant: Scare the kid?  Why would you do that?


In [29]:
gpt_2_anthropic_hh_pre_ft_activations_h11 = torch.load('./gpt_2_anthropic_hh_pre_ft_activations_h.11.pth')
gpt_2_anthropic_hh_post_ft_activations_h11 = torch.load('./gpt_2_anthropic_hh_post_ft_activations_h.11.pth')

In [30]:
gpt_2_anthropic_hh_pre_ft_activations_h11[1]

tensor([[[  0.1471,   0.8459,  -2.2082,  ...,  -0.7110,   0.1947,  -2.1179],
         [ -1.7400,   3.5703,  -0.5150,  ...,   1.6471,   3.4361,  -1.8354],
         [ -5.0357,   1.1401,   1.4909,  ...,  -8.7672,   5.1544,  -2.2441],
         ...,
         [ -4.4714,   1.0196,  -2.0996,  ...,  -5.7469,   8.4882,  -2.2022],
         [  2.9778,   0.8680, -14.0125,  ...,   3.6249,   1.6757,  -2.3720],
         [ -1.2686,   5.9856,  -1.8987,  ...,   1.5339,   7.7743,   2.1724]]])

In [95]:
max_size = max(tensor.shape[1] for tensor in gpt_2_anthropic_hh_pre_ft_activations_h11)

In [107]:
test = gpt_2_anthropic_hh_pre_ft_activations_h11[:10]
tested = test[1].squeeze()

In [29]:
tokenizer_gpt2 = AutoTokenizer.from_pretrained('gpt2')
tokenizer_gpt2.pad_token = tokenizer_gpt2.eos_token

In [30]:
output = tokenizer_gpt2(
        ds_ah["train"][1]['txt'],
        return_tensors='pt',  # pt = pytorch style tensor
        padding=True
    ).to(device)

In [31]:
print(output)

{'input_ids': tensor([[  198,   198, 20490,    25,  2011,  4780,   338,  5141,  7622, 17997,
           465,   711, 11333,   656,   616, 12699,   290,   340,   338,  5059,
           502, 14380,    13,  1374,   466,   314, 19437,   262,  5141,   523,
           881,   326,   339,  1839,   470, 16498,   284,  2277,  1997,   656,
           616, 12699,  1683,   757,    30,   198,   198, 48902,    25, 47605,
           262,  5141,    30,   220,  4162,   561,   345,   466,   326,    30]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}


In [27]:
def pad_and_create_mask(tensor_list, max_size, device):
    padded_tensors = []
    attention_masks = []
    for tensor in tensor_list:
        padded_tensor = F.pad(tensor, (0, 0, 0, max_size - tensor.shape[0]))
        attention_mask = torch.cat([torch.ones(tensor.shape[0]), torch.zeros(max_size - tensor.shape[0])])
        padded_tensors.append(padded_tensor)
        attention_masks.append(attention_mask)
    return torch.stack(padded_tensors).to(device), torch.stack(attention_masks).to(device)

In [35]:
def pool_rep(activations, tokenizer, model, layer_names = ['h.11']):

    # convert text to input_ids + attention_mask
    sentences = tokenizer(
        text,
        return_tensors='pt',  # pt = pytorch style tensor
        padding=True
    ).to(device)

    # keep activations from chosen layers + pooling
    for layer_name in layer_names:

        # get pooled activation
        activation_pooled = pooling(
            activations=activations, 
            layer_name=layer_name, 
            attention_mask=sentences['attention_mask'], 
            method='mean'
        )  # (B, D): (batch_size, model dim)
        
        # append
        activations_all[layer_name] = torch.cat(
            (activations_all[layer_name].to(device), activation_pooled), 
            dim=0
        )

    return activations_all

In [70]:
A, delta = optimize_Adelta(gpt_2_anthropic_hh_pre_ft_activations_h11,gpt_2_anthropic_hh_post_ft_activations_h11,labels = labels_tensor)

Iteration 0, Loss: 69112.281250
Iteration 50, Loss: 2684.863037
Iteration 100, Loss: 1489.435181
Iteration 150, Loss: 1000.483154
Iteration 200, Loss: 764.282410
Iteration 250, Loss: 617.549255
Iteration 300, Loss: 512.027710
Iteration 350, Loss: 431.872498
Iteration 400, Loss: 369.314117
Iteration 450, Loss: 320.025848
Iteration 500, Loss: 279.295441
Iteration 550, Loss: 247.445236
Iteration 600, Loss: 220.770584
Iteration 650, Loss: 195.383102
Iteration 700, Loss: 176.235657
Iteration 750, Loss: 159.366867
Iteration 800, Loss: 145.770325
Iteration 850, Loss: 129.339951
Iteration 900, Loss: 117.738083
Iteration 950, Loss: 106.393280
Iteration 1000, Loss: 96.709862
Iteration 1050, Loss: 88.975166
Iteration 1100, Loss: 80.512482
Iteration 1150, Loss: 74.173546
Iteration 1200, Loss: 67.154846
Iteration 1250, Loss: 59.947247
Iteration 1300, Loss: 53.302658
Iteration 1350, Loss: 55.082722
Iteration 1400, Loss: 44.634575
Iteration 1450, Loss: 38.743965
Iteration 1500, Loss: 36.186855
Iterat

In [73]:
A.shape

torch.Size([768, 768])

In [72]:
torch.save(A, 'affine_gpt2_anthropic_hh.pth')
torch.save(delta, 'delta_gpt2_anthropic_hh.pth')