In [1]:
%cd ../../spell_correction

/Users/tyoyo/lab/spell_correction


In [8]:
from spell_correction.t5_spell_corrector import T5SpellCorrector
import pandas as pd
from spell_correction.utils import sigmoid

In [3]:
# Corrector whose filter only accepts candidates with distance <= 1.
# The original token is not added to the masked text here; when it is wanted,
# it is added by hand instead.
spell_corrector = T5SpellCorrector(
    add_original_token_to_masked_text=False,
    filter_func=lambda c: c["distance"] <= 1,
    model_name="google/mt5-small",
)

## 問題点
filter_func や score_func で match_len == 0 の候補を弾いてしまうと、insert: "つ" のような挿入の編集ができなくなる。

In [16]:
# Inputs: the sentence with one span masked out, and the token that was masked.
masked_text = "クラゲは4つ <extra_id_0>を持つが触手はもたない。"
original_token = '放射管'

# Current settings: the filter requires distance <= 2 and match_len >= 1, and the
# score is match_len / (distance + 1) multiplied by sigmoid(logit).
spell_corrector = T5SpellCorrector(
    model_name="google/mt5-small",
    add_original_token_to_masked_text=True,
    filter_func=lambda c: c["distance"] <= 2 and c["match_len"] >= 1,
    score_func=lambda c: (
        c["match_len"] / (c["distance"] + 1) * sigmoid(c["logit"])
    ),
)

candidates = spell_corrector.calc_next_token_candidates(masked_text, original_token)

# Show the first ten candidates.
pd.DataFrame(candidates[:10])

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
0,放射,2,1.0,2726,反射,-6.38833,154237
1,放,1,0.0,2737,放,-6.393455,11215
2,放,1,1.0,2394,開放,-6.159548,148405
3,放射,2,1.0,3544,注射,-6.881783,193502
4,放,1,1.0,2572,解放,-6.283083,112781
5,放,1,1.0,3165,放置,-6.654175,122462
6,放,1,1.0,5629,手放,-7.874041,224867
7,放射,2,1.0,10636,照射,-9.404441,191487
8,放,1,1.0,8409,を放,-8.831514,214310
9,放,1,1.0,9051,放大,-9.017117,210626


In [17]:
# Inputs: the sentence with one span masked out, and the token that was masked.
masked_text = "クラゲは4つ <extra_id_0>を持つが触手はもたない。"
original_token = '放射管'

# Relax the match_len condition: the filter now only checks distance <= 1.
# Change the score function: the numerator is match_len + 1, so a candidate with
# match_len == 0 no longer gets a score of exactly 0.
spell_corrector = T5SpellCorrector(
    score_func=lambda c: (
        (c["match_len"] + 1) / (c["distance"] + 1) * sigmoid(c["logit"])
    ),
    filter_func=lambda c: c["distance"] <= 1,
    model_name="google/mt5-small",
)

candidates = spell_corrector.calc_next_token_candidates(masked_text, original_token)

# Show the first ten candidates.
pd.DataFrame(candidates[:10])

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
0,,0,1,1,も,2.261275,978
1,,0,1,2,に,1.226257,600
2,,0,1,3,が,0.853528,536
3,,0,1,4,と,0.801322,725
4,,0,1,6,目,0.364599,7326
5,,0,1,7,▁,0.296484,259
6,,0,1,9,で,0.241346,627
7,,0,1,11,触,0.166528,40845
8,,0,1,12,力,-0.0154,3862
9,,0,1,13,2,-0.016775,338


In [18]:
# Inputs: the sentence with one span masked out, and the token that was masked.
masked_text = "クラゲは4つ <extra_id_0>を持つが触手はもたない。"
original_token = '放射管'

# Relax the match_len condition: the filter only checks distance <= 1.
# Change the score function: the numerator is match_len + 0.1.
spell_corrector = T5SpellCorrector(
    score_func=lambda c: (
        (c["match_len"] + 0.1) / (c["distance"] + 1) * sigmoid(c["logit"])
    ),
    filter_func=lambda c: c["distance"] <= 1,
    model_name="google/mt5-small",
)

candidates = spell_corrector.calc_next_token_candidates(masked_text, original_token)

# Show the first ten candidates.
pd.DataFrame(candidates[:10])

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
0,,0,1,1,も,2.261275,978
1,,0,1,2,に,1.226257,600
2,,0,1,3,が,0.853528,536
3,,0,1,4,と,0.801322,725
4,,0,1,6,目,0.364599,7326
5,,0,1,7,▁,0.296484,259
6,,0,1,9,で,0.241346,627
7,,0,1,11,触,0.166528,40845
8,,0,1,12,力,-0.0154,3862
9,,0,1,13,2,-0.016775,338


In [15]:
# Inputs: the sentence with one span masked out, and the token that was masked.
masked_text = "クラゲは4つ <extra_id_0>を持つが触手はもたない。"
original_token = '放射管'

# Relax the match_len condition: the filter only checks distance <= 1.
# Change the score function: the numerator is match_len + 0.01.
spell_corrector = T5SpellCorrector(
    score_func=lambda c: (
        (c["match_len"] + 0.01) / (c["distance"] + 1) * sigmoid(c["logit"])
    ),
    filter_func=lambda c: c["distance"] <= 1,
    model_name="google/mt5-small",
)

candidates = spell_corrector.calc_next_token_candidates(masked_text, original_token)

# Show the first thirty candidates.
pd.DataFrame(candidates[:30])

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
0,,0,1.0,1,も,2.261275,978
1,,0,1.0,2,に,1.226257,600
2,,0,1.0,3,が,0.853528,536
3,,0,1.0,4,と,0.801322,725
4,,0,1.0,6,目,0.364599,7326
5,,0,1.0,7,▁,0.296484,259
6,,0,1.0,9,で,0.241346,627
7,,0,1.0,11,触,0.166528,40845
8,,0,1.0,12,力,-0.0154,3862
9,,0,1.0,13,2,-0.016775,338


## ちゃんと探索できるならスコア関数や filter 関数は不要なので、探索方法を少し考える
  * どの candidate から選んでいくか
    * 例: distance == 1 の中で尤度最大のものと、distance == 0 の中で尤度最大のものを選ぶ(ビーム幅 2)など

In [19]:
# Inputs: the sentence with one span masked out, and the token that was masked.
masked_text = "クラゲは4つ <extra_id_0>を持つが触手はもたない。"
original_token = '放射管'

# No filtering at all (the filter accepts everything); the score is sigmoid(logit) only.
spell_corrector = T5SpellCorrector(
    score_func=lambda c: sigmoid(c["logit"]),
    filter_func=lambda c: True,
    model_name="google/mt5-small",
)

candidates = spell_corrector.calc_next_token_candidates(masked_text, original_token)

# Show the first ten candidates.
pd.DataFrame(candidates[:10])

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
0,,0,2,0,もの,2.492165,12488
1,,0,1,1,も,2.261275,978
2,,0,1,2,に,1.226257,600
3,,0,1,3,が,0.853528,536
4,,0,1,4,と,0.801322,725
5,,0,2,5,だけ,0.672474,13386
6,,0,1,6,目,0.364599,7326
7,,0,1,7,▁,0.296484,259
8,,0,2,8,にも,0.246668,6578
9,,0,1,9,で,0.241346,627


In [20]:
# Build a DataFrame from the full `candidates` value left in the kernel by an earlier cell.
df = pd.DataFrame(candidates)

In [21]:
# First rows of df whose distance is 0.
df.loc[df["distance"] == 0].head()

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
2737,放,1,0.0,2737,放,-6.393455,11215
19533,放射,2,0.0,19533,放射,-10.803551,166590


In [23]:
# First rows of df whose distance is 1.
df.loc[df["distance"] == 1].head()

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
1,,0,1.0,1,も,2.261275,978
2,,0,1.0,2,に,1.226257,600
3,,0,1.0,3,が,0.853528,536
4,,0,1.0,4,と,0.801322,725
6,,0,1.0,6,目,0.364599,7326


In [24]:
# First rows of df whose distance is 2.
df.loc[df["distance"] == 2].head()

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
0,,0,2.0,0,もの,2.492165,12488
5,,0,2.0,5,だけ,0.672474,13386
8,,0,2.0,8,にも,0.246668,6578
10,,0,2.0,10,左右,0.233014,37279
15,,0,2.0,15,ほど,-0.225622,16754


↑ よさそう。

In [25]:
# Try the same setup on another example.
masked_text = "彼はファヴローに <extra_id_0>西部劇映画を提供した。"
original_token = 'クラッシック'

# No filtering (the filter accepts everything); the score is sigmoid(logit) only.
spell_corrector = T5SpellCorrector(
    score_func=lambda c: sigmoid(c["logit"]),
    filter_func=lambda c: True,
    model_name="google/mt5-small",
)

candidates = spell_corrector.calc_next_token_candidates(masked_text, original_token)

# Keep the full candidate list as a DataFrame for the inspection cells that follow.
df = pd.DataFrame(candidates)

In [26]:
# First rows of df whose distance is 0.
df.loc[df["distance"] == 0].head()

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
428,ク,1,0.0,428,ク,-4.204565,6662
979,クラ,2,0.0,979,クラ,-5.102506,48440


In [40]:
# First thirty rows of df whose distance is 1.
df.loc[df["distance"] == 1].head(30)

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
5,,0,1.0,5,▁,-0.845111,259
7,,0,1.0,7,、,-0.887233,292
13,,0,1.0,13,壮,-1.265505,78590
17,,0,1.0,17,新,-1.374436,2188
20,,0,1.0,20,一,-1.580832,1374
25,,0,1.0,25,好,-1.860867,3586
27,,0,1.0,27,愛,-1.884031,14942
28,,0,1.0,28,傑,-1.904481,194552
29,,0,1.0,29,適,-1.922753,65187
31,,0,1.0,31,「,-1.947448,939


In [28]:
# First rows of df whose distance is 2.
df.loc[df["distance"] == 2].head()

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
4,,0,2.0,4,最も,-0.709116,65454
8,,0,2.0,8,特別,-0.914652,41666
9,,0,2.0,9,新作,-0.94686,54509
11,,0,2.0,11,続く,-1.21436,119760
12,,0,2.0,12,映画,-1.264738,25875


In [39]:
# Row of df with the largest match_len.
# idxmax() returns an index *label*, so the row is looked up with the label-based
# .loc; the position-based .iloc only gives the same row while df keeps its
# default 0..n-1 index.
df.loc[df["match_len"].idxmax()]

match_ref     クラッシック
match_len          6
distance           1
rank             100
token          クラシック
logit       -2.89716
index         119003
Name: 100, dtype: object

In [42]:
# First rows of df whose match_len is 6.
df.loc[df["match_len"] == 6].head()

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
100,クラッシック,6,1.0,100,クラシック,-2.897156,119003
1317,クラッシック,6,4.0,1317,ダイナミック,-5.468344,232973
2233,クラッシック,6,3.0,2233,グラフィック,-6.302787,197624
3063,クラッシック,6,4.0,3063,オリンピック,-6.855439,150232
3649,クラッシック,6,3.0,3649,ベーシック,-7.222574,176686


In [43]:
# First rows of df whose match_len is 5.
df.loc[df["match_len"] == 5].head()

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
18416,クラッシッ,5,10.0,18416,ポケットコイルマットレス,-11.047287,228405
24960,クラッシッ,5,12.0,24960,ポケットコイルマットレス付き,-11.726411,226019
25716,クラッシッ,5,6.0,25716,オンラインショップ,-11.794085,228295
39527,クラッシッ,5,5.0,39527,チャップアップ,-12.754765,217220


In [44]:
# First rows of df whose match_len is 4.
df.loc[df["match_len"] == 4].head()

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
524,クラッシ,4,5.0,524,スタイリッシュ,-4.407331,197005
1000,クラッシ,4,6.0,1000,プロモーション,-5.127721,233516
1144,クラッシ,4,6.0,1144,アニメーション,-5.284406,235395
1401,クラッシ,4,4.0,1401,ファッション,-5.555392,48804
1922,クラッシ,4,8.0,1922,インターナショナル,-6.064025,223486


In [45]:
# First rows of df whose match_len is 3.
df.loc[df["match_len"] == 3].head()

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
207,クラッ,3,2.0,207,ラップ,-3.529986,145702
525,クラッ,3,4.0,525,ターゲット,-4.410338,199802
541,クラッ,3,3.0,541,ステップ,-4.425168,90780
716,クラッ,3,3.0,716,スポット,-4.728745,106090
799,クラッ,3,4.0,799,ヨーロッパ,-4.860272,122950


In [47]:
# First rows of df whose match_len is 2.
df.loc[df["match_len"] == 2].head()

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
88,クラ,2,5.0,88,プライベート,-2.752218,162722
121,クラ,2,2.0,121,ドラマ,-3.079191,53130
316,クラ,2,3.0,316,サプライ,-3.917451,238640
356,クラ,2,1.0,356,グラ,-4.013418,67915
457,クラ,2,2.0,457,プラス,-4.271429,47175


In [48]:
# First rows of df whose match_len is 1.
df.loc[df["match_len"] == 1].head()

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
74,ク,1,4.0,74,クリスマス,-2.650558,81219
169,ク,1,4.0,169,スクリーン,-3.37588,149854
203,ク,1,5.0,203,ニューヨーク,-3.504385,136215
228,ク,1,2.0,228,クール,-3.593795,124023
270,ク,1,4.0,270,アクション,-3.794083,122951


In [50]:
# First rows of df whose match_len is 0.
df.loc[df["match_len"] == 0].head()

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
0,,0,3.0,0,新たな,0.558038,87989
1,,0,3.0,1,新しい,0.507257,26634
2,,0,3.0,2,優れた,-0.278255,163464
3,,0,5.0,3,素晴らしい,-0.439402,71651
4,,0,2.0,4,最も,-0.709116,65454


In [54]:
from itertools import product

# Print the first row of df for every (match_len, distance) pair, with match_len
# in 0..len(original_token) and distance in 0..2, each followed by a blank line.
for match_len in range(len(original_token) + 1):
    for distance in range(3):
        first_row = df.query(f'match_len == {match_len} and distance == {distance}').head(1)
        print(first_row)
        print()

Empty DataFrame
Columns: [match_ref, match_len, distance, rank, token, logit, index]
Index: []

  match_ref  match_len  distance  rank token     logit  index
5                    0       1.0     5     ▁ -0.845111    259

  match_ref  match_len  distance  rank token     logit  index
4                    0       2.0     4    最も -0.709116  65454

    match_ref  match_len  distance  rank token     logit  index
428         ク          1       0.0   428     ク -4.204565   6662

     match_ref  match_len  distance  rank token     logit  index
2487         ク          1       1.0  2487    クリ -6.496171  68250

    match_ref  match_len  distance  rank token     logit   index
228         ク          1       2.0   228   クール -3.593795  124023

    match_ref  match_len  distance  rank token     logit  index
979        クラ          2       0.0   979    クラ -5.102506  48440

    match_ref  match_len  distance  rank token     logit  index
356        クラ          2       1.0   356    グラ -4.013418  67915

    m

In [62]:
# Inputs: the sentence with one span masked out, and the token that was masked.
masked_text = "クラゲは4つ <extra_id_0>を持つが触手はもたない。"
original_token = '放射管'

candidates = spell_corrector.calc_next_token_candidates(masked_text, original_token)
df = pd.DataFrame(candidates)

# For every (match_len, distance) pair with match_len in 0..len(original_token)
# and distance in 0..2, keep the first matching row of df (empty if none).
hopeful_candidates = [
    df.query(f'match_len == {match_len} and distance == {distance}').head(1)
    for match_len in range(len(original_token) + 1)
    for distance in range(3)
]

# Show the picked rows from highest to lowest logit.
pd.concat(hopeful_candidates).sort_values('logit', ascending=False)

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
0,,0,2.0,0,もの,2.492165,12488
1,,0,1.0,1,も,2.261275,978
2394,放,1,1.0,2394,開放,-6.159548,148405
2726,放射,2,1.0,2726,反射,-6.38833,154237
2737,放,1,0.0,2737,放,-6.393455,11215
19533,放射,2,0.0,19533,放射,-10.803551,166590


In [63]:
# Inputs: the sentence with one span masked out, and the token that was masked.
masked_text = "彼はファヴローに <extra_id_0>西部劇映画を提供した。"
original_token = 'クラッシック'

candidates = spell_corrector.calc_next_token_candidates(masked_text, original_token)
df = pd.DataFrame(candidates)

# For every (match_len, distance) pair with match_len in 0..len(original_token)
# and distance in 0..2, keep the first matching row of df (empty if none).
hopeful_candidates = [
    df.query(f'match_len == {match_len} and distance == {distance}').head(1)
    for match_len in range(len(original_token) + 1)
    for distance in range(3)
]

# Show the picked rows from highest to lowest logit.
pd.concat(hopeful_candidates).sort_values('logit', ascending=False)

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
4,,0,2.0,4,最も,-0.709116,65454
5,,0,1.0,5,▁,-0.845111,259
100,クラッシック,6,1.0,100,クラシック,-2.897156,119003
121,クラ,2,2.0,121,ドラマ,-3.079191,53130
207,クラッ,3,2.0,207,ラップ,-3.529986,145702
228,ク,1,2.0,228,クール,-3.593795,124023
356,クラ,2,1.0,356,グラ,-4.013418,67915
428,ク,1,0.0,428,ク,-4.204565,6662
979,クラ,2,0.0,979,クラ,-5.102506,48440
2487,ク,1,1.0,2487,クリ,-6.496171,68250


In [64]:
# Inputs: the sentence with one span masked out, and the token that was masked.
masked_text = "桶狭間の戦いで <extra_id_0>の鉄砲が織田信長により初めて戦力として使用された。"
original_token = '国友制'

candidates = spell_corrector.calc_next_token_candidates(masked_text, original_token)
df = pd.DataFrame(candidates)

# For every (match_len, distance) pair with match_len in 0..len(original_token)
# and distance in 0..2, keep the first matching row of df (empty if none).
hopeful_candidates = [
    df.query(f'match_len == {match_len} and distance == {distance}').head(1)
    for match_len in range(len(original_token) + 1)
    for distance in range(3)
]

# Show the picked rows from highest to lowest logit.
pd.concat(hopeful_candidates).sort_values('logit', ascending=False)

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
0,,0,1.0,0,、,4.866378,292
2,,0,2.0,2,使用,2.831828,7092
16,国,1,1.0,16,帝国,1.671213,143559
520,国,1,0.0,520,国,-0.841577,3802
1320,国,1,2.0,1320,中国人,-1.929748,123968
27175,国友,2,2.0,27175,お友達,-8.456841,197207
27398,国友,2,1.0,27398,朋友,-8.475882,53392


In [65]:
# Inputs: the sentence with one span masked out, and the token that was masked.
masked_text = "民主党の予備選挙ではヒラリー・クリントンが1位となり、バラク・オバマが僅差の2位 <extra_id_0>。"
original_token = 'だた'

candidates = spell_corrector.calc_next_token_candidates(masked_text, original_token)
df = pd.DataFrame(candidates)

# For every (match_len, distance) pair with match_len in 0..len(original_token)
# and distance in 0..2, keep the first matching row of df (empty if none).
hopeful_candidates = [
    df.query(f'match_len == {match_len} and distance == {distance}').head(1)
    for match_len in range(len(original_token) + 1)
    for distance in range(3)
]

# Show the picked rows from highest to lowest logit.
pd.concat(hopeful_candidates).sort_values('logit', ascending=False)

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
9,,0,2.0,9,です,0.691815,1252
10,だた,2,1.0,10,だった,0.64612,9372
11,,0,1.0,11,と,0.533252,725
18,だた,2,2.0,18,でした,-0.292879,9083
22,だ,1,0.0,22,だ,-0.579966,5908
238,だ,1,1.0,238,だけ,-4.8915,13386
247,だ,1,2.0,247,だろう,-4.933357,37591


In [66]:
# Inputs: the sentence with one span masked out, and the token that was masked.
masked_text = "ひきこもりの <extra_id_0>をはじめたきっかけ"
original_token = '更正'

candidates = spell_corrector.calc_next_token_candidates(masked_text, original_token)
df = pd.DataFrame(candidates)

# For every (match_len, distance) pair with match_len in 0..len(original_token)
# and distance in 0..2, keep the first matching row of df (empty if none).
hopeful_candidates = [
    df.query(f'match_len == {match_len} and distance == {distance}').head(1)
    for match_len in range(len(original_token) + 1)
    for distance in range(3)
]

# Show the picked rows from highest to lowest logit.
pd.concat(hopeful_candidates).sort_values('logit', ascending=False)

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
1,,0,2.0,1,日々,4.200026,95008
10,,0,1.0,10,女,2.971959,6043
1275,更正,2,1.0,1275,矯正,-2.190703,190284
2603,更正,2,2.0,2603,無修正,-3.47817,209103
3005,更,1,0.0,3005,更,-3.754842,9411
4856,更,1,1.0,4856,更新,-4.876542,16449


    "彼はファヴローにクラッシック西部劇映画を提供した。", # 'クラッシック' → 'クラシック'
    "民主党の予備選挙ではヒラリー・クリントンが1位となり、バラク・オバマが僅差の2位だた。",  # 'だた' → 'だった'
    "クラゲには4つ放射管を持つが触手はもたない。",  # '' -> 'の'
    "桶狭間の戦いで国友制の鉄砲が織田信長により初めて戦力として使用された。",  # '制' → '製'
    "ひきこもりの更正をはじめたきっかけ",  # '正' → '生'


In [67]:
# Inputs: only the second character of the word is masked this time.
masked_text = "ひきこもりの更 <extra_id_0>をはじめたきっかけ"
original_token = '正'

candidates = spell_corrector.calc_next_token_candidates(masked_text, original_token)
df = pd.DataFrame(candidates)

# For every (match_len, distance) pair with match_len in 0..len(original_token)
# and distance in 0..2, keep the first matching row of df (empty if none).
hopeful_candidates = [
    df.query(f'match_len == {match_len} and distance == {distance}').head(1)
    for match_len in range(len(original_token) + 1)
    for distance in range(3)
]

# Show the picked rows from highest to lowest logit.
pd.concat(hopeful_candidates).sort_values('logit', ascending=False)

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
0,,0,2.0,0,かし,3.996111,119628
1,,0,1.0,1,年,3.845809,848
83,正,1,0.0,83,正,-2.498208,6929
1120,正,1,1.0,1120,正月,-6.447443,177694
2388,正,1,2.0,2388,無修正,-7.852767,209103


In [70]:
# '生' did not come up in the previous cell, so I changed
# 「ひきこもりの更生」 to 「ひきこもりからの更生」, and then it was predicted properly.
# Its language ability is better than mine; I did not notice this for a while...
masked_text = "ひきこもりからの更 <extra_id_0>をはじめたきっかけ"
original_token = '正'

candidates = spell_corrector.calc_next_token_candidates(masked_text, original_token)
df = pd.DataFrame(candidates)

# For every (match_len, distance) pair with match_len in 0..len(original_token)
# and distance in 0..2, keep the first matching row of df (empty if none).
hopeful_candidates = [
    df.query(f'match_len == {match_len} and distance == {distance}').head(1)
    for match_len in range(len(original_token) + 1)
    for distance in range(3)
]

# Show the picked rows from highest to lowest logit.
pd.concat(hopeful_candidates).sort_values('logit', ascending=False)

Unnamed: 0,match_ref,match_len,distance,rank,token,logit,index
0,,0,1.0,0,生,3.804604,3355
2,,0,2.0,2,かし,1.831364,119628
337,正,1,0.0,337,正,-5.613061,6929
803,正,1,1.0,803,,-7.177146,1
2721,正,1,2.0,2721,無修正,-9.507661,209103
