In [24]:
import torch
from isegm.model.ifss_pfenet_model import PFENetModel

# Smoke test: run PFENetModel end-to-end on random inputs and check the
# output shapes of the support and query branches.
SEED = 0
# Seed before building the model and sampling inputs so the random tensors
# (and the min/max printed in the next cell) are reproducible across runs.
torch.manual_seed(SEED)

model = PFENetModel()
B = 2
H = 473
W = 473
query_image = torch.randn(B, 3, H, W)
prev_query_mask = torch.randint(0, 2, (B, 1, H, W)).float()
query_mask = torch.randint(0, 2, (B, 1, H, W)).float()

support_image = torch.randn(B, 3, H, W)
support_mask = torch.randint(0, 2, (B, 1, H, W)).float()

# Forward-only shape check: no autograd graph is needed.
with torch.no_grad():
    support_output = model.support_forward(support_image, s_gt=support_mask)
    helpers = support_output["query_helpers"]
    # The query branch reads the query ground truth from the helpers dict.
    helpers["q_gt"] = query_mask
    query_output = model.query_forward(query_image, prev_query_mask, helpers)

print("Support Output:", support_output['instances'][0].shape)
print("Query Output:", query_output['masks'].shape)

Support Output: torch.Size([2, 473, 473])
Query Output: torch.Size([2, 473, 473])


In [25]:
# Value range of the query image, then the support image.
for img in (query_image, support_image):
    print(img.min(), img.max())

tensor(-4.8016) tensor(4.7887)
tensor(-4.7099) tensor(4.7189)


In [3]:
from pathlib import Path
from isegm.utils.exp import load_config

from albumentations import *
from isegm.data.transforms import *
from isegm.data.datasets.fss_sbd import iFSS_SBD_Dataset

cfg = load_config(Path("config.yml"))
cfg.debug = False

crop_size = (473, 473)

# ImageNet statistics, expressed on a [0, 1] scale.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
# albumentations.Normalize computes (img - mean * max_pixel_value) / (std * max_pixel_value).
# NOTE(review): assumes the dataset feeds uint8 images in [0, 255] to the augmentator
# (the +/-10 RGBShift limits below are in that scale) -- TODO confirm. With
# max_pixel_value=1.0 such images would normalize to values around 1000, which can
# blow up activations / produce NaNs downstream.
MAX_PIXEL_VALUE = 255.0

train_augmentator = Compose(
    [
        UniformRandomResize(scale_range=(0.75, 1.25)),
        Flip(),
        RandomRotate90(),
        ShiftScaleRotate(border_mode=0, p=0.75),
        PadIfNeeded(min_height=crop_size[0], min_width=crop_size[1], border_mode=0),
        RandomCrop(*crop_size),
        RandomBrightnessContrast(),
        RGBShift(
            r_shift_limit=(-10, 10),
            g_shift_limit=(-10, 10),
            b_shift_limit=(-10, 10),
        ),
        Normalize(
            mean=IMAGENET_MEAN,
            std=IMAGENET_STD,
            max_pixel_value=MAX_PIXEL_VALUE,
        ),
    ],
    p=1.0,
)

trainset = iFSS_SBD_Dataset(
    cfg,  # FIXME: Find a better way of doing this
    data_root=cfg.SBD_TRAIN_PATH,
    data_list=cfg.SBD_TRAIN_LIST,
    mode="train",
    split=0,
    use_coco=False,
    use_split_coco=False,
    augmentator=train_augmentator,
    min_object_area=80,
    keep_background_prob=0.01,
)

Loading from saved dump


In [5]:
# Fields produced for one training sample.
first_sample = trainset[0]
first_sample.keys()

dict_keys(['s_images', 's_instances', 's_points', 'q_images', 'q_masks'])

In [12]:
import torch

# Batch the query images of samples 0 and 1 along a new leading dimension.
query_image = torch.stack([trainset[idx]['q_images'] for idx in (0, 1)], dim=0)
query_image.shape

torch.Size([2, 3, 473, 473])

In [37]:

from isegm.model.ifss_pfenet_model import PFENetModel

model = PFENetModel()
B = 2
H = 473
W = 473


def _stack_field(samples, key):
    """Stack ``sample[key]`` from each sample along a new leading batch dimension."""
    return torch.stack([torch.as_tensor(sample[key]) for sample in samples], dim=0)


# Scan the training set in batches of B and report whether any support feature
# map contains NaNs. A trailing partial batch is skipped so indices stay in range.
with torch.no_grad():  # diagnostic forward passes only; no autograd graph needed
    for i in range(0, len(trainset) - B + 1, B):
        # trainset was built with a random augmentator (config cell), so each
        # trainset[idx] call presumably yields a freshly augmented draw. Fetch each
        # sample once and read every field from that single draw so images and
        # masks stay aligned.
        samples = [trainset[i + k] for k in range(B)]

        query_image = _stack_field(samples, 'q_images')
        query_mask = _stack_field(samples, 'q_masks').float()
        support_image = _stack_field(samples, 's_images')
        support_mask = _stack_field(samples, 's_instances').float()
        prev_query_mask = torch.randint(0, 2, (B, 1, H, W)).float()

        support_output = model.support_forward(support_image, s_gt=support_mask)
        helpers = support_output["query_helpers"]
        helpers["q_gt"] = query_mask
        query_output = model.query_forward(query_image, prev_query_mask, helpers)

        supp_feat_list = helpers['supp_feat_list']
        nan_check = [torch.isnan(feat).any().item() for feat in supp_feat_list]
        print("NaN values in supp_feat_list:", nan_check)

  model.load_state_dict(torch.load(model_path), strict=False)


NaN values in supp_feat_list: [False]
NaN values in supp_feat_list: [False]
NaN values in supp_feat_list: [False]
NaN values in supp_feat_list: [False]


KeyboardInterrupt: 

In [31]:
# Top-level entries returned by the support branch.
support_keys = support_output.keys()
support_keys

dict_keys(['instances', 'query_helpers'])

In [32]:
# Entries the support branch hands over to the query branch.
query_helpers = support_output['query_helpers']
query_helpers.keys()

dict_keys(['supp_feat_list', 'final_supp_list', 'mask_list', 'q_gt'])

In [35]:
# One flag per support feature map: True if it contains any NaN.
supp_feat_list = support_output['query_helpers']['supp_feat_list']
nan_check = list(map(lambda fmap: fmap.isnan().any().item(), supp_feat_list))
print("NaN values in supp_feat_list:", nan_check)

NaN values in supp_feat_list: [False]


In [36]:
# One flag per entry of final_supp_list: True if it contains any NaN.
# Uses its own variable name so the earlier `supp_feat_list` is not overwritten
# and the printed label matches the tensor list actually checked.
final_supp_list = support_output['query_helpers']['final_supp_list']
nan_check = [torch.isnan(feat).any().item() for feat in final_supp_list]
print("NaN values in final_supp_list:", nan_check)

NaN values in supp_feat_list: [False]
