In [0]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# from tqdm import tqdm
from tqdm.notebook import tqdm as tqdm

import warnings
# Silence every warning category for the rest of the session; note that this
# also hides deprecation notices and other library warnings.
warnings.filterwarnings("ignore")

In [8]:
# Colab-only: mount Google Drive at /content/drive so that files stored under
# that mount point can be read with ordinary file paths.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
# Data directory on the mounted Drive (use 'data/' instead for a local run).
path = '/content/drive/My Drive/M5_Competition/data/'

# Read the full training table, then keep only its first six columns.
sales_train = pd.read_csv(f'{path}sales_train_validation.csv')
sales = sales_train.iloc[:, 0:6]

# Graph construction

In [13]:
import torch
import time
import os
import pickle
from scipy import sparse

class Graph:
    """Fully connect the rows of a DataFrame that share the same category values.

    Rows are grouped by ``categories``; every row becomes a node identified by
    its index label and is linked to every other row of its group.  Results are
    cached as pickles under ``load_path`` and can be read back with
    ``flag_load=True``.

    Attributes set by the methods:
        group_num:  number of groups found by the groupby.
        group_card: for each group, its size minus one (i.e. the number of
                    neighbours every member of that group has).
        edge_index: ``torch.long`` tensor of shape (2, n_edges), set by
                    ``torch_format``.
        edge_attr:  ``torch`` tensor of shape (n_edges, group_num) holding a
                    one-hot group indicator per edge, set by ``torch_format``.
    """

    # File names of the two pickle caches kept under ``load_path``.
    _GRAPH_FILE = 'graph_dict.pickle'
    _ATTR_FILE = 'attributes.pickle'

    def __init__(self, categories=('state_id', 'store_id', 'cat_id', 'dept_id'), flag_load=False, load_path='graph_data/'):
        """Store the configuration.

        Args:
            categories: column name(s) used as groupby keys.
            flag_load: when True, read the cached pickles instead of rebuilding.
            load_path: directory holding (or receiving) the cache files.
        """
        # The default is an immutable tuple so it cannot be shared and mutated
        # between instances.  pandas treats a tuple passed to ``groupby`` as a
        # single key, hence the conversion to a list.
        self.cats = list(categories) if isinstance(categories, tuple) else categories
        self.flag_load = flag_load
        self.load_path = load_path

    def __call__(self, df):
        """Shortcut for ``create_graph(df)``."""
        return self.create_graph(df)

    def _cache_file(self, name):
        """Return the path of cache file ``name`` inside ``load_path``."""
        # os.path.join works whether or not load_path ends with a separator.
        return os.path.join(self.load_path, name)

    def _load_pickle(self, name):
        """Read one cache file and return its unpickled content."""
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # load_path at caches produced by this class.
        with open(self._cache_file(name), 'rb') as handle:
            return pickle.load(handle)

    def _dump_pickle(self, name, payload):
        """Write ``payload`` to one cache file, creating ``load_path`` if needed."""
        if self.load_path:
            os.makedirs(self.load_path, exist_ok=True)
        with open(self._cache_file(name), 'wb') as handle:
            pickle.dump(payload, handle)

    def create_graph(self, df):
        """Return the adjacency dict ``{node: array of the other nodes in its group}``.

        With ``flag_load`` the dict, ``group_num`` and ``group_card`` are read
        from the cache and ``df`` is not used.  Otherwise they are computed
        from ``df`` and written to the cache.
        """
        if self.flag_load:
            graph_dct, self.group_num, self.group_card = self._load_pickle(self._GRAPH_FILE)
            return graph_dct

        # reset_index exposes the index labels as the 'index' column; those
        # labels are the node ids.
        # NOTE(review): torch_format feeds these ids straight into edge_index,
        # which presumably expects a default 0..n-1 index - confirm for callers.
        df_ind = df.reset_index()
        groups_lst = df_ind.groupby(by=self.cats)['index'].apply(list).values

        graph_dct = {}
        self.group_card = []  # per group: number of members minus one
        self.group_num = len(groups_lst)

        for group in tqdm(groups_lst):
            members = np.asarray(group)
            for position, node in enumerate(group):
                # Neighbours of a node are all members of its group but itself.
                graph_dct[node] = np.delete(members, position, axis=0)
            # Computed from the length rather than a leaked loop variable.
            self.group_card.append(max(len(group) - 1, 0))

        self._dump_pickle(self._GRAPH_FILE, (graph_dct, self.group_num, self.group_card))
        return graph_dct

    def one_hot(self, ind, card):
        """Return the one-hot rows describing every edge of group ``ind``.

        Args:
            ind: position of the group, ``0 <= ind < group_num``.
            card: group size minus one, as stored in ``group_card``.

        Returns:
            A list of ``card * (card + 1)`` arrays of length ``group_num``,
            each with a single 1 at position ``ind``.

        Raises:
            ValueError: if ``ind`` is outside ``[0, group_num)``.
        """
        if not 0 <= ind < self.group_num:
            raise ValueError('ind must satisfy 0 <= ind < group_num')
        # With n = card + 1 members there are n * (n - 1) ordered pairs.
        n_rows = card * (card + 1)
        indicator = np.zeros(self.group_num)
        indicator[ind] = 1
        return list(np.tile(indicator, (n_rows, 1)))

    def torch_format(self, df):
        """Populate ``edge_index`` and ``edge_attr`` as torch tensors and print their shapes.

        With ``flag_load`` both are read from the cache; otherwise they are
        derived from the adjacency dict and the cache is written (``edge_attr``
        is stored as a SciPy CSR matrix).  Either way ``edge_attr`` ends up as a
        dense torch tensor.
        """
        graph = self.create_graph(df)
        if self.flag_load:
            self.edge_index, edge_attr = self._load_pickle(self._ATTR_FILE)
            if sparse.issparse(edge_attr):
                edge_attr = edge_attr.toarray()
            # Same type as the freshly built branch (a tensor, not np.matrix).
            self.edge_attr = torch.as_tensor(np.asarray(edge_attr))
        else:
            # Edges follow the dict order: group by group, node by node.  Both
            # directions are already present because every member lists every
            # other member, so no reversed copy is needed.
            neighbours = list(graph.values())
            degree = np.fromiter((len(adjacent) for adjacent in neighbours),
                                 dtype=np.int64, count=len(neighbours))
            sources = np.repeat(np.asarray(list(graph.keys()), dtype=np.int64), degree)
            if neighbours:
                targets = np.concatenate(neighbours).astype(np.int64, copy=False)
            else:
                targets = np.empty(0, dtype=np.int64)
            self.edge_index = torch.from_numpy(np.stack((sources, targets)))

            # One one-hot row per edge, in the same group order as the edges.
            # A single-member group contributes zero rows instead of breaking
            # the stacking step.
            edges_per_group = np.array([card * (card + 1) for card in self.group_card],
                                       dtype=np.int64)
            group_of_edge = np.repeat(np.arange(self.group_num), edges_per_group)
            dense_attr = np.zeros((group_of_edge.size, self.group_num))
            dense_attr[np.arange(group_of_edge.size), group_of_edge] = 1
            self.edge_attr = torch.from_numpy(dense_attr)

            self._dump_pickle(self._ATTR_FILE, (self.edge_index, sparse.csr_matrix(dense_attr)))

        print(np.shape(self.edge_index))
        print(np.shape(self.edge_attr))

# Time the graph set-up with a monotonic clock (time.time() can jump when the
# system clock is adjusted).  flag_load=True reads the pickled caches from
# Graph's default load_path ('graph_data/'), so they must already exist; run
# once with flag_load=False to create them.
start = time.perf_counter()
graph_class = Graph(flag_load=True)
graph_class.torch_format(sales)
print(time.perf_counter() - start)

torch.Size([2, 16228460])
(16228460, 70)
3.3800888061523438
