In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Use this if running this notebook from within its place in the truera repository.
sys.path.insert(0, "../..")

# Install transformers / huggingface.
# !{sys.executable} -m pip install torch
# !{sys.executable} -m pip install transformers
# !{sys.executable} -m pip install mkl
# !{sys.executable} -m pip install vision

# Or otherwise install trulens.
# !{sys.executable} -m pip install git+https://github.com/truera/trulens.git
# ! {sys.executable} -m pip uninstall trulens -y

from IPython.display import display
import torch
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from types import SimpleNamespace
import base64

# Twitter Sentiment Model

[Huggingface](https://huggingface.co/models) offers a variety of pre-trained NLP models to explore. We exemplify in this notebook a [transformer-based twitter sentiment classification model](https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment). Before getting started, familiarize yourself with the general Truera API as demonstrated in the [intro notebook using pytorch](intro_demo_pytorch.ipynb).

In [None]:
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer

# Wrap all of the necessary components.
class TwitterSentiment:
    """Namespace bundling the model, tokenizer and label metadata for this notebook.

    Everything is a class attribute evaluated once when the class body runs;
    the class is used directly (never instantiated) via `task = TwitterSentiment`.
    Evaluating the body loads the pretrained weights and tokenizer named by
    `MODEL`.
    """

    # MODEL = "cardiffnlp/twitter-roberta-base-sentiment"
    MODEL = "distilbert-base-uncased-finetuned-sst-2-english"

    # Use the first CUDA device when one is present, otherwise fall back to the
    # CPU so the notebook still runs on machines without a GPU.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    model = AutoModelForSequenceClassification.from_pretrained(MODEL).to(device)
    tokenizer = AutoTokenizer.from_pretrained(MODEL)

    # Word-embedding matrix as a numpy array on the host.
    #embeddings = model.roberta.embeddings.word_embeddings.weight.detach().cpu().numpy()
    embeddings = model.distilbert.embeddings.word_embeddings.weight.detach().cpu().numpy()

    # Layer name handed to trulens `Cut(...)` to address the embedding layer.
    #embeddings_layer = 'roberta_embeddings_word_embeddings'
    embeddings_layer = 'distilbert_embeddings_word_embeddings'

    # Token string -> vocabulary id, and the inverse mapping.
    id_of_token = tokenizer.get_vocab()
    token_of_id = {v: k for k, v in id_of_token.items()}

    # Human-readable class names, indexed by the model's output position.
    labels = [
        'negative',
        #'neutral',
        'positive'
    ]

    # Output indices of the individual classes.
    NEGATIVE = labels.index('negative')
    #NEUTRAL = labels.index('neutral')
    POSITIVE = labels.index('positive')

def create_gender_baseline(embeddings):
    """Build helpers that pick a replacement ("baseline") token along a fixed direction.

    Args:
        embeddings: 2-D array with one embedding per row. Its width must match
            the length of the decoded `gender_vector`, since the two are
            combined with `np.dot`.

    Returns:
        A `SimpleNamespace` exposing EVERY local name of this function
        (`embeddings`, `gender_vector`, `normalize_many`, `all_embs_norm`,
        `baseline_penalties`, `token_id_baseline`). Because of that, renaming
        or adding locals changes the returned interface.
    """
    # Direction vector stored inline as base85-encoded raw float32 bytes.
    # NOTE(review): going by the names this is presumably a "gender direction"
    # in this model's embedding space — confirm how it was derived.
    gender_vector = np.frombuffer(base64.b85decode(
        b'w0E7n-g=omA8Z@E5|za~=n6Nz<?^#TCW>smy&qV-C2$_S(W<Dr1=8p|Kjs}g<_wTLnu3G97JfH96>*F_ubVAAMbc<J|DClw0*nnk?urb&Z^*Vh_%FUZ%Ym9Y?-)8foj0jHfrd0Z>lD'
        b'sBixnR|m54vPX_HgE*Y6TN8skO0P5i69h@nwEAn<EE?$ZywF(9G5e@P8JD(mh$V>v)Q;m}w-f6>>wUGJ$p3@m88baPa@kA)IE>S7wbTv^XM&9=R~NoS!wrH)9v)#+Tk6dZfK8Qq6Hb'
        b'AQ*o2O8ZxWanNyhf9e)>3KIiuwNcMP{E$O4LJn8-MT+L&AgwyU*L7T`cCq^XOz!80hU6$l74SJaU3AL+Hj>kn$9M@TIf$ZovbfCe6tF@I-(Cfx9L;7<>Uyx4_^tqmVOJpIvC(OIMQj'
        b'opW0+R4{^RcSNs{h{pSol!?@GDx|B^l+vTV`LYqdsic7ye&;T|)QpfeY))$q$uw6Yp1#z~$1B)X$BAmFq3)F<WM5$-I<{nKw(rS-A-GtUWdX?6^a{Up#9>7ICNUCr>d%e;-KqZnq;W'
        b'E%Xv%wU-1}YuAWLl)W42%UlH;Gg{JZfS+JO3rSh8~Z+@7aevyKxk~b}(bTse~uIfzNWgvZQ=G-S7&%U__%hkrf0z=TwM1FE$aovT`sz(mV0IC8&J8Q+%_$B9JG&wgM_WF?|fZ+I1hj'
        b'nxa3w1Wu8?O}vu5bXYRI!h6m<bKNpMqm&0dm}X`>dkE{h4(klP)@ZvuvQ-s5toyq>Hn4ZR{4i@h$d=VS32i02V7DJV+@<`yYD|&6=QWr;AoxwZXrd`SJbx=Z1P;%;hiVGDArSYvcuy'
        b'g`O2unDNnpIZ825?2@htMZm1CE?>%qIcB)W3E=%VmES$~AQOdZj??@7lyJUQ08-X&waL`0^$nQ#`plAE7B=MRiL`AjmsQ&>tpd(j;{TOp`Cdf$ybM@5u9$Sy%VqLq+5KUX3=7u`2KJ'
        b'PQduY(*BmV7F7eHr&rVKq8I1^x^|OlDdeyC%(J91foMd`o)2}Qk!<XC80ySYnYb3Ifnr~Stlhua*QxNW9OhdV*E?IQb6fEV5z`7u9svz8$O9WsKW%k9t`QcD!s71)sSkuUj-07je&~'
        b'2rO1K3Ci81NPImG<m&KR7XU74(i3L<Wx4oOZEWEY60WDrUryMywZDDsi9`71Gb>s28?jXjzPQFAvv2;kgF*)2k<fDH*kD&Uy%rA_*n~}6U*{~bEVR-305e5)F5{;ZZzj%VYjrScpr{'
        b'^_2&F~YwQ~eP>d|R8n<DtI1WbE!dwKoDi+H=ahsINi1jhS1!Z}ot@)}4|&#1YdxZupKpJGKtJ3S0+09N&sPsHeF+bRa%FJ-YThjqB_@2ih||h-9q26GFN>@hS1Sy^%k?CN4g^_unkM'
        b'ykn0#d%)AX9k<53dv0w!s8ZZKnpoaEW<XxOqH>PCh1NqnK=(Slv3p@WBdEl?Lk|Z%L>tOH>vCT^7;lEWPj${aIW4O^EnR**hzz_u!UP$;zgft;f>aMZCh|wU1+<F2)pR{P^fn<qt*?'
        b'W;!vV-VhjGF>{%P8~LCqgMB*7BBeZy`%TZtvTr6v`;4cP;|EG|gBCN0oBM{VUiPFC@}L;_bljob`7ghficiOMHDdCLbqrc|uF>SYPMI=+g$blLGd)9b}Na1;7GEv03>b3tM~MFx~TO'
        b'dbQg27fBM5%Iviv8kWDdiMJ~`w0@gLVCD8v3rKS=~y>C=tW$-3vRkRZmbqOGxw!Enj!N%MO8?>l&$+aH2T*(Z@(ctd__jQgATzv#Nsc#<}q2kSnNZ+&6JzG><T|TO)=g)c|hwsl8rg'
        b'L@E(S|zTn}zdZ6w*%a$fRbWmA5*=QO)RA?eRDM+3>44m*h7uh$wB&e7@oIeD;R^}7EC9D2CU%xXwdBS@=#>?$I_=2;%Q=*nV3BugH3#%Qye@gK@I+oizM7q7bj|jy)g6>+ol6yQoH|'
        b'0dUfUjRYTDQu)ymgAacMS(UTx2i2Y9zuvw+pL0<?K(rP)?&eNeN=TrE&B;G75=3x&7C?Z~IoffIw(G2pyrk1vu@zru-5;gR2(2y6#B4ouq<1G4e3IsIbMn_Lzmdrukbvy$O%I{7sWQ'
        b'8otW9#)4bCO#_3wZ@qWCRoE=O>_1&S9)_(t7GmN&3Ev<+wPVmc1%24OfUJN#AgK~OV9_<atBn&pA_@b&MjbxAlyU7m3*T})fC$sPeEnm+a$ttMz_RK*Hkw&Il}$1|SL#|k={gNOLT7'
        b'%w`NXa~x3^KfkOtH|2zg9BJTpbS^PN*Y0+y9Lk+I#o!V@b!7!`QCT+yApL&2*&<?fU`bR*BYVHPMoa7GinRo11vGQ+IAN(IupGtm~ktf^}}^X;WPp+Vq01>F=puO!zz^GSodXq*o{$'
        b'154U9ank021r*tKDoNPvz2<hAv3JJtX-BoPzcXEr`#AkASot1TD1T?uJNtBC*k(I@OIj~LD1Da2Im*NPR0Scbs{Of%5&ek?qtclU`yyc<|isV=H>~#w{=>*<okEMbb}W?k6DF1K;2Z'
        b'lI(WRh5N`Xtr6vPC4)ip=lVs|=HJyGtd#-XloS3FOjSjE8o{A~EOqLYA1EzVsmRB%6mlU8pGpU9>BJW~7s$zvatYA{Sf9Ik+LO-ZHOzU2~gCusmXzQ*#E0_$rUvF<b{S-XC-^Io~^N'
        b'1k5!qC{fSstD|NI`Nu6%2|!EqzD5$%qL&+s`6A7cF7E2TeJ>RM*=*QSr&V2w8PKsZ!Lu*jN2L{Z6Pn9DT{W2z3^{s|nw`yYjO<;zqnYjSEM;>H$$bZwLfE5;l-Lw*CFP)x3>7v9`dx'
        b'0~}mDDz(-;4pzgx>+>l+=xi80Z2w+8-BYEzGZ;@iQ)+;{syNd;$kJK8dWQ2nIf(u|&kuw=1!5LF>~X(6Yfh59njrN&DWo{PlNHc8P6ZA;;sMw^575RtL@^vZAZK*EJ@lQuw1M)x+I)'
        b'pP@?qAy%i37HVieFiiwZ40=m$SN;s72zheC<GmBJsra?Gf_e%Uv@b`YmKI<MF~qY!*LRr?k^9~42n-`Ao&5x<$ds^2y}qF3WQFq=_4YN8Ok6zu@KQ7sC+o@|c1CX7Nn<~;{Hem?oU3'
        b'+vH5nqP}OU=7SXvVnQMejpybwUH1#FqI;_F3mnX@0J-os7qWuj3-b$%vIn#hm+mB4oCDnW!nI~Ep3av?9#P7d(z`Px^N=BR*6@=iA{&RpM-fln~sM)+hLcy{EU!2>`|mVw}V1GBhEF'
        b'w8mSLGAZ(<(l+~O(WyxebA3^p!rg5FSEIE=q3gEas(_qFt9_wR0+*px3p!s<`)Wohk$uLMfLtkw?@2lawIkhCb5Ht+E1-tG%4I)0hu`zBumr{v6d6IoSw?|LC$PqNV5-X^?>Ub_aUQ'
        b'SXulh5G1S+`6*^bK0Q)eIB8U9A&6&8AX4o>;ZKHtz_%OxH2Iq^sGxJ2V8nnv&AI$#;Z3x07$YPEql@Gg_rRHM86~!#x+hd#aqhj!}j?!U~kUN3{q&Oy5_%o*fcBXZZEIb$UuWZ;`OP'
        b'Yfi<y0$!xNJP{GSuB4s3wzAp0$6}tk013#vggP#~z(GDeWs4O(iDoi9cL!U&!?|`mS&Q+!HfYYgXbGviTSf!DyMxZVvWCz+$Z55_gMT1B-f<{Bwfeq1>%<DZ!qw_K-mEe`Sw*<Kd2C'
        b'WWrcBp7IArs@R+qNBDJ~1Ws)h%=0yTp>H2O+CfcnS0&#2kEh|B^!tJ@Gg=M$7XAEe|vFt%+xea3&iN)mawdt)iR&2*(brfzw??jSw8c@L;OV&nilmzJBnT?<&ePBw_WoK+G%qjXfg>'
        b'lFyRNg75xsP4YIU=dwC{#G@-TTw?n*N8j4t<W94FJxxDEEjUUD!i*aI2t&**R)|gLWuml<tGrmLkA)~f$gh3hQgdYn`_@ZQF;WrhH1OJB(Wg8uQsJS'
    ), dtype='float32')

    #def normalize(v):
    #    return v / np.linalg.norm(v, ord=2)

    def normalize_many(v):
        # Scale each row of the 2-D array `v` to unit L2 norm.
        return v / np.linalg.norm(v, axis=1, ord=2)[:, np.newaxis]

    # Row-normalized embeddings and, per vocabulary entry, the magnitude of its
    # projection onto gender_vector (used below as a penalty).
    all_embs_norm = normalize_many(embeddings)
    baseline_penalties = np.abs(np.dot(all_embs_norm, gender_vector))

    def token_id_baseline(token_id):
        # For every vocabulary entry, score the row-normalized difference
        # between the selected token's embedding and that entry's embedding by
        # the magnitude of its projection onto gender_vector, subtract
        # 0.55 * the entry's own penalty, and return the index of the best
        # score. The tiny additive constant presumably keeps the norm non-zero
        # for the row where the difference is exactly zero (the token itself).
        return np.argmax(
            np.abs(
                np.dot(
                    normalize_many(all_embs_norm[token_id] - all_embs_norm +
                                   0.000000001), gender_vector)) -
            0.55 * baseline_penalties)

    # Expose all locals (arrays and helper functions) as attributes.
    return SimpleNamespace(**locals())

task = TwitterSentiment


In [None]:
# Build the baseline helpers from the model's embedding matrix and keep only
# the token lookup function.
token_id_baseline = create_gender_baseline(task.embeddings).token_id_baseline

# Quick check: find the baseline token for 'poor' (its id is passed as a
# one-element list) and show it as a token string.
test1 = token_id_baseline([task.id_of_token['poor']])
task.token_of_id[test1]

This model quantifies tweets (or really any text you give it) according to their sentiment: positive, negative, or neutral. (With the model configured above, only the positive and negative labels are in use.) Let's try it out on some examples.

In [None]:
# Example texts reused by the evaluation and visualization cells below.
sentences = ["I'm so happy!", "I'm so sad!", "I cannot tell whether I should be happy or sad!", "meh"]

# Input sentences need to be tokenized first; the result is moved to the same
# device as the model.

inputs = task.tokenizer(sentences, padding=True, return_tensors="pt").to(task.device) # pt refers to pytorch tensor

# The tokenizer gives us vocabulary indexes for each input token (in this case,
# words and some word parts like the "'m" part of "I'm" are tokens).

print(inputs)

# Decoding helps to inspect the tokenization produced:

print(task.tokenizer.batch_decode(torch.flatten(inputs['input_ids'])))
# Normally decode would give us a single string for each sentence but we would
# not be able to see some of the non-word tokens there. Flattening first gives
# us a string for each input_id.

Evaluating huggingface models is straightforward if we use the structure produced by the tokenizer.

In [None]:
# Run the model on the tokenized batch; the tokenizer output is unpacked as
# keyword arguments.
outputs = task.model(**inputs)

print(outputs)

# From logits we can extract the most likely class for each sentence and its readable label.

predictions = [task.labels[i] for i in outputs.logits.argmax(axis=1)]

# Print raw logits (moved to the CPU as a numpy array), the predicted label and
# the sentence, one line per input.
for sentence, logits, prediction in zip(sentences, outputs.logits, predictions):
    print(logits.to('cpu').detach().numpy(), prediction, sentence)

# Model Wrapper

As in the prior notebooks, we need to wrap the pytorch model with the appropriate Trulens functionality. Here we specify the maximum input size (in terms of tokens) each tweet may have.

In [None]:
from trulens.nn.models import get_model_wrapper
from trulens.nn.quantities import ClassQoI
from trulens.nn.attribution import IntegratedGradients
from trulens.nn.attribution import Cut, OutputCut
from trulens.utils.typing import ModelInputs

task.wrapper = get_model_wrapper(task.model, device=task.device)

# Attributions

In [None]:
# Integrated-gradients attribution for the POSITIVE class. No baseline is
# given here; a variant with an explicit baseline is built further below.
infl = IntegratedGradients(
    model = task.wrapper,
    # Attribute with respect to the word-embedding layer named on the task.
    # doi_cut=Cut('roberta_embeddings_word_embeddings'),
    doi_cut=Cut(task.embeddings_layer),
    # Quantity being explained: the output at index task.POSITIVE ...
    qoi=ClassQoI(task.POSITIVE),
    # ... read from the 'logits' entry of the model output.
    qoi_cut=OutputCut(accessor=lambda o: o['logits'])
)

A listing as above is not very readable so Trulens comes with some utilities to present token influences a bit more concisely. First we need to set up a few parameters to make use of it:

In [None]:
from trulens.visualizations import NLP

def distilibert_token(x):
    """Decode `x` with the task tokenizer and format it for display.

    Pieces that start with "##" (a word continuation) are returned without the
    marker so they attach to the preceding piece; everything else is prefixed
    with a single space to mark the start of a new word.
    """
    piece = task.tokenizer.decode(x)
    is_continuation = piece.startswith("##")
    return piece[2:] if is_continuation else " " + piece

# Visualizer configured for the wrapped model; `V.tokens` renders per-token
# influences for a list of sentences.
V = NLP(
    wrapper=task.wrapper,
    labels=task.labels,
    #decode=lambda x: task.tokenizer.decode(x),
    decode=distilibert_token,
    tokenize=lambda sentences: ModelInputs(kwargs=task.tokenizer(sentences, padding=True, return_tensors='pt')).map(lambda t: t.to(task.device)),
    # huggingface models can take as input the keyword args as per produced by their tokenizers.
    # Each tensor in the inputs is moved to the task device.

    input_accessor=lambda x: x.kwargs['input_ids'],
    # for huggingface models, input/token ids are under input_ids key in the input dictionary

    output_accessor=lambda x: x['logits'],
    # and logits under 'logits' key in the output dictionary

    hidden_tokens=set([task.tokenizer.pad_token_id])
    # do not display these tokens
)

# Show the token influences on the POSITIVE class for the example sentences.
print("QOI = POSITIVE")
display(V.tokens(sentences, infl))

# Baselines

We see in the above results that special tokens such as the sentence end **&lt;/s&gt;** are found to contribute a lot to the model outputs. While this may be useful in some contexts, we are more interested in the contributions of the actual words in these sentences. To focus on the words more, we need to adjust the **baseline** used in the integrated gradients computation. By default in the instantiation so far, the baseline for each token is a zero vector of the same shape as its embedding. By making the baseline identical to the explained instances on special tokens, we can remove their impact from our measurement. Trulens provides a utility for this purpose in the form of `token_baseline`, which constructs for you the methods to compute the appropriate baseline.

In [None]:
from trulens.utils.nlp import token_baseline

# token_baseline returns two callables: the first yields baseline token ids,
# the second the corresponding baseline embeddings.
inputs_baseline_ids, inputs_baseline_embeddings = token_baseline(
    keep_tokens=set([task.tokenizer.cls_token_id, task.tokenizer.sep_token_id]),
    # Which tokens to preserve.

    replacement_token=task.tokenizer.pad_token_id,
    # What to replace tokens with.

    input_accessor=lambda x: x.kwargs['input_ids'],
    # Where to find the token ids inside the model inputs.

    ids_to_embeddings=task.model.get_input_embeddings()
    # Callable to produce embeddings from token ids.
)

We can now inspect the baselines on some example sentences. The first method returned by `token_baseline` gives us token ids to inspect while the second gives us the embeddings of the baseline which we will pass to the attributions method.

In [None]:
# Original inputs decoded back to text, one string per sentence.
print("originals=", task.tokenizer.batch_decode(inputs['input_ids']))

# Baseline token ids for the same tokenized batch, decoded the same way for
# side-by-side comparison.
baseline_word_ids = inputs_baseline_ids(model_inputs=ModelInputs(args=[], kwargs=inputs))
print("baselines=", task.tokenizer.batch_decode(baseline_word_ids))

In [None]:
# Same attribution setup as `infl`, but with the token-aware baseline
# embeddings and an explicit resolution of 50.
infl_positive_baseline = IntegratedGradients(
    model = task.wrapper,
    resolution=50,
    baseline = inputs_baseline_embeddings,
    doi_cut=Cut(task.embeddings_layer),
    qoi=ClassQoI(task.POSITIVE),
    qoi_cut=OutputCut(accessor=lambda o: o['logits'])
)

# Render the influences again for comparison with the earlier output.
print("QOI = POSITIVE WITH BASELINE")
display(V.tokens(sentences, infl_positive_baseline))

In [None]:
from datasets import load_dataset

# Train and test splits of the two review corpora compared in the cells below.
imdb_train = load_dataset("imdb", "plain_text", split="train")
imdb_test = load_dataset("imdb", "plain_text", split="test")
rotten_train = load_dataset("rotten_tomatoes", split="train")
# rotten_test is loaded but not used in the visible cells.
rotten_test = load_dataset("rotten_tomatoes", split="test")

In [58]:
def tokenize(portion):
    """Encode a batch of texts with the task tokenizer and return the token ids.

    Special tokens are added, each text is truncated to at most 512 tokens and
    no attention mask is requested. Returns the 'input_ids' entry of the
    tokenizer output (one id sequence per text).
    """
    encoded = task.tokenizer.batch_encode_plus(
        portion,
        truncation=True,
        max_length=512,
        add_special_tokens=True,
        return_attention_mask=False,
    )
    return encoded['input_ids']

import os
# NOTE(review): presumably turns off the tokenizers library's own parallelism
# so it does not clash with the process pool below — confirm.
os.environ['TOKENIZERS_PARALLELISM'] = '0'

import multiprocessing as mp
# Notebook-level pool of 24 worker processes used by toks_of_texts. It is
# created after `tokenize` is defined and is never closed in the visible cells.
p = mp.Pool(24)

def toks_of_texts(texts):
    """Tokenize `texts` in parallel and return all token ids as one flat array.

    The texts are split into chunks of 1000, each chunk is passed to
    `tokenize` on the notebook-level pool `p`, and the resulting id sequences
    are concatenated in order.

    Args:
        texts: sliceable sequence of strings.

    Returns:
        1-D integer numpy array with the token ids of every text.
    """
    chunk_size = 1000
    # Step by chunk_size up to len(texts) so the final partial chunk is kept.
    # Flooring len(texts) // chunk_size would drop up to 999 trailing texts,
    # and everything when there are fewer than 1000 texts.
    chunks = [
        texts[start:start + chunk_size]
        for start in range(0, len(texts), chunk_size)
    ]
    ids_per_chunk = p.map(tokenize, chunks)

    # Flatten chunk -> text -> token id. The explicit dtype keeps the result
    # integer even when there are no tokens at all.
    return np.array(
        [token_id for chunk in ids_per_chunk for ids in chunk for token_id in ids],
        dtype=np.int64,
    )

def dists_of_texts(texts):
    """Count token occurrences in `texts` and turn them into relative frequencies.

    Args:
        texts: sequence of strings, tokenized via `toks_of_texts`.

    Returns:
        Tuple `(counts, dist)` of float arrays indexed by token id, of length
        `task.tokenizer.vocab_size`. `counts[i]` is how often id `i` occurs
        and `dist` is `counts` divided by the total number of tokens. If there
        are no tokens, `dist` is all zeros rather than the result of 0/0.
        Should an id fall outside `vocab_size`, the arrays grow to fit it.
    """
    token_ids = toks_of_texts(texts)
    total = len(token_ids)

    # One vectorized pass instead of a Python-level loop over every token.
    # The cast covers an empty float array; the final astype keeps the float
    # dtype that np.zeros-based counting produced.
    counts = np.bincount(
        token_ids.astype(np.int64),
        minlength=task.tokenizer.vocab_size,
    ).astype(np.float64)

    dist = counts / total if total else np.zeros_like(counts)

    return counts, dist

def tops_of_texts(texts, n = 10):
    """Compute token counts/frequencies for `texts` and summarize their extremes.

    Shorthand for feeding the `(counts, dist)` pair from `dists_of_texts`
    into `tops_of_dists` with the same `n`.
    """
    return tops_of_dists(*dists_of_texts(texts), n=n)

def tops_of_dists(c, d, n=10):
    """Summarize the two extremes of `d`, with everything else aggregated.

    Args:
        c: per-token values (counts or count differences), indexed by token id.
        d: per-token values (frequencies or frequency differences) that
            determine the ordering.
        n: number of entries to list at each end.

    Returns:
        List of `(token_id, c_value, d_value, text)` tuples: the `n` entries
        with the smallest `d` in ascending order, then two aggregate rows with
        id -1 and text "*" (first the sums over remaining entries whose `c` is
        not >= 0, then over those whose `c` is >= 0), then the `n` entries
        with the largest `d` in ascending order.
    """
    order = np.argsort(d)

    def row(token_id):
        return (token_id, c[token_id], d[token_id], task.tokenizer.decode(token_id))

    lowest = [row(token_id) for token_id in order[0:n]]
    highest = [row(token_id) for token_id in order[-n:]]

    # Running [c_sum, d_sum] for the middle of the ordering, split by whether
    # the c value is non-negative. Accumulated one entry at a time.
    middle = {True: [0, 0], False: [0, 0]}
    for token_id in order[n:-n]:
        sums = middle[bool(c[token_id] >= 0)]
        sums[0] += c[token_id]
        sums[1] += d[token_id]

    aggregates = [
        (-1, middle[False][0], middle[False][1], "*"),
        (-1, middle[True][0], middle[True][1], "*"),
    ]

    return lowest + aggregates + highest

In [59]:
# Token counts and relative frequencies of the two training corpora.
c1, d1 = dists_of_texts(imdb_train['text'])
c2, d2 = dists_of_texts(rotten_train['text'])
# The 20 tokens at each end of the frequency difference (imdb minus rotten
# tomatoes), plus the two aggregate "*" rows for everything in between.
top = tops_of_dists(c1 - c2, d1 - d2, n = 20)
# Sanity check: the rows together cover every token id, so the frequency
# differences should sum to approximately zero.
sum([t[2] for t in top])

-4.6058643010660205e-14

In [None]:
from matplotlib import transforms as trans

def plotdist(top, l1, l2):
    """Plot the token frequencies of two corpora and their difference.

    Reads the notebook-level arrays `d1` and `d2` for the per-corpus
    frequencies. Draws a two-row figure: the top row shows side-by-side bars
    for both corpora, the bottom row the difference stored in the rows of
    `top`.

    Args:
        top: rows of `(token_id, count, prob_diff, text)` as produced by
            `tops_of_dists`; rows with a negative id are aggregates.
        l1: legend label for `d1`.
        l2: legend label for `d2`.
    """
    fig, (ax_prob, ax_diff) = plt.subplots(2,1, figsize=(8,6))

    slots = range(len(top))
    centers = np.arange(len(top))
    token_ids = [row[0] for row in top]

    def heights_from(dist):
        # Aggregate rows (negative id) have no single frequency; draw them as 0.
        return [dist[token_id] if token_id >= 0 else 0 for token_id in token_ids]

    # Top row: per-corpus frequencies, labelled with the token text.
    ax_prob.bar(x=centers + 0.2, width=0.4, height=heights_from(d1), alpha=1.0, color="green", label=l1)
    ax_prob.bar(x=centers - 0.2, width=0.4, height=heights_from(d2), alpha=1.0, color="blue", label=l2)
    ax_prob.set_xticks(ticks=slots, labels=[row[3] for row in top], rotation=90)
    ax_prob.set_ylabel("prob.")

    # Bottom row: frequency difference, labelled with the token id.
    ax_diff.bar(x=slots, width=0.8, height=[row[2] for row in top], alpha=1.0, color="black", label=f"{l1} - {l2}")
    ax_diff.set_xticks(ticks=slots, labels=[str(token_id) if token_id >= 0 else "*" for token_id in token_ids], rotation=90)
    # Thin zero line across the difference plot.
    ax_diff.plot([0, len(top) - 0.6], [0, 0], lw=1, color='black')
    ax_diff.set_ylabel("prob. diff")

    fig.legend()
    fig.tight_layout()

plotdist(top, l1='imdb', l2='rotten tomatoes')

In [None]:
# Same comparison, now between the IMDB train and test splits.
# Note: this rebinds the notebook-level c1/d1/c2/d2 and `top` set by the
# earlier imdb-vs-rotten-tomatoes cell.
c1, d1 = dists_of_texts(imdb_train['text'])
c2, d2 = dists_of_texts(imdb_test['text'])
top = tops_of_dists(c1 - c2, d1 - d2, n = 20)
# top

In [None]:

def plotdist(d1, d2, top, l1, l2):
    """Plot the token frequencies of two corpora and their difference.

    Replaces the earlier `plotdist` defined in this notebook; this version
    takes the two frequency arrays as arguments instead of reading globals.

    Args:
        d1: per-token frequencies of the first corpus, indexed by token id.
        d2: per-token frequencies of the second corpus, indexed by token id.
        top: rows of `(token_id, count, prob_diff, text)` as produced by
            `tops_of_dists`; rows with a negative id are aggregates.
        l1: legend label for `d1`.
        l2: legend label for `d2`.
    """
    fig, axes = plt.subplots(2,1, figsize=(8,6))
    upper, lower = axes[0], axes[1]

    count = len(top)
    ticks = range(count)
    ids = [entry[0] for entry in top]

    # Top row: grouped bars per corpus; aggregate rows (negative id) are
    # drawn as 0.
    for dist, shift, colour, name in ((d1, 0.2, "green", l1), (d2, -0.2, "blue", l2)):
        upper.bar(
            x=np.arange(count) + shift,
            width=0.4,
            height=[dist[i] if i>=0 else 0 for i in ids],
            alpha=1.0,
            color=colour,
            label=name,
        )
    upper.set_xticks(ticks=ticks, labels=[entry[3] for entry in top], rotation=90)
    upper.set_ylabel("prob.")

    # Bottom row: frequency difference, with a thin zero line on top.
    lower.bar(x=ticks, width=0.8, height=[entry[2] for entry in top], alpha=1.0, color="black", label=f"{l1} - {l2}")
    lower.set_xticks(ticks=ticks, labels=[str(i) if i>=0 else "*" for i in ids], rotation=90)
    lower.plot([0,count - 0.6], [0,0], lw=1, color='black')
    lower.set_ylabel("prob. diff")

    fig.legend()
    fig.tight_layout()

plotdist(d1=d1, d2=d2, top=top, l1='imdb train', l2='imdb test')

In [None]:
# toks = torch.tensor(toks_of_texts(imdb_train['text'])).to("cpu")
#task.model.roberta.embeddings.word_embeddings(toks)

In [None]:
# embs = task.model.roberta.embeddings.word_embeddings.to("cpu")(toks.to("cpu"))

In [None]:
# embs = embs.detach().numpy()

In [None]:
# plt.hist2d(embs[:,0], embs[:,1])

In [None]:
%matplotlib agg
#from sklearn.manifold import TSNE
from tsnecuda import TSNE as TSNE
# plt.ioff()
# Where t-SNE snapshots are written: scatter plots under images/, raw
# embeddings under data/. Set the TSNE_VIS_DIR environment variable to
# override the original author's machine-specific default.
# Note: the name `dir` shadows the builtin; it is kept because it is a
# notebook-level name.
import os

dir = Path(os.environ.get("TSNE_VIS_DIR", "/home/piotrm/vistest"))
dir_images = dir / "images"
dir_data = dir / "data"

# Create both directories up front so the later iterdir() and savefig()/dump()
# calls do not fail on a fresh machine.
for _subdir in (dir_images, dir_data):
    _subdir.mkdir(parents=True, exist_ok=True)

from IPython.display import clear_output, display
# from MulticoreTSNE import MulticoreTSNE as TSNE
# import umap 

In [None]:
# Fetch the word-embedding layer through the model-agnostic accessor. The
# configured DistilBERT model has no `.roberta` attribute, so the previous
# hard-coded `task.model.roberta...` path fails with the current MODEL.
embedder = task.model.get_input_embeddings()
# Full embedding matrix as a numpy array on the host.
all_embs = embedder.weight.detach().cpu().numpy()#[0:50000]
len(all_embs)

In [None]:
import re
# Matches snapshot files named like ".../tsne00012.npy" and captures the
# iteration number.
match_iter = re.compile(r'^.*tsne([0-9]+)\.npy')
# man = TSNE(n_jobs=48)#(n_iter=250)
#man = umap.UMAP(n_epochs=1)
# Embedder configured with n_iter=1; do_iter() calls it once per step.
man = TSNE(n_iter=1)
# Resume state: highest iteration seen so far and its embedding (None when no
# snapshot with an iteration number above 0 exists).
iters = 0
emb = None
# dir(man)
# man.fit(all_embs)
# Scan the data directory for the snapshot with the highest iteration number.
for file in dir_data.iterdir():
    matches = match_iter.fullmatch(str(file))
    if matches:
        matched_iter = int(matches.group(1))
        if matched_iter > iters:
            iters = matched_iter
            # NOTE(review): allow_pickle=True can execute arbitrary code from
            # the file; only load snapshots written by this notebook (they are
            # produced with ndarray.dump, which pickles).
            emb = np.load(file, allow_pickle=True)

print(f"last iter loaded: {iters}")

In [None]:
def do_iter():
    """Advance the embedding by one `man.fit_transform` call.

    Rebinds the notebook-level `emb` to the result of fitting `all_embs`
    (passing the current `emb` as `y`) and increments `iters` by one.
    """
    global emb, iters
    emb, iters = man.fit_transform(X=all_embs, y=emb), iters + 1

def do_show():
    """Display the current embedding and save it to disk.

    Reads the notebook-level `emb`, `iters`, `dir_images` and `dir_data`.
    Replaces the cell output with a scatter plot of the first two columns of
    `emb`, saves the plot as `tsne<iters>.png` and the array as
    `tsne<iters>.npy`. Does nothing while no embedding exists yet, i.e. no
    snapshot was loaded and `do_iter` has not run.
    """
    if emb is None:
        # Nothing to draw before the first iteration; iterating over None
        # would raise a TypeError.
        return

    clear_output(wait=True)
    fig, ax = plt.subplots(1,1, figsize=(10,10))
    ax.scatter(x=emb[:, 0], y=emb[:, 1], s=1)
    display(fig)

    fig.savefig(dir_images / f"tsne{iters:05}.png")
    emb.dump(dir_data / f"tsne{iters:05}.npy")

    # Release the figure; otherwise every call in the driver loop leaves
    # another open figure in memory.
    plt.close(fig)

do_show()

In [None]:
# Driver loop: save/show the current snapshot, then advance the embedding by
# one fit, 1000 times. The final embedding is not shown or saved by this loop
# because do_show runs before do_iter.
for i in range(1000):
    do_show()
    do_iter()    

In [None]:
plt.show()

In [None]:
# NOTE(review): in the visible cells `fig` is only ever bound inside functions
# (plotdist, do_show), so this relies on a notebook-level `fig` defined
# elsewhere — confirm before a Restart & Run All.
display(fig)

In [None]:
import tsnecuda

In [None]:
tsnecuda.test()