In [1]:
import torch
import torch.nn.functional as F
import torch.nn.init as init
import os
import numpy as np
import cv2
import h5py
import random
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn




In [2]:
def load_pairs_from_hdf5(hdf5_file_path, hdf5_folder):
    with h5py.File(hdf5_file_path, 'r') as hdf:
        loaded_pairs = []

        # Function to convert a single relative path back to an absolute path
        def make_absolute(rel_path):
            # Decode if the path is a byte string
            if isinstance(rel_path, bytes):
                rel_path = rel_path.decode('utf-8')
            parts = rel_path.split('/')
            new_parts = []
            for part in parts:
              if part == 'sun3d_extracted' or part =='..':
                continue
              new_parts.append(part)
            corrected_path = '/'.join(new_parts)
            absolute_path = os.path.join(hdf5_folder, corrected_path)
            return absolute_path
            # rel_path = rel_path.replace('../sun3d_extracted', '')
            # return os.path.join(hdf5_folder, rel_path)

        # Function to process paths in pairs
        def process_paths(img_paths_array):
            # Ensure each path in the tuple is absolute
            return tuple(make_absolute(path) for path in img_paths_array)

        # Load pairs
        pairs_group = hdf['pairs']
        for pair_name in pairs_group:
            pair_group = pairs_group[pair_name]
            img_paths_array = pair_group['img_paths'][()]  # This will be a NumPy array
            img_paths = process_paths(img_paths_array)  # Process each path to be absolute
            points1 = torch.tensor(pair_group['points1'][()])
            pos_points2 = torch.tensor(pair_group['pos_points2'][()])
            neg_points2 = torch.tensor(pair_group['neg_points2'][()])
            loaded_pairs.append({
                'img_paths': img_paths,
                'points1': points1,
                'pos_points2': pos_points2,
                'neg_points2': neg_points2
            })

    return loaded_pairs

In [3]:
# Get the current working directory
current_directory = os.getcwd()
output_path = os.path.join(current_directory, os.pardir,  os.pardir,'datasets','sun3d_training')
# output_path = '/content/drive/MyDrive/project_slam/dataset/sun3d_training'
hdf5_file_path = os.path.join(output_path, 'pairs.hdf5')
# hdf5_file_path = '/content/drive/MyDrive/project_slam/dataset/sun3d_training/pairs.hdf5'
loaded_pairs= load_pairs_from_hdf5(hdf5_file_path,output_path)


In [4]:
class BinarizedActivation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(torch.sign(input))
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # grad_input[(input < -1) | (input > 1)] = 0
        grad_input[(input.abs() > 1)] = 0
        return grad_input


class GCNv2(nn.Module):
    def __init__(self):
        super(GCNv2, self).__init__()
        # self.elu = F.elu
        self.elu = torch.nn.ELU(inplace=True)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1)
        self.conv4_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=1)
        self.convF_1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.convF_2 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)
        self.convD_1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.convD_2 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)
        self.binarized_activation = BinarizedActivation()
        self.pixel_shuffle = nn.PixelShuffle(16)
        self.reset_parameters()

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
                    nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        x = self.elu(self.conv1(x))
        x = self.elu(self.conv2(x))
        x = self.elu(self.conv3_1(x))
        x = self.elu(self.conv3_2(x))
        x = self.elu(self.conv4_1(x))
        x = self.elu(self.conv4_2(x))

        # Descriptor
        xF = self.elu(self.convF_1(x))
        desc = self.convF_2(xF)
        dn = torch.norm(desc, p=2, dim=1)
        # desc = BinarizedActivation.apply(desc.div(torch.unsqueeze(dn, 1)))
        desc = desc.div(torch.unsqueeze(dn, 1))
        desc = self.binarized_activation.apply(desc)
        # Detector
        xD = self.elu(self.convD_1(x))
        det = self.convD_2(xD).sigmoid()
        det = self.pixel_shuffle(det)
        return desc, det



In [5]:
class loss_calculator_3:
    def __init__(self, model):
        self.model = model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def remove_padding(self, pts):
        mask = (pts != torch.tensor([float('inf'), float('inf')]))
        non_padding_vectors = pts[mask.all(dim=1)]
        return non_padding_vectors

    def get_images_for_batch(self, img_paths):
      images = []
      # img_paths = batch_loaded_pairs['img_paths']
      for img_path in img_paths:
          img1, img2 = cv2.imread(img_path[0]), cv2.imread(img_path[1])
          gray_img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
          gray_img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
          img_tensor1 = torch.from_numpy(gray_img1).unsqueeze(0).float().to(self.device)
          img_tensor2 = torch.from_numpy(gray_img2).unsqueeze(0).float().to(self.device)
          images.append((img_tensor1, img_tensor2))

      return images

    def normalized_pts(self, batch, height=480, width=640):
      normalized_pts = {}
      scale_w = float(width) / 2.
      scale_h = float(height) / 2.
      pts_cur = batch['points1'].clone()
      pts_tar_pos = batch['pos_points2'].clone()
      pts_tar_neg = batch['neg_points2'].clone()

      normalized_pts['cur'] = torch.zeros_like(pts_cur)
      normalized_pts['tar_pos'] = torch.zeros_like(pts_tar_pos)
      normalized_pts['tar_neg'] = torch.zeros_like(pts_tar_neg)
      normalized_pts['cur'][:,:,0] = pts_cur[:,:,0] / scale_w - 1
      normalized_pts['cur'][:,:,1] = pts_cur[:,:,1] / scale_h - 1
      normalized_pts['tar_pos'][:,:,0] = pts_tar_pos[:,:,0] /scale_w - 1
      normalized_pts['tar_pos'][:,:,1] = pts_tar_pos[:,:,1] / scale_h - 1
      normalized_pts['tar_neg'][:,:,0] = pts_tar_neg[:,:,0] / scale_w - 1
      normalized_pts['tar_neg'][:,:,1] = pts_tar_neg[:,:,1] / scale_h - 1
      # for key in normalized_pts:
      #     if torch.isinf(normalized_pts[key]).any():
      #         normalized_pts[key] = self.remove_padding(normalized_pts[key])
      return normalized_pts

    def get_desc_pairs(self, batch, height=480, width=640):
      batchsize= len(batch.get('img_paths'))
      normal_pts = self.normalized_pts(batch, height, width)
      img_paths = batch.get('img_paths')
      gray_images = self.get_images_for_batch(img_paths) #### get img_paths
      desc_dict = {}
      for key, pts_group in normal_pts.items():
        desc = []
        for i in range(batchsize):
          inp_cur, inp_tar =  gray_images[i]
          inp_cur = inp_cur.unsqueeze(0).float()
          inp_tar = inp_tar.unsqueeze(0).float()
          desc_cur, desc_tar = self.model(inp_cur)[0], self.model(inp_tar)[0]
          desc_cur.requires_grad_(True)
          desc_tar.requires_grad_(True)
          desc_tensor = desc_cur if key == 'cur' else desc_tar
          if torch.isinf(pts_group[i]).any():
              pts = self.remove_padding(pts_group[i]).view(1, 1, -1, 2).float().to(self.device)
          else:
            pts = pts_group[i].view(1, 1, -1, 2).float().to(self.device)
          sample_desc = F.grid_sample(desc_tensor, pts, align_corners=False, mode='nearest').squeeze()
          desc.append(sample_desc)

        desc_dict[key] = desc
      return desc_dict
    
    def hamming_distance(self,tensor1: torch.Tensor, tensor2: torch.Tensor) -> torch.Tensor:
        """Calculate the Hamming distance between two tensors.

        Args:
            tensor1 (torch.Tensor): The first tensor.
            tensor2 (torch.Tensor): The second tensor, must be the same shape as tensor1.

        Returns:
            torch.Tensor: Tensor of Hamming distances.
        """
        if tensor1.dtype != torch.bool:
            tensor1 = tensor1.bool()
        if tensor2.dtype != torch.bool:
            tensor2 = tensor2.bool()
        
        differing_bits = tensor1 ^ tensor2
        distance = differing_bits.sum(dim=-1)
        return distance

    def batch_l_desc_loss(self, desc_batch, margin=1.0):
      ldesc = 0
      for cur_list, pos_list, neg_list in zip(desc_batch['cur'], desc_batch['tar_pos'], desc_batch['tar_neg']):
          # print("Shapes:", cur_list.shape, pos_list.shape, neg_list.shape)
          # pairwise_dist_pos = torch.sum((cur_list - pos_list) ** 2, dim=-1)
          # pairwise_dist_neg = torch.sum((cur_list - neg_list) ** 2, dim=-1)
          
          # Compute Hamming distances
          pairwise_dist_pos = self.hamming_distance(cur_list, pos_list)
          pairwise_dist_neg = self.hamming_distance(cur_list, neg_list)

          sample_loss = torch.sum(torch.max(torch.zeros_like(pairwise_dist_pos), pairwise_dist_pos - pairwise_dist_neg + margin))
          ldesc += sample_loss
      return ldesc.to(self.device)

    def get_det_pairs(self, batch):
      img_paths = batch.get('img_paths')
      batchsize= len(img_paths)
      gray_images = self.get_images_for_batch(img_paths)
      det_dict = {}
      det_cur_list = []
      det_tar_list = []
      for i in range(batchsize):
          inp_cur, inp_tar =  gray_images[i]
          inp_cur = inp_cur.unsqueeze(0).float()
          inp_tar = inp_tar.unsqueeze(0).float()
          det_cur, det_tar = self.model(inp_cur)[1], self.model(inp_tar)[1]
          det_cur.requires_grad_(True)
          det_tar.requires_grad_(True)
          det_cur_list.append(det_cur.squeeze())
          det_tar_list.append(det_tar.squeeze())
      det_cur_list = torch.stack(det_cur_list, dim=0)
      det_tar_list = torch.stack(det_tar_list, dim=0)
      det_dict = {'points1': det_cur_list, 'pos_points2': det_tar_list}
      return det_dict

    def generate_batch_mask(self, coords, height=480, width=640):

      mask = torch.zeros((coords.shape[0], height, width), dtype=torch.uint8)
      for batch_idx in range(coords.shape[0]):
          batch_coords = coords[batch_idx]
          for point_idx, (x, y) in enumerate(batch_coords):
              x = x.int()
              y = y.int()
              if 0 <= x < width and 0 <= y < height:
                  mask[batch_idx, y, x] = 1

      return mask


    def transform_target(self,batch,height=480,width=640):
        img_path_cur = batch.get('img_paths')[0]
        img_path_tar = batch.get('img_paths')[1]
        coord_cur = batch.get('points1')
        coord_tar = batch.get('pos_points2')
        cur_tensor = self.generate_batch_mask(coord_cur,height,width)
        tar_tensor = self.generate_batch_mask(coord_tar,height,width)
        return cur_tensor, tar_tensor

    def l_det_loss(self, o_cur, c_cur, o_tar, c_tar, alpha1=1.0, alpha2=1.0):
        Lce_cur = self.binary_cross_entropy(o_cur, c_cur)
        Lce_tar = self.binary_cross_entropy(o_tar, c_tar)
        Ldet = alpha1 * Lce_cur + alpha2 * Lce_tar
        return Ldet
    def binary_cross_entropy(self, o, c):
        c = c.to(o.dtype)
        bce_loss = F.binary_cross_entropy_with_logits(o, c, reduction='none')
        total_loss = torch.sum(bce_loss)
        return total_loss

    def batch_l_det_loss(self, batch, det_batch, margin=1.0):
        ldet = 0
        # for i in range(len(batch_loaded_pairs)):
        cur_det = det_batch['points1']
        tar_det = det_batch['pos_points2']
        trans_cur = self.transform_target(batch,height=480,width=640)[0].to(self.device)
        trans_tar_pos = self.transform_target(batch,height=480,width=640)[1].to(self.device)
        # trans_cur = batch_loaded_pairs[i]['trans_cur']
        # trans_tar_pos = batch_loaded_pairs[i]['trans_tar_pos']
        ldet += self.l_det_loss(cur_det, trans_cur, tar_det, trans_tar_pos)
        return ldet


    def loss(self, batch_loaded_pairs, batch_size, height=480, width=640, margin=1.0):

      desc_batch = self.get_desc_pairs(batch_loaded_pairs, height, width)
      det_batch = self.get_det_pairs(batch_loaded_pairs)
      # for key, desc_list in desc_batch.items():
      #   for desc_tensor in desc_list:
      #       print(f"Tensor name: {key}, requires_grad: {desc_tensor.requires_grad}")
      loss_desc = self.batch_l_desc_loss(desc_batch, margin)
      # loss_desc = loss_desc.requires_grad_(True)
      loss_det = self.batch_l_det_loss(batch_loaded_pairs, det_batch, margin)
      # loss_det = loss_det.requires_grad_(True)
      return (loss_desc +  0.001*loss_det)/batch_size


In [6]:
model = GCNv2()
loss_calculator = loss_calculator_3(model)

def collate_fn(batch):
    max_len_points = max(len(sample['points1']) for sample in batch)
    for sample in batch:
      n = max_len_points - len(sample['points1'])
      inf_tensor = torch.tensor([[float('inf'), float('inf')]] *n )
      sample['points1'] = torch.cat((sample['points1'],inf_tensor), dim=0)
      sample['pos_points2'] = torch.cat((sample['pos_points2'], inf_tensor), dim=0)
      sample['neg_points2'] = torch.cat((sample['neg_points2'], inf_tensor), dim=0)
    img_path = [sample['img_paths'] for sample in batch]
    points1 = torch.stack([sample['points1'] for sample in batch], dim=0)
    pos_points2 = torch.stack([sample['pos_points2'] for sample in batch], dim=0)
    neg_points2 = torch.stack([sample['neg_points2'] for sample in batch], dim=0)

    return {'img_paths':img_path,'points1': points1, 'pos_points2': pos_points2, 'neg_points2': neg_points2}
class Pairs(Dataset):
    def __init__(self, loaded_pairs):
        self.loaded_pairs = loaded_pairs

    def __len__(self):
        return len(self.loaded_pairs)

    def __getitem__(self, idx):
        pair = self.loaded_pairs[idx]
        # path_tensor = torch.tensor(pair['img_paths'])
        pts_tensor = pair['points1']
        pos_tensor = pair['pos_points2']
        neg_tensor = pair['neg_points2']
        return {'img_paths':pair['img_paths'],'points1':pts_tensor,'pos_points2':pos_tensor,'neg_points2':neg_tensor}

dataset = loaded_pairs

batch_size = 4
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
# for batch in dataloader:
#   l_total = loss_calculator.loss(batch,batch_size)
#   print(l_total)
  # print(count_ones)


In [7]:
def adjust_learning_rate(optimizer: torch.optim.Optimizer, epoch: int) -> None:
    """Halves the learning rate of the optimizer every 40 epochs.

    Args:
    optimizer (torch.optim.Optimizer): The optimizer for which to adjust the learning rate.
    epoch (int): The current epoch number.

    """
    # Every 40 epochs, halve the learning rate
    if epoch % 40 == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.5

In [8]:
# for name, param in model.named_parameters():
#     if 'weight' in name:  # 只对权重进行梯度计算
#         param.requires_grad = True
#     else:
#         param.requires_grad = False

# optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01)  # 仅传递需要梯度更新的参数给优化器
optimizer = optim.Adam( model.parameters(), lr=0.001)
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    for i, batch in enumerate(dataloader):
        # batch_loaded_pairs = batch
        optimizer.zero_grad()
        loss = loss_calculator.loss(batch, batch_size)
        # print(loss)
        loss.backward()
        optimizer.step()
        # optimizer.zero_grad()

        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item()}')

    torch.save(model.state_dict(), f'model_epoch_{epoch+1}.pth')


Epoch [1/10], Step [100/18689], Loss: 689.3131713867188
Epoch [1/10], Step [200/18689], Loss: 689.3078002929688
Epoch [1/10], Step [300/18689], Loss: 689.309814453125
Epoch [1/10], Step [400/18689], Loss: 689.3075561523438
Epoch [1/10], Step [500/18689], Loss: 689.3075561523438
Epoch [1/10], Step [600/18689], Loss: 689.3062744140625
Epoch [1/10], Step [700/18689], Loss: 689.3082885742188
Epoch [1/10], Step [800/18689], Loss: 689.308837890625
Epoch [1/10], Step [900/18689], Loss: 689.30859375
Epoch [1/10], Step [1000/18689], Loss: 689.309326171875
Epoch [1/10], Step [1100/18689], Loss: 689.30859375
Epoch [1/10], Step [1200/18689], Loss: 689.309326171875
Epoch [1/10], Step [1300/18689], Loss: 689.309326171875
Epoch [1/10], Step [1400/18689], Loss: 689.30908203125
Epoch [1/10], Step [1500/18689], Loss: 689.3080444335938
Epoch [1/10], Step [1600/18689], Loss: 689.309814453125
Epoch [1/10], Step [1700/18689], Loss: 689.3070068359375
Epoch [1/10], Step [1800/18689], Loss: 689.3078002929688
E

In [None]:
import h5py
import torch

def save_model_weights_to_hdf5(model, filepath):
    with h5py.File(filepath, 'w') as f:
        for name, param in model.named_parameters():
            param_value = param.data.numpy()
            f.create_dataset(name, data=param_value)

def load_model_weights_from_hdf5(model, filepath):
    
    
    
    
    
    
    
    
    
    
    
    with h5py.File(filepath, 'r') as f:
        for name in f.keys():
            param_value = torch.tensor(f[name][:])
            model._parameters[name] = torch.nn.Parameter(param_value)


save_model_weights_to_hdf5(model, '/content/drive/MyDrive/project_slam/model_weights.h5')

# load_model_weights_from_hdf5(model, 'model_weights.h5')
