In [1]:
import numpy as np
import pandas as pd
import jieba
from collections import defaultdict
import torch
from operator import add
from functools import reduce
from collections import Counter
from torch.utils.data import DataLoader
from icecream import ic

In [2]:
from self_embedding import get_embedding

load embedding file: dataset/sgns.weibo.word.bz2 end!
函数注释 {'ham': 42, 'eggs': <class 'int'>, 'return': 'Nothing to see here'}
参数值打印 www spam
<class 'str'> <class 'str'>
<enumerate object at 0x000001702E2D18C0>
5 Spring
6 Summer
7 Fall
8 Winter


In [3]:
def add_with_print(all_corpus):
    add_with_print.i = 0
    def _wrap(a, b):
        add_with_print.i += 1  ##这里不能用++
        print('{}/{}'.format(add_with_print.i, len(all_corpus)),end=' ')
        return a+b
    return _wrap

In [4]:
def get_all_vocabulary(train_file_path, vocab_size):
    CUT, SENTENCE = 'cut', 'sentence'
    
    corpus = pd.read_csv(train_file_path)
    corpus[CUT] = corpus[SENTENCE].apply(lambda s:' '.join(list(jieba.cut(s))))
    sentence_counters = map(Counter, map(lambda s:s.split(),corpus[CUT].values))
    chose_words = reduce(add_with_print(corpus), sentence_counters).most_common(vocab_size)
    
    return [w for w,_ in chose_words]

## 字典的get方法
### 语法
get()方法语法：

dict.get(key, value) 
### 参数
key -- 字典中要查找的键。

value -- 可选，如果指定键的值不存在时，返回该默认值。
### 返回值
返回指定键的值，如果键不在字典中返回默认值 None 或者设置的默认值。

In [5]:
def tokenizer(sentence, vocab: dict):
    UNK = 1
    ids = [vocab.get(word, UNK) for word in jieba.cut(sentence)]
    
    return ids

字典(Dictionary) items() 函数以列表返回可遍历的(键, 值) 元组数组。

In [6]:
def get_train_data(train_file, vocab2ids):
    val_ratio = 0.2
    content = pd.read_csv(train_file)
    num_val = int(len(content) * val_ratio)
    
    LABEL, SENTENCE = 'label', 'sentence'
    
    labels = content[LABEL].values
    content['input_ids'] = content[SENTENCE].apply(lambda s:' '.join([str(_id) for _id in tokenizer(s, vocab2ids)]))
    sentence_ids = np.array([[int(_id) for _id in v.split()] for v in content['input_ids'].values])
    
    ids = np.random.choice(range(len(content)), size = len(content))
    ## shuffle ids
    
    train_ids = ids[num_val:]
    val_ids = ids[:num_val]
    
    X_train,y_train = sentence_ids[train_ids], labels[train_ids]
    X_val, y_val = sentence_ids[val_ids], labels[val_ids]
    
    label2id = {label: i for i, label in enumerate(np.unique(y_train))}
    id2label = {i : label for label, i in label2id.items()}
    y_train = torch.tensor([label2id[y] for  y in y_train], dtype = torch.long)
    y_val = torch.tensor([label2id[y] for y in y_val], dtype = torch.long)
    
    return X_train, y_train, X_val, y_val, label2id, id2label

#  DataLoader
做成的数据集放入Data.DataLoader中，可以生成一个迭代器，从而我们可以方便的进行批处理。

## 介绍一下DataLoader(object)的参数：
    
dataset：Dataset类型，从其中加载数据 

batch_size：int，可选。每个batch加载多少样本 

shuffle：bool，可选。为True时表示在每个epoch对数据进行打乱 

sampler：Sampler，可选。从数据集中采样样本的方法。 

num_workers：int，可选。加载数据时使用多少子进程。默认值为0，表示在主进程中加载数据。 

collate_fn：callable，可选。 

pin_memory：bool，可选 

drop_last：bool，可选。True表示如果最后剩下不完全的batch,丢弃。False表示不丢弃

## zip() 函数
用于将可迭代的对象作为参数，将对象中对应的元素打包成一个个元组，然后返回由这些元组组成的列表。

如果各个迭代器的元素个数不一致，则返回列表长度与最短的对象相同，利用 * 号操作符，可以将元组解压为列表。

### zip举例

In [7]:
for x, y in zip([1,2,3], [11,22,33]):
    print(x,y)

1 11
2 22
3 33


## 特别注意：
1.这里额外定义了一个collate_fn 并用在了build_dataloader中
如果没有用collate_fn限制会出现

RuntimeError: each element in list of batch should be of equal size

没有用collate_fn的解决方案： 在if __name__ == '__main__':中设置batch_size=1

2.在运行if __name__ == '__main__':中出现了

Caught TypeError in DataLoader worker process 0.

解决方案：把build_dataloader 中num_workers = 0 （原本为4）

num_workers = 4的方案暂时没有解决

In [8]:
def collate_fn(batch):
    #  batch是一个列表，其中是一个一个的元组，每个元组是dataset中_getitem__的结果
    batch = list(zip(*batch))
    labels = torch.tensor(batch[1], dtype=torch.int32)
    texts = batch[1]
    del batch  
    return labels, texts

In [9]:
def build_dataloader(X_train, y_train, X_val, y_val, batch_size):
    train_dataloader = DataLoader([(x,y) for x, y in zip(X_train, y_train)], batch_size = batch_size, num_workers = 0, shuffle = True, collate_fn= collate_fn)
    val_dataloader = DataLoader([(x,y) for x, y in zip(X_val, y_val)], batch_size = batch_size, num_workers = 0, shuffle = True,collate_fn = collate_fn)
    
    return train_dataloader, val_dataloader

### isinstance() 函数来判断一个对象是否是一个已知的类型

### open (xxx,f)以只读方式打开文件。文件的指针将会放在文件的开头。这是默认模式。

## if __name__ == '__main__':的作用
一个python文件通常有两种使用方法，第一是作为脚本直接执行，第二是 import 到其他的 python 脚本中被调用（模块重用）执行。因此 if __name__ == 'main': 的作用就是控制这两种情况执行代码的过程，在 if __name__ == 'main': 下的代码只有在第一种情况下（即文件作为脚本直接执行）才会被执行，而 import 到其他脚本中是不会被执行的。

具体情况参照：https://blog.csdn.net/heqiang525/article/details/89879056

In [10]:
if __name__ == '__main__':
    #vocab_size = 10000
    #vocabulary = get_all_vocabulary(train_file_path='dataset/train.csv', vocab_size=vocab_size)
    #assert isinstance(vocabulary, list)
    #assert isinstance(vocabulary[0], str)
    #assert len(vocabulary) <= vocab_size
    
    ## with open('dataset/vocabulary.txt','w', encoding='utf-8') as f:
        ## for word in vocabulary:
            ## f.write(word + '\n')
    ## print('write vocabulary is finished!')
    ## 生成vocabulary.txt  生成完后可以不用执行
    f = open('dataset/vocabulary.txt', 'r', encoding='utf-8')
    vocabulary = f.readlines()
    vocabulary = [v.strip() for v in vocabulary]
    
    embedding, token2id, vocab_size = get_embedding(set(vocabulary))
    
    X_train, y_train, X_val, y_val, label2id, id2label = get_train_data('dataset/train.csv', vocab2ids = token2id)
    
    print(X_train, y_train, X_val, y_val, label2id, id2label)
    
    train_loader, val_loader = build_dataloader(X_train, y_train, X_val, y_val, batch_size=128)
    
    for i, (x, y) in enumerate(train_loader):
        ic(x)
        ic(y)
        if i > 10: break

100%|███████████████████████████████████████████████████████████████████████| 195202/195202 [00:02<00:00, 69641.16it/s]
Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\90392\AppData\Local\Temp\jieba.cache
Loading model cost 0.540 seconds.
Prefix dict has been built successfully.
  sentence_ids = np.array([[int(_id) for _id in v.split()] for v in content['input_ids'].values])


[list([1, 24, 1, 61, 4, 439, 1439, 24, 1, 300, 946, 4, 1, 5534, 350, 1, 197, 2232, 315])
 list([877, 620, 936, 1397, 8989, 1, 4, 4121, 10, 21, 232, 1, 7])
 list([6519, 10, 4611, 1, 3216, 4, 6897, 1, 4, 341, 85, 6295, 17]) ...
 list([1353, 4301, 6363, 5028, 191, 762, 4, 6538, 208, 760, 281, 1278, 4, 4877, 6363, 1159, 1424, 8])
 list([256, 743, 7513, 4290, 4, 493, 1, 85, 2855, 298, 14, 719, 4, 7513, 4290, 221, 2855, 298, 14, 115, 1109, 92, 17])
 list([3445, 1258, 4, 1, 2852, 5, 164])] tensor([ 6,  4,  1,  ...,  9, 11, 10]) [list([4970, 890, 5066, 1286, 946, 4, 574, 1, 169, 7456, 7, 1499, 81, 667, 3574, 4087, 2582, 6939, 7])
 list([31, 73, 2718, 5, 4906, 1047, 644, 1])
 list([4556, 141, 8647, 8, 1, 4, 4208, 5817, 1703, 3027, 4, 4027, 31, 341, 4983, 17])
 ...
 list([1, 356, 1, 5, 25, 4, 5148, 90, 8, 4375, 2321, 1, 4, 1, 216, 1, 6727, 7])
 list([107, 13, 1, 91, 3996, 17, 1243, 1, 2008, 1436, 6430, 1609, 31, 7])
 list([1, 3176, 5528, 10, 1, 797, 158, 4633, 154, 1861, 4, 1, 2162, 95, 409, 26,

ic| x: tensor([ 8, 14,  6,  8, 10, 10,  0, 11, 13,  4,  9,  8, 13, 11,  6,  4, 13, 14,
                7, 10,  3,  8, 10,  7,  3,  9,  1,  8, 11,  2,  2,  6, 14,  2, 10,  1,
               10,  8,  4, 13,  1,  6,  7, 14,  3, 14,  6, 13,  2,  8, 11,  6,  6, 14,
                9,  6,  7,  4,  4,  6,  4,  2,  4,  6,  9,  2,  9,  4,  2, 11,  6,  3,
                7, 13, 11,  3,  8,  4,  8,  4,  4,  3,  9,  4, 13,  0, 13,  8,  1,  3,
                7,  1,  8,  4,  2, 11, 12,  5,  2, 14,  1,  6, 11,  1, 14, 11,  8,  6,
                4,  1,  5,  4,  2,  6,  8,  7,  1, 11, 10,  2,  0,  8,  3,  4, 11,  1,
                5,  2], dtype=torch.int32)
ic| y: (tensor(8),
        tensor(14),
        tensor(6),
        tensor(8),
        tensor(10),
        tensor(10),
        tensor(0),
        tensor(11),
        tensor(13),
        tensor(4),
        tensor(9),
        tensor(8),
        tensor(13),
        tensor(11),
        tensor(6),
        tensor(4),
        tensor(13),
        tensor(14

        tensor(10),
        tensor(2),
        tensor(2),
        tensor(14),
        tensor(11),
        tensor(6),
        tensor(4),
        tensor(9),
        tensor(11),
        tensor(5),
        tensor(3),
        tensor(13),
        tensor(6),
        tensor(11),
        tensor(3),
        tensor(3),
        tensor(9),
        tensor(2),
        tensor(1),
        tensor(11),
        tensor(2),
        tensor(1),
        tensor(11),
        tensor(2),
        tensor(14),
        tensor(11),
        tensor(10),
        tensor(14),
        tensor(13),
        tensor(8),
        tensor(13),
        tensor(3),
        tensor(13),
        tensor(6),
        tensor(2),
        tensor(11),
        tensor(11),
        tensor(2),
        tensor(13),
        tensor(8),
        tensor(11),
        tensor(6),
        tensor(11),
        tensor(8),
        tensor(6),
        tensor(3),
        tensor(1),
        tensor(7),
        tensor(8),
        tensor(4),
        tensor(1),
        ten