In [1]:
import re
import json
import torch
import numpy as np
from tqdm import tqdm 
import itertools

def load_attr_dict(file):
    """Load the attribute dictionary and flatten its alias groups.

    The JSON file maps each attribute name to a list of strings, where one
    string may join several values with '='. Every such string is split on
    '=' and the pieces are concatenated, in file order, into one flat list
    per attribute.

    Args:
        file: Path of the UTF-8 encoded JSON file.

    Returns:
        dict mapping attribute name -> flat list of value strings.
    """
    with open(file, 'r', encoding='utf-8') as handle:
        raw_groups = json.load(handle)
    return {
        name: [alias for group in groups for alias in group.split('=')]
        for name, groups in raw_groups.items()
    }


def generate_key_attr(path):
    """Mine key attributes from titles and bucket the records four ways.

    Every non-blank line of ``path`` is parsed as a JSON record. Digits and
    the character '年' are stripped from the title; each attribute value
    from ``./data/attr_to_attrvals.json`` that occurs in the title is stored
    in ``data['key_attr']``; the cleaned title is written back; and the
    record is re-serialized into one of four lists depending on
    ``data['match']['图文']`` and on whether ``key_attr`` ended up non-empty.

    Fixes relative to the previous version:
      * '闭合方式' is now skipped only when the title contains neither '鞋'
        nor '靴' (the old ``or`` skipped it unless both were present).
      * The '中长款' -> '中款' and '中长裙' -> '中裙' rewrites are actually
        applied to the title (``str.replace`` returns a new string).
      * ``data['title']`` is stored after all rewrites, so the
        '厚度常规' -> '常规厚度' rewrite reaches the output, matching
        ``process_title``.
      * Blank lines are skipped instead of crashing ``json.loads``.

    Args:
        path: Path of a UTF-8 text file holding one JSON object per line.
            Each record must provide ``title``, a dict ``key_attr`` and
            ``match['图文']``.

    Returns:
        Tuple ``(pos_attr_yes, pos_attr_no, neg_attr_yes, neg_attr_no)`` of
        lists of newline-terminated JSON strings.
    """
    attr_dict = load_attr_dict("./data/attr_to_attrvals.json")
    pos_attr_yes = []
    pos_attr_no = []
    neg_attr_yes = []
    neg_attr_no = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in tqdm(f):
            if not line.strip():
                continue
            data = json.loads(line)
            # Strip digits and the year marker from the title.
            title = ''.join(
                ch for ch in data['title'] if not ch.isdigit() and ch != '年'
            )

            for query, values in attr_dict.items():
                # The fly-type attribute only makes sense for trousers.
                if query == '裤门襟' and '裤' not in title:
                    continue
                # The closure attribute only makes sense for shoes or boots:
                # skip when the title mentions neither.
                if query == '闭合方式' and '鞋' not in title and '靴' not in title:
                    continue
                # '中长款' / '中长裙' contain '长款' / '长裙' as substrings, so
                # they are resolved up front and renamed in the title to keep
                # the title consistent with the stored attribute.
                if query == '衣长' and '中长款' in title:
                    data['key_attr'][query] = '中款'
                    title = title.replace('中长款', '中款')
                    continue
                if query == '裙长' and '中长裙' in title:
                    data['key_attr'][query] = '中裙'
                    title = title.replace('中长裙', '中裙')
                    continue
                # No break on purpose: the last matching value wins, as before.
                for value in values:
                    if value in title:
                        data['key_attr'][query] = value

            title = title.replace('厚度常规', '常规厚度')
            data['title'] = title

            is_match = data['match']['图文'] == 1
            has_attr = bool(data['key_attr'])
            serialized = json.dumps(data, ensure_ascii=False) + '\n'
            if is_match:
                (pos_attr_yes if has_attr else pos_attr_no).append(serialized)
            else:
                (neg_attr_yes if has_attr else neg_attr_no).append(serialized)

    print('pos attr yes {:} | pos attr no {:} | neg attr yes {:} | neg attr no {:} |'.format(len(pos_attr_yes), len(pos_attr_no), len(neg_attr_yes), len(neg_attr_no)))
    print('sum is : ', len(pos_attr_yes) + len(pos_attr_no) + len(neg_attr_yes) + len(neg_attr_no))
    return pos_attr_yes, pos_attr_no, neg_attr_yes, neg_attr_no

# Run attribute mining on the coarse file. generate_key_attr returns four
# lists of JSON lines: (match == 1, match != 1) x (key_attr non-empty, empty).
coarse_path = './data/train_coarse.txt'

coarse_pos_attr_yes, coarse_pos_attr_no, coarse_neg_attr_yes, coarse_neg_attr_no = generate_key_attr(coarse_path)

100000it [01:45, 947.63it/s]

pos attr yes 88463 | pos attr no 1125 | neg attr yes 10412 | neg attr no 0 |
sum is :  100000





In [2]:
import json
# Build the grouped attribute table {attr: [[alias, ...], [alias, ...]]}:
# each inner list is one '='-joined group from attr_to_attrvals.json.
with open('./data/attr_to_attrvals.json', 'r', encoding='utf-8') as f:
    attr_key = json.load(f)
print(attr_key)

new_dic = {}
for key, value in attr_key.items():
    groups = []
    for v in value:
        if key == '裙长' and '=' in v and '中长裙' in v:
            # The whole '中裙=中长裙' group collapses to the single name '中裙'.
            groups.append(['中裙'])
        elif key == '衣长' and '=' not in v and '中长款' in v:
            # The stand-alone '中长款' value is renamed to '中款'.
            groups.append(['中款'])
        else:
            # split('=') yields [v] when v has no '=', so one call covers both cases.
            groups.append(v.split('='))
    new_dic[key] = groups

# Persist the table as a single JSON line.
dic_rets = [json.dumps(new_dic, ensure_ascii=False) + '\n']
with open('./data/attr_match_78.json', 'w', encoding='utf-8') as f:
    f.writelines(dic_rets)

{'领型': ['高领=半高领=立领', '连帽=可脱卸帽', '翻领=衬衫领=POLO领=方领=娃娃领=荷叶领', '双层领', '西装领', 'U型领', '一字领', '围巾领', '堆堆领', 'V领', '棒球领', '圆领', '斜领', '亨利领'], '袖长': ['短袖=五分袖', '九分袖=长袖', '七分袖', '无袖'], '衣长': ['超短款=短款=常规款', '长款=超长款', '中长款'], '版型': ['修身型=标准型', '宽松型'], '裙长': ['短裙=超短裙', '中裙=中长裙', '长裙'], '穿着方式': ['套头', '开衫'], '类别': ['手提包', '单肩包', '斜挎包', '双肩包'], '裤型': ['O型裤=锥形裤=哈伦裤=灯笼裤', '铅笔裤=直筒裤=小脚裤', '工装裤', '紧身裤', '背带裤', '喇叭裤=微喇裤', '阔腿裤'], '裤长': ['短裤', '五分裤', '七分裤', '九分裤=长裤'], '裤门襟': ['松紧', '拉链', '系带'], '闭合方式': ['松紧带', '拉链', '套筒=套脚=一脚蹬', '系带', '魔术贴', '搭扣'], '鞋帮高度': ['高帮=中帮', '低帮']}


In [3]:
def process_title(path):
    """Normalize the title and two ambiguous key attributes of every record.

    Every non-blank line of ``path`` is parsed as a JSON record. Digits and
    the character '年' are stripped from the title. When ``key_attr`` holds
    '衣长' == '中长款' or '裙长' == '中长裙', the attribute is renamed to
    '中款' / '中裙' and the same rename is applied inside the title. Finally
    '厚度常规' in the title becomes '常规厚度'.

    The file is now opened explicitly as UTF-8, like every other file in
    this notebook, so decoding does not depend on the platform default.
    Blank lines are skipped instead of crashing ``json.loads``.

    Args:
        path: Path of a UTF-8 text file holding one JSON object per line.
            Each record must provide ``title`` and a dict ``key_attr``.

    Returns:
        List of newline-terminated JSON strings, one per input record.
    """
    # attribute name -> (ambiguous value, replacement value)
    renames = {
        '衣长': ('中长款', '中款'),
        '裙长': ('中长裙', '中裙'),
    }
    rets = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in tqdm(f):
            if not line.strip():
                continue
            data = json.loads(line)
            # Strip digits and the year marker from the title.
            title = ''.join(
                ch for ch in data['title'] if not ch.isdigit() and ch != '年'
            )

            key_attr = data['key_attr']
            # Driven by the rename table rather than by iterating key_attr,
            # so the dict is never reassigned while being iterated.
            for query, (old_value, new_value) in renames.items():
                if key_attr.get(query) == old_value:
                    key_attr[query] = new_value
                    title = title.replace(old_value, new_value)

            # Special handling for one frequent phrase.
            title = title.replace('厚度常规', '常规厚度')

            data['key_attr'] = key_attr
            data['title'] = title
            rets.append(json.dumps(data, ensure_ascii=False) + '\n')
    return rets

# Normalize every record of train_fine.txt; process_title returns a list of
# newline-terminated JSON strings.
train_fine_path = './data/train_fine.txt'

train_fine_data = process_title(train_fine_path)


50000it [00:53, 939.37it/s]


In [4]:
print(len(coarse_pos_attr_yes))
print(len(coarse_neg_attr_yes))
print(len(train_fine_data))

# Pool the coarse records that matched and carry attributes with the
# processed fine records, then shuffle before splitting.
pos_data = coarse_pos_attr_yes + train_fine_data
# A seeded local generator makes the split written to disk reproducible
# across runs and leaves NumPy's global random state untouched.
SPLIT_SEED = 42
np.random.default_rng(SPLIT_SEED).shuffle(pos_data)

# The finetune set takes as many pooled records as there are coarse negatives,
# plus the negatives themselves; everything else is training data.
n_neg = len(coarse_neg_attr_yes)
train_data = pos_data[n_neg:]
fine_data = pos_data[:n_neg] + coarse_neg_attr_yes

print(len(train_data))
print(len(fine_data))

with open('./data/new_pos_train.txt', 'w', encoding='utf-8') as f:
    f.writelines(train_data)

with open('./data/new_finetune_data.txt', 'w', encoding='utf-8') as f:
    f.writelines(fine_data)

88463
10412
50000
128051
20824


In [1]:
# Attribute dictionary: map every "<attribute><value>" string to an integer id.
import json

with open('./data/attr_match.json', 'r', encoding='utf-8') as f:
    attr_key = json.load(f)

# Flatten attribute -> groups -> values into one ordered list of names; the
# position of a name in that list is its id (a repeated name keeps the id of
# its last occurrence).
flat_names = [
    key + val
    for key, values in attr_key.items()
    for single_lis in values
    for val in single_lis
]
new_dic = {name: idx for idx, name in enumerate(flat_names)}
print(len(new_dic))

with open('./data/attr_dic.json', 'w', encoding='utf-8') as f:
    json.dump(new_dic, f, ensure_ascii=False)


80
