In [None]:
import jax
import jax.numpy as jnp
from jax import grad
import numpy as np
import flax
from flax.core import FrozenDict, freeze
from flax.traverse_util import unflatten_dict
from flax import linen as nn
from flax.linen import BatchNorm, Conv

from PIL import Image
import matplotlib.pyplot as plt

In [None]:
pip install torchvision==0.11.1  


In [None]:
import torch
from functools import partial
from typing import (Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union)


In [None]:
# ResNet model structure in JAX.
# Number of residual blocks in each of the four stages, keyed by network depth.
STAGE_SIZES = dict([
    (18, [2, 2, 2, 2]),
    (34, [3, 4, 6, 3]),
    (50, [3, 4, 6, 3]),
    (101, [3, 4, 23, 3]),
    (152, [3, 8, 36, 3]),
    (200, [3, 24, 36, 3]),
    (269, [3, 30, 48, 8]),
])

# Anything that, when called, builds a module (a class or a functools.partial of one).
ModuleDef = Callable[..., Callable]
# Initializer signature: (PRNGKey, Shape, DType) -> Array.
InitFn = Callable[[Any, Iterable[int], Any], Any]
#conv 
 
class ConvBlock(nn.Module):
    """Convolution, then optional normalization, then optional activation.

    Attributes:
        n_filters: number of output features of the convolution.
        kernel_size: spatial size of the convolution window.
        strides: convolution strides.
        activation: applied at the end unless ``is_last`` is True.
        padding: forwarded unchanged to ``conv_cls``.
        is_last: when True the activation is skipped and the norm layer's
            scale is initialised with zeros instead of ones.
        groups: ``feature_group_count`` of the convolution.
        kernel_init: kernel initializer forwarded to ``conv_cls``.
        bias_init: bias initializer forwarded to ``conv_cls``.
        conv_cls: convolution module constructor.
        norm_cls: normalization module constructor; a falsy value disables
            normalization, and the conv then gets a bias.
        force_conv_bias: keep the conv bias even when ``norm_cls`` is set.
    """
    n_filters: int
    kernel_size: Tuple[int, int] = (3, 3)
    strides: Tuple[int, int] = (1, 1)
    activation: Callable = nn.relu
    padding: Union[str, Iterable[Tuple[int, int]]] = ((0, 0), (0, 0))
    is_last: bool = False
    groups: int = 1
    kernel_init: InitFn = nn.initializers.kaiming_normal()
    bias_init: InitFn = nn.initializers.zeros

    conv_cls: ModuleDef = nn.Conv
    norm_cls: Optional[ModuleDef] = partial(nn.BatchNorm, momentum=0.9)

    force_conv_bias: bool = False

    @nn.compact
    def __call__(self, x):
        # A conv bias is redundant when a norm layer follows, unless forced.
        wants_bias = (not self.norm_cls) or self.force_conv_bias
        conv = self.conv_cls(
            self.n_filters,
            self.kernel_size,
            self.strides,
            use_bias=wants_bias,
            padding=self.padding,
            feature_group_count=self.groups,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )
        h = conv(x)

        if self.norm_cls:
            if self.is_last:
                gamma_init = nn.initializers.zeros
            else:
                gamma_init = nn.initializers.ones
            # Use the stored running statistics unless the caller made
            # 'batch_stats' mutable (i.e. is updating them).
            frozen_stats = not self.is_mutable_collection('batch_stats')
            h = self.norm_cls(use_running_average=frozen_stats, scale_init=gamma_init)(h)

        if self.is_last:
            return h
        return self.activation(h)


class Sequential(nn.Module):
    """Applies ``layers`` one after another, feeding each output to the next.

    Entries may be Flax modules or plain callables (e.g. ``functools.partial``
    of a pooling function); each is called with a single argument.
    """
    layers: Sequence[Union[nn.Module, Callable[[jnp.ndarray], jnp.ndarray]]]

    @nn.compact
    def __call__(self, x):
        out = x
        for fn in self.layers:
            out = fn(out)
        return out

def slice_variables(variables: Mapping[str, Any],
                    start: int = 0,
                    end: Optional[int] = None) -> flax.core.FrozenDict:
    """Returns the variables of a sliced ``Sequential`` model.

    Layer entries named ``layers_{i}`` with ``start <= i < end`` are kept and
    renumbered from ``layers_0``; missing indices are skipped silently.

    Args:
        variables: A dict (typically a flax.core.FrozenDict) containing the
            model parameters and state, e.g. 'params' and 'batch_stats'.
        start: integer indicating the first layer to keep.
        end: integer indicating the first layer to exclude (can be negative,
            with the same semantics as negative list indexing). ``None`` keeps
            everything up to the last layer found in ``variables['params']``.
    Returns:
        A flax.core.FrozenDict with the subset of parameters/state requested.
    """
    # The layer count is taken from the highest numeric suffix under 'params'.
    layer_ids = [int(name.split('_')[-1]) for name in variables['params']]
    n_layers = max(layer_ids) + 1

    if end is None:
        stop = n_layers
    elif end < 0:
        stop = end + n_layers
    else:
        stop = end

    sliced: Dict[str, Any] = {}
    for collection, entries in variables.items():  # usually params and batch_stats
        renamed = {}
        for idx in range(start, stop):
            key = f'layers_{idx}'
            if key in entries:
                renamed[f'layers_{idx - start}'] = entries[key]
        sliced[collection] = renamed

    return flax.core.freeze(sliced)
#---------
#softmax func
def rsoftmax(x, radix, cardinality):
    """Radix softmax used by split attention; output has the same shape as ``x``.

    For ``radix > 1`` the feature axis of the 2-D ``x`` is viewed as
    ``(cardinality, radix, rest)`` and a softmax is taken across the radix
    splits. Otherwise an element-wise sigmoid is applied.

    Args:
        x: array of shape (batch_size, features).
        radix: number of splits.
        cardinality: number of groups.
    """
    n = x.shape[0]
    if radix <= 1:
        return nn.sigmoid(x)
    grouped = jnp.swapaxes(x.reshape((n, cardinality, radix, -1)), 1, 2)
    weights = nn.softmax(grouped, axis=1)
    return weights.reshape((n, -1))

#conv2d
class SplAtConv2d(nn.Module):
    """Split-attention convolution used by ResNeSt bottleneck blocks.

    A grouped conv block produces ``channels * radix`` features, which are split
    into ``radix`` chunks along the last axis. The chunks are summed and globally
    average-pooled over axes (1, 2). The result goes through a 1x1 conv block
    and a 1x1 conv, then through ``rsoftmax`` (a softmax across splits, or a
    sigmoid when ``radix == 1``). The output is the attention-weighted sum of
    the chunks.

    Attributes:
        channels: output features.
        kernel_size: kernel size of the main conv block.
        strides: strides of the main conv block.
        padding: padding of the main conv block.
        groups: groups of the main conv; the conv block uses
            ``groups * radix``.
        radix: number of splits.
        reduction_factor: the attention bottleneck width is
            ``max(in_features * radix // reduction_factor, 32)``.
        conv_block_cls: conv block constructor.
        cardinality: groups of the two attention 1x1 convs and of ``rsoftmax``.
            ``None`` (the default) means "use ``groups``". This mirrors the
            ResNeSt reference, where cardinality equals groups.
            NOTE(review): confirm against upstream.
        match_reference: if True, the attention conv block keeps its conv bias
            even though it has a norm layer (``force_conv_bias``).
    """
    channels: int
    kernel_size: Tuple[int, int]
    strides: Tuple[int, int] = (1, 1)
    padding: Union[str, Iterable[Tuple[int, int]]] = ((0, 0), (0, 0))
    groups: int = 1
    radix: int = 2
    reduction_factor: int = 4
    conv_block_cls: ModuleDef = ConvBlock
    # Do not write `cardinality: int = groups`. In a class body that binds the
    # class-level default of `groups` (1) once, at definition time, so
    # `cardinality` would ignore the `groups` passed to an instance.
    # `None` is resolved to `self.groups` in __call__ instead.
    cardinality: Optional[int] = None
    match_reference: bool = False

    @nn.compact
    def __call__(self, x):
        cardinality = self.groups if self.cardinality is None else self.cardinality
        inter_channels = max(x.shape[-1] * self.radix // self.reduction_factor, 32)

        conv_block = self.conv_block_cls(self.channels * self.radix,
                                         kernel_size=self.kernel_size,
                                         strides=self.strides,
                                         groups=self.groups * self.radix,
                                         padding=self.padding)
        conv_cls = conv_block.conv_cls  # type: ignore
        x = conv_block(x)

        if self.radix > 1:
            # torch split takes split_size: int(rchannel//self.radix)
            # jnp split takes num sections: self.radix
            split = jnp.split(x, self.radix, axis=-1)
            gap = sum(split)
        else:
            gap = x

        gap = gap.mean((1, 2), keepdims=True)  # type: ignore # global average pool

        # Remove force_conv_bias after resolving
        # github.com/zhanghang1989/ResNeSt/issues/125
        gap = self.conv_block_cls(inter_channels,
                                  kernel_size=(1, 1),
                                  groups=cardinality,
                                  force_conv_bias=self.match_reference)(gap)

        attn = conv_cls(self.channels * self.radix,
                        kernel_size=(1, 1),
                        feature_group_count=cardinality)(gap)  # n x 1 x 1 x c
        attn = attn.reshape((x.shape[0], -1))
        attn = rsoftmax(attn, self.radix, cardinality)
        attn = attn.reshape((x.shape[0], 1, 1, -1))

        if self.radix > 1:
            attns = jnp.split(attn, self.radix, axis=-1)
            out = sum(a * s for a, s in zip(attns, split))
        else:
            out = attn * x

        return out
#------------------------------------------------------------------
class ResNetStem(nn.Module):
    """Classic ResNet stem: a single 7x7, stride-2 conv block with 64 filters."""
    conv_block_cls: ModuleDef = ConvBlock

    @nn.compact
    def __call__(self, x):
        stem_conv = self.conv_block_cls(
            64,
            kernel_size=(7, 7),
            strides=(2, 2),
            padding=[(3, 3), (3, 3)],
        )
        return stem_conv(x)

class ResNetDStem(nn.Module):
    """ResNet-D stem: three 3x3 conv blocks; only the first one is strided (2, 2).

    The widths are ``first_width``, ``stem_width`` and ``2 * stem_width``.
    """
    conv_block_cls: ModuleDef = ConvBlock
    stem_width: int = 32

    # If True, n_filters for first conv is (input_channels + 1) * 8
    adaptive_first_width: bool = False

    @nn.compact
    def __call__(self, x):
        make_conv = partial(self.conv_block_cls, kernel_size=(3, 3), padding=((1, 1), (1, 1)))
        if self.adaptive_first_width:
            first_width = 8 * (x.shape[-1] + 1)
        else:
            first_width = self.stem_width

        specs = (
            (first_width, (2, 2)),
            (self.stem_width, (1, 1)),
            (self.stem_width * 2, (1, 1)),
        )
        for width, stride in specs:
            x = make_conv(width, strides=stride)(x)
        return x


class ResNetSkipConnection(nn.Module):
    """Shortcut path of a ResNet block.

    Identity when ``x`` already has ``out_shape``. Otherwise a strided 1x1 conv
    block with an identity activation projects ``x`` to ``out_shape[-1]``
    features.
    """
    strides: Tuple[int, int]
    conv_block_cls: ModuleDef = ConvBlock

    @nn.compact
    def __call__(self, x, out_shape):
        if x.shape == out_shape:
            return x
        projection = self.conv_block_cls(out_shape[-1],
                                         kernel_size=(1, 1),
                                         strides=self.strides,
                                         activation=lambda v: v)
        return projection(x)

class ResNetDSkipConnection(nn.Module):
    """ResNet-D shortcut: average-pool, then project the channels if needed.

    When strided, ``x`` is average-pooled 2x2 with stride 2. A 1x1 conv block
    with an identity activation is added only when the channel count differs
    from ``out_shape[-1]``.
    """
    strides: Tuple[int, int]
    conv_block_cls: ModuleDef = ConvBlock

    @nn.compact
    def __call__(self, x, out_shape):
        out = x
        if self.strides != (1, 1):
            out = nn.avg_pool(out, (2, 2), strides=(2, 2), padding=((0, 0), (0, 0)))
        n_out = out_shape[-1]
        if out.shape[-1] != n_out:
            out = self.conv_block_cls(n_out, (1, 1), activation=lambda v: v)(out)
        return out

class ResNeStSkipConnection(ResNetDSkipConnection):
    """Behaves exactly like ResNetDSkipConnection.

    It exists as its own subclass so that the auto-generated submodule name in
    the variables dict is derived from this class name, which gives the
    variables dict the right names.
    """

class ResNetBlock(nn.Module):
    """Basic residual block: two conv blocks plus a shortcut.

    The first conv block carries ``strides``. The second is built with
    ``is_last=True``, so it skips its own activation. ``activation`` is applied
    after the shortcut is added.
    """
    n_hidden: int
    strides: Tuple[int, int] = (1, 1)

    activation: Callable = nn.relu
    conv_block_cls: ModuleDef = ConvBlock
    skip_cls: ModuleDef = ResNetSkipConnection

    @nn.compact
    def __call__(self, x):
        h = self.conv_block_cls(self.n_hidden, padding=[(1, 1), (1, 1)],
                                strides=self.strides)(x)
        h = self.conv_block_cls(self.n_hidden, padding=[(1, 1), (1, 1)],
                                is_last=True)(h)
        shortcut = partial(self.skip_cls, conv_block_cls=self.conv_block_cls)(self.strides)
        return self.activation(h + shortcut(x, h.shape))

class ResNetBottleneckBlock(nn.Module):
    """Bottleneck residual block: 1x1 -> strided/grouped 3x3 -> 1x1 expansion.

    The inner width is ``int(n_hidden * (base_width / 64.)) * groups``. The
    output has ``n_hidden * expansion`` features. ``activation`` is applied
    after the shortcut is added.
    """
    n_hidden: int
    strides: Tuple[int, int] = (1, 1)
    expansion: int = 4
    groups: int = 1  # cardinality
    base_width: int = 64

    activation: Callable = nn.relu
    conv_block_cls: ModuleDef = ConvBlock
    skip_cls: ModuleDef = ResNetSkipConnection

    @nn.compact
    def __call__(self, x):
        inner = int(self.n_hidden * (self.base_width / 64.)) * self.groups
        n_out = self.n_hidden * self.expansion

        # Downsampling strides live in the 3x3 conv instead of the 1x1 conv,
        # which improves accuracy. This variant is called ResNet V1.5
        # (matches torchvision).
        h = self.conv_block_cls(inner, kernel_size=(1, 1))(x)
        h = self.conv_block_cls(inner,
                                strides=self.strides,
                                groups=self.groups,
                                padding=((1, 1), (1, 1)))(h)
        h = self.conv_block_cls(n_out, kernel_size=(1, 1), is_last=True)(h)

        shortcut = partial(self.skip_cls, conv_block_cls=self.conv_block_cls)(self.strides)
        return self.activation(h + shortcut(x, h.shape))


class ResNetDBlock(ResNetBlock):
    """ResNetBlock whose shortcut is the ResNet-D variant (average-pool first)."""
    skip_cls: ModuleDef = ResNetDSkipConnection


class ResNetDBottleneckBlock(ResNetBottleneckBlock):
    """ResNetBottleneckBlock whose shortcut is the ResNet-D variant."""
    skip_cls: ModuleDef = ResNetDSkipConnection

class ResNeStBottleneckBlock(ResNetBottleneckBlock):
    """ResNeSt bottleneck: 1x1 conv, split-attention 3x3 conv, 1x1 expansion.

    Downsampling is a 3x3 average pool with ``strides``. It is applied before
    the split-attention conv when ``avg_pool_first`` is True, and after it
    otherwise. Only ``radix == 2`` is supported.
    """
    skip_cls: ModuleDef = ResNeStSkipConnection
    avg_pool_first: bool = False
    radix: int = 2

    splat_cls: ModuleDef = SplAtConv2d

    @nn.compact
    def __call__(self, x):
        assert self.radix == 2  # TODO: implement radix != 2

        inner = int(self.n_hidden * (self.base_width / 64.)) * self.groups
        downsample = self.strides != (1, 1)

        def pool(t):
            return nn.avg_pool(t, (3, 3), strides=self.strides, padding=[(1, 1), (1, 1)])

        h = self.conv_block_cls(inner, kernel_size=(1, 1))(x)
        if downsample and self.avg_pool_first:
            h = pool(h)

        h = self.splat_cls(inner,
                           kernel_size=(3, 3),
                           strides=(1, 1),
                           padding=[(1, 1), (1, 1)],
                           groups=self.groups,
                           radix=self.radix)(h)
        if downsample and not self.avg_pool_first:
            h = pool(h)

        h = self.conv_block_cls(self.n_hidden * self.expansion,
                                kernel_size=(1, 1),
                                is_last=True)(h)

        shortcut = partial(self.skip_cls, conv_block_cls=self.conv_block_cls)(self.strides)
        return self.activation(h + shortcut(x, h.shape))

def ResNet(
    block_cls: ModuleDef,
    *,
    stage_sizes: Sequence[int],
    n_classes: int,
    hidden_sizes: Sequence[int] = (64, 128, 256, 512),
    conv_cls: ModuleDef = nn.Conv,
    norm_cls: Optional[ModuleDef] = partial(nn.BatchNorm, momentum=0.9),
    conv_block_cls: ModuleDef = ConvBlock,
    stem_cls: ModuleDef = ResNetStem,
    pool_fn: Callable = partial(nn.max_pool,
                                window_shape=(3, 3),
                                strides=(2, 2),
                                padding=((1, 1), (1, 1))),
) -> Sequential:
    """Builds a ResNet as a flat ``Sequential``.

    The layer order is: the stem, ``pool_fn``, every residual block of every
    stage, a global mean over axes (1, 2), and ``Dense(n_classes)``. Stage
    ``k`` has ``stage_sizes[k]`` blocks of width ``hidden_sizes[k]``. The first
    block of every stage except stage 0 uses stride (2, 2); all others use
    (1, 1). ``conv_cls`` and ``norm_cls`` are bound into ``conv_block_cls``,
    which is shared by the stem and all blocks.
    """
    conv_block = partial(conv_block_cls, conv_cls=conv_cls, norm_cls=norm_cls)
    make_stem = partial(stem_cls, conv_block_cls=conv_block)
    make_block = partial(block_cls, conv_block_cls=conv_block)

    layers = [make_stem(), pool_fn]
    for stage, (width, n_blocks) in enumerate(zip(hidden_sizes, stage_sizes)):
        for b in range(n_blocks):
            downsample = stage != 0 and b == 0
            layers.append(make_block(n_hidden=width,
                                     strides=(2, 2) if downsample else (1, 1)))

    layers += [partial(jnp.mean, axis=(1, 2)),  # global average pool
               nn.Dense(n_classes)]
    return Sequential(layers)

# ResNet-101: ResNetStem + bottleneck blocks with STAGE_SIZES[101] per stage.
ResNet101 = partial(
    ResNet,
    block_cls=ResNetBottleneckBlock,
    stage_sizes=STAGE_SIZES[101],
    stem_cls=ResNetStem,
)


In [None]:

PyTorchTensor = Any


def pretrained_resnet(
    size: int,
    state_dict: Optional[Mapping[str, PyTorchTensor]] = None,
    path: str = "/notebooks/storage/resnet101-63fe2227.pth",
) -> Tuple[ModuleDef, FrozenDict]:
    """Returns a ResNet constructor and pretrained variables ported from PyTorch.

    Args:
        size: network depth; must be a key of ``STAGE_SIZES``. Depths >= 50 use
            ``ResNetBottleneckBlock`` (conv1..conv3 per block); smaller depths
            use ``ResNetBlock`` (conv1..conv2 per block).
        state_dict: PyTorch state_dict of a torchvision-style ResNet of depth
            ``size``. If None, it is loaded from ``path``.
        path: checkpoint read with ``torch.load`` when ``state_dict`` is None.
            The default is a ResNet-101 file, so pass a matching checkpoint
            (or ``state_dict``) for any other ``size``.

    Returns:
        ``(model_cls, variables)``. ``model_cls()`` builds the Flax model with
        1000 output classes. ``variables`` is a FrozenDict with 'params' and
        'batch_stats'.

    Raises:
        ValueError: if ``size`` is not in ``STAGE_SIZES``, or if the checkpoint
            has tensors this mapping does not use (e.g. a deeper checkpoint
            than ``size``), which would otherwise be silently dropped.
        KeyError: if the checkpoint lacks a tensor the mapping expects.
    """
    if size not in STAGE_SIZES:
        raise ValueError(
            f'Unsupported ResNet size {size}; expected one of {sorted(STAGE_SIZES)}.')
    stage_sizes = STAGE_SIZES[size]
    bottleneck = size >= 50

    if state_dict is None:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code, so only load trusted checkpoints. map_location='cpu'
        # keeps the later `.numpy()` conversion working for GPU-saved files.
        state_dict = torch.load(path, map_location='cpu')

    pt2jax: Dict[str, Sequence[str]] = {}
    add_bn = _get_add_bn(pt2jax)
    # layers_0 is the stem conv block; layers_1 is the max-pool (no variables).
    pt2jax['conv1.weight'] = ('params', 'layers_0', 'ConvBlock_0', 'Conv_0', 'kernel')
    add_bn('bn1', ('layers_0', 'ConvBlock_0', 'BatchNorm_0'))

    lyr = 2  # Sequential index of the first residual block
    for b, n_blocks in enumerate(stage_sizes, 1):  # torch stages are layer1..layer4
        for i in range(n_blocks):
            for j in range(2 + bottleneck):
                pt2jax[f'layer{b}.{i}.conv{j+1}.weight'] = ('params', f'layers_{lyr}',
                                                            f'ConvBlock_{j}', 'Conv_0',
                                                            'kernel')
                add_bn(f'layer{b}.{i}.bn{j+1}',
                       (f'layers_{lyr}', f'ConvBlock_{j}', 'BatchNorm_0'))

            if f'layer{b}.{i}.downsample.0.weight' in state_dict:
                pt2jax[f'layer{b}.{i}.downsample.0.weight'] = ('params',
                                                               f'layers_{lyr}',
                                                               'ResNetSkipConnection_0',
                                                               'ConvBlock_0', 'Conv_0',
                                                               'kernel')
                add_bn(f'layer{b}.{i}.downsample.1',
                       (f'layers_{lyr}', 'ResNetSkipConnection_0', 'ConvBlock_0',
                        'BatchNorm_0'))

            lyr += 1

    lyr += 1  # skip the global-average-pool entry; the next one is the Dense head
    pt2jax['fc.weight'] = ('params', f'layers_{lyr}', 'kernel')
    pt2jax['fc.bias'] = ('params', f'layers_{lyr}', 'bias')

    # Guard against a checkpoint whose depth does not match `size`.
    # BatchNorm's num_batches_tracked counter has no Flax counterpart.
    unused = [k for k in state_dict
              if k not in pt2jax and not k.endswith('num_batches_tracked')]
    if unused:
        raise ValueError(
            f'Checkpoint has {len(unused)} tensors not used by ResNet-{size} '
            f'(e.g. {unused[:3]}); does the checkpoint depth match `size`?')

    variables = _pytorch_to_jax_params(pt2jax, state_dict, ('fc.weight',))
    block_cls = ResNetBottleneckBlock if bottleneck else ResNetBlock
    model_cls = partial(ResNet, stage_sizes=stage_sizes, stem_cls=ResNetStem,
                        block_cls=block_cls, n_classes=1000)

    return model_cls, freeze(unflatten_dict(variables))

def _pytorch_to_jax_params(pt2jax, state_dict, fc_keys):
    variables = {}
    for pt_name, jax_key in pt2jax.items():
        w = state_dict[pt_name].detach().numpy()
        if w.ndim == 4:
            w = w.transpose((2, 3, 1, 0))
        elif pt_name in fc_keys:
            w = w.transpose()
        variables[jax_key] = w

    return variables

def _get_add_bn(pt2jax):
    def add_bn(pname, jprefix):
        pt2jax[f'{pname}.weight'] = ('params', *jprefix, 'scale')
        pt2jax[f'{pname}.bias'] = ('params', *jprefix, 'bias')
        pt2jax[f'{pname}.running_mean'] = ('batch_stats', *jprefix, 'mean')
        pt2jax[f'{pname}.running_var'] = ('batch_stats', *jprefix, 'var')

    return add_bn

In [None]:
# Load the ResNet-101 constructor and its ported torchvision weights
# (the constructor is bound to the name RESNET100).
RESNET100, variables = pretrained_resnet(101)
model = RESNET100()
# Sanity check: push a (1, 224, 224, 3) batch of ones through the model.
# mutable=False, so BatchNorm uses the stored running statistics.
model_out = model.apply(
    variables,
    jnp.ones((1, 224, 224, 3)),
    mutable=False,
)
print(np.shape(model_out))



In [None]:
#-----------Image Classification through resnet101-------------

# Image transforms: ToTensor yields a float tensor in channels-first (C, H, W) layout.
import torchvision.transforms as T

# convert('RGB') guarantees 3 channels (grayscale/RGBA inputs would otherwise
# break the 3-channel Normalize below).
img_res = Image.open('/notebooks/storage/data/bird.jpg').convert('RGB')
plt.imshow(img_res)
preprocess = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
inp_img_tensor = preprocess(img_res)  # (3, 224, 224)
# The Flax model takes channels-last input, like the (1, 224, 224, 3) dummy batch
# above. Permute CHW -> HWC: `.view(1, 224, 224, 3)` would only reinterpret the
# memory and scramble the pixels. Then convert to a JAX array and add the batch axis.
inp_batch = jnp.asarray(inp_img_tensor.permute(1, 2, 0).numpy())[None]
out = model.apply(variables, inp_batch, mutable=False)

#-------classify-----------

with open('/notebooks/storage/imagenet_classes.txt') as f:  # path to the ImageNet labels
    classes = [line.strip() for line in f]

max_index = jnp.argmax(out, axis=1)
percentage = jax.nn.softmax(out, axis=1) * 100
top = int(max_index[0])
print("The image is classified as (", classes[top], ") with percentage ",
      percentage[0, top].item(), "%")
