### view에 다른 size 값을 명시하지 않고 -1만 지정하면 텐서가 1차원으로 펼쳐진다. shape이 (1, N)인 텐서에서는 np.squeeze()와 같은 결과를 보인다.

In [1]:
import torch
import numpy as np

In [2]:
# Sample 1-D float32 vector holding the values 1.0 through 5.0.
a = np.arange(1, 6, dtype=np.float32)
a

array([1., 2., 3., 4., 5.], dtype=float32)

In [3]:
# Prepend a length-1 axis so the shape goes from (5,) to (1, 5).
a = a[np.newaxis, ...]
print(a.shape)

(1, 5)


In [4]:
# torch.tensor copies the NumPy data and keeps its dtype (float32 here) and shape;
# the legacy torch.Tensor constructor always casts to the default float type instead.
at = torch.tensor(a)
at

tensor([[1., 2., 3., 4., 5.]])

In [5]:
# With -1 as the only size, view() infers the single dimension from the element count,
# so the (1, 5) tensor is returned as a 1-D tensor of 5 elements.
at.view(-1)

tensor([1., 2., 3., 4., 5.])

## torchtext

### raw dataset iterator

In [6]:
import torch
from torchtext.datasets import AG_NEWS

# Wrap the AG_NEWS train split in iter() so a single sample can be pulled with next().
train_iter = iter(AG_NEWS(split='train'))
sample = next(train_iter)
print(sample)  # each sample is a (label, sentence) tuple

(3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.")


### data process pipeline
가공되지 않은 텍스트 문자열(raw dataset)을 처리하기 위한 기본적인 데이터 처리 빌딩 블록(data processing building block)을 살펴본다.

tokenization(토큰화): 주어진 문장을 token이라는 단위로 나누는 작업을 말하며, 토큰의 단위는 보통 의미 있는 단어(문법적으로 더 이상 나눌 수 없는 언어 요소)를 기준으로 정한다.

In [7]:
from torchtext.data.utils import get_tokenizer

# "basic_english" tokenizer: as the printed result shows, words are lowercased
# and the trailing "!" is split off as its own token.
tokenizer = get_tokenizer("basic_english")
sample_tokens = tokenizer("You can now install TorchText using pip!")
print(sample_tokens)

['you', 'can', 'now', 'install', 'torchtext', 'using', 'pip', '!']


In [10]:
from torchtext.vocab import build_vocab_from_iterator

# Re-create the tokenizer and load the train split again as the source for the vocabulary.
tokenizer = get_tokenizer('basic_english')
train_iter = AG_NEWS(split='train')

def yield_tokens(data_iter):
    """Lazily yield the token list of every sample's text.

    ``data_iter`` must produce 2-item samples; the first item (the label) is
    ignored and the second (the text) is passed to the module-level
    ``tokenizer``.
    """
    yield from (tokenizer(text) for _label, text in data_iter)

# First step: build the vocabulary from the tokenized raw dataset,
# registering "<unk>" as a special token.
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
# Tokens missing from the vocabulary map to the "<unk>" index (0 in the printed output).
vocab.set_default_index(vocab["<unk>"])
# Index-to-token list: position i holds the token whose index is i.
total_vocabs = vocab.get_itos()
print(len(total_vocabs))
print(total_vocabs[:10])

# Calling the vocab on a token list returns the index of each token.
print(vocab(sample_tokens))
# The position of "you" in the index-to-token list equals its vocab index.
print(total_vocabs.index("you"))

95811
['<unk>', '.', 'the', ',', 'to', 'a', 'of', 'in', 'and', 's']
[165, 122, 185, 5015, 0, 590, 0, 764]
165


In [19]:
from torch.utils.data import DataLoader
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def text_pipeline(x):
    """Tokenize the raw string ``x`` and map its tokens to vocabulary indices."""
    return vocab(tokenizer(x))


def label_pipeline(x):
    """Convert a raw label to an int and shift it down by one (e.g. 3 -> 2)."""
    return int(x) - 1

def collate_batch(batch):
    """Collate ``(label, text)`` samples into flat tensors, printing debug info.

    Each label goes through ``label_pipeline`` and each text through
    ``text_pipeline``. The per-sample token-id tensors are concatenated into
    one 1-D tensor, and ``offsets`` records where each sample starts inside
    that concatenation.

    Returns:
        A ``(labels, texts, offsets)`` tuple of int64 tensors, each moved to
        the module-level ``device``.
    """
    print(len(batch))
    labels, token_chunks, starts = [], [], [0]
    for raw_label, raw_text in batch:
        print(f"original : {raw_label}, {raw_text}")
        label = label_pipeline(raw_label)
        labels.append(label)

        token_ids = torch.tensor(text_pipeline(raw_text), dtype=torch.int64)
        token_chunks.append(token_ids)
        print(f"pre-processed : {label}")
        print(token_ids.shape)
        print(token_ids)

        # Record this sample's length; the running sum below turns lengths into start positions.
        starts.append(token_ids.size(0))

    label_tensor = torch.tensor(labels, dtype=torch.int64)
    # Drop the last length (nothing starts after the final sample), then accumulate.
    offset_tensor = torch.tensor(starts[:-1]).cumsum(dim=0)
    text_tensor = torch.cat(token_chunks)

    return (
        label_tensor.to(device),
        text_tensor.to(device),
        offset_tensor.to(device),
    )

# Fresh train split, batched four samples at a time in dataset order;
# collate_batch turns each batch into (labels, texts, offsets).
train_iter = AG_NEWS(split='train')
dataloader = DataLoader(
    train_iter,
    batch_size=4,
    shuffle=False,
    collate_fn=collate_batch,
)

# Inspect only the first collated batch, then stop.
for data in dataloader:
    print("=====Dataloader======")
    print(f"Labels : {data[0]}")
    print(f"Texts : {data[1]}")
    print(f"Offsets : {data[2]}")
    break

4
original : 3, Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.
pre-processed : 2
torch.Size([29])
tensor([  431,   425,     1,  1605, 14838,   113,    66,     2,   848,    13,
           27,    14,    27,    15, 50725,     3,   431,   374,    16,     9,
        67507,     6, 52258,     3,    42,  4009,   783,   325,     1])
original : 3, Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market.
pre-processed : 2
torch.Size([42])
tensor([15874,  1072,   854,  1310,  4250,    13,    27,    14,    27,    15,
          929,   797,   320, 15874,    98,     3, 27657,    28,     5,  4459,
           11,   564, 52790,     8, 80617,  2125,     7,     2,   525,   241,
            3,    28,  389