In [144]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from pykakasi import kakasi
from typing import Tuple, List

In [9]:
df = pd.read_csv('data/train_manual.csv', encoding="SHIFT-JIS")
df.head()

Unnamed: 0,short,original
0,アイコ,アイスコーヒー
1,アイシン,アイシン精機
2,赤プリ,グランドプリンスホテル赤坂
3,アキバ,秋葉原
4,アクエリ,アクエリアス


In [50]:
df['short_hepburn'] = df['short'].apply(
    lambda x: ''.join([item['hepburn'].replace('#', '-') for item in transform.convert(x.replace('ー', '#'))])
)

df['original_hepburn'] = df['original'].apply(
    lambda x: ''.join([item['hepburn'].replace('#', '-') for item in transform.convert(x.replace('ー', '#'))])
)

In [51]:
df

Unnamed: 0,short,original,short_hepburn,original_hepburn
0,アイコ,アイスコーヒー,aiko,aisuko-hi-
1,アイシン,アイシン精機,aishin,aishinseiki
2,赤プリ,グランドプリンスホテル赤坂,akapuri,gurandopurinsuhoteruakasaka
3,アキバ,秋葉原,akiba,akihabara
4,アクエリ,アクエリアス,akueri,akueriasu
...,...,...,...,...
392,ワラキン,笑いの金メダル,warakin,warainokinmedaru
393,ワンナイ,ワンナイト,wannai,wannaito
394,ワーネバ,ワールドネバーランド,wa-neba,wa-rudoneba-rando
395,ワンピ,ワンピース,wanpi,wanpi-su


In [55]:
list(df['original_hepburn'].iloc[1])

['a', 'i', 's', 'h', 'i', 'n', 's', 'e', 'i', 'k', 'i']

In [61]:
alpha2num = {chr(c): i+1 for i, c in enumerate(range(ord('a'),ord('z')+1))}
num2alpha = {i+1: chr(c) for i, c in enumerate(range(ord('a'),ord('z')+1))}

In [62]:
num2alpha

{1: 'a',
 2: 'b',
 3: 'c',
 4: 'd',
 5: 'e',
 6: 'f',
 7: 'g',
 8: 'h',
 9: 'i',
 10: 'j',
 11: 'k',
 12: 'l',
 13: 'm',
 14: 'n',
 15: 'o',
 16: 'p',
 17: 'q',
 18: 'r',
 19: 's',
 20: 't',
 21: 'u',
 22: 'v',
 23: 'w',
 24: 'x',
 25: 'y',
 26: 'z'}

In [60]:
alpha_dict

{'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'f': 6,
 'g': 7,
 'h': 8,
 'i': 9,
 'j': 10,
 'k': 11,
 'l': 12,
 'm': 13,
 'n': 14,
 'o': 15,
 'p': 16,
 'q': 17,
 'r': 18,
 's': 19,
 't': 20,
 'u': 21,
 'v': 22,
 'w': 23,
 'x': 24,
 'y': 25,
 'z': 26}

In [63]:
originals = df['short'].tolist()

In [173]:
class ShortenDataset(Dataset):
    def __init__(self, 
                 originals: list, 
                 shorts: list,
                 max_len: int
                ):
        self.pad_id = 0
        self.macron_id = 27
        self.start_id = 28
        self.end_id = 29
        self.max_len = max_len
        transform = kakasi()
        
        self.originals = list(map(self.padding ,
                                  [list(map(self.transform, 
                                            list(''.join([item['hepburn'].replace('#', '-') 
                                                          for item in transform.convert(text.replace('ー', '#'))]))
                                           ))
                               for text in originals]))
        
        self.shorts = list(map(self.shift_right, 
                               [list(map(self.transform, 
                                         list(''.join([item['hepburn'].replace('#', '-') 
                                                       for item in transform.convert(text.replace('ー', '#'))]))
                                        )) 
                                for text in shorts]))
        
    def __len__(self):
        return len(self.originals)
    
    def __getitem__(self, 
                    index: int
                   ) -> Tuple[torch.Tensor]:
        return torch.tensor(self.originals[index]), torch.tensor(self.shorts[index])
    
    def transform(self, alphabet: List[str]) -> List[str]:
        if alphabet.isalpha():
            return ord(alphabet) - ord('a') + 1
        elif alphabet == '-':
            return self.macron_id
        
    def padding(self, alphabet_list: List[str]) -> List[str]:
        return alphabet_list + [self.pad_id for _ in range(self.max_len-len(alphabet_list))]
    
    def shift_right(self, alphabet_list: List[str]) -> List[str]:
        return [self.start_id] + alphabet_list + [self.end_id] + [self.pad_id for _ in range(self.max_len-len(alphabet_list))]

In [174]:
dataset = ShortenDataset(df['original'].tolist(), df['short'].tolist(), 20)
dataset[0]

(tensor([ 1,  9, 19, 21, 11, 15, 27,  8,  9, 27,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0]),
 tensor([28,  1,  9, 11, 15, 29,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0]))

In [175]:
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

In [177]:
for x, t in dataloader:
    print(x)
    print(t[:,:-1])
    print(t[:,1:])
    break

tensor([[ 4, 15, 20,  1, 14,  2,  1, 11, 25,  1, 14, 19,  5, 18, 21,  0,  0,  0,
          0,  0],
        [ 1, 11,  1, 26, 21, 11,  9, 14,  3,  8,  1,  3,  8,  1,  0,  0,  0,  0,
          0,  0],
        [ 2,  1, 14,  2, 21, 27,  2, 21, 18,  5, 27,  4, 15,  0,  0,  0,  0,  0,
          0,  0],
        [ 1, 18,  9, 16, 21, 18, 15, 10,  9,  5, 11, 21, 20, 15,  0,  0,  0,  0,
          0,  0]])
tensor([[28,  4, 15, 20,  1, 11, 25,  1, 14, 29,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0],
        [28,  3,  8,  1,  3,  8,  1, 29,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0],
        [28,  2,  1, 14,  2, 21, 18,  5, 29,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0],
        [28,  1, 18,  9, 16, 21, 18, 15, 29,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0]])
tensor([[ 4, 15, 20,  1, 11, 25,  1, 14, 29,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0],
        [ 3,  8,  1,  3,  8,  1, 29,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
      