In [1]:
from ast import literal_eval
from configparser import ConfigParser
from io import BytesIO

import click
import comet_ml
import torch
from easydict import EasyDict

import pandas as pd

In [2]:
import sys

# Make modules under ./src importable. Guarded so that re-running this cell does not
# keep appending duplicate entries to sys.path.
if './src' not in sys.path:
    sys.path.append('./src')

In [3]:
from src.models.protoconv.lit_module import ProtoConvLitModule
from src.models.protoconv.train import get_dataset

In [4]:
# Comet experiment key, checkpoint file name (loaded from checkpoints/ later) and the
# k-fold index whose validation split is inspected. The checkpoint name starts with
# `fold_0_`, so keep `fold` in sync with the chosen checkpoint.
experiment, weights_path, fold = (
    '60647a521a164486831a2b186814edae',
    'fold_0_epoch=10-val_loss_0=0.4314-val_acc_0=0.8127.ckpt',
    0,
)

In [5]:
# Connect to Comet with credentials from config.ini ([cometml] section: apikey,
# projectname, workspace) and fetch the experiment to inspect.
config = ConfigParser()
if not config.read('config.ini'):
    # ConfigParser.read silently skips files it cannot open; fail here with a clear
    # message instead of a bare KeyError on config['cometml'] below.
    raise FileNotFoundError("config.ini not found in the working directory")

comet_config = EasyDict(config['cometml'])
comet_api = comet_ml.api.API(api_key=comet_config.apikey)

# `experiment` starts out as the experiment key (str) and is rebound below to the API
# experiment object; remember the key so re-running this cell still works.
experiment_key = experiment if isinstance(experiment, str) else experiment_key
experiment = comet_api.get(project_name=comet_config.projectname, workspace=comet_config.workspace,
                           experiment=experiment_key)

In [6]:
# Recover the exact train/validation split used for `fold` from the experiment's assets,
# then load the trained checkpoint for that fold.
dataset = experiment.get_parameters_summary('data_set')['valueCurrent']
kfold_split_id = next(
    (asset['assetId'] for asset in experiment.get_asset_list()
     if asset['fileName'] == 'kfold_split_indices.csv'),
    None,
)
if kfold_split_id is None:
    raise LookupError("asset 'kfold_split_indices.csv' not found in the Comet experiment")
kfold_split_binary = experiment.get_asset(kfold_split_id, return_type="binary")
# assumes row `fold` of the CSV corresponds to fold id `fold` — TODO confirm against the training script
kfold_split = pd.read_csv(BytesIO(kfold_split_binary)).iloc[fold]

# Index lists are stored as Python literals; literal_eval parses them without executing code.
train_index = literal_eval(kfold_split['train_indices'])
val_index = literal_eval(kfold_split['val_indices'])

df_dataset = pd.read_csv(f'data/{dataset}/data.csv')
train_df, valid_df = df_dataset.iloc[train_index], df_dataset.iloc[val_index]

# batch_size=1: later cells use X[0] / squeeze(0), i.e. they assume one example per batch.
TEXT, LABEL, train_loader, val_loader = get_dataset(train_df, valid_df, batch_size=1, cache=None)
model = ProtoConvLitModule.load_from_checkpoint(
    f'checkpoints/{weights_path}',
    vocab_size=len(TEXT.vocab),
    embedding_dim=TEXT.vocab.vectors.shape[1],
    lr=1,  # no training happens in this notebook; value presumably irrelevant here
    fold_id=fold,
)

  torch.nn.init.xavier_uniform(self.prototypes.data)


In [7]:
# How many prototypes the model learned: the leading dimension of the prototype tensor.
n_prototypes = len(model.prototype.prototypes)
n_prototypes

16

In [8]:
from collections import namedtuple

class PrototypeExample:
    """A candidate training example for one prototype, orderable by similarity.

    Instances compare by `similarity` (the negated best distance). `heapq` keeps the
    smallest item at the root, so in a heap of these the root is the *least* similar
    (most distant) example and `heappushpop` evicts it. The heap therefore keeps the
    closest examples seen so far.
    """

    def __init__(self, best_distance, distances, X):
        self.best_distance = best_distance  # smallest distance over all windows
        self.distances = distances  # per-window distances to the prototype
        self.similarity = -best_distance  # ordering key: higher means closer
        self.X = X  # token ids of the example

    def __lt__(self, other):
        # Ascending by similarity (a min-heap on similarity, not a max-heap).
        return other.similarity > self.similarity

    def __repr__(self):
        return f'PE(dist:{self.best_distance})'

In [9]:
import torch.nn.functional as F
import heapq

# For every prototype, scan the training set and keep the k examples whose closest
# window is nearest to that prototype.
model.eval()
# Use the GPU when one is available instead of hard-requiring CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

k = 3  # number of closest training examples kept per prototype
heaps = [[] for _ in range(n_prototypes)]  # one bounded heap of PrototypeExample per prototype

with torch.no_grad():
    for X, batch_id in train_loader:
        # X holds token ids, batch-first (embedding output is permuted to channels-first
        # below); batch_size=1, so X[0] is the single example in the batch.
        X = X.to(device)  # no-op if the loader already yields tensors on `device`
        tokens = X[0].cpu().numpy()  # same for every prototype: compute once per batch

        temp = model.embedding(X).permute((0, 2, 1))
        temp = model.conv1(temp)
        temp = model.prototype(temp)  # presumably (batch, n_prototypes, n_windows)
        prototypes_distances = temp.squeeze(0).cpu().numpy()  # per-window distances per prototype
        # -maxpool(-x) over the whole window axis == min distance per prototype
        temp = -F.max_pool1d(-temp, temp.size(2))
        best_distances = temp.view(temp.size(0), -1).squeeze(0).cpu().numpy()  # one value per prototype

        for i, (best_dist, dists_in_prot) in enumerate(zip(best_distances, prototypes_distances)):
            prototype_repr = PrototypeExample(best_dist, dists_in_prot, tokens)
            if len(heaps[i]) < k:
                heapq.heappush(heaps[i], prototype_repr)
            else:
                # Push, then pop the least similar: retains the k closest seen so far.
                heapq.heappushpop(heaps[i], prototype_repr)



In [10]:
import html
# IPython.display is the public home of these; importing display from
# IPython.core.display is deprecated since IPython 7.14.
from IPython.display import display, HTML

def html_escape(text):
    """Return `text` with &, <, > and both quote characters replaced by HTML entities."""
    escaped = html.escape(text, quote=True)
    return escaped

In [11]:
import numpy as np


def _highlight_html(words, distances):
    """Render `words` as HTML spans shaded by closeness to a prototype.

    `distances` holds one prototype distance per convolution window. Each window's score
    is placed on the word at the window centre. The closest window gets alpha 1 and the
    farthest gets alpha 0. Words not covered by any window centre get alpha 0.
    """
    distances = np.asarray(distances)
    if distances.size and distances.max() > distances.min():
        span = distances.max() - distances.min()
        scores = 1 - (distances - distances.min()) / span
    else:
        # All windows equally close (or no windows): avoid 0/0 producing a NaN alpha.
        scores = np.ones_like(distances)
    # The conv output is shorter than the token sequence; centre the scores on it.
    # (assumes windows are centred on tokens — TODO confirm against conv1 padding/kernel)
    n_missing = max(len(words) - len(scores), 0)
    left = n_missing // 2
    weights = [0.0] * left + list(scores) + [0.0] * (n_missing - left)
    return ' '.join(
        f'<span style="background-color:rgba(135,206,250,{str(weight)});">' + html_escape(word) + '</span>'
        for word, weight in zip(words, weights)
    )


for i, heap in enumerate(heaps):
    print('-'*10, i,'-'*10)
    # nlargest by similarity == closest examples first; k matches the search cell.
    for p in heapq.nlargest(k, heap):
        print('Best distance:', p.best_distance)
        words = [TEXT.vocab.itos[j] for j in p.X]
        display(HTML(_highlight_html(words, p.distances)))

---------- 0 ----------
Best distance: 0.540432


Best distance: 0.5565988


Best distance: 0.5565988


---------- 1 ----------
Best distance: 1.5243436


Best distance: 1.6273602


Best distance: 1.674057


---------- 2 ----------
Best distance: 1.0624387


Best distance: 1.227362


Best distance: 1.315919


---------- 3 ----------
Best distance: 0.54540205


Best distance: 0.5818137


Best distance: 0.5837115


---------- 4 ----------
Best distance: 0.3842943


Best distance: 0.43312174


Best distance: 0.4380337


---------- 5 ----------
Best distance: 1.627188


Best distance: 1.627188


Best distance: 1.627188


---------- 6 ----------
Best distance: 0.46669734


Best distance: 0.5580617


Best distance: 0.56232154


---------- 7 ----------
Best distance: 0.45444566


Best distance: 0.47361863


Best distance: 0.47713345


---------- 8 ----------
Best distance: 0.6177577


Best distance: 0.6177577


Best distance: 0.6177577


---------- 9 ----------
Best distance: 2.232697


Best distance: 2.232697


Best distance: 2.232697


---------- 10 ----------
Best distance: 0.82497096


Best distance: 0.8676913


Best distance: 0.87733126


---------- 11 ----------
Best distance: 0.58848876


Best distance: 0.62661844


Best distance: 0.64970315


---------- 12 ----------
Best distance: 1.0794783


Best distance: 1.1237422


Best distance: 1.5306046


---------- 13 ----------
Best distance: 1.2296857


Best distance: 1.3217659


Best distance: 1.3929272


---------- 14 ----------
Best distance: 1.1317965


Best distance: 1.2983837


Best distance: 1.3035163


---------- 15 ----------
Best distance: 1.0656776


Best distance: 1.068504


Best distance: 1.1170831
