Think about how your brain effortlessly links attributes to things. When someone says "apple," your mind instantly pictures an apple, whether it's green, red, or something in between.
Now, in the digital world, we usually search for stuff using words or images, but it's time for a twist.
We're shaking up the way you search. Imagine blending words and pictures to find what you're looking for. That's what our project is all about. You toss in some text, and it influences the images you get back. Perfect for hunting down products online, keeping an eye on things, or finding stuff on the web.
Here's a quick example to paint the picture. Say you're shopping online and want a dress that's similar to your friend's but with some specific tweaks, like stripes and a bit more coverage. Our smart algorithm makes it happen.
We've got a nifty model, powered by autoencoders and transformers, to make sense of text and images for your search. It's all about learning from the good stuff and using it to get you the perfect matches. We even throw in a sprinkle of math to keep things in check.
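To make the idea concrete, here's a minimal, hypothetical sketch of the composition step: an image feature and a text feature are fused by a transformer layer into one query embedding, which is then matched against target image embeddings by cosine similarity. The class name and layer sizes are illustrative only, not the actual modules in img_text_composition_models.py.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QueryComposer(nn.Module):
    """Toy composer: fuses one image feature and one text feature into a query embedding."""
    def __init__(self, dim=512):
        super().__init__()
        # A single transformer encoder layer attends over the two query tokens.
        self.fuse = nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True)
        self.proj = nn.Linear(dim, dim)

    def forward(self, img_feat, txt_feat):
        # img_feat, txt_feat: (batch, dim) outputs of the image / text encoders.
        tokens = torch.stack([img_feat, txt_feat], dim=1)   # (batch, 2, dim)
        fused = self.fuse(tokens).mean(dim=1)                # pool the two tokens
        return F.normalize(self.proj(fused), dim=-1)         # unit-norm query embedding

composer = QueryComposer()
query = composer(torch.randn(1, 512), torch.randn(1, 512))   # composed query
gallery = F.normalize(torch.randn(100, 512), dim=-1)         # candidate image embeddings
print((query @ gallery.t()).topk(5).indices)                 # indices of the top-5 matches
```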
Our approach is the champ here, beating the state-of-the-art methods on the MIT-States benchmark dataset.
- You'll find all the packages you need in requirements.txt.
About the Code (Inspired by ComposeAE)
We built our code based on ComposeAE's work, but we've spiced it up with some serious upgrades.
- main.py: The main event; run this script for training and testing.
- datasets.py: Gets the goods, like loading images and whipping up training retrieval queries.
- text_model.py: Model for text features.
- img_text_composition_models.py: Fancy models for mixing up images and text.
- torch_function.py: Holds our secret sauce, the soft triplet loss function and feature normalization magic (a minimal sketch of the loss follows this list).
- test_retrieval.py: This one's all about retrieval tests and calculating recall performance.
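For reference, here's a minimal sketch of a soft triplet loss in the spirit of what torch_function.py implements; the real code also handles triplet sampling and batching, so treat this as an illustration of the idea rather than the exact implementation.

```python
import torch
import torch.nn.functional as F

def soft_triplet_loss(anchor, positive, negative):
    """Softmax-based triplet loss: the positive should end up closer than the negative."""
    d_pos = (anchor - positive).pow(2).sum(dim=1)   # squared distance to the positive
    d_neg = (anchor - negative).pow(2).sum(dim=1)   # squared distance to the negative
    # Equivalent to -log softmax over (-d_pos, -d_neg), always picking the positive.
    logits = torch.stack([-d_pos, -d_neg], dim=1)
    labels = torch.zeros(anchor.size(0), dtype=torch.long)
    return F.cross_entropy(logits, labels)

# Example with normalized query / target embeddings.
q = F.normalize(torch.randn(8, 512), dim=-1)
pos = F.normalize(torch.randn(8, 512), dim=-1)
neg = F.normalize(torch.randn(8, 512), dim=-1)
print(soft_triplet_loss(q, pos, neg))
```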
Get your hands on the MITStates dataset right here. Save it in the data folder. Just make sure it's got these files:
data/processed/mitstates/images/<adj noun>/*.jpg
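If you want to double-check the layout before training, a quick sketch like this will count the attribute-noun folders and images. The root path below is an assumption; point it at wherever you unpacked the data.

```python
from pathlib import Path

# Assumed location; change this if your data folder lives elsewhere.
root = Path("data/processed/mitstates/images")
pairs = [d for d in root.iterdir() if d.is_dir()]      # folders named "<adj noun>"
n_images = sum(1 for _ in root.glob("*/*.jpg"))
print(f"{len(pairs)} attribute-noun folders, {n_images} jpg images")
```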
First up, grab the Fashion200k dataset from this spot, and toss it into your trusty data folder. To keep things on the level, we're using the same test queries as TIRG. You can nab those queries from here. Just make sure your dataset has these files:
data/processed/fashion200k/labels/*.txt
data/processed/fashion200k/women/<category>/<caption>/<id>/*.jpeg
data/processed/fashion200k/test_queries.txt
Now, for the FashionIQ dataset, head over to this link, and stash it in your data folder. This one's a bit of a mix, with three separate subsets: dress, top-tee, and shirt. Each query comes with two human annotations; we give them a little text twist, combining them into a single modification text so it reads more like something a user might ask on an e-commerce platform.
What's more, we're bringing all three categories together for a beefed-up training set and training a single model on it. We're doing the same with the validation sets to keep things neat and tidy.
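As a rough illustration of that preprocessing, the sketch below merges the two annotations of each query into one sentence and concatenates the three category splits into a single training list. The file names and JSON fields follow the public FashionIQ release; if your copy is organized differently, adjust the paths (they are assumptions here, not what datasets.py hardcodes).

```python
import json

def combine_captions(cap1, cap2):
    # e.g. "is shorter" + "has no sleeves" -> "is shorter and has no sleeves"
    return f"{cap1} and {cap2}"

def load_split(category, split, root="data/fashionIQ"):
    # Assumed file naming from the public release; "toptee" is how those files spell top-tee.
    with open(f"{root}/captions/cap.{category}.{split}.json") as f:
        return json.load(f)

# Merge the three categories into a single training list.
train = []
for category in ["dress", "toptee", "shirt"]:
    for item in load_split(category, "train"):
        c1, c2 = item["captions"][:2]
        train.append({
            "candidate": item["candidate"],   # reference image id
            "target": item["target"],         # target image id
            "mod": combine_captions(c1, c2),  # user-style modification text
        })
print(len(train), "training queries")
```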
To train and test your models, just use the right commands. Here are some examples to get you started:
- Training the MQIRTN Model on MITStates Dataset:
  python -W ignore main.py --dataset=mitstates --dataset_path=../data/mitstates/ --model=MQIRTN --loss=soft_triplet --learning_rate_decay_frequency=50000 --num_iters=160000 --weight_decay=5e-5 --comment=mitstates_MQIRTN --log_dir ../logs/mitstates/
- Training the MQIRTN Model on Fashion200k Dataset:
  python -W ignore main.py --dataset=fashion200k --dataset_path=../data/fashion200k/ --model=MQIRTN --loss=batch_based_classification --learning_rate_decay_frequency=50000 --num_iters=160000 --use_bert True --use_complete_text_query True --weight_decay=5e-5 --comment=fashion200k_MQIRTN --log_dir ../logs/fashion200k/
We've got a snazzy BERT model that helps encode text queries. We use BERT-as-service with Uncased BERT-Base, and it dishes out a 512-dimensional feature vector for text queries. To get the nitty-gritty on how to use it, check out the instructions here. Just a heads-up, make sure you've got BERT-as-service up and running in the background before you dive into training your models.
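For example, once the server is up, encoding text queries from Python looks roughly like this (the query sentence is just a placeholder):

```python
from bert_serving.client import BertClient

bc = BertClient()                                             # connects to the running bert-as-service server
vecs = bc.encode(["make it striped and more conservative"])   # one feature vector per text query
print(vecs.shape)
```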
To keep tabs on how your models are doing, use this command to monitor loss and retrieval performance:
tensorboard --logdir ./reports/fashion200k/ --port 8898
If our code has been a big help in your research, show us some love by citing it:
@InProceedings{,
  author    = {Seyed Mohammad Bagher Hosseini and Soodeh Bakhshandeh},
  title     = {Multimodal Query Enhancement for Image Retrieval using Transformer Networks (MQIRTN)},
  booktitle = {},
  month     = {March},
  year      = {2024},
  pages     = {}
}