Does it not support two-input image networks? #14

Open
wikiwen opened this issue Aug 5, 2019 · 9 comments
Labels
question Further information is requested

Comments

@wikiwen

wikiwen commented Aug 5, 2019

No description provided.

@sovrasov
Owner

sovrasov commented Aug 5, 2019

Try to use the input_constructor argument:

import torch
import torch.nn as nn

from ptflops import get_model_complexity_info


class Siamese(nn.Module):
    def __init__(self):
        super(Siamese, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 3, 1)
        self.conv2 = nn.Conv2d(1, 10, 3, 1)

    def forward(self, x):
        # x is expected to be a list of two tensors
        return self.conv1(x[0]) + self.conv2(x[1])

def prepare_input(resolution):
    # Build both inputs and return them as a dict of keyword arguments
    # that will be passed to the model's forward().
    x1 = torch.FloatTensor(1, *resolution)
    x2 = torch.FloatTensor(1, *resolution)
    return dict(x=[x1, x2])

if __name__ == '__main__':
    model = Siamese()
    flops, params = get_model_complexity_info(model, input_res=(1, 224, 224),
                                              input_constructor=prepare_input,
                                              as_strings=True, print_per_layer_stat=False)
    print('      - Flops:  ' + flops)
    print('      - Params: ' + params)

@wikiwen
Author

wikiwen commented Aug 6, 2019

Thanks~~it works!!

@chyohoo

chyohoo commented Sep 4, 2020

What if the two inputs have different sizes?

@sovrasov
Owner

sovrasov commented Sep 7, 2020

@chyohoo in that case you can ignore the resolution parameter and use custom shapes:

def prepare_input(resolution):
    x1 = torch.FloatTensor(1, 3, 224, 224)
    x2 = torch.FloatTensor(1, 3, 128, 128)
    return dict(x=[x1, x2])
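
For reference, a minimal sketch of how such a constructor could be wired up, assuming a hypothetical two-branch model (here called TwoResolutionNet, with 3-channel convolutions) whose forward(self, x) takes a list of two tensors; the resolution passed to get_model_complexity_info then only acts as a placeholder, since prepare_input hard-codes the shapes:

import torch
import torch.nn as nn

from ptflops import get_model_complexity_info

# Hypothetical two-branch model, for illustration only.
class TwoResolutionNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.branch1 = nn.Conv2d(3, 10, 3, 1)
        self.branch2 = nn.Conv2d(3, 10, 3, 1)

    def forward(self, x):
        # x is a list of two tensors with different spatial sizes,
        # so reduce each branch before combining them.
        return self.branch1(x[0]).mean() + self.branch2(x[1]).mean()

def prepare_input(resolution):
    x1 = torch.FloatTensor(1, 3, 224, 224)
    x2 = torch.FloatTensor(1, 3, 128, 128)
    return dict(x=[x1, x2])

model = TwoResolutionNet()
# The (3, 224, 224) below is only forwarded to prepare_input and otherwise unused.
flops, params = get_model_complexity_info(model, input_res=(3, 224, 224),
                                          input_constructor=prepare_input,
                                          as_strings=True, print_per_layer_stat=False)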

@JamesLee789

JamesLee789 commented Apr 13, 2021

Hi. I tried to implement the calculation following your advice:

def prepare_input(resolution):
    x = torch.FloatTensor(1, 3, 224, 224)
    depth = torch.FloatTensor(1, 1, 224, 224)
    return dict(x=[x, depth])

...

flops, macs, params = get_model_complexity_info(model, input_res=((1, 3, 224, 224), (1, 1, 224, 224)),
                                                input_constructor=prepare_input,
                                                as_strings=True, print_per_layer_stat=True, verbose=True)

However, I got the following error:

Warning: module Softmax is treated as a zero-op.
Warning: module TestNet is treated as a zero-op.
Traceback (most recent call last):
  File "get_model_complexity_info.py", line 18, in <module>
    flops, macs, params = get_model_complexity_info(model, input_res=((1, 3, 384, 384),(1, 1, 384, 384)),input_constructor=prepare_input,
  File "/home/user/software/anaconda3/lib/python3.8/site-packages/ptflops/flops_counter.py", line 34, in get_model_complexity_info
    _ = flops_model(**input)
  File "/home/user/software/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'depth'

How could I fix it? Many thanks.

@sovrasov
Owner

sovrasov commented Apr 13, 2021

Hi! Take a look at the forward method of my test model: it assumes x is a list, whereas your model may require a different input format as well as different arguments besides x. The general suggestion is to preserve the forward arguments and their layout in prepare_input.
My guess is that your forward method looks like this: def forward(self, x, depth). If so, you have to pass the depth argument separately as well:

def prepare_input(resolution):
    x = torch.FloatTensor(1, 3, 224, 224)
    depth = torch.FloatTensor(1, 1, 224, 224)
    return dict(x=x, depth=depth)
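
To make the correspondence concrete, here is a minimal self-contained sketch, assuming a hypothetical two-stream model (TwoStreamNet below; the actual TestNet from the traceback is not shown in this thread) whose forward signature is forward(self, x, depth). The keys of the dict returned by prepare_input must match those parameter names exactly:

import torch
import torch.nn as nn

from ptflops import get_model_complexity_info

# Hypothetical two-stream model, for illustration only.
class TwoStreamNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.rgb_branch = nn.Conv2d(3, 16, 3, 1)
        self.depth_branch = nn.Conv2d(1, 16, 3, 1)

    def forward(self, x, depth):
        return self.rgb_branch(x) + self.depth_branch(depth)

def prepare_input(resolution):
    x = torch.FloatTensor(1, 3, 224, 224)
    depth = torch.FloatTensor(1, 1, 224, 224)
    # Keys must match the names of forward()'s arguments.
    return dict(x=x, depth=depth)

model = TwoStreamNet()
flops, params = get_model_complexity_info(model, input_res=(3, 224, 224),
                                          input_constructor=prepare_input,
                                          as_strings=True, print_per_layer_stat=False)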

@JamesLee789

Thanks for your prompt reply! It works now.

@hj611

hj611 commented Apr 5, 2023

Does the batch size have to be 1, or can it be customized? I tried your prepare_input and got the following error:
input = input_constructor(input_res)
TypeError: 'dict' object is not callable

@sovrasov
Owner

sovrasov commented Apr 17, 2023

Hi! If you have an ordinary model that consumes only one input tensor x, the following would work for you:

bs = 2  # desired batch size
# net is your model; ost is an optional output stream for the per-layer report
input_constructor = lambda _: {"x": torch.FloatTensor(bs, 3, 224, 224)}
macs, params = get_model_complexity_info(net, (3, 224, 224),
                                         as_strings=True,
                                         input_constructor=input_constructor,
                                         print_per_layer_stat=True,
                                         ost=ost)
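
For reference, the "'dict' object is not callable" error above suggests a dict was passed directly as input_constructor; it has to be a callable that returns the dict. A self-contained sketch of the snippet above, assuming a torchvision resnet18 purely for illustration (its forward takes a single argument named x):

import sys

import torch
import torchvision.models as models

from ptflops import get_model_complexity_info

net = models.resnet18()  # any single-input model works here
bs = 2

# input_constructor must be callable: it receives the resolution tuple and
# returns the keyword arguments for the model's forward().
input_constructor = lambda _: {"x": torch.FloatTensor(bs, 3, 224, 224)}

macs, params = get_model_complexity_info(net, (3, 224, 224),
                                         as_strings=True,
                                         input_constructor=input_constructor,
                                         print_per_layer_stat=True,
                                         ost=sys.stdout)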
