LPIPS-J

This is a minimal JAX/Flax port of lpips, as implemented in richzhang/PerceptualSimilarity.

Only the essential features have been implemented. Our motivation is to support VQGAN training for DALL•E Mini.

It currently supports the vgg16 backend, leveraging the implementation in flaxmodels.

Pre-trained weights for the network and the linear layers are downloaded from the 🤗 Hugging Face hub.

Installation

1. Install JAX for CUDA or TPU following the instructions at https://github.com/google/jax#installation.
2. Install this package:

   ```bash
   pip install lpips-j
   ```

Use

Inputs must be in the range [-1, 1] and must not already be normalized with ImageNet statistics. (They are converted internally to [0, 1] and then normalized by the underlying Flax model.)
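As a rough sketch of what that internal conversion amounts to (illustrative only, not the library's actual code; the mean and std below are the standard ImageNet values):

```python
import numpy as np

# Standard ImageNet channel statistics.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def internal_normalize(x):
    """x: NHWC float array in [-1, 1]. Returns an ImageNet-normalized array."""
    x01 = (x + 1) / 2                        # [-1, 1] -> [0, 1]
    return (x01 - IMAGENET_MEAN) / IMAGENET_STD
```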

```python
import jax
import jax.numpy as jnp
from PIL import Image
from torchvision.transforms import PILToTensor
from lpips_j.lpips import LPIPS  # import path assumed from the package name

# Load the images as uint8 tensors with a leading batch dimension (NCHW).
x = PILToTensor()(Image.open("img8.jpg")).unsqueeze(0)
y = PILToTensor()(Image.open("img8_edited.jpg")).unsqueeze(0)

# Scale from [0, 255] to [-1, 1].
x = 2 * (x / 255.) - 1
y = 2 * (y / 255.) - 1

# Convert to JAX arrays in NHWC (channels-last) layout.
x = jnp.array(x).transpose(0, 2, 3, 1)
y = jnp.array(y).transpose(0, 2, 3, 1)

key = jax.random.PRNGKey(0)
lpips = LPIPS()
params = lpips.init(key, x, x)
loss = lpips.apply(params, x, y)
```
