In [1]:
from ast import literal_eval
from configparser import ConfigParser
from io import BytesIO

import click
import comet_ml
import torch
from easydict import EasyDict

import pandas as pd

In [3]:
import sys

# Make the project's ``src`` directory importable. The membership guard keeps
# the cell idempotent: re-running it no longer appends duplicate entries.
if './src' not in sys.path:
    sys.path.append('./src')

In [4]:
from src.models.protoconv.lit_module import ProtoConvLitModule
from src.models.protoconv.train import get_dataset

In [5]:
# Index of the cross-validation fold to inspect; selects the row of the split
# table and is forwarded to the model as ``fold_id``.
fold = 0

# Checkpoint file name; it is joined onto ``checkpoints/`` when the model is loaded.
weights_path = 'fold_0_epoch=16-val_loss_0=0.4415-val_acc_0=0.8268.ckpt'

# Comet.ml experiment key. NOTE: the cell that connects to Comet rebinds this
# name to the fetched experiment object, so re-run this cell before re-running
# that one.
experiment = '51f3d18dcf9a4b99b3a7b73f6989382b'

In [6]:
config = ConfigParser()
# ``ConfigParser.read`` silently skips files it cannot open and returns the list
# of files it actually parsed. Fail loudly here rather than with an opaque
# ``KeyError: 'cometml'`` a few lines below.
if not config.read('config.ini'):
    raise FileNotFoundError(
        "config.ini was not found or could not be read; "
        "it must provide a [cometml] section"
    )

# Credentials and project coordinates come from the [cometml] section, so no
# secret is hard-coded in the notebook.
comet_config = EasyDict(config['cometml'])
comet_api = comet_ml.api.API(api_key=comet_config.apikey)

# NOTE: ``experiment`` enters this cell as the experiment key (a string) and
# leaves it as the object returned by ``comet_api.get``. Re-running this cell
# without first re-running the parameters cell passes that object back in as
# the key.
experiment = comet_api.get(project_name=comet_config.projectname, workspace=comet_config.workspace,
                           experiment=experiment)

In [7]:
# Name of the data set, read from the experiment's 'data_set' parameter.
dataset = experiment.get_parameters_summary('data_set')['valueCurrent']

# Find the k-fold split file among the experiment's assets and download it.
split_assets = [
    asset
    for asset in experiment.get_asset_list()
    if asset['fileName'] == 'kfold_split_indices.csv'
]
kfold_split_id = split_assets[0]['assetId']
kfold_split_binary = experiment.get_asset(kfold_split_id, return_type="binary")

# Keep only the row that belongs to the fold under inspection.
kfold_split_table = pd.read_csv(BytesIO(kfold_split_binary))
kfold_split = kfold_split_table.iloc[fold]

# The index columns hold Python-literal strings; parse them back into objects.
train_index = literal_eval(kfold_split['train_indices'])
val_index = literal_eval(kfold_split['val_indices'])

# Rebuild the train/validation frames from the positional indices.
df_dataset = pd.read_csv(f'data/{dataset}/data.csv')
train_df = df_dataset.iloc[train_index]
valid_df = df_dataset.iloc[val_index]

# batch_size=1 so every batch yielded by the loaders is a single example.
TEXT, LABEL, train_loader, val_loader = get_dataset(train_df, valid_df, batch_size=1, cache=None)

# Restore the trained model, sizing it from the vocabulary built above.
checkpoint_file = 'checkpoints/' + weights_path
model = ProtoConvLitModule.load_from_checkpoint(
    checkpoint_file,
    vocab_size=len(TEXT.vocab),
    embedding_dim=TEXT.vocab.vectors.shape[1],
    lr=1,
    fold_id=fold,
)

  torch.nn.init.xavier_uniform(self.prototypes.data)


In [82]:
import torch.nn.functional as F

# Index of the prototype whose closest training example is searched for.
# Kept in one place instead of being repeated as a bare ``5`` in the loop.
PROTOTYPE_ID = 5

# Use the GPU when one is present, otherwise fall back to the CPU so the cell
# still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.eval()
model.to(device)

# (per-position distances for PROTOTYPE_ID, token ids) of the best example so far.
best_prototype_0 = None
best_distance = float('inf')

with torch.no_grad():
    for X, batch_id in train_loader:  # second element of the batch is unused here
        X = X.to(device)  # keep the input on the same device as the model

        temp = model.embedding(X).permute((0, 2, 1))
        temp = model.conv1(temp)
        temp = model.prototype(temp)
        prototypes_distances = temp.squeeze(0)  # [Prototype, len(X)-4]

        # Negating, max-pooling over the full length and negating again yields
        # the minimum along the last axis, i.e. the smallest distance per prototype.
        temp = -F.max_pool1d(-temp, temp.size(2))
        best_distances = temp.view(temp.size(0), -1).squeeze(0)  # [Prototype]

        # Pull out a single scalar instead of converting the whole tensor to
        # numpy several times per batch.
        candidate_distance = best_distances[PROTOTYPE_ID].item()
        if candidate_distance < best_distance:
            best_distance = candidate_distance
            best_prototype_0 = (
                prototypes_distances[PROTOTYPE_ID].cpu().numpy(),
                X.cpu().numpy(),
            )

# Later cells index into ``best_prototype_0``; surface an empty search here
# rather than as a ``TypeError`` on ``None`` further down.
if best_prototype_0 is None:
    raise RuntimeError(
        f"no training example produced a finite distance for prototype {PROTOTYPE_ID}"
    )


In [86]:
# best_prototype_0 is (per-position distances, token ids) of the selected example.
position_distances = best_prototype_0[0].ravel()
token_ids = best_prototype_0[1].ravel()

# Map token ids back to their vocabulary strings.
words = [TEXT.vocab.itos[j] for j in token_ids]

# The distance sequence is shorter than the token sequence. Pad it with zeros
# on both sides so each weight lines up with a token. The padding is derived
# from the actual lengths rather than hard-coded; when the difference is 4
# this reproduces the previous [0, 0] ... [0, 0].
pad_total = len(words) - len(position_distances)
assert pad_total >= 0, "more distance positions than tokens"
pad_left = pad_total // 2
pad_right = pad_total - pad_left
weights = [0] * pad_left + list(position_distances) + [0] * pad_right

In [87]:
# Show the tokens of the selected training example.
print(words)

['<START>', 'poor', 'agatha,', 'poor', 'reader....:', "i've", 'read', 'every', 'book', 'in', 'the', 'series', 'as', 'well', 'as', 'the', 'h.', 'macbeth', 'series.i', "haven't", 'finished', 'this', 'title', 'yet', 'but', 'agree', 'with', 'other', 'reviewers', 'regarding', 'the', 'lack', 'of', 'proofreading.', 'i', 'was', 'even', 'stymied', 'as', 'to', 'why', 'this', 'book', 'would', 'be', 'listed', 'under', '"women-detectives-england-norfolk', 'and', 'not', 'glous.', 'or', 'warwickshire."a.r.', 'and', 'the', 'love', 'from', 'hell', 'is', 'a', 'bit', 'darker', 'than', 'the', 'others', 'in', 'the', 'a.r.', 'series', 'and', 'so', 'far', '(up', 'to', 'p.', '172)', 'there', 'is', 'no', 'humor', 'at', 'all.', '<END>']


In [88]:
# Show the zero-padded per-token weights built from the prototype distances.
print(weights)

[0, 0, 4.052055, 6.3237424, 7.799147, 9.631209, 10.185423, 12.775558, 14.446464, 17.705328, 18.057262, 19.036509, 19.217224, 17.733301, 15.123251, 13.684399, 13.591164, 13.136839, 13.027698, 13.425535, 11.515947, 13.288067, 13.144882, 13.19454, 13.414672, 14.644038, 11.726691, 15.134076, 15.421644, 15.696607, 13.410811, 15.256836, 11.329752, 11.387961, 11.367109, 12.03377, 11.766523, 11.745758, 15.492865, 16.92241, 16.75719, 17.890406, 18.69532, 17.317614, 16.990002, 17.912249, 16.078377, 18.082115, 17.087307, 14.021075, 13.244951, 13.173656, 16.043446, 15.143722, 15.354596, 15.773807, 15.002738, 13.982317, 13.299905, 13.254038, 14.868161, 16.319359, 17.26325, 19.297539, 19.809128, 18.067993, 17.617771, 16.886292, 16.407635, 17.309208, 18.796083, 17.15499, 16.392714, 16.959522, 14.851638, 11.850729, 9.816828, 9.59249, 9.489878, 13.347328, 16.010143, 0, 0]


In [89]:
# Sanity check: the two lengths should match so zip() pairs every token with a weight.
len(words), len(weights)

(83, 83)

In [90]:
import html
# ``display`` and ``HTML`` are public API of ``IPython.display``. Importing them
# from ``IPython.core.display`` has been deprecated since IPython 7.14.
from IPython.display import display, HTML

def html_escape(text):
    """Return ``text`` with ``&``, ``<``, ``>`` and quote characters replaced by HTML entities."""
    escaped = html.escape(text)
    return escaped

# Largest weight; used to scale every weight into an alpha value in [0, 1].
max_alpha = max(weights)
# Guard against an all-zero weight list, which would otherwise divide by zero.
alpha_scale = max_alpha if max_alpha > 0 else 1.0

# zip() silently truncates to the shorter input, so check the alignment first.
assert len(words) == len(weights), "words and weights must have the same length"

# NOTE(review): the weights appear to be prototype distances, so the most
# opaque highlight marks the position with the LARGEST distance and the zero
# padding is left unhighlighted. Confirm this is the intended visual encoding.
highlighted_spans = [
    # Fixed-point formatting keeps the alpha a plain decimal regardless of how
    # the underlying numeric type stringifies.
    f'<span style="background-color:rgba(135,206,250,{weight / alpha_scale:.3f});">'
    f'{html_escape(word)}</span>'
    for word, weight in zip(words, weights)
]

highlighted_text = ' '.join(highlighted_spans)

In [91]:
# Render the assembled <span> markup as rich HTML output in the notebook.
display(HTML(highlighted_text))