## Imports

In [3]:
import os
import sys
import time
import re
import numpy as np
import random
from tqdm import tqdm
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
import math
import zipfile
import sklearn
import shutil
from glob import glob
from PIL import Image
import tempfile

In [2]:
# NOTE(review): the `%matplotlib inline` magic on the next line selects the
# inline backend again, so the non-interactive 'agg' backend chosen here is
# overridden right away — one of these two lines is likely redundant.
plt.switch_backend('agg')
%matplotlib inline
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils import data
from torchvision import datasets, models, transforms
import torchvision.utils as vutils

## Test Dataset

In [123]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
import zipfile
import tempfile
import re
from tqdm import tqdm


def read_image(image_path):
    """
    Read an image from disk as grayscale and add a leading channel axis.

    :param image_path: Path to an image file readable by OpenCV.
    :return: numpy array of shape (1, H, W), i.e. (c, h, w) with c == 1.
    :raises OSError: if OpenCV cannot read the file (cv2.imread signals
        failure by returning None rather than raising).
    """
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # Fail loudly here instead of producing a meaningless object array
        # that only breaks later when frames are concatenated.
        raise OSError(f"cv2.imread could not read image: {image_path}")
    return np.expand_dims(img, axis=0)  # (H, W) -> (1, H, W)


def format_name(zip_file_name):
    """
    Turn a zip file name into a lower-case, space-separated label.

    Only a trailing ".zip" extension (any letter case) is removed; underscores
    become spaces. Example: "Caracalla_Baths_A.zip" -> "caracalla baths a".

    :param zip_file_name: File name such as "caracalla_baths_a.zip".
    :return: The formatted, lower-cased label.
    """
    extension = ".zip"
    stem = zip_file_name
    if stem.lower().endswith(extension):
        stem = stem[: -len(extension)]
    return stem.replace("_", " ").lower()


class BldgDataset(Dataset):
    """
    Panoramic image sequences of buildings, loaded from a nested zip archive.

    ``data_path`` points to a zip whose ``.zip`` members are themselves zip
    files, one per building/route pair. Each member's file name is turned into
    a label with ``format_name``. The last character of that label is taken as
    the route, and everything before the final separator as the building name
    (e.g. ``caracalla_baths_a.zip`` -> building ``caracalla baths``, route
    ``a``). Inside each member, every folder whose name starts with ``path``
    is expected to contain ``panoramic_00.png`` ...
    ``panoramic_{num_frame - 1:02d}.png``.

    Each complete path becomes one sample. The dataset-wide pixel mean and
    standard deviation are computed while loading and used to normalize
    ``t_imgs`` in ``__getitem__``.

    :param data_path: Path to the outer zip archive.
    :param mode: Stored on the instance; not used by this class itself.
    :param transform: Optional callable applied to the sequence of
        ``(1, H, W)`` frames before stacking and normalization.
    :param seq_len: Frames per sub-sequence in ``t_imgs``.
    :param num_seq: Number of sub-sequences in ``t_imgs``;
        ``num_seq * seq_len`` must equal ``num_frame``.
    :param num_frame: Frames expected in every path folder.
    """

    def __init__(
        self,
        data_path="./data/experiments/casestudy.zip",
        mode="train",
        transform=None,
        seq_len=5,
        num_seq=6,
        num_frame=30,
    ):
        super(BldgDataset, self).__init__()

        self.data_path = data_path
        self.mode = mode
        self.transform = transform
        self.seq_len = seq_len
        self.num_seq = num_seq
        self.building_names = []  # unique building names, in load order
        self.num_frame = num_frame  # frames expected in each path folder

        # Populated by load_dataset().
        self.data = []
        self.mean = 0.0
        self.std = 0.0
        self.load_dataset()

    def load_dataset(self):
        """
        Extract the archive, load every complete path and compute pixel stats.

        Paths missing any of the ``num_frame`` expected frames are skipped with
        a printed warning; only loaded frames contribute to ``mean``/``std``.

        :raises ValueError: if no complete path was loaded, or the computed
            variance is clearly negative.
        """
        sum_pixels = np.float64(0)
        sum_pixels_squared = np.float64(0)
        pixel_count = 0
        with zipfile.ZipFile(self.data_path, "r") as main_zip:
            with tempfile.TemporaryDirectory() as temp_dir:
                main_zip.extractall(temp_dir)

                for bldg_zip_name in main_zip.namelist():
                    # Directory entries and non-zip members are not building
                    # archives; opening them with ZipFile would fail.
                    if not bldg_zip_name.lower().endswith(".zip"):
                        continue
                    # Label from the file name only, so inner zips nested in
                    # a folder still yield clean building names.
                    bldg_route_label = format_name(os.path.basename(bldg_zip_name))
                    route = bldg_route_label[-1]  # last character: route label
                    bldg = bldg_route_label[:-2]  # drop separator + route label
                    if bldg not in self.building_names:
                        self.building_names.append(bldg)

                    bldg_temp_dir = os.path.join(temp_dir, bldg_route_label)
                    bldg_zip_path = os.path.join(temp_dir, bldg_zip_name)
                    with zipfile.ZipFile(bldg_zip_path, "r") as bldg_zip:
                        bldg_zip.extractall(bldg_temp_dir)

                    for path_folder in tqdm(
                        sorted(os.listdir(bldg_temp_dir)),
                        desc=f"Loading {bldg}",
                        unit="path",
                    ):
                        if not path_folder.startswith("path"):
                            continue
                        images = self._load_path(
                            os.path.join(bldg_temp_dir, path_folder)
                        )
                        if images is None:
                            # Skip incomplete sequences instead of aborting
                            # the whole load.
                            print(
                                f"Warning: {bldg}, Route {route}, Path {path_folder} "
                                f"does not have {self.num_frame} images; skipping."
                            )
                            continue
                        sum_pixels += images.sum()
                        sum_pixels_squared += np.square(images).sum()
                        pixel_count += images.size
                        self.data.append(
                            {
                                "images": images,
                                "path": int(re.search(r"\d+", path_folder).group()),
                                "route": route,
                                "bldg": bldg,
                            }
                        )
        self.mean, self.std = self._compute_stats(
            sum_pixels, sum_pixels_squared, pixel_count
        )

    def _load_path(self, path_folder_full):
        """Return one path's frames stacked as (num_frame, H, W) float64, or None if any frame is missing."""
        frames = []
        for frame in range(self.num_frame):
            img_path = os.path.join(path_folder_full, f"panoramic_{frame:02d}.png")
            if not os.path.exists(img_path):
                return None
            frames.append(read_image(img_path).astype(np.float64))  # (1, H, W)
        return np.concatenate(frames, axis=0)

    @staticmethod
    def _compute_stats(sum_pixels, sum_pixels_squared, pixel_count):
        """
        Return (mean, std) computed from running pixel sums.

        Uses Var[x] = E[x^2] - E[x]^2. Small negative variances caused by
        floating-point cancellation are treated as a std of 0.

        :raises ValueError: if ``pixel_count`` is 0 or the variance is
            clearly negative.
        """
        if pixel_count == 0:
            raise ValueError("No complete paths were loaded; cannot compute mean/std.")
        mean = sum_pixels / pixel_count
        mean_of_squares = sum_pixels_squared / pixel_count
        variance = mean_of_squares - mean**2
        if variance >= 0:
            return mean, np.sqrt(variance)
        # Relative comparison: cancellation error scales with E[x^2].
        if np.isclose(mean_of_squares, mean**2):
            return mean, 0
        raise ValueError(f"Calculated negative variance: {variance}")

    def __getitem__(self, index):
        """
        Return one sample as a dict.

        ``t_imgs`` is the (optionally transformed) sequence, normalized with
        the dataset mean/std and shaped ``(num_seq, C, seq_len, H, W)``.
        ``imgs`` is the raw stored ``(num_frame, H, W)`` array. ``path``,
        ``route`` and ``bldg`` identify the sample.
        """
        item = self.data[index]
        imgs = item["images"]
        # (num_frame, H, W) -> (num_frame, 1, H, W), using the stored H and W.
        t_imgs = imgs.reshape(-1, 1, *imgs.shape[-2:])
        if self.transform:
            t_imgs = self.transform(t_imgs)
        # Transforms may return a list of frames; stack back into one array.
        t_imgs = torch.from_numpy(np.stack(t_imgs, axis=0)).float()
        # A zero std (constant pixels) would turn every value into inf/NaN.
        std = self.std if self.std else 1.0
        t_imgs = (t_imgs - self.mean) / std
        (C, H, W) = t_imgs[0].size()
        # Requires num_seq * seq_len == number of frames.
        t_imgs = t_imgs.view(self.num_seq, self.seq_len, C, H, W).transpose(
            1, 2
        )  # num_seq, C, seq_len, H, W

        return {
            "t_imgs": t_imgs,
            "imgs": imgs,
            "path": item["path"],
            "route": item["route"],
            "bldg": item["bldg"],
        }

    def __len__(self):
        return len(self.data)

    def find_indices(self, bldg_name, route, path_number):
        """
        Find the indices of the data items that match the given building, route, and path number.

        Building and route comparisons are case-insensitive.

        :param bldg_name: The name of the building (formatted as 'caracalla baths', for example).
        :param route: The route label (a single character like 'a', 'b', etc.).
        :param path_number: The path number (an integer).
        :return: A list of indices that match the criteria.
        """
        bldg_key = bldg_name.lower()
        route_key = route.lower()
        return [
            idx
            for idx, item in enumerate(self.data)
            if item["bldg"].lower() == bldg_key
            and item["route"].lower() == route_key
            and item["path"] == path_number
        ]

    def plot_one_sequence(self, idx):
        """
        Plot every raw frame of sample ``idx`` as its own grayscale figure.

        Reads the stored frames directly rather than through ``__getitem__``,
        so no (random) transform or normalization runs just to plot.
        """
        item = self.data[idx]
        title = f"{item['bldg']} route {item['route']} path {item['path']}"
        for img in item["images"]:  # each img is one (H, W) frame
            fig = plt.figure()
            ax = fig.add_subplot(1, 1, 1)
            ax.set_title(title)
            ax.imshow(img, cmap="gray", vmin=0, vmax=255)
            plt.show()

In [106]:
# Example usage: build the dataset from the default archive, then inspect the
# building names and the pixel statistics computed while loading.
dataset = BldgDataset()
for summary in (dataset.building_names, dataset.mean, dataset.std):
    print(summary)

Loading caracalla baths: 100%|██████████| 100/100 [00:00<00:00, 262.28path/s]
Loading caracalla baths: 100%|██████████| 100/100 [00:00<00:00, 252.62path/s]
Loading caracalla baths: 100%|██████████| 100/100 [00:00<00:00, 267.88path/s]
Loading caracalla baths: 100%|██████████| 100/100 [00:00<00:00, 258.91path/s]
Loading india institute of management: 100%|██████████| 100/100 [00:00<00:00, 249.92path/s]
Loading india institute of management: 100%|██████████| 100/100 [00:00<00:00, 269.39path/s]
Loading india institute of management: 100%|██████████| 100/100 [00:00<00:00, 252.80path/s]
Loading india institute of management: 100%|██████████| 100/100 [00:00<00:00, 244.84path/s]
Loading pantheon: 100%|██████████| 100/100 [00:00<00:00, 262.34path/s]
Loading pantheon: 100%|██████████| 100/100 [00:00<00:00, 243.83path/s]
Loading pantheon: 100%|██████████| 100/100 [00:00<00:00, 269.62path/s]
Loading trajans market: 100%|██████████| 100/100 [00:00<00:00, 251.69path/s]
Loading trajans market: 100%|█

['caracalla baths', 'india institute of management', 'pantheon', 'trajans market', 'trenton bath house']
99.77257469135803
92.27697774383148


In [107]:
# Fetch the first sample once and report its raw/normalized shapes and building.
first_sample = dataset[0]
print(first_sample["imgs"].shape)
print(first_sample["t_imgs"].shape)
print(first_sample["bldg"])

(30, 30, 60)
torch.Size([6, 1, 5, 30, 60])
caracalla baths


### Examine the dataset

In [None]:
dataset.plot_one_sequence(idx=0)  # one grayscale figure per frame of sample 0

# Test Image Augmentation

In [65]:
import torchvision
from torchvision import transforms
import torchvision.transforms.functional as F
import numbers
from PIL import Image

In [66]:
class ToTensor:
    """Convert every numpy array in a sequence to a float32 torch tensor."""

    def __call__(self, imgmap):
        tensors = []
        for array in imgmap:
            # copy() yields a fresh, positively-strided buffer (e.g. after
            # np.flip), which torch.from_numpy requires.
            tensors.append(torch.from_numpy(array.copy()).float())
        return tensors

In [108]:
class BrightnessJitter(object):  # 0.5 to 5 is a good range
    """
    Randomly scale image intensities by a brightness factor.

    On each call, with probability ``p`` the sequence is jittered. Either one
    factor is shared by every image (``consistent=True``) or each image gets
    an independent factor (``consistent=False``). Otherwise the input is
    returned unchanged.

    :param brightness: a non-negative number b (factor drawn uniformly from
        [max(1 - b, 0), 1 + b]) or a (min, max) pair with 0 <= min <= max.
    :param consistent: share one factor across all images in a call.
    :param p: probability of applying the jitter on a call.
    """

    def __init__(self, brightness=0, consistent=True, p=0.5):
        self.brightness = self._check_input(brightness, "brightness")
        self.consistent = consistent
        self.threshold = p

    def _check_input(
        self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True
    ):
        """
        Validate ``value`` and normalize it to a [min, max] factor range.

        A single non-negative number b becomes [center - b, center + b], with
        the lower end clipped at 0 when ``clip_first_on_zero``. A 2-element
        tuple/list must satisfy bound[0] <= min <= max <= bound[1].

        :return: the range, or None when it is exactly [center, center]
            (the jitter would be a no-op).
        :raises ValueError: for a negative number or an out-of-bounds range.
        :raises TypeError: for any other type or length.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(
                    "If {} is a single number, it must be non-negative.".format(name)
                )
            value = [center - value, center + value]
            if clip_first_on_zero:
                value[0] = max(value[0], 0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError("{} values should be between {}".format(name, bound))
        else:
            raise TypeError(
                "{} should be a single number or a list/tuple with length 2.".format(
                    name
                )
            )

        if value[0] == value[1] == center:
            value = None
        return value

    @staticmethod
    def get_params(brightness):
        """
        Draw a random factor and return a callable that applies it.

        :param brightness: a [min, max] range, or None for the identity.
        :return: a function ``img -> img * factor`` (or ``img -> img``).
        """
        if brightness is None:
            return lambda img: img
        brightness_factor = random.uniform(brightness[0], brightness[1])
        return lambda img: img * brightness_factor

    def __call__(self, imgmap):
        if random.random() >= self.threshold:  # don't jitter; return input as-is
            return imgmap
        if self.consistent:
            transform = self.get_params(self.brightness)
            return [transform(img) for img in imgmap]
        # Fresh factor per image.
        return [self.get_params(self.brightness)(img) for img in imgmap]

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        format_string += "brightness={0}".format(self.brightness)
        format_string += ")"
        return format_string

In [109]:
class RandomHorizontalShift:
    """
    Randomly roll each image horizontally, wrapping pixels around.

    Each image is handled independently. With probability ``p`` it is rolled
    right by a random whole number of pixels in [0, max_shift] along its last
    (width) axis, and pixels pushed off the right edge reappear on the left.
    Presumably intended for panoramic frames — confirm with the data owner.
    """

    def __init__(self, max_shift=30, p=0.5):
        """
        Args:
            max_shift (int): the maximum number of pixels for the horizontal shift.
            p (float): probability of applying the shift. Default is 0.5.
        """
        self.max_shift = max_shift
        self.p = p

    def __call__(self, imgmap):
        return [self.horizontal_shift(img) for img in imgmap]

    def horizontal_shift(self, img):
        """
        Shift the image horizontally by a random number of pixels and wrap around.
        Args:
            img (ndarray): input image, (H, W) or (C, H, W); width is the last axis.
        Returns:
            img (ndarray): the shifted image, or the input object itself when
            no shift is applied.
        """
        if random.random() >= self.p:
            return img
        shift = random.randint(0, self.max_shift)  # inclusive of max_shift
        # axis=-1 works for both (H, W) and (C, H, W) images.
        return np.roll(img, shift, axis=-1)

In [117]:
class RandomHorizontalFlip:
    """
    Randomly mirror images left-to-right (along the last axis).

    :param consistent: if True, one random draw decides for the whole
        sequence; if False, each image is flipped independently.
        (The original author noted "choose consistent to be false".)
    :param p: probability that a flip is applied, matching the ``p``
        convention of BrightnessJitter and RandomHorizontalShift.

    Flipped images are numpy views with negative strides; copy them (e.g.
    np.ascontiguousarray) before passing them to torch.from_numpy.
    """

    def __init__(self, consistent=True, p=0.5):
        self.consistent = consistent
        self.threshold = p

    def _should_flip(self):
        """Return True with probability ``self.threshold``."""
        return random.random() < self.threshold

    def __call__(self, imgmap):
        if self.consistent:
            if self._should_flip():
                # axis=-1 is the width axis for both (H, W) and (C, H, W).
                return [np.flip(img, axis=-1) for img in imgmap]
            return imgmap
        return [np.flip(img, axis=-1) if self._should_flip() else img for img in imgmap]

## Test of Image Augmentation

In [None]:
test_dataset = dataset[0]["imgs"][:10]  # first 10 raw frames of sample 0
brightness_jitter_transform = BrightnessJitter(
    brightness=[0.5, 5], consistent=False, p=1
)
brightness_jittered_dataset = brightness_jitter_transform(test_dataset)
print(len(brightness_jittered_dataset[0]))

# Factors above 1 can push values past 255; vmax=255 saturates them in the plot.
for frame_idx, jittered in enumerate(brightness_jittered_dataset):
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.set_title(f"img {frame_idx}")
    ax.imshow(jittered.reshape(30, 60), cmap="gray", vmin=0, vmax=255)
    plt.show()

In [None]:
# Add a channel axis so each frame is (1, H, W), matching what BldgDataset
# passes to its transform.
test_dataset = test_dataset.reshape(-1, 1, 30, 60)
horizontal_shift_transform = RandomHorizontalShift(max_shift=60)
shifted_dataset = horizontal_shift_transform(test_dataset)

for frame_idx, shifted in enumerate(shifted_dataset):
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.set_title(f"Img {frame_idx}")
    ax.imshow(shifted.reshape(30, 60), cmap="gray", vmin=0, vmax=255)
    plt.show()

In [None]:
# consistent=False: decide per frame, so the side-by-side comparison shows a
# mix of flipped and unflipped frames (consistent must be a bool, not an object).
horizontal_flip_transform = RandomHorizontalFlip(consistent=False, p=0.5)
flipped_dataset = horizontal_flip_transform(test_dataset)

for origin, flipped in zip(test_dataset, flipped_dataset):
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    axes[0].set_title("Original")
    axes[0].imshow(origin.reshape(30, 60), cmap="gray", vmin=0, vmax=255)

    axes[1].set_title("Flipped")
    axes[1].imshow(flipped.reshape(30, 60), cmap="gray", vmin=0, vmax=255)

    plt.show()

In [124]:
# Training-time augmentation pipeline, applied to each sample's sequence of
# (1, H, W) frames.
augmentation_steps = [
    RandomHorizontalFlip(consistent=True, p=0.5),
    BrightnessJitter(brightness=[0.5, 5], consistent=True, p=0.5),
    RandomHorizontalShift(max_shift=60, p=0.5),
]
train_transform = transforms.Compose(augmentation_steps)
# NOTE: despite its name, `test_transform` is a BldgDataset that applies
# `train_transform` in __getitem__; the name is kept because later cells use it.
test_transform = BldgDataset(transform=train_transform)

Loading caracalla baths: 100%|██████████| 100/100 [00:00<00:00, 254.94path/s]
Loading caracalla baths: 100%|██████████| 100/100 [00:00<00:00, 254.73path/s]
Loading caracalla baths: 100%|██████████| 100/100 [00:00<00:00, 255.97path/s]
Loading caracalla baths: 100%|██████████| 100/100 [00:00<00:00, 256.23path/s]
Loading india institute of management: 100%|██████████| 100/100 [00:00<00:00, 255.18path/s]
Loading india institute of management: 100%|██████████| 100/100 [00:00<00:00, 261.80path/s]
Loading india institute of management: 100%|██████████| 100/100 [00:00<00:00, 251.77path/s]
Loading india institute of management: 100%|██████████| 100/100 [00:00<00:00, 256.07path/s]
Loading pantheon: 100%|██████████| 100/100 [00:00<00:00, 254.27path/s]
Loading pantheon: 100%|██████████| 100/100 [00:00<00:00, 239.45path/s]
Loading pantheon: 100%|██████████| 100/100 [00:00<00:00, 260.91path/s]
Loading trajans market: 100%|██████████| 100/100 [00:00<00:00, 250.24path/s]
Loading trajans market: 100%|█

In [128]:
# Mean of one normalized, augmented sample (varies between runs because the
# augmentations are random).
augmented_sample = test_transform[5]
augmented_sample["t_imgs"].mean()

tensor(-0.1890)

# Model