# Notebook for preparing and saving VOC2011 (VOC) graphs

In [None]:
import numpy as np
import torch
import pickle
import time
import os
%matplotlib inline
import matplotlib.pyplot as plt

### First run the generate_vocsuperpixels_raw.ipynb notebook in the current directory


# Convert to DGL format and save with pickle

In [None]:
import os
# Show the working directory: the `superpixels` import and the
# ./voc_superpixels_<format> output paths used in this notebook are relative.
current_dir = os.getcwd()
print(current_dir)

In [None]:
import pickle

# %load_ext autoreload
# %autoreload 2

from superpixels import VOCSegDatasetDGL 

# from data.data import LoadData
from torch.utils.data import DataLoader
# from data.superpixels import VOCSegDataset


In [None]:
DATASET_NAME = 'VOC'
# Every graph construction listed in this notebook.
ALL_GRAPH_FORMATS = ['edge_wt_only_coord', 'edge_wt_coord_feat', 'edge_wt_region_boundary']
# Subset actually built in this run; set graph_format = ALL_GRAPH_FORMATS to build all of them.
graph_format = ['edge_wt_region_boundary']
# Value passed as `slic_compactness` to VOCSegDatasetDGL.
SLIC_COMPACTNESS = 30

# One dataset per entry of graph_format, appended in the same order,
# so dataset[i] was built from graph_format[i].
dataset = []
for gf in graph_format:
    start = time.time()
    data = VOCSegDatasetDGL(DATASET_NAME, gf, slic_compactness=SLIC_COMPACTNESS)
    print('Time (sec):', time.time() - start)
    dataset.append(data)

In [None]:
def plot_histo_graphs(dataset, title, bins=20):
    """Plot a histogram of per-graph node counts and print count/min/max.

    Args:
        dataset: iterable of samples whose first element is a graph exposing
            ``number_of_nodes()`` (e.g. a split such as ``dataset[0].train``).
        title: title of the histogram figure.
        bins: number of histogram bins (default 20).
    """
    # Node counts are already Python ints; no need to route them through a
    # float tensor just to take min/max.
    graph_sizes = [graph[0].number_of_nodes() for graph in dataset]
    # A fresh figure per call, so an interactive backend never draws this
    # histogram on top of a previous one.
    plt.figure()
    plt.hist(graph_sizes, bins=bins)
    plt.title(title)
    plt.show()
    if graph_sizes:
        print('nb/min/max :', len(graph_sizes), min(graph_sizes), max(graph_sizes))
    else:
        # min/max of an empty split is undefined; report it instead of raising.
        print('nb/min/max :', 0, 'n/a', 'n/a')
    
# Node-count histogram for each split of the first graph format.
for _split, _title in ((dataset[0].train, 'trainset'),
                       (dataset[0].val, 'valset'),
                       (dataset[0].test, 'testset')):
    plot_histo_graphs(_split, _title)


In [None]:
# Sanity check on the first graph format: the size of each split (train, val,
# test), followed by the first sample of each split in the same order.
_splits = (dataset[0].train, dataset[0].val, dataset[0].test)
for _split in _splits:
    print(len(_split))
for _split in _splits:
    print(_split[0])


## Prepare train, val and test pickles as a PyG data source

In [None]:
def _dgl_split_to_pyg_tuples(split):
    """Convert one split into a list of ``(x, edge_attr, edge_index, y)`` tuples.

    Each sample is indexed as ``sample[0]`` (the graph) and ``sample[1]``
    (the label). ``x`` is ``graph.ndata['feat']`` and ``edge_attr`` is
    ``graph.edata['feat']``. ``edge_index`` stacks the tensors returned by
    ``graph.edges()`` along a new leading dimension (presumably giving a
    ``(2, num_edges)`` src/dst tensor; confirm against DGL).
    """
    return [
        (
            sample[0].ndata['feat'],
            sample[0].edata['feat'],
            torch.stack(sample[0].edges(), 0),
            sample[1],
        )
        for sample in split
    ]


def dump_voc_pyg_source(dataset, graph_format):
    """Pickle the train/val/test splits of *dataset* as PyG-ready tuples.

    Writes ``train.pickle``, ``val.pickle`` and ``test.pickle`` into
    ``./voc_superpixels_<graph_format>``, creating the directory if needed.
    Each pickle holds a list of ``(x, edge_attr, edge_index, y)`` tuples built
    by ``_dgl_split_to_pyg_tuples``.

    Args:
        dataset: object with ``train``, ``val`` and ``test`` splits (in this
            notebook, a ``VOCSegDatasetDGL``).
        graph_format: graph-format name; used only to name the output directory.
    """
    splits = {
        'train': _dgl_split_to_pyg_tuples(dataset.train),
        'val': _dgl_split_to_pyg_tuples(dataset.val),
        'test': _dgl_split_to_pyg_tuples(dataset.test),
    }
    print(len(splits['train']), len(splits['val']), len(splits['test']))

    pyg_source_dir = os.path.join('.', 'voc_superpixels_' + graph_format)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(pyg_source_dir, exist_ok=True)

    for split_name, items in splits.items():
        start = time.time()
        with open(os.path.join(pyg_source_dir, split_name + '.pickle'), 'wb') as f:
            pickle.dump(items, f)
        print('Time (sec):', time.time() - start)

In [None]:
# Write the PyG pickles for every built graph format; dataset[i] was built
# from graph_format[i].
for fmt_idx, fmt_name in enumerate(graph_format):
    fmt_dataset = dataset[fmt_idx]
    dump_voc_pyg_source(fmt_dataset, fmt_name)

In [None]:
# Number of graphs in each split of the first graph format, as (val, train, test).
tuple(len(split) for split in (dataset[0].val, dataset[0].train, dataset[0].test))

In [None]:
# Split sizes (val, train, test) for the second graph format. `dataset` holds one
# entry per built format, so show None instead of raising IndexError when fewer
# than two formats were built.
(len(dataset[1].val), len(dataset[1].train), len(dataset[1].test)) if len(dataset) > 1 else None

In [None]:
# Split sizes (val, train, test) for the third graph format; all three sizes must
# come from dataset[2]. Show None instead of raising IndexError when fewer than
# three formats were built.
(len(dataset[2].val), len(dataset[2].train), len(dataset[2].test)) if len(dataset) > 2 else None

In [None]:
# NOTE(review): this looks like a pasted result of the split-size cells above,
# in (val, train, test) order. It is a literal, not computed here; verify it
# against the current data before relying on it.
(1428, 8498, 1430)