In [5]:
from PIL import Image
import torchvision.transforms as tt
import torch
from torchvision import models
import torch.nn as nn
import torch.nn.functional as F

In [6]:
# Per-channel normalisation statistics passed to Normalize: (mean, std), each ordered R, G, B.
# NOTE(review): these look like the commonly cited CIFAR-10 statistics, not the
# ImageNet ones the pretrained MobileNetV2 backbone was built with. They must match
# whatever was used when training food_classifier.pth — confirm before changing.
stats = (
    (0.4914, 0.4822, 0.4465),  # mean
    (0.2023, 0.1994, 0.2010),  # std
)

# Evaluation-time preprocessing: fixed 224x224 resize (aspect ratio not preserved),
# PIL -> float tensor in [0, 1], then per-channel normalisation.
valid_tfms = tt.Compose(
    [
        tt.Resize([224, 224]),
        tt.ToTensor(),
        tt.Normalize(*stats),
    ]
)

In [7]:
class Flatten(nn.Module):
    """Module wrapper that merges every dimension after the batch axis into one.

    Input of shape (N, d1, d2, ...) becomes (N, d1 * d2 * ...). Equivalent to
    ``torch.flatten(x, 1)``; usable as a layer inside ``nn.Sequential``.
    NOTE(review): not referenced by the visible cells — FoodImageClassifer
    flattens inline.
    """

    def forward(self, x):
        # Keep dim 0 (the batch), collapse all remaining dims.
        return x.flatten(start_dim=1)

class FoodImageClassifer(nn.Module):
    """Image classifier built on torchvision's MobileNetV2 feature extractor.

    ``body`` is ``mobilenet_v2(...).features``; its output is globally
    average-pooled, flattened to 1280 features and fed to ``head``
    (Dropout -> Linear) which produces one logit per class.

    Args:
        num_classes: number of output logits (default 101).
        dropout: dropout probability applied before the final Linear layer.
        pretrained: forwarded to ``models.mobilenet_v2``. When True the backbone
            starts from torchvision's pretrained weights (torchvision may download
            them). Pass False when the weights are going to be replaced by
            ``load_state_dict`` anyway.
    """

    def __init__(self, num_classes=101, dropout=0.2, pretrained=True):
        super().__init__()
        mobilenet = models.mobilenet_v2(pretrained=pretrained)
        # Keep only the convolutional trunk; MobileNet's own classifier is unused.
        self.body = mobilenet.features
        # The Sequential positions (0: Dropout, 1: Linear) appear in state_dict
        # keys such as "head.1.weight", so this layout must stay fixed for saved
        # checkpoints to keep loading.
        self.head = nn.Sequential(
            nn.Dropout(p=dropout),
            # 1280 = channel count of the body's output feature map, as assumed
            # by the original design.
            nn.Linear(1280, num_classes))

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for an image batch ``x``."""
        x = self.body(x)
        # Global average pooling to (batch, channels, 1, 1), then drop the
        # singleton spatial dims.
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        return self.head(x)

    def freeze(self):
        """Disable gradients for every ``body`` parameter.

        ``head`` parameters are left untouched, so they keep receiving gradients.
        """
        for param in self.body.parameters():
            param.requires_grad = False

In [14]:
# Rebuild the architecture and load the trained weights.
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only machine;
# the inference cell below feeds CPU tensors, so the model must live on CPU anyway.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
model = FoodImageClassifer()
model.load_state_dict(torch.load('food_classifier.pth', map_location='cpu'))
# Switch Dropout and BatchNorm layers to inference behaviour.
model.eval()

FoodImageClassifer(
  (body): Sequential(
    (0): ConvBNActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): ConvBNActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=1e-05, mo

In [15]:
imgpath = 'food-101/images/apple_pie/3004621.jpg'
# convert('RGB') guarantees 3 channels (grayscale/RGBA inputs would otherwise fail
# in the 3-channel Normalize and first conv). It returns an in-memory copy, so the
# file handle can be closed right away while `img` stays usable.
with Image.open(imgpath) as raw_img:
    img = raw_img.convert('RGB')
img_ts = valid_tfms(img)
batch_t = torch.unsqueeze(img_ts, 0)  # add batch dim -> (1, 3, 224, 224)

# Pure inference: don't build an autograd graph.
with torch.no_grad():
    out = model(batch_t)

# Class probabilities (in percent) for the single image in the batch.
prob = F.softmax(out, dim=1)[0] * 100
# Class indices ranked from most to least likely. argsort avoids assigning to `_`,
# which IPython reserves for the previous cell's output.
indices = torch.argsort(out, dim=1, descending=True)

In [16]:
# Class indices sorted by descending logit; shape (1, 101) — one row for the single image.
indices

tensor([[  0,  58,   6,  73,  16,   8,  17,  12,  31, 100,  98,  43,  57,  39,
          24,  22,  60,  94,  72,  23,  84,  30,  67,  47,   9,  18,  45,  89,
          32,  34,  83,  80,  14,  49,   5,  87,  21,  42,  82,  44,  27,  46,
          62,  36,  68,  97,  29,  52,   4,  74,  54,  55,  15,  85,  41,  69,
          37,  59,  91,  64,  65,  63,  10,  20,   2,  40,  19,  61,  93,  99,
          35,  95,  33,  26,  56,  70,  76,  79,  38,  48,  51,  96,  13,   1,
          28,  78,  71,  77,  66,  50,  53,  90,  11,  92,  88,   7,  25,  75,
          81,  86,   3]])

In [17]:
 # Class names, one per line; position i in `classes` is used as the label for logit i.
with open('classes.txt', encoding='utf-8') as f:
    classes = [line.strip() for line in f if line.strip()]

# Guard against a label file that does not line up with the model's output layer.
assert len(classes) == model.head[-1].out_features, (
    f'classes.txt has {len(classes)} names but the model predicts '
    f'{model.head[-1].out_features} classes')

In [18]:
# Top-5 predictions as (class name, probability in %) pairs, most likely first.
val = [
    (classes[class_idx], prob[class_idx].item())
    for class_idx in indices[0, :5].tolist()
]

In [19]:
# Top-5 (class name, probability in %) pairs, highest probability first.
val

[('apple_pie', 32.59563064575195),
 ('ice_cream', 12.091999053955078),
 ('beignets', 11.654073715209961),
 ('panna_cotta', 11.581610679626465),
 ('cheese_plate', 4.838817596435547)]