In [1]:
import numpy as np
import wandb

import os
import shutil

import torch
import scipy
import numpy as np
import ot

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import tabulate
import tqdm

import itertools

from matplotlib import rc

# Render all figure text through LaTeX (matplotlib's `usetex` mode needs a working LaTeX installation).
rc('text', usetex=True)
# Use a serif font family, with Computer Modern Roman as the serif face.
rc('font',**{'family':'serif','serif':['Computer Modern Roman']})

%matplotlib inline
# Seaborn's current color palette.
palette = sns.color_palette()

# Client for the wandb public API; `query_runs` takes such a client as its first argument.
api = wandb.Api()

**Note:** this notebook uses outdated experiment config names; you may need to change experiment name regexps to run this code.

In [2]:
def query_runs(api, exp_name_regex, step, project='timgaripov/support_alignment_v2'):
    """Query the finished wandb runs of an experiment that reached a given step.

    Args:
        api: wandb public API client (any object exposing
            `runs(path, filters=..., order=...)`).
        exp_name_regex: regular expression matched against `config.experiment_name`.
        step: value that `summary_metrics.step` must be equal to.
        project: wandb project path to query. Defaults to the project this
            notebook was written against; pass your own project name to reuse
            the query elsewhere.

    Returns:
        The result of `api.runs(...)`, requested in order of the training seed
        stored in the run config.
    """
    # All three conditions must hold: name matches, run finished, final step reached.
    filters = {
        '$and': [
            {'config.experiment_name': {'$regex': exp_name_regex}},
            {'state': 'finished'},
            {'summary_metrics.step': {'$eq': step}},
        ]
    }
    return api.runs(
        project,
        filters=filters,
        order='config.config/training/seed.value',
    )


In [3]:
# Re-create the wandb API client so this cell can be run on its own.
api = wandb.Api()
# Updated regex pattern: usps_mnist_3c/lenet_2d/seed_[1-5]/s_alpha_15/dann_zero[^0_]
# The trailing `[^0_]` demands one extra character after the algorithm name that is neither `0` nor `_`;
# presumably this keeps a name from also matching longer names that extend it with `_...` or `0` — TODO confirm.
runs = query_runs(api, '0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/dann_zero[^0_]', 30000)

# Sanity check: how many runs matched and which experiments they belong to.
# The same experiment name (and hence seed) can show up more than once in the listing.
print(len(runs))
for run in runs:
    print(run.config['experiment_name'])

9
0_usps_mnist/features_lenetnorelu_v3/seed_1/c_3/s_alpha_15/d_2/dann_zero
0_usps_mnist/features_lenetnorelu_v3/seed_2/c_3/s_alpha_15/d_2/dann_zero
0_usps_mnist/features_lenetnorelu_v3/seed_2/c_3/s_alpha_15/d_2/dann_zero
0_usps_mnist/features_lenetnorelu_v3/seed_3/c_3/s_alpha_15/d_2/dann_zero
0_usps_mnist/features_lenetnorelu_v3/seed_3/c_3/s_alpha_15/d_2/dann_zero
0_usps_mnist/features_lenetnorelu_v3/seed_4/c_3/s_alpha_15/d_2/dann_zero
0_usps_mnist/features_lenetnorelu_v3/seed_4/c_3/s_alpha_15/d_2/dann_zero
0_usps_mnist/features_lenetnorelu_v3/seed_5/c_3/s_alpha_15/d_2/dann_zero
0_usps_mnist/features_lenetnorelu_v3/seed_5/c_3/s_alpha_15/d_2/dann_zero


In [4]:
def compute_distances(x, y):
    """Compute normalized distances between two point clouds of feature vectors.

    Both clouds are centered on the mean of their union and divided by the
    average norm of the centered union, so the result does not depend on the
    overall offset or scale of the features. Pairwise Euclidean distances are
    then used to compute three quantities.

    The inputs are never modified: `Tensor.numpy()` shares memory with the
    tensor, so all arithmetic here is done out of place.

    Args:
        x: 2-D torch tensor (or array-like) of shape (n_x, d).
        y: 2-D torch tensor (or array-like) of shape (n_y, d).

    Returns:
        Tuple `(w1, ssd1, h1)`:
            w1: optimal-transport cost from `ot.emd2` with empty weight lists
                (POT treats empty weights as uniform).
            ssd1: average of the two directed mean nearest-neighbor distances.
            h1: Hausdorff distance (larger of the two directed max-min distances).
    """
    x_np = _as_numpy(x)
    y_np = _as_numpy(y)

    xy_combined = np.concatenate((x_np, y_np))
    mean = np.mean(xy_combined, axis=0)
    avg_norm = np.mean(np.linalg.norm(xy_combined - mean[None, :], axis=1))

    x_np = x_np - mean[None, :]
    y_np = y_np - mean[None, :]

    # avg_norm == 0 means every point coincides with the mean: all distances
    # are zero already, and dividing would only produce NaNs.
    if avg_norm > 0:
        x_np = x_np / avg_norm
        y_np = y_np / avg_norm

    M = ot.dist(x_np, y_np, metric='euclidean')
    w1 = ot.emd2([], [], M)

    nearest_in_y = np.min(M, axis=1)  # for each x point, distance to its closest y point
    nearest_in_x = np.min(M, axis=0)  # for each y point, distance to its closest x point
    ssd1 = 0.5 * (np.mean(nearest_in_y) + np.mean(nearest_in_x))
    h1 = np.maximum(np.max(nearest_in_y), np.max(nearest_in_x))
    return w1, ssd1, h1


def _as_numpy(points):
    """Return `points` as a NumPy array; tensors are detached and moved to CPU first."""
    if hasattr(points, 'detach'):
        points = points.detach().cpu().numpy()
    return np.asarray(points)

def feature_distances(run, step):
    """Download a run's saved features for `step` and compute source/target distances.

    The file `features_step_{step:06d}.pkl` is downloaded from the run into
    `./data_feature_distance`, loaded with `torch.load`, and its
    `features_src_tr` / `features_trg_tr` entries are passed to
    `compute_distances`. The downloaded file is always removed afterwards,
    including when loading or the computation fails.

    Args:
        run: wandb run object exposing `file(name).download(root, replace=...)`.
        step: training step whose feature dump is requested.

    Returns:
        Tuple `(w1, ssd1, hd1)` as returned by `compute_distances`.
    """
    download_dir = './data_feature_distance'
    fname = f'features_step_{step:06d}.pkl'
    local_path = f'./data_feature_distance/{fname}'

    # Make the function usable on its own, without relying on a caller to create the directory.
    os.makedirs(download_dir, exist_ok=True)
    try:
        run.file(fname).download(download_dir, replace=True)
        # NOTE(security): torch.load unpickles the file; only use this on runs you trust.
        # map_location='cpu' lets dumps written on a GPU machine load anywhere.
        feature_info = torch.load(local_path, map_location='cpu')
        w1, ssd1, hd1 = compute_distances(feature_info['features_src_tr'], feature_info['features_trg_tr'])
    finally:
        # Feature dumps are only needed transiently; never leave them behind.
        if os.path.exists(local_path):
            os.remove(local_path)
    return w1, ssd1, hd1
    

class ResultsTable(object):
    """Per-algorithm wandb summaries and feature distances, printable as tables.

    On construction, for every algorithm name the matching runs are queried
    with `query_runs`, reduced to one run per seed, and two kinds of results
    are collected in `self.results` (a dict mapping a result name to a list
    with one entry per algorithm, in the order of `self.algorithms`):

    * the wandb summary values listed in `self.summaries`;
    * the distances `W1`, `SSD1`, `H1` returned by `feature_distances`.

    The `*_fn` static methods are aggregation functions for `print_table`:
    each takes a 1-D array of per-run values and returns a formatted string.
    """

    @staticmethod
    def mean_std_fn(vals):
        """Format as ``mean (std) [count]`` with two decimals."""
        mean = np.mean(vals)
        std = np.std(vals)
        return f'{mean:.2f} ({std:.2f}) [{len(vals)}]'

    @staticmethod
    def median_iqr_fn(vals):
        """Format as ``median (interquartile range) [count]`` with two decimals."""
        # `import scipy` alone does not guarantee that the `stats` submodule is loaded.
        import scipy.stats

        median = np.median(vals)
        iqr = scipy.stats.iqr(vals)
        return f'{median:.2f} ({iqr:.2f}) [{len(vals)}]'

    @staticmethod
    def quant_fn(vals):
        """Format as ``q25 q50 q75 [count]``, zero-padded with two decimals."""
        quants = np.quantile(vals, np.array([0.25, 0.5, 0.75]))
        return f'{quants[0]:05.2f} {quants[1]:05.2f} {quants[2]:05.2f} [{len(vals)}]'

    @staticmethod
    def quant_latex_fn(vals):
        """Format as LaTeX math: median with q25 as subscript and q75 as superscript (one decimal)."""
        quants = np.quantile(vals, np.array([0.25, 0.5, 0.75]))
        return f'$ {quants[1]:04.1f}_{{~{quants[0]:04.1f}}}^{{~{quants[2]:04.1f}}} $'

    @staticmethod
    def median_x_fn(vals):
        """Format as ``median [count]`` with four decimals."""
        median = np.median(vals)
        return f'{median:.4f} [{len(vals)}]'

    @staticmethod
    def quant_latex_x_fn(vals):
        """Like `quant_latex_fn`, but with two decimals."""
        quants = np.quantile(vals, np.array([0.25, 0.5, 0.75]))
        return f'$ {quants[1]:05.2f}_{{~{quants[0]:05.2f}}}^{{~{quants[2]:05.2f}}} $'

    @staticmethod
    def quant_or_fn(vals):
        """Like `quant_latex_x_fn`, but without zero padding and `~` spacing."""
        quants = np.quantile(vals, np.array([0.25, 0.5, 0.75]))
        return f'$ {quants[1]:0.2f}_{{{quants[0]:0.2f}}}^{{{quants[2]:0.2f}}} $'

    @staticmethod
    def _first_run_per_seed(runs):
        """Keep only the first run encountered for each training seed."""
        seen_seeds = set()
        unique_runs = []
        for run in runs:
            seed = int(run.config['config/training/seed'])
            if seed not in seen_seeds:
                seen_seeds.add(seed)
                unique_runs.append(run)
        return unique_runs

    def __init__(self, prefix, algorithms, num_steps):
        """Query wandb and collect results for every algorithm.

        Args:
            prefix: experiment-name regex prefix; `/<algorithm>[^0_]` is appended per algorithm.
            algorithms: algorithm names, one table row each.
            num_steps: training step the runs must have reached; also selects
                which feature dump is downloaded.
        """
        # `import tqdm` does not load the `notebook` submodule by itself.
        import tqdm.notebook

        api = wandb.Api()
        self.algorithms = algorithms
        self.prefix = prefix

        os.makedirs('./data_feature_distance', exist_ok=True)

        self.summaries = [
            'eval/target_val/accuracy_class_avg',
            'eval/target_val/accuracy_class_min',
            'eval/target_val/ce_class_avg',
            'alignment_eval_original/ot_sq',
            'alignment_eval_original/supp_dist_sq',
            'alignment_eval_original/log_loss',
        ]

        self.results = {summary_name: [] for summary_name in self.summaries}
        self.results.update({dist_name: [] for dist_name in ['W1', 'SSD1', 'H1']})

        for algorithm_name in algorithms:
            regex = f'{prefix}/{algorithm_name}[^0_]'
            print(regex)
            qruns = list(query_runs(api, regex, step=num_steps))
            print(f'Q runs: {len(qruns)}')
            # The query can return several runs for one seed; use one run per seed.
            runs = self._first_run_per_seed(qruns)
            print(f'Runs: {len(runs)}')

            # NOTE: a summary missing from a run is recorded as -1.0 and is
            # therefore included in the aggregated statistics.
            for summary_name in self.summaries:
                self.results[summary_name].append([run.summary.get(summary_name, -1.0) for run in runs])

            distances_list = [
                feature_distances(run, step=num_steps)
                for run in tqdm.notebook.tqdm(runs)
            ]
            if distances_list:
                w1_list, ssd1_list, h1_list = zip(*distances_list)
            else:
                # No matching runs: keep the row (print_table renders empty
                # cells) instead of failing on the unpacking of an empty zip.
                w1_list, ssd1_list, h1_list = (), (), ()
            self.results['W1'].append(w1_list)
            self.results['SSD1'].append(ssd1_list)
            self.results['H1'].append(h1_list)

            print()

    def print_table(self, summaries, summaries_short, agg_fn, tablefmt='pipe', sep=' '):
        """Print one row per algorithm with the aggregated value of each requested result.

        Args:
            summaries: keys of `self.results` to show, one column each.
            summaries_short: column headers, in the same order as `summaries`.
            agg_fn: function turning a 1-D array of per-run values into a cell string.
            tablefmt: table format passed to `tabulate.tabulate`.
            sep: not used; accepted so existing calls that pass it keep working.
        """
        print(f'{self.prefix}\n{" ".join(summaries)}')
        table = [['algorithm'] + list(summaries_short)]
        for i, algorithm_name in enumerate(self.algorithms):
            row = [algorithm_name]
            for summary_name in summaries:
                values = self.results[summary_name][i]
                # Algorithms without any runs get an empty cell.
                row.append(agg_fn(np.array(values)) if len(values) > 0 else '')
            table.append(row)
        print(tabulate.tabulate(table, tablefmt=tablefmt, headers="firstrow"))
        

# 3 classes, dim 2, no dropout, no relu

In [5]:
# NOTE(review): `prefix` and `mid` are not used anywhere in this cell — confirm whether later cells need them.
prefix = ''
mid = ''
# Algorithm names to compare; each becomes one row of the results tables.
algorithms = [
    'dann_zero',
    'dann',    
    'support_abs_h0',
    'support_abs_h100',
    'support_abs_h500',
    'support_abs', 
    'support_abs_h2000',
    'support_abs_h5000',
]

# Updated regex pattern: usps_mnist_3c/lenet_2d/seed_[1-5]/s_alpha_15/
# ResultsTable appends `/<algorithm>[^0_]` to this prefix for every algorithm; constructing it
# queries wandb and downloads one feature file per selected run, so this cell is slow.
results_3c_d2 = ResultsTable('0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2', 
                             algorithms, num_steps=30000)

0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/dann_zero[^0_]
Q runs: 9
Runs: 5


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

  check_result(result_code)




0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/dann[^0_]
Q runs: 5
Runs: 5


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))



0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/support_abs_h0[^0_]
Q runs: 5
Runs: 5


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))



0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/support_abs_h100[^0_]
Q runs: 5
Runs: 5


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))



0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/support_abs_h500[^0_]
Q runs: 5
Runs: 5


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))



0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/support_abs[^0_]
Q runs: 5
Runs: 5


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))



0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/support_abs_h2000[^0_]
Q runs: 5
Runs: 5


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))



0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2/support_abs_h5000[^0_]
Q runs: 5
Runs: 5


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))





In [14]:
# (results key, column header) pairs for the plain-text quartile table.
columns = [('W1', 'W'), ('SSD1', 'SSD')]
summaries = [key for key, _ in columns]
summaries_short = [header for _, header in columns]
results_3c_d2.print_table(
    summaries,
    summaries_short,
    ResultsTable.quant_fn,
)

0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2
W1 SSD1
| algorithm         | W                     | SSD                   |
|:------------------|:----------------------|:----------------------|
| dann_zero         | 00.76 00.77 00.84 [5] | 00.10 00.10 00.10 [5] |
| dann              | 00.06 00.07 00.08 [5] | 00.02 00.02 00.02 [5] |
| support_abs_h0    | 00.22 00.23 00.47 [5] | 00.03 00.03 00.03 [5] |
| support_abs_h100  | 00.36 00.55 00.56 [5] | 00.03 00.03 00.03 [5] |
| support_abs_h500  | 00.55 00.58 00.64 [5] | 00.03 00.03 00.03 [5] |
| support_abs       | 00.56 00.59 00.62 [5] | 00.03 00.03 00.03 [5] |
| support_abs_h2000 | 00.58 00.62 00.66 [5] | 00.03 00.03 00.03 [5] |
| support_abs_h5000 | 00.63 00.64 00.67 [5] | 00.04 00.04 00.04 [5] |


In [15]:
# Same two distance columns, rendered as LaTeX median/quartile cells.
columns = [('W1', 'W'), ('SSD1', 'SSD')]
summaries, summaries_short = (list(part) for part in zip(*columns))
results_3c_d2.print_table(
    summaries,
    summaries_short,
    ResultsTable.quant_or_fn,
)

0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2
W1 SSD1
| algorithm         | W                      | SSD                    |
|:------------------|:-----------------------|:-----------------------|
| dann_zero         | $ 0.77_{0.76}^{0.84} $ | $ 0.10_{0.10}^{0.10} $ |
| dann              | $ 0.07_{0.06}^{0.08} $ | $ 0.02_{0.02}^{0.02} $ |
| support_abs_h0    | $ 0.23_{0.22}^{0.47} $ | $ 0.03_{0.03}^{0.03} $ |
| support_abs_h100  | $ 0.55_{0.36}^{0.56} $ | $ 0.03_{0.03}^{0.03} $ |
| support_abs_h500  | $ 0.58_{0.55}^{0.64} $ | $ 0.03_{0.03}^{0.03} $ |
| support_abs       | $ 0.59_{0.56}^{0.62} $ | $ 0.03_{0.03}^{0.03} $ |
| support_abs_h2000 | $ 0.62_{0.58}^{0.66} $ | $ 0.03_{0.03}^{0.03} $ |
| support_abs_h5000 | $ 0.64_{0.63}^{0.67} $ | $ 0.04_{0.04}^{0.04} $ |


| algorithm         | W                      | SSD                    |
|:------------------|:-----------------------|:-----------------------|
| dann_zero         | $ 0.77_{0.76}^{0.84} $ | $ 0.10_{0.10}^{0.10} $ |
| dann              | $ 0.07_{0.06}^{0.08} $ | $ 0.02_{0.02}^{0.02} $ |
| support_abs_h0    | $ 0.23_{0.22}^{0.47} $ | $ 0.03_{0.03}^{0.03} $ |
| support_abs_h100  | $ 0.55_{0.36}^{0.56} $ | $ 0.03_{0.03}^{0.03} $ |
| support_abs_h500  | $ 0.58_{0.55}^{0.64} $ | $ 0.03_{0.03}^{0.03} $ |
| support_abs       | $ 0.59_{0.56}^{0.62} $ | $ 0.03_{0.03}^{0.03} $ |
| support_abs_h2000 | $ 0.62_{0.58}^{0.66} $ | $ 0.03_{0.03}^{0.03} $ |
| support_abs_h5000 | $ 0.64_{0.63}^{0.67} $ | $ 0.04_{0.04}^{0.04} $ |


In [16]:
# Accuracy and distance columns for the LaTeX table: (results key, column header).
columns = [
    ('eval/target_val/accuracy_class_avg', 'acc (avg)'),
    ('eval/target_val/accuracy_class_min', 'acc (min)'),
    ('W1', 'W'),
    ('SSD1', 'SSD'),
]
summaries = [key for key, _ in columns]
summaries_short = [header for _, header in columns]
results_3c_d2.print_table(
    summaries,
    summaries_short,
    ResultsTable.quant_latex_fn,
    tablefmt='latex_raw',
    sep=' & ',
)

0_usps_mnist/features_lenetnorelu_v3/seed_[1-5]/c_3/s_alpha_15/d_2
eval/target_val/accuracy_class_avg eval/target_val/accuracy_class_min W1 SSD1
\begin{tabular}{lllll}
\hline
 algorithm         & acc (avg)                & acc (min)                & W                        & SSD                      \\
\hline
 dann_zero         & $ 63.0_{~62.3}^{~69.6} $ & $ 45.3_{~37.9}^{~53.6} $ & $ 00.8_{~00.8}^{~00.8} $ & $ 00.1_{~00.1}^{~00.1} $ \\
 dann              & $ 75.6_{~72.4}^{~83.7} $ & $ 54.8_{~49.6}^{~55.1} $ & $ 00.1_{~00.1}^{~00.1} $ & $ 00.0_{~00.0}^{~00.0} $ \\
 support_abs_h0    & $ 73.9_{~73.4}^{~84.1} $ & $ 61.8_{~54.6}^{~72.4} $ & $ 00.2_{~00.2}^{~00.5} $ & $ 00.0_{~00.0}^{~00.0} $ \\
 support_abs_h100  & $ 88.5_{~86.8}^{~95.1} $ & $ 71.4_{~70.6}^{~93.3} $ & $ 00.5_{~00.4}^{~00.6} $ & $ 00.0_{~00.0}^{~00.0} $ \\
 support_abs_h500  & $ 94.5_{~88.7}^{~94.7} $ & $ 89.0_{~83.1}^{~90.3} $ & $ 00.6_{~00.6}^{~00.6} $ & $ 00.0_{~00.0}^{~00.0} $ \\
 support_abs       & $ 91.1_{~91.1}^{~