In [47]:
import typing

import cltrier_nlp as nlp

import numpy as np
import pandas as pd

from datasets import load_dataset

from sklearn.decomposition import PCA

import plotly
import plotly.graph_objs as go

import tqdm
import requests

In [46]:
# Selects the embedding backend; used as a key into the ENCODER mapping
# defined further down (its keys are "local" and "remote").
encoder_type: str = "local"
# Number of rows sliced from the dataset split when building `dataset`.
SAMPLE_SIZE: int = 500

In [39]:
# Draw a reproducible sample of SAMPLE_SIZE rows from the IMDB train split.
# Only the "train" split is requested (the other splits were never used), and
# the shuffle is seeded so that Restart & Run All yields the same rows.
dataset = pd.DataFrame(
    load_dataset("stanfordnlp/imdb", split="train").shuffle(seed=42)[:SAMPLE_SIZE]
)
dataset

Unnamed: 0,text,label
0,"How viewers react to this new ""adaption"" of Sh...",0
1,"**SPOILERS**This was an ugly movie, and I'm so...",0
2,Following the release of Cube 2: Hypercube (20...,0
3,"I loved this movie, I'll admit it. This has to...",0
4,'Ninteen Eighty-Four' is a film about a futuri...,0
...,...,...
495,In my opinion this movie advances no new thoug...,0
496,Daraar got off to a pretty good start. The fir...,0
497,"I don't like using the word ""awful"" to describ...",0
498,This is another of the many B minus movies tag...,0


In [44]:
def local_encoding(data: pd.Series) -> typing.List:
    """Embed every text in ``data`` with the local ``cltrier_nlp`` encoder.

    Args:
        data: Series of raw text strings.

    Returns:
        A list with one NumPy array per input text (the "sent_cls" pooling
        of the encoder output), in the same order as ``data``.
    """
    texts = data.tolist()
    if not texts:
        # Nothing to embed; avoid instantiating the model for an empty input.
        return []

    # The pooler operates on an *encoded* batch, so the texts must first pass
    # through the encoder. The previous code handed the raw text list to the
    # EncoderPooler constructor and pooled a pooler instead of encoder output.
    # NOTE(review): assumes `nlp.encoder.Encoder` is the encoder class of
    # cltrier_nlp and is callable on a list of strings -- confirm against the
    # installed package version.
    encoded_batch = nlp.encoder.Encoder()(texts)
    pooled = nlp.encoder.EncoderPooler()(encoded_batch, form="sent_cls")

    # `.cpu()` is a no-op for CPU tensors and keeps `.numpy()` valid if the
    # model runs on an accelerator.
    return [embed.detach().cpu().numpy() for embed in pooled]

In [45]:
def remote_encoding(data: pd.Series, timeout: float = 30.0) -> typing.List:
    """Embed every text in ``data`` via the remote embedding HTTP service.

    Each value is POSTed as ``{"prompt": value}`` and the ``"response"`` field
    of the JSON reply is converted to a NumPy array. Failures are displayed
    and recorded as ``None`` so one bad request does not abort the whole run.

    Args:
        data: Series of raw text strings.
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        A list aligned with ``data`` (one entry per row, in order) holding a
        NumPy array, or ``None`` where the request failed.
    """
    embeds: typing.Dict[typing.Hashable, typing.Optional[np.ndarray]] = {}

    # Size the progress bar from the argument, not from the global `dataset`.
    for index, value in tqdm.tqdm(data.items(), total=len(data)):

        # An index label that was already processed is not requested again.
        if index in embeds:
            continue

        try:
            response = requests.post(
                'https://inf.cl.uni-trier.de/embed/',
                json={'prompt': value},
                timeout=timeout,
            )
            # Surface HTTP errors explicitly instead of failing later on a
            # missing "response" key.
            response.raise_for_status()
            embed = np.array(response.json()["response"])

        except Exception as _e:
            # Best effort: show the error in the notebook and keep going.
            display(_e)
            embed = None

        embeds[index] = embed

    # The original fell off the end and returned None; return one entry per
    # input row so the result can be assigned as a DataFrame column.
    return [embeds[index] for index in data.index]

In [40]:
# Registry of available embedding backends, keyed by the names that
# `encoder_type` may take.
ENCODER: typing.Dict[str, typing.Callable] = dict(
    local=local_encoding,
    remote=remote_encoding,
)


`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.



In [41]:
# Look up the encoder selected by `encoder_type` in ENCODER, apply it to the
# "text" column and store the result as a new "embeds" column.
dataset["embeds"] = ENCODER[encoder_type](dataset["text"])
# Preview the first rows, including the new column.
dataset.head()

Unnamed: 0,text,label,embeds
0,"How viewers react to this new ""adaption"" of Sh...",0,"[-0.97060066, -1.0717797, -5.1491156, -1.05498..."
1,"**SPOILERS**This was an ugly movie, and I'm so...",0,"[-0.63698316, -1.3964503, -4.581432, -0.651500..."
2,Following the release of Cube 2: Hypercube (20...,0,"[0.26786214, -1.1600063, -4.3904576, -0.463652..."
3,"I loved this movie, I'll admit it. This has to...",0,"[-1.1490418, -0.3227527, -4.8551035, -1.191128..."
4,'Ninteen Eighty-Four' is a film about a futuri...,0,"[0.23239323, -0.74430376, -3.9842298, -1.20341..."


In [42]:
# Reduce every embedding to 3 principal components (one per plot axis) and
# store the result as the "embeds_n3" column.
# `random_state` is fixed because PCA's default solver selection may fall back
# to a randomized SVD depending on the input shape, which would otherwise make
# the projection differ between runs.
dataset["embeds_n3"] = list(
    PCA(n_components=3, random_state=42).fit_transform(dataset["embeds"].tolist())
)
dataset.head()

Unnamed: 0,text,label,embeds,embeds_n3
0,"How viewers react to this new ""adaption"" of Sh...",0,"[-0.97060066, -1.0717797, -5.1491156, -1.05498...","[0.3445591365699556, 2.5830269435245343, 0.129..."
1,"**SPOILERS**This was an ugly movie, and I'm so...",0,"[-0.63698316, -1.3964503, -4.581432, -0.651500...","[2.20437178883263, -0.9610163790752045, -1.773..."
2,Following the release of Cube 2: Hypercube (20...,0,"[0.26786214, -1.1600063, -4.3904576, -0.463652...","[0.28819386307322586, -2.7382544045099246, -2...."
3,"I loved this movie, I'll admit it. This has to...",0,"[-1.1490418, -0.3227527, -4.8551035, -1.191128...","[1.0786834937328944, 3.5832474025486167, -0.08..."
4,'Ninteen Eighty-Four' is a film about a futuri...,0,"[0.23239323, -0.74430376, -3.9842298, -1.20341...","[3.4755524299036824, -1.1269907083641717, -2.7..."


In [43]:
# 3-D scatter of the PCA-reduced embeddings: one trace per value of the
# "label" column, with each point's `text` set to the corresponding review.
plotly.offline.iplot(
    go.Figure(
        data=[
            go.Scatter3d(
                # Components 0, 1 and 2 of each "embeds_n3" entry become x, y, z.
                x=points["embeds_n3"].str[0].tolist(),
                y=points["embeds_n3"].str[1].tolist(),
                z=points["embeds_n3"].str[2].tolist(),
                text=points["text"],
                name=label_value,
                mode='markers',
                marker=dict(size=6, opacity=0.8),
            )
            for label_value, points in dataset.groupby("label")
        ],
        layout=go.Layout(
            margin=dict(l=0, r=0, b=0, t=0),
            height=820,
        ),
    )
)