Here we collect and analyze the individual results of each setup for each left-out test domain. These results can be compared to those in Appendix B of the [DomainBed paper](https://arxiv.org/pdf/2007.01434.pdf).

In [1]:
import numpy as np
import json
import os
import glob
import sys
from pprint import pprint
from matplotlib import pyplot as plt
import pandas as pd

from domainbed.lib import misc, reporting
from domainbed import datasets
from domainbed import algorithms
from domainbed.lib.query import Q
from domainbed.model_selection import OracleSelectionMethod

# Arguments

In [2]:
# Root directory; results are read from <base_output_dir>/<dataset>/<setup>/...
base_output_dir = "./checkpoints"

# Benchmarks to aggregate, in the order they are processed.
dataset_all = [
    'PACS',
    'VLCS',
    'OfficeHome',
    'TerraIncognita',
    'DomainNet',
]

# Short display name -> algorithm directory name inside each setup folder.
algorithm_all = {
    'CLIPPretrained': 'CLIPPretrained',
    'CLIPBase': 'SupCLIPBottleneckBase',
    'CLIPCondCAD': 'SupCLIPBottleneckCondCAD',
}

# Model-selection strategy passed to get_results_per_domain.
select_method = OracleSelectionMethod

# Helper Functions

In [3]:
def pretty(d, indent=0):
    """Print a (possibly nested) dict as a tab-indented tree.

    Every key is printed on its own line, prefixed with ``indent`` tabs.
    A dict value is expanded recursively one level deeper; any other value
    is printed via ``str`` on the next line with one additional tab.
    """
    for name, item in d.items():
        print('\t' * indent + str(name))
        if not isinstance(item, dict):
            # Leaf: show the value one level below its key.
            print('\t' * (indent + 1) + str(item))
            continue
        pretty(item, indent + 1)

def get_record(output_dir):
    """Load the records stored under ``output_dir`` and report how many were found.

    Delegates to ``reporting.load_records`` and prints the directory being
    read followed by the number of records loaded.

    Returns:
        Whatever ``reporting.load_records(output_dir)`` returns.
    """
    print("Loading records from:", output_dir)
    loaded = reporting.load_records(output_dir)
    print("Total records:", len(loaded))
    return loaded

def get_results_per_domain(out_dir, selection_method, num_envs, env_names=None):
    """Summarize the selected accuracy of one run directory per left-out test env.

    The records found under ``out_dir`` are grouped with
    ``reporting.get_grouped_records(..., group_test_envs=True)`` and each group
    is scored with ``selection_method.sweep_acc(records, return_extra=True)``.
    For every test environment the score tuples are first averaged within each
    trial seed; the first component of those per-trial averages is then
    reduced to a mean and a standard deviation across trial seeds.

    Args:
        out_dir: Directory containing the records of exactly one algorithm on
            exactly one dataset.
        selection_method: Object exposing
            ``sweep_acc(records, return_extra=True)``. It is unpacked here as a
            3-tuple ``(tgt, src, tgt_in)`` (presumably target, source and
            in-split target accuracy; confirm against the selection method).
            Groups for which it returns ``None`` are discarded.
        num_envs: Number of test environments; indices ``0 .. num_envs - 1``
            are reported.
        env_names: Optional sequence indexed by test-env index and used for the
            result keys. Keys default to ``'env_<index>'`` when omitted.

    Returns:
        dict mapping each environment key to ``'<mean> +/- <std>'``, both in
        percent with one decimal. ``std`` is ``np.std`` with its default
        arguments, computed over the per-trial averages.

    Raises:
        AssertionError: If the records mix several algorithms or datasets.
        ValueError: If a test environment has no usable records.
    """
    records = get_record(out_dir)

    grouped_records = reporting.get_grouped_records(records,
                                                    group_test_envs=True).map(
        lambda group:
        {**group, "sweep_acc": selection_method.sweep_acc(group["records"], return_extra=True)}
        ).filter(lambda g: g["sweep_acc"] is not None)

    # A run directory is expected to hold a single algorithm ...
    alg_names = Q(records).select("args.algorithm").unique()
    assert len(alg_names) == 1, f"expected exactly one algorithm in {out_dir!r}"
    algorithm = alg_names[0]

    # ... evaluated on a single dataset.
    dataset_names = Q(records).select("args.dataset").unique().sorted()
    assert len(dataset_names) == 1, f"expected exactly one dataset in {out_dir!r}"
    dataset = dataset_names[0]

    results = {}
    for test_env in range(num_envs):
        # One tuple per trial seed: component-wise average of the sweep_acc
        # tuples of all groups that share this seed.
        trial_averages = (grouped_records
                          .filter_equals("algorithm, dataset, test_env", (algorithm, dataset, test_env))
                          .group("trial_seed")
                          .map(lambda trial_seed, group:
                               tuple(map(lambda y: sum(y) / float(len(y)), zip(*group.select("sweep_acc"))))
                               )
                          )

        per_trial = list(trial_averages)
        if not per_trial:
            # Without this guard the unpacking below fails with an opaque
            # "not enough values to unpack" ValueError.
            raise ValueError(
                f"No usable records for test_env={test_env} "
                f"(algorithm={algorithm!r}, dataset={dataset!r}) in {out_dir!r}"
            )

        # Only the first component is reported; the unpacking still enforces
        # that sweep_acc yields exactly three components.
        tgt_all, _src_all, _tgt_in_all = zip(*per_trial)
        tgt_mean = 100 * np.mean(list(tgt_all))
        tgt_std = 100 * np.std(list(tgt_all))

        if env_names is not None:
            result_key = env_names[test_env]
        else:
            result_key = f'env_{test_env}'

        results[result_key] = '{:.1f} +/- {:.1f}'.format(tgt_mean, tgt_std)

    return results

# Collect Results

Collect the result dict for each setup:

In [4]:
def get_result(setup):
    """Collect per-test-domain results of every dataset and algorithm for one setup.

    For each dataset in ``dataset_all`` the run directories are looked up under
    ``<base_output_dir>/<dataset>/<setup>/<algorithm directory>/``:

    * ``CLIPPretrained`` and ``CLIPBase`` are read from the ``base``
      sub-directory and stored under their short name.
    * Every other algorithm is read from each ``lambda_<value>`` sub-directory
      found on disk, in increasing numeric order of ``<value>``, and stored
      under ``'<short name>_lambda_<value>'``. Entries whose name is not
      ``lambda_`` followed by a float-parsable value are ignored.

    Args:
        setup: Name of the setup sub-directory (e.g. ``'clip_resnet'``).

    Returns:
        Nested dict ``{dataset: {algorithm key: {test domain: 'mean +/- std'}}}``
        where the innermost dict comes from ``get_results_per_domain``.
    """
    single_run_algs = ('CLIPPretrained', 'CLIPBase')
    lambda_prefix = 'lambda_'

    result_dict = {}
    for dataset in dataset_all:
        sub_result_dict = {}
        result_dict[dataset] = sub_result_dict
        basedir = f'{base_output_dir}/{dataset}/{setup}'

        env_names = datasets.get_dataset_class(dataset).ENVIRONMENTS
        num_envs = len(env_names)

        for alg_name, alg_name_long in algorithm_all.items():
            if alg_name in single_run_algs:
                output_dir = os.path.join(basedir, f'{alg_name_long}/base')
                sub_result_dict[alg_name] = dict(
                    get_results_per_domain(output_dir, select_method, num_envs, env_names))
                continue

            # Discover the lambda values from the directory names. Parsing the
            # basename (not the full path) keeps underscores elsewhere in the
            # path, such as in the setup name, from leaking into the value.
            lambda_str_array = []
            for path in glob.glob(os.path.join(basedir, f'{alg_name_long}/*')):
                entry = os.path.basename(path)
                if not entry.startswith(lambda_prefix):
                    continue
                lambda_str = entry[len(lambda_prefix):]
                try:
                    float(lambda_str)
                except ValueError:
                    continue
                lambda_str_array.append(lambda_str)

            for lambda_str in sorted(lambda_str_array, key=float):
                output_dir = os.path.join(basedir, f'{alg_name_long}/lambda_{lambda_str}')
                sub_result_dict[f'{alg_name}_lambda_{lambda_str}'] = dict(
                    get_results_per_domain(output_dir, select_method, num_envs, env_names))

    return result_dict


Result with CLIP-RN50 (CLIP S in our paper):

In [5]:
# Aggregate all datasets and algorithms stored under the 'clip_resnet' setup directories.
resnet_result_dict = get_result('clip_resnet')

Loading records from: ./checkpoints/PACS/clip_resnet/CLIPPretrained/base


                                                                                

Total records: 20
Loading records from: ./checkpoints/PACS/clip_resnet/SupCLIPBottleneckBase/base


                                                                                

Total records: 2000
Loading records from: ./checkpoints/PACS/clip_resnet/SupCLIPBottleneckCondCAD/lambda_1e-2


                                                                                

Total records: 2000
Loading records from: ./checkpoints/VLCS/clip_resnet/CLIPPretrained/base


                                                                                

Total records: 20
Loading records from: ./checkpoints/VLCS/clip_resnet/SupCLIPBottleneckBase/base


                                                                                

Total records: 2000
Loading records from: ./checkpoints/VLCS/clip_resnet/SupCLIPBottleneckCondCAD/lambda_1e-2


                                                                                

Total records: 2000
Loading records from: ./checkpoints/OfficeHome/clip_resnet/CLIPPretrained/base


                                                                                

Total records: 20
Loading records from: ./checkpoints/OfficeHome/clip_resnet/SupCLIPBottleneckBase/base


                                                                                

Total records: 200
Loading records from: ./checkpoints/OfficeHome/clip_resnet/SupCLIPBottleneckCondCAD/lambda_1e-2


                                                                                

Total records: 200
Loading records from: ./checkpoints/TerraIncognita/clip_resnet/CLIPPretrained/base


                                                                                

Total records: 20
Loading records from: ./checkpoints/TerraIncognita/clip_resnet/SupCLIPBottleneckBase/base


                                                                                

Total records: 2181
Loading records from: ./checkpoints/TerraIncognita/clip_resnet/SupCLIPBottleneckCondCAD/lambda_1e-2


                                                                                

Total records: 2122
Loading records from: ./checkpoints/DomainNet/clip_resnet/CLIPPretrained/base


                                                                                

Total records: 30
Loading records from: ./checkpoints/DomainNet/clip_resnet/SupCLIPBottleneckBase/base


                                                                                

Total records: 300
Loading records from: ./checkpoints/DomainNet/clip_resnet/SupCLIPBottleneckCondCAD/lambda_1


                                                                                

Total records: 300




Result with CLIP-ViT-B/32 (CLIP L in our paper):

In [6]:
# Aggregate all datasets and algorithms stored under the 'clip_vit' setup directories.
vit_result_dict = get_result('clip_vit')

Loading records from: ./checkpoints/PACS/clip_vit/CLIPPretrained/base


                                                                                

Total records: 20
Loading records from: ./checkpoints/PACS/clip_vit/SupCLIPBottleneckBase/base


                                                                                

Total records: 2000
Loading records from: ./checkpoints/PACS/clip_vit/SupCLIPBottleneckCondCAD/lambda_1e-2


                                                                                

Total records: 2000
Loading records from: ./checkpoints/VLCS/clip_vit/CLIPPretrained/base


                                                                                

Total records: 20
Loading records from: ./checkpoints/VLCS/clip_vit/SupCLIPBottleneckBase/base


                                                                                

Total records: 2000
Loading records from: ./checkpoints/VLCS/clip_vit/SupCLIPBottleneckCondCAD/lambda_1e-2


                                                                                

Total records: 2000
Loading records from: ./checkpoints/OfficeHome/clip_vit/CLIPPretrained/base


                                                                                

Total records: 20
Loading records from: ./checkpoints/OfficeHome/clip_vit/SupCLIPBottleneckBase/base


                                                                                

Total records: 200
Loading records from: ./checkpoints/OfficeHome/clip_vit/SupCLIPBottleneckCondCAD/lambda_1e-2


                                                                                

Total records: 200
Loading records from: ./checkpoints/TerraIncognita/clip_vit/CLIPPretrained/base


                                                                                

Total records: 20
Loading records from: ./checkpoints/TerraIncognita/clip_vit/SupCLIPBottleneckBase/base


                                                                                

Total records: 2015
Loading records from: ./checkpoints/TerraIncognita/clip_vit/SupCLIPBottleneckCondCAD/lambda_1e-2


                                                                                

Total records: 2008
Loading records from: ./checkpoints/DomainNet/clip_vit/CLIPPretrained/base


                                                                                

Total records: 30
Loading records from: ./checkpoints/DomainNet/clip_vit/SupCLIPBottleneckBase/base


                                                                                

Total records: 300
Loading records from: ./checkpoints/DomainNet/clip_vit/SupCLIPBottleneckCondCAD/lambda_1e-1


                                                                                

Total records: 300




# Plot Results

We show more fine-grained results broken down by left-out test domain, which can be compared to those in Appendix B of the [DomainBed paper](https://arxiv.org/pdf/2007.01434.pdf).

The improvement of CLIP methods over DomainBed baselines varies depending on the specific dataset and left-out test domain. Finetuning with bottlenecks (CLIPCondCAD) leads to consistent improvement over the other two CLIP baselines for nearly all setups.

## DomainNet

CAD leads to significant improvements on DomainNet.

In [7]:
# DomainNet, 'clip_resnet' setup: one row per algorithm variant, one column per left-out test domain.
pd.DataFrame(resnet_result_dict['DomainNet']).T

Unnamed: 0,clip,info,paint,quick,real,sketch
CLIPPretrained,60.4 +/- 0.1,38.5 +/- 0.3,54.7 +/- 0.2,8.2 +/- 0.1,72.8 +/- 0.1,51.1 +/- 0.2
CLIPBase,60.7 +/- 0.4,35.3 +/- 0.3,53.2 +/- 0.3,9.6 +/- 0.2,71.4 +/- 0.2,50.1 +/- 0.2
CLIPCondCAD_lambda_1,61.8 +/- 0.5,38.3 +/- 0.4,56.0 +/- 0.2,9.5 +/- 0.2,73.8 +/- 0.1,52.8 +/- 0.1


In [8]:
# DomainNet, 'clip_vit' setup: one row per algorithm variant, one column per left-out test domain.
pd.DataFrame(vit_result_dict['DomainNet']).T

Unnamed: 0,clip,info,paint,quick,real,sketch
CLIPPretrained,70.6 +/- 0.2,38.1 +/- 0.3,60.6 +/- 0.3,12.8 +/- 0.2,76.5 +/- 0.2,58.2 +/- 0.1
CLIPBase,70.7 +/- 0.1,36.9 +/- 0.3,59.8 +/- 0.2,14.3 +/- 0.1,75.4 +/- 0.1,57.4 +/- 0.3
CLIPCondCAD_lambda_1e-1,71.3 +/- 0.2,38.3 +/- 0.6,61.5 +/- 0.2,14.3 +/- 0.1,77.5 +/- 0.1,59.3 +/- 0.2


## VLCS

In [9]:
# VLCS, 'clip_resnet' setup: one row per algorithm variant, one column per left-out test domain.
pd.DataFrame(resnet_result_dict['VLCS']).T

Unnamed: 0,C,L,S,V
CLIPPretrained,97.4 +/- 0.9,63.4 +/- 0.6,79.4 +/- 0.2,84.1 +/- 1.6
CLIPBase,98.1 +/- 0.9,64.6 +/- 1.4,80.3 +/- 0.2,83.4 +/- 0.9
CLIPCondCAD_lambda_1e-2,98.3 +/- 0.3,66.1 +/- 1.2,80.6 +/- 0.3,84.3 +/- 0.4


In [10]:
# VLCS, 'clip_vit' setup: one row per algorithm variant, one column per left-out test domain.
pd.DataFrame(vit_result_dict['VLCS']).T

Unnamed: 0,C,L,S,V
CLIPPretrained,98.4 +/- 0.6,63.5 +/- 0.7,78.8 +/- 0.5,81.4 +/- 0.9
CLIPBase,99.1 +/- 0.6,64.7 +/- 1.2,78.5 +/- 1.1,82.2 +/- 1.9
CLIPCondCAD_lambda_1e-2,99.0 +/- 0.3,65.3 +/- 1.3,78.8 +/- 0.5,83.4 +/- 1.6


## PACS

In [11]:
# PACS, 'clip_resnet' setup: one row per algorithm variant, one column per left-out test domain.
pd.DataFrame(resnet_result_dict['PACS']).T

Unnamed: 0,A,C,P,S
CLIPPretrained,87.2 +/- 0.6,90.4 +/- 0.5,98.5 +/- 0.1,85.1 +/- 0.4
CLIPBase,90.1 +/- 0.6,91.3 +/- 1.0,99.1 +/- 0.1,84.8 +/- 0.4
CLIPCondCAD_lambda_1e-2,91.4 +/- 0.5,92.1 +/- 0.8,98.9 +/- 0.4,85.9 +/- 0.5


In [12]:
# PACS, 'clip_vit' setup: one row per algorithm variant, one column per left-out test domain.
pd.DataFrame(vit_result_dict['PACS']).T

Unnamed: 0,A,C,P,S
CLIPPretrained,94.7 +/- 1.1,96.2 +/- 1.4,99.2 +/- 0.4,84.8 +/- 0.9
CLIPBase,94.5 +/- 0.9,97.2 +/- 0.6,99.5 +/- 0.2,86.8 +/- 1.1
CLIPCondCAD_lambda_1e-2,95.8 +/- 0.4,97.7 +/- 0.3,99.4 +/- 0.2,86.7 +/- 0.7


## OfficeHome

In [13]:
# OfficeHome, 'clip_resnet' setup: one row per algorithm variant, one column per left-out test domain.
pd.DataFrame(resnet_result_dict['OfficeHome']).T

Unnamed: 0,A,C,P,R
CLIPPretrained,68.1 +/- 0.4,50.1 +/- 0.7,82.3 +/- 0.2,82.0 +/- 0.3
CLIPBase,68.4 +/- 1.6,50.6 +/- 0.4,81.7 +/- 0.3,81.8 +/- 0.9
CLIPCondCAD_lambda_1e-2,69.6 +/- 1.6,51.2 +/- 1.1,82.2 +/- 0.3,82.2 +/- 0.6


In [14]:
# OfficeHome, 'clip_vit' setup: one row per algorithm variant, one column per left-out test domain.
pd.DataFrame(vit_result_dict['OfficeHome']).T

Unnamed: 0,A,C,P,R
CLIPPretrained,76.8 +/- 0.5,67.2 +/- 0.5,87.6 +/- 0.3,87.8 +/- 0.3
CLIPBase,76.2 +/- 0.9,67.3 +/- 0.4,87.3 +/- 0.3,87.4 +/- 0.4
CLIPCondCAD_lambda_1e-2,76.2 +/- 0.7,68.2 +/- 0.7,88.1 +/- 0.2,87.6 +/- 0.3


## TerraIncognita

In [15]:
# TerraIncognita, 'clip_resnet' setup: one row per algorithm variant, one column per left-out test domain.
pd.DataFrame(resnet_result_dict['TerraIncognita']).T

Unnamed: 0,L100,L38,L43,L46
CLIPPretrained,27.1 +/- 0.9,22.5 +/- 1.2,36.5 +/- 2.1,30.6 +/- 2.5
CLIPBase,33.2 +/- 2.5,36.0 +/- 2.4,40.7 +/- 1.9,34.8 +/- 0.7
CLIPCondCAD_lambda_1e-2,35.0 +/- 2.5,34.3 +/- 3.7,40.4 +/- 1.4,34.5 +/- 1.5


In [16]:
# TerraIncognita, 'clip_vit' setup: one row per algorithm variant, one column per left-out test domain.
pd.DataFrame(vit_result_dict['TerraIncognita']).T

Unnamed: 0,L100,L38,L43,L46
CLIPPretrained,42.5 +/- 1.5,36.7 +/- 0.4,38.9 +/- 1.0,29.3 +/- 0.8
CLIPBase,46.6 +/- 0.7,46.4 +/- 2.2,39.5 +/- 0.4,29.5 +/- 1.8
CLIPCondCAD_lambda_1e-2,45.4 +/- 1.2,46.8 +/- 1.7,40.7 +/- 1.1,29.3 +/- 1.0
