In [15]:
# coding=utf-8
import sys
from os.path import abspath, join, dirname, exists
import pandas as pd
import numpy as np
import re
from tqdm import tqdm
import time
import jieba
import random
import csv
import difflib
import ast
import json
import nest_asyncio
import torch

sys.path.insert(0, join(abspath(dirname('__file__')), '../../../'))

import application as app
from chatgpt import gpt_utils
from extensions import utils

from models.analyze.analyze_tool import used_time
from models.tiktok.video_common import *
from models.nlp.common import text_normalize

In [4]:
# Open the database handles used by the rest of the notebook (the connection
# log printed under this cell shows which host/schema each one points at).
# `dy2`: used for the douyin2_cleaner queries in the cells below.
dy2 = app.connect_db('dy2')

# `db_sop`: ClickHouse client; in the visible code only get_data() uses it (task 11).
db_sop = app.get_clickhouse('chsop')
db_sop.connect()

# `db`: default connection; not referenced by any visible cell.
db = app.connect_db('default')

Connect db 10.21.90.164 douyin2
Connect clickhouse 10.21.90.164 sop
Connect db 10.21.90.164 cleaner


In [24]:
# Run configuration:
#   categorynm - category display name, used in the GPT prompt and as the classifier label
#   cat        - category id embedded in the douyin_video_*_<cat> table names
#   subcid     - sub-category filter; 0 (falsy) means "any non-zero sub_cid"
categorynm, cat, subcid = '运动鞋/运动穿搭', 3000003, 0

In [11]:
# Chat-prompt template sent to GPT for every candidate text: one system message
# followed by three user messages (requirements, few-shot examples, caveats +
# the actual question). Messages [1] and [3] contain `{category_nm}` / `{txt}`
# placeholders that the message-building cell fills with str.format; message
# [2] has no placeholders and is sent unchanged.
# NOTE(review): the few-shot examples in message [2] are all about 巧克力
# (chocolate) rather than the configured `categorynm`; consider category-specific
# examples. The answer-parsing cell depends on GPT starting its reply with the
# exact phrases "此文本与{category_nm}相关" / "此文本与{category_nm}无关" requested here.
prompt_head = [
        {"role": "system", "content": '''假设你是一个资深短视频内容分析师，熟悉短视频内容，十分严谨，擅长识别短视频文本的主题，擅长对短视频文本做归类，擅长分析短视频文本的描述主体。'''},
    {"role": "user", "content": '''【需求】
                我会先给你一段文本，请你，
                1、先找出与{category_nm}相关的实体或产品，只要找到，就回答“此文本与{category_nm}相关，提到了与{category_nm}相关的xxx产品”并结束对话，
                2、如果找不到与{category_nm}相关的实体或产品，请你判断整体文本是否跟“{category_nm}”相关，如果不相关请直接回答“此文本与{category_nm}无关”并简要说明理由，如果相关请回答“此文本与{category_nm}相关”并简要说明理由。'''},
     {"role": "user", "content": '''【示例】
                例子1、文本：“我喜欢在他爸手里,,我想我的亲人你说你俩忙一天,你俩吃饭了吗?这身体累垮了怎么办,?等于,赶快吃了吧,,私底下就是底下新品半条士力架,糖分减少58%,到某烟卖两种口味,香脆坚果和谷物在脖上,铃铛和小好吃不甜腻,工作没时间吃饭”，此文本与巧克力相关，因为提到了士力架巧克力的特点。
                例子2、文本：“有白桃茉莉味,紫薯芋泥味,杨枝甘露味,竹香抹茶味,黑巧摩卡味,黑巧珍珍味,外皮糯唧唧的超q弹,白桃茉莉味尝起来很清新,里面是真的有白桃果肉,甜度也刚好,每一口都很享受,黑巧摩卡有一种很细腻的巧克力摩卡风味,搭配着糯唧唧的表皮”，此文本与巧克力无关，“巧克力摩卡风味”与巧克力相关，但整体文本在说的是口味，不是巧克力产品。
                例子3、文本：“嘿嘿嘿,你这样拿冰糖吃呀,你都把我的棒棒糖全部烫起来了,好了好了放回去放回去,可怜的家里还有,气消米乳酸菌的,还有那个巧克力味的蛋糕,这次的气消米不一样了”，此文本与巧克力无关，因为实际上说的是巧克力味的蛋糕。'''},
    {"role": "user", "content": '''【注意】
                1、短视频内容的ocr和asr文本一般较长，而且可能有错别字，句子乱序，因此需要你先去噪。
                2、与“{category_nm}”相关的产品可能只是在内容中简单提到，因此需要你有极强的总结分析能力。
                3、回答时，先说结论（此文本与{category_nm}相关 或 此文本与{category_nm}无关），再简要说明理由。
                【提问】
                下面我会给你文本，用```包起来，请你按需求，参考示例的分析，遵守注意点，给出你的回答。
            ```{txt}```直接给出你的回答，不要拖泥带水，不要产生幻觉。'''}
]

In [12]:
# Decide which videos are sent to GPT.
# digg_count>10000 is an arbitrary popularity threshold and can be changed freely.
# A falsy `subcid` (0) keeps every video with a non-zero sub_cid; otherwise only
# the sub_cid value(s) interpolated into `in (...)` are kept.
if subcid:
    sub_cid_cond = f'sub_cid in ({subcid})'
else:
    sub_cid_cond = 'sub_cid<>0'

sql = f'''
    select distinct aweme_id
    from douyin2_cleaner.douyin_video_zl_{cat}
    where digg_count>10000
    and aweme_id in (select aweme_id from douyin2_cleaner.douyin_video_sub_cid_zl_{cat} where {sub_cid_cond})
    limit 10
    '''
rr = dy2.query_all(sql)
awmids = [str(rec[0]) for rec in rr]

In [13]:
# Display the sampled aweme ids (at most 10 because of the LIMIT in the query above).
awmids

['7257562810513640759',
 '7267245942326250793',
 '7255228628684852519',
 '7265139593673018636',
 '7212930080870386959',
 '7230776140158225725',
 '7200307004185873703',
 '7201072338413063483',
 '7210204920631283001',
 '7242229599843700000']

In [16]:
# Fetch data for the selected videos.
# Task ids to pull per video; get_data() runs one SQL template per task id.
multi_tasks = [1, 2, 3, 14]
# Human-readable task names, looked up in `type2name` (presumably provided by
# the `video_common` star import).
task_names = list(map(type2name.__getitem__, multi_tasks))

def get_data(awmids, multi_tasks=(1, 2, 3, 12, 14), prefix='', category='6'):
    """Fetch the per-task text rows for a set of videos.

    Args:
        awmids: aweme ids (strings); comma-joined verbatim into each template's
            ``{ids}`` slot.
        multi_tasks: task ids to fetch; each must be a key of the module-level
            ``data_sql`` map. Task 11 is queried through ``db_sop`` (ClickHouse),
            every other task through ``dy2``.
        prefix: table-version string substituted for ``{version}``.
        category: category id substituted for ``{category}``.

    Returns:
        ``{aweme_id: {task: [row_rest, ...]}}``. Every video that appears in at
        least one task's result gets an entry for *every* requested task (empty
        list if that task returned nothing for it). Each ``row_rest`` is the list
        of columns after the leading aweme_id column.
    """
    multi_data = {}
    if not awmids:
        # Nothing to fetch; an empty ``{ids}`` substitution would presumably
        # yield an invalid ``in ()`` clause and make the query fail.
        return multi_data

    ids = ','.join(awmids)
    for task in multi_tasks:
        sql = data_sql[task].format(version=prefix, category=category, ids=ids)
        if task == 11:
            rows = db_sop.query_all(sql, print_sql=False)
        else:
            rows = dy2.query_all(sql)
        for aweme_id, *info in rows:
            if aweme_id not in multi_data:
                multi_data[aweme_id] = {t: [] for t in multi_tasks}
            multi_data[aweme_id][task].append(info)

    return multi_data

def get_data_source(dy2, cat, prefix=''):
    """Return the ``data_source`` configured for the project table of ``cat``.

    Args:
        dy2: connection object exposing ``query_all(sql)``.
        cat: category id used to build the ``douyin_video_zl{prefix}_{cat}`` name.
        prefix: optional table-version string inserted after ``douyin_video_zl``.

    Returns:
        The stored ``data_source`` value, or ``None`` when the project table has
        no row for this table name or the stored value is empty/0.
    """
    tblnm = f"douyin_video_zl{prefix}_{cat}"
    sql = f"select data_source from {project_table} where table_name='{tblnm}';"
    rr = dy2.query_all(sql)
    # A missing project row used to raise IndexError; treat it like an unset
    # value, mirroring how get_project_cat handles an empty result.
    if not rr or not rr[0][0]:
        return None
    return rr[0][0]

def get_project_cat(dy2, category, prefix):
    """Look up the brand / prop category ids configured for a project table.

    Args:
        dy2: connection object exposing ``query_all(sql)``.
        category: category id; also the fallback for any missing value.
        prefix: table-version string inserted after ``douyin_video_zl``.

    Returns:
        ``(brand, prop)``. Either value falls back to ``category`` when it is
        empty in the project row, and both do when no row exists.
    """
    table_name = f"douyin_video_zl{prefix}_{category}"
    rows = dy2.query_all(f'''
    select brand, prop from douyin2_cleaner.project
    where table_name='{table_name}'
    ''')
    if rows:
        brand_val, prop_val = rows[0][0], rows[0][1]
        return (brand_val or category), (prop_val or category)
    return category, category

# Resolve the project's configured category ids and data source, then fetch
# the texts for every sampled video.
brand_cat, prop_cat = get_project_cat(dy2, cat, '')

data_type = get_data_source(dy2, cat)
print('data_type', data_type)
# Data sources 2/3 (xiaohongshu templates) and 4 (juliang templates) only use
# task types 1-3; their SQL templates are merged into the shared `data_sql` map.
# NOTE(review): update() mutates `data_sql` in place, so re-running this cell
# for a different source keeps previously merged templates.
if data_type in (2, 3, 4):
    multi_tasks = [1, 2, 3]
    data_sql.update(juliang_data_sql if data_type == 4 else xiaohongshu_data_sql)

multi_data = get_data(awmids, multi_tasks=multi_tasks, category=cat)

# One row per video: ['_<aweme_id>', <task-1 texts>, <task-2 texts>, ...] where
# each task column is its normalized texts joined with ','.
res = [
    ['_' + str(aweme_id)]
    + [','.join(text_normalize(row[0]) for row in rows) for rows in per_task.values()]
    for aweme_id, per_task in multi_data.items()
]

data_type 5


In [17]:
# Filtering rule: low-quality text is not sent to GPT.
def count_chars(s):
    """Return the fraction of characters in ``s`` that are non-ASCII letters.

    Despite the name, this is a ratio, not a count: alphabetic characters
    outside ASCII (CJK ideographs, but also e.g. accented Latin letters or
    kana) divided by ``len(s)``. Digits, spaces, punctuation and ASCII letters
    only add to the denominator.

    Raises:
        ValueError: if ``s`` is empty or None.
    """
    if not s:
        raise ValueError("error")
    non_ascii_letters = sum(1 for ch in s if ch.isalpha() and not ch.isascii())
    return non_ascii_letters / len(s)

In [18]:
# Pick the text(s) sent to GPT for each video: the first good-quality
# transcript in priority order whp -> ocr -> xf, plus the tt text when it is
# good too. (Column meanings presumably: tt = title/desc, xf/whp = ASR
# variants; they follow the order of `multi_tasks`.)
MIN_TEXT_LEN = 20     # texts must be longer than this
MIN_CJK_RATIO = 0.5   # share of non-ASCII letters, see count_chars


def _is_quality_text(text):
    """True if ``text`` is non-empty, mostly non-ASCII letters and long enough."""
    return bool(text) and count_chars(text) > MIN_CJK_RATIO and len(text) > MIN_TEXT_LEN


ress = []
for row in tqdm(res):
    # Rows have one column per task after the id. With multi_tasks [1, 2, 3]
    # (data_type 2/3/4) there is no 4th text column, which used to break the
    # fixed 5-way unpack; treat the missing column as empty.
    awm, tt, xf, ocr, *rest = row
    whp = rest[0] if rest else ''
    for candidate in (whp, ocr, xf):
        if _is_quality_text(candidate):
            ress.append((awm, candidate))
            break
    if _is_quality_text(tt):
        ress.append((awm, tt))

100%|██████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 12539.03it/s]


In [20]:
# Spot-check the first selected ('_'-prefixed aweme id, text) pair.
ress[0]

('_7200307004185873703',
 '生活不易,被理解也是一种幸福。#711便利店#暖心饮系列 #给生活加点甜 #记录真实生活 #粤语 711便利店;暖心饮系列;给生活加点甜;记录真实生活;粤语')

In [None]:
# Build the GPT chat requests: one per distinct candidate text.
import copy

h_tt_mess = dict()  # candidate text -> str(messages), the key used to look up its answer
messages = []
for ii in ress:
    for txtt in ii[1:]:
        if txtt in h_tt_mess:
            # Same text already queued (e.g. a title shared by several videos).
            # It maps to the identical message string, so a second request
            # would only duplicate a paid API call.
            continue
        # Deep-copy so filling placeholders never mutates the shared template
        # (replaces the str() / ast.literal_eval round-trip).
        prompthead = copy.deepcopy(prompt_head)
        prompthead[1]['content'] = prompthead[1]['content'].format(category_nm=categorynm)
        prompthead[3]['content'] = prompthead[3]['content'].format(category_nm=categorynm, txt=txtt)
        h_tt_mess[txtt] = str(prompthead)
        messages.append(prompthead)

In [None]:
# Send every prompt to GPT. Top-level `await` works because IPython runs
# coroutines on the notebook's event loop (autoawait). The first return value
# is discarded; the second is a mapping that the result-mapping cell indexes
# by str(message).
# NOTE(review): `nest_asyncio` is imported but `nest_asyncio.apply()` is not
# called in the visible cells; confirm whether gpt_utils needs it.
_, ask_result_dict3 = await gpt_utils.get_chat_answers_nint(messages, mode='nintjp', model='gpt3')

In [None]:
# Map GPT answers back to labels: '1' = related to the category, '0' = unrelated.
# The marker phrases must match the wording the prompt template asks GPT to use.
resss = []
ccp = f"此文本与{categorynm}相关"
ccn = f"此文本与{categorynm}无关"
n_missing = 0     # no / malformed response for the prompt
n_unlabeled = 0   # response contains neither marker phrase

for ii in tqdm(ress):
    for txtt in ii[1:]:
        if len(txtt) < 20:
            continue
        # .get: a failed request should be counted, not abort the whole loop
        # (assumes ask_result_dict3 is dict-like).
        response = ask_result_dict3.get(h_tt_mess[txtt])
        try:
            resp_detail = response["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            n_missing += 1
            continue
        if not isinstance(resp_detail, str):
            n_missing += 1
            continue
        # Check "unrelated" first, as before, in case both phrases appear.
        if ccn in resp_detail:
            gpt = '0'
        elif ccp in resp_detail:
            gpt = '1'
        else:
            # Previously the label from the *previous* text was reused here.
            n_unlabeled += 1
            continue
        resss.append([ii[0], txtt, resp_detail, gpt])

print(f'labeled: {len(resss)}, missing/malformed responses: {n_missing}, unparseable answers: {n_unlabeled}')

In [None]:
# Build the train / dev / label files for the text classifier from the GPT labels.
import os

from sklearn.model_selection import train_test_split

train_data_path = f'./data/{cat}/train.txt'
dev_data_path = f'./data/{cat}/dev.txt'
label_path = f'./data/{cat}/label.txt'

pos_label = f'{categorynm}'
neg_label = f'不是{categorynm}'


def _unique_texts(rows, label):
    """Texts of ``rows`` whose last field equals ``label``, de-duplicated in first-seen order.

    De-duplicating before the split keeps one text from landing in both train and dev.
    """
    return list(dict.fromkeys(row[1] for row in rows if row[-1] == label))


def _split(texts, test_size=0.2, seed=2023):
    """Split ``texts`` into (train, dev); lists too small to split all go to train.

    train_test_split raises ValueError for 0 samples, and for 1 sample the
    rounded-up dev share would leave the train set empty.
    """
    if len(texts) < 2:
        return list(texts), []
    return train_test_split(texts, test_size=test_size, random_state=seed)


def _tsv_safe(text):
    """Replace tab/newline characters so each sample stays on one ``text<TAB>label`` line."""
    return text.replace('\t', ' ').replace('\r', ' ').replace('\n', ' ')


def _write_split(path, positives, negatives):
    """Write positives (pos_label) first, then negatives (neg_label), one TSV line each."""
    with open(path, 'w', encoding='utf-8') as f:
        for text in positives:
            f.write(f'{_tsv_safe(text)}\t{pos_label}\n')
        for text in negatives:
            f.write(f'{_tsv_safe(text)}\t{neg_label}\n')


resone = _unique_texts(resss, '1')
reszero = _unique_texts(resss, '0')

train_one, test_one = _split(resone)
train_zero, test_zero = _split(reszero)

# open(..., 'w') does not create missing parent directories.
os.makedirs(os.path.dirname(train_data_path), exist_ok=True)

_write_split(train_data_path, train_one, train_zero)
_write_split(dev_data_path, test_one, test_zero)

# One label per line: the positive label first, then the negative one.
with open(label_path, 'w', encoding='utf-8') as f:
    f.write(f'{pos_label}\n')
    f.write(f'{neg_label}\n')


In [27]:
# Start training the model.
# Output directory passed to train.py via --output_dir.
output_dir = f'models/{cat}'
# Pretrained backbone passed to train.py via --model_name_or_path.
model = "ernie-3.0-tiny-medium-v2-zh"

# train.py is a local script not shown here (its log below reports a Paddle
# commit id). IPython expands $name to the Python variables defined above
# before running the shell command. The trailing backslashes join all lines
# into one command, so comments cannot be placed inside it.
# NOTE(review): --train_path / --dev_path / --label_path must point at the
# files written by the training-data cell (./data/<cat>/...).
! python train.py \
    --train_path $train_data_path\
    --dev_path $dev_data_path\
    --label_path $label_path\
    --do_train \
    --do_eval \
    --do_export \
    --model_name_or_path $model \
    --output_dir $output_dir \
    --device gpu \
    --num_train_epochs 10 \
    --early_stopping True \
    --early_stopping_patience 5 \
    --learning_rate 3e-5 \
    --max_length 512 \
    --per_device_eval_batch_size 32 \
    --per_device_train_batch_size 32 \
    --metric_for_best_model accuracy \
    --load_best_model_at_end \
    --logging_steps 5 \
    --evaluation_strategy epoch \
    --save_strategy epoch \
    --save_total_limit 1

[32m[2023-12-12 16:22:26,329] [    INFO][0m - The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).[0m
[32m[2023-12-12 16:22:26,330] [    INFO][0m -      Model Configuration Arguments      [0m
[32m[2023-12-12 16:22:26,330] [    INFO][0m - paddle commit id              :3a1b1659a405a044ce806fbe027cc146f1193e6d[0m
[32m[2023-12-12 16:22:26,330] [    INFO][0m - export_model_dir              :None[0m
[32m[2023-12-12 16:22:26,330] [    INFO][0m - model_name_or_path            :ernie-3.0-tiny-medium-v2-zh[0m
[32m[2023-12-12 16:22:26,330] [    INFO][0m - [0m
[32m[2023-12-12 16:22:26,330] [    INFO][0m -       Data Configuration Arguments      [0m
[32m[2023-12-12 16:22:26,330] [    INFO][0m - paddle commit id              :3a1b1659a405a044ce806fbe027cc146f1193e6

loss: 0.69463539, learning_rate: 8.772e-07, global_step: 5, interval_runtime: 2.5993, interval_samples_per_second: 61.55544275448713, interval_steps_per_second: 1.9236075860777229, epoch: 0.0292
loss: 0.6883739, learning_rate: 1.754e-06, global_step: 10, interval_runtime: 1.4299, interval_samples_per_second: 111.89250934412473, interval_steps_per_second: 3.496640917003898, epoch: 0.0585
loss: 0.681604, learning_rate: 2.632e-06, global_step: 15, interval_runtime: 1.4415, interval_samples_per_second: 110.99371405378531, interval_steps_per_second: 3.468553564180791, epoch: 0.0877
loss: 0.66056333, learning_rate: 3.509e-06, global_step: 20, interval_runtime: 1.4153, interval_samples_per_second: 113.0515593579587, interval_steps_per_second: 3.532861229936209, epoch: 0.117
loss: 0.63736095, learning_rate: 4.386e-06, global_step: 25, interval_runtime: 1.4078, interval_samples_per_second: 113.65066300729386, interval_steps_per_second: 3.551583218977933, epoch: 0.1462
loss: 0.62513413, learning