In [1]:
# Expose only physical GPU 5 to this kernel, with device indices following PCI bus
# order. CUDA reads these variables when it initialises, so this cell has to run
# before any GPU work (torch is imported further down).
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=5

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=5


In [2]:
import os
import sys

# Put the parent directory on sys.path so the project-local `modules` package
# (imported in later cells) can be resolved; the guard keeps re-runs idempotent.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [3]:
OUT_PATH: str = "../results/manifolds"  # destination for the per-(split, encoder) analysis CSVs

In [4]:
# Each split's CSV is resolved as f"{DATA_PATH}.{file_stem}.csv"; every entry pairs
# the split label with its file stem.
DATA_PATH: str = '../data/imdb'
DATASETS: list = [(split, f'sample.{split}') for split in ('train', 'test')]

In [5]:
# Target sizes: 768 (presumably the encoders' native output size, which
# reduce_analyse passes through without MDS), 576, then repeated halving 384 -> 3.
DIMS: list = [768, 576] + [384 // 2 ** k for k in range(8)]

# (short label, Hugging Face model reference) pairs, in evaluation order.
MODELS: list = list({
    'base': 'bert-base-uncased',
    'textattack': 'textattack/bert-base-uncased-imdb',
    'fabriceyhc': 'fabriceyhc/bert-base-uncased-imdb',
    'wakaka': 'Wakaka/bert-finetuned-imdb',
}.items())

In [6]:
### Load datasets into memory

In [7]:
# Keyword arguments forwarded to every Data(...) constructed in the next cell.
data_config: dict = dict(
    polarities=dict(negative=0, positive=1),
    data_label='text',
    target_label='sentiment',
)

In [8]:
from typing import Dict
from modules import Data

# One Data object per split, read from f"{DATA_PATH}.<file stem>.csv" and
# configured with the shared data_config entries.
datasets: Dict[str, Data] = {
    split: Data(
        data_path=f"{DATA_PATH}.{file_stem}.csv",
        **data_config,
    )
    for split, file_stem in DATASETS
}

In [9]:
# Show each split and its class balance. The target column comes from the Data
# object (configured via data_config['target_label']) rather than being
# hard-coded, so this cell stays correct if the configuration changes.
for label, dataset in datasets.items():
    display(dataset.data)
    display(dataset.data[dataset.target_label].value_counts(normalize=True))

Unnamed: 0,text,sentiment
0,Aside from the horrendous acting and the ridic...,negative
1,Can such an ambient production have failed its...,positive
2,Oh-so-familiar comedy story about low-key nice...,negative
3,"In the beginning of this film, one of the comm...",positive
4,"Yes, some plots are a bit hard to follow, and ...",positive
...,...,...
1240,"I have watched 3 episodes of Caveman, and I ha...",negative
1241,"It's a very good movie, not only for the fans ...",positive
1242,"Modern, original, romantic story.Very good act...",positive
1243,The Straight Story is a multilevel exploration...,positive


positive    0.519679
negative    0.480321
Name: sentiment, dtype: float64

Unnamed: 0,text,sentiment
0,"Out of the top 24 lesbian films in my library,...",positive
1,I went to see this movie with a lady freind of...,positive
2,I have to say that I really liked UNDER SIEGE ...,negative
3,I paid attention and enjoyed the very rich exp...,positive
4,"A bad Quentin Tarantino rip off, at least I ho...",negative
...,...,...
1235,A group of obnoxious teens go to a former fune...,positive
1236,This is one of the movies having made signific...,positive
1237,I approached this movie with the understanding...,negative
1238,I watched fantabulosa! because over the last f...,positive


positive    0.504839
negative    0.495161
Name: sentiment, dtype: float64

In [10]:
### Load encoders into memory

In [11]:
# Extra settings merged into each Encoder's config; `layers` presumably selects
# which hidden layer(s) embeddings are taken from -- confirm in modules.Encoder.
encoder_config: dict = dict(layers=[11])

In [12]:
from modules import Encoder

# One Encoder per entry in MODELS, each configured with its model reference
# plus the shared encoder_config entries (which win on key collisions).
encoders: Dict[str, Encoder] = {
    model_label: Encoder({'model': model_ref, **encoder_config})
    for model_label, model_ref in MODELS
}

In [13]:
### Reduce embeddings with MDS, then measure centroid distance and cluster dispersion

In [14]:
import numpy as np
import pandas as pd

In [15]:
from sklearn.manifold import MDS

def manifold_reduction(data: np.ndarray, dim: int = 3, random_state=None) -> np.ndarray:
    """Project the rows of ``data`` to ``dim`` dimensions with scikit-learn's MDS.

    Args:
        data: 2-D array with one sample per row.
        dim: number of output dimensions (``n_components``).
        random_state: forwarded to ``MDS``. MDS starts from a random
            configuration, so pass an int to make embeddings reproducible.
            The default ``None`` keeps the previous unseeded behaviour.

    Returns:
        The embedded samples, shape ``(n_samples, dim)``.
    """
    return MDS(n_components=dim, random_state=random_state).fit_transform(data)

In [16]:
from scipy.spatial import distance

def metric_computation(record: dict, groups: pd.core.groupby.GroupBy,
                       positive: str = 'positive', negative: str = 'negative') -> None:
    """Add per-class centroid/dispersion and the inter-centroid distance to ``record``.

    For every ``(label, group)`` produced by ``groups``, the vectors in
    ``group['reduced_embeds']`` are stacked into a matrix and two keys are set:

    * ``centroid_<label>``: the mean vector, as a Python list;
    * ``dispersion_<label>``: the sum of Euclidean distances over all unordered
      pairs of rows. This is not normalised by group size, so it grows roughly
      with the square of the number of rows in the class.

    Finally, ``distance`` is set to the Euclidean distance between the
    ``positive`` and ``negative`` centroids.

    Args:
        record: dict updated in place.
        groups: grouped frame whose groups carry a ``reduced_embeds`` column of
            equal-length vectors.
        positive: group label of the first centroid used for ``distance``.
        negative: group label of the second centroid used for ``distance``.

    Raises:
        KeyError: if ``positive`` or ``negative`` is not among the group labels.
    """
    for label, group in groups:
        # Stack once and reuse the matrix for both metrics.
        points = np.stack(group['reduced_embeds'].tolist(), axis=0)
        record[f'centroid_{label}'] = points.mean(axis=0).tolist()
        record[f'dispersion_{label}'] = np.sum(distance.pdist(points))

    record['distance'] = distance.cdist(
        [record[f'centroid_{positive}']], [record[f'centroid_{negative}']]
    ).item(0)

In [17]:
from typing import Generator
import torch

def reduce_analyse(data: pd.DataFrame, col: str, dim: list, default_dim: int = 768,
                   target_label: str = 'sentiment') -> Generator:
    """Reduce the stacked embeddings to each size in ``dim`` and yield cluster metrics.

    Args:
        data: frame with one embedding tensor per row in ``col`` and the class
            label in ``target_label``. It is left unmodified.
        col: column holding the per-row embedding tensors. They are combined
            with ``torch.stack``, so they must share a shape (assumed to be 1-D
            vectors of length ``default_dim`` -- TODO confirm against Encoder).
        dim: target dimensionalities, processed in the given order.
        default_dim: size at which the raw embeddings are used as-is instead of
            running MDS.
        target_label: column to group rows by before computing metrics. The
            default matches ``data_config['target_label']`` in this notebook.
            Previously this was read from the global ``dataset`` variable.

    Yields:
        One dict per entry in ``dim``: ``{'dim': d}`` extended in place by
        ``metric_computation``.
    """
    # detach()/cpu() also make this work for tensors that track gradients or live
    # on a GPU; they are harmless for plain CPU tensors.
    embeddings: np.ndarray = torch.stack(data[col].tolist()).detach().cpu().numpy()

    # Group on a small side frame instead of writing a column into ``data``. The
    # caller's frame is reused across encoders and should not be mutated.
    labels: pd.DataFrame = data[[target_label]]

    for d in dim:
        record: dict = {'dim': d}

        # At the encoder's native size, skip MDS and use the raw embeddings.
        if d == default_dim:
            reduced = embeddings
        else:
            reduced = manifold_reduction(embeddings, dim=d)

        reduced_frame = labels.assign(reduced_embeds=list(reduced))
        metric_computation(record, reduced_frame.groupby(target_label))

        yield record

In [18]:
# Columns kept in the saved/displayed summary tables (the centroid_* entries are dropped).
output_cols: list = ['dim', *(f'dispersion_{cls}' for cls in ('positive', 'negative')), 'distance']

In [19]:
results: Dict[str, pd.DataFrame] = dict()  # summary tables keyed by "<split>.<encoder>"

In [20]:
for enc_label, encoder in encoders.items():
    for data_label, dataset in datasets.items():
        # df_encode presumably writes this encoder's embeddings into
        # encoder.col_name on dataset.data (read back just below) -- confirm.
        encoder.df_encode(dataset.data, col=dataset.data_label)
        records = list(reduce_analyse(
            dataset.data,
            encoder.col_name,
            DIMS,
            default_dim=encoder.dim,
        ))
        results[f'{data_label}.{enc_label}'] = pd.DataFrame.from_records(records)

                                                                                                                                                                                                                                                                              

In [21]:
# pandas' to_csv does not create missing directories, so make sure OUT_PATH exists.
os.makedirs(OUT_PATH, exist_ok=True)

# Bind each table as `frame` instead of `dataset`, so this cell does not rebind
# the `dataset` name that earlier cells use for Data objects.
for label, frame in results.items():
    summary = frame[output_cols]
    summary.to_csv(f'{OUT_PATH}/analysis.{label}.csv')
    display(summary)

Unnamed: 0,dim,dispersion_positive,dispersion_negative,distance
0,768,1846608.0,1474544.0,1.191455
1,576,1848007.0,1475781.0,1.104581
2,384,1847754.0,1475565.0,1.115865
3,192,1847403.0,1475146.0,1.117858
4,96,1845973.0,1473852.0,1.167782
5,48,1843396.0,1471407.0,1.218857
6,24,1838040.0,1465957.0,1.328104
7,12,1826143.0,1454249.0,1.517147
8,6,1812401.0,1435127.0,1.248237
9,3,1768678.0,1375959.0,1.019506


Unnamed: 0,dim,dispersion_positive,dispersion_negative,distance
0,768,1682362.0,1571783.0,1.254312
1,576,1683568.0,1572977.0,1.182686
2,384,1683447.0,1572856.0,1.180682
3,192,1682942.0,1572307.0,1.195615
4,96,1681746.0,1571151.0,1.225906
5,48,1679057.0,1568448.0,1.286972
6,24,1674267.0,1563556.0,1.354284
7,12,1661568.0,1550616.0,1.593289
8,6,1634855.0,1520840.0,1.905295
9,3,1603535.0,1468301.0,1.333726


Unnamed: 0,dim,dispersion_positive,dispersion_negative,distance
0,768,2136783.0,1716567.0,7.983203
1,576,2137147.0,1716806.0,7.985439
2,384,2136869.0,1716547.0,7.986923
3,192,2135925.0,1715649.0,7.993095
4,96,2133923.0,1713791.0,8.005571
5,48,2129540.0,1709696.0,8.032312
6,24,2120059.0,1700843.0,8.090835
7,12,2098074.0,1680481.0,8.221065
8,6,2042790.0,1630799.0,8.522049
9,3,1881023.0,1487288.0,9.3137


Unnamed: 0,dim,dispersion_positive,dispersion_negative,distance
0,768,1997442.0,1865335.0,7.293892
1,576,1997749.0,1865593.0,7.296332
2,384,1997455.0,1865311.0,7.298676
3,192,1996618.0,1864474.0,7.304216
4,96,1994685.0,1862614.0,7.317522
5,48,1990453.0,1858556.0,7.346356
6,24,1981440.0,1849680.0,7.40848
7,12,1960951.0,1829719.0,7.54179
8,6,1907166.0,1777770.0,7.878604
9,3,1757716.0,1640257.0,8.683845


Unnamed: 0,dim,dispersion_positive,dispersion_negative,distance
0,768,1841314.0,1778661.0,23.344327
1,576,1841794.0,1779128.0,23.344435
2,384,1841676.0,1779056.0,23.344594
3,192,1841426.0,1778840.0,23.344719
4,96,1840792.0,1778376.0,23.345179
5,48,1839467.0,1777263.0,23.346219
6,24,1836317.0,1774793.0,23.34869
7,12,1828286.0,1768427.0,23.355068
8,6,1805838.0,1749138.0,23.376483
9,3,1707622.0,1662055.0,23.559772


Unnamed: 0,dim,dispersion_positive,dispersion_negative,distance
0,768,2011053.0,2339915.0,19.961293
1,576,2011486.0,2340341.0,19.96168
2,384,2011368.0,2340235.0,19.961977
3,192,2011151.0,2340037.0,19.961979
4,96,2010568.0,2339567.0,19.962522
5,48,2009287.0,2338370.0,19.9638
6,24,2006147.0,2335904.0,19.966276
7,12,1998519.0,2329213.0,19.974889
8,6,1974660.0,2308466.0,20.006246
9,3,1879749.0,2233494.0,20.204602


Unnamed: 0,dim,dispersion_positive,dispersion_negative,distance
0,768,2522510.0,1890411.0,11.64789
1,576,2522973.0,1890731.0,11.650852
2,384,2522782.0,1890579.0,11.651119
3,192,2522103.0,1889882.0,11.653348
4,96,2520431.0,1888266.0,11.660106
5,48,2516740.0,1884904.0,11.674473
6,24,2509036.0,1877356.0,11.702982
7,12,2489279.0,1859478.0,11.784167
8,6,2438433.0,1812340.0,11.991117
9,3,2300045.0,1676438.0,12.581335


Unnamed: 0,dim,dispersion_positive,dispersion_negative,distance
0,768,2327735.0,2037541.0,11.71839
1,576,2328194.0,2037943.0,11.720677
2,384,2327959.0,2037697.0,11.721993
3,192,2327250.0,2036991.0,11.724361
4,96,2325808.0,2035462.0,11.72977
5,48,2322370.0,2032033.0,11.741664
6,24,2314735.0,2024343.0,11.77179
7,12,2295898.0,2005609.0,11.851487
8,6,2244949.0,1957892.0,12.057583
9,3,2108218.0,1825098.0,12.633208
