In [1]:
import rootutils

# Locate the project root (the directory holding `.project-root`, searched for
# starting from the current directory) and add it to sys.path so that the
# `src.*` imports below resolve from inside the notebook.
rootutils.setup_root('.', indicator=".project-root", pythonpath=True)

import torch
from omegaconf import OmegaConf
# tqdm.auto renders a widget bar under Jupyter and falls back to text elsewhere.
from tqdm.auto import tqdm

from src.model.model_module import LitMML
from src.model.utils import get_model_and_processor

# Single device used for the model, the zero-shot classifier and every batch.
# NOTE(review): hard-coded GPU index — change here when running on another machine.
device = "cuda:3"

In [2]:
ckpt_path = '/home/data/bhavin/ckpts/93t3xgrr/last.ckpt' # CLIP+ITM

# Build a fresh model/processor from the model config (path is relative to the
# notebook's working directory), then restore the trained weights into it.
model, processor = get_model_and_processor(config=OmegaConf.load('../configs/model/dual_encoder.yaml'))
lit_model = LitMML.load_from_checkpoint(ckpt_path, model=model, processor=processor, map_location=device)

# load_from_checkpoint does not switch the module to inference mode; without
# eval() any dropout / batch-norm layers would behave as during training and
# make the zero-shot evaluation below noisy and non-deterministic.
lit_model.eval()

model, processor = lit_model.model, lit_model.processor
itm_head = lit_model.itm_head
print(itm_head)

/home/phisch/venv_py3.8/py3.8/lib/python3.8/site-packages/lightning/pytorch/utilities/migration/utils.py:56: The loaded checkpoint was produced with Lightning v2.2.3, which is newer than your current Lightning version: v2.2.0.post0


Sequential(
  (0): Linear(in_features=1024, out_features=512, bias=True)
  (1): ReLU()
  (2): Linear(in_features=512, out_features=2, bias=True)
)


In [3]:
from src.data.data_module import MyDataModule

# Overrides applied on top of the test data config.
DATA_ROOT = '/home/data'
EVAL_BATCH_SIZE = 16
EVAL_DATASETS = ["Caltech101"]

data_config = OmegaConf.load('../configs/data/test.yaml')
data_config.root = DATA_ROOT
data_config.dataloader.test.batch_size = EVAL_BATCH_SIZE
# presumably restricts loading to the listed datasets — confirm in MyDataModule
data_config.datasets = EVAL_DATASETS

data_module = MyDataModule(data_config=data_config, processor=processor)
# Mapping dataset name -> {split name -> DataLoader}.
dataloaders = data_module.get_test_dataloaders()
print(dataloaders)

{'Caltech101': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f404c45c430>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7f3ee43f5700>}}


In [4]:
from src.callbacks.zeroshot_callback import (
    _create_zero_shot_classifier,
    _evaluate_zero_shot,
    _get_itm_head_predictions,
)

# Evaluation settings.
verbose = True
dtype = torch.float32
top_k_preds = 10  # number of top candidates passed to the ITM head
templates = ["a photo of a {}."]

# Caltech101 test split and its label space.
loader = dataloaders["Caltech101"]["test"]
class_names = loader.dataset.classnames
num_classes = len(class_names)


def forward_func(pixel_values):
    """Return image embeddings from the dual encoder for a batch of pixel values."""
    return model.get_image_features(pixel_values=pixel_values)


def _encode_text(x, y):
    """Return text embeddings; `x` is passed as input_ids and `y` as attention_mask."""
    return model.get_text_features(input_ids=x, attention_mask=y)


# Class-prompt text embeddings, used as the weights of the zero-shot classifier.
classifier = _create_zero_shot_classifier(
    forward_func=_encode_text,
    classnames=class_names,
    templates=templates,
    tokenizer=processor,
    batch_size=loader.batch_size,
    device=device,
    verbose=verbose,
)

Classifier weights...: 100%|██████████| 7/7 [00:02<00:00,  2.66it/s]


In [None]:
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics import MetricCollection

# Top-k accuracy, tracked separately for the plain embedding-similarity logits
# and for the ITM-reranked logits.
top_k = [1, 5, 10]
metric_kwargs = dict(dist_sync_on_step=False, sync_on_compute=False)
metric = MetricCollection({
    f'Top{k}Accuracy': MulticlassAccuracy(top_k=k, average="micro", num_classes=num_classes, **metric_kwargs)
    for k in top_k
}).to(device)
metric_itm = MetricCollection({k: v.clone() for k, v in metric.items()}).to(device)

# Explicit None check instead of relying on nn.Sequential truthiness (its __len__).
use_itm = itm_head is not None


def _scalar_metrics(values, prefix=""):
    """Convert a dict of scalar metric tensors to floats, skipping any ConfusionMatrix entry."""
    return {f"{prefix}{k}": v.item() for k, v in values.items() if k != 'ConfusionMatrix'}


with torch.no_grad():
    bar = tqdm(loader, desc='Predicting...', total=len(loader)) if verbose else loader
    for inputs, target in bar:
        inputs = inputs.to(device)
        # reshape(-1) instead of squeeze(): squeeze() turns a batch holding a single
        # sample into a 0-d tensor, which the metrics reject against (1, C) logits.
        target = target.to(device).reshape(-1).long()
        # NOTE(review): with dtype=torch.float32 this autocast region presumably changes
        # nothing; use torch.float16 / torch.bfloat16 to actually get mixed precision.
        with torch.cuda.amp.autocast(enabled=True, dtype=dtype):
            features = forward_func(inputs)
        features /= features.norm(dim=-1, keepdim=True)
        logits_regular = features @ classifier
        step_metric = metric(logits_regular, target)
        if use_itm:
            logits_itm = _get_itm_head_predictions(itm_head, logits_regular.clone(), top_k_preds, features, classifier)
            step_metric_itm = metric_itm(logits_itm, target)
        if verbose:
            postfix = _scalar_metrics(step_metric)
            if use_itm:
                postfix.update(_scalar_metrics(step_metric_itm, prefix="itm_"))
            bar.set_postfix(postfix)

# The values shown on the progress bar are per-batch; compute() aggregates the
# accumulated state over every batch to give the dataset-level accuracy.
results = _scalar_metrics(metric.compute())
if use_itm:
    results.update(_scalar_metrics(metric_itm.compute(), prefix="itm_"))
results

Predicting...:   0%|          | 0/109 [00:00<?, ?it/s]ERROR:tornado.general:SEND Error: Host unreachable
