In [1]:
#dataloader 
import torch.nn as nn
import torch
import torch.nn.functional as F

class TextRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_of_class, weights=None, rnn_type="RNN"):
        super(TextRNN, self).__init__()

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_of_class = num_of_class
        self.embedding_dim = embedding_dim
        self.rnn_type = rnn_type

        if weights is not None:
            self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, _weight=weights)
        else:
            self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

        if rnn_type == "RNN":
            self.rnn = nn.RNN(input_size=embedding_dim, hidden_size=hidden_size, batch_first=True)
            self.hidden2label = nn.Linear(hidden_size, num_of_class)
        elif rnn_type == "LSTM":
            self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, batch_first=True, bidirectional=True)
            self.hidden2label = nn.Linear(hidden_size*2, num_of_class)

    def forward(self, input_sents):
        batch_size, seq_len = input_sents.shape
        embed_out = self.embed(input_sents)

        if self.rnn_type == "RNN":
            h0 = torch.randn(1, batch_size, self.hidden_size)
            _, hn = self.rnn(embed_out, h0)
        elif self.rnn_type == "LSTM":
            h0, c0 = torch.randn(2, batch_size, self.hidden_size), torch.randn(2, batch_size, self.hidden_size)
            output, (hn, _) = self.lstm(embed_out, (h0, c0))

        logits = self.hidden2label(hn).squeeze(0)

        return logits


class TextCNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_of_class, embedding_vectors=None, kernel_num=100, kerner_size=[3, 4, 5], dropout=0.5):
        super(TextCNN, self).__init__()
        if embedding_vectors is None:
            self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        else:
            self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, _weight=embedding_vectors)
        self.convs = nn.ModuleList([nn.Conv2d(1, kernel_num, (K, embedding_dim)) for K in kerner_size])
        self.dropout = nn.Dropout(dropout)
        self.feature2label = nn.Linear(3*kernel_num, num_of_class)

    def forward(self, x):
        embed_out = self.embed(x).unsqueeze(1)
        conv_out = [F.relu(conv(embed_out)).squeeze(3) for conv in self.convs]

        pool_out = [F.max_pool1d(block, block.size(2)).squeeze(2) for block in conv_out]

        pool_out = torch.cat(pool_out, 1)

        logits = self.feature2label(pool_out)

        return logits


if __name__ == "__main__":
    model = TextCNN(vocab_size=10, embedding_dim=10, num_of_class=10)
    x = torch.randint(10, (10, 20))
    logits = model.forward(x)



  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [2]:
#models
import pandas as pd
import os
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader


def prepare_data(dataset_path, sent_col_name, label_col_name):
    file_path = os.path.join(dataset_path, "train.tsv")
    data = pd.read_csv(file_path, sep="\t")
    X = data[sent_col_name].values
    y = data[label_col_name].values
    return X, y


class Language:
    def __init__(self):
        self.word2id = {}
        self.id2word = {}

    def fit(self, sent_list):
        vocab = set()
        for sent in sent_list:
            vocab.update(sent.split(" "))
        word_list = ["<pad>", "<unk>"] + list(vocab)
        self.word2id = {word: i for i, word in enumerate(word_list)}
        self.id2word = {i: word for i, word in enumerate(word_list)}

    def transform(self, sent_list, reverse=False):
        sent_list_id = []
        word_mapper = self.word2id if not reverse else self.id2word
        unk = self.word2id["<unk>"] if not reverse else None
        for sent in sent_list:
            sent_id = list(map(lambda x: word_mapper.get(x, unk), sent.split(" ") if not reverse else sent))
            sent_list_id.append(sent_id)
        return sent_list_id


class ClsDataset(Dataset):
    def __init__(self, sents, labels):
        self.sents = sents
        self.labels = labels

    def __getitem__(self, item):
        return self.sents[item], self.labels[item]

    def __len__(self):
        return len(self.sents)


def collate_fn(batch_data):
    batch_data.sort(key=lambda data_pair: len(data_pair[0]), reverse=True)

    sents, labels = zip(*batch_data)
    sents_len = [len(sent) for sent in sents]
    sents = [torch.LongTensor(sent) for sent in sents]
    padded_sents = pad_sequence(sents, batch_first=True, padding_value=0)

    return torch.LongTensor(padded_sents), torch.LongTensor(labels),  torch.FloatTensor(sents_len)


def get_wordvec(word2id, vec_file_path, vec_dim=50):
    print("开始加载词向量")
    word_vectors = torch.nn.init.xavier_uniform_(torch.empty(len(word2id), vec_dim))
    word_vectors[0, :] = 0  # <pad>
    found = 0
    with open(vec_file_path, "r", encoding="utf-8") as f:
        lines = f.readlines()
        for line in lines:
            splited = line.split(" ")
            if splited[0] in word2id:
                found += 1
                word_vectors[word2id[splited[0]]] = torch.tensor(list(map(lambda x: float(x), splited[1:])))
            if found == len(word2id) - 1:  # 允许<unk>找不到
                break
    print("总共 %d个词，其中%d个找到了对应的词向量" % (len(word2id), found))
    return word_vectors.float()


def make_dataloader(dataset_path="dataset", sent_col_name="Phrase", label_col_name="Sentiment", batch_size=32, vec_file_path="glove.6B.50d.txt", debug=False):
    X, y = prepare_data(dataset_path=dataset_path, sent_col_name=sent_col_name, label_col_name=label_col_name)

    if debug:
        X, y = X[:100], y[:100]

    X_language = Language()
    X_language.fit(X)
    X = X_language.transform(X)

    word_vectors = get_wordvec(X_language.word2id, vec_file_path=vec_file_path, vec_dim=50)


    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

    cls_train_dataset, cls_val_dataset = ClsDataset(X_train, y_train), ClsDataset(X_val, y_val)
    cls_train_dataloader = DataLoader(cls_train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    cls_val_dataloader = DataLoader(cls_val_dataset, batch_size=batch_size, collate_fn=collate_fn)

    return cls_train_dataloader, cls_val_dataloader, word_vectors, X_language


if __name__ == "__main__":
    cls_train_dataloader, cls_val_dataloader, word_vectors, X_language = make_dataloader(debug=True, batch_size=10)
    for batch in cls_train_dataloader:
        X, y, lens = batch
        print(X.shape, y.shape)
        break


开始加载词向量
总共 54个词，其中46个找到了对应的词向量
torch.Size([10, 6]) torch.Size([10])


In [3]:
#main
from torch import optim
import torch
from models import TextRNN, TextCNN
import numpy as np

if __name__ == "__main__":
    model_names = ["LSTM", "RNN", "CNN"]  
    learning_rate = 0.001
    epoch_num = 500
    num_of_class = 5
   
    train_iter, val_iter, word_vectors, X_lang = make_dataloader(batch_size=100, debug=True)

    for model_name in model_names[-1:]:
        if model_name == "RNN":
            model = TextRNN(vocab_size=len(word_vectors), embedding_dim=50, hidden_size=128, num_of_class=num_of_class, weights=word_vectors)
        elif model_name == "CNN":
            model = TextCNN(vocab_size=len(word_vectors), embedding_dim=50, num_of_class=num_of_class, embedding_vectors=word_vectors)
        elif model_name == "LSTM":
            model = TextRNN(vocab_size=len(word_vectors), embedding_dim=50, hidden_size=128, num_of_class=num_of_class, weights=word_vectors, rnn_type="LSTM")
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        loss_fun = torch.nn.CrossEntropyLoss()

        for epoch in range(epoch_num):
            model.train()
            for i, batch in enumerate(train_iter):
                x, y, lens = batch
                logits = model(x)
                optimizer.zero_grad()
                loss = loss_fun(logits, y)
                loss.backward()
                optimizer.step()

            model.eval()
            train_accs = []
            for i, batch in enumerate(train_iter):
                x, y, lens = batch
                _, y_pre = torch.max(logits, -1)
                acc = torch.mean((torch.tensor(y_pre == y, dtype=torch.float)))
                train_accs.append(acc)
            train_acc = np.array(train_accs).mean()

            val_accs = []
            for i, batch in enumerate(val_iter):
                x, y, lens = batch
                logits = model(x)
                _, y_pre = torch.max(logits, -1)
                acc = torch.mean((torch.tensor(y_pre == y, dtype=torch.float)))
                val_accs.append(acc)
            val_acc = np.array(val_accs).mean()
            print("epoch %d train acc:%.2f, val acc:%.2f" % (epoch, train_acc, val_acc))
            if train_acc >= 0.99:
                break



开始加载词向量
总共 54个词，其中46个找到了对应的词向量
epoch 0 train acc:0.49, val acc:0.70




epoch 1 train acc:0.81, val acc:0.70
epoch 2 train acc:0.82, val acc:0.70
epoch 3 train acc:0.82, val acc:0.70
epoch 4 train acc:0.82, val acc:0.70
epoch 5 train acc:0.82, val acc:0.70
epoch 6 train acc:0.82, val acc:0.70
epoch 7 train acc:0.82, val acc:0.70
epoch 8 train acc:0.82, val acc:0.70
epoch 9 train acc:0.82, val acc:0.70
epoch 10 train acc:0.82, val acc:0.70
epoch 11 train acc:0.82, val acc:0.70
epoch 12 train acc:0.82, val acc:0.70
epoch 13 train acc:0.82, val acc:0.70
epoch 14 train acc:0.80, val acc:0.70
epoch 15 train acc:0.77, val acc:0.75
epoch 16 train acc:0.80, val acc:0.75
epoch 17 train acc:0.82, val acc:0.75
epoch 18 train acc:0.81, val acc:0.75
epoch 19 train acc:0.84, val acc:0.75
epoch 20 train acc:0.82, val acc:0.75
epoch 21 train acc:0.79, val acc:0.75
epoch 22 train acc:0.80, val acc:0.75
epoch 23 train acc:0.81, val acc:0.75
epoch 24 train acc:0.84, val acc:0.75
epoch 25 train acc:0.76, val acc:0.75
epoch 26 train acc:0.82, val acc:0.75
epoch 27 train acc:0.

epoch 217 train acc:0.73, val acc:0.65
epoch 218 train acc:0.82, val acc:0.65
epoch 219 train acc:0.82, val acc:0.65
epoch 220 train acc:0.77, val acc:0.65
epoch 221 train acc:0.82, val acc:0.65
epoch 222 train acc:0.77, val acc:0.65
epoch 223 train acc:0.82, val acc:0.65
epoch 224 train acc:0.81, val acc:0.65
epoch 225 train acc:0.77, val acc:0.65
epoch 226 train acc:0.75, val acc:0.65
epoch 227 train acc:0.73, val acc:0.65
epoch 228 train acc:0.80, val acc:0.65
epoch 229 train acc:0.77, val acc:0.65
epoch 230 train acc:0.76, val acc:0.65
epoch 231 train acc:0.75, val acc:0.65
epoch 232 train acc:0.77, val acc:0.65
epoch 233 train acc:0.75, val acc:0.65
epoch 234 train acc:0.75, val acc:0.65
epoch 235 train acc:0.80, val acc:0.65
epoch 236 train acc:0.77, val acc:0.65
epoch 237 train acc:0.84, val acc:0.65
epoch 238 train acc:0.77, val acc:0.65
epoch 239 train acc:0.77, val acc:0.65
epoch 240 train acc:0.80, val acc:0.65
epoch 241 train acc:0.80, val acc:0.65
epoch 242 train acc:0.77,

epoch 432 train acc:0.76, val acc:0.65
epoch 433 train acc:0.80, val acc:0.65
epoch 434 train acc:0.80, val acc:0.65
epoch 435 train acc:0.80, val acc:0.65
epoch 436 train acc:0.77, val acc:0.65
epoch 437 train acc:0.77, val acc:0.65
epoch 438 train acc:0.86, val acc:0.65
epoch 439 train acc:0.70, val acc:0.65
epoch 440 train acc:0.74, val acc:0.65
epoch 441 train acc:0.80, val acc:0.65
epoch 442 train acc:0.77, val acc:0.65
epoch 443 train acc:0.80, val acc:0.65
epoch 444 train acc:0.80, val acc:0.65
epoch 445 train acc:0.85, val acc:0.65
epoch 446 train acc:0.80, val acc:0.65
epoch 447 train acc:0.74, val acc:0.65
epoch 448 train acc:0.82, val acc:0.65
epoch 449 train acc:0.71, val acc:0.65
epoch 450 train acc:0.77, val acc:0.65
epoch 451 train acc:0.74, val acc:0.65
epoch 452 train acc:0.80, val acc:0.65
epoch 453 train acc:0.75, val acc:0.65
epoch 454 train acc:0.75, val acc:0.65
epoch 455 train acc:0.81, val acc:0.65
epoch 456 train acc:0.80, val acc:0.65
epoch 457 train acc:0.88,