# Amazon Reviewsを用いてレビュー文の評価分類をLSTMとRNNで比較

In [50]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
import re
from collections import Counter

device = "cuda" if torch.cuda.is_available() else "cpu"

# --- データロード ---
dataset = load_dataset("amazon_polarity")
train_texts = dataset["train"]["content"][:50000]  # 軽量化
train_labels = dataset["train"]["label"][:50000]
test_texts  = dataset["test"]["content"][:10000]
test_labels = dataset["test"]["label"][:10000]

# --- トークナイズ（数字も残す） ---
def tokenize(text):
    text = text.lower()
    text = re.sub(r"[^a-z0-9 ]", "", text)  # 数字も残す
    return text.split()

train_tokens = [tokenize(t) for t in train_texts]
test_tokens  = [tokenize(t) for t in test_texts]

# --- 語彙作成（上位 50,000語） ---
counter = Counter()
for tokens in train_tokens:
    counter.update(tokens)

vocab = {"<pad>": 0, "<unk>": 1}
for word, _ in counter.most_common(50000):
    vocab[word] = len(vocab)

def encode(tokens):
    return torch.tensor([vocab.get(t, vocab["<unk>"]) for t in tokens])

train_encoded = [encode(t) for t in train_tokens]
test_encoded  = [encode(t) for t in test_tokens]

# --- DataLoader ---
def collate_fn(batch):
    texts, labels = zip(*batch)
    texts = pad_sequence(texts, batch_first=True)
    labels = torch.tensor(labels)
    return texts, labels

train_loader = DataLoader(list(zip(train_encoded, train_labels)), batch_size=64, shuffle=True, collate_fn=collate_fn)
test_loader  = DataLoader(list(zip(test_encoded, test_labels)), batch_size=64, collate_fn=collate_fn)

# --- モデル定義（LSTM / RNN） ---
class SentimentRNN(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=128, rnn_type="LSTM"):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        if rnn_type == "LSTM":
            self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        else:
            self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 2)
        self.rnn_type = rnn_type

    def forward(self, x):
        x = self.embedding(x)
        if self.rnn_type == "LSTM":
            _, (h, _) = self.rnn(x)
        else:
            _, h = self.rnn(x)
        out = self.fc(h[-1])
        return out

# --- 学習関数 ---
def train_model(model, train_loader, test_loader, epochs=5, lr=1e-3):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.to(device)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"{model.rnn_type} Epoch {epoch+1}, Loss: {total_loss:.3f}")

    # 評価
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    acc = correct / total
    print(f"{model.rnn_type} Test Accuracy: {acc:.4f}")
    return model, acc

# --- ラベル & 単語復元 ---
label_map = {0: "negative", 1: "positive"}
idx_to_word = {idx: word for word, idx in vocab.items()}

def decode_review(token_ids):
    words = [idx_to_word.get(i, "<unk>") for i in token_ids if i != 0]
    return " ".join(words)

def visualize_review(token_ids, y_true, y_pred, confidence, orig_text="", title="Review"):
    print(f"{title}:")
    print("Original:", orig_text)  # 元の文章も表示
    print("Tokenized:", decode_review(token_ids))  # モデルに入ったトークン列
    print(f"Ground Truth : {label_map[y_true]}")
    print(f"Prediction   : {label_map[y_pred]}")
    print(f"Confidence   : {confidence:.3f}")
    print("-" * 80)

# --- モデル作成 & 学習 ---
lstm_model = SentimentRNN(len(vocab), rnn_type="LSTM")
rnn_model  = SentimentRNN(len(vocab), rnn_type="RNN")

lstm_model, lstm_acc = train_model(lstm_model, train_loader, test_loader)
rnn_model, rnn_acc   = train_model(rnn_model, train_loader, test_loader)

print(f"\nComparison -> LSTM Accuracy: {lstm_acc:.4f}, RNN Accuracy: {rnn_acc:.4f}\n")

LSTM Epoch 1, Loss: 541.844
LSTM Epoch 2, Loss: 528.238
LSTM Epoch 3, Loss: 287.088
LSTM Epoch 4, Loss: 183.813
LSTM Epoch 5, Loss: 125.049
LSTM Test Accuracy: 0.8731
RNN Epoch 1, Loss: 542.618
RNN Epoch 2, Loss: 541.661
RNN Epoch 3, Loss: 543.998
RNN Epoch 4, Loss: 546.622
RNN Epoch 5, Loss: 544.783
RNN Test Accuracy: 0.4875

Comparison -> LSTM Accuracy: 0.8731, RNN Accuracy: 0.4875



In [51]:
# --- サンプルレビュー比較 ---
def compare_models_on_sample(idx):
    x = test_encoded[idx].unsqueeze(0).to(device)
    y_true = test_labels[idx]
    orig_text = test_texts[idx]
    with torch.no_grad():
        lstm_pred_logits = lstm_model(x)
        rnn_pred_logits  = rnn_model(x)
        lstm_pred = torch.softmax(lstm_pred_logits, dim=1).argmax(1).item()
        rnn_pred  = torch.softmax(rnn_pred_logits, dim=1).argmax(1).item()
        lstm_conf = torch.softmax(lstm_pred_logits, dim=1)[0, lstm_pred].item()
        rnn_conf  = torch.softmax(rnn_pred_logits, dim=1)[0, rnn_pred].item()
    visualize_review(test_encoded[idx], y_true, lstm_pred, lstm_conf, orig_text, title=f"LSTM/RNN Sample {idx+1}")
    print(f"LSTM Prediction: {label_map[lstm_pred]} (conf: {lstm_conf:.3f})")
    print(f"RNN  Prediction: {label_map[rnn_pred]} (conf: {rnn_conf:.3f})")
    print("-"*80)

for i in range(5):
    compare_models_on_sample(i)

LSTM/RNN Sample 1:
Original: My lovely Pat has one of the GREAT voices of her generation. I have listened to this CD for YEARS and I still LOVE IT. When I'm in a good mood it makes me feel better. A bad mood just evaporates like sugar in the rain. This CD just oozes LIFE. Vocals are jusat STUUNNING and lyrics just kill. One of life's hidden gems. This is a desert isle CD in my book. Why she never made it big is just beyond me. Everytime I play this, no matter black, white, young, old, male, female EVERYBODY says one thing "Who was that singing ?"
Tokenized: <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk

In [52]:
# --- 誤分類レビュー比較 ---
def compare_misclassified(max_samples=5):
    misclassified_count = 0
    for i, (x_tokens, y_true) in enumerate(zip(test_encoded, test_labels)):
        x = x_tokens.unsqueeze(0).to(device)
        orig_text = test_texts[i]
        with torch.no_grad():
            lstm_pred = torch.softmax(lstm_model(x), dim=1).argmax(1).item()
            rnn_pred  = torch.softmax(rnn_model(x), dim=1).argmax(1).item()
            lstm_conf = torch.softmax(lstm_model(x), dim=1)[0, lstm_pred].item()
            rnn_conf  = torch.softmax(rnn_model(x), dim=1)[0, rnn_pred].item()
        if lstm_pred != y_true or rnn_pred != y_true:
            misclassified_count += 1
            visualize_review(x_tokens, y_true, lstm_pred, lstm_conf, orig_text, title=f"Misclassified Review {misclassified_count}")
            print(f"LSTM Prediction: {label_map[lstm_pred]} (conf: {lstm_conf:.3f})")
            print(f"RNN  Prediction: {label_map[rnn_pred]} (conf: {rnn_conf:.3f})")
            print("-"*80)
            if misclassified_count >= max_samples:
                break

compare_misclassified(max_samples=5)

Misclassified Review 1:
Original: Despite the fact that I have only played a small portion of the game, the music I heard (plus the connection to Chrono Trigger which was great as well) led me to purchase the soundtrack, and it remains one of my favorite albums. There is an incredible mix of fun, epic, and emotional songs. Those sad and beautiful tracks I especially like, as there's not too many of those kinds of songs in my other video game soundtracks. I must admit that one of the songs (Life-A Distant Promise) has brought tears to my eyes on many occasions.My one complaint about this soundtrack is that they use guitar fretting effects in many of the songs, which I find distracting. But even if those weren't included I would still consider the collection worth it.
Tokenized: <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <u

In [53]:
import pandas as pd

# --- 詳細比較表作成 ---
def compute_detailed_metrics(model, test_encoded, test_labels):
    """モデルごとの詳細評価指標を計算"""
    correct, total = 0, 0
    tp, tn, fp, fn = 0, 0, 0, 0  # 正例／負例ごとの分類
    for x_tokens, y_true in zip(test_encoded, test_labels):
        x = x_tokens.unsqueeze(0).to(device)
        with torch.no_grad():
            pred = torch.softmax(model(x), dim=1).argmax(1).item()
        total += 1
        if pred == y_true:
            correct += 1
            if y_true == 1:
                tp += 1
            else:
                tn += 1
        else:
            if y_true == 1:
                fn += 1
            else:
                fp += 1
    accuracy = correct / total
    pos_acc = tp / (tp + fn) if (tp + fn) > 0 else 0
    neg_acc = tn / (tn + fp) if (tn + fp) > 0 else 0
    misclassified_rate = 1 - accuracy
    return {
        "Accuracy": accuracy,
        "Misclassified Rate": misclassified_rate,
        "Positive Accuracy": pos_acc,
        "Negative Accuracy": neg_acc
    }

# 各モデルの詳細評価
lstm_metrics = compute_detailed_metrics(lstm_model, test_encoded, test_labels)
rnn_metrics  = compute_detailed_metrics(rnn_model, test_encoded, test_labels)

# DataFrame にまとめる
df_metrics = pd.DataFrame([lstm_metrics, rnn_metrics], index=["LSTM", "RNN"])
print("\nDetailed Comparison Table:")
print(df_metrics)


Detailed Comparison Table:
      Accuracy  Misclassified Rate  Positive Accuracy  Negative Accuracy
LSTM    0.8655              0.1345           0.859707           0.871590
RNN     0.5001              0.4999           0.505171           0.494769
