<a href="https://colab.research.google.com/github/rdkworld/AIPND-2022/blob/main/Generalized/Predict.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [105]:
#Download image files
!wget https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
!tar -xf 102flowers.tgz

--2022-08-30 19:46:29--  https://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz
Resolving www.robots.ox.ac.uk (www.robots.ox.ac.uk)... 129.67.94.2
Connecting to www.robots.ox.ac.uk (www.robots.ox.ac.uk)|129.67.94.2|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 344862509 (329M) [application/x-gzip]
Saving to: ‘102flowers.tgz.1’


2022-08-30 19:46:48 (18.7 MB/s) - ‘102flowers.tgz.1’ saved [344862509/344862509]



In [106]:
#Clone Model Inference Repository
!git lfs install
!git clone  <huggigface_url>n%cd which-flower
!git init
%cd ..

Updated git hooks.
Git LFS initialized.
fatal: destination path 'which-flower' already exists and is not an empty directory.
/content/which-flower
Reinitialized existing Git repository in /content/which-flower/.git/
/content


In [143]:
import torch
import torchvision
from PIL import Image
from torchvision import models
from torch import nn 
from typing import List
import json

In [148]:
#Read labels file
with open('which-flower/cat_to_name.json','r') as f:
    cat_to_name = json.load(f)
cat_to_name    

{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 

In [108]:
#Update last layer of model
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

def update_last_layer_pretrained_model(pretrained_model, num_classes, feature_extract):
    set_parameter_requires_grad(pretrained_model, feature_extract)
    if hasattr(pretrained_model, 'fc') and 'resnet' in pretrained_model.__class__.__name__.lower(): #resnet
        num_ftrs = pretrained_model.fc.in_features
        pretrained_model.fc = nn.Linear(num_ftrs, num_classes, bias = True)
    elif hasattr(pretrained_model, 'classifier') and ('alexnet' in pretrained_model.__class__.__name__.lower() or 'vgg' in pretrained_model.__class__.__name__.lower()): #alexNet, vgg
        num_ftrs = pretrained_model.classifier[6].in_features
        pretrained_model.classifier[6] = nn.Linear(num_ftrs, num_classes, bias = True)
    elif hasattr(pretrained_model, 'classifier') and 'squeezenet' in pretrained_model.__class__.__name__.lower(): #squeezenet
        pretrained_model.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
        pretrained_model.num_classes = num_classes
    elif hasattr(pretrained_model, 'classifier') and ('efficientnet' in pretrained_model.__class__.__name__.lower() or 'mobilenet' in pretrained_model.__class__.__name__.lower()): #efficientnet, mobilenet
        num_ftrs = pretrained_model.classifier[1].in_features
        pretrained_model.classifier[1] = nn.Linear(num_ftrs, num_classes, bias = True)
    elif hasattr(pretrained_model, 'AuxLogits') and 'inception' in pretrained_model.__class__.__name__.lower(): #inception
        num_ftrs = pretrained_model.AuxLogits.fc.in_features 
        pretrained_model.AuxLogits.fc = nn.Linear(num_ftrs, num_classes) #Auxilary net
        num_ftrs = pretrained_model.fc.in_features
        pretrained_model.fc = nn.Linear(num_ftrs,num_classes) #Primary net
    elif hasattr(pretrained_model, 'classifier') and 'densenet' in pretrained_model.__class__.__name__.lower(): #densenet
        num_ftrs = pretrained_model.classifier.in_features
        pretrained_model.classifier = nn.Linear(num_ftrs, num_classes, bias = True)
    elif hasattr(pretrained_model, 'heads') and 'visiontransformer' in pretrained_model.__class__.__name__.lower(): #vit transformer
        num_ftrs = pretrained_model.heads.head.in_features
        pretrained_model.heads.head = nn.Linear(num_ftrs, num_classes, bias = True)
    elif hasattr(pretrained_model, 'head') and 'swin' in pretrained_model.__class__.__name__.lower(): #swin transformer
        num_ftrs = pretrained_model.head.in_features
        pretrained_model.head = nn.Linear(num_ftrs, num_classes, bias = True)
    return pretrained_model

In [198]:
#Load Model
model_name, model_weights, model_path = ('efficientnet_b2','EfficientNet_B2_Weights','which-flower/flowers_efficientnet_b2_model.pth')
#model_name, model_weights, model_path = ('alexnet','AlexNet_Weights','which-flower/flowers_alexnet_model.pth')
checkpoint = torch.load(model_path, map_location='cpu')
pretrained_weights = eval(f"models.{model_weights}.DEFAULT")
auto_transforms = pretrained_weights.transforms()
#pretrained_model = eval(f"torchvision.models.{model_name}(weights = pretrained_weights)")
pretrained_model = eval(f"models.{model_name}(pretrained = True)")
pretrained_model = update_last_layer_pretrained_model(pretrained_model, 102, True)
pretrained_model.class_to_idx = checkpoint['class_to_idx']
pretrained_model.class_names = checkpoint['class_names']    
pretrained_model.load_state_dict(checkpoint['state_dict'])
pretrained_model.to('cpu')

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
Downloading: "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b2_rwightman-bcdf34b7.pth


  0%|          | 0.00/35.2M [00:00<?, ?B/s]

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActivat

In [199]:
#pred_and_plot_image
def pred_image(model, image_path, class_names = None, transform=None, device: torch.device = "cuda" if torch.cuda.is_available() else "cpu"):

    target_image = Image.open(image_path)
    if transform:
        target_image = transform(target_image)
    model.to(device)
    model.eval()
    with torch.inference_mode():
        target_image = target_image.unsqueeze(dim=0)
        target_image_pred = model(target_image.to(device))

    target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
    ps = target_image_pred_probs.topk(3)
    ps_numpy = ps[0].cpu().numpy()[0]
    idxs = [class_names[i] for i in ps[1].numpy()[0]] if class_names else ps[1].numpy()[0]

    return (ps_numpy, idxs)

In [203]:
#Predict
image_path = 'which-flower/80_image_02020.jpg'
probs, idxs = pred_image(model=pretrained_model, image_path=image_path, class_names=pretrained_model.class_names, transform=auto_transforms)
names = [cat_to_name[i] for i in idxs]

#Display or return to main function
print({names[i]: float(probs[i]) for i in range(len(names))})
print({names[i]: float(probs[i]) for i in range(len(names))})
print({names[i]: float(probs[i]) for i in range(len(names))})


{'anthurium': 0.985121488571167, 'lotus lotus': 0.006184564903378487, 'magnolia': 0.002347745466977358}
{'anthurium': 0.985121488571167, 'lotus lotus': 0.006184564903378487, 'magnolia': 0.002347745466977358}
{'anthurium': 0.985121488571167, 'lotus lotus': 0.006184564903378487, 'magnolia': 0.002347745466977358}


In [206]:
list_of_models_and_weights = [
    ('mobilenet_v2','MobileNet_V2_Weights','flowers_mobilenet_v2_model.pth'),
    ('densenet121','DenseNet121_Weights','flowers_densenet121_model.pth'),
    ('inception_v3','Inception_V3_Weights','flowers_inception_v3_model.pth'),
    ('efficientnet_b2','EfficientNet_B2_Weights','flowers_efficientnet_b2_model.pth'),
    ('squeezenet1_1','SqueezeNet1_1_Weights','flowers_squeezenet1_1_model.pth'),
    ('vgg16','VGG16_Weights','flowers_vgg16_model.pth'),
    ('alexnet','AlexNet_Weights','flowers_alexnet_model.pth'),
    ('resnet18','ResNet18_Weights','flowers_resnet18_model.pth'),
    ('swin_b','Swin_B_Weights','flowers_swin_b_model.pth'),
    ('vit_b_16', 'ViT_B_16_Weights','flowers_vit_b_16_model.pth')
                             ]

In [209]:
for model_name, model_weights, model_path in list_of_models_and_weights:
  print(model_weights)

MobileNet_V2_Weights
DenseNet121_Weights
Inception_V3_Weights
EfficientNet_B2_Weights
SqueezeNet1_1_Weights
VGG16_Weights
AlexNet_Weights
ResNet18_Weights
Swin_B_Weights
ViT_B_16_Weights


In [None]:
iface = gr.Interface(fn=process_input, inputs=gr.inputs.Image(type='filepath'), outputs=gr.outputs.Label, examples = examples,
                     title=title, description=description,article=article,interpretation=interpretation, enable_queue=enable_queue
                    )

In [207]:
!ls which-flower/*.pth
list_of_models_and_weights

which-flower/checkpoint-densenet121.pth
which-flower/flowers_alexnet_model.pth
which-flower/flowers_densenet121_model.pth
which-flower/flowers_efficientnet_b2_model.pth
which-flower/flowers_inception_v3_model.pth
which-flower/flowers_mobilenet_v2_model.pth
which-flower/flowers_resnet18_model.pth
which-flower/flowers_squeezenet1_1_model.pth
which-flower/flowers_swin_b_model.pth
which-flower/flowers_vgg16_model.pth
which-flower/flowers_vit_b_16_model.pth


[('mobilenet_v2', 'MobileNet_V2_Weights', 'flowers_mobilenet_v2_model.pth'),
 ('densenet121', 'DenseNet121_Weights', 'flowers_densenet121_model.pth'),
 ('inception_v3', 'Inception_V3_Weights', 'flowers_inception_v3_model.pth'),
 ('efficientnet_b2',
  'EfficientNet_B2_Weights',
  'flowers_efficientnet_b2_model.pth'),
 ('squeezenet1_1', 'SqueezeNet1_1_Weights', 'flowers_squeezenet1_1_model.pth'),
 ('vgg16', 'VGG16_Weights', 'flowers_vgg16_model.pth'),
 ('alexnet', 'AlexNet_Weights', 'flowers_alexnet_model.pth'),
 ('resnet18', 'ResNet18_Weights', 'flowers_resnet18_model.pth'),
 ('swin_b', 'Swin_B_Weights', 'flowers_swin_b_model.pth'),
 ('vit_b_16', 'ViT_B_16_Weights', 'flowers_vit_b_16_model.pth')]