# Plotting data from SQLite database

This notebook queries the SQLite database to generate a matrix of plots of the number of articles/images per year, run separately for each category.

This code reproduces the top-16 plot of the average number of images per article by year.

For more on plots, see https://github.com/re-imaging/re-imaging/blob/master/sqlite-scripts/db_plots.ipynb

## Structure

- setup
- load list of categories
- pull specific data (and save as pickle)
- format data
- generate plot
- save image

Notebook is intended to be navigated and blocks to be run selectively, rather than the whole notebook being executed.

## Setup

Import required libraries, connect to SQLite database, create cursor, fetch table info

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
import numpy as np
import sqlite3
import pickle
import copy
import json
import math
import pandas as pd
import os

In [None]:
# Open the SQLite database and create the cursor `c` used by every query below.
# With default settings, sqlite3.connect() creates a new, empty database file when the
# path does not exist. That would only surface later as confusing "no such table"
# errors, so fail fast here instead.
db_path = os.path.expanduser("~/data/db/arxiv_db_images.sqlite3")
if not os.path.isfile(db_path):
    raise FileNotFoundError(f"SQLite database not found: {db_path}")
db = sqlite3.connect(db_path)
c = db.cursor()

In [None]:
# Print the schema of the `metadata` table (one tuple per column) as a quick sanity check.
table_name = "metadata"
c.execute(f"PRAGMA TABLE_INFO({table_name})")
info = c.fetchall()

print("\nColumn Info:\nID, Name, Type, NotNull, DefaultVal, PrimaryKey")
for column_info in info:
    print(column_info)

## Build category lists

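First get a full list of all the primary categories by querying the SQLite database; these lists drive the later queries. Three orderings are pulled (by article count, by image count, alphabetical). Choose which one is assigned to `catlist` before running the per-category queries.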

In [None]:
# Build three lists of (primary_category, count) tuples. The primary category is the
# first space-separated token of metadata.cat.
#   articles_catlist -- ordered by number of articles, most first
#   images_catlist   -- ordered by number of images, most first
#   alpha_catlist    -- alphabetical, counting only articles not created in 2019/2020
# NOTE(review): the LEFT JOIN in the images query would group images that have no
# metadata row under a NULL category -- confirm whether that can happen.

SQL_CATS_BY_ARTICLES = '''
    SELECT substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1), count(metadata.identifier)
    FROM metadata
    GROUP BY substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1)
    ORDER BY count(metadata.identifier) DESC    
    '''

SQL_CATS_BY_IMAGES = '''
    SELECT substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1), count(images.identifier)
    FROM images
    LEFT JOIN metadata ON images.identifier = metadata.identifier 
    GROUP BY substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1)
    ORDER BY count(images.identifier) DESC    
'''

SQL_CATS_ALPHA = '''
    SELECT substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1), count(metadata.identifier)
    FROM metadata
    WHERE strftime("%Y", metadata.created) != '2019'
    AND strftime("%Y", metadata.created) != '2020'
    GROUP BY substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1)
    ORDER BY substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1) ASC
'''


def _fetch_category_counts(message, query):
    """Print `message`, run `query` on the shared cursor `c` and return every result row."""
    print(message)
    c.execute(query)
    return c.fetchall()


articles_catlist = _fetch_category_counts(
    "pulling all categories by total number of articles", SQL_CATS_BY_ARTICLES)
images_catlist = _fetch_category_counts(
    "pulling all categories by total number of images", SQL_CATS_BY_IMAGES)
alpha_catlist = _fetch_category_counts(
    "pulling all categories in alphabetical order", SQL_CATS_ALPHA)

# `catlist` drives the per-category queries further down; alphabetical by default.
catlist = alpha_catlist
# to order categories by article count instead, use:
# catlist = articles_catlist
print("done")

In [None]:
# Preview how many primary categories were found and the first few entries.
num_categories = len(catlist)
print(f'The list of categories totals {num_categories}. Here are the first entries:')
print(catlist[:8])

## Finding the change in rank between number of articles and number of images

In [None]:
# One row per category, in article-count order:
#   [category, article-rank, images-rank, rank-difference]
# images-rank and rank-difference are placeholders filled in by the next two cells.
ordering = []
for article_rank, cat_row in enumerate(articles_catlist):
    ordering.append([cat_row[0], article_rank, 0, 0])
print(ordering)

In [None]:
# Fill in each category's rank by image count (column 2 of `ordering`).
# A dict lookup replaces the nested scan over both lists. When a name appears twice,
# the later rank wins, exactly as the old scan behaved.
# A category missing from images_catlist keeps its placeholder rank of 0. A warning is
# printed because 0 would otherwise look like "ranked first".
image_rank = {}
for rank, cat_row in enumerate(images_catlist):
    image_rank[cat_row[0]] = rank

for row_idx, cat_row in enumerate(articles_catlist):
    category = cat_row[0]
    if category in image_rank:
        ordering[row_idx][2] = image_rank[category]
    else:
        print(f'category {category} not found in images_catlist; images-rank left at 0')

In [None]:
# rank-difference = article-rank - images-rank. Rank 0 is the largest count, so a
# positive value means the category ranks higher by images than by articles.
for entry in ordering:
    entry[3] = entry[1] - entry[2]

In [None]:
# Show every row as [category, article-rank, images-rank, change]
for entry in ordering:
    print(entry)

## Generating Plots

### Article data

Then use that list of primary categories to query the database for how many articles were published per year. Store it in the `article_data` variable.

In [None]:
# by category
# Number of articles for each year (2019 and 2020 are excluded in the SQL).
# Produces one row per category: [category, [years...], [article counts...]]

sql = ('''
    SELECT count(metadata.identifier), strftime("%Y", metadata.created) as 'Y'
    FROM metadata
    WHERE substr(trim(cat),1,instr(trim(cat)||' ',' ')-1) = ?
    AND strftime("%Y", metadata.created) != '2019'
    AND strftime("%Y", metadata.created) != '2020'
    GROUP BY strftime("%Y", metadata.created)
    ORDER BY strftime("%Y", metadata.created) ASC
    ''')

data = []

for cat in catlist:
    category = cat[0]
    print("querying for category: " + str(category))
    rows = c.execute(sql, (category, )).fetchall()
    print(rows)
    # each result row is (count, year)
    years = [year for _n, year in rows]
    totals = [n for n, _year in rows]
    data.append([category, years, totals])

print("*" * 20)
print("done")

article_data = data

# WRITE PKL -- cache the result so later sessions can reload it instead of re-querying
filename = "articles_cat_year_data.pkl"
with open(filename, "wb") as pkl_file:
    pickle.dump(article_data, pkl_file)

In [None]:
# Dump the freshly queried per-category article counts. Right after the query cell,
# `data` is the same list as `article_data`.
print(data)

In [None]:
# READ PKL -- reload the cached article counts so the SQL queries need not be re-run
filename = "articles_cat_year_data.pkl"
with open(filename, mode="rb") as pkl_file:
    article_data = pickle.load(pkl_file)

In [None]:
# Report how many category rows were loaded, then print all of them.
n_article_rows = len(article_data)
print(f'length of article_data: {n_article_rows}')
print(article_data)

### Image data

Pull the number of images in each year for each category. Store in `image_data`.

In [None]:
# by category
# Total number of images for each year (2019 and 2020 are excluded in the SQL).
# Produces one row per category: [category, [years...], [image counts...]]

# Single definition of the pickle filename, so the writer below and the reader cell agree.
image_pkl_filename = "images_cat_year_data.pkl"

sql = ('''
    SELECT count(images.identifier), strftime("%Y", metadata.created) as 'Y'
    FROM images
    LEFT JOIN metadata on images.identifier = metadata.identifier
    WHERE substr(trim(cat),1,instr(trim(cat)||' ',' ')-1) = ?
    AND strftime("%Y", metadata.created) != '2019'
    AND strftime("%Y", metadata.created) != '2020'
    GROUP BY strftime("%Y", metadata.created)
    ORDER BY strftime("%Y", metadata.created) ASC
    ''')

data = []

for cat in catlist:
    print("querying for category: " + str(cat[0]))
    c.execute(sql, (cat[0], ))
    rows = c.fetchall()
    print(rows)

    # each result row is (count, year)
    years = [row[1] for row in rows]
    totals = [row[0] for row in rows]

    data.append([cat[0], years, totals])

print("*" * 20)
print("done")

image_data = data

# WRITE
with open(image_pkl_filename, "wb") as write_file:
    pickle.dump(image_data, write_file)

In [None]:
# Number of category rows in the image data. Right after the query cell this equals
# len(catlist), because one row is appended per category.
print(len(image_data))

In [None]:
# READ PKL -- reload the cached image counts so the SQL queries need not be re-run
image_pkl_filename = "images_cat_year_data.pkl"
with open(image_pkl_filename, mode="rb") as pkl_file:
    image_data = pickle.load(pkl_file)

In [None]:
# Report the size of the image data and preview the first ten category rows.
n_image_rows = len(image_data)
print(f'length of image_data: {n_image_rows}')
print(image_data[:10])

### Get averages (images per article)

In [None]:
# Build `average_data` with the same shape as `article_data` ([category, years, values]).
# Each value is (number of images) / (number of articles) for that category-year.
# A year that has articles but no images keeps the value 0.
# Image rows are matched to article rows by category name, not by list position, so
# the result stays correct even if the two pickles came from differently ordered catlists.

images_by_cat = {img_row[0]: img_row for img_row in image_data}

average_data = copy.deepcopy(article_data)

for avg_row in average_data:
    category, years, article_totals = avg_row[0], avg_row[1], avg_row[2]
    averages = [0] * len(years)

    image_row = images_by_cat.get(category)
    if image_row is None:
        print(f'no image data for category {category}; averages left at 0')
    else:
        # years may be stored as str (straight from SQLite) or int; compare as str
        year_index = {str(year): pos for pos, year in enumerate(years)}
        for year, n_images in zip(image_row[1], image_row[2]):
            pos = year_index.get(str(year))
            if pos is None:
                print(f"didn't find index for year {year} in category {category}")
                continue
            averages[pos] = n_images / article_totals[pos]

    avg_row[2] = averages

In [None]:
# Preview the first ten rows of [category, years, images-per-article averages].
print(average_data[:10])

In [None]:
# copy data over for plotting
data = copy.deepcopy(average_data)

## Clean data
- remove any entries for 2019 and 2020 from the years and totals columns of data (we don't have full data for these years).
- rewrite all entries as integers rather than strings (otherwise there will be problems when adjusting the axes)
- find the minimum and maximum for any entries, so that we can set our axes later as needed.

Data is saved as nested lists in the format
```
[
    [cat1, [year1, year2, ..., yearN], [total1, total2, ..., totalN]],
    [cat2, [year1, year2, ..., yearN], [total1, total2, ..., totalN]],
    ...
    [catZ, [year1, year2, ..., yearN], [total1, total2, ..., totalN]]
]
```

### Clean year entries - make sure to run this! (later cells compare years as integers)

In [None]:
# SQLite's strftime() returns years as strings. Convert each one to int in place so the
# axis limits and year comparisons in later cells work numerically.
for cat in data:
    year_list = cat[1]
    for pos, year in enumerate(year_list):
        year_list[pos] = int(year)

#### Remove years that aren't being used in plots

In [None]:
# Drop incomplete years from every category row, removing each year together with its
# matching value. Years are compared via int(), so this works whether or not the
# "Clean year entries" cell has already run. Otherwise a string year such as '2019'
# would never equal the int 2019 and would silently survive.

# list of years to remove
years_to_remove = [2020, 2019]

for i, data_row in enumerate(data):
    years, values = data_row[1], data_row[2]
    if len(years) != len(values):
        raise ValueError(f'row {i} ({data_row[0]}): {len(years)} years but {len(values)} values')

    keep = [int(year) not in years_to_remove for year in years]
    n_removed = keep.count(False)
    if n_removed:
        print(f'removed {n_removed} entries from row {i} ({data_row[0]})')

    # slice assignment keeps the same list objects, i.e. the rows are still edited in place
    years[:] = [year for year, k in zip(years, keep) if k]
    values[:] = [value for value, k in zip(values, keep) if k]

In [None]:
# test to make sure there is still a total for each year.
# cat[0] is the category name. Concatenating the whole row (a list) onto a str would
# raise TypeError exactly when a problem is found.
for cat in data:
    if len(cat[1]) != len(cat[2]):
        print("problem with category: " + str(cat[0]))

In [None]:
# Verify that no 2019/2020 entries survived the removal step (expects int years).
excluded_years = (2019, 2020)
leftover = [year for row in data for year in row[1] if year in excluded_years]
for year in leftover:
    print(f'found year {year}')
count = len(leftover)
print("-" * 20)
print(f'found a total of {count} entries with year 2019 or 2020')

### Save data 
Interim progress, to prevent having to run SQL queries again : )
Save as either json file or pickle for reloading.
First set filename

In [None]:
# Choose the format used to save/reload the intermediate data: "pkl" or "json".
save_mode = "pkl"
# save_mode = "json"

# Base name of the snapshot file; alternatives: "articles_cat_year", "images_cat_year"
base_name = "average_images_article_cat_year"
filename = f"{base_name}.{save_mode}"

print(filename)

#### JSON
saves as human-readable JSON format

In [None]:
# WRITE -- uncomment to save `data` as human-readable JSON under `filename`
# (only meaningful when save_mode == "json")

# with open(filename, "w") as write_file:
#     json.dump(data, write_file)

In [None]:
# READ -- load the JSON snapshot when that format is selected
if save_mode == "json":
    load_data = []
    with open(filename, mode="r") as json_file:
        load_data = json.load(json_file)

#### pickle
Save data as a serialized file using pickle

In [None]:
# READ -- load the pickled snapshot when that format is selected.
# The `with` block closes the file on exit, so no explicit close() is needed.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.

if save_mode == "pkl":
    load_data = []

    with open(filename, "rb") as read_file:
        load_data = pickle.load(read_file)

In [None]:
# WRITE -- uncomment to save `data` as a pickle under `filename`
# (only meaningful when save_mode == "pkl"; the `with` block already closes the file,
# so the explicit close() below is redundant)
# with open(filename, "wb") as write_file:
#     pickle.dump(data, write_file)
#     write_file.close()

#### Check data

In [None]:
# Adopt the reloaded snapshot as the working data and preview its first ten rows.
data = load_data
print(data[:10])

## Generate plots

Load in the imported data in the `data` variable

In [None]:
# Show the first ten category rows, one per line.
for entry in data[:10]:
    print(entry)

### Load data

#### For average number of images per article by year top-16 plot

Take the first 16 categories from `images_catlist`, i.e. the top 16 by image count.

In [None]:
# Keep only the rows of `load_data` for the 16 categories with the most images, in
# image-count order. Loop variables deliberately avoid the name `c`, which holds the
# SQLite cursor used by the query cells.
selected_cats = [cat_row[0] for cat_row in images_catlist[:16]]

# group rows by category name (keeps duplicates, in load_data order)
rows_by_cat = {}
for load_row in load_data:
    rows_by_cat.setdefault(load_row[0], []).append(load_row)

data = []
for cat_name in selected_cats:
    matches = rows_by_cat.get(cat_name, [])
    if not matches:
        # panel titles are indexed by position, so a gap here misaligns them
        print(f'warning: no data found for category {cat_name}')
    data.extend(matches)

print(len(selected_cats))
print(selected_cats)

In [None]:
# full names for giving plot titles
selected_cats_full = ["Computer Science: Computer Vision",
                      "High Energy Physics - Phenomenology",
                      "Astrophysics",
                      "Astrophysics of Galaxies",
                      "Computer Science: Machine Learning",
                      "Solar and Stellar Astrophysics",
                      "Cosmology and Nongalactic Astrophysics",
                      "Quantum Physics",
                      "High Energy Physics - Theory",
                      "High Energy Astrophysical Phenomena",
                      "Mesoscale and Nanoscale Physics",
                      "Strongly Correlated Electrons",
                      "Mathematics: Numerical Analysis",
                      "High Energy Physics - Experiment",
                      "General Relativity", #  and Quantum Cosmology
                      "Condensed Matter: Statistical Mechanics"
                     ]

Or select specific categories to plot

In [None]:
# Alternative to the top-16 cell: uncomment to plot a hand-picked list of categories.
# If used, `selected_cats_full` must also be rewritten in this order so titles match.
# selected_cats = ["hep-ph", "astro-ph", "cs.CV", "astro-ph.GA", "astro-ph.CO", "astro-ph.SR",
#                 "quant-ph", "hep-th", "astro-ph.HE", "cond-mat.mes-hall", "cond-mat.str-el",
#                 "hep-ex", "cond-mat.stat-mech", "nucl-th", "gr-qc", "cs.LG"]

# data = []
# for c in selected_cats:
#     for d in load_data:
#         if c == d[0]:
#             data.append(d)
# print(len(data))
# print(selected_cats)

In [None]:
# Show every selected category row.
for entry in data:
    print(entry)

#### Find max and min

Go through each value in the data to find the maximum and minimums for plotting

In [None]:
# Global extremes across all selected categories, used to set shared axis limits:
#   Y = year (x-axis), A = articles / images / average value (y-axis)
minY, maxY = math.inf, -math.inf
minA, maxA = math.inf, -math.inf

for cat in data:
    cat_years, cat_values = cat[1], cat[2]
    print(cat[0])
    lo_year, hi_year = min(cat_years), max(cat_years)
    lo_val, hi_val = min(cat_values), max(cat_values)
    minY = min(minY, lo_year)
    maxY = max(maxY, hi_year)
    minA = min(minA, lo_val)
    maxA = max(maxA, hi_val)
    print("min year: " + str(lo_year))
    print("max year: " + str(hi_year))
    print("min articles/images: " + str(lo_val))
    print("max articles/images: " + str(hi_val))
    print("*" * 20)

for label, value in (("minY", minY), ("maxY", maxY), ("minA", minA), ("maxA", maxA)):
    print(label + ": " + str(value))

print("done")

#### Save data in org format
Use an org-mode table format. This can be printed to the console or written to a file, e.g. for posting to GitHub, which renders org files.

In [None]:
# Output path for the org-mode tables; swap in the commented name when exporting image counts.
# filename = "stats_images_cat_year.org"
filename = "stats_average_images_article_cat_year.org"

In [None]:
# Write the data in an org-friendly format for posting on GitHub: one "* category"
# heading per row, followed by a two-column |year|value| table framed by |-|-| rules.
# The `with` block closes the file itself.
with open(filename, "w") as write_file:
    for cat in data:
        print(f"* {cat[0]}", file=write_file)
        print("|-|-|", file=write_file)
        for year, value in zip(cat[1], cat[2]):
            print(f"|{year}|{value}|", file=write_file)
        print("|-|-|", file=write_file)

#### Plotting matrix of scatterplots

Plot data in either of two formats:
- with shared x and y axes, for comparison across data
- with individual x and y axes taken automatically from the min/max of each plot, for individual trends

Finally, save the figure as a high-resolution (300 dpi) image.

In [None]:
# Plot option flags. The plotting cell below re-assigns bLog10 and bArticles itself.
#   bArticles -- True: default-colour '--.' line style; False: black '--k.' style
#   bAverage  -- NOTE(review): not read by any cell visible in this notebook section
#   bLog10    -- True: log-scale y-axis
bArticles = True
# bArticles = False
bAverage = True
# bAverage = False
bLog10 = False
# bLog10 = True

In [None]:
# manually set the maximum for the Y-axis to ignore large outliers.
# This overrides the data-derived maxA from the min/max cell; the plotting cell passes it to axis().
maxA = 32

### Plot for "Images of the arXiv" paper

#### Average number of images per article by year in each category.

Each panel plots, for one category, the average number of images per article published in each year.

In [None]:
# Plot one panel per category in a ydim x xdim grid, sharing the y-axis within each row.
# The x range is fixed to 1991-2018 and the y range to [0, maxA] (maxA is set manually above).

xdim = 4
ydim = 4

bLog10 = False
bArticles = False

# Look up each panel's image total by category name, so the title always describes the
# plotted row even when `data` was not built from images_catlist[:16] in that order.
image_totals = {cat_row[0]: cat_row[1] for cat_row in images_catlist}

# squeeze=False keeps `ax` two-dimensional even if xdim or ydim is 1
fig, ax = plt.subplots(ydim, xdim, sharey='row', squeeze=False) # sharex='col',
fig.subplots_adjust(hspace=0.5, wspace=0.3)
fig.set_size_inches(16, 12)

# if bArticles: fig.suptitle("arXiv total articles per year between 1991 and 2018\nShared Axes", x=0.5, y=0.92, size=28)
# else: fig.suptitle("arXiv total images per year between 1991 and 2018\nShared Axes", x=0.5, y=0.92, size=28)

data_size = len(data)

# ax.flat walks the grid row by row, so idx == row * xdim + col
for idx, panel in enumerate(ax.flat):
    if idx >= data_size:
        panel.set_visible(False)  # hide leftover empty panels
        continue
    cat_name, years, values = data[idx][0], data[idx][1], data[idx][2]
    panel.plot(years, values, '--.' if bArticles else '--k.')
    full_name = selected_cats_full[idx] if idx < len(selected_cats_full) else cat_name
    panel.title.set_text(f'{full_name}\ntotal: {image_totals.get(cat_name, "n/a")}')
    panel.axis([1991, 2018, 0, maxA])
    if bLog10:
        panel.set_yscale('log')

In [None]:
# Export the figure as vector (SVG) and raster (PNG at 300 dpi), cropped tightly to content.
out_stem = "plot_images_cat_year_indax_shareY_top16_v4"
fig.savefig(out_stem + ".svg", dpi=300, bbox_inches='tight', pad_inches=0, transparent=False)
fig.savefig(out_stem + ".png", dpi=300, bbox_inches='tight', pad_inches=0)