In [None]:
import os
import csv
import shutil

# Root directory of the background dataset. This must be the directory itself
# (not a file inside it): the image/ and text/ sub-folders and the output
# tag.csv are all resolved relative to it.
dataset_folder = "/home/sckim/Dataset/background"
image_folder = os.path.join(dataset_folder, "image")
text_folder = os.path.join(dataset_folder, "text")
meta_path = os.path.join(dataset_folder, "tag.csv")

# Captions and images are paired purely by position in the two sorted listings.
text_names = sorted(os.listdir(text_folder))
image_names = sorted(os.listdir(image_folder))

# Positional pairing is only meaningful when both folders have the same number
# of entries. The loop below moves matching images out of image_folder, so a
# second run (or any stray file) would otherwise attach captions to the wrong
# images without any visible error.
if len(text_names) != len(image_names):
    raise ValueError(
        f"Cannot pair captions with images: {len(text_names)} text files in "
        f"{text_folder} but {len(image_names)} images in {image_folder}."
    )

data = []

for text_file_name, image_file_name in zip(text_names, image_names):
    text_file_path = os.path.join(text_folder, text_file_name)
    image_file_path = os.path.join(image_folder, image_file_name)

    # The whole file content is the caption; the `with` block closes the file.
    with open(text_file_path, "r") as f:
        caption = f.read()

    # Keep only images whose caption contains the "no human" tag: move them up
    # into the dataset root and record their new location with the caption.
    if "no human" in caption:
        file_path = os.path.join(dataset_folder, image_file_name)
        data.append({"file_path": file_path, "text": caption})
        shutil.move(image_file_path, file_path)

# Write one (file_path, text) row per kept image.
with open(meta_path, mode="w", newline="") as file:
    writer = csv.DictWriter(file, fieldnames=["file_path", "text"])
    writer.writeheader()
    writer.writerows(data)

print(f"{meta_path} 파일이 생성되었습니다.")

In [None]:
import csv
from tqdm import tqdm
import os
from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from diffusers.blip.models.blip import blip_decoder


def load_demo_image(image_path, image_size, device):
    """Load an image file and convert it into a normalized, batched tensor.

    Args:
        image_path: Path of the image file to open with PIL.
        image_size: Side length, in pixels, of the square the image is resized to.
        device: Target passed to ``Tensor.to`` for the returned tensor.

    Returns:
        The transformed RGB image with a leading batch dimension of size 1
        (added by ``unsqueeze(0)``), placed on ``device``.
    """
    # Open inside a context manager so the underlying file handle is released
    # immediately; convert() yields an independent in-memory image, so it stays
    # usable after the file is closed. This matters when looping over many files.
    with Image.open(image_path) as opened:
        raw_image = opened.convert("RGB")

    transform = transforms.Compose(
        [
            # Resize to an exact square (aspect ratio is not preserved).
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            # The image is already image_size x image_size here, so this crop
            # leaves it unchanged; kept to preserve the original pipeline.
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            # Per-channel mean and standard deviation used for normalization.
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ]
    )
    return transform(raw_image).unsqueeze(0).to(device)


# Captioning configuration: input resolution, image folder, checkpoint, GPU.
image_size = 512
image_dir = "/home/sckim/Dataset/background"
model_path = "/home/sckim/Dataset/model_large_caption.pth"
device = "cuda:0"
torch.cuda.set_device(torch.device(device))

# Build the captioning model from the local checkpoint and switch it to
# inference mode before moving it to the GPU.
model = blip_decoder(pretrained=model_path, image_size=image_size, vit="large")
model.eval()
model = model.to(device)

file_name = "/home/sckim/Dataset/background/metadata.csv"
data = []

# Extensions are compared in lower case so files such as "photo.JPG" are not
# silently skipped.
image_extensions = {".jpg", ".png", ".jpeg"}

# Sort the listing so the rows of the output CSV come out in a stable order
# (os.listdir returns entries in arbitrary order).
for img_name in tqdm(sorted(os.listdir(image_dir))):
    if os.path.splitext(img_name)[-1].lower() not in image_extensions:
        continue

    image_path = os.path.join(image_dir, img_name)

    image = load_demo_image(image_path=image_path, image_size=image_size, device=device)

    with torch.no_grad():
        # beam search
        # caption = model.generate(image, sample=False, num_beams=3, max_length=40, min_length=5)
        # nucleus sampling
        caption = model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)

    # Only the first element of the generate() result is stored for this image.
    data.append({"file_path": image_path, "text": caption[0]})

# Write one (file_path, text) row per captioned image.
with open(file_name, mode="w", newline="") as file:
    writer = csv.DictWriter(file, fieldnames=["file_path", "text"])
    writer.writeheader()
    writer.writerows(data)

print(f"{file_name} 파일이 생성되었습니다.")

In [None]:
from transformers import AutoTokenizer
import transformers
import torch
import csv
from tqdm import tqdm

# Make "cuda:0" the current CUDA device for this cell.
# NOTE(review): `device` is not referenced again in the visible cells below
# (the text-generation pipeline is built with device_map="auto") — confirm it
# is still needed.
device = "cuda:0"
torch.cuda.set_device(device)


def read_csv(csv_path):
    """Read a CSV file and return its data rows, skipping the header line.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        A list with one entry per data row; each entry is the list of string
        fields produced by ``csv.reader``. An empty file, or a file that only
        contains a header, yields an empty list.
    """
    with open(csv_path, "r", newline="") as file:
        csv_reader = csv.reader(file)
        # Discard the header row. The default keeps an empty file from raising
        # StopIteration.
        next(csv_reader, None)
        return list(csv_reader)


# Path handed to both the tokenizer and the text-generation pipeline, and the
# caption CSV whose rows are sent to the model one by one.
model = "/home/sckim/Dataset/llama2_7b_chat_hf"
csv_path = "/home/sckim/Dataset/background/tag.csv"
annotations = read_csv(csv_path)

tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
)

# NOTE: meta_list stays empty while the parsing at the bottom of this cell is
# commented out; the loop currently only prints the raw model output.
meta_list = [[], [], []]  # 0:weather, 1:place, 2:time
pre_state = 0

# The caption comes first, then a newline, then the instruction.
prompt_template = (
    "{}\nIn the sentence above, tell me the words related to the weather, time, and place in each category."
)

for ann in tqdm(annotations):
    # Each row is (file_path, text); only the caption text is used here.
    caption = ann[1]

    # Greedy decoding (do_sample=False). top_k only affects sampling, so it is
    # not passed. Each keyword is given exactly once — repeating a keyword
    # argument is a SyntaxError.
    sequences = pipeline(
        prompt_template.format(caption),
        do_sample=False,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_length=200,
    )

    print(sequences[0]["generated_text"])

    # Disabled parsing of the model answer into meta_list. It handles two
    # answer layouts: a bulleted list ("* word" lines under a category heading)
    # and an inline list ("Category: a, b, c").
    # if "*" in sequences[0]['generated_text']:
    #     for word in sequences[0]['generated_text'].split('\n'):
    #         if 'Weather' in word:
    #             pre_state=0
    #         elif 'Place' in word:
    #             pre_state=1
    #         elif 'Time' in word:
    #             pre_state=2
    #         elif '* ' in word:
    #             text = word.split('* ')[-1]
    #             meta_list[pre_state].append(text)
    #         else:
    #             continue
    # else:
    #     for word in sequences[0]['generated_text'].split('\n'):
    #         if 'Weather' in word:
    #             pre_state=0
    #             text = word.split(': ')[-1]
    #             text = text.split(', ')
    #             meta_list[pre_state].extend(text)
    #         elif 'Place' in word:
    #             pre_state=1
    #             text = word.split(': ')[-1]
    #             text = text.split(', ')
    #             meta_list[pre_state].extend(text)
    #         elif 'Time' in word:
    #             pre_state=2
    #             text = word.split(': ')[-1]
    #             text = text.split(', ')
    #             meta_list[pre_state].extend(text)
    #         else:
    #             continue

In [None]:
import csv

# De-duplicate the words collected for each category. sorted() is used instead
# of list() because set iteration order for strings can differ between runs,
# which would shuffle the rows of the CSV files written from these lists.
meta_list = [sorted(set(meta)) for meta in meta_list]

# One output CSV per category.
meta_weather_path = "/home/sckim/Dataset/background/tag_weather.csv"
meta_place_path = "/home/sckim/Dataset/background/tag_place.csv"
meta_time_path = "/home/sckim/Dataset/background/tag_time.csv"


def write_csv(csv_path, data):
    """Write ``data`` as a single-column CSV file with the header ``meta``.

    Args:
        csv_path: Destination path; an existing file is overwritten.
        data: Iterable of values, each written as one row.

    Prints a confirmation message containing ``csv_path`` once the file is
    written.
    """
    with open(csv_path, mode="w", newline="") as out_file:
        sheet = csv.writer(out_file)
        # Header line first, then one single-field row per value.
        sheet.writerow(["meta"])
        sheet.writerows([value] for value in data)

    print(f"{csv_path} 파일이 생성되었습니다.")


# Write one CSV per category; the position in this tuple is the category's
# index in meta_list (0: weather, 1: place, 2: time).
for category_index, category_path in enumerate((meta_weather_path, meta_place_path, meta_time_path)):
    write_csv(category_path, meta_list[category_index])