In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import json 
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageDraw

from torchvision import transforms
from tqdm import tqdm
import cv2
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage.filters import gaussian_filter
import sys 
sys.path.append('../..')
import src.utils as utils
import src.clip as clip 
import yaml
import math 
from tqdm import tqdm  
from src.clip_led.dataset import LEDDataset

import src.fusion as fusion
from src.blocks import Up, ConvBlock, IdentityBlock
%matplotlib inline 

  warn(f"Failed to load image Python extension: {e}")


In [2]:
config = {
    # Data Paths
    'train_path' : '../../data/way_splits/train_data.json',
    'valid_seen_path' : '../../data/way_splits/valSeen_data.json',
    'valid_unseen_path': '../../data/way_splits/valUnseen_data.json',
    'mesh2meters': '../../data/floorplans/pix2meshDistance.json',
    'image_dir': '../../data/floorplans/',
    'geodistance_file': '../../data/geodistance_nodes.json',

    'device': 'cpu',

    # Hyper Parameters
    'max_floors': 5,

    # Image Parameters
    'image_size': [3, 448, 448],
    # 'image_size': [3, 700, 1200],
    'original_image_size': [3, 700, 1200],
    'cropped_image_size': [3, 700, 800],
    'scaled_image_size': [3, 448, 448],


    'crop_translate_x': 200,
    'crop_translate_y': 0,
    'resize_scale_x': 448/800,
    'resize_scale_y': 448/700,
    'conversion_scale': 448/800,


    'lang_fusion_type': 'mult',
    'num_post_clip_channels': 2048, 
    'bilinear': True,
    'batch_norm': True, 
    'num_output_channels': 1
}

In [23]:
train_dataset[0]['dialogs'].size()

11.478775426276762


torch.Size([1, 77])

In [3]:
train_dataset = LEDDataset(config['valid_seen_path'], config['image_dir'], config)

In [4]:
train_dataset[0].keys()

11.478775426276762


dict_keys(['maps', 'target_maps', 'conversions', 'dialogs', 'scan_names', 'episode_ids', 'true_viewpoints'])

In [4]:
clip_rn50, preprocess = clip.load("RN50")

In [4]:
class LEDModel(nn.Module):
    """ CLIP RN50 with U-Net skip connections """
    def __init__(self, config):
        super(LEDModel, self).__init__()
        self.config = config 
        self.up_factor = 2 if self.config['bilinear'] else 1
        self.clip_rn50, self.preprocess = clip.load("RN50")

        # Freezing the CLIP model
        for param in self.clip_rn50.parameters():
            param.requires_grad = False

        self._build_decoder()


    def _build_decoder(self):
        # language
        self.lang_fuser1 = fusion.names[self.config['lang_fusion_type']](input_dim=self.config['num_post_clip_channels'] // 2)
        self.lang_fuser2 = fusion.names[self.config['lang_fusion_type']](input_dim=self.config['num_post_clip_channels'] // 4)
        self.lang_fuser3 = fusion.names[self.config['lang_fusion_type']](input_dim=self.config['num_post_clip_channels'] // 8)

        # CLIP encoder output -> 1024
        self.proj_input_dim = 512 if 'word' in self.config['lang_fusion_type'] else 1024
        self.lang_proj1 = nn.Linear(self.proj_input_dim, 1024)
        self.lang_proj2 = nn.Linear(self.proj_input_dim, 512)
        self.lang_proj3 = nn.Linear(self.proj_input_dim, 256)

        # vision
        self.conv1 = nn.Sequential(
            nn.Conv2d(self.config['num_post_clip_channels'], 1024, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True)
        )
        self.up1 = Up(2048, 1024 // self.up_factor, self.config['bilinear'])

        self.up2 = Up(1024, 512 // self.up_factor, self.config['bilinear'])

        self.up3 = Up(512, 256 // self.up_factor, self.config['bilinear'])

        self.layer1 = nn.Sequential(
            ConvBlock(128, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.config['batch_norm']),
            IdentityBlock(64, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.config['batch_norm']),
            nn.UpsamplingBilinear2d(scale_factor=2),
        )

        self.layer2 = nn.Sequential(
            ConvBlock(64, [32, 32, 32], kernel_size=3, stride=1, batchnorm=self.config['batch_norm']),
            IdentityBlock(32, [32, 32, 32], kernel_size=3, stride=1, batchnorm=self.config['batch_norm']),
            nn.UpsamplingBilinear2d(scale_factor=2),
        )

        self.layer3 = nn.Sequential(
            ConvBlock(32, [16, 16, 16], kernel_size=3, stride=1, batchnorm=self.config['batch_norm']),
            IdentityBlock(16, [16, 16, 16], kernel_size=3, stride=1, batchnorm=self.config['batch_norm']),
            nn.UpsamplingBilinear2d(scale_factor=1),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(16, self.config['num_output_channels'], kernel_size=1)
        )

    def encode_image(self, img):
        with torch.no_grad():
            # The default CLIP function has been updated to be able to get intermediate prepools 
            img_encoding, img_im = self.clip_rn50.visual.prepool_im(img)
        return img_encoding, img_im

    def encode_text(self, x):
        x = x.type(torch.LongTensor)
        with torch.no_grad():
            text_feat = self.clip_rn50.encode_text(x)
            text_feat = torch.repeat_interleave(text_feat, self.config['max_floors'], 0)

        text_mask = torch.where(x==0, x, 1)  # [1, max_token_len]
        return text_feat, text_mask



    def forward(self, x, l):
        B, num_maps, C, H, W = x.size()
        x = x.view(B*num_maps, C, H, W)
        in_type = x.dtype
        in_shape = x.shape
        x = x[:,:3]  # select RGB
        x, im = self.encode_image(x)
        x = x.to(in_type)

        # encode text
        l_enc, l_mask = self.encode_text(l)
        l_input = l_enc
        l_input = l_input.to(dtype=x.dtype)

        # # encode image
        assert x.shape[1] == self.config['num_post_clip_channels']
        # print('after CLIP encoding: ', x.size())
        x = self.conv1(x)

        # print('after convolution after CLIP encoding: ', x.size())


        x = self.lang_fuser1(x, l_input, x2_mask=l_mask, x2_proj=self.lang_proj1)
        # print('after lang_fuser 1: ', x.size())
        x = self.up1(x, im[-2])
        # print('after up after lang_fuser 1: ', x.size())

        x = self.lang_fuser2(x, l_input, x2_mask=l_mask, x2_proj=self.lang_proj2)
        # print('after lang_fuser 2: ', x.size())
        x = self.up2(x, im[-3])
        # print('after up after lang_fuser 2: ', x.size())

        x = self.lang_fuser3(x, l_input, x2_mask=l_mask, x2_proj=self.lang_proj3)
        # print('after lang_fuser 3: ', x.size())
        x = self.up3(x, im[-4])
        # print('after up after lang_fuser 3: ', x.size())

        for enum, layer in enumerate([self.layer1, self.layer2, self.layer3, self.conv2]):
            x = layer(x)
            # print(f'after layer {enum} after all lang_fusions', x.size())
        
        h, w = x.size()[-2], x.size()[-1]
        x = x.squeeze(1)
        x = x.view(B, num_maps, x.size()[-2], x.size()[-1])
        x = F.log_softmax(x.view(B, -1), 1).view(B, num_maps, h, w)
        return x


In [5]:
led_clip = LEDModel(config)

In [9]:
train_dataset[0]['maps'].size()

11.478775426276762


torch.Size([5, 3, 448, 448])

In [11]:
preds = led_clip(train_dataset[0]['maps'].unsqueeze(0), train_dataset[0]['dialogs'])

11.478775426276762
11.478775426276762


In [6]:
maps, dialogs = train_dataset[0]['maps'], train_dataset[0]['dialogs']
maps = maps.unsqueeze(0)
maps.size()

11.478775426276762
11.478775426276762


torch.Size([1, 5, 3, 448, 448])

In [37]:
out.size()

torch.Size([1, 5, 448, 448])

In [None]:
# Training Parameters 

loss_fn = nn.KLDivLoss(reduction="batchmean")
optimizer = torch.optim.AdamW(led_clip.parameters(), lr=config['lr'], betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)
scheduler = torch.optim.ReduceLROnPlateau(optimizer, 'min')
scaler = torch.cuda.amp.GradScaler()


In [14]:

def snap_to_grid(geodistance_nodes, node2pix, sn, pred_coord, conversion, level):
    min_dist = math.inf
    best_node = ""
    for node in node2pix[sn].keys():
        if node2pix[sn][node][2] != int(level) or node not in geodistance_nodes:
            continue
        target_coord = [node2pix[sn][node][0][1], node2pix[sn][node][0][0]]
        dist = np.sqrt(
            (target_coord[0] - pred_coord[0]) ** 2
            + (target_coord[1] - pred_coord[1]) ** 2
        ) / (conversion)
        if dist.item() < min_dist:
            best_node = node
            min_dist = dist.item()
    return best_node


def distance_from_pixels(config, preds, mesh_conversions, scan_names, true_viewpoints, episode_ids, mode):
    """Calculate distances between model predictions and targets within a batch.
    Takes the propablity map over the pixels and returns the geodesic distance"""
    node2pix = json.load(open(config['image_dir'] + "allScans_Node2pix.json"))
    geodistance_nodes = json.load(open(config['geodistance_file']))
    distances, episode_predictions = [], []
    print(scan_names)
    for pred, conversion, sn, tv, id in zip(
        preds, mesh_conversions, scan_names, true_viewpoints, episode_ids
    ):

        total_floors = len(set([v[2] for k, v in node2pix[sn].items()]))
        pred = nn.functional.interpolate(
            pred.unsqueeze(1), (700, 1200), mode="bilinear"
        ).squeeze(1)[:total_floors]
        pred_coord = np.unravel_index(pred.argmax(), pred.size())
        convers = conversion.view(config['max_floors'], 1, 1)[pred_coord[0].item()]
        pred_viewpoint = snap_to_grid(
            geodistance_nodes[sn],
            node2pix,
            sn,
            [pred_coord[1].item(), pred_coord[2].item()],
            convers,
            pred_coord[0].item(),
        )
        if mode != "test":
            dist = geodistance_nodes[sn][tv][pred_viewpoint]
            distances.append(dist)
        episode_predictions.append([id, pred_viewpoint])
    return distances, episode_predictions

In [7]:
train_loader = DataLoader(train_dataset, batch_size=2)

In [8]:
for data in train_loader:
    maps = data['maps']
    target_maps = data['target_maps']
    conversions = data['conversions']
    dialogs = data['dialogs']
    dialogs = dialogs.squeeze(1)
    print(dialogs.size())

    preds = led_clip(maps, dialogs)
    break 
    
    

11.478775426276762
16.93066725972494
torch.Size([2, 77])


In [15]:
a, b = distance_from_pixels(config, preds, data['conversions'], data['scan_names'], data['true_viewpoints'], data['episode_ids'], train_dataset.mode )

['VLzqgDo317F', 'VVfe2KiqLaN']


In [18]:
def accuracy(dists, threshold=3):
    """Calculating accuracy at 3 meters by default"""
    return np.mean((torch.tensor(dists) <= threshold).int().numpy())

In [None]:
# Training Loop 


def train(train_loader, valid_seen_loader, valid_unseen_loader, epochs, model, loss_fn, optimizer, scaler, scheduler):

    # Training 
    for enum, data in enumerate(tqdm(train_loader)):
        
        optimizer.zero_grad()

        maps = data['maps']
        target_maps = data['target_maps']
        conversions = data['conversions']
        dialogs = data['dialogs']

        with torch.autocast():
            preds = model(maps, dialogs)
            loss = loss_fn(preds, target_maps)

        scaler.scale(loss).backward()

        scaler.step(optimizer)

        scaler.update()

        le, ep = distance_from_pixels(
            args, preds.detach().cpu(), batch_conversions, info_elem, mode
        )
        return loss, accuracy(le, 0), accuracy(le, 5), ep

        


        
    
