In [1]:
import argparse

from FastMETRO.model import FastMETRO_Body_Network

In [None]:
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', required=False, type=bool, default=False, help='Select whether to plot preview.')
parser.add_argument('--model_dim_1', required=True, type=str, default='resnet50', help='Specify path to weights.')
parser.add_argument('--model_dim_2', required=False, type=str, default=False, help='Specify path to weights.')
args = parser.parse_args()

In [None]:
model = FastMETRO_Body_Network()

In [3]:
from model import load_model

model = load_model('mobilenet', finetune=True)
param_size = 0
for param in model.parameters():
    param_size += param.nelement() * param.element_size()
buffer_size = 0
for buffer in model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()

size_all_mb = (param_size + buffer_size) / 1024**2

total_params = sum(p.numel() for p in model.parameters())
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model size: {size_all_mb:.3f}MB, \n{total_params:,} total parameters, \n{total_trainable_params:,} trainable parameters.")

DeepLabV3, mobilenet backbone
Model size: 42.146MB, 
11,020,337 total parameters, 
8,048,385 trainable parameters.


In [1]:
import torchvision.models.segmentation as models

deeplab = models.deeplabv3_resnet101(weight='DEFAULT')

In [2]:
import torch.nn as nn

class DeepLabV3Modified(nn.Module):
    def __init__(self, backbone):
        super(DeepLabV3Modified, self).__init__()
        self.backbone = backbone
        self.backbone.classifier[4] = nn.Conv2d(256, 1, kernel_size=(1, 1), stride=(1, 1))  # Remove the final classifier layer
        
    def forward(self, x):
        result = self.backbone(x)
        return {
            "res2": result['low_level'],  # Low-level features for fine-grained sampling
            "coarse": result['out']       # Coarse segmentation map
        }

deeplab_modified = DeepLabV3Modified(deeplab)

In [3]:
from model.pointrend import PointHead, PointRend

point_head = PointHead(in_c=532, num_classes=1)  # Adjust input channels if needed

pointrend_model = PointRend(deeplab_modified, point_head)

In [4]:
print(pointrend_model)

PointRend(
  (backbone): DeepLabV3Modified(
    (backbone): DeepLabV3(
      (backbone): IntermediateLayerGetter(
        (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (layer1): Sequential(
          (0): Bottleneck(
            (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn3): BatchNorm2d(256, eps=1e-05, 

In [5]:
print(deeplab_modified)

DeepLabV3Modified(
  (backbone): DeepLabV3(
    (backbone): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)


In [7]:
import torch
from PIL import Image
import torchvision.transforms as T

def load_image(image_path):
    image = Image.open(image_path).convert("RGB")
    transform = T.Compose([
        T.Resize((512, 512)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(image).unsqueeze(0)  

def test_model(model, input_tensor):
    model.eval()
    with torch.no_grad():
        output = model(input_tensor)
    return output

# Path to the image
image_path = 'data/segmentation_dataset/subfolder_1/images/csr0129a_front_view.png'

# Load and preprocess the image
input_tensor = load_image(image_path)

In [8]:
import matplotlib.pyplot as plt

# Load and preprocess the image
input_tensor = load_image(image_path)

# Perform inference
output = test_model(pointrend_model, input_tensor)

# Convert the output to a NumPy array and visualize it
output_mask = output["fine"].squeeze().cpu().numpy()

plt.imshow(output_mask, cmap='gray')
plt.title("Segmentation Output")
plt.show()

KeyError: 'low_level'