In [None]:
# Create your tasks here
#from __future__ import absolute_import, unicode_literals
from celery.task.schedules import crontab
from celery.decorators import periodic_task, task
from celery.utils.log import get_task_logger
#from celery import shared_task
from celery.canvas import subtask
from celery.result import AsyncResult
from model_engine import ModelRunnerSet
from orm.models.htmmodel import ModelSet, HtmModel, Prediction
from .search import getData
import time
#from orm.models.project import Project

# Celery task logger for this module (named after the module via __name__).
logger = get_task_logger(__name__)

#DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"

# Periodic dispatcher: once a minute, start run_model_predictions for every idle active model set.
@periodic_task(
    run_every=(crontab(minute='*')),
    name="run_all_model_predictions",
)
def run_all_model_predictions ():
    """
    Start a run_model_predictions task for every active, idle model set.

    A model set is idle when its ``task_id`` column is "not running". The set
    is claimed *before* the task is dispatched: a fresh task id is written with
    a conditional update that only succeeds while the set is still idle, and
    the task is then sent with that same id. Writing the id after dispatching
    (as before) raced with fast tasks. A task could finish and reset
    ``task_id`` to "not running" before the dispatcher overwrote it with the
    finished task's id, and the set would then never be scheduled again.
    """
    import uuid

    # get all active model sets
    active_sets = ModelSet.objects.filter(active=True).values("id", "task_id")
    logger.info("active sets: {}".format(active_sets))

    for key in active_sets:
        if key["task_id"] != "not running":
            logger.info("already running")
            continue

        task_id = str(uuid.uuid4())
        # Conditional update: only claims the set if nobody else did meanwhile.
        claimed = ModelSet.objects.filter(
            id=key["id"], task_id="not running").update(task_id=task_id)
        if not claimed:
            logger.info("already running")
            continue

        try:
            subtask('run_model_predictions', args=(key["id"], )).apply_async(
                task_id=task_id)
        except Exception:
            # Dispatch failed: release the claim so a later tick can retry.
            ModelSet.objects.filter(
                id=key["id"], task_id=task_id).update(task_id="not running")
            raise


@task (name="run_model_predictions")
def run_model_predictions(modelset_id, n=None, defaultValue=0):
    """
    Run the models of one model set over the data it has not processed yet.

    A newly created model set is fed all data from its project index.
    Otherwise only data after the models' last processed timestamp is
    fetched. At most ``n`` rows are processed (all rows when ``n`` is None).
    Prediction fields missing from a row are filled with ``defaultValue``.
    The models are saved every 1000 rows and once more at the end. The loop
    stops early when the model set's ``stop_running`` flag is set.

    However the run ends (normally, with no new data, after a stop request or
    with an exception), the ``finally`` block releases the model set. It
    clears ``stop_running`` and resets ``task_id`` to "not running", so that
    run_all_model_predictions can schedule the set again. Previously this only
    happened when at least one row was processed, so a set with no new data
    (or a failed run) stayed marked as running forever.

    :param modelset_id: primary key of the ModelSet to run
    :param n: optional upper bound on the number of rows to process
    :param defaultValue: value substituted for missing prediction fields
    :return: 'finished model predictions'
    """
    try:
        mset = ModelSet.objects.filter(id=modelset_id).values(
            "newly_created", "aggregation_interval",
            'project__name', 'project__identifier')[0]

        newly_created = mset["newly_created"]
        ival = mset["aggregation_interval"]
        index = '%s_%s' % (mset['project__name'].replace(' ', '-').lower(),
                           mset['project__identifier'])

        logger.info("loading data")
        if newly_created:
            data = getData(index, ival=ival)
        else:
            # Get last timestamps that have been processed
            last_timestamps = HtmModel.objects.filter(
                model_set_id=modelset_id).values("last_timestamp_processed")

            # get all data since last timestamp
            lts = last_timestamps[0]["last_timestamp_processed"]
            logger.info("last timestamp processed: {}".format(lts))
            data = getData(index, timestamp=lts, ival=ival)

        if n is None or len(data) < n:
            n = len(data)
        else:
            data = data[:n]

        if n < 1:
            # Nothing new to process; the finally block still releases the set.
            logger.info("no new data to process")
            return 'finished model predictions'

        logger.info("number of input fields {}".format(len(data[0])))
        logger.info("running predictions")
        run_model_predictions.update_state(state="LOADING")
        modelrunners = ModelRunnerSet(
            modelset_id=modelset_id,
            pretrained=not newly_created
        )
        predFields = modelrunners.getPredFields()

        for step, row in enumerate(data):
            # Re-read the flag every row so a stop request takes effect quickly.
            stop_running = ModelSet.objects.filter(
                id=modelset_id).values("stop_running")[0]["stop_running"]
            if stop_running:
                break

            # Periodic checkpoint so a long run does not lose all progress.
            if ((step + 1) % 1000) == 0:
                logger.info('no. of rows processed: {}'.format(step))
                run_model_predictions.update_state(state="SAVING")
                modelrunners.save()

            run_model_predictions.update_state(
                state="PROCESSING {}%".format(step * float(100) / n))

            for field in predFields:
                if field not in row:
                    row[field] = defaultValue
            modelrunners.run(row)

        run_model_predictions.update_state(state="SAVING")
        modelrunners.save()
        # Only once trained models have been persisted is the set no longer
        # "newly created"; marking it earlier would make a failed first run
        # come back with pretrained=True and no saved models.
        if newly_created:
            ModelSet.objects.filter(id=modelset_id).update(newly_created=False)
        run_model_predictions.update_state(state="SUCCESS")
        return 'finished model predictions'
    finally:
        # Allow to start predictions again
        ModelSet.objects.filter(id=modelset_id).update(
            stop_running=False, task_id="not running")