# RMSNorm 测试

In [26]:
import mindspore
from mindspore import nn, ops

def torch_rms_layernorm(hidden: mindspore.Tensor, weight: mindspore.Tensor, eps: float):
    old_dtype = hidden.dtype
    variance = hidden.to(mindspore.float32).pow(2).mean(axis=-1, keep_dims=True)
    hidden = (hidden * ops.rsqrt(variance + eps)).to(old_dtype)
    return hidden * weight


class MSMiniCPMRMSNorm(nn.Cell):
    def __init__(self, hidden_size, eps=1e-6):
        """
        MiniCPMRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = mindspore.Parameter(ops.ones(hidden_size))
        self.variance_epsilon = eps

    def construct(self, hidden_states):
        return ms_rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
    


In [27]:
import numpy as np

hidden_size = 4096
batch_size = 1
seq_len = 10

ms_nms_norm = MSMiniCPMRMSNorm(hidden_size)


input_array = np.random.rand(batch_size, seq_len, hidden_size)
input_array = input_array.astype(np.float32)
ms_input = mindspore.Tensor.from_numpy(input_array)


ms_output = ms_nms_norm(ms_input)
print(ms_output)

[[[0.78732544 1.2833312  1.2922165  ... 0.04089143 0.49289078 0.6821702 ]
  [1.1579348  0.84069234 0.76891506 ... 1.5036066  1.2710027  0.71515465]
  [1.1578155  0.05839312 0.49893308 ... 0.69267124 0.08189241 0.7144707 ]
  ...
  [0.29475603 0.5684639  0.6187994  ... 1.5154487  1.4241893  0.34242994]
  [0.51642406 0.667304   1.2583175  ... 1.3388052  1.2167487  0.94261634]
  [1.1134353  1.1357505  1.1854455  ... 1.2128674  1.0710244  0.49770108]]]


In [28]:
import torch
from torch import nn

def torch_rms_layernorm(hidden: torch.Tensor, weight: torch.Tensor, eps: float):
    old_dtype = hidden.dtype
    variance = hidden.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    hidden = (hidden * torch.rsqrt(variance + eps)).to(old_dtype)
    return hidden * weight


class TorchMiniCPMRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        MiniCPMRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        return torch_rms_layernorm(hidden_states, self.weight, self.variance_epsilon)
    
torch_nms_norm = TorchMiniCPMRMSNorm(hidden_size)

In [30]:
torch_input = torch.from_numpy(input_array)
torch_output = torch_nms_norm(torch_input)
print(torch_output)


tensor([[[0.7873, 1.2833, 1.2922,  ..., 0.0409, 0.4929, 0.6822],
         [1.1579, 0.8407, 0.7689,  ..., 1.5036, 1.2710, 0.7152],
         [1.1578, 0.0584, 0.4989,  ..., 0.6927, 0.0819, 0.7145],
         ...,
         [0.2948, 0.5685, 0.6188,  ..., 1.5154, 1.4242, 0.3424],
         [0.5164, 0.6673, 1.2583,  ..., 1.3388, 1.2167, 0.9426],
         [1.1134, 1.1358, 1.1854,  ..., 1.2129, 1.0710, 0.4977]]],
       grad_fn=<MulBackward0>)


In [33]:
print(np.allclose(ms_output.asnumpy(), torch_output.detach().numpy(), atol=1e-3))

True


# RotrayEmbedding

In [46]:
import mindspore
from mindspore import nn, ops

class MSMiniCPMRotaryEmbedding(nn.Cell):
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.inv_freq = 1.0 / (self.base ** (ops.arange(0, self.dim, 2).to(mindspore.float32) / self.dim))

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            # seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
            seq_len=max_position_embeddings, dtype=mindspore.float32
        )

    def _set_cos_sin_cache(self, seq_len, dtype):
        self.max_seq_len_cached = seq_len
        t = ops.arange(end=self.max_seq_len_cached, dtype=self.inv_freq.dtype)
        freqs = ops.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = ops.cat((freqs, freqs), axis=-1)

        self.cos_cached = emb.cos().to(dtype)
        self.sin_cached = emb.sin().to(dtype)


    def construct(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )

In [47]:
num_attn_heads = 32
head_size = hidden_size // num_attn_heads

ms_rotary_emb = MSMiniCPMRotaryEmbedding(dim=hidden_size)

input_array_rope = np.random.rand(batch_size, num_attn_heads, seq_len, head_size).astype(np.float32)
ms_input_rope = mindspore.Tensor.from_numpy(input_array_rope)

ms_output_rope = ms_rotary_emb(ms_input_rope, seq_len=seq_len)
print(ms_output_rope)

(Tensor(shape=[10, 4096], dtype=Float32, value=
[[ 1.00000000e+00,  1.00000000e+00,  1.00000000e+00 ...  1.00000000e+00,  1.00000000e+00,  1.00000000e+00],
 [ 5.40302277e-01,  5.44072688e-01,  5.47815144e-01 ...  1.00000000e+00,  1.00000000e+00,  1.00000000e+00],
 [-4.16146845e-01, -4.07969773e-01, -3.99797082e-01 ...  1.00000000e+00,  1.00000000e+00,  1.00000000e+00],
 ...
 [ 7.53902256e-01,  7.74163365e-01,  7.93574035e-01 ...  9.99999762e-01,  9.99999762e-01,  9.99999762e-01],
 [-1.45500034e-01, -1.09898202e-01, -7.43169263e-02 ...  9.99999642e-01,  9.99999702e-01,  9.99999702e-01],
 [-9.11130250e-01, -8.93748343e-01, -8.74997675e-01 ...  9.99999583e-01,  9.99999583e-01,  9.99999583e-01]]), Tensor(shape=[10, 4096], dtype=Float32, value=
[[ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00 ...  0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
 [ 8.41470957e-01,  8.39038074e-01,  8.36599410e-01 ...  1.01358310e-04,  1.00903497e-04,  1.00450736e-04],
 [ 9.09297407e-01,  9.12995458e-01

In [48]:
import torch
from torch import nn

class TorchMiniCPMRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            # seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


In [49]:
torch_rotary_emb = TorchMiniCPMRotaryEmbedding(dim=hidden_size, device=torch.device("cpu"))

input_array_rope = np.random.rand(batch_size, seq_len, hidden_size).astype(np.float32)
torch_input_rope = torch.from_numpy(input_array_rope)

torch_output_rope = torch_rotary_emb(torch_input_rope, seq_len=seq_len)
print(torch_output_rope)

(tensor([[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
        [ 0.5403,  0.5441,  0.5478,  ...,  1.0000,  1.0000,  1.0000],
        [-0.4161, -0.4080, -0.3998,  ...,  1.0000,  1.0000,  1.0000],
        ...,
        [ 0.7539,  0.7742,  0.7936,  ...,  1.0000,  1.0000,  1.0000],
        [-0.1455, -0.1099, -0.0743,  ...,  1.0000,  1.0000,  1.0000],
        [-0.9111, -0.8937, -0.8750,  ...,  1.0000,  1.0000,  1.0000]]), tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [8.4147e-01, 8.3904e-01, 8.3660e-01,  ..., 1.0136e-04, 1.0090e-04,
         1.0045e-04],
        [9.0930e-01, 9.1300e-01, 9.1660e-01,  ..., 2.0272e-04, 2.0181e-04,
         2.0090e-04],
        ...,
        [6.5699e-01, 6.3299e-01, 6.0847e-01,  ..., 7.0951e-04, 7.0632e-04,
         7.0316e-04],
        [9.8936e-01, 9.9394e-01, 9.9723e-01,  ..., 8.1087e-04, 8.0723e-04,
         8.0361e-04],
        [4.1212e-01, 4.4857e-01, 4.8413e-01,  ..., 9.1222e-04, 9.0

In [50]:
ms_output_rope_cos, ms_output_rope_sin = ms_output_rope
torch_output_rope_cos, torch_output_rope_sin = torch_output_rope
print(np.allclose(ms_output_rope_cos.asnumpy(), torch_output_rope_cos.detach().numpy(), atol=1e-3))
print(np.allclose(ms_output_rope_sin.asnumpy(), torch_output_rope_sin.detach().numpy(), atol=1e-3))

True
True


# LinearScalingRoPE

In [51]:
class MSMiniCPMLinearScalingRotaryEmbedding(MSMiniCPMRotaryEmbedding):
    """MiniCPMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base)

    def _set_cos_sin_cache(self, seq_len, dtype):
        self.max_seq_len_cached = seq_len
        t = ops.arange(end=self.max_seq_len_cached, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor

        freqs = ops.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = ops.cat((freqs, freqs), axis=-1)
        self.cos_cached = emb.cos().to(dtype)
        self.sin_cached = emb.sin().to(dtype)


In [52]:
ms_rotary_emb_linear_scaling = MSMiniCPMLinearScalingRotaryEmbedding(dim=hidden_size)
ms_rotary_emb_linear_scaling._set_cos_sin_cache(seq_len, mindspore.float32)
ms_output_cos = ms_rotary_emb_linear_scaling.cos_cached
ms_output_sin = ms_rotary_emb_linear_scaling.sin_cached

In [53]:
class TorchMiniCPMLinearScalingRotaryEmbedding(TorchMiniCPMRotaryEmbedding):
    """MiniCPMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

In [55]:
torch_rotary_emb_linear_scaling = TorchMiniCPMLinearScalingRotaryEmbedding(dim=hidden_size, device=torch.device("cpu"))
torch_rotary_emb_linear_scaling._set_cos_sin_cache(seq_len, torch.device("cpu"), torch.float32)
torch_output_cos = torch_rotary_emb_linear_scaling.cos_cached
torch_output_sin = torch_rotary_emb_linear_scaling.sin_cached

In [56]:
print(np.allclose(ms_output_sin.asnumpy(), torch_output_sin.detach().numpy(), atol=1e-3))
print(np.allclose(ms_output_cos.asnumpy(), torch_output_cos.detach().numpy(), atol=1e-3))

True
True


# DynamicNTK Scaling RoPE

In [57]:
class MSMiniCPMDynamicNTKScalingRotaryEmbedding(MSMiniCPMRotaryEmbedding):
    """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base)

    def _set_cos_sin_cache(self, seq_len, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            self.inv_freq = 1.0 / (base ** (ops.arange(0, self.dim, 2).to(mindspore.float32) / self.dim))

        t = ops.arange(end=self.max_seq_len_cached, dtype=self.inv_freq.dtype)

        freqs = ops.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = ops.cat((freqs, freqs), axis=-1)

        self.cos_cached = emb.cos().to(dtype)
        self.sin_cached = emb.sin().to(dtype)

In [72]:
ms_rotary_emb_dynamic_scaling = MSMiniCPMDynamicNTKScalingRotaryEmbedding(hidden_size)
ms_rotary_emb_dynamic_scaling._set_cos_sin_cache(seq_len, mindspore.float32)
ms_output_cos = ms_rotary_emb_dynamic_scaling.cos_cached
ms_output_sin = ms_rotary_emb_dynamic_scaling.sin_cached

print(ms_output_cos.shape)

(10, 4096)


In [60]:
class TorchMiniCPMDynamicNTKScalingRotaryEmbedding(TorchMiniCPMRotaryEmbedding):
    """MiniCPMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

In [61]:
torch_rotary_emb_dynamic_scaling = TorchMiniCPMDynamicNTKScalingRotaryEmbedding(dim=hidden_size, device=torch.device("cpu"))
torch_rotary_emb_dynamic_scaling._set_cos_sin_cache(seq_len, torch.device("cpu"), torch.float32)
torch_output_cos = torch_rotary_emb_dynamic_scaling.cos_cached
torch_output_sin = torch_rotary_emb_dynamic_scaling.sin_cached

In [62]:
print(np.allclose(ms_output_sin.asnumpy(), torch_output_sin.detach().numpy(), atol=1e-3))
print(np.allclose(ms_output_cos.asnumpy(), torch_output_cos.detach().numpy(), atol=1e-3))

True
True


# rotate half

In [63]:
import mindspore
from mindspore import ops

def ms_rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return ops.cat((-x2, x1), axis=-1)

In [64]:
input_array = np.random.rand(batch_size, num_attn_heads, seq_len, head_size).astype(np.float32)

ms_input = mindspore.Tensor.from_numpy(input_array)
ms_output = ms_rotate_half(ms_input)

print(ms_output)


[[[[-0.9383789  -0.7959062  -0.7229492  ...  0.75974023  0.57705337
     0.5258102 ]
   [-0.4408688  -0.0808837  -0.4772482  ...  0.64175427  0.0088954
     0.8439258 ]
   [-0.10451046 -0.24723798 -0.3941851  ...  0.4869339   0.5427277
     0.27998877]
   ...
   [-0.9992323  -0.6110815  -0.51049364 ...  0.48206657  0.1275518
     0.65963066]
   [-0.81961334 -0.60193527 -0.16736974 ...  0.6794498   0.08161078
     0.70404893]
   [-0.9148321  -0.9682962  -0.4853782  ...  0.50420827  0.48732084
     0.74103487]]

  [[-0.9561082  -0.4691347  -0.20137118 ...  0.7968146   0.76639074
     0.5343106 ]
   [-0.8239253  -0.8691968  -0.26496986 ...  0.8805403   0.87118614
     0.86909795]
   [-0.90201867 -0.92350847 -0.4557022  ...  0.7991929   0.14225402
     0.89251596]
   ...
   [-0.5357112  -0.12910734 -0.87065876 ...  0.967807    0.48311362
     0.9123667 ]
   [-0.18272442 -0.80684733 -0.6087966  ...  0.69655305  0.7977314
     0.33861685]
   [-0.24963343 -0.2640282  -0.60354143 ...  0.022681

In [67]:
import torch
from torch import nn

def torch_rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


In [68]:
torch_input = torch.from_numpy(input_array)
torch_output = torch_rotate_half(torch_input)

print(torch_output)

tensor([[[[-0.9384, -0.7959, -0.7229,  ...,  0.7597,  0.5771,  0.5258],
          [-0.4409, -0.0809, -0.4772,  ...,  0.6418,  0.0089,  0.8439],
          [-0.1045, -0.2472, -0.3942,  ...,  0.4869,  0.5427,  0.2800],
          ...,
          [-0.9992, -0.6111, -0.5105,  ...,  0.4821,  0.1276,  0.6596],
          [-0.8196, -0.6019, -0.1674,  ...,  0.6794,  0.0816,  0.7040],
          [-0.9148, -0.9683, -0.4854,  ...,  0.5042,  0.4873,  0.7410]],

         [[-0.9561, -0.4691, -0.2014,  ...,  0.7968,  0.7664,  0.5343],
          [-0.8239, -0.8692, -0.2650,  ...,  0.8805,  0.8712,  0.8691],
          [-0.9020, -0.9235, -0.4557,  ...,  0.7992,  0.1423,  0.8925],
          ...,
          [-0.5357, -0.1291, -0.8707,  ...,  0.9678,  0.4831,  0.9124],
          [-0.1827, -0.8068, -0.6088,  ...,  0.6966,  0.7977,  0.3386],
          [-0.2496, -0.2640, -0.6035,  ...,  0.0227,  0.8105,  0.9328]],

         [[-0.8062, -0.8548, -0.9654,  ...,  0.5692,  0.7955,  0.5584],
          [-0.9142, -0.5408, -

In [71]:
print(np.allclose(ms_output.asnumpy(), torch_output.detach().numpy(), atol=1e-3))

True


# apply RoPE

In [73]:
def ms_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`mindspore.Tensor`): The query tensor.
        k (`mindspore.Tensor`): The key tensor.
        cos (`mindspore.Tensor`): The cosine part of the rotary embedding.
        sin (`mindspore.Tensor`): The sine part of the rotary embedding.
        position_ids (`mindspore.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(mindspore.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    # sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    # q_embed = (q * cos) + (rotate_half(q) * sin)
    # k_embed = (k * cos) + (rotate_half(k) * sin)
    orig_dtype = k.dtype
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
    q_fp32 = q.to(dtype=mindspore.float32)
    k_fp32 = k.to(dtype=mindspore.float32)
    q_embed = (q_fp32 * cos) + (ms_rotate_half(q_fp32) * sin)
    k_embed = (k_fp32 * cos) + (ms_rotate_half(k_fp32) * sin)
    return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)

In [88]:
input_qk_array = np.random.rand(batch_size, seq_len, hidden_size).astype(np.float32)
ms_q = ms_k = mindspore.Tensor.from_numpy(input_qk_array)
ms_cos = ms_sin = mindspore.ops.ones((batch_size, seq_len, hidden_size), dtype=mindspore.float32)
position_ids = ops.arange(end=batch_size)

ms_output_k, ms_output_q = ms_apply_rotary_pos_emb(ms_q, ms_k, ms_cos, ms_sin, position_ids)

In [89]:
def torch_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    # sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    # q_embed = (q * cos) + (rotate_half(q) * sin)
    # k_embed = (k * cos) + (rotate_half(k) * sin)
    orig_dtype = k.dtype
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
    q_fp32 = q.to(dtype=torch.float32, device=q.device)
    k_fp32 = k.to(dtype=torch.float32, device=k.device)
    q_embed = (q_fp32 * cos) + (torch_rotate_half(q_fp32) * sin)
    k_embed = (k_fp32 * cos) + (torch_rotate_half(k_fp32) * sin)
    return q_embed.to(dtype=orig_dtype), k_embed.to(dtype=orig_dtype)

In [90]:
torch_q = torch_k = torch.from_numpy(input_qk_array)
torch_cos = torch_sin = torch.ones((batch_size, seq_len, hidden_size)).float()
position_ids = torch.arange(end=batch_size)

torch_output_k, torch_output_q = torch_apply_rotary_pos_emb(torch_q, torch_k, torch_cos, torch_sin, position_ids)

In [95]:
print(np.allclose(ms_output_k.asnumpy(), torch_output_k.detach().numpy(), atol=1e-3))
print(np.allclose(ms_output_q.asnumpy(), torch_output_q.detach().numpy(), atol=1e-3))

True
True


# CPMMLP

In [118]:
import mindspore
from mindspore import nn, ops

class MSMiniCPMMLP(nn.Cell):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Dense(self.hidden_size, self.intermediate_size, has_bias=False)
        self.up_proj = nn.Dense(self.hidden_size, self.intermediate_size, has_bias=False)
        self.down_proj = nn.Dense(self.intermediate_size, self.hidden_size, has_bias=False)
        self.act_fn = config.hidden_act

    def construct(self, x):
        if self.config.pretraining_tp > 1:
            slice = self.intermediate_size // self.config.pretraining_tp
            gate_proj_slices = self.gate_proj.weight.split(slice, axis=0)
            up_proj_slices = self.up_proj.weight.split(slice, axis=0)
            down_proj_slices = self.down_proj.weight.split(slice, axis=1)

            gate_proj = ops.cat(
                [ops.dense(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], axis=-1
            )
            up_proj = ops.cat([ops.dense(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], axis=-1)

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, axis=2)
            down_proj = [
                ops.dense(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
            ]
            down_proj = sum(down_proj)
        else:
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj

In [119]:
class MiniCPMConfig:
    r"""
    This is the configuration class to store the configuration of a [`MiniCPMModel`]. It is used to instantiate an MiniCPM
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the MiniCPM-7B.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the MiniCPM model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`MiniCPMModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. MiniCPM 1 supports up to 2048 tokens,
            MiniCPM 2 up to 4096, CodeMiniCPM up to 16384.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
            these scaling strategies behave:
            https://www.reddit.com/r/LocalMiniCPM/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
            experimental feature, subject to breaking API changes in future versions.
        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import MiniCPMModel, MiniCPMConfig

    >>> # Initializing a MiniCPM minicpm-7b style configuration
    >>> configuration = MiniCPMConfig()

    >>> # Initializing a model from the minicpm-7b style configuration
    >>> model = MiniCPMModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "minicpm"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act=nn.SiLU(),
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        scale_emb=1,
        dim_model_base=1,
        scale_depth=1,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        # self._rope_scaling_validation()
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.scale_emb = scale_emb
        self.dim_model_base = dim_model_base
        self.scale_depth = scale_depth

In [120]:
ms_config = MiniCPMConfig()
ms_mlp = MSMiniCPMMLP(ms_config)

input_array = np.random.rand(batch_size, 2048, hidden_size).astype(np.float32)

ms_input_mlp = mindspore.Tensor.from_numpy(input_array)
ms_output_mlp = ms_mlp(ms_input_mlp)

In [121]:
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN

class TorchMiniCPMMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        if self.config.pretraining_tp > 1:
            slice = self.intermediate_size // self.config.pretraining_tp
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)

            gate_proj = torch.cat(
                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
            )
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
            down_proj = [
                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
            ]
            down_proj = sum(down_proj)
        else:
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj

In [122]:
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" MiniCPM model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)

MINICPM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}


class MiniCPMConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MiniCPMModel`]. It is used to instantiate an MiniCPM
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the MiniCPM-7B.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the MiniCPM model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`MiniCPMModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. MiniCPM 1 supports up to 2048 tokens,
            MiniCPM 2 up to 4096, CodeMiniCPM up to 16384.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
            these scaling strategies behave:
            https://www.reddit.com/r/LocalMiniCPM/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
            experimental feature, subject to breaking API changes in future versions.
        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import MiniCPMModel, MiniCPMConfig

    >>> # Initializing a MiniCPM minicpm-7b style configuration
    >>> configuration = MiniCPMConfig()

    >>> # Initializing a model from the minicpm-7b style configuration
    >>> model = MiniCPMModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "minicpm"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        scale_emb=1,
        dim_model_base=1,
        scale_depth=1,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.scale_emb = scale_emb
        self.dim_model_base = dim_model_base
        self.scale_depth = scale_depth

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        try:
            import flash_attn
            self._attn_implementation = "flash_attention_2"
        except:
            pass

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")

In [123]:
torch_config = MiniCPMConfig()
torch_mlp = TorchMiniCPMMLP(torch_config)
torch_input_mlp = torch.from_numpy(input_array)

torch_output_mlp = torch_mlp(torch_input_mlp)

In [127]:
print(np.allclose(ms_output_mlp.asnumpy(), torch_output_mlp.detach().numpy(), atol=1e-3))
print(np.allclose(ms_input_mlp.asnumpy(), torch_input_mlp.detach().numpy(), atol=1e-3))

False
True


In [125]:
print(ms_output_mlp)

[[[ 0.00555526  0.03499052  0.00351002 ... -0.03566707 -0.00674556
   -0.00390526]
  [ 0.00502021  0.04264032  0.04439148 ... -0.02678786  0.00661847
   -0.03573214]
  [ 0.00616076  0.01205319  0.04827144 ...  0.00361077 -0.01063555
   -0.01574899]
  ...
  [-0.02905967 -0.02110439  0.02138827 ... -0.00567385  0.00530072
   -0.03021811]
  [-0.02721993  0.00997222 -0.00347654 ... -0.00987197 -0.00353347
   -0.02058237]
  [ 0.00977214  0.02368069 -0.00850708 ... -0.05252867  0.00523038
   -0.03006126]]]


In [126]:
torch_output_mlp

tensor([[[ 0.0011,  0.0185,  0.0542,  ...,  0.0349,  0.0352, -0.0069],
         [-0.0247,  0.0302, -0.0036,  ...,  0.0202,  0.0188, -0.0153],
         [-0.0372, -0.0314, -0.0384,  ...,  0.0185,  0.0270, -0.0268],
         ...,
         [-0.0097, -0.0449, -0.0001,  ..., -0.0157,  0.0637, -0.0172],
         [ 0.0173,  0.0053, -0.0149,  ...,  0.0137,  0.0653, -0.0643],
         [-0.0006,  0.0005, -0.0208,  ...,  0.0389,  0.0164, -0.0161]]],
       grad_fn=<UnsafeViewBackward0>)

In [136]:
# PyTorch
import torch
from torch import nn
import numpy as np

input_array = np.load("./input_array.npy").astype(np.float32)

torch_net = nn.Linear(4096, 11008, bias=False)
x = torch.from_numpy(input_array)
torch_output = torch_net(x)
print(torch_output.detach().numpy())
# (2, 4)

# MindSpore
import mindspore
from mindspore import Tensor, nn
import numpy as np

x = mindspore.Tensor.from_numpy(input_array)
ms_net = nn.Dense(4096, 11008, has_bias=False)
ms_output = ms_net(x)
print(ms_output)
# (2, 4)

[[[ 0.7589559  -0.07527499 -0.35819474 ... -0.49136406  0.22672285
    0.03544176]
  [ 0.4955714  -0.2900765  -0.5850546  ... -0.6383949   0.276924
   -0.37251896]
  [ 0.8016432  -0.04432888 -0.61939305 ... -0.50569415  0.34891298
   -0.03819778]
  ...
  [ 0.76070076 -0.33570215 -0.5777197  ... -0.46006587  0.5309063
   -0.5314831 ]
  [ 0.9872962  -0.05854413 -0.63073397 ... -0.59783506  0.34394875
   -0.35045716]
  [ 0.73312736 -0.2669691  -0.4728214  ... -0.30531365  0.5313557
   -0.46166003]]]
[[[ 0.11332035  0.474159   -0.01539239 ...  0.10245368  0.05044889
   -0.250153  ]
  [ 0.10101779  0.01157233 -0.38048276 ...  0.11030661  0.25223908
   -0.32689023]
  [-0.02514513  0.1507451  -0.37613693 ...  0.39953572 -0.04060515
   -0.40098608]
  ...
  [-0.18186352  0.0609047  -0.32666564 ... -0.05875109  0.44832134
   -0.10809004]
  [ 0.0401681  -0.1633428  -0.5373877  ...  0.05558265  0.23133233
   -0.50210017]
  [-0.24035767  0.30383867 -0.5683611  ... -0.02691719  0.10517808
   -0.3336

In [133]:
print(ms_net.weight.value())

[[-0.00408853 -0.0072077  -0.01317369 ... -0.01395462 -0.0130742
  -0.01503178]
 [-0.01060252 -0.00302445 -0.00668621 ...  0.00487923  0.01343558
   0.00919498]
 [ 0.00601849 -0.01478642 -0.01136461 ... -0.0005769  -0.00788664
   0.01203256]
 ...
 [ 0.00654013  0.00516639 -0.00797721 ...  0.00618667  0.00966447
  -0.01506723]
 [-0.00121664 -0.00063749 -0.00711064 ... -0.01112679 -0.01172142
   0.0137059 ]
 [-0.00770224  0.00940624 -0.01486338 ... -0.00967028 -0.006234
  -0.00682555]]


In [134]:
print(torch_net.weight)

Parameter containing:
tensor([[-0.0004, -0.0023,  0.0156,  ...,  0.0138, -0.0082,  0.0101],
        [ 0.0138, -0.0060,  0.0135,  ..., -0.0085,  0.0012,  0.0061],
        [-0.0013,  0.0037,  0.0005,  ...,  0.0149,  0.0034,  0.0112],
        ...,
        [-0.0037, -0.0051, -0.0134,  ..., -0.0050, -0.0096,  0.0031],
        [-0.0104, -0.0088,  0.0093,  ...,  0.0087, -0.0028, -0.0149],
        [-0.0148, -0.0138, -0.0059,  ..., -0.0099, -0.0034, -0.0148]],
       requires_grad=True)


In [137]:
print(ms_output.shape)
print(torch_output.shape)

(1, 2048, 11008)
torch.Size([1, 2048, 11008])


In [139]:
from torch import nn
torch_net2 = nn.Linear(4096, 11008, bias=False)
print(torch_net2.weight)

Parameter containing:
tensor([[ 0.0102,  0.0066,  0.0002,  ..., -0.0097, -0.0008,  0.0030],
        [ 0.0154,  0.0099, -0.0091,  ...,  0.0092, -0.0035,  0.0094],
        [ 0.0108, -0.0143, -0.0117,  ..., -0.0146, -0.0131, -0.0018],
        ...,
        [ 0.0138,  0.0054, -0.0034,  ...,  0.0134,  0.0105, -0.0056],
        [ 0.0100, -0.0059,  0.0095,  ...,  0.0115, -0.0059, -0.0018],
        [ 0.0024,  0.0020, -0.0050,  ..., -0.0013, -0.0100,  0.0147]],
       requires_grad=True)


In [148]:
import mindspore
from mindspore import nn, ops
import numpy as np

hidden_size = 4096
num_attn_heads = 32
seq_len = 10
batch_size = 1
n_rep = 3
head_size = hidden_size // num_attn_heads

input_array_repeat_kv = np.random.rand(batch_size, num_attn_heads, seq_len, head_size).astype(np.float32)
ms_input_repeat_kv = mindspore.Tensor(input_array_repeat_kv)

def ms_repeat_kv(hidden_states: mindspore.Tensor, n_rep: int) -> mindspore.Tensor:
    """
    This is the equivalent of ops.repeat_interleave(input, repeats=n_rep, axis=1). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].broadcast_to((batch, num_key_value_heads, n_rep, slen, head_dim))
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

ms_output_repeat_kv = ms_repeat_kv(ms_input_repeat_kv, n_rep=3)

In [149]:
import torch
from torch import nn

def torch_repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

torch_input_repeat_kv = torch.from_numpy(input_array_repeat_kv)
torch_output_repeat_kv = torch_repeat_kv(torch_input_repeat_kv, n_rep=3)

In [150]:
print(np.allclose(ms_output_repeat_kv.asnumpy(), torch_output_repeat_kv.detach().numpy(), atol=1e-3))

True


In [152]:
import mindspore
from mindspore import nn

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 10, weight_init="normal", bias_init="zeros")
        )
        self.class_name = self.__class__.__name__

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        print(self.class_name)
        return logits