In [60]:
import os
import spacy
import pandas as pd

import torch
import torch.nn as nn
from torchtext import data

# English spaCy pipeline; the `tokenizer` helper calls its `.tokenizer` component.
# NOTE(review): 'en' is a shortcut model name — presumably resolved by the
# installed spaCy version; confirm it still loads after any spaCy upgrade.
spacy_en = spacy.load('en')

In [2]:
# Directory containing the train/val/test CSV splits.
# The default is the original machine-specific location; set the
# SCENE_DATAPATH environment variable to run the notebook on another machine
# without editing this cell.
DATAPATH = os.environ.get(
    'SCENE_DATAPATH',
    '/home/ygx/dev/kaggle/scene/data/splits/csv',
)

In [6]:
# Read the raw training split into a pandas DataFrame for inspection.
trainpath = os.path.join(DATAPATH, 'train.csv')
train = pd.read_csv(filepath_or_buffer=trainpath)

In [7]:
# Preview the first five rows of the raw training DataFrame.
train.head(5)

Unnamed: 0,id,text,genre,labels
0,7688,waves the Boy down into the cellar. The Boy d...,drama,3
1,10801,"rge the beefy jerks, who bolt for a glass part...",comedy,2
2,5212,se You must now be warned! the danger to thyse...,thriller,8
3,20313,", but it looks like Frank is crying. JULIA Fra...",drama,3
4,18121,EX is pacing back and forth. LENNY struts back...,action,0


In [3]:
def tokenizer(text):
    """Split ``text`` into a list of token strings.

    Runs the tokenizer of the module-level ``spacy_en`` pipeline over
    ``text`` and returns the ``.text`` of each resulting token, in order.
    """
    pieces = []
    for token in spacy_en.tokenizer(text):
        pieces.append(token.text)
    return pieces

In [4]:
# Text column: tokenised with the spaCy-backed `tokenizer` and lower-cased.
TEXT = data.Field(
    tokenize=tokenizer,
    lower=True,
    sequential=True,
)
# Label column: a single value per example, used as-is (no vocabulary).
LABEL = data.Field(
    use_vocab=False,
    sequential=False,
)

In [8]:
# Build the torchtext datasets from the three CSV splits.
#
# `fields` is a list with one entry per CSV column, in file order; columns
# mapped to None are not loaded.
#
# skip_header=True is required because the CSV files start with a header
# line (id,text,genre,labels). Without it that line is parsed as a regular
# example, and its 'labels' string cannot be converted to a number by the
# use_vocab=False LABEL field when the example reaches a batch.
# NOTE(review): assumes val.csv and test.csv carry the same header row as
# train.csv — confirm.
#
# NOTE: this rebinds `train` from the pandas DataFrame loaded earlier to a
# torchtext dataset.
train, val, test = data.TabularDataset.splits(
    path=DATAPATH,
    train='train.csv',
    validation='val.csv',
    test='test.csv',
    format='csv',
    skip_header=True,
    fields=[
        ('id', None),
        ('text', TEXT),
        ('genre', None),
        ('labels', LABEL)
    ]
)

In [9]:
# Show which column names are bound to a Field and which are ignored (None).
train.fields

{'id': None,
 'text': <torchtext.data.field.Field at 0x7f9537fb4cf8>,
 'genre': None,
 'labels': <torchtext.data.field.Field at 0x7f9537fb47f0>}

In [10]:
# Build the TEXT vocabulary from the training split and attach the
# pretrained vectors selected by the 'glove.6B.100d' alias.
TEXT.build_vocab(train, vectors='glove.6B.100d')

In [42]:
device = 'cpu'

# Create one iterator per split; batch sizes are given in the same order as
# the datasets (train=32, val=128, test=128).
#
# The sort key must read the `text` attribute: examples expose the names
# declared in the dataset's `fields` list ('text', 'labels'). The previous
# `x.Text` raised AttributeError as soon as an iterator actually invoked the
# sort key.
train_iter, val_iter, test_iter = data.BucketIterator.splits(
    (train, val, test),
    sort_key=lambda x: len(x.text),
    batch_sizes=(32, 128, 128),
    device=device
)

In [43]:
# Display the training iterator object to confirm it was created.
train_iter

<torchtext.data.iterator.BucketIterator at 0x7f95185ab6a0>

In [44]:
# Embedding table with one 200-dimensional row per vocabulary entry.
# NOTE(review): the vocabulary was built with 100-d GloVe vectors
# ('glove.6B.100d'); this layer is 200-d, so those vectors could not be
# copied into it as-is — confirm the intended dimension.
vocab = TEXT.vocab
embed = nn.Embedding(num_embeddings=len(vocab), embedding_dim=200)

In [45]:
# Display the embedding layer (vocabulary size x embedding dimension).
embed

Embedding(54294, 200)

In [48]:
# Pull a single batch from the training iterator and display its summary.
batch = next(iter(train_iter))
batch


[torchtext.data.batch.Batch of size 32]
	[.text]:[torch.LongTensor of size 248x32]
	[.labels]:[torch.LongTensor of size 32]

In [49]:
# List the attribute names stored on the batch object.
vars(batch).keys()

dict_keys(['batch_size', 'dataset', 'fields', 'input_fields', 'target_fields', 'text', 'labels'])

In [47]:
# Five most frequent tokens in the vocabulary's frequency counter.
TEXT.vocab.freqs.most_common(5)

[('.', 227756), ('the', 137092), (',', 118420), ('a', 59798), ('to', 55231)]

In [78]:
class BatchWrapper:
    """Adapt a batch iterable so that it yields plain ``(x, y)`` tuples.

    Each item produced by ``dataloader`` is expected to expose its inputs
    and targets as attributes; this wrapper looks them up by name.

    Args:
        dataloader: iterable of batch objects; must support ``len()`` for
            ``len(wrapper)`` to work.
        data: attribute name holding the inputs on each batch.
        label: attribute name holding the targets, or ``None`` when the
            batches carry no targets.
    """

    def __init__(self, dataloader, data="text", label="labels"):
        self.dataloader, self.data, self.label = dataloader, data, label

    def __len__(self):
        """Return the length reported by the wrapped dataloader."""
        return len(self.dataloader)

    def __iter__(self):
        """Yield ``(x, y)`` for every batch of the wrapped dataloader.

        When ``label`` is ``None`` a one-element zero tensor is yielded in
        place of the targets.
        """
        for minibatch in self.dataloader:
            inputs = getattr(minibatch, self.data)
            if self.label is None:
                targets = torch.zeros((1))
            else:
                targets = getattr(minibatch, self.label)
            yield (inputs, targets)

In [79]:
# Wrap the training iterator so it yields (x, y) tuples read from each
# batch's default `text` / `labels` attributes.
trainloader = BatchWrapper(train_iter)

In [80]:
# Fetch one (x, y) pair from the wrapped loader to check its contents.
next(iter(trainloader))

(tensor([[16104,  2407,   688,  ...,   434, 47015,  6255],
         [    8,   647,    59,  ...,     4,     2,   900],
         [ 4160,     6,    22,  ...,    77,   128,    36],
         ...,
         [    1,     1,     1,  ...,     1,     1,     1],
         [    1,     1,     1,  ...,     1,     1,     1],
         [    1,     1,     1,  ...,     1,     1,     1]]),
 tensor([5, 3, 8, 8, 3, 2, 3, 2, 0, 8, 8, 0, 3, 3, 3, 8, 3, 3, 5, 2, 3, 3, 3, 3,
         8, 8, 3, 0, 0, 3, 2, 2]))

In [81]:
# Batch count; BatchWrapper.__len__ delegates to the wrapped iterator.
len(trainloader)

424