# In [None]:
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image


class BaselineTactileEncoder(nn.Module):
    """Baseline slip predictor fusing vision, tactile, force/torque and gripper data.

    Two ResNet-50 backbones (their ``fc`` heads replaced by ``nn.Identity``)
    embed the camera and the tactile images. Per time step, both embeddings
    are concatenated with the flattened force/torque and gripper readings,
    projected by ``fc1``, run through a 6-layer bidirectional LSTM and reduced
    to one value per time step by ``fc2 -> relu -> bn -> fc3``.

    Args:
        vision_resnet_emb_size: Feature width produced by the vision backbone
            (2048 for torchvision's ResNet-50 without its ``fc`` layer).
        tactile_resnet_emb_size: Feature width produced by the tactile backbone.
        ft_emb_size: Multiplier on the 6 force/torque channels; every time
            step of ``ft_data`` must flatten to ``ft_emb_size * 6`` values.
        gripper_emb_size: Multiplier on the 2 gripper channels; every time
            step of ``gripper_data`` must flatten to ``gripper_emb_size * 2``
            values.
        model_emb_size: Output width of the pre-LSTM projection ``fc1``.
        cell_state_size: Stored but otherwise unused; kept so existing callers
            keep working. ``nn.LSTM`` sizes its cell state from ``hidden_size``.
        hidden_state_size: Hidden size of each LSTM direction.
    """

    def __init__(
        self,
        vision_resnet_emb_size: int,
        tactile_resnet_emb_size: int,
        ft_emb_size: int,
        gripper_emb_size: int,
        model_emb_size: int,
        cell_state_size: int,
        hidden_state_size: int,
    ):
        super().__init__()
        self.vision_encoder = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.tactile_encoder = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.vision_resnet_emb_size = vision_resnet_emb_size
        self.tactile_resnet_emb_size = tactile_resnet_emb_size
        self.model_emb_size = model_emb_size
        self.ft_emb_size = ft_emb_size
        self.gripper_emb_size = gripper_emb_size
        self.cell_state_size = cell_state_size
        self.hidden_state_size = hidden_state_size

        # Preprocessing preset that ships with the pretrained weights; it is
        # applied to both image streams in ``forward``.
        self.resnet_weights = ResNet50_Weights.DEFAULT
        self.resnet_transforms = self.resnet_weights.transforms()

        # Drop the classification heads so the backbones emit raw features.
        self.vision_encoder.fc = nn.Identity()
        self.tactile_encoder.fc = nn.Identity()

        self._build_head()

    def _build_head(self) -> None:
        """Create the fusion projection, the LSTM and the per-step output layers."""
        fused_width = (
            self.vision_resnet_emb_size
            + self.tactile_resnet_emb_size
            + self.ft_emb_size * 6
            + self.gripper_emb_size * 2
        )
        reduced_width = self.hidden_state_size // 4

        # Pre-LSTM projection. nn.Linear's keywords are in_features/out_features.
        self.fc1 = nn.Linear(in_features=fused_width,
                             out_features=self.model_emb_size)

        self.lstm = nn.LSTM(input_size=self.model_emb_size,
                            hidden_size=self.hidden_state_size,
                            num_layers=6,
                            bidirectional=True,
                            batch_first=True)

        # A bidirectional LSTM concatenates both directions, so its output is
        # 2 * hidden_state_size wide.
        self.fc2 = nn.Linear(in_features=2 * self.hidden_state_size,
                             out_features=reduced_width)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(reduced_width)
        self.fc3 = nn.Linear(in_features=reduced_width,
                             out_features=1)

    def forward(
        self,
        vision_img: torch.Tensor,
        tactile_img: torch.Tensor,
        ft_data: torch.Tensor,
        gripper_data: torch.Tensor,
    ) -> torch.Tensor:
        """Predict a slip probability for every time step of one sequence.

        Args:
            vision_img: Camera frames, shape ``(N, 3, H, W)``.
            tactile_img: Tactile frames, shape ``(N, 3, H, W)``.
            ft_data: Force/torque readings, shape ``(N, ...)`` flattening to
                ``ft_emb_size * 6`` values per step.
            gripper_data: Gripper readings, shape ``(N, ...)`` flattening to
                ``gripper_emb_size * 2`` values per step.

        Returns:
            Tensor of shape ``(N, 1)`` with values in ``[0, 1]``.

        Raises:
            ValueError: If the modalities do not share the same number of
                time steps ``N`` (they are concatenated step by step).
        """
        preprocd_vision = self.resnet_transforms(vision_img)
        preprocd_tactile = self.resnet_transforms(tactile_img)

        # Keep the leading (time) dimension and flatten everything else.
        vision_emb = torch.flatten(self.vision_encoder(preprocd_vision), start_dim=1)
        tactile_emb = torch.flatten(self.tactile_encoder(preprocd_tactile), start_dim=1)
        ft_emb = torch.flatten(ft_data, start_dim=1)
        gripper_emb = torch.flatten(gripper_data, start_dim=1)

        steps = {
            "vision_img": vision_emb.shape[0],
            "tactile_img": tactile_emb.shape[0],
            "ft_data": ft_emb.shape[0],
            "gripper_data": gripper_emb.shape[0],
        }
        if len(set(steps.values())) != 1:
            raise ValueError(
                f"all modalities must have the same number of time steps, got {steps}"
            )

        # (N, vision + tactile + ft_emb_size * 6 + gripper_emb_size * 2)
        combined_emb = torch.cat((vision_emb, tactile_emb, ft_emb, gripper_emb), dim=1)

        # NOTE(review): the original forward stopped after the concatenation
        # and returned None. The remainder applies the layers in the order
        # __init__ declares them and squashes the result with a sigmoid to
        # match the documented "slip probability" -- confirm this is intended.
        projected = self.fc1(combined_emb)                   # (N, model_emb_size)
        lstm_out, _ = self.lstm(projected.unsqueeze(0))      # (1, N, 2 * hidden)
        per_step = lstm_out.squeeze(0)                       # (N, 2 * hidden)
        per_step = self.bn(self.relu(self.fc2(per_step)))    # (N, hidden // 4)
        return torch.sigmoid(self.fc3(per_step))             # (N, 1)


