# In [None]:
import os
import torch
import numpy as np
from models.pointnet2_utils import PointNetSetAbstraction, PointNetFeaturePropagation
import torch.nn.functional as F
import torch.nn as nn
from scipy.spatial.distance import directed_hausdorff  # Importing the directed_hausdorff function
import torch.optim as optim
import open3d as o3d
import random  # Make sure to include this at the start of your script


from scipy.stats import wasserstein_distance



# Semantic label names (13 entries); main() sizes the classifier head with len(classes).
classes = (
    'ceiling floor wall beam column window door '
    'table chair sofa bookcase board clutter'
).split()



class TransformerFeatureAbstraction(nn.Module):
    """One Transformer encoder layer preceded by learned positional embeddings.

    Args:
        feature_dim: Size of each token's feature vector (``d_model``).
        num_heads: Number of attention heads; must divide ``feature_dim``.
        max_len: Maximum number of tokens supported. One learned embedding is
            kept per position; shorter inputs use the first ``num_tokens``
            embeddings. Defaults to 256.
        dim_feedforward: Hidden size of the encoder layer's feed-forward block.
    """

    def __init__(self, feature_dim, num_heads, max_len=256, dim_feedforward=2048):
        super(TransformerFeatureAbstraction, self).__init__()
        self.transformer = nn.TransformerEncoderLayer(
            d_model=feature_dim,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        # Attribute names are unchanged so existing state_dict keys still load.
        self.positional_embeddings = nn.Parameter(torch.randn(1, max_len, feature_dim))

    def forward(self, x):
        """Add positional embeddings to ``x`` and run the encoder layer.

        Args:
            x: Tensor of shape [batch_size, num_tokens, feature_dim]
                (``batch_first=True``).

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``num_tokens`` exceeds the number of learned
                positional embeddings.
        """
        num_tokens = x.size(1)
        max_len = self.positional_embeddings.size(1)
        if num_tokens > max_len:
            raise ValueError(
                f"input has {num_tokens} tokens but only {max_len} positional embeddings exist"
            )
        # Slice so any sequence length up to max_len works, not only exactly max_len.
        x = x + self.positional_embeddings[:, :num_tokens]
        return self.transformer(x)


class Validator:
    """Score predicted point-cloud segments against per-class reference shapes.

    Reference ("ground-truth") clouds are loaded from
    ``<validator_folder>/<class_id>.txt`` for class ids 1..13. A predicted
    chunk is split into runs of consecutive points that share a predicted
    class (and whose x coordinates stay close). Each run with enough points is
    centred, ICP-aligned to the reference cloud of its class and scored with
    directed Hausdorff + 1-D EMD + RMSE distances.

    NOTE(review): reference files are keyed 1..13 while ``main`` derives the
    predicted labels from ``argmax`` over the model outputs (0..12) — confirm
    the intended label-to-file mapping.
    """

    # Segments with fewer points than this are not compared against a
    # reference; each one only adds SMALL_SEGMENT_PENALTY to the loss.
    MIN_SEGMENT_POINTS = 100
    SMALL_SEGMENT_PENALTY = 0.001
    # Consecutive same-class points whose x coordinates differ by at least
    # this much start a new segment.
    MAX_X_GAP = 2
    # ICP settings: maximum correspondence distance and iteration cap.
    ICP_MAX_CORRESPONDENCE_DISTANCE = 0.02
    ICP_MAX_ITERATIONS = 2000

    def __init__(self, validator_folder):
        self.validator_folder = validator_folder
        self.ground_truth_files = self.load_ground_truth_files()

    def preprocess_point_cloud(self, points):
        """Build an Open3D point cloud from ``points`` centred on its centroid.

        Args:
            points (np.array): (N, 3) array of XYZ coordinates.

        Returns:
            o3d.geometry.PointCloud: Cloud whose points have zero mean.
        """
        xyz = np.asarray(points, dtype=float)
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(xyz - xyz.mean(axis=0))
        return pcd

    def read_point_cloud(self, file_path):
        """Read the first four numeric columns of each line of a text file.

        Lines with fewer than four whitespace-separated values are skipped.

        Args:
            file_path (str): Path to the .txt file.

        Returns:
            np.array: (N, 4) float array; (0, 4) if no line was usable.
        """
        points = []
        with open(file_path, 'r') as file:
            for line in file:
                values = line.split()
                if len(values) >= 4:
                    points.append([float(val) for val in values[:4]])
        return np.array(points, dtype=float).reshape(-1, 4)

    def calculate_hausdorff_distance(self, source_pcd, target_pcd):
        """Symmetric Hausdorff distance between the points of two clouds."""
        source_points = np.asarray(source_pcd.points)
        target_points = np.asarray(target_pcd.points)
        return max(directed_hausdorff(source_points, target_points)[0],
                   directed_hausdorff(target_points, source_points)[0])

    def calculate_emd(self, source_pcd, target_pcd):
        """1-D Wasserstein distance between the flattened coordinates of two clouds.

        All x, y and z values are pooled into one sample per cloud, so this is
        a coarse proxy for a true 3-D earth mover's distance. The clouds may
        have different sizes.
        """
        source_points = np.asarray(source_pcd.points).flatten()
        target_points = np.asarray(target_pcd.points).flatten()
        return wasserstein_distance(source_points, target_points)

    def calculate_rmse(self, source_pcd, target_pcd):
        """Root-mean-square distance from the source points to the target points.

        Clouds with identical shapes are paired point-by-point by index.
        Otherwise every source point is matched with its nearest target point,
        so clouds of different sizes can be compared.
        """
        source_points = np.asarray(source_pcd.points)
        target_points = np.asarray(target_pcd.points)
        if source_points.shape == target_points.shape:
            squared_errors = np.sum(np.square(source_points - target_points), axis=1)
        else:
            from scipy.spatial import cKDTree
            nearest_dist, _ = cKDTree(target_points).query(source_points)
            squared_errors = np.square(nearest_dist)
        return np.sqrt(np.mean(squared_errors))

    def load_ground_truth_files(self):
        """Load reference clouds ``1.txt`` .. ``13.txt`` from ``validator_folder``.

        Each line with at least three whitespace-separated values contributes
        its first three values as XYZ. Missing files are reported and skipped.

        Returns:
            dict: class id (int) -> (N, 3) float array.
        """
        ground_truth_files = {}
        for class_id in range(1, 14):
            file_path = os.path.join(self.validator_folder, f"{class_id}.txt")
            print("Checking file path:", file_path)  # Debug statement
            if not os.path.exists(file_path):
                print("File does not exist:", file_path)  # Debug statement
                continue
            print("File exists:", file_path)  # Debug statement
            points = []
            with open(file_path, 'r') as file:
                for line in file:
                    values = line.split()
                    if len(values) >= 3:
                        points.append([float(val) for val in values[:3]])
            # reshape keeps a file without usable lines as a (0, 3) array, so
            # later [:, :3] slicing cannot fail on a 1-D empty array.
            ground_truth_files[class_id] = np.array(points, dtype=float).reshape(-1, 3)
        return ground_truth_files

    def calculate_min_loss_class(self, validation_points):
        """Find the reference class whose shape is closest to ``validation_points``.

        Args:
            validation_points: (N, >=3) array-like; only XYZ is used.

        Returns:
            tuple: ``(class_id, loss)`` with the smallest distance (lower is a
            better match). With fewer than MIN_SEGMENT_POINTS points no distance
            is computed: every class scores SMALL_SEGMENT_PENALTY and the first
            loaded class is returned. ``(None, inf)`` when no reference cloud is
            available.
        """
        if len(validation_points) < self.MIN_SEGMENT_POINTS:
            if not self.ground_truth_files:
                return None, float('inf')
            return next(iter(self.ground_truth_files)), self.SMALL_SEGMENT_PENALTY

        distances = self._class_distances(np.asarray(validation_points, dtype=float)[:, :3])
        if not distances:
            return None, float('inf')
        best_class = min(distances, key=distances.get)
        return best_class, distances[best_class]

    def calculate_validation_loss(self, xyzc_data_chunk):
        """Sum the segment losses over one or more chunks of predicted points.

        Args:
            xyzc_data_chunk: Iterable of chunks; each chunk is a sequence of
                ``(x, y, z, c)`` rows where ``c`` is the predicted class.

        Returns:
            tuple: ``(total_loss, min_loss_value)``.
            - ``total_loss`` is the sum over all segments of the segment's
              distance to its predicted class, or SMALL_SEGMENT_PENALTY for
              small segments.
            - ``min_loss_value`` is the best class distance of the last scored
              segment, or 100 if no segment was scored.
        """
        total_loss = 0
        min_loss_value = 100
        for xyzc_data in xyzc_data_chunk:
            current_class = None
            segment = []
            for point in xyzc_data:
                x, _, _, c = point
                if current_class is None:
                    current_class = c
                    segment.append(point)
                elif c == current_class and abs(segment[-1][0] - x) < self.MAX_X_GAP:
                    segment.append(point)
                else:
                    # Class changed or x jumped: close the running segment.
                    loss, segment_min = self._score_segment(segment, current_class)
                    total_loss += loss
                    if segment_min is not None:
                        min_loss_value = segment_min
                    current_class = c
                    segment = [point]

            # Score the trailing segment of the chunk (an empty chunk yields an
            # empty segment, which receives the small-segment penalty).
            loss, segment_min = self._score_segment(segment, current_class)
            total_loss += loss
            if segment_min is not None:
                min_loss_value = segment_min

        return total_loss, min_loss_value

    def _score_segment(self, validation_points, class_id):
        """Return ``(loss, best_class_loss)`` for one segment; the second item is None if not scored."""
        if len(validation_points) < self.MIN_SEGMENT_POINTS:
            return self.SMALL_SEGMENT_PENALTY, None
        if len(self.ground_truth_files.get(class_id, [])) == 0:
            # Large segment but no reference shape for its class: contributes nothing.
            return 0, None
        validation_xyz = np.asarray(validation_points, dtype=float)[:, :3]
        # One ICP per class gives both the predicted-class loss and the best-class loss.
        distances = self._class_distances(validation_xyz)
        return distances[class_id], min(distances.values())

    def _class_distances(self, validation_xyz):
        """Map every class with a non-empty reference cloud to its distance from the segment."""
        return {
            class_id: self._segment_distance(validation_xyz, ground_truth_points)
            for class_id, ground_truth_points in self.ground_truth_files.items()
            if len(ground_truth_points) > 0
        }

    def _segment_distance(self, validation_xyz, ground_truth_points):
        """Align a segment to one reference cloud with ICP and return Hausdorff + EMD + RMSE."""
        target_xyz = np.asarray(ground_truth_points, dtype=float)[:, :3]
        source_pcd = self.preprocess_point_cloud(validation_xyz)
        target_pcd = self.preprocess_point_cloud(target_xyz)

        reg_result = o3d.pipelines.registration.registration_icp(
            source_pcd, target_pcd, self.ICP_MAX_CORRESPONDENCE_DISTANCE, np.eye(4),
            o3d.pipelines.registration.TransformationEstimationPointToPoint(),
            o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=self.ICP_MAX_ITERATIONS))

        # PointCloud.transform() modifies the cloud in place and returns it, so
        # the aligned cloud must be compared with the *target*; comparing it
        # with source_pcd would always give zero.
        aligned_pcd = source_pcd.transform(reg_result.transformation)
        aligned_xyz = np.asarray(aligned_pcd.points)
        emd_dist = self.calculate_emd(aligned_pcd, target_pcd)
        rmse_dist = self.calculate_rmse(aligned_pcd, target_pcd)
        # Measured after alignment so it reflects shape mismatch rather than
        # where the segment happens to sit in the scene.
        hausdorff_dist = directed_hausdorff(aligned_xyz, np.asarray(target_pcd.points))[0]
        return hausdorff_dist + emd_dist + rmse_dist


    






class PointNetClassifier(nn.Module):
    """PointNet++-style per-point segmentation network with one Transformer layer.

    Four set-abstraction stages (``sa1``..``sa4``) are followed by four
    feature-propagation stages (``fp4``..``fp1``) and a 1x1-convolution head
    that emits per-point log-probabilities over ``num_classes`` classes.

    NOTE(review): the order in which submodules are assigned in ``__init__``
    fixes the order of ``model.parameters()``. ``main`` passes that to the Adam
    optimizer and saves ``optimizer.state_dict()`` in checkpoints, so keep the
    order stable if existing checkpoints must stay loadable.
    """

    def __init__(self, num_classes: int):
        super(PointNetClassifier, self).__init__()
        # Arguments follow the signatures of PointNetSetAbstraction /
        # PointNetFeaturePropagation from models.pointnet2_utils (not visible
        # here). sa1 expects 9 + 3 input channels; main() pads inputs to 9
        # channels with pad_channels.
        self.sa1 = PointNetSetAbstraction(1024, 0.1, 32, 9 + 3, [32, 32, 64], False)
        self.sa2 = PointNetSetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128], False)
        self.sa3 = PointNetSetAbstraction(64, 0.4, 32, 128 + 3, [128, 128, 256], False)
        # NOTE(review): TransformerFeatureAbstraction documents its input as
        # [batch, num_points, feature_dim], but it is applied to sa3's output
        # with feature_dim=64. If sa3 returns channels-first features
        # ([B, 256 channels, 64 points] — presumably, TODO confirm against
        # pointnet2_utils), attention runs across the 256 feature channels with
        # the 64 points as the embedding, not across points. Fixing it
        # (feature_dim=256 plus transposes) changes parameter shapes and would
        # break loading of existing checkpoints.
        self.transformer_layer = TransformerFeatureAbstraction(feature_dim=64, num_heads=8)
        self.sa4 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 256, 512], False)
        self.fp4 = PointNetFeaturePropagation(768, [256, 256])
        self.fp3 = PointNetFeaturePropagation(384, [256, 256])
        self.fp2 = PointNetFeaturePropagation(320, [256, 128])
        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
        # Per-point classification head: conv -> BN -> ReLU -> dropout -> conv.
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)

    def forward(self, xyz: torch.Tensor):
        """Run the network on a batch of point clouds.

        Args:
            xyz: Tensor laid out as [batch, channels, num_points]; the first
                three channels are used as coordinates and all channels as
                input features.

        Returns:
            tuple: ``(log_probs, l4_points)``.
            - ``log_probs`` holds the log-softmax class scores with the class
              axis moved last, i.e. [batch, points, num_classes] (presumably
              one row per input point — depends on fp1's output).
            - ``l4_points`` holds the deepest set-abstraction features.
        """
        l0_points = xyz
        l0_xyz = xyz[:, :3, :]

        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        # Apply Transformer to enhance features (see NOTE(review) in __init__
        # about which axis it attends over).
        l3_points = self.transformer_layer(l3_points)

        l4_xyz, l4_points = self.sa4(l3_xyz, l3_points)
        # Propagate features back up the hierarchy to the original points;
        # fp1 receives no skip features (None) from level 0.
        l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points)
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)

        x = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
        x = self.conv2(x)
        # Class scores live on dim 1 of the Conv1d output; normalise there,
        # then move the class axis last.
        x = F.log_softmax(x, dim=1)
        x = x.permute(0, 2, 1)
        return x, l4_points

def read_point_cloud(file_path, chunk_size=40000, sort_axis=0):
    """Read an XYZRGB text file and yield it in chunks ordered along one column.

    Each line with at least six whitespace-separated values contributes its
    first six values as floats; shorter lines are skipped.

    Args:
        file_path: Path to the text file.
        chunk_size: Maximum number of rows per yielded chunk.
        sort_axis: Column used to order the points before chunking. Defaults
            to 0 (the X coordinate).

    Yields:
        np.ndarray: Arrays of shape (k, 6) with k <= chunk_size. Nothing is
        yielded for a file without usable lines.
    """
    points = []
    with open(file_path, 'r') as file:
        for line in file:
            values = line.split()
            if len(values) >= 6:
                points.append([float(val) for val in values[:6]])

    # reshape keeps an empty file as a (0, 6) array instead of failing on the
    # column index below.
    points = np.array(points, dtype=float).reshape(-1, 6)
    points = points[points[:, sort_axis].argsort()]

    for start in range(0, len(points), chunk_size):
        yield points[start:start + chunk_size]

def read_point_cloud1(point_cloud_data, chunk_size=200000):
    """Yield consecutive row chunks of a point array after ordering it by Y.

    The whole array is reordered by column 1 (the Y coordinate for XYZRGB
    rows) before it is split, so each chunk covers a band of Y values.

    Args:
        point_cloud_data (np.array): 2-D array with one point per row.
        chunk_size (int): Maximum number of rows per yielded chunk.

    Yields:
        np.array: Consecutive slices of the Y-ordered array.
    """
    order = np.argsort(point_cloud_data[:, 1])
    ordered = point_cloud_data[order]
    for start in range(0, len(ordered), chunk_size):
        yield ordered[start:start + chunk_size]


def pad_channels(data, target_channels=9):
    """Zero-pad dimension 1 (channels) of ``data`` up to ``target_channels``.

    Args:
        data (torch.Tensor): Tensor of shape [batch_size, current_channels, ...].
        target_channels (int): Desired number of channels.

    Returns:
        torch.Tensor: ``data`` itself if it already has at least
        ``target_channels`` channels; otherwise a new tensor with zero channels
        appended. The result has the same dtype and device as ``data``.
    """
    missing = target_channels - data.shape[1]
    if missing <= 0:
        return data
    # new_zeros inherits dtype and device from data; plain torch.zeros would be
    # float32 and make torch.cat promote e.g. integer inputs to float.
    padding = data.new_zeros((data.shape[0], missing, *data.shape[2:]))
    return torch.cat([data, padding], dim=1)

def main(
    dataset_dir='/home/puoza/Pointnet_Pointnet2_pytorch/stanford_indoor3d',
    weights_path='/home/puoza/p2/Pointnet_Pointnet2_pytorch/log/sem_seg/validator/weights_epoch_1.pth',
    validator_folder='/home/puoza/Pointnet_Pointnet2_pytorch/validate',
    output_dir='/home/puoza/Pointnet_Pointnet2_pytorch/pointcloud',
    save_dir='/home/puoza/p2/Pointnet_Pointnet2_pytorch/log/sem_seg/validator',
    chunk_size=500000,
    num_epochs=10,
    num_written_chunks=10,
    learning_rate=0.0001,
):
    """Run the model over every ``.npy`` scene, score predictions, and checkpoint.

    For each epoch the scenes in ``dataset_dir`` are processed in random order.

    1. Each scene is split into Y-ordered chunks of ``chunk_size`` points.
    2. Each chunk is classified, and the predictions for the first
       ``num_written_chunks`` chunks are written to ``output_dir``.
    3. Each chunk is scored with :class:`Validator`.
    4. A checkpoint is saved to ``save_dir`` after every scene.

    All parameters default to the previously hard-coded values.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = PointNetClassifier(num_classes=len(classes)).to(device)
    checkpoint = torch.load(weights_path, map_location=device)
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    # sa3 entries are deliberately not restored — presumably because sa3's
    # configuration changed since the checkpoint was written (confirm).
    # strict=False tolerates the resulting missing keys.
    state_dict = {k: v for k, v in state_dict.items() if not k.startswith('sa3.')}
    model.load_state_dict(state_dict, strict=False)

    model.train()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    validator = Validator(validator_folder=validator_folder)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        scene_files = [name for name in os.listdir(dataset_dir) if name.endswith('.npy')]
        random.shuffle(scene_files)

        for scene_file in scene_files:
            scene_path = os.path.join(dataset_dir, scene_file)
            xyzrgb_data = np.load(scene_path)[:, :6]  # Extract XYZRGB data
            scene_loss = 0.0  # Sum of combined losses over the scene's chunks

            for chunk_id, chunk_points in enumerate(read_point_cloud1(xyzrgb_data, chunk_size)):
                # [N, 6] -> [1, 6, N], then pad to the 9 channels sa1 expects.
                input_pc_tensor = torch.tensor(chunk_points, dtype=torch.float32).unsqueeze(0).permute(0, 2, 1)
                input_pc_tensor = pad_channels(input_pc_tensor, 9).to(device)

                optimizer.zero_grad()
                output_cls, _ = model(input_pc_tensor)
                # [0] rather than squeeze() keeps a 1-D array even for a one-point chunk.
                class_predictions = output_cls.argmax(dim=2)[0].cpu().numpy()

                xyzc_data = [
                    [x, y, z, c]
                    for (x, y, z), c in zip(chunk_points[:, :3], class_predictions)
                ]

                if chunk_id < num_written_chunks:
                    output_chunk_path = os.path.join(output_dir, f'output_chunk_{chunk_id}.txt')
                    # Sort by predicted class, then by x. np.lexsort treats the
                    # LAST key as the primary one.
                    order = np.lexsort((chunk_points[:, 0], class_predictions))
                    with open(output_chunk_path, 'w') as f:
                        f.writelines(
                            f"{chunk_points[i, 0]} {chunk_points[i, 1]} {chunk_points[i, 2]} {class_predictions[i]}\n"
                            for i in order
                        )

                # Validate only the current chunk; re-submitting every chunk
                # seen so far would redo all earlier ICP work each iteration.
                validation_loss, min_loss_value = validator.calculate_validation_loss([xyzc_data])
                combined_loss = validation_loss + min_loss_value
                scene_loss += combined_loss
                print(f"Scene: {scene_file}, Chunk: {chunk_id}, Validation Loss: {validation_loss},min Loss: {min_loss_value}")

                # NOTE(review): combined_loss is a plain number computed with
                # NumPy/Open3D, so this tensor is a fresh leaf with no link to
                # the model's autograd graph. backward() therefore gives the
                # model parameters no gradients, and optimizer.step() has
                # nothing to apply. A differentiable loss is needed for this
                # loop to actually train the model.
                combined_loss_tensor = torch.tensor(
                    combined_loss, dtype=torch.float32, device=device, requires_grad=True)
                combined_loss_tensor.backward()
                optimizer.step()

            weights_save_path = os.path.join(save_dir, f'weights_epoch_{epoch + 1}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': scene_loss,
            }, weights_save_path)
            print(f"Weights saved successfully after scene: {scene_file}")
            print(f"Scene: {scene_file}, Total Validation Loss: {scene_loss}")

# Run the training/validation loop only when executed directly (a notebook
# cell also runs with __name__ == "__main__"), not when this file is imported.
if __name__ == "__main__":
    main()


      
