### Assignment goal: practice loading data with a custom collate_fn and sampler

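This assignment uses the [IMDB](http://ai.stanford.edu/~amaas/data/sentiment/) dataset with PyTorch's Dataset and DataLoader to practice
customized data loading.
The downloaded data is split into train and test. Because the goal of this assignment is loading data, we use only the train split for practice.
(Please download the data from the IMDB page first.)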

### Load packages

In [3]:
# Import torch and other required modules
import os
import re

import nltk
import torch
from nltk.corpus import stopwords
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, RandomSampler


nltk.download('stopwords')  # stopword lists used to filter the vocabulary
nltk.download('punkt')      # tokenizer models required by nltk.word_tokenize
# recent NLTK releases may additionally require nltk.download('punkt_tab') for word_tokenize

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/xuyouming/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package punkt to /Users/xuyouming/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

### Data exploration and preprocessing
In this assignment we use the pos and neg reviews in the train split.


In [52]:
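# Read the vocabulary file; it lists every word that appears in the reviews
with open(os.path.join('E:\\jupyter\\dataset\\aclimdb', 'imdb.vocab'), encoding='utf-8') as e:
    vocab = [word.strip() for word in e.read().split('\n')]

# Remove stopwords with NLTK: they carry little useful information and can also hurt model training
stop_words = set(stopwords.words('english'))  # build the set once instead of once per vocabulary word
print(f"vocab length before removing stopwords: {len(vocab)}")
vocab = [word for word in vocab if word not in stop_words]
print(f"vocab length after removing stopwords: {len(vocab)}")
# Turn the list into a word -> index dictionary; start at 1 so that index 0 stays reserved for padding
vocab = {word: idx for idx, word in enumerate(vocab, start=1)}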


vocab length before removing stopwords: 89527
vocab length after removing stopwords: 89356
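A small, self-contained illustration of why the index 0 deserves care: `dict.get` returns the stored value, so a word mapped to 0 is falsy and is silently dropped by a truthiness filter like `if vocab_dic.get(word)`, and 0 is also the padding value used by the collate_fn later. The toy vocabulary and review below are made up for this sketch.

```python
# Hypothetical toy vocabulary, numbered from 0 as enumerate(vocab) does by default
words = ["movie", "great", "boring"]
vocab_from_0 = {word: idx for idx, word in enumerate(words)}

review = ["great", "movie", "unknownword"]

# Truthiness filter: "movie" maps to 0, which is falsy, so it is dropped along with the unknown word
vec_truthy = [vocab_from_0[w] for w in review if vocab_from_0.get(w)]

# Membership filter keeps every in-vocabulary word, including the one at index 0
vec_membership = [vocab_from_0[w] for w in review if w in vocab_from_0]

# Numbering from 1 keeps 0 free for padding, so a padded 0 never collides with a real word
vocab_from_1 = {word: idx for idx, word in enumerate(words, start=1)}
vec_from_1 = [vocab_from_1[w] for w in review if w in vocab_from_1]

print(vec_truthy)      # [1]
print(vec_membership)  # [1, 0]
print(vec_from_1)      # [2, 1]
```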


In [53]:
# Pack the data into (x, y) pairs, where x is the path of a review file and y is positive (1) or negative (0).
# x is a file path rather than the text so you can practice not reading all the data at once. If your machine has
# enough memory (the files are not large in total), you can read everything up front to cut I/O time during training.
path = os.path.join('E:\\jupyter\\dataset\\aclimdb', 'train')
pos_dir_list = [(os.path.join(path, 'pos', i), 1) for i in os.listdir(os.path.join(path, 'pos'))]
neg_dir_list = [(os.path.join(path, 'neg', i), 0) for i in os.listdir(os.path.join(path, 'neg'))]
review_pairs = pos_dir_list + neg_dir_list

print(review_pairs[:2])
print(f"Total reviews: {len(review_pairs)}")

[('E:\\jupyter\\dataset\\aclimdb\\train\\pos\\0_9.txt', 1), ('E:\\jupyter\\dataset\\aclimdb\\train\\pos\\10000_8.txt', 1)]
Total reviews: 25000
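The same pairing can be written with `os.path.join` alone, so it does not depend on Windows-style `\\` separators. This sketch builds a throwaway directory with the `pos/` and `neg/` layout (the file names are made up) and follows the positive = 1, negative = 0 convention stated in the cell comment; `build_review_pairs` is a hypothetical helper, not part of the assignment. Sorting the listing is an extra choice here, since `os.listdir` does not guarantee any order.

```python
import os
import tempfile


def build_review_pairs(train_dir):
    """Hypothetical helper: return (file_path, label) pairs, 1 for pos/ and 0 for neg/."""
    pairs = []
    for sub_dir, label in (("pos", 1), ("neg", 0)):
        folder = os.path.join(train_dir, sub_dir)
        for name in sorted(os.listdir(folder)):  # sorted for a reproducible order
            if name.endswith(".txt"):
                pairs.append((os.path.join(folder, name), label))
    return pairs


# Miniature stand-in for the train/ folder, created in a temporary directory
with tempfile.TemporaryDirectory() as root:
    for sub_dir, names in (("pos", ["0_9.txt", "1_7.txt"]), ("neg", ["0_3.txt"])):
        os.makedirs(os.path.join(root, sub_dir))
        for name in names:
            with open(os.path.join(root, sub_dir, name), "w", encoding="utf-8") as f:
                f.write("placeholder review text")
    pairs = build_review_pairs(root)
    labels = [label for _, label in pairs]
    file_names = [os.path.basename(p) for p, _ in pairs]

print(file_names)  # ['0_9.txt', '1_7.txt', '0_3.txt']
print(labels)      # [1, 1, 0]
```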


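### Build the Dataset, DataLoader, Sampler, and collate_fn to load the data
We need two helper functions here: one reads and cleans a review (load_review), and the other turns the cleaned review into a vector of
word indices (generate_vec). Note that this vector is produced simply by tokenizing the text and mapping each token to its vocabulary index; we deliberately do not use BoW, so that the resulting sequences keep their different lengths.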
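To see why plain tokenization yields variable-length output while BoW does not, compare the two on a toy vocabulary (made up for this sketch). The variable lengths are exactly what the custom collate_fn in this notebook has to pad.

```python
# Hypothetical toy vocabulary; indices start at 1 so that 0 can serve as padding
vocab = {"great": 1, "movie": 2, "boring": 3, "plot": 4}

reviews = [
    ["great", "movie"],
    ["boring", "plot", "boring", "movie", "plot"],
]

# Index sequences: one entry per in-vocabulary token, so length follows the review
index_seqs = [[vocab[w] for w in r if w in vocab] for r in reviews]


def bag_of_words(tokens, vocab):
    """Count vector with one slot per index (slot 0 unused), so its length is fixed."""
    counts = [0] * (len(vocab) + 1)
    for w in tokens:
        if w in vocab:
            counts[vocab[w]] += 1
    return counts


bows = [bag_of_words(r, vocab) for r in reviews]

print(index_seqs)                    # [[1, 2], [3, 4, 3, 2, 4]]
print([len(s) for s in index_seqs])  # [2, 5]
print([len(b) for b in bows])        # [5, 5]
print(bows[1])                       # [0, 0, 1, 2, 2]
```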

In [59]:
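def load_review(review_path):
    """Read one review file and return its list of tokens."""
    with open(review_path, encoding='utf-8') as e:
        review = e.read()
    review = re.sub(r'<br\s*/?>', ' ', review)   # drop the HTML line breaks that IMDB reviews contain
    review = re.sub(r'\W', ' ', review.lower())  # lowercase to match imdb.vocab, then drop non-word characters
    review = nltk.word_tokenize(review)
    return review

def generate_vec(review, vocab_dic):
    """Map tokens to vocabulary indices, skipping words that are not in the vocabulary."""
    vec = [vocab_dic[word] for word in review if word in vocab_dic]
    return vec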
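A cleaning step that only applies `re.sub(r'\W', ' ', ...)` turns the `<br />` line breaks found in IMDB reviews into stray `br` tokens, and capitalized words will not match a lowercase vocabulary such as imdb.vocab unless the text is lowercased. Stray `br` tokens may also explain repeated index pairs like `75154, 75154` in the sample batch output, though that is not verified here. The sketch below compares the two cleaning variants on a made-up fragment; `str.split` stands in for `nltk.word_tokenize` only to avoid the punkt download, and the two tokenizers are not identical in general.

```python
import re

# A made-up review fragment containing the HTML line breaks that IMDB reviews use
raw = "Great movie.<br /><br />I LOVED it!"

# Only replacing non-word characters: the tags leave "br" behind and case is preserved
naive_tokens = re.sub(r"\W", " ", raw).split()

# Sketch of a stricter variant: remove <br /> tags and lowercase before replacing non-word characters
cleaned = re.sub(r"<br\s*/?>", " ", raw).lower()
clean_tokens = re.sub(r"\W", " ", cleaned).split()

print(naive_tokens)  # ['Great', 'movie', 'br', 'br', 'I', 'LOVED', 'it']
print(clean_tokens)  # ['great', 'movie', 'i', 'loved', 'it']
```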

In [60]:
# Custom dataset
class dataset(Dataset):
    '''Custom dataset that loads reviews and their labels lazily, one file per item.

    Parameters
    ----------
    data_pairs: list
        list of (review_file_path, label) tuples
    vocab: dict
        mapping from word to integer index
    '''
    def __init__(self, data_pairs, vocab):
        self.data_pairs = data_pairs
        self.vocab = vocab

    def __len__(self):
        return len(self.data_pairs)

    def __getitem__(self, idx):
        x_path, y = self.data_pairs[idx]
        x_word = load_review(x_path)              # read and tokenize the review file
        x_vec = generate_vec(x_word, self.vocab)  # map tokens to vocabulary indices
        return x_vec, y

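# Custom collate_fn: pad the variable-length reviews with 0 so every review in a batch has the same length
def collate_fn(batch):
    reviews, labels = zip(*batch)
    lengths = torch.LongTensor([len(i) for i in reviews])  # true length of each review before padding
    labels = torch.LongTensor(labels)
    reviews = pad_sequence([
        torch.LongTensor(review) for review in reviews
    ], batch_first=True, padding_value=0)
    return reviews, labels, lengths  # the labels are needed to train on the batch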
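One way to check a collate_fn without touching the review files is to feed a DataLoader a few in-memory samples. The sketch below uses a collate_fn that returns the padded reviews, the labels, and the lengths (the lengths can be useful later, for example with `torch.nn.utils.rnn.pack_padded_sequence`); the toy samples are made up and stand in for `dataset[i]`.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def collate_fn(batch):
    """Pad variable-length index lists with 0; return (padded reviews, labels, lengths)."""
    reviews, labels = zip(*batch)
    lengths = torch.tensor([len(r) for r in reviews], dtype=torch.long)
    labels = torch.tensor(labels, dtype=torch.long)
    reviews = pad_sequence(
        [torch.tensor(r, dtype=torch.long) for r in reviews],
        batch_first=True,
        padding_value=0,
    )
    return reviews, labels, lengths


# Hypothetical in-memory samples: (list of word indices, label)
toy_samples = [([5, 8, 2], 1), ([7], 0), ([3, 3, 9, 4], 1)]

# A plain list works as a map-style dataset because it supports len() and indexing
loader = DataLoader(toy_samples, batch_size=3, shuffle=False, collate_fn=collate_fn)
padded, labels, lengths = next(iter(loader))

print(padded)
# tensor([[5, 8, 2, 0],
#         [7, 0, 0, 0],
#         [3, 3, 9, 4]])
print(labels)   # tensor([1, 0, 1])
print(lengths)  # tensor([3, 1, 4])
```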

In [70]:
# Use PyTorch's RandomSampler to draw indices in random order, then build the DataLoader
reviews_dataset = dataset(review_pairs, vocab)
custom_dataloader = DataLoader(reviews_dataset,
                               batch_size=3,
                               sampler=RandomSampler(reviews_dataset),
                               collate_fn=collate_fn)

next(iter(custom_dataloader))

(tensor([[ 1241,   128,    10, 20044, 17112,  3705,   595, 16901,  3533,  4141,
           6090,   362,    61,   789,  8187,   159,   378,    18,   322,   281,
             25,    93,  1568, 11665,   159,  3189,  1007,    79,   211,  8754,
          75154, 75154,    17,  2428,  3856,   272,   160,    31,    31,   180,
           1808,  1557,    20,    49,    49,    28,  1518, 27011,   394,  3323,
          75154, 75154,    46,    26,   437,    61,    21,   101,   930,  1350,
          27377, 27377,    20,   125,    49,    22,  2218,  8731,   916,  1346,
           3522,   472,    45,  2524,  1894,  3448,  5182,    12,   497,    28,
          16284,  1856,  1094,    53,   224,   104,   981,  8837,    31, 75154,
          75154,   240,   112,    98,    28,    24,  6255,  3239,   261,   128,
            114, 15540,  1177,  1403,  8383,   512,   136,   102,    24,  2283,
            491,   443,   263,  3187,   747,    86,    14,  1810,   215,    79,
           8844,   128, 75154, 75154,   
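The DataLoader above relies on the built-in RandomSampler. Since the assignment goal also covers custom samplers, here is a sketch of one common idea, length bucketing: put reviews of similar length into the same batch so the collate_fn has to pad less. `BucketByLengthSampler` and its parameters are hypothetical, not part of PyTorch; only the `Sampler` base class, which a subclass implements through `__iter__` and `__len__`, comes from the library.

```python
import random

from torch.utils.data import Sampler


class BucketByLengthSampler(Sampler):
    """Hypothetical length-bucketing sampler (not a PyTorch built-in).

    Shuffles all indices, cuts them into chunks of batch_size * bucket_factor,
    sorts each chunk by length, splits it into batch-sized groups, and shuffles
    the order of those groups, so each batch holds reviews of similar length.
    """

    def __init__(self, lengths, batch_size, bucket_factor=50, seed=None):
        self.lengths = list(lengths)
        self.batch_size = batch_size
        self.bucket_factor = bucket_factor
        self.rng = random.Random(seed)

    def __iter__(self):
        indices = list(range(len(self.lengths)))
        self.rng.shuffle(indices)
        chunk_size = self.batch_size * self.bucket_factor
        ordered = []
        for start in range(0, len(indices), chunk_size):
            chunk = sorted(indices[start:start + chunk_size], key=lambda i: self.lengths[i])
            groups = [chunk[j:j + self.batch_size] for j in range(0, len(chunk), self.batch_size)]
            # Keep a trailing short group last so DataLoader batches stay aligned with groups
            tail = groups.pop() if len(groups[-1]) < self.batch_size else None
            self.rng.shuffle(groups)
            if tail is not None:
                groups.append(tail)
            for group in groups:
                ordered.extend(group)
        return iter(ordered)

    def __len__(self):
        return len(self.lengths)


# Toy lengths; with chunk_size = 2 * 3 = 6 all indices fall into a single chunk
toy_lengths = [5, 1, 9, 2, 8, 3]
sampler = BucketByLengthSampler(toy_lengths, batch_size=2, bucket_factor=3, seed=0)
order = list(sampler)
batches = [tuple(order[k:k + 2]) for k in range(0, len(order), 2)]
print(sorted(batches))  # [(1, 3), (4, 2), (5, 0)]
```

To use it with the notebook's data, you would pass something like `sampler=BucketByLengthSampler(lengths, batch_size=3)` to the DataLoader together with the same `batch_size=3`. The `lengths` list has to be known before iterating; one cheap, approximate stand-in (an assumption, not measured here) is the file size of each review from `os.path.getsize`.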