In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import grad
import numpy as np

from functools import partial
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Callable ,Union, Iterable

import flax
from flax.core import FrozenDict, freeze
from flax.traverse_util import unflatten_dict
from flax import linen as nn
from flax.linen import BatchNorm, Conv
from flax.core.frozen_dict import freeze, unfreeze
PyTorchTensor = Any


In [None]:
pip install torchvision==0.11.1  


In [None]:
import torch

### ResNet model structure in JAX / Flax ###

In [None]:
# ResNet model structure in JAX.
# Number of residual blocks in each of the four stages, keyed by network depth.
STAGE_SIZES = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
    200: [3, 24, 36, 3],
    269: [3, 30, 48, 8],
}
# Anything callable that builds a module (a Module class or a functools.partial of one).
ModuleDef = Callable[..., Callable]
# (rng key, shape, dtype)-style parameter initializer, as used for kernel_init / bias_init.
InitFn = Callable[[Any, Iterable[int], Any], Any]
#conv 
class ConvBlock(nn.Module):
    """Convolution, then optional normalization, then optional activation.

    Attributes:
      n_filters: number of output channels of the convolution.
      dilation: kernel dilation used for both spatial dimensions. Defaults to 1
        so callers that do not use dilation may omit it. Previously it was
        required, and every caller that left it out raised TypeError.
      kernel_size: convolution window.
      strides: convolution strides.
      activation: applied after normalization unless ``is_last`` is True.
      padding: passed straight to ``conv_cls``.
      is_last: if True, the norm scale is zero-initialized and the activation
        is skipped.
      groups: ``feature_group_count`` of the convolution.
      kernel_init / bias_init: initializers for the convolution.
      conv_cls: convolution module class.
      norm_cls: normalization module class, or None/falsy for no normalization.
      force_conv_bias: give the conv a bias even when a norm layer follows it.
    """

    n_filters: int
    dilation: int = 1
    kernel_size: Tuple[int, int] = (3, 3)
    strides: Tuple[int, int] = (1, 1)
    activation: Callable = nn.relu
    padding: Union[str, Iterable[Tuple[int, int]]] = ((0, 0), (0, 0))
    is_last: bool = False
    groups: int = 1
    kernel_init: InitFn = nn.initializers.kaiming_normal()
    bias_init: InitFn = nn.initializers.zeros
    conv_cls: ModuleDef = nn.Conv
    norm_cls: Optional[ModuleDef] = partial(nn.BatchNorm, momentum=0.9)
    force_conv_bias: bool = False

    @nn.compact
    def __call__(self, x):
        x = self.conv_cls(
            self.n_filters,
            self.kernel_size,
            self.strides,
            # A bias is redundant when a norm layer follows, unless explicitly forced.
            use_bias=(not self.norm_cls or self.force_conv_bias),
            padding=self.padding,
            feature_group_count=self.groups,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            kernel_dilation=(self.dilation, self.dilation),
        )(x)
        if self.norm_cls:
            # Zero-init the scale of the last norm in a residual branch.
            scale_init = (nn.initializers.zeros
                          if self.is_last else nn.initializers.ones)
            # Use batch statistics only when the caller allows 'batch_stats' to be
            # updated; otherwise use the stored running averages.
            mutable = self.is_mutable_collection('batch_stats')
            x = self.norm_cls(use_running_average=not mutable, scale_init=scale_init)(x)

        if not self.is_last:
            x = self.activation(x)
        return x


class Sequential(nn.Module):
    """Feed the input through ``layers`` in order, each output feeding the next.

    Entries may be flax Modules or plain callables (e.g. a pooling function).
    Only Module entries own variables; they appear under ``layers_<index>``,
    where index is the entry's position in the list.
    """

    layers: Sequence[Union[nn.Module, Callable[[jnp.ndarray], jnp.ndarray]]]

    @nn.compact
    def __call__(self, x):
        out = x
        for stage in self.layers:
            out = stage(out)
        return out

def slice_variables(variables: Mapping[str, Any],
                    start: int = 0,
                    end: Optional[int] = None) -> flax.core.FrozenDict:
    """Return the variables of ``Sequential`` layers ``start..end-1``, renumbered from 0.

    Args:
      variables: collections (e.g. 'params', 'batch_stats') keyed by ``layers_<i>``.
        The layer count is derived from the numeric suffixes in ``variables['params']``.
      start: first layer index to keep.
      end: one past the last index to keep. None means "through the last layer";
        a negative value counts from the end, like a Python slice.

    Returns:
      A FrozenDict with the same collections, where old ``layers_<i>`` becomes
      ``layers_<i - start>``. Layers without variables in a collection are skipped.
    """
    n_layers = 1 + max(int(name.split('_')[-1]) for name in variables['params'])
    if end is None:
        end = n_layers
    elif end < 0:
        end = end + n_layers

    sliced: Dict[str, Any] = {}
    for collection, layer_vars in variables.items():
        renumbered = {}
        for idx in range(start, end):
            key = f'layers_{idx}'
            if key in layer_vars:
                renumbered[f'layers_{idx - start}'] = layer_vars[key]
        sliced[collection] = renumbered

    return flax.core.freeze(sliced)

#softmax func
def rsoftmax(x, radix, cardinality):
    """Radix-wise softmax used by split attention.

    Args:
      x: 2-D array ``(batch, features)``, where features = cardinality * radix * k.
      radix: number of splits. For radix > 1 the softmax is taken across the
        radix axis; otherwise a plain sigmoid is applied elementwise.
      cardinality: number of groups the features are divided into.

    Returns:
      Array of shape ``(batch, features)``. For radix > 1 it is reordered
      radix-major, so it can be split into ``radix`` equal chunks along the last axis.
    """
    if radix <= 1:
        return nn.sigmoid(x)
    n = x.shape[0]
    # (batch, cardinality, radix, k) -> (batch, radix, cardinality, k)
    grouped = x.reshape((n, cardinality, radix, -1)).swapaxes(1, 2)
    weights = nn.softmax(grouped, axis=1)
    return weights.reshape((n, -1))

class SplAtConv2d(nn.Module):
    """Split-attention convolution (ResNeSt).

    1. A grouped conv produces ``radix`` splits of ``channels`` features each.
    2. The splits are summed (or the conv output is used as-is when radix == 1)
       and globally average-pooled.
    3. The pooled features pass through a 1x1 ConvBlock (``inter_channels``)
       and a 1x1 conv back to ``channels * radix``.
    4. ``rsoftmax`` turns the result into per-split attention weights.
    5. The output is the attention-weighted sum of the splits.

    Attributes:
      channels: output channels.
      kernel_size / strides / padding: of the first (grouped) convolution.
      groups: base group count of the first conv (multiplied by ``radix``).
      radix: number of splits.
      reduction_factor: bottleneck reduction for the attention branch (min 32 channels).
      conv_block_cls: ConvBlock-like module used for the convolutions.
      cardinality: groups of the attention 1x1 convs. NOTE(review): its default
        is the class attribute ``groups`` evaluated at class-definition time,
        i.e. always 1. It does NOT follow an instance's ``groups``; pass it
        explicitly when needed.
      match_reference: force a bias on the attention ConvBlock (see linked issue).
    """

    channels: int
    kernel_size: Tuple[int, int]
    strides: Tuple[int, int] = (1, 1)
    padding: Union[str, Iterable[Tuple[int, int]]] = ((0, 0), (0, 0))
    groups: int = 1
    radix: int = 2
    reduction_factor: int = 4
    conv_block_cls: ModuleDef = ConvBlock
    cardinality: int = groups
    match_reference: bool = False

    @nn.compact
    def __call__(self, x):
        inter_channels = max(x.shape[-1] * self.radix // self.reduction_factor, 32)

        # dilation is passed explicitly, as ResNetStem / ResNetSkipConnection do, so this
        # also works with conv blocks whose `dilation` field has no default.
        conv_block = self.conv_block_cls(self.channels * self.radix,
                                         dilation=1,
                                         kernel_size=self.kernel_size,
                                         strides=self.strides,
                                         groups=self.groups * self.radix,
                                         padding=self.padding)
        conv_cls = conv_block.conv_cls  # type: ignore
        x = conv_block(x)

        if self.radix > 1:
            # torch split takes split_size: int(rchannel//self.radix)
            # jnp split takes num sections: self.radix
            split = jnp.split(x, self.radix, axis=-1)
            gap = sum(split)
        else:
            gap = x

        gap = gap.mean((1, 2), keepdims=True)  # type: ignore # global average pool
        # Remove force_conv_bias after resolving
        # github.com/zhanghang1989/ResNeSt/issues/125
        gap = self.conv_block_cls(inter_channels,
                                  dilation=1,
                                  kernel_size=(1, 1),
                                  groups=self.cardinality,
                                  force_conv_bias=self.match_reference)(gap)

        attn = conv_cls(self.channels * self.radix,
                        kernel_size=(1, 1),
                        feature_group_count=self.cardinality)(gap)  # n x 1 x 1 x c
        attn = attn.reshape((x.shape[0], -1))
        attn = rsoftmax(attn, self.radix, self.cardinality)
        attn = attn.reshape((x.shape[0], 1, 1, -1))

        if self.radix > 1:
            attns = jnp.split(attn, self.radix, axis=-1)
            out = sum(a * s for a, s in zip(attns, split))
        else:
            out = attn * x

        return out

class ResNetStem(nn.Module):
    """ResNet stem: a 64-filter 7x7 convolution with stride 2 and padding 3.

    The convolution is built with ``conv_block_cls`` (for ConvBlock this also
    adds normalization and activation). It creates ``ConvBlock_0``, the name
    the pretrained ``conv1`` / ``bn1`` weights are mapped onto.
    """

    conv_block_cls: ModuleDef = ConvBlock

    @nn.compact
    def __call__(self, x):
        stem_conv = self.conv_block_cls(
            64,
            dilation=1,
            kernel_size=(7, 7),
            strides=(2, 2),
            padding=[(3, 3), (3, 3)],
        )
        return stem_conv(x)

class ResNetSkipConnection(nn.Module):
    """Residual shortcut: identity when shapes match, otherwise a 1x1 projection.

    Attributes:
      strides: strides of the 1x1 projection (should match the residual branch).
      conv_block_cls: ConvBlock-like module used for the projection.
    """

    strides: Tuple[int, int]
    conv_block_cls: ModuleDef = ConvBlock

    @nn.compact
    def __call__(self, x, out_shape):
        """Return ``x`` reshaped to ``out_shape`` by a projection only if needed."""
        if x.shape == out_shape:
            return x
        # Projection conv block with an identity activation (normalization still applies).
        projection = self.conv_block_cls(out_shape[-1],
                                         dilation=1,
                                         kernel_size=(1, 1),
                                         strides=self.strides,
                                         activation=lambda y: y)
        return projection(x)

class ResNetBlock(nn.Module):
    """Basic residual block: two 3x3 conv blocks plus a shortcut, then activation.

    Attributes:
      n_hidden: channels of both convolutions (and of the block output).
      strides: strides of the first conv and of the projection shortcut.
      activation: applied to the sum of the residual branch and the shortcut.
      conv_block_cls: ConvBlock-like module for the convolutions.
      skip_cls: shortcut module.
    """

    n_hidden: int
    strides: Tuple[int, int] = (1, 1)
    activation: Callable = nn.relu
    conv_block_cls: ModuleDef = ConvBlock
    skip_cls: ModuleDef = ResNetSkipConnection

    @nn.compact
    def __call__(self, x):
        skip_cls = partial(self.skip_cls, conv_block_cls=self.conv_block_cls)
        # dilation is passed explicitly, like ResNetStem / ResNetSkipConnection, so the
        # block works with conv blocks whose `dilation` field has no default (omitting
        # it there raised TypeError).
        y = self.conv_block_cls(self.n_hidden,
                                dilation=1,
                                padding=[(1, 1), (1, 1)],
                                strides=self.strides)(x)
        y = self.conv_block_cls(self.n_hidden,
                                dilation=1,
                                padding=[(1, 1), (1, 1)],
                                is_last=True)(y)
        return self.activation(y + skip_cls(self.strides)(x, y.shape))

class ResNetBottleneckBlock(nn.Module):
    """Bottleneck residual block with an optionally dilated 3x3 convolution.

    Branch: 1x1 conv -> 3x3 conv (strided, grouped, dilated) -> 1x1 conv to
    ``n_hidden * expansion`` channels (``is_last``). The shortcut is added and
    ``activation`` is applied to the sum. Downsampling strides sit on the 3x3
    conv rather than the first 1x1.

    Submodules are created in the order ConvBlock_0, ConvBlock_1, ConvBlock_2,
    ResNetSkipConnection_0. The pretrained-weight mappings rely on these names.

    Attributes:
      dilation: dilation (and padding) of the 3x3 convolution.
      n_hidden: base width; the block outputs ``n_hidden * expansion`` channels.
      strides: strides of the 3x3 convolution and of the projection shortcut.
      expansion: output-channel multiplier.
      groups: group count (cardinality) of the 3x3 convolution.
      base_width: inner width = int(n_hidden * base_width / 64) * groups.
      activation: applied after the residual addition.
      conv_block_cls / skip_cls: building blocks.
    """

    dilation: int
    n_hidden: int
    strides: Tuple[int, int]
    expansion: int = 4
    groups: int = 1  # cardinality
    base_width: int = 64
    activation: Callable = nn.relu
    conv_block_cls: ModuleDef = ConvBlock
    skip_cls: ModuleDef = ResNetSkipConnection

    @nn.compact
    def __call__(self, x):
        conv = self.conv_block_cls
        width = int(self.n_hidden * (self.base_width / 64.)) * self.groups
        d = self.dilation

        # 1x1 reduce.
        h = conv(width, dilation=1, kernel_size=(1, 1))(x)
        # 3x3 (default kernel) with padding == dilation, as in torch's conv3x3.
        h = conv(width,
                 strides=self.strides,
                 groups=self.groups,
                 padding=((d, d), (d, d)),
                 dilation=d)(h)
        # 1x1 expand; is_last zero-inits its norm scale and skips the activation.
        h = conv(self.n_hidden * self.expansion,
                 dilation=1,
                 kernel_size=(1, 1),
                 is_last=True)(h)

        shortcut = partial(self.skip_cls, conv_block_cls=self.conv_block_cls)(self.strides)
        return self.activation(h + shortcut(x, h.shape))
    
#=----------------------------------------------------------------------------------------------------
def ResNet(
    block_cls: ModuleDef,
    *,
    stage_sizes: Sequence[int],
    n_classes: int,
    hidden_sizes: Sequence[int] = (64, 128, 256, 512),
    conv_cls: ModuleDef = nn.Conv,
    norm_cls: Optional[ModuleDef] = partial(nn.BatchNorm, momentum=0.9),
    conv_block_cls: ModuleDef = ConvBlock,
    stem_cls: ModuleDef = ResNetStem,
    pool_fn: Callable = partial(nn.max_pool,
                                window_shape=(3, 3),
                                strides=(2, 2),
                                padding=((1, 1), (1, 1))),
) -> Sequential:
    """Build a ResNet classifier as a flat ``Sequential``.

    Layer order (position i becomes the ``layers_<i>`` variable name):
    ``[stem, pool_fn, *residual blocks, global average pool, Dense(n_classes)]``.

    Args:
      block_cls: residual block module (e.g. ResNetBottleneckBlock or ResNetBlock).
      stage_sizes: number of blocks per stage, e.g. ``STAGE_SIZES[101]``.
      n_classes: output size of the final Dense layer.
      hidden_sizes: ``n_hidden`` of each stage.
      conv_cls / norm_cls: convolution and normalization used in every ConvBlock.
      conv_block_cls / stem_cls: building blocks for convolutions and the stem.
      pool_fn: pooling applied right after the stem.

    The first block of every stage except the first uses stride 2. Blocks that
    declare a required ``dilation`` field (ResNetBottleneckBlock) get
    ``dilation=1``; without it, constructing ResNet101 raised TypeError. Blocks
    without such a field are constructed exactly as before.
    """
    conv_block_cls = partial(conv_block_cls, conv_cls=conv_cls, norm_cls=norm_cls)
    stem_cls = partial(stem_cls, conv_block_cls=conv_block_cls)
    block_kwargs = {'dilation': 1} if _needs_dilation_kwarg(block_cls) else {}
    block_cls = partial(block_cls, conv_block_cls=conv_block_cls)

    layers = [stem_cls(), pool_fn]

    for i, (hsize, n_blocks) in enumerate(zip(hidden_sizes, stage_sizes)):
        for b in range(n_blocks):
            strides = (1, 1) if i == 0 or b != 0 else (2, 2)
            layers.append(block_cls(n_hidden=hsize, strides=strides, **block_kwargs))

    layers.append(partial(jnp.mean, axis=(1, 2)))  # global average pool
    layers.append(nn.Dense(n_classes))
    return Sequential(layers)


def _needs_dilation_kwarg(block_cls) -> bool:
    """True if ``block_cls`` (a Module class or nested partial of one) has an unbound required ``dilation``."""
    import dataclasses

    target = block_cls
    while isinstance(target, partial):
        if 'dilation' in target.keywords:
            return False  # caller already chose a dilation; never override it
        target = target.func
    field = getattr(target, '__dataclass_fields__', {}).get('dilation')
    return (field is not None
            and field.default is dataclasses.MISSING
            and field.default_factory is dataclasses.MISSING)


# ResNet-101 classifier constructor; call as ResNet101(n_classes=...).
ResNet101 = partial(ResNet, stage_sizes=STAGE_SIZES[101], stem_cls=ResNetStem, block_cls=ResNetBottleneckBlock)


In [None]:
def ResNet_layer4(
    block_cls: ModuleDef,
    *,
    stage_sizes: Sequence[int],
    n_classes: int,
    hidden_sizes: Sequence[int] = (64, 128, 256, 512),
    conv_cls: ModuleDef = nn.Conv,
    norm_cls: Optional[ModuleDef] = partial(nn.BatchNorm, momentum=0.9),
    conv_block_cls: ModuleDef = ConvBlock,
    stem_cls: ModuleDef = ResNetStem,
    pool_fn: Callable = partial(nn.max_pool,
                                window_shape=(3, 3),
                                strides=(2, 2),
                                padding=((1, 1), (1, 1))),
    replace_stride_with_dilation: Optional[Sequence[bool]] = None

) -> Sequential:
    """Build a headless ResNet whose output is the last-stage ("layer4") feature map.

    Layout: ``[stem, pool_fn, *residual blocks]``. There is no global pooling
    and no Dense layer; ``n_classes`` is accepted only for signature
    compatibility and is unused.

    ``replace_stride_with_dilation[k]`` (k = 0, 1, 2 for stages 2, 3, 4)
    replaces that stage's stride-2 downsampling with dilation: the stage keeps
    stride 1 and the running dilation doubles. The first block of a stage uses
    the dilation in effect before the stage; the remaining blocks use the
    updated value. Stage 1 always uses stride 1 and dilation 1.

    Raises:
      ValueError: if ``replace_stride_with_dilation`` does not have exactly 3 entries.
    """
    conv_block_cls = partial(conv_block_cls, conv_cls=conv_cls, norm_cls=norm_cls)
    make_stem = partial(stem_cls, conv_block_cls=conv_block_cls)
    make_block = partial(block_cls, conv_block_cls=conv_block_cls)
    layers = [make_stem(), pool_fn]

    dilate_flags = ([False, False, False] if replace_stride_with_dilation is None
                    else replace_stride_with_dilation)
    if len(dilate_flags) != 3:
        raise ValueError(
            "replace_stride_with_dilation should be None "
            "or a 3-element tuple, got {}".format(dilate_flags)
        )

    dilation = 1
    for stage, (hsize, n_blocks) in enumerate(zip(hidden_sizes, stage_sizes)):
        first_dilation = dilation
        stride = (1, 1) if stage == 0 else (2, 2)
        if stage > 0 and dilate_flags[stage - 1]:
            # Trade this stage's downsampling for a larger receptive field.
            dilation *= 2
            stride = (1, 1)

        layers.append(make_block(dilation=first_dilation, n_hidden=hsize, strides=stride))
        layers.extend(
            make_block(dilation=dilation, n_hidden=hsize, strides=(1, 1))
            for _ in range(1, n_blocks)
        )

    return Sequential(layers)


# Dilated ResNet-101 backbone: stages 3 and 4 are dilated instead of strided, so the
# output feature map is 1/8 of the input resolution (stem /2, pool /2, stage 2 /2).
ResNet_layer4 = partial(ResNet_layer4, stage_sizes=STAGE_SIZES[101], stem_cls=ResNetStem,
                        block_cls=ResNetBottleneckBlock, replace_stride_with_dilation=[False, True, True])


In [None]:
def _pytorch_to_jax_params(pt2jax, state_dict, fc_keys):
    variables = {}
   # print(pt2jax.items())
    for pt_name, jax_key in pt2jax.items():
        w = state_dict[pt_name].detach().numpy()
        if w.ndim == 4:
            w = w.transpose((2, 3, 1, 0))
        elif pt_name in fc_keys:
            w = w.transpose()
        variables[jax_key] = w 
    return variables

def _get_add_bn(pt2jax):
    def add_bn(pname, jprefix):
        pt2jax[f'{pname}.weight'] = ('params', *jprefix, 'scale')
        pt2jax[f'{pname}.bias'] = ('params', *jprefix, 'bias')
        pt2jax[f'{pname}.running_mean'] = ('batch_stats', *jprefix, 'mean')
        pt2jax[f'{pname}.running_var'] = ('batch_stats', *jprefix, 'var')
    return add_bn


In [None]:
def pretrained_resnet(
    size: int,
    state_dict: Optional[Mapping[str, PyTorchTensor]] = None,
    path: str = "/notebooks/storage/resnet101-63fe2227.pth",
) -> Tuple[ModuleDef, FrozenDict]:
    """Returns pretrained variables for ResNet ported from a PyTorch checkpoint.

    Maps the checkpoint keys ``conv1``, ``bn1``, ``layer{1..4}.{i}.*`` and ``fc``
    onto the ``layers_<i>`` names of the Flax ``ResNet101`` Sequential.

    Args:
      size: network depth. It only decides 2 (size < 50) or 3 (size >= 50)
        convolutions per block; the stage layout and model are always ResNet-101.
      state_dict: an already-loaded PyTorch state dict. If None, it is loaded from ``path``.
      path: checkpoint file used when ``state_dict`` is None. ``torch.load``
        unpickles the file, so only use trusted checkpoints.

    Returns:
      Module class (``ResNet101`` with ``n_classes=1000``) and the frozen variables for it.
    """
    # Previously the state_dict argument was unconditionally overwritten by torch.load.
    if state_dict is None:
        state_dict = torch.load(path, map_location='cpu')
    pt2jax: Dict[str, Sequence[str]] = {}
    add_bn = _get_add_bn(pt2jax)
    pt2jax['conv1.weight'] = ('params', 'layers_0', 'ConvBlock_0', 'Conv_0', 'kernel')
    add_bn('bn1', ('layers_0', 'ConvBlock_0', 'BatchNorm_0'))

    convs_per_block = 3 if size >= 50 else 2
    lyr = 2  # layers_0 = stem, layers_1 = max-pool (no variables); blocks start at 2
    for b, n_blocks in enumerate(STAGE_SIZES[101], 1):
        for i in range(n_blocks):
            pt_block = f'layer{b}.{i}'
            jax_block = f'layers_{lyr}'
            for j in range(convs_per_block):
                pt2jax[f'{pt_block}.conv{j+1}.weight'] = ('params', jax_block,
                                                          f'ConvBlock_{j}', 'Conv_0', 'kernel')
                add_bn(f'{pt_block}.bn{j+1}', (jax_block, f'ConvBlock_{j}', 'BatchNorm_0'))

            # Only blocks whose shape changes have a projection shortcut.
            if f'{pt_block}.downsample.0.weight' in state_dict:
                skip = (jax_block, 'ResNetSkipConnection_0', 'ConvBlock_0')
                pt2jax[f'{pt_block}.downsample.0.weight'] = ('params', *skip, 'Conv_0', 'kernel')
                add_bn(f'{pt_block}.downsample.1', (*skip, 'BatchNorm_0'))
            lyr += 1

    # After the blocks: global average pool (layers_{lyr}, no variables), then Dense.
    head = f'layers_{lyr + 1}'
    pt2jax['fc.weight'] = ('params', head, 'kernel')
    pt2jax['fc.bias'] = ('params', head, 'bias')
    variables = _pytorch_to_jax_params(pt2jax, state_dict, ('fc.weight',))
    model_cls = partial(ResNet101, n_classes=1000)

    return model_cls, freeze(unflatten_dict(variables))


In [None]:
def pretrained_customized_resnet(
    size: int,
    state_dict: Optional[Mapping[str, PyTorchTensor]] = None,
    path: str = "/notebooks/storage/resnet101-63fe2227.pth",
) -> Tuple[ModuleDef, FrozenDict]:
    """Port ResNet-101 checkpoint weights onto the headless, dilated ``ResNet_layer4``.

    ``ResNet_layer4`` ends after the last residual block (no pooling, no Dense
    layer). The checkpoint's ``fc.*`` tensors are therefore not mapped; they
    used to be stored under a ``layers_36`` entry that the model does not have.

    Args:
      size: network depth. It only decides 2 (size < 50) or 3 (size >= 50)
        convolutions per block; the stage layout is always ResNet-101.
      state_dict: an already-loaded PyTorch state dict. If None, it is loaded from ``path``.
      path: checkpoint file used when ``state_dict`` is None. ``torch.load``
        unpickles the file, so only use trusted checkpoints.

    Returns:
      Module class (``ResNet_layer4`` with ``n_classes=1000``) and its frozen variables.
    """
    if state_dict is None:
        state_dict = torch.load(path, map_location='cpu')
    pt2jax: Dict[str, Sequence[str]] = {}
    add_bn = _get_add_bn(pt2jax)
    pt2jax['conv1.weight'] = ('params', 'layers_0', 'ConvBlock_0', 'Conv_0', 'kernel')
    add_bn('bn1', ('layers_0', 'ConvBlock_0', 'BatchNorm_0'))

    convs_per_block = 3 if size >= 50 else 2
    lyr = 2  # layers_0 = stem, layers_1 = max-pool (no variables); blocks start at 2
    for b, n_blocks in enumerate(STAGE_SIZES[101], 1):
        for i in range(n_blocks):
            pt_block = f'layer{b}.{i}'
            jax_block = f'layers_{lyr}'
            for j in range(convs_per_block):
                pt2jax[f'{pt_block}.conv{j+1}.weight'] = ('params', jax_block,
                                                          f'ConvBlock_{j}', 'Conv_0', 'kernel')
                add_bn(f'{pt_block}.bn{j+1}', (jax_block, f'ConvBlock_{j}', 'BatchNorm_0'))

            if f'{pt_block}.downsample.0.weight' in state_dict:
                skip = (jax_block, 'ResNetSkipConnection_0', 'ConvBlock_0')
                pt2jax[f'{pt_block}.downsample.0.weight'] = ('params', *skip, 'Conv_0', 'kernel')
                add_bn(f'{pt_block}.downsample.1', (*skip, 'BatchNorm_0'))
            lyr += 1

    variables = _pytorch_to_jax_params(pt2jax, state_dict, ())
    model_cls = partial(ResNet_layer4, n_classes=1000)  # customized model up to layer4

    return model_cls, freeze(unflatten_dict(variables))


In [None]:
#FCNHead classifier in FLAX  
class FCNHead(nn.Module):
    """FCN segmentation head.

    The layers are a 3x3 conv (no bias), BatchNorm, ReLU, Dropout(0.1) and a
    1x1 conv to ``num_classes``. Submodules are created as Conv_0, BatchNorm_0,
    Dropout_0, Conv_1; FCN_classifier maps the PyTorch ``classifier.*`` weights
    onto those names, so the layer order must not change.

    Attributes:
      in_channels: channels of the incoming feature map. Only used to size the
        hidden conv (``in_channels // 4`` filters).
      num_classes: output channels (per-pixel class logits).
      train: if False (default, inference), BatchNorm uses its running
        statistics and Dropout is disabled. If True, BatchNorm uses batch
        statistics and Dropout is active; apply then needs mutable
        'batch_stats' and a 'dropout' rng.
    """

    in_channels: int = 2048
    num_classes: int = 21
    train: bool = False

    @nn.compact
    def __call__(self, x):
        inter_channels = self.in_channels // 4
        x = nn.Conv(inter_channels, kernel_size=(3, 3), padding=[(1, 1), (1, 1)], use_bias=False)(x)
        x = nn.BatchNorm(use_running_average=not self.train, momentum=0.9, epsilon=1e-5,
                         dtype=jnp.float32)(x)
        x = nn.relu(x)
        x = nn.Dropout(0.1, deterministic=not self.train)(x)
        x = nn.Conv(self.num_classes, kernel_size=(1, 1))(x)
        return x


In [None]:
def FCN_classifier(
    size: int,
    state_dict: Optional[Mapping[str, PyTorchTensor]] = None,
    path: str = "/notebooks/storage/fcn_resnet101_coco-7ecb50ca.pth",
) -> Tuple[ModuleDef, FrozenDict]:
    """Port the FCN head (``classifier.*``) weights of a PyTorch checkpoint to ``FCNHead``.

    Args:
      size: unused; kept for signature parity with the other loaders.
      state_dict: an already-loaded PyTorch state dict. If None, it is loaded from ``path``.
      path: checkpoint file used when ``state_dict`` is None. ``torch.load``
        unpickles the file, so only use trusted checkpoints.

    Returns:
      ``FCNHead`` (the class) and the frozen variables for it.
    """
    # Previously the state_dict argument was unconditionally overwritten by torch.load.
    if state_dict is None:
        state_dict = torch.load(path, map_location='cpu')

    pt2jax: Dict[str, Sequence[str]] = {
        # classifier.0: 3x3 conv without bias -> Conv_0
        'classifier.0.weight': ('params', 'Conv_0', 'kernel'),
        # classifier.1: BatchNorm -> BatchNorm_0
        'classifier.1.weight': ('params', 'BatchNorm_0', 'scale'),
        'classifier.1.bias': ('params', 'BatchNorm_0', 'bias'),
        'classifier.1.running_mean': ('batch_stats', 'BatchNorm_0', 'mean'),
        'classifier.1.running_var': ('batch_stats', 'BatchNorm_0', 'var'),
        # classifier.4: final 1x1 conv -> Conv_1 (indices 2 and 3 are ReLU / Dropout)
        'classifier.4.weight': ('params', 'Conv_1', 'kernel'),
        'classifier.4.bias': ('params', 'Conv_1', 'bias'),
    }

    variables = _pytorch_to_jax_params(pt2jax, state_dict, ())
    model_cls = FCNHead  # the model structure comes from FCNHead

    return model_cls, freeze(unflatten_dict(variables))

In [None]:
#for customized resnet layer4- backbone
def backbone_pretrained_customized_resnet(
    size: int,
    state_dict: Optional[Mapping[str, PyTorchTensor]] = None,
    path: str = "/notebooks/storage/fcn_resnet101_coco-7ecb50ca.pth",
) -> Tuple[ModuleDef, FrozenDict]:
    """Port the ``backbone.*`` weights of the FCN-ResNet101 checkpoint onto ``ResNet_layer4``.

    Args:
      size: network depth. It only decides 2 (size < 50) or 3 (size >= 50)
        convolutions per block; the stage layout is always ResNet-101.
      state_dict: an already-loaded PyTorch state dict. If None, it is loaded from ``path``.
      path: checkpoint file used when ``state_dict`` is None. ``torch.load``
        unpickles the file, so only use trusted checkpoints.

    Returns:
      Module class (``ResNet_layer4`` with ``n_classes=1000``) and its frozen variables.
    """
    # Previously the state_dict argument was unconditionally overwritten by torch.load.
    if state_dict is None:
        state_dict = torch.load(path, map_location='cpu')
    pt2jax: Dict[str, Sequence[str]] = {}
    add_bn = _get_add_bn(pt2jax)

    # Stem; add_bn also maps bn1's running mean/var, so no separate entries are needed.
    pt2jax['backbone.conv1.weight'] = ('params', 'layers_0', 'ConvBlock_0', 'Conv_0', 'kernel')
    add_bn('backbone.bn1', ('layers_0', 'ConvBlock_0', 'BatchNorm_0'))

    convs_per_block = 3 if size >= 50 else 2
    lyr = 2  # layers_0 = stem, layers_1 = max-pool (no variables); blocks start at 2
    for b, n_blocks in enumerate(STAGE_SIZES[101], 1):
        for i in range(n_blocks):
            pt_block = f'backbone.layer{b}.{i}'
            jax_block = f'layers_{lyr}'
            for j in range(convs_per_block):
                pt2jax[f'{pt_block}.conv{j+1}.weight'] = ('params', jax_block,
                                                          f'ConvBlock_{j}', 'Conv_0', 'kernel')
                add_bn(f'{pt_block}.bn{j+1}', (jax_block, f'ConvBlock_{j}', 'BatchNorm_0'))

            if f'{pt_block}.downsample.0.weight' in state_dict:
                skip = (jax_block, 'ResNetSkipConnection_0', 'ConvBlock_0')
                pt2jax[f'{pt_block}.downsample.0.weight'] = ('params', *skip, 'Conv_0', 'kernel')
                add_bn(f'{pt_block}.downsample.1', (*skip, 'BatchNorm_0'))
            lyr += 1

    variables = _pytorch_to_jax_params(pt2jax, state_dict, ())
    model_cls = partial(ResNet_layer4, n_classes=1000)

    return model_cls, freeze(unflatten_dict(variables))


In [None]:
def _fcn_resnet():
    """Assemble the pretrained FCN-ResNet101: dilated ResNet-101 backbone plus FCNHead.

    Returns:
      ``(backbone, backbone_variables, head, head_variables)``.
    """
    from flax.traverse_util import flatten_dict

    head = FCNHead()
    head_weights = FCN_classifier(101)[1]  # FCN_classifier returns (model_cls, weights)

    backbone_cls, backbone_variables = backbone_pretrained_customized_resnet(101)
    backbone = backbone_cls()

    # The head's init only needs the backbone's output shape. Trace it abstractly
    # instead of running a full ResNet-101 forward pass on a dummy image.
    dummy_image = jax.ShapeDtypeStruct((1, 224, 224, 3), jnp.float32)
    feature_spec = jax.eval_shape(partial(backbone.apply, backbone_variables, mutable=False),
                                  dummy_image)
    init_variables = head.init(jax.random.PRNGKey(0),
                               jnp.zeros(feature_spec.shape, feature_spec.dtype))

    # Leaf-wise merge: pretrained tensors override the random init. The previous shallow
    # dict.update replaced whole collections and would drop any leaf the checkpoint lacks.
    merged = flatten_dict(unfreeze(init_variables))
    merged.update(flatten_dict(unfreeze(head_weights)))
    head_variables = freeze(unflatten_dict(merged))

    return (backbone, backbone_variables, head, head_variables)
    
def fcn_resnet101(pretrained: bool = False):
    """Return ``(backbone, backbone_variables, head, head_variables)`` for FCN-ResNet101.

    Only pretrained weights are supported. ``pretrained=False`` still loads the
    checkpoint and emits a warning; it does not silently pretend to return a
    randomly initialized model.
    """
    if not pretrained:
        import warnings
        warnings.warn(
            "fcn_resnet101: random initialization is not implemented; "
            "loading pretrained weights anyway.",
            stacklevel=2,
        )
    return _fcn_resnet()


### Image Transformation ###

In [None]:
# Image transforms. torchvision is imported here because it is pip-installed in an earlier cell.
import torchvision.transforms as T

# Decode the image and close the file. Force 3 channels so the 3-element
# Normalize below also works for grayscale, RGBA or palette images.
with Image.open('/notebooks/storage/data/bird.jpg') as img_file:
    img_res = img_file.convert('RGB')

preprocess = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
inp_img_tensor = preprocess(img_res)  # CHW float tensor
# Flax convolutions take NHWC: add a batch dim, move channels last, and hand JAX a
# NumPy array rather than relying on implicit conversion of a torch tensor.
inp_batch = inp_img_tensor.unsqueeze(0).permute(0, 2, 3, 1).numpy()  # (1, 224, 224, 3)
print(inp_batch.shape)

plt.imshow(img_res); plt.show()


### Running the Model ###

In [None]:
# Run the model: backbone features, then per-pixel class logits from the FCN head.
fcn_model = fcn_resnet101(pretrained=True)
backbone, backbone_variables, classifier, classifier_variables = fcn_model

# inp_batch is NHWC: (1, 224, 224, 3)
Backbone_out = backbone.apply(backbone_variables, inp_batch, mutable=False)
# jax.tree_map is deprecated (removed in newer JAX); jax.tree_util.tree_map is the stable spelling.
print("Bckbone_out:", jax.tree_util.tree_map(lambda x: x.shape, Backbone_out))  # (1, 28, 28, 2048)

classifier_out = classifier.apply(classifier_variables, Backbone_out, mutable=False)
print("Classifier output:", classifier_out.shape)  # (1, 28, 28, 21)


In [None]:
# Helper: turn a 2-D map of class indices into an RGB image for plotting.
def decode_segmap(image, nc=21):
    """Color a 2-D map of class indices for display.

    Args:
      image: 2-D array of class ids, e.g. the per-pixel arg-max of the logits.
      nc: number of classes to color. Pixels whose value is not in ``range(nc)``
        stay black. Values of nc above 21 (the palette size) raise IndexError.

    Returns:
      For an ``(H, W)`` input, an ``(H, W, 3)`` uint8 RGB array.
    """
    palette = np.array([(0, 0, 0),  # 0=background
                        # 1=aeroplane, 2=bicycle, 3=bird, 4=boat, 5=bottle
                        (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128),
                        # 6=bus, 7=car, 8=cat, 9=chair, 10=cow
                        (0, 128, 128), (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
                        # 11=dining table, 12=dog, 13=horse, 14=motorbike, 15=person
                        (192, 128, 0), (64, 0, 128), (192, 0, 128), (64, 128, 128), (192, 128, 128),
                        # 16=potted plant, 17=sheep, 18=sofa, 19=train, 20=tv/monitor
                        (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128)])

    # One zero-filled uint8 plane per color channel (R, G, B).
    channels = [np.zeros_like(image).astype(np.uint8) for _ in range(3)]
    for class_id in range(nc):
        mask = image == class_id
        for plane, value in zip(channels, palette[class_id]):
            plane[mask] = value

    return np.stack(channels, axis=2)

### Segmentation ###

In [None]:
# Segmentation: upsample the class logits of the first image back to the input
# resolution (bilinear), then take the arg-max class per pixel. Sizes come from the
# tensors instead of hard-coded 224 / 21, and [0] (unlike squeeze) never drops spatial dims.
logits = classifier_out[0]  # (h, w, n_classes)
target_hw = tuple(inp_batch.shape[1:3])
om = jax.image.resize(logits, (*target_hw, logits.shape[-1]), 'bilinear')
om2 = jnp.argmax(om, axis=-1)
rgb = decode_segmap(om2)
plt.imshow(rgb); plt.show()
     