In [16]:
import json
import operator
import os

import clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [17]:
# Load CLIP with the ResNet-50 ('RN50') image encoder; clip.load returns the model and its image preprocessing transform.
model, preprocess = clip.load('RN50')

In [15]:


config = dict(
    # Data paths (relative paths resolve against the notebook's working directory)
    train_path='../../data/way_splits/train_data.json',
    valid_seen_path='../../data/way_splits/valSeen_data.json',
    valid_unseen_path='../../data/way_splits/valUnseen_data.json',
    mesh2meters='../../data/floorplans/pix2meshDistance.json',
    image_dir='../../data/floorplans/',

    # Hyper-parameters: number of floor slots allocated per scan
    max_floors=5,

    # Image sizes, ordered as [channels, height, width]
    image_size=[3, 224, 224],
    original_image_size=[3, 700, 1200],
)

In [20]:
# Show the transform returned by clip.load; the LEDDataset preprocessing below is modeled on it.
preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x7f3b42ae53a0>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

In [4]:
class utils:
    """Namespace for small notebook helpers."""

    class DotDict(dict):
        """
        A dictionary that supports dot notation as well as item access.

        usage: d = DotDict() or d = DotDict({'val1':'first'})
        set attributes: d.val2 = 'second' or d['val2'] = 'second'
        get attributes: d.val2 or d['val2']

        A missing key raises ``AttributeError`` on dot access (``KeyError`` on
        item access), so ``hasattr``, ``getattr(d, name, default)`` and
        ``copy.deepcopy`` behave as they do for ordinary objects. Keys that
        collide with dict methods (e.g. ``'items'``) are only readable via
        ``d['items']``.
        """

        def __getattr__(self, name):
            # Only invoked when normal lookup fails, so real dict methods still win.
            try:
                return self[name]
            except KeyError:
                # Must be AttributeError: getattr()/hasattr()/copy treat only that as "absent".
                raise AttributeError(name) from None

        __setattr__ = dict.__setitem__

        def __delattr__(self, name):
            try:
                del self[name]
            except KeyError:
                raise AttributeError(name) from None

In [28]:
# Check the split-name rule used by LEDDataset: basename, minus '.json', up to the first '_'
config['valid_seen_path'].rsplit('/', 1)[-1][:-len('.json')].split('_')[0]

'valSeen'

In [30]:
# Spot-check one floorplan; the path comes from config rather than a user-specific absolute path.
sample_floorplan_path = os.path.join(config['image_dir'], 'floor_0', '1LXtFkjw3qL_0.png')
img = Image.open(sample_floorplan_path)
# PIL reports size as (width, height) while config stores [C, H, W]
assert img.size == (config['original_image_size'][2], config['original_image_size'][1])
img.size

(1200, 700)

In [None]:
# Create Dataset 

class LEDDataset(Dataset):
    def __init__(self, data_path, image_dir, config):
        """Load one data split plus floorplan metadata and build the image transform.

        Args:
            data_path: path to a ``<mode>_data.json`` file (e.g. ``train_data.json``,
                ``valSeen_data.json``). The file name up to the first ``_`` is
                stored as ``self.mode``.
            image_dir: directory containing the ``floor_<n>/`` floorplan folders.
            config: mapping accessed with ``[]``; this method reads
                ``'mesh2meters'`` (path to a JSON file) and ``'image_size'``
                (``[C, H, W]``).
        """
        # Gather the samples from {train/valSeen/valUnseen}_data.json
        self.data_path = data_path
        with open(self.data_path) as data_file:
            self.data = json.load(data_file)

        # Extract the mode (train, valSeen, valUnseen) from the file name
        data_stem = os.path.splitext(os.path.basename(self.data_path))[0]
        self.mode = data_stem.split('_')[0]

        # Store access to the floorplans directory
        self.image_dir = image_dir

        # Save the global config
        self.config = config

        # mesh2meters: per-scan, per-floor metadata read by gather_all_floors
        self.mesh2meters_path = self.config['mesh2meters']
        with open(self.mesh2meters_path) as mesh2meters_file:
            # json.load (not json.loads): the argument is an open file object
            self.mesh2meters = json.load(mesh2meters_file)

        # Transform modeled on CLIP's preprocess: bicubic resize of the shorter side,
        # center crop, RGB conversion, ToTensor, and CLIP's normalization statistics.
        # NOTE(review): with the 224x224 config and 1200x700 floorplans, the crop keeps
        # only the middle 224 of 384 resized columns (~42% of the width is discarded);
        # confirm this is acceptable for localization.
        crop_size = tuple(self.config['image_size'][1:])
        self.preprocess = transforms.Compose([
            transforms.Resize(min(crop_size), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(crop_size),
            # methodcaller (unlike a nested function) is picklable, so DataLoader
            # workers started with the 'spawn' method can receive this dataset.
            operator.methodcaller('convert', 'RGB'),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
        ])


    def gather_all_floors(self, index):
        """Load and preprocess every floorplan image for the scan of sample ``index``.

        The scan is ``self.data[index]['scan_name']``; its floors are the keys of
        ``self.mesh2meters[scan_name]``, and each floor image is read from
        ``<image_dir>/floor_<floor>/<scan_name>_<floor>.png``.

        Args:
            index: position of the sample in ``self.data``.

        Returns:
            all_maps: float tensor of shape ``(max_floors, *config['image_size'])``;
                slots beyond the scan's floor count stay zero.
            all_conversions: float tensor of shape ``(max_floors, 1)`` holding
                ``threeMeterRadius / 3.0`` per floor (zero for unused slots).
                NOTE(review): presumably pixels per meter; confirm.
        """
        # config is indexed with [] (as in __init__) so a plain dict config works too
        max_floors = self.config['max_floors']
        all_maps = torch.zeros(max_floors, *self.config['image_size'])
        all_conversions = torch.zeros(max_floors, 1)

        scan_name = self.data[index]['scan_name']
        floor_info = self.mesh2meters[scan_name]
        # NOTE(review): a scan with more than max_floors floors raises IndexError below.
        for slot, floor in enumerate(floor_info):
            image_path = os.path.join(self.image_dir, f'floor_{floor}', f'{scan_name}_{floor}.png')
            # Close the file handle; convert() returns an independent, loaded image.
            with Image.open(image_path) as floor_image:
                rgb_image = floor_image.convert('RGB')
            # Train and eval modes currently share the same preprocessing.
            all_maps[slot] = self.preprocess(rgb_image)[:3]
            all_conversions[slot, 0] = floor_info[floor]['threeMeterRadius'] / 3.0
        return all_maps, all_conversions
    
    def scale_location(self, location):
        return {
            'x': 
        }

    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        

             