In [4]:
%cd /home/thiennguyen/research/pseudogroups/split_pgl

/home/thiennguyen/research/pseudogroups/split_pgl


In [5]:
import os
from train import run_epoch
from loss import LossComputer
import numpy as np
import pandas as pd
import torch
from utils import Logger, CSVBatchLogger
from data import dro_dataset
from torch.utils.data.sampler import WeightedRandomSampler
from copy import deepcopy
import wandb
from tqdm import tqdm

device = 'cuda:0'

def run_eval_data_on_model(part1_model_path, part2_data_path, log_dir):
    """
        Run data on part1_model_path and save to log_dir/part2_eval.csv 
    """
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    logger = Logger(os.path.join(log_dir, "log.txt"), 'w')
    logger.flush()
    model = torch.load(part1_model_path)
    model.to(device)
    part2_data = torch.load(data_path)["part2"]
    csv_logger = CSVBatchLogger(os.path.join(log_dir, f"part2_eval.csv"), part2_data.n_groups, mode='w')
    loader_kwargs = {  # setting for args
        "batch_size": 128,
        "num_workers": 4,
        "pin_memory": True,
    }
    part2_loader = dro_dataset.get_loader(part2_data,
                                      train=False,
                                      reweight_groups=None,
                                      **loader_kwargs)

    # then run an epoch on part2 and during that run, generate a csv containing the status of each example
    for batch_idx, batch in enumerate(tqdm(part2_loader)):
        batch = tuple(t.to(device) for t in batch)
        x, y, g, data_idx = batch[0], batch[1], batch[2], batch[3]
        outputs = model(x)
        output_df = pd.DataFrame()
        
        # Calculate stats -- get the prediction and compare with groundtruth -- save to output df
        if batch_idx == 0:
            acc_y_pred = np.argmax(outputs.detach().cpu().numpy(), axis=1)
            acc_y_true = y.cpu().numpy()
            acc_g_true = g.cpu().numpy()
            indices = data_idx.cpu().numpy()

            probs = outputs.detach().cpu().numpy()
        else:  # concatenate
            acc_y_pred = np.concatenate([
                acc_y_pred,
                np.argmax(outputs.detach().cpu().numpy(), axis=1)
            ])
            acc_y_true = np.concatenate([acc_y_true, y.cpu().numpy()])
            acc_g_true = np.concatenate([acc_g_true, g.cpu().numpy()])
            indices = np.concatenate([indices, data_idx.cpu().numpy()])
            probs = np.concatenate([probs, outputs.detach().cpu().numpy()], axis = 0)

        assert probs.shape[0] == indices.shape[0]
        output_df[f"y_pred"] = acc_y_pred
        output_df[f"y_true"] = acc_y_true
        output_df[f"indices"] = indices
        output_df[f"g_true"] = acc_g_true

        for class_ind in range(probs.shape[1]):
            output_df[f"pred_prob_{class_ind}"] = probs[:, class_ind]
    save_dir = "/".join(csv_logger.path.split("/")[:-1])
    output_df.to_csv(
        os.path.join(save_dir, 
                        f"output.csv"))
    print("Saved", os.path.join(save_dir, 
                        f"output.csv"))

def analyze_pgl(csv_path):
    n_groups = 4
    n_classes = 2

    part2_df = pd.read_csv(csv_path)
    group = part2_df['g_true']
    y_true = part2_df['y_true']
    y_pred = part2_df['y_pred']
    group_count = [len(group[group == g]) for g in range(n_groups)]

    pgl = y_true*n_classes + y_pred  # can flip y_pred to get 1-y_pred...
    pgl_count = [len(pgl[pgl == g]) for g in range(n_groups)]

    recall = []
    precision = []
    for g in range(n_groups):
        recall.append(round(sum((pgl == g) & (group == g))/group_count[g],2))
        precision.append(round(sum((pgl == g) & (group == g))/pgl_count[g],2))

    print(f"pgl_count: \t{pgl_count}")
    print(f"group_count: \t{group_count}")
    print(f"recall: \t{recall}")
    print(f"precision: \t{precision}")
    print(f"{pd.crosstab(pgl, group)}")    

In [7]:
# run = wandb.init(project=f"{args.project_name}_{args.dataset}")
# wandb.config.update(args)


model_data_root_dir = "/home/thiennguyen/research/pseudogroups/CUB/splitpgl_sweep_logs/p0.5_wd0.0001_lr0.0001_s1/part1"

for model_epoch in [10, 20, 30, 100, 300]:
    print(model_epoch)
    part1_model_path = f"{model_data_root_dir}/{model_epoch}_model.pth"
    data_path = f"{model_data_root_dir}/part1and2_data"
    log_dir = f"{model_data_root_dir}/pgl_analysis_{model_epoch}"

    run_eval_data_on_model(part1_model_path, data_path, log_dir)
    
    csv_path = f"{model_data_root_dir}/pgl_analysis/output.csv"
    analyze_pgl(csv_path)

10


100%|███████████████████████████████████████████████████████████████████| 19/19 [00:03<00:00,  5.87it/s]

Saved /home/thiennguyen/research/pseudogroups/CUB/splitpgl_sweep_logs/p0.5_wd0.0001_lr0.0001_s1/part1/pgl_analysis_10/output.csv
pgl_count: 	[1808, 39, 42, 509]
group_count: 	[1756, 91, 25, 526]
recall: 	[1.0, 0.35, 0.36, 0.94]
precision: 	[0.97, 0.82, 0.21, 0.97]
g_true     0   1   2    3
row_0                    
0       1749  59   0    0
1          7  32   0    0
2          0   0   9   33
3          0   0  16  493
20



100%|███████████████████████████████████████████████████████████████████| 19/19 [00:03<00:00,  5.87it/s]

Saved /home/thiennguyen/research/pseudogroups/CUB/splitpgl_sweep_logs/p0.5_wd0.0001_lr0.0001_s1/part1/pgl_analysis_20/output.csv
pgl_count: 	[1808, 39, 42, 509]
group_count: 	[1756, 91, 25, 526]
recall: 	[1.0, 0.35, 0.36, 0.94]
precision: 	[0.97, 0.82, 0.21, 0.97]
g_true     0   1   2    3
row_0                    
0       1749  59   0    0
1          7  32   0    0
2          0   0   9   33
3          0   0  16  493
30



100%|███████████████████████████████████████████████████████████████████| 19/19 [00:03<00:00,  5.86it/s]

Saved /home/thiennguyen/research/pseudogroups/CUB/splitpgl_sweep_logs/p0.5_wd0.0001_lr0.0001_s1/part1/pgl_analysis_30/output.csv
pgl_count: 	[1808, 39, 42, 509]
group_count: 	[1756, 91, 25, 526]
recall: 	[1.0, 0.35, 0.36, 0.94]
precision: 	[0.97, 0.82, 0.21, 0.97]
g_true     0   1   2    3
row_0                    
0       1749  59   0    0
1          7  32   0    0
2          0   0   9   33
3          0   0  16  493
100



100%|███████████████████████████████████████████████████████████████████| 19/19 [00:03<00:00,  5.86it/s]

Saved /home/thiennguyen/research/pseudogroups/CUB/splitpgl_sweep_logs/p0.5_wd0.0001_lr0.0001_s1/part1/pgl_analysis_100/output.csv
pgl_count: 	[1808, 39, 42, 509]
group_count: 	[1756, 91, 25, 526]
recall: 	[1.0, 0.35, 0.36, 0.94]
precision: 	[0.97, 0.82, 0.21, 0.97]
g_true     0   1   2    3
row_0                    
0       1749  59   0    0
1          7  32   0    0
2          0   0   9   33
3          0   0  16  493
300



100%|███████████████████████████████████████████████████████████████████| 19/19 [00:03<00:00,  5.79it/s]

Saved /home/thiennguyen/research/pseudogroups/CUB/splitpgl_sweep_logs/p0.5_wd0.0001_lr0.0001_s1/part1/pgl_analysis_300/output.csv
pgl_count: 	[1808, 39, 42, 509]
group_count: 	[1756, 91, 25, 526]
recall: 	[1.0, 0.35, 0.36, 0.94]
precision: 	[0.97, 0.82, 0.21, 0.97]
g_true     0   1   2    3
row_0                    
0       1749  59   0    0
1          7  32   0    0
2          0   0   9   33
3          0   0  16  493





In [50]:
csv_path = "/home/thiennguyen/research/pseudogroups/CUB/splitpgl_sweep_logs/p0.5_wd0.0001_lr0.0001_s1/part1/pgl_analysis/output.csv"



pgl_count: 	[1801, 46, 66, 485]
group_count: 	[1756, 91, 25, 526]
recall: 	[1.0, 0.47, 0.76, 0.91]
precision: 	[0.97, 0.93, 0.29, 0.99]
g_true     0   1   2    3
row_0                    
0       1753  48   0    0
1          3  43   0    0
2          0   0  19   47
3          0   0   6  479


In [37]:
p = 0.5
wd = 1e-4
lr = 1e-4
seed = 1

model_data_root_dir = f"/home/thiennguyen/research/pseudogroups/CUB/splitpgl_sweep_logs/p{p}_wd{wd}_lr{lr}_s{seed}/part1"
model_epoch = 10
part1_model_path = f"{model_data_root_dir}/{model_epoch}_model.pth"
data_path = f"{model_data_root_dir}/part1and2_data"
log_dir = f"{model_data_root_dir}/pgl_analysis"

run_eval_data_on_model(part1_model_path, data_path, log_dir)

g_true,0,1,2,3
row_0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,1753,48,0,0
1,3,43,0,0
2,0,0,19,47
3,0,0,6,479


group_count: 	[1756, 91, 25, 526]
recall: 	[1.0, 0.47, 0.76, 0.91]
precision: 	[0.97, 0.93, 0.29, 0.99]


In [5]:
0.9982915717539863
0.4725274725274725
0.76
0.91

Unnamed: 0.1,Unnamed: 0,y_pred,y_true,indices,g_true,pred_prob_0,pred_prob_1
0,0,0,0,7089,0,1.591131,-1.550143
1,1,0,0,5599,0,2.486492,-2.085317
2,2,0,0,11008,0,2.010336,-1.709602
3,3,1,1,5890,3,-0.482630,0.736747
4,4,0,1,4991,2,0.787146,-0.974922
...,...,...,...,...,...,...,...
2393,2393,0,0,7007,0,2.785468,-2.244950
2394,2394,0,0,6697,0,0.654927,-0.834275
2395,2395,0,0,2101,0,2.683824,-2.309115
2396,2396,0,0,9764,0,2.117131,-1.985194
