In [None]:
import copy
import pickle
import re
import subprocess
import sys
from itertools import compress
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from matplotlib.patches import Polygon
from skimage.measure import find_contours
from torchvision.ops import nms

import src.GLIP.maskrcnn_benchmark as maskrcnn_benchmark

# Register the vendored GLIP package under the top-level module name `maskrcnn_benchmark`.
# NOTE(review): presumably needed so the pickled GLIP results loaded below (whose entries
# expose `get_field` / `bbox`) can be unpickled by their original module path — TODO confirm.
sys.modules['maskrcnn_benchmark'] = maskrcnn_benchmark

In [None]:
phrase_grounding_results_folder = '../data/phrase_grounding_results/'
dataset_for_phrase_grounding = '../data/dataset_for_phrase_grounding/'
save_dir = '../data/phrase_grounding_selected/'


def _load_pickle(path):
    """
    Load a single pickle file and close its handle afterwards.

    NOTE: pickle.load can execute arbitrary code; only use it on trusted result files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


MDETR_caption = _load_pickle(phrase_grounding_results_folder + 'MDETR_full_caption.p')
MDETR_title = _load_pickle(phrase_grounding_results_folder + 'MDETR_full_title.p')
GLIP_caption = _load_pickle(phrase_grounding_results_folder + 'GLIP_full_caption.p')
GLIP_title = _load_pickle(phrase_grounding_results_folder + 'GLIP_full_title.p')
dataset = _load_pickle(dataset_for_phrase_grounding + 'dataset.p')

In [None]:
def GLIP2MDETR(glip_array):
    """
    Convert GLIP inference results into the MDETR layout.

    Each GLIP entry is indexed as ``entry[0]`` (an object exposing
    ``get_field('labels')``, ``get_field('scores')`` and ``.bbox``) and
    ``entry[1]`` (the list of grounded phrases). Label ``k`` is mapped to
    ``phrases[k - 1]``; labels that are not smaller than ``len(phrases)`` fall
    back to the last phrase.

    :param glip_array: array of inference of GLIP
    :return: list of ``[scores, boxes, captions]`` entries, one per input entry
    """
    converted = []
    for entry in glip_array:
        boxlist, phrases = entry[0], entry[1]
        last_index = len(phrases) - 1
        captions = []
        for label in boxlist.get_field('labels'):
            if label < len(phrases):
                captions.append(phrases[label - 1])
            else:
                captions.append(phrases[last_index])
        converted.append([boxlist.get_field('scores'), boxlist.bbox, captions])
    return converted


# Bring both GLIP result lists into the MDETR layout [scores, boxes, captions].
# NOTE: not idempotent — re-running this cell feeds already-converted lists back into GLIP2MDETR.
GLIP_caption, GLIP_title = GLIP2MDETR(GLIP_caption), GLIP2MDETR(GLIP_title)

Helper functions for non-maximum suppression (NMS) of the detections

In [None]:
def det_nms(segmentation_array, iou_threshold=0.2):
    """
    Per-phrase (local) non-maximum suppression.

    Every entry is a ``[scores, boxes, captions, ...]`` sequence (the MDETR layout).
    A box is only suppressed by higher-scoring boxes that carry the *same* caption,
    so overlapping detections of different phrases all survive. Kept captions have
    one leading space removed.

    :param segmentation_array: iterable of per-image results; ``scores`` and ``boxes``
        must be tensors accepted by ``torchvision.ops.nms``
    :param iou_threshold: IoU above which a lower-scoring box with the same caption is dropped
    :return: a new list of lists (the input is not modified); entries without captions
        are returned as deep copies
    """
    filtered = []
    for entry in segmentation_array:
        entry = list(entry)
        scores, boxes, captions = entry[0], entry[1], entry[2]
        if len(captions) == 0:
            filtered.append(copy.deepcopy(entry))
            continue

        kept_scores, kept_boxes, kept_captions = [], [], []
        # sorted(): iterating a set of str is not reproducible across interpreter runs
        # (hash randomisation), which made the output order change from run to run.
        for caption in sorted(set(captions)):
            # index tensor on the boxes' device so this also works for CUDA results
            group = torch.tensor([k for k, c in enumerate(captions) if c == caption],
                                 device=boxes.device)
            group_keep = nms(boxes=boxes[group], scores=scores[group], iou_threshold=iou_threshold)
            kept = group[group_keep]
            kept_scores.append(scores[kept])
            kept_boxes.append(boxes[kept])
            kept_captions += [captions[k].removeprefix(' ') for k in kept.tolist()]

        # Any extra trailing fields are carried over unchanged.
        filtered.append([torch.cat(kept_scores), torch.cat(kept_boxes), kept_captions]
                        + copy.deepcopy(entry[3:]))
    return filtered


def global_det_nms(segmentation_array, iou_threshold=0.9):
    """
    Caption-agnostic (global) non-maximum suppression.

    Unlike ``det_nms``, all boxes of an entry compete with each other regardless of
    their caption; the high default threshold only removes near-duplicate boxes.
    Kept captions have one leading space removed.

    :param segmentation_array: iterable of ``[scores, boxes, captions, ...]`` sequences
    :param iou_threshold: IoU above which the lower-scoring of two boxes is dropped
    :return: a new list of lists (the input is not modified); entries without captions
        are returned as deep copies
    """
    filtered = []
    for entry in segmentation_array:
        entry = list(entry)
        scores, boxes, captions = entry[0], entry[1], entry[2]
        if len(captions) == 0:
            filtered.append(copy.deepcopy(entry))
            continue

        keep = nms(boxes=boxes, scores=scores, iou_threshold=iou_threshold)
        # Advanced indexing already returns fresh tensors, no need to deep-copy first.
        filtered.append([scores[keep], boxes[keep],
                         [captions[k].removeprefix(' ') for k in keep.tolist()]]
                        + copy.deepcopy(entry[3:]))
    return filtered

Apply non-maximum suppression: first per phrase (local), then across all phrases (global)

In [None]:
# Local (per-phrase) suppression first, then one global pass over every box of an image.
(MDETR_caption, MDETR_title, GLIP_caption, GLIP_title) = [
    global_det_nms(det_nms(results))
    for results in (MDETR_caption, MDETR_title, GLIP_caption, GLIP_title)
]

In [None]:
# Enable matplotlib interactive mode (this needs an assignment; a `==` comparison has no effect).
matplotlib.interactive(True)

In [None]:
# colors for visualization
# RGB triples with components in [0, 1]. The last entry (black, index 6) is the value
# get_color returns for labels it cannot find, so unmatched text is drawn in black.
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933], [0,0,0]]

def apply_mask(image, mask, color, alpha=0.5):
    """Blend ``color`` into ``image`` wherever ``mask == 1``, modifying ``image`` in place.

    Only the first three channels are touched. ``color`` components are in [0, 1]
    and are scaled by 255; ``alpha`` is the weight of the colour. Returns ``image``.
    """
    selected = mask == 1
    for channel in range(3):
        plane = image[:, :, channel]
        blended = plane * (1 - alpha) + alpha * color[channel] * 255
        image[:, :, channel] = np.where(selected, blended, plane)
    return image

def get_color(label, set_label):
    """Return the position of ``label`` in ``set_label`` (case-insensitive), or 6 if absent."""
    wanted = label.lower()
    return next(
        (position for position, candidate in enumerate(set_label) if candidate.lower() == wanted),
        6,
    )

from matplotlib.transforms import Affine2D, offset_copy


def rainbow_text(x, y, strings, colors, orientation='horizontal',
                 ax=None, **kwargs):
    """
    Draw ``strings`` one after another, ``strings[k]`` in ``colors[k]``.

    Parameters
    ----------
    x, y : float
        Starting position in data coordinates.
    strings : list of str
        Pieces of text to draw (a trailing space is appended to each).
    colors : list of color
        Colour of each piece.
    orientation : {'horizontal', 'vertical'}
        'vertical' rotates the text by 90 degrees and stacks pieces upwards.
    ax : Axes, optional
        Target Axes; defaults to the current axes.
    **kwargs
        Forwarded to ``Axes.text`` (font size, family, ...).
    """
    axes = plt.gca() if ax is None else ax
    fig = axes.figure
    canvas = fig.canvas

    assert orientation in ['horizontal', 'vertical']
    vertical = orientation == 'vertical'
    if vertical:
        kwargs.update(rotation=90, verticalalignment='bottom')

    transform = axes.transData
    for piece, piece_color in zip(strings, colors):
        artist = axes.text(x, y, piece + " ", color=piece_color, transform=transform, **kwargs)
        # The artist must be drawn before its window extent is known.
        artist.draw(canvas.get_renderer())
        # Pixel extent -> inches, so the offset is independent of the figure dpi.
        extent = fig.dpi_scale_trans.inverted().transform_bbox(artist.get_window_extent())
        dx, dy = (0, extent.height) if vertical else (extent.width, 0)
        transform = artist.get_transform() + offset_copy(Affine2D(), fig=fig, x=dx, y=dy)

def get_title(s, set_label, ax):
    """
    Draw the expression ``s`` on ``ax``, colouring each occurrence of a label.

    ``s`` is split around the case-insensitive occurrences of the labels in
    ``set_label``. Each non-blank piece is coloured with
    ``COLORS[get_color(piece, labels)]``, which is black for text that is not a
    label. The pieces are drawn with ``rainbow_text`` starting at data position
    (-100, -30).

    :param s: expression to draw
    :param set_label: labels of the detections shown on ``ax``
    :param ax: matplotlib Axes to draw into
    """
    # Normalise spacing inside parentheses: "( x )" -> "(x)".
    labels = [label.replace('( ', '(').replace(' )', ')') for label in set_label]

    # Longest labels first so "dog house" is matched before "dog". re.escape protects
    # every regex metacharacter ('.', '+', '?', '[', ...), not only parentheses.
    # Empty labels are left out because an empty alternative matches everywhere.
    alternatives = [re.escape(label) for label in sorted(labels, key=len, reverse=True) if label]
    if alternatives:
        pieces = re.split('(' + '|'.join(alternatives) + ')', s, flags=re.IGNORECASE)
    else:
        pieces = [s]
    strings = [piece for piece in pieces if piece.strip()]

    # Wrap around instead of raising IndexError when there are more labels than colours.
    colors = [COLORS[get_color(word.replace('( ', '(').replace(' )', ')'), labels) % len(COLORS)]
              for word in strings]

    rainbow_text(-100, -30, strings, colors, size=12, ax=ax)



def plot_results(ax, pil_img, results, expr="", masks=None, conf=0.7):
    """
    Draw thresholded detections (and optional masks) of one image on ``ax``.

    :param ax: matplotlib Axes to draw into
    :param pil_img: the image; converted with ``np.array``
    :param results: ``[scores, boxes, captions]``; boxes are unpacked as
        ``(xmin, ymin, xmax, ymax)``
    :param expr: referring expression; currently unused (title drawing is disabled)
    :param masks: optional per-detection masks, filtered with the same keep-mask as the boxes
    :param conf: detections with ``score <= conf`` are not drawn
    :return: ``ax``
    """
    np_image = np.array(pil_img)
    colors = COLORS * 100
    keep = results[0] > conf
    scores = results[0][keep]
    boxes = results[1][keep]
    labels = list(compress(results[2], keep))
    # Unique labels in first-appearance order. Iterating a set of str differs between
    # interpreter runs, which made a phrase's colour change from run to run.
    set_label = list(dict.fromkeys(labels))

    if masks is None:
        masks = [None] * len(scores)
    else:
        masks = masks[keep]

    # get_title(expr, set_label, ax)

    for (xmin, ymin, xmax, ymax), label, mask in zip(boxes.tolist(), labels, masks):
        c = colors[get_color(label, set_label)]
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=1))
        ax.text(int(xmin), int(ymin), f'{label}', fontsize=10, bbox=dict(facecolor='white', alpha=0.8))

        if mask is None:
            continue
        np_image = apply_mask(np_image, mask, c)

        # Pad by one pixel so contours touching the image border are closed.
        padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)
        padded_mask[1:-1, 1:-1] = mask
        for verts in find_contours(padded_mask, 0.5):
            # Subtract the padding and flip (y, x) to (x, y)
            verts = np.fliplr(verts) - 1
            ax.add_patch(Polygon(verts, facecolor="none", edgecolor=c))

    ax.imshow(np_image)

    return ax


## Selection of the best phrase grounding inference

In [None]:
%matplotlib qt
# Index of the dataset item currently shown; read and advanced by the key-press callbacks below.
i = 0
# Unseeded random sample of 20 distinct dataset indices (differs on every run).
j = np.random.choice(len(dataset), 20, replace=False)

def save_result(i, model, expr):
    """
    Pickle the chosen inference for dataset item ``i`` into ``save_dir``.

    The file is named after the image with a trailing ".jpg" removed and ".p" appended.

    :param i: dataset index of the item
    :param model: 'GLIP' selects GLIP results; any other value selects MDETR results
    :param expr: key of the expression in ``dataset[i]``; 'caption' selects caption
        results, any other value selects title results
    """
    caption = dataset[i][expr]['raw']
    if model == 'GLIP':
        results = GLIP_caption[i] if expr == 'caption' else GLIP_title[i]
    else:
        results = MDETR_caption[i] if expr == 'caption' else MDETR_title[i]

    filename = dataset[i]['filename']
    Path(save_dir).mkdir(parents=True, exist_ok=True)
    # `with` guarantees the pickle is flushed and the handle closed.
    with open(save_dir + f'{filename.removesuffix(".jpg")}.p', 'wb') as f:
        pickle.dump({
            'idx': i,
            'filename': filename,
            'caption': caption,
            'results': results,
            'model': model,
            'expr': expr
        }, f)


def on_press(event):
    """
    Key handler of the selection tool; acts on the global item index ``i``.

    x: close, write ``i`` to ../temp/latest.txt and print it (quit)
    q / e: save the MDETR / GLIP caption result, go to the next item
    a / d: save the MDETR / GLIP title result, go to the next item
    p: log ``i`` to ../temp/none.txt (no good result), go to the next item
    k: save the current figure as ../temp/<i>.png, log ``i`` to ../temp/selected.txt, go on
    r: go back one item (never below 0)
    """
    global i
    save_keys = {
        'q': ('MDETR', 'caption'),
        'e': ('GLIP', 'caption'),
        'a': ('MDETR', 'title'),
        'd': ('GLIP', 'title'),
    }
    key = event.key

    if key == 'x':
        plt.close('all')
        with open('../temp/latest.txt', 'w') as f:
            f.write(str(i))
        print(i)
    elif key in save_keys:
        plt.close('all')
        model, expr = save_keys[key]
        save_result(i, model, expr)
        i += 1
        load_slide()
    elif key == 'p':
        plt.close('all')
        with open('../temp/none.txt', 'a') as file:
            file.write(str(i) + ' ')
        i += 1
        load_slide()
    elif key == 'k':
        # Name the figure and the log entry after the item actually shown (i),
        # not after the unrelated random sample j[i].
        plt.savefig('../temp/' + str(i) + '.png')
        plt.close('all')
        with open('../temp/selected.txt', 'a') as file:
            file.write(str(i) + ' ')
        i += 1
        load_slide()
    elif key == 'r':
        # Close the current window like every other navigation key, and don't wrap
        # around to the end of the dataset.
        plt.close('all')
        i = max(i - 1, 0)
        load_slide()


def load_slide():
    """
    Show the four inferences (MDETR/GLIP x caption/title) of dataset item ``i``.

    Layout: left column MDETR, right column GLIP; top row caption, bottom row title.
    MDETR detections are thresholded at 0.7, GLIP detections are all shown (conf=0).
    Key presses on the window are routed to ``on_press``.
    """
    if not 0 <= i < len(dataset):
        print(f'Index {i} is outside the dataset (0..{len(dataset) - 1}); nothing to show.')
        return

    elem = dataset[i]
    fig, axes = plt.subplots(2, 2, figsize=(35, 20))
    fig.canvas.mpl_connect('key_press_event', on_press)
    im = Image.open(dataset_for_phrase_grounding + 'img/' + elem["filename"]).convert('RGB')

    # (panel position, results, expression, confidence threshold)
    panels = [
        ((0, 0), MDETR_caption[i], elem["caption"]["raw"], 0.7),
        ((1, 0), MDETR_title[i], elem["title"]["raw"], 0.7),
        ((0, 1), GLIP_caption[i], elem["caption"]["raw"], 0),
        ((1, 1), GLIP_title[i], elem["title"]["raw"], 0),
    ]
    for pos, results, expr, conf in panels:
        plot_results(axes[pos], im, results, expr=expr, conf=conf)

    for panel_ax in axes.flat:
        panel_ax.axis('off')
    plt.show()


load_slide()

## Best results visualization

In [None]:
%matplotlib inline
# iterate over files in that directory
files = Path(save_dir).glob('*.p')

for file in files:
    result = pickle.load(open(file, 'rb'))
    fig, ax = plt.subplots()
    im = Image.open(dataset_for_phrase_grounding + 'img/' + result["filename"]).convert('RGB')
    if result["model"] == "GLIP":
        conf = 0
    else:
        conf = 0.7
    ax = plot_results(ax=ax, pil_img=im, results=result["results"], conf=conf)
    fig.suptitle(result['caption'])
    plt.savefig(save_dir + result["filename"])

In [None]:
# Rounded, pink-filled label box and an angled arrow for the annotation.
box_style = dict(boxstyle="round", ec=(1., 0.5, 0.5), fc=(1., 0.8, 0.8))
arrow_style = dict(arrowstyle="->", connectionstyle="angle,angleA=0,angleB=90,rad=10")

text = plt.annotate("Hello, World!", xy=(0.5, 0.5), xytext=(0, 0),
                    textcoords="offset points", size=20,
                    bbox=box_style, arrowprops=arrow_style)

# Centre the text horizontally on its anchor point.
text.set_horizontalalignment("center")

plt.show()

In [None]:
import matplotlib.pyplot as plt

# Generate rainbow text using the rainbow_text function
strings = ["Hello", "World!"]
colors = ["red", "green"]

# Measure the total width of the words, in data coordinates, so the line can be centred.
width = 0
for string, color in zip(strings, colors):
    # Temporary text object used only for measuring
    text = plt.text(0, 0, string + " ", color=color, size=20)
    # Draw the text so its window extent is up to date
    text.draw(plt.gcf().canvas.get_renderer())
    # Window extent (pixels) -> data coordinates via the inverse of the text's transform
    bbox = text.get_window_extent().transformed(text.get_transform().inverted())
    width += bbox.width
    # Remove the probe; otherwise it stays on the axes and is rendered at (0, 0)
    text.remove()

# Center the text by adjusting the x coordinate
x = 0.5 - width / 2

# Generate the rainbow text at the new x coordinate
rainbow_text(x, 0.5, strings, colors, size=20)

plt.show()


In [None]:
import matplotlib.pyplot as plt

# One figure with two stacked subplots; plt.text draws on the current axes
# (the last subplot created).
fig, ax_ = plt.subplots(2, 1)

# Three slices of one sentence, each in its own colour, stacked 0.1 apart.
caption5 = 'This caption has words in different colors'
segments = [
    (caption5[:8], 'red', 0.5),
    (caption5[8:19], 'green', 0.5 + 0.1),
    (caption5[19:], 'blue', 0.5 + 0.2),
]
for chunk, chunk_color, y_pos in segments:
    plt.text(0.5, y_pos, chunk, color=chunk_color, fontsize=14)

plt.show()

In [None]:
from PIL import Image, ImageDraw, ImageFont

# Create a white 300x300 image
img = Image.new('RGB', (300, 300), color=(255, 255, 255))

# Create a draw object
draw = ImageDraw.Draw(img)

# Titles and their colours; no font is passed, so ImageDraw's default font is used.
titles = [('Image 1', (255, 0, 0)), ('Image 2', (0, 255, 0)), ('Image 3', (0, 0, 255)), ('Image 4', (255, 255, 0))]

# Draw one title per row, 30 px apart. `row` rather than `i`: the selection tool's
# key-press callbacks read the global `i`.
for row, (title, color) in enumerate(titles):
    draw.text((10, 10 + row * 30), title, fill=color)

# Save the image
img.save('images.png')

In [None]:
items = [2, 3, 10, 35, 181, 373, 422, 641, 951, 1017, 1041, 1149, 1196, 1247, 1269, 1277]

# For each item, save one PNG per (model, expression) pair as ../temp/<idx>_<MODEL>_<expr>.png.
# A single figure is reused and cleared between panels, then closed at the end.
fig, ax = plt.subplots()
for item_idx in items:
    elem = dataset[item_idx]
    im = Image.open(dataset_for_phrase_grounding + 'img/' + elem["filename"]).convert('RGB')

    # (results per item, expression key, model name, confidence threshold)
    for results, expr_key, model, conf in (
        (MDETR_caption, 'caption', 'MDETR', 0.7),
        (MDETR_title, 'title', 'MDETR', 0.7),
        (GLIP_caption, 'caption', 'GLIP', 0),
        (GLIP_title, 'title', 'GLIP', 0),
    ):
        expr = elem[expr_key]["raw"]
        plot_results(ax=ax, pil_img=im, expr=expr, results=results[item_idx], conf=conf)
        ax.set_title(f'{expr}')
        ax.axis('off')
        fig.savefig('../temp/' + str(item_idx) + f'_{model}_{expr_key}.png')
        ax.cla()
plt.close(fig)


In [None]:
items = [2, 3, 10, 35, 181, 373, 422, 641, 951, 1017, 1041, 1149, 1196, 1247, 1269, 1277]

# Append "<index>\n<caption>\n<title>\n\n" for every item, using one file handle for the loop.
with open('../temp/' + 'expression.txt', 'a') as f:
    for item_idx in items:
        caption = dataset[item_idx]["caption"]["raw"]
        title = dataset[item_idx]["title"]["raw"]
        f.write(str(item_idx) + '\n' + caption + '\n' + title + '\n\n')


In [None]:
# Among 100 random dataset items, keep those whose caption and title both exist and have
# at most 10 words; plot and save the four inferences of every kept item.
filtered_items = []
items = np.arange(len(dataset))
# Seeded so the selection (pickled as filtered_items) can be reproduced.
np.random.default_rng(0).shuffle(items)
for item_idx in items[:100]:
    elem = dataset[item_idx]
    caption = elem["caption"]["raw"]
    title = elem["title"]["raw"]

    # skip items without a caption or a title
    if caption is None or title is None:
        continue

    # skip items whose caption or title is longer than 10 words
    if len(caption.split()) > 10 or len(title.split()) > 10:
        continue

    filtered_items.append(item_idx)
    fig, axes = plt.subplots(2, 2, figsize=(13, 7))
    im = Image.open(dataset_for_phrase_grounding + 'img/' + elem["filename"]).convert('RGB')

    # (panel position, results, expression, confidence threshold)
    for pos, results, expr, conf in (
        ((0, 0), MDETR_caption[item_idx], caption, 0.7),
        ((1, 0), MDETR_title[item_idx], title, 0.7),
        ((0, 1), GLIP_caption[item_idx], caption, 0),
        ((1, 1), GLIP_title[item_idx], title, 0),
    ):
        plot_results(axes[pos], im, results, expr=expr, conf=conf)
        axes[pos].set_title(f'{expr}', y=1.08)

    for panel_ax in axes.flat:
        panel_ax.axis('off')
    fig.savefig('../temp/' + str(item_idx) + '_all.png')
    print(item_idx)
    plt.show()

In [None]:
# `with` makes sure the pickle is flushed and the file closed.
with open('../temp/filtered_items.pkl', 'wb') as f:
    pickle.dump(filtered_items, f)

In [None]:
# One pickle per filtered item holding its GLIP caption inference
# (same keys as the records written by save_result).
Path('../glip_caption/').mkdir(parents=True, exist_ok=True)
for item_idx in filtered_items:
    record = {
        'idx': item_idx,
        'filename': dataset[item_idx]['filename'],
        'caption': dataset[item_idx]['caption']['raw'],
        'results': GLIP_caption[item_idx],
        'model': 'GLIP',
        'expr': 'caption',
    }
    with open(f'../glip_caption/{item_idx}.pkl', 'wb') as f:
        pickle.dump(record, f)

In [1]:
import os

In [18]:
# get all filenames of images in the folder using glob
from pathlib import Path

folders = ['../glip_caption/', '../glip_title/', '../mdetr_caption/', '../mdetr_title/']

# For each folder, the set of .jpg filenames without their extension.
file_set = [{file.stem for file in Path(folder).glob('*.jpg')} for folder in folders]

# Images present in all four folders; sorted so they are processed in a reproducible order.
intersection = set.intersection(*file_set)
filenames = sorted(intersection)



In [22]:
# Tile the four panels of every image into a 2x2 grid with ImageMagick's `montage`
# (tiles are filled row by row: MDETR caption, GLIP caption / MDETR title, GLIP title).
Path('../combined/').mkdir(parents=True, exist_ok=True)
for filename in filenames:
    # Argument list instead of a shell string: filenames containing spaces or shell
    # metacharacters are passed through verbatim.
    command = ['montage', '-tile', '2x2', '-geometry', '+2+2',
               '../mdetr_caption/' + filename + '.jpg',
               '../glip_caption/' + filename + '.jpg',
               '../mdetr_title/' + filename + '.jpg',
               '../glip_title/' + filename + '.jpg',
               '../combined/' + filename + '.jpg']
    completed = subprocess.run(command, check=False)
    if completed.returncode != 0:
        print(f'montage failed (exit code {completed.returncode}) for {filename}')