In [None]:

import json
import numpy as np
from pycocotools.coco import COCO
from sentence_transformers import SentenceTransformer, util
import random
from sklearn.cluster import KMeans
import torch
from io import StringIO
import pandas as pd
import subprocess
import os
import qianfan
from tqdm import tqdm
import spacy
from collections import Counter
os.environ["QIANFAN_AK"] = ""
os.environ["QIANFAN_SK"] = ""

coco = COCO('coco2014annotations/captions_val2014.json')

captions = []
for img_id in coco.imgs.keys():
    img_captions = coco.imgToAnns[img_id]
    caption = random.choice(img_captions)
    captions.append(caption['caption'])
    
nlp = spacy.load("en_core_web_sm")
# 1. 统计名词频率
def get_noun_frequencies(sentences):
    noun_counter = Counter()
    for sentence in tqdm(sentences):
        doc = nlp(sentence)
        nouns = [token.text.lower() for token in doc if token.pos_ == "NOUN"]
        noun_counter.update(nouns)
    return noun_counter

# 2. 进行分层采样
def balanced_sampling(sentences, num_samples):
    noun_frequencies = get_noun_frequencies(sentences)
    
    # 创建一个字典，存储每个名词对应的句子
    noun_to_sentences = {noun: [] for noun in noun_frequencies}
    for sentence in tqdm(sentences):
        doc = nlp(sentence)
        nouns = set(token.text.lower() for token in doc if token.pos_ == "NOUN")
        for noun in nouns:
            noun_to_sentences[noun].append(sentence)

    # 计算每个名词的样本数量
    total_nouns = len(noun_frequencies)
    samples_per_noun = num_samples // total_nouns

    sampled_sentences = set()
    for noun, sentence_list in noun_to_sentences.items():
        # 从每个名词对应的句子中采样
        selected_sentences = random.sample(sentence_list, min(samples_per_noun, len(sentence_list)))
        sampled_sentences.update(selected_sentences)

    return list(sampled_sentences)

# 3. 采样句子
num_samples = 20000  # 你想采样的句子数量
captions = balanced_sampling(captions, num_samples)
captions = list(set(captions))

model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2').to('mps') 
captions_embeddings = model.encode(captions, convert_to_tensor=True)
captions_embeddings = torch.nn.functional.normalize(captions_embeddings, dim=1).cpu()

loading annotations into memory...
Done (t=0.13s)
creating index...
index created!


100%|██████████| 40504/40504 [02:32<00:00, 266.43it/s]
100%|██████████| 40504/40504 [02:28<00:00, 272.16it/s]


In [2]:
len(captions)

9489

In [3]:
num_clusters = 400 
kmeans = KMeans(n_clusters=num_clusters)
kmeans.fit(captions_embeddings)
clusters = kmeans.labels_

# captions_class = {i:[] for i in range(num_clusters)}
captions_class = {i:[] for i in range(len(set(clusters)))}
for i, cluster in enumerate(clusters):
    captions_class[cluster].append((captions[i],captions_embeddings[i]))
captions_selected = [random.choice(captions_class[i]) for i in captions_class]



In [4]:
num_reference = 6
images_selected = {}
captions_reference_selected = {}
for idx, caption_selected in enumerate(captions_selected):
    embeddings_cluster = np.array([caption_class[1] for caption_class in captions_class[idx]])
    sims = util.pytorch_cos_sim(caption_selected[1],embeddings_cluster)
    if len(sims.reshape(-1)) < num_reference:
        continue
    captions_reference_idx = sims.reshape(-1).topk(k=num_reference).indices.tolist()
    captions_reference = [caption_reference[0] for idx_, caption_reference in enumerate(captions_class[idx]) if idx_ in captions_reference_idx]
    images_reference = []
    for caption_reference in captions_reference:
        for img_id in coco.imgs.keys():
            img_captions = coco.imgToAnns[img_id]
            captions_temp = [img_caption['caption'] for img_caption in img_captions]
            if caption_reference in captions_temp:
                images_reference.append(coco.loadImgs(img_id)[0]['file_name'])
                images_selected[caption_selected[0]] = images_reference
                captions_reference_selected[caption_selected[0]] = captions_reference
                break

In [5]:
images_selected_llm = {}
for caption in images_selected:
    names_image = images_selected[caption]
    captions_total = captions_reference_selected[caption]
    # for name_image in names_image:
    #     img_id = int(name_image.split('.')[0][-6:])
    #     ann_ids = coco.getAnnIds(img_id)
    #     captions_ = coco.loadAnns(ann_ids)
    #     captions_ = [i['caption'] for i in captions_]
    #     captions_total += captions_
        
    captions_combine = ''
    idx = 0
    for caption_total in captions_total:
        captions_combine+=f'Caption{idx}: {caption_total}\n'
        idx += 1
    input_llm = f'''Here are some captions.
{captions_combine}
Please find what these captions have in common, don't have to describe the difference between them, DO NOT use generalisations such as various, different and so on and write it in one caption. Please only answer the caption without anything else.'''
#     input_llm = f'''Here are some captions.
# {captions_combine}
# Please find what these captions have in common, don't have to describe the difference between them, DO NOT use generalisations such as various, different and so on and write it in one caption, DO NOT use the word 'or'. Please only answer the caption without anything else.'''

    resp = qianfan.ChatCompletion().do(model="ERNIE-4.0-8K-0613", messages=[{"role":"user","content":input_llm}])
    images_selected_llm[resp.body['result']] = names_image

[INFO][2025-01-09 15:45:32.052] oauth.py:228 [t:8533180992]: trying to refresh access_token for ak `KRwV4I***`
[INFO][2025-01-09 15:45:32.655] oauth.py:243 [t:8533180992]: sucessfully refresh access_token


In [6]:
images_selected_llm_jsonl = [{'id': idx, 'prompt': line, 'reference': images_selected_llm[line]} for idx, line in enumerate(images_selected_llm)]
with open('data.jsonl', "w") as f:
    for obj in images_selected_llm_jsonl:
        json.dump(obj, f)
        f.write('\n')

In [7]:
import pandas as pd
pd.DataFrame(images_selected_llm_jsonl)[['id','prompt']].to_csv('only_prompt.csv',index=False)