# Imports

In [2]:
# Enable IPython's autoreload extension in mode 2, so edited project modules are re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

In [42]:
import os
import warnings
from copy import deepcopy
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch

from assembly import Assembly
from utils import load_yaml, num_params
from utils.models import load_model

In [32]:
# `sns.set` is only a legacy alias of `sns.set_theme`; call the preferred interface directly.
sns.set_theme(font_scale=1.25, style="whitegrid")

# Ignore known warnings that come when constructing subnets.
warnings.filterwarnings("ignore", message=".*The parameter 'pretrained' is deprecated.*")
warnings.filterwarnings("ignore", message=".*Arguments other than a weight enum.*")
warnings.filterwarnings("ignore", message=".*already erased node.*")

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# print(f"Using device = {device}")

# Work out how many CPU cores this process can use.
if hasattr(os, "sched_getaffinity"):
    # Only some platforms provide this function. Under Slurm it reports the true number of cores we
    # have access to, so prefer it whenever it exists.
    NUM_CORES = len(os.sched_getaffinity(0))
else:
    NUM_CORES = os.cpu_count()
print(f"Using {NUM_CORES} cores.")

Using 8 cores.


# Counting Compute

Figuring out how to count parameters and FLOPs, or whatever else might be a good proxy for the amount of compute.

In [36]:
from fvcore.nn import FlopCountAnalysis
from torchtnt.utils.flops import FlopTensorDispatchMode

def print_assembly_sizes(cfile):
    """Build the assembly described by a YAML config file and print its size report.

    The part list is read from the config under "assembly", falling back to "stages" and then to
    "src_stages"; it is None when none of these keys is present. The config's "head" entry (None when
    absent) is passed as the second argument to `Assembly`, which is built for a [3, 224, 224] input.

    Args:
        cfile: Location of the YAML config, handed to `load_yaml`.
    """
    config = load_yaml(cfile)

    # Walk the keys from lowest to highest priority: each lookup falls back to the value found so
    # far, so a key later in the tuple wins whenever it is present.
    parts_spec = None
    for key in ("src_stages", "stages", "assembly"):
        parts_spec = config.get(key, parts_spec)

    head_spec = config.get("head")
    net = Assembly(parts_spec, head_spec, input_shape=[3, 224, 224])
    print_sizes(net)


def print_sizes(model, input_shape=(3, 224, 224)):
    """Print the parameter count and two FLOP estimates for `model`.

    The model is switched to training mode (and left in it) before anything is measured. Parameter
    counts are printed for the whole model and, when the model has a `parts` attribute, for each part.
    FLOPs are then estimated for a forward pass on one random input, once with fvcore's
    `FlopCountAnalysis` and once with TorchTNT's `FlopTensorDispatchMode`.

    Args:
        model: The module to measure. It is called on a tensor of shape `(1, *input_shape)`.
        input_shape: Shape of a single input sample, without the batch dimension. The default is the
            3x224x224 input that used to be hard-coded here, so existing calls behave as before.
    """
    model.train()
    print(f"Number of parameters: {num_params(model):.3e}")
    if hasattr(model, "parts"):
        for i, p in enumerate(model.parts):
            print(f"Num params in part {i+1} ({p.__class__.__name__}): {num_params(p):.3e}")

    # The values of the dummy input are irrelevant; only its shape drives the forward pass.
    dummy_img = torch.randn(1, *input_shape)

    # For a finer breakdown, see `flops.by_module()` and `flops.by_operator()`.
    flops = FlopCountAnalysis(model, dummy_img)
    print(f"\nfvcore FLOPs: {flops.total():.3e}")

    with FlopTensorDispatchMode(model) as ftdm:
        _ = model(dummy_img)
        # NOTE(review): the '' key is read here as the whole-model entry of `flop_counts` — confirm.
        print(f"\nTorchTNT FLOPs: {sum(ftdm.flop_counts[''].values()):.3e}")

In [38]:
# Size report for the assembly configured in mobilenet-v3.yml.
print_assembly_sizes(Path("../across-scales/mobilenet-v3.yml"))



Number of parameters: 5.483e+06
Num params in part 1 (Subnet): 4.368e+03
Num params in part 2 (Subnet): 4.440e+03
Num params in part 3 (Subnet): 1.033e+04
Num params in part 4 (Subnet): 4.198e+04
Num params in part 5 (Subnet): 3.208e+04
Num params in part 6 (Subnet): 6.993e+05
Num params in part 7 (Subnet): 4.292e+05
Num params in part 8 (Subnet): 4.261e+06

fvcore FLOPs: 2.386e+08

TorchTNT FLOPs: 2.166e+08


In [39]:
# Size report for the assembly configured in resnet-50.yml.
print_assembly_sizes(Path("../across-scales/resnet-50.yml"))



Number of parameters: 2.556e+07
Num params in part 1 (Subnet): 8.454e+04
Num params in part 2 (Subnet): 1.408e+05
Num params in part 3 (Subnet): 3.794e+05
Num params in part 4 (Subnet): 8.402e+05
Num params in part 5 (Subnet): 1.512e+06
Num params in part 6 (Subnet): 5.586e+06
Num params in part 7 (Subnet): 6.040e+06
Num params in part 8 (Subnet): 1.097e+07

fvcore FLOPs: 4.145e+09

TorchTNT FLOPs: 4.089e+09


In [40]:
# Size report for the assembly configured in swin-t.yml.
print_assembly_sizes(Path("../across-scales/swin-t.yml"))

Number of parameters: 2.829e+07
Num params in part 1 (Subnet): 4.896e+03
Num params in part 2 (Subnet): 2.247e+05
Num params in part 3 (Subnet): 7.450e+04
Num params in part 4 (Subnet): 8.918e+05
Num params in part 5 (Subnet): 2.964e+05
Num params in part 6 (Subnet): 1.066e+07
Num params in part 7 (Subnet): 1.183e+06
Num params in part 8 (Subnet): 1.495e+07


parts.1.net.features.1.0.stochastic_depth



fvcore FLOPs: 4.509e+09

TorchTNT FLOPs: 4.491e+09


In [37]:
# Size report for a resnet18 obtained from `load_model` with pretrained=False (not built from a config file).
print_sizes(load_model("resnet18", "pytorch", pretrained=False))



Number of parameters: 1.169e+07

fvcore FLOPs: 1.827e+09

TorchTNT FLOPs: 1.814e+09


Below: a demonstration that the parameter and FLOP counts of a ResNet bottleneck block are actually much smaller than those of a single conv3x3 layer.

In [58]:
# Assembly equivalent to a ResNet-50.
def _resnet50_subnet(block_input, block_output, in_format, out_format):
    """Return one "Subnet" entry covering the given block range of timm's resnet50.a1_in1k."""
    return {"Subnet": {
        "backend": "timm",
        "model_name": "resnet50.a1_in1k",
        "block_input": block_input,
        "block_output": block_output,
        "in_format": in_format,
        "out_format": out_format,
    }}


assembly_config = [
    # From the input through layer2.3 (layer1.0 input is [64, 56, 56]).
    _resnet50_subnet("x", "layer2.3", "img", ["img", [512, 28, 28]]),
    # Downsample block.
    _resnet50_subnet("layer3.0", "layer3.0", ["img", [512, 28, 28]], ["img", [1024, 14, 14]]),
    # From layer3.1 through the final fc layer.
    _resnet50_subnet("layer3.1", "fc", ["img", [1024, 14, 14]], "vector"),
]


def newcfg():
    """Return a fresh deep copy of `assembly_config`.

    Each call yields an independent structure, so changes made to the copy never reach the shared
    template or any other copy.
    """
    config_copy = deepcopy(assembly_config)
    return config_copy


# Build the unmodified three-part assembly from a fresh copy of the config and report its sizes.
model = Assembly(newcfg(), input_shape=[3, 224, 224])
print_sizes(model)



Number of parameters: 2.556e+07
Num params in part 1 (Subnet): 1.445e+06
Num params in part 2 (Subnet): 1.512e+06
Num params in part 3 (Subnet): 2.260e+07

fvcore FLOPs: 4.145e+09

TorchTNT FLOPs: 4.089e+09


In [59]:
from launch_scaling_experiments import linear, stitch, conv3x3, bottleneck, stitch_no_downsample

# Gap description handed to `stitch` / `stitch_no_downsample` together with the two configs.
# NOTE(review): the meaning of each key is defined in launch_scaling_experiments — confirm there.
gap = {"blocks_to_drop": [1, 1], "num_downsamples": 1}

In [60]:
# Combine two fresh copies of the ResNet-50 config with `stitch`, passing `conv3x3`, then report sizes.
# NOTE(review): `conv3x3` presumably selects what is inserted across `gap` — confirm in launch_scaling_experiments.
conv3x3_assembly = stitch(newcfg(), newcfg(), gap, conv3x3)
model = Assembly(conv3x3_assembly, input_shape=[3, 224, 224])
print_sizes(model)



Number of parameters: 2.877e+07
Num params in part 1 (Subnet): 1.445e+06
Num params in part 2 (SimpleAdapter): 4.723e+06
Num params in part 3 (Subnet): 2.260e+07

fvcore FLOPs: 4.696e+09

TorchTNT FLOPs: 4.642e+09


In [61]:
# Combine two fresh copies of the ResNet-50 config with `stitch`, passing `bottleneck`, then report sizes.
# NOTE(review): `bottleneck` presumably selects what is inserted across `gap` — confirm in launch_scaling_experiments.
bottleneck_assembly = stitch(newcfg(), newcfg(), gap, bottleneck)
model = Assembly(bottleneck_assembly, input_shape=[3, 224, 224])
print_sizes(model)



Number of parameters: 2.556e+07
Num params in part 1 (Subnet): 1.445e+06
Num params in part 2 (ResNetBottleneck): 1.512e+06
Num params in part 3 (Subnet): 2.260e+07

fvcore FLOPs: 4.067e+09

TorchTNT FLOPs: 4.012e+09


In [62]:
# Combine two fresh copies of the ResNet-50 config with `stitch_no_downsample`, passing `bottleneck`, then report sizes.
# NOTE(review): presumably the variant that skips spatial downsampling across `gap` — confirm in launch_scaling_experiments.
bottleneck_no_downsample_assembly = stitch_no_downsample(newcfg(), newcfg(), gap, bottleneck)
model = Assembly(bottleneck_no_downsample_assembly, input_shape=[3, 224, 224])
print_sizes(model)

Number of parameters: 2.556e+07
Num params in part 1 (Subnet): 1.445e+06
Num params in part 2 (ResNetBottleneck): 1.512e+06
Num params in part 3 (Subnet): 2.260e+07





fvcore FLOPs: 1.070e+10

TorchTNT FLOPs: 1.060e+10


In [63]:
# Combine two fresh copies of the ResNet-50 config with `stitch`, passing `linear`, then report sizes.
# NOTE(review): `linear` presumably selects what is inserted across `gap` — confirm in launch_scaling_experiments.
linear_assembly = stitch(newcfg(), newcfg(), gap, linear)
model = Assembly(linear_assembly, input_shape=[3, 224, 224])
print_sizes(model)



Number of parameters: 2.457e+07
Num params in part 1 (Subnet): 1.445e+06
Num params in part 2 (SimpleAdapter): 5.284e+05
Num params in part 3 (Subnet): 2.260e+07

fvcore FLOPs: 3.906e+09

TorchTNT FLOPs: 3.851e+09


In [64]:
# Combine two fresh copies of the ResNet-50 config with `stitch_no_downsample`, passing `linear`, then report sizes.
# NOTE(review): presumably the variant that skips spatial downsampling across `gap` — confirm in launch_scaling_experiments.
linear_no_downsample_assembly = stitch_no_downsample(newcfg(), newcfg(), gap, linear)
model = Assembly(linear_no_downsample_assembly, input_shape=[3, 224, 224])
print_sizes(model)

Number of parameters: 2.457e+07
Num params in part 1 (Subnet): 1.445e+06
Num params in part 2 (SimpleAdapter): 5.284e+05
Num params in part 3 (Subnet): 2.260e+07





fvcore FLOPs: 1.111e+10

TorchTNT FLOPs: 1.102e+10
