# Basic engine time profiling

In [168]:
from collections import OrderedDict

import torch

from ignite.engine import Engine, Events
from ignite.handlers import Timer


def remove_handler(engine, handler, event_name):
    assert event_name in engine._event_handlers    
    engine._event_handlers[event_name] = [(h, args, kwargs) 
                                          for h, args, kwargs in engine._event_handlers[event_name] 
                                            if h != handler]

class BasicTimeProfiler(object):

    def __init__(self, output_path=None):
        self.output_path = output_path
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()

    def _reset(self, num_iters, num_epochs, total_num_iters):
        self.dataflow_times = torch.zeros(total_num_iters)
        self.processing_times = torch.zeros(total_num_iters)        
        self.event_handlers_times = {
            Events.STARTED: torch.zeros(1),
            Events.COMPLETED: torch.zeros(1),
            Events.EPOCH_STARTED: torch.zeros(num_epochs),
            Events.EPOCH_COMPLETED: torch.zeros(num_epochs),
            Events.ITERATION_STARTED: torch.zeros(total_num_iters),
            Events.ITERATION_COMPLETED: torch.zeros(total_num_iters)
        }

    def _as_first_started(self, engine):
        num_iters = engine.state.max_epochs * len(engine.state.dataloader)
        self._reset(len(engine.state.dataloader), engine.state.max_epochs, num_iters)
        
        self.event_handlers_names = {
            e: [h.__name__ for (h, _, _) in engine._event_handlers[e]]
            for e in Events if e != Events.EXCEPTION_RAISED
        }
        
        # Setup all other handlers:
        engine._event_handlers[Events.STARTED].append((self._as_last_started, (), {}))
        #  - add the first handlers 
        events = [Events.EPOCH_STARTED, Events.EPOCH_COMPLETED, 
                  Events.ITERATION_STARTED, Events.ITERATION_COMPLETED, 
                  Events.COMPLETED]
        fmethods = [self._as_first_epoch_started, self._as_first_epoch_completed, 
                   self._as_first_iter_started, self._as_first_iter_completed,
                   self._as_first_completed]
        lmethods = [self._as_last_epoch_started, self._as_last_epoch_completed, 
                   self._as_last_iter_started, self._as_last_iter_completed,
                   self._as_last_completed]

        for e, m in zip(events, fmethods):
            engine._event_handlers[e].insert(0, (m, (), {}))

        for e, m in zip(events, lmethods):
            engine._event_handlers[e].append((m, (), {}))

        # Let's go
        self._event_handlers_timer.reset()

    def _as_last_started(self, engine):        
        self.event_handlers_times[Events.STARTED][0] = self._event_handlers_timer.value()

    def _as_first_epoch_started(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_started(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_STARTED][e] = t

        self._dataflow_timer.reset()

    def _as_first_iter_started(self, engine):        
        t = self._dataflow_timer.value()
        i = engine.state.iteration - 1
        self.dataflow_times[i] = t

        self._event_handlers_timer.reset()

    def _as_last_iter_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_STARTED][i] = t

        self._processing_timer.reset()

    def _as_first_iter_completed(self, engine):
        t = self._processing_timer.value()
        i = engine.state.iteration - 1
        self.processing_times[i] = t

        self._event_handlers_timer.reset()        

    def _as_last_iter_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_COMPLETED][i] = t

        self._dataflow_timer.reset()

    def _as_first_epoch_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_epoch_completed(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_COMPLETED][e] = t

    def _as_first_completed(self, engine):
        self._event_handlers_timer.reset()

    def _as_last_completed(self, engine):
        self.event_handlers_times[Events.COMPLETED][0] = self._event_handlers_timer.value()

        # Display stats
        self.print_results(self.get_results())

        # Write results
        if self.output_path is not None:
            self.write_results(self.output_path)
        
        # Remove added handlers:
        remove_handler(engine, self._as_last_started, Events.STARTED)
        
        #  - add the first handlers 
        events = [Events.EPOCH_STARTED, Events.EPOCH_COMPLETED, 
                  Events.ITERATION_STARTED, Events.ITERATION_COMPLETED, 
                  Events.COMPLETED]
        fmethods = [self._as_first_epoch_started, self._as_first_epoch_completed, 
                   self._as_first_iter_started, self._as_first_iter_completed,
                   self._as_first_completed]
        lmethods = [self._as_last_epoch_started, self._as_last_epoch_completed, 
                   self._as_last_iter_started, self._as_last_iter_completed,
                   self._as_last_completed]

        for e, m in zip(events, fmethods):
            remove_handler(engine, m, e)          

        for e, m in zip(events, lmethods):
            remove_handler(engine, m, e)

    def attach(self, engine):
        if not isinstance(engine, Engine):
            raise TypeError("Argument engine should be ignite.engine.Engine, "
                            "but given {}".format(type(engine)))
        
        if not engine.has_event_handler(self._as_first_started):
            engine._event_handlers[Events.STARTED].insert(0, (self._as_first_started, (), {}))
    
    @staticmethod
    def _compute_basic_stats(data):    
        return OrderedDict([
            ('min/index', (torch.min(data).item(), torch.argmin(data).item())),
            ('max/index', (torch.max(data).item(), torch.argmax(data).item())),
            ('mean', torch.mean(data).item()),
            ('std', torch.std(data).item()),
            ('total', torch.sum(data).item())
        ])
            
    def get_results(self):
        total_eh_time = sum([sum(self.event_handlers_times[e]) for e in Events if e != Events.EXCEPTION_RAISED])
        return OrderedDict([
            ("processing_stats", self._compute_basic_stats(self.processing_times)),
            ("dataflow_stats", self._compute_basic_stats(self.dataflow_times)),
            ("event_handlers_stats",
                dict([(str(e).replace(".","_"), self._compute_basic_stats(self.event_handlers_times[e])) for e in Events 
                      if e != Events.EXCEPTION_RAISED] + 
                     [("total_time", total_eh_time)])
            ),
            ("event_handlers_names", {str(e).replace(".","_") + "_names": v 
                                      for e, v in self.event_handlers_names.items()})
        ])

    @staticmethod
    def odict_to_str(d):
        out = ""
        for k, v in d.items():
            out += "\t{}: {}\n".format(k, v)
        return out

    @staticmethod
    def print_results(results):
        
        others = {k: BasicTimeProfiler.odict_to_str(v) if isinstance(v, OrderedDict) else v 
                  for k, v in results['event_handlers_stats'].items()}
        
        others.update(results['event_handlers_names'])
        
        output_message = """
--------------------------------------------
- Time profiling results:
--------------------------------------------
              
Processing function time stats (in seconds):
{processing_stats}

Dataflow time stats (in seconds):
{dataflow_stats}
        
Time stats of event handlers (in seconds):
- Total time spent:
\t{total_time}

- Events.STARTED:
{Events_STARTED}
Handlers names:
{Events_STARTED_names}

- Events.EPOCH_STARTED:
{Events_EPOCH_STARTED}
Handlers names:
{Events_EPOCH_STARTED_names}

- Events.ITERATION_STARTED:
{Events_ITERATION_STARTED}
Handlers names:
{Events_ITERATION_STARTED_names}

- Events.ITERATION_COMPLETED:
{Events_ITERATION_COMPLETED}
Handlers names:
{Events_ITERATION_COMPLETED_names}

- Events.EPOCH_COMPLETED:
{Events_EPOCH_COMPLETED}
Handlers names:
{Events_EPOCH_COMPLETED_names}

- Events.COMPLETED:
{Events_COMPLETED}
Handlers names:
{Events_COMPLETED_names}

""".format(processing_stats=BasicTimeProfiler.odict_to_str(results['processing_stats']), 
           dataflow_stats=BasicTimeProfiler.odict_to_str(results['dataflow_stats']), 
           **others)
        print(output_message)
    
    @staticmethod
    def write_results(output_path):
        try:
            import pandas as pd
        except ImportError:
            print("Need pandas to write results as files")
            return
        
        raise NotImplementedError("")

In [169]:
from ignite.engine import Engine, Events

import time

In [170]:
def processing_function(engine, batch):
    time.sleep(0.12345)
    return 

In [171]:
engine = Engine(processing_function)

In [172]:
@engine.on(Events.STARTED)
def f1(engine):
    print("f1 - Events.STARTED")
    time.sleep(0.1)
    
@engine.on(Events.STARTED)
def f2(engine):
    print("f2 - Events.STARTED")
    time.sleep(0.2)    

In [173]:
@engine.on(Events.COMPLETED)
def f3(engine):
    print("f3 - Events.COMPLETED")    
    time.sleep(0.11)
    
@engine.on(Events.COMPLETED)
def f4(engine):
    print("f4 - Events.COMPLETED")    
    time.sleep(0.22)    

In [174]:
profiler = BasicTimeProfiler()

profiler.attach(engine)

In [175]:
@engine.on(Events.EPOCH_STARTED)
def f5(engine):
    print("f5 - Events.EPOCH_STARTED")
    time.sleep(0.23)
        
@engine.on(Events.EPOCH_COMPLETED)
def f6(engine):
    print("f6 - Events.EPOCH_COMPLETED")
    time.sleep(0.12)    

In [176]:
@engine.on(Events.ITERATION_STARTED)
def f7(engine):
    time.sleep(0.0123)
        
@engine.on(Events.ITERATION_COMPLETED)
def f8(engine):
    time.sleep(0.5333)    

In [177]:
class DataLoader():
    
    def __init__(self):
        self.data = list(range(10))

    def __getitem__(self, i):
        time.sleep(0.0111)
        return self.data[i]

    def __len__(self):
        return len(self.data)



data = DataLoader()

engine.run(data, max_epochs=2)

f1 - Events.STARTED
f2 - Events.STARTED
f5 - Events.EPOCH_STARTED
f6 - Events.EPOCH_COMPLETED
f5 - Events.EPOCH_STARTED
f6 - Events.EPOCH_COMPLETED
f3 - Events.COMPLETED
f4 - Events.COMPLETED

--------------------------------------------
- Time profiling results:
--------------------------------------------
              
Processing function time stats (in seconds):
	min/index: (0.12363585084676743, 8)
	max/index: (0.12369796633720398, 6)
	mean: 0.12366489320993423
	std: 1.790072383300867e-05
	total: 2.4732978343963623


Dataflow time stats (in seconds):
	min/index: (0.011218300089240074, 3)
	max/index: (0.011265149340033531, 7)
	mean: 0.011234122328460217
	std: 1.2680082363658585e-05
	total: 0.22468245029449463

        
Time stats of event handlers (in seconds):
- Total time spent:
	22.268726348876953

- Events.STARTED:
	min/index: (0.30058592557907104, 0)
	max/index: (0.30058592557907104, 0)
	mean: 0.30058592557907104
	std: nan
	total: 0.30058592557907104

Handlers names:
['_as_firs

<ignite.engine.engine.State at 0x7fd1a609f7f0>

In [120]:
res = profiler.get_results()

In [121]:
d = res['event_handlers_stats']

In [125]:
{k: type(v) for k, v in res['event_handlers_stats'].items()}

{'Events_COMPLETED': collections.OrderedDict,
 'Events_EPOCH_COMPLETED': collections.OrderedDict,
 'Events_EPOCH_STARTED': collections.OrderedDict,
 'Events_ITERATION_COMPLETED': collections.OrderedDict,
 'Events_ITERATION_STARTED': collections.OrderedDict,
 'Events_STARTED': collections.OrderedDict,
 'total_time': torch.Tensor}

In [108]:
def odict_to_str(d):
    out = ""
    for k, v in d.items():
        out += "\t{}: {}\n".format(k, v)
    return out

In [109]:
print(odict_to_str(d))

	min/index: (0.12362832576036453, 17)
	max/index: (0.12369902431964874, 10)
	mean: 0.12365998327732086
	std: 1.8820035620592535e-05
	total: 2.4731996059417725



In [84]:
res['event_handlers_stats']

{'Events.COMPLETED': OrderedDict([('min/index', (0.3306533098220825, 0)),
              ('max/index', (0.3306533098220825, 0)),
              ('mean', 0.3306533098220825),
              ('std', nan),
              ('total', 0.3306533098220825)]),
 'Events.EPOCH_COMPLETED': OrderedDict([('min/index',
               (0.12030424177646637, 0)),
              ('max/index', (0.12030671536922455, 1)),
              ('mean', 0.12030547857284546),
              ('std', 1.7490941672804183e-06),
              ('total', 0.24061095714569092)]),
 'Events.EPOCH_STARTED': OrderedDict([('min/index', (0.23040813207626343, 1)),
              ('max/index', (0.23040902614593506, 0)),
              ('mean', 0.23040857911109924),
              ('std', 6.322027275018627e-07),
              ('total', 0.4608171582221985)]),
 'Events.ITERATION_COMPLETED': OrderedDict([('min/index',
               (0.03337053954601288, 13)),
              ('max/index', (0.03342927619814873, 17)),
              ('mean', 0.03340839

In [57]:
profiler._compute_basic_stats(profiler.processing_times)

{'max': (0.12424758076667786, 233),
 'mean': 0.12366418540477753,
 'min': (0.12356055527925491, 222),
 'std': 4.8520789277972654e-05,
 'total': 30.916046142578125}

In [58]:
profiler._compute_basic_stats??

[0;31mSignature:[0m [0mprofiler[0m[0;34m.[0m[0m_compute_basic_stats[0m[0;34m([0m[0mdata[0m[0;34m)[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mSource:[0m   
    [0;34m@[0m[0mstaticmethod[0m[0;34m[0m
[0;34m[0m    [0;32mdef[0m [0m_compute_basic_stats[0m[0;34m([0m[0mdata[0m[0;34m)[0m[0;34m:[0m    [0;34m[0m
[0;34m[0m        [0;32mreturn[0m [0mOrderedDict[0m[0;34m([0m[0;34m[0m
[0;34m[0m            [0;34m([0m[0;34m'min/index'[0m[0;34m,[0m [0;34m([0m[0mtorch[0m[0;34m.[0m[0mmin[0m[0;34m([0m[0mdata[0m[0;34m)[0m[0;34m.[0m[0mitem[0m[0;34m([0m[0;34m)[0m[0;34m,[0m [0mtorch[0m[0;34m.[0m[0margmin[0m[0;34m([0m[0mdata[0m[0;34m)[0m[0;34m.[0m[0mitem[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m)[0m[0;34m,[0m[0;34m[0m
[0;34m[0m            [0;34m([0m[0;34m'max/index'[0m[0;34m,[0m [0;34m([0m[0mtorch[0m[0;34m.[0m[0mmax[0m[0;34m([0m[0mdata[0m[0;34m)[0m[0;34m.[0m[0mitem

In [59]:
compute_basic_stats??

[0;31mSignature:[0m [0mcompute_basic_stats[0m[0;34m([0m[0mdata[0m[0;34m)[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mSource:[0m   
[0;32mdef[0m [0mcompute_basic_stats[0m[0;34m([0m[0mdata[0m[0;34m)[0m[0;34m:[0m    [0;34m[0m
[0;34m[0m    [0;32mreturn[0m [0;34m{[0m[0;34m[0m
[0;34m[0m        [0;34m'min'[0m[0;34m:[0m [0;34m([0m[0mtorch[0m[0;34m.[0m[0mmin[0m[0;34m([0m[0mdata[0m[0;34m)[0m[0;34m.[0m[0mitem[0m[0;34m([0m[0;34m)[0m[0;34m,[0m [0mtorch[0m[0;34m.[0m[0margmin[0m[0;34m([0m[0mdata[0m[0;34m)[0m[0;34m.[0m[0mitem[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m,[0m[0;34m[0m
[0;34m[0m        [0;34m'max'[0m[0;34m:[0m [0;34m([0m[0mtorch[0m[0;34m.[0m[0mmax[0m[0;34m([0m[0mdata[0m[0;34m)[0m[0;34m.[0m[0mitem[0m[0;34m([0m[0;34m)[0m[0;34m,[0m [0mtorch[0m[0;34m.[0m[0margmax[0m[0;34m([0m[0mdata[0m[0;34m)[0m[0;34m.[0m[0mitem[0m[0;34m([0m[0;34m)[0m[0;34m)

In [None]:
engine._event_handlers

In [None]:
profiler.event_handlers_times

In [52]:
# profiler.dataflow_times

In [53]:
profiler.processing_times

tensor([0.1237, 0.1236, 0.1236, 0.1237, 0.1237, 0.1237, 0.1236, 0.1237, 0.1237,
        0.1237, 0.1237, 0.1236, 0.1237, 0.1237, 0.1237, 0.1237, 0.1237, 0.1236,
        0.1236, 0.1237, 0.1236, 0.1236, 0.1237, 0.1236, 0.1237, 0.1237, 0.1236,
        0.1237, 0.1237, 0.1236, 0.1237, 0.1237, 0.1236, 0.1236, 0.1237, 0.1236,
        0.1237, 0.1237, 0.1236, 0.1237, 0.1237, 0.1237, 0.1236, 0.1237, 0.1237,
        0.1237, 0.1236, 0.1237, 0.1237, 0.1237, 0.1237, 0.1237, 0.1237, 0.1237,
        0.1237, 0.1236, 0.1237, 0.1237, 0.1236, 0.1237, 0.1237, 0.1236, 0.1237,
        0.1237, 0.1236, 0.1237, 0.1237, 0.1237, 0.1237, 0.1236, 0.1237, 0.1236,
        0.1236, 0.1237, 0.1237, 0.1237, 0.1237, 0.1236, 0.1236, 0.1237, 0.1236,
        0.1237, 0.1237, 0.1236, 0.1236, 0.1237, 0.1237, 0.1236, 0.1237, 0.1237,
        0.1237, 0.1237, 0.1236, 0.1237, 0.1237, 0.1236, 0.1237, 0.1237, 0.1237,
        0.1236, 0.1237, 0.1237, 0.1236, 0.1237, 0.1237, 0.1237, 0.1237, 0.1237,
        0.1237, 0.1237, 0.1237, 0.1237, 

In [27]:
0.12345 * 5 * 50

30.862500000000004

In [25]:
def compute_basic_stats(data):    
    return {
        'min': (torch.min(data).item(), torch.argmin(data).item()),
        'max': (torch.max(data).item(), torch.argmax(data).item()),
        'mean': torch.mean(data).item(),
        'std': torch.std(data).item(),
        'total': torch.sum(data).item()
    }

In [54]:
compute_basic_stats(profiler.processing_times)

{'max': (0.12424758076667786, 233),
 'mean': 0.12366418540477753,
 'min': (0.12356055527925491, 222),
 'std': 4.8520789277972654e-05,
 'total': 30.916046142578125}

In [28]:
compute_basic_stats(profiler.dataflow_times)

{'max': (0.011513255536556244, 247),
 'mean': 0.011294993571937084,
 'min': (0.011203711852431297, 244),
 'std': 3.814018418779597e-05,
 'total': 2.8237483501434326}

In [29]:
compute_basic_stats(profiler.event_handlers_times[Events.EPOCH_STARTED])

{'max': (0.23074786365032196, 0),
 'mean': 0.23057159781455994,
 'min': (0.23048798739910126, 2),
 'std': 0.00010309197386959568,
 'total': 1.152858018875122}

In [30]:
compute_basic_stats(profiler.event_handlers_times[Events.EPOCH_COMPLETED])

{'max': (0.1205575242638588, 2),
 'mean': 0.12043551355600357,
 'min': (0.12034600228071213, 0),
 'std': 8.179348515113816e-05,
 'total': 0.6021775603294373}

In [38]:
sum([sum(profiler.event_handlers_times[e]) for e in Events if e != Events.EXCEPTION_RAISED])

TypeError: 'float' object is not iterable

In [40]:
a = [1, 2]
b = [3,]
a + b

[1, 2, 3]

In [71]:
[str(e) for e in Events]

['Events.EPOCH_STARTED',
 'Events.EPOCH_COMPLETED',
 'Events.STARTED',
 'Events.COMPLETED',
 'Events.ITERATION_STARTED',
 'Events.ITERATION_COMPLETED',
 'Events.EXCEPTION_RAISED']