
Error when using the PSA module #83

Open
nlper01 opened this issue Sep 12, 2022 · 9 comments

Comments

@nlper01 commented Sep 12, 2022

Hi, when I use the PSA.py module I keep getting this error:
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
I have already called .cuda() on the module:

    input=torch.randn(50,512,7,7).cuda()
    psa = PSA(channel=512,reduction=8).cuda()
    output=psa(input)
    a=output.view(-1).sum()
    a.backward()
    print(output.shape)

It still throws the error. Could you take a look?

@XiaoyuWant

I have the same problem and don't know how to fix it.

@nlper01 (Author) commented Oct 21, 2022

I have the same problem and don't know how to fix it.

You can't use a plain Python list directly; change it to nn.ModuleList([]).
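
For context, a minimal sketch of why this matters (assuming the original PSA stored its branch convolutions in a plain Python list; names below are illustrative, not the repo's actual code): submodules kept in a plain list are not registered with the parent module, so .cuda()/.to(device) never moves their weights, which is exactly what produces the "Input type ... and weight type ..." RuntimeError.

    import torch
    from torch import nn

    class BrokenPSA(nn.Module):
        def __init__(self):
            super().__init__()
            # Plain Python list: the convs are NOT registered as submodules,
            # so .cuda()/.to(device) leaves their weights on the CPU.
            self.convs = [nn.Conv2d(8, 8, kernel_size=3, padding=1) for _ in range(3)]

        def forward(self, x):
            return sum(conv(x) for conv in self.convs)

    class FixedPSA(nn.Module):
        def __init__(self):
            super().__init__()
            # nn.ModuleList registers each conv, so it follows .cuda()/.to()
            # and its weights appear in .parameters() for the optimizer.
            self.convs = nn.ModuleList(nn.Conv2d(8, 8, kernel_size=3, padding=1) for _ in range(3))

        def forward(self, x):
            return sum(conv(x) for conv in self.convs)

    if torch.cuda.is_available():
        x = torch.randn(2, 8, 16, 16).cuda()
        print(FixedPSA().cuda()(x).shape)  # torch.Size([2, 8, 16, 16])
        # BrokenPSA().cuda()(x) would raise the "Input type ... and weight type ..." RuntimeError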

@nlper01 (Author) commented Oct 21, 2022

Someone already posted a fixed version in another issue, which also resolves the in-place operation problem. I can't find that issue anymore, so I'm reposting the code here:

import torch
from torch import nn
from torch.nn import init
# from attention.SpaceAtt import SpatialAttention  # unused below; remove if attention.SpaceAtt is unavailable

class PSA(nn.Module):

    def __init__(self, channel=552, reduction=4, S=3):
        super().__init__()
        self.S = S

        self.convs = nn.ModuleList([])
        for i in range(S):
            # Add groups
            self.convs.append(
                nn.Conv2d(channel // S, channel // S, kernel_size=2 * (i + 1) + 1, padding=i + 1, groups=2 ** i))

        self.se_blocks = nn.ModuleList([])
        for i in range(S):
            self.se_blocks.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(channel // S, channel // (S * reduction), kernel_size=1, bias=False),
                nn.ReLU(inplace=False),
                nn.Conv2d(channel // (S * reduction), channel // S, kernel_size=1, bias=False),
                nn.Sigmoid()
                ))

        self.softmax = nn.Softmax(dim=1)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, h, w = x.size()

        # Step 1: SPC module -- split the channels into S groups and apply a
        # grouped conv with a different kernel size to each group.
        PSA_input = x.view(b, self.S, c // self.S, h, w)  # bs, s, ci, h, w
        outs = []
        for idx, conv in enumerate(self.convs):
            SPC_input = PSA_input[:, idx, :, :, :]
            outs.append(conv(SPC_input))
        SPC_out = torch.stack(outs, dim=1)

        # Step 2: SE weight -- per-group channel attention. Outputs are collected
        # in a list and stacked, instead of being written into a pre-allocated
        # tensor in place, which is what broke autograd before.
        outs = []
        for idx, se in enumerate(self.se_blocks):
            SE_input = SPC_out[:, idx, :, :, :]
            outs.append(se(SE_input))
        SE_out = torch.stack(outs, dim=1)

        # Step 3: softmax over the S scale branches
        softmax_out = self.softmax(SE_out)

        # Step 4: reweight the branches and merge them back into the channel dimension
        PSA_out = SPC_out * softmax_out
        PSA_out = PSA_out.view(b, -1, h, w)

        return PSA_out

if __name__ == '__main__':
    device = torch.device('cuda')
    input = torch.randn(8, 552, 7, 7).to(device)
    psa = PSA(channel=552, reduction=8).to(device)
    output = psa(input)
    a = output.reshape(-1).sum()
    a.backward()
    print(output.shape)

@zeng-cy commented Dec 1, 2022

Hi, when I use the PSA.py module I keep getting this error: RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same. I have already called .cuda() on the module:

    input=torch.randn(50,512,7,7).cuda()
    psa = PSA(channel=512,reduction=8).cuda()
    output=psa(input)
    a=output.view(-1).sum()
    a.backward()
    print(output.shape)

It still throws the error. Could you take a look?

Have you solved this problem?

@zeng-cy commented Dec 1, 2022

I have the same problem and don't know how to fix it.

You can't use a plain Python list directly; change it to nn.ModuleList([]).

That still doesn't seem to work.

@zeng-cy commented Dec 2, 2022

…when I use the PSA.py module I keep getting a RuntimeError

The reposted version still doesn't work for me; the import from attention.SpaceAtt import SpatialAttention cannot be resolved.

@zeng-cy commented Dec 2, 2022

I have the same problem and don't know how to fix it.

You can't use a plain Python list directly; change it to nn.ModuleList([]).

Changing only the nn.ModuleList part, I do hit the problem you mentioned: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4, 64, 20, 20]], which is output 0 of ReluBackward1, is at version 5; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
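
For context, a minimal sketch of this class of autograd error (a hypothetical example, not the PSA code itself): it appears when a tensor that autograd saved for the backward pass is modified in place afterwards, e.g. by nn.ReLU(inplace=True) or by slice-assigning into a pre-allocated buffer.

    import torch

    x = torch.randn(4, requires_grad=True)
    y = torch.relu(x)
    z = (y * y).sum()   # y is saved by autograd for the backward pass
    y.add_(1)           # in-place edit bumps y's version counter
    # z.backward()      # -> RuntimeError: one of the variables needed for gradient
    #                   #    computation has been modified by an inplace operation

The reposted code above avoids this by using nn.ReLU(inplace=False) in the SE blocks and by collecting branch outputs in a Python list and torch.stack-ing them, rather than slice-assigning into pre-allocated tensors such as SE_out[:, idx] = se(SE_input).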

@SuperFDX

I solved it by adding conv = conv.to(x.device) at line 48 and se = se.to(x.device) at line 54, so everything is moved onto the same device as the input x.
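
A sketch of that workaround (the line numbers 48 and 54 refer to that user's local copy of PSA.py; the class and names below are illustrative only): moving each submodule to the input's device inside forward() silences the device-mismatch error even with plain Python lists, but parameters kept in a plain list still do not appear in model.parameters(), so they would not be trained. nn.ModuleList remains the cleaner fix.

    import torch
    from torch import nn

    class PSALikeWorkaround(nn.Module):
        """Illustrative only: plain lists plus per-forward .to(x.device)."""
        def __init__(self, channel=512, S=4):
            super().__init__()
            self.S = S
            self.convs = [nn.Conv2d(channel // S, channel // S, 3, padding=1) for _ in range(S)]

        def forward(self, x):
            b, c, h, w = x.size()
            x = x.view(b, self.S, c // self.S, h, w)
            outs = []
            for idx, conv in enumerate(self.convs):
                conv = conv.to(x.device)   # the "line 48" fix; the SE blocks get the same treatment
                outs.append(conv(x[:, idx]))
            return torch.stack(outs, dim=1).view(b, -1, h, w)

    m = PSALikeWorkaround()
    print(len(list(m.parameters())))  # 0 -- plain-list convs stay invisible to the optimizer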

@18373233193

I have the same problem and don't know how to fix it.

You can't use a plain Python list directly; change it to nn.ModuleList([]).

Changing only the nn.ModuleList part, I do hit the problem you mentioned: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [4, 64, 20, 20]], which is output 0 of ReluBackward1, is at version 5; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

Have you solved this problem? I'm running into it now too.
