In [1]:
# Inference
MEAN_CHANNEL_VALUES = (0.07730, 0.05958, 0.07135)  # RGB
CHANNEL_STD_DEV = (0.12032, 0.08593, 0.14364)

In [2]:
import torch
from torch import nn
from torchvision.models import resnet18
import tez
import albumentations as A
import numpy as np

class ResNet18(tez.Model):
    '''Model class to facilitate transfer learning 
    from a resnet-18 model
    '''
    NUM_CLASSES = 19
    DROPOUT_RATE = 0.1
    IMG_DIR = 'D:/HPA_comp/single_cells'
    
    def __init__(self, train_df=None, valid_df=None, batch_size=16, train_aug=None, valid_aug=None, pretrained=True):
        # Initialise pretrained net and final layers for cell classification
        super().__init__()
        self.convolutions = nn.Sequential(*(list(resnet18(pretrained).children())[0:-1]))
        self.dropout = nn.Dropout(self.DROPOUT_RATE)
        self.dense = nn.Linear(512, self.NUM_CLASSES)
        self.out = nn.Sigmoid()
        self.loss_fn = nn.BCELoss()
                
    def forward(self, image, target=None):
        batch_size = image.shape[0]
        
        # Extracts 512x1 feature vector from pretrained resnet18 conv layers
        x = self.convolutions(image).reshape(batch_size, -1)
        # Fully connected dense layer to 19 class output
        output = self.dense(self.dropout(x))
        # Sigmoid activations on output to infer class probabilities
        output_probs = self.out(output)
        
        if target is not None:
            loss = self.loss_fn(output_probs, target.to(torch.float32))  # why to float32???
            metrics = self.monitor_metrics(output_probs, target)
            return output_probs, loss, metrics
        return output_probs, None, None
    
    def fetch_optimizer(self):
        opt = torch.optim.Adam(self.parameters(), lr=3e-4)
        return opt
    
    def fetch_scheduler(self):
        sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=1, eta_min=1e-6, last_epoch=-1
        )
        return sch

inf_aug = A.Compose([
    A.Normalize(
        mean=MEAN_CHANNEL_VALUES,
        std=CHANNEL_STD_DEV,
        max_pixel_value=1.0,
        p=1.0
    )
])

In [4]:
def infer(img, model, inf_aug):
    'In: np.array (244, 244, 3); out: np.array (1, 19)'
    # Augment; re-shuffle channels; reshape; send to gpu
    X = inf_aug(image=img)['image']
    X = np.transpose(img, (2, 0, 1)).astype(np.float32)
    X = X.reshape((1, 3, 224, 224))
    X = torch.tensor(X, dtype=torch.float32).to('cuda')
    with torch.no_grad():
        out = model(X)[0]
    return out.cpu().detach().numpy()

In [3]:
model = ResNet18()
model.load('../models/test_trained_model.bin')

In [5]:
img = np.zeros([224, 224, 3], dtype='float32')

infer(img)

array([[4.0226492e-01, 1.2569706e-02, 4.9887609e-02, 1.2750092e-02,
        1.2971192e-02, 2.3765745e-02, 1.9739717e-02, 1.9892044e-02,
        1.4863893e-02, 1.0636472e-02, 2.0367175e-03, 3.3339605e-04,
        3.7639823e-02, 7.0945047e-02, 3.2744169e-02, 4.9520819e-03,
        1.7193443e-01, 5.3259977e-03, 6.2999275e-04]], dtype=float32)