<a href="https://colab.research.google.com/github/timsetsfire/wandb-examples/blob/main/colab/sarimax_%26_pyspark_%2B_W%26B_Sweeps.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q findspark pyspark wandb

[K     |████████████████████████████████| 1.9 MB 5.1 MB/s 
[K     |████████████████████████████████| 182 kB 50.6 MB/s 
[K     |████████████████████████████████| 166 kB 39.8 MB/s 
[K     |████████████████████████████████| 63 kB 567 kB/s 
[K     |████████████████████████████████| 166 kB 15.6 MB/s 
[K     |████████████████████████████████| 162 kB 33.7 MB/s 
[K     |████████████████████████████████| 162 kB 38.7 MB/s 
[K     |████████████████████████████████| 158 kB 48.4 MB/s 
[K     |████████████████████████████████| 157 kB 52.1 MB/s 
[K     |████████████████████████████████| 157 kB 60.3 MB/s 
[K     |████████████████████████████████| 157 kB 38.8 MB/s 
[K     |████████████████████████████████| 157 kB 48.2 MB/s 
[K     |████████████████████████████████| 157 kB 44.5 MB/s 
[K     |████████████████████████████████| 157 kB 36.0 MB/s 
[K     |████████████████████████████████| 157 kB 23.5 MB/s 
[K     |████████████████████████████████| 156 kB 35.7 MB/s 
[?25h  Building wheel for 

In [None]:
##
import numpy as np
import pandas as pd
from scipy.stats import norm
import statsmodels.api as sm
import matplotlib.pyplot as plt
from datetime import datetime
import requests
from io import BytesIO

##
# findspark.init() makes a Spark installation importable as `pyspark`, so it must run
# *before* the first pyspark import; otherwise the import fails whenever pyspark is
# only reachable through SPARK_HOME (e.g. a cluster install) rather than pip.
import random
import string

import findspark
findspark.init()

import wandb
from pyspark.sql import SparkSession

## Toy Example

In [None]:
# Local Spark session (master "local" = one worker thread) with its web UI on port 4050.
spark = (
    SparkSession.builder
    .master("local")
    .appName("Colab")
    .config('spark.ui.port', '4050')
    .getOrCreate()
)

In [None]:
## create dataset
# Airline-passenger series distributed as a Stata file by stata-press.com.
response = requests.get('https://www.stata-press.com/data/r12/air2.dta', timeout=30)
# Fail loudly on an HTTP error instead of handing an error page to read_stata.
response.raise_for_status()
air2 = response.content
data = pd.read_stata(BytesIO(air2))
# `time` is a float column; datetime() only accepts ints (a float raises TypeError on
# Python >= 3.10), so cast the first year explicitly.
first_year = int(data['time'].iloc[0])
data.index = pd.date_range(start=datetime(first_year, 1, 1), periods=len(data), freq='MS')
data['lnair'] = np.log(data['air'])
# Helpers for building a synthetic panel data set.
def create_house_id(N=7):
    """Return a random identifier of ``N`` characters drawn from A-Z and 0-9.

    Draws from the stdlib ``random`` module's global generator, so the result is
    reproducible only after ``random.seed`` has been called.
    """
    alphabet = string.ascii_uppercase + string.digits
    picked = random.choices(alphabet, k=N)
    return ''.join(picked)
# Mimic a panel data set: N_HOUSES copies of the series, each tagged with a random
# house id and perturbed with independent uniform noise so per-house models differ.
N_HOUSES = 5
SEED = 42
random.seed(SEED)     # create_house_id draws from the stdlib `random` generator
np.random.seed(SEED)  # the noise below comes from numpy's global generator

# Draw ids until N_HOUSES distinct ones exist: a collision would silently merge two
# "houses" into a single group in the Spark groupBy later on.
houses = []
while len(houses) < N_HOUSES:
    candidate = create_house_id(7)
    if candidate not in houses:
        houses.append(candidate)

# assign() returns a new frame, so `data` itself is never mutated.
house_frames = [data.assign(house_id=house) for house in houses]
mdf = pd.concat(house_frames)
mdf["lnair"] = mdf["lnair"] + np.random.rand(mdf.shape[0])
mdf["house_id_copy"] = mdf["house_id"].copy()
sparkDF = spark.createDataFrame(mdf)
sparkDF

The generated random string : FRM9GVP


DataFrame[air: bigint, time: double, t: double, lnair: double, house_id: string, house_id_copy: string]

In [None]:
import wandb

## Set up

If you run this, `wandb.login` needs access to your W&B API key. Don't hard-code the key in the notebook; provide it through the `WANDB_API_KEY` environment variable instead. On Databricks, you can set environment variables in the cluster configuration.

In [None]:
def sweep_udf(data):
    """Run a W&B hyperparameter sweep of SARIMAX models for a single house.

    Meant to be passed to ``GroupedData.applyInPandas``: Spark calls it once per
    ``house_id`` group with that group's rows as a pandas DataFrame. It creates one
    W&B sweep named after the house and runs an agent in which every run fits a
    SARIMAX model to ``data['lnair']`` and logs its AIC and summary.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows of one house; must contain ``lnair``, ``house_id`` and
        ``house_id_copy`` columns.

    Returns
    -------
    pandas.DataFrame
        ``data`` unchanged, so the output matches the input schema. The sweep
        results live in W&B, not in the returned frame.
    """
    import wandb

    # Upper bound on runs per house. The 'random' method never exhausts its search
    # space, so without a count the agent -- and the Spark task -- would never return.
    max_runs = 20

    house_id = data['house_id_copy'].unique()[0]

    def train_func():
        """Fit one SARIMAX configuration sampled by the sweep and log the results."""
        # The context manager finishes the run (marking it failed) even if fit() raises.
        with wandb.init() as run:
            config = run.config
            print("wandb.config:", config)

            order = [config.p, config.d, config.q]
            seasonal_order = (0, 0, 0, 0)  # no seasonal component (SARIMAX default)
            print("order:", order)
            print("seasonal order:", seasonal_order)
            # `trend` is one of the swept parameters, so it has to reach the model;
            # otherwise runs that differ only in trend are identical fits.
            mod = sm.tsa.statespace.SARIMAX(
                data['lnair'],
                order=order,
                seasonal_order=seasonal_order,
                trend=config.trend,
                simple_differencing=True,
            )
            res = mod.fit()
            summary_text = res.summary().as_text()
            model_summary_html = "<html>\n<plaintext>\n" + summary_text
            print("mod order:", mod.order)
            print("mod seasonal order:", mod.seasonal_order)
            print(summary_text)
            run.log({"aic": res.aic, "model_summary": wandb.Html(model_summary_html)})
            run.log({"house_id": house_id})

    sweep_config = {
        'method': 'random',
        'name': f"House ID {house_id}",
        # NOTE(review): hyperband judges runs at logged steps between min_iter and
        # max_iter; each run here logs only a couple of steps, so this likely never
        # triggers -- confirm or remove.
        'early_terminate': {
            'type': 'hyperband',
            'min_iter': 5,
            'max_iter': 10
        },
        'metric': {'goal': 'minimize', 'name': 'aic'},
        'parameters': {
            'trend': {'values': ['c', 't', 'ct']},
            'p': {'values': [0, 1, 2]},
            'd': {'values': [0, 1]},
            'q': {'values': [0, 1, 2]},
        }
    }

    # Authenticate once, before wandb.sweep (which already talks to the W&B API).
    # Without a key argument, wandb.login() reads WANDB_API_KEY / netrc, so no secret
    # is hard-coded in the notebook.
    wandb.login()
    sweep_id = wandb.sweep(sweep_config, project="sarimax-spark")
    wandb.agent(sweep_id, function=train_func, count=max_runs)

    return data




In [None]:
# One group per house; each group becomes a single sweep_udf call.
groupedSparkDF = sparkDF.groupby("house_id")

In [None]:
# sweep_udf returns its input unchanged, so the output schema is the input schema.
# This is a lazy transformation: no sweep runs until an action is called on `out`.
out = groupedSparkDF.applyInPandas(func=sweep_udf, schema=sparkDF.schema)

In [None]:
# count() is the action that actually triggers the per-house sweeps (everything above is lazy).
out.count()
## prints from sweep_udf run on the executors and will not appear here; follow progress in the W&B "sarimax-spark" project.