In [1]:
import re
import json
import torch
import numpy as np
from tqdm import tqdm 
import itertools

def load_attr_dict(file):
    """Load the attribute dictionary and flatten its '='-joined value groups.

    The JSON file maps each attribute name to a list of strings. Each string
    may pack several alternatives joined by ``=``; every string is split on
    ``=`` and the pieces are concatenated, in file order, into one flat list
    per attribute.

    Args:
        file: Path to the JSON file.

    Returns:
        dict mapping each attribute name to a flat list of value strings.
    """
    # Decode explicitly as UTF-8: the file holds non-ASCII text, and the
    # default encoding of open() depends on the platform locale.
    with open(file, 'r', encoding='utf-8') as f:
        raw = json.load(f)
    return {
        attr: [value for group in groups for value in group.split('=')]
        for attr, groups in raw.items()
    }

# Build the attribute -> values lookup used by the extraction cells below.
# NOTE(review): this path is relative to a different base than the
# 'original_data/...' paths in the later cells — confirm the working directory.
attr_dict_file = "../../data/original_data/attr_to_attrvals.json"
attr_dict = load_attr_dict(file=attr_dict_file)

In [4]:
# Strip year markers from the titles of the fine-grained training set.
fine_data = 'original_data/train_fine.txt'
new_fine_data = 'processed_data/fine50000.txt'

rets = []
delete_list = ['2017年','2018年','2019年','2020年','2021年','2022年']
# Read and write explicitly as UTF-8: the records contain non-ASCII text and
# are dumped with ensure_ascii=False, so the locale default encoding is unsafe.
with open(fine_data, 'r', encoding='utf-8') as f:
    for line in tqdm(f):  # one JSON record per line
        data = json.loads(line)
        title = data['title'].upper()  # normalise Latin letters to upper case
        for delete_item in delete_list:
            title = title.replace(delete_item, '')
        data['title'] = title

        rets.append(json.dumps(data, ensure_ascii=False) + '\n')

with open(new_fine_data, 'w', encoding='utf-8') as f:
    f.writelines(rets)

501it [00:00, 905.35it/s]


In [8]:
# Strip year markers from the coarse training titles and mine attribute
# values from them.
# The original note says the trouser-fly and shoelace attribute values overlap,
# so an extra check on the title decides which attribute may be assigned:
# '裤门襟' is only considered when the title contains '裤', and '闭合方式' is
# skipped when it does.
coarse_data = 'original_data/train_coarse.txt'
new_coarse_data = 'processed_data/coarse100000.txt'

rets = []
delete_list = ['2017年','2018年','2019年','2020年','2021年','2022年']
querys = attr_dict.keys()
# Read and write explicitly as UTF-8: the records contain non-ASCII text and
# are dumped with ensure_ascii=False, so the locale default encoding is unsafe.
with open(coarse_data, 'r', encoding='utf-8') as f:
    for line in tqdm(f):  # one JSON record per line
        data = json.loads(line)
        title = data['title'].upper()  # normalise Latin letters to upper case
        for delete_item in delete_list:
            title = title.replace(delete_item, '')
        data['title'] = title

        # Attribute labels are only mined for records flagged as matching.
        if data['match']['图文'] == 1:
            for query in querys:
                if (query == '裤门襟') and ('裤' not in title):
                    continue
                if (query == '闭合方式') and ('裤' in title):
                    continue
                for value in attr_dict[query]:
                    if value in title:
                        # NOTE(review): plain substring match; if several
                        # values of one attribute occur in the title, the
                        # last one in list order wins — confirm intended.
                        data['key_attr'][query] = value
                        data['match'][query] = 1

        rets.append(json.dumps(data, ensure_ascii=False) + '\n')

with open(new_coarse_data, 'w', encoding='utf-8') as f:
    f.writelines(rets)

100000it [01:54, 875.45it/s]


In [5]:
# Preprocess the test set: strip year markers from the titles and extract the
# values of the queried attributes in advance.
test_data = 'original_data/preliminary_testA.txt'
new_test_data = 'processed_data/test4000.txt'

rets = []
delete_list = ['2017年','2018年','2019年','2020年','2021年','2022年']

# Read and write explicitly as UTF-8: the records contain non-ASCII text and
# are dumped with ensure_ascii=False, so the locale default encoding is unsafe.
with open(test_data, 'r', encoding='utf-8') as f:
    for line in tqdm(f):  # one JSON record per line
        data = json.loads(line)
        # Upper-case the title exactly as the two training cells do, so test
        # titles are normalised the same way as the training titles.
        title = data['title'].upper()
        for delete_item in delete_list:
            title = title.replace(delete_item, '')
        data['title'] = title

        data['key_attr'] = {}
        # The first query entry is skipped; the remaining ones are looked up
        # in attr_dict. An empty slice simply yields no iterations.
        for query in data['query'][1:]:
            for value in attr_dict[query]:
                if value in title:
                    data['key_attr'][query] = value
        # Move 'feature' to the end of the dict so it is serialised last.
        data['feature'] = data.pop('feature')
        rets.append(json.dumps(data, ensure_ascii=False) + '\n')

print(len(rets))

with open(new_test_data, 'w', encoding='utf-8') as f:
    f.writelines(rets)

4000it [00:04, 903.44it/s]


4000


In [11]:
# Split the processed coarse data into train / val files.
coarse_path = 'processed_data/coarse100000.txt'

coarse_train_path = 'train/coarse89588.txt'
coarse_val_path = 'val/coarse10412.txt'


train_rets = []
val_rets = []

# Read and write explicitly as UTF-8: the records contain non-ASCII text and
# are dumped with ensure_ascii=False, so the locale default encoding is unsafe.
with open(coarse_path, 'r', encoding='utf-8') as f:
    for line in tqdm(f):  # one JSON record per line
        data = json.loads(line)
        record = json.dumps(data, ensure_ascii=False) + '\n'
        # Records whose '图文' match flag is 1 go to train; all others to val.
        if data['match']['图文'] == 1:
            train_rets.append(record)
        else:
            val_rets.append(record)

print(len(train_rets))
print(len(val_rets))

with open(coarse_train_path, 'w', encoding='utf-8') as f:
    f.writelines(train_rets)
with open(coarse_val_path, 'w', encoding='utf-8') as f:
    f.writelines(val_rets)

100000it [01:52, 889.82it/s]


89588
10412


In [2]:
# Split the processed fine data into train / val files.
fine_path = 'processed_data/fine50000.txt'

fine_train_path = 'train/fine45000.txt'
fine_val_path = 'val/fine5000.txt'

# Number of leading records that go to the train file; the rest go to val.
# The split is positional: records are taken in file order, without shuffling.
fine_train_size = 45000

train_rets = []
val_rets = []

# Read and write explicitly as UTF-8: the records contain non-ASCII text and
# are dumped with ensure_ascii=False, so the locale default encoding is unsafe.
with open(fine_path, 'r', encoding='utf-8') as f:
    for line in tqdm(f):  # one JSON record per line
        record = json.dumps(json.loads(line), ensure_ascii=False) + '\n'
        if len(train_rets) < fine_train_size:
            train_rets.append(record)
        else:
            val_rets.append(record)

print(len(train_rets))
print(len(val_rets))

with open(fine_train_path, 'w', encoding='utf-8') as f:
    f.writelines(train_rets)
with open(fine_val_path, 'w', encoding='utf-8') as f:
    f.writelines(val_rets)

50000it [00:55, 901.53it/s]


45000
5000
