## Benchmarking tor
We provide benchmarking code to measure tor's throughput.

**Note**: Notebooks add overhead. For accurate results, run the benchmark from a standalone script.
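A throughput benchmark comes down to timing repeated forward passes and dividing the number of images processed by the elapsed time. The sketch below shows that idea for a standalone script. `measure_throughput` and its `warmup` parameter are hypothetical illustrations, not `tor.utils.benchmark`. If `step` launches asynchronous GPU work, it should call `torch.cuda.synchronize()` before returning (and run the model under `torch.no_grad()`); otherwise the timer stops before the GPU has finished.

```python
import time
from typing import Callable


def measure_throughput(step: Callable[[], None], batch_size: int,
                       runs: int, warmup: int = 5) -> float:
    """Return images per second for `step`, which processes one batch per call.

    Hypothetical helper for illustration; not tor's own benchmark function.
    """
    # Untimed warm-up calls (e.g. to let kernels and caches settle).
    for _ in range(warmup):
        step()

    start = time.perf_counter()
    for _ in range(runs):
        step()
    elapsed = time.perf_counter() - start

    # Total images processed divided by wall-clock time.
    return runs * batch_size / elapsed


if __name__ == "__main__":
    # Dummy "model" that takes ~10 ms per batch of 256 images.
    tp = measure_throughput(lambda: time.sleep(0.01), batch_size=256, runs=5, warmup=1)
    print(f"Throughput: {tp:.2f} im/s")
```

Because `time.sleep(0.01)` waits at least 10 ms per call, the dummy run reports at most 25,600 im/s.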

In [1]:
import timm
import tor

In [2]:
# Use any ViT model here (see timm.models.vision_transformer)
model_name = "vit_base_patch16_224"

# Load a pretrained model
model = timm.create_model(model_name, pretrained=True)

In [3]:
# Set this to be whatever device you want to benchmark on
# If you don't have a GPU, you can use "cpu", but you'll probably want to lower `runs`
device = "cuda:0"
runs = 50
batch_size = 256  # Lower this if you run out of memory
input_size = model.default_cfg["input_size"]

In [4]:
# Baseline benchmark
baseline_throughput = tor.utils.benchmark(
    model,
    device=device,
    verbose=True,
    runs=runs,
    batch_size=batch_size,
    input_size=input_size
)

Benchmarking: 100%|██████████| 50/50 [00:27<00:00,  1.80it/s]


Throughput: 420.78 im/s


### Applying tor
To enable tor, patch the model after it has been created.
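Patching an already-constructed model is typically done by swapping the instance's class for a subclass that overrides `forward` and reads extra attributes such as `r`. This is how ToMe's `patch.timm` works; whether tor does the same is an assumption, since its internals aren't shown here. The class and function names below (`VisionTransformer`, `make_patched_class`, `patch`) are hypothetical stand-ins, not timm's or tor's code.

```python
class VisionTransformer:
    """Hypothetical stand-in for a model class (not timm's implementation)."""

    def forward(self, x):
        return f"baseline forward on {x}"


def make_patched_class(base_cls):
    """Build a subclass whose forward reads settings stored on the instance."""

    class Patched(base_cls):
        def forward(self, x):
            # Illustrative only: a real patch would change how tokens flow
            # through the transformer blocks; here we just report `r`.
            return f"patched forward on {x} with r={self.r}"

    return Patched


def patch(model):
    # Swap the class in place: existing weights/attributes are untouched,
    # only the methods change.
    model.__class__ = make_patched_class(model.__class__)
    model.r = 0  # default setting: no token reduction
    return model


m = VisionTransformer()
patch(m)
m.r = 16  # settings are plain attributes, changed after patching
print(m.forward("batch"))  # patched forward on batch with r=16
```

Since the patched class subclasses the original, the model still passes `isinstance` checks against its original type.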

In [5]:
# Apply tor
tor.patch.timm(model)

In [6]:
# Configure tor's token-reduction settings
model.keep_rate = 0.5
model.drop_loc = [3, 6, 9]
model.token_fusion = True
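The notebook doesn't spell out what `keep_rate`, `drop_loc`, and `token_fusion` do. Their names match EViT-style token pruning, so the sketch below assumes that convention: at each block index in `drop_loc`, keep a `keep_rate` fraction (rounded up) of the non-[CLS] tokens, and if `token_fusion` is on, fuse the discarded tokens into one extra token. `tokens_per_block` is a hypothetical helper; it only illustrates how quickly the sequence shrinks for `vit_base_patch16_224` (196 patch tokens, 12 blocks) under that assumption.

```python
import math


def tokens_per_block(num_patches, depth, keep_rate, drop_loc, token_fusion):
    """Token count leaving each block under an assumed EViT-style pruning rule."""
    n = num_patches + 1  # patch tokens + [CLS]
    counts = []
    for i in range(depth):
        if i in drop_loc and keep_rate < 1.0:
            # Keep a fraction of the non-[CLS] tokens, rounded up.
            kept = math.ceil(keep_rate * (n - 1))
            # [CLS] + kept tokens (+ one fused token for the rest).
            n = 1 + kept + (1 if token_fusion else 0)
        counts.append(n)
    return counts


print(tokens_per_block(196, 12, 0.5, [3, 6, 9], True))
# [197, 197, 197, 100, 100, 100, 52, 52, 52, 28, 28, 28]
```

Under this assumption, the last three blocks process 28 tokens instead of 197.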

In [7]:
# tor with r=16
model.r = 16
tor_throughput = tor.utils.benchmark(
    model,
    device=device,
    verbose=True,
    runs=runs,
    batch_size=batch_size,
    input_size=input_size
)
print(f"Throughput improvement: {tor_throughput / baseline_throughput:.2f}x")

Benchmarking: 100%|██████████| 50/50 [00:10<00:00,  4.99it/s]


Throughput: 1256.18 im/s
Throughput improvement: 2.99x
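The improvement figure is simply the ratio of the two measured throughputs. Converting im/s to milliseconds per image can make the gain easier to reason about; the numbers below are copied from the outputs above, and `ms_per_image` is a small illustrative helper.

```python
baseline = 420.78  # im/s, baseline output above
tor_r16 = 1256.18  # im/s, tor with r=16 output above


def ms_per_image(throughput: float) -> float:
    """Average time per image in milliseconds for a throughput in im/s."""
    return 1000.0 / throughput


speedup = tor_r16 / baseline
print(f"Speedup: {speedup:.2f}x")                           # Speedup: 2.99x
print(f"Baseline: {ms_per_image(baseline):.2f} ms/image")   # Baseline: 2.38 ms/image
print(f"tor r=16: {ms_per_image(tor_r16):.2f} ms/image")    # tor r=16: 0.80 ms/image
```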


In [8]:
# tor with r=16 and a decreasing schedule
model.r = (16, -1.0)
tor_decr_throughput = tor.utils.benchmark(
    model,
    device=device,
    verbose=True,
    runs=runs,
    batch_size=batch_size,
    input_size=input_size
)
print(f"Throughput improvement: {tor_decr_throughput / baseline_throughput:.2f}x")

Benchmarking: 100%|██████████| 50/50 [00:07<00:00,  6.37it/s]


Throughput: 1586.95 im/s
Throughput improvement: 3.77x
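The notation `model.r = (16, -1.0)` matches ToMe (Token Merging), where `r` is the number of tokens removed per layer and a tuple `(r, inflect)` sets a linear schedule: `inflect = 0` is constant, `-1` decreases from `2r` to 0, and `1` increases. Assuming tor follows the same convention, the sketch below mirrors ToMe's schedule parsing for a 12-layer ViT-B. `parse_r` is written here for illustration and handles only an int or an `(r, inflect)` tuple.

```python
from typing import List, Tuple, Union


def parse_r(num_layers: int, r: Union[int, Tuple[int, float]]) -> List[int]:
    """Per-layer reduction amounts under ToMe's (r, inflect) convention (assumed for tor)."""
    inflect = 0.0
    if isinstance(r, tuple):
        r, inflect = r
    # Linear schedule between min_val and max_val whose mean is roughly r.
    min_val = int(r * (1.0 - inflect))
    max_val = 2 * r - min_val
    step = (max_val - min_val) / (num_layers - 1)
    return [int(min_val + step * i) for i in range(num_layers)]


constant = parse_r(12, 16)
decreasing = parse_r(12, (16, -1.0))
print(constant)    # [16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16]
print(decreasing)  # [32, 29, 26, 23, 20, 17, 14, 11, 8, 5, 2, 0]
print(sum(constant), sum(decreasing))  # 192 187
```

Under this convention the decreasing schedule removes slightly fewer tokens in total (187 vs. 192) but front-loads the reduction, so more of the network runs on a shorter sequence.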
