## Example

In this simple example, we load an image, pre-process it, and classify it with a pretrained EfficientNet.

In [1]:
import sys
sys.path.append('../../')

import json
from PIL import Image

import torch
from torchvision import transforms

from efficientnet_pytorch import EfficientNet
from efficientnet_pytorch.utils import *

In [2]:
# Choose the EfficientNet variant once; every later cell refers to it through
# `model_name`, and the matching input resolution is looked up from the library.
model_name = 'efficientnet-b0'
image_size = EfficientNet.get_image_size(model_name)  # 224 for efficientnet-b0

# Trailing expression so the notebook displays the resolved size.
image_size

224

In [3]:
# Open image: read the mammogram from disk as float32 and wrap it in a tensor
# with two leading length-1 axes.
import numpy as np
import imageio

# NOTE(review): machine-specific absolute path — this cell only runs on a machine
# where this exact file exists.
file_name = '/Users/yanqixu/Documents/1.0.MasterCDS/project-summer/breast_img/0_L-CC.png'
loaded_image = np.array(imageio.imread(file_name), dtype=np.float32)

# Prepend two singleton axes and take a fresh copy.
# Assumes the PNG decodes to a 2-D array, giving (1, 1, H, W) — TODO confirm.
loaded_image = loaded_image[np.newaxis, np.newaxis].copy()
mammo = torch.Tensor(loaded_image)

In [4]:
# Convert the loaded image array into a 4-D torch tensor in N,C,H,W format.
#
# The array is reshaped from its last two axes instead of being expanded with
# np.expand_dims again: by the time this cell runs, `loaded_image` may already
# carry two leading length-1 axes from the image-loading cell, in which case a
# second pair of expand_dims calls yields a 6-D tensor rather than N,C,H,W.
# The previous version also overwrote `loaded_image`, so every re-run of the cell
# added two more axes.  This version leaves `loaded_image` untouched, accepts
# either an (H, W) or a (1, 1, H, W) array, and raises if the element count does
# not fit a single one-channel image.
tensor_batch = torch.Tensor(
    loaded_image.reshape((1, 1) + loaded_image.shape[-2:]).copy()
)

In [7]:
# Preprocess image: resize and centre-crop to the model's input size, convert to
# a tensor, normalise each channel with the given mean/std, then add a batch axis.
tfms = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# `img` is bound once, directly to the batched tensor.
img = tfms(Image.open('img.jpg')).unsqueeze(0)

In [8]:
# Zero-pad the image batch on the right and bottom so its spatial size is 256 x 256.
#
# The pad widths come from the tensor's own height/width rather than a hard-coded
# 224 and literal (1, 3, ...) zero blocks, so the cell keeps working if the
# preprocessing size, batch size or channel count changes.  For the 224 x 224
# input produced above the result is identical to the previous torch.cat version.
# F.pad takes the pads for the last two axes as (left, right, top, bottom).
img_new = torch.nn.functional.pad(
    img,
    (0, 256 - img.shape[-1], 0, 256 - img.shape[-2]),
)
img_new.shape

torch.Size([1, 3, 256, 256])

In [7]:
# Load class names.
# labels_map.txt is parsed as JSON and indexed by the string forms of 0..999,
# producing a list whose position is the class index.
# The file is opened in a `with` block so the handle is closed as soon as the
# JSON has been parsed (the bare open() call previously left it open).
with open('labels_map.txt') as labels_file:
    labels_map = json.load(labels_file)
labels_map = [labels_map[str(i)] for i in range(1000)]

In [8]:
# Classify with EfficientNet: run the pretrained model on the padded image batch
# and print the five most likely classes with their probabilities.
#
# in_channels is taken from the batch itself.  It was previously hard-coded to 1
# while `img_new` is a 3-channel batch (its shape is shown as [1, 3, 256, 256]
# where it is built), so the model's input layer and the data disagreed.
# NOTE(review): the saved output below was presumably produced before
# in_channels=1 was added — re-run to confirm.
model = EfficientNet.from_pretrained(model_name, in_channels=img_new.shape[1])
model.eval()
with torch.no_grad():
    logits = model(img_new)

# Softmax is computed once for the whole batch instead of once per printed row.
probs = torch.softmax(logits, dim=1)
preds = torch.topk(logits, k=5).indices.squeeze(0).tolist()

print('-----')
for idx in preds:
    label = labels_map[idx]
    prob = probs[0, idx].item()
    print('{:<75} ({:.2f}%)'.format(label, prob*100))

Loaded pretrained weights for efficientnet-b0
-----
window screen                                                               (33.92%)
fire screen, fireguard                                                      (18.37%)
hamper                                                                      (1.62%)
envelope                                                                    (1.31%)
radiator                                                                    (0.88%)


In [7]:
# Build an EfficientNet from its name, configured for a single input channel
# and stem_filters=16.
new_model = EfficientNet.from_name(
    model_name,
    in_channels=1,
    stem_filters=16,
)

In [9]:
# Run new_model's extract_features on the mammogram tensor.
# This cell prints nothing itself; the shape lines in its output are emitted
# from inside the model code that is called here.
output = new_model.extract_features(mammo)

stem torch.Size([1, 16, 1472, 960])
In MBConvBlock
torch.Size([1, 16, 737, 481])
torch.Size([1, 24, 368, 240])
torch.Size([1, 24, 368, 240])
torch.Size([1, 40, 184, 120])
torch.Size([1, 40, 184, 120])
torch.Size([1, 80, 92, 60])
torch.Size([1, 80, 92, 60])
torch.Size([1, 80, 92, 60])
torch.Size([1, 112, 92, 60])
torch.Size([1, 112, 92, 60])
torch.Size([1, 112, 92, 60])
torch.Size([1, 192, 46, 30])
torch.Size([1, 192, 46, 30])
torch.Size([1, 192, 46, 30])
torch.Size([1, 192, 46, 30])
torch.Size([1, 320, 46, 30])


In [9]:
# Build a 3-channel EfficientNet from its name and run the padded image batch
# through it.
local_model = EfficientNet.from_name(model_name, in_channels=3)

# Call the module itself rather than .forward(): nn.Module.__call__ is the
# supported entry point and also runs any registered hooks, which a direct
# .forward() call skips.
output2 = local_model(img_new)

torch.Size([1, 32, 128, 128])
torch.Size([1, 32, 128, 128])
torch.Size([1, 32, 128, 128])
In MBConvBlock
torch.Size([1, 16, 128, 128])
torch.Size([1, 24, 64, 64])
torch.Size([1, 24, 64, 64])
torch.Size([1, 40, 32, 32])
torch.Size([1, 40, 32, 32])
torch.Size([1, 80, 16, 16])
torch.Size([1, 80, 16, 16])
torch.Size([1, 80, 16, 16])
torch.Size([1, 112, 16, 16])
torch.Size([1, 112, 16, 16])
torch.Size([1, 112, 16, 16])
torch.Size([1, 192, 8, 8])
torch.Size([1, 192, 8, 8])
torch.Size([1, 192, 8, 8])
torch.Size([1, 192, 8, 8])
torch.Size([1, 320, 8, 8])
torch.Size([1, 1280, 8, 8])


In [18]:
# Block specifications written in the compact string encoding and decoded in one
# step by BlockDecoder, so `blocks_args` is only ever bound to the decoded list
# (the cell output further down shows it as a list of BlockArgs tuples).
blocks_args = BlockDecoder.decode([
    'r1_k3_s11_e1_i32_o16_se0.25',
    'r2_k3_s22_e6_i16_o24_se0.25',
    'r2_k5_s22_e6_i24_o40_se0.25',
    'r3_k3_s22_e6_i40_o80_se0.25',
    'r3_k5_s11_e6_i80_o112_se0.25',
    'r4_k5_s22_e6_i112_o192_se0.25',
    'r1_k3_s11_e6_i192_o320_se0.25',
])


In [28]:
# _replace returns a copy of the first BlockArgs with num_repeat set to 6.
# The result is only displayed, not stored: the list itself is unchanged, as the
# following cell's output (num_repeat=1 for the first entry) shows.
blocks_args[0]._replace(num_repeat=6)

BlockArgs(num_repeat=6, kernel_size=3, stride=[1], expand_ratio=1, input_filters=32, output_filters=16, se_ratio=0.25, id_skip=True)

In [30]:
# Display the decoded list to check that the _replace call did not modify it.
blocks_args

[BlockArgs(num_repeat=1, kernel_size=3, stride=[1], expand_ratio=1, input_filters=32, output_filters=16, se_ratio=0.25, id_skip=True),
 BlockArgs(num_repeat=2, kernel_size=3, stride=[2], expand_ratio=6, input_filters=16, output_filters=24, se_ratio=0.25, id_skip=True),
 BlockArgs(num_repeat=2, kernel_size=5, stride=[2], expand_ratio=6, input_filters=24, output_filters=40, se_ratio=0.25, id_skip=True),
 BlockArgs(num_repeat=3, kernel_size=3, stride=[2], expand_ratio=6, input_filters=40, output_filters=80, se_ratio=0.25, id_skip=True),
 BlockArgs(num_repeat=3, kernel_size=5, stride=[1], expand_ratio=6, input_filters=80, output_filters=112, se_ratio=0.25, id_skip=True),
 BlockArgs(num_repeat=4, kernel_size=5, stride=[2], expand_ratio=6, input_filters=112, output_filters=192, se_ratio=0.25, id_skip=True),
 BlockArgs(num_repeat=1, kernel_size=3, stride=[1], expand_ratio=6, input_filters=192, output_filters=320, se_ratio=0.25, id_skip=True)]

In [33]:
# Construct a GlobalParams tuple by hand, naming every field explicitly.
# NOTE(review): image_size=1, dropout_rate=1 and drop_connect_rate=1 look like
# placeholder values — confirm before using these parameters to build a model.
global_params = GlobalParams(
    # scaling and input settings
    width_coefficient=1,
    depth_coefficient=1,
    image_size=1,
    # regularisation
    dropout_rate=1,
    drop_connect_rate=1,
    # classifier head
    num_classes=2,
    # batch-norm settings
    batch_norm_momentum=0.99,
    batch_norm_epsilon=1e-3,
    # channel rounding and stem width
    depth_divisor=8,
    min_depth=None,
    stem_filters=32,
)

In [35]:
# _replace returns a copy of global_params with num_classes set to 10.
# The result is only displayed, not stored: global_params keeps num_classes=2,
# as the following cell's output shows.
global_params._replace(num_classes=10)

GlobalParams(width_coefficient=1, depth_coefficient=1, image_size=1, dropout_rate=1, num_classes=10, batch_norm_momentum=0.99, batch_norm_epsilon=0.001, drop_connect_rate=1, depth_divisor=8, min_depth=None, stem_filters=32)

In [36]:
# Display global_params to check that the _replace call did not modify it.
global_params

GlobalParams(width_coefficient=1, depth_coefficient=1, image_size=1, dropout_rate=1, num_classes=2, batch_norm_momentum=0.99, batch_norm_epsilon=0.001, drop_connect_rate=1, depth_divisor=8, min_depth=None, stem_filters=32)