In [1]:
# Pin this notebook to a single GPU. With PCI_BUS_ID ordering, device index 5 refers to the
# GPU in PCI-bus order (the order nvidia-smi lists them). These must be set before CUDA is
# initialised, i.e. before torch first touches the GPU.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=5

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=5


In [2]:
import os
import sys

# Make the project root (this notebook's parent directory) importable so the
# local `modules` package used below can be found. os.path.join('..') is just '..'.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

In [3]:
# Directory that receives one CSV of manifold metrics per (split, model) pair.
OUT_PATH: str = "../results/manifolds/raw"

In [4]:
# Prefix of the IMDB sample files; each split is read from f"{DATA_PATH}.{name}.csv".
DATA_PATH: str = "../data/imdb"
# (split label, file-name infix) pairs.
DATASETS: list = [
    ("train", "sample.train"),
    ("test", "sample.test"),
]

In [5]:
# Target dimensionalities for the MDS sweep. An entry equal to the encoder's output
# size skips the reduction and uses the raw embeddings (see reduce_analyse's default_dim).
# ALL_DIMS is the full sweep; DIMS currently evaluates only the unreduced baseline.
ALL_DIMS: list = [768, 576, 384, 192, 96, 48, 24, 12, 6, 3]
DIMS: list = [768]  # set to ALL_DIMS to run the full sweep
# (short label, Hugging Face model reference) pairs.
MODELS: list = [
    ('base', 'bert-base-uncased'),
    ('textattack', 'textattack/bert-base-uncased-imdb'),
    ('fabriceyhc', 'fabriceyhc/bert-base-uncased-imdb'),
    ('wakaka', 'Wakaka/bert-finetuned-imdb')
]

In [6]:
### Load datasets into memory

In [7]:
# Column names and label encoding handed to every `Data` loader.
data_config: dict = dict(
    polarities={"negative": 0, "positive": 1},
    data_label='text',
    target_label='sentiment',
)

In [8]:
from typing import Dict
from modules import Data

# One `Data` loader per split, keyed by the split label from DATASETS.
datasets: Dict[str, Data] = dict(
    (split, Data(data_path=f"{DATA_PATH}.{stem}.csv", **data_config))
    for split, stem in DATASETS
)

In [9]:
# Preview each split and its class balance (label proportions).
# The loop variable is not called `dataset`, so this cell does not leave a global of that
# name behind for later cells to pick up by accident.
for split_data in datasets.values():
    display(split_data.data)
    # Read the label column from the loader's configuration instead of hard-coding it.
    display(split_data.data[split_data.target_label].value_counts(normalize=True))

Unnamed: 0,text,sentiment
0,"And it falls squarely into the category of ""aw...",negative
1,This is one seriously disturbed movie. Even Th...,negative
2,"Basically this is an overlong, unfunny, action...",negative
3,Hey if you have a little over an hour to kill ...,negative
4,Did anyone read the script. This has to be som...,negative
...,...,...
1240,"First of all, Jenna Jameson is the best actres...",negative
1241,"I didnt think it was possible, but i have foun...",negative
1242,"OK, I taped this off TV and missed the very st...",negative
1243,"Okay, okay, maybe not THE greatest. I mean, Th...",positive


negative    0.503614
positive    0.496386
Name: sentiment, dtype: float64

Unnamed: 0,text,sentiment
0,This...... Movie.... Is..... Horrible!!!!!! Yo...,negative
1,At the same time John Russell was playing ranc...,positive
2,This is the best version of Gypsy that has bee...,positive
3,"It's just stories, some we wish happen to us, ...",positive
4,"This film, without doubt, is the clearest exam...",positive
...,...,...
1235,"""Cooley High"" is one of my favorite movies EVE...",positive
1236,The Comic Strip featured actors from 'The Youn...,negative
1237,I suppose you could say this film has a grain ...,negative
1238,"Having just watched Acacia, I find that I have...",negative


negative    0.520968
positive    0.479032
Name: sentiment, dtype: float64

In [10]:
### Load encoders into memory

In [11]:
# Options merged into every Encoder config. `layers` presumably selects which hidden
# layer(s) the embeddings are taken from — TODO confirm in modules.Encoder.
encoder_config: dict = {'layers': [11]}

In [12]:
from modules import Encoder

# One Encoder per model; each config is {'model': <reference>} plus the shared encoder_config.
encoders: Dict[str, Encoder] = {
    model_label: Encoder({'model': model_ref, **encoder_config})
    for model_label, model_ref in MODELS
}

In [13]:
### Compute manifolds and measure inter-centroid distance and intra-cluster dispersion

In [14]:
import numpy as np
import pandas as pd

In [15]:
from sklearn.manifold import MDS

def manifold_reduction(data: np.ndarray, dim: int = 3, random_state=None) -> np.ndarray:
    """Embed ``data`` into ``dim`` dimensions with scikit-learn's metric MDS.

    Args:
        data: 2-D array with one sample per row (here: the stacked encoder embeddings).
        dim: number of output components.
        random_state: seed forwarded to ``MDS``. Pass an int for reproducible embeddings;
            the default ``None`` keeps MDS's non-deterministic initialisation.

    Returns:
        Array of shape ``(n_samples, dim)``.
    """
    return MDS(n_components=dim, random_state=random_state).fit_transform(data)

In [16]:
from scipy.spatial import distance

def metric_computation(record: dict, groups: pd.core.groupby.GroupBy,
                       pos_label: str = 'positive', neg_label: str = 'negative') -> None:
    """Write per-class centroid/dispersion metrics and the inter-centroid distance into ``record``.

    For each ``(label, group)`` in ``groups`` the vectors in ``group['reduced_embeds']``
    are stacked into a matrix, and these keys are written:

    * ``centroid_<label>``: the mean vector, as a list of floats.
    * ``centroid_point_distances_<label>``: Euclidean distances from the centroid to every
      point, as a one-row nested list ``[[d_0, ..., d_n]]`` (the shape ``cdist`` returns).
    * ``intra_distance_<label>``: the mean of those distances (cluster dispersion).

    Finally ``extra_distance`` is set to the Euclidean distance between the centroids of
    ``pos_label`` and ``neg_label``.

    Args:
        record: dict that is updated in place.
        groups: result of ``DataFrame.groupby``; each group frame must have a
            ``reduced_embeds`` column of equally-sized 1-D vectors.
        pos_label: label of the first group whose centroid is compared.
        neg_label: label of the second group whose centroid is compared.

    Raises:
        KeyError: if ``pos_label`` or ``neg_label`` is not one of the group labels.
    """
    for label, group in groups:
        # Stack once and reuse the matrix for both the centroid and the distances.
        points = np.stack(group['reduced_embeds'].tolist(), axis=0)
        centroid = points.mean(axis=0)
        dists = distance.cdist(centroid[np.newaxis, :], points)  # shape (1, n_points)
        record[f'centroid_{label}'] = centroid.tolist()
        record[f'centroid_point_distances_{label}'] = dists.tolist()
        record[f'intra_distance_{label}'] = dists.mean(axis=1).item()

    missing = [lbl for lbl in (pos_label, neg_label) if f'centroid_{lbl}' not in record]
    if missing:
        raise KeyError(f"no group labelled {missing!r}; cannot compute extra_distance")
    record['extra_distance'] = distance.cdist(
        [record[f'centroid_{pos_label}']], [record[f'centroid_{neg_label}']]
    ).item()

In [17]:
from typing import Generator
import torch

def reduce_analyse(data: pd.DataFrame, col: str, dim: list, default_dim: int = 768,
                   target_label: str = 'sentiment') -> Generator:
    """Sweep target dimensionalities and yield one metrics record per dimensionality.

    For every ``d`` in ``dim``, the embeddings in ``data[col]`` are either used as-is
    (when ``d == default_dim``) or reduced to ``d`` components with ``manifold_reduction``.
    The points are then grouped by ``data[target_label]`` and passed to
    ``metric_computation``.

    Args:
        data: frame with one embedding per row in ``col`` and the class label in
            ``target_label``. It is not modified.
        col: column holding the per-row embedding tensors (assumed to be equally-shaped
            1-D ``torch.Tensor`` objects, since they are ``torch.stack``-ed — TODO confirm
            against ``Encoder``).
        dim: dimensionalities to evaluate, in order.
        default_dim: native embedding size; matching entries skip the reduction.
        target_label: column to group by. The default matches ``data_config['target_label']``.

    Yields:
        One ``dict`` per entry of ``dim``: ``{'dim': d}`` plus the keys written by
        ``metric_computation``.
    """
    # detach()/cpu() are no-ops for CPU tensors without grad; they keep .numpy() working
    # if the encoder returns GPU or grad-tracking tensors (this notebook pins a GPU).
    embed_col: np.ndarray = torch.stack(data[col].tolist()).detach().cpu().numpy()

    for d in dim:
        # create record to keep row data
        record: dict = {'dim': d}

        # if reduction size is equal to encoder output dim, skip manifold reduction
        reduced = embed_col if d == default_dim else manifold_reduction(embed_col, dim=d)

        # Group a private frame instead of adding a column to the caller's `data`, and take
        # the label column from the argument rather than from a global variable.
        points = data[[target_label]].copy()
        points['reduced_embeds'] = list(reduced)

        metric_computation(record, points.groupby(target_label))

        yield record

In [19]:
# Metric tables keyed as "<split>.<model>", filled by the analysis loop below.
results: Dict[str, pd.DataFrame] = dict()

In [20]:
# Encode each split with each model, then evaluate every dimensionality in DIMS.
# NOTE: the inner loop variable is deliberately named `dataset`: reduce_analyse has looked
# up the global `dataset.target_label`, so renaming it could break that lookup.
for model_label, encoder in encoders.items():
    for split_label, dataset in datasets.items():
        # presumably adds the embedding column `encoder.col_name` to dataset.data — confirm in modules.Encoder
        encoder.df_encode(dataset.data, col=dataset.data_label)
        sweep_records = [
            rec for rec in reduce_analyse(dataset.data, encoder.col_name, DIMS, default_dim=encoder.dim)
        ]
        results[f'{split_label}.{model_label}'] = pd.DataFrame.from_records(sweep_records)

                                                                                                                                                                                                                                                                                                                                                                          

In [18]:
# Regex selecting the columns to export (used with DataFrame.filter(regex=...), which applies
# re.search to each column name). `_.+` means "prefix followed by a class label"; a glob-style
# `_*` would mean "zero or more underscores". The anchors make each alternative match a whole name.
output_cols: str = r'^(dim|extra_distance|intra_distance_.+|centroid_point_distances_.+)$'

In [21]:
# Write the selected metric columns of each (split, model) table to OUT_PATH and show them.
# List-valued columns (centroid_point_distances_*) are written as their string repr.
os.makedirs(OUT_PATH, exist_ok=True)  # to_csv fails if the directory does not exist yet
for result_label, result_df in results.items():
    # Filter once; the same frame is written and displayed. The loop variable is not called
    # `dataset`, so this cell does not overwrite that global with a DataFrame.
    result_out = result_df.filter(regex=output_cols)
    result_out.to_csv(os.path.join(OUT_PATH, f'{result_label}.csv'))
    display(result_out)

Unnamed: 0,dim,dispersion_positive,dispersion_negative,distance
0,768,6.138071,5.764313,1.259246
1,576,6.139669,5.765659,1.194376
2,384,6.139204,5.765113,1.199099
3,192,6.137501,5.763306,1.219064
4,96,6.134574,5.760337,1.241529
5,48,6.128319,5.753036,1.305715
6,24,6.114873,5.737919,1.42086
7,12,6.085098,5.701169,1.636281
8,6,6.018489,5.616422,1.876313
9,3,5.918405,5.411482,1.77994


Unnamed: 0,dim,dispersion_positive,dispersion_negative,distance
0,768,6.084546,5.798742,1.229414
1,576,6.08645,5.800295,1.157035
2,384,6.086301,5.800096,1.155812
3,192,6.084319,5.798407,1.178628
4,96,6.082008,5.795524,1.204912
5,48,6.074534,5.788623,1.268671
6,24,6.058615,5.772423,1.404744
7,12,6.03168,5.745048,1.56663
8,6,5.954286,5.664161,1.905443
9,3,5.892543,5.513398,1.143659


Unnamed: 0,dim,dispersion_positive,dispersion_negative,distance
0,768,7.117969,6.738079,7.820493
1,576,7.116583,6.736554,7.822101
2,384,7.115844,6.735819,7.823868
3,192,7.113526,6.733359,7.829962
4,96,7.108215,6.72787,7.843417
5,48,7.097139,6.71642,7.870813
6,24,7.072455,6.690959,7.932147
7,12,7.01659,6.632266,8.064154
8,6,6.867037,6.482053,8.374858
9,3,6.329559,5.987389,9.175378


Unnamed: 0,dim,dispersion_positive,dispersion_negative,distance
0,768,7.183807,6.935167,7.312816
1,576,7.18244,6.933758,7.314245
2,384,7.181604,6.933,7.316373
3,192,7.17919,6.930747,7.322635
4,96,7.173834,6.926006,7.335768
5,48,7.162849,6.916276,7.361726
6,24,7.136895,6.892173,7.425125
7,12,7.078968,6.839176,7.558527
8,6,6.925149,6.69792,7.877356
9,3,6.395235,6.210816,8.715597


Unnamed: 0,dim,dispersion_positive,dispersion_negative,distance
0,768,6.040889,7.644175,22.864886
1,576,6.039529,7.641622,22.865088
2,384,6.039256,7.641588,22.865183
3,192,6.038501,7.641333,22.865379
4,96,6.037833,7.640487,22.865782
5,48,6.035165,7.639025,22.866592
6,24,6.0271,7.631796,22.869685
7,12,6.01079,7.616623,22.875936
8,6,5.947494,7.559823,22.900114
9,3,5.612881,7.313653,23.043346


Unnamed: 0,dim,dispersion_positive,dispersion_negative,distance
0,768,7.438157,9.273344,19.964325
1,576,7.436355,9.271827,19.964597
2,384,7.436274,9.271746,19.964672
3,192,7.435166,9.271504,19.965052
4,96,7.434171,9.270402,19.965542
5,48,7.431265,9.268296,19.966721
6,24,7.42602,9.265023,19.969651
7,12,7.405454,9.254471,19.978131
8,6,7.336097,9.197858,20.011267
9,3,6.961802,8.920493,20.250669


Unnamed: 0,dim,dispersion_positive,dispersion_negative,distance
0,768,8.561553,7.437458,11.483087
1,576,8.559432,7.435664,11.485481
2,384,8.559021,7.435237,11.485998
3,192,8.557336,7.433518,11.488414
4,96,8.553857,7.429942,11.494228
5,48,8.54477,7.421367,11.507466
6,24,8.52599,7.402648,11.537118
7,12,8.47919,7.356498,11.609513
8,6,8.340063,7.223993,11.823141
9,3,7.904102,6.753563,12.400316


Unnamed: 0,dim,dispersion_positive,dispersion_negative,distance
0,768,8.446871,7.581062,11.398275
1,576,8.444721,7.579173,11.400781
2,384,8.444202,7.578564,11.401899
3,192,8.442202,7.576835,11.405012
4,96,8.438007,7.57384,11.410919
5,48,8.429272,7.565218,11.424533
6,24,8.40891,7.547488,11.455621
7,12,8.357105,7.501981,11.534418
8,6,8.207643,7.370401,11.754118
9,3,7.730735,6.907276,12.368755


In [46]:
['dim', 'dispersion_positive', 'dispersion_negative', 'extra_distance']

['dim', 'dispersion_positive', 'dispersion_negative', 'extra_distance']

In [52]:
from itertools import permutations



[(1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1)]