
PyTorch conversion of LiT: Zero-Shot Transfer with Locked-image text Tuning


samedii/pytorch-zero-lit


pytorch-zero-lit

Official JAX models for LiT: Zero-Shot Transfer with Locked-image text Tuning, converted to PyTorch.

Conversion path: JAX -> TensorFlow -> ONNX -> PyTorch.

  • The image encoder is loaded into PyTorch and supports gradients.
  • The text encoder is not loaded into PyTorch; it runs via ONNX on the CPU.
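Because only the image side is differentiable, gradients can flow from a similarity score back into the input pixels, while text encodings behave as constants. The sketch below shows one use of that split: optimizing an image toward a fixed text encoding. It does not load LiT. `ToyImageEncoder` is a hypothetical stand-in for the real image encoder, and a random tensor stands in for a text encoding. Swapping in `model.encode_images` and `model.encode_texts` assumes both return PyTorch tensors of the same width.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyImageEncoder(nn.Module):
    """Hypothetical stand-in for the LiT image encoder: flattens a 3x224x224 image and projects it."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(3 * 224 * 224, dim)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        return self.proj(images.flatten(start_dim=1))


def optimize_image(encoder, text_encoding, steps=10, lr=0.05):
    # Start from a random image in [0, 1] with a batch dimension and track gradients on its pixels
    image = torch.rand(1, 3, 224, 224, requires_grad=True)
    optimizer = torch.optim.Adam([image], lr=lr)
    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        image_encoding = encoder(image)
        # The text encoding is a constant, like the output of an encoder that runs outside autograd
        loss = 1 - F.cosine_similarity(image_encoding, text_encoding).mean()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            image.clamp_(0, 1)  # keep pixels in the range produced by TF.to_tensor
        losses.append(loss.item())
    return image.detach(), losses


torch.manual_seed(0)
encoder = ToyImageEncoder()
text_encoding = torch.randn(1, 16)  # placeholder for a fixed text encoding
image, losses = optimize_image(encoder, text_encoding)
```

Only `image` is passed to the optimizer, so the encoder weights stay frozen. The text encoding never needs gradients.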

Install

poetry add pytorch-zero-lit

or

pip install pytorch-zero-lit

Usage

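import torchvision.transforms.functional as TF
from PIL import Image

from lit import LiT

model = LiT()

# Load one RGB image, resize it to 224x224, convert it to a float tensor in [0, 1],
# and add a batch dimension: shape (1, 3, 224, 224)
images = TF.to_tensor(
    Image.open("cat.png").convert("RGB").resize((224, 224))
)[None]
texts = [
    "a photo of a cat",
    "a photo of a dog",
    "a photo of a bird",
    "a photo of a fish",
]

# Images go through the PyTorch image encoder; texts go through the ONNX text encoder on the CPU
image_encodings = model.encode_images(images)
text_encodings = model.encode_texts(texts)

# Cosine similarities between the image encodings and the text encodings
cosine_similarity = model.cosine_similarity(image_encodings, text_encodings)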
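For zero-shot classification, each image's similarities to the candidate prompts can be turned into a probability distribution. This sketch assumes the encodings are PyTorch tensors of shape (num_images, dim) and (num_texts, dim). It computes cosine similarity explicitly, since this README does not describe how `model.cosine_similarity` is implemented. `zero_shot_probabilities` and its `temperature` value are hypothetical choices, and random tensors stand in for real encodings.

```python
import torch
import torch.nn.functional as F


def zero_shot_probabilities(image_encodings, text_encodings, temperature=0.01):
    # L2-normalize so the matrix product yields cosine similarities
    image_encodings = F.normalize(image_encodings, dim=-1)
    text_encodings = F.normalize(text_encodings, dim=-1)
    # (num_images, dim) @ (dim, num_texts) -> (num_images, num_texts)
    similarity = image_encodings @ text_encodings.T
    # Softmax over the prompts for each image; a smaller temperature gives a sharper distribution
    return (similarity / temperature).softmax(dim=-1)


texts = ["a photo of a cat", "a photo of a dog", "a photo of a bird", "a photo of a fish"]
torch.manual_seed(0)
image_encodings = torch.randn(1, 8)  # placeholder for model.encode_images(images)
text_encodings = torch.randn(len(texts), 8)  # placeholder for model.encode_texts(texts)
probs = zero_shot_probabilities(image_encodings, text_encodings)
best = texts[probs[0].argmax().item()]
```

To use real encodings from the usage example, pass them in directly. If the ONNX text encoder returns a NumPy array, first convert it with `torch.as_tensor`. The softmax does not change which prompt scores highest, so `best` is always the prompt with the largest cosine similarity.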
