In [1]:
import torch
import torchtext
from torchtext.datasets import IMDB
from collections import Counter
from torchtext.vocab import vocab
from torchtext.data.utils import get_tokenizer
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
import random

In [2]:
# Load the train and test splits; each item is a (label, review_text) pair.
train_iter, test_iter = IMDB(split=("train", "test"))
# Display the first training example.
next(iter(train_iter))

(1,
 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far betwee

In [3]:
tuple(len(list(split)) for split in (train_iter, test_iter))  # (n_train, n_test)

(12500, 12500)

In [4]:
{label for label, _text in test_iter}  # distinct raw labels in the test split

{1}

In [5]:
# torchtext's built-in "basic_english" tokenizer: maps a string to a sequence of tokens.
tokenizer = get_tokenizer(tokenizer="basic_english")

In [6]:
# Token frequencies over every review in the training split.
counter = Counter(
    token
    for _label, review in train_iter
    for token in tokenizer(review)
)

In [7]:
# Vocabulary over training tokens seen at least twice. The specials are placed
# first in the listed order, so <pad>=0, <sos>=1, <eos>=2, <unk>=3.
imdb_vocab = vocab(counter, min_freq=2, specials=['<pad>', '<sos>', '<eos>', '<unk>'])
# Without a default index, looking up any out-of-vocabulary token (every word
# seen only once in training, and unseen test words) raises a RuntimeError.
# Route those lookups to <unk> instead.
imdb_vocab.set_default_index(imdb_vocab['<unk>'])

In [8]:
def text_transform(x):
    """Encode a raw review string as vocabulary ids wrapped in <sos> ... <eos>."""
    return [imdb_vocab['<sos>']] + [imdb_vocab[token] for token in tokenizer(x)] + [imdb_vocab['<eos>']]


def label_transform(x):
    """Map a raw IMDB label to 1 (positive) or 0 (negative).

    Older torchtext releases yield the strings 'neg'/'pos'. The installed one
    yields integers: the first training example above carries label 1.
    Presumably 1 = neg and 2 = pos; confirm against the installed torchtext.
    Comparing only with 'pos' would silently map every integer label to 0,
    so both encodings are accepted.
    """
    return 1 if x in ('pos', 2) else 0

In [9]:
print("input to the text_transform:", "here is an example")
print("output of the text_transform:", text_transform("here is an example"))

input to the text_transform: here is an example
output of the text_transform: [1, 1019, 56, 202, 3736, 2]


In [10]:
def collate_batch(batch):
    """Turn a list of (label, text) pairs into padded model inputs.

    Returns:
        (labels, texts): a 1-D tensor of labels mapped by ``label_transform``,
        and a (max_seq_len, batch) tensor of token ids. ``pad_sequence``
        defaults to ``batch_first=False``. Shorter sequences are filled with
        the <pad> id.
    """
    labels = torch.tensor([label_transform(label) for label, _ in batch])
    texts = [torch.tensor(text_transform(text)) for _, text in batch]
    # Pad with the <pad> id. The previous hard-coded 3.0 is the <unk> id (4th
    # special), which made padding indistinguishable from unknown words.
    return labels, pad_sequence(texts, padding_value=imdb_vocab['<pad>'])

train_dataloader = DataLoader(list(train_iter), batch_size=8, shuffle=True, collate_fn=collate_batch)

In [21]:
# Three toy sentences of different lengths. Each is encoded to a 1-D id tensor
# (text_transform adds <sos>/<eos>) so pad_sequence has something to pad.
text_lst = [
    torch.tensor(text_transform(sentence))
    for sentence in ('sentence 1 more words', 'sentence 2 less', 'sentence 3')
]
text_lst  # lengths 6, 5 and 4

[tensor([   1, 4552,  345,  784, 1021,    2]),
 tensor([   1, 4552,  348,  286,    2]),
 tensor([   1, 4552, 1126,    2])]

In [22]:
# Pad with the <pad> id, matching collate_batch (3 is <unk>, not padding).
# The result is (max_seq_len, n_sequences) because batch_first defaults to False.
pad_sequence(text_lst, padding_value=imdb_vocab['<pad>']).shape

torch.Size([6, 3])

In [None]:
# Disabled scratch cell: would encode every training review into a tensor up
# front. collate_batch does the same encoding per batch instead.
# text_list = []
# for label, line in train_iter:
#     text_list.append(torch.tensor(text_transform(line)))
# text_list

In [None]:
train_list = list(train_iter)
batch_size = 8


def batch_sampler():
    """Yield batches of ``train_list`` indices grouped by similar review length.

    The steps are:
    1. Shuffle all indices.
    2. Sort by token count inside pools of ``batch_size * 100`` examples.
    3. Cut into batches of ``batch_size``.

    Batches therefore hold similar-length reviews (less padding), while the
    pool boundaries keep the overall order random. Each call returns a new,
    independently shuffled, single-pass generator.
    """
    pool_size = batch_size * 100
    lengths = [(idx, len(tokenizer(sample[1]))) for idx, sample in enumerate(train_list)]
    random.shuffle(lengths)

    pooled = []
    for start in range(0, len(lengths), pool_size):
        pooled.extend(sorted(lengths[start:start + pool_size], key=lambda pair: pair[1]))
    ordered = [idx for idx, _ in pooled]

    for start in range(0, len(ordered), batch_size):
        yield ordered[start:start + batch_size]


class BucketBatchSampler:
    """Re-iterable batch sampler for ``DataLoader``.

    ``DataLoader`` calls ``iter()`` on its ``batch_sampler`` at the start of
    every epoch. A bare generator is exhausted after the first epoch, so later
    epochs would silently yield no batches. Returning a fresh
    ``batch_sampler()`` here reshuffles and re-buckets on every pass.
    """

    def __iter__(self):
        return batch_sampler()

    def __len__(self):
        # Number of batches, counting a final partial batch.
        return (len(train_list) + batch_size - 1) // batch_size


bucket_dataloader = DataLoader(train_list, batch_sampler=BucketBatchSampler(),
                               collate_fn=collate_batch)

In [None]:
# Uncomment to inspect one (labels, padded_texts) batch from the bucketed loader.
# next(iter(bucket_dataloader))

In [None]:
# Toy illustration of pooled sorting on the integers 0..999: shuffle them,
# then sort each pool of batch_size * 25 values independently.
batch_size = 8
indices = list(range(1000))
random.shuffle(indices)
pool_width = batch_size * 25
pooled_indices = [
    value
    for pool_start in range(0, len(indices), pool_width)
    for value in sorted(indices[pool_start:pool_start + pool_width])
]
pooled_indices

In [None]:
def random_sampler():
    """Yield consecutive ``batch_size``-long slices of ``pooled_indices`` in order."""
    yield from (
        pooled_indices[offset:offset + batch_size]
        for offset in range(0, len(pooled_indices), batch_size)
    )

for batch in random_sampler():
    print(batch)