### Import libraries

In [2]:
import torch
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import model
from dataset_class import PTBXL_Dataset

### Import experiment settings

In [4]:
# book keeping namings and code
from settings import base_architecture, img_size, prototype_shape, num_classes, \
                     prototype_activation_function, add_on_layers_type, experiment_run, \
                     train_batch_size, test_batch_size, train_push_batch_size, \
                     train_information, test_information, num_train_examples, num_test_examples

### Create the model

In [5]:
# Construct the ProtoPNet model.
# NOTE(review): this overrides the `base_architecture` imported from `settings`,
# so the backbone built here is ResNet-18 regardless of that file's value.
# The same name is reassigned (rather than introducing a new one) so any later
# code reading `base_architecture` sees the architecture actually built.
# TODO: move this choice into `settings` so the two cannot disagree.
base_architecture = 'resnet18'

# Prefer the GPU but fall back to the CPU instead of hard-coding 'cuda', so the
# notebook can still run on a machine without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ppnet = model.construct_PPNet(base_architecture=base_architecture,
                              pretrained=True, img_size=img_size,
                              prototype_shape=prototype_shape,
                              num_classes=num_classes,
                              prototype_activation_function=prototype_activation_function,
                              add_on_layers_type=add_on_layers_type)
ppnet = ppnet.to(device)

In [6]:
# Optionally load saved weights into the model. Set MODEL_PATH to the path of a
# saved state_dict to evaluate those weights. With None, ppnet keeps the weights
# it was constructed with (pretrained=True presumably only covers the backbone —
# TODO confirm in model.construct_PPNet), so predictions below are not yet
# meaningful.
MODEL_PATH = None

if MODEL_PATH is not None:
    # map_location loads the tensors onto the device ppnet currently lives on,
    # so a checkpoint saved on a different device still loads.
    state_dict = torch.load(MODEL_PATH, map_location=next(ppnet.parameters()).device)
    ppnet.load_state_dict(state_dict)

# eval() puts layers such as batch-norm/dropout into inference mode and returns
# the module itself, so the model summary is shown as this cell's output.
ppnet.eval()

PPNet(
	features: resnet18_features,
	img_size: 224,
	prototype_shape: (160, 128, 1, 1),
	proto_layer_rf_info: [7, 32, 435, 0.5],
	num_classes: 5,
	epsilon: 0.0001
)

### Data loaders (see the `dataset_class` module for how `PTBXL_Dataset` prepares each example)

In [7]:
# Number of worker processes per DataLoader; tune to the available CPU cores.
NUM_WORKERS = 4

# Training and test datasets. reshape=True makes each waveform a 3-channel,
# image-like tensor for the convolutional backbone (see the Inference note).
train_dataset = PTBXL_Dataset(train_information, reshape=True)
test_dataset = PTBXL_Dataset(test_information, reshape=True)

# Shuffle only the training data. The test set is used for evaluation, so a
# fixed iteration order keeps per-example results reproducible between runs.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True,
                                           num_workers=NUM_WORKERS, pin_memory=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False,
                                          num_workers=NUM_WORKERS, pin_memory=False)


In [8]:
# Pull a single training batch to inspect its shapes; it is also the input used
# by the inference example below. next(iter(...)) raises StopIteration if the
# loader is empty, rather than leaving `waveforms` undefined.
waveforms, labels = next(iter(train_loader))

# Move the batch to whichever device the model lives on instead of hard-coding
# CUDA, so this cell stays consistent with the model-construction cell.
model_device = next(ppnet.parameters()).device
waveforms = waveforms.to(model_device)
labels = labels.to(model_device)
print(waveforms.shape, labels.shape)

torch.Size([80, 3, 100, 40]) torch.Size([80])


### Inference
#### Note:
Each waveform is loaded and reshaped (`reshape=True`) into a 3-channel, image-like tensor of shape (3, H, W) — (3, 100, 40) in the batch above. This is because the backbones used here are still convolutional image networks that expect 3-channel input.

In [9]:
# Forward pass on the batch loaded above. The model returns the class logits
# and the minimum distances to each prototype. torch.no_grad() skips building
# the autograd graph, which inference does not need and which would otherwise
# keep intermediate activations for the whole batch in memory.
with torch.no_grad():
    logits, min_distances = ppnet(waveforms)

print(f"Shape of logits is: {logits.shape}")
print(f"Shape of prototype layer activations is: {min_distances.shape}")

Shape of logits is: torch.Size([80, 5])
Shape of prototype layer activations is: torch.Size([80, 160])


## Note:
In this model, each prototype belongs to one diagnostic label, and each of the 5 labels has 32 prototypes. Hence `min_distances` has 5 × 32 = 160 entries per example: one minimum distance per prototype.