In [29]:
import torch
from transformers import AutoModelForSequenceClassification

# Token-length cap used for tokenization below; it is also part of the checkpoint directory name.
MAX_LENGTH = 128
# Use the first CUDA GPU when one is available, otherwise run on CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
checkpoint = "./sougou_test_trainer_{}/checkpoint-96".format(MAX_LENGTH)
# Load the fine-tuned classifier from the local checkpoint directory, then move it to `device`.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
model = model.to(device)

In [30]:
from transformers import AutoTokenizer, DataCollatorWithPadding

# Load the tokenizer saved alongside the model checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=checkpoint)

In [31]:
import pandas as pd

# The CSV appears to have been written together with its pandas index (it shows up as an
# "Unnamed: 0" column); read that column back as the index instead of as a spurious data column.
test_df = pd.read_csv("./data/sougou/test.csv", index_col=0)

In [32]:
# Peek at the first five test rows.
test_df.head(n=5)

Unnamed: 0,text,label
0,届数比赛时间比赛地点参加国家和地区冠军亚军决赛成绩第一届1956-1957英国11美国丹麦6...,0
1,商品属性材质软橡胶带加浮雕工艺+合金彩色队徽吊牌规格162mm数量这一系列产品不限量发行图案...,0
2,今天下午，沈阳金德和长春亚泰队将在五里河相遇。在这两支球队中沈阳籍球员居多，因此这场比赛实际...,0
3,本报讯中国足协准备好了与特鲁西埃谈判的合同文本，也在北京给他预订好了房间，但特鲁西埃爽约了！...,0
4,网友点击发表评论祝贺中国队夺得五连冠搜狐体育讯北京时间5月6日，2006年尤伯杯羽毛球赛在日...,0


In [33]:
import numpy as np
import time

# Per-row latency benchmark of the fp32 model on `device` (a GPU when one is available).
# NOTE(review): the quantized benchmark below always runs on CPU, so the two latency
# series are only directly comparable when `device` is CPU.
# torch.no_grad(): inference does not need an autograd graph; building one per row wastes
# time and memory and inflates the measured latency.
# time.perf_counter() is a monotonic, high-resolution clock intended for measuring intervals.
true_labels, pred_labels = [], []

s_time = time.perf_counter()
with torch.no_grad():
    for i, row in test_df.iterrows():
        row_s_time = time.perf_counter()
        true_labels.append(row["label"])
        encoded_text = tokenizer(
            row["text"], max_length=MAX_LENGTH, truncation=True, padding=True, return_tensors="pt"
        ).to(device)
        logits = model(**encoded_text).logits  # shape (1, num_labels) for a single text
        # int() copies the scalar back to the host, which also synchronizes a CUDA device
        # before the per-row timer is read.
        label_id = int(logits.argmax(dim=-1)[0])
        pred_labels.append(label_id)
        print(i, (time.perf_counter() - row_s_time) * 1000, label_id)

print("avg time: ", (time.perf_counter() - s_time) * 1000 / test_df.shape[0])

0 229.3872833251953 0
1 243.35908889770508 0
2 214.42627906799316 0
3 199.4609832763672 0
4 204.45489883422852 0
5 214.42580223083496 0
6 206.4521312713623 0
7 213.43016624450684 0
8 208.44221115112305 0
9 225.39734840393066 0
10 199.4636058807373 0
11 214.46609497070312 0
12 197.432279586792 0
13 204.4544219970703 0
14 193.4809684753418 0
15 196.4743137359619 0
16 197.5078582763672 0
17 195.4798698425293 0
18 205.45053482055664 0
19 284.2395305633545 0
20 354.09021377563477 0
21 369.00925636291504 0
22 380.9833526611328 0
23 361.03320121765137 0
24 378.9849281311035 0
25 327.1212577819824 0
26 367.05589294433594 0
27 379.9445629119873 0
28 422.87611961364746 0
29 391.9501304626465 0
30 368.01648139953613 0
31 383.9724063873291 0
32 421.8719005584717 0
33 394.9441909790039 0
34 356.0502529144287 0
35 414.8907661437988 0
36 413.8920307159424 0
37 354.0537357330322 0
38 341.0918712615967 0
39 344.1135883331299 0
40 305.18531799316406 0
41 333.10723304748535 0
42 363.03091049194336 0
43 3

341 361.0348701477051 3
342 386.9631290435791 3
343 336.1067771911621 3
344 339.0941619873047 3
345 384.9678039550781 3
346 310.17065048217773 3
347 341.08543395996094 3
348 361.0367774963379 3
349 353.0540466308594 3
350 374.00150299072266 3
351 341.0828113555908 3
352 367.0172691345215 3
353 315.1590824127197 3
354 335.10398864746094 3
355 388.98158073425293 3
356 401.9613265991211 3
357 398.8938331604004 3
358 363.02947998046875 3
359 369.016170501709 3
360 365.0217056274414 3
361 400.9289741516113 3
362 365.02671241760254 3
363 345.0808525085449 3
364 466.7515754699707 3
365 399.9297618865967 3
366 360.03971099853516 3
367 391.94750785827637 3
368 387.96496391296387 3
369 428.85398864746094 3
370 384.96971130371094 3
371 388.9598846435547 3
372 349.06864166259766 3
373 330.11674880981445 3
374 354.0534973144531 3
375 356.0514450073242 3
376 372.0054626464844 3
377 374.00007247924805 3
378 349.1096496582031 3
379 342.08130836486816 3
380 359.0400218963623 3
381 391.9527530670166 3
3

In [34]:
# First ten ground-truth labels.
true_labels[0:10]

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [35]:
# First ten fp32 predictions, for a side-by-side glance with the labels above.
pred_labels[0:10]

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [36]:
from sklearn.metrics import classification_report

# Per-class precision/recall/F1 of the fp32 model, to 4 decimal places.
print(classification_report(y_true=true_labels, y_pred=pred_labels, digits=4))

              precision    recall  f1-score   support

           0     0.9900    1.0000    0.9950        99
           1     0.9691    0.9495    0.9592        99
           2     0.9900    1.0000    0.9950        99
           3     0.9320    0.9697    0.9505        99
           4     0.9895    0.9495    0.9691        99

    accuracy                         0.9737       495
   macro avg     0.9741    0.9737    0.9737       495
weighted avg     0.9741    0.9737    0.9737       495



In [37]:
# Model quantization: the quantized model below is created and run on this CPU device.
cpu_device = torch.device('cpu')

In [38]:
# Select the backend used by quantized kernels. 'x86' is preferred, but it is not listed in
# supported_engines on every PyTorch build, so fall back to 'fbgemm' and then 'qnnpack'.
for _engine in ("x86", "fbgemm", "qnnpack"):
    if _engine in torch.backends.quantized.supported_engines:
        torch.backends.quantized.engine = _engine
        break
else:
    raise RuntimeError(
        f"No usable quantized engine; supported: {torch.backends.quantized.supported_engines}"
    )

In [43]:
# 8-bit dynamic quantization of every nn.Linear layer (int8 weights).
# Quantize a private CPU copy: the quantized engine chosen above is a CPU backend, and
# `model` may live on a CUDA `device`. Quantizing it there and only moving the result to
# CPU afterwards is expected to fail when the weights are prepacked. `model` stays on
# `device` for the fp32 baseline.
import copy

quantized_model = torch.quantization.quantize_dynamic(
    copy.deepcopy(model).to(cpu_device),
    {torch.nn.Linear},
    dtype=torch.qint8,
    inplace=True,  # the argument is already a fresh copy; skip quantize_dynamic's own deepcopy
)
quantized_model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (key): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (value): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (dropout): Dropout(p=0.1, inplace=False)
            )
      

In [40]:
# Per-row latency benchmark of the int8 dynamically quantized model on CPU; mirrors the
# fp32 benchmark cell so the two latency series and reports can be compared.
# torch.no_grad(): inference does not need an autograd graph; building one per row wastes
# time and memory and inflates the measured latency.
# time.perf_counter() is a monotonic, high-resolution clock intended for measuring intervals.
q_true_labels, q_pred_labels = [], []

q_s_time = time.perf_counter()
with torch.no_grad():
    for i, row in test_df.iterrows():
        row_s_time = time.perf_counter()
        q_true_labels.append(row["label"])
        encoded_text = tokenizer(
            row["text"], max_length=MAX_LENGTH, truncation=True, padding=True, return_tensors="pt"
        ).to(cpu_device)
        logits = quantized_model(**encoded_text).logits  # shape (1, num_labels) for a single text
        label_id = int(logits.argmax(dim=-1)[0])
        q_pred_labels.append(label_id)
        print(i, (time.perf_counter() - row_s_time) * 1000, label_id)

print("avg time: ", (time.perf_counter() - q_s_time) * 1000 / test_df.shape[0])

0 195.47462463378906 0
1 214.42794799804688 0
2 187.49666213989258 0
3 229.38966751098633 0
4 210.43682098388672 0
5 206.44688606262207 0
6 229.3863296508789 0
7 194.48041915893555 0
8 194.47994232177734 0
9 223.4032154083252 0
10 196.47884368896484 0
11 201.45750045776367 0
12 211.43531799316406 0
13 200.46401023864746 0
14 221.40908241271973 0
15 202.45695114135742 0
16 200.46687126159668 0
17 222.40281105041504 0
18 196.47479057312012 0
19 217.4217700958252 0
20 230.38458824157715 0
21 185.50443649291992 0
22 238.36445808410645 0
23 241.35565757751465 0
24 212.4342918395996 0
25 228.3918857574463 0
26 223.402738571167 0
27 193.4826374053955 0
28 198.4689235687256 0
29 209.43951606750488 0
30 199.46789741516113 0
31 194.47946548461914 0
32 189.5456314086914 0
33 213.42802047729492 0
34 189.53251838684082 0
35 190.44995307922363 0
36 203.45640182495117 0
37 207.44705200195312 0
38 232.37872123718262 0
39 227.3998260498047 0
40 249.33409690856934 0
41 255.31530380249023 0
42 221.409797

337 195.44172286987305 3
338 195.47510147094727 3
339 235.37063598632812 3
340 193.4833526611328 3
341 226.3946533203125 3
342 212.4338150024414 3
343 206.44187927246094 3
344 178.51972579956055 3
345 217.41890907287598 3
346 198.4696388244629 3
347 223.4041690826416 3
348 200.46114921569824 3
349 204.45489883422852 3
350 198.4701156616211 3
351 197.47304916381836 3
352 206.44783973693848 3
353 190.4921531677246 3
354 192.48437881469727 3
355 207.44633674621582 3
356 217.41986274719238 3
357 207.4449062347412 3
358 220.4113006591797 3
359 191.48778915405273 3
360 207.44609832763672 3
361 191.48826599121094 3
362 209.4414234161377 3
363 216.42208099365234 3
364 325.1354694366455 3
365 186.50269508361816 3
366 232.38134384155273 3
367 195.4786777496338 3
368 200.46257972717285 3
369 194.48041915893555 3
370 208.44244956970215 3
371 179.52251434326172 3
372 208.44340324401855 3
373 193.4812068939209 3
374 188.49468231201172 3
375 208.44483375549316 3
376 198.4691619873047 3
377 202.458858

In [41]:
from sklearn.metrics import classification_report

# Per-class precision/recall/F1 of the int8 quantized model, to 4 decimal places.
print(classification_report(y_true=q_true_labels, y_pred=q_pred_labels, digits=4))

              precision    recall  f1-score   support

           0     0.9900    1.0000    0.9950        99
           1     0.9688    0.9394    0.9538        99
           2     0.9900    1.0000    0.9950        99
           3     0.9320    0.9697    0.9505        99
           4     0.9896    0.9596    0.9744        99

    accuracy                         0.9737       495
   macro avg     0.9741    0.9737    0.9737       495
weighted avg     0.9741    0.9737    0.9737       495



In [44]:
# Quantized-kernel backends available in this PyTorch build.
list(torch.backends.quantized.supported_engines)

['none', 'onednn', 'x86', 'fbgemm']

In [46]:
import os
import tempfile


def print_size_of_model(model):
    """Print the serialized size of ``model.state_dict()`` in MB (10**6 bytes).

    The state dict is written with ``torch.save`` into a private temporary directory, so
    no file in the working directory is overwritten. The file is also removed when
    ``torch.save`` raises.

    Args:
        model: a ``torch.nn.Module`` (fp32 or quantized).
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the original file name so the written archive matches what was measured before.
        path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), path)
        size_mb = os.path.getsize(path) / 1e6
    print("Size (MB): ", size_mb)

# Compare the serialized size of the fp32 baseline with its int8 dynamically quantized copy.
# Each size line is prefixed with a label so the two outputs can be told apart.
for label, candidate in (("fp32", model), ("int8 (dynamic)", quantized_model)):
    print(f"{label}:", end=" ")
    print_size_of_model(candidate)

Size (MB):  409.155273
Size (MB):  152.627621
