# Download Pushshift Reddit Data
See this on [Github](https://github.com/yinleon/doppler_tutorials/blob/master/1-download-data.ipynb), [NbViewer](https://nbviewer.jupyter.org/github/yinleon/doppler_tutorials/blob/master/0-download-data.ipynb)<br>
By Leon Yin 2019-04-05<br>
This Notebook collects subreddit metadata from PushShift's REST API, and downloads images from Reddit using requests.

## API Endpoints
A few helpful notes about PushShift's API:

To get a subset of files here:<br>
`http://api.pushshift.io/reddit/submission/search/?subreddit=politics&size=500&sort=desc`

The latest data is accessible here:<br>
`http://api.pushshift.io/reddit/submission/search/?subreddit=politics&after=2h&size=500`

Get everything under the sun here:<br>
`files.pushshift.io/reddit/submissions`

In [1]:
# IPython notebook setup: load the autoreload extension and set it to mode 2,
# and render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os
import json
import shutil
import requests
import itertools
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry

import pandas as pd
import imagehash
from PIL import Image
from tqdm import tqdm

from config import data_dir

In [3]:
# Shared HTTP session used by every request in this notebook.
# Requests that come back with a 502/503/504 status are retried up to
# 5 times with an increasing back-off between attempts.
s = requests.Session()
retries = Retry(total=5, backoff_factor=1, status_forcelist=[ 502, 503, 504 ])
# Mount the retrying adapter for both schemes. An adapter mounted only on
# 'http://' is not used for 'https://' URLs (e.g. after a redirect), which
# would silently fall back to the default adapter without these retries.
s.mount('http://', HTTPAdapter(max_retries=retries))
s.mount('https://', HTTPAdapter(max_retries=retries))

A trick in data engineering is creating a dictionary with variables such as the destination of the output file. We will use functions like `get_context` throughout this article.

In [4]:
def get_context(subreddit):
    '''
    Describe where the data for `subreddit` is stored.

    Returns a dictionary with the root data directory, the subreddit's own
    directory, the shared media directory, and the paths of the gzipped
    posts and media metadata files. The three directories are created
    if they do not exist yet.
    '''
    subreddit_root = os.path.join(data_dir, subreddit)
    shared_media = os.path.join(data_dir, 'media')

    # Make sure every directory handed back to the caller exists on disk.
    for folder in (data_dir, subreddit_root, shared_media):
        os.makedirs(folder, exist_ok=True)

    return {
        'data_dir' : data_dir,
        'sub_dir' : subreddit_root,
        'media_dir' : shared_media,
        'file_subreddit' : os.path.join(subreddit_root, 'posts.json.gz'),
        'file_subreddit_media' : os.path.join(subreddit_root, 'media.json.gz'),
    }

## Query PushShift
These next two functions can be re-used for other purposes where you may need data from PushShift.

In [None]:
def build_api_endpoint(subreddit, size, ascending, last_record_date, verbose):
    '''
    Assemble a PushShift submission-search URL for one page of results.

    Without a `last_record_date` the URL only carries the subreddit and
    page size. With one, the URL asks for records after that timestamp
    (sorted ascending) or before it (sorted descending), depending on
    `ascending`. The URL is printed when `verbose` is truthy.
    '''
    endpoint = ('http://api.pushshift.io/reddit/submission/search/'
                f'?subreddit={ subreddit }&size={ size }')

    # Only add a date filter once we have a timestamp to page from.
    if last_record_date:
        if ascending:
            endpoint += f'&after={ last_record_date + 1 }&sort=asc'
        else:
            endpoint += f'&before={ last_record_date - 1 }&sort=desc'

    if verbose:
        print(endpoint)
    return endpoint


def download_subreddit_posts(subreddit, size=5000, ascending=False, 
                             start_date=False, seen_ids=None, 
                             display_every_x_iterations=20,
                             verbose=True):
    '''
    Downloads subreddit data, one page at a time, until a page contains
    no IDs that are new.

    To go back in time from `start_date`, set ascending to False.
    To go forward in time from `start_date`, set ascending to True.
    Records whose ID is in `seen_ids` are skipped.

    subreddit: name of the subreddit to query.
    size: number of records requested per page.
    ascending: paging direction (see above).
    start_date: timestamp to start paging from; falsy means no date filter.
    seen_ids: set of record IDs to skip (None means nothing has been seen).
    display_every_x_iterations: print the URL once every this many pages.
    verbose: print progress information.

    Returns the list of newly collected records. Pressing Ctrl-C stops
    the download early and returns what was collected so far.
    Raises TypeError if `seen_ids` is given but is not a set.
    '''
    if seen_ids is None:
        # Fresh set per call rather than a shared mutable default argument.
        seen_ids = set()
    elif not isinstance(seen_ids, set):
        # Raising a plain string is itself an error in Python 3.
        raise TypeError("seen_ids needs to be a set!")
    i = 0
    records = []
    last_record_date = start_date
    try:
        while True:
            # Build the url (only echo it every few iterations)
            url = build_api_endpoint(subreddit, size, ascending, last_record_date, 
                                     verbose if i % display_every_x_iterations == 0 else False)

            # make the HTTP request to the API
            r = s.get(url)
            resp = r.json()
            # a response without a usable 'data' field is treated as an empty page
            data = resp.get('data') or []

            # check which records were returned by the API and which are new?
            # if there are no new IDs, then we're done!
            paginated_ids = {row.get("id") for row in data}
            new_ids = paginated_ids - seen_ids
            if not new_ids:
                break

            # add new records to existing records. 
            records.extend(row for row in data if row.get('id') in new_ids)

            # the next page continues from the last record's date.
            last_record_date = data[-1].get('created_utc')
            if last_record_date is None:
                # Without a timestamp the next URL would carry no date filter
                # and the same page would be fetched forever.
                break
            i += 1
    
    except KeyboardInterrupt:
        if verbose:
            print("Cancelled early")
        
    return records

Here we call these functions to collect data.

In [None]:
# Only `data_dir` is imported from config at the top of the notebook, so the
# module itself has to be imported for `config.subreddit` to resolve.
import config

verbose = True
subreddit = config.subreddit # change this in config.py
context = get_context(subreddit)

if os.path.exists(context['file_subreddit']):
    # Posts were collected before: load them and only fetch what is missing.
    print('File Exists')
    df = pd.read_json(context['file_subreddit'], 
                      lines=True, orient='records',
                      compression='gzip')
    
    print(f'{ len(df) } Records exist')
    min_date = df.created_utc.min()
    max_date = df.created_utc.max()
    seen_ids = set(df.id.unique())
    
    # page forward from the newest stored record...
    newer_records = download_subreddit_posts(subreddit, verbose=verbose, 
                                             seen_ids=seen_ids,
                                             ascending=True, 
                                             start_date=max_date)
    # ...and backward from the oldest stored record.
    older_records = download_subreddit_posts(subreddit, verbose=verbose, 
                                             seen_ids=seen_ids,
                                             ascending=False, 
                                             start_date=min_date)
    newer_records.extend(older_records)
    _df = pd.DataFrame(newer_records)
    if verbose:
        print(f"collected { len(_df) } records")
    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    df = pd.concat([df, _df], sort=False)
    df.drop_duplicates(subset=['id'], inplace=True)
    df.sort_values(by=['created_utc'], ascending=False, inplace=True)
    df.to_json(context['file_subreddit'], lines=True, orient='records', compression='gzip')

else:
    # First run for this subreddit: download from scratch.
    print("New Subreddit")
    records = download_subreddit_posts(subreddit, verbose=verbose, size=5000)
    if verbose:
        print(f"collected { len(records) } records")
    df = pd.DataFrame(records)
    df.to_json(context['file_subreddit'], lines=True, orient='records', compression='gzip')

# An empty frame has no 'created_utc' column, so skip the summary in that case.
if verbose and not df.empty:
    # Summary stats
    print('\n****************')
    df['created_at'] = pd.to_datetime(df['created_utc'], unit='s')
    print(f"N = { len(df) }\n"
          f"Start Date = { df['created_at'].min() }\n"
          f"End Date = { df['created_at'].max() }")

File Exists
294611 Records exist
http://api.pushshift.io/reddit/submission/search/?subreddit=pewdiepiesubmissions&size=5000&after=1555134672&sort=asc
http://api.pushshift.io/reddit/submission/search/?subreddit=pewdiepiesubmissions&size=5000&before=1551823502&sort=desc
http://api.pushshift.io/reddit/submission/search/?subreddit=pewdiepiesubmissions&size=5000&before=1551618603&sort=desc
Cancelled early
collected 40263 records

****************
N = 334874
Start Date = 2019-02-28 22:47:45
End Date = 2019-04-13 15:05:37


## Collect Images

In [None]:
def _hash_image(f):
    '''
    Returns the dhash (hash_size=8) of the image file `f` as a string.
    '''
    # The context manager closes the file handle once the hash is computed;
    # leaving it open leaks one handle per image.
    with Image.open(f) as img:
        return str(imagehash.dhash(img, hash_size=8))


def download_media(url, f):
    '''
    Downloads an image from the net and calculates the dhash.

    url: where to fetch the image from.
    f: local path the image is (or will be) stored at.

    Returns a tuple (dhash, size of the file in bytes). If the server does
    not answer with status 200, ('NOHASH', 0) is returned and nothing is
    written to disk.
    '''
    if os.path.exists(f):
        # if the image exists and is not empty, don't download it again;
        # just calculate the size and hash of the local copy.
        img_size = os.path.getsize(f)
        if img_size != 0:
            return _hash_image(f), img_size

    # Download the image. The `with` block releases the streamed connection
    # on every path, including the early return for failed requests.
    with s.get(url, stream=True) as r:
        if r.status_code != 200:
            return 'NOHASH', 0

        # write the (decoded) response body to the local file
        with open(f, 'wb') as file:
            r.raw.decode_content = True
            shutil.copyfileobj(r.raw, file)

    # calculate the hash and size of what was just written
    return _hash_image(f), os.path.getsize(f)

In [None]:
def get_media_context(image, context):
    '''
    Establishes where media files will be saved.

    image: dictionary with an 'id' and a list of 'resolutions', each of
           which may carry a 'url'.
    context: dictionary with a 'media_dir' entry (see `get_context`).

    Returns a tuple (image url, local file path). The local directory is
    created if needed. Returns (None, None) when the image has no
    resolutions or the chosen resolution has no url.
    '''
    image_id = image['id']
    pos_images = image.get('resolutions')
    if not pos_images:
        # no images...
        return None, None

    # the last resolution is used (assumed to be the largest one).
    largest_image = pos_images[-1]

    # where is the image to be downloaded from?
    img_url = largest_image.get('url')
    if not img_url:
        # a resolution without a url cannot be downloaded.
        return None, None
    img_url = img_url.replace('&amp;', '&')
    
    # what is the file extension? (ignore the query string)
    _, ext = os.path.splitext(img_url.split('?')[0])
    ext = ext.replace('jpeg', 'jpg')
    
    # where will the images be downloaded locally?
    # files are spread over two directory levels taken from the image id.
    dir_img = os.path.join(context['media_dir'], 
                           image_id[:2].lower(), 
                           image_id[2:4].lower())
    
    f_img = os.path.join(dir_img, image_id + ext)
    os.makedirs(dir_img, exist_ok=True)
    
    return img_url, f_img

In [None]:
# Walk over every post, download each preview image and record one metadata
# row (the post's fields plus hash, local path, size and deleted flag) per image.
img_meta = []
for _, row in tqdm(df.iterrows()):
    preview = row.get('preview')
    if not isinstance(preview, dict):
        # posts without a preview dictionary have nothing to download
        continue
    images = preview.get('images') or []
    for img in images:
        img_url, f_img = get_media_context(img, context)
        if not img_url:
            continue
        d_hash, img_size = download_media(img_url, f_img)
        r = row.copy()
        # a size of zero means nothing could be downloaded
        r['deleted'] = img_size == 0
        r['d_hash'] = d_hash
        r['f_img'] = f_img
        r['img_size'] = img_size
        img_meta.append(r.to_dict())

# persist the per-image metadata next to the subreddit's posts
df_img_meta = pd.DataFrame(img_meta)
df_img_meta.to_json(context['file_subreddit_media'],
                    lines=True, orient='records',
                    compression='gzip')

  ' expressed in bytes should be converted ' +
71337it [10:58, 122.27it/s]

In [None]:
# Number of image metadata rows collected by the download loop.
len(df_img_meta)

In [None]:
# Open the image at `f_img`, the path left over from the last iteration of the download loop.
Image.open(f_img)