In [2]:
import tools
from transformers import (
    BertTokenizer,
    BertModel,
    GPT2Tokenizer,
    GPT2Model,
    LlamaTokenizer,
    LlamaModel,
)
import polars as pl


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load BERT-base (uncased) with its tokenizer, then print the module tree to
# inspect the layer layout.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
print("Bert:{}".format(model))

Bert:BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=Fals

In [4]:
# Load GPT-2 (the 12-layer, 768-dim "gpt2" checkpoint) with its tokenizer,
# then print the module tree to inspect the layer layout.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2")
print("GPT2:" + str(model))

GPT2:GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)


In [5]:
# Load the Llama-2-7B base model and print its module tree.
# torch_dtype="auto" keeps the dtype stored in the checkpoint (presumably
# float16 for this repo -- confirm in its config.json) instead of upcasting to
# float32, roughly halving host memory. The printed architecture is unchanged
# because the module repr does not include dtype.
# NOTE: this Hub repo is gated; loading requires an authenticated Hugging Face
# token (e.g. `huggingface-cli login` or the HF_TOKEN environment variable).
model = LlamaModel.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype="auto")

print(model)

Loading checkpoint shards: 100%|██████████| 2/2 [00:30<00:00, 15.10s/it]

LlamaModel(
  (embed_tokens): Embedding(32000, 4096)
  (layers): ModuleList(
    (0-31): 32 x LlamaDecoderLayer(
      (self_attn): LlamaSdpaAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
)





# Distribution methods
1. Distributed by batch (lock-step blocks)
    1. No copy; runs in a single block.
    1. Needs n copies of the weights at every step (broadcast the weights).
1. Distributed by batch (distributed blocks; each PE computes a block)
    1. Loads the full weights and attention, generates a 1x4096 output, no communication.
1. Distributed by projection weight
    1. Distributed by rows: broadcast the input (1x4096), no communication, generates N outputs (1x4096/n each).
    1. Distributed by columns: each PE fetches 1/n of the input; needs a reduction (n partial 1x4096 results); generates 1 output (1x4096).
1. Distributed by attention (heads or tokens)
    1. Distributed by heads: accepts KQV distributed by columns and distributed input; generates distributed V (1x4096/n); no communication.
    1. Distributed by tokens: accepts KQV distributed by token and broadcast input; needs a reduction; generates 1 output.
1. Distributed by FFN
    1. Both layers distributed by rows: (broadcast input (1x4096), no communication, generates N outputs (1x4096/n)), then reduce and broadcast, then (no communication, generates N outputs (1x4096/n)).
    1. Both layers distributed by columns: (accepts distributed input, needs a reduction, generates 1 output), then scatter the output (1/n each), then (accepts distributed input, needs a reduction, generates 1 output).
    1. Row + col: (broadcast input (1x4096), no communication, generates N outputs (1x4096/n)), then (accepts distributed input, needs a reduction, generates 1 output).
    1. Col + row: (accepts distributed input, needs a reduction, generates 1 output), then (broadcast input (1x4096), no communication, generates N outputs (1x4096/n)).

# Motivation
1. Broadcast is cheap: a single shared bus is enough to achieve it.
2. Scatter or gather is expensive: if each PE needs different data, the bus needs N iterations to send it.
3. Reduction is cheap: it only needs a local bus to do a tree-like reduction, in O(log n) time and with no extra connections.
4. So avoid scatter; use only broadcast and reduction.
# Workflow
1. Accept broadcast input -> projection distributed by rows -> attention distributed by heads (or sub-heads) -> FFN col + row -> generate distributed input (for the next stage).
    - No scatter: 1 reduction and 1 broadcast, both in the FFN col + row.
2. Accept distributed input -> projection distributed by columns -> attention distributed by tokens -> FFN row + col -> generate full output.
    - No scatter: reduction + broadcast in the projection, reduction + broadcast in attention, not in the FFN; generates distributed output (**TODO**: reconcile with the "full output" stated in the step above).

# Scalability

**TODO**: not yet written.

# Situations
1. Multiple batched inputs:
    - Best to use 2.1 to distribute the tasks across multiple PEs,
    - or use 1.1 / 2.1 and stack the same step onto the same PE.
2. Single task, short sequence:
    - Distribute the weights or heads.
3. Single task, long sequence:
    - Distribute the tokens.
4. Combined:
    1. Multiple batches, short sequence.
    2. Multiple batches, long sequence.

# Challenges
1. The design needs to scale.
2. Broadcast and reduction need to be fast.
3. The system needs to switch between distribution modes (mode transformation).

# Solution
1. A good interconnect design and routing mechanism.