# Monitoring Dependency During CLIP Training

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import sys

import numpy as np
import ot
import torch


sys.path.append('../')
from src.ind_tests import HSICTest
from src.utils import median_dist
from open_clip.src.clip.clip import _transform
from open_clip.src.clip.model import CLIP, convert_weights
from open_clip.src.training.data import get_dataset_fn

In [None]:
# Name of the open_clip run to analyse. It selects the embedding and checkpoint
# directories below, and parse_args() picks the batch size from the 'b128'
# substring.
MODEL = 'yfcc15M_b128'  # model trained with 15M data (from Mitchell)
#MODEL = 'yfcc15M_b64'

# Directory holding the shard_XXXXX.tar files referenced by parse_args().
DATA_PATH = '/mnt/hdd2/liu16/data/yfcc'
# Per-run output directory for embeddings (one epoch_<n> subdirectory per checkpoint).
EMBED_PATH = f'/mnt/hdd2/liu16/open_clip/{MODEL}/embeddings'
# Per-run directory containing the epoch_<n>.pt checkpoints read by load_model().
CKP_PATH = f'/mnt/hdd2/liu16/open_clip/{MODEL}/checkpoints'

# Use CUDA device index 1 when CUDA is available, otherwise fall back to the CPU.
DEVICE = 'cuda:1' if torch.cuda.is_available() else "cpu"

## Obtain feature embeddings from the model checkpoints

In [None]:
class dotdict(dict):
    """Dictionary whose entries can also be read, set and deleted as attributes."""

    def __getattr__(self, name):
        # Only reached when regular attribute lookup fails; a missing key
        # yields None instead of raising.
        return self.get(name)

    def __setattr__(self, name, value):
        # Attribute assignment stores the value as a dictionary entry.
        self[name] = value

    def __delattr__(self, name):
        # Attribute deletion removes the entry (KeyError when it is absent).
        del self[name]
    
    
def integer_to_string(n, length=5):
    """Render ``n`` as text, left-padded with zeros to at least ``length`` characters.

    Values whose text form already has ``length`` characters or more are
    returned unpadded.
    """
    return str(n).rjust(length, '0')


def parse_args(data_id=0, init=14720):
    """Build the argument bundle used to load a range of shards.

    The shard range runs from index ``init`` to ``init + data_id`` and is
    written as a brace pattern under ``DATA_PATH``. The batch size is 128 when
    ``MODEL`` contains ``'b128'`` and 64 otherwise.

    Returns:
        A ``dotdict`` with the keys ``val_data``, ``batch_size``,
        ``distributed`` and ``workers``.
    """
    first_shard = integer_to_string(init)
    last_shard = integer_to_string(data_id + init)
    return dotdict(
        val_data=f'{DATA_PATH}/shard_{{{first_shard}..{last_shard}}}.tar',
        batch_size=128 if 'b128' in MODEL else 64,
        distributed=False,
        workers=1,
    )


def load_model(epoch, model_info):
    """Load the CLIP checkpoint saved for ``epoch`` and return it in eval mode.

    Args:
        epoch: Epoch identifier; the checkpoint is read from
            ``{CKP_PATH}/epoch_{epoch}.pt``.
        model_info: Mapping of keyword arguments passed to the ``CLIP``
            constructor.

    Returns:
        The ``CLIP`` model with the checkpoint weights loaded, moved to
        ``DEVICE`` and switched to evaluation mode.
    """
    ckp = f'{CKP_PATH}/epoch_{epoch}.pt'
    checkpoint = torch.load(ckp, map_location=DEVICE)
    sd = checkpoint["state_dict"]
    # State dicts saved from a DataParallel/DistributedDataParallel wrapper
    # carry a 'module.' prefix on every key. Strip it only where it is actually
    # present so that a checkpoint saved without the wrapper is not corrupted
    # by blindly dropping the first seven characters of each key.
    prefix = 'module.'
    sd = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()}

    model = CLIP(**model_info)
    # NOTE(review): convert_weights is an open_clip helper applied before the
    # weights are loaded; presumably it casts parameters to half precision.
    convert_weights(model)
    model.load_state_dict(sd)
    model.to(DEVICE)
    model.eval()
    return model


def obtain_embed(epoch, data_id, model, preprocess_val, init_id=14720):
    """Embed a range of shards with ``model`` and save the features as text.

    The shards ``init_id`` .. ``init_id + data_id`` (see ``parse_args``) are
    run through ``model`` batch by batch. All image features are stacked
    first, followed by all text features, along the first axis, and the result
    is written to ``{EMBED_PATH}/epoch_{epoch}/shard_{end}.txt`` where ``end``
    is the zero-padded index ``init_id + data_id``.

    NOTE(review): the output file is named after the last shard only, but the
    brace pattern built by ``parse_args`` spans the whole range starting at
    ``init_id`` - confirm that cumulative ranges are intended.

    Args:
        epoch: Epoch identifier used in the output directory name.
        data_id: Offset of the last shard relative to ``init_id``.
        model: Callable returning ``(image_features, text_features, extra)``
            for a batch of images and texts; ``extra`` is ignored.
        preprocess_val: Image preprocessing passed to the dataset factory.
        init_id: Index of the first shard.

    Returns:
        None. The embeddings are written to disk; the output directory must
        already exist.
    """
    args = parse_args(data_id, init_id)
    val_data = get_dataset_fn('shard.tar', 'webdataset')(args, preprocess_val, is_train=False)
    data_loader = val_data.dataloader

    all_image_features, all_text_features = [], []
    with torch.no_grad():
        for images, text in data_loader:
            # .to() works for both CUDA and CPU targets, whereas .cuda() fails
            # when DEVICE has fallen back to 'cpu'.
            images = images.to(DEVICE, non_blocking=True)
            text = text.to(DEVICE, non_blocking=True)
            image_features, text_features, _ = model(images, text)
            all_image_features.append(image_features)
            all_text_features.append(text_features)

    # Image rows first, then text rows, in one 2-D array for np.savetxt.
    embed = torch.cat(all_image_features + all_text_features).cpu().detach().numpy()
    end = integer_to_string(data_id + init_id)
    np.savetxt(f'{EMBED_PATH}/epoch_{epoch}/shard_{end}.txt', embed)

In [None]:
# Read the RN50 model configuration; load_model() passes this mapping as
# keyword arguments to the CLIP constructor.
with open('../open_clip/src/training/model_configs/RN50.json', 'r') as f:
    model_info = json.load(f)

In [None]:
# For every saved checkpoint, compute and store embeddings for ten shard ranges.
# sorted() gives a reproducible processing order (os.listdir order is arbitrary).
for ckp in sorted(os.listdir(CKP_PATH)):
    # Only files named epoch_<id>.pt can be reloaded by load_model(); a plain
    # substring test would also accept names such as 'epoch_1.pt.bak'.
    if not (ckp.startswith('epoch_') and ckp.endswith('.pt')):
        continue
    epoch = ckp[len('epoch_'):-len('.pt')]
    print(f'epoch = {epoch}')
    # Create the output directory without spawning a shell; this is a no-op
    # when the directory already exists.
    os.makedirs(f'{EMBED_PATH}/epoch_{epoch}', exist_ok=True)
    model = load_model(epoch, model_info)
    preprocess_val = _transform(model.visual.input_resolution, is_train=False)
    for data_id in range(10):
        # set init_id = 0 for training set
        obtain_embed(epoch, data_id, model, preprocess_val)

In [None]:
# Settings handed to hsic.decision(...) in the evaluation cell below.
alpha = 0.05  # third argument of hsic.decision; presumably the test level - confirm
nperms = 500  # fourth argument of hsic.decision; presumably the permutation count - confirm
hsic = HSICTest()  # independence-test object from src.ind_tests

In [None]:
# Pairwise distance matrices of X and Y computed with ot.dist.
# NOTE(review): X and Y are not defined in any earlier cell of this notebook;
# they are only assigned inside the loop of the following cell, so this cell
# presumably has to be run after it. xdist and ydist are not used anywhere
# else in the visible code.
xdist = ot.dist(X, X)
ydist = ot.dist(Y, Y)


In [None]:
# Run HSIC decisions on three pairings of embedding blocks and print the mean
# decision for each pairing.
# NOTE(review): `size` is not defined anywhere in the visible notebook, and the
# input path below differs from EMBED_PATH used by the embedding cells -
# confirm both before running.
res1, res2, res3 = [], [], []
for i in range(1, 2):  # only part 1 is processed
    embed = np.loadtxt(f'data/image-text/cc-data/size{size}-part{i}.txt')
    # Split the rows into four consecutive blocks of `size` rows each.
    X, Y, Z, W = embed[:size], embed[size:2*size], embed[2*size:3*size], embed[3*size:4*size]
    # NOTE(review): xeps is reassigned by each of the three calls, so only the
    # value returned by median_dist(X, W) is used for xgram - confirm that
    # median_dist returns the same first value regardless of its second input.
    xeps, yeps = median_dist(X, Y)
    xeps, zeps = median_dist(X, Z)
    xeps, weps = median_dist(X, W)
    # Kernel matrices: exp(-d / eps) of the pairwise ot.dist values within each block.
    xgram = np.exp(-ot.dist(X, X)/xeps)
    ygram = np.exp(-ot.dist(Y, Y)/yeps)
    zgram = np.exp(-ot.dist(Z, Z)/zeps)
    wgram = np.exp(-ot.dist(W, W)/weps)
    res1.append(hsic.decision(xgram, ygram, alpha, nperms))  # X against Y
    res2.append(hsic.decision(xgram, wgram, alpha, nperms))  # X against W
    res3.append(hsic.decision(zgram, ygram, alpha, nperms))  # Z against Y
# Mean decision over the processed parts for each pairing.
print(np.mean(res1))
print(np.mean(res2))
print(np.mean(res3))