# Import libraries

In [1]:
import torch
import torchvision
import torch.nn as nn
from torchviz import make_dot, make_dot_from_trace
from torch.autograd import Variable
import torchvision.models as models
from model.edsr import *
from model.vdsr import *
from model.srcnn import *
from model.espcn import *
from model.edsr_mobile import *
import numpy as np
from graphviz import Digraph

In [2]:
import easydict

# Settings bundle with attribute-style access (e.g. ``args.model``).
# NOTE(review): most keys look like options for the super-resolution model
# constructors (cf. the ``EDSR(args)`` remnants below) — confirm against the
# ``model`` package; in the cells shown here only ``args.model`` is read.
args = easydict.EasyDict(
    {
        "n_GPUs": 1,
        "scale": (2,),
        "n_colors": 3,
        "act": "Relu",
        "n_resblocks": 32,
        "n_feats": 256,
        "shift_mean": True,
        "res_scale": 0.1,
        "rgb_range": 255,
        "model": "VDSR",
    }
)

In [3]:
def print_model_parm_nums(my_models=None):
    """Print the trainable-parameter count and a size estimate for a model.

    Args:
        my_models: A ``torch.nn.Module``. When omitted, a fresh ``VDSR()`` is
            built at call time (rather than once, when the function is
            defined, which would construct the network even if it is never
            used and share that single instance across calls).

    Returns:
        None. Two lines are printed: the number of parameters whose
        ``requires_grad`` is true, in millions, and that figure multiplied by
        4 (presumably bytes per float32 value, i.e. an estimate in MB —
        confirm if models use another dtype).
    """
    model = VDSR() if my_models is None else my_models
    # Frozen parameters (requires_grad=False) are deliberately excluded.
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('num of parameters: %.2fM ' % (num_params / 1e6) + '\n')
    print('estimated of size %.2fM ' % (num_params / 1e6 * 4) + '\n')
# Report the parameter count using the function's default model argument.
print_model_parm_nums()

num of parameters: 0.66M 

estimated of size 2.66M 



In [5]:
def visual(_model=None, _input=None, filename=args.model):
    """Render the autograd graph of one forward pass with torchviz.

    Args:
        _model: Module to trace. Defaults to a fresh ``models.AlexNet()``
            built at call time, so defining this function no longer
            allocates a network that may never be used.
        _input: Tensor fed to the model. Defaults to a new
            ``torch.randn(1, 3, 224, 224)`` with ``requires_grad`` enabled.
        filename: Passed to graphviz ``render``; defaults to ``args.model`` as
            evaluated when this function is defined.

    Returns:
        None. The graph is written to disk by ``render`` (output format and
        exact file names follow graphviz defaults).
    """
    model = models.AlexNet() if _model is None else _model
    if _input is None:
        x = torch.randn(1, 3, 224, 224).requires_grad_(True)
    else:
        x = _input

    y = model(x)

    # Label graph nodes with parameter names, plus the input under 'x'.
    graph_params = dict(model.named_parameters())
    graph_params['x'] = x
    make_dot(y, params=graph_params).render(filename=filename)
# Render the graph using the function's default arguments.
visual()

In [63]:
def print_model_parm_flops(my_models=None, _input=None, multiply_adds=False):
    """Run one forward pass and print estimated FLOPs, total and per layer.

    Forward hooks are attached to every leaf ``Conv2d``, ``Linear``,
    ``BatchNorm2d``, ``ReLU``, ``MaxPool2d`` and ``AvgPool2d`` module; each
    hook records a cost estimate every time its module is called. Other layer
    types are not counted. The hooks are removed again before returning, so
    repeated calls do not pile hooks up on the model.

    Args:
        my_models: Module to profile. Defaults to a fresh
            ``models.resnet50()`` built at call time.
        _input: Input tensor for the forward pass. Defaults to
            ``torch.rand(1, 3, 224, 224)``.
        multiply_adds: When true, a multiply-accumulate counts as two
            operations in conv and linear layers; otherwise as one.

    Returns:
        None. Results are printed in units of 1e9 operations ("G").

    Note:
        The forward pass runs under ``torch.no_grad()`` in whatever
        train/eval mode the model is already in.
    """
    model = models.resnet50() if my_models is None else my_models
    sample = torch.rand(1, 3, 224, 224) if _input is None else _input
    ops_per_mac = 2 if multiply_adds else 1

    # All counts are plain Python ints: no tensor re-wrapping warnings, no
    # overflow, and uniform handling when summing and printing.
    list_conv, list_linear, list_bn, list_relu, list_pooling = [], [], [], [], []

    def _as_pair(value):
        # Pooling layers keep kernel_size as given: an int or an (h, w) tuple.
        return (value, value) if isinstance(value, int) else tuple(value)

    def conv_hook(module, inputs, output):
        batch_size, _, _, _ = inputs[0].size()
        output_channels, output_height, output_width = output[0].size()
        kernel_h, kernel_w = _as_pair(module.kernel_size)
        kernel_ops = kernel_h * kernel_w * (module.in_channels // module.groups) * ops_per_mac
        bias_ops = 1 if module.bias is not None else 0
        per_position = output_channels * (kernel_ops + bias_ops)
        list_conv.append(batch_size * per_position * output_height * output_width)

    def linear_hook(module, inputs, output):
        batch_size = inputs[0].size(0) if inputs[0].dim() == 2 else 1
        weight_ops = module.weight.nelement() * ops_per_mac
        # Linear(bias=False) has bias None; count zero instead of crashing.
        bias_ops = module.bias.nelement() if module.bias is not None else 0
        list_linear.append(batch_size * (weight_ops + bias_ops))

    def bn_hook(module, inputs, output):
        list_bn.append(inputs[0].nelement())

    def relu_hook(module, inputs, output):
        list_relu.append(inputs[0].nelement())

    def pooling_hook(module, inputs, output):
        batch_size, _, _, _ = inputs[0].size()
        output_channels, output_height, output_width = output[0].size()
        kernel_h, kernel_w = _as_pair(module.kernel_size)
        per_position = output_channels * kernel_h * kernel_w
        list_pooling.append(batch_size * per_position * output_height * output_width)

    hooks_by_type = (
        (torch.nn.Conv2d, conv_hook),
        (torch.nn.Linear, linear_hook),
        (torch.nn.BatchNorm2d, bn_hook),
        (torch.nn.ReLU, relu_hook),
        ((torch.nn.MaxPool2d, torch.nn.AvgPool2d), pooling_hook),
    )

    handles = []
    # modules() yields each distinct module once, so a module shared between
    # several parents is hooked a single time.
    for module in model.modules():
        if next(module.children(), None) is not None:
            continue  # only leaf modules are instrumented
        for layer_types, hook in hooks_by_type:
            if isinstance(module, layer_types):
                handles.append(module.register_forward_hook(hook))

    try:
        with torch.no_grad():
            model(sample)
    finally:
        for handle in handles:
            handle.remove()

    total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn)
                   + sum(list_relu) + sum(list_pooling))
    print('Total Number of FLOPs: %.3fG' % (total_flops / 1e9) + '\n')
    for i, flops in enumerate(list_conv, 1):
        print(' Conv  layer `s ' + str(i) + ' of FLOPs: %.3fG' % (flops / 1e9) + '\n')
    for i, flops in enumerate(list_linear, 1):
        print(' linear layer `s ' + str(i) + ' of FLOPs: %.3fG' % (flops / 1e9) + '\n')
    for i, flops in enumerate(list_bn, 1):
        print(' bn layer `s ' + str(i) + ' of FLOPs: %.3fG' % (flops / 1e9) + '\n')
    for i, flops in enumerate(list_relu, 1):
        print(' activation layer `s ' + str(i) + ' of FLOPs: %.4fG' % (flops / 1e9) + '\n')
    for i, flops in enumerate(list_pooling, 1):
        print(' pooling layer `s ' + str(i) + ' of FLOPs: %.3fG' % (flops / 1e9) + '\n')

# Build a torchvision VGG-16 and report its FLOPs, then its parameter count.
mo=models.vgg16()
print_model_parm_flops(mo)
print_model_parm_nums(mo)

Total Number of FLOPs: 15.503G

 Conv  layer `s 1f FLOPs: 0.090G

 Conv  layer `s 2f FLOPs: 1.853G

 Conv  layer `s 3f FLOPs: 0.926G

 Conv  layer `s 4f FLOPs: 1.851G

 Conv  layer `s 5f FLOPs: 0.926G

 Conv  layer `s 6f FLOPs: 1.850G

 Conv  layer `s 7f FLOPs: 1.850G

 Conv  layer `s 8f FLOPs: 0.925G

 Conv  layer `s 9f FLOPs: 1.850G

 Conv  layer `s 10f FLOPs: 1.850G

 Conv  layer `s 11f FLOPs: 0.463G

 Conv  layer `s 12f FLOPs: 0.463G

 Conv  layer `s 13f FLOPs: 0.463G

 linear layer `s 1 of FLOPs: 0.103G

 linear layer `s 2 of FLOPs: 0.017G

 linear layer `s 3 of FLOPs: 0.004G

 activation layer `s 1 of FLOPs: 0.0032G

 activation layer `s 2 of FLOPs: 0.0032G

 activation layer `s 3 of FLOPs: 0.0016G

 activation layer `s 4 of FLOPs: 0.0016G

 activation layer `s 5 of FLOPs: 0.0008G

 activation layer `s 6 of FLOPs: 0.0008G

 activation layer `s 7 of FLOPs: 0.0008G

 activation layer `s 8 of FLOPs: 0.0004G

 activation layer `s 9 of FLOPs: 0.0004G

 activation layer `s 10 of FLOPs: