In [2]:
import torch

from liptrf.models.moderate import MNIST_4C3F_ReLU, CIFAR10_4C3F_ReLUx, CIFAR10_6C2F_ReLUx, CIFAR100_8C2F_ReLUx, TinyImageNet_8C2F_ReLUx
from liptrf.models.vit import ViT

from liptrf.models.layers.linear import LinearX
from liptrf.models.layers.conv import Conv2dX

from thop import profile, clever_format
from thop.vision.basic_hooks import calculate_conv2d_flops, calculate_linear

from timm import create_model

In [3]:
def count_linearx(m, x, y):
    """thop custom-op hook: add the cost of one LinearX forward pass to ``m.total_ops``.

    Args:
        m: the profiled layer; its ``input`` attribute is used as the number
           of multiplies per output element (presumably the input feature
           count of LinearX -- TODO confirm against the layer definition).
        x: forward inputs supplied by the hook; not used here.
        y: forward output; only its element count is read.

    Additions (including any bias add) are deliberately left out of the count.
    """
    mults_per_output = m.input
    m.total_ops += calculate_linear(mults_per_output, y.numel())
    
    
def count_conv2d(m, x, y):
    """thop custom-op hook: add the cost of one Conv2dX forward pass to ``m.total_ops``.

    Args:
        m: the profiled layer; only the shape of ``m.weight`` is read.
        x: tuple of forward inputs as passed by the hook; only the first
           entry is used.
        y: forward output tensor.

    The count is delegated to thop's ``calculate_conv2d_flops`` using the
    input, output and weight shapes. ``groups`` is fixed to 1 and ``bias`` is
    passed as ``None``, exactly as before.
    NOTE(review): a grouped Conv2dX would be over-counted with ``groups=1`` --
    confirm the layer never uses groups.
    """
    # The previous version also computed `kernel_ops` (allocating a throwaway
    # tensor) and `bias_ops`; neither value was ever used, so both are gone.
    first_input = x[0]

    m.total_ops += calculate_conv2d_flops(
        input_size=list(first_input.shape),
        output_size=list(y.shape),
        kernel_size=list(m.weight.shape),
        groups=1,
        bias=None,
    )

In [48]:
# Profile MNIST_4C3F_ReLU on a single random input of shape (1, 1, 28, 28).
model = MNIST_4C3F_ReLU(lmbda=0.1, power_iter=10)
inp = torch.randn(1, 1, 28, 28)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (
    ("Standard", 100),
    ("BCP", 300),
    ("GloRo", 500),
    ("Local-Lip", 290),
    ("CertViT", 90),
):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

1973536
[INFO] Customize rule count_conv2d() <class 'liptrf.models.layers.conv.Conv2dX'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Customize rule count_linearx() <class 'liptrf.models.layers.linear.LinearX'>.
Standard ('60.669T', '1.974M')
BCP ('182.008T', '1.974M')
GloRo ('303.347T', '1.974M')
Local-Lip ('175.941T', '1.974M')
CertViT ('54.602T', '1.974M')


In [49]:
# Profile the L2-attention ViT on a single random input of shape (1, 1, 28, 28).
model = ViT(
    image_size=28,
    patch_size=7,
    num_classes=10,
    channels=1,
    dim=128,
    depth=6,
    heads=8,
    mlp_ratio=4,
    attention_type='L2',
)
inp = torch.randn(1, 1, 28, 28)
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 100), ("CertViT", 90)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

[INFO] Customize rule count_linearx() <class 'liptrf.models.layers.linear.LinearX'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
Standard ('92.947T', '1.092M')
CertViT ('83.652T', '1.092M')


In [51]:
# Profile CIFAR10_4C3F_ReLUx on a single random input of shape (1, 3, 32, 32).
model = CIFAR10_4C3F_ReLUx(lmbda=0.1, power_iter=10)
inp = torch.randn(1, 3, 32, 32)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (
    ("Standard", 100),
    ("BCP", 200),
    ("GloRo", 600),
    ("Local-Lip", 250),
    ("CertViT", 90),
):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

2528096
[INFO] Customize rule count_conv2d() <class 'liptrf.models.layers.conv.Conv2dX'>.
[INFO] Customize rule count_linearx() <class 'liptrf.models.layers.linear.LinearX'>.
Standard ('81.782T', '2.466M')
BCP ('163.564T', '2.466M')
GloRo ('490.691T', '2.466M')
Local-Lip ('204.454T', '2.466M')
CertViT ('73.604T', '2.466M')


In [53]:
# Profile CIFAR10_6C2F_ReLUx on a single random input of shape (1, 3, 32, 32).
model = CIFAR10_6C2F_ReLUx(lmbda=0.1, power_iter=10)
inp = torch.randn(1, 3, 32, 32)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (
    ("Standard", 100),
    ("BCP", 200),
    ("GloRo", 800),
    ("Local-Lip", 250),
    ("CertViT", 110),
):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

2360672
[INFO] Customize rule count_conv2d() <class 'liptrf.models.layers.conv.Conv2dX'>.
[INFO] Customize rule count_linearx() <class 'liptrf.models.layers.linear.LinearX'>.
Standard ('174.843T', '2.250M')
BCP ('349.686T', '2.250M')
GloRo ('1398.743T', '2.250M')
Local-Lip ('437.107T', '2.250M')
CertViT ('192.327T', '2.250M')


In [57]:
# Profile CIFAR100_8C2F_ReLUx on a single random input of shape (1, 3, 32, 32).
model = CIFAR100_8C2F_ReLUx(lmbda=0.1, power_iter=10)
inp = torch.randn(1, 3, 32, 32)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (
    ("Standard", 200),
    ("GloRo", 800),
    ("Local-Lip", 250),
    ("CertViT", 220),
):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

2436864
[INFO] Customize rule count_conv2d() <class 'liptrf.models.layers.conv.Conv2dX'>.
[INFO] Customize rule count_linearx() <class 'liptrf.models.layers.linear.LinearX'>.
Standard ('1285.663T', '2.219M')
GloRo ('5142.651T', '2.219M')
Local-Lip ('1607.078T', '2.219M')
CertViT ('1414.229T', '2.219M')


In [66]:
# Profile TinyImageNet_8C2F_ReLUx on a single random input of shape (1, 3, 64, 64).
model = TinyImageNet_8C2F_ReLUx(lmbda=0.1, power_iter=10)
inp = torch.randn(1, 3, 64, 64)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm the 50000 for this dataset).
for label, scale in (
    ("Standard", 200),
    ("GloRo", 800),
    ("Local-Lip", 250),
    ("CertViT", 220),
):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

5257984
[INFO] Customize rule count_conv2d() <class 'liptrf.models.layers.conv.Conv2dX'>.
[INFO] Customize rule count_linearx() <class 'liptrf.models.layers.linear.LinearX'>.
Standard ('5829.530T', '4.341M')
GloRo ('23318.118T', '4.341M')
Local-Lip ('7286.912T', '4.341M')
CertViT ('6412.483T', '4.341M')


In [67]:
# Profile the 10-class, depth-10 L2-attention ViT on a single random
# input of shape (1, 3, 32, 32).
model = ViT(
    image_size=32,
    patch_size=4,
    num_classes=10,
    channels=3,
    dim=192,
    depth=10,
    heads=3,
    mlp_ratio=4,
    attention_type='L2',
    dropout=0.1,
)
inp = torch.randn(1, 3, 32, 32)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 250), ("CertViT", 270)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

4086912
[INFO] Customize rule count_linearx() <class 'liptrf.models.layers.linear.LinearX'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
Standard ('3314.606T', '4.074M')
CertViT ('3579.775T', '4.074M')


In [68]:
# Profile the 100-class, depth-12 L2-attention ViT on a single random
# input of shape (1, 3, 32, 32).
model = ViT(
    image_size=32,
    patch_size=4,
    num_classes=100,
    channels=3,
    dim=192,
    depth=12,
    heads=3,
    mlp_ratio=4,
    attention_type='L2',
    dropout=0.1,
)
inp = torch.randn(1, 3, 32, 32)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 250), ("CertViT", 270)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

4916736
[INFO] Customize rule count_linearx() <class 'liptrf.models.layers.linear.LinearX'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
Standard ('3976.262T', '4.904M')
CertViT ('4294.363T', '4.904M')


In [69]:
# Profile the 200-class L2-attention ViT on a single random input of
# shape (1, 3, 64, 64).
model = ViT(
    image_size=64,
    patch_size=4,
    num_classes=200,
    channels=3,
    dim=384,
    depth=12,
    heads=12,
    mlp_ratio=1,
    attention_type='L2',
    dropout=0.1,
)
inp = torch.randn(1, 3, 64, 64)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm the 50000 for this dataset).
for label, scale in (("Standard", 250), ("CertViT", 270)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

9060864
[INFO] Customize rule count_linearx() <class 'liptrf.models.layers.linear.LinearX'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
Standard ('28600.531T', '8.962M')
CertViT ('30888.574T', '8.962M')


In [4]:
# Profile timm's 'vit_tiny_patch16_224' on a single random input of shape (1, 3, 224, 224).
# NOTE(review): pretrained weights look unnecessary for a MAC/parameter count;
# pretrained=False may avoid the download -- confirm.
model = create_model('vit_tiny_patch16_224', pretrained=True)
model.eval()
inp = torch.randn(1, 3, 224, 224)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 300), ("CertViT", 220)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

5717416
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
Standard ('16179.506T', '5.679M')
CertViT ('11864.971T', '5.679M')


In [6]:
# Profile timm's 'vit_small_patch16_224' on a single random input of shape (1, 3, 224, 224).
# NOTE(review): pretrained weights look unnecessary for a MAC/parameter count;
# pretrained=False may avoid the download -- confirm.
model = create_model('vit_small_patch16_224', pretrained=True)
model.eval()
inp = torch.randn(1, 3, 224, 224)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 300), ("CertViT", 220)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

22050664
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
Standard ('63731.750T', '21.975M')
CertViT ('46736.617T', '21.975M')


In [7]:
# Profile timm's 'vit_small_patch32_224' on a single random input of shape (1, 3, 224, 224).
# NOTE(review): pretrained weights look unnecessary for a MAC/parameter count;
# pretrained=False may avoid the download -- confirm.
model = create_model('vit_small_patch32_224', pretrained=True)
model.eval()
inp = torch.randn(1, 3, 224, 224)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 300), ("CertViT", 220)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

22878952
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
Standard ('16826.849T', '22.859M')
CertViT ('12339.689T', '22.859M')


In [8]:
# Profile timm's 'vit_base_patch8_224' on a single random input of shape (1, 3, 224, 224).
# NOTE(review): pretrained weights look unnecessary for a MAC/parameter count;
# pretrained=False may avoid the download -- confirm.
model = create_model('vit_base_patch8_224', pretrained=True)
model.eval()
inp = torch.randn(1, 3, 224, 224)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 300), ("CertViT", 220)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

86576872
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
Standard ('1002755.497T', '85.973M')
CertViT ('735354.031T', '85.973M')


In [9]:
# Profile timm's 'vit_base_patch16_224' on a single random input of shape (1, 3, 224, 224).
# NOTE(review): pretrained weights look unnecessary for a MAC/parameter count;
# pretrained=False may avoid the download -- confirm.
model = create_model('vit_base_patch16_224', pretrained=True)
model.eval()
inp = torch.randn(1, 3, 224, 224)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 300), ("CertViT", 220)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

86567656
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
Standard ('252954.455T', '86.416M')
CertViT ('185499.934T', '86.416M')


In [10]:
# Profile timm's 'vit_base_patch32_224' on a single random input of shape (1, 3, 224, 224).
# NOTE(review): pretrained weights look unnecessary for a MAC/parameter count;
# pretrained=False may avoid the download -- confirm.
model = create_model('vit_base_patch32_224', pretrained=True)
model.eval()
inp = torch.randn(1, 3, 224, 224)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 300), ("CertViT", 220)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

88224232
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
Standard ('65504.195T', '88.185M')
CertViT ('48036.409T', '88.185M')


In [11]:
# Profile timm's 'vit_large_patch16_224' on a single random input of shape (1, 3, 224, 224).
# NOTE(review): pretrained weights look unnecessary for a MAC/parameter count;
# pretrained=False may avoid the download -- confirm.
model = create_model('vit_large_patch16_224', pretrained=True)
model.eval()
inp = torch.randn(1, 3, 224, 224)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 300), ("CertViT", 220)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

304326632
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
Standard ('895300.669T', '304.124M')
CertViT ('656553.824T', '304.124M')


In [12]:
# Profile timm's 'deit_tiny_patch16_224' on a single random input of shape (1, 3, 224, 224).
# NOTE(review): pretrained weights look unnecessary for a MAC/parameter count;
# pretrained=False may avoid the download -- confirm.
model = create_model('deit_tiny_patch16_224', pretrained=True)
model.eval()
inp = torch.randn(1, 3, 224, 224)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 300), ("CertViT", 220)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

Downloading: "https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth" to /home/versag/.cache/torch/hub/checkpoints/deit_tiny_patch16_224-a1311bcf.pth


5717416
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
Standard ('16179.506T', '5.679M')
CertViT ('11864.971T', '5.679M')


In [13]:
# Profile timm's 'deit_small_patch16_224' on a single random input of shape (1, 3, 224, 224).
# NOTE(review): pretrained weights look unnecessary for a MAC/parameter count;
# pretrained=False may avoid the download -- confirm.
model = create_model('deit_small_patch16_224', pretrained=True)
model.eval()
inp = torch.randn(1, 3, 224, 224)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 300), ("CertViT", 220)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

Downloading: "https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth" to /home/versag/.cache/torch/hub/checkpoints/deit_small_patch16_224-cd65a155.pth


22050664
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
Standard ('63731.750T', '21.975M')
CertViT ('46736.617T', '21.975M')


In [14]:
# Profile timm's 'deit_base_patch16_224' on a single random input of shape (1, 3, 224, 224).
# NOTE(review): pretrained weights look unnecessary for a MAC/parameter count;
# pretrained=False may avoid the download -- confirm.
model = create_model('deit_base_patch16_224', pretrained=True)
model.eval()
inp = torch.randn(1, 3, 224, 224)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 300), ("CertViT", 220)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /home/versag/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth


86567656
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
Standard ('252954.455T', '86.416M')
CertViT ('185499.934T', '86.416M')


In [16]:
# Profile timm's 'swin_tiny_patch4_window7_224' on a single random input of shape (1, 3, 224, 224).
# NOTE(review): pretrained weights look unnecessary for a MAC/parameter count;
# pretrained=False may avoid the download -- confirm.
model = create_model('swin_tiny_patch4_window7_224', pretrained=True)
model.eval()
inp = torch.randn(1, 3, 224, 224)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 300), ("CertViT", 220)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth" to /home/versag/.cache/torch/hub/checkpoints/swin_tiny_patch4_window7_224.pth


28288354
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool1d'>.
Standard ('65578.353T', '28.265M')
CertViT ('48090.792T', '28.265M')


In [17]:
# Profile timm's 'swin_small_patch4_window7_224' on a single random input of shape (1, 3, 224, 224).
# NOTE(review): pretrained weights look unnecessary for a MAC/parameter count;
# pretrained=False may avoid the download -- confirm.
model = create_model('swin_small_patch4_window7_224', pretrained=True)
model.eval()
inp = torch.randn(1, 3, 224, 224)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 300), ("CertViT", 220)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth" to /home/versag/.cache/torch/hub/checkpoints/swin_small_patch4_window7_224.pth


49606258
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool1d'>.
Standard ('128175.516T', '49.559M')
CertViT ('93995.378T', '49.559M')


In [19]:
# Profile timm's 'swin_base_patch4_window7_224' on a single random input of shape (1, 3, 224, 224).
# NOTE(review): pretrained weights look unnecessary for a MAC/parameter count;
# pretrained=False may avoid the download -- confirm.
model = create_model('swin_base_patch4_window7_224', pretrained=True)
model.eval()
inp = torch.randn(1, 3, 224, 224)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 300), ("CertViT", 220)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

87768224
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool1d'>.
Standard ('227547.385T', '87.705M')
CertViT ('166868.082T', '87.705M')


In [18]:
# Profile timm's 'swin_large_patch4_window7_224' on a single random input of shape (1, 3, 224, 224).
# NOTE(review): pretrained weights look unnecessary for a MAC/parameter count;
# pretrained=False may avoid the download -- confirm.
model = create_model('swin_large_patch4_window7_224', pretrained=True)
model.eval()
inp = torch.randn(1, 3, 224, 224)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 300), ("CertViT", 220)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

Downloading: "https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth" to /home/versag/.cache/torch/hub/checkpoints/swin_large_patch4_window7_224_22kto1k.pth


196532476
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool1d'>.
Standard ('511261.168T', '196.437M')
CertViT ('374924.857T', '196.437M')


In [24]:
# Profile the L2-attention ViT (patch 16, dim 192, depth 12) on a single
# random input of shape (1, 3, 224, 224).
model = ViT(
    patch_size=16,
    dim=192,
    depth=12,
    heads=3,
    attention_type='L2',
    image_size=224,
    num_classes=10,
)
model.eval()
inp = torch.randn(1, 3, 224, 224)
print(sum(p.numel() for p in model.parameters()))
macs, params = profile(
    model,
    inputs=(inp,),
    custom_ops={LinearX: count_linearx, Conv2dX: count_conv2d},
)
# Scale the single-sample count per method: factor * 50000
# (presumably epochs x training-set size -- TODO confirm).
for label, scale in (("Standard", 300), ("CertViT", 220)):
    print(label, clever_format([macs * scale * 50000, params], "%.3f"))

5063040
[INFO] Customize rule count_linearx() <class 'liptrf.models.layers.linear.LinearX'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
Standard ('14867.199T', '5.025M')
CertViT ('10902.613T', '5.025M')
