In [2]:
import json
import os
import pickle
import warnings
from typing import Tuple, Union

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

import clip
from clip.model import AttentionPool2d
from clip.model import ModifiedResNet
from clip.model import CLIP
from clip.model import convert_weights

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Dataset

In [3]:
# Root of the RefCOCOg dataset (expects annotations/ and images/ sub-folders).
# Set the REFCOCOG_PATH environment variable to point elsewhere instead of editing this line.
refcocog_path = os.environ.get("REFCOCOG_PATH", "E:/DL_Datasets/refcocog")

In [4]:
# Load the referring expressions (pickle) and the instances file (JSON).
# `with` closes each file as soon as it has been parsed.
# NOTE: pickle.load can execute arbitrary code -- only load refs files from a trusted source.
with open(refcocog_path + "/annotations/refs(umd).p", "rb") as refs_file:
    pick = pickle.load(refs_file)
with open(refcocog_path + "/annotations/instances.json", "rb") as instances_file:
    jsn = json.load(instances_file)

In [5]:
# Index the entries of instances.json by their 'id' field for direct lookup.
images_set = {image['id']: image for image in jsn['images']}
annotations_set = {annotation['id']: annotation for annotation in jsn['annotations']}

# Categories are not needed below; if they are, index them the same way:
# categories_set = {category['id']: category for category in jsn['categories']}

**Build dataset splits**

In [46]:
train_data, train_label       = [], []
validate_data, validate_label = [], []
test_data, test_label         = [], []

# Destination list for each split name; refs with any other split are dropped.
split_targets = {'train': train_data, 'test': test_data, 'val': validate_data}

for ref in pick:
    image_path = f"{refcocog_path}/images/{images_set[ref['image_id']]['file_name']}"
    ref_bbox = annotations_set[ref['ann_id']]['bbox']
    # One sample per sentence: all sentences of a ref share the same image and box.
    rows = [[image_path, sentence['sent'], ref_bbox] for sentence in ref['sentences']]

    destination = split_targets.get(ref['split'])
    if destination is not None:
        destination.extend(rows)

print(f"train {len(train_data)}, validation {len(validate_data)}, test {len(test_data)}")

train 80512, validation 4896, test9602


**Display an image with a bounding box**

In [7]:
def view_image_with_bbox(image_path, prompt, bbox):
    """Display an image with a box drawn on it and the prompt as the title.

    Args:
        image_path: path of an image file readable by PIL.
        prompt: text shown as the plot title.
        bbox: [x, y, width, height] box in pixel coordinates.
    """
    with Image.open(image_path) as raw:
        # np.array (not np.asarray) makes a writable copy: the array returned by
        # np.asarray(PIL image) is read-only, and cv2.rectangle draws in place.
        image = np.array(raw.convert("RGB"))

    p1 = (int(bbox[0]), int(bbox[1]))
    p2 = (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3]))

    cv2.rectangle(image, p1, p2, (0, 255, 255), 3)

    plt.imshow(image)
    plt.title(prompt)
    plt.show()

# Model

In [8]:
# Pretrained CLIP with the RN50 image backbone, plus its matching image transform,
# placed on the notebook-wide `device` (the same device clip.load picks by default).
clip_model, clip_preprocess = clip.load("RN50", device=device)

def linear(x, weight, bias):
    """Affine map ``x @ weight.T + bias`` applied over the last dimension of `x`.

    Args:
        x: input tensor whose last dimension matches weight's second dimension.
        weight: (out_features, in_features) matrix.
        bias: tensor added to the product (in place on the fresh product tensor).
    """
    out = torch.matmul(x, weight.t())
    out.add_(bias)
    return out

class AttentionSpatial2d(AttentionPool2d):
    """Attention-pool layer edited to keep spatial resolution.

    Instead of pooling, each spatial position is passed independently through
    the layer's `v_proj` and then `c_proj` linear layers. The other parameters
    inherited from AttentionPool2d are not used by `forward`.
    """
    def __init__(self,
                 spacial_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 output_dim: int = None):
        super().__init__(spacial_dim, embed_dim, num_heads, output_dim)

    def forward(self, x):
        """Project an (N, C, H, W) feature map to (N, D, H, W), D being c_proj's output size."""
        batch, channels, height, width = x.shape
        # (N, C, H, W) -> (N, C, H*W) -> (H*W, N, C): one token per spatial position.
        tokens = x.reshape(batch, channels, height * width).permute(2, 0, 1)
        tokens = linear(tokens, self.v_proj.weight, self.v_proj.bias)
        tokens = linear(tokens, self.c_proj.weight, self.c_proj.bias)
        # (H*W, N, D) -> (N, D, H*W) -> (N, D, H, W)
        return tokens.permute(1, 2, 0).reshape(batch, -1, height, width)

class ModifiedSpatialResNet(ModifiedResNet):
    """CLIP's ModifiedResNet whose `attnpool` is replaced by AttentionSpatial2d,
    so the image encoder yields a spatial feature map rather than a pooled vector."""
    def __init__(self,
                 layers,
                 output_dim,
                 heads,
                 input_resolution=224,
                 width=64):
        super().__init__(layers, output_dim, heads, input_resolution, width)

        # spacial_dim = input_resolution // 32 and embed_dim = width * 32 for the replacement layer.
        spatial_size = input_resolution // 32
        feature_channels = width * 32
        self.attnpool = AttentionSpatial2d(spatial_size, feature_channels, heads, output_dim)

class CLIPSpatialResNet(CLIP):
    """Modified spatial CLIP whose image branch returns per-pixel embeddings.

    The visual tower is replaced by ModifiedSpatialResNet; `forward` takes
    images only and returns a feature map at the input's spatial resolution.
    """
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int):

        super().__init__(embed_dim, image_resolution, vision_layers, vision_width,
                         vision_patch_size, context_length, vocab_size,
                         transformer_width, transformer_heads, transformer_layers)

        # Override the visual model
        vision_heads = vision_width * 32 // 64
        self.visual = ModifiedSpatialResNet(layers=vision_layers,
                                            output_dim=embed_dim,
                                            heads=vision_heads,
                                            input_resolution=image_resolution,
                                            width=vision_width)

    def forward(self, image):
        """Compute dense image features.

        Args:
            image: image batch indexed as (..., H, W) -- presumably (N, 3, H, W)
                preprocessed CLIP input; TODO confirm.

        Returns:
            Feature map of shape (N, D, H, W), bilinearly upsampled from the
            backbone's 1/32-resolution output.
        """
        image = image.type(self.dtype)

        # Zero-pad every side by 64 px (a multiple of the /32 backbone stride) so
        # border pixels get context; the padding is cropped away again below.
        pad = 64
        padded_image = F.pad(image, (pad, pad, pad, pad), "constant", 0)

        # get features
        features = self.encode_image(padded_image)
        target_size_h, target_size_w = image.size(-2) // 32, image.size(-1) // 32

        # Crop the central region that corresponds to the unpadded image.
        # (Height uses target_size_h; previously target_size_w was used for both axes,
        # which mis-cropped non-square inputs.)
        pad_h = (features.size(-2) - target_size_h) // 2
        pad_w = (features.size(-1) - target_size_w) // 2
        features = features[:, :, pad_h:pad_h + target_size_h, pad_w:pad_w + target_size_w]

        # Upsample back to the input resolution. F.upsample is deprecated; with
        # mode="bilinear" its align_corners=None resolved to False.
        features = F.interpolate(features, size=(image.size(-2), image.size(-1)),
                                 mode="bilinear", align_corners=False)

        return features
    

def build_feature_extractor_model(clip_model):
    """Instantiate the modified spatial CLIP model and transfer `clip_model`'s weights.

    Architecture hyper-parameters are recovered from tensor shapes and key
    names in the source state dict (ResNet variant of CLIP).

    Args:
        clip_model: a loaded CLIP model with a ModifiedResNet image tower.

    Returns:
        CLIPSpatialResNet on `device`. On CPU it is float32; otherwise the layers
        touched by `convert_weights` are half precision.
    """
    # transfer learning: extract weights from CLIP
    clip_state_dict = clip_model.state_dict()
    # Blocks per ResNet stage, counted from keys like "visual.layer2.<block>.<...>".
    vision_layers = tuple(
        len(set(k.split(".")[2] for k in clip_state_dict if k.startswith(f"visual.layer{b}")))
        for b in [1, 2, 3, 4])
    vision_width = clip_state_dict["visual.layer1.0.conv1.weight"].shape[0]
    output_width = round(
        (clip_state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)

    vision_patch_size = None
    image_resolution = output_width * 32

    embed_dim = clip_state_dict["text_projection"].shape[1]
    context_length = clip_state_dict["positional_embedding"].shape[0]
    vocab_size = clip_state_dict["token_embedding.weight"].shape[0]
    transformer_width = clip_state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(
        k.split(".")[2] for k in clip_state_dict if k.startswith("transformer.resblocks")))

    model = CLIPSpatialResNet(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers).to(device)

    # Drop non-tensor metadata entries if present.
    for key in ["input_resolution", "context_length", "vocab_size"]:
        clip_state_dict.pop(key, None)

    # Only convert to fp16 off-CPU. On CPU, converting first would round every
    # copied float32 weight to half precision before model.float() casts it back.
    if device != 'cpu':
        convert_weights(model)

    # strict=False tolerates key mismatches; report them rather than ignore them silently.
    incompatible = model.load_state_dict(clip_state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"CLIP weight transfer mismatch: missing={incompatible.missing_keys}, "
            f"unexpected={incompatible.unexpected_keys}")
    # model.eval()
    if device == 'cpu':
        model.float()
    return model

### BBox extraction

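This code is borrowed from the paper's reference implementation. It provides a brute-force box search (`BruteForceBoxSearch`) and two objectives for scoring candidate boxes (`SumAreaObjective`, `FractionAreaObjective`); together they extract a bounding box from a heatmap.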

In [9]:
class BruteForceBoxSearch():
    """Exhaustive box search over a 2-D score map, accelerated with PyTorch.

    Two passes: a coarse search over every box of a bilinearly downsampled map,
    then a full-resolution search restricted to a +/- `downsample` pixel window
    around each coordinate of the rescaled coarse box. All candidates of a pass
    are materialized at once, so very large windows may run out of memory.
    """

    def __init__(self, downsample=8):
        self.downsample = downsample

    def __call__(self, matrix, objective_cls):
        """Return the best box as a numpy array [x1, y1, x2, y2] (inclusive pixels).

        Args:
            matrix: 2-D numpy score array (converted with torch.from_numpy).
            objective_cls: callable taking a 2-D tensor and returning an object
                whose ``eval(boxes)`` scores an (N, 4) tensor of boxes.
        """
        full_h, full_w = matrix.shape[:2]
        full_matrix = torch.from_numpy(matrix).to(device)
        step = self.downsample

        # Coarse pass on the downsampled map, over every possible box.
        self.h = full_h // step
        self.w = full_w // step
        self.matrix = F.interpolate(
            full_matrix[None, None], (self.h, self.w), mode='bilinear')[0, 0]
        self.objective = objective_cls(self.matrix)
        coarse = self.search([[0, self.w - 1], [0, self.h - 1],
                              [0, self.w - 1], [0, self.h - 1]])
        coarse *= step
        ax1, ay1, ax2, ay2 = coarse

        # Fine pass at full resolution, in a window around the rescaled anchor.
        self.matrix = full_matrix
        self.h, self.w = full_h, full_w
        self.objective = objective_cls(self.matrix)
        window = [
            [max(0, ax1 - step), min(ax1 + step, self.w - 1)],
            [max(0, ay1 - step), min(ay1 + step, self.h - 1)],
            [max(0, ax2 - step), min(ax2 + step, self.w - 1)],
            [max(0, ay2 - step), min(ay2 + step, self.h - 1)],
        ]
        return self.search(window)

    def search(self, intervals):
        """Score every valid box inside `intervals` and return the best one.

        Args:
            intervals: [[x1_min, x1_max], [y1_min, y1_max], [x2_min, x2_max],
                [y2_min, y2_max]] with inclusive bounds.
        Returns:
            numpy array [x1, y1, x2, y2] of the highest-scoring candidate.
        """
        ranges = [torch.arange(lo, hi + 1).to(device) for lo, hi in intervals]
        candidates = torch.cartesian_prod(*ranges)
        x1, y1, x2, y2 = candidates.transpose(0, 1)
        # Keep in-bounds boxes with strictly positive width and height.
        valid = ((x1 >= 0) & (y1 >= 0) & (x2 < self.w) &
                 (y2 < self.h) & (x2 > x1) & (y2 > y1))
        candidates = candidates[valid]
        best = candidates[self.objective.eval(candidates).argmax()]
        return best.cpu().numpy()


class SumAreaObjective():
    """f(x) = (sum inside x) - alpha * (normalized area of x)

    Unlike FractionAreaObjective, the box sum is NOT divided by the matrix total.
    Previously this class was a verbatim copy of FractionAreaObjective and so
    computed the fraction, contradicting the formula above.
    """

    def __init__(self, alpha):
        self.alpha = alpha

    def __call__(self, matrix):
        """Bind to a 2-D tensor, precompute its integral image and return self."""
        self.matrix = matrix
        self.h, self.w = matrix.size(0), matrix.size(1)
        # for normalizing area
        self.area = float(self.h * self.w)
        # Integral image; padded with a zero row/column on top/left so that box
        # sums need no boundary special-casing.
        self.sums = self.matrix.cumsum(1).cumsum(0)
        self.sums = F.pad(self.sums, (1, 0, 1, 0))
        # total sum, kept for _compute_frac
        self.total_sum = self.matrix.sum().item()
        return self

    def eval(self, boxes):
        """Score an (N, 4) tensor of [x1, y1, x2, y2] boxes (inclusive coords)."""
        box_sum = self._compute_sum(boxes)
        area = self._compute_area(boxes)
        return box_sum - self.alpha * area

    def _compute_sum(self, boxes):
        # boxes is Nx4, each is [x1, y1, x2, y2]; assume all boxes are valid
        x1, y1, x2, y2 = boxes.transpose(0, 1)
        bottom_right = self.sums[y2+1, x2+1]
        bottom_left = self.sums[y2+1, x1]
        top_right = self.sums[y1, x2+1]
        top_left = self.sums[y1, x1]
        return bottom_right - bottom_left - top_right + top_left

    def _compute_frac(self, boxes):
        # Box sum as a fraction of the matrix total (kept for callers of the old helper).
        return self._compute_sum(boxes) / (self.total_sum + 1e-8)

    def _compute_area(self, boxes):
        # boxes is Nx4, each is [x1, y1, x2, y2]; assume all boxes are valid
        x1, y1, x2, y2 = boxes.transpose(0, 1)
        return (x2 - x1 + 1).float() * (y2 - y1 + 1).float() / self.area


class FractionAreaObjective():
    """f(x) = (fraction of sum inside x) - alpha * (normalized area of x)

    Call the instance with a 2-D tensor to bind it, then score boxes with `eval`.
    """

    def __init__(self, alpha):
        self.alpha = alpha

    def __call__(self, matrix):
        """Bind to `matrix`, precompute its integral image and return self."""
        self.matrix = matrix
        self.h, self.w = matrix.size(0), matrix.size(1)
        self.area = float(self.h * self.w)
        # Integral image with a zero row/column prepended so box sums need no bounds checks.
        integral = self.matrix.cumsum(1).cumsum(0)
        self.sums = F.pad(integral, (1, 0, 1, 0))
        self.total_sum = self.matrix.sum().item()
        return self

    def eval(self, boxes):
        """Score an (N, 4) tensor of [x1, y1, x2, y2] boxes (inclusive coords)."""
        return self._compute_frac(boxes) - self.alpha * self._compute_area(boxes)

    def _compute_frac(self, boxes):
        # Share of the matrix total that lies inside each (assumed valid) box.
        left, top, right, bottom = boxes.transpose(0, 1)
        s = self.sums
        inside = s[bottom+1, right+1] - s[bottom+1, left] - s[top, right+1] + s[top, left]
        return inside / (self.total_sum + 1e-8)

    def _compute_area(self, boxes):
        # Box area as a fraction of the whole matrix (assumes valid boxes).
        left, top, right, bottom = boxes.transpose(0, 1)
        widths = (right - left + 1).float()
        heights = (bottom - top + 1).float()
        return widths * heights / self.area

### Feature Extractor

In [32]:
class ResNetHighResV2(nn.Module):
    """Feature extractor that uses CLIP as its foundation model.

    Produces one text-conditioned heatmap per (image, prompt) pair from the
    dense image features of the spatial CLIP model.
    """
    def __init__(self, clip_preprocess, tokenize, temperature=0.1, alpha=0.7, remap_heatmaps=True,
                 base_model=None):
        """
        Args:
            clip_preprocess: transform applied to each PIL image before encoding.
            tokenize: callable turning a list of strings into token ids (e.g. clip.tokenize).
            temperature: stored on the instance. NOTE(review): get_heatmaps uses a
                fixed 0.025 and does not read this value.
            alpha: area penalty used by box_from_heatmap.
            remap_heatmaps: if True, min-max normalize each heatmap to [0, 1].
            base_model: CLIP model whose weights are transferred; defaults to the
                notebook-level `clip_model`.
        """
        super().__init__()
        source_model = clip_model if base_model is None else base_model
        self.spatial_model = build_feature_extractor_model(source_model)
        self.clip_preprocess = clip_preprocess
        self.tokenize = tokenize
        self.temperature = temperature
        self.remap_heatmaps = remap_heatmaps
        self.alpha = alpha

    def get_image_features(self, images):
        """Preprocess PIL images with the configured transform and encode them densely."""
        # Use the transform passed to __init__ (previously the global was used).
        images = [self.clip_preprocess(image) for image in images]
        images = torch.stack(images).to(device)
        return self.spatial_model(images)

    def get_text_features(self, texts):
        """Tokenize and encode a list of prompts."""
        tokenized_texts = self.tokenize(texts).to(device)
        return self.spatial_model.encode_text(tokenized_texts)

    def box_from_heatmap(self, heatmap):
        """Box maximising (heat fraction inside) - alpha * (area fraction).

        Args:
            heatmap: 2-D numpy array.
        Returns:
            float32 numpy array of shape (1, 4): [x1, y1, x2, y2] in heatmap pixels.
        """
        # A normalised, alpha-shifted copy of the heatmap used to be built here but
        # was never used; FractionAreaObjective computes that objective directly.
        search = BruteForceBoxSearch()
        objective = FractionAreaObjective(alpha=self.alpha)
        box = search(heatmap, objective)
        return box.astype(np.float32)[None]

    def get_heatmaps(self, image_features, text_features):
        """Sharpened cosine-similarity maps between dense image and text features.

        Args:
            image_features: (N, C, H, W) dense embeddings.
            text_features: (N, C) embeddings, one per image.
        Returns:
            float32 tensor of shape (N, H, W).
        """
        # Compute in float32: with fp16 features exp(sim / 0.025) overflows
        # (fp16 max ~65504) once a similarity exceeds ~0.28.
        image_features = image_features.float()
        text_features = text_features.float()
        # Out-of-place L2 normalisation so the caller's tensors are not modified.
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)
        heatmaps = (image_features * text_features[:, :, None, None]).sum(1)
        # NOTE(review): fixed sharpening constant; self.temperature is not used.
        heatmaps = torch.exp(heatmaps / 0.025)
        if self.remap_heatmaps:
            # Per-sample min-max normalisation; the clamp avoids 0/0 on a constant map.
            heatmaps = heatmaps - heatmaps.amin(dim=(1, 2), keepdim=True)
            heatmaps = heatmaps / heatmaps.amax(dim=(1, 2), keepdim=True).clamp_min(1e-12)
        return heatmaps

    def forward(self, images, texts):
        """Return detached CPU float32 heatmaps (N, H, W) for paired images and prompts."""
        image_features = self.get_image_features(images)
        text_features = self.get_text_features(texts)
        heatmaps = self.get_heatmaps(image_features, text_features)
        return heatmaps.cpu().detach().float()


In [33]:
class HeatmapToBox(nn.Module):
    """Custom model to regress a bounding box from a heatmap.

    Three convolutions with no activations between them shrink the heatmap.
    A sigmoid MLP then maps the result to 4 values.
    NOTE(review): Linear(676, ...) fixes the input size; a 224x224 heatmap gives
    52x52 after the convolutions and 2704 / 4 = 676 after pooling.
    """
    def __init__(self):
        super().__init__()
        # Submodule order/indices are part of the state_dict keys -- keep them stable.
        self.conv = nn.Sequential(
            nn.Conv2d(1, 1, 8, stride=1),
            nn.Conv2d(1, 1, 6, stride=2),
            nn.Conv2d(1, 1, 4, stride=2),
        )
        self.seq = nn.Sequential(
            nn.Flatten(),
            # On the 2-D (N, L) flattened input this averages each row in windows of 4.
            nn.AvgPool1d(4),

            nn.Linear(676, 256),
            nn.Dropout(p=0.05),
            nn.Sigmoid(),

            nn.Linear(256, 128),
            nn.Sigmoid(),

            nn.Linear(128, 4),
            nn.Sigmoid(),
        )
        for part in (self.conv, self.seq):
            part.apply(self.weights_init)

    def weights_init(self, m):
        """Draw Linear/BatchNorm1d/Conv2d weights from N(0, 0.75^2); biases are left as initialised."""
        if isinstance(m, (torch.nn.Linear, torch.nn.BatchNorm1d, torch.nn.Conv2d)):
            torch.nn.init.normal_(m.weight, 0.0, .75)

    def forward(self, x):
        """Map a batch of 2-D heatmaps to (N, 4) outputs; a channel dim is added first."""
        with_channel = x.unsqueeze(dim=1)
        return self.seq(self.conv(with_channel))

# Training

**Generalized Intersection over Union (GIoU) loss**

In [34]:
from torchvision.ops import generalized_box_iou_loss
from torchvision.ops.boxes import box_convert

def iou(boxes1, boxes2, in_fmt: str = "xyxy", reduction: str = "mean") -> torch.Tensor:
    """Generalized IoU loss (1 - GIoU) between paired boxes; lower is better.

    Args:
        boxes1, boxes2: (N, 4) box tensors in `in_fmt`. The training targets are
            built as [x1, y1, x2, y2], so "xyxy" is the default. (Previously the
            boxes were wrongly converted from "xywh".)
        in_fmt: any format accepted by torchvision's box_convert.
        reduction: "none" | "mean" | "sum", forwarded to generalized_box_iou_loss.
            "mean" yields a scalar, so callers can add it to other scalar losses
            and compare the result.
    Returns:
        Scalar tensor for "mean"/"sum"; one value per box pair for "none".
    """
    if in_fmt != "xyxy":
        boxes1 = box_convert(boxes1, in_fmt=in_fmt, out_fmt="xyxy")
        boxes2 = box_convert(boxes2, in_fmt=in_fmt, out_fmt="xyxy")
    return generalized_box_iou_loss(boxes1, boxes2, reduction=reduction)

**Training parameters**

In [36]:
# Subset sizes and batch sizes per split; each mini_* list is a prefix of its full split.
train_size, train_batch_size = 4096, 64
epochs = 512

validation_size, validation_batch_size = 1024, 32
validation_module = 6  # validate every `validation_module` epochs (skipping epoch 0)

test_size, test_batch_size = 512, 32

mini_train_data = train_data[:train_size]
mini_val_data = validate_data[:validation_size]
mini_test_data = test_data[:test_size]

### Routines
- Train
- Validation
- Test

In [41]:
def _load_batch(batch_data):
    """Open a batch's images and normalise its boxes by image size.

    Args:
        batch_data: rows of [image_path, prompt, bbox], bbox = [x, y, width, height] in pixels.
    Returns:
        (images, prompts, target_boxes): RGB PIL images, prompt strings, and a
        float tensor of [x1, y1, x2, y2] boxes divided by image width/height.
    """
    images, target_boxes, prompts = [], [], []
    for image_path, prompt, box in batch_data:
        # convert() loads the pixels, so the file can be closed right away.
        with Image.open(image_path) as raw:
            image = raw.convert("RGB")
        w, h = image.size
        target_boxes.append([
            box[0] / w,
            box[1] / h,
            (box[0] + box[2]) / w,
            (box[1] + box[3]) / h,
        ])
        images.append(image)
        prompts.append(prompt)
    return images, prompts, torch.tensor(target_boxes)


def _to_tensor(heatmaps):
    """Turn the feature extractor's output (tensor or array-like) into a tensor."""
    return torch.as_tensor(np.asarray(heatmaps))


def _mean(values):
    """Mean of a list of floats; NaN (not ZeroDivisionError) when the list is empty."""
    return sum(values) / len(values) if values else float("nan")


def _evaluate(model, loss_fn, feature_extractor, data, batch_size):
    """Mean loss and mean GIoU loss of `model` over `data`, without gradients."""
    model.eval()
    epoch_loss, giou = [], []
    for start in range(0, len(data), batch_size):
        images, prompts, target_boxes = _load_batch(data[start:start + batch_size])
        with torch.no_grad():
            heatmaps = feature_extractor(images, prompts)
            prediction_boxes = model(_to_tensor(heatmaps))
            loss = loss_fn(prediction_boxes, target_boxes)
            batch_giou = iou(prediction_boxes, target_boxes)
        epoch_loss.append(loss.item())
        # .mean() collapses a per-box result to one value per batch (no-op on a scalar).
        giou.append(batch_giou.float().mean().item())
    return _mean(epoch_loss), _mean(giou)


def training_routine(model, loss_fn, feature_extractor, optimizer):
    """Train `model` for one epoch over mini_train_data; return the mean batch loss."""
    model.train()
    epoch_loss = []
    # Iterate over the rows actually present (the split may hold fewer than train_size).
    for start in range(0, len(mini_train_data), train_batch_size):
        images, prompts, target_boxes = _load_batch(mini_train_data[start:start + train_batch_size])

        # The CLIP feature extractor is frozen; only `model` receives gradients.
        with torch.no_grad():
            heatmaps = feature_extractor(images, prompts)

        optimizer.zero_grad()
        prediction_boxes = model(_to_tensor(heatmaps))
        loss = loss_fn(prediction_boxes, target_boxes)
        epoch_loss.append(loss.item())
        loss.backward()
        optimizer.step()

    return _mean(epoch_loss)


def validation_routine(model, loss_fn, feature_extractor):
    """Evaluate on mini_val_data; return (mean loss, mean GIoU loss) as floats."""
    print("running validation")
    return _evaluate(model, loss_fn, feature_extractor, mini_val_data, validation_batch_size)


def test_routine(model, loss_fn, feature_extractor):
    """Evaluate on mini_test_data; return (mean loss, mean GIoU loss) as floats."""
    print("running test")
    return _evaluate(model, loss_fn, feature_extractor, mini_test_data, test_batch_size)

### Training Cycle

In [39]:
# Frozen CLIP heatmap extractor, trainable box head, MSE on normalised box corners.
feature_extractor = ResNetHighResV2(clip_preprocess, clip.tokenize, remap_heatmaps=False)
bboxer = HeatmapToBox()
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(bboxer.parameters(), lr=1e-2)

In [None]:
# Train the box head. Every `validation_module` epochs, validate and keep the weights
# with the lowest (validation loss + validation GIoU loss) in "checkpoint".
best = 1E3
saved_best = False
for epoch in range(epochs):
    loss = training_routine(bboxer, loss_fn, feature_extractor, optimizer)
    print(f"epoch {epoch}")
    print(f"training_loss: {loss}")
    if epoch != 0 and epoch % validation_module == 0:
        val_loss, giou = validation_routine(bboxer, loss_fn, feature_extractor)
        # Reduce to a Python float so the comparison is well defined even when
        # a per-box tensor is returned.
        giou = float(torch.as_tensor(giou).float().mean())
        if val_loss + giou < best:
            torch.save(bboxer.state_dict(), "checkpoint")
            print("saving checkpoint")
            best = val_loss + giou
            saved_best = True
        print(f"validation loss: {val_loss}, giou: {giou}")
optimizer.zero_grad(set_to_none=True)
# Save the last-epoch weights separately; previously they overwrote the best checkpoint.
torch.save(bboxer.state_dict(), "checkpoint_last")
if saved_best:
    # Evaluate the best validated weights in the Testing section.
    bboxer.load_state_dict(torch.load("checkpoint"))
else:
    # No validation ran, so the last-epoch weights are the only candidate.
    torch.save(bboxer.state_dict(), "checkpoint")

### Testing

In [None]:
# Evaluate on the held-out test subset and report the test metrics
# (previously the validation values were printed by mistake).
test_loss, test_giou = test_routine(bboxer, loss_fn, feature_extractor)
print(f"test loss: {test_loss}, test giou: {test_giou}")