In [None]:
from numpy.lib.function_base import average
from torch.optim import optimizer
import valeodata
import laspy
import logging
import os
import numpy as np
import math
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import resnet_no_pool
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython import display
%matplotlib inline

In [None]:
def save_image(filename, im):
    """Render an image array with matplotlib and write it to ``filename``.

    The image fills the whole figure (no frame, no ticks).

    Args:
        filename: destination path handed to ``Figure.savefig``.
        im: array accepted by ``Axes.imshow``; ``im.shape[0]`` is the number of
            rows (height) and ``im.shape[1]`` the number of columns (width).
    """
    dpi = 100
    n_rows, n_cols = im.shape[:2]

    # figsize is (width, height) in inches: columns first, then rows.
    fig = plt.figure(figsize=(n_cols // dpi + 1, n_rows // dpi + 1), dpi=dpi)
    try:
        ax = fig.add_axes([0., 0., 1., 1.], frameon=False, xticks=[], yticks=[])
        ax.imshow(im, interpolation='none')
        fig.savefig(filename, dpi=dpi)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

In [None]:
class DFC2018_Im_Elevation_Dataset():
    """Random co-located (image patch, elevation patch) pairs from one tile.

    The constructor loads one image and one LAS point cloud, resamples the
    image, and rasterises the point cloud into a 2-D elevation map. Both are
    stored as ``(x, y[, channel])`` tensors and cropped with the same indices.

    Attributes:
        iterations: value reported by ``len()``.
        imsize: side length, in raster cells, of the square patches.
        im: float tensor ``(X, Y, C)`` with values scaled by 1/255.
        points: float tensor ``(X, Y)``; observed cells hold an elevation
            normalised to [0, 1], cells without any point hold -1.
        target_size: shape of ``points``.
    """

    def __init__(self, iterations, target_scale=0.5, imsize=64,
                 rootdir="/datasets_local/DFC2018/test_mini_dataset"):
        """Load the image and the point cloud found in ``rootdir``.

        Args:
            iterations: number of samples reported by ``len()``; each access
                draws a fresh random patch regardless of the index.
            target_scale: resolution of the common raster, expressed in the
                same unit as ``image_scale``/``points_scale`` below
                (presumably metres per cell -- TODO confirm).
            imsize: side length of the square patches.
            rootdir: directory containing the ``.las`` and ``.tif`` files.
        """
        image_scale = 0.05
        points_scale = 1.0
        self.iterations = iterations
        self.imsize = imsize

        filename_points = os.path.join(rootdir, "lidar_C3_271460_3289689.las")
        filename_image = os.path.join(rootdir, "UH_NAD83_271460_3289689.tif")

        # load the image (the context manager closes the file handle)
        downscale = target_scale / image_scale
        with Image.open(filename_image) as raw_im:
            target_im_size = (int(raw_im.width / downscale), int(raw_im.height / downscale))
            im = np.array(raw_im.resize(target_im_size))
        # (row, col, c) -> (x, y, c), then reverse the second axis so the
        # image is indexed like the rasterised point cloud.
        im = im.transpose(1, 0, 2)
        im = im[:, ::-1]
        im = im.astype(np.float32) / 255
        self.im = torch.tensor(im, dtype=torch.float)

        # load the points as an elevation map
        # NOTE(review): laspy.file.File is the laspy 1.x API -- confirm the
        # installed version before upgrading the dependency.
        with laspy.file.File(filename_points, mode="r") as file:
            points = np.stack([file.x, file.y, file.z], axis=1)

        elevation_map = self._rasterize_elevation(points, points_scale, target_scale)
        self.points = torch.tensor(elevation_map, dtype=torch.float)
        self.target_size = elevation_map.shape

    @staticmethod
    def _rasterize_elevation(points, points_scale, target_scale):
        """Turn an ``(N, 3)`` xyz array into a 2-D normalised elevation map.

        Elevations are centred, values further than three standard deviations
        from the mean are discarded, and the rest is mapped linearly from
        ``[-3*std, 3*std]`` to ``[0, 1]``. Cells that receive no valid point
        are set to -1.

        Returns:
            ``numpy.ndarray`` of shape ``(max_x_index + 1, max_y_index + 1)``.
        """
        image_points = np.asarray(points, dtype=np.float64) * points_scale / target_scale

        pts_coords = (image_points[:, :2] - image_points[:, :2].min(axis=0)).astype(np.int64)
        # Size the grid from the largest index actually produced: using
        # ceil(extent) is one cell short whenever the extent is an integer.
        grid_size = (int(pts_coords[:, 0].max()) + 1, int(pts_coords[:, 1].max()) + 1)
        grid = np.full(grid_size, np.nan)

        elevation = image_points[:, 2]
        elevation = elevation - elevation.mean()
        elevation_std = elevation.std()
        if elevation_std > 0:
            elevation[np.abs(elevation) > 3 * elevation_std] = np.nan
            # shift by +3*std so that the kept range lands on [0, 1]
            elevation = (elevation + 3 * elevation_std) / (6 * elevation_std)
        else:
            # perfectly flat input: every point sits at the middle of the range
            elevation = np.full_like(elevation, 0.5)

        # When several points fall in the same cell only one value is kept.
        grid[pts_coords[:, 0], pts_coords[:, 1]] = elevation
        grid[np.isnan(grid)] = -1
        return grid

    def __len__(self):
        """Return the configured number of iterations."""
        return self.iterations

    def prepare_image_patch(self, im_patch):
        """Reorder an image from ``(X, Y, C)`` to channel-first ``(C, X, Y)``."""
        im_patch = im_patch.permute(2,0,1)
        return im_patch

    def prepare_points_patch(self, points_patch):
        """Subtract the patch mean from an elevation patch."""
        points_patch = points_patch - points_patch.mean() # prevent from learning an elevation bias
        return points_patch

    def _random_corner(self):
        """Draw the top-left corner of a random ``imsize`` x ``imsize`` crop.

        The corner is bounded by both rasters so that a crop is never
        truncated when the image and the elevation map differ in size.
        """
        limit_x = min(self.target_size[0], self.im.shape[0]) - self.imsize
        limit_y = min(self.target_size[1], self.im.shape[1]) - self.imsize
        x = torch.randint(0, limit_x, (1,)).item()
        y = torch.randint(0, limit_y, (1,)).item()
        return x, y

    def __getitem__(self, index):
        """Return a random ``(image patch, elevation patch)`` pair.

        ``index`` is ignored. Both patches are ``(C, imsize, imsize)``; the
        elevation patch is the single-channel map expanded to the image's
        channel count. Each modality is flipped independently along each
        spatial axis with probability 1/2.
        """
        x, y = self._random_corner()

        im_patch = self.im[x:x+self.imsize,y:y+self.imsize]
        im_patch = self.prepare_image_patch(im_patch)

        points_patch = self.points[x:x+self.imsize,y:y+self.imsize]
        points_patch = points_patch.unsqueeze(0).expand_as(im_patch)

        if torch.randint(2, (1,)).item():
            im_patch = torch.flip(im_patch, dims=[1])
        if torch.randint(2, (1,)).item():
            im_patch = torch.flip(im_patch, dims=[2])
        if torch.randint(2, (1,)).item():
            points_patch = torch.flip(points_patch, dims=[1])
        if torch.randint(2, (1,)).item():
            points_patch = torch.flip(points_patch, dims=[2])

        return im_patch, points_patch

    def get_full_image(self):
        """Return the whole image as a channel-first ``(C, X, Y)`` tensor."""
        im_patch = self.im
        im_patch = self.prepare_image_patch(im_patch)
        return im_patch

    def get_full_points(self):
        """Return the whole elevation map with its mean subtracted."""
        points_patch = self.points
        points_patch = self.prepare_points_patch(points_patch)
        return points_patch

    def get_point_patch(self):
        """Return a random elevation patch ``(3, imsize, imsize)`` and its corner ``x, y``."""
        x, y = self._random_corner()
        points_patch = self.points[x:x+self.imsize,y:y+self.imsize]
        points_patch = points_patch.unsqueeze(0).expand(3, -1, -1)
        return points_patch, x, y

In [None]:
def nt_xent_loss(out_1, out_2, temperature):
    """
    Loss used in SimCLR (normalised temperature-scaled cross entropy).

    Row ``i`` of ``out_1`` and row ``i`` of ``out_2`` form a positive pair;
    every other row of the concatenated batch appears in the denominator.
    The loss is evaluated in log space, so it stays finite for small
    temperatures where ``exp(similarity / temperature)`` overflows.

    Args:
        out_1: tensor ``(N, D)``, embeddings of the first view.
        out_2: tensor ``(N, D)``, embeddings of the second view.
        temperature: positive scalar dividing the cosine similarities.

    Returns:
        ``(loss, sim)`` where ``loss`` is a scalar tensor and ``sim`` is the
        ``(2N, 2N)`` matrix ``exp(cosine_similarity / temperature)`` of the
        concatenated batch ``[out_1; out_2]``.
    """
    out_1 = F.normalize(out_1, dim=1)
    out_2 = F.normalize(out_2, dim=1)

    out = torch.cat([out_1, out_2], dim=0)
    n_samples = len(out)

    # Full similarity matrix (kept as logits for the loss, exponentiated
    # only for the returned matrix).
    logits = torch.mm(out, out.t().contiguous()) / temperature
    sim = torch.exp(logits)

    # Denominator: log-sum-exp over every entry of a row except the
    # self-similarity on the diagonal.
    self_mask = torch.eye(n_samples, device=logits.device, dtype=torch.bool)
    log_denominator = torch.logsumexp(
        logits.masked_fill(self_mask, float("-inf")), dim=-1
    )

    # Numerator: similarity of each sample with its paired view, listed once
    # for each half of the concatenated batch.
    pos_logits = torch.sum(out_1 * out_2, dim=-1) / temperature
    pos_logits = torch.cat([pos_logits, pos_logits], dim=0)

    # -log(exp(pos) / sum(exp(neg))) == logsumexp(neg) - pos
    loss = (log_denominator - pos_logits).mean()

    return loss, sim

In [None]:
# Hyper-parameters shared by the training and evaluation cells.
in_channels = 3          # channels of the input patches
out_channels = 128       # size of the embedding produced by the encoders
batch_size = 128
num_workers = 8          # DataLoader worker processes
num_epochs = 10
temperature = 1          # temperature of the NT-Xent loss
num_epoch_steps = 1000   # batches per epoch
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
logging.info("Creating the dataset")
dataset = DFC2018_Im_Elevation_Dataset(iterations=num_epoch_steps*batch_size)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
)

logging.info("Creating the models")
# One encoder per modality; both feed the same 1x1-convolution projection head.
net_image = resnet_no_pool.resnet18(num_classes=out_channels).to(device)
net_points = resnet_no_pool.resnet18(num_classes=out_channels).to(device)
net_proj = nn.Sequential(
    nn.Conv2d(out_channels, out_channels, 1),
    nn.ReLU(),
    nn.Conv2d(out_channels, out_channels, 1),
).to(device)

logging.info("Creating the optimizer")
# A single Adam instance with one parameter group per network.
optimizer = torch.optim.Adam(
    [
        {"params": net.parameters(), "lr": 0.001}
        for net in (net_image, net_points, net_proj)
    ]
)


for epoch in range(num_epochs):

    # running sums, reset at the start of every epoch
    total_loss = 0
    average_accuracy_1 = 0
    average_accuracy_2 = 0

    t = tqdm(dataloader, ncols=100)
    for batch_id, batch in enumerate(t):

        images = batch[0].to(device)
        points = batch[1].to(device)

        optimizer.zero_grad()

        # encode each modality and average the feature map to one vector
        image_features = F.adaptive_avg_pool2d(net_image(images), output_size=(1, 1))
        points_features = F.adaptive_avg_pool2d(net_points(points), output_size=(1, 1))

        # shared projection head, then drop the 1x1 spatial dimensions
        image_proj = torch.flatten(net_proj(image_features), 1)
        points_proj = torch.flatten(net_proj(points_features), 1)

        loss, sim_mat = nt_xent_loss(image_proj, points_proj, temperature)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Keep the block whose rows are the point embeddings and whose
        # columns are the image embeddings; the matching pair is on the diagonal.
        n_pairs = image_proj.shape[0]
        sim_mat = sim_mat[n_pairs:, :n_pairs].detach().cpu().numpy()
        labels = np.arange(n_pairs)

        # Acc1: for every point patch, is its own image the most similar one?
        predictions = np.argmax(sim_mat, axis=1)
        average_accuracy_1 += (predictions==labels).sum() / labels.shape[0]

        # Acc2: for every image patch, is its own point patch the most similar one?
        predictions = np.argmax(sim_mat, axis=0)
        average_accuracy_2 += (predictions==labels).sum() / labels.shape[0]

        t.set_description_str(f"Loss {total_loss/(batch_id+1):.5e} - Acc1 {average_accuracy_1 / (batch_id+1) * 100:.2f}% - Acc2 {average_accuracy_2 / (batch_id+1) * 100:.2f}%")

In [None]:
# validation

# Put every network in inference mode before running on the full tile.
for _net in (net_image, net_points, net_proj):
    _net.eval()

# Whole image as a batch of one, channel-first.
images = dataset.get_full_image().unsqueeze(0).to(device)

with torch.no_grad():
    # No spatial pooling here: the projection head is applied at every
    # location of the feature map, giving one unit-norm embedding per cell.
    image_features = net_image(images)
    image_proj = F.normalize(net_proj(image_features), dim=1)
    # channel-last layout, moved to the CPU
    image_proj = image_proj.permute(0, 2, 3, 1).cpu()


In [None]:


with torch.no_grad():
    # Embed one random elevation patch into a single unit-norm vector.
    points, x, y = dataset.get_point_patch()
    points = points.unsqueeze(0).to(device)
    points_features = net_points(points)
    points_features = F.adaptive_avg_pool2d(points_features, output_size=(1, 1))
    points_proj = net_proj(points_features)
    points_proj = F.normalize(points_proj, dim=1)
    points_proj = torch.flatten(points_proj, 1)
    points_proj = points_proj.squeeze(0)
    points_proj = points_proj.cpu()

    # compute similarity
    print(image_proj.shape)
    print(points_proj.shape)

    # Dot product between the patch embedding and the embedding stored at
    # every location of the dense image feature map.
    sim = torch.matmul(image_proj, points_proj.unsqueeze(1))
    sim = sim.squeeze(3).squeeze(0)
    sim = sim.cpu().numpy()

im = (images[0].cpu().permute(1,2,0) * 255).long().numpy()
im_pts = points[0,0].cpu().numpy()

# Paint the location of the sampled patch in red; use the dataset's patch
# size instead of a literal so the square matches the patch that was drawn.
patch_size = dataset.imsize
im[x:x+patch_size, y:y+patch_size, 0] = 255
im[x:x+patch_size, y:y+patch_size, 1] = 0
im[x:x+patch_size, y:y+patch_size, 2] = 0

fig = plt.figure(figsize=(12,4), dpi= 100)
ax1 = plt.subplot(1, 3, 1)
ax2 = plt.subplot(1, 3, 2)
ax3 = plt.subplot(1, 3, 3)
ax1.imshow(im_pts)
ax1.set_title("Elevation patch")
ax2.imshow(im)
ax2.set_title("Image (patch location in red)")
pos = ax3.imshow(sim)
ax3.set_title("Similarity map")
fig.colorbar(pos, ax=ax3)
