# Empirical Marginal Distributions of a CLIP Model

This notebook recreates Figure 3 of the manuscript, which shows the empirical marginals $P_{\theta, X}^{(k)}$, $Q_{\theta, X}^{(k)}$, $P_{\theta, Y}^{(k)}$, and $Q_{\theta, Y}^{(k)}$ described in Sections 2 and 4 of the paper. These marginals are computed via the following steps:
1. Randomly initialize a CLIP model $(f_{\theta_I}, f_{\theta_T})$ (as in the notation of Section 2 Example 2).
2. Pass a minibatch of image-text pairs $\{(X_1, Y_1), \ldots, (X_n, Y_n)\}$ through the model.
3. Compute the joint probability mass functions $P_{\theta}^{(k)}$ and $Q_{\theta}^{(k)}$ (i.e., $n \times n$ matrices) using Equation 5.
4. Marginalize the joint probability mass functions to obtain $P_{\theta, X}^{(k)}$, $Q_{\theta, X}^{(k)}$, $P_{\theta, Y}^{(k)}$, and $Q_{\theta, Y}^{(k)}$.
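
The marginalization in step 4 can be illustrated with a small, self-contained sketch. It is not the notebook's implementation: Equation 5 is not reproduced here, so the hypothetical helper `joint_from_scores` builds a stand-in joint distribution by taking a softmax over all entries of a similarity matrix between random unit-norm embeddings. The convention that rows index images ($X$) and columns index texts ($Y$) is also an assumption. Only the row and column sums in `marginals` reflect the step described above.

```python
import numpy as np


def joint_from_scores(scores):
    """Turn an n x n score matrix into a joint pmf via a softmax over all entries.

    Placeholder only: this is NOT Equation 5 of the paper.
    """
    z = scores - scores.max()  # subtract the max for numerical stability
    p = np.exp(z)
    return p / p.sum()


def marginals(joint):
    """Return the two marginals of an n x n joint pmf.

    Assumed convention: rows index images (X), columns index texts (Y).
    """
    p_x = joint.sum(axis=1)  # sum over the Y index: one mass per X_i
    p_y = joint.sum(axis=0)  # sum over the X index: one mass per Y_j
    return p_x, p_y


# Random unit-norm embeddings stand in for the outputs of the two encoders.
rng = np.random.default_rng(0)
n, d = 16, 8
img = rng.normal(size=(n, d))
txt = rng.normal(size=(n, d))
img /= np.linalg.norm(img, axis=1, keepdims=True)
txt /= np.linalg.norm(txt, axis=1, keepdims=True)

joint = joint_from_scores(img @ txt.T)
p_x, p_y = marginals(joint)
print(p_x.sum(), p_y.sum())  # both marginals should sum to 1 up to rounding
```

Comparing `p_x` and `p_y` against the uniform vector with entries $1/n$ is one simple way to inspect how unbalanced a joint distribution is.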

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

import sys
sys.path.append("..")
from src.multimodal_models import MiniCLIP
from src.multimodal_data import get_multimodal_dataloaders

import matplotlib as mpl

mpl.rcParams['lines.linewidth'] = 4
mpl.rcParams['xtick.labelsize'] = 20
mpl.rcParams['ytick.labelsize'] = 20
mpl.rcParams["axes.labelsize"] = 20
mpl.rcParams['legend.fontsize'] = 33
mpl.rcParams['axes.titlesize'] = 32
# mpl.rcParams['text.usetex'] = True
mpl.rcParams['pdf.fonttype'] = 42
mpl.rcParams['ps.fonttype'] = 42

Select the device below (for example, `"cuda:0"` for the first GPU, or `"cpu"` if no GPU is available).

In [2]:
DEVICE = "cuda:0"

Next, locate the ImageNet-Captions dataset on your system and set the `root` variable below to its path.
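
Before building the dataloaders, it can help to confirm that the path exists, so that a typo fails early with a clear message. The helper below is a minimal sketch: `check_dataset_root` is a hypothetical name, not part of `src.multimodal_data`, and it only checks that the directory exists, not that its contents are valid.

```python
from pathlib import Path


def check_dataset_root(root):
    """Return root as a Path, raising early if it is not an existing directory."""
    path = Path(root).expanduser()
    if not path.is_dir():
        raise FileNotFoundError(f"Dataset root not found: {path}")
    return path


# Example usage (hypothetical): validate the path, then pass it on as a string.
# root = str(check_dataset_root(root))
```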

In [None]:
# load data
batch_size = 16
rank = 0
root = "/mnt/ssd/ronak/datasets/imagenet_captions_250k"
img_embed = "vit_b32_laion2b"
txt_embed = "vit_b32_datacompxl"

train_dataloader, test_dataloader = get_multimodal_dataloaders(
    batch_size, 
    rank, 
    img_embed,
    txt_embed,
    root=root, 
    quantization=None,
)