In [41]:
# perf_files = [
#     "output/20230313-133243-pr.pkl",
#     "output/20230313-134520-nightly.pkl",
# ]

perf_files = [
    "output/20230315-161003-pr.pkl",
    "output/20230315-134541-nightly.pkl",
]

In [42]:
import pickle
from pathlib import Path
import torch
import torch.utils.benchmark as benchmark

ab_results = []
for perf_filepath in perf_files:
    assert Path(perf_filepath).exists(), f"{perf_filepath} is not found"
    with open(perf_filepath, "rb") as handler:
        output = pickle.load(handler)
        ab_results.extend(output["test_results"])

    
compare = benchmark.Compare(ab_results)

In [43]:
# compare.colorize()
compare.print()

[---------------------------------------------------------------------------- Resize ---------------------------------------------------------------------------]
                                                                 |  Pillow (9.0.0.post1)  |  torch (2.1.0a0+git0968a5d) PR  |  torch (2.1.0a0+git5309c44) nightly
1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=True     |          39.0          |                56.6             |                 133.2              
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=False    |                        |                36.9             |                 112.8              
      3 torch.uint8 channels_last bilinear 256 -> 224 aa=True    |         128.1          |               152.5             |                 305.4              
      3 torch.uint8 channels

In [44]:
from torch.utils.benchmark.utils import common

In [45]:
results = common.Measurement.merge(compare._results)

In [46]:
grouped_results = compare._group_by_label(results)

In [47]:
# compare._render??
groups_iter = iter(grouped_results.values())
group = next(groups_iter)

In [50]:
c1 = "torch (2.1.0a0+git0968a5d) PR"
c2 = "torch (2.1.0a0+git5309c44) nightly"
# description = f"Speed-up: {c1} vs {c2}"
description = f"Speed-up: PR vs nightly"

In [51]:
class Value(common.Measurement): pass
#     @property
#     def times(self):
#         assert len(self.raw_times) == 1, self.raw_times
#         return self.raw_times[0]
    
#     def _lazy_init(self):
#         pass

In [52]:
# for measurement in group:
#     print(measurement.task_spec.description, measurement.task_spec.sub_label)

In [53]:
updated_group = []

sub_label = None
v1 = None
v2 = None
r = None

_, scale = common.select_unit(min([r.median for r in group]))

for measurement in group:
    # print(measurement.task_spec.description)
        
    if measurement.task_spec.description == c1:
        v1 = measurement.median
        sub_label = measurement.task_spec.sub_label
        # print(c1, v1, sub_label)

    measurement2 = None
    for m2 in group:
        d2 = m2.task_spec.description
        sl2 = m2.task_spec.sub_label
        if d2 == c2 and sl2 == sub_label:
            v2 = m2.median
            # print(c2, v2)
            measurement2 = m2            
            break
    
    if measurement not in updated_group:
        updated_group.append(measurement)
    if v1 is not None and v2 is not None:
        if measurement2 not in updated_group:
            updated_group.append(measurement2)
        r = v2 / v1 * scale
        # print("->", r)
        v1 = None
        v2 = None
        sub_label = None
        speedup_task = common.TaskSpec(
            "", 
            setup="", 
            label=measurement.label,
            sub_label=measurement.sub_label,
            num_threads=measurement.num_threads,
            env=measurement.env,
            description=description
        )
        speedup_measurement = Value(1, [r, ], speedup_task)
        r = None
        updated_group.append(speedup_measurement)

In [54]:
from torch.utils.benchmark.utils.compare import Table


class CustomizedTable(Table):
    
    def __init__(self, results, colorize, trim_significant_figures, highlight_warnings):
        assert len(set(r.label for r in results)) == 1

        self.results = results
        self._colorize = colorize
        self._trim_significant_figures = trim_significant_figures
        self._highlight_warnings = highlight_warnings
        self.label = results[0].label
        self.time_unit, self.time_scale = common.select_unit(
            min(r.median for r in results if not isinstance(r, Value))
        )

        self.row_keys = common.ordered_unique([self.row_fn(i) for i in results])
        self.row_keys.sort(key=lambda args: args[:2])  # preserve stmt order
        self.column_keys = common.ordered_unique([self.col_fn(i) for i in results])
        self.rows, self.columns = self.populate_rows_and_columns()

In [55]:
table = CustomizedTable(
    updated_group,
    compare._colorize,
    compare._trim_significant_figures,
    compare._highlight_warnings
)
print(table.render())

[------------------------------------------------------------------------------------------ Resize -----------------------------------------------------------------------------------------]
                                                                 |  Pillow (9.0.0.post1)  |  torch (2.1.0a0+git0968a5d) PR  |  torch (2.1.0a0+git5309c44) nightly  |  Speed-up: PR vs nightly
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=True     |          39.0          |                56.6             |                 133.2                |            2.4          
      3 torch.uint8 channels_last bilinear 256 -> 32 aa=False    |                        |                36.9             |                 112.8                |            3.1          
      3 torch.uint8 channels_last bilinear 256 -> 

```
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=True     |                        |               193.8             |               197.8             |           1.0          
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=False    |                        |                72.9             |                76.0             |           1.0          
      4 torch.uint8 channels_last bilinear 712 -> 224 aa=True    |                        |               422.5             |               426.1             |           1.0          
      4 torch.uint8 channels_last bilinear 712 -> 224 aa=False   |                        |               253.3             |               277.3             |           1.1          

      4 torch.uint8 channels_first bilinear 712 -> 32 aa=True    |                        |               321.5             |               320.2             |           1.0          
      4 torch.uint8 channels_first bilinear 712 -> 32 aa=False   |                        |               198.8             |               347.9             |           1.7          
      4 torch.uint8 channels_first bilinear 712 -> 224 aa=True   |                        |               637.6             |               788.2             |           1.2          
      4 torch.uint8 channels_first bilinear 712 -> 224 aa=False  |                        |               466.5             |               492.2             |           1.1          

```

```
PIL version:  9.0.0.post1
[-------------------------------------------- Resize -------------------------------------------]
                                                                 |  torch (2.1.0a0+git0c58e8a) PR
1 threads: --------------------------------------------------------------------------------------
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=True     |        193.671 (+-6.284)      
      4 torch.uint8 channels_last bilinear 712 -> 32 aa=False    |         73.858 (+-0.251)      
      4 torch.uint8 channels_last bilinear 712 -> 224 aa=True    |        424.211 (+-2.159)      
      4 torch.uint8 channels_last bilinear 712 -> 224 aa=False   |        275.816 (+-2.406)      

      4 torch.uint8 channels_first bilinear 712 -> 32 aa=True    |        319.124 (+-2.168)      
      4 torch.uint8 channels_first bilinear 712 -> 32 aa=False   |        199.970 (+-1.239)      
      4 torch.uint8 channels_first bilinear 712 -> 224 aa=True   |        640.515 (+-2.885)      
      4 torch.uint8 channels_first bilinear 712 -> 224 aa=False  |        491.073 (+-2.918)
```

In [56]:
import torch
from torchvision.models.video import swin3d_t

video = torch.rand(1, 3, 32, 800, 600)
# or swin3d_b, swin3d_s
model = swin3d_t(weights="DEFAULT")
model.eval()
with torch.inference_mode():
    prediction = model(video)
print(prediction)

tensor([[ 0.5948,  0.3137,  0.7556, -0.8223, -0.6089,  0.5881, -0.9447, -2.0455,
         -1.5483, -0.7381, -0.2854, -0.1242,  0.7795, -0.5873, -1.3449, -2.2452,
          1.7047, -1.8982,  0.2163, -0.6561,  1.0124, -0.4089, -1.1075, -0.1892,
         -1.0364, -1.1145,  0.2619, -0.4389, -0.4747, -0.3441, -0.9742, -1.1090,
         -0.3340, -2.2728,  0.3145,  0.6618,  0.1702,  0.2511, -0.0185, -1.3574,
          3.1635, -0.5269, -0.1911, -0.0828,  0.3138,  1.2309, -0.6196, -0.8707,
          0.9189,  1.9639, -0.6416,  0.0385, -1.3630, -1.4469, -1.1519, -0.9806,
         -1.3827,  0.0831, -1.0821, -0.7155,  1.0057, -0.9480,  0.6128, -0.1704,
         -0.3204,  0.9887, -1.1291,  0.3477, -0.8860,  0.8537, -1.3530, -0.8751,
         -2.0728, -1.5736,  0.0971, -1.6477,  0.5872,  0.6412, -1.6775,  0.7837,
          0.5288,  0.4561, -0.8195, -1.4059,  0.5379, -0.9413, -1.2652, -0.7151,
         -1.2707, -0.9139, -0.6774, -1.0966, -0.0820,  0.6377, -0.8144, -1.4586,
         -0.6194,  1.0986,  