# src - utils

## Notebook运行提示
- 代码已拆分为多个小单元, 按顺序运行即可在每一步观察输出与中间变量。
- 涉及 `Path(__file__)` 或相对路径的脚本会自动注入 `__file__` 解析逻辑, Notebook 环境下也能引用原项目资源。
- 可在每个单元下追加说明或参数试验记录, 以跟踪核心算法和数据处理步骤。


In [None]:
from .custom_model import CustomModel
import os
import torch
from transformers import AutoConfig

In [None]:


def freeze(module):
    """Disable gradient tracking for every parameter yielded by ``module.parameters()``.

    The parameters are modified in place; nothing is returned.
    """
    for weight in module.parameters():
        weight.requires_grad_(False)

In [None]:


def get_backbone_config(config):
    """Return the backbone configuration described by ``config.model``.

    If ``config.model.backbone_config_path`` is the empty string, a config is
    created with ``AutoConfig.from_pretrained`` for ``config.model.backbone_type``
    (requesting ``output_hidden_states=True``) and its four dropout attributes
    are overwritten with the matching ``backbone_*`` values from ``config.model``.
    Otherwise the object stored at that path is loaded with ``torch.load`` and
    returned as-is.
    """
    model_settings = config.model
    saved_path = model_settings.backbone_config_path

    if saved_path != '':
        # NOTE(review): torch.load unpickles the file, so only trusted paths
        # should be configured here.
        return torch.load(saved_path)

    fresh_config = AutoConfig.from_pretrained(model_settings.backbone_type, output_hidden_states=True)

    # Each attribute on the backbone config is fed from the setting of the same
    # name prefixed with "backbone_"; the order matches the original assignments.
    dropout_names = (
        'hidden_dropout',
        'hidden_dropout_prob',
        'attention_dropout',
        'attention_probs_dropout_prob',
    )
    for name in dropout_names:
        setattr(fresh_config, name, getattr(model_settings, f'backbone_{name}'))

    return fresh_config

In [None]:


def update_old_state(state):
    """Convert a checkpoint that uses the old ``model.`` key prefix to ``backbone.``.

    Args:
        state: Mapping with a ``'model'`` entry (parameter name -> value) and a
            ``'predictions'`` entry. It is not modified.

    Returns:
        A new dict ``{'model': ..., 'predictions': ...}`` where every key of
        ``state['model']`` that starts with ``'model.'`` has that leading prefix
        replaced by ``'backbone.'``. All other keys, all values, and
        ``state['predictions']`` are carried over unchanged.

    Raises:
        KeyError: If ``state`` has no ``'model'`` or ``'predictions'`` entry.
    """
    old_prefix = 'model.'
    new_prefix = 'backbone.'

    # Only the leading prefix is rewritten. A plain str.replace would also
    # rewrite any later occurrence of "model" inside the key
    # (e.g. "model.encoder.model_proj.weight").
    renamed = {
        (new_prefix + key[len(old_prefix):] if key.startswith(old_prefix) else key): value
        for key, value in state['model'].items()
    }

    return {'model': renamed, 'predictions': state['predictions']}

In [None]:


def get_model(config, backbone_config_path=None, model_checkpoint_path=None, train=True):
    """Build a ``CustomModel`` and optionally load weights and apply training tweaks.

    Args:
        config: Project configuration; this function reads the
            ``gradient_checkpointing``, ``backbone_type``, ``freeze_embeddings``,
            ``freeze_n_layers`` and ``reinitialize_n_layers`` fields of
            ``config.model`` and passes ``config`` on to ``CustomModel``.
        backbone_config_path: If given, the backbone config is loaded from this
            path with ``torch.load``; otherwise ``get_backbone_config(config)``
            is used.
        model_checkpoint_path: If given, a checkpoint is loaded onto the CPU and
            its ``'model'`` entry is passed to ``model.load_state_dict``.
        train: When true, the freezing / re-initialisation options from
            ``config.model`` are applied.

    Returns:
        The constructed model.
    """
    if backbone_config_path is None:
        backbone_config = get_backbone_config(config)
    else:
        # NOTE(review): torch.load unpickles the file; use trusted paths only.
        backbone_config = torch.load(backbone_config_path)

    model = CustomModel(config, backbone_config=backbone_config)

    if model_checkpoint_path is not None:
        checkpoint = torch.load(model_checkpoint_path, map_location='cpu')
        # Checkpoints that still contain this key use the old "model." naming
        # and are converted before loading.
        if 'model.embeddings.position_ids' in checkpoint['model']:
            checkpoint = update_old_state(checkpoint)
        model.load_state_dict(checkpoint['model'])

    model_settings = config.model

    if model_settings.gradient_checkpointing:
        if not model.backbone.supports_gradient_checkpointing:
            print(f'{config.model.backbone_type} does not support gradient checkpointing')
        else:
            model.backbone.gradient_checkpointing_enable()

    # The remaining adjustments are applied only when train is truthy.
    if not train:
        return model

    if model_settings.freeze_embeddings:
        freeze(model.backbone.embeddings)

    frozen_layer_count = model_settings.freeze_n_layers
    if frozen_layer_count > 0:
        # Freeze the first N encoder layers.
        freeze(model.backbone.encoder.layer[:frozen_layer_count])

    reinit_layer_count = model_settings.reinitialize_n_layers
    if reinit_layer_count > 0:
        # Re-initialise the last N encoder layers.
        # NOTE(review): _init_weights receives each whole layer module here;
        # confirm that CustomModel._init_weights handles nested submodules.
        for top_layer in model.backbone.encoder.layer[-reinit_layer_count:]:
            model._init_weights(top_layer)

    return model