In [33]:
# NOTE(review): no visible cell in this notebook plots with matplotlib; this magic
# looks unnecessary — confirm before removing.
%matplotlib inline

把新闻报道文件放入以分类名称命名的文件夹，例如体育类新闻放入 `sport` 文件夹。然后运行 `build-sent-datasets.py --sr` 构建训练和测试数据，最终生成 `train.csv` 和 `test.csv` 两个文件。

In [75]:
# --- Imports: standard library, third-party, then the dataset helper ---
import os

import torch
import torchtext

import chinese_news_dataset

# --- Configuration ---
# Each (Chinese) token produced by the tokenizer is treated as one word,
# so only unigrams are used.
NGRAMS = 1
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load the train/test splits from ./data/chinese_news ---
# vocab=None: presumably lets CN_NEWS build its own vocabulary — confirm in
# chinese_news_dataset.
train_dataset, test_dataset = chinese_news_dataset.CN_NEWS(
    root='./data/chinese_news', ngrams=NGRAMS, vocab=None)

2066lines [00:11, 183.59lines/s]
2066lines [00:11, 177.05lines/s]
1033lines [00:05, 184.54lines/s]


In [76]:
import torch.nn as nn
import torch.nn.functional as F
class TextSentiment(nn.Module):
    """Bag-of-embeddings text classifier.

    The token ids of each sample are pooled by an ``nn.EmbeddingBag`` (its
    default mode, with sparse gradients). A single linear layer then projects
    the pooled vector to ``num_class`` logits.
    """

    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Draw embedding and linear weights from U(-0.5, 0.5) and zero the bias."""
        bound = 0.5
        nn.init.uniform_(self.embedding.weight, -bound, bound)
        nn.init.uniform_(self.fc.weight, -bound, bound)
        nn.init.zeros_(self.fc.bias)

    def forward(self, text, offsets):
        """Return class logits, one row per bag described by ``offsets``.

        ``text`` holds the concatenated token ids of every sample in the batch.
        ``offsets`` holds the start index of each sample inside ``text``.
        """
        pooled = self.embedding(text, offsets)
        return self.fc(pooled)

In [77]:
# Model hyper-parameters derived from the training data.
VOCAB_SIZE = len(train_dataset.get_vocab())
print(VOCAB_SIZE)
EMBED_DIM = 32
NUM_CLASS = len(train_dataset.get_labels())
# Backward-compatible alias for the original misspelled name.
NUN_CLASS = NUM_CLASS
print(NUM_CLASS)
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)

67872
3


In [78]:
def generate_batch(batch):
    """Collate ``(label, token_ids)`` samples into EmbeddingBag inputs.

    Args:
        batch: sequence of samples. ``sample[0]`` is the label and
            ``sample[1]`` is a tensor of token ids (presumably 1-D).

    Returns:
        (text, offsets, label):
            text: all token ids concatenated into a single int64 tensor.
            offsets: the start index of each sample inside ``text``.
            label: the labels as a tensor.
    """
    labels = torch.tensor([sample[0] for sample in batch])
    pieces = [sample[1].long() for sample in batch]

    # Exclusive prefix sum: sample i starts where samples 0..i-1 end.
    lengths = torch.tensor([len(piece) for piece in pieces])
    offsets = lengths.cumsum(dim=0) - lengths

    return torch.cat(pieces), offsets, labels

In [73]:
from torch.utils.data import DataLoader

def train_func(sub_train_):
    """Run one training epoch of the notebook-level ``model`` over ``sub_train_``.

    Uses the globals ``model``, ``criterion``, ``optimizer``, ``scheduler``,
    ``BATCH_SIZE``, ``device`` and ``generate_batch``. The learning-rate
    scheduler is stepped once, at the end of the epoch.

    Returns:
        (mean_loss, accuracy):
            mean_loss: the loss averaged per sample.
            accuracy: the fraction of samples whose arg-max prediction equals
                the label (measured on the pre-update outputs of each batch).
    """
    model.train()
    total_loss = 0.0
    correct = 0
    loader = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,
                        collate_fn=generate_batch)
    for text, offsets, cls in loader:
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        optimizer.zero_grad()
        output = model(text, offsets)
        loss = criterion(output, cls)
        loss.backward()
        optimizer.step()
        # Assumes criterion averages over the batch (CrossEntropyLoss default).
        # Re-weighting by the batch size makes the division below a true
        # per-sample mean, so the last, smaller batch is not over-weighted.
        total_loss += loss.item() * cls.size(0)
        correct += (output.argmax(1) == cls).sum().item()

    # Adjust the learning rate once per epoch.
    scheduler.step()

    return total_loss / len(sub_train_), correct / len(sub_train_)

def test(data_):
    loss = 0
    acc = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    for text, offsets, cls in data:
        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
        with torch.no_grad():
            output = model(text, offsets)
            loss = criterion(output, cls)
            loss += loss.item()
            acc += (output.argmax(1) == cls).sum().item()

    return loss / len(data_), acc / len(data_)

In [74]:
import time
from torch.utils.data.dataset import random_split
N_EPOCHS = 5
# Best (lowest) validation loss seen so far, updated after every epoch.
min_valid_loss = float('inf')

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
# Multiply the learning rate by 0.9 on every scheduler.step() (once per epoch in train_func).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

# Hold out 5% of the training data for validation.
train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = \
    random_split(train_dataset, [train_len, len(train_dataset) - train_len])

# Errors are deliberately not caught here: a failing epoch should stop the cell
# visibly instead of printing a one-line message and carrying on.
for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss, train_acc = train_func(sub_train_)
    valid_loss, valid_acc = test(sub_valid_)
    min_valid_loss = min(min_valid_loss, valid_loss)

    mins, secs = divmod(int(time.time() - start_time), 60)

    print('Epoch: %d' %(epoch + 1), " | time in %d minutes, %d seconds" %(mins, secs))
    print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
    print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')

tensor([    5,     5,    75,    51,    17,   745,   118,    15,    64,    49,
          189,  1779,  1602,     2,    85,    21,     8,  3349,  1975,  1487,
            9, 10292,   454,  5163, 12804,    33, 51567,  2985,  2471, 15938,
          733,    34,     2,   578,    31,     7, 64062, 14218, 19654,  2064,
         5046,  4746,     4,     8, 13027,  1975,     2, 64482,     9,  1797,
         3489,   273,   275,  2083,     3,  2608,  6129,     8,    18, 46411,
          360,  6405,     9,   103,   103, 15938,     2,  2905,    10,  5163,
            3,  9016,     4,     5,     5,  1482,   189,   116,  9242, 34118,
          777,     2, 14218,  2104,  1444,     2, 14218,   183,  6229,    12,
         8120,     6, 14218,  9425,     6, 14218, 11028,    12,   643,  2820,
            6, 67850,  2104,  2268,     2, 15938, 48665,  1483,  3886,    66,
          382,     6,   705,  1975,  1487,   228,    66,   382,   863,     4,
            5,     5,   489,   189,   895,  1887,   454,    12, 

tensor([  5,   5, 979,  ..., 137,   9,   4]) tensor([  5,   5, 979,  ..., 137,   9,   4])
tensor([    5,     5,   836,   930,   818,    25,   219,  2619, 10408,    28,
          463,    40,   276,  6342,  8068,     5,     5,   496,    75,    89,
           17,    89,   118,   101,    64,  7962,    99,   836,   930,   818,
        10440,   594,    52,     2,    49,   744,    40,   252,   182,     2,
          219,     3,  2619, 10408,    28,   463,    40,     3,   276,  6342,
         8068,     4,     5,     5, 10440,     7,   465, 10034,     2,   219,
          328,   751, 50872, 11819,     2,   130,    18,  5502,    32,     8,
          607,  3670,  1155,     9,     3,  4356,     2,  3287,   980,     6,
         3287, 13407,     2, 28183,    40,   524,     2,   687,   227,  1763,
         7252,  6909,     4,   219, 62412,     2,   348,   170,    12,   611,
            3,  2009,  4154, 65385,     4,   219,   451, 14365,     3,   508,
            2, 10472,    11, 54094,   124, 50747,   

         1335,     9,     4])
tensor([  5,   5, 979,  ...,  33, 191,  34]) tensor([  5,   5, 979,  ...,  33, 191,  34])
tensor([    5,     5,   846,   100,    17,    68,   118,    15,    33,    64,
           15, 48163,    34, 12331,     6,  4444,     6,   846,   492,    12,
         3523,  1779,    26, 16153,   606,    68,    31,     7,   846,   358,
        13257,  7078,     2,   218,   620,    23, 11617, 21354,     2,    27,
          492,   454,    12, 60763,     3,  1667,  3986,    23, 22994, 32849,
            4,     5,     5, 16153, 34279, 31140, 25387,  6706,    52,     2,
          279,   752,  2815,  4631,  3986,     2,    35,    38, 16792,  4444,
         1842,   218,   620,     4,     5,     5,     7, 12676,  1384,  2000,
          619,    20,    96,    65,    17,     3,   139,   164,  1688,   620,
        15212,    48,  4216, 13347,  7099,     4,   269,  4444,   298,     2,
         3326,  2685,   515, 28440,     2, 10139,     7,  7666,  3682,  1191,
           19,     2, 

tensor([    5,     5,    85,    21,   100,    17,   434,  1452,   500,    31,
            2, 60884,    18, 21383, 17612,   862,    27,     7,    18,  3309,
          347,     4,   594,     2,   116, 18639,  8819,  2082,     3,  3863,
         4559,   112,  5426,   271,   393,  1286,     4, 12538,     8,   437,
          114,   179,     6,  6707,   114,   294,     6,   110,   114,    36,
            9,     2,   489,  4559,  1373,    27,  2322,   803,   383,   576,
         5426,     8,  1117,     6,   179,     6,   491,    19,     3,   112,
            9,     2,  4832,  3674,    23,    91,    96,    17,   243,    31,
            4,     5,     5,    18,    36,  3863,   809,    90,  7877,    21,
            4,  4581,   295,     3, 17265, 48290,    18, 10243,     2,  7877,
           21,   968, 64756, 67610,     2,   353,  8819,    37,   144,   109,
           36,   951,  1335, 50628,    23,     8,  3863,     9,     2,  3863,
            3, 14832,    11,  3863,  1497,    12,     8, 66230, 

tensor([    5,     5,    75,    51,    17,   657,   118,    15,     8, 51805,
           15, 28385,     9, 60865, 15262,   666,  1222,  4441,     2,   657,
           31,     7,    75, 15196, 30681, 17946,   358,     4,   189,   716,
           10,   273,   275,  2083,     3,  2838,  2943,  3554,  1728,   365,
          127,     2,    42,  9560,   552,  3554,  2491,     6,   365,  2105,
         1081,     2,  5548,  1502,  4312,   643,   281,  1279,     4,     5,
            5,   426, 15262,   666,  1222,  4441,   148,    10,  2643, 17255,
          383,   228,   777,   326,     6,  2643, 17255,   183,    12,  8120,
            6,  2643, 17255, 46843, 36391,    29,  1698,  5586,     4, 17255,
         2150,  5754,  8504, 12317,     6, 17255,   183,    12,  8120, 35643,
        56326,     6, 30681,    33,    75,    34,   643,   205,   382,  1775,
        52249,     6, 15262, 15600,     3, 11127, 56182,    29,  1558,  1804,
            4,     5,     5,   426,  4441,   112,   895,    20, 

          834,    23,    44, 17393,     4,  3376,    15, 62632]) tensor([    5,     5,  3142,  1191,     2,   436,   822,    76,   543,    74,
         2197,     5,     5,  2016,   492,     2,  1489, 46594,     5,     5,
          604,   139,    51,    17,    96,    31,     2,   436,   174, 36169,
            7, 14696,    48,    19,     4,     5,     5,   496,    64,    15,
        65335,    15, 66288,     5,     5,   604,   139,    51,    17,   504,
           31,     2,    85,    21,  3142,  1191,     7,  1346,  8113,  4271,
            4,    18,  8859,   565,   148,    76,   543,    74,  2197,     2,
          266,    11,  4374,  3234,     6,   350,  3427,  3234,     6,   350,
        13414,  3234,    12,  2755,     6,   523, 15627,  2755,     4,   679,
          133,     7,  1772,     3,   252,  3282,    20,  2766,     2,   250,
           16, 13615,   981,   151,     4,     5,     5,     7,   880,   112,
            3,  1553,    19,     2,   436,  1946,    78,   424,     2,    35,

tensor([    5,     5,   215,    89,    17,   197,   118,    15,    33,    64,
           15, 63642,    34, 17700,  1313,     8,  7685, 63762,     9,   215,
          888,  2152,   143, 50431,     8,  4923, 11216,  4134,     9,   888,
         2361,  9159,   826,  1286,  1810,   197,    31,     7,  1498,   358,
            4,     5,     5,   826,    42,     8,  1149,  3195,    38,   830,
            9,    23,   342,   540,     2,  2322,   227,   888,  2155,   785,
         5426,  2361,  9159,     2,  2344,    49,  7056,  5946,  2525,   697,
          342,   963,     2,  1511,  6707,   888,   640,  1854,   294,     2,
         1079,    55,   332,   888,  6254,    20,  2313,     6,  8421, 12716,
            6, 46700,  8458,     3,  9811,     4,     5,     5,   422, 26634,
         5426,   760,     2,    56,    27,  8434,  3238,     6, 11926,     6,
        12573,   192,     6,   472,    50,    29,   880,   732,     2,  1391,
         7612, 15019,   472,  2313,     3,   888, 47541,     2, 

         1496,   210,  1184,    15,    15,  9178]) tensor([    5,     5,     5,     5,  4924,   867,  9261,     8,  2401,   975,
          499,   112,  3146,     9, 23708, 16283,    15,  2551,   867,   242,
         1276,   873,  1147,  1555,   285,     5,     5,   102,    17,   100,
           31,  1329,     2,   513,  4924,   867,     8,  4737,  7983,   438,
          975,   499,   112,  3146,     9,     3,  8383,  1276,  1371,     2,
         1682,  1391,   285,     4,   670,    20,   903,   298,     2,   422,
         4924,   867,     2, 19336, 25020,   867,    16,  2401,    10,  5633,
          112,     4,   268,   479,   298,     2,   497,     8,  4737,  7983,
            9,   207,   801,    75,  7830,     3,  2604,     4,   907,     2,
         1437,    75,   543,  3186,     3,  1388,   766,  5567,    64,     2,
           75,   543,  3186,   926,    60,    78,  7983,   975,   867,     2,
         7615,    28,   975,   867,   160,     3,   258,    16,    11,   557,
         2349

tensor([   5,    5,  178,  ..., 3128,    4,    9]) tensor([   5,    5,  178,  ..., 3128,    4,    9])
tensor([  5,   5, 606,  ..., 601,   9,   4]) tensor([  5,   5, 606,  ..., 601,   9,   4])
tensor([    5,     5,    51,    17,   904,   118,    15,    75,   139,   904,
           31,  1921,     2,  4511,   125,   100,  1095,     2,  3337,   938,
          470, 10470,     4,  2121,   403,   217, 45635, 10295,   548,  1282,
         4485,     4,  1987,  1928,   605,   217,  2706,  1843,     2,   670,
        30733,  6918,  7806,     4,  1372,  6490,     2,  3337,   339,    25,
           76,  2453, 10470,   274,  7387,  5379, 12423,     4,     5,     5,
        10470,   125,   104,   217,   415,   603,     2, 51202,  7306,    19,
            2, 29589,    72, 19055,  1997,    43,  8999,     2,  9253,     3,
         4139, 10295,   548,   719,  4107,  6163,  8276,  4552,     4,   125,
          544,   217,  7555,   934,   916,     2,  3337,  5498,  4048,     2,
         4139, 10295,   548,

         1053, 11079,  5147, 43793,   504,   229,    60,  1103,  3697,     4]) tensor([    5,     5,   496,  1053,    51,    17,   504,   118,    33,    64,
        48253,    34,    64,    49, 12980,  1053,  4612,  1602,     2,   594,
            2,  5464, 11079,   110,  2889,   382,   548,   238,  1759,     3,
          110,  3548,   558,   216,    60,  6953,   175,   110, 19273,   262,
           43,  1175,  4068,   229,     4,     5,     5,   269,  1053,  7461,
        22034,    85, 22035,    95,   731,  3227,  5388,   158,     2, 11079,
          110,   215,  3720,  2231,  1206,     8,  2093,   690,     9,   188,
           79,     2,    20,     8,   197,  3489,   517,   110,  3548,     9,
            8,  1601,  2742,  7560,   394,     9,    29,  8685,     2,   238,
         1759,     3,   110,  3548,   216,    60,  6953,   175,   110, 19273,
          262,     4,     5,     5, 11079,   110,   705,  3720,   863,     3,
          161,  1696, 47883,  2281,    33, 41508,    34,     2,

KeyboardInterrupt: 

In [70]:
# Final evaluation on the held-out test split.
print('Checking the results of test dataset...')
test_loss, test_acc = test(test_dataset)
print(
    f'\tLoss: {test_loss:.4f}(test)\t|'
    f'\tAcc: {test_acc * 100:.1f}%(test)'
)

Checking the results of test dataset...
	Loss: 0.0000(test)	|	Acc: 94.6%(test)


In [71]:
import re
from torchtext.data.utils import ngrams_iterator
from chinese_news_dataset import get_tokenizer

# Maps the 1-based class ids returned by predict() (argmax + 1) to category names.
# NOTE(review): the variable name comes from the AG News tutorial; these are
# the Chinese-news categories.
ag_news_label = dict(enumerate(("ent", "sport", "fortune"), start=1))

def predict(text, model, vocab, ngrams):
    """Classify one raw text string and return its 1-based class id.

    Args:
        text: raw news text. It is tokenized with ``get_tokenizer("jieba")``.
        model: an ``nn.Module`` called as ``model(token_ids, offsets)`` that
            returns class logits.
        vocab: mapping indexed as ``vocab[token]`` that returns a token id.
        ngrams: n-gram order passed to ``ngrams_iterator``. It should match the
            order the vocabulary was built with.

    Returns:
        int: ``argmax + 1``, i.e. a 1-based class id.

    The inputs are created on the model's device, so a GPU-resident model
    works too. The model runs in eval mode and its previous mode is restored.
    """
    tokenizer = get_tokenizer("jieba")
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # dtype=long keeps an empty token list a valid index tensor (a bare
    # torch.tensor([]) would be float32).
    token_ids = torch.tensor(
        [vocab[token] for token in ngrams_iterator(tokenizer(text), ngrams)],
        dtype=torch.long, device=device)
    # A single bag that starts at index 0.
    offsets = torch.tensor([0], device=device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(token_ids, offsets)
    finally:
        model.train(was_training)
    return output.argmax(1).item() + 1

ex_text_str = '''
在今天76人以104-106惜败于爵士的比赛中，本-西蒙斯在一次对抗中遭遇右肩扭伤，在赛后接受采访时，76人主帅布朗特-布朗也称他的缺阵让球队的防守大打折扣。最新消息指出，西蒙斯被诊断出右肩肩锁关节一级扭伤，他将缺席76人明日与掘金的比赛。并很可能缺席接下来三场比赛。
'''
vocab = train_dataset.get_vocab()
# Run inference on CPU.
model = model.to("cpu")

# Use the same n-gram order the dataset/vocabulary was built with (NGRAMS).
# The original hard-coded 2 here, and bigrams cannot appear in a vocabulary
# built with ngrams=NGRAMS (=1).
print("This is a %s " %ag_news_label[predict(ex_text_str, model, vocab, NGRAMS)])

This is a sport 
