In [11]:
from generation import Llama
import torch
import gc
import os
os.environ['RANK'] = '0'  # Example rank, adjust accordingly
os.environ['WORLD_SIZE'] = '1'  # Example world size, adjust accordingly
os.environ['MASTER_ADDR'] = 'localhost'  # Example master address
os.environ['MASTER_PORT'] = '12355'  # Example master port


In [2]:

llama = Llama.build(ckpt_dir= "../llama3/Meta-Llama-3-8B-Instruct/", tokenizer_path="../llama3/Meta-Llama-3-8B-Instruct/tokenizer.model", max_seq_len= 512, max_batch_size = 4, model_parallel_size=1)

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1


  _C._set_default_tensor_type(t)


Loaded in 10.14 seconds


In [12]:
cpu = torch.device("cpu")
cuda = torch.device("cuda")

In [4]:
llama.model.to(cuda)


Transformer(
  (tok_embeddings): VocabParallelEmbedding()
  (layers): ModuleList(
    (0-31): 32 x TransformerBlock(
      (attention): Attention(
        (wq): ColumnParallelLinear()
        (wk): ColumnParallelLinear()
        (wv): ColumnParallelLinear()
        (wo): RowParallelLinear()
      )
      (feed_forward): FeedForward(
        (w1): ColumnParallelLinear()
        (w2): RowParallelLinear()
        (w3): ColumnParallelLinear()
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): ColumnParallelLinear()
)

In [5]:
model = llama.model

In [6]:
tokenizer = llama.tokenizer

In [7]:



prompt_template = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful AI assistant for travel tips and recommendations<|eot_id|>
<|start_header_id|>user<|end_header_id|>

{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

def get_prompt(prompt: str):
    return torch.tensor(tokenizer.encode(prompt_template.format(user_prompt = prompt), eos= False, bos = False), dtype= torch.long).unsqueeze(0)




prompt_tokens = get_prompt("Write a poem in 1000 words")


In [8]:
print(tokenizer.decode(prompt_tokens[0].tolist()))


<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful AI assistant for travel tips and recommendations<|eot_id|>
<|start_header_id|>user<|end_header_id|>

Write a poem in 1000 words<|eot_id|><|start_header_id|>assistant<|end_header_id|>




In [9]:
out_tokens = llama.generate(prompt_tokens.tolist(), max_gen_len= 256, temperature = 0)

  0%|          | 0/256 [00:00<?, ?it/s]

0 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)


  1%|          | 3/256 [00:00<00:55,  4.55it/s]

1 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
2 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
3 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

  2%|▏         | 6/256 [00:00<00:27,  9.00it/s]

9 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
10 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
11 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLi

  5%|▍         | 12/256 [00:01<00:15, 15.60it/s]

5 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
6 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
7 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

  6%|▌         | 15/256 [00:01<00:13, 17.88it/s]

0 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
1 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
2 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

  8%|▊         | 21/256 [00:01<00:11, 20.94it/s]

0 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
1 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
2 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

 11%|█         | 27/256 [00:01<00:10, 22.57it/s]

0 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
1 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
2 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

 12%|█▏        | 30/256 [00:01<00:09, 22.71it/s]

30 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
31 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
0 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLi

 14%|█▍        | 36/256 [00:02<00:09, 23.02it/s]

25 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
26 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
27 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 16%|█▋        | 42/256 [00:02<00:09, 23.37it/s]

13 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
14 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
15 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 18%|█▊        | 45/256 [00:02<00:09, 23.35it/s]

6 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
7 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
8 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

 20%|█▉        | 51/256 [00:02<00:08, 23.32it/s]

0 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
1 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
2 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

 21%|██        | 54/256 [00:02<00:08, 23.26it/s]

23 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
24 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
25 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 23%|██▎       | 60/256 [00:03<00:08, 23.48it/s]

17 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
18 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
19 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 26%|██▌       | 66/256 [00:03<00:09, 20.28it/s]

10 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
11 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
12 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 27%|██▋       | 69/256 [00:03<00:08, 21.09it/s]

2 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
3 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
4 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

 29%|██▉       | 75/256 [00:03<00:08, 22.40it/s]

31 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
0 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
1 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLin

 30%|███       | 78/256 [00:04<00:07, 22.57it/s]

24 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
25 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
26 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 33%|███▎      | 84/256 [00:04<00:07, 23.11it/s]

18 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
19 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
20 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 35%|███▌      | 90/256 [00:04<00:07, 23.41it/s]

10 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
11 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
12 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 36%|███▋      | 93/256 [00:04<00:06, 23.34it/s]

3 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
4 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
5 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

 39%|███▊      | 99/256 [00:04<00:06, 23.57it/s]

31 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
0 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
1 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLin

 40%|███▉      | 102/256 [00:05<00:06, 23.38it/s]

24 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
25 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
26 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 42%|████▏     | 108/256 [00:05<00:06, 23.45it/s]

16 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
17 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
18 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 45%|████▍     | 114/256 [00:05<00:06, 23.56it/s]

7 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
8 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
9 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

 46%|████▌     | 117/256 [00:05<00:05, 23.52it/s]

0 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
1 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
2 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

 48%|████▊     | 123/256 [00:05<00:05, 23.62it/s]

29 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
30 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
31 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 49%|████▉     | 126/256 [00:06<00:05, 23.42it/s]

20 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
21 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
22 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 52%|█████▏    | 132/256 [00:06<00:05, 23.41it/s]

12 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
13 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
14 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 53%|█████▎    | 135/256 [00:06<00:05, 23.30it/s]

3 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
4 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
5 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

 55%|█████▌    | 141/256 [00:06<00:04, 23.22it/s]

0 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
1 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
2 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

 57%|█████▋    | 147/256 [00:07<00:04, 23.38it/s]

22 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
23 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
24 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 59%|█████▊    | 150/256 [00:07<00:04, 23.31it/s]

14 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
15 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
16 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 61%|██████    | 156/256 [00:07<00:04, 23.49it/s]

8 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
9 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
10 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLin

 62%|██████▏   | 159/256 [00:07<00:04, 23.42it/s]

0 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
1 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
2 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

 64%|██████▍   | 165/256 [00:07<00:03, 23.27it/s]

27 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
28 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
29 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 67%|██████▋   | 171/256 [00:08<00:03, 23.48it/s]

16 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
17 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
18 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 68%|██████▊   | 174/256 [00:08<00:03, 23.27it/s]

8 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
9 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
10 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLin

 70%|███████   | 180/256 [00:08<00:03, 23.50it/s]

0 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
1 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
2 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

 71%|███████▏  | 183/256 [00:08<00:03, 23.32it/s]

26 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
27 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
28 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 74%|███████▍  | 189/256 [00:08<00:02, 23.30it/s]

17 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
18 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
19 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 75%|███████▌  | 192/256 [00:08<00:02, 23.20it/s]

8 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
9 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
10 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLin

 77%|███████▋  | 198/256 [00:09<00:02, 23.41it/s]

0 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
1 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
2 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

 80%|███████▉  | 204/256 [00:09<00:02, 23.56it/s]

28 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
29 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
30 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 81%|████████  | 207/256 [00:09<00:02, 23.34it/s]

20 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
21 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
22 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 83%|████████▎ | 213/256 [00:09<00:01, 23.37it/s]

10 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
11 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
12 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 84%|████████▍ | 216/256 [00:09<00:01, 23.14it/s]

1 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
2 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
3 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

 87%|████████▋ | 222/256 [00:10<00:01, 23.26it/s]

28 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
29 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
30 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 89%|████████▉ | 228/256 [00:10<00:01, 23.44it/s]

17 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
18 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
19 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 90%|█████████ | 231/256 [00:10<00:01, 23.09it/s]

9 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
10 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
11 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLi

 93%|█████████▎| 237/256 [00:10<00:00, 23.13it/s]

31 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
0 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
1 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLin

 94%|█████████▍| 240/256 [00:11<00:00, 22.64it/s]

20 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
21 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
22 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL

 96%|█████████▌| 246/256 [00:11<00:00, 23.03it/s]

5 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
6 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
7 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLine

 97%|█████████▋| 249/256 [00:11<00:00, 22.87it/s]

30 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
31 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
0 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLi

100%|██████████| 256/256 [00:11<00:00, 21.83it/s]

22 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
23 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelLinear()
    (w3): ColumnParallelLinear()
  )
  (attention_norm): RMSNorm()
  (ffn_norm): RMSNorm()
)
24 TransformerBlock(
  (attention): Attention(
    (wq): ColumnParallelLinear()
    (wk): ColumnParallelLinear()
    (wv): ColumnParallelLinear()
    (wo): RowParallelLinear()
  )
  (feed_forward): FeedForward(
    (w1): ColumnParallelLinear()
    (w2): RowParallelL




In [11]:
print(tokenizer.decode(out_tokens[0][0]))

A world of wonder, a world of might,
Where cultures blend and stories take flight,
A tapestry rich, with threads of gold,
A journey awaits, young and old.

In cities bustling, with sounds and sights,
Where ancient ruins whisper secrets at night,
In mountains towering, where eagles soar,
And oceans vast, where dolphins play once more.

In deserts hot, where camels roam free,
And forests dark, where wolves roam wild and carefree,
In villages quaint, where laughter echoes loud,
And cities modern, where innovation avows.

A world of wonder, a world of might,
Where dreams are made, and stories take flight,
A journey awaits, young and old,
To explore, to discover, to unfold.

In every step, a new tale unfolds,
Of people, places, and stories untold,
Of struggles and triumphs, of love and of strife,
Of cultures blending, in a tapestry of life.

In every breath, a new scent is born,
Of spices and incense, of coffee and corn,
Of fresh-cut grass, of ocean spray,
Of distant lands, where memories s

In [39]:
len(out_tokens[0][0])

428