In [1]:
import os
import gc
import json
import gzip
import pickle
import numpy as np
import pandas as pd
from bs4 import BeautifulSoup
from urllib.request import urlopen
from scipy.sparse import csr_matrix

## Arguement

In [2]:
dataset_home = '/home/scai/phd/aiz218323/scratch/XML/amazon-review-data/'

amazon_file = f'{dataset_home}/datasets/All_Amazon_Meta.json.gz'
save_dir = f'{dataset_home}/GraphAmazonProducts/results'

## Helper function

In [3]:
def dict_head_random(dictionary, n=10):
    if len(dictionary):
        keys = np.random.choice(list(dictionary.keys()), size=n)
        for k in keys:
            print(f'{k} : {dictionary[k]}')
    else:
        print("EMPTY!!")
        

In [4]:
def read_amazon_dataset(filename, num_prods=3):
    data = []
    num_prods_cnt = 0
    with gzip.open(filename) as file:
        for i, line in enumerate(file):
            if  num_prods is None or num_prods_cnt < num_prods:
                product = json.loads(line.strip())
                if product['also_buy'] or product['also_view'] or product['similar_item']:
                    data.append(product)
                    num_prods_cnt += 1
            else:
                break
            if i > num_prods:
                break
    return data


## View data

Column headers : <br>
**category | 
tech1 |
description |
fit |
title |
also_buy |
image |
tech2 |
brand |
feature |
rank |
also_view |
details |
main_cat |
similar_item |
date |
price |
asin**

Columns under consideration : <br>
**also_buy | also_view | similar_item**

In [91]:
data = read_amazon_dataset(amazon_file, num_prods=10)

In [5]:
data = read_amazon_dataset(amazon_file, num_prods=100_000)

In [92]:
n = 4
print(f"{data[n]['title']} : {data[n]['asin']}\n", data[n]['also_view'])

FQQ Women Sexy Lingerie Lace Dress Sheer Babydoll Underwear G-string Outfit Set : 6342522545
 []


In [93]:
for i in range(len(data)):
    print(data[i]['asin'])

6342506256
6342509379
6342522081
6342502315
6342522545
634252209X
6342522898
6342523002


In [156]:
for i, d in enumerate(data):
    also, cnt = np.unique(d['also_view'], return_counts=True)
    if len(np.where(cnt > 1)[0]):
        print(i)

## Extract similar items

### Exploration

In [10]:
prod_with_similar = []
for d in data:
    if d['similar_item']:
        prod_with_similar.append(d['similar_item'])

In [11]:
len(prod_with_similar)

2564

In [12]:
html_text = prod_with_similar[15]
print(html_text)

 class="a-bordered a-horizontal-stripes  a-spacing-extra-large a-size-base comparison_table">



            
            
            
            
            
            <tr class="comparison_table_image_row">
                <td class="comparison_table_first_col"></td>


                <th class="comparison_image_title_cell" role="columnheader">
                    <div class="a-row a-spacing-top-micro">
                        <center>
                             <img alt="Dogs Sterling Silver Loud Figural Dog WHISTLE Pendant" src="https://images-na.ssl-images-amazon.com/images/I/419XwGqfZIL._SL500_AC_SS350_.jpg" id="comparison_image">
                        </center>
                    </div>
                    <div class="a-row a-spacing-top-small">
                        <div id="comparison_title" class="a-section a-spacing-none">
                            <span aria-hidden="true" class="a-size-base a-color-base a-text-bold">
                                This item
 

In [11]:
soup = BeautifulSoup(html_text, 'html.parser')

In [12]:
header = soup.find('tr')

In [13]:
for th in header.find_all('th'):
    product_html = th.find('span')
    print(product_html)

<span aria-hidden="true" class="a-size-base a-color-base a-text-bold">
                                This item
                            </span>
<span class="a-size-base">.925 Sterling Silver Whistle Charm Pendant</span>
<span class="a-size-base">Sterling Silver Chihuahuas Dog Pendant</span>
<span class="a-size-base">Sterling Silver Chihuahuas Dog Pendant</span>


In [14]:
for th in header.find_all('th'):
    product_html = th.find('span')
    if len(product_html['class']) == 1 and product_html['class'][0] == 'a-size-base':
        print(product_html.get_text())

.925 Sterling Silver Whistle Charm Pendant
Sterling Silver Chihuahuas Dog Pendant
Sterling Silver Chihuahuas Dog Pendant


### Code

In [5]:
def extract_similar_items(html_text):
    soup = BeautifulSoup(html_text, 'html.parser')
    header = soup.find('tr')
    
    similar_products = []
    
    if header:
        for th in header.find_all('th'):
            product_html = th.find('span')
            if len(product_html['class']) == 1 and product_html['class'][0] == 'a-size-base':
                similar_products.append(product_html.get_text())
                
    return similar_products


In [13]:
extract_similar_items(html_text)

['.925 Sterling Silver Whistle Charm Pendant',
 'Sterling Silver Chihuahuas Dog Pendant',
 'Sterling Silver Chihuahuas Dog Pendant']

In [14]:
extract_similar_items('')

[]

## AmazonGraph

In [122]:
class AmazonGraph:

    def __init__(self):
        self.graph = {}
        
    def add_product(self, prod_id, products):
        products, counts = np.unique(products, return_counts=True)
        self.graph[prod_id] = (products.tolist(), counts.tolist())
        
    def save_data(self, save_dir, tag=''):
        os.makedirs(save_dir, exist_ok=True)
        filename = f'{save_dir}/amazon_graph{tag}.pickle'
        with open(filename, 'wb') as f:
            pickle.dump(self.graph, f)
            
        del self.graph
        gc.collect()
            
    def load_data(self, save_dir, tag=''):
        filename = f'{save_dir}/amazon_graph{tag}.pickle'
        if os.path.exists(filename):
            with open(filename, 'rb') as f:
                self.graph = pickle.load(f)
            return True
        print(f"ERROR:: Unable to load the graph at '{filename}'")
        return False
    
    def replace_duplicates(self, duplicates):
        delete_nodes = []

        for node in self.graph:
            if node in duplicates:
                delete_nodes.append(node)
            
            for i, (edge, count) in enumerate( zip(*self.graph[node]) ):
                if edge in duplicates:
                    self.graph[node][0][i] = duplicates[edge]

        for node in delete_nodes:
            self.graph[duplicates[node]] = self.graph[node]
            del self.graph[node]
            
    def convert_graph(self):
        if len(self.graph) and isinstance(self.graph, dict):
            key = list(self.graph.keys())[0]
            if isinstance(self.graph[key], tuple):
                for doc, (edges, counts) in self.graph.items():
                    self.graph[doc] = {e:c for e, c in zip(edges, counts)}
            elif isinstance(self.graph[key], dict):
                for doc, edge_count in self.graph.items():
                    self.graph[doc] = (list(edge_count.keys()), list(edge_count.values()))
            else:
                raise Exception("Invalid graph format.")
                
    def remove_dead(self, id_to_title):
        delete_nodes = []
        for product_id, (edges, counts) in self.graph.items():
            active_edges = list()
            active_counts = list()
            
            if product_id in id_to_title and id_to_title[product_id]:
                while edges:
                    edge = edges.pop()
                    count = counts.pop()

                    if edge in id_to_title and id_to_title[edge]:
                        active_edges.append(edge)
                        active_counts.append(count)

                if len(active_edges):
                    self.graph[product_id] = (active_edges, active_counts)
                else:
                    delete_nodes.append(product_id)
            else:
                delete_nodes.append(product_id)
        
        for node in delete_nodes:
            del self.graph[node]
        return None
    

### Testing

In [108]:
g = AmazonGraph()

#### add_product

In [114]:
g.add_product(1, [1, 1, 2, 3, 4, 4, 4, 5, 8, 99])
g.add_product(2, [2, 3, 44, 55, 55, 66, 11])
g.add_product(3, [1, 2])

In [115]:
g.graph

{1: ([1, 2, 3, 4, 5, 8, 99], [2, 1, 1, 3, 1, 1, 1]),
 2: ([2, 3, 11, 44, 55, 66], [1, 1, 1, 1, 2, 1]),
 3: ([1, 2], [1, 1])}

In [119]:
id_to_title = {1:'a', 2:'b', 3:'', 5:'d', 99:'e', 11:'f'}

In [120]:
g.remove_dead(id_to_title)

In [121]:
g.graph

{1: ([1, 2, 5, 99], [2, 1, 1, 1]), 2: ([2, 11], [1, 1])}

#### save and load

In [21]:
g.save_data(f'{dataset_home}/trial', tag='t')

In [25]:
new_g = AmazonGraph()
new_g.load_data(f'{dataset_home}/trial', tag='t')

True

In [26]:
new_g.graph

{100: ([1, 2, 3, 4, 5, 8, 99], [2, 1, 1, 3, 1, 1, 1]),
 200: ([2, 3, 11, 44, 55, 66], [1, 1, 1, 1, 2, 1]),
 300: ([133, 144, 199], [2, 1, 1])}

#### convert graph

In [31]:
g.convert_graph()

In [32]:
g.graph

{100: {1: 2, 2: 1, 3: 1, 4: 3, 5: 1, 8: 1, 99: 1},
 200: {2: 1, 3: 1, 11: 1, 44: 1, 55: 2, 66: 1},
 300: {133: 2, 144: 1, 199: 1}}

## AmazonGraphContainer

In [123]:
def parse(path):
    g = gzip.open(path, 'rb')
    for l in g:
        yield json.loads(l)
        

In [124]:
class AmazonGraphContainer:
    
    def __init__(self):
        """
        Amazon graphs.
        """
        self.graphs = {}
        self.graphs['similar'] = AmazonGraph()
        self.graphs['also_buy'] = AmazonGraph()
        self.graphs['also_view'] = AmazonGraph()
        
        """
        Map from prod_id to title.
        """
        self.id_to_title = {}
        
        """
        Store product description.
        """
        self.description = {}
        
    def convert_graph(self, graph_type="all"):
        graph_types = []
        if graph_type == "all":
            graph_types = ['similar', 'also_buy', 'also_view']
        else:
            graph_types.append(graph_type)
        
        for graph_type in graph_types:
            self.graphs[graph_type].convert_graph()
            
    def remove_dead(self, graph_type="all"):
        graph_types = []
        if graph_type == "all":
            graph_types = ['similar', 'also_buy', 'also_view']
        else:
            graph_types.append(graph_type)
        
        for graph_type in graph_types:
            self.graphs[graph_type].remove_dead(self.id_to_title)
            
    def create_graph(self, filename, limit=None):
        
        for i, product in enumerate(parse(filename)):
            
            similar_items = extract_similar_items(product['similar_item'])
            
            if (len(product['also_view']) or len(product['also_buy']) or len(similar_items)) \
            and len(product['title']) and len(product['asin']):
                product_id = product['asin']
                
                self.id_to_title[product_id] = product['title']
                
                if len(product['also_view']):
                    self.graphs['also_view'].add_product(product_id, product['also_view'])
                if len(product['also_buy']):
                    self.graphs['also_buy'].add_product(product_id, product['also_buy'])
                if len(similar_items):
                    self.graphs['similar'].add_product(product_id, similar_items)
                if len(product['description']):
                    self.description[product_id] = product['description']
                
            if limit is not None and i > limit:
                break
                
    def replace_similar_graph_titles(self):
        nodes_to_delete = []
        title_to_id = { product_title:product_id for product_id, product_title in self.id_to_title.items() }
        
        for node, (title_edges, count_edges) in self.graphs['similar'].graph.items():
            id_edges, new_count_edges = [], []
            for i, edge in enumerate(title_edges):
                if edge in title_to_id:
                    id_edges.append(title_to_id[edge])
                    new_count_edges.append(count_edges[i])

            if not len(id_edges):
                nodes_to_delete.append(node)
            else:
                self.graphs['similar'].graph[node] = (id_edges, new_count_edges)

        for node in nodes_to_delete:
            del self.graphs['similar'].graph[node]
            
    def replace_graph_duplicates(self, duplicates, graph_type='all'):
        graph_types = []
        if graph_type == "all":
            graph_types = ['similar', 'also_buy', 'also_view']
        else:
            graph_types.append(graph_type)
        
        for graph_type in graph_types:
            self.graphs[graph_type].replace_duplicates(duplicates)
                
    def save_graph(self, save_dir, tag='', graph_type='similar'):
        if graph_type != 'similar' and graph_type != 'also_buy' and graph_type != 'also_view':
            raise Exception("graph_type should be in ['similar', 'also_buy', 'also_view']")
            
        self.graphs[graph_type].save_data(save_dir, tag=f'_{graph_type}{tag}')
            
    def save_graphs(self, save_dir, tag=''):
        graph_types = ['similar', 'also_view', 'also_buy']
        
        for graph_type in graph_types:
            self.save_graph(save_dir, tag=tag, graph_type=graph_type)
              
    def save_idtotitle(self, save_dir, tag=''):
        map_file = f'{save_dir}/id_to_title{tag}.pickle'
        with open(map_file, 'wb') as f:
            pickle.dump(self.id_to_title, f)   
        del self.id_to_title
        gc.collect()
            
    def save_description(self, save_dir, tag=''):
        content_file = f'{save_dir}/description{tag}.pickle'
        with open(content_file, 'wb') as f:
            pickle.dump(self.description, f)
        del self.description
        gc.collect()
        
    def save_data(self, save_dir, tag=''):
        os.makedirs(save_dir, exist_ok=True)
        
        self.save_graphs(save_dir, tag)
        self.save_idtotitle(save_dir, tag)
        self.save_description(save_dir, tag)
            
    def load_graph(self, save_dir, tag='', graph_type='similar'):
        if graph_type != 'similar' and graph_type != 'also_buy' and graph_type != 'also_view':
            raise Exception("graph_type should be in ['similar', 'also_buy', 'also_view']")
            
        if not self.graphs[graph_type].load_data(save_dir, tag=f'_{graph_type}{tag}'):
            raise Exception(f"Unable to load '{graph_type} graph'.")
                
    def load_graphs(self, save_dir, tag=''):
        graph_types = ['similar', 'also_view', 'also_buy']
        
        for graph_type in graph_types:
            self.load_graph(save_dir, tag=tag, graph_type=graph_type)
        
    def load_idtotitle(self, save_dir, tag=''):
        map_file = f'{save_dir}/id_to_title{tag}.pickle'
        if os.path.exists(map_file):
            with open(map_file, 'rb') as f:
                self.id_to_title = pickle.load(f)
        else:
            raise Exception(f"Unable to load 'id_to_title' from '{map_file}'.")
            
    def load_description(self, save_dir, tag=''):
        content_file = f'{save_dir}/description{tag}.pickle'
        if os.path.exists(content_file):
            with open(content_file, 'rb') as f:
                self.description = pickle.load(f)
        else:
            raise Exception(f"Unable to load 'description' from '{content_file}'.")
            
    def load_data(self, save_dir, tag=''):
        self.load_graphs(save_dir, tag)
        self.load_idtotitle(save_dir, tag)
        self.load_description(save_dir, tag)
        

### Create graph

In [9]:
save_dir = f'{dataset_home}/GraphAmazonProducts/results'
amazon_graphs = AmazonGraphContainer()

#amazon_graphs.create_graph(amazon_file)
#amazon_graphs.save_data(save_dir)

#### Testing

In [111]:
save_dir = f'{dataset_home}/trial'

amazon_graphs = AmazonGraphContainer()
amazon_graphs.create_graph(amazon_file, limit=10)
#amazon_graphs.save_data(save_dir, tag='_test')

In [112]:
len(amazon_graphs.graphs['also_view'].graph), amazon_graphs.graphs['also_view'].graph.keys()

(7,
 dict_keys(['6342506256', '6342509379', '6342522081', '6342502315', '634252209X', '6342522898', '6342523002']))

6342506256
6342509379
6342522081
6342502315

634252209X
6342522898
6342523002

#### Loading

In [96]:
graph = AmazonGraphContainer()
graph.load_data(save_dir, tag='_test')

In [105]:
dict_head_random(graph.graphs['similar'].graph, n=2)

EMPTY!!


In [106]:
dict_head_random(graph.graphs['also_view'].graph, n=2)

6342502315 : (['B008Q0E5GG', 'B00MJB94W8', 'B00MQ2R4E0', 'B00MSGURES', 'B00NSF70KM', 'B00O2XKJRK', 'B00OYXPK40', 'B00PKTJWZU', 'B00UV3KXJE', 'B015MY2YZ2', 'B0169ZHJDK', 'B016OV70UQ', 'B0176MWX9K', 'B017M55DI4', 'B017M5BVXA', 'B017MRPFDA', 'B01ARAERV0', 'B01CIW8NAG', 'B01FQGGF5G', 'B01LWNUSJJ', 'B01LXAY36A', 'B01LYGQ76G', 'B01LYRMI0Q', 'B01M335X8C', 'B01M3N8NJ0', 'B01M8LC79F', 'B01MAZMNMA', 'B01MTOY6FM', 'B01N6IFJCK', 'B07285LHMN', 'B0741DMN8C', 'B07574HH9H', 'B075GY6V85', 'B075TF2SF3', 'B075VKDZW7', 'B075XHS5ZR', 'B0765BQZCM', 'B07716THHJ', 'B0772NH4K4', 'B077XRC6TB', 'B0786124FK', 'B079GBYMH2', 'B079GCD4VP', 'B07BBJRRP4', 'B07BGJHVZK', 'B07BVK2HSK', 'B07CT5SSVL', 'B07D1N4NPS', 'B07DWV353P', 'B07FD9HWPM', 'B07GDHLQN9', 'B07GWZV32X', 'B07H373XYB', 'B07H4LYXJG', 'B07HHJPC61', 'B07JJNQGYJ', 'B07JWFZMCG', 'B07KQYM4MD', 'B07M8FGC86'], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [107]:
dict_head_random(graph.graphs['also_buy'].graph, n=2)

6342522545 : (['B00VBVXVPI'], [1])
6342502315 : (['B00BIJC8I4', 'B00G9EMVIA', 'B00KU1NLGO', 'B00MQ2R4E0', 'B00NSF70KM', 'B00NV1VFPO', 'B0169ZHJDK', 'B017M55DI4', 'B017M5BVXA', 'B018YRBB80', 'B019ZAYUB0', 'B01CIW8NAG', 'B01LLOUFRQ', 'B01LYDMB6U', 'B07FD9HWPM', 'B07KX6PPW6'], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


In [113]:
print(amazon_graphs.description == graph.description, 
      amazon_graphs.id_to_title == graph.id_to_title,
      amazon_graphs.graphs['similar'].graph == graph.graphs['similar'].graph,
      amazon_graphs.graphs['also_view'].graph == graph.graphs['also_view'].graph,
      amazon_graphs.graphs['also_buy'].graph == graph.graphs['also_buy'].graph)

True True True True True


#### Visualize

In [115]:
amazon_graphs.load_idtotitle(save_dir, tag='_test')
print(f'Number of nodes : {len(amazon_graphs.id_to_title)}')

Number of nodes : 8


In [116]:
amazon_graphs.load_graphs(save_dir, tag='_test')
len(amazon_graphs.graphs['similar'].graph), len(amazon_graphs.graphs['also_view'].graph), \
len(amazon_graphs.graphs['also_buy'].graph)

(0, 7, 5)

In [31]:
amazon_graphs.load_graph(save_dir, graph_type='also_buy')
amazon_graphs.load_graph(save_dir, graph_type='also_view')
amazon_graphs.load_graph(save_dir, tag='_resolved', graph_type='similar')

In [117]:
dict_head_random(amazon_graphs.graphs['also_view'].graph, n=2)

6342509379 : (['B0018OY5X0', 'B001AQVMDM', 'B002DMJOC8', 'B004EENYW4', 'B005F28HIK', 'B00CAKH5NI', 'B00W5AWTNY', 'B012KC6EO6', 'B012KC7F2Q', 'B012MKMMWO', 'B0156SZQ5O', 'B0157M276K', 'B0179ATZ9A', 'B0179AUC5Q', 'B0179AUNYQ', 'B017Y1GMR2', 'B019KYRQYO', 'B01A3S8MLC', 'B01CG5HYE6', 'B01E377QRU', 'B01FVRKZ4W', 'B01GJNI42C', 'B01H7EMBXK', 'B01N5M3ZMG', 'B06VV1HQ79', 'B06Y26PZ5R', 'B06Y2QZW18', 'B06Y5N9YFF', 'B071PFP967', 'B072XTTTK9', 'B073P7PHCD', 'B073WRLB3Y', 'B074WHY9NZ', 'B0757DW7XJ', 'B075F6WW62', 'B075KQPV28', 'B075ZWMDK4', 'B076B8J2TX', 'B076ZKSN8Q', 'B077GBDQ44', 'B077GQQKRV', 'B077TR2855', 'B0797PSJ4J', 'B079F6M38L', 'B07C3FNYF5', 'B07CBJQTF6', 'B07CHLF5RG', 'B07CKDN5GS', 'B07CMK5VRD', 'B07CQ4P1ZD', 'B07CQCKX1P', 'B07CQJ8S9V', 'B07FS739BL', 'B07GBSJSWP', 'B07H2YZS5K', 'B07H2Z6S9J', 'B07H319WRY', 'B07JGB3X5J'], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [118]:
dict_head_random(amazon_graphs.graphs['also_buy'].graph, n=2)

6342522898 : (['B00125SQ44', 'B015W134LS', 'B01AHZSZ9A', 'B01H43Z5GY', 'B01H8BRYXY', 'B01I809NCO', 'B01JRDL1JE', 'B01NBT7VLP', 'B06XKWCGTT', 'B06ZZBQMT4', 'B07219C7LQ', 'B0748C68ZX', 'B07CMD59N4'], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
6342502315 : (['B00BIJC8I4', 'B00G9EMVIA', 'B00KU1NLGO', 'B00MQ2R4E0', 'B00NSF70KM', 'B00NV1VFPO', 'B0169ZHJDK', 'B017M55DI4', 'B017M5BVXA', 'B018YRBB80', 'B019ZAYUB0', 'B01CIW8NAG', 'B01LLOUFRQ', 'B01LYDMB6U', 'B07FD9HWPM', 'B07KX6PPW6'], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


In [119]:
dict_head_random(amazon_graphs.graphs['similar'].graph, n=2)

EMPTY!!


In [122]:
amazon_graphs.id_to_title

{'6342506256': "Gaok Men's Retro Cotton Multi-Pocket Camo Cargo Shorts",
 '6342509379': "QIBOE Men's Baggy Jeans Denim Sweatpants Loose Pants",
 '6342522081': 'Crazy Explosion-proof Lens Polarized Sunglasses Cycling Glasses Lenses',
 '6342502315': "Crazy Women's Voile Crinkle Scarf Shawl",
 '6342522545': 'FQQ Women Sexy Lingerie Lace Dress Sheer Babydoll Underwear G-string Outfit Set',
 '634252209X': 'Crazy Explosion-proof Lens Polarized Sunglasses Cycling Glasses Lenses',
 '6342522898': "Crazy Women's Sexy Leather Backless Bodycon Clubwear Mini Dress Nightclub",
 '6342523002': "FQQ Women's Sexy Lingerie Babydoll Dress Sleepwear Bodysuit"}

###  Similar products

In [125]:
amazon_graphs.load_data(save_dir)

In [96]:
amazon_graphs.replace_similar_graph_titles()
amazon_graphs.save_graph(save_dir, tag='_resolved', graph_type='similar')

#### Testing

In [157]:
graph = AmazonGraphContainer()

g  = {1:(['s', 'u', 'c'], [1, 2, 1]), 
     2:(['h', 'i', 't'], [3, 1, 1]), 
     3:(['h', 'p', 'r'], [5, 1, 1]),
     4:(['a', 'b', 'h', 'u'], [1, 99, 1, 1]), 
     5:(['f'], [1])}

In [158]:
graph.graphs['similar'].graph = g

In [159]:
graph.id_to_title = {i:c for i, c in enumerate("suhihpabhg")}
graph.id_to_title

{0: 's',
 1: 'u',
 2: 'h',
 3: 'i',
 4: 'h',
 5: 'p',
 6: 'a',
 7: 'b',
 8: 'h',
 9: 'g'}

In [160]:
graph.replace_similar_graph_titles()

In [161]:
graph.graphs['similar'].graph

{1: ([0, 1], [1, 2]),
 2: ([8, 3], [3, 1]),
 3: ([8, 5], [5, 1]),
 4: ([6, 7, 8, 1], [1, 99, 1, 1])}

#### Statistics

In [203]:
num_products, num_absent = 0, 0
absent_similar_products = list()

title_to_id = {ptitle:pid for pid, ptitle in amazon_graphs.id_to_title.items()}

for node, edges in amazon_graphs.graphs['similar'].graph.items():
    num_products += len(edges)
    for edge in edges[0]:
        if edge not in title_to_id:
            num_absent += 1
            absent_similar_products.append(edge)

print(f'Number of products : {num_products}')
print(f'Number of absent   : {num_absent}')

Number of products : 8
Number of absent   : 14


In [206]:
np.random.choice(absent_similar_products, size=4)

array(['Princess Paradise Percy Ride-in Train Costume, Green, Child',
       'Forum Novelties 78762 Festive Happy Birthday Cone Hat Adults Kids Pom Balls Circus Fancy Accessory Party Supplies, One Size',
       'Girls Princess Tiara Cone Hat Headband, Lavender',
       'Disney Pixar Monsters University Mike Boys Deluxe Costume, Large/4-6'],
      dtype='<U123')

#### Visualize

In [126]:
amazon_graphs.load_graph(save_dir, tag='_resolved', graph_type='similar')

dict_head_random(amazon_graphs.graphs['similar'].graph)

B00J4NR412 : (['B00JJ1TOVC', 'B00JJ1SU1W'], [1, 1])
B004I16NYY : (['B002R0P0C4'], [1])
B007KFNMJ4 : (['B000IY80S2', 'B000IY9IUQ', 'B002L6JFL6'], [1, 1, 1])
B000CIXMLY : (['B000CIV4PU', 'B000E35WHO'], [1, 1])
B001PGTQAS : (['B01E6BS8EI', 'B001PGTQIK', 'B019VFFUGI'], [1, 1, 1])
B00INXD0XA : (['B00INXD2TM', 'B00INXD5O4'], [1, 1])
B00VWXHVII : (['B00QXUDIL4'], [1])
B01D34UKV8 : (['B00ZOZ9XRC'], [1])
B00PAV3GLE : (['B00APBDG4G'], [1])
B01B2TXMIO : (['B00OTXYLKE'], [1])


### Read duplicates

In [68]:
def load_duplicates_map(duplicate_file, id_to_title):
    duplicates = {}
    
    with open(duplicates_file) as file:
        for line in file:
            product_ids = line[:-1].split(' ')
            
            representative_id = None
            if len(product_ids):
                for product_id in product_ids:
                    if product_id in id_to_title:
                        if representative_id is None:
                            representative_id = product_id
                        duplicates[product_id] = representative_id
                    
    return duplicates


In [127]:
duplicates_file = f'{dataset_home}/datasets/duplicates.txt'
duplicates = load_duplicates_map(duplicates_file, amazon_graphs.id_to_title)

In [65]:
dict_head_random(duplicates, n=5)

1935937286 : 1935937286
1503940918 : 1503940918
B000025M3E : B000025M3E
B00GWYE40U : B00GWYE40U
B00KX0RJL0 : B00KX0RJL0


### Remove duplicates

In [128]:
amazon_graphs.replace_graph_duplicates(duplicates)
amazon_graphs.remove_dead()
amazon_graphs.convert_graph()

amazon_graphs.save_graphs(save_dir, tag='_dict')

#### Testing

In [40]:
graph = AmazonGraph()

d = {1:100, 4:400}
g = {1:([22, 33, 44], [1, 1, 1]), 
     2:([1, 2, 3], [1, 1, 1]), 
     3:([99, 23, 12], [1, 1, 1]),
     4:([1, 11, 3, 13], [1, 1, 1]), 
     5:([4, 5], [1, 1])}

In [41]:
graph.graph = g
graph.graph

{1: ([22, 33, 44], [1, 1, 1]),
 2: ([1, 2, 3], [1, 1, 1]),
 3: ([99, 23, 12], [1, 1, 1]),
 4: ([1, 11, 3, 13], [1, 1, 1]),
 5: ([4, 5], [1, 1])}

In [42]:
graph.replace_duplicates(d)
graph.graph

{2: ([100, 2, 3], [1, 1, 1]),
 3: ([99, 23, 12], [1, 1, 1]),
 5: ([400, 5], [1, 1]),
 100: ([22, 33, 44], [1, 1, 1]),
 400: ([100, 11, 3, 13], [1, 1, 1])}

#### Visualize

In [102]:
amazon_graphs.load_graphs(save_dir, tag='_dict')

In [103]:
dict_head_random(amazon_graphs.graphs['also_view'].graph, n=2)

0131879057 : {'0132826453': 1}
B002OC311M : {'B001D6A4G0': 1, 'B0001NLE06': 1, 'B002NLXZX8': 1, 'B002NLZYFA': 1, 'B002S60V76': 1, 'B002ZKGQGK': 1, 'B002ZKKX3M': 1, 'B00NVGLE0U': 1, 'B00NVGLX42': 1, 'B00QH7PSV6': 1, 'B01N4CV2VB': 1, 'B071VJPDK3': 1, 'B0731S5NMG': 1, 'B0779Q7LXW': 1, 'B07F195ZSH': 1}


In [106]:
dict_head_random(amazon_graphs.graphs['also_buy'].graph, n=2)

B01BISDLWQ : {'B006P2SZ6S': 1, 'B01950TCEU': 1, 'B01J5RH5Y2': 1}
B004IZN51Y : {'B000ZFPZHI': 1, 'B00861UPUA': 1, 'B008VPU1YG': 1, 'B0091CC1OG': 1, 'B00DYR9FUC': 1, 'B00BJEP7NQ': 1, 'B00PM2EQE2': 1, 'B00UXK6G00': 1, 'B00XK60UMY': 1, 'B015JCJX2Y': 1, 'B01MSVZFAU': 1, 'B01N4ASYLJ': 1, 'B06ZZ3M9Y7': 1, 'B073KWCFTB': 1, 'B073R18713': 1, 'B074QMCX31': 1, 'B075RLMYBQ': 1, 'B076Q7XY2Q': 1, 'B077GDFV4V': 1, 'B078N7M2HF': 1, 'B079L3WBDN': 1, 'B07CS9MZCD': 1, 'B07HDHXZH4': 1, 'B07HR8SSJV': 1, 'B07L5FPMV3': 1}


In [105]:
dict_head_random(amazon_graphs.graphs['similar'].graph, n=3)

B014LGK1P2 : {'B00MNOPS1C': 1, 'B00CM1AAOG': 1}
B00ILUQSZW : {'B00UPSMB0O': 1}
B00UQQWN3A : {'B014KMKV84': 1, 'B014KMKVGG': 1, 'B0009V89NK': 1, 'B007CISMEE': 1}


## Train-test split

In [129]:
def prune_map(mapping, idxs):
    rev_mapping = {value:key for key, value in mapping.items()}
    pruned_mapping = {}
    for i, idx in enumerate(idxs):
        pruned_mapping[rev_mapping[idx]] = i
    return pruned_mapping

def split_count(num_samples, perc=0.7):
    if num_samples == 1:
        return 1 if np.random.rand() > 0.3 else 0
    
    num_train = int( np.ceil(num_samples*perc) )
    if num_train == num_samples and num_samples > 1:
        num_train -= 1
    return num_train


In [130]:
class AmazonSplit:
    
    def __init__(self, graph=None):
        self.graph = None
        self.labels = None
        self.doc_to_rowindex = None
        
        if graph:
            _ = self.to_matrix(graph)
        
        self.train, self.test = None, None
        self.trn_tst_labels = None
        self.train_doc_to_rowindex, self.test_doc_to_rowindex = None, None
        
    def to_matrix(self, graph):
        indptr = [0]
        indices = []
        data = []

        self.doc_to_rowindex = {}

        self.labels = {}
        for i, (doc, edge_count) in enumerate(graph.items()):
            self.doc_to_rowindex[doc] = i
            
            for link, cnt in edge_count.items():
                index = self.labels.setdefault(link, len(self.labels))
                indices.append(index)
                data.append(cnt)
            indptr.append(len(indices))

        self.graph = csr_matrix((data, indices, indptr), dtype=int)
        return self.graph, self.labels, self.doc_to_rowindex
    
    def clean_matrix(self, clean_type=0):
        if clean_type == 0:
            self.graph, self.labels, self.doc_to_rowindex = self.remove_single_labels(self.graph,
                                                                                      self.labels,
                                                                                      self.doc_to_rowindex)
        elif clean_type == 1:
            pruned_rows = self.get_pruned_row(self.graph)
            self.graph, self.doc_to_rowindex = self.prune_graph_rows(self.graph,
                                                                     self.doc_to_rowindex,
                                                                     pruned_rows)
        else:
            pruned_cols = self.get_pruned_cols(self.graph)
            self.graph, self.labels = self.prune_graph_cols(self.graph,
                                                           self.labels,
                                                           pruned_cols)
    
    """
    def get_split_idx(self):
        train_rowidx = []
        test_rowidx = []
        
        num_rows, num_cols = self.graph.shape
        row_available_flag = np.ones(num_rows, dtype=bool)
        
        label_cnt = np.array(self.graph.sum(axis=0)).reshape(-1)
        uni_label_cnt = np.unique(label_cnt)
        
        for lcnt in uni_label_cnt:
            if lcnt > upper_threshold:
                break
                
            label_pos = np.where(label_cnt == lcnt)[0]
            row_numlabels = np.array(self.graph[:, label_pos].sum(axis=1)).reshape(-1)
            
            row_idx = np.where(row_numlabels > 0)[0]
            row_available_flag[row_idx] = False
            num_train = split_count(row_idx.shape[0], perc=0.7)
            row_idx = list(np.random.permutation(row_idx))

            train_rowidx.extend(row_idx[:num_train])
            test_rowidx.extend(row_idx[num_train:])
        
        row_idx = np.where(row_available_flag == True)[0]
        num_train = split_count(row_idx.shape[0], perc=0.7)
        row_idx = list(np.random.permutation(row_idx))
        train_rowidx.extend(row_idx[:num_train])
        test_rowidx.extend(row_idx[num_train:])
        
        return train_rowidx, test_rowidx
        """
    
    def get_split_idx(self, upper_threshold=10, perc=0.7):
        train_rowidx = []
        test_rowidx = []
        
        row_idxs, col_idxs = self.graph.nonzero()
        sort_idx = np.argsort(col_idxs)
        row_idxs = row_idxs[sort_idx]
        col_idxs = col_idxs[sort_idx]
        
        num_rows = self.graph.shape[0]
        row_inserted_flag = np.zeros(num_rows, dtype=bool)
        
        label_cnt = np.array(self.graph.getnnz(axis=0)).reshape(-1)
        uni_label_cnt = np.unique(label_cnt)
        
        """
        print(label_cnt)
        print(uni_label_cnt)
        """
        
        cnt = 0
        for lcnt in uni_label_cnt:
            if cnt == num_rows or lcnt >= upper_threshold:
                break
                
            pos_ptr, col_ptr = 0, 0
            pos_idxs = np.where(label_cnt == lcnt)[0]
            pos_idxs.sort()

            while col_ptr < len(col_idxs) and pos_ptr < len(pos_idxs) and cnt < num_rows:
                if pos_idxs[pos_ptr] != col_idxs[col_ptr]:
                    col_ptr += 1
                else:
                    sample_row_idxs = []
                    while col_ptr < len(col_idxs) and pos_ptr < len(pos_idxs) and \
                    cnt < num_rows and pos_idxs[pos_ptr] == col_idxs[col_ptr]:
                        rn = row_idxs[col_ptr]
                        if not row_inserted_flag[rn]:
                            sample_row_idxs.append(rn)
                            row_inserted_flag[rn] = True
                            cnt += 1
                        col_ptr += 1
                    pos_ptr += 1
                    
                    num_train = split_count(len(sample_row_idxs), perc=0.7)
                    sample_row_idxs = list(np.random.permutation(sample_row_idxs))
                    train_rowidx.extend(sample_row_idxs[:num_train])
                    test_rowidx.extend(sample_row_idxs[num_train:])
                    """
                    print(f'Sample row : {sample_row_idxs}')
                    print(f'train : {train_rowidx}')
                    print(f'test  : {test_rowidx}')
                    """
                    
        sample_row_idxs = np.where(row_inserted_flag == False)[0]
        num_train = split_count(num_rows, perc=0.7)
        num_train -= len(train_rowidx)
        
        sample_row_idxs = list(np.random.permutation(sample_row_idxs))
        train_rowidx.extend(sample_row_idxs[:num_train])
        test_rowidx.extend(sample_row_idxs[num_train:])
        
        return train_rowidx, test_rowidx
    
    def get_split_bylabel(self, upper_threshold=10, perc=0.7):
        """
        splitting data into train-test
        """
        train_idx, test_idx = self.get_split_idx(upper_threshold, perc)
        
        self.train = self.graph[train_idx, :]
        self.test = self.graph[test_idx, :]
        self.trn_tst_labels = self.labels
        
        rowindex_to_doc = {row_idx:doc for doc, row_idx in self.doc_to_rowindex.items()}
        
        self.train_doc_to_rowindex = {rowindex_to_doc[idx]:i for i, idx in enumerate(train_idx)}
        self.test_doc_to_rowindex = {rowindex_to_doc[idx]:i for i, idx in enumerate(test_idx)}
        
        """
        pruning the columns
        """
        trn_pruned_cols = self.get_pruned_cols(self.train)
        tst_pruned_cols = self.get_pruned_cols(self.test)
        pruned_cols = np.intersect1d(trn_pruned_cols, tst_pruned_cols)
        self.train = self.train[:, pruned_cols]
        self.test = self.test[:, pruned_cols]
        self.trn_tst_labels = prune_map(self.trn_tst_labels, pruned_cols)
        
        """
        prunning the rows
        """
        pruned_rows = self.get_pruned_row(self.train)
        self.train, self.train_doc_to_rowindex = self.prune_graph_rows(self.train, 
                                                                       self.train_doc_to_rowindex, pruned_rows)
        pruned_rows = self.get_pruned_row(self.test)
        self.test, self.test_doc_to_rowindex = self.prune_graph_rows(self.test,
                                                                     self.test_doc_to_rowindex, pruned_rows)
        
    def get_random_split_idx(self, perc=0.7):
        n_docs = self.graph.shape[0]
        n_trn = int(perc * n_docs)
        rand_idx = np.random.permutation(n_docs)
        return rand_idx[:n_trn], rand_idx[n_trn:]
    
    def get_pruned_cols(self, graph, count=0):
        label_cnt = np.array(graph.getnnz(axis=0)).reshape(-1)
        pruned_cols = np.where(label_cnt > count)[0]
        return pruned_cols
    
    def get_pruned_row(self, graph, count=0):
        pruned_rows = np.where( np.array(graph.getnnz(axis=1)).reshape(-1) > count )[0]
        return pruned_rows
    
    def prune_graph_cols(self, graph, labels, pruned_cols):
        graph = graph[:, pruned_cols]
        labels = prune_map(labels, pruned_cols)
        return graph, labels
    
    def prune_graph_rows(self, graph, doc_to_rowindex, pruned_rows):
        graph = graph[pruned_rows, :]
        doc_to_rowindex = prune_map(doc_to_rowindex, pruned_rows)
        return graph, doc_to_rowindex
    
    def remove_single_labels(self, graph, labels, doc_to_rowindex):
        pruned_cols = self.get_pruned_cols(graph, count=1)
        graph, labels = self.prune_graph_cols(graph, labels, pruned_cols)
        
        pruned_rows = self.get_pruned_row(graph)
        graph, doc_to_rowindex = self.prune_graph_rows(graph, doc_to_rowindex, pruned_rows)

        return graph, labels, doc_to_rowindex
        
    def get_split_byrandom(self, perc=0.7):
        train_idx, test_idx = self.get_random_split_idx(perc)
        
        self.train = self.graph[train_idx, :]
        self.test = self.graph[test_idx, :]
        rowindex_to_doc = {row_idx:doc for doc, row_idx in self.doc_to_rowindex.items()}
        self.train_doc_to_rowindex = {rowindex_to_doc[idx]:i for i, idx in enumerate(train_idx)}
        self.test_doc_to_rowindex = {rowindex_to_doc[idx]:i for i, idx in enumerate(test_idx)}
        
        train_pruned_cols = self.get_pruned_cols(self.train)
        test_pruned_cols = self.get_pruned_cols(self.test)
        pruned_cols = np.intersect1d(train_pruned_cols, test_pruned_cols)
        
        self.train = self.train[:, pruned_cols]
        self.test = self.test[:, pruned_cols]
        self.trn_tst_labels = prune_map(self.labels, pruned_cols)
        
        pruned_rows = self.get_pruned_row(self.train)
        self.train, self.train_doc_to_rowindex = self.prune_graph_rows(self.train, 
                                                                       self.train_doc_to_rowindex, 
                                                                       pruned_rows)
        pruned_rows = self.get_pruned_row(self.test)
        self.test, self.test_doc_to_rowindex = self.prune_graph_rows(self.test,
                                                                     self.test_doc_to_rowindex,
                                                                     pruned_rows)
        
    def save_data(self, save_dir, tag='category'):
        train_file = f'{save_dir}/{tag}_train.pkl'
        with open(train_file, 'wb') as fout:
            train = (self.trn_tst_labels, self.train_doc_to_rowindex, self.train)
            pickle.dump(train, fout)
            
        test_file = f'{save_dir}/{tag}_test.pkl'
        with open(test_file, 'wb') as fout:
            test = (self.trn_tst_labels, self.test_doc_to_rowindex, self.test)
            pickle.dump(test, fout)
            
    def load_data(self, save_dir, tag='category'):
        train_file = f'{save_dir}/{tag}_train.pkl'
        with open(train_file, 'rb') as fout:
            train = pickle.load(fout)
            self.trn_tst_labels, self.train_doc_to_rowindex, self.train = train
            
        test_file = f'{save_dir}/{tag}_test.pkl'
        with open(test_file, 'rb') as fout:
            test = pickle.load(fout)
            _, self.test_doc_to_rowindex, self.test = test
            

In [135]:
amazon_graphs = AmazonGraphContainer()
amazon_graphs.load_graphs(save_dir, tag='_dict')

In [136]:
amazon_graphs.load_idtotitle(save_dir)

In [137]:
id_to_title = amazon_graphs.id_to_title

In [141]:
graph_types = ['also_buy', 'also_view', 'similar']

data_splitter = {}
for graph_type in graph_types:
    
    data_splitter[graph_type] = AmazonSplit(amazon_graphs.graphs[graph_type].graph)

    data_splitter[graph_type].clean_matrix()
    data_splitter[graph_type].get_split_bylabel(upper_threshold=10)
    data_splitter[graph_type].save_data(save_dir, tag=f'{graph_type}')

In [142]:
metrics = []

for graph_type in graph_types:
    m = []
    trn_graph = data_splitter[graph_type].train
    tst_graph = data_splitter[graph_type].test
    m.append(trn_graph.shape[1])
    m.append(trn_graph.shape[0])
    m.append(tst_graph.shape[0])
    
    graph = data_splitter[graph_type].graph
    points_per_label = np.array(graph.sum(axis=0))[0]
    m.append(np.all( points_per_label > 0 ))
    m.append(np.mean(points_per_label))
    
    labels_per_point = np.array(graph.sum(axis=1))[0]
    m.append(np.all( labels_per_point > 0 ))
    m.append(np.mean(labels_per_point))
    
    metrics.append(m)
    

In [143]:
import pandas as pd

In [144]:
pd.DataFrame(metrics,columns=['labels', 'number of train', 'number of test', 'full labels', 'points per label',
                             'full rows', 'labels per point'], index=graph_types)

Unnamed: 0,labels,number of train,number of test,full labels,points per label,full rows,labels per point
also_buy,1591824,1502829,645683,True,16.246278,True,3.0
also_view,1712182,1722517,742851,True,10.785238,True,1.0
similar,388544,840754,361756,True,6.297799,True,1.0


In [80]:
count_empty_id = 0
count_empty_title = 0

for product_id, product_title in amazon_graphs.id_to_title.items():
    if not len(product_id):
        count_empty_id += 1
    if not len(product_title):
        count_empty_title += 1
        break

In [82]:
amazon_graphs.id_to_title[product_id]

''

## XC dataset format

In [131]:
class XCDataset:
    
    def __init__(self):
        self.description = None
        self.id_to_title = {}
        
        self.data_splitter = {}
        
    def load_idtotitle(self, save_dir, tag='', verbose=True):
        """
        Loading id_to_title
        """
        if verbose:
            print("** Loading Amazon product 'id_to_title'.")
            
        graph_data = AmazonGraphContainer()        
        graph_data.load_idtotitle(save_dir, tag=tag)
        
        self.id_to_title = graph_data.id_to_title
        
    def load_description(self, save_dir, tag='', verbose=True):
        """
        Loading description.
        """
        if verbose:
            print("** Loading Amazon 'description'.")
            
        graph_data = AmazonGraphContainer()       
        graph_data.load_description(save_dir, tag=tag)
        
        self.description = graph_data.description
        
    def save_sparse_file(self, matrix, save_file):
        matrix.sort_indices()
        rows, cols = matrix.nonzero()
        data = matrix.data

        with open(save_file, 'w') as fout:
            fout.write(f'{matrix.shape[0]} {matrix.shape[1]}\n')

            row_ctr = -1
            line = ''
            for r, c, d in zip(rows, cols, data):
                if row_ctr == r:
                    line = line+f' {c}:{d}'
                elif row_ctr+1 == r:
                    if line:
                        fout.write(f'{line}\n')
                    line = f'{c}:{d}'
                    row_ctr += 1
                else:
                    raise Exception("Row is missing.")
            if line:
                fout.write(f'{line}\n')
        
    def save_XY_text(self, save_dir, doc_to_rowindex, tag=''):
        os.makedirs(save_dir, exist_ok=True)
        
        idtocontent_file = f'{save_dir}/{tag}_id_to_text.txt'
        self.save_XY_content(idtocontent_file, self.id_to_title, doc_to_rowindex, self.description)

        idtotitle_file = f'{save_dir}/{tag}_id_to_title.txt'
        self.save_XY_title(idtotitle_file, self.id_to_title, doc_to_rowindex)
                    
    def save_XY_content(self, idtocontent_file, id_to_title, doc_to_rowindex, content):
        with open(idtocontent_file, 'w') as fout:
            for doc_id in sorted(doc_to_rowindex, key=lambda x : doc_to_rowindex[x]):
                if doc_id in content:
                    fout.write(f'{doc_id}->{content[doc_id]}\n')
                else:
                    fout.write(f'{doc_id}->\n')
                    
    def save_XY_title(self, idtotitle_file, id_to_title, doc_to_rowindex):
        with open(idtotitle_file, 'w') as fout:
            for doc_id in sorted(doc_to_rowindex, key=lambda x : doc_to_rowindex[x]):
                if doc_id in id_to_title:
                    fout.write(f'{doc_id}->{id_to_title[doc_id]}\n')
                else:
                    fout.write(f'{doc_id}\n')
                
    def load_classification_data(self, save_dir, tag='_dict', graph_type='also_buy', verbose=True):
        """
        Creating Classification train-test split
        """
        if verbose:
            print(f"** Creating {graph_type} classification train-test.")
        
        graph_data = AmazonGraphContainer()
        graph_data.load_graph(save_dir, tag=tag, graph_type=graph_type)
        
        self.data_splitter[graph_type] = AmazonSplit(graph_data.graphs[graph_type].graph)
        self.data_splitter[graph_type].clean_matrix()
        self.data_splitter[graph_type].get_split_bylabel(upper_threshold=10)
                
    def save_XCClassification_text(self, xc_dir, graph_type='also_buy', verbose=True):
        """
        Saving Classification XY - title and content(text) 
        """
        if verbose:
            print(f"** Saving {graph_type} classification train X-article text")
        self.save_XY_text(xc_dir, self.data_splitter[graph_type].train_doc_to_rowindex,
                          tag=f'{graph_type}_classification_train_X')
        if verbose:
            print(f"** Saving {graph_type} classification test X-article text")
        self.save_XY_text(xc_dir, self.data_splitter[graph_type].test_doc_to_rowindex,
                          tag=f'{graph_type}_classification_test_X')
        if verbose:
            print(f"** Saving {graph_type} classification Y-label text")
        self.save_XY_text(xc_dir, self.data_splitter[graph_type].trn_tst_labels,
                          tag=f'{graph_type}_classification_Y')
                      
    def save_XCClassification_data(self, xc_dir, graph_type='also_buy', verbose=True):
        if verbose:
            print(f"** Saving {graph_type} classification train 'trn_X_Y.txt'.")
        train_file = f'{xc_dir}/{graph_type}_trn_X_Y.txt'
        self.save_sparse_file(self.data_splitter[graph_type].train, train_file)
        
        if verbose:
            print(f"** Saving {graph_type} classification test 'tst_X_Y.txt'.")
        test_file = f'{xc_dir}/{graph_type}_tst_X_Y.txt'
        self.save_sparse_file(self.data_splitter[graph_type].test, test_file)
        
    def save_XCClassification(self, save_dir, xc_dir, tag='_dict', graph_type='also_buy', verbose=True):
        """
        XC Classification
        """
        self.load_classification_data(save_dir, tag=tag, graph_type=graph_type, verbose=verbose)
        self.save_XCClassification_text(xc_dir, graph_type=graph_type, verbose=verbose)
        self.save_XCClassification_data(xc_dir, graph_type=graph_type, verbose=verbose)
        
    def load_graph_data(self, save_dir, tag='', graph_type='also_buy', verbose=True):
        """
        Loading Graph
        """
        if verbose:
            print(f"** Loading {graph_type} graph.")
            
        graph_data = AmazonGraphContainer()
        graph_data.load_graph(save_dir, tag=tag, graph_type=graph_type)
        
        self.data_splitter[graph_type] = AmazonSplit(graph_data.graphs[graph_type].graph)
        self.data_splitter[graph_type].clean_matrix(clean_type=1)
    
    def save_XCGraph_text(self, xc_dir, graph_type='also_buy', verbose=True):
        """
        Saving XC-Graph text
        """
        if verbose:
            print(f"** Saving {graph_type}_graph X-text.")
        self.save_XY_text(xc_dir, self.data_splitter[graph_type].doc_to_rowindex,
                          tag=f'{graph_type}_graph_X')
        
        if verbose:
            print(f"** Saving {graph_type}_graph Y-text.")
        self.save_XY_text(xc_dir, self.data_splitter[graph_type].labels,
                          tag=f'{graph_type}_graph_Y')
        
    def save_XCGraph_data(self, xc_dir, graph_type='also_buy', verbose=True):
        if verbose:
            print(f"** Saving '{graph_type}_graph_trn_X_Y.txt'")
        graph_file = f'{xc_dir}/{graph_type}_graph_trn_X_Y.txt'
        self.save_sparse_file(self.data_splitter[graph_type].graph, graph_file)
        
    def save_XCGraph(self, save_dir, xc_dir, tag='_dict', graph_type='also_buy', verbose=True):
        """
        XC Graph
        """
        self.load_graph_data(save_dir, tag=tag, graph_type=graph_type, verbose=verbose)
        self.save_XCGraph_text(xc_dir, graph_type=graph_type, verbose=verbose)
        self.save_XCGraph_data(xc_dir, graph_type=graph_type, verbose=verbose)
        
    def create_XCData(self, save_dir, xc_dir, tag='_dict', verbose=True):
        self.load_idtotitle(save_dir, tag='', verbose=verbose)
        self.load_description(save_dir, tag='', verbose=verbose)
        
        graph_types = ['also_buy', 'also_view', 'similar']
        for graph_type in graph_types:
            if verbose:
                print(f'-- Processing {graph_type} graph.')
            self.save_XCClassification(save_dir, xc_dir, tag=tag, graph_type=graph_type, verbose=verbose)
            self.save_XCGraph(save_dir, xc_dir, tag=tag, graph_type=graph_type, verbose=verbose)
            if verbose:
                print()
                

In [132]:
save_dir = f'{dataset_home}/GraphAmazonProducts/results'
xc_dir = f'{dataset_home}/GraphAmazonProducts/XCData'

In [133]:
xc_data = XCDataset()
xc_data.create_XCData(save_dir, xc_dir, tag=f'_dict')

** Loading Amazon product 'id_to_title'.
** Loading Amazon 'description'.
-- Processing also_buy graph.
** Creating also_buy classification train-test.
** Saving also_buy classification train X-article text
** Saving also_buy classification test X-article text
** Saving also_buy classification Y-label text
** Saving also_buy classification train 'trn_X_Y.txt'.
** Saving also_buy classification test 'tst_X_Y.txt'.
** Loading also_buy graph.
** Saving also_buy_graph X-text.
** Saving also_buy_graph Y-text.
** Saving 'also_buy_graph_trn_X_Y.txt'

-- Processing also_view graph.
** Creating also_view classification train-test.
** Saving also_view classification train X-article text
** Saving also_view classification test X-article text
** Saving also_view classification Y-label text
** Saving also_view classification train 'trn_X_Y.txt'.
** Saving also_view classification test 'tst_X_Y.txt'.
** Loading also_view graph.
** Saving also_view_graph X-text.
** Saving also_view_graph Y-text.
** S