# 决赛数据预处理

---

## 环境 & 配置
---

In [1]:
import os
import re
from tqdm import tqdm
import traceback
import random
import sys
import pprint
import jieba

sys.path.insert(0, "/home/team55/notespace/zengbin")

from jddc.config import PreConfig
from jddc.utils import write_file, read_file, save_to_pkl, read_from_pkl, create_logger
from jddc.obj import Session, Sentence
from jddc.seg import JiebaSeg, jieba_tokenize

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.926 seconds.
Prefix dict has been built succesfully.


In [2]:
# Create the preprocessing configuration and the logger shared by the cells
# below; the log file and the console switch both come from the config.
conf = PreConfig()
logger = create_logger(name='pre', log_file=conf.log_file, cmd=conf.cmd_log)

## 1 - 基础处理
---


**处理内容**

1. 读入chat.txt，对超过7个字段的数据行进行处理，整理成7个字段（将第7个字段之后的所有字段与第7个字段合并），结果文件 chat_pred.txt
2. 读入chat_pred.txt，根据session_id划分对话，将每一个对话所有行归集，结果文件 chat_splited.txt
3. 读入chat_splited.txt，解析每一个会话，合并连续q、a，提取订单信息等，结果文件 session_parsed.txt

### a - 合并多余字段

In [None]:
def merge_surplus_col(input_file, output_file):
    """Collapse rows of chat.txt that have more than 7 tab-separated fields.

    For such rows the 7th field and everything after it are joined with a
    full-width comma ("，") into a single 7th field. Rows with 7 or fewer
    fields are left unchanged.

    Args:
        input_file: path of the raw chat file (decoded as UTF-8, a leading
            BOM is tolerated).
        output_file: path the resulting rows are written to via ``write_file``.

    Returns:
        list[str]: one entry per input line, without line terminators.
    """
    with open(input_file, 'r', encoding="utf-8-sig") as f:
        # Strip the terminator from EVERY row. Previously only the merged rows
        # lost their "\n", so write_file received a mix of terminated and
        # unterminated rows.
        lines = [line.rstrip("\r\n") for line in f]

    for i, line in tqdm(enumerate(lines), desc="merge_surplus_col",
                        ncols=100, total=len(lines)):
        cols = line.split('\t')
        if len(cols) > 7:
            # Extra tabs belong to the message text: fold them back into it.
            lines[i] = '\t'.join(cols[:6] + ["，".join(cols[6:])])

    write_file(output_file, lines, mode='w')
    return lines


In [None]:
# Step a: normalise the raw chat file to 7 fields per row.
lines = merge_surplus_col(conf.file_chat, conf.file_chat_pred)

### b - 将数据集按session拆分

In [None]:
def chat_split_by_session(input_file, output_file):
    """Group the rows of chat_pred.txt into sessions by session id.

    Consecutive rows sharing the same first field (session id) are collected
    into one ``{"session_id": ..., "lines": [...]}`` dict. Rows that do not
    have exactly 7 tab-separated fields are logged and skipped.

    Args:
        input_file: path read with ``read_file``.
            NOTE(review): assumes read_file returns one string per row
            without line terminators — confirm.
        output_file: path the sessions are written to, one ``str(dict)`` per
            entry, via ``write_file``.

    Returns:
        list[str]: the ``str()`` representation of every session dict, in
        file order.
    """
    lines = read_file(input_file)

    chat_splited = []
    sess_info = None  # session currently being collected

    for line in tqdm(lines, desc="chat_split_by_session", ncols=100):
        try:
            cols = line.split("\t")
            # Validate before any field is used, so short rows report the
            # field count instead of an IndexError.
            if len(cols) != 7:
                raise ValueError("总共有七个字段，当前行有%i个字段" % len(cols))

            session_id = cols[0]
            if sess_info is not None and sess_info['session_id'] == session_id:
                sess_info['lines'].append(line)
            else:
                if sess_info is not None:
                    chat_splited.append(sess_info)
                sess_info = {
                    "session_id": session_id,
                    "lines": [line]
                }
        except Exception as e:
            logger.error('line error: %s' % line)
            logger.exception(e)

    # Flush the last session after the loop: doing it inside the try block
    # lost the final session whenever the last row was malformed.
    if sess_info is not None:
        chat_splited.append(sess_info)

    chat_splited = [str(x) for x in chat_splited]
    write_file(output_file, chat_splited, mode='w')
    return chat_splited

In [None]:
# Step b: group the normalised rows into per-session records.
chat_splited = chat_split_by_session(conf.file_chat_pred, conf.file_chat_splited)

### c - 会话解析

In [None]:
def _parse_session(sess_info):
    """解析单个session，获取order_id等信息

    返回结果：
        {
        "session_id": 会话id,
        "user_id": user_id,
        "order_id": order_id,
        "sku": 商品品类,
        "transfer": 是否转移,
        "repeat": 是否重复,
        "lines": 原始数据行,
        "qas_merged": 合并之后的对话记录
        }
    """
    lines = sess_info['lines']
    user_id = lines[0].split("\t")[1]
    transfer = list(set([line.split("\t")[3] for line in lines
                         if line.split("\t")[3] != '']))
    repeat = list(set([line.split("\t")[4] for line in lines
                       if line.split("\t")[4] != '']))
    sku = list(set([line.split("\t")[5] for line in lines
                    if line.split("\t")[5] != '']))

    # 提取订单号
    contents = "\t".join([line.split("\t")[6] for line in lines])
    pat_oid = re.compile(r'(ORDERID_\d{8})')
    order_id = list(set(pat_oid.findall(contents)))

    # 合并q/a
    qas = [(line.split("\t")[2], line.split("\t")[6]) for line in lines]
    qas_merged = []
    current_sender = qas[0][0]
    content = qas[0][1]
    for i, qa in enumerate(qas[1:]):
        if current_sender == qa[0]:
            content += "\t" + qa[1]
            # 尾行处理 
            # 必须用行标来定位尾行，不能用内容
            if i == len(qas) - 2:
                qa_ = (current_sender, content)
                qas_merged.append(qa_)
        else:
            qa_ = (current_sender, content)
            qas_merged.append(qa_)
            current_sender = qa[0]
            content = qa[1]
            # 尾行处理
            if i == len(qas) - 2:
                qa_ = (current_sender, content)
                qas_merged.append(qa_)

    return {
        "session_id": sess_info['session_id'],
        "user_id": user_id,
        "order_id": order_id,
        "sku": sku,
        "transfer": transfer,
        "repeat": repeat,
        "lines": lines,
        "qas_merged": qas_merged
    }

def chat_session_parse(input_file, output_file):
    """Parse every session of chat_splited.txt and write the results.

    Each input entry is the ``str()`` of a session dict. It is converted back
    to a dict, passed to ``_parse_session`` and the parsed dict is collected.
    Entries that cannot be converted or parsed are printed together with the
    traceback and skipped.

    Args:
        input_file: path read with ``read_file`` (one session per entry).
        output_file: path the parsed sessions are written to, one
            ``str(dict)`` per entry, via ``write_file``.

    Returns:
        list[str]: the ``str()`` representation of every parsed session.
    """
    # Local import keeps this cell self-contained.
    import ast

    session_parsed = []
    for raw in tqdm(read_file(input_file), desc='chat_session_parse', ncols=100):
        try:
            # The entries are plain dict literals, so literal_eval is enough.
            # eval() would execute whatever text the file contains. Converting
            # inside the try also keeps one corrupt entry from aborting the
            # whole run.
            sess_info = ast.literal_eval(raw)
            session_parsed.append(_parse_session(sess_info))
        except Exception:
            print(raw)
            traceback.print_exc()
    session_parsed = [str(x) for x in session_parsed]
    write_file(output_file, session_parsed, mode='w')
    return session_parsed

In [None]:
# Step c: parse every session and write the result to conf.file_session_parsed.
session_parsed = chat_session_parse(conf.file_chat_splited, conf.file_session_parsed)

In [3]:
# Reload the parsed sessions from disk; each entry is later passed to eval(),
# so it is the text form of one parsed-session dict.
sessions = read_file(conf.file_session_parsed)

In [6]:
# Display the path of the pickled session cache.
conf.pkl_sessions

'/home/team55/notespace/data/temp/all_sessions.pkl'

In [None]:
# Build one Session from the first raw entry for inspection.
# NOTE: only valid while `sessions` holds the raw text entries from read_file;
# a later cell rebinds `sessions` to the list returned by load_sessions().
s = Session(eval(sessions[0]))

In [None]:
# Display the multi_qa attribute of the sample Session.
s.multi_qa

In [4]:
def create_pkl_sessions(sessions):
    """Turn raw session entries into Session objects and pickle the good ones.

    Args:
        sessions: iterable of text entries, each the textual form of one
            parsed-session dict.

    Only objects whose ``data_quality`` is truthy are kept; the resulting
    list is saved to ``conf.pkl_sessions``. Nothing is returned.
    """
    target_path = conf.pkl_sessions
    progress = tqdm(sessions, ncols=100, desc="create sess_objs")
    # NOTE(review): eval() executes the entry text; the entries are expected
    # to be dict literals produced by this pipeline.
    candidates = (Session(eval(raw)) for raw in progress)
    kept = [candidate for candidate in candidates if candidate.data_quality]
    print("save new sessions to %s" % target_path)
    save_to_pkl(target_path, data=kept)

In [7]:
# Build the Session objects from the raw entries and pickle them.
create_pkl_sessions(sessions)

create sess_objs: 100%|█████████████████████████████████| 1025140/1025140 [04:19<00:00, 3947.47it/s]


save new sessions to /home/team55/notespace/data/temp/all_sessions.pkl


In [8]:
def load_sessions():
    """Return the list of Session objects, using the pickle cache if present.

    When ``conf.pkl_sessions`` exists it is loaded and returned. Otherwise the
    parsed-session file is read, every entry is turned into a Session, the
    objects with a truthy ``data_quality`` are kept, saved to the cache path
    and returned.
    """
    cache_path = conf.pkl_sessions

    # Fast path: the cache already exists.
    if os.path.exists(cache_path):
        print("load sessions from %s" % cache_path)
        return read_from_pkl(cache_path)

    print("refresh sessions ...")
    raw_entries = read_file(conf.file_session_parsed)
    kept = []
    for raw in tqdm(raw_entries, ncols=100, desc="create sess_objs"):
        # NOTE(review): eval() executes the entry text; the entries are
        # expected to be dict literals produced by this pipeline.
        candidate = Session(eval(raw))
        if not candidate.data_quality:
            continue
        kept.append(candidate)
    print("save new sessions to %s" % cache_path)
    save_to_pkl(cache_path, data=kept)
    return kept

In [9]:
# From here on `sessions` is the list of Session objects, not raw text entries.
sessions = load_sessions()

load sessions from /home/team55/notespace/data/temp/all_sessions.pkl


In [13]:
# Number of Session objects that passed the data_quality filter.
len(sessions)

989930

### d - 查找所有脱敏词
---


In [None]:
# Gather the text part (index 1) of every merged turn across all sessions.
texts = []
for sess in tqdm(sessions, ncols=100, desc="find desensitization"):
    texts.extend(turn[1] for turn in sess.qas_merged)

In [None]:
import re

In [None]:
# pat1: shortest span that starts at "#" and ends at the first "]" following a "[".
# pat2: shortest "[...]" span.
# Raw strings keep "\[" from being treated as an (invalid) string escape.
pat1 = re.compile(r"(#.*?\[.*?\])")
pat2 = re.compile(r"(\[.*?\])")

In [None]:
# Join the corpus once and run both patterns over the same string; the join
# was previously built twice.
joined_texts = " ".join(texts)
res_pat1 = pat1.findall(joined_texts)
res_pat2 = pat2.findall(joined_texts)
del joined_texts  # release the large temporary string

## 2 - 数据清洗 & 拆分

---

1. 根据qaqaq的长度进行清洗，仅保留长度在(30, 500)范围内的对话

### a - 选取1000 / 10000 / 100000个session进行开发测试


In [10]:
# Draw 1000 sessions at random (without replacement) and pickle them.
# NOTE: no random seed is set in this notebook, so the sample differs per run.
s_1000 = random.sample(sessions, 1000)
save_to_pkl(conf.pkl_mqa_1000, s_1000)

In [11]:
# Draw 10000 sessions at random (without replacement) and pickle them.
# NOTE: no random seed is set in this notebook, so the sample differs per run.
s_10000 = random.sample(sessions, 10000)
save_to_pkl(conf.pkl_mqa_10000, s_10000)

In [12]:
# Draw 100000 sessions at random (without replacement) and pickle them.
s_100000 = random.sample(sessions, 100000)
# Fixed: this sample used to be written to conf.pkl_mqa_10000, overwriting the
# 10000-session sample saved by the previous cell.
# NOTE(review): assumes PreConfig defines pkl_mqa_100000 alongside
# pkl_mqa_1000 / pkl_mqa_10000 — add it to the config if it is missing.
save_to_pkl(conf.pkl_mqa_100000, s_100000)

## 3 - 构造用于训练词向量的数据集

---

In [None]:
def create_sentences_for_embedding():
    """Build the segmented corpus used to train word embeddings.

    Reads the module-level ``sessions`` list, takes the text of every merged
    turn (tabs replaced by spaces), segments it with a ``JiebaSeg`` built
    from ``conf.file_stopwords`` and writes one space-joined token sequence
    per turn to ``conf.file_texts_for_embedding``.

    Returns:
        list[str]: the space-joined token sequences, one per turn.

    Raises:
        FileNotFoundError: if ``conf.file_stopwords`` does not exist.
    """
    # Fail fast: previously a missing stopword file left `jb_seg` unbound and
    # the function crashed only after the whole corpus had been collected.
    if not os.path.exists(conf.file_stopwords):
        raise FileNotFoundError(
            "stopwords file not found: %s" % conf.file_stopwords)
    jb_seg = JiebaSeg(file_stopwords=conf.file_stopwords)

    texts = []
    for sess in tqdm(sessions, ncols=100, desc="create texts for embedding"):
        # Tabs separate the rows merged into one turn; flatten them to spaces.
        texts.extend(x[1].replace('\t', " ") for x in sess.qas_merged)

    sentences = [
        " ".join(Sentence(text, seg=jb_seg).cuted_sentence)
        for text in tqdm(texts, ncols=100, desc="cut sentence")
    ]

    write_file(file=conf.file_texts_for_embedding, content=sentences, mode='w', encoding='utf-8')
    return sentences

In [None]:
# Build and write the segmented corpus for word-embedding training.
sentences = create_sentences_for_embedding()

## 4 - 单轮训练集构建

---

1. 构造QAQAQ+Q形式的训练集
2. 仅使用Q来进行匹配

## 5 - 多轮训练集构建

----

1. 字段列表：session id, question, answer