# MAXP 2021初赛数据探索和处理-4

把原始数据的标签转换成数字形式，并完成Train/Validation/Test的划分。这里的划分仅用于比赛中的模型训练和模型选择，与原始数据文件名所代表的划分并不相同。

In [1]:
import pandas as pd
import numpy as np
import os
import pickle

import dgl

Using backend: pytorch


In [2]:
# Dataset location. publish_path is an optional sub-directory under
# base_path; it is empty here, so files are read directly from base_path.
publish_path = ''
base_path = './dataset'

# CSV holding every node's index, paper id, label letter and original split id.
nodes_path = os.path.join(base_path, publish_path, 'IDandLabels.csv')

### 读取节点列表

In [3]:
# Load the node table; force Label to str so the letter labels are read as
# text (unlabelled rows stay NaN).
nodes_df = pd.read_csv(nodes_path, dtype=dict(Label=str))
print(nodes_df.shape)
nodes_df.iloc[-4:]

(5346177, 4)


Unnamed: 0,node_idx,paper_id,Label,Split_ID
5346173,5346173,caed47d55d1e193ecb1fa97a415c13dd,,1
5346174,5346174,c82eb6be79a245392fb626b9a7e1f246,,1
5346175,5346175,926a31f6b378575204aae30b5dfa6dd3,,1
5346176,5346176,bbace2419c3f827158ea4602f3eb35fa,,1


### 转换标签为数字

In [4]:
# Inspect the class distribution: per-label row counts. groupby drops rows
# whose Label is NaN and sorts the label keys.
label_dist = nodes_df.groupby('Label').count()
print(label_dist.shape)
label_dist

(23, 3)


Unnamed: 0_level_0,node_idx,paper_id,Split_ID
Label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
A,2670,2670,2670
B,65303,65303,65303
C,111502,111502,111502
D,104005,104005,104005
E,45014,45014,45014
F,32876,32876,32876
G,43452,43452,43452
H,71824,71824,71824
I,23994,23994,23994
J,25241,25241,25241


#### 可以看到一共有23个标签，A类最少，C类最多，基本每类都有几万个。下面从0开始，重新编码标签


In [5]:
# Encode label letters as integers starting from 0, following the sorted
# order of label_dist's index (A -> 0, B -> 1, ...). Rows without a Label
# (NaN) get -1.
label_to_id = {l: i for i, l in enumerate(label_dist.index.to_list())}

# Build the whole column in one assignment. The previous pattern
# `nodes_df.label.fillna(-1, inplace=True)` is a chained in-place call on a
# column: pandas warns about it, and under Copy-on-Write it does not modify
# nodes_df at all, which would leave NaNs and make astype('int') fail.
# Series.map also replaces one full-column boolean mask per label.
nodes_df['label'] = nodes_df['Label'].map(label_to_id).fillna(-1).astype('int')
nodes_df.head(4)

Unnamed: 0,node_idx,paper_id,Label,Split_ID,label
0,0,bfdee5ab86ef5e68da974d48a138c28e,S,0,18
1,1,78f43b8b62f040347fec0be44e5f08bd,,0,-1
2,2,a971601a0286d2701aa5cde46e63a9fd,G,0,6
3,3,ac4b88a72146bae66cedfd1c13e1552d,,0,-1


#### 只保留node index、新的数字标签（label）、原始的分割标签（Split_ID）以及paper_id

In [6]:
# Keep only the node index, numeric label, original split id and paper id.
nodes = nodes_df.loc[:, ['node_idx', 'label', 'Split_ID', 'paper_id']]
nodes.tail(4)

Unnamed: 0,node_idx,label,Split_ID,paper_id
5346173,5346173,-1,1,caed47d55d1e193ecb1fa97a415c13dd
5346174,5346174,-1,1,c82eb6be79a245392fb626b9a7e1f246
5346175,5346175,-1,1,926a31f6b378575204aae30b5dfa6dd3
5346176,5346176,-1,1,bbace2419c3f827158ea4602f3eb35fa


## 划分Train/Validation/Test

由于只有原始的Train_nodes文件中的节点带有标签，所以这里的Train/Validation是对原始Train_nodes中有标签节点的再划分。

这里在每个类别内按照9:1的比例划分Train/Validation（按节点在文件中的原有顺序，前90%作为Train，其余作为Validation，未做随机打乱）。Test就是原来validation_nodes中全部节点的index。

In [7]:
# Labelled nodes from the original training split (Split_ID == 0, label >= 0)
# are divided into train/validation; every node with Split_ID == 1 forms
# the test set.
tr_val_labels_df = nodes[(nodes.Split_ID == 0) & (nodes.label >= 0)]
test_label_df = nodes[nodes.Split_ID == 1]

# Train/validation ratio, applied separately within each class.
# NOTE: inside a class the first 90% of nodes (in file order) go to train and
# the remainder to validation; nothing is shuffled.
split_ratio = 0.9

# Take the class ids from the data rather than hard-coding range(23), so a
# label set of a different size is not silently truncated.
class_ids = np.sort(tr_val_labels_df.label.unique())

# Each array starts with a placeholder 0 because the next cell strips the
# first element (tr_labels_idx[1:], val_labels_idx[1:]).
tr_parts = [np.array([0])]
val_parts = [np.array([0])]
for label in class_ids:
    label_idx = tr_val_labels_df[tr_val_labels_df.label == label].node_idx.to_numpy()
    split_point = int(label_idx.shape[0] * split_ratio)

    # Collect this class's train and validation node indices.
    tr_parts.append(label_idx[:split_point])
    val_parts.append(label_idx[split_point:])

# Concatenate once, instead of np.append re-copying the growing array on
# every iteration.
tr_labels_idx = np.concatenate(tr_parts)
val_labels_idx = np.concatenate(val_parts)

In [8]:
# Remove the placeholder 0 that was placed at the front of both index arrays.
tr_labels_idx, val_labels_idx = tr_labels_idx[1:], val_labels_idx[1:]

# Test set: node indices and paper ids of every Split_ID == 1 row, in order.
test_labels_idx = test_label_df['node_idx'].to_numpy()
test_paper_id = test_label_df['paper_id'].to_numpy()

In [9]:
# Show the test node indices followed by their paper ids, one per line.
print(test_labels_idx, test_paper_id, sep='\n')

[3063061 3063062 3063063 ... 5346174 5346175 5346176]
['c39457cc34fa969b03819eaa4f9b7a52' '668b9d0c53e9b6e2c6b1093102f976b3'
 'ca5c7bc1b40c0ef3c3f864aed032ca90' ... 'c82eb6be79a245392fb626b9a7e1f246'
 '926a31f6b378575204aae30b5dfa6dd3' 'bbace2419c3f827158ea4602f3eb35fa']


In [10]:
test_label_df.iloc[:5]

Unnamed: 0,node_idx,label,Split_ID,paper_id
3063061,3063061,-1,1,c39457cc34fa969b03819eaa4f9b7a52
3063062,3063062,-1,1,668b9d0c53e9b6e2c6b1093102f976b3
3063063,3063063,-1,1,ca5c7bc1b40c0ef3c3f864aed032ca90
3063064,3063064,-1,1,44f810c0c000cda27ce618add55e815f
3063065,3063065,-1,1,3c206335d88637d36d83c2942586be98


In [11]:
test_label_df.paper_id.iloc[0]

'c39457cc34fa969b03819eaa4f9b7a52'

In [12]:
# Map each test node index to its paper_id (both arrays come from
# test_label_df, so they have the same length and order).
test_id_dict = dict(zip(test_labels_idx, test_paper_id))
print(test_id_dict[3063061])
print(len(test_id_dict))

c39457cc34fa969b03819eaa4f9b7a52
2283116


In [13]:
# Numeric label of every node (-1 for unlabelled), indexed by row position.
labels = nodes['label'].to_numpy()

In [14]:
# Pickle the label array together with the Train/Validation/Test index
# arrays so later modelling code can load them quickly.
label_path = os.path.join(base_path, publish_path, 'labels.pkl')

with open(label_path, 'wb') as fh:
    pickle.dump(
        {
            'tr_label_idx': tr_labels_idx,
            'val_label_idx': val_labels_idx,
            'test_label_idx': test_labels_idx,
            'label': labels,
        },
        fh,
    )

In [15]:
# Pickle the test node-index -> paper_id mapping next to labels.pkl.
dict_path = os.path.join(base_path, publish_path, 'test_id_dict.pkl')
with open(dict_path, 'wb') as out_file:
    pickle.dump(test_id_dict, out_file)