In [7]:
from config import args
import joblib
import numpy as np
from torch_geometric.data import Data, DataLoader
import torch
import random
from tqdm import tqdm

In [2]:
class MyDataLoader(object):
    """Serve a dataset as consecutive chunks, each wrapped in its own ``DataLoader``.

    Indexing with ``i`` returns a ``DataLoader`` over the samples in
    ``[i * batch_size, (i + 1) * batch_size)``; that loader is built with
    ``batch_size=mini_batch_size``. Once the last chunk has been sliced out,
    the underlying dataset is shuffled in place so that the next pass over
    the chunks sees a different order.
    """

    def __init__(self, dataset, batch_size, mini_batch_size=0):
        """
        Args:
            dataset: Sequence of samples. It is shuffled in place by
                ``__getitem__`` whenever the last chunk is requested.
            batch_size: Number of samples per chunk.
            mini_batch_size: Batch size handed to the inner ``DataLoader``.
                ``0`` (the default) means "use ``batch_size``".
        """
        # The length is captured once; the dataset is assumed not to grow or shrink afterwards.
        self.total = len(dataset)
        self.dataset = dataset
        self.batch_size = batch_size
        self.mini_batch_size = mini_batch_size if mini_batch_size != 0 else batch_size

    def __getitem__(self, item):
        """Return a ``DataLoader`` over chunk number ``item``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: If ``item`` is outside ``range(len(self))``. Raising here
                is also what lets ``for chunk in loader`` terminate, since
                iteration falls back to calling ``__getitem__`` with 0, 1, 2, ...
        """
        num_chunks = len(self)
        if item < 0:
            item += num_chunks
        if not 0 <= item < num_chunks:
            raise IndexError("MyDataLoader index out of range")

        start = item * self.batch_size
        stop = start + self.batch_size
        # Slice first, then shuffle: the chunk being returned must come from the current order.
        sub_dataset = self.dataset[start:stop]
        if stop >= self.total:
            # Last chunk of this pass: reshuffle for the next pass.
            random.shuffle(self.dataset)
        return DataLoader(sub_dataset, batch_size=self.mini_batch_size)

    def __len__(self):
        """Return the number of chunks (ceiling of ``total / batch_size``)."""
        if self.total == 0:
            return 0
        return (self.total - 1) // self.batch_size + 1

In [3]:
def split_train_valid_test(data, train_size, valid_part=0.1):
    """Split ``data`` into train, validation and test portions.

    The first ``train_size`` samples form the training pool and the remainder
    is the test split. The training pool is shuffled with the global
    ``random`` generator, then its first ``valid_part`` fraction is set
    aside as the validation split.

    Args:
        data: Sliceable sequence of samples.
        train_size: Number of leading samples that belong to the training pool.
        valid_part: Fraction of the training pool used for validation.

    Returns:
        ``(train_data, valid_data, test_data)`` — note the order.
    """
    train_data = data[:train_size]
    test_data = data[train_size:]
    random.shuffle(train_data)
    # Size the validation split from the samples actually in the pool, so that a
    # train_size larger than len(data) cannot make validation swallow the whole pool.
    valid_size = round(valid_part * len(train_data))
    valid_data = train_data[:valid_size]
    train_data = train_data[valid_size:]
    return train_data, valid_data, test_data

In [4]:
# Dataset key: indexes `args` and prefixes the temp/ files loaded by the cells below.
dataset = 'mr'

In [5]:
# param
# Number of leading samples that form the training pool (the rest becomes the test split);
# read from the per-dataset entry in config.args.
train_size = args[dataset]["train_size"]

In [34]:
# load data
# Preprocessed per-sample arrays for the selected dataset; they are zipped together
# sample-by-sample in the graph-building cell, so they are expected to be aligned.
inputs = np.load(f"temp/{dataset}.inputs.npy")
graphs = np.load(f"temp/{dataset}.graphs.npy")
weights = np.load(f"temp/{dataset}.weights.npy")
targets = np.load(f"temp/{dataset}.targets.npy")
# NOTE(review): joblib.load unpickles its input — only load files from a trusted source.
# Presumably these hold the unpadded length of each input / graph — TODO confirm.
len_inputs = joblib.load(f"temp/{dataset}.len.inputs.pkl")
len_graphs = joblib.load(f"temp/{dataset}.len.graphs.pkl")
word2vec = np.load(f"temp/{dataset}.word2vec.npy")

# Part-of-speech style inputs, judging by the names — TODO confirm what `poses` encodes.
poses = np.load(f"temp/{dataset}.poses_pad.npy")
pos2vec = np.load(f"temp/{dataset}.pos2vec.npy")

In [51]:
# py graph dtype
# Convert every sample into a torch_geometric `Data` object, trimming each array
# to the per-sample lengths `lx` (nodes) and `le` (edges).
data = []
for x, edge_index, edge_attr, y, lx, le, pos in tqdm(list(zip(inputs, graphs, weights, targets, len_inputs, len_graphs, poses))):
    # x holds the input nodes
    x = torch.tensor(x[:lx], dtype=torch.long)
    # y is the label
    y = torch.tensor(y, dtype=torch.long)
    
    # keep the first lx entries of pos, the same cut applied to x
    pos = torch.tensor(pos[:lx], dtype=torch.long)
    
    # keep the first le entries of every row of edge_index and of edge_attr
    edge_index = torch.tensor([e[:le] for e in edge_index], dtype=torch.long)
    edge_attr = torch.tensor(edge_attr[:le], dtype=torch.float)
    lens = torch.tensor(lx, dtype=torch.long)
    # x=nodes y=labels
    data.append(Data(x=x, y=y, edge_attr=edge_attr, edge_index=edge_index, length=lens, pos=pos))

100%|██████████| 10662/10662 [00:00<00:00, 14806.23it/s]


In [None]:
# split
# NOTE(review): this cell has no execution count and cannot run as written: the
# statements are indented and end in a `return`, which is only valid inside a
# function body. It also uses `batch_size` / `mini_batch_size`, which no earlier
# cell defines. A complete version of this logic lives in get_data_loader().
# NOTE(review): split_train_valid_test returns (train, valid, test), but the names
# below unpack it as (train, test, valid) — confirm which order is intended.
    train_data, test_data, valid_data = split_train_valid_test(data, train_size, valid_part=0.1)

    # return loader & word2vec
    return [MyDataLoader(data, batch_size=batch_size, mini_batch_size=mini_batch_size)
            for data in [train_data, test_data, valid_data]], word2vec

In [None]:
def get_data_loader(dataset, batch_size, mini_batch_size):
    """Load the preprocessed arrays for ``dataset`` and wrap each split in a ``MyDataLoader``.

    Args:
        dataset: Dataset key; indexes ``args`` and prefixes the files under ``temp/``.
        batch_size: Chunk size passed to every ``MyDataLoader``.
        mini_batch_size: Mini-batch size passed to every ``MyDataLoader``.

    Returns:
        A 3-tuple ``(loaders, word2vec, pos2vec)`` where ``loaders`` is
        ``[train_loader, test_loader, valid_loader]`` and the two arrays are
        returned exactly as loaded from disk.
    """
    # param
    train_size = args[dataset]["train_size"]

    # load data
    inputs = np.load(f"temp/{dataset}.inputs.npy")
    graphs = np.load(f"temp/{dataset}.graphs.npy")
    weights = np.load(f"temp/{dataset}.weights.npy")
    targets = np.load(f"temp/{dataset}.targets.npy")
    # NOTE(review): joblib.load unpickles its input — only load files from a trusted source.
    len_inputs = joblib.load(f"temp/{dataset}.len.inputs.pkl")
    len_graphs = joblib.load(f"temp/{dataset}.len.graphs.pkl")
    word2vec = np.load(f"temp/{dataset}.word2vec.npy")

    poses = np.load(f"temp/{dataset}.poses_pad.npy")
    pos2vec = np.load(f"temp/{dataset}.pos2vec.npy")

    # py graph dtype: one torch_geometric `Data` object per sample, trimmed to the
    # per-sample lengths lx (nodes) and le (edges).
    data = []
    for x, edge_index, edge_attr, y, lx, le, pos in tqdm(list(zip(inputs, graphs, weights, targets, len_inputs, len_graphs, poses))):
        # x holds the input nodes
        x = torch.tensor(x[:lx], dtype=torch.long)
        # y is the label
        y = torch.tensor(y, dtype=torch.long)

        pos = torch.tensor(pos[:lx], dtype=torch.long)

        edge_index = torch.tensor([e[:le] for e in edge_index], dtype=torch.long)
        edge_attr = torch.tensor(edge_attr[:le], dtype=torch.float)
        lens = torch.tensor(lx, dtype=torch.long)
        # x=nodes y=labels
        data.append(Data(x=x, y=y, edge_attr=edge_attr, edge_index=edge_index, length=lens, pos=pos))

    # split
    # split_train_valid_test returns (train, valid, test). Unpack in that order so
    # every name holds the split it claims to; previously the validation and test
    # splits were swapped here.
    train_data, valid_data, test_data = split_train_valid_test(data, train_size, valid_part=0.1)

    # return loaders in [train, test, valid] order, plus word2vec & pos2vec
    loaders = [MyDataLoader(split, batch_size=batch_size, mini_batch_size=mini_batch_size)
               for split in [train_data, test_data, valid_data]]
    return loaders, word2vec, pos2vec

##### 数据集样本统计

In [1]:
from collections import Counter

In [2]:
# Dataset key used to build the temp/ and corpus/ file paths in the statistics cells below.
dataset = 'R8'

In [3]:
# Read the text file for this dataset: one document per line.
with open(f"temp/{dataset}.texts.remove.txt", "r") as f:
    texts = f.read().strip().split("\n")

In [4]:
# Number of documents.
len(texts)

7674

In [5]:
# Peek at the first document.
texts[0]

'computer terminal systems completes sale computer terminal systems inc said completed sale shares common stock warrants acquire additional one mln shares n v switzerland dlrs company said warrants exercisable five years purchase price dlrs per share computer terminal said also right buy additional shares increase total holdings pct computer terminal outstanding common stock certain circumstances involving change control company company said conditions occur warrants would exercisable price equal pct common stock market price time exceed dlrs per share computer terminal also said sold rights dot impact technology including future improvements inc houston tex dlrs said would continue exclusive worldwide licensee technology company said moves part reorganization plan would help pay current operation costs ensure product delivery computer terminal makes computer generated forms ticket printers terminals'

In [6]:
# Histogram of document lengths, measured in whitespace-separated tokens.
Counter(len(doc.split()) for doc in texts)

Counter({125: 15,
         198: 4,
         31: 94,
         42: 93,
         341: 1,
         63: 59,
         22: 78,
         27: 129,
         58: 94,
         95: 19,
         26: 112,
         21: 60,
         16: 155,
         47: 108,
         78: 19,
         9: 146,
         222: 4,
         117: 25,
         49: 94,
         187: 2,
         57: 78,
         325: 1,
         81: 24,
         65: 60,
         304: 1,
         25: 88,
         35: 109,
         8: 141,
         70: 34,
         151: 10,
         7: 69,
         10: 125,
         75: 32,
         104: 15,
         11: 68,
         48: 84,
         87: 23,
         43: 99,
         30: 104,
         60: 64,
         295: 4,
         46: 94,
         37: 120,
         83: 25,
         71: 33,
         32: 93,
         137: 10,
         80: 22,
         45: 77,
         36: 115,
         38: 96,
         34: 93,
         77: 29,
         410: 4,
         51: 81,
         44: 96,
         55: 84,
         66: 40,
 

In [79]:
# Per-document target ids. Relies on `np` from the import cell at the top of the notebook.
targets = np.load(f"temp/{dataset}.targets.npy")

In [80]:
# Display the target array (numpy abbreviates long arrays in the repr).
targets

array([19, 19, 19, ..., 10, 10, 10])

In [81]:
# Number of documents per target id.
# NOTE(review): the recorded output lists 23 distinct ids; it may be stale from a
# run with a different `dataset` value than the one set above — re-run to confirm.
Counter(targets)

Counter({19: 216,
         17: 75,
         3: 50,
         22: 1030,
         2: 223,
         16: 354,
         8: 63,
         0: 250,
         11: 63,
         21: 557,
         1: 125,
         13: 342,
         14: 195,
         4: 1175,
         20: 133,
         18: 129,
         5: 233,
         9: 323,
         6: 86,
         12: 410,
         15: 553,
         7: 19,
         10: 796})

In [82]:
# Read the label file for this dataset: one label string per line.
with open(f"corpus/{dataset}.labels.txt", "r") as f:
    labels = f.read().strip().split("\n")

In [83]:
# Number of documents per label string, for comparison with the target-id counts.
Counter(labels)

Counter({'C01': 216,
         'C02': 75,
         'C03': 50,
         'C04': 1030,
         'C05': 223,
         'C06': 354,
         'C07': 63,
         'C08': 250,
         'C09': 63,
         'C10': 557,
         'C11': 125,
         'C12': 342,
         'C13': 195,
         'C14': 1175,
         'C15': 133,
         'C16': 129,
         'C17': 233,
         'C18': 323,
         'C19': 86,
         'C20': 410,
         'C21': 553,
         'C22': 19,
         'C23': 796})