In [1]:
# Pickled benchmark results to compare, one file per benchmark run.
# Alternative pairing used earlier (PR vs nightly):
#     "output/20230313-133243-pr.pkl", "output/20230313-134520-nightly.pkl"
perf_files = [
    f"output/{run}.pkl"
    for run in ("20230313-133243-pr", "20230315-011856-pr")
]

In [2]:
import pickle
from pathlib import Path
import torch
import torch.utils.benchmark as benchmark

# Concatenate the "test_results" stored in every results file and hand them to
# torch's Compare (presumably benchmark Measurements, since Compare consumes them).
# NOTE(review): pickle.load can execute arbitrary code from the file — only load
# result files produced by your own benchmark runs.
ab_results = []
for perf_filepath in perf_files:
    perf_path = Path(perf_filepath)
    assert perf_path.exists(), f"{perf_filepath} is not found"
    with perf_path.open("rb") as handler:
        output = pickle.load(handler)
    ab_results += output["test_results"]

compare = benchmark.Compare(ab_results)

No CUDA runtime is found, using CUDA_HOME='/opt/conda'


In [3]:
# Stock torch comparison table for every loaded result (baseline view).
# Optionally call `compare.colorize()` first for colored output.
print(compare)

[------------------------------------------------------------------------- Resize -------------------------------------------------------------------------]
                                                                 |  Pillow (9.0.0.post1)  |  torch (2.1.0a0+git1d3a939) PR  |  torch (2.1.0a0+git0c58e8a) PR
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=True     |          38.3          |                56.3             |                58.3           
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=False    |                        |                36.2             |                39.5           
      3 torch.uint8 channels_last bilinear 256 -> 224 aa=True    |         127.6          |               149.9             |               159.8           
      3 torch.uint8 channels_last bilinear 256 -> 224 aa=F

In [4]:
from torch.utils.benchmark.utils import common

In [5]:
# Merge replicate Measurements of the same task (torch's `Measurement.merge`),
# starting from the list Compare keeps internally.
all_measurements = compare._results  # NOTE: private attribute of torch's Compare
results = common.Measurement.merge(all_measurements)

In [6]:
# Bucket the merged measurements by benchmark label, keeping first-seen label order.
# This mirrors torch's private `Compare._group_by_label` (grouping on `.label`)
# without depending on a private method that may change between torch versions.
grouped_results = {}
for measurement in results:
    grouped_results.setdefault(measurement.label, []).append(measurement)

In [7]:
import warnings

# Results are grouped per benchmark label (one table each); the cells below
# post-process a single label group. Re-run `group = next(groups_iter)` to move
# on to the next label.
groups_iter = iter(grouped_results.values())
group = next(groups_iter, None)
if group is None:
    # A bare next() would surface as an unexplained StopIteration here.
    raise ValueError("No benchmark measurements were loaded; check `perf_files`.")
if len(grouped_results) > 1:
    # Extra label groups used to be ignored silently.
    warnings.warn(
        f"Found {len(grouped_results)} benchmark labels; only the first one "
        "is post-processed by the following cells."
    )

In [9]:
c1 = "torch (2.1.0a0+git1d3a939) PR"
c2 = "torch (2.1.0a0+git0c58e8a) PR"
# description = f"Speed-up: {c1} vs {c2}"
description = f"Speed-up: PR 1 vs PR 2"

In [10]:
class Value(common.Measurement):
    """A Measurement that carries a derived number (e.g. a speed-up ratio), not a timing.

    It behaves exactly like ``common.Measurement``; the subclass exists only so that
    derived cells can be told apart from real timings via ``isinstance(m, Value)``.
    """

In [11]:
# for measurement in group:
#     print(measurement.task_spec.description, measurement.task_spec.sub_label)

In [12]:
def add_speedup_measurements(measurements, baseline, candidate, speedup_description):
    """Return ``measurements`` plus one speed-up ``Value`` per row that has both columns.

    For every table row that contains a measurement whose description is ``baseline``
    and one whose description is ``candidate``, a ``Value`` whose median is
    ``median(candidate) / median(baseline)`` is placed right after that pair, in a
    new column named ``speedup_description``. Rows missing either side get no
    speed-up cell, instead of being paired with a stale value from an earlier row.

    Args:
        measurements: list of Measurements for a single label group.
        baseline: description (column name) used as the denominator.
        candidate: description (column name) used as the numerator.
        speedup_description: column name for the generated speed-up cells.

    Returns:
        A new list: every input measurement (each once, in input order), with a
        candidate moved next to its baseline and followed by its speed-up Value.
    """
    # The table divides every cell by the time scale picked from the smallest real
    # median; pre-multiplying the ratio by that scale makes the cell show the ratio.
    _, scale = common.select_unit(min(m.median for m in measurements))

    def row_key(m):
        # Identify the rendered row: thread count, env and row name (the row name
        # is sub_label, falling back to stmt when sub_label is unset).
        spec = m.task_spec
        return spec.num_threads, spec.env, spec.sub_label or spec.stmt

    # First candidate per row ("first match wins", like the previous linear scan).
    candidate_by_row = {}
    for m in measurements:
        if m.task_spec.description == candidate:
            candidate_by_row.setdefault(row_key(m), m)

    output, seen = [], set()

    def append_once(m):
        # Dedup by identity: a candidate is appended next to its baseline and must
        # not be appended again when the loop reaches it later.
        if id(m) not in seen:
            seen.add(id(m))
            output.append(m)

    for m in measurements:
        append_once(m)
        if m.task_spec.description != baseline:
            continue
        paired = candidate_by_row.get(row_key(m))
        if paired is None:
            continue
        append_once(paired)
        spec = m.task_spec
        speedup_task = common.TaskSpec(
            spec.stmt,  # same row name as the baseline even when sub_label is unset
            setup="",
            label=spec.label,
            sub_label=spec.sub_label,
            num_threads=spec.num_threads,
            env=spec.env,
            description=speedup_description,
        )
        ratio = paired.median / m.median * scale
        output.append(Value(1, [ratio], speedup_task))
    return output


updated_group = add_speedup_measurements(group, c1, c2, description)

In [13]:
from torch.utils.benchmark.utils.compare import Table


class CustomizedTable(Table):
    """``Table`` whose time unit is chosen from real timings only.

    ``__init__`` re-implements the table setup without calling ``super().__init__``
    so that ``Value`` cells (derived numbers such as speed-up ratios) are ignored
    when picking the time unit and scale. NOTE(review): this appears to copy the
    upstream ``Table.__init__``; keep it in sync with the installed torch version.
    """

    def __init__(self, results, colorize, trim_significant_figures, highlight_warnings):
        labels = {m.label for m in results}
        assert len(labels) == 1

        self.results = results
        self.label = results[0].label
        self._colorize = colorize
        self._trim_significant_figures = trim_significant_figures
        self._highlight_warnings = highlight_warnings

        timing_medians = [m.median for m in results if not isinstance(m, Value)]
        self.time_unit, self.time_scale = common.select_unit(min(timing_medians))

        row_keys = common.ordered_unique([self.row_fn(m) for m in results])
        # Sort on the first two key fields only; the stable sort preserves stmt order.
        row_keys.sort(key=lambda key: key[:2])
        self.row_keys = row_keys
        self.column_keys = common.ordered_unique([self.col_fn(m) for m in results])
        self.rows, self.columns = self.populate_rows_and_columns()

In [14]:
# Render the post-processed group (original columns + speed-up column), reusing
# the display settings configured on `compare`.
table_options = dict(
    colorize=compare._colorize,
    trim_significant_figures=compare._trim_significant_figures,
    highlight_warnings=compare._highlight_warnings,
)
table = CustomizedTable(updated_group, **table_options)
print(table.render())

[--------------------------------------------------------------------------------------- Resize --------------------------------------------------------------------------------------]
                                                                 |  Pillow (9.0.0.post1)  |  torch (2.1.0a0+git1d3a939) PR  |  torch (2.1.0a0+git0c58e8a) PR  |  Speed-up: PR 1 vs PR 2
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=True     |          38.3          |                56.3             |                58.3             |           1.0          
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=False    |                        |                36.2             |                39.5             |           1.1          
      3 torch.uint8 channels_last bilinear 256 -> 224 aa=True    |         127.6

```
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=True     |                        |               193.8             |               197.8             |           1.0          
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=False    |                        |                72.9             |                76.0             |           1.0          
      4 torch.uint8 channels_last bilinear 712 -> 224 aa=True    |                        |               422.5             |               426.1             |           1.0          
      4 torch.uint8 channels_last bilinear 712 -> 224 aa=False   |                        |               253.3             |               277.3             |           1.1          

      4 torch.uint8 channels_first bilinear 712 -> 32 aa=True    |                        |               321.5             |               320.2             |           1.0          
      4 torch.uint8 channels_first bilinear 712 -> 32 aa=False   |                        |               198.8             |               347.9             |           1.7          
      4 torch.uint8 channels_first bilinear 712 -> 224 aa=True   |                        |               637.6             |               788.2             |           1.2          
      4 torch.uint8 channels_first bilinear 712 -> 224 aa=False  |                        |               466.5             |               492.2             |           1.1          

```

```
PIL version:  9.0.0.post1
[-------------------------------------------- Resize -------------------------------------------]
                                                                 |  torch (2.1.0a0+git0c58e8a) PR
1 threads: --------------------------------------------------------------------------------------
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=True     |        193.671 (+-6.284)      
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=False    |         73.858 (+-0.251)      
      4 torch.uint8 channels_last bilinear 712 -> 224 aa=True    |        424.211 (+-2.159)      
      4 torch.uint8 channels_last bilinear 712 -> 224 aa=False   |        275.816 (+-2.406)      

      4 torch.uint8 channels_first bilinear 712 -> 32 aa=True    |        319.124 (+-2.168)      
      4 torch.uint8 channels_first bilinear 712 -> 32 aa=False   |        199.970 (+-1.239)      
      4 torch.uint8 channels_first bilinear 712 -> 224 aa=True   |        640.515 (+-2.885)      
      4 torch.uint8 channels_first bilinear 712 -> 224 aa=False  |        491.073 (+-2.918)
```