In [4]:
from models.ShaSpec import  ShaSpec
import torch
import yaml
from math import ceil
import numpy as np

# Read the "shaspec" section of the model configuration.
# The `with` block closes the file handle as soon as parsing finishes.
with open('configs/model.yaml', mode='r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)["shaspec"]

# Load the configuration
filter_num = config['filter_num']
filter_size = config['filter_size']
sa_div = config['sa_div']

# Parameter for the model
activation = "ReLU"
shared_encoder_type = "concatenated"

# Craft exemplary dataset for one subject (similar to DSADS)
num_modalities = 5
num_classes = 19
num_of_sensor_channels = 9
miss_rate = 0.6  # fraction of modalities to drop (see modalities_to_omit below)

# Number of modalities to drop, rounded up.
# Round the product first: in floating point e.g. 0.7 * 10 == 7.000000000000001,
# which a bare ceil() would turn into 8 instead of 7.
# NOTE(review): confirm ShaSpec derives its "available modalities" count the same way.
modalities_to_omit = ceil(round(miss_rate * num_modalities, 6))
# Pick distinct modality indices at random (no seed, so this varies per run).
missing_indices = np.random.choice(num_modalities, modalities_to_omit, replace=False)

print(missing_indices)

ablate_shared_encoder = False
ablate_missing_modality_features = False

# Dummy input for each modality
# Input shape in B F T C order (per the original note):
# batch 64, F=1, T=125 steps, C=num_of_sensor_channels.
# NOTE: the name `input` shadows the builtin; it is kept because other cells may use it.
input = (64, 1, 125, num_of_sensor_channels)
dummy_inputs = [torch.randn(input) for _ in range(num_modalities)]

shaspec_model = ShaSpec(input,
                        num_modalities,
                        miss_rate,
                        num_classes,
                        activation,
                        shared_encoder_type,
                        ablate_shared_encoder,
                        ablate_missing_modality_features,
                        config
                        )

# Forward pass with the dummy inputs and missing_indices
output = shaspec_model(dummy_inputs, missing_indices)


[1 2 0]
Number of total modalities:  5
Selected miss rate:  0.6
Number of available modalities:  2
Ablate shared encoder:  False
Ablate missing modality features:  False
PREDICTION HERE
tensor([[ 0.0263, -0.0242, -0.0382,  ..., -0.0016, -0.0026,  0.0090],
        [ 0.0284, -0.0247, -0.0423,  ..., -0.0024, -0.0016,  0.0053],
        [ 0.0269, -0.0224, -0.0332,  ...,  0.0059, -0.0031,  0.0073],
        ...,
        [ 0.0251, -0.0201, -0.0385,  ..., -0.0005, -0.0068,  0.0064],
        [ 0.0297, -0.0269, -0.0362,  ...,  0.0016, -0.0080,  0.0115],
        [ 0.0262, -0.0204, -0.0331,  ...,  0.0002, -0.0015,  0.0076]],
       grad_fn=<AddmmBackward0>)
torch.Size([64, 19])
