In [2]:
import torch
from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt
import random

from sklearn.metrics import precision_recall_fscore_support, classification_report, confusion_matrix
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.config import Config
from src.dataloaders.dataloader import ChineseBertDataset
from src.models.original_model_bert import MultiSentFeatClassifier, BertClassifier


def load_model():
    """Restore weights from ``model_save_path`` into the module-level ``model``.

    Only checkpoint entries whose key also exists in the current model's
    state dict are applied; parameters that are absent from the checkpoint
    keep their current values. Entries whose shape does not match are still
    rejected by ``load_state_dict``.
    """
    print("Loading model")
    # map_location lets a checkpoint that was written on a GPU be restored on
    # whatever device this session is configured for (e.g. a CPU-only machine).
    checkpoint = torch.load(model_save_path, map_location=config.device)
    merged_state = model.state_dict()
    # Drop checkpoint keys the current architecture does not know about, then
    # overlay the remaining entries on top of the model's own state.
    merged_state.update(
        {key: value for key, value in checkpoint.items() if key in merged_state}
    )
    model.load_state_dict(merged_state)


def train():
    """Train the module-level ``model`` on ``train_dataloader``.

    Re-initialises the model weights, then runs ``config.epoch`` epochs of
    Adam optimisation against a cross-entropy loss. After every epoch the
    mean batch loss is printed and the model's state dict is written to
    ``model_save_path`` (overwriting the previous epoch's file).
    """
    model.init_weights()
    loss_function = nn.CrossEntropyLoss()
    # Only parameters that require gradients are handed to the optimizer.
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(trainable_params, lr=5e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0)
    for epoch in range(config.epoch):
        # Running sum of the batch losses for this epoch.
        total_loss = 0.0
        num_batches = 0
        model.train()
        train_data_tqdm = tqdm(train_dataloader, desc=f'Epoch {epoch}')
        for data in train_data_tqdm:
            # Clear gradients left over from the previous step.
            model.zero_grad()
            # Unpack the batch; the third field is not used here.
            token_ids, masks, _, out, para, positional_embedding = data
            _output, _, _, _ = model(token_ids, masks, para, positional_embedding)
            # Drop the leading dimension of the labels when it has size 1
            # (presumably the batch dimension added by the DataLoader -- confirm).
            out = out.squeeze(0)
            # Cross-entropy between the model output and the labels.
            loss = loss_function(_output, out)
            # Read the scalar once: every .item() call forces a device sync.
            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1
            loss.backward()
            optimizer.step()
            train_data_tqdm.set_postfix(loss=batch_loss)
            del token_ids, masks, out
            if config.device == 'cuda':
                torch.cuda.empty_cache()
        if num_batches:
            print(f'Epoch {epoch}: mean loss {total_loss / num_batches:.4f}')
        torch.save(model.state_dict(), model_save_path)


def evaluate(max_batches=100):
    """Evaluate the saved checkpoint and collect features for visualisation.

    Loads the weights via ``load_model()``, runs the module-level ``model``
    over ``train_dataloader`` without gradients and prints a classification
    report, macro/micro precision-recall-F1 and the confusion matrix.

    Side effects on module-level names (all three are rebuilt on every call,
    so repeated calls do not accumulate stale entries):
      * ``sentence_feature_list`` -- the items of the model's third output,
        appended batch by batch.
      * ``discourse_feature_list`` -- element 0 of the model's fourth output,
        one entry per batch.
      * ``label_list`` -- the gold labels of every evaluated item.

    Args:
        max_batches: Stop after this many batches (positive int). Pass
            ``None`` to run over the whole dataloader. Defaults to 100.
    """
    global label_list, sentence_feature_list, discourse_feature_list
    # Create the collections here so the function works even when the caller
    # has not defined them, and so they stay aligned with ``label_list``.
    sentence_feature_list = []
    discourse_feature_list = []
    y_true, y_pred = [], []
    load_model()
    model.eval()
    # NOTE(review): this iterates the *training* dataloader although the
    # progress bar is labelled "Test" -- switch to a held-out loader when one
    # is available.
    data_tqdm = tqdm(train_dataloader, desc=r"Test")
    with torch.no_grad():
        for processed, data in enumerate(data_tqdm, start=1):
            token_ids, masks, _, out, para, positional_embedding = data
            _output, _, pre_pred, disc_encoding = model(token_ids, masks, para, positional_embedding)
            sentence_feature_list += pre_pred
            discourse_feature_list += [disc_encoding[0]]
            # Flatten to (items, classes) explicitly; a bare squeeze() would
            # also collapse the item axis when a batch holds a single item and
            # break the max over dimension 1.
            logits = _output.reshape(-1, _output.shape[-1])
            _, predict = torch.max(logits, 1)
            # .cpu() is a no-op for tensors already on the CPU, so it is safe
            # to call regardless of which device the model runs on.
            y_pred += predict.cpu().tolist()
            y_true += out.reshape(-1).cpu().tolist()
            if max_batches is not None and processed >= max_batches:
                break
    label_list = y_true
    macro_scores = precision_recall_fscore_support(y_true, y_pred, average='macro')
    micro_scores = precision_recall_fscore_support(y_true, y_pred, average='micro')
    print("Classification Report \n", classification_report(y_true, y_pred, digits=4))
    print("MACRO: ", macro_scores)
    print("MICRO: ", micro_scores)
    print("\nConfusion Matrix \n", confusion_matrix(y_true, y_pred))


# --- Configuration, model and data ---------------------------------------
config = Config()
config.epoch = 40
# Checkpoint file written by train() and read back by load_model().
model_save_path = r'test_model.ckpt'
model = MultiSentFeatClassifier(config).to(config.device)
# model = BertClassifier(config).to(config.device)
print("Loading trainset")
train_dataset = ChineseBertDataset(config, r'data/train/')
# Order is kept fixed (shuffle=False); batch size is left at the DataLoader default.
train_dataloader = DataLoader(train_dataset, shuffle=False)
# eval_dataset = ChineseBertDataset(config, r'data/validation')
# eval_dataloader = DataLoader(eval_dataset, shuffle=False)
# test_dataset = ChineseBertDataset(config, r'data/test')
# test_dataloader = DataLoader(test_dataset, shuffle=False)

# Training is disabled; evaluate() below relies on an existing checkpoint.
# train()

# Module-level collections filled by evaluate(). All three must exist because
# evaluate() declares them global and appends to them.
sentence_feature_list = []
discourse_feature_list = []
label_list = []

evaluate()

ModuleNotFoundError: No module named 'src.dataloaders'

In [None]:
# Project the features collected by evaluate() with t-SNE.
if not sentence_feature_list:
    raise RuntimeError("sentence_feature_list is empty; run evaluate() before computing t-SNE")
# Stack the collected feature tensors along a new leading axis and move the
# result to the CPU (assumes one 1-D feature vector per item -- confirm).
sentence_features = torch.stack(list(sentence_feature_list)).detach().cpu()
# discourse_features = torch.cat([i[None, :] for i in discourse_feature_list]).to('cpu')
# feature_list = torch.cat([sentence_features, discourse_features])
# A fixed random_state makes the projection reproducible across runs.
tsne = TSNE(n_components=3, init='pca', verbose=1, random_state=0)
# scikit-learn expects array-likes, so hand it an explicit NumPy array.
embedding = tsne.fit_transform(sentence_features.numpy())

In [None]:
# Write the embedding: one point per line, components separated by single spaces.
with open(r'embedding.txt', 'w') as embedding_file:
    embedding_file.writelines(
        ' '.join(map(str, point)) + '\n' for point in embedding
    )
# label_list += [9] * discourse_features.shape[0]
# Write the labels, one per line, in the same order as the embedding rows.
with open(r'labels.txt', 'w') as label_file:
    label_file.writelines(str(label) + '\n' for label in label_list)

In [None]:
# Label name -> class index, and the inverse lookup used to annotate points.
out_map = {'NA': 0, 'M1': 1, 'M2': 2, 'C1': 3, 'C2': 4,
           'HI': 5,
           'AN': 6, 'EV': 7, 'EX': 8}
m = {index: tag for tag, index in out_map.items()}
fig, ax = plt.subplots(figsize=(15, 11))
# ax = plt.gca(projection='3d')
# Plot at most 180 points, but never more than are actually available;
# indexing label_list past its end would otherwise raise IndexError.
n = min(180, len(embedding), len(label_list))
# n = embedding.shape[0]
# Only the first two embedding components are drawn on this 2-D axis.
ax.scatter(embedding[:n, 0], embedding[:n, 1],
           s=200,
           linewidth=2,
           c=[(label_list[i] + 1) / 10 for i in range(n)],
           cmap='Dark2')
# Annotate every second point with its label name to limit clutter.
for i in range(0, n, 2):
    ax.text(embedding[i][0], embedding[i][1], m[label_list[i]], fontsize=15)
ax.set(xlabel='t-SNE component 1',
       ylabel='t-SNE component 2',
       title=f't-SNE of sentence features (first {n} points)')
plt.show()