In [1]:
import pandas as pd
import string 

'''
    单词预处理，将单词全部小写，并且去除标点符号
'''
def preprocessing(phrase):
    lower = [phras.lower() for phras in phrase]    # 将字母全部小写
    no_punct = [text.translate(str.maketrans('','',string.punctuation)) for text in lower]   # 去掉标点符号
    sp = [text.split() for text in no_punct]
    res = [' '.join(lis) for lis in sp]

    return res


def Get_Preprocess_Data(path):
    df = pd.read_csv(path,sep='\t')
    print("原始数据标签统计： ")
    print(df['gold_label'].value_counts())
    df = df[['gold_label','sentence1','sentence2']]
    print("")
    print("去除Nan非法制以及非法标签中... ")
    print("")
    df = df.dropna()
    df = df[df.gold_label.isin(['entailment','neutral','contradiction'])]
    print("处理后数据标签统计： ")
    print(df.gold_label.value_counts())


    df['sentence1'] = preprocessing(df['sentence1'])
    df['sentence2'] = preprocessing(df['sentence2'])
    
    return df

In [2]:
train_df = Get_Preprocess_Data('../Datasets/snli_1.0/snli_1.0/snli_1.0_train.txt')
dev_df = Get_Preprocess_Data('../Datasets/snli_1.0/snli_1.0/snli_1.0_dev.txt')
test_df = Get_Preprocess_Data('../Datasets/snli_1.0/snli_1.0/snli_1.0_test.txt')

原始数据标签统计： 
entailment       183416
contradiction    183187
neutral          182764
-                   785
Name: gold_label, dtype: int64

去除Nan非法制以及非法标签中... 

处理后数据标签统计： 
entailment       183414
contradiction    183185
neutral          182762
Name: gold_label, dtype: int64
原始数据标签统计： 
entailment       3329
contradiction    3278
neutral          3235
-                 158
Name: gold_label, dtype: int64

去除Nan非法制以及非法标签中... 

处理后数据标签统计： 
entailment       3329
contradiction    3278
neutral          3235
Name: gold_label, dtype: int64
原始数据标签统计： 
entailment       3368
contradiction    3237
neutral          3219
-                 176
Name: gold_label, dtype: int64

去除Nan非法制以及非法标签中... 

处理后数据标签统计： 
entailment       3368
contradiction    3237
neutral          3219
Name: gold_label, dtype: int64


In [3]:
from importlib import import_module
from utils import build_vocab,build_dataset

model_name = 'ESIM'

x = import_module('models.' + model_name)
config = x.Config()
print('all class number : ',config.num_classes)

vocab,train_dataset = build_dataset(config,train_df,use_word=True,isTest=False)
vocab,dev_dataset = build_dataset(config,dev_df,use_word=True,isTest=False)
vocab,test_dataset = build_dataset(config,test_df,use_word=True,isTest=True)

config.num_vocab = len(vocab)

all class number :  3
Loading exist vocab ... 
Vocab size: 10002
Loading exist vocab ... 
Vocab size: 10002
Loading exist vocab ... 
Vocab size: 10002


In [4]:
from utils import MyDataset
from torch.utils.data import DataLoader,Dataset


train_dataset = MyDataset(train_dataset)
train_loader = DataLoader(train_dataset,config.batch_size,shuffle=True)

dev_dataset = MyDataset(dev_dataset)
dev_loader = DataLoader(dev_dataset,config.batch_size,shuffle=True)

test_dataset = MyDataset(test_dataset)
test_loader = DataLoader(test_dataset,config.batch_size,shuffle=False)

In [5]:
from train import train

model = x.Model(config).to(config.device)

train(config,model,train_loader,dev_loader,test_loader)

Epoch [1/50]


100%|██████████| 8584/8584 [09:51<00:00, 14.52it/s]


train loss : 0.7938 ,train acc:0.750 , dev loss : 0.7931,dev acc : 0.751 ,test acc : 0.336
saving model ...
Epoch [2/50]


100%|██████████| 8584/8584 [09:54<00:00, 14.44it/s]


train loss : 0.7658 ,train acc:0.780 , dev loss : 0.7735,dev acc : 0.771 ,test acc : 0.339
saving model ...
Epoch [3/50]


100%|██████████| 8584/8584 [09:50<00:00, 14.53it/s]


train loss : 0.7431 ,train acc:0.804 , dev loss : 0.7597,dev acc : 0.787 ,test acc : 0.342
saving model ...
Epoch [4/50]


100%|██████████| 8584/8584 [09:48<00:00, 14.58it/s]


train loss : 0.7296 ,train acc:0.818 , dev loss : 0.7536,dev acc : 0.792 ,test acc : 0.337
saving model ...
Epoch [5/50]


100%|██████████| 8584/8584 [09:52<00:00, 14.49it/s]


train loss : 0.7152 ,train acc:0.833 , dev loss : 0.7471,dev acc : 0.799 ,test acc : 0.341
saving model ...
Epoch [6/50]


100%|██████████| 8584/8584 [09:56<00:00, 14.40it/s]


train loss : 0.7062 ,train acc:0.843 , dev loss : 0.7448,dev acc : 0.800 ,test acc : 0.343
saving model ...
Epoch [7/50]


100%|██████████| 8584/8584 [09:50<00:00, 14.53it/s]


train loss : 0.6969 ,train acc:0.852 , dev loss : 0.7413,dev acc : 0.806 ,test acc : 0.338
saving model ...
Epoch [8/50]


100%|██████████| 8584/8584 [09:54<00:00, 14.43it/s]


train loss : 0.6907 ,train acc:0.858 , dev loss : 0.7443,dev acc : 0.802 ,test acc : 0.346
Epoch [9/50]


100%|██████████| 8584/8584 [09:52<00:00, 14.49it/s]


train loss : 0.6828 ,train acc:0.867 , dev loss : 0.7403,dev acc : 0.808 ,test acc : 0.341
saving model ...
Epoch [10/50]


100%|██████████| 8584/8584 [09:56<00:00, 14.40it/s]


train loss : 0.6765 ,train acc:0.873 , dev loss : 0.7441,dev acc : 0.802 ,test acc : 0.342
Epoch [11/50]


100%|██████████| 8584/8584 [09:54<00:00, 14.43it/s]


train loss : 0.6729 ,train acc:0.877 , dev loss : 0.7380,dev acc : 0.810 ,test acc : 0.343
saving model ...
Epoch [12/50]


 63%|██████▎   | 5392/8584 [06:12<03:40, 14.46it/s]


KeyboardInterrupt: 