https://androidkt.com/create-dataloader-with-collate_fn-for-variable-length-input-in-pytorch/

실습코드

In [19]:
import os
import sys
import pandas as pd
import numpy as np 
import torch
import random
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler, random_split
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence

In [11]:
reviews=['No man is an island','Entire of itself',
'Every man is a piece of the continent','part of the main',
'If a clod be washed away by the sea','Europe is the less',
'As well as if a promontory were','As well as if a manor of thy friend',
'Or of thine own were','Any man’s death diminishes me',
'Because I am involved in mankind',
'And therefore never send to know for whom the bell tolls',
'It tolls for thee']


labels=[random.randint(0, 1) for i in range(13)]

dataset=list(zip(reviews,labels))

tokenizer = get_tokenizer('basic_english')  

def yield_tokens(data_iter):
  for text,label in data_iter:
    yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(iter(dataset)), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

text_pipeline = lambda x: vocab(tokenizer(x))

In [13]:
dataset

[('No man is an island', 1),
 ('Entire of itself', 0),
 ('Every man is a piece of the continent', 0),
 ('part of the main', 0),
 ('If a clod be washed away by the sea', 1),
 ('Europe is the less', 1),
 ('As well as if a promontory were', 1),
 ('As well as if a manor of thy friend', 0),
 ('Or of thine own were', 0),
 ('Any man’s death diminishes me', 1),
 ('Because I am involved in mankind', 1),
 ('And therefore never send to know for whom the bell tolls', 1),
 ('It tolls for thee', 0)]

In [20]:
def collate_batch(batch):
  
  label_list, text_list, = [], []
  
  for (_text,_label) in batch:
    label_list.append(_label)
    processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
    text_list.append(processed_text)
  
  label_list = torch.tensor(label_list, dtype=torch.int64)
  
  text_list = pad_sequence(text_list, batch_first=True, padding_value=0)
  
  return text_list.to(device),label_list.to(device),


In [25]:
collate_batch(dataset)[0]

tensor([[43,  8,  6, 13, 32,  0,  0,  0,  0,  0,  0],
        [25,  1, 34,  0,  0,  0,  0,  0,  0,  0,  0],
        [27,  8,  6,  3, 47,  1,  2, 22,  0,  0,  0],
        [46,  1,  2, 37,  0,  0,  0,  0,  0,  0,  0],
        [ 5,  3, 21, 17, 56, 16, 20,  2, 49,  0,  0],
        [26,  6,  2, 36,  0,  0,  0,  0,  0,  0,  0],
        [ 4, 10,  4,  5,  3, 48, 11,  0,  0,  0,  0],
        [ 4, 10,  4,  5,  3, 39,  1, 54, 28,  0,  0],
        [44,  1, 53, 45, 11,  0,  0,  0,  0,  0,  0],
        [15, 40, 23, 24, 41,  0,  0,  0,  0,  0,  0],
        [18, 29, 12, 31, 30, 38,  0,  0,  0,  0,  0],
        [14, 52, 42, 50, 55, 35,  7, 57,  2, 19,  9],
        [33,  9,  7, 51,  0,  0,  0,  0,  0,  0,  0]])

In [21]:
dataloader = DataLoader(
    dataset, 
    batch_size=2, 
    collate_fn=collate_batch,
    shuffle=True)

for x,y in dataloader:
  print(x,"Targets",y,"\n")

tensor([[44,  1, 53, 45, 11,  0,  0,  0],
        [27,  8,  6,  3, 47,  1,  2, 22]]) Targets tensor([0, 0]) 

tensor([[ 5,  3, 21, 17, 56, 16, 20,  2, 49],
        [15, 40, 23, 24, 41,  0,  0,  0,  0]]) Targets tensor([1, 1]) 

tensor([[46,  1,  2, 37,  0,  0,  0],
        [ 4, 10,  4,  5,  3, 48, 11]]) Targets tensor([0, 1]) 

tensor([[18, 29, 12, 31, 30, 38,  0,  0,  0,  0,  0],
        [14, 52, 42, 50, 55, 35,  7, 57,  2, 19,  9]]) Targets tensor([1, 1]) 

tensor([[33,  9,  7, 51],
        [26,  6,  2, 36]]) Targets tensor([0, 1]) 

tensor([[ 4, 10,  4,  5,  3, 39,  1, 54, 28],
        [25,  1, 34,  0,  0,  0,  0,  0,  0]]) Targets tensor([0, 0]) 

tensor([[43,  8,  6, 13, 32]]) Targets tensor([1]) 

