<a href="https://colab.research.google.com/github/rwong-current/colab/blob/main/utils/query.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.cloud import bigquery
from google.cloud import bigquery_storage
import google.auth

from google.colab import auth
auth.authenticate_user()

import pandas as pd
import numpy as np
import time, pytz
import re
import pprint

from datetime import datetime, timedelta

# Notebook display settings: show every column instead of truncating wide frames.
pd.options.display.max_columns = None
# Render floats with a thousands separator and two decimal places.
pd.options.display.float_format = '{:,.2f}'.format

class bq:
    '''
    Notebook helper around a BigQuery client.

    Runs queries into pandas DataFrames, aligns DataFrame dtypes with a table's
    BigQuery schema, and uploads DataFrames into expiring "temp_" tables.
    Progress and errors are reported with print() so they show up in the cell output.
    '''

    def __init__(self, project_id: str):
        '''
        Create the BigQuery client for `project_id` and the dtype lookup tables.
        '''
        print("BQ V2")
        self.project_id = project_id
        self.bqclient = bigquery.Client(
            project=project_id
        )

        # BigQuery column type -> pandas dtype applied by match_dtypes.
        self.bq_to_pd_dtype = {
            'STRING': 'O',
            'INTEGER': 'Int64',
            'FLOAT': 'Float64',
            'BOOLEAN': pd.BooleanDtype(),
            'DATETIME': 'datetime64[ns]'
        }
        # Reverse table, kept because it is a public attribute. It is keyed by dtype
        # *names*, so it cannot be looked up with dtype objects (their hashes differ);
        # build_temp_table uses _bq_type_for_dtype instead.
        self.pd_to_bq_dtype = {
            pd_dtype: bq_type for bq_type, pd_dtype in self.bq_to_pd_dtype.items()
        }

    def generate_query(self, columns: list, table: str):
        '''
        Build a "select <columns> from <table>" statement.

        `columns` are joined with ", "; `table` is inserted as given (no quoting).
        '''
        # The embedded newline/indent reproduces the historical query text exactly.
        template = "select {columns}\n      from {table}"
        return template.format(table=table, columns=', '.join(columns)).strip()

    def run_query(self, columns: list = None, table: str = None, override_query: str = None):
        '''
        Run a query and return the result as a DataFrame.

        Either pass `columns` and `table` (a select statement is generated) or pass
        `override_query` with the full SQL. When `table` is not given, it is taken from
        the first "from `...`" clause of the query. If the table's schema has exactly
        the same columns as the result, dtypes are aligned via match_dtypes; otherwise
        the result is returned as delivered by BigQuery.

        Returns None (after printing the error) when the query itself fails.
        '''
        assert ((columns is not None) and (table is not None)) or (override_query is not None), "Missing arguments"

        print("Running Query")
        try:
            start = time.time()
            if override_query:
                query = override_query.strip()
            else:
                query = self.generate_query(columns, table)
            df = self.bqclient.query(query).result().to_dataframe()

            if table is None:
                table = self._table_from_query(query)
            # Dtype alignment is best effort: it must never cost us the query result.
            df = self._match_table_dtypes(df, table)

            print("Query Complete. Runtime:", round(time.time() - start, 2), "s")
            print("df shape:", df.shape)
            return df
        except Exception as e:
            # Only Exception is handled so Ctrl-C (KeyboardInterrupt) still stops the cell.
            print(e)
            return None

    @staticmethod
    def _table_from_query(query: str):
        '''Return the first backtick-quoted name following "from", or None if there is none.'''
        found = re.search(r'\bfrom\s+`([^`]+)`', query, flags=re.IGNORECASE)
        return found.group(1) if found else None

    def _match_table_dtypes(self, df: pd.DataFrame, table):
        '''Align `df` with `table`'s schema when the columns line up; otherwise return `df` unchanged.'''
        if not table:
            return df
        schema = self.get_table_schema(table)  # fetched once; None when the lookup failed
        if not schema or len(schema) != len(df.columns) or set(schema) != set(df.columns):
            return df
        try:
            return self.match_dtypes(df, schema)
        except Exception as e:
            print("Could not match dtypes:", e)
            return df

    def get_table_schema(self, table: str):
        '''
        Return {column name: BigQuery field type} for `table`, or None if the lookup fails.

        https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.schema.SchemaField
        '''
        print("Getting Table Schema")
        try:
            table_ref = self.bqclient.get_table(table)
            return {field.name: field.field_type for field in table_ref.schema}
        except Exception as e:
            print("ruh roh", e)
            return None

    def is_date(self, data):
        '''
        Returns true if column or variable with dtype datetime64[ns] has type date (no hr, min, sec info)

        Accepts a datetime-like Series (checked through the .dt accessor) or a single
        datetime-like value; anything without hour/minute/second attributes is not a date.
        '''
        try:
            return (data.dt.hour == 0).all() & (data.dt.minute == 0).all() & (data.dt.second == 0).all()
        except (AttributeError, TypeError):
            pass  # not a datetime-like Series; try it as a scalar

        try:
            return (data.hour == 0) & (data.minute == 0) & (data.second == 0)
        except (AttributeError, TypeError):
            pass

        return False

    def match_dtypes(self, df: pd.DataFrame, schema: dict):
        '''
        Change df dtypes to match schema specified (BQ) dtypes

        `schema` maps column name -> BigQuery type. Types missing from bq_to_pd_dtype
        become object; DATE columns are then parsed with the %Y-%m-%d format.
        Returns a new DataFrame.
        '''
        print("Matching dtypes")
        schema_new = {col: self.bq_to_pd_dtype.get(dtype, 'O') for col, dtype in schema.items()}
        df_type_match = df.astype(schema_new)

        for col, dtype in schema.items():
            if dtype == 'DATE':
                df_type_match[col] = pd.to_datetime(df_type_match[col], format='%Y-%m-%d')

        return df_type_match

    @staticmethod
    def _bq_type_for_dtype(dtype) -> str:
        '''Map a pandas/numpy dtype to the BigQuery type used for temp tables (STRING fallback).'''
        kinds = pd.api.types
        if kinds.is_bool_dtype(dtype):
            return 'BOOLEAN'
        if kinds.is_integer_dtype(dtype):
            return 'INTEGER'
        if kinds.is_float_dtype(dtype):
            return 'FLOAT'
        if kinds.is_datetime64_any_dtype(dtype):
            return 'DATETIME'
        return 'STRING'

    @staticmethod
    def _expiry_delta(expires=None):
        '''
        Turn the `expires` options into a timedelta without modifying the caller's dict.

        Unknown keys are reported and ignored. With no valid key the lifetime is 1 hour;
        as soon as any valid key is given, only the given keys count (hours defaults to 0).
        '''
        accepted_keys = ['days', 'seconds', 'microseconds', 'milliseconds', 'minutes', 'hours', 'weeks']
        requested = dict(expires) if expires else {}
        invalid_key = [key for key in requested if key not in accepted_keys]
        if invalid_key:
            print(invalid_key, 'not in', accepted_keys)
        lifetime = {key: value for key, value in requested.items() if key in accepted_keys}
        if not lifetime:
            lifetime = {'hours': 1}
        return timedelta(**lifetime)

    def _rows_for_insert(self, df: pd.DataFrame):
        '''Return `df` as JSON-serialisable row dicts, with an added `inserted_at` column.'''
        cleaned = df.replace({np.nan: None})  # new frame; the caller's df is not touched
        cleaned['inserted_at'] = datetime.now()
        rows = cleaned.to_dict('records')

        # Object of type x is not JSON serializable
        for row in rows:
            for key, value in row.items():
                if isinstance(value, np.ndarray):
                    row[key] = list(value)
                elif isinstance(value, datetime):
                    if pd.isna(value):
                        row[key] = None  # NaT cannot be formatted with strftime
                    elif self.is_date(value):
                        row[key] = value.strftime('%Y-%m-%d')
                    else:
                        row[key] = value.strftime('%Y-%m-%d %H:%M:%S')
        return rows

    @staticmethod
    def _chunks(rows, max_chunk_size=5000):
        '''Yield consecutive slices of `rows` holding at most `max_chunk_size` items.'''
        for offset in range(0, len(rows), max_chunk_size):
            yield rows[offset:offset + max_chunk_size]

    def build_temp_table(self, dataset: str, table_name: str, df: pd.DataFrame, expires=None) -> "bigquery.Table":
        '''
        Populate temporary table for querying purposes.
        Expires: (days, seconds, microseconds, milliseconds, minutes, hours, weeks); Table lives for default 1hr

        The table is `<project>.<dataset>.temp_<table_name>`. Its schema is derived from
        `df` (midnight-only datetime columns -> DATE, columns whose first value is a numpy
        array -> REPEATED) plus a required `inserted_at` DATETIME column. If the table
        cannot be created, the existing table of that name is used. Rows are streamed in
        chunks of 5000 and insert errors are printed.

        Returns the table the rows were inserted into.
        '''
        start = time.time()

        # Define table
        temp_table_name = f"{self.project_id}.{dataset}.temp_{table_name}"
        temp_table_def = bigquery.Table(temp_table_name)
        temp_table_def.expires = datetime.now(pytz.utc) + self._expiry_delta(expires)

        schema_list = []
        for col in df.columns:
            series = df[col]
            bq_type = 'DATE' if self.is_date(series) else self._bq_type_for_dtype(series.dtype)
            # Positional access: the index may not contain the label 0, or may be empty.
            is_repeated = len(series) > 0 and isinstance(series.iloc[0], np.ndarray)
            bq_mode = 'REPEATED' if is_repeated else 'NULLABLE'
            schema_list.append(bigquery.SchemaField(col, bq_type, mode=bq_mode))

        schema_list.append(bigquery.SchemaField('inserted_at', 'DATETIME', mode='REQUIRED'))
        temp_table_def.schema = schema_list

        try:
            temp_table = self.bqclient.create_table(temp_table_def)
            print("Table created. Runtime:", round(time.time() - start, 2), "s")
        except Exception as e:
            # Usually the table already exists; the reason is printed in case it is
            # something else, in which case get_table below raises.
            print(temp_table_name, 'already exists.')
            print(e)
            temp_table = self.bqclient.get_table(temp_table_name)
            print("Accessed table. Runtime:", round(time.time() - start, 2), "s")

        # Clean data
        data = self._rows_for_insert(df)
        print("Data transformed. Runtime:", round(time.time() - start, 2), "s")

        # Insert data - chunking
        errors = []
        for chunk in self._chunks(data):
            errors += self.bqclient.insert_rows(temp_table, chunk)

        print("Data inserted. Runtime:", round(time.time() - start, 2), "s")
        print(f"Loaded {len(data)} rows into {temp_table.dataset_id}:{temp_table.table_id} with {len(errors)} errors")

        if len(errors) > 0:
            pprint.pprint(errors)

        return temp_table

class excel:
    '''
    Reads selected sheets of one Excel workbook into DataFrames.

    Results accumulate in `self.result` as {sheet name: DataFrame}.
    '''

    def __init__(self, filepath: str, sheet_names: list):
        '''
        Remember which workbook (`filepath`) and which sheets (`sheet_names`) to load.
        '''
        self.filepath = filepath
        self.sheet_names = sheet_names
        self.result = {}

    def read_excel(self):
        '''
        Load every sheet listed in `sheet_names` and return `self.result`.

        The workbook is opened once and closed again when reading finishes or fails.
        '''
        print("Reading Excel File")
        # Context manager closes the underlying file handle, which was previously leaked.
        with pd.ExcelFile(self.filepath) as xls:
            for sheet in self.sheet_names:
                self.result[sheet] = pd.read_excel(xls, sheet_name=sheet)

        return self.result

KeyboardInterrupt: 