In [1]:
import os
import re
import gc
import json
import pickle
import xml.sax
import requests
import pandas as pd

import bz2
import subprocess
import numpy as np
from IPython import display
from scipy.sparse import csr_matrix
from matplotlib import pyplot as plt
from timeit import default_timer as timer

In [2]:
import tqdm
from functools import partial
from multiprocessing import Pool

In [3]:
import mwparserfromhell

In [4]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

## Helper functions

In [None]:
def dict_head_random(dictionary, n=10):
    keys = np.random.choice(list(dictionary.keys()), size=n)
    for k in keys:
        print(f'{k} : {dictionary[k]}')
        

## Arguement

In [6]:
project = 'enwiki'
dump_date = "20220420"
dataset_home = '/home/cse/phd/anz198717/scratch/suchith_data/wikipedia/wikipedia-data-science'

In [7]:
prog = re.compile(f'({project}-{dump_date})-'+r'pages-articles-multistream([0-9]{1,2}).xml-(p[0-9]+p[0-9]+).bz2')

partitions = sorted([f'{dataset_home}/datasets/{file}' for file in os.listdir(f'{dataset_home}/datasets') 
                     if prog.match(file)])

print(f'Total number of partitions of the wikipedia dump : {len(partitions)}')

Total number of partitions of the wikipedia dump : 62


## Page Handler

In [73]:
class WikiXmlHandler(xml.sax.handler.ContentHandler):
    
    def __init__(self, matches=None):
        xml.sax.handler.ContentHandler.__init__(self)
        
        """
        PARSING VARIABLES: these variables will be useful
        while parsing the wikipedia dumps.
        """
        #regrex to select wikipedia section
        #TODO: remove this from here.
        self.matches = matches
        
        #basic storage for on-the-fly processing
        self._buffer = None
        self._values = {}
        self._current_tag = None
        
        #flags for handling special cases.
        self._add_page = True
        self._is_pageid = True
        
        #from here improvements can be made
        self.pages = []
        self.category = []
        self._total_edges = 0
        self._article_count = 0
        self._category_count = 0
        
        """
        output information
        """
        #self.article_mat = WikilinkGraph()
        #self.seealso_mat = WikilinkGraph()
        
        self.redirects = {}
        self.id_to_title = {}
        self.page_content = {}
        
        #self.found_cat_redirect = False
        
    def characters(self, content):
        if self._current_tag:
            self._buffer.append(content)
            
    def startElement(self, name, attrs):
        if name in ('title', 'text', 'ns'):
            self._current_tag = name
            self._buffer = []
        elif name == 'id' and self._is_pageid:
            self._is_pageid = False
            self._current_tag = name
            self._buffer = []    
        elif name == 'redirect':
            article_title = self._values['article_title'].strip()
            self.redirects[article_title] = attrs.getValue('title').strip()
            #if is_category_link(article_title):
                #self.found_cat_redirect = True
                #print(article_title)
            self._add_page = False
                   
    def endElement(self, name):
        if name == self._current_tag:
            self._values[f'article_{name}'] = ' '.join(self._buffer)
            self._current_tag = None
        elif name == 'page':
            article_ns = int(self._values['article_ns'])
            if article_ns != 0 and article_ns != 14:
                self._add_page = False
                
            if self._add_page:
                if article_ns == 0:
                    self._article_count += 1
                    self.pages.append(self._values.copy())
                elif article_ns == 14:
                    self._category_count += 1
                    self.category.append(self._values.copy())
                
            self._add_page = True
            self._is_pageid = True
            

In [74]:
data_path = sorted(partitions)[2]
matches = r'^([Ss]ee[ ]*|[Ss]ee[ ]*([Aa]lso|[Mm]ore|[Aa]ll)|[Ss]ee[ ]*[Aa]lso[ ]*\(.+\))$'

In [75]:
data_path

'/home/cse/phd/anz198717/scratch/suchith_data/wikipedia/wikipedia-data-science/datasets/enwiki-20220420-pages-articles-multistream11.xml-p5399367p6899366.bz2'

In [77]:
start = timer()

handler = WikiXmlHandler(matches)

parser = xml.sax.make_parser()
parser.setContentHandler(handler)

for i, line in enumerate(subprocess.Popen(['bzcat'], 
                         stdin = open(data_path), 
                         stdout = subprocess.PIPE).stdout):
    parser.feed(line)
    
    if handler._category_count > 5:
    #if handler.found_cat_redirect:
        break
        
end = timer()
        
print(f'Searched through {handler._article_count} articles.')
print(f'Processing time is {round(end - start)} secs.')

Searched through 106 articles.
Processing time is 0 secs.


In [78]:
handler._article_count, len(handler.pages), handler._category_count, len(handler.category)

(106, 106, 6, 6)

In [92]:
article_num = 50

article_ns = handler.pages[article_num]['article_ns']
article_id = handler.pages[article_num]['article_id']
article_title = handler.pages[article_num]['article_title']
article_text = handler.pages[article_num]['article_text']

In [109]:
article_num = 5

article_ns = handler.category[article_num]['article_ns']
article_id = handler.category[article_num]['article_id']
article_title = handler.category[article_num]['article_title']
article_text = handler.category[article_num]['article_text']

In [110]:
print(article_ns)
print(article_id)
print(article_title)
print(article_text)

14
5400296
Category:Sensors
{{Cat main|Sensor}} 
 {{Commons cat|Sensors}} 
 
 [[Category:Measuring instruments]] 
 [[Category:Transducers]] 
 {{CatAutoTOC}}


## Graph XML handler

In [38]:
class WikiXmlHandler(xml.sax.handler.ContentHandler):
    
    def __init__(self):
        xml.sax.handler.ContentHandler.__init__(self)
        
        """
        PARSING VARIABLES: these variables will be useful
        while parsing the wikipedia dumps.
        """
        #basic storage for on-the-fly processing
        self._buffer = None
        self._values = {}
        self._current_tag = None
        
        #flags for handling special cases.
        self._add_page = True
        self._is_pageid = True
        
        """
        STORAGE VARIABLES: these variable will be used for
        storing the graph and content of the wikipedia dump.
        """
        self.wikidataset = WikiGraphDataset()
        
    def characters(self, content):
        if self._current_tag:
            self._buffer.append(content)
            
    def startElement(self, name, attrs):
        if name in ('title', 'text', 'ns'):
            self._current_tag = name
            self._buffer = []
            
        elif name == 'id' and self._is_pageid:
            self._is_pageid = False
            self._current_tag = name
            self._buffer = []
            
        elif name == 'redirect':
            article_title = self._values['article_title'].strip()
            target_title = attrs.getValue('title').strip()
            article_ns = int(self._values['article_ns'])
            if article_ns == 0 or article_ns == 14 :
                self.wikidataset.add_redirect(article_title, target_title, article_ns)
            self._add_page = False
                   
    def endElement(self, name):
        if name == self._current_tag:
            self._values[f'article_{name}'] = ' '.join(self._buffer)
            self._current_tag = None
            
        elif name == 'page':
            article_ns = int(self._values['article_ns'])
            if article_ns != 0 and article_ns != 14:
                self._add_page = False
                
            """
            EXTRACT_DATA : The following code stores the data
            """
            if self._add_page:
                self.wikidataset.add_article(**self._values)
            """
            EXTRACT_DATA
            """
            
            self._add_page = True
            self._is_pageid = True
            

## Filtering links

In [10]:
class FilterWikilinks:
    
    def __init__(self):
        """
        Removed category, cat from subject namespaces.
        """
        self.subject_namespaces = {'user', 'wikipedia', 'wp', 'project', 'file', 'image', 'mediawiki', 
                                   'template', 't', 'help', 'h', 'portal', 'p', 
                                   'draft', 'timedtext', 'module', 'special', 'media'}
        
        self.talk_namespaces = {'talk', 'user talk', 'wikipedia talk', 'wt', 'project talk', 'file talk', 
                                'image talk','mediawiki talk', 'template talk', 'help talk', 'category talk', 
                                'portal talk', 'draft talk', 'timedtext talk', 'module talk'}
        
        self.interwiki_links = {'wiktionary', 'wikt', 'wikinews', 'n', 'wikibooks', 'b', 'wikiquote','q', 
                                'wikisource','s', 'oldwikisource', 'wikispecies', 'species', 
                                'wikiversity', 'v', 'wikivoyage', 'voy', 'wikimedia','foundation', 'wmf', 
                                'commons', 'c', 'metawiki', 'metawikimedia', 'metawikipedia', 'meta' , 'm', 
                                'incubator', 'strategy', 'mediawikiwiki', 'mw', 'mediazilla', 'bugzilla'}
        
        self.language_code = re.compile(r'^[a-z][a-z]$')
        
    def remove_section_tags(self, link):
        hash_parts = link.split('#')
        if len(hash_parts) > 1:
            link = hash_parts[0].strip()  
        return link
    
    def filter_special_tags(self, link):
        colon_parts = link.split(':')
        part_num = 0
        for part in colon_parts:
            part = part.lower()
            if (part == "w") or (part == "en") or (part_num == 0 and not part):
                part_num += 1
            elif (part in self.subject_namespaces) or (part in self.talk_namespaces) \
            or (part in self.interwiki_links) or self.language_code.match(part):
                return ''
            else:
                break
        return ':'.join(colon_parts[part_num:])
    
    def lower_wikilink(self, link):
        cat_link = is_category_link(link)
        if cat_link:
            return 'category:' + cat_link[0].lower() + cat_link[1:] 
        elif link: 
            link = link[0].lower() + link[1:]
        return link
        
    def remove_underscore(self, link):
        return link.replace('_', ' ')
    
    def fix_single_quotes(self, link):
        if '"' in link:
            link_split = link.split('"')
            for i in range(len(link_split)):
                if i == 0:
                    link_split[i] = link_split[i][:-1]
                elif i == len(link_split)-1:
                    link_split[i] = link_split[i][1:]
                else:
                    link_split[i] = link_split[i][1:-1]
            return '"'.join(link_split)
        return link
    
    def process_wikilink(self, link):
        link = self.remove_section_tags(link)
        link = self.filter_special_tags(link)
        link = self.lower_wikilink(link)
        link = self.remove_underscore(link)
        link = self.fix_single_quotes(link)
        return link
    
    def extract_wikilink_from_wikicode(self, wikicode):
        links = list( map(lambda link: (link.title).strip_code().strip(), wikicode.filter_wikilinks()) )
        return links
    
    """
    The main reason for this function is to seperate category and article links.
    'group_1' contains the category links
    'group_2' contains the article links.
    """
    def seggregate_links(self, links, seg_func):
        group_1, group_2 = list(), list()
        for link in links:
            processed_link = seg_func(link)
            if processed_link:
                group_1.append(processed_link)
            else:
                group_2.append(link)
        return group_1, group_2
    

In [11]:
def is_category_link(link):
    link_parts = link.split(":", maxsplit=1)
    link_type = link_parts[0].lower()
    if len(link_parts) == 2 and (link_type == "category" or link_type == "cat"):
        return link_parts[1]
    return ''


### Testing

#### test_1

In [80]:
link_filter = FilterWikilinks()
link_filter.process_wikilink('en:w:category:sdfsf')

'category:sdfsf'

#### test_2

In [55]:
link_filter = FilterWikilinks()

In [56]:
wikicode = mwparserfromhell.parse(article_text, skip_style_tags=True)
wikicode.remove_nodetype(inplace=True)

In [57]:
links = link_filter.extract_wikilink_from_wikicode(wikicode)

wikilinks = list()
for link in links:
    processed_link = link_filter.process_wikilink(link)
    if processed_link:
        wikilinks.append(processed_link)

In [58]:
wikilinks, len(wikilinks)

(['category:measuring instruments', 'category:transducers'], 2)

In [59]:
cat_links, article_links = link_filter.seggregate_links(wikilinks, is_category_link)

In [60]:
cat_links, len(cat_links)

(['measuring instruments', 'transducers'], 2)

In [61]:
article_links, len(article_links)

([], 0)

## Data-structures

In [226]:
class WikilinkGraph:

    def __init__(self):
        self.graph = {}
        
    def add_article(self, article_id, wikilinks):
        links, count = np.unique(wikilinks, return_counts=True)
        self.graph[article_id] = (links.tolist(), count.tolist())
        
    def save_data(self, save_dir, tag=''):
        os.makedirs(save_dir, exist_ok=True)
        filename = f'{save_dir}/link_graph{tag}.pickle'
        with open(filename, 'wb') as f:
            pickle.dump(self.graph, f)
            
        del self.graph
        gc.collect()
            
    def load_data(self, save_dir, tag=''):
        filename = f'{save_dir}/link_graph{tag}.pickle'
        if os.path.exists(filename):
            with open(filename, 'rb') as f:
                self.graph = pickle.load(f)
            return True
        print(f"ERROR:: Unable to load the graph at '{filename}'")
        return False
    
    def replace_redirects(self, redirects):
        for edges, _ in self.graph.values():
            for i, edge in enumerate(edges):
                if edge:
                    edge = edge[0].lower()+edge[1:]
                    if edge in redirects:
                        edges[i] = redirects[edge]
        return None
    
    def remove_dead(self, title_to_id):
        for article_id, (edges, counts) in self.graph.items():
            active_edges = list()
            active_counts = list()

            while edges:
                edge = edges.pop()
                count = counts.pop()

                if edge:
                    edge = edge[0].lower()+edge[1:]
                    if edge in title_to_id:
                        active_edges.append(title_to_id[edge])
                        active_counts.append(count)
            self.graph[article_id] = (active_edges, active_counts)  
        return None
    

In [82]:
class WikiGraphDataset:
    
    def __init__(self):
        
        self.article_graph = WikilinkGraph()
        self.label_graph = WikilinkGraph()
        self.classification_graph = WikilinkGraph()
        
        self.id_to_title = {}
        self.label_id_to_title = {}
        
        self.wiki_content = {}
        
        self.redirects = {}
        self.label_redirects = {}
        
        self.filter_links = FilterWikilinks()
    
    """
    Add articles to graph dataset.
    """
    def extract_article_info(self, article_text):
        wikicode = mwparserfromhell.parse(article_text, skip_style_tags=True)
        wikicode.remove_nodetype(inplace=True)
        
        """
        extract link
        """
        links = self.filter_links.extract_wikilink_from_wikicode(wikicode)

        wikilinks = list()
        for link in links:
            processed_link = self.filter_links.process_wikilink(link)
            if processed_link:
                wikilinks.append(processed_link)

        category_links, article_links = self.filter_links.seggregate_links(wikilinks, is_category_link)
        
        """
        extract content
        """
        article_content = wikicode.strip_code().strip()

        return category_links, article_links, article_content
    
    def add_article(self, article_title, article_text, article_id, article_ns):
        article_title, article_id, article_ns = article_title.strip(), int(article_id.strip()), int(article_ns.strip())
        
        article_title = self.filter_links.process_wikilink(article_title)
        cat_title = is_category_link(article_title)
        if cat_title:
            article_title = cat_title
        
        if not article_title or article_id in self.id_to_title or article_id in self.label_id_to_title:
            return
        
        category_links, article_links, article_content = self.extract_article_info(article_text)
        
        if len(category_links) or len(article_links):
            if article_ns == 0:
                self.id_to_title[article_id] = article_title
                if len(category_links):
                    self.classification_graph.add_article(article_id, category_links) 
                if len(article_links):
                    self.article_graph.add_article(article_id, article_links)
                if article_content:
                    self.wiki_content[article_id] = article_content
            elif article_ns == 14:
                self.label_id_to_title[article_id] = article_title
                if len(category_links):
                    self.label_graph.add_article(article_id, category_links)
                    
    """
    Add redirect
    """
    def add_redirect(self, article_title, target_title, article_ns):
        article_title = self.filter_links.process_wikilink(article_title)
        target_title = self.filter_links.process_wikilink(target_title)
        
        if article_title and target_title:
            if article_ns == 14:
                target_title = is_category_link(target_title)
                article_title = is_category_link(article_title)
                if target_title and article_title:
                    self.label_redirects[article_title] = target_title
            else:
                self.redirects[article_title] = target_title
            
    """
    Save content
    """
    def save_graph(self, save_dir, tag='', graph_type='all'):
        if graph_type == 'classification' or graph_type == 'all':
            self.classification_graph.save_data(save_dir, tag=f'_classification{tag}')
            
        if graph_type == 'article' or graph_type == 'all':
            self.article_graph.save_data(save_dir, tag=f'_article{tag}')
            
        if graph_type == 'label' or graph_type == 'all':
            self.label_graph.save_data(save_dir, tag=f'_label{tag}')
              
    def save_idtotitle(self, save_dir, tag='', id_type='all'):
        if id_type == "article" or id_type == "all":
            map_file = f'{save_dir}/id_to_title{tag}.pickle'
            with open(map_file, 'wb') as f:
                pickle.dump(self.id_to_title, f)   
            del self.id_to_title
        
        if id_type == "label" or id_type == "all":
            map_file = f'{save_dir}/label_id_to_title{tag}.pickle'
            with open(map_file, 'wb') as f:
                pickle.dump(self.label_id_to_title, f)   
            del self.label_id_to_title
        gc.collect()
            
    def save_wikicontent(self, save_dir, tag=''):
        content_file = f'{save_dir}/wiki_content{tag}.pickle'
        with open(content_file, 'wb') as f:
            pickle.dump(self.wiki_content, f)
        del self.wiki_content
        gc.collect()
            
    def save_redirects(self, save_dir, tag='', redirect_type='all'):
        if redirect_type == "article" or redirect_type == "all":
            redirect_file = f'{save_dir}/redirects{tag}.pickle'
            with open(redirect_file, 'wb') as f:
                pickle.dump(self.redirects, f)
            del self.redirects
            
        if redirect_type == "label" or redirect_type == "all":
            redirect_file = f'{save_dir}/label_redirects{tag}.pickle'
            with open(redirect_file, 'wb') as f:
                pickle.dump(self.label_redirects, f)
            del self.label_redirects
        gc.collect()
        
    def save_data(self, save_dir, tag=''):
        os.makedirs(save_dir, exist_ok=True)
        
        self.save_graph(save_dir, tag)
        self.save_idtotitle(save_dir, tag)
        self.save_redirects(save_dir, tag)
        self.save_wikicontent(save_dir, tag)
            
    def load_graph(self, save_dir, tag='', graph_type='all'):
        if graph_type == 'classification' or graph_type == 'all':
            if not self.classification_graph.load_data(save_dir, tag=f'_classification{tag}'):
                raise Exception("Unable to load 'classification graph'.")
                
        if graph_type == 'article' or graph_type == 'all':
            if not self.article_graph.load_data(save_dir, tag=f'_article{tag}'):
                raise Exception("Unable to load 'article graph'.")
                
        if graph_type == 'label' or graph_type == 'all':
            if not self.label_graph.load_data(save_dir, tag=f'_label{tag}'):
                raise Exception("Unable to load 'label graph'.")
        
    def load_idtotitle(self, save_dir, tag='', id_type='all'):
        if id_type == "article" or id_type == "all":
            map_file = f'{save_dir}/id_to_title{tag}.pickle'
            if os.path.exists(map_file):
                with open(map_file, 'rb') as f:
                    self.id_to_title = pickle.load(f)
            else:
                raise Exception(f"Unable to load 'id_to_title' from '{map_file}'.")
                
        if id_type == "label" or id_type == "all":
            map_file = f'{save_dir}/label_id_to_title{tag}.pickle'
            if os.path.exists(map_file):
                with open(map_file, 'rb') as f:
                    self.label_id_to_title = pickle.load(f)
            else:
                raise Exception(f"Unable to load 'label_id_to_title' from '{map_file}'.")
            
    def load_wikicontent(self, save_dir, tag=''):
        content_file = f'{save_dir}/wiki_content{tag}.pickle'
        if os.path.exists(content_file):
            with open(content_file, 'rb') as f:
                self.wiki_content = pickle.load(f)
        else:
            raise Exception(f"Unable to load 'wiki_content' from '{content_file}'.")
            
    def load_redirects(self, save_dir, tag='', redirect_type='all'):
        if redirect_type == "article" or redirect_type == "all":
            redirect_file = f'{save_dir}/redirects{tag}.pickle'
            if os.path.exists(redirect_file):
                with open(redirect_file, 'rb') as f:
                    self.redirects = pickle.load(f)
            else:
                raise Exception(f"Unable to load 'redirects' from '{redirect_file}'.")
                
        if redirect_type == "label" or redirect_type == "all":
            redirect_file = f'{save_dir}/label_redirects{tag}.pickle'
            if os.path.exists(redirect_file):
                with open(redirect_file, 'rb') as f:
                    self.label_redirects = pickle.load(f)
            else:
                raise Exception(f"Unable to load 'label redirects' from '{redirect_file}'.")
                
    def load_data(self, save_dir, tag=''):
        self.load_graph(save_dir, tag)
        self.load_idtotitle(save_dir, tag)
        self.load_wikicontent(save_dir, tag)
        self.load_redirects(save_dir, tag)
                

### Testing

#### Adding single article

In [99]:
page_num = 50
handler.pages[page_num]['article_title']

'Walt Disney World Millennium Celebration'

In [111]:
page_num = 5
handler.category[page_num]['article_title']

'Category:Sensors'

In [103]:
wikidataset = WikiGraphDataset()

In [104]:
wikidataset.add_article(**handler.pages[page_num])

In [112]:
wikidataset.add_article(**handler.category[page_num])

In [113]:
wikidataset.article_graph.graph

{5399804: (['eXPO 2000',
   'epcot',
   'epcot attraction and entertainment history',
   'illumiNations: Reflections of Earth',
   "mcDonald's",
   'millennium celebrations',
   'spaceship Earth (Epcot)',
   'super Bowl XXXIV',
   'tapestry of Nations',
   'uNESCO',
   'united Nations',
   'walt Disney World Resort',
   'world Bank'],
  [1, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}

In [114]:
wikidataset.classification_graph.graph

{5399804: (['1999 establishments in Florida',
   '2001 disestablishments in Florida',
   'epcot',
   'events at Walt Disney World',
   'former Walt Disney Parks and Resorts attractions',
   'turn of the third millennium',
   'walt Disney Parks and Resorts entertainment'],
  [1, 1, 1, 1, 1, 1, 1])}

In [115]:
wikidataset.label_graph.graph

{5400296: (['measuring instruments', 'transducers'], [1, 1])}

In [116]:
wikidataset.label_id_to_title, wikidataset.id_to_title

({5400296: 'sensors'}, {5399804: 'walt Disney World Millennium Celebration'})

#### Adding multiple articles

In [117]:
wikidataset = WikiGraphDataset()

In [118]:
for i in range(len(handler.pages)):
    wikidataset.add_article(**handler.pages[i])

In [119]:
for i in range(len(handler.category)):
    wikidataset.add_article(**handler.category[i])

In [121]:
wikidataset.label_id_to_title

{5399399: 'restaurants in Ohio',
 5399559: 'congress of South African Trade Unions',
 5400010: 'rAAF radar and surveillance units',
 5400259: 'prison Break episodes',
 5400296: 'sensors'}

In [122]:
dict_head_random(wikidataset.article_graph.graph, n=2)

5399441 : (['a Ticket in Tatts (1934 film)', 'aberdeen, New South Wales', 'ausStage', 'austLit: The Australian Literature Resource', 'australian Dictionary of Biography', 'australian Screen Online', 'barry Humphries', 'brisbane', 'charles Chaplin', 'cinema of the United States', 'cinesound Productions', "collits' Inn", 'duck Soup (1933 film)', 'everyman', 'frontline (Australian TV series)', 'gone to the Dogs (1939 film)', 'harmony Row (1933 film)', 'his Royal Highness (1932 film)', 'ken G. Hall', 'kensington, New South Wales', 'let George Do It (1938 film)', 'logie Award', 'marx Brothers', 'national Film and Sound Archive', 'new Tivoli Theatre, Sydney', 'north Queensland', 'paul Hogan', 'pyrmont, New South Wales', 'roy Rene', 'slapstick', 'stiffy and Mo', 'street performance', 'sydney', 'the Beloved Vagabond (play)', 'the Great Dictator', 'the National Film and Sound Archive', 'the Rats of Tobruk', 'the Rats of Tobruk (1944 film)', 'tivoli circuit', 'vaudeville', 'wharfie', 'wherever S

In [123]:
dict_head_random(wikidataset.classification_graph.graph, n=2)

5399672 : (['1981 births', '20th-century African-American people', '20th-century African-American women', '21st-century African-American musicians', '21st-century African-American women', '21st-century American rappers', '21st-century American women musicians', 'african-American songwriters', 'african-American women rappers', 'american women rappers', 'def Jam Recordings artists', 'living people', 'people from Englewood, New Jersey', 'rappers from New Jersey', 'songwriters from New Jersey'], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
5400271 : (['hoya de Buñol', 'municipalities in the Province of Valencia'], [1, 1])


In [125]:
wikidataset.id_to_title[5399672]

'lady Luck (rapper)'

In [126]:
dict_head_random(wikidataset.label_graph.graph, n=2)

5400259 : (['american television episodes by series', 'prison Break', 'television episodes set in Chicago', 'television episodes set in Los Angeles', 'television episodes set in prisons'], [1, 1, 1, 1, 1])
5400296 : (['measuring instruments', 'transducers'], [1, 1])


#### Running generation script

In [135]:
save_dir = f'{dataset_home}/WikiCategory/results'

In [136]:
wikidataset = WikiGraphDataset()
wikidataset.load_data(save_dir, tag='-enwiki-20220420-27-p69975910p70585441')

#wikidataset.load_data(save_dir, tag='-enwiki-20220420-27-p69975910p70585441')

In [139]:
dict_head_random(wikidataset.article_graph.graph, n=2)

70460953 : (['alburnoides', 'alburnus', 'amazon River', 'brycon', 'bryconops', 'bryconops caudomaculatus', 'bryconops gracilis', 'carl H. Eigenmann', 'chalceus', 'characidae', 'characiformes', 'cyprinidae', 'floodplain', 'iguanodectes', 'iguanodectidae', 'insect', 'insectivore', 'lateral line', 'orinoco', 'piabucus', 'premaxilla', 'rudolf Kner', 'south America', 'species complex', 'standard Length', 'tetragonopterus', 'type species'], [1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1])
70459860 : (['battle of the Cosmin Forest', 'bayezid II', 'black Sea', 'bukovina', 'crimea', 'crimean Khanate', 'danube Delta', 'dniester', 'john I Albert', 'lviv', 'moldavia', 'stephen the Great', 'valia Kuzmyna', 'volga'], [1, 1, 1, 1, 2, 1, 2, 1, 2, 1, 2, 2, 1, 1])


In [141]:
#dict_head_random(wikidataset.classification_graph.graph, n=2)
wikidataset.classification_graph.graph[70460953]

(['characiformes', 'fish described in 1858'], [1, 1])

In [142]:
wikidataset.id_to_title[70460953]

'bryconops alburnoides'

In [143]:
dict_head_random(wikidataset.label_graph.graph, n=2)

70425522 : (['association football', 'association football players', 'sportsmen by sport'], [1, 1, 1])
70022000 : (['international universities', 'universities in Sweden'], [1, 1])


In [144]:
wikidataset.label_id_to_title[70425522]

"men's association football players"

## Create Graph

In [138]:
def create_graph(data_path, save_dir, matches=None, limit=None, save=True, tag_extractor=None, 
                 back=True, verbose=True):
    
    handler = WikiXmlHandler()
    
    parser = xml.sax.make_parser()
    parser.setContentHandler(handler)
    
    for i, line in enumerate(subprocess.Popen(['bzcat'],
                                             stdin = open(data_path),
                                             stdout = subprocess.PIPE).stdout):
        try:
            parser.feed(line)
        except StopIteration:
            break
            
        if limit is not None and len(handler.wikidataset.id_to_title) >= limit:
            if back:
                return handler.wikidataset
            else:
                break
    
    if save:
        file_tag = ''
        if isinstance(tag_extractor, re.Pattern):
            parts_tag = tag_extractor.match(os.path.basename(data_path))
            if parts_tag:
                try:
                    file_tag = f'-{parts_tag.group(1)}-{parts_tag.group(2)}-{parts_tag.group(3)}'
                except:
                    file_tag = ''
        elif isinstance(tag_extractor, str):
            file_tag = tag_extractor
            
        handler.wikidataset.save_data(save_dir, tag=file_tag)
        if verbose:
            print(f"** Completed processing {os.path.basename(data_path)}.", end='\r')

    del handler
    del parser
    gc.collect()
    
    return None


In [139]:
save_dir = f'{dataset_home}/test_data'

tag_extractor = re.compile(f'({project}-{dump_date})-'+r'pages-articles-multistream([0-9]{1,2}).xml-(p[0-9]+p[0-9]+).bz2')

In [140]:
data_path = partitions[0]
create_graph(data_path, save_dir, matches=matches, limit=100, save=True, tag_extractor=tag_extractor, back=False)

** Completed processing enwiki-20220420-pages-articles-multistream1.xml-p1p41242.bz2.

In [141]:
for data_path in partitions[0:4]:
    create_graph(data_path, save_dir, matches=matches, limit=100, save=True, tag_extractor=tag_extractor, 
                 back=False)

** Completed processing enwiki-20220420-pages-articles-multistream11.xml-p6899367p7054859.bz2.

## Combine redirects, id_to_title and wiki_content

In [145]:
def multiprocessor_1(func, tasks, num_process=10):
    pool = Pool(processes=num_process)
    results = {}
    for x in tqdm.tqdm( pool.imap(func, tasks), total=len(tasks)):
        results.update(x)
    return results


In [146]:
def extract_filetag(data_path, tag_extractor=None):
    file_tag = ''
    if isinstance(tag_extractor, re.Pattern):
        parts_tag = tag_extractor.match(os.path.basename(data_path))
        if parts_tag:
            try:
                file_tag = f'-{parts_tag.group(1)}-{parts_tag.group(2)}-{parts_tag.group(3)}'
            except:
                file_tag = ''
    elif isinstance(tag_extractor, str):
        file_tag = tag_extractor
    return file_tag


In [170]:
class WikiGraphCombine:
    
    def __init__(self, partition_files=None):
        self.partition_files = partition_files
        
        self.parameters = {}
        self.parameters['id_to_title'] = {}
        self.parameters['redirects'] = {}
        
        self.parameters['label_id_to_title'] = {}
        self.parameters['label_redirects'] = {}
        
        self.parameters['wiki_content'] = {}
        
        self.parameters['label_graph'] = {}
        self.parameters['article_graph'] = {}
        self.parameters['classification_graph'] = {}
    
    def load_param(self, data_path, save_dir, param='id_to_title', graph_resolved=True, tag_extractor=None):
        wikidataset = WikiGraphDataset()
        tag = extract_filetag(data_path, tag_extractor)
        
        if param == 'id_to_title':
            wikidataset.load_idtotitle(save_dir=save_dir, tag=tag, id_type="article")
            return wikidataset.id_to_title
        elif param == 'label_id_to_title':
            wikidataset.load_idtotitle(save_dir=save_dir, tag=tag, id_type="label")
            return wikidataset.label_id_to_title
        elif param == 'redirects':
            wikidataset.load_redirects(save_dir=save_dir, tag=tag, redirect_type="article")
            return wikidataset.redirects
        elif param == 'label_redirects':
            wikidataset.load_redirects(save_dir=save_dir, tag=tag, redirect_type="label")
            return wikidataset.label_redirects
        elif param == 'wiki_content':
            wikidataset.load_wikicontent(save_dir=save_dir, tag=tag)
            return wikidataset.wiki_content
        elif param == 'article_graph':
            if graph_resolved:
                tag = f'{tag}_resolved'
            wikidataset.load_graph(save_dir=save_dir, tag=tag, graph_type='article')
            self.convert_graph(wikidataset.article_graph.graph)
            return wikidataset.article_graph.graph
        elif param == 'label_graph':
            if graph_resolved:
                tag = f'{tag}_resolved'
            wikidataset.load_graph(save_dir=save_dir, tag=tag, graph_type='label')
            self.convert_graph(wikidataset.label_graph.graph)
            return wikidataset.label_graph.graph
        elif param == 'classification_graph':
            if graph_resolved:
                tag = f'{tag}_resolved'
            wikidataset.load_graph(save_dir=save_dir, tag=tag, graph_type='classification')
            self.convert_graph(wikidataset.classification_graph.graph)
            return wikidataset.classification_graph.graph
        else:
            raise Exception(f'Invalid value of param : {param}')
        
    def combine_param(self, save_dir, param='id_to_title', graph_resolved=True, tag_extractor=None):
        combine_helper = partial(self.load_param, param=param, save_dir=save_dir, graph_resolved=graph_resolved, 
                                 tag_extractor=tag_extractor)
        self.parameters[param] = multiprocessor_1(combine_helper, self.partition_files)
                    
    def convert_graph(self, graph):
        if len(graph) and isinstance(graph, dict):
            key = list(graph.keys())[0]
            if isinstance(graph[key], tuple):
                for doc, (edges, counts) in graph.items():
                    graph[doc] = {e:c for e, c in zip(edges, counts)}
            elif isinstance(graph[key], dict):
                for doc, edge_count in graph.items():
                    graph[doc] = (list(edge_count.keys()), list(edge_count.values()))
            else:
                raise Exception("Invalid graph format.")
                
    def save_idtotitle(self, save_dir, tag='', id_type="all"):
        if id_type == "article" or id_type == "all":
            os.makedirs(save_dir, exist_ok=True)
            map_file = f'{save_dir}/id_to_title{tag}.pickle'
            with open(map_file, 'wb') as f:
                pickle.dump(self.parameters['id_to_title'], f)
        if id_type == "label" or id_type == "all":
            os.makedirs(save_dir, exist_ok=True)
            map_file = f'{save_dir}/label_id_to_title{tag}.pickle'
            with open(map_file, 'wb') as f:
                pickle.dump(self.parameters['label_id_to_title'], f)
            
    def load_idtotitle(self, save_dir, tag='', id_type="all"):
        if id_type == "article" or id_type == "all":
            map_file = f'{save_dir}/id_to_title{tag}.pickle'
            with open(map_file, 'rb') as f:
                self.parameters['id_to_title'] = pickle.load(f)
        if id_type == "label" or id_type == "all":
            map_file = f'{save_dir}/label_id_to_title{tag}.pickle'
            with open(map_file, 'rb') as f:
                self.parameters['label_id_to_title'] = pickle.load(f)
                
    def save_wikicontent(self, save_dir, tag=''):
        os.makedirs(save_dir, exist_ok=True)
        content_file = f'{save_dir}/wiki_content{tag}.pickle'
        with open(content_file, 'wb') as f:
            pickle.dump(self.parameters['wiki_content'], f)
            
    def load_wikicontent(self, save_dir, tag=''):
        content_file = f'{save_dir}/wiki_content{tag}.pickle'
        with open(content_file, 'rb') as f:
            self.parameters['wiki_content'] = pickle.load(f)
            
    def save_redirects(self, save_dir, tag='', redirect_type="all"):
        if redirect_type == "article" or redirect_type == "all":
            os.makedirs(save_dir, exist_ok=True)
            redirect_file = f'{save_dir}/redirects{tag}.pickle'
            with open(redirect_file, 'wb') as f:
                pickle.dump(self.parameters['redirects'], f)
        if redirect_type == "label" or redirect_type == "all":
            os.makedirs(save_dir, exist_ok=True)
            redirect_file = f'{save_dir}/label_redirects{tag}.pickle'
            with open(redirect_file, 'wb') as f:
                pickle.dump(self.parameters['label_redirects'], f)
            
    def load_redirects(self, save_dir, tag='', redirect_type="all"):
        if redirect_type == "article" or redirect_type == "all":
            redirect_file = f'{save_dir}/redirects{tag}.pickle'
            with open(redirect_file, 'rb') as f:
                self.parameters['redirects'] = pickle.load(f)
        if redirect_type == "label" or redirect_type == "all":
            redirect_file = f'{save_dir}/label_redirects{tag}.pickle'
            with open(redirect_file, 'rb') as f:
                self.parameters['label_redirects'] = pickle.load(f)
            
    def save_graph(self, save_dir, tag='', graph_type='classification'):
        if graph_type != 'classification' and graph_type != 'article' and graph_type != "label":
            raise Exception("graph_type can only take values in ['classification', 'article', 'label']")
            
        os.makedirs(save_dir, exist_ok=True)
        graph_file = f'{save_dir}/link_graph_{graph_type}{tag}.pickle'
        with open(graph_file, 'wb') as f:
            pickle.dump(self.parameters[f'{graph_type}_graph'], f)
            
    def load_graph(self, save_dir, tag='', graph_type='classification'):
        if graph_type != 'classification' and graph_type != 'article' and graph_type != "label":
            raise Exception("graph_type can only take values in ['classification', 'article', 'label']")
            
        graph_file = f'{save_dir}/link_graph_{graph_type}{tag}.pickle'
        with open(graph_file, 'rb') as f:
            self.parameters[f'{graph_type}_graph'] = pickle.load(f)
            
    def remove_nodes_from_articles(self, nodes):
        articles = self.parameters['article_graph']
        if not articles:
            raise Exception('Article graph is empty.')
            
        for node in tqdm.tqdm( list(articles.keys()) ):
            if node in nodes:
                del articles[node]
            else:
                for edge in list(articles[node].keys()):
                    if edge in nodes:
                        del articles[node][edge]
 

In [192]:
tag_extractor = re.compile(f'({project}-{dump_date})-'+r'pages-articles-multistream([0-9]{1,2}).xml-(p[0-9]+p[0-9]+).bz2')
results_dir = f'{dataset_home}/WikiCategory/results'

save_dir = f'{dataset_home}/WikiCategory/combined'

combine_graph = WikiGraphCombine(partitions)

In [214]:
combine_graph.combine_param(results_dir, param='id_to_title', tag_extractor=tag_extractor)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [00:03<00:00, 16.59it/s]


In [215]:
combine_graph.combine_param(results_dir, param='label_id_to_title', tag_extractor=tag_extractor)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [02:36<00:00,  2.52s/it]


In [243]:
combine_graph.combine_param(results_dir, param='redirects', tag_extractor=tag_extractor)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [03:18<00:00,  3.20s/it]


In [244]:
combine_graph.combine_param(results_dir, param='label_redirects', tag_extractor=tag_extractor)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [10:56<00:00, 10.58s/it]


In [245]:
combine_graph.save_idtotitle(save_dir, tag=f'-{project}-{dump_date}')
combine_graph.save_redirects(save_dir, tag=f'-{project}-{dump_date}')

In [246]:
len(combine_graph.parameters['id_to_title']), len(combine_graph.parameters['label_id_to_title'])

(6483885, 1682026)

In [247]:
len(combine_graph.parameters['redirects']), len(combine_graph.parameters['label_redirects'])

(10022173, 50)

### Testing

#### load_graph

In [174]:
id_to_title = combine_graph.load_param(sorted(partitions)[2], results_dir,
                                  param='id_to_title', graph_resolved=False, tag_extractor=tag_extractor)

In [175]:
article_graph = combine_graph.load_param(sorted(partitions)[2], results_dir,
                                  param='article_graph', graph_resolved=False, tag_extractor=tag_extractor)

In [178]:
id_to_title[5399388]

'james E. Hayes'

In [179]:
article_graph

{5399373: {'alexander Helios': 1,
  'caesarion': 1,
  'christopher Bulis': 1,
  'cleopatra': 1,
  'cleopatra Selene II': 1,
  'cybermen': 1,
  'doctor Who': 1,
  'mondas': 1,
  'peri Brown': 1,
  'science fiction on television': 1,
  'sixth Doctor': 1,
  'third Doctor': 1},
 5399379: {'1954 FIFA World Cup': 1,
  'aleksandar Tirnanić': 1,
  'andy Beattie': 1,
  'antonio López Herranz': 1,
  'doug Livingstone': 1,
  'gusztáv Sebes': 1,
  'juan López Fontana': 1,
  'karl Rappan': 1,
  'karol Borhy': 1,
  'kim Yong-sik': 1,
  'lajos Czeizler': 1,
  'pierre Pibarot': 1,
  'sandro Puppo': 1,
  'sepp Herberger': 1,
  'squad number': 1,
  'walter Nausch': 1,
  'walter Winterbottom': 1,
  'zezé Moreira': 1},
 5399386: {'2002–03 Ukrainian Premier League': 1,
  '2003–04 Ukrainian Cup': 1,
  '2003–04 Ukrainian Premier League': 1,
  '2004–05 Ukrainian Premier League': 2,
  '2005–06 Ukrainian Premier League': 3,
  '2006–07 Ukrainian Premier League': 1,
  '2007–08 Ukrainian Cup': 1,
  '2007–08 Ukrain

In [180]:
classification_graph = combine_graph.load_param(sorted(partitions)[2], results_dir,
                                  param='classification_graph', graph_resolved=False, tag_extractor=tag_extractor)

In [181]:
classification_graph[5399388]

{'1865 births': 1,
 '1898 deaths': 1,
 '19th-century American lawyers': 1,
 '19th-century American politicians': 1,
 'boston College people': 1,
 'catholics from Massachusetts': 1,
 'deaths from peritonitis': 1,
 'lawyers from Boston': 1,
 'massachusetts Democrats': 1,
 'massachusetts state senators': 1,
 'members of the Massachusetts House of Representatives': 1,
 'people from Suffolk County, Massachusetts': 1,
 'road incident deaths in Massachusetts': 1,
 'supreme Knights of the Knights of Columbus': 1}

#### convert_graph

In [182]:
g = {1: ([1, 2, 3],[11, 22, 33]), 2: ([4, 5], [44, 55]), 3: ([6], [66])}

In [189]:
g

{1: ([1, 2, 3], [11, 22, 33]), 2: ([4, 5], [44, 55]), 3: ([6], [66])}

In [190]:
combine_graph.convert_graph(g)

In [191]:
g

{1: {1: 11, 2: 22, 3: 33}, 2: {4: 44, 5: 55}, 3: {6: 66}}

#### loading graphs

In [581]:
combine_graph.load_idtotitle(save_dir, tag=f'-{project}-{dump_date}')
combine_graph.load_redirects(save_dir, tag=f'-{project}-{dump_date}')

In [600]:
article_title = set(combine_graph.parameters['id_to_title'].values())
label_title = set(combine_graph.parameters['label_id_to_title'].values())

In [603]:
article_title.intersection(label_title)

{'time loop',
 'sony Music Latin',
 'taxation in the United Arab Emirates',
 '1913 in Mexico',
 'murininae',
 'carl Wilson',
 'theni district',
 '2013 NCAA Division I FCS football season',
 'royal Yugoslav Army',
 'eratasthelys',
 'burlington County, New Jersey',
 'ponthieva',
 'hinche',
 'criticism of Christianity',
 'barranquitas, Puerto Rico',
 'telugu Desam Party',
 'judaean Desert',
 'dHB-Pokal',
 'missouri Baptist University',
 'colonial National Historical Park',
 'alternaria',
 'basti division',
 'telecommunications in Iceland',
 'jersey Airways',
 'batticaloa',
 'netball in Tonga',
 '2016 Tasmanian energy crisis',
 'hurricane Isaac (2012)',
 'quorum of the Twelve Apostles (LDS Church)',
 'drymini',
 'moriori',
 'danny Elfman',
 'article 6 of the European Convention on Human Rights',
 'costume design',
 'croatian First Football League',
 'world Federation of Engineering Organizations',
 'sepidan County',
 'organizational conflict',
 'water polo at the 1904 Summer Olympics',
 'm

This proves that there is intersection between `article_title` and `label_title`

In [582]:
dict_head_random(combine_graph.parameters['label_id_to_title'])

69374678 : comic strips by date
64760643 : disestablishments in Mongolia by year
60111416 : avalon Beach, New South Wales
52111754 : indian Jeet Kune Do practitioners
1485620 : thai films
57065483 : people from Cisco, Texas
37282838 : 1941 establishments in Portugal
69411131 : entertainment by former country
67989095 : 2001 anime ONAs
4196902 : scientology magazines


## Resolving Graph

In [198]:
def multiprocessor_2(func, tasks, num_process=10):
    pool = Pool(processes=num_process)
    results = []
    for x in tqdm.tqdm( pool.imap(func, tasks), total=len(tasks)):
        results.append(x)
    return results


In [291]:
class ResolveGraph:
    
    def __init__(self, partition_files):
        self.partition_files = partition_files
        self.id_to_title, self.redirects = None, None
    
    def change_maps(self, id_to_title, redirects): 
        self.id_to_title = id_to_title
        self.redirects = redirects
        
        self.title_to_id = {article_title:article_id for article_id, article_title in id_to_title.items()}
        
    def resolver(self, data_path, save_dir, graph_type='article', tag_extractor=None):
        if graph_type != 'classification' and graph_type != 'article' and graph_type != "label":
            raise Exception("graph_type can only take values in ['classification', 'article', 'label']")
            
        graph = WikilinkGraph()
        tag = extract_filetag(data_path, tag_extractor)
            
        tag=f'_{graph_type}{tag}'
        graph.load_data(save_dir, tag)
        
        graph.replace_redirects(self.redirects)
        graph.remove_dead(self.title_to_id)
        
        tag = f'{tag}_resolved'
        graph.save_data(save_dir, tag)
            
    def resolve(self, save_dir, graph_type='article', tag_extractor=None):
        resolve_helper = partial(self.resolver, save_dir=save_dir, graph_type=graph_type, 
                                 tag_extractor=tag_extractor)
        multiprocessor_2(resolve_helper, self.partition_files)
        

In [253]:
combined_dir = f'{dataset_home}/WikiCategory/combined'

combined_graph = WikiGraphCombine()
combined_graph.load_idtotitle(combined_dir, tag=f'-{project}-{dump_date}')
combined_graph.load_redirects(combined_dir, tag=f'-{project}-{dump_date}')

In [254]:
resolve_graph = ResolveGraph(partitions)

In [255]:
tag_extractor = re.compile(f'({project}-{dump_date})-'+r'pages-articles-multistream([0-9]{1,2}).xml-(p[0-9]+p[0-9]+).bz2')
results_dir = f'{dataset_home}/WikiCategory/results'

In [256]:
resolve_graph.change_maps(combined_graph.parameters['id_to_title'], combined_graph.parameters['redirects'])
resolve_graph.resolve(results_dir, graph_type='article', tag_extractor=tag_extractor)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [12:19<00:00, 11.93s/it]


In [259]:
resolve_graph.change_maps(combined_graph.parameters['label_id_to_title'], 
                          combined_graph.parameters['label_redirects'])

resolve_graph.resolve(results_dir, graph_type='classification', tag_extractor=tag_extractor)
resolve_graph.resolve(results_dir, graph_type='label', tag_extractor=tag_extractor)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [01:04<00:00,  1.04s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [01:00<00:00,  1.03it/s]


### Testing

In [310]:
gs = [{1: (['a', 'b', 'c'],[11, 22, 33]), 2: (['d', 'e'], [44, 55]), 3: (['f'], [66])},
      {4: (['a', 'b'],[11, 22, 33]), 5: (['d'], [44, 55]), 6: (['f'], [66])}]

In [311]:
test_dir = f"{dataset_home}/data/testing"
os.makedirs(test_dir, exist_ok=True)

In [312]:
graph = WikilinkGraph()
for i, g in enumerate(gs):
    graph.graph = g
    graph.save_data(test_dir, tag=f'_article-{i}-{i}-{i}')

In [313]:
resolve_graph = ResolveGraph(['000', '111'])
tag_extractor = re.compile(r'([01])([01])([01])')

resolve_graph.change_maps({1:'aa', 2:'b', 3:'c', 4:'ee'}, {'a':'aa', 'e':'ee', 'f':'ff'})

resolve_graph.resolve(test_dir, graph_type='article', tag_extractor=tag_extractor)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.88it/s]


In [314]:
resolved_gs = []

graph = WikilinkGraph()
for i, g in enumerate(gs):
    graph.load_data(test_dir, tag=f'_article-{i}-{i}-{i}_resolved')
    resolved_gs.append(graph.graph)

In [315]:
resolved_gs

[{1: ([3, 2, 1], [33, 22, 11]), 2: ([4], [55]), 3: ([], [])},
 {4: ([2, 1], [33, 22]), 5: ([], []), 6: ([], [])}]

#### __Verification__

In [307]:
graph = WikilinkGraph()
graph.graph = gs[0]
graph.graph

{1: (['a', 'b', 'c'], [11, 22, 33]),
 2: (['d', 'e'], [44, 55]),
 3: (['f'], [66])}

In [308]:
graph.replace_redirects({'a':'aa', 'e':'ee', 'f':'ff'})
graph.graph

{1: (['aa', 'b', 'c'], [11, 22, 33]),
 2: (['d', 'ee'], [44, 55]),
 3: (['ff'], [66])}

In [309]:
graph.remove_dead({'aa':1, 'b':2, 'c':3, 'ee':4})
graph.graph

{1: ([3, 2, 1], [33, 22, 11]), 2: ([4], [55]), 3: ([], [])}

## Combine graphs

In [264]:
combine_graph = WikiGraphCombine(partitions)

In [265]:
combine_graph.combine_param(results_dir, param='classification_graph', tag_extractor=tag_extractor)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [00:08<00:00,  7.17it/s]


In [266]:
combine_graph.combine_param(results_dir, param='article_graph', tag_extractor=tag_extractor)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [05:45<00:00,  5.58s/it]


In [267]:
combine_graph.combine_param(results_dir, param='label_graph', tag_extractor=tag_extractor)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [21:51<00:00, 21.15s/it]


In [273]:
combine_graph.save_graph(save_dir, tag=f'-{project}-{dump_date}', graph_type='classification')
combine_graph.save_graph(save_dir, tag=f'-{project}-{dump_date}', graph_type='article')
combine_graph.save_graph(save_dir, tag=f'-{project}-{dump_date}', graph_type='label')

### testing

#### test_1

In [316]:
combine_graph = WikiGraphCombine(['000', '111'])
tag_extractor = re.compile(r'([01])([01])([01])')
test_dir = f"{dataset_home}/data/testing"

combine_graph.combine_param(test_dir, param='article_graph', tag_extractor=tag_extractor)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 420.27it/s]


In [317]:
combine_graph.parameters['article_graph']

{1: {3: 33, 2: 22, 1: 11}, 2: {4: 55}, 3: {}, 4: {2: 33, 1: 22}, 5: {}, 6: {}}

#### test_2

In [327]:
combine_graph.load_idtotitle(save_dir, tag=f'-{project}-{dump_date}', id_type='article')

In [320]:
combine_graph.load_graph(save_dir, tag=f'-{project}-{dump_date}', graph_type='classification')
combine_graph.load_graph(save_dir, tag=f'-{project}-{dump_date}', graph_type='article')
combine_graph.load_graph(save_dir, tag=f'-{project}-{dump_date}', graph_type='label')

In [332]:
dict_head_random(combine_graph.parameters['id_to_title'], n=3)

3906764 : ian Johnson (writer)
36059385 : atrak Rural District (Golestan Province)
66334742 : robert Farris-Olsen


In [334]:
combine_graph.parameters['id_to_title'][266210]

'traditional Chinese characters'

In [322]:
dict_head_random(combine_graph.parameters['article_graph'], n=3)

51182710 : {266210: 1}
34449347 : {5667302: 1, 488385: 1, 488375: 1, 386244: 1, 55776255: 2}
6533043 : {39486898: 1, 3434750: 1, 74337: 1, 292161: 2, 288589: 1, 413785: 1, 622127: 1, 252981: 1, 6694705: 1, 7293: 1, 4699157: 1, 2103879: 1, 63429: 1, 2779: 1, 49827: 1, 576646: 1, 1916229: 1}


In [323]:
dict_head_random(combine_graph.parameters['classification_graph'], n=3)

26019812 : {66317279: 1, 48600105: 1, 33788527: 1, 3782398: 1, 32210340: 1, 3171852: 1}
53234107 : {52764134: 1, 30833314: 1}
10370203 : {3955033: 1, 35753495: 1, 10323481: 1}


In [324]:
dict_head_random(combine_graph.parameters['label_graph'], n=3)

60298633 : {65365387: 1, 50835116: 1}
697659 : {66230355: 1, 51801943: 1, 15466725: 1}
43427946 : {42235686: 1}


## Train-test split

In [543]:
def prune_map(mapping, idxs):
    rev_mapping = {value:key for key, value in mapping.items()}
    pruned_mapping = {}
    for i, idx in enumerate(idxs):
        pruned_mapping[rev_mapping[idx]] = i
    return pruned_mapping

def split_count(num_samples, perc=0.7):
    if num_samples == 1:
        return 1 if np.random.rand() > 0.3 else 0
    
    num_train = int( np.ceil(num_samples*perc) )
    if num_train == num_samples and num_samples > 1:
        num_train -= 1
    return num_train


In [569]:
class WikipediaSplit:
    
    def __init__(self, graph=None):
        self.graph = None
        self.labels = None
        self.doc_to_rowindex = None
        
        if graph:
            _ = self.to_matrix(graph)
        
        self.train, self.test = None, None
        self.trn_tst_labels = None
        self.train_doc_to_rowindex, self.test_doc_to_rowindex = None, None
        
    def to_matrix(self, graph):
        indptr = [0]
        indices = []
        data = []

        self.doc_to_rowindex = {}

        self.labels = {}
        for i, (doc, edge_count) in enumerate(graph.items()):
            self.doc_to_rowindex[doc] = i
            
            for link, cnt in edge_count.items():
                index = self.labels.setdefault(link, len(self.labels))
                indices.append(index)
                data.append(cnt)
            indptr.append(len(indices))

        self.graph = csr_matrix((data, indices, indptr), dtype=int)
        return self.graph, self.labels, self.doc_to_rowindex
    
    def clean_matrix(self, clean_type=0):
        if clean_type == 0:
            self.graph, self.labels, self.doc_to_rowindex = self.remove_single_labels(self.graph,
                                                                                      self.labels,
                                                                                      self.doc_to_rowindex)
        elif clean_type == 1:
            pruned_rows = self.get_pruned_row(self.graph)
            self.graph, self.doc_to_rowindex = self.prune_graph_rows(self.graph,
                                                                     self.doc_to_rowindex,
                                                                     pruned_rows)
        else:
            pruned_cols = self.get_pruned_cols(self.graph)
            self.graph, self.labels = self.prune_graph_cols(self.graph,
                                                           self.labels,
                                                           pruned_cols)
    
    """
    def get_split_idx(self):
        train_rowidx = []
        test_rowidx = []
        
        num_rows, num_cols = self.graph.shape
        row_available_flag = np.ones(num_rows, dtype=bool)
        
        label_cnt = np.array(self.graph.sum(axis=0)).reshape(-1)
        uni_label_cnt = np.unique(label_cnt)
        
        for lcnt in uni_label_cnt:
            if lcnt > upper_threshold:
                break
                
            label_pos = np.where(label_cnt == lcnt)[0]
            row_numlabels = np.array(self.graph[:, label_pos].sum(axis=1)).reshape(-1)
            
            row_idx = np.where(row_numlabels > 0)[0]
            row_available_flag[row_idx] = False
            num_train = split_count(row_idx.shape[0], perc=0.7)
            row_idx = list(np.random.permutation(row_idx))

            train_rowidx.extend(row_idx[:num_train])
            test_rowidx.extend(row_idx[num_train:])
        
        row_idx = np.where(row_available_flag == True)[0]
        num_train = split_count(row_idx.shape[0], perc=0.7)
        row_idx = list(np.random.permutation(row_idx))
        train_rowidx.extend(row_idx[:num_train])
        test_rowidx.extend(row_idx[num_train:])
        
        return train_rowidx, test_rowidx
        """
    
    def get_split_idx(self, upper_threshold=10, perc=0.7):
        train_rowidx = []
        test_rowidx = []
        
        row_idxs, col_idxs = self.graph.nonzero()
        sort_idx = np.argsort(col_idxs)
        row_idxs = row_idxs[sort_idx]
        col_idxs = col_idxs[sort_idx]
        
        num_rows = self.graph.shape[0]
        row_inserted_flag = np.zeros(num_rows, dtype=bool)
        
        label_cnt = np.array(self.graph.getnnz(axis=0)).reshape(-1)
        uni_label_cnt = np.unique(label_cnt)
        
        """
        print(label_cnt)
        print(uni_label_cnt)
        """
        
        cnt = 0
        for lcnt in uni_label_cnt:
            if cnt == num_rows or lcnt >= upper_threshold:
                break
                
            pos_ptr, col_ptr = 0, 0
            pos_idxs = np.where(label_cnt == lcnt)[0]
            pos_idxs.sort()

            while col_ptr < len(col_idxs) and pos_ptr < len(pos_idxs) and cnt < num_rows:
                if pos_idxs[pos_ptr] != col_idxs[col_ptr]:
                    col_ptr += 1
                else:
                    sample_row_idxs = []
                    while col_ptr < len(col_idxs) and pos_ptr < len(pos_idxs) and \
                    cnt < num_rows and pos_idxs[pos_ptr] == col_idxs[col_ptr]:
                        rn = row_idxs[col_ptr]
                        if not row_inserted_flag[rn]:
                            sample_row_idxs.append(rn)
                            row_inserted_flag[rn] = True
                            cnt += 1
                        col_ptr += 1
                    pos_ptr += 1
                    
                    num_train = split_count(len(sample_row_idxs), perc=0.7)
                    sample_row_idxs = list(np.random.permutation(sample_row_idxs))
                    train_rowidx.extend(sample_row_idxs[:num_train])
                    test_rowidx.extend(sample_row_idxs[num_train:])
                    """
                    print(f'Sample row : {sample_row_idxs}')
                    print(f'train : {train_rowidx}')
                    print(f'test  : {test_rowidx}')
                    """
                    
        sample_row_idxs = np.where(row_inserted_flag == False)[0]
        num_train = split_count(num_rows, perc=0.7)
        num_train -= len(train_rowidx)
        
        sample_row_idxs = list(np.random.permutation(sample_row_idxs))
        train_rowidx.extend(sample_row_idxs[:num_train])
        test_rowidx.extend(sample_row_idxs[num_train:])
        
        return train_rowidx, test_rowidx
    
    def get_split_bylabel(self, upper_threshold=10, perc=0.7):
        """
        splitting data into train-test
        """
        train_idx, test_idx = self.get_split_idx(upper_threshold, perc)
        
        self.train = self.graph[train_idx, :]
        self.test = self.graph[test_idx, :]
        self.trn_tst_labels = self.labels
        
        rowindex_to_doc = {row_idx:doc for doc, row_idx in self.doc_to_rowindex.items()}
        
        self.train_doc_to_rowindex = {rowindex_to_doc[idx]:i for i, idx in enumerate(train_idx)}
        self.test_doc_to_rowindex = {rowindex_to_doc[idx]:i for i, idx in enumerate(test_idx)}
        
        """
        pruning the columns
        """
        trn_pruned_cols = self.get_pruned_cols(self.train)
        tst_pruned_cols = self.get_pruned_cols(self.test)
        pruned_cols = np.intersect1d(trn_pruned_cols, tst_pruned_cols)
        self.train = self.train[:, pruned_cols]
        self.test = self.test[:, pruned_cols]
        self.trn_tst_labels = prune_map(self.trn_tst_labels, pruned_cols)
        
        """
        prunning the rows
        """
        pruned_rows = self.get_pruned_row(self.train)
        self.train, self.train_doc_to_rowindex = self.prune_graph_rows(self.train, 
                                                                       self.train_doc_to_rowindex, pruned_rows)
        pruned_rows = self.get_pruned_row(self.test)
        self.test, self.test_doc_to_rowindex = self.prune_graph_rows(self.test,
                                                                     self.test_doc_to_rowindex, pruned_rows)
        
    def get_random_split_idx(self, perc=0.7):
        n_docs = self.graph.shape[0]
        n_trn = int(perc * n_docs)
        rand_idx = np.random.permutation(n_docs)
        return rand_idx[:n_trn], rand_idx[n_trn:]
    
    def get_pruned_cols(self, graph, count=0):
        label_cnt = np.array(graph.getnnz(axis=0)).reshape(-1)
        pruned_cols = np.where(label_cnt > count)[0]
        return pruned_cols
    
    def get_pruned_row(self, graph, count=0):
        pruned_rows = np.where( np.array(graph.getnnz(axis=1)).reshape(-1) > count )[0]
        return pruned_rows
    
    def prune_graph_cols(self, graph, labels, pruned_cols):
        graph = graph[:, pruned_cols]
        labels = prune_map(labels, pruned_cols)
        return graph, labels
    
    def prune_graph_rows(self, graph, doc_to_rowindex, pruned_rows):
        graph = graph[pruned_rows, :]
        doc_to_rowindex = prune_map(doc_to_rowindex, pruned_rows)
        return graph, doc_to_rowindex
    
    def remove_single_labels(self, graph, labels, doc_to_rowindex):
        pruned_cols = self.get_pruned_cols(graph, count=1)
        graph, labels = self.prune_graph_cols(graph, labels, pruned_cols)
        
        pruned_rows = self.get_pruned_row(graph)
        graph, doc_to_rowindex = self.prune_graph_rows(graph, doc_to_rowindex, pruned_rows)

        return graph, labels, doc_to_rowindex
        
    def get_split_byrandom(self, perc=0.7):
        train_idx, test_idx = self.get_random_split_idx(perc)
        
        self.train = self.graph[train_idx, :]
        self.test = self.graph[test_idx, :]
        rowindex_to_doc = {row_idx:doc for doc, row_idx in self.doc_to_rowindex.items()}
        self.train_doc_to_rowindex = {rowindex_to_doc[idx]:i for i, idx in enumerate(train_idx)}
        self.test_doc_to_rowindex = {rowindex_to_doc[idx]:i for i, idx in enumerate(test_idx)}
        
        train_pruned_cols = self.get_pruned_cols(self.train)
        test_pruned_cols = self.get_pruned_cols(self.test)
        pruned_cols = np.intersect1d(train_pruned_cols, test_pruned_cols)
        
        self.train = self.train[:, pruned_cols]
        self.test = self.test[:, pruned_cols]
        self.trn_tst_labels = prune_map(self.labels, pruned_cols)
        
        pruned_rows = self.get_pruned_row(self.train)
        self.train, self.train_doc_to_rowindex = self.prune_graph_rows(self.train, 
                                                                       self.train_doc_to_rowindex, 
                                                                       pruned_rows)
        pruned_rows = self.get_pruned_row(self.test)
        self.test, self.test_doc_to_rowindex = self.prune_graph_rows(self.test,
                                                                     self.test_doc_to_rowindex,
                                                                     pruned_rows)
        
    def save_data(self, save_dir, tag='category'):
        train_file = f'{save_dir}/{tag}_train.pkl'
        with open(train_file, 'wb') as fout:
            train = (self.trn_tst_labels, self.train_doc_to_rowindex, self.train)
            pickle.dump(train, fout)
            
        test_file = f'{save_dir}/{tag}_test.pkl'
        with open(test_file, 'wb') as fout:
            test = (self.trn_tst_labels, self.test_doc_to_rowindex, self.test)
            pickle.dump(test, fout)
            
    def load_data(self, save_dir, tag='category'):
        train_file = f'{save_dir}/{tag}_train.pkl'
        with open(train_file, 'rb') as fout:
            train = pickle.load(fout)
            self.trn_tst_labels, self.train_doc_to_rowindex, self.train = train
            
        test_file = f'{save_dir}/{tag}_test.pkl'
        with open(test_file, 'rb') as fout:
            test = pickle.load(fout)
            _, self.test_doc_to_rowindex, self.test = test
            

### Code

In [570]:
graph = WikiGraphCombine(partitions)
graph.load_graph(save_dir, tag=f'-{project}-{dump_date}', graph_type='classification')

In [571]:
classification_graph = graph.parameters['classification_graph']

In [572]:
dict_head_random(classification_graph, n=5)

24001090 : {69120618: 1, 14525713: 1, 2601865: 1, 1399788: 1, 2602036: 1, 2615923: 1, 24543100: 1, 740423: 1, 698698: 1, 10661449: 1, 2604014: 1, 21460042: 1, 26185725: 1}
16028312 : {69483762: 1, 24430172: 1}
4324148 : {3308254: 1, 18668365: 1, 27687677: 1, 10926462: 1}
53816396 : {11801442: 1, 1698784: 1}
11611474 : {8456049: 1, 1127028: 1, 61317860: 1, 27586728: 1, 57501298: 1, 19922908: 1, 62593050: 1, 69584432: 1}


In [573]:
data_splitter = WikipediaSplit(classification_graph)

data_splitter.clean_matrix()
data_splitter.get_split_bylabel(upper_threshold=10)

In [574]:
data_splitter.save_data(save_dir, tag='wiki_category')

In [575]:
save_dir

'/home/cse/phd/anz198717/scratch/suchith_data/wikipedia/wikipedia-data-science/WikiCategory/combined'

#### statistics

In [576]:
data_splitter.train, data_splitter.test

(<4245310x1080148 sparse matrix of type '<class 'numpy.int64'>'
 	with 19852712 stored elements in Compressed Sparse Row format>,
 <1820068x1080148 sparse matrix of type '<class 'numpy.int64'>'
 	with 8648479 stored elements in Compressed Sparse Row format>)

In [577]:
data_splitter.train.sum(axis=0).mean()

18.383777963760522

In [578]:
data_splitter.train.sum(axis=1).mean()

4.6774442855763185

### Testing

In [545]:
g = {1: {1:2, 2:3, 3:1, 4:1}, 
     2:{1:1, 2:1, 7:1, 8:1}, 
     3:{4:1, 5:1, 6:1, 7:1},
     4:{9:2},
     5:{2:1, 5:1, 8:1}
    }

In [546]:
ws = WikipediaSplit(graph=g)

#### __to_matrix__

In [525]:
ws.graph.todense(), ws.doc_to_rowindex, ws.labels

(matrix([[2, 3, 1, 1, 0, 0, 0, 0, 0],
         [1, 1, 0, 0, 1, 1, 0, 0, 0],
         [0, 0, 0, 1, 1, 0, 1, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 2],
         [0, 1, 0, 0, 0, 1, 1, 0, 0]]),
 {1: 0, 2: 1, 3: 2, 4: 3, 5: 4},
 {1: 0, 2: 1, 3: 2, 4: 3, 7: 4, 8: 5, 5: 6, 6: 7, 9: 8})

In [513]:
ws.get_pruned_cols(ws.graph, count=1)

array([0, 1, 3, 4, 5, 6])

In [514]:
ws.get_pruned_row(ws.graph, count=1)

array([0, 1, 2, 4])

#### __clean_matrix__

In [515]:
ws.clean_matrix()

In [562]:
ws.graph.todense(), ws.doc_to_rowindex, ws.labels

(matrix([[2, 3, 1, 1, 0, 0, 0, 0, 0],
         [1, 1, 0, 0, 1, 1, 0, 0, 0],
         [0, 0, 0, 1, 1, 0, 1, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 2],
         [0, 1, 0, 0, 0, 1, 1, 0, 0]]),
 {1: 0, 2: 1, 3: 2, 4: 3, 5: 4},
 {1: 0, 2: 1, 3: 2, 4: 3, 7: 4, 8: 5, 5: 6, 6: 7, 9: 8})

In [373]:
rows, cols = ws.graph.nonzero()

idx = np.argsort(cols)

rows_idx = rows[idx]
cols_idx = cols[idx]

In [374]:
rows_idx, cols_idx

(array([0, 1, 0, 1, 3, 0, 2, 1, 2, 1, 3, 2, 3], dtype=int32),
 array([0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5], dtype=int32))

#### get_split_idx

In [557]:
trn_idx, tst_idx = ws.get_split_idx()
trn_idx, tst_idx

([0, 1, 4], [2, 3])

## Abstract

In [593]:
class AbstractXmlHandler(xml.sax.handler.ContentHandler):
    
    def __init__(self):
        xml.sax.handler.ContentHandler.__init__(self)
        self._buffer = None
        self._values = {}
        self._current_tag = None
        self._article_count = 0
        
        self._abstract = {}
        
    def characters(self, content):
        if self._current_tag:
            self._buffer.append(content)
            
    def startElement(self, name, attrs):
        if name in ('title','abstract'):
            self._current_tag = name
            self._buffer = []
            
    def endElement(self, name):
        if name == self._current_tag:
            self._values[name] = ' '.join(self._buffer)
            self._current_tag = None
                
        elif name == 'doc':
            self._article_count += 1
            cat, title = self._values['title'].split(':', maxsplit=1)
            title = title.strip()
            title = title[0].lower() + title[1:]
            self._abstract[title.strip()] = self._values['abstract'].strip()
            

In [594]:
class WikiAbstract:
    
    def __init__(self, abstract_paths=None):
        self.abstract = {}
        self.abstract_paths = abstract_paths
        
    def parse_abstract(self, abstract_file):
        handler = AbstractXmlHandler()

        parser = xml.sax.make_parser()
        parser.setContentHandler(handler)

        for i, line in enumerate( subprocess.Popen(['gzip','-cd'],
                                                  stdin=open(abstract_file),
                                                  stdout=subprocess.PIPE ).stdout ):
            try:
                parser.feed(line)
            except StopIteration:
                break

        return handler._abstract
    
    def create_abstract(self, processes=15):
        pool = Pool(processes=processes)
        if self.abstract_paths is None:
            raise Exception("Abstract paths are empty.")
        for abstract in tqdm.tqdm(pool.imap_unordered(self.parse_abstract, self.abstract_paths), 
                                           total=len(self.abstract_paths)):
            self.abstract.update(abstract)
    
    def save_abstract(self, save_dir, tag=''):
        os.makedirs(save_dir, exist_ok=True)
        abstract_file = f'{save_dir}/abstract{tag}.pickle'
        with open(abstract_file, 'wb') as f:
            pickle.dump(self.abstract, f)
            
    def load_abstract(self, save_dir, tag=''):
        abstract_file = f'{save_dir}/abstract{tag}.pickle'
        if os.path.exists(abstract_file):
            with open(abstract_file, 'rb') as f:
                self.abstract = pickle.load(f)
        else:
            raise Exception(f"Unable to load abstract from '{abstract_file}'.")
            

In [595]:
save_dir = f'{dataset_home}/WikiCategory/combined'

In [596]:
abstract = WikiAbstract()
abstract.load_abstract(save_dir, tag=f'-{project}-{dump_date}')

## XC dataset format

In [626]:
class XCDataset:
    
    def __init__(self):
        self.id_to_title = {}
        self.abstract = None
        
        self.data_splitter = {}
        
    def load_idtotitle(self, save_dir, tag='', verbose=True):
        """
        Loading id_to_title
        """
        if verbose:
            print("Loading Article and Category 'id_to_title'.")
        graph_data = WikiGraphCombine()
        graph_data.load_idtotitle(save_dir, tag=tag)
        
        self.id_to_title['article'] = graph_data.parameters['id_to_title']
        self.id_to_title['label'] = graph_data.parameters['label_id_to_title']
        
    def load_abstract(self, save_dir, tag='', verbose=True):
        """
        Loading abstract
        """
        if verbose:
            print("Loading Article 'abstracts'.")
        abstract = WikiAbstract()
        abstract.load_abstract(save_dir, tag=tag)
        self.abstract = abstract.abstract
        
    def save_sparse_file(self, matrix, save_file):
        matrix.sort_indices()
        rows, cols = matrix.nonzero()
        data = matrix.data

        with open(save_file, 'w') as fout:
            fout.write(f'{matrix.shape[0]} {matrix.shape[1]}\n')

            row_ctr = -1
            line = ''
            for r, c, d in zip(rows, cols, data):
                if row_ctr == r:
                    line = line+f' {c}:{d}'
                elif row_ctr+1 == r:
                    if line:
                        fout.write(f'{line}\n')
                    line = f'{c}:{d}'
                    row_ctr += 1
                else:
                    raise Exception("Row is missing.")
        
    def save_XY_text(self, save_dir, doc_to_rowindex, tag='', id_type='article'):
        os.makedirs(save_dir, exist_ok=True)
        
        if id_type == 'article':
            idtocontent_file = f'{save_dir}/{tag}_id_to_text.txt'
            self.save_XY_content(idtocontent_file, self.id_to_title[id_type], doc_to_rowindex, self.abstract)

        idtotitle_file = f'{save_dir}/{tag}_id_to_title.txt'
        self.save_XY_title(idtotitle_file, self.id_to_title[id_type], doc_to_rowindex)
                    
    def save_XY_content(self, idtocontent_file, id_to_title, doc_to_rowindex, content):
        with open(idtocontent_file, 'w') as fout:
            for doc_id in sorted(doc_to_rowindex, key=lambda x : doc_to_rowindex[x]):
                if id_to_title[doc_id] in content:
                    fout.write(f'{doc_id}->{content[id_to_title[doc_id]]}\n')
                else:
                    fout.write(f'{doc_id}->\n')
                    
    def save_XY_title(self, idtotitle_file, id_to_title, doc_to_rowindex):
        with open(idtotitle_file, 'w') as fout:
            for doc_id in sorted(doc_to_rowindex, key=lambda x : doc_to_rowindex[x]):
                fout.write(f'{doc_id}->{id_to_title[doc_id]}\n')
                
    def load_classification_data(self, save_dir, verbose=True):
        """
        Loading Classification train-test split
        """
        if verbose:
            print("Loading Classification train-test split.")
        self.data_splitter['classification'] = WikipediaSplit()
        self.data_splitter['classification'].load_data(save_dir, tag='wiki_category')
                
    def save_XCClassification_text(self, xc_dir, verbose=True):
        """
        Saving Classification XY - title and content(text) 
        """
        if verbose:
            print("Saving Classification train X-article text")
        self.save_XY_text(xc_dir, self.data_splitter['classification'].train_doc_to_rowindex, 
                                    tag='classification_train_X', id_type='article')
        if verbose:
            print("Saving Classification test X-article text")
        self.save_XY_text(xc_dir, self.data_splitter['classification'].test_doc_to_rowindex, 
                                    tag='classification_test_X', id_type='article')
        if verbose:
            print("Saving Classification Y-label text")
        self.save_XY_text(xc_dir, self.data_splitter['classification'].trn_tst_labels, 
                                    tag='classification_Y', id_type='label')
                      
    def save_XCClassification_data(self, xc_dir, verbose=True):
        if verbose:
            print("Saving Classification train 'trn_X_Y.txt'.")
        train_file = f'{xc_dir}/trn_X_Y.txt'
        self.save_sparse_file(self.data_splitter['classification'].train, train_file)
        
        if verbose:
            print("Saving Classification test 'tst_X_Y.txt'.")
        test_file = f'{xc_dir}/tst_X_Y.txt'
        self.save_sparse_file(self.data_splitter['classification'].test, test_file)
        
    def save_XCClassification(self, save_dir, xc_dir, verbose=True):
        """
        XC Classification
        """
        self.load_classification_data(save_dir, verbose=verbose)
        self.save_XCClassification_text(xc_dir, verbose=verbose)
        self.save_XCClassification_data(xc_dir, verbose=verbose)
        
    def load_graph_data(self, save_dir, tag='', graph_type='article', verbose=True):
        """
        Loading Graph
        """
        if verbose:
            print(f"Loading {graph_type} graph.")
        graph_data = WikiGraphCombine()
        graph_data.load_graph(save_dir, tag=tag, graph_type=graph_type)
        
        self.data_splitter[graph_type] = WikipediaSplit(graph_data.parameters[f'{graph_type}_graph'])
        self.data_splitter[graph_type].clean_matrix(clean_type=1)
    
    def save_XCGraph_text(self, xc_dir, graph_type='article', verbose=True):
        """
        Saving XC-Graph text
        """
        if verbose:
            print(f"Saving {graph_type}_graph X-text.")
        self.save_XY_text(xc_dir, self.data_splitter[graph_type].doc_to_rowindex, 
                          tag=f'{graph_type}_graph_X', id_type=graph_type)
        
        if verbose:
            print(f"Saving {graph_type}_graph Y-text.")
        self.save_XY_text(xc_dir, self.data_splitter[graph_type].labels, 
                          tag=f'{graph_type}_graph_Y', id_type=graph_type)
        
    def save_XCGraph_data(self, xc_dir, graph_type='article', verbose=True):
        if verbose:
            print(f"Saving '{graph_type}_graph_trn_X_Y.txt'")
        graph_file = f'{xc_dir}/{graph_type}_graph_trn_X_Y.txt'
        self.save_sparse_file(self.data_splitter[graph_type].graph, graph_file)
        
    def save_XCGraph(self, save_dir, xc_dir, tag='', graph_type='article', verbose=True):
        """
        XC Graph
        """
        self.load_graph_data(save_dir, tag=tag, graph_type=graph_type, verbose=verbose)
        self.save_XCGraph_text(xc_dir, graph_type=graph_type, verbose=verbose)
        self.save_XCGraph_data(xc_dir, graph_type=graph_type, verbose=verbose)
        
    
    def create_XCData(self, save_dir, xc_dir, tag='', verbose=True):
        self.load_idtotitle(save_dir, tag=tag, verbose=verbose)
        self.load_abstract(save_dir, tag=tag, verbose=verbose)
        
        self.save_XCClassification(save_dir, xc_dir, verbose=verbose)
        self.save_XCGraph(save_dir, xc_dir, tag=tag, graph_type='article', verbose=verbose)
        self.save_XCGraph(save_dir, xc_dir, tag=tag, graph_type='label', verbose=verbose)
        

In [627]:
save_dir = f'{dataset_home}/WikiCategory/combined'
xc_dir = f'{dataset_home}/WikiCategory/XCData'

In [628]:
xc_data = XCDataset()
xc_data.create_XCData(save_dir, xc_dir, tag=f'-{project}-{dump_date}')

Loading Article and Category 'id_to_title'.
Loading Article 'abstracts'.
Loading Classification train-test split.
Saving Classification train X-article text
Saving Classification test X-article text
Saving Classification Y-label text
Saving Classification train 'trn_X_Y.txt'.
Saving Classification test 'tst_X_Y.txt'.
Loading article graph.
Saving article_graph X-text.
Saving article_graph Y-text.
Saving 'article_graph_trn_X_Y.txt'
Loading label graph.
Saving label_graph X-text.
Saving label_graph Y-text.
Saving 'label_graph_trn_X_Y.txt'
