In [None]:
import math
import warnings
from typing import List, Optional, Tuple, Union
from safetensors import safe_open

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers import PreTrainedModel
from transformers import ACT2FN
from transformers import Cache, DynamicCache, StaticCache
from transformers import AttentionMaskConverter

from transformers import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from transformers import Phi3Config


if is_flash_attn_2_available():
    from transformers import _flash_attention_forward
from safetensors import safe_open

In [None]:
from model_ref import (
    _prepare_4d_causal_attention_mask_with_cache_position,
    Phi3RMSNorm,
    Phi3RotaryEmbedding,
    Phi3SuScaledRotaryEmbedding,
    Phi3YarnScaledRotaryEmbedding,
    Phi3LongRoPEScaledRotaryEmbedding,
    rotate_half,
    apply_rotary_pos_emb,
    Phi3MLP,
    repeat_kv,
    Phi3Attention,
    Phi3FlashAttention2,
    Phi3SdpaAttention,
    Phi3DecoderLayer,
    NewPhi3Config
)

In [None]:
class Phi3Head(nn.Module):
    """First stage of the split Phi-3 model: token embedding + the first ``head`` decoder layers.

    Args:
        tokenizer: Tokenizer object; only ``eos_token_id`` (used as the embedding
            padding index) and ``vocab_size`` are read from it.
        config: Model configuration. ``hidden_size``, ``embd_pdrop`` and
            ``_attn_implementation`` are read here; it is also forwarded to
            every ``Phi3DecoderLayer``.
        head: Number of decoder layers owned by this stage (global layers ``[0, head)``).
    """

    def __init__(self, tokenizer, config, head):
        super().__init__()
        self.tokenizer = tokenizer
        self.padding_idx = self.tokenizer.eos_token_id
        # NOTE(review): the embedding is sized from the tokenizer, not from
        # `config.vocab_size`; confirm this matches the checkpoint's
        # `model.embed_tokens.weight` shape before loading weights into it.
        self.vocab_size = self.tokenizer.vocab_size
        self.config = config
        self.head_length = head

        # Attribute name must be `embed_tokens`: the accessors and `forward` below use it.
        self.embed_tokens = nn.Embedding(self.vocab_size, self.config.hidden_size, self.padding_idx)
        # Kept for state/attribute parity; it is not applied in `forward`.
        self.embed_dropout = nn.Dropout(self.config.embd_pdrop)

        self.layers = nn.ModuleList(
            [Phi3DecoderLayer(self.config, layer_idx) for layer_idx in range(self.head_length)]
        )

        self._attn_implementation = self.config._attn_implementation
        self.gradient_checkpointing = False

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed_tokens = value

    def _update_causal_mask(
        self,
        attention_mask,
        input_tensor,
        cache_position,
        past_key_values,
        output_attentions=False,
    ):
        """Build the attention mask handed to the decoder layers.

        For ``flash_attention_2`` the 2D padding mask is passed through when it
        actually masks something, otherwise ``None``. For every other
        implementation a 4D additive causal mask is built with
        ``_prepare_4d_causal_attention_mask_with_cache_position``.
        """
        if self._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        batch_size, sequence_length = input_tensor.shape[0], input_tensor.shape[1]

        if past_key_values is not None and isinstance(past_key_values, StaticCache):
            target_length = past_key_values.get_max_length()
        elif isinstance(attention_mask, torch.Tensor):
            target_length = attention_mask.shape[-1]
        else:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            target_length = past_seen_tokens + sequence_length

        causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            min_dtype=min_dtype,
            cache_position=cache_position,
            batch_size=batch_size,
        )

        if (
            self._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Rows that are fully masked (e.g. left padding) must attend to
            # everything, otherwise SDPA's memory-efficient path can yield NaNs.
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: bool = False,
    ):
        """Embed ``input_ids`` and run them through this stage's decoder layers.

        Args:
            input_ids: Token ids; indexed as ``(batch, seq)`` by the code below.
            attention_mask: Optional padding mask forwarded to the mask builder.
            position_ids: Optional positions; defaults to ``cache_position.unsqueeze(0)``.
            past_key_values: Optional cache object; ``get_seq_length()`` is called
                on it, so a ``Cache`` instance is expected despite the annotation.
            cache_position: Optional absolute positions of the current tokens;
                derived from the cache length when omitted.
            output_attentions: Only affects how the causal mask is post-processed.

        Returns:
            Tuple ``(hidden_states, causal_mask, position_ids, cache_position,
            next_decoder_cache)`` consumed by the next stage.
        """
        inputs_embeds = self.embed_tokens(input_ids)

        # Honour a caller-supplied cache_position instead of overwriting it.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds
        next_decoder_cache = None

        for decoder_layer in self.layers:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                cache_position=cache_position,
            )

            hidden_states = layer_outputs[0]
            # The layer may return only the hidden states (no cache entry);
            # guard the index instead of assuming a second element.
            if len(layer_outputs) > 1:
                next_decoder_cache = layer_outputs[1]

        return (hidden_states, causal_mask, position_ids, cache_position, next_decoder_cache)

In [None]:
class Phi3Body(nn.Module):
    """Middle stage of the split Phi-3 model: decoder layers ``[head, body)``.

    Args:
        config: Model configuration forwarded to every ``Phi3DecoderLayer``;
            ``_attn_implementation`` is read here.
        head: Global index of the first layer owned by this stage.
        body: Global index one past the last layer owned by this stage.
    """

    def __init__(self, config, head, body):
        super().__init__()
        self.config = config
        self.body_length = body - head

        # Layers get their *global* index so that, when a shared key/value cache
        # is used, this stage does not write into the slots of the head's layers.
        # (Assumes Phi3DecoderLayer uses layer_idx to address the cache - TODO confirm.)
        # Positions inside the ModuleList are still local: 0 .. body_length - 1.
        self.layers = nn.ModuleList(
            [Phi3DecoderLayer(self.config, layer_idx) for layer_idx in range(head, body)]
        )

        self._attn_implementation = self.config._attn_implementation
        self.gradient_checkpointing = False

    def forward(
        self,
        head_output,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
    ):
        """Continue the forward pass from the previous stage's output.

        Args:
            head_output: Tuple ``(hidden_states, causal_mask, position_ids,
                cache_position, next_decoder_cache)`` produced by the previous stage.
            past_key_values: Optional cache forwarded to every decoder layer.

        Returns:
            A tuple with the same five-slot layout, with ``hidden_states`` (and
            the cache entry, when a layer returns one) updated.
        """
        hidden_states, causal_mask, position_ids, cache_position, next_decoder_cache = head_output[:5]

        for decoder_layer in self.layers:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                cache_position=cache_position,
            )

            hidden_states = layer_outputs[0]
            # The layer may return only the hidden states; keep the previous
            # cache reference in that case instead of raising IndexError.
            if len(layer_outputs) > 1:
                next_decoder_cache = layer_outputs[1]

        # Mask / positions are passed through unchanged for the next stage.
        return (hidden_states, causal_mask, position_ids, cache_position, next_decoder_cache)

In [None]:
class Phi3Tail(nn.Module):
    """Last stage of the split Phi-3 model: layers ``[body, num_hidden_layers)``, final norm and LM head.

    Args:
        config: Model configuration. ``vocab_size``, ``hidden_size``,
            ``num_hidden_layers``, ``rms_norm_eps`` and ``_attn_implementation``
            are read here; it is also forwarded to every ``Phi3DecoderLayer``.
        body: Global index of the first layer owned by this stage.
    """

    def __init__(self, config, body):
        super().__init__()
        self.config = config
        self.vocab_size = self.config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.tail_length = self.config.num_hidden_layers - body

        # Layers get their *global* index so a shared key/value cache is not
        # overwritten by this stage (assumes Phi3DecoderLayer addresses the
        # cache by layer_idx - TODO confirm). ModuleList positions stay local.
        self.layers = nn.ModuleList(
            [Phi3DecoderLayer(self.config, layer_idx) for layer_idx in range(body, self.config.num_hidden_layers)]
        )
        # Final normalisation applied before the LM head; `forward` relies on it.
        # NOTE(review): assumes Phi3RMSNorm(hidden_size, eps=...) - confirm against model_ref.
        self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self._attn_implementation = self.config._attn_implementation
        self.gradient_checkpointing = False

    def forward(
        self,
        body_output,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
    ):
        """Finish the forward pass and project to vocabulary logits.

        Args:
            body_output: Tuple ``(hidden_states, causal_mask, position_ids,
                cache_position, next_decoder_cache)`` produced by the previous stage.
            past_key_values: Optional cache forwarded to every decoder layer.

        Returns:
            Tuple ``(logits, next_decoder_cache)``; ``logits`` is cast to float32.
        """
        hidden_states, causal_mask, position_ids, cache_position, next_decoder_cache = body_output[:5]

        for decoder_layer in self.layers:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                cache_position=cache_position,
            )

            hidden_states = layer_outputs[0]
            # The layer may return only the hidden states; keep the previous
            # cache reference in that case instead of raising IndexError.
            if len(layer_outputs) > 1:
                next_decoder_cache = layer_outputs[1]

        hidden_states = self.norm(hidden_states)

        logits = self.lm_head(hidden_states)
        logits = logits.float()

        return (logits, next_decoder_cache)

In [None]:
class CustomedPhi3ForCausalLM(PreTrainedModel):
    """Phi-3 causal LM executed as three stages (head / body / tail).

    The head stage is kept on the instance; the body and tail stages are built
    inside ``forward`` and filled from the safetensors shards on every call.

    Args:
        tokenizer: Tokenizer handed to ``Phi3Head``.
        config: Model configuration; must expose ``head`` and ``body`` (global
            layer indices where the body and the tail start) and
            ``num_hidden_layers``.
        file_path: Stored on the instance as ``self.file_path``.
            NOTE(review): it is not used to locate the shards, which come from
            ``_SHARD_PATH_TEMPLATE`` - confirm the intended meaning.
    """

    _tied_weights_keys = ["lm_head.weight"]

    # Location and count of the checkpoint shards (shards are numbered from 1).
    _SHARD_PATH_TEMPLATE = '/nas/user/hayoung/model-0000{}-of-00006.safetensors'
    _NUM_SHARDS = 6
    _LAYER_KEY_PREFIX = "model.layers."
    # Per-layer parameters copied from the checkpoint.
    _LAYER_PARAM_SUFFIXES = (
        '.input_layernorm.weight',
        '.mlp.down_proj.weight',
        '.mlp.gate_up_proj.weight',
        '.post_attention_layernorm.weight',
        '.self_attn.o_proj.weight',
        '.self_attn.qkv_proj.weight',
    )

    def __init__(self, tokenizer, config, file_path):
        # PreTrainedModel requires the config in its constructor.
        super().__init__(config)
        self.tokenizer = tokenizer
        self.config = config
        self.head = self.config.head
        self.body = self.config.body
        self.Head_Model = Phi3Head(self.tokenizer, self.config, self.head)
        self.file_path = file_path
        # The head stage is persistent, so its weights only need loading once.
        self._head_weights_loaded = False

    def _wanted_keys(self, start, end):
        """Map checkpoint keys for layers ``[start, end)`` to the partial model's local keys.

        The checkpoint names layers globally (``model.layers.<i>...``) while a
        stage's ``state_dict`` names them by their position in its own
        ``layers`` list (``layers.<i - start>...``).
        """
        pairs = []
        if start == 0:
            pairs.append(('model.embed_tokens.weight', 'embed_tokens.weight'))
        for i in range(start, end):
            for suffix in self._LAYER_PARAM_SUFFIXES:
                pairs.append((self._LAYER_KEY_PREFIX + str(i) + suffix, "layers." + str(i - start) + suffix))
        if end == self.config.num_hidden_layers:
            pairs.append(('model.norm.weight', 'norm.weight'))
            pairs.append(('lm_head.weight', 'lm_head.weight'))
        return pairs

    def _copy_from_shard(self, file_num, partial_model, start, end):
        """Copy every wanted tensor present in shard ``file_num`` into ``partial_model``.

        Returns:
            ``(loaded, failed)``: ``loaded`` is the list of checkpoint keys that
            were copied, ``failed`` is a list of ``(checkpoint_key, file_num)``
            for keys that are absent from the shard, absent from the module, or
            whose shape does not match.
        """
        file_path = self._SHARD_PATH_TEMPLATE.format(file_num)
        target = partial_model.state_dict()
        # Read tensors straight onto the device the module lives on.
        first_param = next(partial_model.parameters(), None)
        device = str(first_param.device) if first_param is not None else "cpu"

        loaded = []
        failed = []
        with safe_open(file_path, framework="pt", device=device) as f:
            available = set(f.keys())
            for ckpt_key, local_key in self._wanted_keys(start, end):
                if ckpt_key not in available or local_key not in target:
                    failed.append((ckpt_key, file_num))
                    continue
                tensor = f.get_tensor(ckpt_key)
                if tensor.shape != target[local_key].shape:
                    failed.append((ckpt_key, file_num))
                    continue
                # state_dict() tensors share storage with the parameters, so
                # the in-place copy updates the module itself.
                target[local_key].copy_(tensor)
                loaded.append(ckpt_key)
        return loaded, failed

    def load_weights(self, file_num, partial_model, start, end):
        """Load decoder layers ``[start, end)`` from one external checkpoint shard into ``partial_model``.

        The embedding is included when ``start == 0``; the final norm and LM
        head are included when ``end`` equals ``config.num_hidden_layers``.

        Args:
            file_num: Shard number substituted into the shard path template.
            partial_model: Stage module (head / body / tail) that receives the weights.
            start: Global index of the first layer to load.
            end: Global index one past the last layer to load.

        Returns:
            List of ``(checkpoint_key, file_num)`` entries that could not be
            loaded from this shard. The list is also printed.
        """
        _, failed_name = self._copy_from_shard(file_num, partial_model, start, end)
        print(failed_name)
        return failed_name

    def _load_layer_range(self, partial_model, start, end):
        """Fill ``partial_model`` with layers ``[start, end)``, searching every shard."""
        pending = {ckpt_key for ckpt_key, _ in self._wanted_keys(start, end)}
        for file_num in range(1, self._NUM_SHARDS + 1):
            if not pending:
                break
            loaded, _ = self._copy_from_shard(file_num, partial_model, start, end)
            pending.difference_update(loaded)
        if pending:
            warnings.warn(
                "Weights not loaded for layers [{}, {}): {}".format(start, end, sorted(pending))
            )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        """Run head, body and tail in sequence, loading each stage's weights on demand.

        ``inputs_embeds``, ``labels``, ``use_cache``, ``output_attentions``,
        ``output_hidden_states`` and ``return_dict`` are accepted for interface
        compatibility but are not used.

        Returns:
            Tuple ``(logits, next_decoder_cache)`` as produced by ``Phi3Tail``.
        """
        if not self._head_weights_loaded:
            self._load_layer_range(self.Head_Model, 0, self.head)
            self._head_weights_loaded = True
        head_output = self.Head_Model(input_ids, attention_mask, position_ids, past_key_values, cache_position)
        hidden = head_output[0]

        body_model = Phi3Body(self.config, self.head, self.body)
        # A freshly built stage is on CPU / default dtype / train mode; align it
        # with the running model before loading weights into it.
        body_model.to(device=hidden.device, dtype=hidden.dtype)
        body_model.train(self.training)
        self._load_layer_range(body_model, self.head, self.body)
        body_output = body_model(head_output, past_key_values)
        # Release the previous stage before building the next one.
        del head_output, hidden, body_model

        tail_model = Phi3Tail(self.config, self.body)
        tail_model.to(device=body_output[0].device, dtype=body_output[0].dtype)
        tail_model.train(self.training)
        self._load_layer_range(tail_model, self.body, self.config.num_hidden_layers)
        output = tail_model(body_output, past_key_values)
        del body_output, tail_model

        # The tail already applied the final norm and the LM head.
        return (output[0], output[1])

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        Returns:
            Dict with ``input_ids``, ``inputs_embeds`` (always ``None``),
            ``position_ids``, ``cache_position``, ``past_key_values``,
            ``use_cache`` and ``attention_mask``.
        """
        if past_key_values is not None:
            # With a cache, keep only the tokens that have not been processed yet.
            if input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s  `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        if isinstance(past_key_values, StaticCache) and attention_mask is not None and attention_mask.ndim == 2:
            batch_size, sequence_length = model_inputs["input_ids"].shape
            device = model_inputs["input_ids"].device

            # This class has no top-level `lm_head`; use the model's parameter dtype.
            dtype = self.dtype
            min_dtype = torch.finfo(dtype).min

            attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_length(),
                dtype=dtype,
                device=device,
                min_dtype=min_dtype,
                cache_position=cache_position,
                batch_size=batch_size,
            )

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs