In [None]:
import geopandas as gpd 
import os 
import shutil 
import requests
import rasterio
import numpy as np 
import datetime as dt
import time
import random
import json
import multiprocessing as mp
import tensorflow as tf

from collections import Counter
from json.decoder import JSONDecodeError
from tqdm.notebook import tqdm 
from requests.auth import HTTPBasicAuth
from PIL import Image
from calendar import monthrange

In [None]:
# Mount Google Drive at /gdrive (Colab only); force_remount=True re-mounts
# even if the drive is already mounted in this runtime.
from google.colab import drive
drive.mount('/gdrive', force_remount=True)

# Helper Functions

In [None]:
# Upper bound on how many times search_api() retries the quick-search request
# when the server answers with HTTP 429 (rate limited).
MAX_ATTEMPTS = 20
# Endpoint that submit_requests() POSTs orders to. It is empty by default and
# must be filled in before submitting orders.
URL = '' # PLANET URL, e.g., 'https://api.planet.com/compute/ops/orders/v2' 

def gen_box_coords(lat, lon, height=0.00450, width=0.00592):
    """
    Build a closed rectangular ring centred on a point.

    Args:
        lat (float): latitude of the box centre in decimal degrees
        lon (float): longitude of the box centre in decimal degrees
        height (float): box height in decimal degrees [default = 0.00450, appx 500m at IL latitude]
        width (float): box width in decimal degrees [default = 0.00592, appx 500m at IL latitude]

    Returns:
        A list of five [lon, lat] pairs: bottom-left, bottom-right, top-right,
        top-left, and bottom-left again to close the ring.
    """
    half_width = width / 2
    half_height = height / 2

    left, right = lon - half_width, lon + half_width
    bottom, top = lat - half_height, lat + half_height

    # Corner order: [l, b], [r, b], [r, t], [l, t], then back to [l, b].
    corners = [(left, bottom), (right, bottom), (right, top), (left, top), (left, bottom)]
    return [[x, y] for x, y in corners]

def search_api(coordinates, start_date, end_date, item_type, clear_percent=95, cloud_cover=.05):
    """
    Query Planet's quick-search endpoint for items overlapping a polygon.

    Args:
        coordinates (list): output of gen_box_coords(), a closed ring of [lon, lat] pairs
        start_date (string): RFC 3339 date (inclusive lower bound on `acquired`)
        end_date (string): RFC 3339 date (inclusive upper bound on `acquired`)
        item_type (string): either 'PSScene3Band' or 'PSScene4Band'
        clear_percent (int, 0-100): minimum clarity. NOTE: the filter built from
            this value is currently NOT part of the combined filter; only the
            `cloud_cover` filter is applied.
        cloud_cover (double, 0-1): filter for images at most this cloudy

    Returns:
        A list of item IDs that matched the search filters

    Raises:
        KeyError: if the PL_API_KEY environment variable is not set.
        Exception: if the server answers with a status other than 200, is still
            rate limiting (429) after MAX_ATTEMPTS attempts, or returns a
            non-JSON body.
    """
    # needs update to handle multiple images
    # A bare ring (list of points) is wrapped so it becomes a GeoJSON polygon.
    if len(coordinates) > 1:
        coordinates = [coordinates]

    geo_json_geometry = {
        "type": "Polygon",
        "coordinates": coordinates
    }

    # filter for items that overlap with our chosen geometry
    geometry_filter = {
        "type": "GeometryFilter",
        "field_name": "geometry",
        "config": geo_json_geometry
    }

    # filter images acquired in a certain date range
    date_range_filter = {
        "type": "DateRangeFilter",
        "field_name": "acquired",
        "config": {
            "gte": start_date,
            "lte": end_date
        }
    }

    # filter images based on cloud tolerance
    cloud_cover_filter = {
        "type": "RangeFilter",
        "field_name": "cloud_cover",
        "config": {
            "lte": cloud_cover
        }
    }

    # filter based on total image clarity (built but not applied, see docstring)
    clear_percent_filter = {
        "type": "RangeFilter",
        "field_name": "clear_percent",
        "config": {
            "gte": clear_percent
        }
    }

    permission_filter = {
        "type": "PermissionFilter",
        "config": [
            "assets:download"
        ]
    }

    usable_data_filter = cloud_cover_filter

    # create a filter that combines our geo, date, quality and permission filters
    combined_filter = {
        "type": "AndFilter",
        "config": [geometry_filter, date_range_filter, usable_data_filter, permission_filter]
    }

    # Search API request object
    search_endpoint_request = {
        "item_types": [item_type],
        "filter": combined_filter
    }

    response = None
    for attempt in range(MAX_ATTEMPTS):
        response = requests.post(
            'https://api.planet.com/data/v1/quick-search',
            auth=HTTPBasicAuth(os.environ['PL_API_KEY'], ''),
            json=search_endpoint_request)
        if response.status_code != 429:
            break

        # If rate limited, wait and try again (no point sleeping after the
        # final attempt, we are about to give up).
        if attempt < MAX_ATTEMPTS - 1:
            time.sleep((2 ** attempt) + random.random())

    # Retries exhausted (or MAX_ATTEMPTS <= 0): fail loudly instead of trying
    # to parse a rate-limit response as search results.
    if response is None or response.status_code == 429:
        raise Exception(
            f"search_api() still rate limited after {MAX_ATTEMPTS} attempts"
        )
    if response.status_code != 200:
        raise Exception(response)

    # The header may be absent, in which case .get() returns None.
    content_type = response.headers.get('Content-Type') or ''
    if 'json' not in content_type:
        raise Exception(f"{response} in search_api()")

    return [feature['id'] for feature in response.json()['features']]



In [None]:
def create_date(year, month, day):
    """Format year, month and day as a midnight-UTC ISO-8601 timestamp.

    Month and day are left-padded with zeros to two characters, e.g.
    create_date(2015, 3, 4) gives '2015-03-04T00:00:00.000Z'.
    """
    month_part = str(month).zfill(2)
    day_part = str(day).zfill(2)
    return f"{str(year)}-{month_part}-{day_part}T00:00:00.000Z"

def year_month(year, month):
    """Return 'YYYY-MM' for the given year and month (month zero-padded to two characters)."""
    return f"{str(year)}-{str(month).zfill(2)}"

def n_days(year, month):
    """Return how many days the given month of the given year has."""
    _first_weekday, days_in_month = monthrange(year, month)
    return days_in_month

def create_jobs(df, year=2020, start_month=1, start_day=1, n_months=12, item_type='PSScene4Band',
                lat_name='Latitude', lon_name='Longitude', state_name='State',
                save_path='./'):
    """Build one search/download job per location and per month.

    Args:
        df: table of locations iterated with `iterrows()`. Each row must
            provide a string `key` attribute plus the latitude, longitude and
            state columns named below.
        year (int): year of the first month.
        start_month (int): first month (1-12).
        start_day (int): day of the month each period starts on.
        n_months (int): number of consecutive months to cover. Months running
            past December roll over into the following year.
        item_type (string): Planet item type stored in each job.
        lat_name, lon_name, state_name (string): column names in `df`.
        save_path (string): root directory for the output tree.

    Returns:
        A list of dicts with keys 'start_date', 'end_date', 'item_type',
        'out_dir' and 'coordinates', ordered by location and then by month.
        'out_dir' is `<save_path>/planet_<state>_<year>/loc_<key>/<YYYY-MM>`,
        where `<year>` is always the `year` argument.

    Raises:
        ValueError: if `start_day` is not a valid day in one of the months.
    """
    # The date window of each month does not depend on the location, so it is
    # computed once up front.
    periods = []
    for j in range(n_months):
        month_index = (start_month - 1) + j
        job_year = year + month_index // 12
        job_month = month_index % 12 + 1

        # Each window spans as many days as its starting month has. Real date
        # arithmetic keeps the end date valid when start_day > 1 (the window
        # then ends in the following month).
        first = dt.date(job_year, job_month, start_day)
        last = first + dt.timedelta(days=n_days(job_year, job_month) - 1)

        periods.append((
            create_date(first.year, first.month, first.day),
            create_date(last.year, last.month, last.day),
            year_month(job_year, job_month),
        ))

    jobs = []
    for _, location in df.iterrows():
        lat, lon = location[lat_name], location[lon_name]
        coords = gen_box_coords(lat, lon, height=.0181, width=.0253) #should be around 2km x 2km

        state = location[state_name]
        state_dir_name = os.path.join(save_path, f'planet_{state.lower()}_{year}')
        directory = os.path.join(state_dir_name, 'loc_' + location.key)

        for start_date, end_date, month_dir in periods:
            jobs.append({
                'start_date': start_date,
                'end_date': end_date,
                'item_type': item_type,
                'out_dir': os.path.join(directory, month_dir),
                'coordinates': coords
            })

    return jobs

def create_clip(coordinates, item_ids, item_type='PSScene4Band'):
    """Build the order payload that clips the given items to a polygon.

    Args:
        coordinates (list): polygon ring; it is wrapped in one more list to
            form the AOI's "coordinates".
        item_ids (list): item IDs to order.
        item_type (string): 'PSScene3Band' selects the "visual" bundle and
            'PSScene4Band' the "analytic" bundle; any other value leaves the
            bundle as an empty string.

    Returns:
        A dict describing a partial order with a single clip tool.
    """
    bundle_by_item_type = {
        "PSScene3Band": "visual",
        "PSScene4Band": "analytic",
    }

    product = {
        "item_ids": item_ids,
        "item_type": item_type,
        "product_bundle": bundle_by_item_type.get(item_type, ""),
    }
    area_of_interest = {
        "type": "Polygon",
        "coordinates": [coordinates],
    }

    return {
        "name": "clip",
        "order_type": "partial",
        "products": [product],
        "tools": [{"clip": {"aoi": area_of_interest}}],
    }

def remove_existing(job, ids):
    """
    Drop the ids that already have a file in the job's output directory
    (e.g., if you've run the download script multiple times on overlapping
    sets of images).

    Args:
        job (dict): job description; only job['out_dir'] is read.
        ids (list): item IDs (strings) returned by the search.

    Returns:
        A new list of plain strings holding the ids for which no file name in
        job['out_dir'] starts with that id, in their original order. All ids
        are returned when the directory does not exist.
    """
    if not os.path.exists(job['out_dir']):
        return list(ids)

    # List the directory once rather than once per id.
    existing_files = os.listdir(job['out_dir'])

    return [
        item_id for item_id in ids
        if not any(f.startswith(item_id) for f in existing_files)
    ]


def create_orders(jobs, clear_percent=100):
    """Search Planet for every job and collect the item ids still to order.

    For each job the matching item ids are looked up with search_api() and
    the ones that already have a file in the job's output directory are
    dropped with remove_existing().

    Returns:
        A list with one entry per job (same order as `jobs`), each entry
        being that job's remaining item ids.
    """
    started_at = time.time()

    order_ids = [
        remove_existing(  # remove completed jobs
            job,
            search_api(
                job['coordinates'],
                job['start_date'],
                job['end_date'],
                job['item_type'],
                clear_percent=clear_percent,
            ),
        )
        for job in tqdm(jobs, desc='creating orders')
    ]

    print(f'Orders created. ({time.time() - started_at:0.1f}s)')
    return order_ids

def submit_requests(jobs, order_ids):
    """Submit one clip order per job to Planet.

    Args:
        jobs (list): job dicts; 'coordinates' and 'out_dir' are read, and
            'item_type' when present (defaults to 'PSScene4Band').
        order_ids (list): per-job lists of item ids, aligned with `jobs`.

    Returns:
        A dict keyed by job index. Each value holds 'status' (initially
        'requested'), 'order' (the raw HTTP response of the POST) and
        'out_dir'.

    Raises:
        KeyError: if the PL_API_KEY environment variable is not set.
    """
    responses = {}
    start = time.time()
    desc = 'submitting jobs'
    # enumerate(zip(...)) has no length, so give tqdm the total explicitly.
    for i, (job, ids) in tqdm(enumerate(zip(jobs, order_ids)), desc=desc, total=len(jobs)):

        # Order the same item type that was searched for, so the product
        # bundle matches the ids found for this job.
        clip = create_clip(
            job['coordinates'], ids, item_type=job.get('item_type', 'PSScene4Band')
        )
        response_order = requests.post(
            URL,
            auth=HTTPBasicAuth(os.environ['PL_API_KEY'], ''),
            json=clip
        )
        responses[i] = {
            'status': 'requested',
            'order': response_order,
            'out_dir': job['out_dir']
        }

    print(f'Submitted {len(jobs)} requests. ({time.time() - start:0.1f}s)')
    return responses



def handle_download(jobs, session, max_wait_time=256, verbose=True):
    """Run the whole pipeline for a list of jobs: search, order, poll, download.

    Orders are created and submitted once; afterwards the function alternates
    between updating order states and downloading finished results until
    every job is either 'downloaded' or 'failed'. While the API is rate
    limiting, the pause between rounds doubles up to `max_wait_time` seconds.

    The notebook cells below perform the same steps manually, which makes it
    easier to inspect each stage.

    Args:
        jobs (list): job dicts as produced by create_jobs().
        session: HTTP session with Planet credentials attached.
        max_wait_time (number): upper bound, in seconds, for the back-off pause.
        verbose (bool): print the status counts after every round.

    Returns:
        The responses dict, as built by submit_requests() and updated by the
        check/download helpers.

    NOTE: an order whose submission itself was rate limited stays in the
    'requested' state (nothing here resubmits it), so this loop will not
    finish for such an order.
    """

    # Create orders by querying db, then submit requests
    order_ids = create_orders(jobs)
    responses = submit_requests(jobs, order_ids)

    # Alternate between updating requests and downloading
    # ready resources. If queries to API are rate limited,
    # apply exponential backoff to request times.
    finished_states = ('downloaded', 'failed')
    wait_time = 1
    rate_limited = False
    while not all(v['status'] in finished_states for v in responses.values()):
        if rate_limited:
            time.sleep(wait_time)
            wait_time = min(wait_time * 2, max_wait_time)
        else:
            # The previous round went through: start the back-off over.
            wait_time = 1

        responses, rl1 = check_requested(responses)
        responses, rl2 = check_accepted(responses, session)
        responses, rl3 = download_successes(responses, session)
        rate_limited = rl1 or rl2 or rl3
        if verbose:
            print(status_counter(responses))

    return responses

def status_counter(responses):
    """Return a Counter mapping each status to how many responses have it."""
    return Counter(entry['status'] for entry in responses.values())



def check_requested(responses):
    """Check responses which have been requested but not yet accepted.
    Update status to either accepted or failed if warranted.

    Only the HTTP responses already stored under 'order' are inspected; no
    network request is made here.

    Args:
        responses (dict): entries with at least 'status' and 'order'.

    Returns:
        (responses, rate_limited): the same dict, updated in place, and a
        flag that is True when at least one stored order response has status
        429. Such entries are left in the 'requested' state; this function
        does not resubmit them.
    """
    rate_limited = False
    for v in responses.values():

        if v['status'] != 'requested':
            continue

        if v['order'].status_code == 429:
            # Nothing is fetched in this loop, so keep going and process the
            # remaining stored responses instead of stopping at the first 429.
            rate_limited = True
            continue

        if v['order'].ok:
            v['status'] = 'accepted'
            v['id'] = v['order'].json()['id']
        else:
            v['status'] = 'failed'
            print(f'Failed with code {v["order"].status_code}. \n{v["order"].content}')

    return responses, rate_limited

def check_all_running(responses, session):
    """Returns true iff all responses are in 'running' state.

    Each order is looked up by its 'id' through `session`; the lookup stops at
    the first order whose reported state is not 'running'.

    Args:
        responses (dict): entries that each carry an 'id'.
        session: object with a `get(url)` method returning a response whose
            `.json()` contains a 'state' key.
    """
    # Build the URL with plain string formatting: os.path.join is meant for
    # file-system paths and would insert back-slashes on Windows.
    base_url = 'https://api.planet.com/compute/ops/orders/v2'

    return all(
        session.get(f'{base_url}/{v["id"]}').json()['state'] == 'running'
        for v in responses.values()
    )


def extract_json_results(json_content):
    """Keep the result entries whose 'name' is a JSON file or a clipped
    AnalyticMS GeoTIFF.

    Args:
        json_content: iterable of dicts that each have a 'name' key.

    Returns:
        A list with the entries whose name ends with '.json' or
        'AnalyticMS_clip.tif', in their original order.
    """
    wanted_suffixes = ('.json', 'AnalyticMS_clip.tif')
    return [
        entry for entry in json_content
        if entry['name'].endswith(wanted_suffixes)
    ]



def check_accepted(responses, session):
    """Check all responses with `accepted' status. Update to success
    if warranted.

    For each accepted order the current state is fetched through `session`:
      * 'success' or 'partial': the downloadable results are recorded under
        'media' (one entry per result name, each with 'result' and a
        'response' placeholder set to None) and the status becomes 'success';
      * 'failed': the status becomes 'failed';
      * any other state: the entry is left untouched.

    Args:
        responses (dict): entries with 'status' and, when accepted, 'id'.
        session: object with a `get(url)` method.

    Returns:
        (responses, rate_limited): the same dict, updated in place, and a
        flag that is True when a lookup returned 429 (polling stops there) or
        returned a body that could not be parsed as JSON.
    """
    rate_limited = False
    for v in responses.values():
        if v['status'] != 'accepted':
            continue

        # Plain string formatting: os.path.join is for file-system paths.
        r = session.get(f'https://api.planet.com/compute/ops/orders/v2/{v["id"]}')

        if r.status_code == 429:
            rate_limited = True
            break

        # Parse the body once. JSON decode errors are ValueError subclasses;
        # an unparseable body is treated like rate limiting, as before.
        try:
            order = r.json()
        except ValueError:
            rate_limited = True
            continue

        state = order['state']
        if state in ['success', 'partial']:
            results = extract_json_results(order['_links']['results'])
            v['media'] = {
                result['name']: {'result': result, 'response': None}
                for result in results
            }
            v['status'] = 'success'
        elif state == 'failed':
            v['status'] = 'failed'

    return responses, rate_limited


def all_jobs_downloaded(responses):
    """Return True iff every response has status 'downloaded' (True when empty)."""
    return all(entry['status'] == 'downloaded' for entry in responses.values())


def all_results_downloaded(results, dir):
    """Return True iff every result of a request already has a file in `dir`.

    Args:
        results (dict): keyed by result name; only the keys are used. The
            part after the last os.sep of each key is the expected file name.
        dir (string): directory the files are expected in.
    """
    return all(
        os.path.exists(os.path.join(dir, filename.split(os.sep)[-1]))
        for filename in results.keys()
    )
  

def download_successes(responses, session):
    """Download the responses with status `success'.

    For every media entry of a successful order that has no file in the
    order's 'out_dir' yet, a streamed download is requested (unless a
    response is already stored) and, on HTTP 200, written to
    `<out_dir>/<last component of the result name>`.

    Args:
        responses (dict): entries with 'status', 'out_dir' and, for
            successful orders, 'media' as filled in by check_accepted().
        session: object with a `get(url, params=..., stream=...)` method.

    Returns:
        (responses, rate_limited): the same dict, updated in place (status
        becomes 'downloaded' once all of an order's files exist), and a flag
        that is True when at least one download was answered with 429.
    """
    rate_limited = False
    for v in responses.values():

        if v['status'] != 'success':
            continue

        for name, info_dict in v['media'].items():

            target = os.path.join(v['out_dir'], name.split(os.sep)[-1])
            if os.path.exists(target):
                continue

            if info_dict['response'] is None:
                # submit request
                token = info_dict['result']['location'].partition('?token=')[2]
                params = (
                    ('token', token),
                )
                info_dict['response'] = session.get(
                    'https://api.planet.com/compute/ops/download/',
                    params=params, stream=True
                )

            # Always work with the response stored for THIS file; it may come
            # from an earlier call rather than from the request just above.
            response = info_dict['response']

            if response.status_code == 200:
                os.makedirs(v['out_dir'], exist_ok=True)
                # Write to a temporary name first so an interrupted transfer
                # never leaves a truncated file that looks complete.
                partial = target + '.part'
                try:
                    with open(partial, 'wb') as f:
                        response.raw.decode_content = True
                        shutil.copyfileobj(response.raw, f)
                    os.replace(partial, target)
                finally:
                    if os.path.exists(partial):
                        os.remove(partial)
                    response.close()

            elif response.status_code == 429:
                rate_limited = True
                # Forget the rate-limited response so the next call issues a
                # fresh request instead of re-reading this one forever.
                response.close()
                info_dict['response'] = None

        if all_results_downloaded(v['media'], v['out_dir']):
            v['status'] = 'downloaded'

    return responses, rate_limited


    

# Main Download Script

In [None]:
# Grab locations
#
# Put the Planet API key into the environment (the helper functions read it
# through os.environ['PL_API_KEY']), then load the locations shapefile found
# under `save_path`.

%env PL_API_KEY= # YOUR API KEY HERE
save_path = ''
locations_df = gpd.read_file(os.path.join(save_path, 'SHAPEFILE.shp'))


In [None]:
# Build the job list for every location in `locations_df`: a single month
# (n_months=1) starting on 1 February 2022, with output rooted at `save_path`.
jobs = create_jobs(
    locations_df, year=2022, start_month=2, start_day=1, n_months=1, save_path=save_path
)

In [None]:
# `sub_jobs` is the subset of jobs to order. It was previously never defined,
# so this cell failed on a fresh kernel. Slice it (e.g. jobs[:10]) to try the
# pipeline on a few locations first.
sub_jobs = jobs

# Search Planet for each job, then submit one clip order per job.
order_ids = create_orders(sub_jobs)
responses = submit_requests(sub_jobs, order_ids)

In [None]:
# Create a requests session that sends the Planet API key as the basic-auth
# user name (empty password); it is passed to the polling/download helpers.
session = requests.Session()
session.auth = (os.environ['PL_API_KEY'], '')

In [None]:
# Step 1: inspect the stored submission responses and mark each order as
# 'accepted' or 'failed'; `rl` is True if any submission was rate limited.
responses, rl = check_requested(responses)
print(status_counter(responses))

In [None]:
# Step 2: poll Planet for every accepted order and mark it 'success' once its
# results are available. Re-run this cell until no order is left 'accepted'.
responses, rl = check_accepted(responses, session)
print(status_counter(responses))

In [None]:
# Step 3: download the files of every order in 'success' state; an order
# becomes 'downloaded' once all of its files exist in its output directory.
responses, rl = download_successes(responses, session)
print(status_counter(responses))

# Create RGBs

The downloaded Planet imagery is in GeoTIFF format. The following cell converts these GeoTIFFs to RGB PNG images, skipping scenes that are too cloudy or contain too many blank pixels.

In [None]:
# Convert the downloaded GeoTIFFs of one month into RGB PNGs.
#
# NOTE(review): `plt` and `check_snow` are not defined in the visible part of
# this notebook; confirm they are imported/defined before this cell runs.

parent_dir = 'DOWNLOAD-DIR'
# Named `target_year_month` so it does not shadow the year_month() helper
# that create_jobs() relies on.
target_year_month = "2022-02"
blank_tol = 0.2 # Only 20% of image can be blank
cloud_tol = 0.2 # image can only be 20% cloudy

for root, dirs, files in os.walk(parent_dir):

    # Expected layout: .../loc_<key>/<YYYY-MM>/<files>. basename/dirname also
    # work when `root` has a single path component (no IndexError).
    loc = os.path.basename(os.path.dirname(root))
    date = os.path.basename(root)

    if "loc_" not in loc or target_year_month != date:
        continue

    print(loc)

    for f in files:
        if not f.endswith('.tif'):
            continue

        # Skip TIFFs that do not follow the expected naming (the replacement
        # would be a no-op and the image itself would be read as metadata) and
        # TIFFs whose metadata file is missing.
        metadata = f.replace('3B_AnalyticMS_clip.tif', 'metadata.json')
        if metadata == f or not os.path.exists(os.path.join(root, metadata)):
            continue

        # Remove if too cloudy
        with open(os.path.join(root, metadata)) as metadata_file:
            info = json.load(metadata_file)
        if info['properties']['cloud_cover'] > cloud_tol:
            continue

        # Band order presumably blue, green, red, near-infrared (matches the
        # variable names); requires exactly four bands.
        with rasterio.open(os.path.join(root, f)) as src:
            b, g, r, n = src.read()

        rgb = np.stack((r, g, b), axis=2)

        # Remove if too many blank pixels. An all-zero image is dropped before
        # normalising, which would otherwise divide by zero and produce NaNs.
        peak = rgb.max()
        if peak == 0:
            continue
        blank = (rgb == 0).sum() / rgb.size
        if blank > blank_tol:
            continue

        rgb = rgb / peak

        # Use a dedicated name: `save_path` is the download root used earlier.
        png_path = os.path.join(root, f.replace('.tif', '.png'))
        plt.imsave(png_path, rgb)

        # remove image if not snowy
        if not check_snow(png_path):
            os.remove(png_path)

    