In [92]:
from typing import Iterator, List, Dict, Optional, cast
import torch
import torch.optim as optim
from torch.nn import MSELoss
from torch.nn import functional as F
from torch.nn import ModuleList

import numpy as np
from allennlp.data import Instance
from allennlp.data.fields import TextField, LabelField, SequenceLabelField, ArrayField, MetadataField, ListField
from allennlp.data.dataset_readers import DatasetReader
from allennlp.common.file_utils import cached_path
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.batch import Batch
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding, TokenEmbedder
from allennlp.modules.token_embedders.pretrained_transformer_embedder import PretrainedTransformerEmbedder
from allennlp.modules.token_embedders.pretrained_transformer_mismatched_embedder import PretrainedTransformerMismatchedEmbedder
# from allennlp.modules.seq2seq_encoders.multi_head_self_attention import MultiHeadSelfAttention
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper
from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
from allennlp.modules.attention import Attention
from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention
from allennlp.modules.matrix_attention.linear_matrix_attention import LinearMatrixAttention
from allennlp.modules.matrix_attention.cosine_matrix_attention import CosineMatrixAttention
from allennlp.modules.matrix_attention.bilinear_matrix_attention import BilinearMatrixAttention

from allennlp.modules.conditional_random_field import allowed_transitions, ConditionalRandomField

from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits, \
    get_device_of, masked_softmax, weighted_sum, \
    get_mask_from_sequence_lengths, get_lengths_from_binary_sequence_mask, tensors_equal

from allennlp.training.metrics import BooleanAccuracy, CategoricalAccuracy, MeanAbsoluteError, Average
from allennlp.data.samplers import BucketBatchSampler
from allennlp.data.dataloader import DataLoader, PyTorchDataLoader
from allennlp.training.trainer import GradientDescentTrainer
# from allennlp.predictors import Predictor, Seq2SeqPredictor, SimpleSeq2SeqPredictor, SentenceTaggerPredictor
from allennlp.predictors import Predictor, SentenceTaggerPredictor
from allennlp.nn.activations import Activation
from allennlp.common.tqdm import Tqdm
from allennlp.common.params import Params
from allennlp.common.util import JsonDict, sanitize
from allennlp.common.util import START_SYMBOL, END_SYMBOL

from allennlp_models.generation.predictors import Seq2SeqPredictor
from allennlp_models.generation.models.simple_seq2seq import SimpleSeq2Seq
from allennlp_models.generation.modules.seq_decoders.seq_decoder import SeqDecoder
from allennlp_models.generation.modules.decoder_nets.decoder_net import DecoderNet


# from spacy.tokenizer import Tokenizer as SpacyTokenizer
# from spacy.lang.en import English
# nlp = English()
# Create a blank Tokenizer with just the English vocab
# tokenizer = Tokenizer(nlp.vocab)

from tqdm.notebook import tqdm

from pyAudioAnalysis import audioBasicIO
from pyAudioAnalysis import ShortTermFeatures

import os, sys
import itertools
import json
from collections import defaultdict
from inspect import signature
import warnings
import pickle
from copy import copy, deepcopy
from overrides import overrides
import importlib
import string
import re
import matplotlib.pyplot as plt
import editdistance
from copy import copy, deepcopy
import random

from transformers import BertPreTrainedModel, BertModel, BertConfig, BertTokenizer

from utils.spider import process_sql, evaluation
from utils.schema_gnn.spider_utils import Table, TableColumn, read_dataset_schema
from utils.misc_utils import Load_CMU_Dict, WordPronDist, WordPronSimilarity, ConstructWordSimMatrix

import dataset_readers
import models
import predictors

from dataset_readers.reader_utils import extractAudioFeatures, extractAudioFeatures_NoPooling, \
    extractRawAudios, extractAudioFeatures_NoPooling_Wav2vec, \
    dbToTokens, dbToTokensWithColumnIndexes, dbToTokensWithAddCells, \
    read_DB, Get_align_tags, load_DB_content, collect_DB_toks_dict, text_cell_to_toks
# from modules.encoder import SpeakQLEncoder, SpeakQLEncoderV1, SpeakQLEncoder_Gated_Fusion
from modules.encoder import SpeakQLEncoder, SpeakQLEncoderV1
# from models.reranker import SpiderASRRerankerV0, SpiderASRRerankerV1, SpiderASRRerankerV2, SpiderASRReranker_Siamese
# from predictors.reranker_predictor import SpiderASRRerankerPredictor, SpiderASRRerankerPredictor_Siamese

from dataset_readers import SpiderASRRerankerReaderV2_Siamese_Combined
from dataset_readers import SpiderASRRewriterReader_Seq2seq_Combined, \
    SpiderASRRewriterReader_Tagger_Combined, SpiderASRRewriterReader_ILM_Combined
from models import SpiderASRRewriter_Tagger_Combined, SpiderASRRewriter_ILM_Combined, \
    SpiderASRRewriter_Seq2seq_Combined, \
    SpiderASRRewriter_Tagger_Combined_new, SpiderASRRewriter_ILM_Combined_new
from predictors import SpiderASRRewriterPredictor_Tagger, SpiderASRRewriterPredictor_ILM, SpiderASRRewriterPredictor_Seq2seq

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import pairwise

import spacy

from nltk.stem.porter import PorterStemmer

import pandas as pd

torch.manual_seed(1)

<torch._C.Generator at 0x112615770>

In [2]:
nlp = spacy.load('en_core_web_sm')

In [3]:
# importlib.reload(dataset_readers)
# from dataset_readers import SpiderASRRewriterReader_ILM_Combined

### Loading dataset_reader and model

In [62]:
# don't need to change config version here, if dataset reader is unchanged 
full_config = Params.from_file('train_configs/rewriter_2.29.0.0i.jsonnet')

In [63]:
dsreader_config = deepcopy(full_config['dataset_reader'])
dsreader_config['databases_dir'] = "/Users/mac/Desktop/syt/Deep-Learning/Dataset/spider/database"
dsreader_config['dataset_dir'] = "/Users/mac/Desktop/syt/Deep-Learning/Dataset/spider/my"
dsreader_config['tables_json_fname'] = '/Users/mac/Desktop/syt/Deep-Learning/Dataset/spider/tables.json'
dsreader_config['tabert_model_path'] = '/Users/mac/Desktop/syt/Deep-Learning/Repos/TaBERT/pretrained-models/tabert_base_k1/model.bin'
dsreader_config['pronun_dict_path'] = "/Users/mac/Desktop/syt/Deep-Learning/Dataset/CMUdict/cmudict-0.7b.txt"
dsreader_config['db_tok2phs_dict_path'] = "/Users/mac/Desktop/syt/Deep-Learning/Dataset/spider/my/db/db_tok2phs.json"
dsreader_config['aux_probes'] = {
    'utter_mention_schema': True,
    'schema_dir_mentioned': True,
    'schema_indir_mentioned': True,
}
dsreader_config.as_dict()


{'databases_dir': '/Users/mac/Desktop/syt/Deep-Learning/Dataset/spider/database',
 'dataset_dir': '/Users/mac/Desktop/syt/Deep-Learning/Dataset/spider/my',
 'db_cells_in_bracket': True,
 'db_tok2phs_dict_path': '/Users/mac/Desktop/syt/Deep-Learning/Dataset/spider/my/db/db_tok2phs.json',
 'include_align_tags': True,
 'include_gold_rewrite_seq': True,
 'max_sequence_len': 300,
 'ph_token_indexers': {'phonemes': {'namespace': 'phonemes',
   'type': 'single_id'}},
 'pronun_dict_path': '/Users/mac/Desktop/syt/Deep-Learning/Dataset/CMUdict/cmudict-0.7b.txt',
 'specify_full_path': False,
 'src_token_indexers': {'bert': {'model_name': 'facebook/bart-base',
   'type': 'pretrained_transformer_mismatched'},
  'char': {'min_padding_length': 5,
   'namespace': 'token_characters',
   'type': 'characters'}},
 'tabert_model_path': '/Users/mac/Desktop/syt/Deep-Learning/Repos/TaBERT/pretrained-models/tabert_base_k1/model.bin',
 'tables_json_fname': '/Users/mac/Desktop/syt/Deep-Learning/Dataset/spider/ta

In [64]:
## (empirical memory concern) use ~1000 samples, will give ~5000 datapoints for probing test for R1 
dsreader_config['cands_limit'] = 2

In [65]:
test_dataset_reader = DatasetReader.from_params(deepcopy(dsreader_config))

Loading CMU Dict done: 133854 entries, 125074 words, 113745 prons
Loading db_tok2phs Dict done: 102186 entries, 102186 words, 91693 prons
Joint word2pron size = 202111
[('hellenizing', {('HH', 'EH', 'L', 'AH', 'N', 'AY', 'Z', 'IH', 'NG')}), ('kinsella', {('K', 'IY', 'N', 'S', 'EH', 'L', 'AH')}), ('matsch', {('M', 'AE', 'CH')}), ('odonnel', {('OW', 'D', 'AA', 'N', 'AH', 'L')}), ('quads', {('K', 'W', 'AA', 'D', 'Z')}), ("sears'", {('S', 'IH', 'R', 'Z')})]


In [66]:
test_dataset = test_dataset_reader.read('test')
len(test_dataset)

HBox(children=(IntProgress(value=1, bar_style='info', description='reading instances', max=1, style=ProgressSt…

Loading literals failed: wta_1::players
['first_name', 'last_name', 'hand', 'country_code']
Could not decode to UTF-8 column 'last_name' with text 'Treyes Albarrac��N'
Question OOV: 65, [('doch', ['D', 'AA', 'K']), ('youll', ['Y', 'AW', 'L']), ('republik', ['R', 'IH', 'P', 'AH', 'B', 'L', 'IH', 'K']), ('everdeen', ['EH', 'V', 'ER', 'D', 'IY', 'N']), ('cdo', ['S', 'IY', 'D', 'IY', 'OW']), ('teoh', ['T', 'IH', 'OW']), ('roomba', ['R', 'UW', 'M', 'B', 'AX']), ('citis', ['S', 'AY', 'DX', 'IH', 'S']), ('zahren', ['Z', 'AA', 'R', 'AX', 'N']), ('lexx', ['L', 'EH', 'K', 'S'])]





1082

In [67]:
# limit train set size to fit into memory; 7000 samples, ~35000 datapoints for R1 
dsreader_config['cands_limit'] = 1
train_dataset_reader = DatasetReader.from_params(deepcopy(dsreader_config))

Loading CMU Dict done: 133854 entries, 125074 words, 113745 prons
Loading db_tok2phs Dict done: 102186 entries, 102186 words, 91693 prons
Joint word2pron size = 202111
[('hellenizing', {('HH', 'EH', 'L', 'AH', 'N', 'AY', 'Z', 'IH', 'NG')}), ('kinsella', {('K', 'IY', 'N', 'S', 'EH', 'L', 'AH')}), ('matsch', {('M', 'AE', 'CH')}), ('odonnel', {('OW', 'D', 'AA', 'N', 'AH', 'L')}), ('quads', {('K', 'W', 'AA', 'D', 'Z')}), ("sears'", {('S', 'IH', 'R', 'Z')})]


In [68]:
# Try to set random seed to fix this sampling 
random.seed(127)

In [69]:
train_dataset = train_dataset_reader.read('train')
len(train_dataset)

HBox(children=(IntProgress(value=1, bar_style='info', description='reading instances', max=1, style=ProgressSt…

Loading literals failed: wta_1::players
['first_name', 'last_name', 'hand', 'country_code']
Could not decode to UTF-8 column 'last_name' with text 'Treyes Albarrac��N'
Question OOV: 587, [('sendin', ['S', 'EH', 'N', 'D', 'IH', 'N']), ('fea', ['F', 'IY']), ('conserves', ['K', 'AX', 'N', 'S', 'ER', 'V', 'Z']), ('nancie', ['N', 'AE', 'N', 'S', 'IY']), ('jins', ['JH', 'IH', 'N', 'Z']), ('alis', ['AA', 'L', 'IY', 'Z']), ('roselyn', ['R', 'OW', 'Z', 'L', 'IH', 'N']), ('louiss', ['L', 'UW', 'IH', 'S']), ('statuses', ['S', 'T', 'DX', 'AX', 'S', 'Z']), ('illy', ['IH', 'L', 'IY'])]





7000

In [12]:
# original_dev_path = '/Users/mac/Desktop/syt/Deep-Learning/Dataset/spider/dev.json'
# with open(original_dev_path, 'r') as f:
#     original_dev_dataset = json.load(f)
# len(original_dev_dataset)

In [13]:
full_train_path = '/Users/mac/Desktop/syt/Deep-Learning/Dataset/spider/my/train/train_rewriter+phonemes.json'
with open(full_train_path, 'r') as f:
    full_train_dataset = json.load(f)
len(full_train_dataset)

7000

In [14]:
full_dev_path = '/Users/mac/Desktop/syt/Deep-Learning/Dataset/spider/my/dev/dev_rewriter(full)+phonemes.json'
with open(full_dev_path, 'r') as f:
    full_dev_dataset = json.load(f)
len(full_dev_dataset)

1034

In [15]:
full_dev_dataset[0][0].keys()

dict_keys(['db_id', 'query', 'query_toks', 'query_toks_no_value', 'question', 'question_toks', 'sql', 'span_ranges', 'original_id', 'ratsql_pred_sql', 'gold_question', 'gold_question_toks', 'ratsql_pred_exact', 'ratsql_pred_score', 'question_toks_edit_distance', 'alignment_span_pairs', 'alignment_text_pairs', 'rewriter_tags', 'rewriter_edits', 'token_phonemes', 'token_phoneme_spans'])

In [104]:
_test_instance = test_dataset[100]
_test_instance.fields

{'sentence': <allennlp.data.fields.text_field.TextField at 0x152908410>,
 'text_mask': <allennlp.data.fields.array_field.ArrayField at 0x152908460>,
 'schema_mask': <allennlp.data.fields.array_field.ArrayField at 0x1529084b0>,
 'schema_column_ids': <allennlp.data.fields.array_field.ArrayField at 0x152908550>,
 'audio_feats': <allennlp.data.fields.list_field.ListField at 0x1528f3990>,
 'audio_mask': <allennlp.data.fields.array_field.ArrayField at 0x152908500>,
 'phoneme_multilabels': <allennlp.data.fields.list_field.ListField at 0x1528eff50>,
 'phoneme_labels': <allennlp.data.fields.list_field.ListField at 0x152903e10>,
 'phoneme_label_mask': <allennlp.data.fields.list_field.ListField at 0x152903e90>,
 'utter_mention_schema_labels': <allennlp.data.fields.array_field.ArrayField at 0x152916c30>,
 'schema_dir_mentioned_labels': <allennlp.data.fields.array_field.ArrayField at 0x152916c80>,
 'schema_indir_mentioned_labels': <allennlp.data.fields.array_field.ArrayField at 0x152916cd0>,
 'alig

In [105]:
list(_test_instance.fields['metadata'].keys())

['original_id',
 'text_len',
 'schema_len',
 'concat_len',
 'text_tokens',
 'schema_tokens',
 'concat_tokens',
 'source_tokens',
 'target_tokens',
 'rewrite_seq_len']

In [106]:
_test_instance.fields['target_token_ids'].array

array([41,  8, 42, 43])

In [110]:
_test_instance.fields['metadata']['target_tokens']

[airlines, [ANS]]

In [None]:
_test_instance.fields['metadata']['source_tokens']

In [111]:
len(_test_instance.fields['metadata']['source_tokens']), _test_instance.fields['source_token_ids'].array.shape

(68, (68,))

In [75]:
MODEL_VER = "2.29.0.0i"

# tagger_ILM_model = Model.from_archive('runs/2.0.1/model.tar.gz')
ILM_model = Model.from_archive(f'runs/{MODEL_VER}/model.tar.gz')

self._start_index: 3, @start@
self._end_index: 4, @end@
self._pad_index: 0, @@PADDING@@


In [76]:
# Construct predictor 
# Just using train_reader; shouldn't have problem since the dataset_reader here is not really used in this code 
predictor = SpiderASRRewriterPredictor_ILM(model=ILM_model,
                                           dataset_reader=train_dataset_reader)
predictor.set_save_intermediate(True)

In [77]:
# test code 
predictor_output = predictor.predict_instance(_test_instance)

In [78]:
predictor_output.keys()

dict_keys(['question', 'original_id', 'rewriter_tags', 'align_tags', 'rewrite_seq_prediction', 'rewrite_seq_prediction_cands', 'rewrite_seq_NLL', 'rewrite_seq_prediction_intermediates'])

In [79]:
predictor_output['rewrite_seq_prediction_intermediates']['encoder'].keys()

dict_keys(['phoneme_attention_map_0', 'phoneme_attention_out_0', 'phoneme_attention_out_with_residual_0', 'encoder_representation_0'])

In [80]:
predictor_output['rewrite_seq_prediction_intermediates']['rewriter_main'].keys()
# "tabert_embedding" actually means token embedding, not necessarily tabert 

dict_keys(['word_embeddings', 'audio_feats_encoded', 'phoneme_tag_embeddings', 'encoder_seq_representation_with_tag'])

In [81]:
np.array(predictor_output['rewrite_seq_prediction_intermediates']['decoder']['state_decoder_hidden:step_0']).shape


(1, 256)

In [None]:
# tokens_repr_encoded = predictor_output['rewrite_seq_prediction_intermediates']['rewriter_main']['encoder_seq_representation_with_tag']
# np.array(tokens_repr_encoded).shape

In [None]:
# audio_feats_encoded = predictor_output['rewrite_seq_prediction_intermediates']['rewriter_main']['audio_feats_encoded']
# np.array(audio_feats_encoded).shape


In [None]:
# _batch = Batch([_test_instance]).as_tensor_dict()
# _text_mask = _batch['text_mask']
# _schema_mask = _batch['schema_mask']
# _text_mask, _schema_mask

#### Full prediction (can't do)
- taking too much memory

In [24]:
# all_pred_outputs = []
# for _inst in tqdm(test_dataset):
#     predictor_output = predictor.predict_instance(_inst)
#     all_pred_outputs.append(predictor_output)

# print(len(all_pred_outputs))
# print(list(all_pred_outputs[0].keys()))

### Functions

In [30]:
def eval_results(clf, X, y, res_dict=None):
    p = clf.predict_proba(X)
    preds = [np.argmax(_p) for _p in p]

    corr = 0
    corr_pos = 0
    all_pred_pos = 0
    all_true_pos = 0
    for _pred, _y in zip(preds, y):
        corr += (_pred == _y)
        corr_pos += (_pred > 0 and _y > 0 and _pred == _y)
        all_pred_pos += (_pred > 0)
        all_true_pos += (_y > 0)

    acc = corr / len(y)
    prec = corr_pos / (all_pred_pos + 1e-9)
    rec = corr_pos / (all_true_pos + 1e-9)
    f1 = 2 * prec * rec / (prec + rec + 1e-9)

    pos_prop = sum(y) / len(y)

    print(f'Positive prop: {pos_prop:.4f}')
    print(f'Correct positives: {corr_pos}')
    print(f'All pred positives: {all_pred_pos}')
    print(f'All true positives: {all_true_pos}')
    print(f'Accuracy: {acc:.4f}')
    print(f'Precision: {prec:.4f}')
    print(f'Recall: {rec:.4f}')
    print(f'F1: {f1:.4f}')
    
    if res_dict is not None:
        assert isinstance(res_dict, dict)
        res_dict['pos_prop'] = pos_prop
        res_dict['corr_pos'] = corr_pos
        res_dict['all_pred_pos'] = all_pred_pos
        res_dict['all_true_pos'] = all_true_pos
        res_dict['acc'] = acc
        res_dict['prec'] = prec
        res_dict['rec'] = rec
        res_dict['f1'] = f1

In [31]:
def test_error_analysis_utter(clf, X, y, err_output_path=None):
    p = clf.predict_proba(X)
    preds = [np.argmax(_p) for _p in p]

    incorr_indices = []
    for i, (_p, _x, _y) in enumerate(zip(p, X, y)):
        if _y != np.argmax(_p):
            incorr_indices.append(i)
    print(len(incorr_indices))

    _ptr = 0
    err_output = ''
    for o_id, _insts in test_o_id2instances.items():
    #     _cands = full_dev_dataset[o_id]
        _sample_incorr = False
        for _inst in _insts:
            _incorr = False
            _marked_toks = []
            for _tok in _inst['metadata']['text_tokens']:
                if _ptr in incorr_indices:
                    _marked_toks.append(f'<{_tok}:{p[_ptr][1]:.2f}>')
                    _incorr = True
                    _sample_incorr = True
                else:
                    _marked_toks.append(_tok)

                _ptr += 1

            if _incorr:
                err_output += ' '.join(_marked_toks) + '\n'

        if _sample_incorr:
            err_output += ' '.join(_insts[0]['metadata']['schema_tokens']) + '\n\n'

    if err_output_path is None:
        print(err_output)
    else:
        with open(err_output_path, 'w') as f:
            f.write(err_output)
    

In [32]:
def test_error_analysis_schema(clf, X, y, err_output_path=None):
    p = clf.predict_proba(X)
    preds = [np.argmax(_p) for _p in p]

    incorr_indices = []
    for i, (_p, _x, _y) in enumerate(zip(p, X, y)):
        if _y != np.argmax(_p):
            incorr_indices.append(i)
    print(len(incorr_indices))

    _ptr = 0
    err_output = ''
    for o_id, _insts in test_o_id2instances.items():
    #     _cands = full_dev_dataset[o_id]
        _skip = False
        for _inst in _insts:
            _incorr = False
            _marked_toks = []
            for _tok in _inst['metadata']['schema_tokens']:
                if _skip:
                    pass
                elif _ptr in incorr_indices:
                    _marked_toks.append(f'<{_tok}:{p[_ptr][1]:.2f}>')
                    _incorr = True
                else:
                    _marked_toks.append(_tok)

                _ptr += 1

            if _incorr:
                _skip = True   # when find 1 wrong cand, skip others 
                err_output += f'({o_id}) ' + \
                      ' '.join(_inst['metadata']['text_tokens']) + ' || ' + \
                      ' '.join(_marked_toks) + '\n'

    if err_output_path is None:
        print(err_output)
    else:
        with open(err_output_path, 'w') as f:
            f.write(err_output)

### Gather o_id2instances

In [33]:
# Gather instances by o_id 
train_o_id2instances = defaultdict(list)
test_o_id2instances = defaultdict(list)
# o_id2preds = defaultdict(list)

for _inst in train_dataset:
    o_id = _inst['metadata']['original_id']
    train_o_id2instances[o_id].append(_inst)
for _inst in test_dataset:
    o_id = _inst['metadata']['original_id']
    test_o_id2instances[o_id].append(_inst)

In [34]:
len(train_o_id2instances.keys()), len(test_o_id2instances.keys())

(20, 20)

### Task A - token correctness (utterence-only)
- A.1: getting labels by gold rewriter tags, KEEP = 0, DEL = 1, EDIT = 2
<!-- - A.2: getting labels by checking the alignment, if the token exists in counterpart then it’s correct, otherwise incorrect (doesn't seem so necessary) -->
- Results (2.12.3.3): 
    - pos% = 0.3533
    - Train F1 = 1.0000
    - Test F1 = 0.9996
- Results (2.18.2.2): 
    - pos% = 0.3480
    - Train F1 = 1.0000
    - Test F1 = 0.9969

In [30]:
list(_test_instance['metadata'].keys())

['original_id',
 'text_len',
 'schema_len',
 'text_tokens',
 'schema_tokens',
 'rewrite_seq_len']

#### Gather data

In [31]:
# for o_id, _insts in test_o_id2instances.items():
#     _cands = full_dev_dataset[o_id]
#     _preds = o_id2preds[o_id]
#     assert len(_insts) == len(_cands) == len(_preds)
#     for _inst, _cand, _pred in zip(_insts, _cands, _preds):
#         assert ' '.join(_inst['metadata']['text_tokens']) == ' '.join(_cand['question_toks']) == _pred['question']

In [32]:
# Gather data 
train_samples_A = []
test_samples_A = []

for o_id, _insts in tqdm(train_o_id2instances.items(), total=len(train_o_id2instances)):
    _cands = full_train_dataset[o_id]
#     _preds = o_id2preds[o_id]
    
    for _inst, _cand in zip(_insts, _cands):
        _pred = predictor.predict_instance(_inst)
        
        _metadata = _inst['metadata']
        _text_len = _metadata['text_len']
        _token_encodings = _pred['rewrite_seq_prediction_intermediates']['rewriter_main']['encoder_seq_representation_with_tag'][0]
        _utter_encodings = _token_encodings[:_text_len]
        
        # _labels = [1 if _t == 'O-KEEP' else 0 for _t in _cand['rewriter_tags']]
        _labels = []
        for _t in _cand['rewriter_tags']:
            if _t == 'O-KEEP':
                _labels.append(0)
            elif _t.endswith('DEL'):
                _labels.append(1)
            elif _t.endswith('EDIT'):
                _labels.append(2)
            else:
                raise ValueError(_t)
        
        train_samples_A.extend(list(zip(_utter_encodings, _labels)))


for o_id, _insts in tqdm(test_o_id2instances.items(), total=len(test_o_id2instances)):
    _cands = full_dev_dataset[o_id]
#     _preds = o_id2preds[o_id]
    
    for _inst, _cand in zip(_insts, _cands):
        _pred = predictor.predict_instance(_inst)
        
        _metadata = _inst['metadata']
        _text_len = _metadata['text_len']
        _token_encodings = _pred['rewrite_seq_prediction_intermediates']['rewriter_main']['encoder_seq_representation_with_tag'][0]
        _utter_encodings = _token_encodings[:_text_len]
        
        # _labels = [1 if _t == 'O-KEEP' else 0 for _t in _cand['rewriter_tags']]
        _labels = []
        for _t in _cand['rewriter_tags']:
            if _t == 'O-KEEP':
                _labels.append(0)
            elif _t.endswith('DEL'):
                _labels.append(1)
            elif _t.endswith('EDIT'):
                _labels.append(2)
            else:
                raise ValueError(_t)
        
        test_samples_A.extend(list(zip(_utter_encodings, _labels)))

    

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=547), HTML(value='')))




In [33]:
len(train_samples_A), len(test_samples_A)

(85292, 41776)

#### Train probes

In [34]:
clf = LogisticRegression(tol=0.0001, C=100.0)

train_X = [s[0] for s in train_samples_A]
train_y = [s[1] for s in train_samples_A]
test_X = [s[0] for s in test_samples_A]
test_y = [s[1] for s in test_samples_A]

In [35]:
clf.fit(train_X, train_y)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression


LogisticRegression(C=100.0, class_weight=None, dual=False, fit_intercept=True,
                   intercept_scaling=1, l1_ratio=None, max_iter=100,
                   multi_class='auto', n_jobs=None, penalty='l2',
                   random_state=None, solver='lbfgs', tol=0.0001, verbose=0,
                   warm_start=False)

In [36]:
eval_results(clf, train_X, train_y)

Positive prop: 0.3480
Correct positives: 15972
All pred positives: 15972
All true positives: 15972
Accuracy: 1.0000
Precision: 1.0000
Recall: 1.0000
F1: 1.0000


In [37]:
eval_results(clf, test_X, test_y)

Positive prop: 0.3500
Correct positives: 7814
All pred positives: 7881
All true positives: 7886
Accuracy: 0.9983
Precision: 0.9915
Recall: 0.9909
F1: 0.9912


### Task B - token mentioning schema (utterence-only)
- getting labels by lemma match with schema toks
- Result (2.12.3.3):
    - pos% = 0.2636
    - F1 = 0.8105
- Result (2.18.2.2):
    - pos% = 0.2573
    - Train F1 = 0.8410
    - F1 = 0.7899
- Result (2.23.0.1) (added head):
    - pos% = 0.2573
    - Train F1 = 0.9792
    - F1 = 0.9223

In [38]:
# Gather data (X for schema toks (can use A), y for task-B)

train_labels_B = []
test_labels_B = []

stemmer = PorterStemmer()

for o_id, _insts in tqdm(train_o_id2instances.items(), total=len(train_o_id2instances)):
    _cands = full_train_dataset[o_id]
    
    for _inst, _cand in zip(_insts, _cands):
        _metadata = _inst['metadata']
#         _text_len = _metadata['text_len']
#         _token_encodings = _pred['rewrite_seq_prediction_intermediates']['rewriter_main']['encoder_seq_representation_with_tag'][0]
#         _utter_encodings = _token_encodings[:_text_len]
        
        _labels = []
        
        _utter_tokens_stem = [stemmer.stem(_t) for _t in _metadata['text_tokens']]
        _schema_tokens_stem = [stemmer.stem(_t) for _t in _metadata['schema_tokens'] if _t not in string.punctuation]
        for _ut in _utter_tokens_stem:
            if _ut in _schema_tokens_stem:
                _labels.append(1)
            else:
                _labels.append(0)
        
#         samples2.extend(list(zip(_utter_encodings, _labels)))
        train_labels_B.extend(_labels)

for o_id, _insts in tqdm(test_o_id2instances.items(), total=len(test_o_id2instances)):
    _cands = full_dev_dataset[o_id]
    
    for _inst, _cand in zip(_insts, _cands):
        _metadata = _inst['metadata']
#         _text_len = _metadata['text_len']
#         _token_encodings = _pred['rewrite_seq_prediction_intermediates']['rewriter_main']['encoder_seq_representation_with_tag'][0]
#         _utter_encodings = _token_encodings[:_text_len]
        
        _labels = []
        
        _utter_tokens_stem = [stemmer.stem(_t) for _t in _metadata['text_tokens']]
        _schema_tokens_stem = [stemmer.stem(_t) for _t in _metadata['schema_tokens'] if _t not in string.punctuation]
        for _ut in _utter_tokens_stem:
            if _ut in _schema_tokens_stem:
                _labels.append(1)
            else:
                _labels.append(0)
        
#         samples2.extend(list(zip(_utter_encodings, _labels)))
        test_labels_B.extend(_labels)

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=547), HTML(value='')))




In [39]:
len(train_labels_B), len(test_labels_B)

(85292, 41776)

In [40]:
# assert [s[0] for s in samples_B] == X

In [41]:
clf = LogisticRegression(tol=0.0001, C=100.0)

train_X = [s[0] for s in train_samples_A]
train_y = train_labels_B
test_X = [s[0] for s in test_samples_A]
test_y = test_labels_B

In [42]:
clf.fit(train_X, train_y)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression


LogisticRegression(C=100.0, class_weight=None, dual=False, fit_intercept=True,
                   intercept_scaling=1, l1_ratio=None, max_iter=100,
                   multi_class='auto', n_jobs=None, penalty='l2',
                   random_state=None, solver='lbfgs', tol=0.0001, verbose=0,
                   warm_start=False)

In [43]:
eval_results(clf, train_X, train_y)

Positive prop: 0.2573
Correct positives: 21420
All pred positives: 21810
All true positives: 21942
Accuracy: 0.9893
Precision: 0.9821
Recall: 0.9762
F1: 0.9792


In [44]:
eval_results(clf, test_X, test_y)

Positive prop: 0.2656
Correct positives: 9892
All pred positives: 10357
All true positives: 11094
Accuracy: 0.9601
Precision: 0.9551
Recall: 0.8917
F1: 0.9223


In [45]:
# Error analysis
err_output_path = f"/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SpeakQL/archive_results/B-token_mentioning_schema-errors-{MODEL_VER}.txt"

test_error_analysis_utter(clf, test_X, test_y, err_output_path=err_output_path)

1667


### Task C - schema directly mentioned (schema-only)
- check lemma match with tokens
- Result (2.12.3.3):
    - pos% = 0.1158
    - F1 = 0.2677
- Result (2.18.2.2):
    - pos% = 0.1113
    - Train F1 = 0.3463
    - F1 = 0.2369
- Result (2.23.1.4) (added head):
    - pos% = 0.1113
    - Train F1 = 0.8145
    - F1 = 0.7462

In [85]:
# Gather data (X for schema toks, y for task-C)

train_samples_C = []
test_samples_C = []

stemmer = PorterStemmer()

for o_id, _insts in tqdm(train_o_id2instances.items(), total=len(train_o_id2instances)):
    _cands = full_train_dataset[o_id]
    
    for _inst, _cand in zip(_insts, _cands):
        _pred = predictor.predict_instance(_inst)
        
        _metadata = _inst['metadata']
        _text_len = _metadata['text_len']
        _schema_len = _metadata['schema_len']
        _token_encodings = _pred['rewrite_seq_prediction_intermediates']['rewriter_main']['encoder_seq_representation_with_tag'][0]
        _schema_encodings = _token_encodings[_text_len + 1 : _text_len + 1 + _schema_len]
        
        _labels = []
        
        _utter_tokens_stem = [stemmer.stem(_t) for _t in _metadata['text_tokens']]
        _schema_tokens_stem = [stemmer.stem(_t) for _t in _metadata['schema_tokens']]
        for _ut in _schema_tokens_stem:
            if _ut in string.punctuation:
                _labels.append(0)
            elif _ut in _utter_tokens_stem:
                _labels.append(1)
            else:
                _labels.append(0)
        
        train_samples_C.extend(list(zip(_schema_encodings, _labels)))

for o_id, _insts in tqdm(test_o_id2instances.items(), total=len(test_o_id2instances)):
    _cands = full_dev_dataset[o_id]
    
    for _inst, _cand in zip(_insts, _cands):
        _pred = predictor.predict_instance(_inst)
        
        _metadata = _inst['metadata']
        _text_len = _metadata['text_len']
        _schema_len = _metadata['schema_len']
        _token_encodings = _pred['rewrite_seq_prediction_intermediates']['rewriter_main']['encoder_seq_representation_with_tag'][0]
        _schema_encodings = _token_encodings[_text_len + 1 : _text_len + 1 + _schema_len]
        
        _labels = []
        
        _utter_tokens_stem = [stemmer.stem(_t) for _t in _metadata['text_tokens']]
        _schema_tokens_stem = [stemmer.stem(_t) for _t in _metadata['schema_tokens']]
        for _ut in _schema_tokens_stem:
            if _ut in string.punctuation:
                _labels.append(0)
            elif _ut in _utter_tokens_stem:
                _labels.append(1)
            else:
                _labels.append(0)
        
        test_samples_C.extend(list(zip(_schema_encodings, _labels)))
        

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=547), HTML(value='')))




In [86]:
len(train_samples_C), len(test_samples_C)

(593194, 187337)

In [74]:
clf = LogisticRegression(tol=0.0001, C=100.0)

train_X = [s[0] for s in train_samples_C]
train_y = [s[1] for s in train_samples_C]
test_X = [s[0] for s in test_samples_C]
test_y = [s[1] for s in test_samples_C]

In [75]:
clf.fit(train_X, train_y)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression


LogisticRegression(C=100.0, class_weight=None, dual=False, fit_intercept=True,
                   intercept_scaling=1, l1_ratio=None, max_iter=100,
                   multi_class='auto', n_jobs=None, penalty='l2',
                   random_state=None, solver='lbfgs', tol=0.0001, verbose=0,
                   warm_start=False)

In [76]:
print('-- Train --')
eval_results(clf, train_X, train_y)
print('-- Test --')
eval_results(clf, test_X, test_y)

-- Train --
Positive prop: 0.1113
Correct positives: 51256
All pred positives: 59818
All true positives: 66041
Accuracy: 0.9606
Precision: 0.8569
Recall: 0.7761
F1: 0.8145
-- Test --
Positive prop: 0.1342
Correct positives: 19635
All pred positives: 27490
All true positives: 25138
Accuracy: 0.9287
Precision: 0.7143
Recall: 0.7811
F1: 0.7462


In [77]:
# Error analysis
err_output_path = f"/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SpeakQL/archive_results/C-schema_directly_mentioned-errors-{MODEL_VER}.txt"

test_error_analysis_schema(clf, test_X, test_y, err_output_path=err_output_path)

13358


### Task D - schema implicitly mentioned (schema-only)
- check gold SQL
- Result (2.12.3.3):
    - pos% = 0.0740
    - F1 = 0.0321
- Result (2.18.2.2):
    - pos% = 0.0713
    - Train F1 = 0.0848
    - F1 = 0.0916
- Result (2.23.2.2) (added head):
    - pos% = 0.0713
    - Train F1 = 0.5722
    - F1 = 0.4628

In [87]:
# Gather data (X for schema toks (can use C), y for task-D)

# samples_D = []
train_labels_D = []
test_labels_D = []

stemmer = PorterStemmer()

for o_id, _insts in tqdm(train_o_id2instances.items(), total=len(train_o_id2instances)):
    _cands = full_train_dataset[o_id]
    
    for _inst, _cand in zip(_insts, _cands):
        _pred = predictor.predict_instance(_inst)
        
        _metadata = _inst['metadata']

        schema_id2names = defaultdict(str)
        _tmp_ids = []
        _tmp_toks = []
        for i, _tok in enumerate(_metadata['schema_tokens']):
            if _tok in ',.:':
                # end of name 
                _name = '_'.join(_tmp_toks)
                for _idx in _tmp_ids:
                    schema_id2names[_idx] = _name
                
                _tmp_ids = []
                _tmp_toks = []
            else:
                _tmp_ids.append(i)
                _tmp_toks.append(_tok)
                
        # assert _tmp_ids == _tmp_toks == []    # Might not be true due to truncating 
        
        sql_schema_names = []
        for i, q_tok in enumerate(_cand['query_toks']):
            _toks = re.split(r'^[Tt]\d+\.', q_tok)
            if len(_toks) > 1:
                assert len(_toks) == 2 and _toks[0] == '', q_tok
            for _tok in _toks:
                if _tok.isupper() or _tok in string.punctuation:
                    continue
                assert ' ' not in _tok, q_tok
                sql_schema_names.append(_tok.lower()) # can have '_' in name 
        
        _labels = []
        for i in range(len(_metadata['schema_tokens'])):
            if schema_id2names[i] in sql_schema_names:
                _labels.append(1)
            else:
                _labels.append(0)
        
#         samples_D.extend(list(zip(_schema_encodings, _labels)))
        train_labels_D.extend(_labels)

for o_id, _insts in tqdm(test_o_id2instances.items(), total=len(test_o_id2instances)):
    _cands = full_dev_dataset[o_id]
    
    for _inst, _cand in zip(_insts, _cands):
        _pred = predictor.predict_instance(_inst)
        
        _metadata = _inst['metadata']

        schema_id2names = defaultdict(str)
        _tmp_ids = []
        _tmp_toks = []
        for i, _tok in enumerate(_metadata['schema_tokens']):
            if _tok in ',.:':
                # end of name 
                _name = '_'.join(_tmp_toks)
                for _idx in _tmp_ids:
                    schema_id2names[_idx] = _name
                
                _tmp_ids = []
                _tmp_toks = []
            else:
                _tmp_ids.append(i)
                _tmp_toks.append(_tok)

        # assert _tmp_ids == _tmp_toks == []    # Might not be true due to truncating 
        
        sql_schema_names = []
        for i, q_tok in enumerate(_cand['query_toks']):
            _toks = re.split(r'^[Tt]\d+\.', q_tok)
            if len(_toks) > 1:
                assert len(_toks) == 2 and _toks[0] == '', q_tok
            for _tok in _toks:
                if _tok.isupper() or _tok in string.punctuation:
                    continue
                assert ' ' not in _tok, q_tok
                sql_schema_names.append(_tok.lower()) # can have '_' in name 
        
        _labels = []
        for i in range(len(_metadata['schema_tokens'])):
            if schema_id2names[i] in sql_schema_names:
                _labels.append(1)
            else:
                _labels.append(0)
        
#         samples_D.extend(list(zip(_schema_encodings, _labels)))
        test_labels_D.extend(_labels)

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=547), HTML(value='')))




In [88]:
# _cand['query_toks']

In [89]:
len(train_labels_D), len(test_labels_D)

(593194, 187337)

In [90]:
clf = LogisticRegression(tol=0.0001, C=100.0)

train_X = [s[0] for s in train_samples_C]
train_y = train_labels_D
test_X = [s[0] for s in test_samples_C]
test_y = test_labels_D

In [91]:
clf.fit(train_X, train_y)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression


LogisticRegression(C=100.0, class_weight=None, dual=False, fit_intercept=True,
                   intercept_scaling=1, l1_ratio=None, max_iter=100,
                   multi_class='auto', n_jobs=None, penalty='l2',
                   random_state=None, solver='lbfgs', tol=0.0001, verbose=0,
                   warm_start=False)

In [92]:
print('-- Train --')
eval_results(clf, train_X, train_y)
print('-- Test --')
eval_results(clf, test_X, test_y)

-- Train --
Positive prop: 0.0713
Correct positives: 19621
All pred positives: 26274
All true positives: 42311
Accuracy: 0.9505
Precision: 0.7468
Recall: 0.4637
F1: 0.5722
-- Test --
Positive prop: 0.0888
Correct positives: 7344
All pred positives: 15109
All true positives: 16630
Accuracy: 0.9090
Precision: 0.4861
Recall: 0.4416
F1: 0.4628


In [93]:
# Error analysis
err_output_path = f"/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SpeakQL/archive_results/D-schema_indirectly_mentioned-errors-{MODEL_VER}.txt"

test_error_analysis_schema(clf, test_X, test_y, err_output_path=err_output_path)

17051


In [60]:
# ## Old

# corr_pos = 0
# all_pred_pos = 0
# all_true_pos = 0
# for _pred, _y in zip(preds, y):
#     corr_pos += _pred * _y
#     all_pred_pos += _pred
#     all_true_pos += _y

# prec = corr_pos / all_pred_pos
# rec = corr_pos / all_true_pos
# f1 = 2 * prec * rec / (prec + rec + 1e-9)

# print(f'Correct positives: {corr_pos}')
# print(f'All pred positives: {all_pred_pos}')
# print(f'All true positives: {all_true_pos}')
# print(f'Precision: {prec:.4f}')
# print(f'Recall: {rec:.4f}')
# print(f'F1: {f1:.4f}')

### Task E - phoneme existance (multilabel) (utterence-only)
- Result (2.12.3.1):
    - Pos% (averaged over phonemes with pos_prop > 0.05, same below): 0.1011
    - Train F1: 0.0193
    - Test F1: 0.0221
    - For most phonemes, F1 = 0 (even with pos_prop > 0.1, such as N, S, AH0)
- Result (2.18.2.2):
    - Pos%: 0.0985
    - Train F1: 0.2282
    - Test F1: 0.2113
    - Samples are different because dataset was reloaded; not a big concern, because sample size is large, and test set is identical. If having time, can rerun 2.12.3.1 with the given fixed seed

In [61]:
# PHONEME_VOCAB_NAMESPACE = "phonemes"
phonemes_path = "/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SpeakQL/SpeakQL/Allennlp_models/runs/2.18.2.0i/vocabulary/phonemes.txt"
with open(phonemes_path, 'r') as f:
    phonemes = f.read().strip().split('\n')
len(phonemes), phonemes[:10]

(71, ['@@UNKNOWN@@', 'sil', 'T', 'D', 'N', 'AH0', 'S', 'R', 'K', 'M'])

In [62]:
# Gather data (X audio for utter toks, y for task-E per phoneme)

train_feats_E = []
test_feats_E = []
train_labels_E_per_ph = dict([(_ph, []) for _ph in phonemes])
test_labels_E_per_ph = dict([(_ph, []) for _ph in phonemes])

for o_id, _insts in tqdm(train_o_id2instances.items(), total=len(train_o_id2instances)):
    _cands = full_train_dataset[o_id]
    
    for _inst, _cand in zip(_insts, _cands):
        _pred = predictor.predict_instance(_inst)
        
        _metadata = _inst['metadata']
        _text_len = _metadata['text_len']
        _audio_encodings = _pred['rewrite_seq_prediction_intermediates']['rewriter_main']['audio_feats_encoded'][0]
#         _audio_encodings = _pred['rewrite_seq_prediction_intermediates']['rewriter_main']['encoder_seq_representation_with_tag'][0]
        _utter_audio_encodings = _audio_encodings[:_text_len]
        
        train_feats_E.extend(_utter_audio_encodings)
        
        assert len(_cand['token_phonemes']) == len(_utter_audio_encodings) == _text_len, \
            (_cand['token_phonemes'], len(_utter_audio_encodings), _text_len)
        for _tok_phs in _cand['token_phonemes']:
            _ph_exists = defaultdict(lambda: False) # Dict[str: ph, bool]
            
            if _tok_phs is not None:
                for _ph in _tok_phs:
                    if _ph not in phonemes:
                        _ph_exists['@@UNKNOWN@@'] = True
                    else:
                        _ph_exists[_ph] = True
            # if _tok_phs is None, empty set as default 
            for _ph in phonemes:
                train_labels_E_per_ph[_ph].append(int(_ph_exists[_ph]))
    
for o_id, _insts in tqdm(test_o_id2instances.items(), total=len(test_o_id2instances)):
    _cands = full_dev_dataset[o_id]
    
    for _inst, _cand in zip(_insts, _cands):
        _pred = predictor.predict_instance(_inst)
        
        _metadata = _inst['metadata']
        _text_len = _metadata['text_len']
        _audio_encodings = _pred['rewrite_seq_prediction_intermediates']['rewriter_main']['audio_feats_encoded'][0]
#         _audio_encodings = _pred['rewrite_seq_prediction_intermediates']['rewriter_main']['encoder_seq_representation_with_tag'][0]
        _utter_audio_encodings = _audio_encodings[:_text_len]
        
        test_feats_E.extend(_utter_audio_encodings)
        
        assert len(_cand['token_phonemes']) == len(_utter_audio_encodings) == _text_len, \
            (_cand['token_phonemes'], len(_utter_audio_encodings), _text_len)
        for _tok_phs in _cand['token_phonemes']:
            _ph_exists = defaultdict(lambda: False) # Dict[str: ph, bool]
            
            if _tok_phs is not None:
                for _ph in _tok_phs:
                    if _ph not in phonemes:
                        _ph_exists['@@UNKNOWN@@'] = True
                    else:
                        _ph_exists[_ph] = True
            # if _tok_phs is None, empty set as default 
            for _ph in phonemes:
                test_labels_E_per_ph[_ph].append(int(_ph_exists[_ph]))


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=547), HTML(value='')))




In [63]:
len(train_feats_E), len(test_feats_E)

(85292, 41776)

In [None]:
train_res_dicts = dict()  # Dict[str:ph, Dict:results]
test_res_dicts = dict()

for _ph in tqdm(phonemes):
    if _ph in {'@@UNKNOWN@@', 'sil', '[NONE]', ''}:
        # not an actual phoneme 
        continue
    
    print(f'Phoneme: {_ph}')
    clf = LogisticRegression(tol=0.0001, C=100.0)

    train_X = train_feats_E
    train_y = train_labels_E_per_ph[_ph]
    test_X = test_feats_E
    test_y = test_labels_E_per_ph[_ph]
    
    if sum(train_y) == 0 or sum(test_y) == 0:
        # phoneme not exist in dataset 
        print(f'{_ph} not in dataset')
        print()
        continue
    
    clf.fit(train_X, train_y)
    
    _train_d = dict()
    _test_d = dict()
    eval_results(clf, train_X, train_y, res_dict=_train_d)
    eval_results(clf, test_X, test_y, res_dict=_test_d)
    train_res_dicts[_ph] = _train_d
    test_res_dicts[_ph] = _test_d
    print()

In [65]:
print(np.mean([train_res_dicts[_ph]['pos_prop'] for _ph in train_res_dicts if train_res_dicts[_ph]['pos_prop'] > 0.05]))
print(np.mean([train_res_dicts[_ph]['f1'] for _ph in train_res_dicts if train_res_dicts[_ph]['pos_prop'] > 0.05]))
print(np.mean([test_res_dicts[_ph]['f1'] for _ph in train_res_dicts if train_res_dicts[_ph]['pos_prop'] > 0.05]))

0.09845840507567281
0.2281778533661335
0.21130471085969618


In [66]:
_included_phs = [_ph for _ph in train_res_dicts if train_res_dicts[_ph]['pos_prop'] > 0.05]
[(_ph, train_res_dicts[_ph]['pos_prop'], test_res_dicts[_ph]['f1']) for _ph in _included_phs]

[('T', 0.20149603714299114, 0.37116764472342895),
 ('D', 0.11371523706795479, 0.14226489184585123),
 ('N', 0.1979435351498382, 0.24902939514159889),
 ('AH0', 0.14976785630539793, 0.07930525952786863),
 ('S', 0.14304975847676218, 0.34475045388698267),
 ('R', 0.12463068048586033, 0.3836438265586174),
 ('K', 0.08960981100220419, 0.2250584978536747),
 ('M', 0.12391549031562163, 0.16122531215621105),
 ('L', 0.10229564320217606, 0.444889110780385),
 ('P', 0.057062796041832765, 0.05973025034868963),
 ('EY1', 0.06892791821038316, 0.1813530411299053),
 ('IH0', 0.09059466303990996, 0.0),
 ('ER0', 0.0821179008582282, 0.1391527597244173),
 ('Z', 0.11481733339586363, 0.30796938271590174),
 ('AE1', 0.06946724194531727, 0.0),
 ('IH1', 0.0559138019978427, 0.14818355622460053),
 ('AH1', 0.06053322703184355, 0.37074100509132196),
 ('IY0', 0.05875111382075693, 0.20148588374510407),
 ('EH1', 0.05414341321577639, 0.0),
 ('F', 0.05348684519063922, 0.2003159555782784),
 ('W', 0.055386202691928905, 0.42713270

In [None]:
[sum(_f) for _f in train_feats_E]

In [None]:
# Debug checking
o_id

In [None]:
full_train_dataset[6712][0]['token_phonemes']

In [None]:
train_labels_E_per_ph['D']

In [None]:
len(train_feats_E[0])

In [None]:
clf.predict_log_proba([train_feats_E[9]])

In [None]:
clf = LogisticRegression(tol=0.0001, C=100.0, class_weight='balanced')

train_X = train_feats_E
train_y = train_labels_E_per_ph['T']
test_X = test_feats_E
test_y = test_labels_E_per_ph['T']

clf.fit(train_X, train_y)

In [None]:
eval_results(clf, train_X, train_y)

### Task R1 - decoder awareness of corresponding input text phonemes

In [82]:
# PHONEME_VOCAB_NAMESPACE = "phonemes"
phonemes_path = "/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SpeakQL/SpeakQL/Allennlp_models/runs/2.29.0.0i/vocabulary/phonemes.txt"
with open(phonemes_path, 'r') as f:
    phonemes = f.read().strip().split('\n')
len(phonemes), phonemes[:10]

(50, ['@@UNKNOWN@@', 'T', 'AH', 'D', 'N', 'IH', 'S', '[NONE]', 'R', 'K'])

In [83]:
# _decoder_intermediates[f'state:step_0'].keys()

In [84]:
## Unifying train & test loading as a function, to avoid copying back and forth 
## TODO: add 'phonemes' as argument 
def Load_probing_data_R1(predictor, ds_instances, ds_original, dataset_reader):
    '''
    example:
    Load_probing_data_R1(predictor, train_dataset, full_train_dataset, train_dataset_reader)
    Load_probing_data_R1(predictor, test_dataset, full_dev_dataset, test_dataset_reader)
    '''
    
    feats_R1 = []
    labels_R1_per_ph = dict([(_ph, []) for _ph in phonemes])
    # debug checking 
    tokens_R1 = []
    
    # Gather instances by o_id 
    o_id2instances = defaultdict(list)
    for _inst in ds_instances:
        o_id = _inst['metadata']['original_id']
        o_id2instances[o_id].append(_inst)
        
    for o_id, _insts in tqdm(o_id2instances.items(), total=len(o_id2instances)):
        _cands = ds_original[o_id]

        for _inst, _cand in zip(_insts, _cands):
            _pred = predictor.predict_instance(_inst)

            _metadata = _inst['metadata']
            _target_toks = [str(t) for t in _metadata['target_tokens']] + [END_SYMBOL]
            _decoder_intermediates = _pred['rewrite_seq_prediction_intermediates']['decoder']
            # len(_decoder_hidden_states) == len(_target_toks) from metadata; including the last step predicting @end@ 
            # [0] to get rid of batch dim 
            _decoder_hidden_states = [_decoder_intermediates[f'state_decoder_hidden:step_{i}'][0] for i in range(len(_target_toks))]

            feats_R1.extend(_decoder_hidden_states)
            tokens_R1.extend(_target_toks)

            # List[List[int]], the ids of each span in target sequence 
            _target_spans = []
            _curr_span = []
            for i, t in enumerate(_target_toks):
                _curr_span.append(i)
                if t in {'[ANS]', END_SYMBOL}:
                    _target_spans.append(_curr_span)
                    _curr_span = []

            assert len(_curr_span) == 0, _target_toks
            ## _target_spans include trailing END_SYMBOL as a span, so 1 more 
            assert len(_target_spans) == len(_cand['rewriter_edits']) + 1, f"{_target_toks} || {_cand['rewriter_edits']}"

            for e_i, edit_d in enumerate(_cand['rewriter_edits']):
                q_span = ' '.join([_cand['question_toks'][i].lower() for i in edit_d['src_span']])
                q_span_phs = [ph for q_tok in text_cell_to_toks(q_span) for ph in dataset_reader._token_to_phonemes(q_tok)]

                _ph_exists = defaultdict(lambda: False) # Dict[str: ph, bool]
                for _ph in q_span_phs:
                    if _ph not in phonemes:
                        _ph_exists['@@UNKNOWN@@'] = True
                    else:
                        _ph_exists[_ph] = True

                for _ph in phonemes:
                    # for each target step in the target span correspond to this src span, add these ph labels 
                    labels_R1_per_ph[_ph].extend([int(_ph_exists[_ph])] * len(_target_spans[e_i]))
            # for the last step END_SYMBOL span
            for _ph in phonemes:
                labels_R1_per_ph[_ph].extend([0])

    return {
        "feats_R1": feats_R1,
        "labels_R1_per_ph": labels_R1_per_ph,
        "tokens_R1": tokens_R1,
    }


In [85]:
train_probing_data_R1 = Load_probing_data_R1(predictor, train_dataset, full_train_dataset, train_dataset_reader)
# train_probing_data_R1.keys()

HBox(children=(IntProgress(value=0, max=7000), HTML(value='')))




In [86]:
test_probing_data_R1 = Load_probing_data_R1(predictor, test_dataset, full_dev_dataset, test_dataset_reader)

HBox(children=(IntProgress(value=0, max=547), HTML(value='')))




In [87]:
train_feats_R1 = train_probing_data_R1['feats_R1']
train_tokens_R1 = train_probing_data_R1['tokens_R1']
train_labels_R1_per_ph = train_probing_data_R1['labels_R1_per_ph']
test_feats_R1 = test_probing_data_R1['feats_R1']
test_tokens_R1 = test_probing_data_R1['tokens_R1']
test_labels_R1_per_ph = test_probing_data_R1['labels_R1_per_ph']

In [88]:
np.array(train_feats_R1).shape, len(train_tokens_R1), len(train_labels_R1_per_ph['S']), \
np.array(test_feats_R1).shape, len(test_tokens_R1), len(test_labels_R1_per_ph['S'])


((27629, 256), 27629, 27629, (4469, 256), 4469, 4469)

In [None]:
## Logistic Regression probes 

train_res_dicts = dict()  # Dict[str:ph, Dict:results]
test_res_dicts = dict()

for _ph in tqdm(phonemes):
    if _ph in {'@@UNKNOWN@@', 'sil', '[NONE]', ''}:
        # not an actual phoneme 
        continue
    
    print(f'Phoneme: {_ph}')
    clf = LogisticRegression(tol=0.0001, C=100.0)

    train_X = train_feats_R1
    train_y = train_labels_R1_per_ph[_ph]
    test_X = test_feats_R1
    test_y = test_labels_R1_per_ph[_ph]
    
    if sum(train_y) == 0 or sum(test_y) == 0:
        # phoneme not exist in dataset 
        print(f'{_ph} not in dataset')
        print()
        continue
    
    clf.fit(train_X, train_y)
    
    _train_d = dict()
    _test_d = dict()
    eval_results(clf, train_X, train_y, res_dict=_train_d)
    eval_results(clf, test_X, test_y, res_dict=_test_d)
    train_res_dicts[_ph] = _train_d
    test_res_dicts[_ph] = _test_d
    print()

In [97]:
## Save probing results 

analysis_dir = '/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SpeakQL/SpeakQL/Allennlp_models/analysis_results'
train_save_path = os.path.join(analysis_dir, 'probing_R1_results_train.csv')
test_save_path = os.path.join(analysis_dir, 'probing_R1_results_test.csv')
pd.DataFrame.from_dict(train_res_dicts, orient='index').to_csv(train_save_path)
pd.DataFrame.from_dict(test_res_dicts, orient='index').to_csv(test_save_path)

In [100]:
# del train_probing_data_R1
# del test_probing_data_R1
# del train_feats_R1
# del train_tokens_R1
# del train_labels_R1_per_ph
# del test_feats_R1
# del test_tokens_R1
# del test_labels_R1_per_ph

## Using del seems useless...

#### Task R1a - compare of R1: probing the phoneme tag embeddings
- Should have ~100% accuracy!

In [121]:
## Unifying train & test loading as a function, to avoid copying back and forth 
## TODO: add 'phonemes' as argument 
## R1_compare_a: use phoneme embeddings a token to probe its phonemes on token level 
## tokens still from edit spans 
def Load_probing_data_R1_compare_a(predictor, ds_instances, ds_original, dataset_reader):
    feats_R1 = []
    labels_R1_per_ph = dict([(_ph, []) for _ph in phonemes])
    # debug checking 
    tokens_R1 = []
    
    # Gather instances by o_id 
    o_id2instances = defaultdict(list)
    for _inst in ds_instances:
        o_id = _inst['metadata']['original_id']
        o_id2instances[o_id].append(_inst)
        
    for o_id, _insts in tqdm(o_id2instances.items(), total=len(o_id2instances)):
        _cands = ds_original[o_id]

        for _inst, _cand in zip(_insts, _cands):
            _pred = predictor.predict_instance(_inst)

            _metadata = _inst['metadata']
            _source_toks = [str(t) for t in _metadata['source_tokens']]  # concat tokens, but text in front 
            _intermediates = _pred['rewrite_seq_prediction_intermediates']['rewriter_main']['encoder_seq_representation_with_tag'][0]

            for e_i, edit_d in enumerate(_cand['rewriter_edits']):
                for i in edit_d['src_span']:
                    feats_R1.append(_intermediates[i])
                    tokens_R1.append(_source_toks[i])

                    _tok = _source_toks[i]
                    _tok_phs = [ph for _t in text_cell_to_toks(_tok) for ph in dataset_reader._token_to_phonemes(_t)]

                    _ph_exists = defaultdict(lambda: False) # Dict[str: ph, bool]
                    for _ph in _tok_phs:
                        if _ph not in phonemes:
                            _ph_exists['@@UNKNOWN@@'] = True
                        else:
                            _ph_exists[_ph] = True

                    for _ph in phonemes:
                        # for each target step in the target span correspond to this src span, add these ph labels 
                        labels_R1_per_ph[_ph].append(int(_ph_exists[_ph]))

    return {
        "feats_R1a": feats_R1,
        "labels_R1a_per_ph": labels_R1_per_ph,
        "tokens_R1a": tokens_R1,
    }


In [122]:
train_probing_data_R1a = Load_probing_data_R1_compare_a(predictor, train_dataset, full_train_dataset, train_dataset_reader)


HBox(children=(IntProgress(value=0, max=7000), HTML(value='')))




In [123]:
test_probing_data_R1a = Load_probing_data_R1_compare_a(predictor, test_dataset, full_dev_dataset, test_dataset_reader)


HBox(children=(IntProgress(value=0, max=547), HTML(value='')))




In [125]:
train_feats_R1a = train_probing_data_R1a['feats_R1a']
train_tokens_R1a = train_probing_data_R1a['tokens_R1a']
train_labels_R1a_per_ph = train_probing_data_R1a['labels_R1a_per_ph']
test_feats_R1a = test_probing_data_R1a['feats_R1a']
test_tokens_R1a = test_probing_data_R1a['tokens_R1a']
test_labels_R1a_per_ph = test_probing_data_R1a['labels_R1a_per_ph']

In [126]:
np.array(train_feats_R1a).shape, len(train_tokens_R1a), len(train_labels_R1a_per_ph['S']), \
np.array(test_feats_R1a).shape, len(test_tokens_R1a), len(test_labels_R1a_per_ph['S'])


((10661, 256), 10661, 10661, (1744, 256), 1744, 1744)

In [None]:
## Logistic Regression probes 

train_res_dicts = dict()  # Dict[str:ph, Dict:results]
test_res_dicts = dict()

for _ph in tqdm(phonemes):
    if _ph in {'@@UNKNOWN@@', 'sil', '[NONE]', ''}:
        # not an actual phoneme 
        continue
    
    print(f'Phoneme: {_ph}')
    clf = LogisticRegression(tol=0.0001, C=100.0)

    train_X = train_feats_R1a
    train_y = train_labels_R1a_per_ph[_ph]
    test_X = test_feats_R1a
    test_y = test_labels_R1a_per_ph[_ph]
    
    if sum(train_y) == 0 or sum(test_y) == 0:
        # phoneme not exist in dataset 
        print(f'{_ph} not in dataset')
        print()
        continue
    
    clf.fit(train_X, train_y)
    
    _train_d = dict()
    _test_d = dict()
    eval_results(clf, train_X, train_y, res_dict=_train_d)
    eval_results(clf, test_X, test_y, res_dict=_test_d)
    train_res_dicts[_ph] = _train_d
    test_res_dicts[_ph] = _test_d
    print()

In [128]:
## Save probing results 

analysis_dir = '/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SpeakQL/SpeakQL/Allennlp_models/analysis_results'
train_save_path = os.path.join(analysis_dir, 'probing_R1a_results_train.csv')
test_save_path = os.path.join(analysis_dir, 'probing_R1a_results_test.csv')
pd.DataFrame.from_dict(train_res_dicts, orient='index').to_csv(train_save_path)
pd.DataFrame.from_dict(test_res_dicts, orient='index').to_csv(test_save_path)

#### Task R1b - compare of R1: probing the encoder representations

In [130]:
## Unifying train & test loading as a function, to avoid copying back and forth 
## TODO: add 'phonemes' as argument 
## R1_compare_b: use encoder representation of a token to probe its edit span phonemes 
def Load_probing_data_R1_compare_b(predictor, ds_instances, ds_original, dataset_reader):
    
    feats_R1 = []
    labels_R1_per_ph = dict([(_ph, []) for _ph in phonemes])
    # debug checking 
    tokens_R1 = []
    
    # Gather instances by o_id 
    o_id2instances = defaultdict(list)
    for _inst in ds_instances:
        o_id = _inst['metadata']['original_id']
        o_id2instances[o_id].append(_inst)
        
    for o_id, _insts in tqdm(o_id2instances.items(), total=len(o_id2instances)):
        _cands = ds_original[o_id]

        for _inst, _cand in zip(_insts, _cands):
            _pred = predictor.predict_instance(_inst)

            _metadata = _inst['metadata']
            _source_toks = [str(t) for t in _metadata['source_tokens']] + [END_SYMBOL]
            _intermediates = _pred['rewrite_seq_prediction_intermediates']['rewriter_main']['encoder_seq_representation_with_tag'][0]
#             feats_R1.extend(_intermediates)
#             tokens_R1.extend(_source_toks)

            for e_i, edit_d in enumerate(_cand['rewriter_edits']):
                q_span = ' '.join([_cand['question_toks'][i].lower() for i in edit_d['src_span']])
                q_span_phs = [ph for q_tok in text_cell_to_toks(q_span) for ph in dataset_reader._token_to_phonemes(q_tok)]

                feats_R1.extend([_intermediates[i] for i in edit_d['src_span']])
                tokens_R1.extend([_source_toks[i] for i in edit_d['src_span']])
                
                _ph_exists = defaultdict(lambda: False) # Dict[str: ph, bool]
                for _ph in q_span_phs:
                    if _ph not in phonemes:
                        _ph_exists['@@UNKNOWN@@'] = True
                    else:
                        _ph_exists[_ph] = True

                for _ph in phonemes:
                    # for each target step in the target span correspond to this src span, add these ph labels 
                    labels_R1_per_ph[_ph].extend([int(_ph_exists[_ph])] * len(edit_d['src_span']))

    return {
        "feats_R1b": feats_R1,
        "labels_R1b_per_ph": labels_R1_per_ph,
        "tokens_R1b": tokens_R1,
    }


In [131]:
train_probing_data_R1b = Load_probing_data_R1_compare_b(predictor, train_dataset, full_train_dataset, train_dataset_reader)


HBox(children=(IntProgress(value=0, max=7000), HTML(value='')))




In [132]:
test_probing_data_R1b = Load_probing_data_R1_compare_b(predictor, test_dataset, full_dev_dataset, test_dataset_reader)


HBox(children=(IntProgress(value=0, max=547), HTML(value='')))




In [134]:
train_feats_R1b = train_probing_data_R1b['feats_R1b']
train_tokens_R1b = train_probing_data_R1b['tokens_R1b']
train_labels_R1b_per_ph = train_probing_data_R1b['labels_R1b_per_ph']
test_feats_R1b = test_probing_data_R1b['feats_R1b']
test_tokens_R1b = test_probing_data_R1b['tokens_R1b']
test_labels_R1b_per_ph = test_probing_data_R1b['labels_R1b_per_ph']

In [135]:
np.array(train_feats_R1b).shape, len(train_tokens_R1b), len(train_labels_R1b_per_ph['S']), \
np.array(test_feats_R1b).shape, len(test_tokens_R1b), len(test_labels_R1b_per_ph['S'])


((10661, 256), 10661, 10661, (1744, 256), 1744, 1744)

In [None]:
## Logistic Regression probes 

train_res_dicts = dict()  # Dict[str:ph, Dict:results]
test_res_dicts = dict()

for _ph in tqdm(phonemes):
    if _ph in {'@@UNKNOWN@@', 'sil', '[NONE]', ''}:
        # not an actual phoneme 
        continue
    
    print(f'Phoneme: {_ph}')
    clf = LogisticRegression(tol=0.0001, C=100.0)

    train_X = train_feats_R1b
    train_y = train_labels_R1b_per_ph[_ph]
    test_X = test_feats_R1b
    test_y = test_labels_R1b_per_ph[_ph]
    
    if sum(train_y) == 0 or sum(test_y) == 0:
        # phoneme not exist in dataset 
        print(f'{_ph} not in dataset')
        print()
        continue
    
    clf.fit(train_X, train_y)
    
    _train_d = dict()
    _test_d = dict()
    eval_results(clf, train_X, train_y, res_dict=_train_d)
    eval_results(clf, test_X, test_y, res_dict=_test_d)
    train_res_dicts[_ph] = _train_d
    test_res_dicts[_ph] = _test_d
    print()

In [137]:
## Save probing results 

analysis_dir = '/Users/mac/Desktop/syt/Deep-Learning/Projects-M/SpeakQL/SpeakQL/Allennlp_models/analysis_results'
train_save_path = os.path.join(analysis_dir, 'probing_R1b_results_train.csv')
test_save_path = os.path.join(analysis_dir, 'probing_R1b_results_test.csv')
pd.DataFrame.from_dict(train_res_dicts, orient='index').to_csv(train_save_path)
pd.DataFrame.from_dict(test_res_dicts, orient='index').to_csv(test_save_path)

### Temp

In [None]:
_sen = 'What is the id of students of grade 5 .'
[_t.lemma_ for _t in nlp(_sen)]

In [None]:
stemmer = PorterStemmer()
[stemmer.stem(_t) for _t in _sen.split(' ')]

#### Legacy code

In [38]:
# Gather data (X decoder hidden states, y corresponding source span text)

train_feats_R1 = []
train_labels_R1_per_ph = dict([(_ph, []) for _ph in phonemes])

# debug checking 
train_tokens_R1 = []

for o_id, _insts in tqdm(train_o_id2instances.items(), total=len(train_o_id2instances)):
    _cands = full_train_dataset[o_id]
    
    for _inst, _cand in zip(_insts, _cands):
        _pred = predictor.predict_instance(_inst)
        
        _metadata = _inst['metadata']
        _target_toks = [str(t) for t in _metadata['target_tokens']] + [END_SYMBOL]
        _decoder_intermediates = _pred['rewrite_seq_prediction_intermediates']['decoder']
        # len(_decoder_hidden_states) == len(_target_toks) from metadata; including the last step predicting @end@ 
        _decoder_hidden_states = [_decoder_intermediates[f'state_decoder_hidden:step_{i}'] for i in range(len(_target_toks))]
        
        train_feats_R1.extend(_decoder_hidden_states)
        train_tokens_R1.extend(_target_toks)
        
        # List[List[int]], the ids of each span in target sequence 
        _target_spans = []
        _curr_span = []
        for i, t in enumerate(_target_toks):
            _curr_span.append(i)
            if t in {'[ANS]', END_SYMBOL}:
                _target_spans.append(_curr_span)
                _curr_span = []
                
        assert len(_curr_span) == 0, _target_toks
        ## _target_spans include trailing END_SYMBOL as a span, so 1 more 
        assert len(_target_spans) == len(_cand['rewriter_edits']) + 1, f"{_target_toks} || {_cand['rewriter_edits']}"
        
        for e_i, edit_d in enumerate(_cand['rewriter_edits']):
            q_span = ' '.join([_cand['question_toks'][i].lower() for i in edit_d['src_span']])
            q_span_phs = [ph for q_tok in text_cell_to_toks(q_span) for ph in train_dataset_reader._token_to_phonemes(q_tok)]
            
            ### TODO: add ph labels correspond to target span 
            _ph_exists = defaultdict(lambda: False) # Dict[str: ph, bool]
            for _ph in q_span_phs:
                if _ph not in phonemes:
                    _ph_exists['@@UNKNOWN@@'] = True
                else:
                    _ph_exists[_ph] = True
            
            for _ph in phonemes:
                # for each target step in the target span correspond to this src span, add these ph labels 
                train_labels_R1_per_ph[_ph].extend([int(_ph_exists[_ph])] * len(_target_spans[e_i]))
        # for the last step END_SYMBOL span
        for _ph in phonemes:
            train_labels_R1_per_ph[_ph].extend([0])

    if len(train_feats_R1) > 10:
        break
            


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

In [None]:
test_feats_R1 = []
test_labels_R1_per_ph = dict([(_ph, []) for _ph in phonemes])

test_tokens_R1 = []

for o_id, _insts in tqdm(test_o_id2instances.items(), total=len(test_o_id2instances)):
    _cands = full_dev_dataset[o_id]
    
    for _inst, _cand in zip(_insts, _cands):
        _pred = predictor.predict_instance(_inst)
        
        _metadata = _inst['metadata']
        _target_toks = [str(t) for t in _metadata['target_tokens']] + [END_SYMBOL]
        _decoder_intermediates = _pred['rewrite_seq_prediction_intermediates']['decoder']
        # len(_decoder_hidden_states) == len(_target_toks) from metadata; including the last step predicting @end@ 
        _decoder_hidden_states = [_decoder_intermediates[f'state:step_{i}']['decoder_hidden'] for i in range(len(_target_toks))]
        
        test_feats_R1.extend(_decoder_hidden_states)
        test_tokens_R1.extend(_target_toks)
        
        # List[List[int]], the ids of each span in target sequence 
        _target_spans = []
        _curr_span = []
        for i, t in enumerate(_target_toks):
            _curr_span.append(i)
            if t in {'[ANS]', END_SYMBOL}:
                _target_spans.append(_curr_span)
                _curr_span = []
        assert len(_curr_span) == 0, _target_toks
        assert len(_target_spans) == len(_cand['rewriter_edits']) + 1, f"{_target_toks} || {_cand['rewriter_edits']}"
        
        for e_i, edit_d in enumerate(_cand['rewriter_edits']):
            q_span = ' '.join([_cand['question_toks'][i].lower() for i in edit_d['src_span']])
            q_span_phs = [ph for q_tok in text_cell_to_toks(q_span) for ph in test_dataset_reader._token_to_phonemes(q_tok)]
            
            ### TODO: add ph labels correspond to target span 
            _ph_exists = defaultdict(lambda: False) # Dict[str: ph, bool]
            for _ph in q_span_phs:
                if _ph not in phonemes:
                    _ph_exists['@@UNKNOWN@@'] = True
                else:
                    _ph_exists[_ph] = True
            
            for _ph in phonemes:
                # for each target step in the target span correspond to this src span, add these ph labels 
                test_labels_R1_per_ph[_ph].extend([int(_ph_exists[_ph])] * len(_target_spans[e_i]))
        # for the last step END_SYMBOL span
        for _ph in phonemes:
            test_labels_R1_per_ph[_ph].extend([0])

