# src - custom_model

## Notebook运行提示
- 代码已拆分为多个小单元, 按顺序运行即可在每一步观察输出与中间变量。
- 涉及 `Path(__file__)` 或相对路径的脚本会自动注入 `__file__` 解析逻辑, Notebook 环境下也能引用原项目资源。
- 可在每个单元下追加说明或参数试验记录, 以跟踪核心算法和数据处理步骤。


In [None]:
import torch
import torch.nn as nn
from .pooling_layers import get_pooling_layer
from transformers import AutoModel, AutoConfig
from torch.utils.checkpoint import checkpoint

In [None]:


class CustomModel(nn.Module):
    """Transformer backbone followed by a pooling layer and a linear head.

    The backbone is built through ``transformers.AutoModel``, the pooling
    layer comes from ``get_pooling_layer`` and the head is a single
    ``nn.Linear`` with one output per entry in
    ``cfg.general.target_columns``.
    """

    def __init__(self, cfg, backbone_config):
        """Build the backbone, the pooling layer and the linear head.

        Args:
            cfg: Project configuration object. This constructor reads
                ``cfg.model.pretrained_backbone``, ``cfg.model.backbone_type``,
                ``cfg.tokenizer`` and ``cfg.general.target_columns``; it is
                also forwarded as-is to ``get_pooling_layer``.
            backbone_config: Configuration handed to ``AutoModel``. Its
                ``initializer_range`` is used as the standard deviation when
                the head is initialised.
        """
        super().__init__()
        self.cfg = cfg
        self.backbone_config = backbone_config

        # Load pretrained weights only when requested; otherwise build the
        # architecture from the config alone.
        if cfg.model.pretrained_backbone:
            self.backbone = AutoModel.from_pretrained(
                cfg.model.backbone_type, config=backbone_config
            )
        else:
            self.backbone = AutoModel.from_config(backbone_config)

        # Keep the embedding matrix in sync with the tokenizer's size
        # (presumably the tokenizer had tokens added -- confirm with caller).
        self.backbone.resize_token_embeddings(len(cfg.tokenizer))
        self.pool = get_pooling_layer(cfg, backbone_config)
        self.fc = nn.Linear(self.pool.output_dim, len(cfg.general.target_columns))

        self._init_weights(self.fc)

    def _init_weights(self, module):
        """Initialise ``module`` in place according to its type.

        * ``nn.Linear``: weight ~ N(0, initializer_range), bias (if any) = 0.
        * ``nn.Embedding``: weight ~ N(0, initializer_range), then the
          ``padding_idx`` row (if any) is zeroed.
        * ``nn.LayerNorm``: weight = 1 and bias = 0, for whichever of the two
          parameters exists.

        Any other module type is left untouched.

        Args:
            module: The module whose parameters should be initialised.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(
                module.weight, mean=0.0, std=self.backbone_config.initializer_range
            )
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(
                module.weight, mean=0.0, std=self.backbone_config.initializer_range
            )
            if module.padding_idx is not None:
                # In-place write on a parameter must not be tracked by autograd.
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) has neither weight nor bias
            # (and bias alone may be disabled), so guard both before touching.
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            if module.weight is not None:
                nn.init.ones_(module.weight)

    def forward(self, inputs):
        """Run backbone, pooling and the linear head.

        Args:
            inputs: Mapping that is unpacked as keyword arguments into the
                backbone and also passed unchanged to the pooling layer.

        Returns:
            The output of the linear head applied to the pooled feature.
        """
        backbone_outputs = self.backbone(**inputs)
        # The pooling layer receives both the raw inputs and the backbone
        # outputs (presumably to access the attention mask -- confirm in
        # pooling_layers).
        pooled = self.pool(inputs, backbone_outputs)
        return self.fc(pooled)