- Refactored GA = Gradient Ascent: gets a CLIP "opinion" (text) about an image
- (optimizes for cosine similarity of text embeddings with image embeddings; see the objective sketch below)
- Long-CLIP ViT-L/14 (the model guiding Stable Diffusion) now fits in <24 GB memory!
- Approx. 1.5 minutes / image (RTX 4090); uses torch.cuda.amp / autocast + GradScaler
- To use: python longclipga_AMP.py --image_path "IMG_IN/catpiz.png"
- Likewise, longclipga_AMP_anti.py gets the cosine "DIS-similarity" ("opposite of") an image
- There is no antonym of "cat" in real life - but in CLIP's embeddings, there is!
- Use run_longclipga_AMP_opposites.py for both (batch) -> "What's most ALIKE to the image?" + "What's most UNLIKE the image?"
- Saves output (all + best words) to the "TOK" folder as .txt files. -- Pro Tip: Use the "best" words to prompt SDXL. =)
⚠️ Highly recommended: Enable "Sysmem Fallback" (NVIDIA Control Panel). The model should fit in <24 GB VRAM - BUT that depends on what else is running on your box. Plus, you wouldn't want a CUDA OOM crash just because you opened your browser to watch a video. You can also lower the batch_size in the code, but that degrades the quality of CLIP's "opinion" (try e.g. 8 if you absolutely must).
You won't win benchmarks by throwing small batch sizes at a big model such as ViT-L/14; but used as a fine-tuned text encoder for e.g. Stable Diffusion SDXL, this CLIP will win some hearts!
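A minimal sketch of the objective these scripts climb (or, for the "anti" variant, descend) - not the full longclipga_AMP.py; `model`, `image`, and `tokens` are assumed to be set up as in the Long-CLIP usage example further below:

```python
def clip_opinion_score(model, image, tokens):
    # encode both modalities with Long-CLIP
    image_features = model.encode_image(image)
    text_features = model.encode_text(tokens)
    # normalize so the dot product is the cosine similarity
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return (image_features * text_features).sum(dim=-1)

# longclipga_AMP.py maximizes this score via gradient ascent over text tokens;
# longclipga_AMP_anti.py minimizes it to find text maximally UNLIKE the image.
```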
- Uses AMP (automatic mixed precision) + AdaBelief optimizer (optional: fall back to AdamW) + OneCycleLR scheduler with warmup (see the sketch after this list)
- Gradually unfreeze CLIP (optional) or train the whole model (default) + set the learning rate for individual parameters (optional)
- Debug prints when exploding or vanishing gradients occur + many fancy logs and plots with live training updates
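A minimal sketch of that training setup, assuming `model`, `dataloader`, `device`, and `EPOCHS` are already defined; the hyperparameter values are placeholders, not this repo's defaults:

```python
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

try:
    from adabelief_pytorch import AdaBelief
    optimizer = AdaBelief(model.parameters(), lr=1e-6, eps=1e-16,
                          betas=(0.9, 0.999), weight_decouple=True, rectify=True)
except ImportError:
    # fallback if adabelief_pytorch is not installed
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6, weight_decay=1e-2)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-6, total_steps=EPOCHS * len(dataloader),
    pct_start=0.1)  # pct_start = fraction of steps spent warming up
scaler = GradScaler()

for epoch in range(EPOCHS):
    for images, tokens in dataloader:
        optimizer.zero_grad()
        with autocast():
            logits_per_image, logits_per_text = model(images.to(device), tokens.to(device))
            labels = torch.arange(images.shape[0], device=device)
            # symmetric contrastive loss, as in CLIP
            loss = (F.cross_entropy(logits_per_image, labels) +
                    F.cross_entropy(logits_per_text, labels)) / 2
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
```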
- Converts a "desc.csv" from CLIP Interrogator to dataset labels .json.
- Example: ft-X-example-my-dataset-labels.json shows the format my fine-tuning script expects. If your labels are in a different format - e.g. single text files next to the images - ask GPT-4, Claude 3, or any other AI assistant to convert them: describe your format, add "and I need to convert them to labels in a single .json file that should look like this:", and paste the content of ft-X-example-my-dataset-labels.json into the prompt as a one-shot example.
- If you load your dataset as dataset1 = ImageTextDataset("path/to/image/folder", "path/to/my-text-labels.json", transform=preprocess), and the image entries inside the .json are like "subpath/to/0001.jpg", then the dataloader will look for the image at "path/to/image/folder/subpath/to/0001.jpg" (see the illustration below).
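A purely hypothetical illustration of that path resolution (the authoritative label format is ft-X-example-my-dataset-labels.json in this repo; the caption text here is made up):

```python
import os

# hypothetical labels as they might be loaded from the labels .json:
# relative image path -> caption
labels = {"subpath/to/0001.jpg": "a cat sitting on a slice of pizza"}

root = "path/to/image/folder"
for rel_path in labels:
    full_path = os.path.join(root, rel_path)
    print(full_path)  # -> path/to/image/folder/subpath/to/0001.jpg
```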
- Data augmentation: If your dataset is only ~1000 images, consider augmenting the images, e.g. by flipping them horizontally.
- The example script will create a copy of your images with color jitter, which prevents CLIP from overfitting on specific colors.
- Use the augmented images with .json labels and randomly select from multiple labels for a given image. See the code in (3) for details and the sketch below.
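A minimal sketch of that kind of augmentation, assuming torchvision and a hypothetical labels dict that maps an image path to several captions:

```python
import random
from PIL import Image
from torchvision import transforms

# hypothetical inputs
image = Image.open("path/to/image/folder/subpath/to/0001.jpg")
labels = {"subpath/to/0001.jpg": ["a cat on a pizza", "a photo of a cat eating pizza"]}

augment = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
])

aug_image = augment(image)                               # PIL image in, PIL image out
caption = random.choice(labels["subpath/to/0001.jpg"])   # randomly pick one of the labels
```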
- Fine-tune CLIP: insert the dataset .json and the path to your images as per the previous step. See the code # comments for details.
- 10,000 text-image pairs can achieve good fine-tuning results within 1-2 hours (RTX 4090).
- Convert the torch.save model .pt into a state_dict that you can then just plug into SDXL as the text encoder (see the sketch below).
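A minimal sketch of that conversion, assuming the fine-tune was saved as a whole model object; the file names are placeholders:

```python
import torch

# run this from the repo root so torch.load can resolve the pickled model class
model = torch.load("ft-longclip-L.pt", map_location="cpu")     # whole pickled model
torch.save(model.state_dict(), "ft-longclip-L-state_dict.pt")  # weights only
# the state_dict file is what you plug into SDXL / ComfyUI as the text encoder
```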
- Easy as Pi with ComfyUI, see SeaArtLab/ComfyUI-Long-CLIP for details!
- Same random seed etc., just swapping out the original longCLIP-L model for my fine-tune. CFG scale 14 = high CLIP influence / guidance.
- Please note: The U-Net of SDXL was also trained on the same dataset, with CLIP frozen (i.e. independently of CLIP).
- For fine-tuning the SDXL U-Net Diffusion Model to complement CLIP, please refer to kohya-ss/sd-scripts
- Added run_visualization.py / 'vitvis' for LongCLIP feature activation max visualization
- Check run_visualization.py code # comments for instructions
- Based on hamidkazemi22/vit-visualization
- Added longclipga.py -> Get 'opinion' text from the model about an image
- (Optimizes cosine similarity of text embeddings with image embeddings)
- Check the code, I left comments.
- Original CLIP Gradient Ascent Script: used with permission from Twitter / X: @advadnoun
- Added longclip-token-to-ID.py -> Get token <-> ID mapping
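A minimal sketch of that mapping, assuming the repo bundles CLIP's SimpleTokenizer under model/simple_tokenizer.py:

```python
from model.simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()
ids = tok.encode("a photo of a cat")   # text -> BPE token IDs
print(ids)                             # the raw token IDs
print(tok.decode(ids))                 # IDs -> text again
print(tok.decoder[ids[-1]])            # a single ID -> its BPE token string
```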
This repository is the official implementation of Long-CLIP
Long-CLIP: Unlocking the Long-Text Capability of CLIP
Beichen Zhang, Pan Zhang, Xiaoyi Dong, Yuhang Zang, Jiaqi Wang
- Long input length: increases the maximum input length of CLIP from 77 to 248 tokens.
- Strong performance: improves the R@5 of long-caption text-image retrieval by 20% and traditional text-image retrieval by 6%.
- Plug-and-play: can be directly applied in any work that requires long-text capability.
- [2024/4/1] The training code is released!
- [2024/3/25] The inference code and models (LongCLIP-B and LongCLIP-L) are released!
- [2024/3/25] The paper is released!
- Training code for Long-CLIP based on OpenAI-CLIP
- Evaluation code for Long-CLIP (zero-shot classification and text-image retrieval tasks)
- Usage example of Long-CLIP
- Checkpoints of Long-CLIP
Our model is based on CLIP; please prepare the environment as for CLIP.
Please first clone our repo from GitHub by running the following commands.
git clone https://github.com/beichenzbc/Long-CLIP.git
cd Long-CLIP
Then, download the checkpoints of our model LongCLIP-B and/or LongCLIP-L and place them under ./checkpoints
from model import longclip
import torch
from PIL import Image
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = longclip.load("./checkpoints/longclip-B.pt", device=device)
text = longclip.tokenize(["A man is crossing the street with a red car parked nearby.", "A man is driving a car in an urban scene."]).to(device)
image = preprocess(Image.open("./img/demo.png")).unsqueeze(0).to(device)
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.982 0.01799]]
To run zero-shot classification on the ImageNet dataset, run the following command after preparing the data
cd eval/classification/imagenet
python imagenet.py
Similarly, run the following commands for the CIFAR datasets
cd eval/classification/cifar
python cifar10.py #cifar10
python cifar100.py #cifar100
To run text-image retrieval on COCO2017 or Flickr30k, run the following command after preparing the data
cd eval/retrieval
python coco.py #COCO2017
python flickr30k.py #Flickr30k
Please refer to train/train.md for training details.