In [1]:
from tqdm import trange
import numpy as np
import torch
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import datetime


def gather_features(model, testset):
    """Collect BERT last-layer token features, grouped by token class.

    Every batch of ``testset`` is run through ``model.bert`` without
    gradients. For each sentence the hidden states are truncated to the
    length of that sentence's class list and split by class id (0..4).

    Each sentence is paired with its own class list inside the batch loop.
    Pairing whole batches with per-sentence class lists is only correct when
    the batch size is 1: with larger batches it misassigns labels and drops
    most sentences.

    Args:
        model: Object exposing ``bert(bert_tokens, masks)``, whose result is
            indexed with ``'last_hidden_state'``. Assumed to yield a
            (batch, seq_len, hidden) tensor -- TODO confirm.
        testset: Object exposing ``batch_count`` and ``get_batch(j)``, which
            returns an 8-tuple whose last item holds one sequence of class
            ids per sentence (assumed to be in the same order as the batch
            dimension of the features -- TODO confirm).

    Returns:
        A 5-tuple of numpy arrays, one per class id 0..4, each stacking the
        feature rows of all tokens with that class. A class that never
        occurs yields an array with zero rows. If ``testset`` has no
        batches, every array has shape (0, 0).
    """
    num_classes = 5
    per_class = [[] for _ in range(num_classes)]

    for j in trange(testset.batch_count):
        (_sentence_ids, bert_tokens, masks, _word_spans, _tagging_matrices,
         _tokenized, _cl_masks, token_classes) = testset.get_batch(j)
        with torch.no_grad():
            hidden = model.bert(bert_tokens, masks)['last_hidden_state']

        # Move to host memory right away so the device only ever holds one
        # batch of hidden states instead of the whole test set.
        hidden = hidden.cpu()

        for sentence_hidden, classes in zip(hidden, token_classes):
            classes = np.array(classes)
            # Drop positions beyond the labelled tokens of this sentence.
            sentence_hidden = sentence_hidden[:len(classes), :]
            for class_id, bucket in enumerate(per_class):
                bucket.append(sentence_hidden[classes == class_id, :])

    # torch.cat rejects an empty list, which happens when there are no batches.
    stacked = [
        torch.cat(bucket, dim=0).numpy() if bucket
        else np.empty((0, 0), dtype=np.float32)
        for bucket in per_class
    ]
    return tuple(stacked)

In [None]:
from tqdm import trange
import numpy as np
import torch
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import datetime


def gather_features(model, testset):
    """Collect BERT last-layer token features, grouped by token class.

    Every batch of ``testset`` is run through ``model.bert`` without
    gradients. For each sentence the hidden states are truncated to the
    length of that sentence's class list and split by class id (0..4).

    Each sentence is paired with its own class list inside the batch loop.
    Pairing whole batches with per-sentence class lists is only correct when
    the batch size is 1: with larger batches it misassigns labels and drops
    most sentences.

    Args:
        model: Object exposing ``bert(bert_tokens, masks)``, whose result is
            indexed with ``'last_hidden_state'``. Assumed to yield a
            (batch, seq_len, hidden) tensor -- TODO confirm.
        testset: Object exposing ``batch_count`` and ``get_batch(j)``, which
            returns an 8-tuple whose last item holds one sequence of class
            ids per sentence (assumed to be in the same order as the batch
            dimension of the features -- TODO confirm).

    Returns:
        A 5-tuple of numpy arrays, one per class id 0..4, each stacking the
        feature rows of all tokens with that class. A class that never
        occurs yields an array with zero rows. If ``testset`` has no
        batches, every array has shape (0, 0).
    """
    num_classes = 5
    per_class = [[] for _ in range(num_classes)]

    for j in trange(testset.batch_count):
        (_sentence_ids, bert_tokens, masks, _word_spans, _tagging_matrices,
         _tokenized, _cl_masks, token_classes) = testset.get_batch(j)
        with torch.no_grad():
            hidden = model.bert(bert_tokens, masks)['last_hidden_state']

        # Move to host memory right away so the device only ever holds one
        # batch of hidden states instead of the whole test set.
        hidden = hidden.cpu()

        for sentence_hidden, classes in zip(hidden, token_classes):
            classes = np.array(classes)
            # Drop positions beyond the labelled tokens of this sentence.
            sentence_hidden = sentence_hidden[:len(classes), :]
            for class_id, bucket in enumerate(per_class):
                bucket.append(sentence_hidden[classes == class_id, :])

    # torch.cat rejects an empty list, which happens when there are no batches.
    stacked = [
        torch.cat(bucket, dim=0).numpy() if bucket
        else np.empty((0, 0), dtype=np.float32)
        for bucket in per_class
    ]
    return tuple(stacked)



def plot_pca(gathered_token_class_0, gathered_token_class_1, gathered_token_class_2, gathered_token_class_3, gathered_token_class_4, epoch):
    """Save a 2-D PCA scatter plot of per-class token features.

    Classes with zero rows are skipped. Each remaining class is randomly
    subsampled (without replacement) to at most six times the size of the
    smallest non-empty class, so large classes do not swamp the plot.

    A single PCA is fitted on the pooled subsample and every class is
    projected with it. Fitting a separate PCA per class would put each class
    in its own coordinate system, so positions on the shared axes would not
    be comparable.

    Args:
        gathered_token_class_0..4: 2-D arrays of shape (n_tokens, n_features),
            one per class; plotted as 'NULL', 'Aspect', 'Opinion-POS',
            'Opinion-NEU' and 'Opinion-NEG' respectively.
        epoch: Value embedded in the output file name.

    Raises:
        ValueError: If every class is empty.
    """
    import os

    colors = ['r', 'g', 'b', 'y', 'm']
    labels = ['NULL', 'Aspect', 'Opinion-POS', 'Opinion-NEU', 'Opinion-NEG']
    all_classes = [gathered_token_class_0,
                   gathered_token_class_1,
                   gathered_token_class_2,
                   gathered_token_class_3,
                   gathered_token_class_4]

    present = [(feats, color, label)
               for feats, color, label in zip(all_classes, colors, labels)
               if feats.shape[0] != 0]
    if not present:
        raise ValueError('plot_pca needs at least one non-empty class')

    # Per-class cap: 6x the smallest non-empty class.
    cap = 6 * min(feats.shape[0] for feats, _, _ in present)

    sampled = []
    for feats, color, label in present:
        keep = np.random.choice(feats.shape[0], min(cap, feats.shape[0]), replace=False)
        sampled.append((feats[keep, :], color, label))

    # One shared projection for all classes.
    pca = PCA(n_components=2)
    pca.fit(np.concatenate([feats for feats, _, _ in sampled], axis=0))

    fig = plt.figure()
    ax = fig.add_subplot(111)
    for feats, color, label in sampled:
        projected = pca.transform(feats)
        ax.scatter(projected[:, 0], projected[:, 1], c=color, label=label)
    ax.legend()

    # savefig does not create missing directories.
    os.makedirs('./plots_saved', exist_ok=True)
    # NOTE(review): str(datetime) contains spaces and ':' characters, which
    # are not valid in file names on some platforms (e.g. Windows).
    fig.savefig(f'./plots_saved/pca_2d_{datetime.datetime.today()}_{epoch}.png')
    # Release the figure; otherwise one figure per call stays alive in pyplot.
    plt.close(fig)

 
def plot_pca_3d(gathered_token_class_0, gathered_token_class_1, gathered_token_class_2, gathered_token_class_3, gathered_token_class_4, epoch):
    """Save a 3-D PCA scatter plot of per-class token features.

    Classes with zero rows are skipped. Each remaining class is randomly
    subsampled (without replacement) to at most six times the size of the
    smallest non-empty class, so large classes do not swamp the plot.

    A single PCA is fitted on the pooled subsample and every class is
    projected with it. Fitting a separate PCA per class would put each class
    in its own coordinate system, so positions on the shared axes would not
    be comparable.

    Args:
        gathered_token_class_0..4: 2-D arrays of shape (n_tokens, n_features),
            one per class; plotted as 'NULL', 'Aspect', 'Opinion-POS',
            'Opinion-NEU' and 'Opinion-NEG' respectively.
        epoch: Value embedded in the output file name.

    Raises:
        ValueError: If every class is empty.
    """
    import os

    colors = ['r', 'g', 'b', 'y', 'm']
    labels = ['NULL', 'Aspect', 'Opinion-POS', 'Opinion-NEU', 'Opinion-NEG']
    all_classes = [gathered_token_class_0,
                   gathered_token_class_1,
                   gathered_token_class_2,
                   gathered_token_class_3,
                   gathered_token_class_4]

    present = [(feats, color, label)
               for feats, color, label in zip(all_classes, colors, labels)
               if feats.shape[0] != 0]
    if not present:
        raise ValueError('plot_pca_3d needs at least one non-empty class')

    # Per-class cap: 6x the smallest non-empty class.
    cap = 6 * min(feats.shape[0] for feats, _, _ in present)

    sampled = []
    for feats, color, label in present:
        keep = np.random.choice(feats.shape[0], min(cap, feats.shape[0]), replace=False)
        sampled.append((feats[keep, :], color, label))

    # One shared projection for all classes.
    pca = PCA(n_components=3)
    pca.fit(np.concatenate([feats for feats, _, _ in sampled], axis=0))

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for feats, color, label in sampled:
        projected = pca.transform(feats)
        ax.scatter(projected[:, 0], projected[:, 1], projected[:, 2], c=color, label=label)
    ax.legend()

    # savefig does not create missing directories.
    os.makedirs('./plots_saved', exist_ok=True)
    # NOTE(review): str(datetime) contains spaces and ':' characters, which
    # are not valid in file names on some platforms (e.g. Windows).
    fig.savefig(f'./plots_saved/pca_3d_{datetime.datetime.today()}_{epoch}.png')
    # Release the figure; otherwise one figure per call stays alive in pyplot.
    plt.close(fig)