In [1]:
import os
import torch
import numpy as np
import pytorch_lightning as pl
import albumentations
from albumentations.pytorch import ToTensorV2
from torchvision import models
from torch import nn
from skimage import io
from captum.attr import LayerGradCam


class Classifier(pl.LightningModule):
    """Binary image classifier: a ResNet-50 backbone with a single-logit head.

    ``forward`` returns sigmoid probabilities of shape ``(N, 1)``; the raw
    logits come from ``self.base_model``, which is also the model the
    Grad-CAM explainer attributes against.
    """

    def __init__(self, pretrained: bool = True):
        """Build the network, loss, Grad-CAM explainer and test-time transform.

        Args:
            pretrained: Initialise the backbone from torchvision's ImageNet
                weights. Defaults to ``True`` (the original behaviour); pass
                ``False`` when a full checkpoint is loaded immediately
                afterwards, to skip the weight download.
        """
        super().__init__()

        # Side length (pixels) that test images are resized to.
        self.image_size = 224

        base_model = models.resnet50(pretrained=pretrained)
        # Replace the 1000-class ImageNet head with dropout + a single logit.
        base_model.fc = nn.Sequential(
            nn.Dropout(p=0.6),
            nn.Linear(in_features=base_model.fc.in_features, out_features=1),
        )
        # Use a floating-point pos_weight so the registered buffer matches the
        # float logits/targets it is combined with (an integer tensor only
        # works through implicit type promotion). A weight of 1 applies no
        # extra weighting to the positive class.
        self.loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.0]))
        self.base_model = base_model

        # Grad-CAM over the last block of ResNet-50's final stage (layer4),
        # computed on the raw logits of ``base_model`` rather than the sigmoid.
        self.layer_gc = LayerGradCam(self.base_model, self.base_model.layer4[-1])

        # Decision threshold applied to the sigmoid output.
        # NOTE(review): presumably tuned on a validation set — confirm provenance.
        self.dx_threshold = 0.5045563742228787

        self.sigmoid = nn.Sigmoid()

        # Test-time preprocessing: resize, normalise with albumentations'
        # default mean/std, then convert the HWC numpy image to a CHW tensor.
        self.test_transform = albumentations.Compose([
            albumentations.Resize(self.image_size, self.image_size),
            albumentations.Normalize(),
            ToTensorV2()
        ])

    @staticmethod
    def inverse_normalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """Undo per-channel mean/std normalisation, in place.

        Applies ``x * std + mean`` channel by channel. Because
        ``albumentations.Normalize`` also divides by ``max_pixel_value`` (255
        by default), the default statistics restore values on a ``[0, 1]``
        scale, not ``[0, 255]``.

        Args:
            tensor: A single image ``(C, H, W)`` or a batch ``(N, C, H, W)``.
                It is modified in place.
            mean: Per-channel means that were subtracted.
            std: Per-channel standard deviations that were divided by.

        Returns:
            The same ``tensor`` object, for chaining.
        """
        # Iterating a 4-D tensor directly would pair the *batch* axis with the
        # channel statistics (scaling whole images by one channel's std), so
        # treat 4-D input as a batch and undo the normalisation per sample.
        samples = tensor if tensor.dim() == 4 else (tensor,)
        for sample in samples:
            for channel, m, s in zip(sample, mean, std):
                channel.mul_(s).add_(m)
        return tensor

    def forward(self, x):
        """Return sigmoid probabilities for a batch of images.

        Args:
            x: Image batch in the layout produced by ``test_transform`` plus a
                leading batch dimension (presumably ``(N, 3, H, W)`` float —
                TODO confirm for other callers).

        Returns:
            Tensor of shape ``(N, 1)`` with values in ``[0, 1]``.
        """
        return self.sigmoid(self.base_model(x))

In [2]:
# Build the architecture, then load the trained weights into it.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a
# CPU-only machine; the model in this notebook stays on the CPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (recent PyTorch versions accept weights_only=True to restrict this).
model = Classifier()

checkpoint = torch.load("models/baseline_model.pth", map_location="cpu")
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [3]:

# Image location: override with the ISIC_DATA_DIR environment variable on
# machines where the data lives elsewhere.
DATA_DIR = os.environ.get("ISIC_DATA_DIR", "/home/caduser/Tirtha/data/combined")
IMAGE_FILE = "ISIC_0024306.jpg"

# eval() disables dropout and makes BatchNorm use its running statistics.
model.eval()

raw_img = io.imread(os.path.join(DATA_DIR, IMAGE_FILE))
# Preprocess to a CHW tensor and add a leading batch dimension.
image = model.test_transform(image=raw_img)['image'].unsqueeze(0)

# A plain prediction needs no autograd graph; no_grad avoids building one.
with torch.no_grad():
    predictions = model(image)
predictions

tensor([[0.2299]], grad_fn=<SigmoidBackward0>)