In [1]:
from pathlib import Path

import os
import sys

# Make the parent of the notebook's working directory importable, so the
# in-repo `vectorian` package can be imported without installing it.
current_path = Path.cwd()
sys.path.append(str(current_path.parent))

In [2]:
import vectorian

In [40]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
import vectorian.alignment

class FineTuneableWidget:
    """Dropdown selector paired with a type-specific fine-tuning sub-widget.

    Subclasses provide three class attributes:

    * ``_types`` -- list of ``(label, factory)`` entries. Each factory is
      called with no arguments and must return an object with a ``widget``
      attribute.
    * ``_default`` -- label selected initially when ``fix_to`` is None.
    * ``_description`` -- description text shown next to the dropdown.

    Args:
        fix_to: Optional label to force as the selection. When given, the
            dropdown is disabled so the choice cannot be changed.
    """

    def __init__(self, fix_to=None):
        locked = fix_to is not None
        self._type = widgets.Dropdown(
            options=[entry[0] for entry in self._types],
            value=fix_to if locked else self._default,
            description=self._description,
            disabled=locked)

        self._instantiate_fine_tune(self._type.value)

        # Rebuild the fine-tune sub-widget whenever the selection changes.
        self._type.observe(self.on_changed, names='value')

        self._vbox = widgets.VBox([self._type, self._fine_tune.widget])

    def _instantiate_fine_tune(self, name):
        """Create the sub-widget registered under ``name``.

        Raises ValueError if ``name`` is not a label in ``_types``.
        """
        labels = [entry[0] for entry in self._types]
        factory = self._types[labels.index(name)][1]
        self._fine_tune = factory()

    def on_changed(self, change):
        """Handle a dropdown ``value`` change by swapping in the new sub-widget."""
        self._instantiate_fine_tune(change.new)
        self._vbox.children = [self._type, self._fine_tune.widget]

    @property
    def widget(self):
        """Container holding the dropdown and the current sub-widget."""
        return self._vbox
    
class CosineMetricWidget:
    """Fine-tune panel for the cosine metric, which has no tunable parameters."""

    def __init__(self):
        no_controls = []
        self._vbox = widgets.VBox(no_controls)

    @property
    def widget(self):
        """The (empty) parameter container."""
        return self._vbox

class ImprovedSqrtCosineMetricWidget:
    """Fine-tune panel for the improved sqrt-cosine metric (no parameters)."""

    def __init__(self):
        no_controls = []
        self._vbox = widgets.VBox(no_controls)

    @property
    def widget(self):
        """The (empty) parameter container."""
        return self._vbox

class PNormWidget:
    """Fine-tune panel for the p-norm metric: exponent ``p`` and a scale."""

    def __init__(self):
        def bounded_float(description, upper):
            # Both fields share the default 2, a small positive lower bound
            # and a 0.25 step; only the label and upper bound differ.
            return widgets.BoundedFloatText(
                value=2,
                min=1e-4,
                max=upper,
                step=0.25,
                description=description,
                disabled=False)

        self._p = bounded_float('p:', 10)
        self._scale = bounded_float('Scale:', 1000)

        self._vbox = widgets.VBox([self._p, self._scale])

    @property
    def widget(self):
        """Container holding the ``p`` and scale inputs."""
        return self._vbox

    
class VectorMetricWidget(FineTuneableWidget):
    """Selector for the vector metric used to compare token embeddings."""

    # No description label: the surrounding row supplies "Token Metric:".
    _description = ''
    _default = 'Cosine'
    _types = [
        ('Cosine', CosineMetricWidget),
        ('P-Norm', PNormWidget),
        ('Improved Sqrt Cosine', ImprovedSqrtCosineMetricWidget),
    ]
    
    
class EmbeddingWidget:
    """Dropdown choosing which token embedding the token metric runs on."""

    def __init__(self):
        choices = ['fasttext-en', 'mixed']
        self._widget = widgets.Dropdown(
            options=choices,
            value=choices[0],
            description='',
            disabled=False,
        )

    @property
    def widget(self):
        """The embedding dropdown."""
        return self._widget
        

class TokenMetricWidget:
    """Horizontal row reading "Token Metric: <metric> on <embedding>"."""

    def __init__(self):
        self._metric = VectorMetricWidget()
        self._embedding = EmbeddingWidget()

        row = [
            widgets.Label('Token Metric:'),
            self._metric.widget,
            widgets.Label('on'),
            self._embedding.widget,
        ]
        self._hbox = widgets.HBox(row)

    @property
    def widget(self):
        """The assembled row."""
        return self._hbox
    


class SlidingGapCostWidget:
    """Slider for a single gap-cost parameter, plus a live preview image.

    Args:
        description: Label shown next to the slider.
        construct: Callable that takes the slider value and returns a
            gap-cost object. The object must provide
            ``plot_to_image(fig, ax, n, format=...)`` returning image bytes.
        max: Upper bound of the slider; the lower bound is 0.
    """

    def __init__(self, description, construct, max=1.0):
        self._construct = construct

        self._cost = widgets.FloatSlider(
            value=0,
            min=0,
            max=max,
            step=0.01,
            description=description,
            disabled=False)

        self._plot = widgets.Image(
            value=b'',
            format='png',
            width=300,
            height=400,
        )

        # Render the initial preview before wiring up change notifications.
        self.update_plot()
        self._cost.observe(self.on_changed, names='value')

        self._vbox = widgets.VBox([self._cost, self._plot])

    def make(self):
        """Build the gap-cost object for the current slider value."""
        return self._construct(self._cost.value)

    def update_plot(self):
        """Re-render the preview image for the current slider value."""
        # Build the gap cost before creating a figure: if construction fails,
        # no figure is left behind.
        gap_cost = self.make()
        fig, ax = plt.subplots(1, 1, figsize=(5, 2))
        try:
            # 20 is passed straight through to plot_to_image (presumably the
            # number of gap lengths plotted -- NOTE(review): confirm).
            im_data = gap_cost.plot_to_image(fig, ax, 20, format='png')
        finally:
            # Close this specific figure even when rendering raises.
            # Otherwise pyplot keeps it open across slider moves.
            plt.close(fig)
        self._plot.value = im_data

    def on_changed(self, change):
        """Handle a slider change by refreshing the preview.

        The new value is read back through ``make()``, so ``change`` is unused.
        """
        self.update_plot()

    @property
    def widget(self):
        """Container holding the slider and the preview image."""
        return self._vbox
    
class ConstantGapCostWidget(SlidingGapCostWidget):
    """Slider editor for ``vectorian.alignment.ConstantGapCost``."""

    def __init__(self):
        super().__init__(
            description='Cost:',
            construct=vectorian.alignment.ConstantGapCost)

class LinearGapCostWidget(SlidingGapCostWidget):
    """Slider editor for ``vectorian.alignment.LinearGapCost``."""

    def __init__(self):
        super().__init__(
            description='Cost:',
            construct=vectorian.alignment.LinearGapCost)

class ExponentialGapCostWidget(SlidingGapCostWidget):
    """Slider editor for ``vectorian.alignment.ExponentialGapCost``.

    The slider edits the cutoff and ranges over [0, 20].
    """

    def __init__(self):
        super().__init__(
            description='Cutoff:',
            construct=vectorian.alignment.ExponentialGapCost,
            max=20)
                
class GapCostWidget(FineTuneableWidget):
    """Selector for the gap-cost model, each with its own slider editor."""

    _default = 'Linear'
    _description = 'Gap Type:'
    _types = [
        ('Constant', ConstantGapCostWidget),
        ('Linear', LinearGapCostWidget),
        ('Exponential', ExponentialGapCostWidget),
    ]
    
class AlignmentAlgorithmWidget:
    """Base editor: a token-metric row, optionally followed by parameters.

    Args:
        parameters: Widget holding algorithm-specific controls, or ``None``
            if the algorithm has none.
    """

    def __init__(self, parameters):
        self._token_metric = TokenMetricWidget()

        rows = [self._token_metric.widget]
        if parameters is not None:
            rows.append(parameters)
        self._vbox = widgets.VBox(rows)

    @property
    def widget(self):
        """The vertically stacked editor."""
        return self._vbox
    
class NeedlemanWunschWidget(AlignmentAlgorithmWidget):
    """Needleman-Wunsch editor; the gap-cost type is locked to "Linear"."""

    def __init__(self):
        gap_cost = GapCostWidget(fix_to="Linear")
        self._gap_cost = gap_cost
        super().__init__(gap_cost.widget)

class SmithWatermanWidget(AlignmentAlgorithmWidget):
    """Smith-Waterman editor; the gap-cost type is locked to "Linear"."""

    def __init__(self):
        gap_cost = GapCostWidget(fix_to="Linear")
        self._gap_cost = gap_cost
        super().__init__(gap_cost.widget)

class WatermanSmithBeyerWidget(AlignmentAlgorithmWidget):
    """Waterman-Smith-Beyer editor; any gap-cost type may be selected."""

    def __init__(self):
        gap_cost = GapCostWidget()
        self._gap_cost = gap_cost
        super().__init__(gap_cost.widget)

class WordMoversDistanceWidget(AlignmentAlgorithmWidget):
    """Editor for Word Mover's Distance.

    Below the token-metric row provided by the base class, it adds a variant
    selector and an "Extra Mass Penalty" field.
    """

    # Variant identifiers offered in the dropdown, shaped like
    # "<wmd|rwmd>/<implementation>". NOTE(review): presumably these map onto
    # vectorian constructors such as WordMoversDistance.wmd('kusner') -- confirm.
    _variants = [
        'wmd/kusner',
        'wmd/vectorian',
        'rwmd/kusner',
        'rwmd/jablonsky',
        'rwmd/vectorian'
    ]

    def __init__(self):
        self._variant = widgets.Dropdown(
            options=self._variants,
            value="rwmd/vectorian",
            description="Variant:",
            disabled=False)

        # Defaults to -1. NOTE(review): the meaning of a negative value (e.g.
        # a "disabled/auto" sentinel) is defined by the consumer -- confirm.
        # 'initial' sizes the label to its text; the default fixed description
        # width truncates a label this long.
        self._extra_mass_penalty = widgets.FloatText(
            value=-1,
            description='Extra Mass Penalty:',
            style={'description_width': 'initial'},
            disabled=False
        )

        super().__init__(widgets.VBox([
            self._variant,
            self._extra_mass_penalty
        ]))

class WordRotatorsDistanceWidget(AlignmentAlgorithmWidget):
    """Word Rotator's Distance editor: only the token-metric row."""

    def __init__(self):
        super().__init__(parameters=None)
        
    
class AlignmentWidget(FineTuneableWidget):
    """Selector for the alignment algorithm and its parameter editor."""

    _default = 'Waterman-Smith-Beyer'
    _description = 'Alignment:'
    _types = [
        ('Needleman-Wunsch', NeedlemanWunschWidget),
        ('Smith-Waterman', SmithWatermanWidget),
        ('Waterman-Smith-Beyer', WatermanSmithBeyerWidget),
        ('Word Movers Distance', WordMoversDistanceWidget),
        ('Word Rotators Distance', WordRotatorsDistanceWidget),
    ]
    
class TagWeightedAlignmentWidget:
    """Alignment editor with tag-weighting controls stacked above it.

    It shows a POS mismatch penalty slider, a tag-weights selector and a
    similarity-threshold slider, followed by an ``AlignmentWidget``.
    """

    def __init__(self):
        # 'initial' sizes each label to its text; the default fixed
        # description width truncates labels this long.
        wide_label = {'description_width': 'initial'}

        self._pos_mismatch_penalty = widgets.FloatSlider(
            value=1,
            min=0,
            max=1,
            step=0.1,
            description='POS Mismatch Penalty:',
            style=wide_label,
            disabled=False)

        self._tag_weights = widgets.Dropdown(
            options=['Off', 'POST STSS'],
            value='POST STSS',
            description='Tag Weights:',
            disabled=False)

        self._similarity_threshold = widgets.FloatSlider(
            value=0.2,
            min=0,
            max=1,
            step=0.1,
            description='Similarity Threshold:',
            style=wide_label,
            disabled=False)

        self._alignment = AlignmentWidget()

        self._vbox = widgets.VBox([
            self._pos_mismatch_penalty,
            self._tag_weights,
            self._similarity_threshold,
            self._alignment.widget
        ])

    @property
    def widget(self):
        """The vertically stacked editor."""
        return self._vbox
        
    
class SentenceEmbeddingWidget:
    """Dropdown selecting the sentence-embedding model."""

    # Model comparison: https://docs.google.com/spreadsheets/d/14QplCdTCDwEmTqrn1LH4yrbKvdogK4oQvYO1K1aPR5M/edit#gid=0
    _variants = [
        'stsb-roberta-large',
        'stsb-roberta-base',
        'stsb-bert-large',
        'stsb-distilbert-base'
    ]

    def __init__(self):
        default_model = 'stsb-distilbert-base'
        self._widget = widgets.Dropdown(
            options=self._variants,
            value=default_model,
            description='Model:',
            disabled=False)

    @property
    def widget(self):
        """The model dropdown."""
        return self._widget
    

class SentenceMetricWidget(FineTuneableWidget):
    """Selector for how sentences are compared, with a matching editor."""

    _default = 'Alignment'
    _description = 'Sentence Metric:'
    _types = [
        ('Alignment', AlignmentWidget),
        ('Tag-Weighted Alignment', TagWeightedAlignmentWidget),
        ('Sentence Embedding', SentenceEmbeddingWidget),
    ]

class PartitionWidget:
    """Row with the partition level and its window size and step."""

    def __init__(self):
        def window_box(description):
            # Window size and step share the same bounds: integers 1..1000.
            return widgets.BoundedIntText(
                value=1,
                min=1,
                max=1000,
                step=1,
                description=description,
                disabled=False)

        self._level = widgets.Dropdown(
            options=['sentence', 'token'],
            value='sentence',
            description='Partition:',
            disabled=False)
        self._window_size = window_box('Window Size:')
        self._window_step = window_box('Window Step:')

        controls = [self._level, self._window_size, self._window_step]
        self._hbox = widgets.HBox(controls)

    @property
    def widget(self):
        """The assembled row."""
        return self._hbox
    
class QueryWidget:
    """Top-level query editor: partition settings above the sentence metric."""

    def __init__(self):
        self._partition = PartitionWidget()
        self._sentence = SentenceMetricWidget()
        sections = [self._partition.widget, self._sentence.widget]
        self._vbox = widgets.VBox(sections)

    @property
    def widget(self):
        """The full, vertically stacked editor."""
        return self._vbox

    
# Build the complete query editor. The bare trailing expression makes Jupyter
# render the widget as this cell's output.
w = QueryWidget()
w.widget

VBox(children=(HBox(children=(Dropdown(description='Partition:', options=('sentence', 'token'), value='sentenc…