In [68]:
import os
import re
import gc
import json
import pickle
import xml.sax
import requests
import pandas as pd

import bz2
import subprocess
import numpy as np
from IPython import display
from scipy.sparse import csr_matrix
from matplotlib import pyplot as plt
from timeit import default_timer as timer

In [2]:
import tqdm
from functools import partial
from multiprocessing import Pool

In [3]:
import mwparserfromhell

In [4]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

## Helper functions

In [5]:
def dict_head_random(dictionary, n=10):
    """Print up to `n` distinct, randomly chosen ``key : value`` pairs.

    Parameters
    ----------
    dictionary : dict
        Mapping to preview. Keys may be of any hashable type.
    n : int, default 10
        Maximum number of entries to print. Fewer are printed when the
        dictionary holds fewer than `n` items; nothing is printed for an
        empty dictionary or a non-positive `n`.
    """
    keys = list(dictionary.keys())
    if not keys or n <= 0:
        return

    # Sample positions instead of the keys themselves: handing the keys to
    # np.random.choice coerces them into a NumPy array, which turns mixed-type
    # keys into strings (KeyError on lookup) and rejects tuple keys outright.
    positions = np.random.choice(len(keys), size=min(n, len(keys)), replace=False)
    for position in positions:
        k = keys[position]
        print(f'{k} : {dictionary[k]}')
        

## Arguments

In [6]:
# Wikimedia project code and dump date; both are interpolated into the dump file-name patterns below.
project = 'enwiki'
dump_date = "20220420"
# NOTE(review): absolute, user-specific path — the notebook only runs on a machine with this layout.
dataset_home = '/home/cse/phd/anz198717/scratch/suchith_data/wikipedia/wikipedia-data-science'

In [7]:
# File-name pattern of one multistream partition; the groups capture the
# "<project>-<date>" prefix, the stream number and the page-id range.
prog = re.compile(f'({project}-{dump_date})-'+r'pages-articles-multistream([0-9]{1,2}).xml-(p[0-9]+p[0-9]+).bz2')

# Full paths of every matching file in <dataset_home>/datasets, sorted lexicographically.
partitions = sorted([f'{dataset_home}/datasets/{file}' for file in os.listdir(f'{dataset_home}/datasets') 
                     if prog.match(file)])

print(f'Total number of partitions of the wikipedia dump : {len(partitions)}')

Total number of partitions of the wikipedia dump : 62


## Page Handler

In [21]:
class WikiXmlHandler(xml.sax.handler.ContentHandler):
    """SAX handler that collects main-namespace, non-redirect pages of a MediaWiki XML dump.

    For every ``<page>`` the character data of ``<title>``, ``<ns>``, ``<text>``
    and of the first ``<id>`` inside the page is stored under ``article_<tag>``
    keys. Pages that carry a ``<redirect>`` element are recorded in
    ``redirects`` (source title -> target title) and are not added to ``pages``;
    pages whose ``<ns>`` is non-zero are skipped as well.

    NOTE(review): a second class with this same name is defined further down in
    the notebook; whichever definition ran last is the one in effect.

    NOTE(review): captured chunks are joined with a single space. SAX parsers
    may deliver one run of text in several ``characters`` calls, in which case
    the stored value contains spaces that are not in the dump — confirm this is
    intended before relying on exact text.
    """

    def __init__(self, matches=None):
        """Create an empty handler.

        Parameters
        ----------
        matches : str, optional
            Section-heading pattern; only stored on the instance here.
        """
        super().__init__()

        # Section pattern, kept for callers (TODO in the original: move elsewhere).
        self.matches = matches

        # Scratch state for the element currently being read.
        self._current_tag = None
        self._buffer = None
        self._values = {}

        # Per-page flags: whether the page is still a candidate for `pages`,
        # and whether the next <id> is the page's own id (the first one).
        self._add_page = True
        self._is_pageid = True

        # Collected output and counters.
        self.pages = []
        self._article_count = 0
        self._total_edges = 0

        self.redirects = {}
        self.id_to_title = {}
        self.page_content = {}

    def characters(self, content):
        """Buffer character data while a captured element is open."""
        if not self._current_tag:
            return
        self._buffer.append(content)

    def startElement(self, name, attrs):
        """Start capturing a tracked element, or record a redirect."""
        capture = name in ('title', 'text', 'ns')
        if name == 'id' and self._is_pageid:
            # Only the first <id> of a page is captured.
            self._is_pageid = False
            capture = True

        if capture:
            self._current_tag = name
            self._buffer = []
        elif name == 'redirect':
            target_title = attrs.getValue('title').strip()
            source_title = self._values['article_title'].strip()
            self.redirects[source_title] = target_title
            self._add_page = False

    def endElement(self, name):
        """Store a finished captured element, or finalise the page on ``</page>``."""
        if name == self._current_tag:
            self._values[f'article_{name}'] = ' '.join(self._buffer)
            self._current_tag = None
            return

        if name != 'page':
            return

        in_main_namespace = int(self._values['article_ns']) == 0
        if in_main_namespace and self._add_page:
            self._article_count += 1
            self.pages.append(self._values.copy())

        # Reset the per-page flags for the next <page>.
        self._add_page = True
        self._is_pageid = True
            

In [22]:
# First partition (lexicographic order) is used for the smoke test below.
data_path = sorted(partitions)[0]
# Heading pattern for "See", "See also" / "See more" / "See all" and "See also (...)" sections.
matches = r'^([Ss]ee[ ]*|[Ss]ee[ ]*([Aa]lso|[Mm]ore|[Aa]ll)|[Ss]ee[ ]*[Aa]lso[ ]*\(.+\))$'

In [23]:
start = timer()

handler = WikiXmlHandler(matches)

parser = xml.sax.make_parser()
parser.setContentHandler(handler)

# Stream the decompressed partition into the SAX parser line by line. The
# compressed file and the bzcat process are both context-managed so that the
# file handle is closed and the child process is reaped even though the loop
# stops early.
with open(data_path, 'rb') as compressed, \
        subprocess.Popen(['bzcat'], stdin=compressed, stdout=subprocess.PIPE) as bzcat:
    for line in bzcat.stdout:
        parser.feed(line)

        if handler._article_count > 10:
            # A handful of pages is enough here; stop bzcat rather than
            # leaving it blocked on a pipe nobody reads.
            bzcat.terminate()
            break

end = timer()

print(f'Searched through {handler._article_count} articles.')
print(f'Processing time is {round(end - start)} secs.')

Searched through 11 articles.
Processing time is 0 secs.


In [24]:
handler._article_count, len(handler.pages)

(11, 11)

## Filtering links

In [8]:
class FilterWikilinks:
    """Normalise wikilink targets and drop links that do not point at articles.

    `process_wikilink` is the entry point; it returns the cleaned title, or an
    empty string when the link targets a non-article namespace, a sister
    project or another language edition.
    """

    def __init__(self):
        # Namespace prefixes (lower-cased) whose links are discarded.
        self.subject_namespaces = {'user', 'wikipedia', 'wp', 'project', 'file', 'image', 'mediawiki',
                                   'template', 't', 'help', 'h', 'category', 'cat', 'portal', 'p',
                                   'draft', 'timedtext', 'module', 'special', 'media'}

        self.talk_namespaces = {'talk', 'user talk', 'wikipedia talk', 'wt', 'project talk', 'file talk',
                                'image talk','mediawiki talk', 'template talk', 'help talk', 'category talk',
                                'portal talk', 'draft talk', 'timedtext talk', 'module talk'}

        # Interwiki prefixes (lower-cased) whose links are discarded.
        self.interwiki_links = {'wiktionary', 'wikt', 'wikinews', 'n', 'wikibooks', 'b', 'wikiquote','q',
                                'wikisource','s', 'oldwikisource', 'wikispecies', 'species',
                                'wikiversity', 'v', 'wikivoyage', 'voy', 'wikimedia','foundation', 'wmf',
                                'commons', 'c', 'metawiki', 'metawikimedia', 'metawikipedia', 'meta' , 'm',
                                'incubator', 'strategy', 'mediawikiwiki', 'mw', 'mediazilla', 'bugzilla'}

        # Any two-letter lower-case prefix is treated as a language code.
        self.language_code = re.compile(r'^[a-z][a-z]$')

    def remove_section_tags(self, link):
        """Drop a ``#section`` suffix; the remaining title is stripped.

        Links without ``#`` are returned unchanged (not stripped).
        """
        hash_parts = link.split('#')
        if len(hash_parts) > 1:
            link = hash_parts[0].strip()
        return link

    def filter_special_tags(self, link):
        """Strip leading ``w:`` / ``en:`` / empty prefixes and reject special links.

        Returns an empty string when the first remaining prefix is a known
        namespace, an interwiki prefix or a two-letter language code; otherwise
        returns the link without the skipped prefixes.
        """
        colon_parts = link.split(':')
        part_num = 0
        for part in colon_parts:
            part = part.lower()
            if (part == "w") or (part == "en") or (part_num == 0 and not part):
                # Harmless prefix (local wiki, English, or a leading colon): skip it.
                part_num += 1
            elif (part in self.subject_namespaces) or (part in self.talk_namespaces) \
            or (part in self.interwiki_links) or self.language_code.match(part):
                return ''
            else:
                break
        return ':'.join(colon_parts[part_num:])

    def lower_wikilink(self, link):
        """Lower-case only the first character of a non-empty link."""
        if link:
            link = link[0].lower() + link[1:]
        return link

    def remove_underscore(self, link):
        """Replace underscores with spaces."""
        return link.replace('_', ' ')

    def fix_single_quotes(self, link):
        """Remove the one-space padding found around double quotes.

        Despite its (historical) name this method handles the ``"`` character.
        Input such as ``'" Page "  Title'`` carries one extra space on each
        side of every quote; those spaces are removed, giving
        ``'"Page" Title'``.

        A quote is treated as padded only when BOTH of its sides are a space
        or the edge of the string. The previous implementation removed the
        neighbouring characters unconditionally, which deleted letters from
        un-padded input (``'"Heroes" (song)'`` became ``'"eroe"(song)'``).
        Fully padded input is handled exactly as before.
        """
        if '"' not in link:
            return link

        parts = link.split('"')
        last = len(parts) - 1

        # padded[k] describes the quote sitting between parts[k] and parts[k + 1].
        padded = []
        for k in range(last):
            left_ok = parts[k].endswith(' ') or (k == 0 and not parts[k])
            right_ok = parts[k + 1].startswith(' ') or (k + 1 == last and not parts[k + 1])
            padded.append(left_ok and right_ok)

        cleaned = []
        for j, part in enumerate(parts):
            if j > 0 and padded[j - 1] and part.startswith(' '):
                part = part[1:]
            if j < last and padded[j] and part.endswith(' '):
                part = part[:-1]
            cleaned.append(part)
        return '"'.join(cleaned)

    def process_wikilink(self, link):
        """Run the full normalisation pipeline; returns '' for rejected links."""
        link = self.remove_section_tags(link)
        link = self.filter_special_tags(link)
        link = self.lower_wikilink(link)
        link = self.remove_underscore(link)
        link = self.fix_single_quotes(link)
        return link
    

In [145]:
# Example: a title whose double quotes are padded with spaces is cleaned up.
link_filter = FilterWikilinks()
link_filter.process_wikilink('" Page "  Title')

'"Page" Title'

## Data structures

In [9]:
class WikilinkGraph:
    """Adjacency store mapping an article id to its outgoing wikilinks.

    ``graph[article_id]`` is a ``(links, counts)`` pair of parallel lists:
    the distinct link targets of the article and how often each occurs.
    """

    def __init__(self):
        self.graph = {}

    def add_article(self, article_id, wikilinks):
        """Store the distinct `wikilinks` of `article_id` with their multiplicities.

        An existing entry for `article_id` is overwritten. Links come back in
        the sorted order produced by ``np.unique``.
        """
        links, count = np.unique(wikilinks, return_counts=True)
        self.graph[article_id] = (links.tolist(), count.tolist())

    def save_data(self, save_dir, tag=''):
        """Pickle the graph to ``<save_dir>/link_graph<tag>.pickle`` and release it.

        `save_dir` is created when missing. After saving, the in-memory graph
        is emptied to free memory.
        """
        os.makedirs(save_dir, exist_ok=True)
        filename = f'{save_dir}/link_graph{tag}.pickle'
        with open(filename, 'wb') as f:
            pickle.dump(self.graph, f)

        # Rebind to an empty dict instead of deleting the attribute: the memory
        # is released just the same, but the instance stays usable (a `del`
        # made every later access raise AttributeError).
        self.graph = {}
        gc.collect()

    def load_data(self, save_dir, tag=''):
        """Load the graph from ``<save_dir>/link_graph<tag>.pickle``.

        Returns True on success; prints an error and returns False when the
        file does not exist.
        """
        filename = f'{save_dir}/link_graph{tag}.pickle'
        if os.path.exists(filename):
            # NOTE: pickle.load executes arbitrary code; only load files this
            # pipeline produced itself.
            with open(filename, 'rb') as f:
                self.graph = pickle.load(f)
            return True
        print(f"ERROR:: Unable to load the graph at '{filename}'")
        return False

    def replace_redirects(self, redirects):
        """Replace, in place, every edge that is a redirect source by its target.

        The lookup key is the edge with its first character lower-cased. Only a
        single redirect hop is followed, and counts are left untouched.
        """
        for edges, _ in self.graph.values():
            for i, edge in enumerate(edges):
                if edge:
                    edge = edge[0].lower()+edge[1:]
                    if edge in redirects:
                        edges[i] = redirects[edge]
        return None

    def remove_dead(self, title_to_id):
        """Drop edges missing from `title_to_id` and map the rest to ids.

        The lookup key is the edge with its first character lower-cased. Every
        entry is replaced by new ``(ids, counts)`` lists.
        """
        for article_id, (edges, counts) in self.graph.items():
            active_edges = list()
            active_counts = list()

            # Walk from the last edge to the first so the resulting order is
            # the same as that of the earlier pop()-based implementation.
            for edge, count in zip(reversed(edges), reversed(counts)):
                if not edge:
                    continue
                edge = edge[0].lower()+edge[1:]
                if edge in title_to_id:
                    active_edges.append(title_to_id[edge])
                    active_counts.append(count)
            self.graph[article_id] = (active_edges, active_counts)
        return None
    

In [10]:
class WikiGraphDataset:
    """Container for everything extracted from one dump partition.

    Attributes
    ----------
    seealso_graph, article_graph : WikilinkGraph
        Links found in sections whose heading matches `matches`, and links
        found in all remaining sections, both keyed by article id.
    id_to_title : dict
        Article id -> normalised title.
    wiki_content : dict
        Article id -> stripped article text.
    redirects : dict
        Normalised redirect source title -> normalised target title.
    """

    def __init__(self, matches=None):
        self.matches = matches

        self.seealso_graph = WikilinkGraph()
        self.article_graph = WikilinkGraph()

        self.id_to_title = {}
        self.wiki_content = {}
        self.redirects = {}

        self.filter_links = FilterWikilinks()

    # ------------------------------------------------------------------
    # Add articles to graph dataset.
    # ------------------------------------------------------------------
    def extract_article_info(self, article_text):
        """Parse wikitext and return ``(match_wikilinks, rest_wikilinks, content)``.

        NOTE(review): `remove_nodetype` and `split_sections` are called on the
        parsed wikicode as-is; presumably they are provided by the installed
        mwparserfromhell build — confirm.
        """
        wikicode = mwparserfromhell.parse(article_text, skip_style_tags=True)
        wikicode.remove_nodetype(inplace=True)

        match_sections, rest_sections = wikicode.split_sections(matches=self.matches, include_lead=True, flat=True)

        match_wikilinks = self.extract_section_wikilinks(match_sections)
        rest_wikilinks = self.extract_section_wikilinks(rest_sections)

        article_content = wikicode.strip_code().strip()

        return match_wikilinks, rest_wikilinks, article_content

    def extract_section_wikilinks(self, sections):
        """Return the normalised, non-empty wikilink titles of `sections`."""
        wikilinks = list()

        for section in sections:
            for wikilink in section.filter_wikilinks():
                raw_title = (wikilink.title).strip_code().strip()
                processed_link = self.filter_links.process_wikilink(raw_title)
                if processed_link:
                    wikilinks.append(processed_link)

        return wikilinks

    def add_article(self, article_title, article_text, article_id, article_ns):
        """Add one page to the dataset.

        Parameters
        ----------
        article_title, article_text, article_id, article_ns : str
            Raw values as captured by the XML handler; `article_id` must be
            convertible with ``int``. `article_ns` is accepted but not used.

        The page is ignored when its title normalises to an empty string, when
        its id is already present in either graph, or when it has no links.
        """
        article_title, article_id = article_title.strip(), int(article_id.strip())
        article_title = self.filter_links.process_wikilink(article_title)

        if not article_title:
            return
        # Both graphs are keyed by the integer article id, so the duplicate
        # guard has to test the id; testing the title could never match.
        if article_id in self.article_graph.graph or article_id in self.seealso_graph.graph:
            return
        seealso_wikilinks, article_wikilinks, article_content = self.extract_article_info(article_text)

        if len(seealso_wikilinks) or len(article_wikilinks):
            self.id_to_title[article_id] = article_title

            if len(seealso_wikilinks):
                self.seealso_graph.add_article(article_id, seealso_wikilinks)

            if len(article_wikilinks):
                self.article_graph.add_article(article_id, article_wikilinks)

            if article_content:
                self.wiki_content[article_id] = article_content

    # ------------------------------------------------------------------
    # Add redirect
    # ------------------------------------------------------------------
    def add_redirect(self, article_title, target_title):
        """Record a redirect when both titles survive normalisation."""
        article_title = self.filter_links.process_wikilink(article_title)
        target_title = self.filter_links.process_wikilink(target_title)

        if article_title and target_title:
            self.redirects[article_title] = target_title

    # ------------------------------------------------------------------
    # Save content
    # ------------------------------------------------------------------
    def save_graph(self, save_dir, tag='', graph_type='both'):
        """Save the 'seealso' graph, the 'articles' graph, or 'both'."""
        if graph_type == 'seealso' or graph_type == 'both':
            self.seealso_graph.save_data(save_dir, tag=f'_seealso{tag}')

        if graph_type == 'articles' or graph_type == 'both':
            self.article_graph.save_data(save_dir, tag=f'_articles{tag}')

    def save_data(self, save_dir, tag=''):
        """Save both graphs and all three dictionaries under `save_dir`."""
        os.makedirs(save_dir, exist_ok=True)

        self.save_graph(save_dir, tag)
        self.save_idtotitle(save_dir, tag)
        self.save_wikicontent(save_dir, tag)
        self.save_redirects(save_dir, tag)

    def save_idtotitle(self, save_dir, tag=''):
        """Pickle `id_to_title` to ``id_to_title<tag>.pickle`` and empty it."""
        map_file = f'{save_dir}/id_to_title{tag}.pickle'
        with open(map_file, 'wb') as f:
            pickle.dump(self.id_to_title, f)
        # Rebind instead of `del` so the instance stays usable after saving.
        self.id_to_title = {}
        gc.collect()

    def save_wikicontent(self, save_dir, tag=''):
        """Pickle `wiki_content` to ``wiki_content<tag>.pickle`` and empty it."""
        content_file = f'{save_dir}/wiki_content{tag}.pickle'
        with open(content_file, 'wb') as f:
            pickle.dump(self.wiki_content, f)
        self.wiki_content = {}
        gc.collect()

    def save_redirects(self, save_dir, tag=''):
        """Pickle `redirects` to ``redirects<tag>.pickle`` and empty it."""
        redirect_file = f'{save_dir}/redirects{tag}.pickle'
        with open(redirect_file, 'wb') as f:
            pickle.dump(self.redirects, f)
        self.redirects = {}
        gc.collect()

    # ------------------------------------------------------------------
    # Load content
    # NOTE: every loader uses pickle.load, which executes arbitrary code;
    # only load files produced by this pipeline.
    # ------------------------------------------------------------------
    def load_graph(self, save_dir, tag='', graph_type='both'):
        """Load the 'seealso' graph, the 'articles' graph, or 'both'.

        Raises
        ------
        Exception
            When a requested graph file cannot be loaded.
        """
        if graph_type == 'seealso' or graph_type == 'both':
            if not self.seealso_graph.load_data(save_dir, tag=f'_seealso{tag}'):
                raise Exception("Unable to load 'seealso graph'.")

        if graph_type == 'articles' or graph_type == 'both':
            if not self.article_graph.load_data(save_dir, tag=f'_articles{tag}'):
                raise Exception("Unable to load 'article graph'.")

    def load_data(self, save_dir, tag=''):
        """Load both graphs and all three dictionaries from `save_dir`."""
        self.load_graph(save_dir, tag)
        self.load_idtotitle(save_dir, tag)
        self.load_wikicontent(save_dir, tag)
        self.load_redirects(save_dir, tag)

    def load_idtotitle(self, save_dir, tag=''):
        """Load `id_to_title`; raises Exception when the file is missing."""
        map_file = f'{save_dir}/id_to_title{tag}.pickle'
        if os.path.exists(map_file):
            with open(map_file, 'rb') as f:
                self.id_to_title = pickle.load(f)
        else:
            raise Exception(f"Unable to load 'id_to_title' from {map_file}.")

    def load_wikicontent(self, save_dir, tag=''):
        """Load `wiki_content`; raises Exception when the file is missing."""
        content_file = f'{save_dir}/wiki_content{tag}.pickle'
        if os.path.exists(content_file):
            with open(content_file, 'rb') as f:
                self.wiki_content = pickle.load(f)
        else:
            raise Exception(f"Unable to load 'wiki_content' from {content_file}.")

    def load_redirects(self, save_dir, tag=''):
        """Load `redirects`; raises Exception when the file is missing."""
        redirect_file = f'{save_dir}/redirects{tag}.pickle'
        if os.path.exists(redirect_file):
            with open(redirect_file, 'rb') as f:
                self.redirects = pickle.load(f)
        else:
            raise Exception(f"Unable to load 'redirects' from {redirect_file}.")
            

### Testing

#### Adding a single article

In [125]:
# Index into handler.pages of the page used for the single-article test.
page_num = 1
handler.pages[page_num]['article_title']

'Autism'

In [126]:
# The page dict keys (article_title / article_text / article_id / article_ns)
# line up with add_article's parameter names, hence the ** expansion.
wikidataset = WikiGraphDataset(matches=matches)
wikidataset.add_article(**handler.pages[page_num])

In [127]:
list(wikidataset.article_graph.graph.values())[0][0][:10]

['aDHD',
 'alexithymia',
 'alternative therapies for developmental and learning disabilities',
 'anticonvulsant',
 'antidepressant',
 'antipsychotic',
 'anxiety disorder',
 'applied behavior analysis',
 'apraxia',
 'aripiprazole']

In [129]:
handler.pages[page_num]['article_text']



#### Adding multiple articles

In [130]:
# Feed every collected page into the same dataset (the page added above is passed again).
for i in range(len(handler.pages)):
    wikidataset.add_article(**handler.pages[i])

In [131]:
wikidataset.id_to_title

{25: 'autism',
 12: 'anarchism',
 39: 'albedo',
 290: 'a',
 303: 'alabama',
 305: 'achilles',
 307: 'abraham Lincoln',
 308: 'aristotle',
 309: 'an American in Paris',
 316: 'academy Award for Best Production Design',
 324: 'academy Awards'}

In [132]:
wikidataset.article_graph.graph

{25: (['aDHD',
   'alexithymia',
   'alternative therapies for developmental and learning disabilities',
   'anticonvulsant',
   'antidepressant',
   'antipsychotic',
   'anxiety disorder',
   'applied behavior analysis',
   'apraxia',
   'aripiprazole',
   'asperger syndrome',
   'attachment (psychology)',
   'attachment in children',
   'attention deficit hyperactivity disorder',
   'atypical antipsychotic',
   'autism Diagnostic Interview-Revised',
   'autism Diagnostic Observation Schedule',
   'autism Rights Movement',
   'autism Sunday',
   'autism and LGBT identities',
   'autism rights movement',
   'autism spectrum disorder',
   'autistic Pride Day',
   'autistic rights',
   'autoimmune disease',
   'autonomic nervous system',
   'autreat',
   'babbling',
   'birth defect',
   'blindism',
   'brain scan',
   'brett Abrahams',
   'brominated flame retardant',
   'cEASE therapy',
   'calcium',
   'casein-free diet',
   'causes of autism',
   'cell adhesion',
   'checklist for Au

#### Running the generation script

In [74]:
wikidataset = WikiGraphDataset(matches=matches)
# NOTE(review): `results_dir` is not assigned in any cell visible to this review
# (only `save_dir` is) — confirm where it is defined, otherwise this cell fails
# on a fresh kernel.
wikidataset.load_data(results_dir, tag='-enwiki-20220420-16-p20460153p20570392')

In [87]:
dict_head_random(wikidataset.seealso_graph.graph, n=2)

20487019 : (['national Forests Office (France)'], [1])
20517555 : (['charles Reade (disambiguation)', 'charles Reed (disambiguation)', 'charles Reid (disambiguation)'], [1, 1, 1])


In [86]:
dict_head_random(wikidataset.article_graph.graph, n=2)

20531172 : (['amazons!', 'andre Norton', 'andrew J. Offutt', 'c. J. Cherryh', 'dAW Books', 'flashing Swords! 2', 'flashing Swords! 3: Warriors and Wizards', 'garan the Eternal', 'gerald W. Page', 'hank Reinhardt', 'heroic Fantasy (anthology)', 'jack Gaughan', 'jessica Amanda Salmonson', 'lin Carter', 'michael Whelan', 'science fantasy', 'swords Against Darkness II', 'witch World'], [1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1])
20521626 : (['4669 Høder', 'hoder (Marvel Comics)', 'hossein Derakhshan', 'susanne Hoder'], [1, 1, 1, 1])


In [89]:
dict_head_random(wikidataset.id_to_title, n=3)

20570082 : 1999 Lipton Championships – Men's singles
20468360 : valentin Miculescu
20473962 : linda Wang


In [90]:
dict_head_random(wikidataset.redirects, n=3)

rKPK : gimhae International Airport
wenham, Massachusetts (Plymouth County) : carver, Massachusetts
cystourethroscopy : cystoscopy


## Graph XML handler

In [11]:
class WikiXmlHandler(xml.sax.handler.ContentHandler):
    """SAX handler that feeds MediaWiki dump pages into a `WikiGraphDataset`.

    For every ``<page>`` the character data of ``<title>``, ``<ns>``, ``<text>``
    and of the first ``<id>`` inside the page is kept under ``article_<tag>``
    keys. Redirect pages are passed to ``wikidataset.add_redirect``; all other
    pages whose ``<ns>`` is zero are passed to ``wikidataset.add_article``.

    NOTE(review): this definition replaces the earlier class of the same name
    once its cell has run.

    NOTE(review): captured chunks are joined with a single space. SAX parsers
    may deliver one run of text in several ``characters`` calls, in which case
    the stored value contains spaces that are not in the dump — confirm this is
    intended before relying on exact text.
    """

    def __init__(self, matches=None):
        """Create the handler and an empty dataset built with `matches`."""
        super().__init__()

        # Scratch state for the element currently being read.
        self._current_tag = None
        self._buffer = None
        self._values = {}

        # Per-page flags: whether the page should still be added, and whether
        # the next <id> is the page's own id (the first one in the page).
        self._add_page = True
        self._is_pageid = True

        # Destination of everything that is extracted.
        self.wikidataset = WikiGraphDataset(matches=matches)

    def characters(self, content):
        """Buffer character data while a captured element is open."""
        if not self._current_tag:
            return
        self._buffer.append(content)

    def startElement(self, name, attrs):
        """Start capturing a tracked element, or register a redirect."""
        capture = name in ('title', 'text', 'ns')
        if name == 'id' and self._is_pageid:
            # Only the first <id> of a page is captured.
            self._is_pageid = False
            capture = True

        if capture:
            self._current_tag = name
            self._buffer = []
        elif name == 'redirect':
            source_title = self._values['article_title'].strip()
            redirect_target = attrs.getValue('title').strip()
            self.wikidataset.add_redirect(source_title, redirect_target)
            self._add_page = False

    def endElement(self, name):
        """Store a finished captured element, or hand the page over on ``</page>``."""
        if name == self._current_tag:
            self._values[f'article_{name}'] = ' '.join(self._buffer)
            self._current_tag = None
            return

        if name != 'page':
            return

        in_main_namespace = int(self._values['article_ns']) == 0
        if in_main_namespace and self._add_page:
            # EXTRACT_DATA: hand the captured values to the dataset.
            self.wikidataset.add_article(**self._values)

        # Reset the per-page flags for the next <page>.
        self._add_page = True
        self._is_pageid = True
            

## Create graph

In [11]:
def create_graph(data_path, save_dir, matches=None, limit=None, save=True, tag_extractor=None,
                 back=True, verbose=True):
    """Parse one bz2 dump partition into a `WikiGraphDataset` and optionally save it.

    Parameters
    ----------
    data_path : str
        Path of the ``.bz2`` partition; it is decompressed with ``bzcat``.
    save_dir : str
        Directory handed to ``WikiGraphDataset.save_data``.
    matches : str, optional
        Section-heading pattern forwarded to the XML handler.
    limit : int, optional
        Stop once this many articles have been registered in ``id_to_title``.
    save : bool
        Save the dataset after parsing.
    tag_extractor : re.Pattern or str, optional
        A pattern with three groups matched against the file name to build the
        file tag ``-<g1>-<g2>-<g3>``, or a literal tag string.
    back : bool
        When `limit` is reached: True returns the dataset without saving,
        False falls through to the normal save path.
    verbose : bool
        Print a completion message after saving.

    Returns
    -------
    WikiGraphDataset or None
        The dataset when `limit` was reached and `back` is true, else None.

    Notes
    -----
    Requires the `WikiXmlHandler` variant that exposes ``wikidataset``.
    """
    handler = WikiXmlHandler(matches=matches)

    parser = xml.sax.make_parser()
    parser.setContentHandler(handler)

    limit_reached = False
    # Context managers close the compressed file and reap the bzcat process on
    # every exit path, including the early stops below.
    with open(data_path, 'rb') as compressed, \
            subprocess.Popen(['bzcat'], stdin=compressed, stdout=subprocess.PIPE) as bzcat:
        for line in bzcat.stdout:
            try:
                parser.feed(line)
            except StopIteration:
                bzcat.terminate()
                break

            if limit is not None and len(handler.wikidataset.id_to_title) >= limit:
                limit_reached = True
                # Stop decompressing: nothing will read the rest of the stream.
                bzcat.terminate()
                break

    if limit_reached and back:
        return handler.wikidataset

    if save:
        file_tag = ''
        if isinstance(tag_extractor, re.Pattern):
            parts_tag = tag_extractor.match(os.path.basename(data_path))
            if parts_tag:
                try:
                    file_tag = f'-{parts_tag.group(1)}-{parts_tag.group(2)}-{parts_tag.group(3)}'
                except IndexError:
                    # The pattern has fewer than three groups: fall back to no tag.
                    file_tag = ''
        elif isinstance(tag_extractor, str):
            file_tag = tag_extractor

        handler.wikidataset.save_data(save_dir, tag=file_tag)
        if verbose:
            print(f"** Completed processing {os.path.basename(data_path)}.", end='\r')

    del handler
    del parser
    gc.collect()

    return None

In [12]:
# Directory that receives the pickled outputs of create_graph.
save_dir = f'{dataset_home}/test_data'

# Same file-name pattern as `prog`; create_graph joins its three groups into the per-partition file tag.
tag_extractor = re.compile(f'({project}-{dump_date})-'+r'pages-articles-multistream([0-9]{1,2}).xml-(p[0-9]+p[0-9]+).bz2')

# Heading pattern for "See", "See also" / "See more" / "See all" and "See also (...)" sections.
matches = r'^([Ss]ee[ ]*|[Ss]ee[ ]*([Aa]lso|[Mm]ore|[Aa]ll)|[Ss]ee[ ]*[Aa]lso[ ]*\(.+\))$'

In [136]:
# Single-partition run: stop after 100 registered articles and save (back=False).
data_path = partitions[0]
create_graph(data_path, save_dir, matches=matches, limit=100, save=True, tag_extractor=tag_extractor, back=False)

** Completed processing enwiki-20220420-pages-articles-multistream1.xml-p1p41242.bz2.

In [151]:
# Same run over the first four partitions (in the sorted order of `partitions`).
for data_path in partitions[0:4]:
    create_graph(data_path, save_dir, matches=matches, limit=100, save=True, tag_extractor=tag_extractor, 
                 back=False)

** Completed processing enwiki-20220420-pages-articles-multistream11.xml-p6899367p7054859.bz2.

## Extract Abstract

In [331]:
class AbstractXmlHandler(xml.sax.handler.ContentHandler):
    """SAX handler that maps article titles to abstracts from an abstract dump.

    For every ``<doc>`` the ``<title>`` and ``<abstract>`` texts are captured.
    The part of the title before the first ``:`` is dropped, the first
    character of the remainder is lower-cased, and the result is used as key
    in ``_abstract``.
    """

    def __init__(self):
        xml.sax.handler.ContentHandler.__init__(self)
        # Scratch state for the element currently being read.
        self._buffer = None
        self._values = {}
        self._current_tag = None
        # Number of <doc> elements seen.
        self._article_count = 0

        # title -> abstract
        self._abstract = {}

    def characters(self, content):
        """Buffer character data while a captured element is open."""
        if self._current_tag:
            self._buffer.append(content)

    def startElement(self, name, attrs):
        """Start capturing ``<title>`` and ``<abstract>`` elements."""
        if name in ('title','abstract'):
            self._current_tag = name
            self._buffer = []

    def endElement(self, name):
        """Store a captured element, or register the abstract on ``</doc>``."""
        if name == self._current_tag:
            self._values[name] = ' '.join(self._buffer)
            self._current_tag = None

        elif name == 'doc':
            self._article_count += 1

            raw_title = self._values.get('title', '')
            # Drop the "<prefix>:" part when there is one; a title without a
            # colon is used whole instead of raising ValueError on unpacking.
            _, separator, remainder = raw_title.partition(':')
            title = (remainder if separator else raw_title).strip()

            # Skip documents without a usable title instead of failing on title[0].
            if title:
                title = title[0].lower() + title[1:]
                self._abstract[title] = self._values.get('abstract', '').strip()

            # Start the next <doc> clean so a missing <abstract> cannot pick
            # up the text of the previous document.
            self._values = {}
            

In [332]:
class WikiAbstract:
    """Collect title -> abstract mappings from gzip-compressed abstract dumps."""

    def __init__(self, abstract_paths=None):
        # Combined title -> abstract mapping of every parsed file.
        self.abstract = {}
        # Paths of the ``.xml.gz`` abstract files to parse.
        self.abstract_paths = abstract_paths

    def parse_abstract(self, abstract_file):
        """Parse one abstract file and return its title -> abstract dict.

        The file is decompressed with ``gzip -cd`` and fed to the SAX parser
        line by line.
        """
        handler = AbstractXmlHandler()

        parser = xml.sax.make_parser()
        parser.setContentHandler(handler)

        # Context managers close the compressed file and reap the gzip process
        # on every exit path.
        with open(abstract_file, 'rb') as compressed, \
                subprocess.Popen(['gzip', '-cd'], stdin=compressed, stdout=subprocess.PIPE) as gunzip:
            for line in gunzip.stdout:
                try:
                    parser.feed(line)
                except StopIteration:
                    gunzip.terminate()
                    break

        return handler._abstract

    def create_abstract(self, processes=15):
        """Parse every path in `abstract_paths` in parallel and merge the results.

        Parameters
        ----------
        processes : int
            Number of worker processes.

        Raises
        ------
        Exception
            When `abstract_paths` is None.
        """
        # Validate before spawning workers so a failed call leaves no pool behind.
        if self.abstract_paths is None:
            raise Exception("Abstract paths are empty.")

        # The pool is terminated when the block exits; every result has been
        # consumed by then.
        with Pool(processes=processes) as pool:
            for abstract in tqdm.tqdm(pool.imap_unordered(self.parse_abstract, self.abstract_paths),
                                      total=len(self.abstract_paths)):
                self.abstract.update(abstract)

    def save_abstract(self, save_dir, tag=''):
        """Pickle the mapping to ``<save_dir>/abstract<tag>.pickle``."""
        os.makedirs(save_dir, exist_ok=True)
        abstract_file = f'{save_dir}/abstract{tag}.pickle'
        with open(abstract_file, 'wb') as f:
            pickle.dump(self.abstract, f)

    def load_abstract(self, save_dir, tag=''):
        """Load the mapping from ``<save_dir>/abstract<tag>.pickle``.

        Raises
        ------
        Exception
            When the file does not exist.
        """
        abstract_file = f'{save_dir}/abstract{tag}.pickle'
        if os.path.exists(abstract_file):
            # NOTE: pickle.load executes arbitrary code; only load files this
            # pipeline produced itself.
            with open(abstract_file, 'rb') as f:
                self.abstract = pickle.load(f)
        else:
            raise Exception(f"Unable to load abstract from '{abstract_file}'.")
            

In [335]:
# File-name pattern of one abstract partition of the dump.
prog = re.compile(f'({project}-{dump_date})-'+r'abstract([0-9]{1,2}).xml.gz')

abstract_paths = sorted([f'{dataset_home}/datasets/{file}' for file in os.listdir(f'{dataset_home}/datasets') 
                     if prog.match(file)])

# Count the list built above; the name `abstracts` is not defined in this notebook.
print(f'Total number of partitions of the wikipedia dump : {len(abstract_paths)}')

Total number of partitions of the wikipedia dump : 27


In [336]:
# Parse every abstract file with create_abstract's default of 15 worker processes.
abstract = WikiAbstract(abstract_paths)
abstract.create_abstract()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27/27 [01:29<00:00,  3.32s/it]


In [337]:
# Written to <save_dir>/abstract-<project>-<dump_date>.pickle.
abstract.save_abstract(save_dir, tag=f'-{project}-{dump_date}')

## Combine redirects, id_to_title and wiki_content

In [191]:
def multiprocessor_1(func, tasks, num_process=10):
    """Apply `func` to every task in a process pool and merge the returned dicts.

    Parameters
    ----------
    func : callable
        Picklable function taking one task and returning a dict.
    tasks : sequence
        Tasks to process; ``len(tasks)`` is used for the progress bar.
    num_process : int
        Number of worker processes.

    Returns
    -------
    dict
        Union of all returned dicts, merged in task order (later tasks win on
        duplicate keys).
    """
    results = {}
    # The context manager shuts the pool down when done; every result has been
    # consumed by then, so no work is lost and no worker processes linger.
    with Pool(processes=num_process) as pool:
        for x in tqdm.tqdm(pool.imap(func, tasks), total=len(tasks)):
            results.update(x)
    return results


In [192]:
def extract_filetag(data_path, tag_extractor=None):
    """Build the file tag that identifies the outputs of one dump partition.

    Parameters
    ----------
    data_path : str
        Path of the partition; only its base name is matched.
    tag_extractor : re.Pattern or str, optional
        A compiled pattern with three groups, matched against the base name to
        produce ``-<g1>-<g2>-<g3>``, or a literal tag returned unchanged.

    Returns
    -------
    str
        The tag, or ``''`` when `tag_extractor` is None or of another type,
        when the pattern does not match, or when it has fewer than three groups.
    """
    file_tag = ''
    if isinstance(tag_extractor, re.Pattern):
        parts_tag = tag_extractor.match(os.path.basename(data_path))
        if parts_tag:
            try:
                file_tag = f'-{parts_tag.group(1)}-{parts_tag.group(2)}-{parts_tag.group(3)}'
            except IndexError:
                # Fewer than three groups in the pattern: no tag. Only this
                # error is expected, so nothing else is swallowed.
                file_tag = ''
    elif isinstance(tag_extractor, str):
        file_tag = tag_extractor
    return file_tag


In [263]:
class WikiGraphCombine:
    """Merge per-partition artefacts (maps and link graphs) into single dictionaries.

    The merged values are kept in ``self.parameters`` under the keys
    ``'id_to_title'``, ``'redirects'``, ``'wiki_content'``, ``'seealso_graph'`` and
    ``'articles_graph'``; the ``save_*`` / ``load_*`` methods pickle them to and from
    ``save_dir``.

    NOTE: every ``load_*`` method uses ``pickle.load``, which can execute arbitrary
    code contained in the file -- only load files produced by this pipeline.
    """

    # Graph kinds accepted by save_graph / load_graph.
    _GRAPH_TYPES = ('seealso', 'articles')

    def __init__(self, partition_files=None):
        """Store the partition file list and start with empty combined parameters."""
        self.partition_files = partition_files

        self.parameters = {}
        for key in ('id_to_title', 'redirects', 'wiki_content', 'seealso_graph', 'articles_graph'):
            self.parameters[key] = {}

    def load_param(self, data_path, save_dir, param='id_to_title', graph_resolved=True, tag_extractor=None):
        """Load one parameter of a single partition through ``WikiGraphDataset``.

        Args:
            data_path: partition file; its name is turned into the file tag.
            save_dir: directory holding the per-partition result files.
            param: one of the keys of ``self.parameters``.
            graph_resolved: for the two graph parameters, append ``'_resolved'`` to
                the tag before loading.
            tag_extractor: forwarded to ``extract_filetag``.

        Returns:
            The loaded dictionary. Graphs are passed through ``convert_graph`` first.

        Raises:
            ValueError: if ``param`` is not a known parameter name.
        """
        wikidataset = WikiGraphDataset()
        tag = extract_filetag(data_path, tag_extractor)

        if param == 'id_to_title':
            wikidataset.load_idtotitle(save_dir=save_dir, tag=tag)
            return wikidataset.id_to_title
        if param == 'redirects':
            wikidataset.load_redirects(save_dir=save_dir, tag=tag)
            return wikidataset.redirects
        if param == 'wiki_content':
            wikidataset.load_wikicontent(save_dir=save_dir, tag=tag)
            return wikidataset.wiki_content
        if param == 'seealso_graph':
            if graph_resolved:
                tag = f'{tag}_resolved'
            wikidataset.load_graph(save_dir=save_dir, tag=tag, graph_type='seealso')
            self.convert_graph(wikidataset.seealso_graph.graph)
            return wikidataset.seealso_graph.graph
        if param == 'articles_graph':
            if graph_resolved:
                tag = f'{tag}_resolved'
            wikidataset.load_graph(save_dir=save_dir, tag=tag, graph_type='articles')
            self.convert_graph(wikidataset.article_graph.graph)
            return wikidataset.article_graph.graph

        raise ValueError(f'Invalid value of param : {param}')

    def combine_param(self, save_dir, param='id_to_title', tag_extractor=None):
        """Load ``param`` for every partition in parallel and store the merged dict.

        The result replaces ``self.parameters[param]``.
        """
        # A bound method is pickled together with its instance when it is shipped to
        # the pool. load_param reads nothing from the instance, so bind it to an
        # empty shell of the same class instead of to `self`; otherwise everything
        # already combined in self.parameters would be serialised along with the tasks.
        loader = type(self).__new__(type(self))
        loader.partition_files = None
        loader.parameters = {}

        combine_helper = partial(loader.load_param, param=param, save_dir=save_dir,
                                 tag_extractor=tag_extractor)
        self.parameters[param] = multiprocessor_1(combine_helper, self.partition_files)

    @staticmethod
    def _drop_edge(graph, src, dst):
        """Delete the edge ``src -> dst`` if present and drop ``src`` once it has no edges."""
        if src in graph and dst in graph[src]:
            del graph[src][dst]
            if len(graph[src]) == 0:
                del graph[src]

    def remove_seealso_edges(self):
        """Remove from the articles graph every edge of the see-also graph, in both directions.

        Both graphs are expected in ``{node: {neighbour: count}}`` form. Nothing
        happens when either graph is empty. The articles graph is modified in place.
        """
        seealso = self.parameters['seealso_graph']
        articles = self.parameters['articles_graph']

        if not (seealso and articles):
            return

        for node_1, edge_count in seealso.items():
            for node_2 in edge_count:
                self._drop_edge(articles, node_1, node_2)
                self._drop_edge(articles, node_2, node_1)

    def convert_graph(self, graph):
        """Toggle ``graph`` in place between its two adjacency representations.

        ``{doc: (edges, counts)}`` becomes ``{doc: {edge: count}}`` and vice versa.
        The current representation is detected from the first value only. Empty
        graphs and non-dict arguments are left untouched.

        Raises:
            Exception: if the first value is neither a tuple nor a dict.
        """
        if not isinstance(graph, dict) or not graph:
            return

        first_value = graph[next(iter(graph))]
        if isinstance(first_value, tuple):
            # Only existing keys are reassigned, so iterating the live view is safe.
            for doc, (edges, counts) in graph.items():
                graph[doc] = {e: c for e, c in zip(edges, counts)}
        elif isinstance(first_value, dict):
            for doc, edge_count in graph.items():
                graph[doc] = (list(edge_count.keys()), list(edge_count.values()))
        else:
            raise Exception("Invalid graph format.")

    def _dump_param(self, key, save_dir, filename):
        """Pickle ``self.parameters[key]`` to ``{save_dir}/{filename}`` (directory is created)."""
        os.makedirs(save_dir, exist_ok=True)
        with open(f'{save_dir}/{filename}', 'wb') as f:
            pickle.dump(self.parameters[key], f)

    def _restore_param(self, key, save_dir, filename):
        """Unpickle ``{save_dir}/{filename}`` into ``self.parameters[key]``."""
        with open(f'{save_dir}/{filename}', 'rb') as f:
            self.parameters[key] = pickle.load(f)

    def _check_graph_type(self, graph_type):
        """Raise ``ValueError`` unless ``graph_type`` is 'seealso' or 'articles'."""
        if graph_type not in self._GRAPH_TYPES:
            raise ValueError(
                f"graph_type must be one of {list(self._GRAPH_TYPES)}, got {graph_type!r}")

    def save_idtotitle(self, save_dir, tag=''):
        """Write the id -> title map to ``{save_dir}/id_to_title{tag}.pickle``."""
        self._dump_param('id_to_title', save_dir, f'id_to_title{tag}.pickle')

    def load_idtotitle(self, save_dir, tag=''):
        """Read the id -> title map from ``{save_dir}/id_to_title{tag}.pickle``."""
        self._restore_param('id_to_title', save_dir, f'id_to_title{tag}.pickle')

    def save_wikicontent(self, save_dir, tag=''):
        """Write the wiki content to ``{save_dir}/wiki_content{tag}.pickle``."""
        self._dump_param('wiki_content', save_dir, f'wiki_content{tag}.pickle')

    def load_wikicontent(self, save_dir, tag=''):
        """Read the wiki content from ``{save_dir}/wiki_content{tag}.pickle``."""
        self._restore_param('wiki_content', save_dir, f'wiki_content{tag}.pickle')

    def save_redirects(self, save_dir, tag=''):
        """Write the redirects to ``{save_dir}/redirects{tag}.pickle``."""
        self._dump_param('redirects', save_dir, f'redirects{tag}.pickle')

    def load_redirects(self, save_dir, tag=''):
        """Read the redirects from ``{save_dir}/redirects{tag}.pickle``."""
        self._restore_param('redirects', save_dir, f'redirects{tag}.pickle')

    def save_graph(self, save_dir, tag='', graph_type='seealso'):
        """Write one combined graph to ``{save_dir}/link_graph_{graph_type}{tag}.pickle``.

        Raises:
            ValueError: if ``graph_type`` is neither ``'seealso'`` nor ``'articles'``.
        """
        self._check_graph_type(graph_type)
        self._dump_param(f'{graph_type}_graph', save_dir, f'link_graph_{graph_type}{tag}.pickle')

    def load_graph(self, save_dir, tag='', graph_type='seealso'):
        """Read one combined graph from ``{save_dir}/link_graph_{graph_type}{tag}.pickle``.

        Raises:
            ValueError: if ``graph_type`` is neither ``'seealso'`` nor ``'articles'``.
        """
        self._check_graph_type(graph_type)
        self._restore_param(f'{graph_type}_graph', save_dir, f'link_graph_{graph_type}{tag}.pickle')

    def remove_nodes_from_articles(self, nodes):
        """Delete ``nodes`` from the articles graph, both as sources and as edge targets.

        The graph is expected in ``{node: {neighbour: count}}`` form and is modified
        in place. Nodes whose edge dict becomes empty are kept.

        Args:
            nodes: collection of node ids to remove.

        Raises:
            Exception: if the articles graph is empty.
        """
        articles = self.parameters['articles_graph']
        if not articles:
            raise Exception('Article graph is empty.')

        # Membership is tested once per node and once per edge; make sure each test
        # is a hash lookup even when the caller hands in a list or an array.
        if not isinstance(nodes, (set, frozenset, dict)):
            nodes = set(nodes)

        for node in tqdm.tqdm(list(articles.keys())):
            if node in nodes:
                del articles[node]
            else:
                for edge in list(articles[node].keys()):
                    if edge in nodes:
                        del articles[node][edge]
                        

In [359]:
# Regex matched against each partition file name; extract_filetag joins its three
# capture groups into the tag that identifies that partition's saved files.
tag_extractor = re.compile(f'({project}-{dump_date})-'+r'pages-articles-multistream([0-9]{1,2}).xml-(p[0-9]+p[0-9]+).bz2')
# Directory the per-partition results are read from by combine_param.
results_dir = f'{dataset_home}/results'

# Directory the combined artefacts are written to.
save_dir = f'{dataset_home}/combined'

combine_graph = WikiGraphCombine(partitions)

In [360]:
combine_graph.combine_param(results_dir, param='id_to_title', tag_extractor=tag_extractor)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [00:05<00:00, 12.39it/s]


In [17]:
combine_graph.combine_param(results_dir, param='redirects', tag_extractor=tag_extractor)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [02:38<00:00,  2.55s/it]


In [361]:
combine_graph.save_idtotitle(save_dir, tag=f'-{project}-{dump_date}')
combine_graph.save_redirects(save_dir, tag=f'-{project}-{dump_date}')

In [19]:
len(combine_graph.parameters['id_to_title']), len(combine_graph.parameters['redirects'])

(6472153, 9997267)

### Testing

In [182]:
combine_graph = WikiGraphCombine(partitions)

In [130]:
graph = combine_graph.load_param(partitions[0], results_dir, param='articles_graph', tag_extractor=tag_extractor)

In [132]:
dict_head_random(graph, n=2)

2722 : {31801: 1, 156665: 2, 31784: 1, 30653: 1, 20647050: 1, 29380512: 1, 26741: 1, 210815: 1, 223272: 1, 3720257: 1, 196221: 1, 170691: 1, 1021843: 1, 9313566: 1, 21287: 1, 3641632: 1, 17616: 1, 191035: 1, 14627: 1, 191032: 1, 9903331: 1, 395934: 1, 10303: 1, 20611356: 1, 145381: 1, 9334: 1, 26554345: 1, 19593040: 2, 5233: 1, 49658: 1, 580: 1, 16107068: 1, 51428781: 1}
26079 : {145439: 1, 19666611: 2, 19817390: 1, 844674: 1, 195149: 1, 31769: 1, 2288455: 1, 573136: 1, 30635: 1, 6193480: 1, 1374719: 1, 39565408: 1, 48990: 1, 26961: 1, 7959: 1, 33703952: 1, 7980471: 1, 566942: 1, 66230: 1, 11600: 1, 38355446: 1, 113248: 1, 159284: 1, 178755: 1, 1795147: 1, 1108867: 3, 707374: 1, 16826: 3, 38420: 1, 15245: 1, 15012: 1, 178778: 1, 291717: 1, 14653: 1, 68483773: 1, 332667: 1, 737: 2, 280746: 1, 13831: 2, 46530: 1, 14128: 1, 80949: 2, 68866725: 1, 967570: 2, 57979: 1, 66468: 1, 50465: 2, 18716728: 1, 411381: 1, 4509682: 1, 3768: 1, 1570983: 1, 2963: 1, 994610: 1, 2851029: 1, 178759: 1, 605

In [187]:
a = {1: {1:1, 2:2, 3:3}, 2 : {4:4, 5:5, 6:6}}
s = {1: {1:1, 2:1}, 2: {7:1, 8:1, 9:1}}

In [188]:
combine_graph.parameters['seealso_graph'] = s
combine_graph.parameters['articles_graph'] = a

In [189]:
combine_graph.remove_seealso_edges()

In [190]:
a

{1: {3: 3}, 2: {4: 4, 5: 5, 6: 6}}

## Resolving Graph

In [81]:
def multiprocessor_2(func, tasks, num_process=10):
    """Run ``func`` over ``tasks`` in a process pool and collect the results in a list.

    Args:
        func: callable taking one task; it is sent to worker processes, so it has to
            be picklable.
        tasks: sized sequence of task arguments (``len(tasks)`` sizes the progress bar).
        num_process: number of worker processes.

    Returns:
        list: one result per task, in task order.
    """
    results = []
    # The context manager shuts the pool down when the block exits (including on an
    # exception), so worker processes are not left behind each time this is called.
    with Pool(processes=num_process) as pool:
        for x in tqdm.tqdm(pool.imap(func, tasks), total=len(tasks)):
            results.append(x)
    return results


In [82]:
class ResolveGraph:
    """Rewrite every per-partition link graph with redirects replaced and dead links removed."""

    def __init__(self, partition_files, id_to_title, redirects):
        """Keep the inputs and build the inverse title -> id lookup.

        When several ids share a title, the id seen last in ``id_to_title`` is kept.
        """
        self.partition_files = partition_files
        self.id_to_title = id_to_title
        self.redirects = redirects

        inverse = {}
        for article_id, article_title in id_to_title.items():
            inverse[article_title] = article_id
        self.title_to_id = inverse

    def resolver(self, data_path, save_dir, graph_type='seealso', tag_extractor=None):
        """Resolve the graph of one partition and save it next to the input.

        The graph is loaded with the tag ``_{graph_type}{file tag}`` and written back
        with ``_resolved`` appended to that tag.
        """
        link_graph = WikilinkGraph()
        source_tag = '_' + graph_type + extract_filetag(data_path, tag_extractor)
        link_graph.load_data(save_dir, source_tag)

        link_graph.replace_redirects(self.redirects)
        link_graph.remove_dead(self.title_to_id)

        link_graph.save_data(save_dir, source_tag + '_resolved')

    def resolve(self, save_dir, graph_type='seealso', tag_extractor=None):
        """Run ``resolver`` on every partition file in a process pool."""
        per_partition = partial(
            self.resolver,
            save_dir=save_dir,
            graph_type=graph_type,
            tag_extractor=tag_extractor,
        )
        multiprocessor_2(per_partition, self.partition_files)
        

In [22]:
# Directory holding the combined pickles.
combined_dir = f'{dataset_home}/combined'

# Reload the combined id -> title map and redirects; both are handed to ResolveGraph below.
combined_graph = WikiGraphCombine()
combined_graph.load_idtotitle(combined_dir, tag=f'-{project}-{dump_date}')
combined_graph.load_redirects(combined_dir, tag=f'-{project}-{dump_date}')

In [23]:
resolve_graph = ResolveGraph(partitions, 
                             combined_graph.parameters['id_to_title'], 
                             combined_graph.parameters['redirects'])

In [24]:
tag_extractor = re.compile(f'({project}-{dump_date})-'+r'pages-articles-multistream([0-9]{1,2}).xml-(p[0-9]+p[0-9]+).bz2')
results_dir = f'{dataset_home}/results'

In [25]:
resolve_graph.resolve(results_dir, graph_type='seealso', tag_extractor=tag_extractor)

IOStream.flush timed out
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [11:04<00:00, 10.72s/it]


In [26]:
resolve_graph.resolve(results_dir, graph_type='articles', tag_extractor=tag_extractor)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [12:17<00:00, 11.89s/it]


### Testing

In [130]:
resolve_graph = ResolveGraph(partitions, 
                             combined_graph.parameters['id_to_title'], 
                             combined_graph.parameters['redirects'])

resolve_graph.resolver(partitions[0], results_dir, tag_extractor=tag_extractor)

_seealso-enwiki-20220420-1-p1p41242_resolved


In [131]:
g = WikilinkGraph()
g.load_data(results_dir, tag='_seealso-enwiki-20220420-1-p1p41242_resolved')

True

In [134]:
dict_head_random(g.graph, n=5)

7521 : ([10226249, 20606419], [1, 1])
10989 : ([62036856], [1])
21840 : ([39352291, 24176661, 7980534, 4830379, 16714118, 11127], [1, 1, 1, 1, 1, 1])
27917 : ([9448227], [1])
12244 : ([], [])


## Combine seealso and articles graphs

In [195]:
combine_graph = WikiGraphCombine(partitions)

In [196]:
combine_graph.combine_param(results_dir, param='seealso_graph', tag_extractor=tag_extractor)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [00:23<00:00,  2.67it/s]


In [197]:
combine_graph.combine_param(results_dir, param='articles_graph', tag_extractor=tag_extractor)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 62/62 [33:23<00:00, 32.31s/it]


In [198]:
combine_graph.remove_seealso_edges()

In [199]:
combine_graph.save_graph(save_dir, tag=f'-{project}-{dump_date}', graph_type='seealso')
combine_graph.save_graph(save_dir, tag=f'-{project}-{dump_date}', graph_type='articles')

### Testing

In [40]:
dict_head_random(combine_graph.parameters['seealso_graph'], n=4)

47708265 : ([66701603], [1])
639449 : ([922759], [1])
594590 : ([503179, 870352, 9776805, 22318958, 201457, 12898, 5630, 1789184], [1, 1, 1, 1, 1, 1, 1, 1])
270061 : ([58345806, 755393], [1, 1])


In [38]:
dict_head_random(combine_graph.parameters['articles_graph'], n=4)

44922622 : ([167774, 60919, 23592200, 209935, 6853403, 582311, 6852678, 1025821, 24116, 7141683, 18403271, 1336242, 19391, 158483, 9137495, 1036865, 34758894, 420561, 577429, 1739558, 4237986, 13345289, 50375706], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
40682609 : ([7348713, 173533, 14653, 530348, 33490080, 40681711], [1, 1, 1, 1, 1, 1])
50012208 : ([31222, 28436, 606848, 226651, 240090, 22428, 84349, 12966567, 15786696, 2947544, 180283, 46381336, 1196118, 14024916, 4092, 6594080], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
18589767 : ([469047, 72155, 7293, 4015, 26757970, 510215], [1, 1, 1, 1, 1, 1])


## Creating Test-Train split

In [205]:
def prune_map(mapping, idxs):
    """Restrict a ``key -> index`` mapping to ``idxs`` and renumber it.

    Args:
        mapping: dict from key to its current index.
        idxs: the current indices to keep, in their new order.

    Returns:
        dict: ``key -> position of its old index within idxs``.

    Raises:
        KeyError: if an entry of ``idxs`` is not an index present in ``mapping``.
    """
    key_of_index = {}
    for key, value in mapping.items():
        key_of_index[value] = key

    return {key_of_index[old_idx]: new_idx for new_idx, old_idx in enumerate(idxs)}

def split_count(num_samples, perc=0.7):
    """Number of training samples for a group of ``num_samples``.

    The count is ``ceil(num_samples * perc)``; when that would take the whole group
    and the group has more than one sample, one sample is held back.
    """
    num_train = int(np.ceil(num_samples * perc))
    takes_whole_group = (num_train == num_samples) and (num_samples > 1)
    return num_train - 1 if takes_whole_group else num_train


In [287]:
class WikipediaSplit:
    """Turn a link graph into a sparse document x label matrix and split it into train/test.

    ``to_matrix`` fills ``graph`` (CSR matrix), ``labels`` (label id -> column) and
    ``doc_to_rowindex`` (document id -> row). The ``get_split_by*`` methods fill
    ``train``, ``test``, ``trn_tst_labels`` and the two ``*_doc_to_rowindex`` maps.
    All random choices use the global ``np.random`` state.
    """

    def __init__(self, graph=None):
        """Optionally build the matrix right away from ``{doc: {link: count}}``."""
        self.graph = None
        self.labels = None
        self.doc_to_rowindex = None

        if graph:
            self.to_matrix(graph)

        self.train, self.test = None, None
        self.trn_tst_labels = None
        self.train_doc_to_rowindex, self.test_doc_to_rowindex = None, None

    def to_matrix(self, graph):
        """Build the CSR matrix from ``{doc: {link: count}}``.

        Rows follow the iteration order of ``graph``; columns are numbered in order
        of first appearance of each link.

        Returns:
            tuple: ``(matrix, labels, doc_to_rowindex)``, also stored on the instance.
        """
        indptr = [0]
        indices = []
        data = []

        self.doc_to_rowindex = {}
        self.labels = {}
        for i, (doc, edge_count) in enumerate(graph.items()):
            self.doc_to_rowindex[doc] = i

            for link, cnt in edge_count.items():
                index = self.labels.setdefault(link, len(self.labels))
                indices.append(index)
                data.append(cnt)
            indptr.append(len(indices))

        self.graph = csr_matrix((data, indices, indptr), dtype=int)
        return self.graph, self.labels, self.doc_to_rowindex

    def clean_matrix(self, clean_type=0):
        """Prune the full matrix in place.

        ``clean_type`` 0 removes labels with total count <= 1 and then empty rows,
        1 removes empty rows only, any other value removes empty columns only.
        """
        if clean_type == 0:
            self.graph, self.labels, self.doc_to_rowindex = self.remove_single_labels(
                self.graph, self.labels, self.doc_to_rowindex)
        elif clean_type == 1:
            pruned_rows = self.get_pruned_row(self.graph)
            self.graph, self.doc_to_rowindex = self.prune_graph_rows(
                self.graph, self.doc_to_rowindex, pruned_rows)
        else:
            pruned_cols = self.get_pruned_cols(self.graph)
            self.graph, self.labels = self.prune_graph_cols(
                self.graph, self.labels, pruned_cols)

    def get_split_idx(self, upper_threshold=10, perc=0.7):
        """Choose train/test row indices, splitting rare labels label by label.

        Labels are visited in increasing order of their total count, stopping at the
        first count that is >= ``upper_threshold``. For each visited label the rows
        that contain it and were not assigned yet are shuffled and divided with
        ``split_count(..., perc)``. All rows left over afterwards are shuffled and
        used to fill the training set up to ``split_count(num_rows, perc)``; the
        rest goes to the test set.

        Returns:
            tuple: ``(train_rowidx, test_rowidx)`` as lists of row indices.
        """
        train_rowidx = []
        test_rowidx = []

        # Non-zero coordinates ordered by column, so the rows of one label are contiguous.
        row_idxs, col_idxs = self.graph.nonzero()
        sort_idx = np.argsort(col_idxs)
        row_idxs = row_idxs[sort_idx]
        col_idxs = col_idxs[sort_idx]

        num_rows = self.graph.shape[0]
        row_inserted_flag = np.zeros(num_rows, dtype=bool)

        label_cnt = np.array(self.graph.sum(axis=0)).reshape(-1)
        uni_label_cnt = np.unique(label_cnt)

        cnt = 0  # number of rows assigned so far
        for lcnt in uni_label_cnt:
            if cnt == num_rows or lcnt >= upper_threshold:
                break

            # Merge-style walk over the sorted label columns having this count
            # (pos_idxs) and the sorted column coordinates (col_idxs).
            pos_ptr, col_ptr = 0, 0
            pos_idxs = np.where(label_cnt == lcnt)[0]
            pos_idxs.sort()

            while col_ptr < len(col_idxs) and pos_ptr < len(pos_idxs) and cnt < num_rows:
                if pos_idxs[pos_ptr] != col_idxs[col_ptr]:
                    col_ptr += 1
                else:
                    sample_row_idxs = []
                    while col_ptr < len(col_idxs) and pos_ptr < len(pos_idxs) and cnt < num_rows and pos_idxs[pos_ptr] == col_idxs[col_ptr]:
                        rn = row_idxs[col_ptr]
                        if not row_inserted_flag[rn]:
                            sample_row_idxs.append(rn)
                            row_inserted_flag[rn] = True
                            cnt += 1
                        col_ptr += 1
                    pos_ptr += 1

                    # Honour the caller's ratio (it used to be fixed at 0.7 here).
                    num_train = split_count(len(sample_row_idxs), perc=perc)
                    sample_row_idxs = list(np.random.permutation(sample_row_idxs))
                    train_rowidx.extend(sample_row_idxs[:num_train])
                    test_rowidx.extend(sample_row_idxs[num_train:])

        sample_row_idxs = np.where(~row_inserted_flag)[0]
        # Remaining train quota. The rare-label pass can already exceed the target;
        # clamp at 0 because a negative value would be read as "all but the last k"
        # by the slices below and push most leftover rows into train.
        num_train = max(0, split_count(num_rows, perc=perc) - len(train_rowidx))

        sample_row_idxs = list(np.random.permutation(sample_row_idxs))
        train_rowidx.extend(sample_row_idxs[:num_train])
        test_rowidx.extend(sample_row_idxs[num_train:])

        return train_rowidx, test_rowidx

    def _apply_split(self, train_idx, test_idx):
        """Materialise train/test from row indices and prune them to shared labels.

        Keeps only label columns that are non-empty in both parts, renumbers
        ``trn_tst_labels`` accordingly, then drops rows that became empty in either part.
        """
        self.train = self.graph[train_idx, :]
        self.test = self.graph[test_idx, :]

        rowindex_to_doc = {row_idx: doc for doc, row_idx in self.doc_to_rowindex.items()}
        self.train_doc_to_rowindex = {rowindex_to_doc[idx]: i for i, idx in enumerate(train_idx)}
        self.test_doc_to_rowindex = {rowindex_to_doc[idx]: i for i, idx in enumerate(test_idx)}

        # Columns: keep labels present in both train and test.
        pruned_cols = np.intersect1d(self.get_pruned_cols(self.train),
                                     self.get_pruned_cols(self.test))
        self.train = self.train[:, pruned_cols]
        self.test = self.test[:, pruned_cols]
        self.trn_tst_labels = prune_map(self.labels, pruned_cols)

        # Rows: drop documents left without any label.
        self.train, self.train_doc_to_rowindex = self.prune_graph_rows(
            self.train, self.train_doc_to_rowindex, self.get_pruned_row(self.train))
        self.test, self.test_doc_to_rowindex = self.prune_graph_rows(
            self.test, self.test_doc_to_rowindex, self.get_pruned_row(self.test))

    def get_split_bylabel(self, upper_threshold=10, perc=0.7):
        """Split with ``get_split_idx`` and store the pruned train/test matrices."""
        train_idx, test_idx = self.get_split_idx(upper_threshold, perc)
        self._apply_split(train_idx, test_idx)

    def get_random_split_idx(self, perc=0.7):
        """Return a random permutation of the rows cut at ``int(perc * n_rows)``."""
        n_docs = self.graph.shape[0]
        n_trn = int(perc * n_docs)
        rand_idx = np.random.permutation(n_docs)
        return rand_idx[:n_trn], rand_idx[n_trn:]

    def get_pruned_cols(self, graph, count=0):
        """Indices of the columns whose sum is greater than ``count``."""
        label_cnt = np.array(graph.sum(axis=0)).reshape(-1)
        pruned_cols = np.where(label_cnt > count)[0]
        return pruned_cols

    def get_pruned_row(self, graph, count=0):
        """Indices of the rows whose sum is greater than ``count``."""
        pruned_rows = np.where(np.array(graph.sum(axis=1)).reshape(-1) > count)[0]
        return pruned_rows

    def prune_graph_cols(self, graph, labels, pruned_cols):
        """Keep only ``pruned_cols`` and renumber ``labels`` to match."""
        graph = graph[:, pruned_cols]
        labels = prune_map(labels, pruned_cols)
        return graph, labels

    def prune_graph_rows(self, graph, doc_to_rowindex, pruned_rows):
        """Keep only ``pruned_rows`` and renumber ``doc_to_rowindex`` to match."""
        graph = graph[pruned_rows, :]
        doc_to_rowindex = prune_map(doc_to_rowindex, pruned_rows)
        return graph, doc_to_rowindex

    def remove_single_labels(self, graph, labels, doc_to_rowindex):
        """Drop labels with total count <= 1, then the rows that are left empty."""
        pruned_cols = self.get_pruned_cols(graph, count=1)
        graph, labels = self.prune_graph_cols(graph, labels, pruned_cols)

        pruned_rows = self.get_pruned_row(graph)
        graph, doc_to_rowindex = self.prune_graph_rows(graph, doc_to_rowindex, pruned_rows)

        return graph, labels, doc_to_rowindex

    def get_split_byrandom(self, perc=0.7):
        """Split with ``get_random_split_idx`` and store the pruned train/test matrices."""
        train_idx, test_idx = self.get_random_split_idx(perc)
        self._apply_split(train_idx, test_idx)

    def save_data(self, save_dir, tag='seealso'):
        """Pickle the split to ``{save_dir}/{tag}_train.pkl`` and ``{tag}_test.pkl``.

        Each file holds ``(trn_tst_labels, <part>_doc_to_rowindex, <part matrix>)``.
        ``save_dir`` is created when missing.
        """
        os.makedirs(save_dir, exist_ok=True)

        train_file = f'{save_dir}/{tag}_train.pkl'
        with open(train_file, 'wb') as fout:
            train = (self.trn_tst_labels, self.train_doc_to_rowindex, self.train)
            pickle.dump(train, fout)

        test_file = f'{save_dir}/{tag}_test.pkl'
        with open(test_file, 'wb') as fout:
            test = (self.trn_tst_labels, self.test_doc_to_rowindex, self.test)
            pickle.dump(test, fout)

    def load_data(self, save_dir, tag='seealso'):
        """Load a split written by ``save_data``.

        The label map is taken from the train file; the copy in the test file is ignored.
        NOTE: uses ``pickle.load`` -- only load files produced by this pipeline.
        """
        train_file = f'{save_dir}/{tag}_train.pkl'
        with open(train_file, 'rb') as fin:
            train = pickle.load(fin)
            self.trn_tst_labels, self.train_doc_to_rowindex, self.train = train

        test_file = f'{save_dir}/{tag}_test.pkl'
        with open(test_file, 'rb') as fin:
            test = pickle.load(fin)
            _, self.test_doc_to_rowindex, self.test = test
            

In [249]:
graph = WikiGraphCombine(partitions)
graph.load_graph(save_dir, tag=f'-{project}-{dump_date}', graph_type='seealso')

In [250]:
seealso_graph = graph.parameters['seealso_graph']

In [253]:
dict_head_random(seealso_graph, n=5)

26856928 : {13288203: 1, 9889883: 1, 12032646: 1, 26418726: 1, 29344952: 1}
2164617 : {5073318: 1, 42235: 1, 4678878: 1}
37510133 : {4449080: 1}
59690370 : {650912: 1, 41916954: 1, 58969921: 1, 748825: 1, 1979577: 1, 15324748: 1, 213739: 1}
35889821 : {13908785: 1, 30609033: 1}


In [254]:
data_splitter = WikipediaSplit(seealso_graph)

# Default clean_type=0: drop labels whose total count is <= 1, then rows left empty.
data_splitter.clean_matrix()
# Labels with a total count below upper_threshold are split label by label; the
# assignment is random and no seed is set here.
data_splitter.get_split_bylabel(upper_threshold=10)

In [255]:
data_splitter.train, data_splitter.test

(<668552x333243 sparse matrix of type '<class 'numpy.int64'>'
 	with 1401886 stored elements in Compressed Sparse Row format>,
 <293023x333243 sparse matrix of type '<class 'numpy.int64'>'
 	with 712681 stored elements in Compressed Sparse Row format>)

In [288]:
data_splitter.save_data(save_dir, tag='seealso')

### Testing

In [258]:
# Sanity check: number of label columns whose total over the train rows is 0.
trn_lbl_cnt = np.array(data_splitter.train.sum(axis=0)).reshape(-1)
trn_lbl_cnt_pos = np.where(trn_lbl_cnt == 0)[0]
len(trn_lbl_cnt_pos)

0

In [259]:
# Sanity check: number of label columns whose total over the test rows is 0.
tst_lbl_cnt = np.array(data_splitter.test.sum(axis=0)).reshape(-1)
tst_lbl_cnt_pos = np.where( tst_lbl_cnt == 0 )[0]
len(tst_lbl_cnt_pos)

0

## Delete test nodes from the graph

In [264]:
save_dir = f'{dataset_home}/combined'

articles_graph = WikiGraphCombine(partitions)
articles_graph.load_graph(save_dir, tag=f'-{project}-{dump_date}', graph_type='articles')

In [266]:
len(articles_graph.parameters['articles_graph'])

6468771

In [269]:
articles_graph.remove_nodes_from_articles(data_splitter.test_doc_to_rowindex)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 6468771/6468771 [00:21<00:00, 299493.30it/s]


In [270]:
articles_graph.save_graph(save_dir, tag=f'-{project}-{dump_date}', graph_type='articles')

## XC dataset format

In [368]:
class XCDataset:
    """Write documents, labels and sparse matrices as plain-text XC dataset files."""

    def __init__(self, id_to_title, abstract):
        """Keep the id -> title map and the title -> abstract map used by ``save_text``."""
        self.id_to_title = id_to_title
        self.abstract = abstract

    def save_text(self, save_dir, doc_to_rowindex, tag='train'):
        """Write ``{tag}_map.txt`` and ``{tag}_raw_text.txt`` into ``save_dir``.

        Both files have one ``<id>-><text>`` line per key of ``doc_to_rowindex``,
        ordered by the key's index. The map file carries the title, the raw-text file
        the abstract looked up by title (empty when the title has no abstract).

        Raises:
            KeyError: if an id is missing from ``id_to_title``.
        """
        os.makedirs(save_dir, exist_ok=True)
        ordered_ids = sorted(doc_to_rowindex, key=lambda x: doc_to_rowindex[x])

        # Titles and abstracts are Wikipedia text; write UTF-8 explicitly instead of
        # depending on the locale's default encoding.
        map_file = f'{save_dir}/{tag}_map.txt'
        with open(map_file, 'w', encoding='utf-8') as fout:
            for doc_id in ordered_ids:
                fout.write(f'{doc_id}->{self.id_to_title[doc_id]}\n')

        raw_file = f'{save_dir}/{tag}_raw_text.txt'
        with open(raw_file, 'w', encoding='utf-8') as fout:
            for doc_id in ordered_ids:
                title = self.id_to_title[doc_id]
                if title in self.abstract:
                    fout.write(f'{doc_id}->{self.abstract[title]}\n')
                else:
                    # Keep the line so both files stay aligned line by line.
                    fout.write(f'{doc_id}->\n')

    def save_sparse_file(self, matrix, save_file):
        """Write a sparse matrix as text: a ``rows cols`` header, then one line per row.

        Each row line lists its stored entries as space-separated ``col:value``
        pairs in increasing column order. ``matrix`` has its indices sorted in place.

        Raises:
            Exception: ("Row is missing.") if any row has no stored entry, since the
                output has no way to represent an empty row.
        """
        matrix.sort_indices()
        # COO view: row, column and value arrays of equal length and matching order.
        coo = matrix.tocoo()

        with open(save_file, 'w') as fout:
            fout.write(f'{matrix.shape[0]} {matrix.shape[1]}\n')

            row_ctr = -1
            entries = []
            for r, c, d in zip(coo.row, coo.col, coo.data):
                if r != row_ctr:
                    if r != row_ctr + 1:
                        raise Exception("Row is missing.")
                    if entries:
                        fout.write(' '.join(entries) + '\n')
                    entries = []
                    row_ctr = r
                entries.append(f'{c}:{d}')

            # Flush the final row; without this the last row never reaches the file.
            if entries:
                fout.write(' '.join(entries) + '\n')

            # Empty rows at the end would otherwise be dropped silently although the
            # header announces them.
            if row_ctr != matrix.shape[0] - 1:
                raise Exception("Row is missing.")
                    

In [369]:
# Directory holding the combined pickles that are loaded below.
save_dir = f'{dataset_home}/combined'
# Output directory for the XC-format files.
xc_dir = f'{dataset_home}/XCData'

# Reload the combined id -> title map.
combined_graph = WikiGraphCombine()
combined_graph.load_idtotitle(save_dir, tag=f'-{project}-{dump_date}')

# Reload the abstracts from the same directory and tag.
abstract = WikiAbstract()
abstract.load_abstract(save_dir, tag=f'-{project}-{dump_date}')

In [370]:
# Use the id -> title map loaded into `combined_graph` in the preceding cell. The
# original read it from `combine_graph`, an older variable whose map is only filled
# if the combine cells ran earlier in the same kernel session.
xc_data = XCDataset(combined_graph.parameters['id_to_title'], abstract.abstract)

In [371]:
data_splitter = WikipediaSplit()
data_splitter.load_data(save_dir, tag='seealso')

In [372]:
data_splitter.train, data_splitter.test

(<668552x333243 sparse matrix of type '<class 'numpy.int64'>'
 	with 1401886 stored elements in Compressed Sparse Row format>,
 <293023x333243 sparse matrix of type '<class 'numpy.int64'>'
 	with 712681 stored elements in Compressed Sparse Row format>)

In [373]:
xc_data.save_text(xc_dir, data_splitter.train_doc_to_rowindex, tag='train')
xc_data.save_text(xc_dir, data_splitter.test_doc_to_rowindex, tag='test')
xc_data.save_text(xc_dir, data_splitter.trn_tst_labels, tag='label')

In [377]:
train_file = f'{xc_dir}/trn_X_Y.txt'
xc_data.save_sparse_file(data_splitter.train, train_file)

In [378]:
test_file = f'{xc_dir}/tst_X_Y.txt'
xc_data.save_sparse_file(data_splitter.test, test_file)

In [385]:
combined_graph.load_graph(save_dir, tag=f'-{project}-{dump_date}', graph_type='articles')

In [388]:
graph = WikipediaSplit(combined_graph.parameters['articles_graph'])

In [389]:
graph.graph

<6176501x5471236 sparse matrix of type '<class 'numpy.int64'>'
	with 131076968 stored elements in Compressed Sparse Row format>

In [None]:
xc_data.save_text(xc_dir, graph.doc_to_rowindex , tag='graph_train')
xc_data.save_text(xc_dir, graph.labels, tag='graph_label')