## Import necessary modules

In [None]:
import sys 
import swat
import matplotlib
import matplotlib.pyplot as plt 
from matplotlib import dates as mpldates
import argparse
import os
import inspect
import re

## Set global variables
You may need to change these based on the CAS configuration you are using.

In [None]:
# ---- CAS connection settings (adjust for your CAS server) ----
# Host running the CAS controller; an alternative host is kept commented out.
DEFAULT_HOST = "localhost"
# DEFAULT_HOST = "sasserver.demo.sas.com"
# Port the CAS server listens on
DEFAULT_PORT = 5570
# Connection protocol; None means autodetect
PROTOCOL = None
# Caslib for the session
DEFAULT_CASLIB = "mysess"

# ---- Data location ----
# Directory searched for the skinproduct sample table
DATA_PATH = "/opt/sas/viya/config/data/cas/default/public/"

## Define wrapper functions to create ESM and ARIMA models

NOTE: These functions use algorithms from the statsmodels package. They are not called by this Python program. Instead, they are called by the Python program that the EXTLANG package generates and runs. They are defined here to keep the embedded CMP code below simple. The same code could instead live in a separate file and be pushed into the generated Python code with the PushCodeFile method. That is the preferred approach, because inline code cannot contain certain characters. Each line is wrapped in a single-quoted SAS string, so a single quotation mark would end that string early. Everything from a `#` character to the end of the line is stripped as a comment before the code is embedded.

In [None]:
def _local_esm_fun(y, fcst_length=None):
    # Fit an additive-trend exponential smoothing model (statsmodels
    # holtwinters.ExponentialSmoothing) to y and return its predictions for
    # time indices 0 .. fcst_length - 1.
    # fcst_length defaults to the NFOR variable that the CMP code exposes to
    # the embedded Python program, so the original one-argument call keeps
    # working while the signature now matches _local_arima_fun.
    # There is deliberately no docstring: this source is embedded line by line
    # in single-quoted SAS strings, and only hash comments are removed first.
    from statsmodels.tsa.holtwinters import ExponentialSmoothing
    if fcst_length is None:
        fcst_length = NFOR
    model = ExponentialSmoothing(y, trend="additive")
    model_fit = model.fit()
    yhat = model_fit.predict(0, int(fcst_length) - 1)
    return yhat

def _local_arima_fun(y, fcst_length, dolog=True, order=(0, 1, 1),
                     seasonal_order=(0, 1, 1, 12)):
    # Fit a seasonal ARIMA model (statsmodels SARIMAX) to the series y and
    # return its predictions for time indices 0 .. fcst_length - 1.
    #   y              : input series; must be positive when dolog is true,
    #                    because the model is then fitted on log(y)
    #   fcst_length    : number of values to return (the CMP code passes NFOR)
    #   dolog          : fit on log(y) and exponentiate the predictions back
    #   order          : non-seasonal (p, d, q) order passed to SARIMAX
    #   seasonal_order : seasonal (P, D, Q, s) order passed to SARIMAX
    # There is deliberately no docstring: this source is embedded line by line
    # in single-quoted SAS strings, and only hash comments are removed first.
    # numpy is imported here so that np is always defined when the function
    # runs; this file does not import numpy at module level.
    import numpy as np
    import statsmodels.api as sm
    if dolog:
        y = np.log(y)
    model = sm.tsa.statespace.SARIMAX(y, order=order,
                                      seasonal_order=seasonal_order,
                                      enforce_invertibility=False)
    model_fit = model.fit()
    predicted = model_fit.predict(0, int(fcst_length) - 1)
    if dolog:
        predicted = np.exp(predicted)
    return predicted

## Add code to call the Timedata.RunTimeCode action

This includes creating the connection and loading the action set. 

In [None]:
if __name__ == "__main__":
    # Connection settings come straight from the globals defined above;
    # nothing is read from the command line.
    host = DEFAULT_HOST
    port = DEFAULT_PORT
    protocol = PROTOCOL
    caslib = DEFAULT_CASLIB

    # Open a CAS session with those settings
    conn = swat.CAS(host, port, caslib=caslib, protocol=protocol)

    # runTimeCode is provided by the timedata action set
    conn.loadactionset(actionset="timedata")


## Load the dataset into a CASLIB

In [None]:
    # Make the skinproduct table available in CAS; nothing is uploaded when a
    # table with that name already exists.
    indatadir = DATA_PATH
    indata = "skinproduct"
    if not conn.table.tableExists(table=indata).exists:
        path_no_ext = os.path.join(indatadir, indata)
        # The first SAS-format file found on the machine running this code
        # wins; otherwise fall back to the CSV file (uploaded without an
        # existence check, exactly as before).
        chosen = path_no_ext + ".csv"
        for ext in (".sashdat", ".sas7bdat"):
            if os.path.exists(path_no_ext + ext):
                chosen = path_no_ext + ext
                break
        tbl = conn.upload_file(chosen, casout={"name": indata})

## Specify the CMP code that is to be executed by the Timedata.RunTimeCode action 
Note how we use the `inspect` module to read the source of the `_local_esm_fun` and `_local_arima_fun` functions. We then strip all `#` comments from the text, because they would conflict with the SAS parser. In addition to the function definitions, the generated Python code sets the `PREDICT` array by calling `_local_arima_fun()` with the input series (`Y`). Besides running the code, the program creates objects that collect the log and the variable status and store the code in a table. The stored code will be read in the next example of this tutorial.
    

In [None]:
    # Read the source of both helper functions and drop every "#" comment,
    # because "#" text conflicts with the SAS parser once embedded in CMP.
    # NOTE: the pattern also removes "#" inside string literals, so the
    # helpers must not contain "#" in strings.
    hash_comment = re.compile(r"#.*")

    def remove_comments(s):
        return hash_comment.sub("", s)

    local_esm_text, local_arima_text = (
        remove_comments(inspect.getsource(fn))
        for fn in (_local_esm_fun, _local_arima_fun)
    )
    
    def makeLinePushCalls(s, objName):
        """Return CMP statements that push the lines of ``s`` one at a time.

        Each line of ``s`` becomes ``  rc = <objName>.PushCodeLine('<line>');``
        and the statements are joined with newlines. Lines are pushed
        individually because SAS will hang if it is handed one string
        containing a bunch of newlines.

        Raises:
            ValueError: if a line contains a single quote, which would
                terminate the single-quoted SAS string literal early.
        """
        out = []
        for line in s.splitlines():
            # Explicit check instead of ``assert``: asserts are skipped under -O.
            if "'" in line:
                raise ValueError(
                    "line cannot be embedded in a single-quoted SAS string: "
                    "{0!r}".format(line))
            out.append("  rc = {obj}.PushCodeLine('{line}');".format(obj=objName, line=line))
        return "\n".join(out)
    
    local_esm_text = makeLinePushCalls(local_esm_text, "py")
    local_arima_text = makeLinePushCalls(local_arima_text, "py")
    
    cmpcode = """
       declare object py (PYTHON3);
       rc = py.Initialize();   
       rc = py.AddVariable(revenue, 'ALIAS', 'Y');
       rc = py.AddVariable(pyPred, 'ALIAS', 'PREDICT', 'READONLY', 'FALSE');
       rc = py.AddVariable(_LENGTH_, 'ALIAS', 'NFOR');

      {arimaFunc}
      {esmFunc}
      
      rc = py.PushCodeLine('import numpy');
      rc = py.PushCodeLine('PREDICT = _local_arima_fun(Y, NFOR, seasonal_order=(0,0,0,0))') ;
  
      rc = py.Run();
  
      declare object pylog(OUTEXTLOG) ;
      rc = pylog.Collect(py, 'EXECUTION') ;

      declare object pyvars(OUTEXTVARSTATUS) ;
      rc = pyvars.collect(py) ;

      /* Collect the code into a table to reuse it in the next example of the tutorial*/
      declare object outcode(OUTEXTCODE) ;
      rc = outcode.setOption("RUNID", "saspedia_extlang_ex2") ;
      rc = outcode.collect(py) ;
  """.format(esmFunc=local_esm_text, arimaFunc=local_arima_text) 

In [None]:
# Echo the generated CMP program so it can be inspected before it runs
print("{}".format(cmpcode))


## Specify the parameters of the action call 
Note that we promote the outcode table. This table will contain source code to be used in the next example. Therefore, you must use the same CAS session for that example. 

In [None]:
    # runTimeCode takes deeply nested dict parameters. dict(...) calls are used
    # for the outer levels and {...} literals for the innermost ones.
    def dname(name):
        """Return a parameter dict that only carries a "name" key."""
        return dict(name=name)

    # Run the CMP program once per DistributionCenter BY group on the weekly
    # summed Revenue series, with the extlang package loaded.
    res = conn.timedata.runtimecode(
        table={
            "name": indata,
            "groupby": [dname("DistributionCenter")],
        },
        series=[dict(accumulate="SUM", name="Revenue")],
        interval="Week",
        require=dict(pkg="extlang"),
        timeid=dict(name="date"),
        lead=12,
        # The pyPred array (PREDICT on the Python side) goes to table outarray
        arrayout={
            "arrays": [dname("pyPred")],
            "table": dict(name="outarray", replace=True),
        },
        # Write the objects declared in the CMP code to tables; objRef names the
        # CMP object. Spelled consistently as objRef for every entry.
        objout=[
            dict(table=dname("outobj_pylog"), objRef="pylog"),
            dict(table=dname("outobj_pyvars"), objRef="pyvars"),
            # Promoted because the next tutorial example reads the code from it
            dict(table=swat.CASTable("outcode", promote=True), objRef="outcode"),
        ],
        code=cmpcode,
    )
    print(res)
    print(conn)

## Print Python log
This should just print some output from the statsmodels package. 

In [None]:
    # Show the Python log collected by the pylog object (table outobj_pylog);
    # print nothing when the log is empty.
    outlog_tbl = conn.CASTable("outobj_pylog")
    log_length = sum(outlog_tbl["_LOGLEN_"].values)
    if log_length > 0:
        text = "".join(outlog_tbl["_LOGTEXT_"].values)
        # Emits "LOG:", the log text, and a blank line, one per line
        print("LOG:", text, "", sep="\n")

## Check writable variables' status
If any writable variable was not updated (that is, its data was not brought back to the SAS program), print its name.
This code should not produce any output.

In [None]:
    # Warn about each writable variable whose UPDATED flag is not 1, i.e. whose
    # data was not brought back to the SAS program.
    varstats_tbl = conn.CASTable("outobj_pyvars")
    varnames, updated = (varstats_tbl[col].values for col in ("_NAME_", "UPDATED"))
    for varname, is_upd in zip(varnames, updated):
        if int(is_upd) == 1:
            continue  # data came back; nothing to report
        print("WARNING :: Variable {0} was NOT updated".format(varname))

## Tabulate a summary of the original and predicted values 


In [None]:
    # Print dates, BY-group values, original revenue and predictions from the
    # output array table as a fixed-width text table.
    # NOTE(review): the table is created as "outarray" and the array as
    # "pyPred"; these lookups rely on CAS names being case-insensitive — confirm.
    outarray_tbl = conn.CASTable("OUTARRAY")

    # dates/distr_centers/orig_values/predicted_values are also used by the
    # plotting cell below.
    column_names = ["DATE", "DistributionCenter", "Revenue", "pypred"]
    dates, distr_centers, orig_values, predicted_values = \
        [outarray_tbl[col].values for col in column_names]
    print("------------------------- RESULTS --------------------------")
    print("{0:^12}|{1:^20}|{2:^15}|{3:^15}".format("Date", "DistributionCenter", "Revenue", "PREDICTED"))
    print("-" * 62)
    row_fmt = "{0:<12}|{1:^20}|{2:<15}|{3:<15}"
    for date, distr_center, revenue, pypred in zip(
            dates, distr_centers, orig_values, predicted_values):
        print(row_fmt.format(str(date), distr_center, revenue, pypred))
    print("\n")

## Plot the predicted and actual revenue for Miami, Chicago, and Atlanta 

In [None]:
    # Plot actual revenue (hollow black circles) against the predictions
    # (colored lines) for three distribution centers and save the figure.
    from pandas.plotting import register_matplotlib_converters
    # Register the pandas date converters with matplotlib
    register_matplotlib_converters()

    # Draw on a fresh figure rather than whatever axes happen to be current
    fig, ax = plt.subplots()
    ax.set_ylabel("Revenue")
    for city, color in zip(["Miami", "Chicago", "Atlanta"], ["green", "purple", "red"]):
        # Row positions belonging to this distribution center
        idc = [i for i, ctr in enumerate(distr_centers) if ctr == city]
        ctr_orig = orig_values[idc]
        ctr_predicted = predicted_values[idc]
        dates_1x = dates[idc]
        ax.scatter(dates_1x, ctr_orig, color="black", marker="o", lw=0.5,
                   facecolors="none", edgecolors="black")
        ax.plot(dates_1x, ctr_predicted, color=color, ls="-", lw=0.5,
                label=city)
    ax.legend()
    # Set x ticks to be less frequent so text doesn't get clobbered
    ax.xaxis.set_major_locator(mpldates.YearLocator())
    # The statsmodels ARIMA function produces a big outlier, so set the limits
    ax.set_ylim(1000000, 10000000)
    fig.savefig("extlang_ex2_skincare.png")