In [1]:
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [2]:
SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'
BATCH_SIZE = 4

# 특수 기호(symbol)와 인덱스를 정의합니다
UNK_IDX, PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2, 3
# 토큰들이 어휘집(vocab)에 인덱스 순서대로 잘 삽입되어 있는지 확인합니다
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

In [3]:
token_transform = {}
text_transform = {}
vocab_transform = {}

token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')
token_transform[TGT_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')
print(token_transform)

{'de': functools.partial(<function _spacy_tokenize at 0x7fda32757e50>, spacy=<spacy.lang.de.German object at 0x7fda2ebb6790>), 'en': functools.partial(<function _spacy_tokenize at 0x7fda32757e50>, spacy=<spacy.lang.en.English object at 0x7fda22f2bb20>)}


In [4]:
def yield_tokens(data_iter, language: str):
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}

    for data_sample in data_iter:
        yield token_transform[language](data_sample[language_index[language]])

for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    # torchtext의 Vocab(어휘집) 객체 생성
    vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_iter, ln),
                                                    min_freq=1,
                                                    specials=special_symbols,
                                                    special_first=True)

# ``UNK_IDX`` 를 기본 인덱스로 설정합니다. 이 인덱스는 토큰을 찾지 못하는 경우에 반환됩니다.
# 만약 기본 인덱스를 설정하지 않으면 어휘집(Vocabulary)에서 토큰을 찾지 못하는 경우
# ``RuntimeError`` 가 발생합니다.
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    vocab_transform[ln].set_default_index(UNK_IDX)

print(vocab_transform)
english_vocabs = vocab_transform['en']
print(english_vocabs.get_itos())

{'de': Vocab(), 'en': Vocab()}


In [5]:
# 순차적인 작업들을 하나로 묶는 헬퍼 함수
def sequential_transforms(*transforms): ## token_transform[ln], vocab_transform[ln], tensor_transform
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# BOS/EOS를 추가하고 입력 순서(sequence) 인덱스에 대한 텐서를 생성하는 함수
def tensor_transform(token_ids):
    return torch.cat((torch.tensor([SOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))

for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln], # 토큰화(Tokenization)
                                               vocab_transform[ln], # 수치화(Numericalization)
                                               tensor_transform) # BOS/EOS를 추가하고 텐서를 생성
    
print(text_transform)

{'de': <function sequential_transforms.<locals>.func at 0x7fda22d11b80>, 'en': <function sequential_transforms.<locals>.func at 0x7fda22d11c10>}


In [12]:
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    # print(len(batch))
    for src_sample, tgt_sample in batch:
        src = text_transform[SRC_LANGUAGE](src_sample.rstrip("\n"))
        # print("=====SRC=====")
        # print(src_sample)
        # print(src)

        tgt = text_transform[TGT_LANGUAGE](src_sample.rstrip("\n"))
        # print("=====TGT=====")
        # print(tgt_sample)
        # print(tgt)

        ## Unknown = 0
        src_batch.append(text_transform[SRC_LANGUAGE](src_sample.rstrip("\n"))) ## text_transform["de"](src_sample.rstrip("\n"))
        tgt_batch.append(text_transform[TGT_LANGUAGE](tgt_sample.rstrip("\n"))) ## text_transform["en"](src_sample.rstrip("\n"))

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX) ## PAD_IDX = 1
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)

    return src_batch, tgt_batch


train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

sample = next(iter(train_dataloader))

sample_src, sample_tgt = sample[0], sample[1]
print(sample_src.shape)
print(sample_src) ## Column wise : sentence

test = sample_src.view(BATCH_SIZE, -1)
print(test.shape)
print(test)

torch.Size([17, 4])
tensor([[   2,    2,    2,    2],
        [  21,   84,    5,    5],
        [  85,   31,   69,   12],
        [ 257,   10,   27,    7],
        [  31,  847,  219,    6],
        [  87, 2208,    7,   47],
        [  22,   15,   15,   41],
        [  94, 8268, 6769,   30],
        [   7,    4,   55,   11],
        [  16,    3,  508,   13],
        [ 112,    1,    4,  543],
        [7910,    1,    3,    9],
        [3209,    1,    1,  698],
        [   4,    1,    1,   15],
        [   3,    1,    1,  248],
        [   1,    1,    1,    4],
        [   1,    1,    1,    3]])
torch.Size([4, 17])
tensor([[   2,    2,    2,    2,   21,   84,    5,    5,   85,   31,   69,   12,
          257,   10,   27,    7,   31],
        [ 847,  219,    6,   87, 2208,    7,   47,   22,   15,   15,   41,   94,
         8268, 6769,   30,    7,    4],
        [  55,   11,   16,    3,  508,   13,  112,    1,    4,  543, 7910,    1,
            3,    9, 3209,    1,    1],
        [ 698,    

In [7]:
for src, tgt in train_dataloader:
    print(src.shape)

torch.Size([17, 4])
torch.Size([18, 4])
torch.Size([20, 4])
torch.Size([16, 4])
torch.Size([18, 4])
torch.Size([14, 4])
torch.Size([21, 4])
torch.Size([15, 4])
torch.Size([22, 4])
torch.Size([25, 4])
torch.Size([15, 4])
torch.Size([17, 4])
torch.Size([24, 4])
torch.Size([23, 4])
torch.Size([27, 4])
torch.Size([19, 4])
torch.Size([23, 4])
torch.Size([23, 4])
torch.Size([21, 4])
torch.Size([14, 4])
torch.Size([18, 4])
torch.Size([19, 4])
torch.Size([17, 4])
torch.Size([21, 4])
torch.Size([22, 4])
torch.Size([16, 4])
torch.Size([17, 4])
torch.Size([18, 4])
torch.Size([14, 4])
torch.Size([16, 4])
torch.Size([24, 4])
torch.Size([18, 4])
torch.Size([22, 4])
torch.Size([26, 4])
torch.Size([29, 4])
torch.Size([28, 4])
torch.Size([15, 4])
torch.Size([15, 4])
torch.Size([16, 4])
torch.Size([23, 4])
torch.Size([20, 4])
torch.Size([14, 4])
torch.Size([26, 4])
torch.Size([17, 4])
torch.Size([20, 4])
torch.Size([26, 4])
torch.Size([20, 4])
torch.Size([13, 4])
torch.Size([23, 4])
torch.Size([18, 4])




torch.Size([16, 4])
torch.Size([18, 4])
torch.Size([15, 4])
torch.Size([16, 4])
torch.Size([18, 4])
torch.Size([21, 4])
torch.Size([17, 4])
torch.Size([18, 4])
torch.Size([22, 4])
torch.Size([21, 4])
torch.Size([18, 4])
torch.Size([19, 4])
torch.Size([25, 4])
torch.Size([16, 4])
torch.Size([16, 4])
torch.Size([15, 4])
torch.Size([28, 4])
torch.Size([18, 4])
torch.Size([15, 4])
torch.Size([25, 4])
torch.Size([15, 4])
torch.Size([19, 4])
torch.Size([15, 4])
torch.Size([21, 4])
torch.Size([27, 4])
torch.Size([17, 4])
torch.Size([20, 4])
torch.Size([19, 4])
torch.Size([27, 4])
torch.Size([19, 4])
torch.Size([22, 4])
torch.Size([16, 4])
torch.Size([19, 4])
torch.Size([32, 4])
torch.Size([17, 4])
torch.Size([17, 4])
torch.Size([24, 4])
torch.Size([16, 4])
torch.Size([17, 4])
torch.Size([18, 4])
torch.Size([18, 4])
torch.Size([19, 4])
torch.Size([16, 4])
torch.Size([14, 4])
torch.Size([17, 4])
torch.Size([19, 4])
torch.Size([14, 4])
torch.Size([23, 4])
torch.Size([21, 4])
torch.Size([18, 4])
