In [1]:
import lightning as L

from fusion_bench import (
    CLIPVisionModelPool,
    CLIPVisionModelTaskPool,
    get_default_config_path,
    initialize_hydra_config,
    instantiate,
)
from fusion_bench.models.hf_clip import HFCLIPClassifier
from fusion_bench.tasks.clip_classification import (
    get_classnames_and_templates,
    get_num_classes,
)

In [2]:
# Single-device Lightning Fabric; Lightning chooses the accelerator automatically.
fabric = L.Fabric(devices=1, accelerator="auto")
fabric.launch()

In [3]:
# Compose the FusionBench Hydra config: EMR merging over the TA8 set of
# CLIP ViT-B/32 task models, evaluated on the matching TA8 classification
# task pool.
config = initialize_hydra_config(
    config_name="fabric_model_fusion",
    config_path=get_default_config_path(),
    overrides=[
        "method=emr_merging/emr_merging",
        "modelpool=CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only",
        # Config-group options are referenced without the `.yaml` extension,
        # consistent with the two overrides above.
        "taskpool=CLIPVisionModelTaskPool/clip-vit-classification_TA8",
    ],
)

In [4]:
# Build the three FusionBench components described by the composed config:
# the merging algorithm, the pool of models to merge, and the evaluation
# task pool.
algorithm = instantiate(config.method)
modelpool: CLIPVisionModelPool = instantiate(config.modelpool)
taskpool: CLIPVisionModelTaskPool = instantiate(config.taskpool)

# Hand the Fabric instance created earlier to the task pool.
taskpool.fabric = fabric

Unused argument: base_model=openai/clip-vit-base-patch32


In [9]:
# Run EMR merging over the model pool. The returned model is used below as a
# task-switchable vision backbone (see the `set_task` calls in the evaluation loop).
emr_model = algorithm.run(modelpool)

Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

In [10]:
# Display the module tree of the merged model (it wraps a CLIPVisionModel backbone).
emr_model

EMRModulatedModel(
  (backbone): CLIPVisionModel(
    (vision_model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
        (position_embedding): Embedding(50, 768)
      )
      (pre_layrnorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-11): 12 x CLIPEncoderLayer(
            (self_attn): CLIPAttention(
              (k_proj): Linear(in_features=768, out_features=768, bias=True)
              (v_proj): Linear(in_features=768, out_features=768, bias=True)
              (q_proj): Linear(in_features=768, out_features=768, bias=True)
              (out_proj): Linear(in_features=768, out_features=768, bias=True)
            )
            (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): QuickGELUActivation()
            

In [None]:
# The `_is_setup` guard skips setup if it already ran, so re-running this cell
# is cheap. The classifier below reads `taskpool.clip_model` and
# `taskpool.processor`, which are presumably populated by `setup()`.
if not taskpool._is_setup:
    taskpool.setup()

# Zero-shot CLIP classifier built on the task pool's CLIP model and processor.
classifier = HFCLIPClassifier(taskpool.clip_model, processor=taskpool.processor)

# Swap in the EMR-modulated vision backbone. Assign through
# `classifier.clip_model.vision_model`, not `classifier.vision_model`, because of
# a PyTorch limitation (see the note in hf_clip.py).
classifier.clip_model.vision_model = emr_model

# Move the assembled classifier onto the Fabric device.
classifier = fabric.to_device(classifier)

In [19]:
# Sanity check: the classifier must wrap the very same EMR model object (also
# after `fabric.to_device`). Otherwise `emr_model.set_task(...)` in the
# evaluation loop would not affect what the classifier actually runs.
assert classifier.clip_model.vision_model is emr_model, (
    "classifier does not wrap the EMR model instance"
)
id(emr_model), id(classifier.clip_model.vision_model)

(15046195680, 15046195680)

In [20]:
# Evaluate the EMR model on every test task. For each task:
#   1. `emr_model.set_task` selects that task's modulation of the merged model;
#   2. the classifier's zero-shot head is rebuilt from the task's class names
#      and prompt templates;
#   3. the task pool evaluates the classifier on the task's test loader.
results = {}
for task_name in taskpool._test_datasets:
    emr_model.set_task(task_name)
    classnames, templates = get_classnames_and_templates(task_name)
    classifier.set_classification_task(
        classnames=classnames,
        templates=templates,
    )
    result = taskpool._evaluate(
        classifier,
        test_loader=taskpool.test_dataloaders[task_name],
        task_name=task_name,
        num_classes=get_num_classes(task_name),
    )
    print(f"Results for task {task_name}:\n{result}")
    results[task_name] = result

print("Final results:", results)

# Headline metric for multi-task model merging: mean accuracy across all tasks.
# Guarded so an empty task pool does not raise ZeroDivisionError.
if results:
    avg_accuracy = sum(float(r["accuracy"]) for r in results.values()) / len(results)
    print(f"Average accuracy over {len(results)} tasks: {avg_accuracy:.4f}")

Evaluating sun397:   0%|          | 0/156 [00:00<?, ?it/s]



Results for task sun397:
{'accuracy': 0.7096221446990967, 'loss': 1.107221007347107}


Evaluating stanford-cars:   0%|          | 0/63 [00:00<?, ?it/s]

Results for task stanford-cars:
{'accuracy': 0.7560005187988281, 'loss': 0.8067706227302551}


Evaluating resisc45:   0%|          | 0/50 [00:00<?, ?it/s]

Results for task resisc45:
{'accuracy': 0.918571412563324, 'loss': 0.3021736443042755}


Evaluating eurosat:   0%|          | 0/22 [00:00<?, ?it/s]

Results for task eurosat:
{'accuracy': 0.9762963056564331, 'loss': 0.06937872618436813}


Evaluating svhn:   0%|          | 0/204 [00:00<?, ?it/s]

Results for task svhn:
{'accuracy': 0.9651198387145996, 'loss': 0.13929736614227295}


Evaluating gtsrb:   0%|          | 0/99 [00:00<?, ?it/s]

Results for task gtsrb:
{'accuracy': 0.9775930047035217, 'loss': 0.0880066528916359}


Evaluating mnist:   0%|          | 0/79 [00:00<?, ?it/s]

Results for task mnist:
{'accuracy': 0.9950000047683716, 'loss': 0.023963619023561478}


Evaluating dtd:   0%|          | 0/15 [00:00<?, ?it/s]

Results for task dtd:
{'accuracy': 0.7218084931373596, 'loss': 1.063814401626587}
Final results: {'sun397': {'accuracy': 0.7096221446990967, 'loss': 1.107221007347107}, 'stanford-cars': {'accuracy': 0.7560005187988281, 'loss': 0.8067706227302551}, 'resisc45': {'accuracy': 0.918571412563324, 'loss': 0.3021736443042755}, 'eurosat': {'accuracy': 0.9762963056564331, 'loss': 0.06937872618436813}, 'svhn': {'accuracy': 0.9651198387145996, 'loss': 0.13929736614227295}, 'gtsrb': {'accuracy': 0.9775930047035217, 'loss': 0.0880066528916359}, 'mnist': {'accuracy': 0.9950000047683716, 'loss': 0.023963619023561478}, 'dtd': {'accuracy': 0.7218084931373596, 'loss': 1.063814401626587}}
