summary does not work with the torch.device class #199

Open
Skaifai opened this issue Jan 18, 2024 · 0 comments


Skaifai commented Jan 18, 2024

torchsummary's summary() raises an AttributeError when its device argument is a torch.device object instead of a string.
Code to reproduce the error:

import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.relu in forward()
from torchsummary import summary

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using", device)

class CNN(nn.Module):
    def __init__(self, train_CNN=False, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)  # 28x28 input -> 16 maps of 4x4 after two conv(5)+pool(2) stages
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = CNN().to(device)
summary(model, (3, 28, 28), device=device)

Error message:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_17928\2345870344.py in <module>
     27 
     28 model = CNN().to(device)
---> 29 summary(model, (3, 28, 28), device=device)

~\anaconda3\lib\site-packages\torchsummary\torchsummary.py in summary(model, input_size, batch_size, device)
     42             hooks.append(module.register_forward_hook(hook))
     43 
---> 44     device = device.lower()
     45     assert device in [
     46         "cuda",

AttributeError: 'torch.device' object has no attribute 'lower'
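The traceback shows that summary() calls device.lower() and then checks the result against a list of accepted strings, so only a plain string gets through. On the caller side, passing the string held in torch.device's .type attribute avoids the error, e.g. summary(model, (3, 28, 28), device=device.type). On the library side, the argument could be normalized before the check. The sketch below is a minimal, hypothetical helper (normalize_device is not part of torchsummary). It assumes only "cuda" and "cpu" should be accepted and that any non-string argument behaves like torch.device by exposing its kind as a .type string. It is plain Python so it runs without importing torch.

```python
def normalize_device(device="cuda"):
    """Hypothetical helper (not torchsummary API): reduce a device argument
    to "cuda" or "cpu".

    Accepts plain strings ("cuda", "CPU", "cuda:0") or any object that, like
    torch.device, exposes the device kind as a string via ``.type``.
    """
    if not isinstance(device, str):
        # Assumption: non-strings are torch.device-like, e.g.
        # torch.device("cuda:0").type == "cuda".
        device = device.type
    # Drop an optional ":<index>" suffix and normalise case.
    device = device.split(":")[0].lower()
    if device not in ("cuda", "cpu"):
        raise ValueError(
            f"Input device is not valid, expected 'cuda' or 'cpu', got {device!r}"
        )
    return device
```

With a helper like this, the line device = device.lower() inside summary() would become device = normalize_device(device), and the original call summary(model, (3, 28, 28), device=device) would accept a torch.device. Note that any device index (as in "cuda:1") is discarded, since the existing check only distinguishes "cuda" from "cpu".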