# Library

In [7]:
# If any of the imports below fail, install the missing package first, then re-run this cell.
import torch
import torchvision
import timm
import torch.nn as nn

# Model

In [15]:
class ImageAndClinicalModel(nn.Module):
    """Two-branch classifier that sums image logits and clinical-feature logits.

    The image branch is a timm model created with ``num_classes`` outputs; the
    clinical branch is an MLP ending in a ``Linear`` layer with the same number
    of outputs. ``forward`` adds the two outputs element-wise.

    All constructor arguments are optional; the defaults reproduce the
    original hard-coded configuration.

    Args:
        clinical_dimension: Number of input features of the clinical branch.
        image_model_name: Architecture name passed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model`` (``True`` asks timm for
            pretrained weights).
        in_chans: Number of input image channels, forwarded to timm.
        num_classes: Output width of both branches.
        clinical_hidden_dims: Sequence of hidden-layer widths for the clinical
            MLP. ``None`` selects ``DEFAULT_CLINICAL_HIDDEN_DIMS``.
        clinical_dropout_rates: Sequence of dropout probabilities, one per
            hidden layer. ``None`` selects ``DEFAULT_CLINICAL_DROPOUT_RATES``.

    Raises:
        ValueError: If ``clinical_hidden_dims`` and ``clinical_dropout_rates``
            have different lengths.
    """

    # Default clinical-MLP layout. Values are kept verbatim from the original
    # hard-coded network (presumably the result of a hyperparameter search --
    # TODO confirm provenance).
    DEFAULT_CLINICAL_HIDDEN_DIMS = (6836, 5657, 468)
    DEFAULT_CLINICAL_DROPOUT_RATES = (
        0.012310474179685116,
        0.06788695883005857,
        0.6673239199444652,
    )

    def __init__(
        self,
        clinical_dimension: int = 16,
        image_model_name: str = 'tf_efficientnetv2_s_in21ft1k',
        pretrained: bool = True,
        in_chans: int = 2,
        num_classes: int = 2,
        clinical_hidden_dims=None,
        clinical_dropout_rates=None,
    ) -> None:
        super().__init__()

        if clinical_hidden_dims is None:
            clinical_hidden_dims = self.DEFAULT_CLINICAL_HIDDEN_DIMS
        if clinical_dropout_rates is None:
            clinical_dropout_rates = self.DEFAULT_CLINICAL_DROPOUT_RATES

        self.clinical_dimension = clinical_dimension

        # The image branch is created before the clinical branch, as in the
        # original, so parameter-initialisation order under a fixed seed is
        # unchanged.
        self.image_model = timm.create_model(
            image_model_name,
            pretrained=pretrained,
            in_chans=in_chans,
            num_classes=num_classes,
        )

        self.clinical_model = self._build_clinical_head(
            clinical_dimension,
            num_classes,
            clinical_hidden_dims,
            clinical_dropout_rates,
        )

    @staticmethod
    def _build_clinical_head(in_features, out_features, hidden_dims, dropout_rates):
        """Build the clinical MLP: (Linear, BatchNorm1d, ReLU, Dropout) per hidden layer, then a final Linear.

        Layers are appended in exactly that order, so with the default layout
        the ``nn.Sequential`` indices (0..12) match the original hand-written
        network and existing ``state_dict`` keys remain valid.
        """
        hidden_dims = tuple(hidden_dims)
        dropout_rates = tuple(dropout_rates)
        if len(hidden_dims) != len(dropout_rates):
            # zip() would silently drop the surplus entries; fail loudly instead.
            raise ValueError(
                "clinical_hidden_dims and clinical_dropout_rates must have the same length, "
                f"got {len(hidden_dims)} and {len(dropout_rates)}"
            )

        layers = []
        previous_width = in_features
        for width, drop_probability in zip(hidden_dims, dropout_rates):
            layers.append(nn.Linear(previous_width, width))
            layers.append(nn.BatchNorm1d(width))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(drop_probability))
            previous_width = width
        layers.append(nn.Linear(previous_width, out_features))

        return nn.Sequential(*layers)

    def forward(self, image_feature, clinical_feature):
        """Run both branches and return the element-wise sum of their outputs.

        Args:
            image_feature: Input for the image branch (passed to the timm model).
            clinical_feature: Input for the clinical MLP.

        Returns:
            ``image_model(image_feature) + clinical_model(clinical_feature)``.
        """
        image_output = self.image_model(image_feature)

        clinical_output = self.clinical_model(clinical_feature)

        # Late fusion: the two branches' outputs are simply added.
        final_output = image_output + clinical_output

        return final_output

In [16]:
# Build the combined image + clinical model.
# NOTE(review): the constructor asks timm for pretrained weights, so the first run
# presumably needs network access or a populated weight cache -- verify in your environment.
model = ImageAndClinicalModel()

In [17]:
# Smoke-test the forward pass with random inputs.
# Image tensor layout: (batch size, channel, height, width).
image_feature = torch.randn((2, 2, 500, 500))
# Clinical tensor: one row of 16 values per sample.
clinical_feature = torch.randn((2, 16))

# NOTE(review): no .eval() call is visible before this point, so the model is presumably
# still in training mode (Dropout active, BatchNorm using batch statistics) -- confirm.
# Last expression of the cell, so the notebook displays the returned tensor.
model(image_feature=image_feature, clinical_feature=clinical_feature)

tensor([[3.3108, 0.9135],
        [7.4740, 4.4265]], grad_fn=<AddBackward0>)