In [2]:
from packaging import version

import pandas as pd
from matplotlib import pyplot as plt
# import seaborn as sns
from scipy import stats
import tensorboard as t
import os
import numpy as np
from tensorboard.backend.event_processing import event_accumulator
import glob

import csv
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms


%load_ext autoreload
%autoreload 2
from aggregate_results import parse_results

from IPython.display import Image

def get_seeds_from_result_strings(_result_strings):
    """Extract the integer seed embedded in each result file name.

    Every entry is expected to contain a ``seed<digits>`` token, e.g.
    ``..._BS300_seed1234_RHOTEST09_...``. The digits that directly follow
    the first such token are returned as an ``int``.

    Parameters
    ----------
    _result_strings : iterable of str
        Result file names/paths to parse.

    Returns
    -------
    list of int
        One seed per input string, in input order.

    Raises
    ------
    ValueError
        If a string contains no ``seed<digits>`` token.
    """
    import re  # local import: keeps this cell runnable on its own

    result_seeds = []
    # Iterate over the argument, not a notebook-level global, so the function
    # works on a fresh kernel and for any list that is passed in.
    for result_string in _result_strings:
        match = re.search(r'seed(\d+)', result_string)
        if match is None:
            raise ValueError(
                "No 'seed<digits>' token found in {!r}".format(result_string)
            )
        # Taking all consecutive digits handles seeds of any length; a fixed
        # 4-character slice breaks on 1-2 digit seeds and truncates long ones.
        result_seeds.append(int(match.group(1)))
    return result_seeds
		


In [3]:
import copy
def parse_results(glob_string, lambda_, avg_only=False, no_print=False):
    """Load all result files matching ``glob_string`` and aggregate their metrics.

    Each matching file is read with ``torch.load`` and must hold a dict with a
    ``'final_results'`` entry that maps split names to metric dicts. Only the
    splits ``'val'``, ``'test'`` and ``'unbal_val'`` are aggregated; any other
    split is ignored. For every aggregated split the metrics ``pred_loss``,
    ``acc`` and ``info_loss`` are collected as stored, and ``total_loss`` is
    always recomputed as ``pred_loss + lambda_ * info_loss``.

    Parameters
    ----------
    glob_string : str
        Glob pattern selecting the result files.
    lambda_ : number
        Weight applied to ``info_loss`` when computing ``total_loss``.
    avg_only : bool, optional
        Currently unused; kept so existing callers keep working.
    no_print : bool, optional
        Suppresses the "USED ..." and "No noise ordering." messages only.
        The file counts, the glob string and the per-metric summaries are
        always printed.

    Returns
    -------
    agg_list : dict
        ``{split: {metric: [value per file, ...]}}``.
    consolidated_list : dict
        ``{split: {metric: {'mean': ..., 'sem': ..., 'median': ...}}}``.
        Metrics with no collected values map to an empty dict.
    result_strings : list of str
        The matched file names, in the order they were processed.
    """
    metric_names = ('pred_loss', 'acc', 'info_loss', 'total_loss')
    split_names = ('val', 'test', 'unbal_val')

    def _empty_metric_lists():
        # Fresh, independent lists for every split.
        return {name: [] for name in metric_names}

    result_strings = glob.glob(glob_string)
    print("FOUND {} RESULTS".format(len(result_strings)))

    if not no_print:
        print("USED" , glob_string)
    print("GOT THIS MANY", len(result_strings))

    try:
        # Order files by the single character four from the end of the name
        # (the one right before a ".pt" extension), read as a digit.
        result_strings = sorted(result_strings, key=lambda x: int(x[-4]))
    except (ValueError, IndexError):
        # That character is not a digit (or the name is too short): keep glob order.
        if not no_print:
            print("No noise ordering.")

    agg_list = {'val': _empty_metric_lists(), 'test': _empty_metric_lists()}

    print(glob_string)
    if 'evaluations' in glob_string or "GENNURD" in glob_string:
        agg_list['unbal_val'] = _empty_metric_lists()

    for result_string in result_strings:
        # NOTE: torch.load unpickles the file; only use it on trusted results.
        # map_location='cpu' lets results saved on a GPU be read on any machine.
        result = torch.load(result_string, map_location='cpu')
        pred_result_dict = result['final_results']

        for key, metric_dict in pred_result_dict.items():
            if key not in split_names:
                continue
            # Create the split on demand so a file containing 'unbal_val' does
            # not raise KeyError when the glob string lacks the marker words.
            split_metrics = agg_list.setdefault(key, _empty_metric_lists())
            for metric_key, metric in metric_dict.items():
                # A stored 'total_loss' is skipped: it is recomputed below, and
                # appending both would double the length of that list.
                if metric_key in metric_names and metric_key != 'total_loss':
                    split_metrics[metric_key].append(metric)
            split_metrics['total_loss'].append(
                metric_dict['pred_loss'] + lambda_*metric_dict['info_loss']
            )

    # Mirror the structure of agg_list (including splits created on demand).
    consolidated_list = {
        split: {name: {} for name in metric_names} for split in agg_list
    }

    for key, metric_dict in agg_list.items():
        for metric_key, val_list in metric_dict.items():
            if len(val_list) < 1:
                continue
            summary = consolidated_list[key][metric_key]
            summary['mean'] = np.mean(val_list)
            # SEM here uses the population standard deviation (numpy's default ddof=0).
            summary['sem'] = np.std(val_list)/np.sqrt(len(val_list))
            summary['median'] = np.median(val_list)
            print(" {:5s} {:10s} MIN/MAX         = {:.2f}/{:.2f}".format(key, metric_key, np.min(val_list), np.max(val_list)))
            print(" {:5s} {:10s} MEAN/SEM/MEDIAN = {:.3f}/{:.3f}/{:.3f}".format(
                "",
                "",
                summary['mean'],
                summary['sem'],
                summary['median'],
            ))

    return agg_list, consolidated_list, result_strings

In [13]:
# Weight on info_loss; also substituted into the "LAM<value>" part of the file names below.
LAM = 1

##### PAPER RESULTS
# WATERBIRDS RESULTS, ACTUAL BORDER, reported result
# The pattern now has a "{}" placeholder so .format(LAM) takes effect; before,
# "LAM1" was hard-coded and the .format call was a no-op.
glob_string = "./LOGS/results_RWNURD_weight_waterbirds_BS300_seed*_RHOTEST09_BORDER7_allwd001_glr0005_FIXEDSPLIT_FINAL_LAM{}_FRAC2_RR1*.pt".format(LAM)

# Pass LAM through so the total_loss weighting always matches the files selected.
agg_list, consolidated_list, results_strings = parse_results(glob_string, lambda_=LAM)

print(np.round(agg_list['test']['acc'], 2))
print(np.round(agg_list['val']['acc'], 2))
print(np.round(agg_list['val']['info_loss'], 2))
print(np.round(agg_list['val']['total_loss'], 2))

# NOTE(review): the output saved beneath this cell reports a different (GENNURD)
# glob with 0 matches and includes a line this code cannot print, so it looks
# stale. Re-run the cell from a fresh kernel before relying on it.
result_seeds = np.array(get_seeds_from_result_strings(results_strings))
print(result_seeds)

FOUND 0 RESULTS
USED ./LOGS/results_GENNURD_weight_joint_BALdownsample_BS1000_seed*_RHOTEST09_BORDER6_noBN_LONGGEN_DEFLR_distWD0_LAM1_FRAC2_RR1_LONG.pt
GOT THIS MANY 0
./LOGS/results_GENNURD_weight_joint_BALdownsample_BS1000_seed*_RHOTEST09_BORDER6_noBN_LONGGEN_DEFLR_distWD0_LAM1_FRAC2_RR1_LONG.pt
[]
[]
[]
[]
[]
(0,) nan
[]
