In [None]:
!export PYTHONPATH=$PWD

import sys
import json
import os
from domainbed.lib import misc, reporting
from domainbed.lib.query import Q
from domainbed import model_selection
from domainbed import datasets

import collections
from collections import OrderedDict
from domainbed.scripts.collect_results import print_table, format_mean
# from domainbed.lib.query import make_selector_fn

import scipy
latex = True

: 

In [4]:
import tqdm
def load_records(path, file_name):
    records = []
    for i, subdir in tqdm.tqdm(list(enumerate(os.listdir(path))),
                               ncols=80,
                               leave=False):
        results_path = os.path.join(path, subdir, file_name)
        try:
            with open(results_path, "r") as f:
                for line in f:
                    records.append(json.loads(line[:-1]))
        except IOError:
            pass

    return Q(records)

In [5]:
def get_grouped_records(records, group_str):
    """Group records by (trial_seed, dataset, algorithm, test_env). Because
    records can have multiple test envs, a given record may appear in more than
    one group."""
    result = collections.defaultdict(lambda: [])
    for r in records:
        for test_env in r["args"]["test_envs"]:
            group = list(Q([r]).select(group_str)[0])
            group.append(test_env)
            group = tuple(group)
            result[group].append(r)
    group_key = group_str.replace(' ', '').replace('args.','').replace('hparams.', '').split(',') + ['test_env', 'records']
    
    grouped_records = []
    for v, r in result.items():
        v = list(v)
        v.append(Q(r))
        grouped_records.append(dict(zip(group_key, v)))
    return Q(grouped_records)

In [6]:
def custom_print_table(table, header_text, row_labels, col_labels, colwidth=10,
    latex=True):
    """Pretty-print a 2D array of data, optionally with row/col labels"""
    print("")

    if latex:
        num_cols = len(table[0])
        """
        print("\\begin{center}")
        print("\\adjustbox{max width=\\textwidth}{%")
        print("\\begin{tabular}{l" + "c" * num_cols + "}")
        print("\\toprule")
        """
    else:
        print("--------", header_text)

    for row, label in zip(table, row_labels):
        row.insert(0, label)
    
    """
    if latex:
        col_labels = ["\\textbf{" + str(col_label).replace("%", "\\%") + "}"
            for col_label in col_labels]
    table.insert(0, col_labels)
    """

    for r, row in enumerate(table):
        misc.print_row(row, colwidth=colwidth, latex=latex)
        """
        if latex and r == 0:
            print("\\midrule")
        """
    print("\\midrule")
    """
    if latex:
        print("\\bottomrule")
    """


In [26]:
# algorithms = ['FrozenERM']
# algorithms = ['ERM', 'CORAL', 'DANN', 'APCLIP', 'CLIP', 'WordCLIP']
# algorithms = ['ERM', 'CORAL']
algorithms = ['ERM']
# algorithms = ['APCLIP']
backbones = ['clip_vitb16']
clip_backbones = ['ViT-B/16']
hparams = ['{"backbone": "clip", "clip_backbone": "ViT-B/16"}', '{"backbone": "clip", "clip_backbone": "ViT-B/32"}','{"backbone": "clip", "clip_backbone": "RN101"}']
# backbones = ['DeiT', 'HViT', 'ViT-B32', 'ViT-B16','resnet50', 'resnet18']
tgt_dataset_names = ['VLCS', 'PACS', 'OfficeHome', 'TerraIncognita']
tgt_adaptation_names = ['None', 'T3A-64', 'TentClf-64', 'SHOTIM-64', 'PLClf-64', 'SHOT-64', 'PseudoLabel-64']
# tgt_adaptation_names = ['None']
seeds = [0, 1, 2]
records = []

### 
### Backbone == clip
### clip_backbone == 'ViT-B/16'

In [27]:
for backbone in backbones:
    r = load_records('/groups/gcb50389/xinzhang/t3a/{}'.format(backbone), 'results.jsonl')
    r = r.filter_in('args.trial_seed', seeds)
    r = r.filter_in('args.algorithm', algorithms)
    r = r.filter_in('args.hparams', hparams)
    r = r.map(lambda r: {**r, "adapt": 'None', "selection_method": model_selection.IIDAccuracySelectionMethod}) 
    records += r._list

    for tgt_adaptation in tgt_adaptation_names:
        r = load_records('/groups/gcb50389/xinzhang/t3a/{}'.format(backbone), 'results_{}.jsonl'.format(tgt_adaptation))
        r = r.filter_in('args.trial_seed', seeds)
        r = r.filter_in('args.algorithm', algorithms)
        # r = r.filter_equals('filter_K', 100)
        r = r.map(lambda r: {**r, "adapt": r['args']['adapt_algorithm'], "selection_method": model_selection.IIDAccuracySelectionMethod}) 
        records += r._list
records = Q(records)
print(len(records))

 47%|██████████████████▏                    | 3209/6874 [01:32<02:56, 20.71it/s]

In [15]:
# perpare the.
selection_method = model_selection.IIDAccuracySelectionMethod
group_str = 'args.trial_seed, args.dataset, args.algorithm, hparams.backbone, adapt'
grouped_records = get_grouped_records(records, group_str).map(lambda group:
        { **group, "sweep_acc": group["records"][0]['selection_method'].sweep_acc(group["records"]) }
    ).filter(lambda g: g["sweep_acc"] is not None)

adaptation_names = Q(records).select("adapt").unique()
dataset_names = Q(records).select("args.dataset").unique().sorted()
dataset_names = [d for d in datasets.DATASETS if d in dataset_names]
print('adaptation_names: ', adaptation_names)
print('dataset_names: ', dataset_names)

adaptation_names:  ['None']
dataset_names:  ['VLCS', 'PACS', 'OfficeHome', 'TerraIncognita']


In [25]:
backbone_names = Q(records).select("backbone").unique()
# alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names] +
#     [n for n in alg_names if n not in algorithms.ALGORITHMS])
print(backbone_names)
# # read dataset names and sort (lexicographic order)
dataset_names = Q(records).select("dataset").unique().sorted()
print(dataset_names)
dataset_names = [d for d in datasets.DATASETS if d in dataset_names]

for dataset in dataset_names:
    if latex:
        print()
        print("\\subsubsection{{{}}}".format(dataset))
    test_envs = range(datasets.num_environments(dataset))

    table = [[None for _ in [*test_envs, "Avg"]] for _ in backbone_names]
    for i, backbone in enumerate(backbone_names):
        means = []
        for j, test_env in enumerate(test_envs):
            trial_accs = (grouped_records
                .filter_equals(
                    "dataset, backbone, test_env",
                    (dataset, backbone, test_env)
                ).select("sweep_acc"))
            mean, err, table[i][j] = format_mean(trial_accs, latex)
            means.append(mean)
        if None in means:
            table[i][-1] = "X"
        else:
            table[i][-1] = "{:.1f}".format(sum(means) / len(means))

    col_labels = [
        "Backbone", 
        *datasets.get_dataset_class(dataset).ENVIRONMENTS,
        "Avg"
    ]
    header_text = (f"Dataset: {dataset}, "
        f"model selection method: {selection_method.name}")
    print_table(table, header_text, backbone_names, list(col_labels),
        colwidth=20, latex=latex)


['clip']
['TerraIncognita']

\subsubsection{TerraIncognita}

\begin{center}
\adjustbox{max width=\textwidth}{%
\begin{tabular}{lccccc}
\toprule
\textbf{Backbone}    & \textbf{L100}        & \textbf{L38}         & \textbf{L43}         & \textbf{L46}         & \textbf{Avg}         \\
\midrule
clip                 & 49.3 $\pm$ 8.7       & 36.2 $\pm$ 3.3       & 42.8 $\pm$ 3.2       & 36.6 $\pm$ 2.0       & 41.2                 \\
\bottomrule
\end{tabular}}
\end{center}


In [16]:
# def check_records(records):
#     for rec in records:
#         for r in rec['records']:
#             print(r)
#         # assert len(rec) == 360

# # Check the data missing.
# # if len() == 60 : backbone + 4 tta methods. 
# # elif: len() == 12 : backbone only. 
# else: ERROR.
for backbone in backbones:
    for algo in algorithms:
        for dataset in dataset_names:
            records = grouped_records.filter_equals("dataset, algorithm, backbone", (dataset, algo, backbone))
            if len(records) == 84:
                print('Finish all experiment: backbone + 6 TTA methods', algo, dataset, len(records), backbone)
                # check_records(records)
            elif len(records) == 12:
                print('Finish backbone: ', algo, dataset, len(records), backbone)
            else:
                print('Did not finish, ERROR!!!!!', len(records), algo, dataset, backbone)
        

Finish backbone:  ERM VLCS 12 clip
Finish backbone:  ERM PACS 12 clip
Finish backbone:  ERM OfficeHome 12 clip
Finish backbone:  ERM TerraIncognita 12 clip
Did not finish, ERROR!!!!! 0 CORAL VLCS clip
Did not finish, ERROR!!!!! 4 CORAL PACS clip
Finish backbone:  CORAL OfficeHome 12 clip
Finish backbone:  CORAL TerraIncognita 12 clip


In [12]:
# import shutil
# # for algo in algorithms:
# d = {}
# for dataset in dataset_names:
#     recoreds = grouped_records.filter_equals("dataset, algorithm", (dataset, 'D'))
#     # assert  == 'DANN':
#     if len(recoreds) != 60:
#         print(len(recoreds), 'DANN', dataset)
#         for rec in recoreds:
#             print(len(rec['records']))
#             for r in rec['records']:
#                 assert r['args']['algorithm'] == 'DANN'
#                 print(r['args']['output_dir'])
#                 dir = r['args']['output_dir']
#                 try:
#                     shutil.rmtree(dir)
#                     print(dir)
#                 except:
#                     pass
    # print(len(grouped_records[0]['records']))
    # print(grouped_records[0]['output_dir'])

In [17]:
for backbone in backbones:
    print(backbone)
    for algorithm in algorithms:
        table = [[None for _ in [*dataset_names, "Avg"]] for _ in adaptation_names]
        model_names = []
        for i, adapt_method in enumerate(adaptation_names):
            means = []
            if i == 0:
                model_names.append(algorithm)
            else:
                model_names.append('+'+adapt_method)
            for j, dataset in enumerate(dataset_names):
                trial_averages = (grouped_records
                    .filter_equals(
                        "dataset, adapt, backbone, algorithm",
                        (dataset, adapt_method, backbone, algorithm)
                    ).group("trial_seed")
                    .map(lambda trial_seed, group:
                        group.select("sweep_acc").mean()
                    )
                )
                mean, err, table[i][j] = format_mean(trial_averages, latex)
                means.append(mean)
            if None in means:
                table[i][-1] = "X"
            else:
                table[i][-1] = "{:.1f}".format(sum(means) / len(means))
                a = grouped_records.filter_equals("adapt, backbone, algorithm", ('None', backbone, algorithm)).filter_in('dataset', tgt_dataset_names).select('sweep_acc')
                b = grouped_records.filter_equals("adapt, backbone, algorithm", (adapt_method, backbone, algorithm)).filter_in('dataset', tgt_dataset_names).select('sweep_acc')
                if (len(a) == len(b) == 48) & (i != 0):
                    p_val = scipy.stats.ttest_rel(a, b, alternative='less')[1]
                    if p_val <= 0.01:
                        table[i][-1] += '$^{**}$'
                    elif p_val <= 0.05:
                        table[i][-1] += '$^{*}$'
                else:
                    # print(len(a), len(b))
                    pass 
        # for i, adapt_method in enumerate(adaptation_names):
        #     for j, dataset in enumerate(dataset_names):
        #         try:
        #             val = float(table[i][j].split(' ')[0])
        #             base_val = float(table[0][j].split(' ')[0])
        #             if val > base_val:
        #                 table[i][j] = '\\textbf{' + table[i][j] + '}'
        #         except:
        #             pass

        col_labels = ["Models", *dataset_names, "Avg"]
        header_text = f"Averages, backbone: {algorithm}"
        custom_print_table(table, header_text, model_names, col_labels, colwidth=25,
            latex=latex)

clip

ERM                       & 80.8 $\pm$ 0.4            & 93.0 $\pm$ 0.7            & 77.3 $\pm$ 1.6            & 42.2 $\pm$ 5.5            & 73.4                      \\
\midrule

CORAL                     & X                         & 87.4 $\pm$ 0.0            & 77.5 $\pm$ 1.0            & 40.3 $\pm$ 4.4            & X                         \\
\midrule


In [None]:
records

In [19]:
backbone_names = Q(records).select("args.hparams.backbone").unique()
# alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names] +
#     [n for n in alg_names if n not in algorithms.ALGORITHMS])
print(backbone_names)
# read dataset names and sort (lexicographic order)
dataset_names = Q(records).select("args.dataset").unique().sorted()
dataset_names = [d for d in datasets.DATASETS if d in dataset_names]

for dataset in dataset_names:
    if latex:
        print()
        print("\\subsubsection{{{}}}".format(dataset))
    test_envs = range(datasets.num_environments(dataset))

    table = [[None for _ in [*test_envs, "Avg"]] for _ in backbone_names]
    for i, backbone in enumerate(backbone_names):
        means = []
        for j, test_env in enumerate(test_envs):
            trial_accs = (grouped_records
                .filter_equals(
                    "dataset, backbone, test_env",
                    (dataset, backbone, test_env)
                ).select("sweep_acc"))
            mean, err, table[i][j] = format_mean(trial_accs, latex)
            means.append(mean)
        if None in means:
            table[i][-1] = "X"
        else:
            table[i][-1] = "{:.1f}".format(sum(means) / len(means))

    col_labels = [
        "Backbone", 
        *datasets.get_dataset_class(dataset).ENVIRONMENTS,
        "Avg"
    ]
    header_text = (f"Dataset: {dataset}, "
        f"model selection method: {selection_method.name}")
    print_table(table, header_text, backbone_names, list(col_labels),
        colwidth=20, latex=latex)


KeyError: 'args'