1.optimizer  2.print gradient 3. binarized layer deactivate no need, can be called  3.print weight

In [1]:
import torch
import torch.nn.functional as F
import torch.nn.init as init
import os
import numpy as np
import cv2
import h5py
import random
import types
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

from tensorboardX import SummaryWriter

In [2]:
# %pip install tensorboardX
# %pip install tensorboard

In [3]:
seed = 1
random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f0b8c307990>

In [4]:
def check_tensor(x, name):
    if torch.isnan(x).any() or torch.isinf(x).any():
        print(f"NaN or Inf found in {name}")

In [5]:
def load_pairs_from_hdf5(hdf5_file_path, hdf5_folder):
    with h5py.File(hdf5_file_path, 'r') as hdf:
        loaded_pairs = []

        # Function to convert a single relative path back to an absolute path
        def make_absolute(rel_path):
            # Decode if the path is a byte string
            if isinstance(rel_path, bytes):
                rel_path = rel_path.decode('utf-8')
            parts = rel_path.split('/')
            new_parts = []
            for part in parts:
                if part == 'sun3d_extracted' or part == '..':
                    continue
                new_parts.append(part)
            corrected_path = '/'.join(new_parts)
            absolute_path = os.path.join(hdf5_folder, corrected_path)
            return absolute_path

        # Function to process paths in pairs

        def process_paths(img_paths_array):
            # Ensure each path in the tuple is absolute
            return tuple(make_absolute(path) for path in img_paths_array)

        # Load pairs
        pairs_group = hdf['pairs']
        for pair_name in pairs_group:
            pair_group = pairs_group[pair_name]
            # This will be a NumPy array
            img_paths_array = pair_group['img_paths'][()]
            # Process each path to be absolute
            img_paths = process_paths(img_paths_array)
            points1 = torch.tensor(pair_group['points1'][()])
            pos_points2 = torch.tensor(pair_group['pos_points2'][()])
            neg_points2 = torch.tensor(pair_group['neg_points2'][()])
            loaded_pairs.append({
                'img_paths': img_paths,
                'points1': points1,
                'pos_points2': pos_points2,
                'neg_points2': neg_points2
            })

    return loaded_pairs

In [6]:
# Get the current working directory
current_directory = os.getcwd()
output_path = os.path.join(current_directory, os.pardir,
                           os.pardir, 'datasets', 'sun3d_training')
# output_path = '/content/drive/MyDrive/project_slam/dataset/sun3d_training'
hdf5_file_path = os.path.join(output_path, 'pairs.hdf5')
# hdf5_file_path = '/content/drive/MyDrive/project_slam/dataset/sun3d_training/pairs.hdf5'
loaded_pairs = load_pairs_from_hdf5(hdf5_file_path, output_path)

In [7]:
class BinarizedActivation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        out = torch.sign(input)
        # Ensures that all zero values become one, and all other values remain as they are.
        out = out + (out == 0.0).float()  # Adds 1 to only where out == 0

        # # Check if there are any such elements
        # if torch.any(mask):
        #     print(out[mask])  # Print only the elements that do not meet the condition
        return out

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[(input.abs() > 1)] = 0
        return grad_input


class GCNv2(nn.Module):
    def __init__(self):
        super(GCNv2, self).__init__()
        self.elu = torch.nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.conv3_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1)
        self.conv4_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=1)
        self.convF_1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.convF_2 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)
        self.convD_1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.convD_2 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)
        self.pixel_shuffle = nn.PixelShuffle(16)
        # Adjust scale factor as needed
        self.upsample = nn.Upsample(
            scale_factor=16, mode='bilinear', align_corners=True)
        self.reset_parameters()

    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
                    nn.init.kaiming_normal_(
                        m.weight, mode='fan_in', nonlinearity='relu')
    # def reset_parameters(self):
    #     for m in self.modules():
    #         if isinstance(m, nn.Conv2d):
    #             # Glorot uniform initialization, also known as Xavier uniform initialization
    #             nn.init.xavier_uniform_(m.weight)
    #             if m.bias is not None:
    #                 nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x.requires_grad_(True)
        x = self.elu(self.conv1(x))
        x = self.elu(self.conv2(x))
        x = self.elu(self.conv3_1(x))
        x = self.elu(self.conv3_2(x))
        x = self.elu(self.conv4_1(x))
        x = self.elu(self.conv4_2(x))

        # Detector
        xD = self.elu(self.convD_1(x))
        det = self.convD_2(xD).sigmoid()
        det = self.pixel_shuffle(det)

        # Descriptor
        xF = self.elu(self.convF_1(x))
        desc = self.convF_2(xF)
        desc = self.upsample(desc)
        dn = torch.norm(desc, p=2, dim=1)
        desc = desc.div(torch.unsqueeze(dn, 1))

        return desc, det

In [8]:
class loss_calculator_3:
    def __init__(self, model):
        self.model = model.to(torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'))
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

    def get_images_for_batch(self, img_paths):
        images = []
        for img_path in img_paths:
            img1, img2 = cv2.imread(img_path[0]), cv2.imread(img_path[1])
            gray_img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
            gray_img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
            img_tensor1 = torch.from_numpy(
                gray_img1).unsqueeze(0).float().to(self.device)
            img_tensor2 = torch.from_numpy(
                gray_img2).unsqueeze(0).float().to(self.device)
            images.append((img_tensor1, img_tensor2))

        return images

    def get_batch_mask(self, coord_cur, coord_tar_pos, coord_tar_neg, height=480, width=640):

        mask_cur = self.generate_batch_mask(
            coord_cur, height, width).to(self.device)
        mask_pos = self.generate_batch_mask(
            coord_tar_pos, height, width).to(self.device)
        mask_neg = self.generate_batch_mask(
            coord_tar_neg, height, width).to(self.device)

        return mask_cur, mask_pos, mask_neg

    def get_loss_paris(self, gray_images, batchsize, mask_cur, mask_pos, mask_neg):

        desc_dict = {'cur': [], 'tar_pos': [], 'tar_neg': []}

        det_dict = {}
        det_cur_list = []
        det_tar_list = []

        for i in range(batchsize):
            inp_cur, inp_tar = gray_images[i]
            inp_cur = inp_cur.unsqueeze(0).float()
            inp_tar = inp_tar.unsqueeze(0).float()

            outputs_cur = self.model(inp_cur)
            outputs_tar = self.model(inp_tar)

            # Unpack the outputs
            desc_cur, det_cur = outputs_cur
            desc_tar, det_tar = outputs_tar

            det_cur.requires_grad_(True)
            det_tar.requires_grad_(True)

            desc_cur.requires_grad_(True)
            desc_tar.requires_grad_(True)

            det_cur_list.append(det_cur.squeeze())
            det_tar_list.append(det_tar.squeeze())

            # Generate indices from the mask
            nonzero_indices_cur = mask_cur[i].nonzero(as_tuple=False)
            nonzero_indices_pos = mask_pos[i].nonzero(as_tuple=False)
            nonzero_indices_neg = mask_neg[i].nonzero(as_tuple=False)

            # Extract descriptors using the indices
            # We index with (y, x) because the first dimension is height (rows) and the second is width (columns)
            sampled_desc_cur = desc_cur[0, :,
                                        nonzero_indices_cur[:, 0], nonzero_indices_cur[:, 1]].t()
            sampled_desc_pos = desc_tar[0, :,
                                        nonzero_indices_pos[:, 0], nonzero_indices_pos[:, 1]].t()
            sampled_desc_neg = desc_tar[0, :,
                                        nonzero_indices_neg[:, 0], nonzero_indices_neg[:, 1]].t()

            # Append to lists
            desc_dict['cur'].append(sampled_desc_cur)
            desc_dict['tar_pos'].append(sampled_desc_pos)
            desc_dict['tar_neg'].append(sampled_desc_neg)

        det_cur_list = torch.stack(det_cur_list, dim=0)
        det_tar_list = torch.stack(det_tar_list, dim=0)

        det_dict = {'points1': det_cur_list, 'pos_points2': det_tar_list}

        desc_dict['cur'] = torch.cat(desc_dict['cur'])
        desc_dict['tar_pos'] = torch.cat(desc_dict['tar_pos'])
        desc_dict['tar_neg'] = torch.cat(desc_dict['tar_neg'])

        return det_dict, desc_dict

    def batch_desc_loss(self, desc_batch, margin=1.0):
        ldesc = 0
        count_non_zero_losses = 0  # Counter for non-zero losses
        for cur_list, pos_list, neg_list in zip(desc_batch['cur'], desc_batch['tar_pos'], desc_batch['tar_neg']):

            pairwise_dist_pos = torch.sum((cur_list - pos_list) ** 2, dim=-1)
            pairwise_dist_neg = torch.sum((cur_list - neg_list) ** 2, dim=-1)

            individual_losses = torch.max(torch.zeros_like(
                pairwise_dist_pos), pairwise_dist_pos - pairwise_dist_neg + margin)
            sample_loss = torch.sum(individual_losses)

            # Only add to the loss sum if it's non-zero
            if torch.any(individual_losses > 0):
                # Count how many are non-zero
                count_non_zero_losses += individual_losses.gt(0).sum().item()
                ldesc += sample_loss

        # Normalize the loss by the number of non-zero losses if there are any
        if count_non_zero_losses > 0:
            # print("total num of challenging pairs:", count_non_zero_losses)
            ldesc = ldesc / count_non_zero_losses
        else:
            ldesc = torch.tensor(0.0).to(self.device)

        return ldesc.to(self.device)

    def generate_batch_mask(self, coords, height=480, width=640):

        mask = torch.zeros((coords.shape[0], height, width), dtype=torch.uint8)
        for batch_idx in range(coords.shape[0]):
            batch_coords = coords[batch_idx]
            for point_idx, (x, y) in enumerate(batch_coords):
                x = x.int()
                y = y.int()
                if 0 <= x < width and 0 <= y < height:
                    mask[batch_idx, y, x] = 1

        return mask

    def batch_det_loss(self, mask_cur, mask_pos, batch):
        cur_det = batch['points1']
        tar_det = batch['pos_points2']

        Lce_cur = self.binary_cross_entropy(
            cur_det, mask_cur) / (480 * 640)
        Lce_tar = self.binary_cross_entropy(
            tar_det, mask_pos) / (480 * 640)
        Ldet = (Lce_cur + Lce_tar)/2
        return Ldet

    def binary_cross_entropy(self, o, c, epsilon=1e-8):
        c = c.to(o.dtype)
        positive = c.sum()
        total = c.numel()
        negative = total - positive

        if positive.item() == 0:  # Avoid division by zero
            pos_weight = 1.0  # Default to 1 or handle it according to your context
        else:
            pos_weight = negative / positive
        weights = c * pos_weight + (1 - c) * 1.0
        # Ensuring that the target labels are within (epsilon, 1 - epsilon) to avoid log(0)
        c = torch.clamp(c, epsilon, 1 - epsilon)
        bce_loss = F.binary_cross_entropy_with_logits(
            o, c, reduction='none', weight=weights)
        total_loss = torch.sum(bce_loss)
        return total_loss

    def loss(self, batch_loaded_pairs, height=480, width=640, margin=1.0):
        img_paths = batch_loaded_pairs.get('img_paths')
        coord_cur = batch_loaded_pairs.get('points1')
        coord_tar_pos = batch_loaded_pairs.get('pos_points2')
        coord_tar_neg = batch_loaded_pairs.get('neg_points2')

        batchsize = len(img_paths)

        gray_images = self.get_images_for_batch(img_paths)

        mask_cur, mask_pos, mask_neg = self.get_batch_mask(
            coord_cur, coord_tar_pos, coord_tar_neg, height, width)

        det_batch, desc_batch = self.get_loss_paris(
            gray_images, batchsize, mask_cur, mask_pos, mask_neg)

        loss_det = self.batch_det_loss(
            mask_cur, mask_pos, det_batch)

        loss_desc = self.batch_desc_loss(desc_batch, margin)

        final_loss = loss_desc + (loss_det / batchsize)

        return final_loss, loss_desc, loss_det/batchsize

In [9]:
model = GCNv2()
loss_calculator = loss_calculator_3(model)


def collate_fn(batch):
    max_len_points = max(len(sample['points1']) for sample in batch)
    for sample in batch:
        n = max_len_points - len(sample['points1'])
        inf_tensor = torch.tensor([[float('inf'), float('inf')]] * n)
        sample['points1'] = torch.cat((sample['points1'], inf_tensor), dim=0)
        sample['pos_points2'] = torch.cat(
            (sample['pos_points2'], inf_tensor), dim=0)
        sample['neg_points2'] = torch.cat(
            (sample['neg_points2'], inf_tensor), dim=0)
    img_path = [sample['img_paths'] for sample in batch]
    points1 = torch.stack([sample['points1'] for sample in batch], dim=0)
    pos_points2 = torch.stack([sample['pos_points2']
                              for sample in batch], dim=0)
    neg_points2 = torch.stack([sample['neg_points2']
                              for sample in batch], dim=0)

    return {'img_paths': img_path, 'points1': points1, 'pos_points2': pos_points2, 'neg_points2': neg_points2}


dataset = loaded_pairs

batch_size = 4
dataloader = DataLoader(dataset, batch_size=batch_size,
                        shuffle=True, collate_fn=collate_fn)

In [10]:
def adjust_learning_rate(optimizer: torch.optim.Optimizer, epoch: int) -> None:
    """Halves the learning rate of the optimizer every 40 epochs.

    Args:
    optimizer (torch.optim.Optimizer): The optimizer for which to adjust the learning rate.
    epoch (int): The current epoch number.

    """
    # Every 8 epochs, halve the learning rate
    if epoch % 15 == 0 and epoch != 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.5

In [11]:
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(
    0.95, 0.999), eps=1e-8, weight_decay=1e-5)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)
num_epochs = 32

In [12]:
import datetime

tensorboard_path = './log/'
current_time = datetime.datetime.now()

# Format the current time as a string (e.g., YYYYMMDD-HHMMSS)
formatted_time = current_time.strftime('%Y%m%d-%H%M%S')
# Append the formatted time to the tensorboard path
tensorboard_path += 'time-{}'.format(formatted_time)
tensorboard_path += '-num_epochs-{}'.format(num_epochs)
tensorboard_path += '-batch_size-{}'.format(batch_size)
if learning_rate:
    tensorboard_path += '-learning_rate-{}'.format(learning_rate)
tensorboard_path += '-random_seed-{}'.format(seed)


if not os.path.isdir(tensorboard_path):
    os.makedirs(tensorboard_path)

writer = SummaryWriter(tensorboard_path)

In [13]:
def log_weights(model, writer, step):
    for name, param in model.named_parameters():
        writer.add_histogram(f"weights/{name}", param.data, step)
        if param.grad is not None:
            writer.add_histogram(f"grads/{name}", param.grad.data, step)

In [14]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
example_input = torch.rand(1, 1, 480, 640)
example_input = example_input.to(device)
for epoch in range(num_epochs):
    model.train()
    adjust_learning_rate(optimizer, epoch)
    for i, batch in enumerate(dataloader):

        optimizer.zero_grad()

        loss, loss_desc, loss_det = loss_calculator.loss(
            batch)

        try:
            loss.backward()
        except RuntimeError as e:
            print(f"Error during loss_desc.backward(): {e}")

        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # norm = torch.sqrt(sum(p.grad.data.norm()**2 for p in model.parameters() if p.grad is not None))
        # print(f"Current gradient norm: {norm.item()}")

        optimizer.step()

        writer.add_scalar(
            'Training/Loss', loss.item(), epoch * len(dataloader) + i)
        writer.add_scalar("Loss/Descriptor", loss_desc.item(),
                          epoch * len(dataloader) + i)
        writer.add_scalar("Loss/Detector", loss_det.item(),
                          epoch * len(dataloader) + i)

        writer.add_scalar(
            'Training/Learning Rate', optimizer.param_groups[0]["lr"], epoch * len(dataloader) + i)

        if (i + 1) % 100 == 0:
            print(
                f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item()}')
            log_weights(model, writer, epoch * len(dataloader) + i)
            # torch.save(model, f'model_step.pt')

    # Save the entire model
    model.eval()
    # torch.save(model, f'model_epoch_{epoch+1}.pt')

    traced_script_module = torch.jit.script(model, example_input)
    traced_script_module.save(f'model_epoch_{epoch+1}_script.pt')


writer.close()

Epoch [1/32], Step [100/18689], Loss: 2.55623722076416
Epoch [1/32], Step [200/18689], Loss: 2.3670716285705566
Epoch [1/32], Step [300/18689], Loss: 2.298097848892212
Epoch [1/32], Step [400/18689], Loss: 2.4873671531677246
Epoch [1/32], Step [500/18689], Loss: 2.224130630493164
Epoch [1/32], Step [600/18689], Loss: 2.223588228225708
Epoch [1/32], Step [700/18689], Loss: 2.3517305850982666
Epoch [1/32], Step [800/18689], Loss: 2.285470962524414
Epoch [1/32], Step [900/18689], Loss: 2.4053571224212646
Epoch [1/32], Step [1000/18689], Loss: 2.2743709087371826
Epoch [1/32], Step [1100/18689], Loss: 2.318765640258789
Epoch [1/32], Step [1200/18689], Loss: 2.3776283264160156
Epoch [1/32], Step [1300/18689], Loss: 2.2468085289001465
Epoch [1/32], Step [1400/18689], Loss: 2.313753604888916
Epoch [1/32], Step [1500/18689], Loss: 2.2902376651763916
Epoch [1/32], Step [1600/18689], Loss: 2.23763370513916
Epoch [1/32], Step [1700/18689], Loss: 2.1915838718414307
Epoch [1/32], Step [1800/18689], 



Epoch [2/32], Step [100/18689], Loss: 2.1157188415527344
Epoch [2/32], Step [200/18689], Loss: 2.107551336288452
Epoch [2/32], Step [300/18689], Loss: 2.122422933578491
Epoch [2/32], Step [400/18689], Loss: 2.06986665725708
Epoch [2/32], Step [500/18689], Loss: 1.989598035812378
Epoch [2/32], Step [600/18689], Loss: 2.1741867065429688
Epoch [2/32], Step [700/18689], Loss: 2.1618239879608154
Epoch [2/32], Step [800/18689], Loss: 2.2015223503112793
Epoch [2/32], Step [900/18689], Loss: 1.9318764209747314
Epoch [2/32], Step [1000/18689], Loss: 2.323119878768921
Epoch [2/32], Step [1100/18689], Loss: 2.0877814292907715
Epoch [2/32], Step [1200/18689], Loss: 2.162752866744995
Epoch [2/32], Step [1300/18689], Loss: 2.045079469680786
Epoch [2/32], Step [1400/18689], Loss: 1.7666923999786377
Epoch [2/32], Step [1500/18689], Loss: 2.1580638885498047
Epoch [2/32], Step [1600/18689], Loss: 2.1191818714141846
Epoch [2/32], Step [1700/18689], Loss: 2.0922136306762695
Epoch [2/32], Step [1800/18689]

In [15]:
def save_checkpoint(model, optimizer, epoch, loss, scheduler=None, filename='checkpoint.pth.tar'):
    """
    Saves the training checkpoint during model training.

    Args:
    model (torch.nn.Module): The model being trained.
    optimizer (torch.optim.Optimizer): The optimizer used for training.
    epoch (int): Current training epoch.
    loss (float): The training loss at the current epoch.
    scheduler (torch.optim.lr_scheduler, optional): The learning rate scheduler.
    filename (str, optional): The path to save the checkpoint file.
    """
    checkpoint = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'loss': loss
    }

    if scheduler:
        checkpoint['scheduler'] = scheduler.state_dict()

    torch.save(checkpoint, filename)
    print(f"Checkpoint saved to {filename}")

In [17]:
save_checkpoint(model, optimizer, num_epochs, loss.item(),
                filename='checkpoint_32epoch.pth.tar')

Checkpoint saved to checkpoint_32epoch.pth.tar


In [None]:
%load_ext tensorboard
%tensorboard --logdir tensorboard_path