### Demo: run inference with a trained translator model and extract bottleneck features

In [1]:
import os 
import numpy as np 
import torch 
import glob 
from PIL import Image
from types import SimpleNamespace
import json
from torchvision import transforms
import torch.nn as nn
import sys
root_code = os.path.dirname(os.getcwd())
sys.path.insert(0, root_code)

from src.models.generator import unet_translator


In [2]:
# update following as needed
checkpoint_path = 'model/pytorch_model.pt' # path to trained translator model
config_path = 'model/config.json' # path to corresponding config file from trained model
# fall back to CPU so the demo also runs on machines without a GPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
img_path = None # path to an H&E tile; if None, the demo runs on a random numpy array
features = {} # filled by the forward hook: {hook name: detached output tensor}

In [3]:
# Load the training config. object_hook turns every JSON object (nested ones
# included) into a SimpleNamespace, so fields are read as config.<field>.
with open(config_path, "r") as f:
    config_text = f.read()
config = json.loads(config_text, object_hook=lambda obj: SimpleNamespace(**obj))

In [4]:
# initialize model: architecture hyper-parameters are taken from the training
# config; the model is constructed on CPU and moved to `device` later.
translator_kwargs = dict(
    input_nc=config.input_nc,
    output_nc=config.output_nc,
    use_high_res=config.use_high_res,
    use_multiscale=config.use_multiscale,
    ngf=config.ngf,
    depth=config.depth,
    encoder_padding=config.encoder_padding,
    decoder_padding=config.decoder_padding,
    device="cpu",
    extra_feature_size=config.fm_feature_size,
)
model = unet_translator(**translator_kwargs)


In [5]:
# load checkpoint and set model in eval mode
# map_location='cpu' lets a checkpoint that was saved from a GPU load on any
# machine; the model is moved to `device` right after the weights are loaded.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# the checkpoint also offers 'trans_state_dict'; 'trans_ema_state_dict' is
# presumably the EMA-averaged weights -- confirm against the training code
model.load_state_dict(checkpoint['trans_ema_state_dict'])
model.to(device)
model.eval()
print('model loaded')

model loaded


In [6]:
# defining hooks to get features from unet bottleneck
def get_features(name, store=None, verbose=True):
    """Build a forward hook that caches a module's output.

    name: key under which the module output is stored.
    store: dict to write into; when None, the notebook-level `features`
        dict is used (looked up when the hook fires).
    verbose: print a message and the output shape each time the hook fires.
    Returns a callable with the (module, inputs, output) signature expected
    by ``nn.Module.register_forward_hook``.
    """
    def hook(module, inputs, output):
        # parameter names avoid shadowing the builtin `input` and the notebook's `model`
        if verbose:
            print(f"Hook triggered for {name}!")
            print(f"Output shape: {output.shape}")
        target = features if store is None else store
        # detach so the cached tensor does not keep the autograd graph alive
        target[name] = output.detach()
    return hook
# global avg pooling
gap = nn.AdaptiveAvgPool2d(1)
# translator hook in center block, can also do "model.center_block[0].conv1"
# Remove the hook left by a previous run of this cell so re-running it does
# not stack duplicate hooks on the same module.
if 'hook_handle' in globals():
    hook_handle.remove()
hook_handle = model.center_block[0].register_forward_hook(get_features('feats_translator'))


<torch.utils.hooks.RemovableHandle at 0x7f091d4a56a0>

In [7]:
# transform function for img
def transform_np_to_tensor(np_img, max_value=255.0):
    '''Construct a batched float tensor from an image array.

    np_img: array-like of shape [H, W, C]; a numpy array or a PIL image
        (converted with ``np.asarray``).
    max_value: value the input is divided by to scale it into [0, 1]
        (255 for 8-bit images); pass None if the input is already scaled.
    returns a float32 torch tensor with shape [1, C, H, W].
    raises ValueError if the input is not 3-dimensional.
    '''
    np_img = np.asarray(np_img)
    if np_img.ndim != 3:
        raise ValueError(f"expected an [H, W, C] image, got shape {np_img.shape}")
    np_img = np.ascontiguousarray(np_img.transpose((2, 0, 1)))
    if max_value is not None:
        # true division: floor division (//) would map every value below
        # max_value to 0 and wipe out the image
        np_img = np_img / max_value
    torch_img = torch.from_numpy(np_img).float()
    return torch_img.unsqueeze(dim=0)

In [8]:
# load image (or build a demo input) and apply transformation
if img_path:
    # convert('RGB') drops any alpha channel so the array is [H, W, 3];
    # the context manager closes the file once the pixels are read
    with Image.open(img_path) as pil_img:
        image_np = np.asarray(pil_img.convert('RGB'))
else:
    # seeded random 8-bit tile so the demo is reproducible
    image_np = np.random.default_rng(0).integers(0, 256, size=(256, 256, 3), dtype=np.uint8)
image = transform_np_to_tensor(image_np).to(device)
print(image.shape)

torch.Size([1, 3, 256, 256])


In [9]:
# inference using model; 3 pred imc images are generated with diff resolution -- use last one
# no_grad: this is inference only, so skip building the autograd graph (saves memory)
with torch.no_grad():
    _, _, pred_imc = model(image)
print(pred_imc.shape)

Hook triggered for feats_translator!
Output shape: torch.Size([1, 256, 1, 1])
torch.Size([1, 11, 64, 64])


In [10]:
# move the prediction to numpy in [H, W, C] layout; writing to a new name keeps
# the tensor `pred_imc` intact, so this cell can be re-run safely
pred_imc_np = pred_imc.detach().cpu().numpy().squeeze(0).transpose((1, 2, 0))
print(pred_imc_np.shape)

(64, 64, 11)


In [11]:
# Pool the cached bottleneck (center block) activation to one value per
# channel and drop the size-1 dimensions, giving a flat numpy feature vector.
pred_imc_feats = (
    gap(features['feats_translator'])
    .squeeze()
    .cpu()
    .numpy()
)
pred_imc_feats.shape

(256,)

In [None]:
# display the model's repr (its module tree) as the cell output
model