In [1]:
from PIL import Image
from io import BytesIO
from transformers import CLIPProcessor, CLIPModel, SegformerImageProcessor, AutoModelForSemanticSegmentation , AutoFeatureExtractor
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import os
import weaviate
import numpy as np
import json
import cv2
import base64
from time import sleep

  from .autonotebook import tqdm as notebook_tqdm
  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


In [None]:
# Fashion-CLIP checkpoint: the CLIP model and its processor are both loaded from it
# and are used by the embedding helper functions in this notebook.
checkpoint = "patrickjohncyh/fashion-clip"
model = CLIPModel.from_pretrained(checkpoint)
processor = CLIPProcessor.from_pretrained(checkpoint)
# SegFormer image processor and semantic-segmentation model, both loaded from the
# "mattmdjaga/segformer_b2_clothes" checkpoint; segment() uses them.
seg_processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
seg_model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")

def getTextEmbeddings(text):
	"""Embed text with the notebook-level CLIP ``model`` / ``processor``.

	Args:
		text: Text handed to the processor as its ``text`` argument
			(a single prompt string in this notebook).

	Returns:
		The ``text_embeds`` entry of the model output.
	"""
	# The full CLIP forward pass is used, so a blank 72x72 RGB image is supplied
	# purely as a placeholder for the image input.
	inputs = processor(text=text , images=Image.new('RGB' , (72 , 72)), return_tensors="pt", padding=True)
	# Inference only: disable autograd so no graph is built or kept alive by the result.
	with torch.no_grad():
		outputs = model(**inputs , return_dict=True)
	return outputs["text_embeds"]

# Embed the text query; `embeddings` is the query vector for the PinterestImages search cell.
embeddings = getTextEmbeddings("white formal outfit")

In [None]:
# Display the first row of the text embedding as a plain Python list.
embeddings.tolist()[0]

In [2]:
# Connect to the Weaviate instance at localhost:8080; every query cell uses this client.
client = weaviate.Client(url="http://localhost:8080")

In [3]:
# Show node status, including per-class shard object counts.
client.cluster.get_nodes_status()

[{'gitHash': 'ddb8a43',
  'name': 'node1',
  'shards': [{'class': 'PinterestImages',
    'name': 'T6aWnYa0mcx4',
    'objectCount': 509},
   {'class': 'PinterestTop', 'name': 'PJPTMwTnCrl9', 'objectCount': 509},
   {'class': 'PinterestBottom', 'name': 'NIu2FDBr3MkB', 'objectCount': 509},
   {'class': 'FlipkartProducts', 'name': 'oQcrqlaKza74', 'objectCount': 167}],
  'stats': {'objectCount': 1694, 'shardCount': 4},
  'status': 'HEALTHY',
  'version': '1.19.8'}]

In [None]:
# Fetch the schema definition of the FlipkartProducts class.
client.schema.get("FlipkartProducts")

In [None]:
# Near-vector search over PinterestImages with the text-query embedding: return the
# 7 closest objects with their image, the referenced PinterestTop image + vector,
# and the additional vector / id / distance fields, then pretty-print the raw response.
response = client.query.get(
    "PinterestImages",
    ["image", "top{... on PinterestTop { image, _additional {vector} }}"],
).with_near_vector(
    {"vector": embeddings.tolist()[0]}
).with_additional(
    ["vector", "id", "distance"]
).with_limit(7).do()
print(json.dumps(response, indent=4))

In [None]:
# Stored vector of the first FlipkartProducts hit in `response`.
# NOTE(review): this needs `response` to come from a FlipkartProducts query, but the
# query cell directly above fetches PinterestImages, so on a top-to-bottom run this
# lookup presumably raises KeyError — confirm the intended cell order.
flipkart_embedding = response["data"]["Get"]["FlipkartProducts"][0]['_additional']["vector"]

In [None]:
# Vector of the first PinterestTop object referenced by the first PinterestImages hit;
# the PinterestTop and FlipkartProducts searches use it as their query vector.
top_embedding = response["data"]["Get"]["PinterestImages"][0]["top"][0]['_additional']["vector"]

In [None]:
# Find the 3 PinterestTop objects nearest to `top_embedding` (image plus the
# additional vector / id fields) and pretty-print the raw response.
response = client.query.get(
    "PinterestTop", ["image"]
).with_near_vector(
    {"vector": top_embedding}
).with_additional(
    ["vector", "id"]
).with_limit(3).do()
print(json.dumps(response, indent=4))

In [None]:
# Find the 10 FlipkartProducts objects nearest to `top_embedding` (image plus the
# additional vector / id / distance fields) and pretty-print the raw response.
response = client.query.get(
    "FlipkartProducts", ["image"]
).with_near_vector(
    {"vector": top_embedding}
).with_additional(
    ["vector", "id", "distance"]
).with_limit(10).do()
print(json.dumps(response, indent=4))

In [None]:
def getImageEmbeddingsFromPath(image_path):
	"""Open the image at ``image_path`` and return its CLIP image embedding.

	Args:
		image_path: Path handed to ``Image.open``.

	Returns:
		The ``image_embeds`` tensor produced by ``getImageEmbeddings``.
	"""
	# The context manager closes the file handle that a bare Image.open() leaves open;
	# the embedding is computed while the image is still open.
	with Image.open(image_path) as image:
		return getImageEmbeddings(image)

def getImageEmbeddings(image):
	"""Embed an image with the notebook-level CLIP ``model`` / ``processor``.

	Args:
		image: Image handed to the processor as its ``images`` argument
			(this notebook passes PIL images and numpy arrays).

	Returns:
		The ``image_embeds`` entry of the model output.
	"""
	# The full CLIP forward pass is used, so a placeholder text is supplied
	# for the text input.
	inputs = processor(text=["dummy"] , images=image, return_tensors="pt", padding=True)
	# Inference only: disable autograd so no graph is built or kept alive by the result.
	with torch.no_grad():
		outputs = model(**inputs , return_dict=True)
	return outputs["image_embeds"]

def applyMask(image, mask):
	"""Keep the pixels selected by ``mask`` and paint every other pixel 255.

	Args:
		image: Image (or array-like) that converts to an array with a trailing
			axis of three channels.
		mask: Per-pixel mask; zero entries mark pixels to blank out.

	Returns:
		A new array: ``image`` multiplied by the mask replicated across three
		channels, with positions where the mask is zero set to 255.
	"""
	pixels = np.asarray(image)
	# Replicate the single-channel mask along a new last axis, once per colour channel.
	channel_mask = np.repeat(np.asarray(mask)[..., np.newaxis], 3, axis=-1)
	masked = pixels * channel_mask
	masked[channel_mask == 0] = 255
	return masked

def cropImage(image):
	"""Crop an image array to the bounding box of its largest external contour.

	The array's channel axis is reversed and cast to uint8, converted to
	grayscale, inverse-Otsu thresholded, and the contour with the largest area
	is used for the crop.

	Args:
		image: Array indexed as ``[row, column, channel]``.

	Returns:
		The ``image[y:y+h, x:x+w]`` slice for the largest contour's bounding
		rectangle, or ``image`` unchanged when no contour is found.
	"""
	# Reverse the channel axis before handing the data to OpenCV's BGR2GRAY conversion.
	temp = image[:, :, ::-1].copy()
	temp = temp.astype('uint8')
	gray = cv2.cvtColor(temp, cv2.COLOR_BGR2GRAY)
	thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
	# findContours returns (contours, hierarchy) on OpenCV 4.x but
	# (image, contours, hierarchy) on 3.x; index -2 is the contour list on both.
	contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
	if len(contours) == 0:
		# Nothing to crop to (e.g. a uniform image): return the input as-is
		# instead of failing on contours[0].
		return image
	largest = max(contours, key=cv2.contourArea)
	x, y, w, h = cv2.boundingRect(largest)
	return image[y:y+h, x:x+w]

def segment(image, to_mask):
	"""Segment ``image`` and return one masked copy per requested label.

	Args:
		image: PIL image; ``image.size`` (width, height) is reversed to size the
			upsampled logits.
		to_mask: Iterable of integer class labels to extract.

	Returns:
		A list with one ``applyMask`` result per entry of ``to_mask``, in order:
		pixels predicted as that label are kept, all others are set to 255.
	"""
	inputs = seg_processor(images=image, return_tensors="pt")
	# Inference only: no autograd graph is needed for the logits or the upsampling.
	with torch.no_grad():
		logits = seg_model(**inputs).logits.cpu()
		upsampled_logits = nn.functional.interpolate(
			logits,
			size=image.size[::-1],
			mode="bilinear",
			align_corners=False,
		)
	# Per-pixel predicted label, converted to numpy once rather than once per label.
	pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()
	result = []
	for label in to_mask:
		# Direct equality mask. The previous two-step rewrite (zero out != label, then
		# set == label to 1) turned the whole mask into ones when label was 0.
		mask = (pred_seg == label).astype(pred_seg.dtype)
		result.append(applyMask(image, mask))
	return result

def _pngBase64(image):
	"""Encode a PIL image as PNG and return the bytes as a base64 text string."""
	buffered = BytesIO()
	image.save(buffered, format="PNG")
	return base64.b64encode(buffered.getvalue()).decode()

def segmentAndEmbed(image_path, to_mask):
	"""Embed an image and each of its requested segments.

	Args:
		image_path: Path handed to ``Image.open``.
		to_mask: Iterable of segmentation labels forwarded to ``segment``.

	Returns:
		A dict with ``fullImageBase64`` / ``fullImageEmbedding`` for the whole
		image and, for the i-th entry of ``to_mask``, ``segmentBase64_{i}`` /
		``segmentEmbedding_{i}`` for the masked segment. Base64 values are
		PNG-encoded images; embeddings are ``getImageEmbeddings`` results.
	"""
	# Close the file handle once the pixels are loaded. convert("RGB") yields an
	# in-memory copy and guarantees the three channels applyMask expects.
	with Image.open(image_path) as source:
		image = source.convert("RGB")
	result = {}
	result["fullImageBase64"] = _pngBase64(image)
	result["fullImageEmbedding"] = getImageEmbeddings(image)
	for i, segmentArray in enumerate(segment(image, to_mask)):
		# The embedding is computed from the masked array; the base64 copy is a
		# uint8 PNG rendering of the same array.
		result[f"segmentBase64_{i}"] = _pngBase64(Image.fromarray(np.uint8(segmentArray)))
		result[f"segmentEmbedding_{i}"] = getImageEmbeddings(segmentArray)
	return result

In [None]:
# Segment and embed the saved image, extracting only segmentation label 4.
# NOTE(review): presumably label 4 is the upper-clothes class of the segmentation
# model — confirm against its id2label config.
result = segmentAndEmbed("pinterest_image.jpg", [4])

In [None]:
# First row of the full-image embedding as a plain Python list.
# NOTE(review): the name suggests a top-garment embedding, but this reads
# "fullImageEmbedding" rather than "segmentEmbedding_0" — confirm which is intended.
top_embedding1 = result["fullImageEmbedding"].tolist()[0]

In [None]:
# Decode and display one image from the most recent FlipkartProducts response.
# Responses for PinterestTop / PinterestImages can be inspected the same way by
# changing the class key and index.
image_b64 = response["data"]["Get"]["FlipkartProducts"][2]['image']
# Take the text after the first comma so that a "data:...;base64," prefix, if present,
# is dropped; a bare base64 string contains no comma and is used unchanged.
image = Image.open(BytesIO(base64.b64decode(image_b64.split(",", 1)[-1])))
image.show()
# image.save("pinterest_image.jpg")

In [None]:
# Display the images returned by the most recent FlipkartProducts query (at most 15).
# The loop used to have every per-item lookup commented out, so it called
# .split(",", 1) on the PIL image left over from the previous cell and failed; it
# also indexed up to 15 items regardless of how many hits the query returned.
hits = response["data"]["Get"]["FlipkartProducts"]
for hit in hits[:15]:
    # Drop an optional "data:...;base64," prefix before decoding.
    image_b64 = hit["image"].split(",", 1)[-1]
    image = Image.open(BytesIO(base64.b64decode(image_b64)))
    image.show()
    