<a href="https://colab.research.google.com/github/timsetsfire/wandb-examples/blob/main/colab/Sarimax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
%%capture
# %pip (unlike !pip) installs into the environment of the running kernel.
%pip install -q wandb

## Simple Example of Sweeps with SARIMAX

In [3]:
import wandb
import numpy as np
import pandas as pd
from scipy.stats import norm
import statsmodels.api as sm
import matplotlib.pyplot as plt
from datetime import datetime
import requests
from io import BytesIO
# Register converters to avoid warnings
pd.plotting.register_matplotlib_converters()
plt.rc("figure", figsize=(16,8))
plt.rc("font", size=14)

# Download the Stata air2 dataset. The timeout keeps the cell from hanging on a
# stalled connection, and raise_for_status() reports HTTP errors directly
# instead of handing an error page to read_stata.
response = requests.get('https://www.stata-press.com/data/r12/air2.dta', timeout=30)
response.raise_for_status()
air2 = response.content
data = pd.read_stata(BytesIO(air2))
# Month-start ('MS') index beginning January 1 of the first value in `time`.
data.index = pd.date_range(start=datetime(data.time[0], 1, 1), periods=len(data), freq='MS')
# Log of the `air` column; this is the series the SARIMAX models are fit to.
data['lnair'] = np.log(data['air'])


In [4]:
# Sweep definition: random search over the ARIMA orders (p, d, q) and the
# trend specification, keeping the configuration with the lowest AIC.
# Seasonal terms (s_p, s_d, s_q) are not part of the search.
sweep_config = {
    "method": "random",
    "name": "sarimax-sweep",
    "metric": {"goal": "minimize", "name": "aic"},
    "parameters": {
        "trend": {"values": ["c", "t", "ct"]},
        # p, d and q share the same candidate grid (each gets its own list).
        **{order_term: {"values": [0, 1, 2]} for order_term in ("p", "d", "q")},
    },
}

In [5]:
# Hand-written example configuration (sweep parameters plus seasonal terms).
# NOTE(review): not referenced by the cells shown here — presumably kept for
# trying the training code manually; confirm before removing.
test_config = dict(trend="t", p=0, q=1, d=1, s_p=1, s_d=0, s_q=0)

In [6]:
def train_func():
    """Fit one SARIMAX model for the current sweep trial and log its AIC.

    Intended to be passed as ``function=`` to ``wandb.agent``: each call opens
    a W&B run, reads ``p``, ``d``, ``q`` and ``trend`` from the run config,
    fits a non-seasonal SARIMAX model to the log ``air`` series, and logs the
    AIC (the sweep's target metric) together with an HTML rendering of the
    model summary.
    """
    from html import escape

    with wandb.init() as run:
        config = run.config
        print("wandb.config:", config)

        # Download the dataset; the timeout avoids hanging a sweep trial forever,
        # and raise_for_status() surfaces HTTP errors before read_stata runs.
        response = requests.get('https://www.stata-press.com/data/r12/air2.dta', timeout=30)
        response.raise_for_status()
        data = pd.read_stata(BytesIO(response.content))
        data.index = pd.date_range(start=datetime(data.time[0], 1, 1), periods=len(data), freq='MS')
        data['lnair'] = np.log(data['air'])

        order = (config.p, config.d, config.q)
        # The sweep also searches over `trend`; it must reach the model, or runs
        # that differ only in trend fit identical models. config.get() tolerates
        # configs without a trend key (SARIMAX then uses no trend term).
        trend = config.get("trend")
        mod = sm.tsa.statespace.SARIMAX(
            data['lnair'],
            order=order,
            # (P, D, Q, s) = (0, 0, 0, 0): no seasonal component, which is
            # SARIMAX's documented default.
            seasonal_order=(0, 0, 0, 0),
            trend=trend,
            simple_differencing=True,
        )
        res = mod.fit()

        summary = res.summary()
        print(summary)
        # <pre> with escaped text replaces the obsolete <plaintext> element.
        model_summary_html = "<html>\n<pre>\n" + escape(summary.as_text()) + "\n</pre>\n</html>"
        run.log({"aic": res.aic, "model_summary": wandb.Html(model_summary_html)})

In [7]:
## start the sweep with wandb
# Register sweep_config with W&B under the "sarimax-test" project; the
# returned ID is what the agent uses to pull trials.
sweep_id = wandb.sweep(
    sweep_config,
    project="sarimax-test",
)

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Create sweep with ID: hge7mn0t
Sweep URL: https://wandb.ai/tim-w/sarimax-test/sweeps/hge7mn0t


In [None]:
## start an agent locally
# Run up to 5 trials of the sweep in this kernel; each trial calls train_func.
wandb_agent = wandb.agent(sweep_id, function=train_func, count=5)

## Multiprocessing

There are a few ways to run a sweep with several agents in parallel. The approach below writes the training code to `train.py`, creates a sweep whose `program` is that script, and then launches multiple `wandb agent` processes locally to complete the sweep.

In [None]:
# Define sweep config
import wandb
# Multi-process variant: each agent runs the `program` script (train.py) once
# per trial instead of calling a notebook function. Seasonal terms
# (s_p, s_d, s_q) are not part of the search.
sweep_config = {
    "program": "train.py",
    "method": "random",
    "metric": {"goal": "minimize", "name": "aic"},
    "parameters": {
        "trend": {"values": ["c", "t", "ct"]},
        # p, d and q share the same candidate grid (each gets its own list).
        **{order_term: {"values": [0, 1, 2]} for order_term in ("p", "d", "q")},
    },
}

In [None]:
## write train.py
# Training script that each `wandb agent` subprocess launches as the sweep's
# `program`. Notes on how it is written:
# - It is a raw string, so escape sequences in it (e.g. the HTML prefix's
#   backslash-n) are copied into train.py verbatim. A plain string would turn
#   them into real newlines inside a string literal, and train.py would fail
#   with a SyntaxError.
# - The script runs in its own process, so it must call wandb.init() itself to
#   obtain `run` and its config.
code = r"""
import wandb
import numpy as np
import pandas as pd
import statsmodels.api as sm
from datetime import datetime
from html import escape
import requests
from io import BytesIO

DATA_URL = 'https://www.stata-press.com/data/r12/air2.dta'


def load_data():
    # Download air2, index it by month start and add the log series.
    response = requests.get(DATA_URL, timeout=30)
    response.raise_for_status()
    data = pd.read_stata(BytesIO(response.content))
    data.index = pd.date_range(start=datetime(data.time[0], 1, 1), periods=len(data), freq='MS')
    data['lnair'] = np.log(data['air'])
    return data


def main():
    # One sweep trial: hyperparameters come from the config the agent supplies.
    with wandb.init() as run:
        config = run.config
        data = load_data()
        order = (config.p, config.d, config.q)
        mod = sm.tsa.statespace.SARIMAX(
            data['lnair'],
            order=order,
            seasonal_order=(0, 0, 0, 0),
            trend=config.get('trend'),
            simple_differencing=True,
        )
        res = mod.fit()
        summary = res.summary()
        print(summary)
        model_summary_html = "<html>\n<pre>\n" + escape(summary.as_text()) + "\n</pre>\n</html>"
        run.log({"aic": res.aic, "model_summary": wandb.Html(model_summary_html)})


if __name__ == "__main__":
    main()
"""
with open("train.py", "w", encoding="utf-8") as f:
    f.write(code)

In [None]:
# Register the train.py-based sweep in its own project; the returned ID is
# what the agent processes attach to.
sweep_id = wandb.sweep(
    sweep_config,
    project="sarimax-test-mp",
)

In [None]:
import subprocess

# Launch several `wandb agent` processes in parallel. Each one pulls trials
# from the sweep identified by `sweep_id` and runs up to RUNS_PER_AGENT of them.
AGENT_PROJECT = "sarimax-test-mp"
NUM_AGENTS = 3
RUNS_PER_AGENT = 5

processes = []
for agent_num in range(1, NUM_AGENTS + 1):
    # Each agent's output goes to its own log file. An unread subprocess.PIPE
    # would eventually fill up and block the agent (and its training runs).
    log_path = "wandb_agent_{}.log".format(agent_num)
    # The child process holds its own copy of the file descriptor, so closing
    # our handle right after Popen does not cut off its output.
    with open(log_path, "w") as log_file:
        processes.append(
            subprocess.Popen(
                [
                    "wandb",
                    "agent",
                    "--project", AGENT_PROJECT,
                    # No hard-coded --entity. The sweep was created without an
                    # explicit entity, so the agents should also use the
                    # logged-in user's default entity.
                    "--count", str(RUNS_PER_AGENT),
                    sweep_id,
                ],
                stdout=log_file,
                stderr=subprocess.STDOUT,
            )
        )
