In [1]:
# Select the GPU for this kernel: enumerate devices in PCI bus order and
# expose only device index 5.
# NOTE(review): the index is specific to the original host — adjust before re-running elsewhere.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=5

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=5


In [2]:
import os
import sys

# Absolute path of the parent directory; appended to the import path once so
# that packages living next to the notebook folder can be imported.
module_path = os.path.abspath('..')

if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [3]:
# Directory the per-dataset/per-encoder result CSVs are written to.
OUT_PATH: str = '../results/manifolds/_raw'

In [4]:
# Common path prefix of the sample files; the full path is built as
# "<DATA_PATH>.<file stem>.csv" when the datasets are loaded.
DATA_PATH: str = '../data/imdb'

# (split label, file stem) pairs, one per dataset split.
DATASETS: list = [('train', 'sample.train'), ('test', 'sample.test')]

In [5]:
# Target embedding sizes, largest first. A size equal to the encoder's own
# output dimension is analysed without any reduction.
DIMS: list = [
    768,
    576, 384, 192,
    96, 48, 24, 12,
    6, 3,
]

# (short label, model reference) pairs; the reference is handed to the encoder
# as its 'model' setting — presumably a Hugging Face hub id, confirm against Encoder.
MODELS: list = list({
    'base': 'bert-base-uncased',
    'textattack': 'textattack/bert-base-uncased-imdb',
    'fabriceyhc': 'fabriceyhc/bert-base-uncased-imdb',
    'wakaka': 'Wakaka/bert-finetuned-imdb',
}.items())

In [6]:
### Load Datasets into memory

In [7]:
# Keyword arguments shared by every dataset wrapper: the text column, the
# label column, and the mapping from label string to integer class id.
data_config: dict = dict(
    data_label='text',
    target_label='sentiment',
    target_groups=dict(negative=0, positive=1),
)

In [8]:
from typing import Dict
from modules import Data

# One dataset wrapper per split, keyed by the split label. Each file is
# resolved as "<DATA_PATH>.<file stem>.csv".
datasets: Dict[str, Data] = dict(
    (label, Data(file_path=f"{DATA_PATH}.{name}.csv", **data_config))
    for label, name in DATASETS
)

In [9]:
# Sanity check: show every split together with its class balance.
for label, dataset in datasets.items():
    display(dataset.data)
    # Take the label column from the dataset itself rather than repeating the
    # literal from `data_config`, so both stay in sync if the column is renamed.
    display(dataset.data[dataset.target_label].value_counts(normalize=True))

Unnamed: 0,text,sentiment
0,"And it falls squarely into the category of ""aw...",negative
1,This is one seriously disturbed movie. Even Th...,negative
2,"Basically this is an overlong, unfunny, action...",negative
3,Hey if you have a little over an hour to kill ...,negative
4,Did anyone read the script. This has to be som...,negative
...,...,...
1240,"First of all, Jenna Jameson is the best actres...",negative
1241,"I didnt think it was possible, but i have foun...",negative
1242,"OK, I taped this off TV and missed the very st...",negative
1243,"Okay, okay, maybe not THE greatest. I mean, Th...",positive


negative    0.503614
positive    0.496386
Name: sentiment, dtype: float64

Unnamed: 0,text,sentiment
0,This...... Movie.... Is..... Horrible!!!!!! Yo...,negative
1,At the same time John Russell was playing ranc...,positive
2,This is the best version of Gypsy that has bee...,positive
3,"It's just stories, some we wish happen to us, ...",positive
4,"This film, without doubt, is the clearest exam...",positive
...,...,...
1235,"""Cooley High"" is one of my favorite movies EVE...",positive
1236,The Comic Strip featured actors from 'The Youn...,negative
1237,I suppose you could say this film has a grain ...,negative
1238,"Having just watched Acacia, I find that I have...",negative


negative    0.520968
positive    0.479032
Name: sentiment, dtype: float64

In [10]:
### Load Encoders into memory

In [11]:
# Settings merged into every encoder's configuration.
# NOTE(review): 'layers' presumably selects which hidden layer(s) to read
# embeddings from (index 11 here) — confirm against the Encoder class.
encoder_config: dict = dict(layers=[11])

In [12]:
from modules import Encoder

# One encoder per model, keyed by its short label. The model reference is
# placed first so that shared settings from `encoder_config` are merged on top.
encoders: Dict[str, Encoder] = {
    label: Encoder({'model': ref, **encoder_config})
    for label, ref in MODELS
}

In [13]:
### Compute manifolds and measure centroid distance and cluster dispersion

In [14]:
import numpy as np
import pandas as pd

In [15]:
from sklearn.manifold import MDS

def manifold_reduction(data: np.ndarray, dim: int = 3, random_state: 'int | None' = None) -> np.ndarray:
    """Project `data` onto `dim` components with multidimensional scaling (MDS).

    MDS positions the points in the lower-dimensional space based on the
    pairwise distances among them.

    Args:
        data: 2-D array with one sample per row.
        dim: number of components of the reduced representation.
        random_state: seed forwarded to MDS. The default `None` keeps the
            previous unseeded behaviour; pass an integer to make the
            projection reproducible across runs.

    Returns:
        Array with one row per input sample and `dim` columns.
    """
    reducer = MDS(n_components=dim, n_jobs=-1, random_state=random_state)
    return reducer.fit_transform(data)

In [16]:
from scipy.spatial import distance

def metric_computation(groups: pd.core.groupby.GroupBy, dim: int, col: str) -> dict:
    """Summarise per-class cluster geometry for one embedding size.

    For every group the record receives:
      * ``centroid_<label>``: mean of the group's vectors,
      * ``centroid_point_distances_<label>``: distance of every vector to
        that centroid, as a single-row nested list,
      * ``intra_distance_<label>``: mean of those distances.
    In addition ``extra_distance`` holds the distance between the centroids
    of the ``positive`` and ``negative`` groups.

    Args:
        groups: grouped frame; iterating yields ``(label, sub-frame)`` pairs.
        dim: embedding size the vectors belong to, stored under ``'dim'``.
        col: name of the column holding one vector per row.

    Returns:
        Flat dict with the entries described above.

    Raises:
        KeyError: if there is no ``positive`` or no ``negative`` group.
    """
    record: dict = {'dim': dim}

    for label, group in groups:
        # one vector per row of the group
        points = group[col].tolist()

        # centroid = component-wise mean over the group's vectors
        centroid = np.stack(points, axis=0).mean(axis=0).tolist()
        record[f'centroid_{label}'] = centroid

        # 1 x n matrix: distance from the centroid to each point
        spread = distance.cdist([centroid], points).tolist()
        record[f'centroid_point_distances_{label}'] = spread

        # intra-cluster distance = mean centroid-to-point distance
        record[f'intra_distance_{label}'] = np.mean(spread, axis=1).item()

    # inter-cluster distance for the binary case (positive vs. negative)
    # TODO expand to multi class
    positive, negative = record['centroid_positive'], record['centroid_negative']
    record['extra_distance'] = distance.cdist([positive], [negative]).item()

    return record

In [17]:
from typing import Generator
import torch

def reduce_analyse(data: pd.DataFrame, col: str, dim: list, default_dim: int = 768,
                   target_col: 'str | None' = None) -> Generator:
    """Yield one cluster-metric record per requested embedding size.

    The tensors stored in `data[col]` are stacked into a single array. For
    every size in `dim` the array is either used as is (when the size equals
    `default_dim`) or reduced with `manifold_reduction`; the result is written
    to the helper column ``'reduced_embeds'`` and the frame, grouped by the
    label column, is passed to `metric_computation`.

    Side effect: `data` is modified in place — the ``'reduced_embeds'`` column
    is created and overwritten on every iteration.

    Args:
        data: frame holding one embedding tensor per row in column `col`.
        col: name of the embedding column.
        dim: embedding sizes to analyse, in the order they are yielded.
        default_dim: native embedding size; no reduction is applied for it.
        target_col: name of the label column to group by. When omitted, the
            name is read from the notebook-level variable `dataset`
            (`dataset.target_label`), as this function did before the
            parameter existed. Prefer passing it explicitly: `dataset` is a
            loop variable that other cells rebind.

    Yields:
        The record returned by `metric_computation` for each size.
    """
    # assumes the column holds equally-shaped CPU tensors — TODO confirm against the encoder
    embed_col: np.ndarray = torch.stack(data[col].tolist()).numpy()

    for d in dim:

        # if reduction size is equal to encoder output dim, skip manifold reduction
        if d == default_dim:
            data['reduced_embeds'] = list(embed_col)
        else:
            data['reduced_embeds'] = list(manifold_reduction(embed_col, dim=d))

        # resolve lazily so the legacy global lookup only happens when needed
        group_col = target_col if target_col is not None else dataset.target_label

        yield metric_computation(
            data.groupby(group_col),
            dim=d, col='reduced_embeds'
        )

In [18]:
# Metric tables keyed by "<dataset label>.<encoder label>".
results: Dict[str, pd.DataFrame] = dict()

In [19]:
# Run every encoder over every dataset split.
# NOTE: the inner loop variable must keep the name `dataset` — `reduce_analyse`
# reads the label column name from that notebook-level variable.
for enc_label, encoder in encoders.items():
    for data_label, dataset in datasets.items():

        # embed the text column
        # (presumably stored on the frame under `encoder.col_name` — confirm against Encoder)
        encoder.df_encode(dataset.data, col=dataset.data_label)

        # one metric record per embedding size, collected before building the table
        records = list(
            reduce_analyse(dataset.data, encoder.col_name, DIMS, default_dim=encoder.dim)
        )
        results[f'{data_label}.{enc_label}'] = pd.DataFrame.from_records(records)

                                                                                                                                                                                                                                                                                                                                                                          

In [20]:
# Regular expression selecting the columns to export: the embedding size, the
# centroid-to-centroid distance and the per-class distance columns.
# Anchored, with `.+` for the class suffix: the former glob-style `_*` meant
# "zero or more underscores", and without anchors any column name merely
# containing e.g. "dim" would have been selected as well.
output_cols: str = r'^(dim|extra_distance|intra_distance_.+|centroid_point_distances_.+)$'

In [21]:
# Write each metric table to "<OUT_PATH>/<label>.csv" and show what was written.
# Create the export directory first so a fresh checkout does not fail on write.
os.makedirs(OUT_PATH, exist_ok=True)

# The loop variable is deliberately not called `dataset`: that name is used
# for the dataset wrappers elsewhere and must not be rebound to a DataFrame.
for label, frame in results.items():
    # select the export columns once and reuse the result for file and display
    exported = frame.filter(regex=output_cols)
    exported.to_csv(f'{OUT_PATH}/{label}.csv')
    display(label, exported)

'train.base'

Unnamed: 0,dim,centroid_point_distances_negative,intra_distance_negative,centroid_point_distances_positive,intra_distance_positive,extra_distance
0,768,"[[5.992325561399737, 5.203575077395552, 3.5995...",5.764313,"[[6.28652306558972, 3.8467728310963367, 5.1201...",6.138071,1.259246
1,576,"[[5.998907561659804, 5.20429223352104, 3.60059...",5.76573,"[[6.306472168343806, 3.8406067732668197, 5.137...",6.139768,1.193737
2,384,"[[5.99809831110702, 5.209406302314851, 3.59817...",5.765468,"[[6.302880179398811, 3.8439582126890106, 5.129...",6.139453,1.19783
3,192,"[[5.989179558040728, 5.20308222105544, 3.58133...",5.763025,"[[6.30275337188162, 3.804601725111403, 5.13448...",6.137499,1.22088
4,96,"[[6.001490161397765, 5.200131744815813, 3.5928...",5.760449,"[[6.304338052169215, 3.8114252952024854, 5.137...",6.135308,1.237246
5,48,"[[5.978139737305378, 5.17674970095341, 3.54733...",5.752931,"[[6.275858301193788, 3.7716448119043506, 5.101...",6.128193,1.315556
6,24,"[[5.984241262573122, 5.128108147891398, 3.5290...",5.736505,"[[6.186795323682618, 3.7806063934981142, 5.018...",6.113381,1.437465
7,12,"[[5.935683877594702, 5.0894889821615195, 3.455...",5.699003,"[[6.207877198480866, 3.5430944660054875, 5.033...",6.080595,1.681827
8,6,"[[5.923193085156102, 4.869691135083833, 3.2336...",5.625003,"[[6.306060399857682, 3.335986604969422, 4.8193...",6.02726,1.80134
9,3,"[[5.85203375999411, 4.791429520515079, 2.81340...",5.429721,"[[5.9604018162228645, 2.9664268431571874, 4.80...",5.957408,1.44782


'test.base'

Unnamed: 0,dim,centroid_point_distances_negative,intra_distance_negative,centroid_point_distances_positive,intra_distance_positive,extra_distance
0,768,"[[4.832424433221461, 9.807014103979695, 3.7882...",5.798742,"[[5.229993597663103, 4.427183632815254, 5.6675...",6.084546,1.229414
1,576,"[[4.855815287883087, 9.803366795820503, 3.8120...",5.800334,"[[5.239752862285097, 4.418781130039162, 5.6664...",6.086612,1.154896
2,384,"[[4.851692053966959, 9.806107043630272, 3.8091...",5.799666,"[[5.232982052660435, 4.432929383999893, 5.6582...",6.085929,1.162938
3,192,"[[4.849376566608824, 9.79537313965205, 3.81052...",5.79872,"[[5.226997282097812, 4.400307953161804, 5.6606...",6.08458,1.175887
4,96,"[[4.831381840702488, 9.806197866204979, 3.7824...",5.795773,"[[5.254985269573869, 4.406814474286478, 5.6697...",6.081798,1.201157
5,48,"[[4.789311957478701, 9.815455123641163, 3.7847...",5.788482,"[[5.203252643245383, 4.334539989235488, 5.6522...",6.073079,1.297727
6,24,"[[4.715640775269686, 9.848768446858827, 3.7014...",5.773222,"[[5.206691394231041, 4.2870024697869304, 5.635...",6.059071,1.404172
7,12,"[[4.7194338867860495, 9.895807775076673, 3.643...",5.744582,"[[5.2437577681448815, 4.287747248851521, 5.577...",6.033016,1.535413
8,6,"[[4.294540899133911, 9.97537934213856, 3.10632...",5.666244,"[[4.848189758825476, 3.6853112524572817, 5.556...",5.962391,1.810618
9,3,"[[4.281322298005763, 10.012408322920678, 2.171...",5.483858,"[[4.969311071299542, 3.948964141784532, 5.4763...",5.870369,1.508725


'train.textattack'

Unnamed: 0,dim,centroid_point_distances_negative,intra_distance_negative,centroid_point_distances_positive,intra_distance_positive,extra_distance
0,768,"[[6.751666973990453, 7.309737450435016, 5.0174...",6.738079,"[[8.802257554520567, 5.346788559260816, 6.4295...",7.117969,7.820493
1,576,"[[6.748457047546795, 7.3101276981217, 5.014576...",6.736532,"[[8.786901190846585, 5.341769393081038, 6.4342...",7.116558,7.822064
2,384,"[[6.74644970128014, 7.311641128501631, 5.02177...",6.735737,"[[8.789468126819267, 5.345078678432464, 6.4349...",7.115851,7.823887
3,192,"[[6.742776618919625, 7.304196389217014, 5.0117...",6.733366,"[[8.784576568899515, 5.3381486365684125, 6.422...",7.113394,7.830239
4,96,"[[6.736834823458004, 7.2983001810428805, 5.001...",6.727796,"[[8.781989446951869, 5.314007404308993, 6.4255...",7.108456,7.843153
5,48,"[[6.7144606818574095, 7.286519893448449, 4.965...",6.716302,"[[8.772349855159447, 5.300347940662182, 6.4090...",7.097346,7.870537
6,24,"[[6.687358677496187, 7.304893884494862, 4.9407...",6.69111,"[[8.728616872102121, 5.217161550671266, 6.3663...",7.073249,7.930257
7,12,"[[6.632163350333271, 7.307586147913015, 4.8633...",6.632951,"[[8.660855192938065, 5.035483105430938, 6.2211...",7.016746,8.06339
8,6,"[[6.385649528152672, 7.282262221386803, 4.6832...",6.481661,"[[8.52593895412243, 4.701869303600273, 6.10177...",6.870973,8.369621
9,3,"[[6.35914690195332, 7.171314451281628, 4.00497...",5.981564,"[[7.73679117388107, 3.443858042313132, 5.57577...",6.338681,9.167216


'test.textattack'

Unnamed: 0,dim,centroid_point_distances_negative,intra_distance_negative,centroid_point_distances_positive,intra_distance_positive,extra_distance
0,768,"[[6.785032081639631, 9.407629650697947, 6.6619...",6.935167,"[[5.874596543022215, 7.184561450019481, 6.6260...",7.183807,7.312816
1,576,"[[6.778534930704751, 9.413754056953097, 6.6574...",6.933645,"[[5.870795570202823, 7.18290015554014, 6.62498...",7.18229,7.314614
2,384,"[[6.777516547840319, 9.417010432445407, 6.6522...",6.933008,"[[5.871652120367502, 7.1803443647205665, 6.624...",7.181497,7.316687
3,192,"[[6.778307487190863, 9.413657250337293, 6.6453...",6.930826,"[[5.870810430691806, 7.176800685368074, 6.6175...",7.179239,7.322274
4,96,"[[6.774936805290106, 9.417169863585237, 6.6438...",6.926003,"[[5.853989474020858, 7.1750665206798, 6.623074...",7.173779,7.335579
5,48,"[[6.779253027268124, 9.417302831811549, 6.6141...",6.915722,"[[5.846367538738173, 7.153847141301221, 6.6182...",7.162228,7.364136
6,24,"[[6.737800794000499, 9.428653945569573, 6.5851...",6.891857,"[[5.830317733260155, 7.117771409637466, 6.5907...",7.13698,7.426354
7,12,"[[6.669094555797456, 9.460207205639533, 6.4194...",6.839208,"[[5.727548964347177, 7.019843929947428, 6.5287...",7.080091,7.557102
8,6,"[[6.424961354620629, 9.535341037416224, 6.2235...",6.698832,"[[5.595376811248244, 6.661456697685557, 6.4518...",6.923806,7.876579
9,3,"[[5.666607979351545, 9.534882539694415, 5.6022...",6.230469,"[[5.288506068211632, 5.818130265475848, 6.1381...",6.397738,8.693981


'train.fabriceyhc'

Unnamed: 0,dim,centroid_point_distances_negative,intra_distance_negative,centroid_point_distances_positive,intra_distance_positive,extra_distance
0,768,"[[4.967599522866965, 5.136100990019949, 7.8128...",7.644175,"[[4.750496262880372, 5.36116840102162, 4.42993...",6.040889,22.864886
1,576,"[[4.95821859706932, 5.121970710518513, 7.81382...",7.641595,"[[4.75059080073975, 5.3536446089920755, 4.4242...",6.039522,22.865078
2,384,"[[4.960733725540239, 5.1184495731113415, 7.812...",7.641275,"[[4.752677381873392, 5.358462458188564, 4.4300...",6.039016,22.865182
3,192,"[[4.96130359970093, 5.12095201980016, 7.815212...",7.641667,"[[4.747597383738067, 5.357939036225633, 4.4237...",6.038969,22.865391
4,96,"[[4.958182695462467, 5.125312865067483, 7.8087...",7.639926,"[[4.745376892780683, 5.353282813224601, 4.3976...",6.036935,22.865743
5,48,"[[4.932615807757438, 5.132265102577457, 7.8177...",7.638474,"[[4.755946116997268, 5.359537814207116, 4.4071...",6.035228,22.866746
6,24,"[[4.915745212132268, 5.078671000937027, 7.8346...",7.630471,"[[4.736984447999113, 5.363696131870111, 4.3949...",6.028186,22.869076
7,12,"[[4.923674941557828, 5.130468333428186, 7.8540...",7.615146,"[[4.73894688872526, 5.361716249587495, 4.32878...",6.009874,22.875822
8,6,"[[4.84147038372577, 4.84811398576401, 7.941404...",7.551837,"[[4.743799930350909, 5.308257549095261, 4.2506...",5.951208,22.899133
9,3,"[[5.001696260606553, 4.649309916118462, 7.9220...",7.298931,"[[4.451366652379766, 5.274864060521158, 3.7288...",5.607511,23.059663


'test.fabriceyhc'

Unnamed: 0,dim,centroid_point_distances_negative,intra_distance_negative,centroid_point_distances_positive,intra_distance_positive,extra_distance
0,768,"[[9.073984610715168, 10.79574947060581, 10.667...",9.273344,"[[3.9066403335518403, 9.816295601628925, 11.72...",7.438157,19.964325
1,576,"[[9.06992614970188, 10.800045351310528, 10.669...",9.271932,"[[3.8890632235390843, 9.819785380095174, 11.72...",7.436771,19.964599
2,384,"[[9.07007529164065, 10.797333588568195, 10.670...",9.271734,"[[3.8910182075364883, 9.824774339086003, 11.72...",7.435813,19.964717
3,192,"[[9.069799017971265, 10.804503100101812, 10.66...",9.27165,"[[3.887623179281154, 9.820909703245782, 11.726...",7.436544,19.964907
4,96,"[[9.066407841331793, 10.79928185824498, 10.668...",9.270937,"[[3.8733417814869577, 9.820079236729807, 11.72...",7.434083,19.965551
5,48,"[[9.075261868783645, 10.816659058285115, 10.67...",9.268299,"[[3.8833185097806364, 9.841535991094785, 11.73...",7.430381,19.966897
6,24,"[[9.066545539975548, 10.813844096999658, 10.67...",9.264909,"[[3.880122390595751, 9.83701755685778, 11.7283...",7.424555,19.970252
7,12,"[[9.100826811341541, 10.80368415005548, 10.694...",9.251533,"[[3.86162414969133, 9.850031429356708, 11.7106...",7.40676,19.97957
8,6,"[[9.141878211491996, 10.935425290244883, 10.75...",9.209127,"[[3.601098569782421, 9.900103802798562, 11.748...",7.338289,20.008512
9,3,"[[9.158822279650673, 10.232417097741667, 10.99...",8.909093,"[[3.181293716269652, 9.921919133134383, 11.491...",6.985437,20.246208


'train.wakaka'

Unnamed: 0,dim,centroid_point_distances_negative,intra_distance_negative,centroid_point_distances_positive,intra_distance_positive,extra_distance
0,768,"[[8.912249447423662, 5.991112258898744, 6.4723...",7.437458,"[[11.313891151375156, 10.330851402943722, 5.53...",8.561553,11.483087
1,576,"[[8.90885134907158, 5.982331431708107, 6.46783...",7.435582,"[[11.279683472043187, 10.295851851720883, 5.53...",8.559251,11.485798
2,384,"[[8.905329441731666, 5.9824046861561335, 6.468...",7.435125,"[[11.28073226210674, 10.300748659445828, 5.534...",8.558938,11.486248
3,192,"[[8.903514197064906, 5.9841287951360185, 6.463...",7.43332,"[[11.277017410458756, 10.298353000519308, 5.53...",8.55692,11.489656
4,96,"[[8.908116408368548, 5.981449208167516, 6.4543...",7.429639,"[[11.26084998955271, 10.301450704820029, 5.500...",8.553358,11.49426
5,48,"[[8.899269267900785, 5.960281380208095, 6.4290...",7.421411,"[[11.285779032150222, 10.296455812357042, 5.46...",8.545226,11.507993
6,24,"[[8.913132017885491, 5.9277619373770305, 6.420...",7.403176,"[[11.241352431001893, 10.269065221827828, 5.38...",8.526624,11.536142
7,12,"[[8.855939630511708, 5.939909028033813, 6.2673...",7.356365,"[[11.099102818305294, 10.179133175337357, 5.37...",8.47657,11.611673
8,6,"[[8.834782650292649, 5.0154779273389325, 6.034...",7.218211,"[[10.954506302226845, 9.997661259557704, 5.112...",8.336106,11.829616
9,3,"[[8.82980719945389, 4.548188262132617, 5.76543...",6.757139,"[[10.87100169627255, 9.747964656113373, 4.1889...",7.903175,12.39674


'test.wakaka'

Unnamed: 0,dim,centroid_point_distances_negative,intra_distance_negative,centroid_point_distances_positive,intra_distance_positive,extra_distance
0,768,"[[6.595880875307386, 8.794097214610463, 7.2356...",7.581062,"[[6.3813950679018925, 9.885993432500822, 6.399...",8.446871,11.398275
1,576,"[[6.598851457776016, 8.795474086577475, 7.2321...",7.579218,"[[6.365344573924828, 9.866397399173747, 6.3815...",8.444666,11.400987
2,384,"[[6.598614119272363, 8.792584175630877, 7.2294...",7.578515,"[[6.377239684717132, 9.866619407535786, 6.3811...",8.444124,11.401848
3,192,"[[6.603452619081457, 8.794447139235105, 7.2271...",7.577133,"[[6.373549529392549, 9.86019918844127, 6.38728...",8.442543,11.403979
4,96,"[[6.593191233452919, 8.79945676555837, 7.22372...",7.573466,"[[6.362347403072847, 9.87423549608306, 6.34715...",8.438306,11.410811
5,48,"[[6.595364374107468, 8.784340553542048, 7.2265...",7.565383,"[[6.337306244462101, 9.846386478982225, 6.3443...",8.429617,11.424288
6,24,"[[6.548809308502468, 8.749814663665424, 7.1977...",7.547136,"[[6.272178324462712, 9.8272926277015, 6.311980...",8.408072,11.457864
7,12,"[[6.594984399077988, 8.757362971568737, 7.1621...",7.501272,"[[6.202472856527514, 9.764558570436673, 6.1886...",8.352838,11.537099
8,6,"[[6.426641699285559, 8.822089186706071, 7.0631...",7.373446,"[[5.6114946468036, 9.655419117545906, 6.040898...",8.210152,11.744673
9,3,"[[5.105434059748647, 8.854781247830145, 6.2779...",6.928481,"[[5.265330775801922, 9.26810628440048, 5.41939...",7.752875,12.347285
