-
Notifications
You must be signed in to change notification settings - Fork 24.9k
Description
🐛 Describe the bug
Hi there!
I am implementing a U-Net using pytorch and bumped into the following error when calling my model:
RuntimeError: Input type (MPSByteType) and weight type (MPSFloatType) should be the same
Here is my code:
import os
import torch
import torch.nn as nn
import torch.utils.data as tdata
import torchvision
from torch.nn import functional as F
device = "mps"
torch.manual_seed(127)
#%%
# Dataset creation
class RoadsDataset(tdata.Dataset):
root: str
num_images: int
images: list[torch.Tensor]
gt_images: list[torch.Tensor]
# transform: torchvision.transforms.Tran
def __init__(self, root: str, num_images=20, transform=None, target_transform=None):
self.root = root
self.transform = transform
self.target_transform = target_transform
self.num_images = num_images
assert 10 <= num_images <= 100
self.images = []
self.gt_images = []
for i in range(num_images):
image_path = os.path.join(self.root, "images/satImage_%.3d.png" % (i + 1))
self.images.append(torchvision.io.read_image(image_path).to(device))
gt_image_path = os.path.join(
self.root, "groundtruth/satImage_%.3d.png" % (i + 1)
)
self.gt_images.append(torchvision.io.read_image(gt_image_path).to(device))
print("Loaded {} images from {}".format(num_images, root))
def __len__(self):
return self.num_images
def __getitem__(self, item: int) -> tuple[torch.Tensor, torch.Tensor]:
if self.transform:
image = self.transform(self.images[item])
else:
image = self.images[item]
if self.target_transform:
gt_image = self.target_transform(self.gt_images[item])
else:
gt_image = self.gt_images[item]
return image, gt_image
#%%
training_data = RoadsDataset(root="data/training", num_images=10)
training_dataloader = tdata.DataLoader(training_data, batch_size=5, shuffle=True)
#%%
# U-Net implementation
class Block(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv1 = nn.Conv2d(in_ch, out_ch, 3)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(out_ch, out_ch, 3)
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
class Encoder(nn.Module):
def __init__(self, chs=(3, 64, 128, 256, 512, 1024)):
super().__init__()
self.enc_blocks = nn.ModuleList(
[Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]
)
self.pool = nn.MaxPool2d(2)
def forward(self, x: torch.Tensor):
ftrs = []
for block in self.enc_blocks:
x = block(x)
ftrs.append(x)
x = self.pool(x)
return ftrs
class Decoder(nn.Module):
def __init__(self, chs=(1024, 512, 256, 128, 64)):
super().__init__()
self.chs = chs
self.upconvs = nn.ModuleList(
[nn.ConvTranspose2d(chs[i], chs[i + 1], 2, 2) for i in range(len(chs) - 1)]
)
self.dec_blocks = nn.ModuleList(
[Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]
)
def forward(self, x, encoder_features):
for i in range(len(self.chs) - 1):
x = self.upconvs[i](x)
enc_ftrs = self.crop(encoder_features[i], x)
x = torch.cat([x, enc_ftrs], dim=1)
x = self.dec_blocks[i](x)
return x
def crop(self, enc_ftrs, x):
_, _, H, W = x.shape
enc_ftrs = torchvision.transforms.CenterCrop([H, W])(enc_ftrs)
return enc_ftrs
class UNet(nn.Module):
def __init__(
self,
enc_chs=(3, 64, 128, 256, 512, 1024),
dec_chs=(1024, 512, 256, 128, 64),
num_class=1,
retain_dim=False,
out_sz=(400, 400),
):
super().__init__()
self.encoder = Encoder(enc_chs)
self.decoder = Decoder(dec_chs)
self.head = nn.Conv2d(dec_chs[-1], num_class, 1)
self.retain_dim = retain_dim
self.out_sz = out_sz
def forward(self, x):
enc_ftrs = self.encoder(x)
out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
out = self.head(out)
if self.retain_dim:
out = F.interpolate(out, self.out_sz)
return out
#%%
# train Unet on the training set
unet_model = UNet().to(device)
loss_fun = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(unet_model.parameters(), lr=1e-3)
def train(dataloader: tdata.DataLoader, model: nn.Module, loss_fun, optimizer):
size = len(dataloader.dataset)
model.train()
for batch, (X, y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)
pred = model(X)
loss = loss_fun(pred, y)
loss.backwards()
optimizer.step()
optimizer.zero_grad()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test(dataloader: tdata.DataLoader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(
f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
)
epochs = 5
for t in range(epochs):
print(f"Epoch {t + 1}\n-------------------------------")
train(training_dataloader, unet_model, loss_fun, optimizer)
test(training_dataloader, unet_model, loss_fun)
print("Done!")
I am working on a M2 MacBook Air using MPS as device. I made sure that everything is indeed stored on this device, and I do not understand what the types MPSByteType
and MPSFloatType
are in this context.
Can somebody please help me understand what is going on ?
Thanks
Tudor
Versions
PyTorch version: 1.12.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 13.0.1 (arm64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: version 3.24.2
Libc version: N/A
Python version: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ] (64-bit runtime)
Python platform: macOS-13.0.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.23.4
[pip3] torch==1.12.1
[pip3] torchvision==0.13.0a0
[conda] numpy 1.23.4 py310h5d7c261_1 conda-forge
[conda] pytorch 1.12.1 cpu_py310h7410233_1 conda-forge
[conda] torchvision 0.13.0 cpu_py310he68663e_0 conda-forge
cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev