# Plotting data from SQLite database

This notebook queries the databases to generate a matrix of plots for the number of articles/images per year, run separately for each category.

## Setup

Import required libraries, connect to SQLite database, create cursor, fetch table info

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
import numpy as np
import sqlite3
import pickle
import copy
import json
import math
import pandas as pd
import os

In [None]:
# import the sqlite3 database and create a cursor
db_path = os.path.expanduser("~/data/db/arxiv_db_images.sqlite3")
db = sqlite3.connect(db_path)
c = db.cursor()

In [None]:
c.execute('PRAGMA TABLE_INFO({})'.format("metadata"))
info = c.fetchall()

print("\nColumn Info:\nID, Name, Type, NotNull, DefaultVal, PrimaryKey")
for col in info:
    print(col)

In [None]:
c.execute('PRAGMA TABLE_INFO({})'.format("images"))
info = c.fetchall()

print("\nColumn Info:\nID, Name, Type, NotNull, DefaultVal, PrimaryKey")
for col in info:
    print(col)

## Build category list

First, query the SQLite database for the full list of primary categories; the later queries iterate over this list. Select a `sort_mode` before running the query.

In [None]:
sort_mode = "articles"
# sort_mode = "images"
# sort_mode = "alpha"

In [None]:
# list primary categories, ordered according to sort_mode
#
# The primary category is the first space-separated token of metadata.cat.
# Every branch produces rows of (primary_category, count).

if sort_mode == "articles":
    # ordered by number of articles, largest first
    c.execute('''
        SELECT substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1), count(metadata.identifier)
        FROM metadata
        GROUP BY substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1)
        ORDER BY count(metadata.identifier) DESC
        ''')
elif sort_mode == "images":
    # ordered by number of images, largest first
    c.execute('''
    SELECT substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1), count(images.identifier)
    FROM images
    LEFT JOIN metadata ON images.identifier = metadata.identifier
    GROUP BY substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1)
    ORDER BY count(images.identifier) DESC
    ''')
elif sort_mode == "alpha":
    # ordered alphabetically; articles created in 2019 and 2020 are not counted
    c.execute('''
    SELECT substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1), count(metadata.identifier)
    FROM metadata
    WHERE strftime("%Y", metadata.created) != '2019'
    AND strftime("%Y", metadata.created) != '2020'
    GROUP BY substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1)
    ORDER BY substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1) ASC
    ''')
else:
    raise ValueError(f'unknown sort_mode: {sort_mode!r}')

# Fetch the result for every mode, so that `rows` always belongs to the query
# that was just executed and is never undefined or left over from an earlier cell.
rows = c.fetchall()

print(f'Pulled a list of categories sorted by {sort_mode}, length: {len(rows)}')
for row in rows:
    print(row)

In [None]:
# `catlist` is the working category list for the queries below. The rows are
# additionally remembered under a mode-specific name, because the rank
# comparison needs both the article ordering and the image ordering.
catlist = rows if sort_mode in ("articles", "images", "alpha") else []

if sort_mode == "articles":
    articles_catlist = rows
elif sort_mode == "images":
    images_catlist = rows

In [None]:
# optionally keep only the categories whose count (articles or images,
# depending on the query that produced `rows`) is above 10000
articles_list = [(cat, n) for cat, n in rows if n > 10000]
print(len(articles_list))

In [None]:
# or grab only the first X categories
catlist[:16]

In [None]:
print(f'The list of categories is {len(catlist)}. Here are the first entries:')
print(catlist[:8])

## Finding the change in rank between number of articles and number of images

In [None]:
# structure
# [category, article-rank, images-rank, rank-difference]

ordering = [[cat[0], count, 0, 0] for count, cat in enumerate(articles_catlist)]
print(ordering)

In [None]:
# find the difference in category list when ordering by article numbers vs image numbers
#
# NOTE(review): articles_catlist and images_catlist are each assigned under a
# different sort_mode, so the category cells must have been run once per mode
# before this cell - confirm.

# Position of every category in the image ranking, built once so that the
# image list is not rescanned for each article category. If a category occurs
# more than once, its last position is kept.
image_rank = {ic[0]: icount for icount, ic in enumerate(images_catlist)}

for acount, ac in enumerate(articles_catlist):
    # a category without any images keeps the placeholder 0 it was given
    # when `ordering` was built
    if ac[0] in image_rank:
        ordering[acount][2] = image_rank[ac[0]]

In [None]:
# go through and modify the rank-difference
for count, row in enumerate(ordering):
    ordering[count][3] = row[1] - row[2]

In [None]:
#print out (could add formatting here)
for cat in ordering:
    print(cat)

## Generating Plots
### Query DB for data
Then use that list of primary categories to query the db for how many articles per year. Store it in the `data` variable. The first command is the number of articles only so only requires the `metadata` table. The second block below searches by number of images so also requires the `images` table in the SQLite database.

In [None]:
# number of articles per year, one query per primary category

sql = ('''
    SELECT count(metadata.identifier), strftime("%Y", metadata.created) as 'Y'
    FROM metadata
    WHERE substr(trim(cat),1,instr(trim(cat)||' ',' ')-1) = ?
    GROUP BY strftime("%Y", metadata.created)
    ORDER BY strftime("%Y", metadata.created) ASC
    ''')

data = []

for entry in catlist:
    category = entry[0]
    print("querying for category: " + str(category))
    c.execute(sql, (category, ))
    rows = c.fetchall()
    print(rows)

    # split the (count, year) rows into two parallel lists
    years = [year for _, year in rows]
    totals = [total for total, _ in rows]

    data.append([category, years, totals])

print("*" * 20)
print("done")

In [None]:
# total number of articles for each year by category
#
# NOTE: this cell repeats the article query of the preceding cell verbatim;
# running either one fills `data` with the same structure.

# One bound parameter: the primary category (first token of metadata.cat).
# Each result row is (article_count, year), with the year as a 4-digit string.
sql = ('''
    SELECT count(metadata.identifier), strftime("%Y", metadata.created) as 'Y'
    FROM metadata
    WHERE substr(trim(cat),1,instr(trim(cat)||' ',' ')-1) = ?
    GROUP BY strftime("%Y", metadata.created)
    ORDER BY strftime("%Y", metadata.created) ASC
    ''')

# one [category, years, totals] entry per category, in catlist order
data = []

for cat in catlist:
    print("querying for category: " + str(cat[0]))
    c.execute(sql, (cat[0], ))
    rows = c.fetchall()
    
#     print("total number of images found: " + str(len(rows)))
    print(rows)
#     print("total number of articles: " + rows[0][0])

    # parallel lists: years[i] is the year that totals[i] was counted for
    years = []
    totals = []
    
    for row in rows:
        years.append(row[1])
        totals.append(row[0])
        
    newdata = [cat[0], years, totals]
    data.append(newdata)

print("*" * 20)
print("done")

In [None]:
print(data)

In [None]:
article_data = data

In [None]:
# WRITE PKL

filename = "articles_cat_year_data.pkl"
with open(filename, "wb") as write_file:
    pickle.dump(article_data, write_file)

In [None]:
# READ PKL

filename = "articles_cat_year_data.pkl"
with open(filename, "rb") as read_file:
    article_data = pickle.load(read_file)

In [None]:
print(article_data[:10])

In [None]:
# number of images per year, one query per primary category
# (the year is the creation year of the article an image belongs to)

sql = ('''
    SELECT count(images.identifier), strftime("%Y", metadata.created) as 'Y'
    FROM images
    LEFT JOIN metadata on images.identifier = metadata.identifier
    WHERE substr(trim(cat),1,instr(trim(cat)||' ',' ')-1) = ?
    GROUP BY strftime("%Y", metadata.created)
    ORDER BY strftime("%Y", metadata.created) ASC
    ''')

data = []

for entry in catlist:
    category = entry[0]
    print("querying for category: " + str(category))
    c.execute(sql, (category, ))
    rows = c.fetchall()
    print(rows)

    # transpose the (count, year) rows into a totals list and a years list
    totals = [row[0] for row in rows]
    years = [row[1] for row in rows]

    data.append([category, years, totals])

print("*" * 20)
print("done")

In [None]:
# get the image totals per year and category and store them in `image_data`
#
# The query is the same as in the preceding image cell, but the result is kept
# in `image_data` so that `data` is left untouched.

# One bound parameter: the primary category. Rows are (image_count, year).
sql = ('''
    SELECT count(images.identifier), strftime("%Y", metadata.created) as 'Y'
    FROM images
    LEFT JOIN metadata on images.identifier = metadata.identifier
    WHERE substr(trim(cat),1,instr(trim(cat)||' ',' ')-1) = ?
    GROUP BY strftime("%Y", metadata.created)
    ORDER BY strftime("%Y", metadata.created) ASC
    ''')

# one [category, years, totals] entry per category, in catlist order
image_data = []

for count, cat in enumerate(catlist):
    print("querying for category: " + str(cat[0]))
    c.execute(sql, (cat[0], ))
    rows = c.fetchall()
    
#     print("total number of images found: " + str(len(rows)))
    print(rows)
#     print("total number of articles: " + rows[0][0])

    years = []
    totals = []
    
    for row in rows:
        years.append(row[1])
        totals.append(row[0])
        
    newdata = [cat[0], years, totals]
    image_data.append(newdata)

    # shows the entry at the same position in `data` for comparison;
    # requires `data` to hold at least as many entries as catlist
    print(data[count])
#     for i, d in enumerate(data[count][2]):
#         data[count][2][i] = rows[i][0] / data[count][2][i]
#     print("new data",data[count])


print("*" * 20)
print("done")

In [None]:
print(len(image_data))

In [None]:
image_pkl_filename = "images_cat_year_data.pkl"
# image_pkl_filename = "images_cat_year.pkl"

In [None]:
# WRITE
# with open(image_pkl_filename, "wb") as write_file:
#     pickle.dump(image_data, write_file)

In [None]:
# READ PKL
with open(image_pkl_filename, "rb") as read_file:
    image_data = pickle.load(read_file)

In [None]:
# READ PKL

filename = "articles_cat_year_data.pkl"
with open(filename, "rb") as read_file:
    article_data = pickle.load(read_file)

In [None]:
print(image_data[:10])

In [None]:
print(article_data[:10])

In [None]:
# this code copies the article data then modifies it by 
# dividing the total number of images by the number of articles
#
# Result: `average_data` has the same [category, years, values] layout as
# `article_data`, with every value replaced by images / articles for that
# year. Years that have articles but no images stay at 0.
#
# NOTE(review): entries are paired by position (image_data[count] with
# article_data[count]); this assumes both lists hold the same categories in
# the same order - confirm when the lists are loaded from separate pickle files.

# deep copy, so that resetting the values does not alter article_data
average_data = copy.deepcopy(article_data)

for count, article in enumerate(average_data):
    
    # set all values to zero
    for index, val in enumerate(average_data[count][2]):
        average_data[count][2][index] = 0

#     average_data[count][2]
    for i, year in enumerate(image_data[count][1]):
        try:
            print("*" * 20)
            print(f'count {count} | year {year} | i {i}')
            print(article_data[count][1])
            # position of this year in the article lists; raises ValueError
            # when the year has images but no articles
            listindex = article_data[count][1].index(str(year)) 
            print("listindex:",listindex)
            print("count:",count)
            print("i:",i)
            print("no images:",image_data[count][2][i])
            print("no articles:",article_data[count][2][listindex])
            print("average:",image_data[count][2][i] / article_data[count][2][listindex])
            average_data[count][2][listindex] = image_data[count][2][i] / article_data[count][2][listindex]
        except ValueError:
            # the year is missing from the article data; its value is left as is
            print("!" * 20)
            print("didn't find index")

In [None]:
print(average_data[:10])

In [None]:
data = average_data

In [None]:
for row in data:
    print(row)

In [None]:
# total number of images per file extension, one query per primary category

# SQLite has no built-in reverse(); register one so that the extension can be
# taken as the text after the last '.' of the filename. A NULL value arrives
# as None and is passed through instead of raising inside the callback.
db.create_function("reverse", 1, lambda s: None if s is None else s[::-1])

# The statement needs exactly one '?' placeholder because the loop below binds
# one category per execution; a statement without a placeholder rejects the
# binding with sqlite3.ProgrammingError.
# NOTE(review): the category filter (join on metadata) was added to match the
# per-category loop - confirm that per-category counts are what is wanted.
sql = ('''
    SELECT COUNT(reverse(substr(reverse(images.filename),1,instr(reverse(images.filename),'.')-1))), reverse(substr(reverse(images.filename),1,instr(reverse(images.filename),'.')-1)) AS extension
    FROM images
    LEFT JOIN metadata ON images.identifier = metadata.identifier
    WHERE images.x is not null and images.x != ''
    AND images.y is not null and images.y != ''
    AND images.imageformat is not null and images.imageformat != ''
    AND substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1) = ?
    GROUP BY extension
    ''')

# one [category, extensions, totals] entry per category, in catlist order
data = []

for cat in catlist:
    print("querying for category: " + str(cat[0]))
    c.execute(sql, (cat[0], ))
    rows = c.fetchall()
    print(rows)

    # rows are (count, extension); split them into two parallel lists
    extensions = [row[1] for row in rows]
    totals = [row[0] for row in rows]

    data.append([cat[0], extensions, totals])

print("*" * 20)
print("done")

#### Clean data
- remove any entries of "2019" in the years and articles columns of data (don't have full data for this year).
- rewrite all entries as integers rather than strings (otherwise there will be problems when adjusting the axes)
- find the minimum and maximum for any entries, so that we can set our axes later as needed.

Data is saved as nested lists in the format
```
[
    [cat1, [year1, year2...yearX], [total1, total2...totalX]],
    [cat2, [year1, year2...yearX], [total1, total2...totalX]],
    ...
    [catZ, [year1, year2...yearX], [total1, total2...totalX]]
]
```

In [None]:
# drop every "2019" entry from the year list together with the total stored
# at the same position (the data for that year is incomplete)

for entry in data:
    name, years, totals = entry[0], entry[1], entry[2]
    while "2019" in years:
        position = years.index("2019")
        print(name)
        print(position)
        # remove the total first, then the year, so both lists stay aligned
        totals.pop(position)
        years.pop(position)
        print("*" * 20)

In [None]:
# test to make sure there is still a total for each year

for cat in data:
    if len(cat[1]) != len(cat[2]):
        # cat is a [name, years, totals] list, so report only its name;
        # concatenating the list itself to a string raises TypeError
        print("problem with category: " + str(cat[0]))

In [None]:
# the years come back from SQLite as strings; convert them to integers
# in place so that the axes can be set numerically later

for entry in data:
    years = entry[1]
    for position, year in enumerate(years):
        print(year)
        years[position] = int(year)

In [None]:
print(data)

#### Save data 
Interim progress, to prevent having to run SQL queries again : )
Save as either json file or pickle for reloading

#### json
saves as human-readable JSON format

In [None]:
import json

In [None]:
filename = "articles_cat_year.json" 

In [None]:
filename = "images_cat_year.json"

In [None]:
filename = "average_images_article_cat_year.json"

In [None]:
# READ

load_json = []

with open(filename, "r") as read_file:
    load_json = json.load(read_file)

In [None]:
# WRITE

# with open(filename, "w") as write_file:
#     json.dump(data, write_file)

In [None]:
print(load_json)

In [None]:
data = load_json

In [None]:
data

#### pickle
Save data as a serialized file using pickle

In [None]:
filename = "articles_cat_year.pickle"
bArticles = True

In [None]:
filename = "images_cat_year.pkl"
bArticles = False

In [None]:
filename = "average_images_article_cat_year.pkl"

In [None]:
#READ

load_data = []

# NOTE: pickle.load can execute arbitrary code; only load files that were
# written by this notebook.
# The `with` block closes the file on exit, so no explicit close() is needed.
with open(filename, "rb") as read_file:
    load_data = pickle.load(read_file)

In [None]:
# WRITE
# with open(filename, "wb") as write_file:
#     pickle.dump(data, write_file)
#     write_file.close()

Load in the imported data in the `data` variable

In [None]:
data = load_data

In [None]:
for row in data[:10]:
    print(row)

Set bArticles 

In [None]:
print(filename)

In [None]:
bArticles = True

In [None]:
bArticles = False

Testing loaded data

In [None]:
print(load_data == data)

In [None]:
print(load_data)

In [None]:
print(data)

In [None]:
selected_cats = ["hep-ph", "astro-ph", "cs.CV", "astro-ph.GA", "astro-ph.CO", "astro-ph.SR",
                "quant-ph", "hep-th", "astro-ph.HE", "cond-mat.mes-hall", "cond-mat.str-el",
                "hep-ex", "cond-mat.stat-mech", "nucl-th", "gr-qc", "cs.LG"]
# data = [x for x in load_data if x[0] in selected_cats]

In [None]:
# names of the first 16 categories in catlist
# (the loop variable must not be called `c`: that name is the database cursor
# used by the query cells, and rebinding it breaks every later c.execute)
selected_cats = [row[0] for row in catlist[:16]]
len(selected_cats)

In [None]:
# pick the loaded entries of the selected categories, in selected_cats order
# (loop variables are named so that they do not overwrite the cursor `c`)
data = []
for selected in selected_cats:
    for entry in load_data:
        if selected == entry[0]:
            data.append(entry)
print(len(data))

In [None]:
for d in data:
    print(d)

In [None]:
selected_cats_full = ["Computer Science: Computer Vision",
                      "High Energy Physics - Phenomenology",
                      "Astrophysics",
                      "Astrophysics of Galaxies",
                      "Computer Science: Machine Learning",
                      "Solar and Stellar Astrophysics",
                      "Cosmology and Nongalactic Astrophysics",
                      "Quantum Physics",
                      "High Energy Physics - Theory",
                      "High Energy Astrophysical Phenomena",
                      "Mesoscale and Nanoscale Physics",
                      "Strongly Correlated Electrons",
                      "Mathematics: Numerical Analysis",
                      "High Energy Physics - Experiment",
                      "General Relativity", #  and Quantum Cosmology
                      "Condensed Matter: Statistical Mechanics"
                     ]

In [None]:
selected_cats_full

In [None]:
catlist[:16]

In [None]:
print(selected_cats)

#### Recalculate with log10

In [None]:
# DON'T DO THIS! just use set_yscale('log') in plot instead

# for each item in totals list, replace with log10(d)

# list comprehension on totals list
# for cat in data:
#     cat[2][:] = [math.log10(x) for x in cat[2]]
#     print(cat[2])

#### Find max and min

Go through each value in the data to find the maximum and minimums for plotting

In [None]:
# overall minimum / maximum of the years (Y) and of the article or image
# totals (A) across all categories, used later to set the plot axes
minY, maxY = math.inf, -(math.inf)
minA, maxA = math.inf, -(math.inf)

for entry in data:
    print(entry[0])

    year_lo, year_hi = min(entry[1]), max(entry[1])
    minY = min(minY, year_lo)
    maxY = max(maxY, year_hi)

    total_lo, total_hi = min(entry[2]), max(entry[2])
    minA = min(minA, total_lo)
    maxA = max(maxA, total_hi)

    print("min year: " + str(year_lo))
    print("max year: " + str(year_hi))
    print("min articles/images: " + str(total_lo))
    print("max articles/images: " + str(total_hi))
    print("*" * 20)

print("minY: " + str(minY))
print("maxY: " + str(maxY))
print("minA: " + str(minA))
print("maxA: " + str(maxA))

print("done")

In [None]:
print(data)

#### Save data in org format
Use an org-friendly table format, which can be printed to the console or written to a file. The output is intended for posting on GitHub, where it is rendered as a table.

In [None]:
def write_to_org(data, _write_file):
    """Write ``data`` to ``_write_file`` as an org-mode table.

    Every row of ``data`` becomes one line of the form ``|item1|item2|...|``.
    Items are converted with ``str``; newlines inside an item are replaced by
    spaces so that each row stays on a single line. An empty row is written
    as a lone ``|``.

    Args:
        data: iterable of rows, each row an iterable of items.
        _write_file: path of the file to create or overwrite.
    """
    # the `with` block closes the file on exit, so no explicit close() is needed
    with open(_write_file, "w") as write_file:
        for row in data:
            cells = "".join(str(item).replace("\n", " ") + "|" for item in row)
            write_file.write("|" + cells + "\n")

In [None]:
# print the data as org-friendly tables for posting on github:
# one heading per category followed by a two-column table of year and total
for entry in data:
    print("* " + entry[0])
    print("|-|-|")
    for year, total in zip(entry[1], entry[2]):
        print(f'|{year!s}|{total!s}|')
    print("|-|-|")

In [None]:
filename = "stats_images_cat_year.org"

In [None]:
filename = "stats_average_images_article_cat_year.org"

In [None]:
# write the data to a file: one org heading per category followed by a
# two-column table of year and total
# (the `with` block closes the file on exit, so no explicit close() is needed)
with open(filename, "w") as write_file:
    for cat in data:
        print("* " + cat[0], file=write_file)
        print("|-|-|", file=write_file)
        for year, total in zip(cat[1], cat[2]):
            print('|' + str(year) + "|" + str(total) + "|", file=write_file)
        print("|-|-|", file=write_file)

#### Plotting matrix of scatterplots

Plot data in two formats
- with shared x and y axes, for comparison across data
- with individual x and y axes taken from min/max of each plot automatically, for individual trends
- finally, save as high resolution (300 dpi) image

In [None]:
bArticles = True

In [None]:
bArticles = False

In [None]:
bAverage = True

In [None]:
# manually set the maximum for the Y-axis to ignore large outliers
maxA = 32

#### Plot number of images per article per year, subplots for each category

In [None]:
# one panel per category on a grid with shared x and y axes; the axis limits
# are the overall min/max of the years and totals found in the cleaning step

xdim = 15
ydim = 12

fig, ax = plt.subplots(ydim, xdim, sharex='col', sharey='row')
fig.set_size_inches(40, 30)

data_size = len(data)

# ax.flat walks the grid row by row, so the n-th panel shows the n-th entry;
# zip stops at whichever of the panels or the entries runs out first
for panel, entry in zip(ax.flat, data):
    panel.plot(entry[1], entry[2], '--k.')
    panel.title.set_text(entry[0])
    # add one to the maximum year for alignment
    panel.axis([minY, maxY+1, minA, maxA])

In [None]:
# fig.savefig("plot_articles_cat_year_04.png", dpi=300)
fig.savefig("plot_images_cat_year_03.png", dpi=300)

In [None]:
# fig.savefig("plot_average_images_articles_cat_year_01.svg", dpi=300)
fig.savefig("plot_average_images_articles_cat_year_max32_01.svg", dpi=300)

In [None]:
# one panel per category, each panel with its own automatically chosen
# x and y range
xdim = 15
ydim = 12

fig, ax = plt.subplots(ydim, xdim)
fig.subplots_adjust(hspace=0.4, wspace=0.4)
fig.set_size_inches(40, 30)

data_size = len(data)

for (i, j), panel in np.ndenumerate(ax):
    idx = (i * xdim) + j
    if idx >= data_size:
        # more panels than categories: leave the remaining panels empty
        continue
    entry = data[idx]
    panel.plot(entry[1], entry[2], '--r.')
    panel.title.set_text(entry[0])

In [None]:
# fig.savefig("plot_articles_cat_year_indax_01.png", dpi=300)
fig.savefig("plot_images_cat_year_indax_03.png", dpi=300)

### Additional plots

- Plot data with shared X axis from 1991-2018 but individual Y axes
- Log10 of Y axis
- Plot by individual categories

##### fixed time range, relative totals

In [None]:
# plot one panel per category with a fixed time range (1991-2018) on the x axis
# and a y axis that runs from 0 to that category's own maximum
# (blue markers when bArticles is set, red markers otherwise)

bLog10 = False
xdim = 15
ydim = 12

# x is shared within each column; y is deliberately not shared
fig, ax = plt.subplots(ydim, xdim, sharex='col')
fig.subplots_adjust(hspace=0.4, wspace=0.4)
fig.set_size_inches(40, 30)

if bArticles: fig.suptitle("arXiv relative number of articles per year between 1991 and 2018", x=0.5, y=0.92, size=28)
else: fig.suptitle("arXiv relative number of images per year between 1991 and 2018", x=0.5, y=0.92, size=28)
    
data_size = len(data)

for i in range(ydim):
    for j in range(xdim):
        # panels are filled row by row; idx is the position in `data`
        idx = (i * xdim) + j
        if idx < data_size:
            if bArticles: ax[i, j].plot(data[idx][1], data[idx][2], '--.')
            else: ax[i, j].plot(data[idx][1], data[idx][2], '--r.')
            ax[i, j].title.set_text(data[idx][0])
            # y limit is this category's own maximum, hence "relative"
            ax[i, j].axis([1991, 2018, 0, max(data[idx][2])])
            if bLog10: ax[i, j].set_yscale('log')

In [None]:
if bArticles: fig.savefig("plot_articles_cat_year_fixedtime.svg", dpi=300)
else: fig.savefig("plot_images_cat_year_fixedtime.svg", dpi=300)

##### absolute totals

In [None]:
bLog10 = True

### Plot for methods paper
Number of images published per year in each category.

In [None]:
# 4 x 4 grid for the methods paper: one panel per entry of `data`, y axis shared
# within each row, fixed x range 1991-2018, y range from 0 to maxA
#
# NOTE(review): panel titles are taken from selected_cats_full by position, so
# this assumes `data` holds the same categories in the same order as that
# hand-written list - confirm before publishing the figure.

xdim = 4
ydim = 4

bLog10 = False
bArticles = False

fig, ax = plt.subplots(ydim, xdim, sharey='row') # sharex='col', 
fig.subplots_adjust(hspace=0.5, wspace=0.3)
fig.set_size_inches(16, 12)

# if bArticles: fig.suptitle("arXiv total articles per year between 1991 and 2018\nShared Axes", x=0.5, y=0.92, size=28)
# else: fig.suptitle("arXiv total images per year between 1991 and 2018\nShared Axes", x=0.5, y=0.92, size=28)
    
data_size = len(data)

for i in range(ydim):
    for j in range(xdim):
        # panels are filled row by row; idx is the position in `data`
        idx = (i * xdim) + j
        if idx < data_size:
            if bArticles:
                ax[i, j].plot(data[idx][1], data[idx][2], '--.')
            else:
                ax[i, j].plot(data[idx][1], data[idx][2], '--k.')
#             ax[i, j].title.set_text(data[idx][0])
#             title_string = f'{selected_cats_full[idx]}\ntotal: {(catlist[idx][1])}'
            # full category name instead of the arXiv identifier
            title_string = f'{selected_cats_full[idx]}'
            ax[i, j].title.set_text(title_string)
#             s = f'total: {(catlist[idx][1])}'
#             ax[i, j].text(0.025, 0.88, s, fontsize=12, transform=ax[i, j].transAxes)
            ax[i, j].axis([1991, 2018, 0, maxA])
            if bLog10: ax[i, j].set_yscale('log')

In [None]:
data

In [None]:
cat_list = [x[0] for x in data]
# for dd in data:
#     print(dd[0])
#     for d in dd:
#         print(d)
print(cat_list)

In [None]:
fig.savefig("plot_images_cat_year_indax_shareY_top16_v5.svg", dpi=300, bbox_inches='tight',
    pad_inches=0, transparent=False)
fig.savefig("plot_images_cat_year_indax_shareY_top16_v5.png", dpi=300, bbox_inches='tight',
    pad_inches=0 )

In [None]:
if bArticles: fig.savefig("plot_articles_cat_year_fixedtime_shareY.svg", dpi=300)
else: fig.savefig("plot_images_cat_year_indax_shareY.svg", dpi=300)

In [None]:
if bArticles: fig.savefig("plot_articles_cat_year_fixedtime_log10.svg", dpi=300)
else: fig.savefig("plot_images_cat_year_fixedtime_log10.svg", dpi=300)

##### log10

In [None]:
bArticles = True

In [None]:
bArticles = False

In [None]:
# one panel per category on a grid with shared axes, fixed x range 1991-2018
# and a logarithmic y axis

xdim = 15
ydim = 12

fig, ax = plt.subplots(ydim, xdim, sharex='col', sharey='row')
fig.subplots_adjust(hspace=0.4, wspace=0.4)
fig.set_size_inches(40, 30)

if bArticles: fig.suptitle("arXiv log10 of articles per year between 1991 and 2018\nShared Axes", x=0.5, y=0.92, size=28)
else: fig.suptitle("arXiv log10 of images per year between 1991 and 2018\nShared Axes", x=0.5, y=0.92, size=28)

data_size = len(data)

# blue markers for articles, red markers for images
line_format = '--.' if bArticles else '--r.'

for i in range(ydim):
    for j in range(xdim):
        # panels are filled row by row; idx is the position in `data`
        idx = (i * xdim) + j
        if idx < data_size:
            ax[i, j].plot(data[idx][1], data[idx][2], line_format)
            ax[i, j].title.set_text(data[idx][0])
            ax[i, j].axis([1991, 2018, 0, maxA])
            # The figure is titled and saved as "log10", so the log scale has
            # to be applied here; the raw totals are plotted, not their log.
            ax[i, j].set_yscale('log')

In [None]:
if bArticles: fig.savefig("plot_articles_cat_year_fixedtime_log10.svg", dpi=300)
else: fig.savefig("plot_images_cat_year_indax_log10.svg", dpi=300)

##### categories

- physics (including astro-ph, cond-mat)
- cs
- math
- q-bio
- q-fin
- stat

In [None]:
# testing for primary category search
article_count = 0
for cat in data:
    if "stat." in cat[0]:
        print(cat[0])
        article_count += 1
print(article_count)

In [None]:
print(data)

#### Grab data from only some categories

In [None]:
# get only computer science
select_data = []
for cat in data:
    if "cs." in cat[0] and "physics" not in cat[0]:
        select_data.append(cat)
print(select_data)
print(len(select_data))

data = select_data

In [None]:
# get only maths
select_data = []
for cat in data:
    if "math." in cat[0]:
        select_data.append(cat)
print(select_data)
print(len(select_data))

data = select_data

In [None]:
# get all physics related categories
select_data = []
for cat in data:
    if "ph" in cat[0] or "physics." in cat[0] or "cond-mat" in cat[0] or "nlin" in cat[0]:
        select_data.append(cat)
print(select_data)
print(len(select_data))

data = select_data

In [None]:
# get only quantitative biology
select_data = []
for cat in data:
    if "q-bio." in cat[0]:
        select_data.append(cat)
print(select_data)
print(len(select_data))

data = select_data

In [None]:
# get only quantitative finance
select_data = []
for cat in data:
    if "q-fin." in cat[0]:
        select_data.append(cat)
print(select_data)
print(len(select_data))

data = select_data

In [None]:
# get only statistics
select_data = []
for cat in data:
    if "stat." in cat[0]:
        select_data.append(cat)
print(select_data)
print(len(select_data))

data = select_data

#### set log10, category and find factors

In [None]:
bLog10 = False

In [None]:
bLog10 = True

In [None]:
# get the two factors of len(data) closest to its square root;
# the larger one becomes the number of columns (xdim), the smaller the rows (ydim)

n_plots = len(data)

if n_plots == 0:
    # nothing to plot: fall back to a single panel instead of dividing by zero
    xdim = ydim = 1
else:
    # start at the integer square root and walk down to the nearest divisor
    divisor = int(math.sqrt(n_plots))
    while n_plots % divisor != 0:
        divisor -= 1

    xdim = max(divisor, n_plots // divisor)
    ydim = min(divisor, n_plots // divisor)

print(xdim)
print(ydim)

In [None]:
# category = "computer science"
# category = "math"
# category = "physics"
# category = "q-bio"
# category = "q-fin"
category = "stats"

In [None]:
print(len(data))

#### Plot data

In [None]:
# plot the selected category group on a grid with shared axes,
# fixed x range 1991-2018 and y range from 0 to maxA

# NOTE: these hard-coded dimensions replace the xdim/ydim computed by the
# factor cell above; the grid holds 6 panels, so any further entries of
# `data` are not drawn.
xdim = 3
ydim = 2

fig, ax = plt.subplots(ydim, xdim, sharex='col', sharey='row')
fig.subplots_adjust(hspace=0.4, wspace=0.4)
fig.set_size_inches(40, 30)

if bArticles: fig.suptitle("arXiv " + category + " articles per year between 1991 and 2018", x=0.5, y=0.92, size=28)
else: fig.suptitle("arXiv " + category + " images per year between 1991 and 2018", x=0.5, y=0.92, size=28)

data_size = len(data)

for i in range(ydim):
    for j in range(xdim):
        # panels are filled row by row; idx is the position in `data`
        idx = (i * xdim) + j
        if idx < data_size:
            if bArticles:
                ax[i, j].plot(data[idx][1], data[idx][2], '--.')
            else:
                ax[i, j].plot(data[idx][1], data[idx][2], '--r.')
            ax[i, j].title.set_text(data[idx][0])
            ax[i, j].axis([1991, 2018, 0, maxA])
            if bLog10: ax[i, j].set_yscale('log')

In [None]:
if bArticles: fig.savefig("plot_cs_articles_year_fixedtime.svg", dpi=300)
else: fig.savefig("plot_cs_images_year_fixedtime.svg", dpi=300)

In [None]:
if bArticles: fig.savefig("plot_cs_articles_year_fixedtime_log10.svg", dpi=300)
else: fig.savefig("plot_cs_images_year_log10.svg", dpi=300)

In [None]:
if bArticles: fig.savefig("plot_" + category + "_articles_year_fixedtime.svg", dpi=300)
else: fig.savefig("plot_" + category + "_images_year.svg", dpi=300)

In [None]:
if bArticles: fig.savefig("plot_" + category + "_articles_year_fixedtime_log10.svg", dpi=300)
else: fig.savefig("plot_" + category + "_images_year_log10.svg", dpi=300)

### Generate stackplot of image formats by year

In [None]:
# list primary categories by associated images

c.execute('''
    SELECT images.filename, strftime("%Y", metadata.created) 
    FROM images
    LEFT JOIN metadata ON images.identifier = metadata.identifier
    WHERE strftime("%Y", metadata.created) != '2019'
    AND strftime("%Y", metadata.created) != '2020'
    ''')
rows = c.fetchall()
# for row in rows:
#     print(row)
print(len(rows))
print("sample:\n",rows[:3])

In [None]:
# make sorted lists of the distinct years and file extensions

# distinct creation years, ascending
years = sorted({year for _, year in rows})
# distinct lower-cased extensions (the text after the last '.'), ascending
exts = sorted({name.rsplit(".", 1)[1].lower() for name, _ in rows})

print(years)
print(exts)

In [None]:
# Remove entries by value instead of by list position, so that the cell does
# not depend on which extensions happen to exist and can be re-run safely.

# jpeg, epsf and pstex are counted under jpg, eps and ps respectively by the
# counting cell below; svg is left out of the plot
dropped_exts = {"svg", "pstex", "jpeg", "epsf"}
exts = [ext for ext in exts if ext not in dropped_exts]

# NOTE(review): the original comment said "delete 1998" while the counting
# cell skips '1988'; '1988' is assumed to be the intended outlier year (and
# the first entry of the sorted list) - confirm.
dropped_years = {"1988"}
years = [year for year in years if year not in dropped_years]

In [None]:
print(len(years))
print(years)
print(len(exts))
print(exts)

In [None]:
# make empty array

ext_data = np.zeros((len(exts), len(years)))
print(ext_data)

In [None]:
# Count the images per (extension, year) into ext_data, whose rows follow
# `exts` and whose columns follow `years`.
# NOTE: the counts are added to ext_data, so re-create the empty array before
# running this cell a second time.

# spelling variants that are counted under their canonical extension
ext_aliases = {"jpeg": "jpg", "epsf": "eps", "pstex": "ps"}

# row / column positions looked up by value, so that a column always belongs
# to the year shown at the same position in `years`
ext_index = {ext: position for position, ext in enumerate(exts)}
year_index = {year: position for position, year in enumerate(years)}

for filename, year in rows:
    iyear = year_index.get(year)
    if iyear is None:
        # year was removed from `years` (e.g. the 1988 outlier): skip the image.
        # Strings must be compared by value; an identity test ("is not") does
        # not filter them and lets the row land in the wrong column.
        continue

    stem, dot, fileext = filename.rpartition(".")
    if not dot:
        # no extension at all
        continue
    fileext = fileext.lower()
    fileext = ext_aliases.get(fileext, fileext)

    iext = ext_index.get(fileext)
    if iext is not None:
        ext_data[iext][iyear] += 1

In [None]:
print(ext_data)

In [None]:
# get sum for each year: the column totals of the extension-by-year counts.
# (A variable named `sum` is avoided, because it would hide the builtin sum()
# for the rest of the session.)
sums = list(np.asarray(ext_data).sum(axis=0))

In [None]:
print(sums)

In [None]:
# get percentages: share of every extension in the yearly total
ext_data = np.array(ext_data)
sums = np.array(sums)

# Divide column-wise by the yearly totals. A year without any counted image
# keeps 0 % for every extension instead of turning into NaN (0 / 0), which
# would break the stacked plot.
ext_data_per = np.divide(
    ext_data,
    sums,
    out=np.zeros_like(ext_data, dtype=float),
    where=sums != 0,
)
ext_data_per = ext_data_per * 100

In [None]:
ext_data

In [None]:
ext_data_per

In [None]:
# stacked area plot: percentage of every image file extension per year
fig, ax = plt.subplots(1, 1, sharex='col', sharey='row')
fig.set_size_inches(10, 8)

pal = sns.color_palette("deep", 7)
# pal = sns.diverging_palette(10, 220, sep=80, n=7)
ax.stackplot(years, ext_data_per, labels=exts, colors=pal, alpha=1)
ax.margins(0, 0)
ax.set_ylabel("percentage image file extensions per year")
plt.xticks(years, years, rotation=300)
# plt.title("File extension percentages by year")

# List the legend entries top-down in the same order as the stacked layers.
# The handles are taken from the axes directly; the Legend.legendHandles
# attribute is not available in current matplotlib releases.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1], loc='upper left')

# show only every second year label to keep the axis readable
for label in ax.xaxis.get_ticklabels()[1::2]:
    label.set_visible(False)

In [None]:
fig.savefig("extensions_stackplot_smaller_v5_legend_nosvg.png", bbox_inches='tight',
    pad_inches=0, transparent=False, dpi=300)

### Generating figures for dataset methods paper

In [None]:
# list primary categories by alphabetical order

c.execute('''
    SELECT substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1), count(metadata.identifier)
    FROM metadata
    WHERE strftime("%Y", metadata.created) != '2019'
    AND strftime("%Y", metadata.created) != '2020'
    GROUP BY substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1)
    ORDER BY substr(trim(metadata.cat),1,instr(trim(metadata.cat)||' ',' ')-1) ASC
    ''')
rows = c.fetchall()
for row in rows:
    print(row)
print(len(rows))

In [None]:
# testing that the date check works
c.execute('''
    SELECT count(metadata.identifier)
    FROM metadata
    WHERE strftime("%Y", metadata.created) != '2019'
    AND strftime("%Y", metadata.created) != '2020'
    ''')
rows = c.fetchall()
for row in rows:
    print(row)

In [None]:
# testing that the date check works
c.execute('''
    SELECT count(metadata.identifier)
    FROM metadata
    ''')
rows = c.fetchall()
for row in rows:
    print(row)

In [None]:
# keep only the (category, count) rows whose count is above 5000
catlist = [(cat, n) for cat, n in rows if n > 5000]
print(len(catlist))

In [None]:
def take_second(elem):
    """Return the second item of ``elem``; used as a sort key for ``rows``."""
    return elem[1]
# sort `rows` in place by the second item of every row, largest first
rows.sort(key=take_second, reverse=True)

In [None]:
catlist = rows[:16]
catlist.sort()

In [None]:
# horizontal bar chart of the article count per category in catlist
categories = []
values = []
for entry in catlist:
    categories.append(entry[0])
    values.append(entry[1])
y_pos = np.arange(len(categories))

fig, ax = plt.subplots()
fig.set_size_inches(10, 12)

ax.barh(y_pos, values, align='center')
ax.set_yticks(y_pos)
ax.set_yticklabels(categories)
# first category at the top of the chart
ax.invert_yaxis()
ax.set_xlabel('No. articles')

plt.tight_layout()
plt.show()