In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os


In [5]:
class PSPNet(nn.Module):
    def __init__(self, num_classes=150):
        super(PSPNet, self).__init__()
        self.backbone = models.resnet50(pretrained=True)
        self.layer0 = nn.Sequential(self.backbone.conv1, self.backbone.bn1, self.backbone.relu, self.backbone.maxpool)
        self.layer1 = self.backbone.layer1
        self.layer2 = self.backbone.layer2
        self.layer3 = self.backbone.layer3
        self.layer4 = self.backbone.layer4

        self.ppm = nn.ModuleList([
            nn.AdaptiveAvgPool2d(1),
            nn.AdaptiveAvgPool2d(2),
            nn.AdaptiveAvgPool2d(3),
            nn.AdaptiveAvgPool2d(6)
        ])
        self.conv = nn.Conv2d(2048 + 2048 // 4 * 4, 512, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(512)
        self.final = nn.Conv2d(512, num_classes, kernel_size=1)

    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        input_size = x.size()
        ppm_out = [x]
        for pool in self.ppm:
            pooled = pool(x)
            pooled = F.interpolate(pooled, size=input_size[2:], mode='bilinear', align_corners=False)
            ppm_out.append(pooled)
        x = torch.cat(ppm_out, 1)
        x = self.conv(x)
        x = self.bn(x)
        x = self.final(x)
        return x

In [15]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Simplified PSPNet model
class PSPNet(nn.Module):
    def __init__(self, num_classes=19):
        super(PSPNet, self).__init__()
        self.backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        self.layer0 = nn.Sequential(self.backbone.conv1, self.backbone.bn1, self.backbone.relu, self.backbone.maxpool)
        self.layer1 = self.backbone.layer1
        self.layer2 = self.backbone.layer2
        self.layer3 = self.backbone.layer3
        self.layer4 = self.backbone.layer4

        # Adjusting ppm layers to ensure proper channel concatenation
        self.ppm = nn.ModuleList([
            nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(2048, 512, 1)),
            nn.Sequential(nn.AdaptiveAvgPool2d(2), nn.Conv2d(2048, 512, 1)),
            nn.Sequential(nn.AdaptiveAvgPool2d(3), nn.Conv2d(2048, 512, 1)),
            nn.Sequential(nn.AdaptiveAvgPool2d(6), nn.Conv2d(2048, 512, 1)),
        ])
        self.conv = nn.Conv2d(2048 + 512 * 4, 512, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(512)
        self.final = nn.Conv2d(512, num_classes, kernel_size=1)

    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        input_size = x.size()
        ppm_out = [x]
        for pool in self.ppm:
            pooled = pool(x)
            pooled = F.interpolate(pooled, size=input_size[2:], mode='bilinear', align_corners=False)
            ppm_out.append(pooled)
        x = torch.cat(ppm_out, 1)
        x = self.conv(x)
        x = self.bn(x)
        x = self.final(x)
        return x

# Load the PSPNet model
num_classes = 19  # For Cityscapes dataset
model = PSPNet(num_classes=num_classes)

# Load the checkpoint
checkpoint_file = os.path.join('/Users/sarahbanat/Desktop/segmentation/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth')
checkpoint = torch.load(checkpoint_file, map_location='cpu')

# Extract the state_dict
state_dict = checkpoint['state_dict']

# Fix the keys if needed
new_state_dict = {}
for k, v in state_dict.items():
    new_key = k.replace('backbone.', '')
    new_state_dict[new_key] = v

# Load the state_dict into the model
model.load_state_dict(new_state_dict, strict=False)
model.eval()

# Transformations
preprocess = transforms.Compose([
    transforms.Resize((512, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load and preprocess image
imagepath = os.path.join("/Users/sarahbanat/Desktop/segmentation/data/pic.png")

input_image = Image.open(image_path)
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # Create a mini-batch as expected by the model

# Check if a GPU is available and use it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
input_batch = input_batch.to(device)

# Forward pass
with torch.no_grad():
    output = model(input_batch)

# Post-process the output
output_predictions = output.argmax(1).cpu().numpy()[0]

# Visualize the result
plt.imshow(output_predictions)
plt.axis('off')
plt.show()

  checkpoint = torch.load(checkpoint_file, map_location='cpu')


FileNotFoundError: [Errno 2] No such file or directory: '/mnt/data/pic.png'