In [3]:
import copy
import numpy as np
import mlx
import mlx.nn as mx
import mlx.core as mx_core
import torch
import torch.nn as nn
import torch.nn.functional as F
import safetensors
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from huggingface_hub import notebook_login
from pprint import pprint
from source.gemma.model import *
from source.gemma.model_mlx import *

# 테스트 모델 Gemma-2b-it 업로드
path = "./model/gemma-1.1-2b-it/"
model_name = "model-{}-of-{}.safetensors"
model1 = safetensors.safe_open(path+model_name.format("00001", "00002"), framework="pt")
model2 = safetensors.safe_open(path+model_name.format("00002", "00002"), framework="pt")

print(model1.keys())
print(model2.keys())

# model1, model2의 tensor들을 답음
tensors = {}
for key in model1.keys():
    tensors[key] = model1.get_tensor(key)
for key in model2.keys():
    tensors[key] = model2.get_tensor(key)
print("전체 키의 수: ", len(tensors))


# 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
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
input_text = "Write me a poem about Machine Learning test."
result = tokenizer(input_text, return_tensors="pt")
result = result["input_ids"]
print(result.shape)

['model.embed_tokens.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.laye

In [207]:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """
    Precomputes the frequency cis.
    """
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    # 시간 벡터와 주파수 성분을 외저갛여 각 위치와 주파수 인덱스에 대한 2D 행렬 연산
    freqs     = torch.outer(t, freqs).float()
    # freqs 크기 1 텐서에 1 * cos(freqs) + 1 * sin(freqs)를 연산 -> complex64
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def MLXprecompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> mx_core.array:
    """
    Precomputes the frequencey cis.
    """
    freqs = 1.0 / (theta ** (mx_core.arange(0, dim, 2)[:(dim // 2)].astype(mx_core.float16) / dim))
    t = mx_core.arange(end)
    freqs = mx_core.outer(t, freqs)
    cos = mx_core.ones_like(freqs) * mx_core.cos(freqs)
    sin = mx_core.ones_like(freqs) * mx_core.sin(freqs)
    freq_cis = mx_core.stack([cos, sin], axis = -1)
    return freq_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """
    Applies the rotary embedding to the query and key tensors.
    """
    # [B, C, H, W] -> [B, H, C, W] -> [B, H, C, W/2] & [B, H, C, W / 2]
    # torch.view_as_complex with dim = -1를 통해 [B, H, C, W/2, 0]은 실수, [B, H, C, W/2, 1]은 허수 
    x_ = torch.view_as_complex(torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x) ## 여기까지 결과가 같음

    # 해결해야 할 부분
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], -1).transpose(1, 2)
    return x_out


def MLXapply_rotary_emb(x: mx_core.array, freqs_cis: mx_core.array) -> mx_core.array:
    x_transpose = x.transpose(0, 2, 1, 3).astype(mx_core.float16)
    x_real = x_transpose[:, :, :, :2]
    x_imag = x_transpose[:, :, :, 2:]
    x_ = mx_core.stack([x_real, x_imag], axis = -1)

    x_out_real = x_[:, :, :, :, 0] * freqs_cis[:, :, 0] - x_[:, :, :, :, 1] * freqs_cis[:, :, 1]
    x_out_imag = x_[:, :, :, :, 0] * freqs_cis[:, :, 1] + x_[:, :, :, :, 1] * freqs_cis[:, :, 0]
    x_out = mx_core.stack([x_out_real, x_out_imag], axis = -1) ## 여기까지 결과가 같음

    # 해결해야 할 부분
    x_out = mx_core.stack([x_out[:, :, :, :, 0], x_out[:, :, :, :, 1]], axis = 2)

    



dim = 4
end = 4
theta = 10000.0
x = torch.randn(1, 1, 4, 4)
freq1 = precompute_freqs_cis(dim, end, theta)
freq2 = MLXprecompute_freqs_cis(dim, end, theta)


outputs = apply_rotary_emb(x, freq1)
MLXapply_rotary_emb(mx_core.array(x.numpy()), freq2)

torch.Size([1, 4, 4, 4, 1])
(1, 4, 2, 4, 2)
