In [1]:
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


DeepLabv3 encoder: the backbone for few-shot segmentation

In [2]:
"""
Encoder for few shot segmentation (DeepLabv3)
"""
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

class Encoder(nn.Module):
    """
    DeepLabv3 encoder loaded from torch.hub.

    Args:
        repo:
            torch.hub repository to load the model from
        model_name:
            hub entry point to instantiate
        pretrained:
            whether to load pretrained weights
    """
    def __init__(self, repo='pytorch/vision:v0.10.0', model_name='deeplabv3_resnet101', pretrained=True):
        super().__init__()
        # load (pretrained) segmentation model
        self.features = torch.hub.load(repo, model_name, pretrained=pretrained)
        # Drop the auxiliary head. Assigning None (rather than an identity
        # module) makes torchvision's segmentation forward skip the aux branch
        # entirely. With an identity module, the layer3 features would still be
        # bilinearly upsampled to full input resolution and then discarded.
        self.features.aux_classifier = None

    def forward(self, x):
        """
        Run the wrapped model and return the ``'out'`` entry of its output dict.

        NOTE(review): for torchvision segmentation models, ``'out'`` is the
        classifier-head output upsampled to the input resolution, not raw
        backbone features. Confirm this is the intended embedding for the
        few-shot logic.
        """
        return self.features(x)['out']

Implementation of the few-shot segmentation core logic

In [None]:
class FewShotSeg(nn.Module):
    """
    Few-shot segmentation model built on the DeepLabv3 ``Encoder``.

    Args:
        in_channels:
            number of input channels; currently unused (kept for interface
            compatibility)
    """
    def __init__(self, in_channels=3):
        super().__init__()
        # Encoder
        self.encoder = Encoder()

    def forward(self, Si, fg_mask, bg_mask, Qi):
        """
        Encode one episode (support + query images) in a single batched pass.

        Args:
            Si:
                support images: a list over ways of lists over shots, each
                entry a tensor of shape B x 3 x H x W
            fg_mask:
                foreground masks for the support images (not used yet)
            bg_mask:
                background masks for the support images (not used yet)
            Qi:
                query images: a flat list of tensors, each B x 3 x H x W

        Returns:
            Encoder output for every image, stacked along dim 0: support
            images first (way-major, then shot, each contributing B rows),
            followed by the query images.

        TODO: the remaining few-shot logic (consuming fg_mask / bg_mask and
        the episode sizes below) is not implemented yet.
        """
        # episode sizes: number of classes (ways) and shots per class;
        # not consumed yet, kept for the pending few-shot step
        c_ways = len(Si)
        k_shots = len(Si[0])
        batch_size = Si[0][0].shape[0]
        # spatial size (H, W); shape[1:] would also include the channel dim
        img_size = Si[0][0].shape[-2:]

        # concatenate support and query into a single large tensor so the
        # encoder runs once per episode
        support = [torch.cat(shots, dim=0) for shots in Si]
        query = [torch.cat(Qi, dim=0)]
        model_ip = torch.cat(support + query, dim=0)

        features = self.encoder(model_ip)
        return features



