In [None]:
# Environment sanity check: the interpreter running this kernel (sys.executable)
# is printed first, then the first `python3` found on the shell's PATH.
# If the two paths differ, the kernel is probably not using the intended
# conda environment — confirm before installing packages or trusting imports.
print("These two below should match")
import sys
print(sys.executable)

# IPython shell escape (not plain Python): prints the PATH-resolved python3.
!which python3
print("These two above should match")

These two below should match
/ocean/projects/cis260031p/shared/temu_conda/bin/python3
/ocean/projects/cis260031p/shared/temu_conda/bin/python3


These two above should match


In [1]:
from PIL.Image import Image
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights

In [2]:
class BaselineTactileEncoder(nn.Module):
    """Baseline slip predictor fusing RGB, tactile, force/torque and gripper signals.

    The input window is split into ``T`` time steps. Each step owns
    ``rgb_freq`` RGB frames, ``tactile_freq`` tactile frames, ``ft_freq``
    force/torque readings and ``gripper_freq`` gripper readings. RGB and
    tactile frames are embedded by two separate pretrained ResNet-50 backbones
    whose classification head is replaced by ``nn.Identity``. The per-step
    features of all modalities are concatenated (``_pre_lstm_dim`` wide),
    projected by ``fc1`` to ``model_emb_size`` and run through a stacked LSTM.
    The final hidden and cell states of the top LSTM layer are concatenated and
    mapped by ``fc2 -> ReLU -> BatchNorm1d -> fc3 -> sigmoid`` to a slip
    probability.

    Args:
        rgb_freq: RGB frames per time step (f1).
        tactile_freq: tactile frames per time step (f4).
        ft_freq: force/torque readings per time step (f2).
        vision_resnet_emb_size: feature width produced by the RGB backbone
            (2048 for torchvision's ResNet-50 with ``fc`` removed).
        tactile_resnet_emb_size: feature width produced by the tactile backbone.
        ft_emb_size: features per force/torque reading (e.g. 6 for the
            ``(f2*T, 6)`` layout).
        gripper_emb_size: features per gripper reading.
        model_emb_size: width of the ``fc1`` projection fed to the LSTM.
        cell_state_size: unused; ``nn.LSTM`` cell states always have
            ``hidden_state_size`` features. Kept for backward compatibility.
        hidden_state_size: LSTM hidden (and cell) state size.
        gripper_freq: gripper readings per time step (f3). Defaults to 1.
        num_lstm_layers: number of stacked LSTM layers. Defaults to 6.

    Note:
        Constructing the model loads pretrained ResNet-50 weights through
        torchvision (downloaded on first use).
    """

    def __init__(
        self,
        rgb_freq: int,
        tactile_freq: int,
        ft_freq: int,
        vision_resnet_emb_size: int,
        tactile_resnet_emb_size: int,
        ft_emb_size: int,
        gripper_emb_size: int,
        model_emb_size: int,
        cell_state_size: int,
        hidden_state_size: int,
        gripper_freq: int = 1,
        num_lstm_layers: int = 6):

        super().__init__()
        ## Vision / tactile backbones: independent weights, classifier -> Identity
        self.vision_encoder = self._pretrained_backbone()
        self.tactile_encoder = self._pretrained_backbone()
        self.vision_resnet_emb_size = vision_resnet_emb_size
        self.tactile_resnet_emb_size = tactile_resnet_emb_size
        self.model_emb_size = model_emb_size
        self.ft_emb_size = ft_emb_size
        self.gripper_emb_size = gripper_emb_size
        self.hidden_state_size = hidden_state_size
        # cell_state_size is intentionally ignored: nn.LSTM ties the cell size to hidden_size.

        self._rgb_freq = rgb_freq
        self._tactile_freq = tactile_freq
        self._ft_freq = ft_freq
        # Was read below and in forward() but never assigned (AttributeError).
        self._gripper_freq = gripper_freq

        # Preset inference transforms of the pretrained weights (resize, center
        # crop, rescale to [0, 1], ImageNet mean/std normalization). forward()
        # applies them to batched (N, 3, H, W) tensors.
        self.resnet_weights = self._pretrained_weights()
        self.resnet_transforms = self.resnet_weights.transforms()

        # Width of the concatenated per-time-step feature vector fed to fc1.
        self._pre_lstm_dim = (self._rgb_freq     * self.vision_resnet_emb_size +
                              self._tactile_freq * self.tactile_resnet_emb_size +
                              self._ft_freq      * self.ft_emb_size +
                              self._gripper_freq * self.gripper_emb_size)

        ## pre-lstm projection: (B, T, _pre_lstm_dim) -> (B, T, model_emb_size)
        self.fc1 = nn.Linear(in_features=self._pre_lstm_dim,
                             out_features=self.model_emb_size)

        # The LSTM consumes fc1's output directly, so its input width is model_emb_size.
        self.lstm = nn.LSTM(input_size=self.model_emb_size,
                            hidden_size=self.hidden_state_size,
                            num_layers=num_lstm_layers,
                            bidirectional=False,
                            batch_first=True)

        ## Head: fc2 sees the concatenated final hidden + cell state (2 * hidden).
        self.fc2 = nn.Linear(in_features=2 * self.hidden_state_size,
                             out_features=self.hidden_state_size // 4)
        self.relu = nn.ReLU()
        self.bn_fc2 = nn.BatchNorm1d(self.hidden_state_size // 4)
        self.fc3 = nn.Linear(in_features=self.hidden_state_size // 4,
                             out_features=1)

    @staticmethod
    def _pretrained_weights():
        """Return the torchvision weight enum shared by both backbones."""
        return ResNet50_Weights.DEFAULT

    @classmethod
    def _pretrained_backbone(cls) -> nn.Module:
        """Build a pretrained ResNet-50 whose final ``fc`` is replaced by Identity."""
        backbone = resnet50(weights=cls._pretrained_weights())
        backbone.fc = nn.Identity()  # type: ignore[assignment]
        return backbone

    @staticmethod
    def _check_input(name: str, x: torch.Tensor, expected: tuple) -> None:
        """Assert that ``x.shape`` equals ``expected``; ``-1`` entries match any size."""
        matches = x.dim() == len(expected) and all(
            want == -1 or want == got for want, got in zip(expected, x.shape))
        assert matches, f"Expected {name} shape {expected} (-1 = any), got {tuple(x.shape)}"

    def _embed_frames(self, encoder: nn.Module, frames: torch.Tensor, steps: int) -> torch.Tensor:
        """Embed (B, freq*T, 3, H, W) frames with ``encoder`` and regroup to (B, T, freq*emb)."""
        batch_size = frames.shape[0]
        flat_frames = frames.flatten(0, 1)  # (B*freq*T, 3, H, W): backbones take 4-D input
        features = encoder(self.resnet_transforms(flat_frames))  # (B*freq*T, emb)
        # Frames are chronological, so each run of `freq` consecutive rows is one step.
        return features.reshape(batch_size, steps, -1)

    def forward(self,
                vision_img: torch.Tensor,
                tactile_img: torch.Tensor,
                ft_data: torch.Tensor,
                gripper_force: torch.Tensor,
                ) -> torch.Tensor:
        """Predict the slip probability for step T+1 from a window of T steps.

        All inputs are either batched with a leading dim ``B`` or all unbatched.
        Along the time axis, items must be chronological so step ``t`` owns
        items ``t*freq .. (t+1)*freq - 1``.

        Args:
            vision_img: (B, f1*T, 3, H, W) or (f1*T, 3, H, W) RGB frames, in the
                value range expected by the ResNet-50 preset transforms.
            tactile_img: (B, f4*T, 3, H', W') or (f4*T, 3, H', W') tactile images.
            ft_data: (B, f2*T, ft_emb_size) or (f2*T, ft_emb_size) F/T readings.
            gripper_force: (B, f3*T, gripper_emb_size) or (f3*T, gripper_emb_size).

        Returns:
            (B, 1) slip probabilities in [0, 1], or shape (1,) for unbatched input.

        Raises:
            AssertionError: if input shapes disagree with the configured
                frequencies / feature sizes or with each other.

        Note:
            In training mode ``bn_fc2`` needs B > 1; call ``.eval()`` before
            passing a single (unbatched) window.
        """
        # Accept a single unbatched window by adding a batch dim of 1.
        batched = vision_img.dim() == 5
        if not batched:
            vision_img, tactile_img, ft_data, gripper_force = (
                x.unsqueeze(0) for x in (vision_img, tactile_img, ft_data, gripper_force))

        batch_size, num_rgb_frames = vision_img.shape[0], vision_img.shape[1]
        assert num_rgb_frames > 0 and num_rgb_frames % self._rgb_freq == 0, (
            f"vision_img time axis ({num_rgb_frames}) must be a positive multiple "
            f"of rgb_freq={self._rgb_freq}")
        steps = num_rgb_frames // self._rgb_freq  # T

        self._check_input("vision_img", vision_img,
                          (batch_size, self._rgb_freq * steps, 3, -1, -1))
        self._check_input("tactile_img", tactile_img,
                          (batch_size, self._tactile_freq * steps, 3, -1, -1))
        self._check_input("ft_data", ft_data,
                          (batch_size, self._ft_freq * steps, self.ft_emb_size))
        self._check_input("gripper_force", gripper_force,
                          (batch_size, self._gripper_freq * steps, self.gripper_emb_size))

        vision_emb = self._embed_frames(self.vision_encoder, vision_img, steps)     # (B, T, f1*vision_resnet_emb_size)
        tactile_emb = self._embed_frames(self.tactile_encoder, tactile_img, steps)  # (B, T, f4*tactile_resnet_emb_size)
        # Raw sensor readings are used as-is: grouped per step, cast to the feature dtype.
        ft_emb = ft_data.reshape(batch_size, steps, -1).to(vision_emb.dtype)             # (B, T, f2*ft_emb_size)
        gripper_emb = gripper_force.reshape(batch_size, steps, -1).to(vision_emb.dtype)  # (B, T, f3*gripper_emb_size)

        ## concatenate all the embeddings per time step
        combined_emb = torch.cat((vision_emb, tactile_emb, ft_emb, gripper_emb), dim=-1)  # (B, T, _pre_lstm_dim)
        # Catches backbones whose feature width differs from the configured *_resnet_emb_size.
        assert combined_emb.shape[-1] == self._pre_lstm_dim, (
            f"Expected combined_emb feature size {self._pre_lstm_dim}, got {combined_emb.shape[-1]}")

        fc1_out = self.fc1(combined_emb)  # (B, T, model_emb_size)
        # nn.LSTM returns (output, (h_n, c_n)); h_n and c_n are
        # (num_layers, B, hidden_state_size) even with batch_first=True.
        _, (h_n, c_n) = self.lstm(fc1_out)
        lstm_out = torch.cat((h_n[-1], c_n[-1]), dim=-1)  # top layer: (B, 2*hidden_state_size)
        fc2_out = self.bn_fc2(self.relu(self.fc2(lstm_out)))  # (B, hidden_state_size//4)

        slip_prob = torch.sigmoid(self.fc3(fc2_out))  # (B, 1)
        return slip_prob if batched else slip_prob.squeeze(0)

    # TODO ranais, For Feb 14th
    # 1. Need to add relu, dropout, bn, and complete forward
    # 2. Need to figure out what all the shapes will be and try to overfit with a limited set of images
    # 3. Does the current architecture make sense? is the flattening correct? Do we need to do some sort of pooling instead of flattening? Do we need SOO many fc layers?
        