In [1]:
import json
import sys
import random
from tqdm import tqdm
from collections import defaultdict
import pyarrow as pa
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer, util
import torch

# Import data

In [2]:
%%time
root = '.'

train_data = list(
    map(json.loads, open(f"{root}/cosmos/train_data.json").readlines())
)
test_data = list(
    map(json.loads, open(f"{root}/cosmos/test_data.json").readlines())
)
# train_data = list(map(json.loads, open(f"{root}/cosmos/val_data.json").readlines()))

CPU times: user 9.91 s, sys: 610 ms, total: 10.5 s
Wall time: 10.6 s


# Remove duplicate captions

In [3]:
def remove_duplicate(dataset):
    for data in tqdm(dataset):
        seen_caption = set()
        new_list = []
        for caption in data['articles']:
            if caption['caption_modified'] not in seen_caption:
                new_list.append(caption)
                seen_caption.add(caption['caption_modified'])
        data['articles'] = new_list

remove_duplicate(train_data)


100%|██████████| 161754/161754 [00:00<00:00, 493288.23it/s]


# SBERT Paraphrase Mining

In [4]:
model = SentenceTransformer('all-MiniLM-L6-v2')

In [5]:
import notebook_util
notebook_util.pick_gpu_lowest_memory()


1

In [6]:
torch.cuda.set_device(2)

In [7]:
def neg_cos_sim(a, b):
    """
    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
    :return: Matrix with res[i][j]  = cos_sim(a[i], b[j])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    if len(a.shape) == 1:
        a = a.unsqueeze(0)

    if len(b.shape) == 1:
        b = b.unsqueeze(0)

    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    return torch.neg(torch.mm(a_norm, b_norm.transpose(0, 1)))

In [8]:
# Single list of sentences - Possible tens of thousands of sentences
sentences = []

for data in tqdm(train_data):
    for caption in data['articles']:
        sentences.append(caption['caption_modified'])

paraphrases = util.paraphrase_mining(model, sentences, top_k=1, score_function=neg_cos_sim)



100%|██████████| 161754/161754 [00:00<00:00, 489280.36it/s]


In [25]:
print(len(sentences))
print(len(paraphrases))

264768
435678


In [9]:
furthest_list = np.empty((len(sentences),),dtype=int)
for entry in paraphrases:
    furthest_list[entry[1]]=entry[2]

In [10]:
sentences_dict = dict()
for x in range(len(sentences)):
    sentences_dict[sentences[x]] = x

# Gen case functions

In [11]:
# Take 2 random correct caption
def gen_positive_case(data):
    caption_1_idx, caption_2_idx = random.sample(range(0,len(data['articles'])),2)
    cap1 = data['articles'][caption_1_idx]['caption_modified']
    cap2 = data['articles'][caption_2_idx]['caption_modified']  
    return [data['img_local_path'],[cap1],[cap2],[False]]

In [12]:
# Take 2 random false caption
def gen_negative_case_1(data, train_data):
    data_random1 = train_data[np.random.randint(len(train_data))]
    while data['img_local_path'] == data_random1['img_local_path']:
        data_random1 = train_data[np.random.randint(len(train_data))]
    cap1 = data_random1['articles'][np.random.randint(len(data_random1['articles']))]['caption_modified']
    data_random2 = train_data[np.random.randint(len(train_data))]
    while data['img_local_path'] == data_random2['img_local_path']:
        data_random2 = train_data[np.random.randint(len(train_data))]
    cap2 = data_random2['articles'][np.random.randint(len(data_random2['articles']))]['caption_modified']
    return [data['img_local_path'],[cap1],[cap2],[True]]

In [13]:
# Take 1 random correct caption, 1 random false caption
def gen_negative_case_2(data, train_data):
     # Pick 1st correct caption
    if np.random.rand()>0.5:
        cap1 = data['articles'][np.random.randint(len(data['articles']))]['caption_modified']
        # Pick a random article then pick its first caption
        data_random = train_data[np.random.randint(len(train_data))]
        while data['img_local_path'] == data_random['img_local_path']:
            data_random = train_data[np.random.randint(len(train_data))]
        cap2 = data_random['articles'][np.random.randint(len(data_random['articles']))]['caption_modified']
    else:
        cap2 = data['articles'][np.random.randint(len(data['articles']))]['caption_modified']
        # Pick a random article then pick its first caption
        data_random = train_data[np.random.randint(len(train_data))]
        while data['img_local_path'] == data_random['img_local_path']:
            data_random = train_data[np.random.randint(len(train_data))]
        cap1 = data_random['articles'][np.random.randint(len(data_random['articles']))]['caption_modified']
    return [data['img_local_path'],[cap1],[cap2],[True]]

In [14]:
def gen_negative_case_3(data, train_data, furthest_list, sentences_dict, sentences, count):
    # pick 1 random correct caption
    cap1 = data['articles'][np.random.randint(len(data['articles']))]['caption_modified']
    # then find its furthest distance caption
    cap1_index = sentences_dict[cap1]
    try:
        cap2 = sentences[furthest_list[cap1_index]]
    except:
        data_random = train_data[np.random.randint(len(train_data))]
        while data['img_local_path'] == data_random['img_local_path']:
            data_random = train_data[np.random.randint(len(train_data))]
        cap2 = data_random['articles'][np.random.randint(len(data_random['articles']))]['caption_modified']
        count.append(cap1_index)
    return [data['img_local_path'],[cap1],[cap2],[True]]

In [15]:
def gen_negative_case_4(data,train_data, furthest_list, sentences_dict, sentences, count):
    # pick 1 random false caption
    data_random = train_data[np.random.randint(len(train_data))]
    while data['img_local_path'] == data_random['img_local_path']:
        data_random = train_data[np.random.randint(len(train_data))]
    cap1 = data_random['articles'][np.random.randint(len(data_random['articles']))]['caption_modified']
    # then find its furthest distance caption
    cap1_index = sentences_dict[cap1]
    try:
        cap2 = sentences[furthest_list[cap1_index]]
    except:
        data_random = train_data[np.random.randint(len(train_data))]
        while data['img_local_path'] == data_random['img_local_path']:
            data_random = train_data[np.random.randint(len(train_data))]
        cap2 = data_random['articles'][np.random.randint(len(data_random['articles']))]['caption_modified']
        count.append(cap1_index)
    return [data['img_local_path'],[cap1],[cap2],[True]]

In [None]:
def gen_negative_case_5(data, train_data):
      # choose random image
    result = gen_positive_case(data)
    result[0] = train_data[np.random.randint(len(train_data))]['img_local_path']
    return result


In [None]:
def gen_negative_case_6(data, train_data):
      # choose random image
    result = gen_positive_case(data)
    result[0] = train_data[np.random.randint(len(train_data))]['img_local_path']
    return result

In [27]:
np.random.seed(42)
# train_data_sample = np.random.choice(train_data, size=int(len(train_data)*50/100))
# Not OOC cases
l = []
count = []
count_true = 0
for data in tqdm(train_data):
    if len(data['articles']) > 1:
        l.append(gen_positive_case(data))
        count_true += 1       
    if np.random.rand() > 0.5:
        # 2 random false cap
        l.append(gen_negative_case_1(data, train_data))
    # 1 true, 1 false
    l.append(gen_negative_case_2(data, train_data))
    # 1 true, 1 false furthest
    l.append(gen_negative_case_3(data, train_data, furthest_list, sentences_dict, sentences, count))
    # 1 false, 1 furthest
    if np.random.rand() > 0.5:
        l.append(gen_negative_case_4(data, train_data, furthest_list, sentences_dict, sentences, count))
    # 2 false from 1 other image
    # 2 false from 1 furthest image



100%|██████████| 161754/161754 [00:11<00:00, 13959.11it/s]


In [32]:
(count_true/len(l))/(1-count_true/len(l))

0.149960479840602

In [18]:
dataframe = pd.DataFrame(
    l, columns=["image", "caption_1", "caption_2", "label"],
)

In [19]:
dataframe[0:20]

Unnamed: 0,image,caption_1,caption_2,label
0,train/1.jpg,"[The technical infrastructure of ORG, PRODUCT ...","[This photo taken DATE, shows apps for ORG, OR...",[False]
1,train/1.jpg,[ORG said it has nearly doubled server capacit...,[A Rohingya refugee man holds his child up as ...,[True]
2,train/1.jpg,"[The technical infrastructure of ORG, PRODUCT ...",[PERSON refused to undergo a court-ordered men...,[True]
3,train/1.jpg,[Premier PERSON gave a thumbs up to the crowd ...,[Children ride a manually operated PERSON whee...,[True]
4,train/2.jpg,"[Neel PERSON, the president of ORG of GPE, sai...",[ORG College Football Playoff Semifinal-Oklaho...,[True]
5,train/2.jpg,[Smashed CARDINAL monitors are seen inside the...,[Mr. PERSON wants to increase the utility of t...,[True]
6,train/2.jpg,[Mr. PERSON wants to increase the utility of t...,[The mother was slaughtered and the newborn ba...,[True]
7,train/2.jpg,[PERSON conducts CARDINAL of his healing cerem...,[Ending restrictive zoning doesn't have to lea...,[True]
8,train/0.jpg,[GPE had fiercely objected to its neighbor's u...,"[A statue depicting PERSON stands in GPE, GPE,...",[False]
9,train/0.jpg,"[PERSON, a ORDINAL-generation dairy farmer who...",[Bloggers who quit jobs to travel the world en...,[True]


In [20]:
def load_image(path):
   try:
      with open(path, "rb") as fp:
        return fp.read()
   except:
      return None

In [21]:
tqdm.pandas()

dataframe['image'] = dataframe['image'].progress_apply(lambda x: load_image(x))

  from pandas import Panel
100%|██████████| 558683/558683 [01:55<00:00, 4843.83it/s]


In [23]:
dataframe = dataframe[dataframe.image.notnull()]

# PyArrow

In [24]:
table = pa.Table.from_pandas(dataframe)


In [26]:
split = 'train'
with pa.OSFile(f"dataset/data_1/cosmos_{split}.arrow", "wb") as sink:
    with pa.RecordBatchFileWriter(sink, table.schema) as writer:
        writer.write_table(table)

# Val

# Test

In [None]:
import spacy
nlp = spacy.load("en_core_web_sm")

In [None]:
def modify_caption_replace_entities(caption_text):
    """
        Utility function to replace named entities in the caption with their corresponding hypernyms
        Args:
            caption_text (str): Original caption with named entities
        Returns:
            caption_modified (str): Modified caption after replacing named entities
    """
    doc = nlp(caption_text)
    caption_modified = caption_text
    caption_entity_list = []
    for ent in doc.ents:
        caption_entity_list.append((ent.text, ent.label_))
        caption_modified = caption_modified.replace(ent.text, ent.label_, 1)
    return caption_modified

In [72]:
l_test  = []
for data in tqdm(test_data):
    cap1 = modify_caption_replace_entities(data['caption1'])
    cap2 = modify_caption_replace_entities(data['caption2'])
    l_test.append([data['img_local_path'],[cap1],[cap2],[data['context_label']==True]])

100%|██████████| 1700/1700 [00:19<00:00, 88.33it/s]


In [73]:
dataframe_test = pd.DataFrame(
    l_test, columns=["image", "caption_1", "caption_2", "label"],
)

In [74]:
dataframe_test

Unnamed: 0,image,caption_1,caption_2,label
0,test/0.jpg,"[PERSON at his announcement in GPE, GPE, on DA...","[PERSON at his announcement in GPE, GPE, on DA...",[False]
1,test/1.jpg,[Supporters of GPE's ruling ORG party come out...,[A person sits on a truck as supporters of the...,[False]
2,test/2.jpg,[CARDINAL dead people turned up on the state’s...,[These social media posts did not link to a re...,[True]
3,test/3.jpg,"[Actor, musician, director and devoted followe...",[A shocking report about the former child acto...,[True]
4,test/4.jpg,[Men from the LOC tribe perform a traditional ...,"[And on DATE in GPE's Narok county, young PERS...",[False]
...,...,...,...,...
1695,test/1695.jpg,[President PERSON trademarked the name 'WORK_O...,[There was no truth that PERSON family MONEY w...,[True]
1696,test/1696.jpg,[A photograph shows a soldier carrying a donke...,[Coronavirus meme featuring “EVENT donkey” is ...,[True]
1697,test/1697.jpg,[Homeless people living on streets in GPE],[ORG in GPE],[False]
1698,test/1698.jpg,[The castle's esplanade was a perfect spot for...,[Picture shows an ORG skier],[False]


In [75]:
tqdm.pandas()

dataframe_test['image'] = dataframe_test['image'].progress_apply(lambda x: load_image(x))

  from pandas import Panel
100%|██████████| 1700/1700 [00:00<00:00, 9703.16it/s]


In [76]:
dataframe_test

Unnamed: 0,image,caption_1,caption_2,label
0,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,"[PERSON at his announcement in GPE, GPE, on DA...","[PERSON at his announcement in GPE, GPE, on DA...",[False]
1,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,[Supporters of GPE's ruling ORG party come out...,[A person sits on a truck as supporters of the...,[False]
2,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,[CARDINAL dead people turned up on the state’s...,[These social media posts did not link to a re...,[True]
3,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,"[Actor, musician, director and devoted followe...",[A shocking report about the former child acto...,[True]
4,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,[Men from the LOC tribe perform a traditional ...,"[And on DATE in GPE's Narok county, young PERS...",[False]
...,...,...,...,...
1695,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,[President PERSON trademarked the name 'WORK_O...,[There was no truth that PERSON family MONEY w...,[True]
1696,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,[A photograph shows a soldier carrying a donke...,[Coronavirus meme featuring “EVENT donkey” is ...,[True]
1697,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,[Homeless people living on streets in GPE],[ORG in GPE],[False]
1698,b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00...,[The castle's esplanade was a perfect spot for...,[Picture shows an ORG skier],[False]


# PyArrow

In [77]:
table = pa.Table.from_pandas(dataframe_test)
split = 'test'
with pa.OSFile(f"dataset_50/cosmos_{split}.arrow", "wb") as sink:
    with pa.RecordBatchFileWriter(sink, table.schema) as writer:
        writer.write_table(table)