In [12]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import json 
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageDraw

from torchvision import transforms
from tqdm import tqdm
import cv2
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage.filters import gaussian_filter
import sys 
sys.path.append('../..')
import src.utils as utils
import src.clip as clip 
import yaml 
from src.clip_led.dataset import LEDDataset

import src.fusion as fusion
from src.blocks import Up, ConvBlock, IdentityBlock
%matplotlib inline 

In [2]:
config = {
    # Data Paths
    'train_path' : '../../data/way_splits/train_data.json',
    'valid_seen_path' : '../../data/way_splits/valSeen_data.json',
    'valid_unseen_path': '../../data/way_splits/valUnseen_data.json',
    'mesh2meters': '../../data/floorplans/pix2meshDistance.json',
    'image_dir': '../../data/floorplans/',

    'device': 'cpu',

    # Hyper Parameters
    'max_floors': 5,

    # Image Parameters
    'image_size': [3, 448, 448],
    # 'image_size': [3, 700, 1200],
    'original_image_size': [3, 700, 1200],
    'cropped_image_size': [3, 700, 800],
    'scaled_image_size': [3, 448, 448],


    'crop_translate_x': 200,
    'crop_translate_y': 0,
    'resize_scale_x': 448/800,
    'resize_scale_y': 448/700,
    'conversion_scale': 448/800

}

In [3]:
train_dataset = LEDDataset(config['valid_seen_path'], config['image_dir'], config)

In [15]:
model, _ = clip.load_clip("RN50")
clip_rn50 = clip.build_model(model.state_dict()).to('cpu')

RuntimeError: Method 'forward' is not defined.

In [None]:
class LEDModel(nn.Module):
    """ CLIP RN50 with U-Net skip connections """
    def __init__(self, args):
        super(LEDModel, self).__init__()
        self.args = args 
        # self.output_dim = self.args.output_dim
        self.num_maps = self.args.num_maps
        self.output_dim = self.args.output_dim
        self.input_dim = self.args.input_dim  # penultimate layer channel-size of CLIP-RN50
        self.device = self.args.device 
        self.batchnorm = self.args.batchnorm
        self.lang_fusion_type = self.args.lang_fusion_type
        self.bilinear = self.args.bilinear
        self.batch_size = self.args.batch_size
        self.up_factor = 2 if self.bilinear else 1
        self.clip_rn50, self.preprocess = clip.load("RN50", device=self.args.device)
        model, _ = clip.load_clip("RN50", device=self.device)
        self.clip_rn50 = clip.build_model(model.state_dict()).to(self.device)

        self._build_decoder()


    def _build_decoder(self):
        # language
        self.lang_fuser1 = fusion.names[self.lang_fusion_type](input_dim=self.input_dim // 2)
        self.lang_fuser2 = fusion.names[self.lang_fusion_type](input_dim=self.input_dim // 4)
        self.lang_fuser3 = fusion.names[self.lang_fusion_type](input_dim=self.input_dim // 8)

        # CLIP encoder output -> 1024
        self.proj_input_dim = 512 if 'word' in self.lang_fusion_type else 1024
        self.lang_proj1 = nn.Linear(self.proj_input_dim, 1024)
        self.lang_proj2 = nn.Linear(self.proj_input_dim, 512)
        self.lang_proj3 = nn.Linear(self.proj_input_dim, 256)

        # vision
        self.conv1 = nn.Sequential(
            nn.Conv2d(self.input_dim, 1024, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True)
        )
        self.up1 = Up(2048, 1024 // self.up_factor, self.bilinear)

        self.up2 = Up(1024, 512 // self.up_factor, self.bilinear)

        self.up3 = Up(512, 256 // self.up_factor, self.bilinear)

        self.layer1 = nn.Sequential(
            ConvBlock(128, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            IdentityBlock(64, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            nn.UpsamplingBilinear2d(scale_factor=2),
        )

        self.layer2 = nn.Sequential(
            ConvBlock(64, [32, 32, 32], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            IdentityBlock(32, [32, 32, 32], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            nn.UpsamplingBilinear2d(scale_factor=2),
        )

        self.layer3 = nn.Sequential(
            ConvBlock(32, [16, 16, 16], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            IdentityBlock(16, [16, 16, 16], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            nn.UpsamplingBilinear2d(scale_factor=2),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(16, self.output_dim, kernel_size=1)
        )

    def encode_image(self, img):
        with torch.no_grad():
            # The default CLIP function has been updated to be able to get intermediate prepools 
            img_encoding, img_im = self.clip_rn50.visual.prepool_im(img)
        return img_encoding, img_im

    def encode_text(self, x):
        with torch.no_grad():
            tokens = clip.tokenize(x, truncate=True).to(self.device)

            text_feat = self.clip_rn50.encode_text(tokens)
            text_feat = torch.repeat_interleave(text_feat, self.num_maps, 0)

        text_mask = torch.where(tokens==0, tokens, 1)  # [1, max_token_len]
        return text_feat, text_mask



    def forward(self, x, l):
        B, num_maps, C, H, W = x.size()
        x = x.view(B*num_maps, C, H, W)
        in_type = x.dtype
        in_shape = x.shape
        x = x[:,:3]  # select RGB
        x, im = self.encode_image(x)
        x = x.to(in_type)

        # encode text
        l_enc, l_mask = self.encode_text(l)
        l_input = l_enc
        l_input = l_input.to(dtype=x.dtype)

        # # encode image
        assert x.shape[1] == self.input_dim
        x = self.conv1(x)



        x = self.lang_fuser1(x, l_input, x2_mask=l_mask, x2_proj=self.lang_proj1)
        x = self.up1(x, im[-2])

        x = self.lang_fuser2(x, l_input, x2_mask=l_mask, x2_proj=self.lang_proj2)
        x = self.up2(x, im[-3])

        x = self.lang_fuser3(x, l_input, x2_mask=l_mask, x2_proj=self.lang_proj3)
        x = self.up3(x, im[-4])

        for layer in [self.layer1, self.layer2, self.layer3, self.conv2]:
            x = layer(x)

        # x = F.interpolate(x, size=(780, 455), mode='bilinear')
        h, w = x.size()[-2], x.size()[-1]
        x = x.squeeze(1)
        x = x.view(B, num_maps, x.size()[-2], x.size()[-1])
        x = F.log_softmax(x.view(B, -1), 1).view(B, num_maps, h, w)
        return x
