-
Notifications
You must be signed in to change notification settings - Fork 1
/
model.py
90 lines (70 loc) · 3.06 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from typing import Optional
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from ...utils import MODEL_MAP
from ....layers.pooling import Pooler
def get_auto_fc_tc_model(
    model_type: str = "bert",
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
) -> PreTrainedModel:
    """Build a sequence-classification model class for the given backbone.

    Args:
        model_type: Key into ``MODEL_MAP`` selecting the backbone family
            (e.g. ``"bert"``).
        output_attentions: Default passed to the backbone's forward pass to
            request attention weights.
        output_hidden_states: Default passed to the backbone's forward pass to
            request all hidden states.

    Returns:
        The dynamically created ``SequenceClassification`` class itself (not
        an instance); instantiate it with a config to obtain a model.
    """
    base_model, parent_model, base_model_name = MODEL_MAP[model_type]

    class SequenceClassification(parent_model):
        """Text-classification head on top of a transformer backbone.

        Args:
            config: Model configuration object (must provide ``num_labels``,
                ``hidden_size`` and ``hidden_dropout_prob``).
        """

        def __init__(self, config):
            super().__init__(config)
            self.config = config
            self.num_labels = config.num_labels
            self.pooler_type = getattr(config, 'pooler_type', 'cls')
            # Non-CLS pooling strategies consume the full stack of hidden
            # states, so force the backbone to return them.
            if self.pooler_type != "cls":
                self.config.output_hidden_states = True
            setattr(self, base_model_name, base_model(config, add_pooling_layer=False))
            # Use getattr so configs without a `classifier_dropout` attribute
            # (not merely ones where it is None) fall back to the generic
            # hidden dropout instead of raising AttributeError.
            classifier_dropout = getattr(config, "classifier_dropout", None)
            if classifier_dropout is None:
                classifier_dropout = config.hidden_dropout_prob
            self.dropout = nn.Dropout(classifier_dropout)
            self.pooling = Pooler(self.pooler_type)
            self.classifier = nn.Linear(config.hidden_size, config.num_labels)
            # Initialize weights and apply final processing
            self.post_init()

        def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            token_type_ids: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
        ) -> SequenceClassifierOutput:
            """Run the backbone, pool, classify, and (optionally) compute loss.

            Args:
                input_ids: Token ids for the batch.
                attention_mask: Mask of non-padding positions.
                token_type_ids: Segment ids (for models that use them).
                labels: Class indices; when given, cross-entropy loss is
                    included in the output.

            Returns:
                ``SequenceClassifierOutput`` with logits, optional loss, and
                the backbone's hidden states / attentions.
            """
            outputs = getattr(self, base_model_name)(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                # Factory-level defaults are captured from the enclosing scope.
                output_attentions=output_attentions,
                output_hidden_states=self.config.output_hidden_states or output_hidden_states,
            )
            pooled_output = self.dropout(self.pooling(outputs, attention_mask))
            logits = self.classifier(pooled_output)
            loss = self.compute_loss([logits, labels]) if labels is not None else None
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        def compute_loss(self, inputs):
            """Cross-entropy loss over flattened logits/labels.

            Args:
                inputs: Sequence whose first two items are ``logits`` of shape
                    ``(batch, num_labels)`` and integer ``labels``.
            """
            logits, labels = inputs[:2]
            loss_fct = CrossEntropyLoss()
            return loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

    return SequenceClassification
def get_fc_model_config(id2label, **kwargs):
    """Assemble the configuration dict for a text-classification model.

    Args:
        id2label: Mapping from label index to label name; its length sets
            ``num_labels``.
        **kwargs: Extra options merged on top of the defaults (same-named
            keys override the defaults).

    Returns:
        A dict of model configuration options.
    """
    defaults = {
        "num_labels": len(id2label),
        "pooler_type": "cls",
        "classifier_dropout": 0.3,
        "id2label": id2label,
    }
    # kwargs win over defaults, mirroring dict.update semantics.
    return {**defaults, **kwargs}