In [1]:
import torch
import sys 
sys.path.append('../')
import clip
from PIL import Image
import json 
from torchvision import models, transforms
import clip
import src.clip_lingunet.fusion as fusion
import torch.nn as nn 
import torch.nn.functional as F


In [7]:
a = ["some words", "something else", "this is insane"]
tok = clip.tokenize(a)
tok.size()

torch.Size([3, 77])

In [8]:
tok_all_maps = torch.repeat_interleave(tok, 5, dim=0)

tok_all_maps.size()

torch.Size([15, 77])

In [9]:
tok_all_maps

tensor([[49406,   836,  2709,  ...,     0,     0,     0],
        [49406,   836,  2709,  ...,     0,     0,     0],
        [49406,   836,  2709,  ...,     0,     0,     0],
        ...,
        [49406,   589,   533,  ...,     0,     0,     0],
        [49406,   589,   533,  ...,     0,     0,     0],
        [49406,   589,   533,  ...,     0,     0,     0]])

In [2]:
class IdentityBlock(nn.Module):
    def __init__(self, in_planes, filters, kernel_size, stride=1, final_relu=True, batchnorm=True):
        super(IdentityBlock, self).__init__()
        self.final_relu = final_relu
        self.batchnorm = batchnorm

        filters1, filters2, filters3 = filters
        self.conv1 = nn.Conv2d(in_planes, filters1, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(filters1) if self.batchnorm else nn.Identity()
        self.conv2 = nn.Conv2d(filters1, filters2, kernel_size=kernel_size, dilation=1,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(filters2) if self.batchnorm else nn.Identity()
        self.conv3 = nn.Conv2d(filters2, filters3, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(filters3) if self.batchnorm else nn.Identity()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += x
        if self.final_relu:
            out = F.relu(out)
        return out

class ConvBlock(nn.Module):
    def __init__(self, in_planes, filters, kernel_size, stride=1, final_relu=True, batchnorm=True):
        super(ConvBlock, self).__init__()
        self.final_relu = final_relu
        self.batchnorm = batchnorm

        filters1, filters2, filters3 = filters
        self.conv1 = nn.Conv2d(in_planes, filters1, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(filters1) if self.batchnorm else nn.Identity()
        self.conv2 = nn.Conv2d(filters1, filters2, kernel_size=kernel_size, dilation=1,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(filters2) if self.batchnorm else nn.Identity()
        self.conv3 = nn.Conv2d(filters2, filters3, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(filters3) if self.batchnorm else nn.Identity()

        self.shortcut = nn.Sequential(
            nn.Conv2d(in_planes, filters3,
                      kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(filters3) if self.batchnorm else nn.Identity()
        )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        if self.final_relu:
            out = F.relu(out)
        return out

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),                                     # (Mohit): argh... forgot to remove this batchnorm
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),                                     # (Mohit): argh... forgot to remove this batchnorm
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)


    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

In [3]:
paths = {
    'floorplans': '../data/floorplans',
    'dialogs': '../data/way_splits'
}

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("RN50", device=device)

In [5]:
# img = preprocess(Image.open(paths['floorplans'] + '/floor_0/Uxmj2M2itWa_0.png'))[:3].unsqueeze(0).to(device)

dialogArray = [
            "What kind of room are you in?", 
            "I am standing just outside a bathroom near a purple chair.", 
            "Is there a red bed by you?", 
            "yes, the purple chair is between myself and the red bed."
        ]
text = ' '.join(dialogArray)
# tokens = clip.tokenize(text).to(device)
# 

In [33]:
preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x7f0f18f3f280>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

In [28]:
def gather_all_floors(preprocess):
    image_size = [
            3,
            224,
            224,
        ]

    # preprocess = transforms.Compose(
    #         [
    #             transforms.ToTensor(),
    #             transforms.Normalize(
    #                 mean=[0.485, 0.456, 0.406, 0.555],
    #                 std=[0.229, 0.224, 0.225, 0.222],
    #             ),
    #         ]
    #     )
    all_maps = torch.zeros(
            5,
            image_size[0],
            image_size[1],
            image_size[2],
        )
    all_conversions = torch.zeros(5, 1)
    sn = 'Uxmj2M2itWa'
    mesh2meters = json.load(open(paths['floorplans'] + "/pix2meshDistance.json"))
    floors = mesh2meters[sn].keys()
    for enum, f in enumerate(floors):
        img = Image.open(
            "{}/floor_{}/{}_{}.png".format(paths['floorplans'], f, sn, f)
        )
        img = img.resize((image_size[2], image_size[1]))

        all_maps[enum, :, :, :] = preprocess(img)[:3, :, :]
        all_conversions[enum, :] = mesh2meters[sn][f]["threeMeterRadius"] / 3.0
    return all_maps, all_conversions

In [29]:
maps, convs = gather_all_floors(preprocess)

In [42]:
maps = maps.unsqueeze(0); maps.size()

torch.Size([1, 5, 3, 224, 224])

In [43]:
maps = torch.repeat_interleave(maps, 6, dim=0); maps.size()


torch.Size([6, 5, 3, 224, 224])

In [44]:
maps.to(device);

In [9]:
maps = maps.unsqueeze(0).to(device)

In [45]:
class CLIPLingUNet(nn.Module):
    """ CLIP RN50 with U-Net skip connections """
    def __init__(self, args):
        super(CLIPLingUNet, self).__init__()
        self.args = args 
        # self.output_dim = self.args.output_dim
        self.output_dim = 1
        self.input_dim = 2048  # penultimate layer channel-size of CLIP-RN50
        self.device = self.args.device 
        self.batchnorm = True
        self.lang_fusion_type = 'mult'
        self.bilinear = True
        self.up_factor = 2 if self.bilinear else 1
        self.clip_rn50, self.preprocess = clip.load("RN50", device=self.args.device)

        self._build_decoder()


    def _build_decoder(self):
        # language
        self.lang_fuser1 = fusion.names[self.lang_fusion_type](input_dim=self.input_dim // 2)
        self.lang_fuser2 = fusion.names[self.lang_fusion_type](input_dim=self.input_dim // 4)
        self.lang_fuser3 = fusion.names[self.lang_fusion_type](input_dim=self.input_dim // 8)

        # CLIP encoder output -> 1024
        self.proj_input_dim = 512 if 'word' in self.lang_fusion_type else 1024
        self.lang_proj1 = nn.Linear(self.proj_input_dim, 1024)
        self.lang_proj2 = nn.Linear(self.proj_input_dim, 512)
        self.lang_proj3 = nn.Linear(self.proj_input_dim, 256)

        # vision
        self.conv1 = nn.Sequential(
            nn.Conv2d(self.input_dim, 1024, kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(True)
        )
        self.up1 = Up(2048, 1024 // self.up_factor, self.bilinear)

        self.up2 = Up(1024, 512 // self.up_factor, self.bilinear)

        self.up3 = Up(512, 256 // self.up_factor, self.bilinear)

        self.layer1 = nn.Sequential(
            ConvBlock(128, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            IdentityBlock(64, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            nn.UpsamplingBilinear2d(scale_factor=2),
        )

        self.layer2 = nn.Sequential(
            ConvBlock(64, [32, 32, 32], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            IdentityBlock(32, [32, 32, 32], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            nn.UpsamplingBilinear2d(scale_factor=2),
        )

        self.layer3 = nn.Sequential(
            ConvBlock(32, [16, 16, 16], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            IdentityBlock(16, [16, 16, 16], kernel_size=3, stride=1, batchnorm=self.batchnorm),
            nn.UpsamplingBilinear2d(scale_factor=2),
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(16, self.output_dim, kernel_size=1)
        )

    def encode_image(self, img):
        with torch.no_grad():
            img_encoding, img_im = self.clip_rn50.visual.prepool_im(img)
        return img_encoding, img_im

    def encode_text(self, x):
        with torch.no_grad():
            tokens = clip.tokenize([x]).to(self.device)
            
            text_feat = self.clip_rn50.encode_text(tokens)

        text_mask = torch.where(tokens==0, tokens, 1)  # [1, max_token_len]
        return text_feat, text_mask



    def forward(self, x, l):
        # x = self.preprocess(x)

        in_type = x.dtype
        in_shape = x.shape
        x = x[:,:3]  # select RGB
        x, im = self.encode_image(x)
        x = x.to(in_type)

        # encode text
        l_enc, l_mask = self.encode_text(l)
        l_input = l_enc
        l_input = l_input.to(dtype=x.dtype)

        # # encode image
        assert x.shape[1] == self.input_dim
        x = self.conv1(x)

        x = self.lang_fuser1(x, l_input, x2_mask=l_mask, x2_proj=self.lang_proj1)
        x = self.up1(x, im[-2])

        x = self.lang_fuser2(x, l_input, x2_mask=l_mask, x2_proj=self.lang_proj2)
        x = self.up2(x, im[-3])

        x = self.lang_fuser3(x, l_input, x2_mask=l_mask, x2_proj=self.lang_proj3)
        x = self.up3(x, im[-4])

        for layer in [self.layer1, self.layer2, self.layer3, self.conv2]:
            x = layer(x)

        # x = F.interpolate(x, size=(780, 455), mode='bilinear')
        return x

In [12]:
class args:
    device = 'cuda:0'

In [25]:
cli = CLIPLingUNet(args)
cli.to(device);

In [46]:
maps = maps.view(6*5, 3, 224, 224); maps.size()

torch.Size([30, 3, 224, 224])

In [48]:
maps = maps.to(device)

In [49]:
out = cli(maps, text)

In [50]:
out.size()

torch.Size([30, 1, 448, 448])

In [51]:
out = out.squeeze(1); out.size()

torch.Size([30, 448, 448])

In [52]:
out = out.view(6, 5, out.size()[-2], out.size()[-1])
out.size()

torch.Size([6, 5, 448, 448])

In [53]:
a = F.log_softmax(out.view(6, -1), 1).view(6, 5, 448, 448)

In [54]:
a.size()

torch.Size([6, 5, 448, 448])

In [21]:
resnet = models.resnet18(pretrained=True)
resnet.to(device)
modules = list(resnet.children())[:-4]
resnetPrePool = torch.nn.Sequential(*modules)
pass

In [22]:
x = resnetPrePool(img)

In [23]:
x.size()

torch.Size([1, 128, 28, 28])

In [12]:
img = img[:,:3]
img_encoding, img_im = model.visual.prepool_im(img)

upsample = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)


In [29]:
img_im[-3].size()

torch.Size([1, 512, 28, 28])

In [13]:
img_encoding.size()

torch.Size([1, 2048, 7, 7])

In [47]:
image_features = model.encode_image(img)
text_features = model.encode_text(tokens)

In [48]:
image_features.size(), text_features.size()

(torch.Size([1, 1024]), torch.Size([3, 1024]))