In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
import os
import numpy as np

class MultimodalDataset(Dataset):
    def __init__(self, data_dir):
        # create a list of image/depth/lidar paths
        self.data_dir = data_dir
        self.image_paths = []
        self.depth_paths = []
        self.lidar_paths = []
        self.image_paths, self.depth_paths, self.lidar_paths = self.load_data()
    
    def load_data(self):
        # get folders
        data_dir = self.data_dir
        for folder in os.listdir(data_dir):
            trajectory_folder_path = os.path.join(data_dir, folder)
            if not os.path.isdir(trajectory_folder_path):
                continue
            print(f'Processing folder: {trajectory_folder_path}')

            self.image_folder_name = 'image_lcam_fish'
            self.depth_folder_name = 'depth_lcam_fish'
            self.lidar_folder_name = 'lidar'
            
            image_folder_path = os.path.join(trajectory_folder_path, self.image_folder_name)
            depth_folder_path = os.path.join(trajectory_folder_path, self.depth_folder_name)
            lidar_folder_path = os.path.join(trajectory_folder_path, self.lidar_folder_name)

            if not os.path.exists(image_folder_path):
                continue
            if not os.path.exists(depth_folder_path):
                continue
            if not os.path.exists(lidar_folder_path):
                continue

            # get image/depth/lidar paths
            if len(os.listdir(image_folder_path)) != len(os.listdir(depth_folder_path)) \
                or len(os.listdir(image_folder_path)) != len(os.listdir(lidar_folder_path)) \
                or len(os.listdir(depth_folder_path)) != len(os.listdir(lidar_folder_path)):
                print(f'Number of images, depth, and lidar files do not match in folder: {trajectory_folder_path}')
                continue
            self.image_paths += [os.path.join(image_folder_path, path) for path in os.listdir(image_folder_path)]
            self.depth_paths += [os.path.join(depth_folder_path, path) for path in os.listdir(depth_folder_path)]
            self.lidar_paths += [os.path.join(lidar_folder_path, path) for path in os.listdir(lidar_folder_path)]
        print(f'Number of images: {len(self.image_paths)}')
        print(f'Number of depth: {len(self.depth_paths)}')
        print(f'Number of lidar: {len(self.lidar_paths)}')
        return self.image_paths, self.depth_paths, self.lidar_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        # read the image from disk
        image_path = self.image_paths[index]
        image = process_image(image_path)

        # read the depth from disk
        depth_path = self.depth_paths[index]
        depth = process_depth(depth_path)

        # read the lidar from disk
        lidar_path = self.lidar_paths[index]
        lidar = process_lidar(lidar_path)

        return image, depth, lidar

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [3]:
def process_image(image_path):
    image = Image.open(image_path)
    transform_image = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = transform_image(image)
    return image
image_path = "/home/tyz/Desktop/11_777/Data_easy/P000/image_lcam_fish/000000_lcam_fish_image.png"
image = process_image(image_path)

In [4]:
from PIL import Image
import torchvision.transforms as transforms
depth_path = '/home/tyz/Desktop/11_777/Data_easy/P000/depth_lcam_fish/000000_lcam_fish_depth.png'
# print(depth.size)
def process_depth(depth_path):
    depth = Image.open(depth_path)
    transform_depth = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5/2])
    ])
    depth = transform_depth(depth)
    return depth[1:]
depth = process_depth(depth_path)
# print(depth.shape)

In [5]:
import torch
import numpy as np
import open3d as o3d

def process_lidar(filename):
  # Load point cloud data from file
  pcd = o3d.io.read_point_cloud(filename)
  points = np.asarray(pcd.points)

  # Set voxel size
  voxel_size = 0.1

  # Voxelization
  voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd,voxel_size=voxel_size)
  voxels = np.asarray(voxel_grid.get_voxels())
  # print(voxels.shape)

  # Extract voxel features
  features = []
  for voxel in voxels:
      voxel_indices = voxel.grid_index
      if len(voxel_indices) == 0:
          feature = np.zeros(6, dtype=np.float32)
      else:
          voxel_points = points[voxel_indices]
          feature = np.concatenate([np.mean(voxel_points[:, :3], axis=0), np.max(voxel_points[:, :3], axis=0)])
      features.append(feature)
  features = np.stack(features)

  # Normalize features
  features = (features - np.mean(features, axis=0)) / np.std(features, axis=0)

  # Convert features to tensor
  tensor = torch.from_numpy(features)
  tensor = tensor.permute(1, 0).reshape(-1)  # (batch_size=1, num_channels=6, height=num_voxels, width=1)
  padding=3*224*224-tensor.shape[-1]
  tensor = torch.nn.functional.pad(tensor, (0, padding), mode='constant', value=0).reshape((3,224,224)).float()
  # print(tensor)
#   print(tensor.shape)
  return tensor

lidar=process_lidar('/home/tyz/Desktop/11_777/Data_easy/P000/lidar/000000_lcam_front_lidar.ply')

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [6]:
import torch
from transformers import AutoImageProcessor, ViTMAEForPreTraining, ViTMAEConfig
from multimodal.custom_models.CustomViT import CustomViT
from multimodal.custom_models.CustomViTMAE import CustomViTMAE
import torch.utils.data
# call CustomViT
model_name = "facebook/vit-mae-base"
vit_config = ViTMAEConfig.from_pretrained(model_name)
vit_config.output_hidden_states=True
vit_model = CustomViT.from_pretrained(model_name,config=vit_config)

config = ViTMAEConfig.from_pretrained(model_name)
config.output_hidden_states=True

# load from pretrained model and replace the original encoder with custom encoder
custom_model = CustomViTMAE.from_pretrained("facebook/vit-mae-base",config=config)
custom_model.vit = vit_model
custom_model = custom_model.cuda()

Some weights of the model checkpoint at facebook/vit-mae-base were not used when initializing CustomViT: ['decoder.decoder_layers.4.layernorm_before.bias', 'decoder.decoder_layers.7.intermediate.dense.bias', 'decoder.decoder_layers.3.layernorm_before.bias', 'decoder.decoder_layers.1.layernorm_after.weight', 'decoder.decoder_layers.4.attention.attention.query.weight', 'decoder.decoder_layers.5.layernorm_after.weight', 'decoder.decoder_layers.5.attention.attention.value.weight', 'decoder.decoder_layers.3.output.dense.weight', 'decoder.decoder_layers.4.output.dense.bias', 'decoder.decoder_layers.6.attention.attention.query.bias', 'decoder.decoder_layers.3.attention.attention.key.weight', 'decoder.decoder_layers.6.intermediate.dense.weight', 'decoder.decoder_layers.4.output.dense.weight', 'decoder.decoder_layers.4.intermediate.dense.weight', 'decoder.decoder_layers.2.layernorm_after.weight', 'decoder.decoder_layers.2.layernorm_before.bias', 'decoder.decoder_layers.1.intermediate.dense.weig

In [7]:
# load the dataset
# specify the directory containing the images
data_dir = "/home/tyz/Desktop/11_777/Data_easy"
# create the dataloader
myDataset = MultimodalDataset(data_dir)
dataloader = DataLoader(myDataset, batch_size=64, shuffle=False, num_workers=4)

Processing folder: /home/tyz/Desktop/11_777/Data_easy/P003
Processing folder: /home/tyz/Desktop/11_777/Data_easy/P000
Processing folder: /home/tyz/Desktop/11_777/Data_easy/P002
Processing folder: /home/tyz/Desktop/11_777/Data_easy/P006
Processing folder: /home/tyz/Desktop/11_777/Data_easy/P007
Processing folder: /home/tyz/Desktop/11_777/Data_easy/P001
Processing folder: /home/tyz/Desktop/11_777/Data_easy/P005
Processing folder: /home/tyz/Desktop/11_777/Data_easy/P004
Number of images: 9011
Number of depth: 9011
Number of lidar: 9011


In [8]:
num_epochs = 1
import torch.optim as optim
optimizer = optim.Adam(custom_model.parameters(), lr=0.001)
for epoch in range(num_epochs):
    batch_count = 0
    for image_batch, depth_batch, lidar_batch in dataloader:
        # Zero the parameter gradients
        optimizer.zero_grad()
        print(image_batch.shape)
        # Forward pass
        image_batch = image_batch.cuda()
        depth_batch = depth_batch.cuda()
        lidar_batch = lidar_batch.cuda()
        outputs = custom_model(image_batch,depth_batch,lidar_batch)

        # Calculate loss
        loss = outputs.loss

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Print statistics
        # running_loss += loss.item()
        # if epoch % 10 == 9:
        print(f'Epoch {epoch + 1}, Batch {batch_count + 1}: loss {loss / 1:.3f}')
        # running_loss = 0.0
        batch_count += 1

torch.Size([64, 3, 224, 224])
Epoch 1, Batch 1: loss 4.905
torch.Size([64, 3, 224, 224])
Epoch 1, Batch 2: loss 4.804
torch.Size([64, 3, 224, 224])
Epoch 1, Batch 3: loss 5.039
torch.Size([64, 3, 224, 224])
Epoch 1, Batch 4: loss 4.084
torch.Size([64, 3, 224, 224])
Epoch 1, Batch 5: loss 4.190
torch.Size([64, 3, 224, 224])
Epoch 1, Batch 6: loss 3.656
torch.Size([64, 3, 224, 224])
Epoch 1, Batch 7: loss 3.477
torch.Size([64, 3, 224, 224])
Epoch 1, Batch 8: loss 2.995
torch.Size([64, 3, 224, 224])
Epoch 1, Batch 9: loss 2.727
torch.Size([64, 3, 224, 224])
Epoch 1, Batch 10: loss 3.298
torch.Size([64, 3, 224, 224])
Epoch 1, Batch 11: loss 3.445
torch.Size([64, 3, 224, 224])
Epoch 1, Batch 12: loss 4.401
torch.Size([64, 3, 224, 224])
Epoch 1, Batch 13: loss 2.995
torch.Size([64, 3, 224, 224])
Epoch 1, Batch 14: loss 2.860
torch.Size([64, 3, 224, 224])
Epoch 1, Batch 15: loss 2.650
torch.Size([64, 3, 224, 224])
Epoch 1, Batch 16: loss 2.604
torch.Size([64, 3, 224, 224])
Epoch 1, Batch 17: 

KeyboardInterrupt: 

In [10]:
outputs.keys()

odict_keys(['loss', 'logits', 'mask', 'ids_restore', 'hidden_states'])

In [14]:
outputs.hidden_states.shape

torch.Size([64, 50, 768])

: 

In [16]:
print("image device",image_batch.device)
print("model device",next(custom_model.parameters()).device)

image device cuda:0
model device cpu


In [9]:
torch.save(custom_model.state_dict(), 'model.pth')

In [11]:
import torch
from transformers import ViTMAEForPreTraining, ViTMAEConfig
from multimodal.custom_models.CustomViT import CustomViT
from multimodal.custom_models.CustomViTMAE import CustomViTMAE
import torch.utils.data
# call CustomViT
model_name = "facebook/vit-mae-base"
vit_config = ViTMAEConfig.from_pretrained(model_name)
vit_config.output_hidden_states=True
vit_model = CustomViT.from_pretrained(model_name,config=vit_config)

config = ViTMAEConfig.from_pretrained(model_name)
config.output_hidden_states=True

# load from pretrained model and replace the original encoder with custom encoder
custom_model = CustomViTMAE.from_pretrained("facebook/vit-mae-base",config=config)
custom_model.vit = vit_model
custom_model = custom_model.cuda()
custom_model.load_state_dict(torch.load('model.pth'))

Some weights of the model checkpoint at facebook/vit-mae-base were not used when initializing CustomViT: ['decoder.decoder_layers.4.layernorm_before.bias', 'decoder.decoder_layers.7.intermediate.dense.bias', 'decoder.decoder_layers.3.layernorm_before.bias', 'decoder.decoder_layers.1.layernorm_after.weight', 'decoder.decoder_layers.4.attention.attention.query.weight', 'decoder.decoder_layers.5.layernorm_after.weight', 'decoder.decoder_layers.5.attention.attention.value.weight', 'decoder.decoder_layers.3.output.dense.weight', 'decoder.decoder_layers.4.output.dense.bias', 'decoder.decoder_layers.6.attention.attention.query.bias', 'decoder.decoder_layers.3.attention.attention.key.weight', 'decoder.decoder_layers.6.intermediate.dense.weight', 'decoder.decoder_layers.4.output.dense.weight', 'decoder.decoder_layers.4.intermediate.dense.weight', 'decoder.decoder_layers.2.layernorm_after.weight', 'decoder.decoder_layers.2.layernorm_before.bias', 'decoder.decoder_layers.1.intermediate.dense.weig

<All keys matched successfully>