In [72]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
import pandas as pd
import numpy as np

In [14]:
import sys
sys.path.append("..")
import training.utils as utils
df = utils.read_json("../data/example.json")
df = utils.preprocess(df)

In [248]:
class AITA_Dataset(Dataset):
    def __init__(self, df, post_vocab, comm_vocab):
        """
        :param df: a pandas dataframe
        :param vocabulary: a list or dictionary strings
        """
        self.list_ids = df['post_body']
        self.list_labels = df['comment_body']
        self.post_vocab = post_vocab
        self.comm_vocab = comm_vocab

    def __len__(self):
        return len(self.list_labels)
    
    def __getitem__(self, idx):
        """
        Return a tuple of lists of strings
        """
        X = self.list_ids[idx]
        y = self.list_labels[idx]
        # print(X)
        return self.apply_special_tokens(X, self.post_vocab), self.apply_special_tokens(y, self.comm_vocab)

    def apply_special_tokens(self, sentence, vocab):
        """
        Add begin and end tokens. Also replace words not in 
        vocabulary with unknown token
        :param sentence: a list of strings
        """
        for i, word in enumerate(sentence):
            if word not in vocab:
                sentence[i] = utils.UNKNOWN_TOKEN
        
        return [utils.START_TOKEN] + sentence + [utils.END_TOKEN]

"""
Collator function to be called with dataloader
"""
def collator(batch):

    ids = (x[0] for x in batch)
    labels = (x[1] for x in batch)
    
    id_lengths = [len(sentence) for sentence in ids]
    print(id_lengths)
    
    longest = max(id_lengths)
    batch_size = len(id_lengths)
    padded_ids = np.ones((batch_size, 0)).tolist()
    # copy over the actual sequences
    for i, length in enumerate(id_lengths):
        sequence = batch[i][0]
        #padded_ids[i, 0:length] = sequence[:length] 
        pad = []
        for n in range(length, longest):
            sequence.append(utils.PAD_TOKEN)

        padded_ids[i].append(sequence)
        padded_ids[i] = padded_ids[i][0]
            
        
    label_lengths = [len(sentence) for sentence in labels]
    print(label_lengths)
    
    longest = max(label_lengths)
    batch_size = len(label_lengths)
    padded_labels = np.ones((batch_size, 0)).tolist()
    # copy over the actual sequences
                            
    for i, length in enumerate(label_lengths):
        sequence = batch[i][1]
        pad = []
        for n in range(length, longest):
            sequence.append(utils.PAD_TOKEN)
        
        padded_labels[i].append(sequence)
        padded_labels[i] = padded_labels[i][0]
            
    
    return (padded_ids, padded_labels)
    
"""
Creates a Dataset using post_body and comment_body columns of dataframe
"""
def get_dataloader(df, post_vocab, comm_vocab, batch_size=5):
    ds = AITA_Dataset(df, post_vocab, comm_vocab)
    
    loader = DataLoader(ds, batch_size=batch_size, shuffle=True, collate_fn=collator)
    
    return loader

def sample_dl(dl):
    features, labels = next(iter(dl))
    
    batch = pd.DataFrame([features, labels], ['post_body', 'comment_body'])
    batch = batch.drop(0, 1)
    batch = batch.transpose()
    
    return batch

In [249]:
post_word_to_idx, post_embeddings = utils.get_embeddings("../data/post_embeddings_W2V_30_1000")
comment_word_to_idx, comment_embeddings = utils.get_embeddings("../data/comment_embeddings_W2V_30_1000")

dataset = AITA_Dataset(df, post_word_to_idx, comment_word_to_idx)

dl = get_dataloader(df, post_word_to_idx, comment_word_to_idx, 5)

In [250]:
for i, x in post_word_to_idx.items():
    if (x == 2170):
        print('hi')

hi


In [251]:
post = sample_dl(dl)

[446, 185, 440, 708, 403]
[85, 35, 185, 53, 142]


  batch = batch.drop(0, 1)


In [252]:
post

Unnamed: 0,post_body,comment_body
1,"[<BEG>, aita, <UNK>, <UNK>, want, <UNK>, cook,...","[<BEG>, nta, ., he, is, ., my, wife, does, mos..."
2,"[<BEG>, wibta, <UNK>, told, <UNK>, fiancé, <UN...","[<BEG>, nta, but, i, ’, m, pretty, sure, your,..."
3,"[<BEG>, aita, <UNK>, punch, <UNK>, <UNK>, pedi...","[<BEG>, nta, ., having, a, <UNK>, ,, or, havin..."
4,"[<BEG>, wibta, <UNK>, ask, <UNK>, girlfriend, ...","[<BEG>, yta, it, ’, s, fine, to, <UNK>, your, ..."


In [258]:
len(post.iloc[2][1])

185

In [11]:
dataload.AITA_Dataset(df)

<dataload.AITA_Dataset at 0x1cb83c9d400>

In [4]:
for i in dataloader:
    print(i)

RuntimeError: each element in list of batch should be of equal size