last_2_first_2_BERTweet_model.py
from typing import Tuple

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import RobertaConfig, RobertaModel, BertPreTrainedModel


class BERTweetModelForClassification(BertPreTrainedModel):
    """Binary classifier on top of BERTweet that pools the <s> (CLS) token
    from the last two and first two hidden layers."""

    base_model_prefix = "roberta"

    def __init__(self):
        self.num_labels: int = 2
        # output_hidden_states=True makes the encoder return the hidden
        # states of every layer, which forward() concatenates below.
        config: RobertaConfig = RobertaConfig.from_pretrained(
            "./BERTweet_base_transformers/config.json",
            output_hidden_states=True,
        )
        super().__init__(config)
        self.model: RobertaModel = RobertaModel.from_pretrained(
            "./BERTweet_base_transformers/model.bin",
            config=config,
        )
        # Classification head: four concatenated 768-dim CLS vectors
        # -> 768 -> 256 -> num_labels.
        self.dense = nn.Linear(in_features=768 * 4, out_features=768)
        self.dropout = nn.Dropout(p=0.1)
        self.dense_2 = nn.Linear(in_features=768, out_features=256)
        self.classifier = nn.Linear(in_features=256,
                                    out_features=self.num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
    ):
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
        )
        # With output_hidden_states=True, outputs[2] is a tuple holding the
        # embedding output followed by each encoder layer's hidden states.
        # Take the <s> (CLS) token from the last two and first two entries
        # and concatenate them into one feature vector.
        hidden_states: Tuple[torch.Tensor] = outputs[2]
        sequence_output: torch.Tensor = torch.cat((
            hidden_states[-1][:, 0, :],
            hidden_states[-2][:, 0, :],
            hidden_states[0][:, 0, :],
            hidden_states[1][:, 0, :],
        ), dim=1)
        sequence_output = self.dense(sequence_output)
        sequence_output = self.dropout(sequence_output)
        sequence_output = self.dense_2(sequence_output)
        sequence_output = self.dropout(sequence_output)
        logits: torch.Tensor = self.classifier(sequence_output)
        outputs = (logits,)
        if labels is not None:
            loss_function = CrossEntropyLoss()
            loss = loss_function(
                logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs
        return outputs  # (loss), logits
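
A minimal usage sketch (not part of the original file): it assumes the ./BERTweet_base_transformers/ directory with config.json and model.bin is present locally, exactly as the class above requires, and feeds dummy token IDs in place of real BERTweet BPE-tokenized input.

if __name__ == "__main__":
    model = BERTweetModelForClassification()
    model.eval()

    # Dummy batch: 2 sequences of 16 token IDs (illustrative values only;
    # real IDs would come from the BERTweet BPE tokenizer).
    input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    labels = torch.tensor([0, 1])

    with torch.no_grad():
        loss, logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
    print(loss.item(), logits.shape)  # scalar loss and (2, 2) logits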