In [37]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange, repeat
import numpy as np

In [292]:
# Scratch arithmetic; evaluates to 1_048_576_000 (= 1000 * 2**20).
# NOTE(review): the four factors are unlabeled — bind them to named constants
# so a reader can tell what quantity is being sized.
8 * 8 * 2048 * 8000

1048576000

In [2]:
from transformers import GPT2Tokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the GPT-2 tokenizer published under the "openai-community/gpt2" id.
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")

In [38]:
# Rebind `x` to its NumPy form. `x` is not created by any cell shown before this
# one, so it comes from hidden kernel state. The cell is also not re-runnable:
# after the first run `x` is the ndarray displayed below, which has no
# `.numpy()` method.
x = x.numpy()

In [46]:
x

array([[  340,   345,   517,  1909,  1165,   783,   502,   994,   640,
          503],
       [  198, 50256,    20,    19,    17,    23,    82,   362,   278,
          352],
       [  198,   327, 15886, 50256,    91,  6494,   250,   376,  4522,
          685],
       [15886,   198,  1294,  2937,    91,   327,   311,   317, 50256,
        31440],
       [ 1802,  2026,   838,  1160,    20,  1542,   642,  4019,  1679,
         4101],
       [  371,   311,  5161,   317,  1195,   347,   337,   360,   367,
          350],
       [  198,   440,   259,   288,   267,   364,   300,   360,    82,
          642],
       [  198, 50256, 15886, 24555, 22219,   250,    17,    16,    91,
           20],
       [   16,    17,    82,    18,    64,   251,    19,    65,    20,
           32],
       [  198, 50256, 15886,     0, 11146,   427,   277,    91,    11,
        25998],
       [  198,   830,  3012,    17, 15886,   317,  1160,  1802,  2026,
          838],
       [  340,  2102,   640,    82,   326, 

In [40]:
# Decode the first row of token ids in `x` back into text.
word = tokenizer.decode(x[0])

In [41]:
word

' it you more today too now me here time out'

In [35]:
x

tensor([[  340,   345,   517,  1909,  1165,   783,   502,   994,   640,   503],
        [  198, 50256,    20,    19,    17,    23,    82,   362,   278,   352],
        [  198,   327, 15886, 50256,    91,  6494,   250,   376,  4522,   685],
        [15886,   198,  1294,  2937,    91,   327,   311,   317, 50256, 31440],
        [ 1802,  2026,   838,  1160,    20,  1542,   642,  4019,  1679,  4101],
        [  371,   311,  5161,   317,  1195,   347,   337,   360,   367,   350],
        [  198,   440,   259,   288,   267,   364,   300,   360,    82,   642],
        [  198, 50256, 15886, 24555, 22219,   250,    17,    16,    91,    20],
        [   16,    17,    82,    18,    64,   251,    19,    65,    20,    32],
        [  198, 50256, 15886,     0, 11146,   427,   277,    91,    11, 25998],
        [  198,   830,  3012,    17, 15886,   317,  1160,  1802,  2026,   838],
        [  340,  2102,   640,    82,   326,   614,   531,   812,   352,   345],
        [  198,  1729,  2116,  1462,  10

In [67]:
# Decode a single token id (318) to see which string it maps to.
word = tokenizer.decode([318])

In [68]:
word

' is'

In [163]:
class Model(nn.Module):
    """Two bias-free 3 -> 4 linear layers that share one weight tensor.

    Used to check how parameter counting behaves when a weight is tied
    between sub-modules.
    """

    def __init__(self):
        super().__init__()
        self.x1 = nn.Linear(3, 4, bias=False)
        self.x2 = nn.Linear(3, 4, bias=False)
        # Weight tying: x1 now holds the very same Parameter object as x2.
        self.x1.weight = self.x2.weight

    def params_count(self):
        """Return element counts as ``(x1, x2, whole module)``."""

        def _count(module):
            # Total number of scalar elements over the module's parameters.
            total = 0
            for tensor in module.parameters():
                total += tensor.numel()
            return total

        return _count(self.x1), _count(self.x2), _count(self)

In [None]:
def params_count(self, config: Config):
    """Report the full parameter count plus two reduced counts.

    The reduced counts subtract, from the total number of parameter elements,
    the size of the sub-experts treated as not in use:

    * ``total_a_params`` keeps a single sub-expert (all but one of
      ``expert_num * sub_expert_num`` are subtracted) and additionally
      subtracts ``main_dim * expert_num * (layer_num - 1)``.
    * ``total_fa_params`` keeps ``expert_num_per_token`` experts' worth of
      sub-experts.

    NOTE(review): "a" / "fa" presumably stand for two notions of *activated*
    parameters — confirm against the model definition.

    Args:
        config: Object exposing ``expert_num_per_token``, ``expert_num``,
            ``sub_expert_num``, ``main_dim``, ``sub_expert_dim``,
            ``sub_expert_router`` and ``layer_num``.

    Returns:
        dict with keys ``"total_params"``, ``"total_a_params"`` and
        ``"total_fa_params"``.
    """
    n_total = 0
    for tensor in self.parameters():
        n_total += tensor.numel()

    layers_minus_one = config.layer_num - 1

    # Size attributed to one sub-expert, summed over (layer_num - 1) layers.
    per_sub_expert = (
        config.main_dim
        * (3 * config.sub_expert_dim + config.sub_expert_router)
        * layers_minus_one
    )

    # How many sub-experts are excluded under each of the two accountings.
    excluded_a = config.expert_num * config.sub_expert_num - 1
    excluded_fa = (
        config.expert_num - config.expert_num_per_token
    ) * config.sub_expert_num

    # Extra per-expert term removed only from the "a" count.
    expert_term = config.main_dim * config.expert_num * layers_minus_one

    return {
        "total_params": n_total,
        "total_a_params": n_total - per_sub_expert * excluded_a - expert_term,
        "total_fa_params": n_total - per_sub_expert * excluded_fa,
    }

In [None]:
# Config fields read by the FFN fragments further down in this cell; the "p_"
# and "f_" prefixes select one of two parallel branches.
# NOTE(review): presumably copied out of a Config class that is not visible
# here — confirm.
#   *_dense              : use the dense up/down projections instead of the
#                          expert tensor.
#   *_ffn_factor         : hidden width of the dense path as a multiple of `dim`.
#   *_shared_expert      : add an extra up/down projection on top of the
#                          expert output.
#   *_shared_expert_dim  : hidden width of that extra projection.
p_dense: bool
p_ffn_factor: float
f_dense: bool
f_ffn_factor: float
p_shared_expert: bool
p_shared_expert_dim: int
f_shared_expert: bool
f_shared_expert_dim: int

# Layer-construction fragment. `self`, `config`, `dim`, `ple`, `fle`, `p_dim`
# and `f_dim` are not defined in this cell.
# NOTE(review): looks like an excerpt from a module's __init__ — confirm.
p_dim_s = config.p_shared_expert_dim
f_dim_s = config.f_shared_expert_dim

# "p" branch: either a dense FFN or an expert tensor (plus optional shared expert).
if config.p_dense:
    self.p_ffn_h_dim = int(dim * config.p_ffn_factor)
    # The up projection emits 2 * hidden features; the forward fragment later
    # in this cell splits it into chunks of size p_ffn_h_dim.
    self.p_dense_up = nn.Linear(dim, self.p_ffn_h_dim * 2, bias=False)
    self.p_dense_down = nn.Linear(self.p_ffn_h_dim, dim, bias=False)
else:
    # All expert weights in one tensor of shape (3, ple, dim, p_dim), drawn
    # from a standard normal.
    # NOTE(review): the leading 3 presumably stacks three projection matrices
    # per expert — confirm against grouped_gemm_func.
    self.p_ffn_experts = nn.Parameter(torch.randn(3, ple, dim, p_dim))
    if config.p_shared_expert:
        # Shared expert: same up(2x)/down layout with hidden width p_dim_s.
        self.p_ffn_up = nn.Linear(dim, p_dim_s * 2, bias=False)
        self.p_ffn_down = nn.Linear(p_dim_s, dim, bias=False)

# "f" branch: mirrors the "p" branch with the f_* settings.
if config.f_dense:
    self.f_ffn_h_dim = int(dim * config.f_ffn_factor)
    self.f_dense_up = nn.Linear(dim, self.f_ffn_h_dim * 2, bias=False)
    self.f_dense_down = nn.Linear(self.f_ffn_h_dim, dim, bias=False)
else:
    self.f_ffn_experts = nn.Parameter(torch.randn(3, fle, dim, f_dim))
    if config.f_shared_expert:
        self.f_ffn_up = nn.Linear(dim, f_dim_s * 2, bias=False)
        self.f_ffn_down = nn.Linear(f_dim_s, dim, bias=False)

# Forward-pass fragment. `p_x_ffn`, `f_x_ffn`, `p_values`, `p_indices`,
# `f_values`, `f_indices`, `x_ffn_input` and `grouped_gemm_func` are not
# defined in this cell.
# NOTE(review): looks like an excerpt from a module's forward() — confirm.

# ---- "p" branch -> py ----
if config.p_dense:
    # Dense path: split the up projection into chunks of p_ffn_h_dim along the
    # last dim, gate one with SiLU of the other, then project back down.
    px1, px2 = torch.split(self.p_dense_up(p_x_ffn), self.p_ffn_h_dim, -1)
    px3 = F.silu(px1) * px2
    py = self.p_dense_down(px3)
else:
    # Expert path: sigmoid the router values and renormalize them to sum to 1
    # over the last dim.
    p_scores = p_values.sigmoid()
    p_scores = p_scores / p_scores.sum(-1, keepdim=True)
    # Flatten indices and scores before handing them to the grouped GEMM;
    # scores become a column vector scaled by p_routed_scaling_factor.
    p_indices = p_indices.flatten()
    p_scores = p_scores.flatten().unsqueeze(-1) * config.p_routed_scaling_factor
    py = grouped_gemm_func(
        p_x_ffn, self.p_ffn_experts, p_indices, p_scores, config
    )

    if config.p_shared_expert:
        # Shared expert output is added to the routed output.
        px1, px2 = torch.split(
            self.p_ffn_up(p_x_ffn), config.p_shared_expert_dim, -1
        )
        px3 = F.silu(px1) * px2
        p_y_shared = self.p_ffn_down(px3)
        py = py + p_y_shared

    # NOTE(review): this runs the *f* dense layers on the *p* input and adds the
    # result to py, and only when the p branch is on the expert path. `fy` set
    # here is overwritten by the "f" branch below. Confirm this is intentional
    # and not a copy-paste leftover.
    if config.f_dense:
        fx1, fx2 = torch.split(self.f_dense_up(p_x_ffn), self.f_ffn_h_dim, -1)
        fx3 = F.silu(fx1) * fx2
        fy = self.f_dense_down(fx3)
        py = py + fy

# ---- "f" branch -> fy (same structure as the "p" branch, f_* settings) ----
if config.f_dense:
    fx1, fx2 = torch.split(self.f_dense_up(f_x_ffn), self.f_ffn_h_dim, -1)
    fx3 = F.silu(fx1) * fx2
    fy = self.f_dense_down(fx3)
else:
    f_scores = f_values.sigmoid()
    f_scores = f_scores / f_scores.sum(-1, keepdim=True)
    f_indices = f_indices.flatten()
    f_scores = f_scores.flatten().unsqueeze(-1) * config.f_routed_scaling_factor
    fy = grouped_gemm_func(
        f_x_ffn, self.f_ffn_experts, f_indices, f_scores, config
    )

    if config.f_shared_expert:
        fx1, fx2 = torch.split(
            self.f_ffn_up(f_x_ffn), config.f_shared_expert_dim, -1
        )
        fx3 = F.silu(fx1) * fx2
        f_y_shared = self.f_ffn_down(fx3)
        fy = fy + f_y_shared

# Stack the two branch outputs along dim 0 and add the residual input.
# NOTE(review): assumes x_ffn_input is laid out as [p rows; f rows] along
# dim 0 — confirm against the caller.
y = torch.cat([py, fy], dim=0) + x_ffn_input

In [321]:
# 128 rows of 4 integer ids, each drawn uniformly from [0, 64). No seed is set,
# so every output below changes between runs. This also rebinds the `x` used by
# the tokenizer cells above.
x = torch.randint(0, 64, (128, 4))

In [322]:
# Count how many times each id occurs anywhere in `x`.
b = torch.bincount(x.flatten())

In [323]:
# Rescale the counts so they sum to 64 (i.e. mean 1.0 when all 64 ids occur).
b = b / b.sum() * 64

In [324]:
b

tensor([1.1, 1.2, 0.8, 0.4, 1.5, 1.0, 0.6, 1.2, 1.1, 0.8, 0.8, 1.0, 0.5, 1.2,
        1.1, 0.8, 1.1, 1.4, 1.1, 1.0, 0.6, 1.0, 1.0, 1.2, 1.0, 0.8, 1.0, 1.2,
        1.1, 1.6, 1.0, 1.0, 0.8, 0.6, 0.4, 1.6, 0.8, 1.1, 1.5, 1.4, 1.5, 1.6,
        0.8, 1.1, 0.8, 0.5, 0.6, 1.2, 0.9, 1.0, 1.1, 1.5, 0.6, 1.5, 0.8, 0.9,
        0.5, 0.6, 0.6, 1.1, 0.4, 1.0, 2.0, 0.9])

In [325]:
# For every entry of `x`, look up the normalized load of the id stored there;
# the result has the same shape as `x`.
x_b = b[x]

In [326]:
# Per-row weight derived from the largest load in the row (keepdim keeps one
# value per row).
# NOTE(review): both summands are the same amax term, so this is simply
# 1 / amax. If a max/min midpoint was intended, one term should be amin — confirm.
w1 = 2 / (x_b.amax(-1, keepdim=True) + x_b.amax(-1, keepdim=True))

In [327]:
# Repeat the per-row weight across the 4 columns so it lines up with `x`.
w1 = w1.expand(-1, 4)

In [328]:
# Recount the ids, each occurrence now contributing its row weight instead of 1.
b1 = torch.bincount(x.flatten(), w1.flatten())

In [329]:
# Rescale the weighted counts so they sum to 64, comparable with `b`.
b1 = b1 / b1.sum() * 64

In [330]:
b1

tensor([1.2, 1.4, 0.8, 0.4, 1.4, 1.1, 0.5, 1.3, 1.3, 0.9, 0.7, 1.0, 0.5, 1.2,
        1.2, 0.8, 1.3, 1.2, 1.3, 1.0, 0.6, 1.0, 1.0, 1.3, 1.1, 0.8, 1.1, 1.3,
        1.1, 1.4, 1.0, 1.1, 0.7, 0.7, 0.4, 1.4, 0.9, 1.1, 1.4, 1.3, 1.4, 1.4,
        0.8, 1.2, 0.8, 0.4, 0.6, 1.3, 1.0, 1.1, 1.1, 1.4, 0.7, 1.3, 0.8, 0.9,
        0.5, 0.6, 0.6, 1.2, 0.3, 1.1, 1.4, 0.9])

In [331]:
# Per-entry lookup of the re-weighted load; same shape as `x`.
x_b1 = b1[x]

In [332]:
# Second per-row weight, this time from the smallest re-weighted load in the row.
# NOTE(review): both summands are the same amin term, so this is simply
# 1 / amin — confirm that is the intended formula.
w2 = 2 / (x_b1.amin(-1, keepdim=True) + x_b1.amin(-1, keepdim=True))

In [333]:
# Repeat the per-row weight across the 4 columns so it lines up with `x`.
w2 = w2.expand(-1, 4)

In [334]:
# Recount the ids with both row weights applied (element-wise product w1 * w2).
b2 = torch.bincount(x.flatten(), w1.flatten() * w2.flatten())

In [335]:
# Rescale so the result sums to 64, comparable with `b` and `b1`.
b2 = b2 / b2.sum() * 64

In [336]:
b2

tensor([1.0, 1.1, 0.8, 0.7, 1.5, 0.8, 0.7, 1.2, 1.0, 0.9, 0.9, 1.0, 0.7, 1.2,
        0.9, 1.1, 1.1, 1.3, 1.3, 0.8, 0.7, 0.9, 0.9, 1.1, 1.1, 0.8, 0.8, 1.3,
        0.9, 1.4, 0.9, 1.0, 0.9, 0.7, 0.7, 1.6, 0.9, 1.1, 1.3, 1.3, 1.3, 1.4,
        0.9, 1.1, 0.8, 0.7, 0.7, 1.2, 1.0, 0.9, 1.1, 1.4, 0.8, 1.3, 0.8, 0.8,
        0.7, 0.8, 0.7, 1.1, 0.7, 1.0, 1.6, 1.0])

In [31]:
import torch

def simulate_category_weights(n, m, a, b, generator=None):
    """Accumulate ``m`` random weights into ``n`` randomly chosen categories.

    Each of the ``m`` draws picks a category uniformly from ``[0, n)`` and a
    weight uniformly from ``[a, b)``; weights that land in the same category
    are summed. The expected total per category is ``m / n * (a + b) / 2``.

    Args:
        n: Number of categories (length of the returned tensor).
        m: Number of (category, weight) draws.
        a: Lower bound of the weight range.
        b: Upper bound of the weight range.
        generator: Optional ``torch.Generator`` used for both random draws so
            a run can be reproduced. ``None`` (the default) uses PyTorch's
            global RNG, exactly as before this parameter existed.

    Returns:
        1-D float tensor of length ``n`` holding the summed weight per
        category; categories that received no draw stay at 0.
    """
    categories = torch.randint(0, n, (m,), generator=generator)
    # torch.rand samples [0, 1); stretch and shift it onto [a, b).
    weights = torch.rand(m, generator=generator) * (b - a) + a

    accumulated_weights = torch.zeros(n)
    # In-place scatter: accumulated_weights[categories[i]] += weights[i].
    accumulated_weights.scatter_add_(0, categories, weights)
    return accumulated_weights

In [129]:
# Process-wide display setting: print tensors with one decimal place and no
# scientific notation. Affects every tensor repr shown after it runs.
torch.set_printoptions(precision=1, sci_mode=False)

In [189]:
# 512 draws over 64 categories: 8 draws per category on average at mean weight
# 0.5, so each entry is expected to be about 4.
simulate_category_weights(64, 512, 0.1, 0.9)

tensor([2.9, 2.9, 3.3, 3.3, 4.6, 1.1, 4.1, 3.3, 3.6, 1.6, 3.6, 5.7, 8.3, 3.7,
        4.9, 3.6, 4.2, 5.9, 5.2, 3.4, 2.1, 6.8, 3.6, 5.3, 3.2, 3.1, 1.9, 1.2,
        4.4, 3.1, 3.0, 4.0, 4.6, 4.3, 5.0, 3.0, 4.8, 5.5, 2.9, 3.9, 2.2, 1.5,
        4.6, 4.6, 1.0, 2.5, 5.1, 4.6, 3.5, 7.4, 3.5, 3.7, 4.1, 3.8, 1.8, 3.4,
        6.6, 4.9, 7.8, 3.6, 4.1, 2.2, 3.6, 3.3])

In [178]:
# Difference between two independent runs with only 50 draws for 64 categories,
# so some categories necessarily receive no draw in each run.
simulate_category_weights(64, 50, 0.1, 0.9) - simulate_category_weights(64, 50, 0.1, 0.9)

tensor([    -0.0,      0.0,     -0.2,      0.3,      1.4,     -0.8,      0.5,
             1.4,      0.0,     -0.2,      2.7,     -0.7,      0.0,      0.4,
             0.0,      0.2,     -0.6,      0.0,     -0.5,     -0.8,     -0.5,
             0.1,     -0.4,     -0.4,     -1.4,      0.0,     -0.3,      2.2,
             0.6,      0.4,      0.0,      0.7,      0.4,      0.0,      0.4,
             0.8,     -0.1,     -0.6,     -1.7,      0.2,      0.5,     -0.8,
            -0.7,      0.0,      0.0,      0.3,      0.0,     -0.8,      0.0,
             1.1,      0.1,     -0.2,     -0.4,     -0.9,      0.0,     -1.2,
             0.0,      0.7,      0.0,     -0.5,      0.4,      0.0,     -0.6,
            -1.7])

In [174]:
# 5120 draws over 64 categories: expected 80 * 0.5 = 40 per category; dividing
# by 0.4 rescales that expectation to 100.
simulate_category_weights(64, 5120, 0.1, 0.9) / 0.4

tensor([ 93.9, 113.0, 138.4,  84.9, 102.2, 107.9,  95.6,  95.4,  83.8, 104.5,
        108.3,  94.8, 100.6, 113.3,  98.4,  77.6, 119.6,  94.4,  84.9,  97.8,
         93.8,  98.3, 102.4,  97.4,  92.6,  98.5,  95.0, 112.5, 108.0, 111.6,
        109.7, 103.9,  92.1,  87.0, 123.1, 103.0,  83.9, 104.6, 102.1,  91.8,
         96.1, 100.0,  99.6, 125.8,  84.9, 106.3,  93.5, 103.9,  89.3,  90.0,
         65.9, 101.7,  95.0, 101.1,  94.8,  93.7,  98.4, 109.6, 114.4, 121.7,
        100.1, 101.7, 120.4, 103.1])

In [130]:
# 12800 draws over 64 categories: expected 200 * 0.5 = 100 per category.
simulate_category_weights(64, 12800, 0.1, 0.9)

tensor([ 90.8, 105.2, 101.8, 100.3,  98.2, 104.2, 105.7,  93.9, 111.1,  97.2,
         95.8, 101.6,  99.3,  96.7,  93.4,  90.1,  99.4,  99.2, 105.9,  96.5,
         97.6, 104.7, 116.1, 109.1,  98.6, 100.0,  99.9, 100.0, 106.2, 103.3,
        104.1, 100.8, 102.5,  93.1, 108.8,  98.1,  93.8, 101.6,  95.9,  98.0,
         93.0, 105.7,  79.7, 102.5,  98.7, 110.1,  86.9,  85.3,  92.5,  95.4,
         87.6,  90.2, 105.3, 109.4, 113.1,  96.0,  99.8, 134.6,  92.9,  85.5,
        117.5,  86.5,  95.3, 104.0])

In [131]:
# 800000 draws over 400 categories: expected 2000 * 0.5 = 1000 per category;
# dividing by 10 rescales that expectation to 100.
simulate_category_weights(400, 800000, 0.1, 0.9) / 10

tensor([104.9, 100.1, 101.5,  98.0,  99.8, 100.0,  99.1, 100.9, 100.8,  97.7,
         98.7,  95.3, 104.5, 101.6,  99.0,  98.6, 101.7, 100.0, 100.0, 105.1,
         98.5, 103.9, 100.5, 101.6, 100.2, 100.3, 100.1,  95.4, 100.8,  97.9,
         93.7, 100.9,  99.0, 100.0,  98.1, 101.0,  98.3, 100.1, 103.3, 101.6,
        101.0,  97.8,  98.3, 103.9,  96.5, 103.7, 104.4,  97.7,  95.9, 102.9,
        102.9,  98.1,  95.8,  98.2, 102.0,  98.0, 100.0, 101.1,  97.9, 102.8,
         98.9,  97.7, 107.0, 100.2, 106.3, 103.0,  99.6,  95.6,  95.5,  96.5,
        102.2,  99.6, 100.9,  92.9,  98.0,  98.9, 100.5, 101.0, 102.5, 102.2,
         98.4, 101.3, 100.1,  95.4,  98.3, 103.0, 100.8,  92.0,  98.2, 104.8,
         99.1, 106.6,  99.1,  98.3, 101.0, 101.7, 103.9, 103.1,  99.0, 104.6,
         99.8,  99.4, 102.1,  96.6, 100.3,  96.9, 101.1, 100.0, 102.3,  98.4,
        100.0,  98.6, 100.2,  98.9,  98.8, 103.0, 100.6, 102.4, 106.5, 100.1,
         98.5,  98.7,  97.9,  99.4,  98.8,  98.3,  98.7,  98.2, 

In [132]:
# 8000000 draws over 400 categories: expected 20000 * 0.5 = 10000 per category;
# dividing by 100 rescales that expectation to 100.
simulate_category_weights(400, 8000000, 0.1, 0.9) / 100

tensor([ 98.4,  99.1,  99.0,  99.9, 100.4, 100.7,  98.9,  99.0, 100.8,  98.7,
         99.9, 100.0,  99.7,  99.8,  99.8, 100.0, 100.9, 100.1, 100.4,  98.7,
         99.7,  99.4, 101.0,  99.5, 100.8, 100.6,  99.3, 100.6, 100.7,  99.8,
         99.9,  99.4,  99.2,  99.7, 101.8,  98.5, 100.4, 100.1, 100.1, 100.5,
         99.1,  98.6,  99.7,  99.5, 100.0, 100.1,  99.7, 100.0, 100.4, 100.2,
        100.6,  99.3, 100.4, 100.3,  99.4,  98.7, 100.1, 100.5,  99.5,  99.3,
        100.1, 100.4, 100.0,  99.8, 100.8,  99.3, 100.0,  99.9, 101.6, 100.7,
        100.8, 101.5,  98.9,  99.4, 100.3, 100.7,  99.3,  99.9, 100.6,  99.3,
        100.4,  99.0, 100.6,  99.9, 100.5, 100.7, 100.6, 100.5, 101.6,  99.7,
         99.2, 100.5,  98.5, 101.1, 101.1,  99.0,  99.7,  99.7, 100.8,  99.5,
        100.6,  99.0, 100.8,  98.5, 100.5, 101.4, 100.3,  99.0,  99.2,  99.5,
         99.8, 100.0,  99.5,  99.1, 100.3,  98.4, 100.3, 100.1,  98.8,  98.8,
        100.0, 100.0, 101.8, 100.4,  99.5,  99.2,  99.8,  99.8, 