1.optimizer  2.print gradient 3. binarized layer deactivate no need, can be called  3.print weight

In [1]:
import torch
import torch.nn.functional as F
import torch.nn.init as init
import os
import numpy as np
import cv2
import h5py
import random
import types
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

from tensorboardX import SummaryWriter

In [2]:
# %pip install tensorboardX
# %pip install tensorboard

In [3]:
seed = 1
random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7fc3108fb990>

In [4]:
def check_tensor(x, name):
    if torch.isnan(x).any() or torch.isinf(x).any():
        print(f"NaN or Inf found in {name}")

In [5]:
def load_pairs_from_hdf5(hdf5_file_path, hdf5_folder):
    with h5py.File(hdf5_file_path, 'r') as hdf:
        loaded_pairs = []

        # Function to convert a single relative path back to an absolute path
        def make_absolute(rel_path):
            # Decode if the path is a byte string
            if isinstance(rel_path, bytes):
                rel_path = rel_path.decode('utf-8')
            parts = rel_path.split('/')
            new_parts = []
            for part in parts:
                if part == 'sun3d_extracted' or part == '..':
                    continue
                new_parts.append(part)
            corrected_path = '/'.join(new_parts)
            absolute_path = os.path.join(hdf5_folder, corrected_path)
            return absolute_path
            # rel_path = rel_path.replace('../sun3d_extracted', '')
            # return os.path.join(hdf5_folder, rel_path)

        # Function to process paths in pairs
        def process_paths(img_paths_array):
            # Ensure each path in the tuple is absolute
            return tuple(make_absolute(path) for path in img_paths_array)

        # Load pairs
        pairs_group = hdf['pairs']
        for pair_name in pairs_group:
            pair_group = pairs_group[pair_name]
            # This will be a NumPy array
            img_paths_array = pair_group['img_paths'][()]
            # Process each path to be absolute
            img_paths = process_paths(img_paths_array)
            points1 = torch.tensor(pair_group['points1'][()])
            pos_points2 = torch.tensor(pair_group['pos_points2'][()])
            neg_points2 = torch.tensor(pair_group['neg_points2'][()])
            loaded_pairs.append({
                'img_paths': img_paths,
                'points1': points1,
                'pos_points2': pos_points2,
                'neg_points2': neg_points2
            })

    return loaded_pairs

In [6]:
# Get the current working directory
current_directory = os.getcwd()
output_path = os.path.join(current_directory, os.pardir,
                           os.pardir, 'datasets', 'sun3d_training')
# output_path = '/content/drive/MyDrive/project_slam/dataset/sun3d_training'
hdf5_file_path = os.path.join(output_path, 'pairs.hdf5')
# hdf5_file_path = '/content/drive/MyDrive/project_slam/dataset/sun3d_training/pairs.hdf5'
loaded_pairs = load_pairs_from_hdf5(hdf5_file_path, output_path)

In [7]:
# Get the current working directory
# current_directory = os.getcwd()
# # output_path = os.path.join(current_directory, os.pardir, os.pardir,'sun3d_training')
# output_path = '/content/drive/MyDrive/project_slam/dataset/sun3d_training'
# # hdf5_file_path = os.path.join(output_path, 'pairs.hdf5')
# hdf5_file_path = '/content/drive/MyDrive/project_slam/dataset/sun3d_training/pairs.hdf5'
# loaded_pairs= load_pairs_from_hdf5(hdf5_file_path,output_path)

In [8]:
class BinarizedActivation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        out = torch.sign(input)
        # out = torch.where(out == 0.0, torch.tensor(1.0), out)  # Change only 0 to 1
        # Ensures that all zero values become one, and all other values remain as they are.
        out = out + (out == 0.0).float()  # Adds 1 to only where out == 0

        # Create a mask of elements that are neither 1.0 nor -1.0
        # mask = (out != 1.0) & (out != -1.0)
        # # Check if there are any such elements
        # if torch.any(mask):
        #     print(out[mask])  # Print only the elements that do not meet the condition
        return out

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # print("Gradient from next layer: ", grad_output)
        # grad_input[(input < -1) | (input > 1)] = 0
        grad_input[(input.abs() > 1)] = 0
        # print("Modified gradient: ", grad_input)
        # print("BinarizedActivation backward called")
        return grad_input


class GCNv2(nn.Module):
    def __init__(self):
        super(GCNv2, self).__init__()
        # self.elu = F.elu
        # self.elu = torch.nn.ELU(inplace=True)
        self.elu = torch.nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1)
        self.conv4_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=1)
        self.convF_1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.convF_2 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)
        self.convD_1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.convD_2 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)
        # self.binarized_activation = BinarizedActivation()
        self.pixel_shuffle = nn.PixelShuffle(16)
        # Adjust scale factor as needed
        self.upsample = nn.Upsample(
            scale_factor=16, mode='bilinear', align_corners=True)
        self.reset_parameters()

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
                    nn.init.kaiming_normal_(
                        m.weight, mode='fan_in', nonlinearity='relu')
    # def reset_parameters(self):
    #     for m in self.modules():
    #         if isinstance(m, nn.Conv2d):
    #             # Glorot uniform initialization, also known as Xavier uniform initialization
    #             nn.init.xavier_uniform_(m.weight)
    #             if m.bias is not None:
    #                 nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x.requires_grad_(True)
        x = self.elu(self.conv1(x))
        x = self.elu(self.conv2(x))
        x = self.elu(self.conv3_1(x))
        x = self.elu(self.conv3_2(x))
        x = self.elu(self.conv4_1(x))
        x = self.elu(self.conv4_2(x))

        # Detector
        xD = self.elu(self.convD_1(x))
        det = self.convD_2(xD).sigmoid()
        det = self.pixel_shuffle(det)

        # Descriptor
        xF = self.elu(self.convF_1(x))
        desc = self.convF_2(xF)
        desc = self.upsample(desc)
        dn = torch.norm(desc, p=2, dim=1)
        # desc = BinarizedActivation.apply(desc.div(torch.unsqueeze(dn, 1)))
        desc = desc.div(torch.unsqueeze(dn, 1))

        # desc = F.interpolate(desc, size=det.shape[2:], mode='bilinear', align_corners=False)
        # desc = self.binarized_activation.apply(desc)
        # # Check if the norm across each channel is 1 (with a small tolerance for floating point arithmetic)
        # norms = torch.norm(desc, p=2, dim=1)
        # # Check norms and print those which are not approximately 1 (considering numerical stability)
        # not_normalized = (torch.abs(norms - 1) > 1e-6)
        # if not_normalized.any():
        #     print("Descriptors not normalized to 1:", norms[not_normalized])

        # print(desc.requires_grad)
        # print("desc shape")
        # print(desc.shape)
        # print("det shape")
        # print(det.shape)

        return desc, det

In [9]:
class loss_calculator_3:
    def __init__(self, model):
        self.model = model.to(torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'))
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

    # def remove_padding(self, pts):
    #     mask = (pts != torch.tensor([float('inf'), float('inf')]))
    #     non_padding_vectors = pts[mask.all(dim=1)]
    #     return non_padding_vectors

    def get_images_for_batch(self, img_paths):
        images = []
        # img_paths = batch_loaded_pairs['img_paths']
        for img_path in img_paths:
            img1, img2 = cv2.imread(img_path[0]), cv2.imread(img_path[1])
            gray_img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
            gray_img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
            img_tensor1 = torch.from_numpy(
                gray_img1).unsqueeze(0).float().to(self.device)
            img_tensor2 = torch.from_numpy(
                gray_img2).unsqueeze(0).float().to(self.device)
            images.append((img_tensor1, img_tensor2))

        return images

    # def normalized_pts(self, batch, height=480, width=640):
    #     normalized_pts = {}
    #     scale_w = float(width) / 2.
    #     scale_h = float(height) / 2.
    #     pts_cur = batch['points1'].clone()
    #     pts_tar_pos = batch['pos_points2'].clone()
    #     pts_tar_neg = batch['neg_points2'].clone()

    #     normalized_pts['cur'] = torch.zeros_like(pts_cur)
    #     normalized_pts['tar_pos'] = torch.zeros_like(pts_tar_pos)
    #     normalized_pts['tar_neg'] = torch.zeros_like(pts_tar_neg)
    #     normalized_pts['cur'][:, :, 0] = pts_cur[:, :, 0] / scale_w - 1
    #     normalized_pts['cur'][:, :, 1] = pts_cur[:, :, 1] / scale_h - 1
    #     normalized_pts['tar_pos'][:, :, 0] = pts_tar_pos[:, :, 0] / scale_w - 1
    #     normalized_pts['tar_pos'][:, :, 1] = pts_tar_pos[:, :, 1] / scale_h - 1
    #     normalized_pts['tar_neg'][:, :, 0] = pts_tar_neg[:, :, 0] / scale_w - 1
    #     return normalized_pts

    # def get_desc_pairs(self, batch, height=480, width=640):
    #     batchsize = len(batch.get('img_paths'))
    #     normal_pts = self.normalized_pts(batch, height, width)
    #     img_paths = batch.get('img_paths')
    #     gray_images = self.get_images_for_batch(img_paths)  # get img_paths
    #     desc_dict = {}
    #     for key, pts_group in normal_pts.items():
    #         desc = []
    #         for i in range(batchsize):
    #             inp_cur, inp_tar = gray_images[i]
    #             inp_cur = inp_cur.unsqueeze(0).float()
    #             inp_tar = inp_tar.unsqueeze(0).float()
    #             desc_cur, desc_tar = self.model(
    #                 inp_cur)[0], self.model(inp_tar)[0]
    #             desc_cur.requires_grad_(True)
    #             desc_tar.requires_grad_(True)
    #             desc_tensor = desc_cur if key == 'cur' else desc_tar
    #             if torch.isinf(pts_group[i]).any():
    #                 pts = self.remove_padding(pts_group[i]).view(
    #                     1, 1, -1, 2).float().to(self.device)
    #             else:
    #                 pts = pts_group[i].view(
    #                     1, 1, -1, 2).float().to(self.device)
    #             # sample_desc = F.grid_sample(desc_tensor, pts, align_corners=False, mode='nearest').squeeze()
    #             sample_desc = F.grid_sample(
    #                 desc_tensor, pts, align_corners=False, mode='nearest').squeeze()
    #             # Check if sample_desc contains float numbers
    #             # if not torch.all((sample_desc == 1.0) | (sample_desc == -1.0)):
    #             #   raise ValueError("sample_desc contains values other than 1.0 or -1.0.")
    #             # mask = (sample_desc != 1.0) & (sample_desc != -1.0)
    #             # # Check if there are any such elements
    #             # if torch.any(mask):
    #             #     print(sample_desc[mask])  # Print only the elements that do not meet the condition

    #             desc.append(sample_desc)

    # Ensure to gather descriptors at specific keypoints
    # def gather_descriptors(self, desc_map, coords):
    #     if coords.nelement() == 0:
    #         print("no coords")
    #         return torch.empty(0, desc_map.size(1), device=desc_map.device)
    #     # Convert (x, y) coordinates to indices
    #     indices = coords.long().to(desc_map.device)
    #     # Clamp coordinates to the size of the descriptor map
    #     indices[:, 0].clamp_(0, desc_map.size(3) - 1)  # x coordinates
    #     indices[:, 1].clamp_(0, desc_map.size(2) - 1)  # y coordinates
    #     # Gather descriptors
    #     sampled_desc = desc_map[0, :, indices[:, 1], indices[:, 0]].t()
    #     return sampled_desc

    #         desc_dict[key] = desc
    #     return desc_dict
    def get_desc_pairs(self, batch, height=480, width=640):
        img_paths = batch.get('img_paths')
        batchsize = len(img_paths)
        gray_images = self.get_images_for_batch(img_paths)
        desc_dict = {'cur': [], 'tar_pos': [], 'tar_neg': []}

        # Retrieve coordinate arrays from the batch
        coord_cur = batch.get('points1')
        coord_tar_pos = batch.get('pos_points2')
        coord_tar_neg = batch.get('neg_points2')

        mask_cur = self.generate_batch_mask(
            coord_cur, height, width).to(self.device)
        mask_pos = self.generate_batch_mask(
            coord_tar_pos, height, width).to(self.device)
        mask_neg = self.generate_batch_mask(
            coord_tar_neg, height, width).to(self.device)

        # print("mask shapes:",
        #       mask_cur.shape, mask_pos.shape, mask_neg.shape)

        for i in range(batchsize):
            inp_cur, inp_tar = gray_images[i]
            inp_cur = inp_cur.unsqueeze(0).float()
            inp_tar = inp_tar.unsqueeze(0).float()

            # Get descriptor feature maps
            desc_cur, desc_tar = self.model(inp_cur)[0], self.model(inp_tar)[0]
            desc_cur.requires_grad_(True)
            desc_tar.requires_grad_(True)

            # # Sample descriptors for current, positive, and negative keypoints
            # sampled_desc_cur = self.gather_descriptors(
            #     desc_cur, batch['points1'][i])
            # sampled_desc_pos = self.gather_descriptors(
            #     desc_tar, batch['pos_points2'][i])
            # sampled_desc_neg = self.gather_descriptors(
            #     desc_tar, batch['neg_points2'][i])
            # Generate indices from the mask
            nonzero_indices_cur = mask_cur[i].nonzero(as_tuple=False)
            nonzero_indices_pos = mask_pos[i].nonzero(as_tuple=False)
            nonzero_indices_neg = mask_neg[i].nonzero(as_tuple=False)

            # Extract descriptors using the indices
            # We index with (y, x) because the first dimension is height (rows) and the second is width (columns)
            sampled_desc_cur = desc_cur[0, :,
                                        nonzero_indices_cur[:, 0], nonzero_indices_cur[:, 1]].t()
            sampled_desc_pos = desc_tar[0, :,
                                        nonzero_indices_pos[:, 0], nonzero_indices_pos[:, 1]].t()
            sampled_desc_neg = desc_tar[0, :,
                                        nonzero_indices_neg[:, 0], nonzero_indices_neg[:, 1]].t()

            # Append to lists
            desc_dict['cur'].append(sampled_desc_cur)
            desc_dict['tar_pos'].append(sampled_desc_pos)
            desc_dict['tar_neg'].append(sampled_desc_neg)

        desc_dict['cur'] = torch.cat(desc_dict['cur'])
        desc_dict['tar_pos'] = torch.cat(desc_dict['tar_pos'])
        desc_dict['tar_neg'] = torch.cat(desc_dict['tar_neg'])

        # print("desc shapes", desc_dict['cur'].shape,
        #       desc_dict['tar_pos'].shape, desc_dict['tar_neg'].shape)

        return desc_dict

    def batch_l_desc_loss(self, desc_batch, margin=1.0):
        ldesc = 0
        count_non_zero_losses = 0  # Counter for non-zero losses
        for cur_list, pos_list, neg_list in zip(desc_batch['cur'], desc_batch['tar_pos'], desc_batch['tar_neg']):
            # Calculate the norms of descriptor vectors
            # norms_cur = torch.norm(cur_list, p=2, dim=-1)
            # norms_pos = torch.norm(pos_list, p=2, dim=-1)
            # norms_neg = torch.norm(neg_list, p=2, dim=-1)

            # # Check if any norms are greater than 1 and print them
            # if (torch.abs(norms_cur - 1) > 1e-6).any():
            #     print("Current descriptor norms greater than 1:",
            #           norms_cur[norms_cur > 1])
            # if (torch.abs(norms_pos - 1) > 1e-6).any():
            #     print("Positive descriptor norms greater than 1:",
            #           norms_pos[norms_pos > 1])
            # if (torch.abs(norms_neg - 1) > 1e-6).any():
            #     print("Negative descriptor norms greater than 1:",
            #           norms_neg[norms_neg > 1])

            # print("Shapes:", cur_list.shape, pos_list.shape, neg_list.shape)

            pairwise_dist_pos = torch.sum((cur_list - pos_list) ** 2, dim=-1)
            pairwise_dist_neg = torch.sum((cur_list - neg_list) ** 2, dim=-1)
            # Compute Euclidean distance without squaring
            # pairwise_dist_pos = torch.sqrt(
            #     torch.sum((cur_list - pos_list) ** 2, dim=-1))
            # check_tensor(pairwise_dist_pos, "pos desc")
            # pairwise_dist_neg = torch.sqrt(
            #     torch.sum((cur_list - neg_list) ** 2, dim=-1))
            # check_tensor(pairwise_dist_neg, "neg desc")
            # sample_loss = torch.sum(torch.max(torch.zeros_like(
            #     pairwise_dist_pos), pairwise_dist_pos - pairwise_dist_neg + margin))
            # Compute individual losses with margin
            individual_losses = torch.max(torch.zeros_like(
                pairwise_dist_pos), pairwise_dist_pos - pairwise_dist_neg + margin)
            sample_loss = torch.sum(individual_losses)

            # Only add to the loss sum if it's non-zero
            if torch.any(individual_losses > 0):
                # Count how many are non-zero
                count_non_zero_losses += individual_losses.gt(0).sum().item()
                ldesc += sample_loss

            # ldesc += sample_loss
        # Normalize the loss by the number of non-zero losses if there are any
        if count_non_zero_losses > 0:
            # print("total num of challenging pairs:", count_non_zero_losses)
            ldesc = ldesc / count_non_zero_losses
        else:
            ldesc = torch.tensor(0.0).to(self.device)

        return ldesc.to(self.device)

    def get_det_pairs(self, batch):
        img_paths = batch.get('img_paths')
        batchsize = len(img_paths)
        gray_images = self.get_images_for_batch(img_paths)
        det_dict = {}
        det_cur_list = []
        det_tar_list = []
        for i in range(batchsize):
            inp_cur, inp_tar = gray_images[i]
            inp_cur = inp_cur.unsqueeze(0).float()
            inp_tar = inp_tar.unsqueeze(0).float()
            det_cur, det_tar = self.model(inp_cur)[1], self.model(inp_tar)[1]
            det_cur.requires_grad_(True)
            det_tar.requires_grad_(True)
            det_cur_list.append(det_cur.squeeze())
            det_tar_list.append(det_tar.squeeze())
        det_cur_list = torch.stack(det_cur_list, dim=0)
        det_tar_list = torch.stack(det_tar_list, dim=0)
        # print('det_list_shapes',det_cur_list.shape,det_tar_list.shape)
        det_dict = {'points1': det_cur_list, 'pos_points2': det_tar_list}
        return det_dict

    def generate_batch_mask(self, coords, height=480, width=640):

        mask = torch.zeros((coords.shape[0], height, width), dtype=torch.uint8)
        for batch_idx in range(coords.shape[0]):
            batch_coords = coords[batch_idx]
            for point_idx, (x, y) in enumerate(batch_coords):
                x = x.int()
                y = y.int()
                if 0 <= x < width and 0 <= y < height:
                    mask[batch_idx, y, x] = 1

        return mask

    def transform_target(self, batch, height=480, width=640):
        coord_cur = batch.get('points1')
        coord_tar = batch.get('pos_points2')
        # print("det shapes:", coord_cur.shape, coord_tar.shape)
        cur_tensor = self.generate_batch_mask(coord_cur, height, width)
        tar_tensor = self.generate_batch_mask(coord_tar, height, width)
        return cur_tensor, tar_tensor

    def l_det_loss(self, o_cur, c_cur, o_tar, c_tar, alpha1=1.0, alpha2=1.0):
        Lce_cur = self.binary_cross_entropy(o_cur, c_cur) / (480 * 640)
        Lce_tar = self.binary_cross_entropy(o_tar, c_tar) / (480 * 640)
        Ldet = alpha1 * Lce_cur + alpha2 * Lce_tar
        return Ldet

    def binary_cross_entropy(self, o, c, epsilon=1e-8):
        c = c.to(o.dtype)
        positive = c.sum()
        total = c.numel()
        negative = total - positive
        # print("total pairs positive")
        # print(positive)
        if positive.item() == 0:  # Avoid division by zero
            pos_weight = 1.0  # Default to 1 or handle it according to your context
        else:
            pos_weight = negative / positive
        weights = c * pos_weight + (1 - c) * 1.0
        # Ensuring that the target labels are within (epsilon, 1 - epsilon) to avoid log(0)
        c = torch.clamp(c, epsilon, 1 - epsilon)
        bce_loss = F.binary_cross_entropy_with_logits(
            o, c, reduction='none', weight=weights)
        total_loss = torch.sum(bce_loss)
        return total_loss

    def batch_l_det_loss(self, batch, det_batch):
        ldet = 0
        # for i in range(len(batch_loaded_pairs)):
        cur_det = det_batch['points1']
        tar_det = det_batch['pos_points2']
        trans_cur = self.transform_target(batch, height=480, width=640)[
            0].to(self.device)
        trans_tar_pos = self.transform_target(batch, height=480, width=640)[
            1].to(self.device)
        # print('tar_shapes', trans_cur.shape, trans_tar_pos.shape)
        ldet += self.l_det_loss(cur_det, trans_cur, tar_det, trans_tar_pos)
        return ldet

    def calculate_total_pairs(self, batch):
        valid_points_mask = ~torch.isinf(batch['pos_points2']).any(dim=2)
        # Sum over all true values in the mask to get the count of valid points
        total_pairs = valid_points_mask.sum().item()
        return total_pairs

    def loss(self, batch_loaded_pairs, batch_size, height=480, width=640, margin=1.0):
        batchsize = len(batch_loaded_pairs.get('img_paths'))

        det_batch = self.get_det_pairs(batch_loaded_pairs)

        loss_det = self.batch_l_det_loss(batch_loaded_pairs, det_batch)

        desc_batch = self.get_desc_pairs(batch_loaded_pairs, height, width)
        loss_desc = self.batch_l_desc_loss(desc_batch, margin)
        # Calculate the total number of pairs in the batch
        # total_pairs = self.calculate_total_pairs(batch_loaded_pairs)
        # if total_pairs == 0:
        #     total_pairs = 1  # To avoid division by zero
        # final_loss = (100.0*loss_desc / total_pairs) + (loss_det / batchsize)
        final_loss = 100.0*loss_desc + (loss_det / batchsize)
        return final_loss, loss_desc, loss_det/batchsize
        # print('batchsize')
        # print(batchsize)
        # print('total_pairs')
        # print(total_pairs)
        # return final_loss, loss_desc/total_pairs, loss_det/batchsize

In [10]:
model = GCNv2()
loss_calculator = loss_calculator_3(model)


def collate_fn(batch):
    max_len_points = max(len(sample['points1']) for sample in batch)
    for sample in batch:
        n = max_len_points - len(sample['points1'])
        inf_tensor = torch.tensor([[float('inf'), float('inf')]] * n)
        sample['points1'] = torch.cat((sample['points1'], inf_tensor), dim=0)
        sample['pos_points2'] = torch.cat(
            (sample['pos_points2'], inf_tensor), dim=0)
        sample['neg_points2'] = torch.cat(
            (sample['neg_points2'], inf_tensor), dim=0)
    img_path = [sample['img_paths'] for sample in batch]
    points1 = torch.stack([sample['points1'] for sample in batch], dim=0)
    pos_points2 = torch.stack([sample['pos_points2']
                              for sample in batch], dim=0)
    neg_points2 = torch.stack([sample['neg_points2']
                              for sample in batch], dim=0)

    return {'img_paths': img_path, 'points1': points1, 'pos_points2': pos_points2, 'neg_points2': neg_points2}


# class Pairs(Dataset):
#     def __init__(self, loaded_pairs):
#         self.loaded_pairs = loaded_pairs

#     def __len__(self):
#         return len(self.loaded_pairs)

#     def __getitem__(self, idx):
#         pair = self.loaded_pairs[idx]
#         # path_tensor = torch.tensor(pair['img_paths'])
#         pts_tensor = pair['points1']
#         pos_tensor = pair['pos_points2']
#         neg_tensor = pair['neg_points2']
#         return {'img_paths': pair['img_paths'], 'points1': pts_tensor, 'pos_points2': pos_tensor, 'neg_points2': neg_tensor}


dataset = loaded_pairs

batch_size = 4
dataloader = DataLoader(dataset, batch_size=batch_size,
                        shuffle=True, collate_fn=collate_fn)
# for batch in dataloader:
#   l_total = loss_calculator.loss(batch,batch_size)
#   print(l_total)
# print(count_ones)s

In [11]:
def adjust_learning_rate(optimizer: torch.optim.Optimizer, epoch: int) -> None:
    """Halves the learning rate of the optimizer every 40 epochs.

    Args:
    optimizer (torch.optim.Optimizer): The optimizer for which to adjust the learning rate.
    epoch (int): The current epoch number.

    """
    # Every 40 epochs, halve the learning rate
    if epoch % 40 == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.5

In [12]:

# Define parameter groups with different learning rates
detector_parameters = [p for n, p in model.named_parameters() if 'convD' in n]
descriptor_parameters = [
    p for n, p in model.named_parameters() if 'convF' in n]
base_parameters = [p for n, p in model.named_parameters(
) if 'convD' not in n and 'convF' not in n]

# Sometimes, you might want to ensure that every parameter is assigned only once
# Make sure no parameter is left behind, you might want to verify:
assert set(detector_parameters).isdisjoint(set(descriptor_parameters))
assert len(detector_parameters) + len(descriptor_parameters) + \
    len(base_parameters) == len(list(model.parameters()))

learning_rate = 0.0
base_learning_rate = 5e-8
det_learning_rate = 5e-5
desc_learning_rate = 1e-6
# Setting up the Adam optimizer with different learning rates for different groups
optimizer = optim.Adam([
    # Base parameters, use default or base learning rate
    {'params': base_parameters, 'lr': base_learning_rate},
    # Higher learning rate for detector
    {'params': detector_parameters, 'lr': det_learning_rate},
    # Different learning rate for descriptor
    {'params': descriptor_parameters, 'lr': desc_learning_rate}
], betas=(0.95, 0.999), eps=1e-8, weight_decay=5e-5)

In [13]:
# optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01)  # 仅传递需要梯度更新的参数给优化器
# optimizer = optim.Adam(model.parameters(), lr=0.0001)
# learning_rate = 0.00005
# optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(
#     0.95, 0.999), eps=1e-8, weight_decay=5e-5)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)
num_epochs = 10

In [14]:
# class ReduceLRonPlateauPerStep:
#     def __init__(self, optimizer, factor=0.5, patience_steps=1000, min_lr=1e-15, verbose=False):
#         self.optimizer = optimizer
#         self.factor = factor
#         self.patience_steps = patience_steps
#         self.min_lr = min_lr
#         self.verbose = verbose
#         self.best_loss = float('inf')
#         self.num_bad_steps = 0

#     def step(self, loss):
#         if loss < self.best_loss:
#             self.best_loss = loss
#             self.num_bad_steps = 0
#         else:
#             self.num_bad_steps += 1

#         if self.num_bad_steps >= self.patience_steps:
#             for param_group in self.optimizer.param_groups:
#                 old_lr = param_group['lr']
#                 new_lr = max(old_lr * self.factor, self.min_lr)
#                 if old_lr - new_lr > 0:
#                     param_group['lr'] = new_lr
#                     if self.verbose:
#                         print(
#                             f"Reducing learning rate from {old_lr} to {new_lr}.")
#             self.num_bad_steps = 0


# scheduler = ReduceLRonPlateauPerStep(
#     optimizer, factor=0.95, patience_steps=500, verbose=True)

In [15]:
# class TrendBasedLRScheduler:
#     def __init__(self, optimizer, window_size=200, decrease_factor=0.5, min_lr=1e-16, max_lr=1e-2, verbose=False):
#         self.optimizer = optimizer
#         self.window_size = window_size
#         self.decrease_factor = decrease_factor
#         self.min_lr = min_lr
#         self.max_lr = max_lr
#         self.verbose = verbose
#         self.losses = []
#         self.last_slope = None

#     def step(self, loss):
#         self.losses.append(loss)
#         if len(self.losses) > self.window_size:
#             self.losses.pop(0)

#         if len(self.losses) == self.window_size:
#             # Calculate the slope of the linear fit to the losses
#             times = np.arange(self.window_size)
#             coeffs = np.polyfit(times, self.losses, 1)
#             current_slope = coeffs[0]  # Coefficient of x (time)

#             if self.last_slope is not None and current_slope > 0 and self.last_slope < 0:
#                 # Loss trend changed from decreasing to increasing
#                 for param_group in self.optimizer.param_groups:
#                     new_lr = max(param_group['lr'] *
#                                  self.decrease_factor, self.min_lr)
#                     param_group['lr'] = new_lr
#                     if self.verbose:
#                         print(
#                             f"Decreased learning rate to {new_lr} due to increasing loss trend.")

#             self.last_slope = current_slope  # Update the last known slope


# scheduler = TrendBasedLRScheduler(optimizer, verbose=True)

In [16]:
import datetime

tensorboard_path = './log/'
current_time = datetime.datetime.now()

# Format the current time as a string (e.g., YYYYMMDD-HHMMSS)
formatted_time = current_time.strftime('%Y%m%d-%H%M%S')
# Append the formatted time to the tensorboard path
tensorboard_path += 'time-{}'.format(formatted_time)
tensorboard_path += '-num_epochs-{}'.format(num_epochs)
tensorboard_path += '-batch_size-{}'.format(batch_size)
if learning_rate:
    tensorboard_path += '-learning_rate-{}'.format(learning_rate)
elif base_learning_rate:
    tensorboard_path += '-base_learning_rate-{}'.format(base_learning_rate)
    tensorboard_path += '-det_learning_rate-{}'.format(det_learning_rate)
    tensorboard_path += '-desc_learning_rate-{}'.format(desc_learning_rate)
tensorboard_path += '-random_seed-{}'.format(seed)


if not os.path.isdir(tensorboard_path):
    os.makedirs(tensorboard_path)

writer = SummaryWriter(tensorboard_path)

In [17]:
def log_weights(model, writer, step):
    for name, param in model.named_parameters():
        writer.add_histogram(f"weights/{name}", param.data, step)
        if param.grad is not None:
            writer.add_histogram(f"grads/{name}", param.grad.data, step)

In [18]:
# for name, param in model.named_parameters():
#     if 'weight' in name:  # 只对权重进行梯度计算
#         param.requires_grad = True
#     else:
#         param.requires_grad = False


for epoch in range(num_epochs):
    model.train()
    for i, batch in enumerate(dataloader):
        # batch_loaded_pairs = batch
        optimizer.zero_grad()
        # desc= loss_calculator.get_desc_pairs(batch, height=480, width=640)
        loss, loss_desc, loss_det = loss_calculator.loss(batch, batch_size)
        # print(loss)
        # loss.backward()

        try:
            loss.backward()
        except RuntimeError as e:
            print(f"Error during loss_desc.backward(): {e}")

        # for name, param in model.named_parameters():
        #   if param.requires_grad:
        #       try:
        #           print(f"{name} gradient: {param.grad}")
        #       except AttributeError:
        #           print(f"{name} gradient: None")

        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # nan_detected = False
        # for name, param in model.named_parameters():
        #     if param.grad is not None and torch.isnan(param.grad).any():
        #         print(f"NaN gradients detected in {name}")
        #         nan_detected = True

        # norm = torch.sqrt(sum(p.grad.data.norm()**2 for p in model.parameters() if p.grad is not None))
        # print(f"Current gradient norm: {norm.item()}")

        optimizer.step()
        if (epoch * len(dataloader) + i) % 700 == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] = max(param_group['lr'] * 0.3, 1e-50)
        # if epoch * len(dataloader) + i > 200:
        #     scheduler.step(loss_desc.item())
        # optimizer.zero_grad()
        writer.add_scalar(
            'Training/Loss', loss.item(), epoch * len(dataloader) + i)
        writer.add_scalar("Loss/Descriptor", loss_desc.item(),
                          epoch * len(dataloader) + i)
        writer.add_scalar("Loss/Detector", loss_det.item(),
                          epoch * len(dataloader) + i)
        # writer.add_scalar(
        #     'Training/Learning Rate', optimizer.param_groups[0]["lr"], epoch * len(dataloader) + i)
        for index, param_group in enumerate(optimizer.param_groups):
            writer.add_scalar(
                f'Training/Learning Rate/Group_{index}',
                param_group["lr"],
                epoch * len(dataloader) + i)

        if (i + 1) % 100 == 0:
            print(
                f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item()}')
            log_weights(model, writer, epoch * len(dataloader) + i)

    torch.save(model.state_dict(), f'model_epoch_{epoch+1}.pth')
    # Save the entire model
    # torch.save(model, f'model_epoch_{epoch+1}.pt')
writer.close()

Epoch [1/10], Step [100/18689], Loss: 99.946044921875
Epoch [1/10], Step [200/18689], Loss: 98.30538940429688
Epoch [1/10], Step [300/18689], Loss: 96.3205337524414
Epoch [1/10], Step [400/18689], Loss: 91.78264617919922
Epoch [1/10], Step [500/18689], Loss: 98.62808227539062
Epoch [1/10], Step [600/18689], Loss: 91.729736328125
Epoch [1/10], Step [700/18689], Loss: 91.33296203613281
Epoch [1/10], Step [800/18689], Loss: 93.52735137939453
Epoch [1/10], Step [900/18689], Loss: 90.42301177978516
Epoch [1/10], Step [1000/18689], Loss: 95.57615661621094
Epoch [1/10], Step [1100/18689], Loss: 85.71043395996094
Epoch [1/10], Step [1200/18689], Loss: 91.20927429199219
Epoch [1/10], Step [1300/18689], Loss: 99.64363861083984
Epoch [1/10], Step [1400/18689], Loss: 81.6548080444336
Epoch [1/10], Step [1500/18689], Loss: 88.99545288085938
Epoch [1/10], Step [1600/18689], Loss: 91.843505859375
Epoch [1/10], Step [1700/18689], Loss: 96.46175384521484
Epoch [1/10], Step [1800/18689], Loss: 90.030136

KeyboardInterrupt: 

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
# Adjust input dimensions as necessary
example_input = torch.rand(1, 1, 480, 640)
example_input = example_input.to(device)
traced_script_module = torch.jit.script(model, example_input)
traced_script_module.save(f'model_epoch_{epoch+1}_script.pt')

In [None]:
%load_ext tensorboard
%tensorboard --logdir tensorboard_path

In [None]:
# for name, param in model.named_parameters():
#     print(f"{name}: requires_grad = {param.requires_grad}")

In [None]:
# import h5py
# import torch

# def save_model_weights_to_hdf5(model, filepath):
#     with h5py.File(filepath, 'w') as f:
#         for name, param in model.named_parameters():
#             param_value = param.data.numpy()
#             f.create_dataset(name, data=param_value)

# def load_model_weights_from_hdf5(model, filepath):


#     with h5py.File(filepath, 'r') as f:
#         for name in f.keys():
#             param_value = torch.tensor(f[name][:])
#             model._parameters[name] = torch.nn.Parameter(param_value)


# save_model_weights_to_hdf5(model, '/content/drive/MyDrive/project_slam/model_weights.h5')

# load_model_weights_from_hdf5(model, 'model_weights.h5')