You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
from bert4torch.tokenizers import Tokenizer
from bert4torch.models import build_transformer_model, BaseModel
from bert4torch.snippets import sequence_padding, Callback, text_segmentate, ListDataset
import torch.nn as nn
import torch
import torch.optim as optim
import random, os, numpy as np
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
class MyDataset(ListDataset):
    """Dataset of ``(text, label)`` samples read from tab-separated files."""

    @staticmethod
    def load_data(filenames):
        """Load the data and split it, where possible, into pieces no longer than maxlen.

        Args:
            filenames: Iterable of file paths. Every non-blank line is expected
                to have the form ``text<TAB>label`` with an integer label.

        Returns:
            A list of ``(segment, label)`` tuples, one per piece produced by
            ``text_segmentate`` for each input line.

        Raises:
            ValueError: If a non-blank line does not contain exactly one tab
                or its label is not an integer.
        """
        # NOTE(review): the repeated ASCII marks in these separator strings look
        # like full-width punctuation that was normalised during copy/paste --
        # confirm against the upstream example before relying on them.
        seps, strips = u'\n。!?!?;;,, ', u';;,, '
        samples = []
        for filename in filenames:
            with open(filename, encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # A blank (e.g. trailing) line has no label to unpack.
                        continue
                    text, label = line.split('\t')
                    label = int(label)
                    # maxlen - 2 presumably leaves room for two special tokens
                    # added by the tokenizer -- TODO confirm.
                    for segment in text_segmentate(text, maxlen - 2, seps, strips):
                        samples.append((segment, label))
        return samples
def evaluate(data):
    """Return the classification accuracy of the global ``model`` over ``data``.

    Args:
        data: Iterable yielding ``(x_true, y_true)`` batches. ``x_true`` is
            passed to ``model.predict``; ``y_true`` is compared element-wise
            with the arg-max of the prediction along axis 1.

    Returns:
        The fraction of correctly predicted labels as a float, or ``0.`` when
        ``data`` yields no samples.
    """
    total, right = 0., 0.
    for x_true, y_true in data:
        y_pred = model.predict(x_true).argmax(axis=1)
        total += len(y_true)
        right += (y_true == y_pred).sum().item()
    # An empty loader would otherwise raise ZeroDivisionError.
    return right / total if total else 0.
class Evaluator(Callback):
    """Callback for evaluation and saving (per the original docstring).

    Only the constructor is visible in this snippet; it tracks the best
    validation accuracy seen so far.
    """

    def __init__(self):
        # The pasted code had ``init`` (underscores stripped by markdown), so
        # this never ran as the constructor and ``best_val_acc`` was never set.
        super().__init__()
        self.best_val_acc = 0.
比如这个代码case:
你能提供预测新的文本的inference代码吗?
#! -*- coding:utf-8 -*-
# Sentiment classification task, loading BERT weights
# valid_acc: 94.72, test_acc: 94.11
from bert4torch.tokenizers import Tokenizer
from bert4torch.models import build_transformer_model, BaseModel
from bert4torch.snippets import sequence_padding, Callback, text_segmentate, ListDataset
import torch.nn as nn
import torch
import torch.optim as optim
import random, os, numpy as np
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
# Token-length cap used for segmentation and encoding, and the mini-batch size
maxlen = 256
batch_size = 16

# Locations of the pretrained Chinese BERT config, weights and vocabulary
config_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/vocab.txt'

# Use the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Summary writer logging under ./summary
writer = SummaryWriter(log_dir='./summary')
# Fix the random seeds so that runs are reproducible
seed = 42
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Also seed every visible GPU, not only the current one; this is silently
# ignored when CUDA is unavailable.
torch.cuda.manual_seed_all(seed)
# Build the tokenizer from the vocabulary file (lower-casing enabled)
tokenizer = Tokenizer(dict_path, do_lower_case=True)
# Dataset definition
class MyDataset(ListDataset):
    """Dataset of ``(text, label)`` samples read from tab-separated files."""

    @staticmethod
    def load_data(filenames):
        """Load the data and split it, where possible, into pieces no longer than maxlen.

        Args:
            filenames: Iterable of file paths. Every non-blank line is expected
                to have the form ``text<TAB>label`` with an integer label.

        Returns:
            A list of ``(segment, label)`` tuples, one per piece produced by
            ``text_segmentate`` for each input line.

        Raises:
            ValueError: If a non-blank line does not contain exactly one tab
                or its label is not an integer.
        """
        # NOTE(review): the repeated ASCII marks in these separator strings look
        # like full-width punctuation that was normalised during copy/paste --
        # confirm against the upstream example before relying on them.
        seps, strips = u'\n。!?!?;;,, ', u';;,, '
        samples = []
        for filename in filenames:
            with open(filename, encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # A blank (e.g. trailing) line has no label to unpack.
                        continue
                    text, label = line.split('\t')
                    label = int(label)
                    # maxlen - 2 presumably leaves room for two special tokens
                    # added by the tokenizer -- TODO confirm.
                    for segment in text_segmentate(text, maxlen - 2, seps, strips):
                        samples.append((segment, label))
        return samples
def collate_fn(batch):
    """Collate ``(text, label)`` samples into padded tensors on ``device``.

    Args:
        batch: List of ``(text, label)`` tuples as produced by the dataset.

    Returns:
        A pair ``([token_ids, segment_ids], labels)``: the two id tensors are
        padded with ``sequence_padding`` and ``labels`` is flattened to one
        dimension so it can be compared element-wise with per-sample
        predictions.
    """
    batch_token_ids, batch_segment_ids, batch_labels = [], [], []
    for text, label in batch:
        token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
        batch_token_ids.append(token_ids)
        batch_segment_ids.append(segment_ids)
        batch_labels.append([label])

    # NOTE(review): the pasted function ended after the loop and therefore
    # returned None. The tail below is reconstructed from the otherwise unused
    # ``sequence_padding`` / ``device`` names and from ``evaluate`` unpacking
    # ``(x_true, y_true)`` -- confirm against the upstream example.
    batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
    batch_segment_ids = torch.tensor(sequence_padding(batch_segment_ids), dtype=torch.long, device=device)
    batch_labels = torch.tensor(batch_labels, dtype=torch.long, device=device)
    return [batch_token_ids, batch_segment_ids], batch_labels.flatten()
# Build the train / validation / test data loaders (only training is shuffled)
train_dataloader = DataLoader(
    MyDataset(['E:/Github/bert4torch/examples/datasets/sentiment/sentiment.train.data']),
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)
valid_dataloader = DataLoader(
    MyDataset(['E:/Github/bert4torch/examples/datasets/sentiment/sentiment.valid.data']),
    batch_size=batch_size,
    collate_fn=collate_fn,
)
test_dataloader = DataLoader(
    MyDataset(['E:/Github/bert4torch/examples/datasets/sentiment/sentiment.test.data']),
    batch_size=batch_size,
    collate_fn=collate_fn,
)
# Model structure built on top of BERT
class Model(BaseModel):
    """BERT encoder followed by dropout and a 2-way linear classification layer."""

    def __init__(self) -> None:
        # The pasted code had ``init`` / ``super().init()`` (underscores
        # stripped by markdown), so none of the sub-modules were ever created.
        super().__init__()
        self.bert, self.config = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, with_pool=True, return_model_config=True)
        self.dropout = nn.Dropout(0.1)
        self.dense = nn.Linear(self.config['hidden_size'], 2)

    # NOTE(review): no ``forward`` method is visible in this snippet; one must
    # be defined before the model can be trained or used for prediction.
model = Model().to(device)

# Define the loss and optimizer to use; custom ones are supported here
model.compile(
    loss=nn.CrossEntropyLoss(),
    optimizer=optim.Adam(model.parameters(), lr=2e-5),  # use a sufficiently small learning rate
    metrics=['accuracy']
)
# Evaluation function
def evaluate(data):
    """Return the classification accuracy of the global ``model`` over ``data``.

    Args:
        data: Iterable yielding ``(x_true, y_true)`` batches. ``x_true`` is
            passed to ``model.predict``; ``y_true`` is compared element-wise
            with the arg-max of the prediction along axis 1.

    Returns:
        The fraction of correctly predicted labels as a float, or ``0.`` when
        ``data`` yields no samples.
    """
    total, right = 0., 0.
    for x_true, y_true in data:
        y_pred = model.predict(x_true).argmax(axis=1)
        total += len(y_true)
        right += (y_true == y_pred).sum().item()
    # An empty loader would otherwise raise ZeroDivisionError.
    return right / total if total else 0.
class Evaluator(Callback):
    """Callback for evaluation and saving (per the original docstring).

    Only the constructor is visible in this snippet; it tracks the best
    validation accuracy seen so far.
    """

    def __init__(self):
        # The pasted code had ``init`` (underscores stripped by markdown), so
        # this never ran as the constructor and ``best_val_acc`` was never set.
        super().__init__()
        self.best_val_acc = 0.
# Train when run as a script; when imported, load previously saved weights.
# The pasted code compared ``name == 'main'`` (underscores stripped by
# markdown), which raises NameError instead of detecting script execution.
if __name__ == '__main__':
    evaluator = Evaluator()
    model.fit(train_dataloader, epochs=10, steps_per_epoch=None, callbacks=[evaluator])
else:
    model.load_weights('best_model.pt')
The text was updated successfully, but these errors were encountered: