### MechInterp

n-dimensional activation space => n linearly independent basis vectors => n non-polysemantic neurons



In [None]:
# Scratch note: attention probabilities pushed through a head's value and output weights.
# NOTE(review): `probs`, `W_V` and `W_O` are not defined in this notebook. As written,
# `probs @ W_V` only lines up if probs' last dim equals W_V's first dim; presumably
# `probs @ resid @ W_V @ W_O` (with the residual stream) was intended — TODO confirm.
probs @ W_V @ W_O

In [1]:
# Prompts for the corrupted run; each one pairs positionally with an answer token below.
corrupted_prompts = [
    "When John and Mary went to the shops, Mary gave the bag to",
    "When Tom and James went to the park, James gave the ball to",
]

In [None]:
import torch

# Answer token ids, one per prompt, in prompt order (prompt 0 -> Mary / John,
# prompt 1 -> Tom / James). They index the vocabulary axis of the final-position
# logits, so they must be bare single-token ids: no BOS prefix. Each name also
# carries its leading space, since it follows "... to" mid-sentence
# (GPT-2-style BPE encodes mid-sentence words with the space attached — verify
# for the tokenizer in use). to_single_token raises if a name splits into
# several tokens.
correct_tokens = torch.tensor([model.to_single_token(name) for name in [" Mary", " Tom"]])
incorrect_tokens = torch.tensor([model.to_single_token(name) for name in [" John", " James"]])

In [None]:
# Tokenize the clean and corrupted prompt sets.
# NOTE(review): `clean_prompts` is never defined in this notebook, so a fresh-kernel
# Run All raises NameError here.
clean_tokens = model.to_tokens(clean_prompts)
corrupted_tokens = model.to_tokens(corrupted_prompts)

In [None]:
# Forward pass on both prompt sets; only the logits are kept.
# NOTE(review): the activation caches are discarded (`_`). If no cache is needed,
# a plain forward pass avoids storing every activation. If activation patching
# follows, the clean cache probably should be kept — TODO decide.
clean_logits, _ = model.run_with_cache(clean_tokens)
corrupted_logits, _ = model.run_with_cache(corrupted_tokens)

In [2]:
def compute_average_logit_difference(logits, correct_tokens, incorrect_tokens):
    """Mean over the batch of (correct - incorrect) logit at the final position.

    Args:
        logits: model output indexed as ``logits[:, -1, :]``, so at least 3-D;
            presumably (batch, pos, d_vocab) — confirm for the model in use.
        correct_tokens: one token id per prompt (list / 1-D tensor of length
            batch), or a single id shared by every prompt.
        incorrect_tokens: same format as ``correct_tokens``.

    Returns:
        0-d tensor: mean over prompts of ``logit[correct_i] - logit[incorrect_i]``.

    Raises:
        ValueError: if a token list is neither a single id nor one id per prompt.
    """
    import torch

    final_logits = logits[:, -1, :]
    batch_size = final_logits.shape[0]

    def _per_prompt_logits(token_ids):
        # Pick exactly one logit per prompt (prompt i -> token_ids[i]).
        # Plain `final_logits[:, ids]` would build a batch x batch cross product.
        ids = torch.as_tensor(token_ids, device=final_logits.device).reshape(-1).long()
        if ids.numel() == 1:
            ids = ids.expand(batch_size)
        if ids.numel() != batch_size:
            raise ValueError(
                f"expected 1 or {batch_size} token ids, got {ids.numel()}"
            )
        return final_logits.gather(-1, ids.unsqueeze(-1)).squeeze(-1)

    correct_logits = _per_prompt_logits(correct_tokens)
    incorrect_logits = _per_prompt_logits(incorrect_tokens)
    return (correct_logits - incorrect_logits).mean()

In [None]:
# Baselines for the patching metric: compute_ioi_metric maps the corrupted value
# to 0 and the clean value to 1.
clean_logit_difference = compute_average_logit_difference(clean_logits, correct_tokens, incorrect_tokens)
corrupted_logit_difference = compute_average_logit_difference(corrupted_logits, correct_tokens, incorrect_tokens)

In [None]:
def compute_ioi_metric(logits):
    """Normalised logit-difference metric for a (patched) run.

    Returns 0 when the logit difference of ``logits`` equals the corrupted
    baseline and 1 when it equals the clean baseline. Reads the notebook-level
    ``correct_tokens``, ``incorrect_tokens``, ``clean_logit_difference`` and
    ``corrupted_logit_difference``.
    """
    patched = compute_average_logit_difference(logits, correct_tokens, incorrect_tokens)
    recovered = patched - corrupted_logit_difference
    full_range = clean_logit_difference - corrupted_logit_difference
    return recovered / full_range

step 1: start with a diverse batch of data
step 2: record the activations of the target head
step 3: extract the attention pattern between the target query position and all other positions
step 4: average the attention pattern across the batch
step 5: plot

In [None]:
# Positional embedding matrix plus one head's query and key weights.
# NOTE(review): `layer_idx` / `head_idx` are only assigned in a later cell
# (layer_idx = 2, head_idx = 0), so this cell relies on hidden kernel state.
W_pos = model.W_pos
W_Q = model.W_Q[layer_idx, head_idx]
W_K = model.W_K[layer_idx, head_idx]

In [None]:
# Position-by-position attention scores from positional embeddings alone:
# scores = (W_pos W_Q)(W_pos W_K)^T.
# W_Q and W_K are sliced from model.W_Q / model.W_K identically, so they share a
# shape; W_K must be transposed for the chain to type-check (TransformerLens
# stores both as (d_model, d_head) — verify).
pos_by_pos_scores = W_pos @ W_Q @ W_K.T @ W_pos.T

In [3]:
import torch
from torch import nn
import torch.nn.functional as F

In [4]:
def mask_scores(scores, fill_value=-1e9):
    """Apply a causal mask to attention scores.

    The last two axes of ``scores`` are treated as (query, key). Query position i
    may attend only to key positions j <= i, so every entry above the main
    diagonal (j > i) is replaced by ``fill_value``. Leading axes are handled
    element-wise.

    Args:
        scores: tensor with at least two dimensions.
        fill_value: value written into masked (future) positions. It is large and
            negative so a following softmax gives them ~0 probability.

    Returns:
        A new tensor with the same shape as ``scores``; the input is not modified.
    """
    # tril keeps key <= query (lower triangle incl. diagonal); triu would instead
    # keep key >= query, i.e. let positions attend to the future.
    allowed = torch.tril(torch.ones_like(scores)).bool()
    return scores.masked_fill(~allowed, fill_value)

In [None]:
# Per-head dimension, used below to scale the attention scores.
d_head = model.cfg.d_head

In [None]:
# Scale the positional scores by sqrt(d_head), as in scaled dot-product
# attention, then apply the causal mask. The scores computed above are
# `pos_by_pos_scores`; no variable named `scores` exists in this notebook.
masked = mask_scores(pos_by_pos_scores / d_head ** 0.5)

In [None]:
# Softmax over the last (key) axis turns the masked scores into attention probabilities.
pos_by_pos_pattern = F.softmax(masked, dim=-1)

In [None]:
# Tokenize the input text.
# NOTE(review): `text` is not defined in this notebook.
tokens = model.to_tokens(text)

In [None]:
# Forward pass that keeps the activation cache; the logits are discarded.
_, cache = model.run_with_cache(tokens)

In [None]:
# Residual-stream pieces whose contribution to the target head's query is measured
# below: token embeddings, positional embeddings, and the previous layer's
# per-head outputs.
# NOTE(review): per-head "result" activations are only cached when the model is
# configured to store them (TransformerLens: model.set_use_attn_result(True)) — verify.
# NOTE(review): `prev_layer_idx` is not defined in this notebook.
embed = cache["embed"]
pos_embed = cache["pos_embed"]
head_outputs = cache["result", prev_layer_idx]

In [None]:
# Stack the residual-stream components along a new "component" axis (dim -2).
# embed and pos_embed each contribute one slice; head_outputs contributes one
# slice per head. Concatenating on the default dim 0 would mix batch entries with
# components, and fails outright because the ranks differ.
# Assumes embed/pos_embed are (batch, pos, d_model) and head_outputs is
# (batch, pos, n_heads, d_model), the TransformerLens cache layout — verify.
input_components = torch.cat([
    embed.unsqueeze(-2),
    pos_embed.unsqueeze(-2),
    head_outputs,
], dim=-2)

In [None]:
# Query weights of the target head (same expression as in the earlier W_pos/W_Q/W_K cell).
W_Q = model.W_Q[layer_idx, head_idx]

In [6]:
from einops import einsum

In [None]:
# Project each residual component through the query weights separately.
# The "..." keeps every leading axis (e.g. batch, pos, component), so the
# result holds one d_head query vector per component.
query_components = einsum(
    input_components,
    W_Q,
    "... d_model, d_model d_head -> ... d_head",
)

In [None]:
# Size of each component's contribution to the query: the squared L2 norm over
# d_head, giving one value per remaining index (e.g. per batch, pos, component).
# Summing over every axis would collapse the decomposition into a single scalar.
query_contributions = query_components.pow(2).sum(dim=-1)

In [None]:
# Tokenize the input text (same expression as the earlier tokenization cell).
# NOTE(review): `text` is not defined in this notebook.
tokens = model.to_tokens(text)

In [None]:
# Forward pass returning logits. No hooks are passed, so this is presumably
# equivalent to a plain model(tokens) call — verify.
logits = model.run_with_hooks(tokens)

In [None]:
# Log-probability the model assigns to each actual next token of the text.
# Logits at position i predict token i + 1, so:
#   - drop the prediction made at the last position (nothing follows it), and
#   - drop the first token as a target (nothing predicts it).
# Assumes logits are (batch, pos, d_vocab) and tokens are (batch, pos).
log_probs = F.log_softmax(logits, dim=-1)[:, :-1, :]
target_tokens = tokens[:, 1:]
# gather picks log_probs[b, i, target_tokens[b, i]] -> shape (batch, pos - 1).
predicted_log_probs = log_probs.gather(-1, target_tokens.unsqueeze(-1)).squeeze(-1)

In [None]:
# Forward pass that keeps the activation cache for the attention-pattern analysis below.
_, cache = model.run_with_cache(tokens)

In [None]:
# Target head for the attention-pattern analysis below: layer 2, head 0.
layer_idx, head_idx = 2, 0

In [None]:
# Attention pattern of head `head_idx` in layer `layer_idx`, restricted to the query
# positions of interest: index 1 selects the head, then the query axis is sliced.
# NOTE(review): assumes cache["pattern", l] is (batch, head, query, key), the
# TransformerLens layout — verify. `target_query_positions` is not defined in this notebook.
attn_pattern = cache["pattern", layer_idx][:, head_idx][:, target_query_positions]

In [None]:
# The cached "pattern" activation is already softmax-normalised attention, so it
# is averaged directly; re-applying softmax would distort it. Per the plan
# (step 4), average over the batch axis (dim 0), not over key positions.
average_attention_pattern = attn_pattern.mean(dim=0)

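step 1: extract the feature "this cell contains my piece" (from black's perspective) for cell F1
step 2: take the board history up to, but not including, the target move
step 3: feed it through the model
step 4: intervene on the residual stream using that feature
step 5: continue the forward pass (TODO: finish this step)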

In [None]:
input_embedding > weights > features/neurons > activations

In [None]:
step 1: apply the logit lens across the residual stream
step 2: decompose the residual stream and localize the responsible attention layer
step 3: decompose that attention layer into its individual heads

In [None]:
# Weights for a K-composition check. Head 1.4 (layer 1, head 4) provides the
# query/key weights; head 0.7 (layer 0, head 7) provides the value/output (OV)
# weights whose output feeds head 1.4's keys.
W_E = model.W_E
W_Q = model.W_Q[1, 4]
W_K = model.W_K[1, 4]
W_O = model.W_O[0, 7]
W_V = model.W_V[0, 7]

In [None]:
# Query side: token embeddings mapped directly through head 1.4's W_Q.
# Key side: token embeddings passed through head 0.7's OV circuit (W_V @ W_O),
# then through head 1.4's W_K.
Q = W_E @ W_Q
K = W_E @ W_V @ W_O @ W_K

In [7]:
from transformer_lens import FactoredMatrix

In [None]:
# Token-to-token QK circuit Q @ K.T, kept as a FactoredMatrix (stores the two
# factors instead of the product), so the vocab x vocab matrix need not be formed.
full_circuit = FactoredMatrix(Q, K.T )

In [None]:
step 1: initialize a global distributed group
step 2: initialize parallel groups
step 3: set device
step 4: set seed

In [8]:
from copy import deepcopy

In [None]:
class ModelStateHandler:
    """Hold a model plus a deep-copied snapshot of its state_dict.

    ``commit`` takes a new snapshot, ``restore`` loads the snapshot back into the
    currently held model, and ``set_value`` swaps the held model without touching
    the snapshot.
    """

    def __init__(self, model):
        self.value = model
        # Snapshot immediately so restore() is usable right after construction.
        self.commit()

    def set_value(self, value):
        """Replace the held model; the saved snapshot is left as-is."""
        self.value = value

    def commit(self):
        """Deep-copy the held model's current state_dict as the new snapshot."""
        current_state = self.value.state_dict()
        self._saved_state_dict = deepcopy(current_state)

    def restore(self):
        """Load the saved snapshot into the held model."""
        snapshot = self._saved_state_dict
        self.value.load_state_dict(snapshot)

In [9]:
import torch.distributed as dist

In [None]:
class TensorParallelGroupInitializer(ProcessGroupInitializer):
    """Create the tensor-parallel process groups and report this rank's membership.

    World ranks are split into ``world_size // tensor_parallel_size`` groups of
    ``tensor_parallel_size`` consecutive ranks. For example, world_size=8 and
    tensor_parallel_size=2 give [0, 1], [2, 3], [4, 5], [6, 7].
    """

    def init_process_group(self):
        """Create every tensor-parallel group and return the one containing ``self.rank``.

        ``dist.new_group`` is called for every group, in the same order, on every
        rank: torch.distributed requires all processes to enter each new_group
        call, even processes that will not be members.

        Side effect: sets ``self.num_tensor_parallel_groups``.

        Returns:
            dict with keys "local_rank", "local_world_size", "process_group" and
            "parallel_mode". The first three stay None if ``self.rank`` falls in
            no group (e.g. world_size not divisible by tensor_parallel_size).
        """
        tp_size = self.tensor_parallel_size
        self.num_tensor_parallel_groups = self.world_size // tp_size
        local_rank = None
        local_world_size = None
        process_group = None
        parallel_mode = ParallelMode.TENSOR

        for group_idx in range(self.num_tensor_parallel_groups):
            ranks = list(range(group_idx * tp_size, (group_idx + 1) * tp_size))
            group = dist.new_group(ranks=ranks)

            if self.rank in ranks:
                local_rank = ranks.index(self.rank)
                local_world_size = len(ranks)
                # Keep the group this rank belongs to, not whichever group was created last.
                process_group = group

        return {
            "local_rank": local_rank,
            "local_world_size": local_world_size,
            "process_group": process_group,
            "parallel_mode": parallel_mode
        }

In [None]:
step 1: scale the loss by the scaling factor
step 2: compute the gradients of the scaled loss
step 3: unscale the gradients by the same scaling factor
step 4: update the weights

In [None]:
commit, restore, set_value, sync

In [None]:
# Tear down each process group, synchronising all ranks before every destroy.
# NOTE(review): `groups` is not defined in this notebook. dist.barrier() runs on
# the default group, so if the default group appears before the end of `groups`,
# the next barrier would likely fail — destroy it last. TODO confirm ordering.
for group in groups:
    dist.barrier()
    dist.destroy_process_group(group=group)

In [None]:
step 1: initialize a global distributed group
step 2: initialize parallel groups
step 3: TODO

In [None]:
class VocabParallelEmbedding(nn.Module):
    """Embedding table whose vocabulary rows are sharded across distributed ranks.

    Rank r owns the contiguous id range [r * n, (r + 1) * n), where
    n = num_embeddings // world_size. Each rank embeds only the ids it owns and
    produces zero vectors for the rest. A sum all-reduce then combines the
    partial results into the full embedding on every rank.
    """

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        world_size = torch.distributed.get_world_size()
        if num_embeddings % world_size != 0:
            raise ValueError(
                f"num_embeddings ({num_embeddings}) must be divisible by world_size ({world_size})"
            )
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.num_embeddings_per_partition = num_embeddings // world_size

        self.vocab_start_idx, self.vocab_end_idx = self.get_vocab_range(
            self.num_embeddings_per_partition
        )

        # nn.Parameter wraps a tensor (not shape ints); init like nn.Embedding, i.e. N(0, 1).
        self.weight = nn.Parameter(
            torch.empty(self.num_embeddings_per_partition, embedding_dim)
        )
        nn.init.normal_(self.weight)

    def get_vocab_range(self, num_embeddings_per_partition):
        """Return ``(start_idx, end_idx)``, the half-open id range owned by this rank."""
        rank = torch.distributed.get_rank()
        start_idx = rank * num_embeddings_per_partition
        end_idx = start_idx + num_embeddings_per_partition
        return start_idx, end_idx

    def forward(self, tokens):
        """Embed global token ids; the caller's ``tokens`` tensor is not modified.

        Args:
            tokens: integer tensor of global vocabulary ids, any shape.

        Returns:
            Tensor of shape ``tokens.shape + (embedding_dim,)``, summed across ranks.
        """
        out_of_range = (tokens < self.vocab_start_idx) | (tokens >= self.vocab_end_idx)
        # Shift global ids to rows of the local shard. Ids owned by other ranks are
        # pointed at row 0 only so the lookup is valid; they are zeroed right after.
        local_tokens = (tokens - self.vocab_start_idx).masked_fill(out_of_range, 0)
        embeddings = F.embedding(local_tokens, self.weight)
        embeddings = embeddings.masked_fill(out_of_range.unsqueeze(-1), 0.0)
        # Each id is owned by exactly one rank, so summing over ranks gives its embedding.
        dist.all_reduce(embeddings)
        return embeddings

In [None]:
process-based, thread-based, vectorization, stream processing

In [None]:
global distributed group, tensor, pipeline, data

In [None]:
api server, scheduler, etcd, controller manager

In [None]:
step 1: initialize a global distributed group
step 2: initialize parallel groups
step 3: set device
step 4: set seed

In [None]:
step 1: gather the weights
step 2: set pre backward hook
step 3: do forward
step 4: set post backward hook
step 5: release the irrelevant weights

step 1: determine a list of global ranks in that group
step 2: if the process's global rank is in that group, initialize the parallel group
step 3: get its local rank
step 4: TODO

In [None]:
# Main diagonal over the last two axes: element [..., i, i].
# Presuming attn_weights is (..., query, key), this is each position's attention
# to itself — verify the layout. NOTE(review): `attn_weights` is not defined here.
weights = attn_weights.diagonal(dim1=-2, dim2=-1)

In [None]:
api server, scheduler, etcd, controller manager

In [None]:
def compute_forward_pass_using_data_parallelism(model, input, device_ids, output_id=None):
    """Run one data-parallel forward pass: replicate -> scatter -> parallel_apply -> gather.

    Args:
        model: module to run. For multi-device use it presumably must already
            live on ``device_ids[0]`` (as with nn.DataParallel) — verify.
        input: batch to split along its first dimension across ``device_ids``.
        device_ids: devices to run on. If empty/None, ``model(input)`` is returned.
        output_id: device that receives the gathered output; defaults to
            ``device_ids[0]``.

    Returns:
        The model output for the whole batch, gathered on ``output_id``.
    """
    if not device_ids:
        # Nothing to split across: a plain forward pass is equivalent.
        return model(input)
    if output_id is None:
        output_id = device_ids[0]

    replicas = nn.parallel.replicate(model, device_ids)
    inputs = nn.parallel.scatter(input, device_ids)
    # scatter can yield fewer chunks than devices when the batch is small.
    replicas = replicas[:len(inputs)]
    outputs = nn.parallel.parallel_apply(replicas, inputs)
    return nn.parallel.gather(outputs, output_id)

In [11]:
import torch.distributed.rpc as rpc

In [None]:
# Numeric id of the RPC worker registered under `worker_name`.
# NOTE(review): `worker_name` is not defined in this notebook.
rpc.get_worker_info(worker_name).id 

In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable affine map.

    Computes ``(x - mean) / sqrt(var + eps) * mults + adds``, where mean and the
    biased variance are taken over the last dimension of each input (the same
    formula as torch.nn.LayerNorm). Mean and variance come from the input itself
    on every call, so they are not stored as learnable parameters.
    """

    def __init__(self, features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.adds = nn.Parameter(torch.zeros(features))
        self.mults = nn.Parameter(torch.ones(features))

    def forward(self, x):
        """Normalise ``x`` over its last dim (size ``features``), then scale and shift."""
        mean = x.mean(dim=-1, keepdim=True)
        centered = x - mean
        var = centered.pow(2).mean(dim=-1, keepdim=True)  # biased variance
        normalized = centered / torch.sqrt(var + self.eps)
        return normalized * self.mults + self.adds

In [None]:
# First sub-diagonal over the last two axes: element [..., i + 1, i].
# Presuming attn_weights is (..., query, key), this is each position's attention
# to the immediately preceding token (previous-token score) — verify the layout.
weights = attn_weights.diagonal(dim1=-2, dim2=-1, offset=-1)

In [None]:
step 1: global dist group
step 2: initialize parallel groups
step 3: set device
step 4: set seed

In [None]:
training task: training distribution, function
base optimizer: sgd and architecture