# ***ToMe***

ViT converts image patches into “tokens” and then, in each layer, applies an attention mechanism that lets these tokens gather information from one another, weighted by their similarity. To make ViT faster while maintaining its accuracy, ToMe (Token Merging) finds redundant tokens and merges similar ones together. This reduces the number of tokens while retaining most of the information they carry.
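ToMe performs the merge with bipartite soft matching: tokens are split alternately into two sets A and B, each token in A is matched to its most similar token in B, and the `r` best-matching A tokens are averaged into their partners. The NumPy sketch below illustrates that idea under simplifying assumptions of our own: it uses the token features directly as the similarity metric (the paper uses the attention keys), averages without size weighting, and does not protect the class token. The function name `bipartite_soft_matching` is chosen for illustration and is not ToMe's API.

```python
import numpy as np


def bipartite_soft_matching(x: np.ndarray, r: int) -> np.ndarray:
    """Merge r tokens of x (shape [N, C]) into their most similar partners.

    Simplified sketch: cosine similarity on the features themselves,
    plain (unweighted) averaging, no class-token protection.
    """
    a, b = x[::2], x[1::2]            # alternate tokens into sets A and B
    r = min(r, len(a))                # at most |A| tokens can be merged away
    if r <= 0:
        return x

    # Cosine similarity between every token in A and every token in B
    a_n = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b_n = b / np.linalg.norm(b, axis=-1, keepdims=True)
    scores = a_n @ b_n.T              # shape [|A|, |B|]

    best_b = scores.argmax(axis=-1)   # most similar B token for each A token
    best_score = scores.max(axis=-1)
    order = np.argsort(-best_score, kind="stable")
    merged, kept = order[:r], np.sort(order[r:])

    # Average each merged A token into its matched B token
    b_sum = b.astype(float)           # astype returns a copy
    counts = np.ones(len(b))
    for i in merged:
        b_sum[best_b[i]] += a[i]
        counts[best_b[i]] += 1
    b_new = b_sum / counts[:, None]

    # Unmerged A tokens plus the (possibly updated) B tokens
    return np.concatenate([a[kept], b_new], axis=0)


# Two pairs of near-duplicate tokens collapse into two averaged tokens
tokens = np.array([[1.0, 0.0], [1.0, 0.1], [0.0, 1.0], [0.1, 1.0]])
merged = bipartite_soft_matching(tokens, r=2)
print(merged.shape)  # (2, 2)
```

Each call removes exactly `r` tokens, which is why the token count can be controlled layer by layer.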

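During inference, ToMe merges tokens inside every transformer block (with a constant schedule, `r` tokens per block), so the number of tokens shrinks gradually over the course of the network and the overall inference time drops significantly.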
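To see how quickly the token count shrinks, the sketch below tracks it through the 12 blocks of ViT-B/16 at 224x224 resolution (14 × 14 = 196 patch tokens plus 1 class token). It assumes, as bipartite matching implies, that at most half of the unprotected tokens can be merged in one block and that the class token is never merged; the helper `tokens_per_layer` is hypothetical.

```python
from typing import List


def tokens_per_layer(num_tokens: int, r: int, num_layers: int, protected: int = 1) -> List[int]:
    """Token count after each block when merging up to r tokens per block (sketch).

    Assumption: at most half of the unprotected tokens can be merged in one
    block, and `protected` tokens (the class token) are never merged.
    """
    counts = []
    n = num_tokens
    for _ in range(num_layers):
        n -= min(r, (n - protected) // 2)
        counts.append(n)
    return counts


# ViT-B/16 at 224x224: 196 patch tokens + 1 class token, 12 blocks
print(tokens_per_layer(197, r=16, num_layers=12))
# [181, 165, 149, 133, 117, 101, 85, 69, 53, 37, 21, 11]
```

Under these assumptions the 50% cap only kicks in at the last block, where 21 tokens can lose at most 10.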


# ***Installing dependencies***

In [None]:
# timm and torchvision are imported by the cells below
%pip install timm torchvision transformers datasets --quiet

In [None]:
%pip install git+https://github.com/facebookresearch/ToMe.git --quiet

  Preparing metadata (setup.py) ... done


# ***Validation***

## Pretrained model without token merging

We load a pretrained ViT-B/16 (`vit_base_patch16_224`) from timm, pass an image of a husky through the network, and print the indices of the top-5 predicted ImageNet classes.

In [None]:
import timm
import tome
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image


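# Load a pretrained ViT-B/16 and put it in evaluation mode
model = timm.create_model("vit_base_patch16_224", pretrained=True)
model.eval()

# default_cfg["input_size"] is (channels, height, width); keep the spatial size (224)
input_size = model.default_cfg["input_size"][1]

# Resize to 256, center-crop to 224, then normalize with the model's mean/std
transform = transforms.Compose([
    transforms.Resize(int((256 / 224) * input_size), interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize(model.default_cfg["mean"], model.default_cfg["std"]),
])

# Convert to RGB in case the PNG has an alpha channel
img = Image.open("images/husky.png").convert("RGB")
img_tensor = transform(img)[None, ...]  # add a batch dimension: (1, 3, 224, 224)

# Indices of the top-5 predicted ImageNet classes
model(img_tensor).topk(5).indices[0].tolist()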

[248, 250, 249, 537, 174]
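The raw indices are easier to read as class names. The helper below is a hypothetical sketch with a hand-typed subset of short ImageNet-1k class names, covering only the indices that appear in this notebook; it assumes the standard ImageNet-1k index ordering used by timm's classifiers, so check it against a full label file before relying on it.

```python
# Hand-typed subset of ImageNet-1k class names (assumed standard index order)
IMAGENET_SUBSET = {
    174: "Norwegian elkhound",
    248: "Eskimo dog, husky",
    249: "malamute, Alaskan malamute",
    250: "Siberian husky",
    269: "timber wolf",
    537: "dogsled",
}


def decode(indices):
    """Map class indices to names, falling back to a placeholder for unknown ones."""
    return [IMAGENET_SUBSET.get(i, f"class {i}") for i in indices]


print(decode([248, 250, 249, 537, 174]))
```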

## Pretrained model with token merging


Here we patch the model with ToMe but set `r = 0`, so no tokens are merged. As expected, the top-5 predicted classes are the same as without patching.

In [None]:
# Patch the model with ToMe
tome.patch.timm(model)

# Run the model with no reduction (should be the same as before)
model.r = 0
model(img_tensor).topk(5).indices[0].tolist()


[248, 250, 249, 537, 174]

Merging `r` tokens per layer leaves the top-3 predicted classes unchanged, for both `r = 8` and `r = 16`. Only the lower ranks change: class 269 enters the top 5 and class 174 drops out.

In [None]:
# Run the model with some reduction
model.r = 8
model(img_tensor).topk(5).indices[0].tolist()

[248, 250, 249, 269, 537]

In [None]:
# Run the model with a lot of reduction
# The top-3 classes are unchanged (husky, Siberian husky, Alaskan malamute),
# but the model is now roughly 2x faster (see the Benchmarking section)
model.r = 16
model(img_tensor).topk(5).indices[0].tolist()

[248, 250, 249, 269, 537]
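Comparing predictions by eye works for one image; for more images it helps to quantify the agreement. The hypothetical helper below reports whether two top-k lists are identical and how many classes they share, using the top-5 outputs copied from the cells above.

```python
def topk_agreement(reference, candidate, k):
    """Return (identical top-k order?, number of classes shared by the two top-k lists)."""
    ref, cand = reference[:k], candidate[:k]
    return ref == cand, len(set(ref) & set(cand))


# Top-5 predictions copied from the outputs above
baseline = [248, 250, 249, 537, 174]  # r = 0
tome_r8 = [248, 250, 249, 269, 537]   # r = 8
tome_r16 = [248, 250, 249, 269, 537]  # r = 16

for name, preds in [("r=8", tome_r8), ("r=16", tome_r16)]:
    print(name, topk_agreement(baseline, preds, k=3), topk_agreement(baseline, preds, k=5))
# r=8 (True, 3) (False, 4)
# r=16 (True, 3) (False, 4)
```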

# ***Benchmarking***

We first benchmark the throughput of the vision transformer without token merging, which serves as the baseline.


In [None]:
import timm
import tome


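# Reload a fresh (unpatched) model for the baseline measurement
model = timm.create_model("vit_base_patch16_224", pretrained=True)

device = "cuda:0"  # the benchmark runs on a CUDA GPU
runs = 50          # number of benchmark iterations
batch_size = 256
input_size = model.default_cfg["input_size"]  # (3, 224, 224)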

baseline_throughput = tome.utils.benchmark(
    model,
    device=device,
    verbose=True,
    runs=runs,
    batch_size=batch_size,
    input_size=input_size
)

Benchmarking: 100%|██████████| 50/50 [02:18<00:00,  2.77s/it]


Throughput: 84.07 im/s
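Throughput is the number of images processed per second of wall-clock time. The sketch below shows what such a measurement typically involves; it is a generic harness, not the code of `tome.utils.benchmark`, and the warm-up fraction of 0.25 is an assumption. `measure_throughput` and its parameters are hypothetical names.

```python
import time
from typing import Callable


def measure_throughput(step: Callable[[], None], batch_size: int, runs: int,
                       warmup_fraction: float = 0.25) -> float:
    """Images per second for `step`, which processes one batch per call.

    The first `warmup_fraction` of runs is excluded from timing. For GPU work,
    `step` must block until the batch is finished (e.g. by calling
    torch.cuda.synchronize()); otherwise asynchronous kernels are not timed.
    """
    warmup = int(runs * warmup_fraction)
    for _ in range(warmup):
        step()
    start = time.perf_counter()
    for _ in range(runs - warmup):
        step()
    elapsed = time.perf_counter() - start
    return batch_size * (runs - warmup) / elapsed


# Toy usage: a dummy "batch" that just sleeps for 10 ms
tp = measure_throughput(lambda: time.sleep(0.01), batch_size=256, runs=8)
print(f"Throughput: {tp:.2f} im/s")
```

Excluding warm-up iterations keeps one-time costs (memory allocation, kernel selection) out of the average.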


Next, we benchmark the same model after patching it with ToMe, using a constant schedule that merges the same number of tokens in every layer: first `r = 8`, then `r = 16`. This improves throughput by 1.34x for `r = 8` and by 1.96x, almost double, for `r = 16`.

In [None]:
tome.patch.timm(model)

model.r = 8
tome_throughput = tome.utils.benchmark(
    model,
    device=device,
    verbose=True,
    runs=runs,
    batch_size=batch_size,
    input_size=input_size
)
print(f"Throughput improvement: {tome_throughput / baseline_throughput:.2f}x")

Benchmarking: 100%|██████████| 50/50 [01:49<00:00,  2.19s/it]


Throughput: 113.29 im/s
Throughput improvement: 1.34x


In [None]:
tome.patch.timm(model)

model.r = 16
tome_throughput = tome.utils.benchmark(
    model,
    device=device,
    verbose=True,
    runs=runs,
    batch_size=batch_size,
    input_size=input_size
)
print(f"Throughput improvement: {tome_throughput / baseline_throughput:.2f}x")

Benchmarking: 100%|██████████| 50/50 [01:14<00:00,  1.50s/it]

Throughput: 165.62 im/s
Throughput improvement: 1.96x
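A back-of-the-envelope check on these speedups is to count how many tokens enter each of the 12 blocks, summed over the network. This "token-layers" proxy is our own simplification: it assumes cost grows linearly with the token count and ignores the quadratic attention term and the cost of the matching itself. It reuses the same assumptions as before (class token never merged, at most half of the other tokens merged per block); `token_layers` is a hypothetical helper.

```python
def token_layers(num_tokens: int, r: int, num_layers: int = 12, protected: int = 1) -> int:
    """Sum over blocks of the number of tokens entering each block (crude cost proxy)."""
    total, n = 0, num_tokens
    for _ in range(num_layers):
        total += n
        n -= min(r, (n - protected) // 2)
    return total


base = token_layers(197, r=0)
for r in (8, 16):
    print(f"r={r}: {base / token_layers(197, r):.2f}x fewer token-layers")
# r=8: 1.29x fewer token-layers
# r=16: 1.81x fewer token-layers
```

The proxy lands in the same range as the measured 1.34x and 1.96x, though it should not be read as a model of the actual runtime.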

Finally, we benchmark a “decreasing” schedule, selected by setting `model.r = (16, -1.0)`. It merges 2r tokens in the first layer and 0 tokens in the last layer, linearly interpolating the number of merged tokens for the layers in between.
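The decreasing schedule described above can be sketched as a per-layer list of merge counts. This is our reading of the description (2r in the first layer, 0 in the last, linear in between, truncated to integers), not ToMe's own schedule parser, whose rounding may differ slightly; `linear_schedule` is a hypothetical name.

```python
from typing import List


def linear_schedule(r: int, num_layers: int) -> List[int]:
    """Per-layer merge counts going linearly from 2r (first layer) to 0 (last layer)."""
    start, end = 2 * r, 0
    step = (end - start) / (num_layers - 1)
    return [int(start + step * i) for i in range(num_layers)]


schedule = linear_schedule(16, num_layers=12)
print(schedule)       # [32, 29, 26, 23, 20, 17, 14, 11, 8, 5, 2, 0]
print(sum(schedule))  # 187, versus 16 * 12 = 192 for the constant r = 16 schedule
```

Front-loading the merges shrinks the token count early, where every later layer benefits from it.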

In [None]:
# ToMe with r=16 and a decreasing schedule
model.r = (16, -1.0)
tome_decr_throughput = tome.utils.benchmark(
    model,
    device=device,
    verbose=True,
    runs=runs,
    batch_size=batch_size,
    input_size=input_size
)
print(f"Throughput improvement: {tome_decr_throughput / baseline_throughput:.2f}x")

Benchmarking: 100%|██████████| 50/50 [02:18<00:00,  2.77s/it]


Throughput: 83.89 im/s
Throughput improvement: 1.00x
