In [None]:
import os
import requests

from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait

In [None]:
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as v2
from transformers import  ConditionalDetrImageProcessor, ConditionalDetrForObjectDetection
from transformers import ViTConfig, ViTImageProcessor, ViTModel
from datasets import load_from_disk

In [None]:
# Run on Apple's Metal (MPS) backend when PyTorch reports it as available; otherwise use the CPU.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f'device: {device}')

In [None]:
# Local checkpoint directory holding both the image processor config and the detector weights.
ckpt = "../object_detection/model_ckpt"

# Pre-/post-processing helper paired with the detector below.
object_detection_image_processor = ConditionalDetrImageProcessor.from_pretrained(ckpt)

# Load the detector, move it to the selected device and switch it to inference mode.
object_detector = ConditionalDetrForObjectDetection.from_pretrained(ckpt)
object_detector = object_detector.to(device)
object_detector = object_detector.eval()

In [None]:
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Args:
        box1: Sequence ``(x_min, y_min, x_max, y_max)``.
        box2: Sequence ``(x_min, y_min, x_max, y_max)``.

    Returns:
        float: Intersection area divided by union area. Returns ``0.0`` when
        the union area is not positive (e.g. both boxes are degenerate), so
        callers never hit a ``ZeroDivisionError``.
    """
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2

    # Corners of the overlapping rectangle (may be "inverted" if there is no overlap).
    xx1, yy1 = max(x1, x3), max(y1, y3)
    xx2, yy2 = min(x2, x4), min(y2, y4)

    # Clamp each side at 0 so disjoint boxes yield a zero intersection.
    intersection_area = max(0, xx2 - xx1) * max(0, yy2 - yy1)

    box1_area = (x2 - x1) * (y2 - y1)
    box2_area = (x4 - x3) * (y4 - y3)
    union_area = box1_area + box2_area - intersection_area

    # Zero-area boxes would otherwise divide by zero.
    if union_area <= 0:
        return 0.0

    return intersection_area / float(union_area)

def non_max_suppression(items, iou_threshold=0.75):
    """Greedy non-maximum suppression over scored boxes.

    Args:
        items: Iterable of tuples whose element ``0`` is the score used for
            ranking and whose element ``2`` is the box passed to
            ``calculate_iou``. Element ``1`` is not consulted, so suppression
            ignores whatever it holds.
        iou_threshold: A candidate is kept only if its IoU with every
            already-kept item is strictly below this value.

    Returns:
        list: The kept items, ordered by descending score.
    """
    survivors = []
    # Visit candidates from highest to lowest score; each one must clear all earlier survivors.
    for candidate in sorted(items, key=lambda entry: entry[0], reverse=True):
        clears_all = all(
            calculate_iou(kept[2], candidate[2]) < iou_threshold
            for kept in survivors
        )
        if clears_all:
            survivors.append(candidate)
    return survivors

def detect_objects(image, score_threshold=0.5):
    """Detect objects in an image and return one crop per surviving detection.

    Runs the module-level ``object_detector`` on ``image``, keeps detections
    scoring at least ``score_threshold``, removes overlapping boxes with
    ``non_max_suppression`` and crops each remaining box out of ``image``.

    Args:
        image: Image object exposing ``.size`` as ``(width, height)`` and a
            ``.crop(box)`` method (a PIL image in this notebook).
        score_threshold: Minimum detection score forwarded to the image
            processor's post-processing step. Defaults to ``0.5``.

    Returns:
        list[tuple]: ``(cropped_image, label_name, score)`` tuples, ordered by
        descending score.
    """
    with torch.no_grad():
        inputs = object_detection_image_processor(images=[image], return_tensors="pt")
        outputs = object_detector(**inputs.to(device))
        # image.size is (width, height); target_sizes is built as (height, width).
        target_sizes = torch.tensor([[image.size[1], image.size[0]]])
        results = object_detection_image_processor.post_process_object_detection(
            outputs, threshold=score_threshold, target_sizes=target_sizes
        )[0]

    # Convert tensors to plain Python values: (score, label_id, [x_min, y_min, x_max, y_max]).
    items = [
        (score.item(), label.item(), [coord.item() for coord in box])
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"])
    ]

    # Apply NMS to drop heavily overlapping boxes.
    items = non_max_suppression(items)

    return [
        (image.crop(bbox), object_detector.config.id2label[label], score)
        for score, label, bbox in items
    ]

In [None]:
# Category names; a label's position in this list is the class index it maps from.
cls_labels = ['bag', 'bottom', 'dress', 'hat', 'outer', 'shoes', 'top', 'etc']

# Class index -> category name lookup.
cls_id2label = dict(enumerate(cls_labels))

In [None]:
# Checkpoint that supplies the ViT configuration and the image normalisation statistics.
ckpt = 'google/vit-base-patch16-224-in21k'
vit_image_processor = ViTImageProcessor.from_pretrained(ckpt)
config = ViTConfig.from_pretrained(ckpt)

# Preprocessing for classifier inputs: square resize to the configured image size,
# tensor conversion, then normalisation with the processor's mean/std.
transform = v2.Compose(
    [
        v2.Resize((config.image_size,) * 2),
        v2.ToTensor(),
        v2.Normalize(
            mean=vit_image_processor.image_mean,
            std=vit_image_processor.image_std,
        ),
    ]
)

class Classifier(nn.Module):
    """ViT backbone followed by a single linear classification head.

    Args:
        num_labels: Number of output classes (width of the logits).
    """

    def __init__(self, num_labels):
        super().__init__()
        # The attribute names `vit` and `fc` determine the state_dict keys, so they
        # must stay as-is for previously saved checkpoints to load.
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        # Size the head from the backbone that was actually loaded instead of a
        # separate module-level config, so the two can never disagree.
        self.fc = nn.Linear(self.vit.config.hidden_size, num_labels)

    def forward(self, x):
        """Return class logits computed from the backbone's pooled output for ``x``."""
        return self.fc(self.vit(x).pooler_output)

classifier = Classifier(num_labels=len(cls_id2label))
# Deserialize onto the CPU first so the checkpoint loads even if it was saved from a
# device that is unavailable here; the model is moved to `device` afterwards.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
classifier.load_state_dict(
    torch.load(
        '../image_encoder/classification_model_ckpt/classifier.pt',
        map_location='cpu',
    )
)
classifier = classifier.to(device).eval()

In [None]:
# Shared Selenium Chrome session; save_images() loads every post page through this one browser.
driver = webdriver.Chrome() 

In [None]:
# Browser-like User-Agent sent along with image download requests.
headers = {'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)'}
def _download_rgb_image(url, timeout=10):
    """Download ``url`` with the shared ``headers`` and return it as an RGB PIL image."""
    # The context manager releases the connection; the timeout keeps one stalled
    # request from blocking the whole crawl.
    with requests.get(url, headers=headers, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # Have urllib3 undo any Content-Encoding before PIL parses the bytes.
        response.raw.decode_content = True
        # convert() loads the pixel data, so the image stays usable after the response closes.
        return Image.open(response.raw).convert('RGB')


def _select_anchor_crops(post_image):
    """Return ``{label: highest-scoring crop}`` for the post image, or ``None`` to skip the post."""
    detections = detect_objects(post_image)
    detected_labels = [label for _, label, _ in detections]

    # More than two shoe detections is treated as an unusable post.
    if detected_labels.count('shoes') > 2:
        return None

    # Every non-shoe category may be detected at most once.
    non_shoe_labels = [label for label in detected_labels if label != 'shoes']
    if len(non_shoe_labels) != len(set(non_shoe_labels)):
        return None

    best = {}
    for cropped_image, label, score in detections:
        if label not in best or score > best[label][1]:
            best[label] = (cropped_image, score)
    return {label: cropped_image for label, (cropped_image, _) in best.items()}


def _label_product_images(product_images, min_side=200, min_confidence=0.8):
    """Classify product images and return ``(label, image)`` pairs for the confident, non-'etc' ones."""
    labelled = []
    for product_image in product_images:
        # Skip images smaller than min_side on either side.
        if product_image.width < min_side or product_image.height < min_side:
            continue

        image_tensor = transform(product_image).to(device)
        with torch.no_grad():
            logits = classifier(image_tensor.unsqueeze(0)).squeeze()

        probs = torch.softmax(logits, dim=-1)
        indices = probs.argsort(dim=-1, descending=True).tolist()
        if probs[indices[0]].item() < min_confidence:
            continue

        label = cls_id2label[indices[0]]
        if label == 'etc':
            continue

        labelled.append((label, product_image))
    return labelled


def save_images(post_id):
    """Scrape one post and save anchor/positive image pairs per clothing category.

    The post page is opened with the shared Selenium ``driver``. The post image
    is run through ``detect_objects`` to obtain anchor crops, and each tagged
    product image is classified to obtain positives. For every product label a
    pair is written to
    ``./onthelook_anchor_positive_images/{label}/{post_id}/anchor.jpg`` and
    ``positive.jpg``.

    The post is skipped (nothing is written) when it has no product items, the
    detections are ambiguous (more than two shoes, or a repeated non-shoe
    label), no product image is confidently classified, two products share a
    label, or any product label has no matching detection. The last check is
    done before any file is written so a post is never saved partially.

    Args:
        post_id: Identifier interpolated into the post URL and the output path.

    Returns:
        None.

    Raises:
        Whatever Selenium, requests or PIL raise, e.g. when the page element
        does not appear in time, an HTTP request fails or times out, or the
        downloaded bytes are not an image.
    """
    url = f'https://onthelook.co.kr/post/{post_id}'
    driver.get(url)
    # Wait up to 0.5 s for the post container to be present before reading the page.
    WebDriverWait(driver, 0.5).until(EC.presence_of_element_located((By.CLASS_NAME, 'social_image_box')))

    post_image_url = driver.find_element(By.ID, 'post-image').get_attribute('src')

    # Bail out before any download when the post has no tagged products.
    product_items = driver.find_elements(By.CLASS_NAME, 'product-item')
    if not product_items:
        return

    post_image = _download_rgb_image(post_image_url)

    product_images = []
    for product_item in product_items:
        img_tag = product_item.find_element(By.TAG_NAME, 'img')
        # Drop the query string from the image URL before downloading.
        product_image_url = img_tag.get_attribute('src').split('?')[0]
        product_images.append(_download_rgb_image(product_image_url))

    detected_label_to_image = _select_anchor_crops(post_image)
    # Compare with None explicitly: an empty dict is a valid (if useless) result.
    if detected_label_to_image is None:
        return

    labelled_products = _label_product_images(product_images)
    tag_labels = [label for label, _ in labelled_products]
    if not tag_labels:
        return
    if len(tag_labels) != len(set(tag_labels)):
        return

    # All-or-nothing: make sure every product label has an anchor before writing anything.
    if any(label not in detected_label_to_image for label in tag_labels):
        return

    for label, positive_image in labelled_products:
        anchor_image = detected_label_to_image[label]

        save_dir = f'./onthelook_anchor_positive_images/{label}/{post_id}'
        os.makedirs(save_dir, exist_ok=True)

        anchor_image.save(f'{save_dir}/anchor.jpg')
        positive_image.save(f'{save_dir}/positive.jpg')

In [None]:
# Exclusive upper bound of the post ids to crawl.
END_POST_ID = 211727

# Resume after the highest post_id stored in the most recent `onthelook_dataset*`
# directory, if one exists. Directories are ordered lexicographically, so this
# assumes their names sort chronologically -- TODO confirm the naming scheme.
start_post_id = 1
prev_dirs = sorted(
    d for d in os.listdir('./')
    if d.startswith('onthelook_dataset') and os.path.isdir(d)
)
if prev_dirs:
    prev_dataset = load_from_disk(prev_dirs[-1])
    # default=0 keeps an empty dataset from raising and restarts from post 1.
    start_post_id = max((int(i) for i in prev_dataset['post_id']), default=0) + 1

print(f'start_post_id: {start_post_id}')

# Crawling is best-effort: a failing post must not stop the loop, but failures are
# tallied by exception type so they are visible instead of silently discarded.
failure_counts = {}
for post_id in range(start_post_id, END_POST_ID):
    try:
        save_images(post_id)
    except Exception as error:
        error_name = type(error).__name__
        failure_counts[error_name] = failure_counts.get(error_name, 0) + 1

print(f'failed posts by exception type: {failure_counts}')