In [1]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from sae import SaeConfig, SaeLayerRangeTrainer, TrainConfig
from sae.data import chunk_and_tokenize

%load_ext autoreload
%autoreload 2

In [2]:
# Checkpoint used for both the tokenizer and the causal LM below.
MODEL = "EleutherAI/pythia-160m"

# Keep only the first 16 rows of the train split so the run stays tiny.
# NOTE(review): trust_remote_code=True allows the dataset repository to run
# its own loading code -- confirm this source is trusted.
dataset = load_dataset(
    "togethercomputer/RedPajama-Data-1T-Sample", split="train", trust_remote_code=True
)
dataset = dataset.select(range(16))

tokenizer = AutoTokenizer.from_pretrained(MODEL)
tokenized = chunk_and_tokenize(dataset, tokenizer, max_seq_len=128)

# Base model goes entirely onto the CUDA device, loaded in bfloat16.
model_kwargs = dict(device_map={"": "cuda"}, torch_dtype=torch.bfloat16)
gpt = AutoModelForCausalLM.from_pretrained(MODEL, **model_kwargs)

# SAE settings are wrapped in the training config together with the batch size.
sae_cfg = SaeConfig(expansion_factor=16)
cfg = TrainConfig(sae_cfg, batch_size=32)

trainer = SaeLayerRangeTrainer(cfg, tokenized, gpt)
trainer.fit()

num_proc must be <= 16. Reducing num_proc to 16 for dataset of size 16.


Map (num_proc=16):   0%|          | 0/16 [00:00<?, ? examples/s]

Training on modules: [('gpt_neox.layers.0', 'gpt_neox.layers.1', 'gpt_neox.layers.2', 'gpt_neox.layers.3', 'gpt_neox.layers.4', 'gpt_neox.layers.5', 'gpt_neox.layers.6', 'gpt_neox.layers.7', 'gpt_neox.layers.8', 'gpt_neox.layers.9', 'gpt_neox.layers.10', 'gpt_neox.layers.11')]
Learning rate: 6.67e-05
bitsandbytes 8-bit Adam not available, using torch.optim.Adam
Run `pip install bitsandbytes` for less memory usage.
Number of SAE parameters: 2_718_065_664
Number of model parameters: 162_322_944


Training:   0%|          | 0/60 [00:00<?, ?it/s]

Saving checkpoint


In [2]:
def distributed_matmul(A, B, num_devices):
    """Compute ``torch.matmul(A, B)`` by sharding the contraction dimension.

    The last dimension of ``A`` and the first dimension of ``B`` are cut into
    ``num_devices`` equal shards. Each pair of shards is multiplied on its own
    and the partial products are summed. No tensor is actually moved between
    devices here: every shard stays wherever ``A`` and ``B`` already live, so
    this only simulates the data layout of a sharded matmul.

    Args:
        A: Tensor of shape ``(..., D)``.
        B: Tensor of shape ``(D, N)``.
        num_devices: Number of shards to split ``D`` into; must divide ``D``.

    Returns:
        Tensor of shape ``(..., N)`` equal to ``torch.matmul(A, B)`` up to
        floating-point summation order.

    Raises:
        AssertionError: If ``num_devices`` is not positive, if ``B.shape[0]``
            differs from ``A.shape[-1]``, or if the contraction dimension is
            not divisible by ``num_devices``.
    """
    # Read the contraction size from the input itself instead of deriving it
    # from a notebook-global `H`, so the function works in a fresh kernel.
    inner_dim = A.shape[-1]

    assert num_devices > 0, "num_devices must be positive"
    assert B.shape[0] == inner_dim, "B.shape[0] must equal A.shape[-1]"
    # Check the dimension that is actually split: an uneven split would yield
    # more than `num_devices` chunks and the trailing ones would be ignored.
    assert (
        inner_dim % num_devices == 0
    ), "A.shape[-1] must be divisible by num_devices"

    # Split A and B
    shard_size = inner_dim // num_devices
    A_split = torch.split(A, shard_size, dim=-1)
    B_split = torch.split(B, shard_size, dim=0)

    print("asplit", [x.shape for x in A_split])
    print("bsplit", [x.shape for x in B_split])

    # Accumulate partial products one at a time rather than stacking them all,
    # so at most a couple of full-size results are alive at once.
    final_result = None
    for A_part, B_part in zip(A_split, B_split):
        partial_result = torch.matmul(A_part, B_part)
        if final_result is None:
            final_result = partial_result
        else:
            final_result = final_result + partial_result

    return final_result


# Example usage and verification
# `batch_size` gets its own name so that `B` refers only to the right-hand
# matrix; S, H, L and K just determine the tensor widths (H * L and H * L * K).
batch_size, S, H, L, K = 32, 128, 64, 256, 4
num_devices = 8

# Create random matrices directly on the GPU, avoiding a large host-side
# allocation followed by a host-to-device copy.
A = torch.rand(batch_size, S, H * L, device="cuda")
B = torch.rand(H * L, H * L * K, device="cuda")

print(A.shape, B.shape)

# Perform distributed matrix multiplication
distributed_result = distributed_matmul(A, B, num_devices)

# Perform naive matrix multiplication
naive_result = torch.matmul(A, B)
print(naive_result.shape)

# Assert that the results are the same (within numerical precision); the two
# paths sum the products in a different order, so exact equality is not expected.
assert torch.allclose(
    distributed_result, naive_result, rtol=1e-5, atol=1e-8
), "Results do not match!"

print("Shape of the result:", distributed_result.shape)
print("Distributed and naive results match!")

torch.Size([32, 128, 16384]) torch.Size([16384, 65536])
asplit [torch.Size([32, 128, 2048]), torch.Size([32, 128, 2048]), torch.Size([32, 128, 2048]), torch.Size([32, 128, 2048]), torch.Size([32, 128, 2048]), torch.Size([32, 128, 2048]), torch.Size([32, 128, 2048]), torch.Size([32, 128, 2048])]
bsplit [torch.Size([2048, 65536]), torch.Size([2048, 65536]), torch.Size([2048, 65536]), torch.Size([2048, 65536]), torch.Size([2048, 65536]), torch.Size([2048, 65536]), torch.Size([2048, 65536]), torch.Size([2048, 65536])]
torch.Size([32, 128, 65536])
Shape of the result: torch.Size([32, 128, 65536])
Distributed and naive results match!
