In [6]:
from models.ShaSpec import  ShaSpec
import torch
import yaml
from math import ceil
import numpy as np

# Load the "shaspec" section of the model configuration. The context manager
# guarantees the file handle is closed, even if YAML parsing raises.
# NOTE(review): FullLoader is kept for compatibility with the existing config
# file; prefer yaml.safe_load if the file never relies on extra YAML tags.
with open('configs/model.yaml', mode='r', encoding='utf-8') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)["shaspec"]

# Individual hyper-parameters pulled out of the loaded config section
filter_num = config['filter_num']
filter_size = config['filter_size']
sa_div = config['sa_div']

# Parameter for the model
activation = "ReLU"
shared_encoder_type = "concatenated"

# Craft exemplary dataset for one subject (similar to DSADS)
num_modalities = 5
num_classes = 19
num_of_sensor_channels = 9
miss_rate = 0.6

modalities_to_omit = ceil(miss_rate * num_modalities)
missing_indices = np.random.choice(num_modalities, modalities_to_omit, replace=False)

print(missing_indices)

ablate_shared_encoder = False
ablate_missing_modality_features = False

# One random dummy tensor per modality, with shape laid out as B F T C.
# NOTE(review): `input` shadows the Python builtin; the name is kept because
# later notebook cells may still read it.
input = (64, 1, 125, num_of_sensor_channels)
dummy_inputs = [torch.randn(*input) for _modality in range(num_modalities)]

shaspec_model = ShaSpec(
    input,
    num_modalities,
    miss_rate,
    num_classes,
    activation,
    shared_encoder_type,
    ablate_shared_encoder,
    ablate_missing_modality_features,
    config,
)

# Single forward pass: every modality tensor is supplied, and the indices in
# `missing_indices` are passed alongside (presumably so the model treats those
# modalities as absent -- confirm against ShaSpec.forward).
output = shaspec_model(dummy_inputs, missing_indices)


[0 1 2]
Number of total modalities:  5
Selected miss rate:  0.6
Number of available modalities:  2
Ablate shared encoder:  False
Ablate missing modality features:  False
PREDICTION HERE
tensor([[-0.0396, -0.0113,  0.0155,  ...,  0.0159,  0.0610, -0.0075],
        [-0.0279, -0.0027,  0.0063,  ...,  0.0104,  0.0623, -0.0192],
        [-0.0312, -0.0081,  0.0055,  ...,  0.0152,  0.0502, -0.0109],
        ...,
        [-0.0325, -0.0057,  0.0090,  ...,  0.0153,  0.0596, -0.0051],
        [-0.0351, -0.0102,  0.0032,  ...,  0.0061,  0.0551, -0.0162],
        [-0.0373, -0.0103,  0.0032,  ...,  0.0099,  0.0682, -0.0149]],
       grad_fn=<AddmmBackward0>)
torch.Size([64, 19])
