In [8]:
import numpy as np
import tensorflow as tf

class TwoChannelVisibilityDataset(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that builds two-channel grid batches from two text files.

    File layout, as parsed by this class:

    * ``points_file``: one sample per line, whitespace separated. The first
      four fields are floats (``p1_x p1_y p2_x p2_y``) and the fifth is an
      integer label. Reading a batch stops at the first blank line / EOF.
    * ``encodings_file``: one scene after another. Each scene is
      ``resolution`` lines of whitespace-separated floats followed by one
      extra line that is read and discarded.

    Consecutive groups of ``points_per_scene`` point lines share one scene:
    point line ``k`` uses scene ``k // points_per_scene``.

    Each sample is ``np.stack([scene, point_grid], axis=-1)`` where
    ``point_grid`` is a ``(resolution, resolution)`` float32 grid of zeros
    with the two cells addressed by the point pair set to 1.0.
    NOTE(review): coordinates are presumably normalised to [0, 1]; only the
    upper bound is clamped, so negative values would wrap around - confirm.

    Batches are contiguous slices of the points file; only the *order* in
    which batches are served is shuffled (see ``indices``).
    """

    def __init__(self, points_file, encodings_file, points_per_scene, resolution, batch_size=32, caching=True, cache_factor=1):
        """Count the samples, shuffle the batch order and optionally fill the cache.

        Args:
            points_file: Path to the points/labels text file.
            encodings_file: Path to the scene encodings text file.
            points_per_scene: Number of consecutive point lines per scene.
            resolution: Number of rows per scene; also the grid side length.
            batch_size: Number of point lines per batch.
            caching: When true, batches are pre-built at construction time
                and batches built later by ``__getitem__`` are memoised.
            cache_factor: Fraction of the batches (counted from batch 0)
                that is eligible for caching.

        Raises:
            ValueError: If ``batch_size``, ``points_per_scene`` or
                ``resolution`` is not positive.
        """
        # Newer Keras versions expect Sequence/PyDataset subclasses to call
        # the base constructor; on older versions this is a harmless no-op.
        super().__init__()

        for name, value in (("points_per_scene", points_per_scene),
                            ("resolution", resolution),
                            ("batch_size", batch_size)):
            if value <= 0:
                raise ValueError("{} must be positive, got {!r}".format(name, value))

        self.points_file = points_file
        self.encodings_file = encodings_file
        self.batch_size = batch_size
        self.points_per_scene = points_per_scene
        self.resolution = resolution
        self.caching = caching
        self.cache_factor = cache_factor
        self.indices = None
        self._load_indices()
        # One slot per cacheable batch; ``False`` marks "not built yet".
        self.cache = [False] * self._cached_batch_count()
        if self.caching:
            self.load_cache()

    def _load_indices(self):
        """Count the samples in the points file and shuffle the batch order."""
        with open(self.points_file, 'r') as f:
            # Blank lines are not samples (the readers stop at them), so
            # counting them could yield a spurious empty final batch.
            self.num_samples = sum(1 for line in f if line.strip())
        self.indices = np.arange(len(self))
        np.random.shuffle(self.indices)

    def __len__(self):
        """Return the number of batches, counting a final partial batch."""
        return int(np.ceil(self.num_samples / self.batch_size))

    def _cached_batch_count(self):
        """Return how many leading batches are eligible for caching."""
        return int(len(self) * self.cache_factor)

    def _scene_run_lengths(self, start):
        """Split batch ``start`` into runs of points that share one scene.

        The first run covers what is left of the scene the batch starts in,
        then come full scenes, then a possibly partial last scene. The runs
        always sum to ``batch_size`` (the real batch may be shorter at EOF).
        """
        offset = (start * self.batch_size) % self.points_per_scene
        first = self.points_per_scene - offset
        if first >= self.batch_size:
            return [self.batch_size]

        full_scenes, tail = divmod(self.batch_size - first, self.points_per_scene)
        runs = [first] + [self.points_per_scene] * full_scenes
        if tail > 0:
            runs.append(tail)
        return runs

    def _read_points(self, pf):
        """Read up to ``batch_size`` point lines from ``pf``.

        Returns ``(points, labels)`` as Python lists; stops early at a blank
        line or end of file.
        """
        points = []
        labels = []
        for _ in range(self.batch_size):
            fields = pf.readline().split()
            if not fields:
                break
            points.append([float(x) for x in fields[:4]])
            labels.append(int(fields[4]))
        return points, labels

    def _read_scene(self, ef):
        """Read one scene from ``ef``; return ``None`` if the file ran out.

        On success the line following the scene rows is consumed as well.
        """
        rows = []
        for _ in range(self.resolution):
            fields = ef.readline().split()
            if not fields:
                return None
            rows.append([float(x) for x in fields])
        ef.readline()  # discard the line that follows every scene
        # Convert once per scene instead of once per point that uses it.
        return np.asarray(rows, dtype=float)

    def _grid_index(self, coord):
        """Map a coordinate to a grid index, clamped to the last cell."""
        return min(int(coord * self.resolution), self.resolution - 1)

    def _build_batch(self, points, labels, scenes):
        """Combine per-sample scenes and point pairs into ``(data, labels)``.

        If the two inputs disagree in length (malformed or truncated files)
        the batch is cut to the common length so data and labels stay aligned.
        """
        samples = []
        # zip() stops at the shorter input, which performs the truncation.
        for scene, (p1_x, p1_y, p2_x, p2_y) in zip(scenes, points):
            point_grid = np.zeros((self.resolution, self.resolution), dtype=np.float32)
            point_grid[self._grid_index(p1_x), self._grid_index(p1_y)] = 1.0
            point_grid[self._grid_index(p2_x), self._grid_index(p2_y)] = 1.0
            # Shape: (resolution, resolution, 2)
            samples.append(np.stack([scene, point_grid], axis=-1))

        return np.array(samples), np.array(labels[:len(samples)], dtype=np.float32)

    def load_cache(self):
        """Pre-build every cacheable batch in one sequential pass over both files.

        Because batches are visited in file order, a scene that straddles a
        batch boundary is carried over from the previous batch rather than
        read again.
        """
        last_scene = None
        with open(self.points_file, 'r') as pf, open(self.encodings_file, 'r') as ef:
            for start in range(self._cached_batch_count()):
                points, labels = self._read_points(pf)

                # A batch that starts in the middle of a scene continues the
                # scene the previous batch ended with.
                resume_last = (start * self.batch_size) % self.points_per_scene > 0

                scenes = []
                for run_idx, run_len in enumerate(self._scene_run_lengths(start)):
                    if run_idx == 0 and resume_last:
                        scene = last_scene
                    else:
                        scene = self._read_scene(ef)
                    if scene is None:
                        # Encodings exhausted (or nothing to resume from).
                        break
                    last_scene = scene
                    scenes.extend([scene] * run_len)

                self.cache[start] = self._build_batch(points, labels, scenes)

    def _load_chunk(self, start):
        """Read the raw contents of batch ``start`` straight from disk.

        Both files are scanned from the beginning on every call, so the cost
        grows with ``start``.

        Returns:
            ``(points, labels, scenes)`` as NumPy arrays, where ``scenes``
            holds one scene per sample (a scene is repeated for every point
            that uses it).
        """
        first_point = start * self.batch_size
        scenes = []

        with open(self.points_file, 'r') as pf, open(self.encodings_file, 'r') as ef:
            # Skip the point lines that belong to earlier batches.
            for _ in range(first_point):
                pf.readline()

            # Skip whole scenes that precede this batch; every scene occupies
            # ``resolution`` rows plus one trailing line.
            for _ in range((first_point // self.points_per_scene) * (self.resolution + 1)):
                ef.readline()

            points, labels = self._read_points(pf)

            for run_len in self._scene_run_lengths(start):
                scene = self._read_scene(ef)
                if scene is None:
                    break
                scenes.extend([scene] * run_len)

        return np.array(points, dtype=float), np.array(labels, dtype=int), np.array(scenes, dtype=float)

    def __getitem__(self, idx):
        """Return batch number ``idx`` of the current shuffled order.

        Returns:
            ``(data, labels)``: ``data`` stacks one
            ``(resolution, resolution, 2)`` array per sample and ``labels``
            is a float32 vector of the same length.
        """
        start = self.indices[idx]
        cacheable = self.caching and start < self._cached_batch_count()

        # Identity test: a filled slot holds a tuple of arrays, an empty one
        # holds the ``False`` placeholder.
        if cacheable and self.cache[start] is not False:
            return self.cache[start]

        # Cache miss: load the chunk containing the required batch.
        points, labels, scenes = self._load_chunk(start)
        batch = self._build_batch(points, labels, scenes)

        if cacheable:
            self.cache[start] = batch

        return batch

    def on_epoch_end(self):
        """Reshuffle the order in which batches are served."""
        np.random.shuffle(self.indices)


# Demo: build the dataset once and print the shape of every batch it serves.
#
# Alternative (64x64) dataset, kept for reference:
# points_file = "../datasets/open_building_2D_cor_64x64.txt"
# encodings_file = "../datasets/open_building_2D_encoding_64x64.txt"
points_file = "../datasets/tt_10.txt"
encodings_file = "../datasets/scenett_10.txt"

# Loader configuration handed to the dataset below.
batch_size = 2048
resolution = 32
points_per_scene = 10

dataset = TwoChannelVisibilityDataset(
    points_file=points_file,
    encodings_file=encodings_file,
    points_per_scene=points_per_scene,
    resolution=resolution,
    batch_size=batch_size,
    cache_factor=1,
)

# Each item is a (data, labels) pair.
for data, batch_labels in dataset:
    print("Batch Data Shape:", data.shape)
    print("Batch Labels Shape:", batch_labels.shape)



Batch Data Shape: (2048, 32, 32, 2)
Batch Labels Shape: (2048,)
Batch Data Shape: (2048, 32, 32, 2)
Batch Labels Shape: (2048,)
Batch Data Shape: (2048, 32, 32, 2)
Batch Labels Shape: (2048,)
Batch Data Shape: (2048, 32, 32, 2)
Batch Labels Shape: (2048,)
Batch Data Shape: (2048, 32, 32, 2)
Batch Labels Shape: (2048,)
Batch Data Shape: (2048, 32, 32, 2)
Batch Labels Shape: (2048,)
Batch Data Shape: (2048, 32, 32, 2)
Batch Labels Shape: (2048,)
Batch Data Shape: (848, 32, 32, 2)
Batch Labels Shape: (848,)
Batch Data Shape: (2048, 32, 32, 2)
Batch Labels Shape: (2048,)
Batch Data Shape: (2048, 32, 32, 2)
Batch Labels Shape: (2048,)
Batch Data Shape: (2048, 32, 32, 2)
Batch Labels Shape: (2048,)
Batch Data Shape: (2048, 32, 32, 2)
Batch Labels Shape: (2048,)
Batch Data Shape: (2048, 32, 32, 2)
Batch Labels Shape: (2048,)
Batch Data Shape: (2048, 32, 32, 2)
Batch Labels Shape: (2048,)
Batch Data Shape: (2048, 32, 32, 2)
Batch Labels Shape: (2048,)
Batch Data Shape: (2048, 32, 32, 2)
Batch 