# 利用LSTM神經網路完成路透社新聞分類

(1)下載Reuters資料集

In [None]:
from tensorflow.keras.datasets import reuters

# Vocabulary size: load_data keeps only the num_word most frequent words;
# every rarer word is replaced by the out-of-vocabulary marker.
num_word = 10000
(train_data, train_label), (test_data, test_label) = reuters.load_data(
    num_words=num_word
)

# Show the shape of every split (data arrays hold one variable-length
# list of word indices per news item).
for _caption, _array in (
    ("train_data.shape", train_data),
    ("train_label.shape", train_label),
    ("test_data.shape", test_data),
    ("test_label.shape", test_label),
):
    print(_caption, _array.shape)

印出第一筆資料內容

In [None]:
# Inspect the first training example: its encoded word-index sequence and
# its integer class label.
first_news, first_topic = train_data[0], train_label[0]
print("train_data[0] :", first_news)
print("train_label[0] :", first_topic)

以單詞查詢它在字典中對應的索引（注意：`get_word_index()` 傳回的索引在編碼後的資料中會再加上 `index_from` 的位移，預設為 3）

In [None]:
# Look up a word in the Reuters vocabulary dictionary (word -> rank).
Index_of_word = reuters.get_word_index()
youIndex = Index_of_word["you"]
print("'you' index = ",youIndex)

# The encoded sequences do NOT use these raw ranks directly: load_data() above
# is called with its default index_from=3, so every rank is shifted by 3
# (indices 0/1/2 are reserved for padding, the start marker and OOV words).
# Print the index that actually appears inside train_data / test_data.
index_from = 3
you_encoded_index = youIndex + index_from
print("'you' index in encoded data = ", you_encoded_index)

(2)訓練前準備：資料預處理

In [None]:
# Preprocessing: bring every news item to exactly wordMaxNum tokens.
from tensorflow.keras.preprocessing import sequence

# pad_sequences (defaults: padding='pre', truncating='pre', value=0) left-pads
# short items with 0 and drops the leading tokens of items that are too long.
wordMaxNum = 200
train_data_new, test_data_new = (
    sequence.pad_sequences(split, maxlen=wordMaxNum)
    for split in (train_data, test_data)
)
for _padded in (train_data_new, test_data_new):
    print(_padded.shape)

將類別標籤轉換成 One-hot 編碼形式

In [None]:
import tensorflow as tf

# Number of topic classes in the Reuters dataset used here.
num_classes = 46
# Encode integer labels as one-hot vectors, as required by the
# categorical_crossentropy loss used at compile time.
One_hot_Train, One_hot_Test = (
    tf.one_hot(labels, depth=num_classes)
    for labels in (train_label, test_label)
)

(3) 模型建立

In [None]:
from tensorflow.keras.layers import LSTM
from tensorflow import keras
from tensorflow.keras import layers

# Two stacked LSTMs over learned word embeddings, then a softmax classifier.
model = keras.Sequential()
# Declare the fixed (padded) sequence length with an explicit Input instead of
# Embedding(input_length=...), which is deprecated in Keras 3. This also builds
# the model so summary() can report output shapes and parameter counts.
model.add(keras.Input(shape=(wordMaxNum,)))
# Map each of the num_word word indices to a 200-dimensional dense vector.
model.add(layers.Embedding(num_word, output_dim=200))
# First LSTM returns the whole sequence so the second LSTM can consume it.
model.add(LSTM(128, dropout=0.5, return_sequences=True))
# Second LSTM returns only its last output: one 128-d vector per sample.
model.add(LSTM(128, dropout=0.5))
model.add(layers.Dense(num_classes, activation='softmax'))
# summary() prints the table itself and returns None; wrapping it in print()
# only appended a stray "None" line.
model.summary()

(4) 模型編譯與訓練

In [None]:
# NOTE(review): plt is not used in this cell; presumably a later cell plots
# the training history stored in `hist` — confirm before removing.
import matplotlib.pyplot as plt

# Training hyper-parameters.
batch_sizes = 32
epochs = 50

# One-hot multi-class targets -> categorical cross-entropy, RMSprop optimizer.
model.compile(
    optimizer='rmsprop',
    loss='categorical_crossentropy',
    metrics=["accuracy"],
)
# validation_split=0.2 holds out the last 20% of the training samples for
# validation; `hist` keeps the per-epoch metrics returned by fit().
fit_options = dict(
    epochs=epochs,
    batch_size=batch_sizes,
    verbose=2,
    validation_split=0.2,
)
hist = model.fit(train_data_new, One_hot_Train, **fit_options)

(5) 模型評估：以測試集評估模型並計算正確率

In [None]:
# Evaluate on the held-out test split; evaluate() returns [loss, accuracy]
# because "accuracy" is the only metric passed to compile().
test_metrics = model.evaluate(test_data_new, One_hot_Test)
loss, accuracy = test_metrics
print("測試集的正確率 = ", accuracy)