In [None]:
import MinkowskiEngine as ME
import matplotlib.pyplot as plt
import matplotlib as mpl
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as F
import torch
from torch import nn
import numpy as np

## Run on the GPU when one is visible to torch, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Seed numpy's and torch's global random number generators with the same fixed value
SEED = 12345
np.random.seed(SEED)
## The trailing semicolon stops the notebook from echoing the returned Generator
torch.manual_seed(SEED);

In [None]:
## Includes from my libraries for this project                                                                                                                                           
from ME_dataset_libs import get_transform, DoNothing
from ME_dataset_libs import SingleModuleImage2D_MultiHDF5_ME, cat_ME_collate_fn
from ME_analysis_libs import argmax_consistency, topk_consistency

In [None]:
from FSD_training_analysis import get_models_from_checkpoint

def calc_accuracy(input_file, nevents, batch_size=1024, num_workers=8):
    """Print the batch-averaged top-1 and top-2 cluster consistency of a checkpoint.

    The models stored in ``input_file`` are restored and switched to eval mode,
    then the encoder and the clustering head are run without gradients over the
    dataset found at ``args.data_dir`` (``args`` comes back from the checkpoint
    loader). Every batch is scored with ``argmax_consistency`` and
    ``topk_consistency(..., 2)``; the per-batch scores are averaged and printed.

    Parameters
    ----------
    input_file : str
        Checkpoint path handed to ``get_models_from_checkpoint``.
    nevents : int
        Passed to the dataset as ``max_events``.
    batch_size : int, optional
        DataLoader batch size (default 1024). Incomplete final batches are dropped.
    num_workers : int, optional
        Number of DataLoader worker processes (default 8).

    Raises
    ------
    RuntimeError
        If the loader does not produce a single complete batch, in which case
        there is nothing to average.
    """
    print("Working on:", input_file)

    encoder, proj_head, clust_head, args = get_models_from_checkpoint(input_file)
    for module in (encoder, proj_head, clust_head):
        module.eval()
        module.to(device)

    aug_transform = get_transform('fsd', args.aug_type)
    data_dataset = SingleModuleImage2D_MultiHDF5_ME(args.data_dir,
                                                    nom_transform=DoNothing(),
                                                    aug_transform=aug_transform,
                                                    max_events=nevents)

    loader_kwargs = dict(collate_fn=cat_ME_collate_fn,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=num_workers,
                         drop_last=True,
                         pin_memory=False)
    ## DataLoader only accepts prefetch_factor when worker processes are used
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = 1
    data_loader = torch.utils.data.DataLoader(data_dataset, **loader_kwargs)

    ## Metrics: running sums of the per-batch scores
    total_acc = 0.0
    total_top2 = 0.0
    nbatches = 0

    ## Loop over all of the images
    for cat_bcoords, cat_bfeats, this_batch_size in data_loader:

        cat_bcoords = cat_bcoords.to(device, non_blocking=True)
        cat_bfeats = cat_bfeats.to(device)
        cat_batch = ME.SparseTensor(cat_bfeats, cat_bcoords, device=device)

        ## Forward pass only, no gradients needed
        with torch.no_grad():
            _, encoded_cluster_batch = encoder(cat_batch, this_batch_size)
            clust_batch = clust_head(encoded_cluster_batch)

            ## float() handles tensors, numpy scalars and plain numbers alike, so both
            ## sums end up as ordinary Python floats
            total_acc += float(argmax_consistency(clust_batch))
            total_top2 += float(topk_consistency(clust_batch, 2))

        nbatches += 1

    if nbatches == 0:
        raise RuntimeError("No complete batch of " + str(batch_size) + " events was produced for "
                           + str(input_file) + "; increase nevents or lower batch_size")

    ## drop_last=True makes every batch the same size, so an unweighted mean over batches is fine
    print("TOTAL ACCURACY top 1:", total_acc/nbatches, "; top 2:", total_top2/nbatches)

In [None]:
from FSD_training_analysis import get_models_from_checkpoint

## Want to extend this to do two things:
## 1 - plot accuracy as a function of max cluster index
## 2 - make a smearing matrix showing how often aug1 is in clust X while aug2 is in clust Y, normalized such that the sum of each column = 1
def calc_accuracy_by_cluster(input_file, nevents, batch_size=1024, num_workers=8, normalize=None):
    """Build the cluster smearing matrix and the per-cluster consistency of a checkpoint.

    The clustering-head output of every batch is split at ``this_batch_size // 2``
    into a first and a second half, and the arg-max cluster is taken for each row
    of both halves. The halves are assumed to hold the two augmented views of the
    same events in matching order; that ordering is fixed by ``cat_ME_collate_fn``,
    which is not visible here, so confirm it there.

    Parameters
    ----------
    input_file : str
        Checkpoint path handed to ``get_models_from_checkpoint``.
    nevents : int
        Passed to the dataset as ``max_events``.
    batch_size : int, optional
        DataLoader batch size (default 1024). Incomplete final batches are dropped.
    num_workers : int, optional
        Number of DataLoader worker processes (default 8).
    normalize : {None, "row", "column"}, optional
        ``None`` (default) returns raw pair counts. ``"row"`` / ``"column"`` divide
        each row / column by its sum; rows or columns without entries stay 0.

    Returns
    -------
    smearing_norm : numpy.ndarray, shape (N, N) with N = ``args.nclusters``
        Element ``[i, j]`` counts (or, if normalised, gives the fraction of) pairs
        whose first-half cluster is ``i`` and second-half cluster is ``j``.
    accuracy : numpy.ndarray, shape (N,)
        For each cluster ``i``, the fraction of pairs with first-half cluster ``i``
        whose second-half cluster is also ``i``. NaN for clusters never selected.

    Raises
    ------
    ValueError
        If ``normalize`` is not one of the accepted values.
    """
    if normalize not in (None, "row", "column"):
        raise ValueError("normalize must be None, 'row' or 'column'")

    print("Working on:", input_file)

    encoder, proj_head, clust_head, args = get_models_from_checkpoint(input_file)
    for module in (encoder, proj_head, clust_head):
        module.eval()
        module.to(device)

    print("Using augs:", args.aug_type)
    aug_transform = get_transform('fsd', args.aug_type)
    data_dataset = SingleModuleImage2D_MultiHDF5_ME(args.data_dir,
                                                    nom_transform=DoNothing(),
                                                    aug_transform=aug_transform,
                                                    max_events=nevents)

    loader_kwargs = dict(collate_fn=cat_ME_collate_fn,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=num_workers,
                         drop_last=True,
                         pin_memory=False)
    ## DataLoader only accepts prefetch_factor when worker processes are used
    if num_workers > 0:
        loader_kwargs["prefetch_factor"] = 1
    data_loader = torch.utils.data.DataLoader(data_dataset, **loader_kwargs)

    ## Histogram building
    N = args.nclusters
    ntotal = np.zeros(N)         ## pairs per first-half cluster
    ncorrect = np.zeros(N)       ## of those, pairs where the second half agrees
    smearing = np.zeros((N, N))  ## [first-half cluster, second-half cluster] pair counts

    ## Loop over all of the images
    for cat_bcoords, cat_bfeats, this_batch_size in data_loader:

        cat_bcoords = cat_bcoords.to(device, non_blocking=True)
        cat_bfeats = cat_bfeats.to(device)
        cat_batch = ME.SparseTensor(cat_bfeats, cat_bcoords, device=device)

        ## Forward pass only, no gradients needed
        with torch.no_grad():
            _, encoded_cluster_batch = encoder(cat_batch, this_batch_size)
            clust_batch = clust_head(encoded_cluster_batch)

            ## Split batches
            clust_batch1 = clust_batch[:this_batch_size//2].detach().cpu().numpy()
            clust_batch2 = clust_batch[this_batch_size//2:].detach().cpu().numpy()

        ## Selected cluster for every row of each half
        clust_max1 = np.argmax(clust_batch1, axis=1)
        clust_max2 = np.argmax(clust_batch2, axis=1)

        ## Running total per first-half cluster, used as the accuracy denominator
        ntotal += np.bincount(clust_max1, minlength=N)

        ## np.add.at accumulates repeated (i, j) pairs, which plain fancy-index += would not
        np.add.at(smearing, (clust_max1, clust_max2), 1)

        ## Pairs where both halves picked the same cluster
        same = (clust_max1 == clust_max2)
        ncorrect += np.bincount(clust_max1[same], minlength=N)

    ## Clusters that were never selected keep NaN instead of triggering a 0/0 warning
    accuracy = np.divide(ncorrect, ntotal, out=np.full(N, np.nan), where=ntotal > 0)

    ## Optional normalisation of the pair counts
    if normalize is None:
        smearing_norm = smearing
    else:
        axis = 1 if normalize == "row" else 0
        denom = smearing.sum(axis=axis, keepdims=True)
        smearing_norm = np.divide(smearing, denom, out=np.zeros_like(smearing), where=denom > 0)

    ## Return for plotting
    return smearing_norm, accuracy

In [None]:
## Checkpoint to inspect and how many events to run through it
file_dir = "/pscratch/sd/c/cwilk"
chk_file = "state_lat128_hid256_clust30_nchan64_5E-6_1024_PROJ0.5_CLUST0.5two_ent1E-1_soft1.0_arch12x4_poolmax_flat1_grow1_kern7_sep1_onecycle50_vbigaugbilinfixsmooth_1M_DATA1_FSDCCFIX.pth"
nevents = 100000

smearing_norm, accuracy = calc_accuracy_by_cluster(file_dir+"/"+chk_file, nevents)

## Smearing matrix: rows (y axis) are indexed by the cluster chosen for the first half of
## each batch, columns (x axis) by the cluster chosen for the second half
fig, ax = plt.subplots()
image = ax.imshow(smearing_norm, origin='lower', cmap='viridis', aspect='auto')
fig.colorbar(image, ax=ax, label='Counts')
ax.set_xlabel('Selected cluster (second view)')
ax.set_ylabel('Selected cluster (first view)')
ax.set_title('Smearing Matrix')
plt.show()

## Per-cluster consistency: fraction of pairs in which the second view lands in the same
## cluster as the first view, binned by the first view's cluster
fig, ax = plt.subplots()
ax.bar(np.arange(len(accuracy)), accuracy)
ax.set_xlabel('Selected cluster (first view)')
ax.set_ylabel('Fraction with matching cluster')
ax.set_title('Consistency per selected cluster')
plt.show()

In [None]:
## Defined here as well so this cell does not depend on a value left over from an earlier cell
file_dir = "/pscratch/sd/c/cwilk"

## The six checkpoints differ only in the number that follows "lat" in the file name
file_list = ["state_lat" + str(lat) + "_clust25_nchan64_1E-5_1024_PROJ0.5_CLUST0.5two_ent1E-1_soft1.0_arch12x4_poolmax_flat1_grow1_kern7_sep1_onecycle50_bigaugbilin_1M_DATA1_FSDCCFIX.pth"
             for lat in [24, 32, 48, 64, 128, 256]]

for f in file_list:
    calc_accuracy(file_dir+"/"+f, 100000)

In [None]:
## Evaluate one checkpoint for every (CLUST, PROJ) pair of the 3x3 grid below; the two values
## are substituted after "CLUST" and "PROJ" in the checkpoint file name
file_dir = "/pscratch/sd/c/cwilk"
aug = "newbig"
for clust_temp, proj_temp in [(c, p) for c in (0.25, 0.5, 0.75) for p in (0.25, 0.5, 0.75)]:
    print("CLUST_TEMP =", clust_temp, "; PROJ_TEMP =", proj_temp)
    chk_file = f"state_lat128_clust30_nchan64_5E-5_1024_PROJ{proj_temp}logits_CLUST{clust_temp}one_soft1.0_arch12x4_poolmax_flat1_grow1_kern7_sep1_onecycle50_{aug}_5M_DATA1_FSDCC.pth"
    calc_accuracy(f"{file_dir}/{chk_file}", 200000)

In [None]:
file_dir = "/pscratch/sd/c/cwilk"

## One checkpoint per "_ENT..." tag; the empty tag selects the file name that carries no tag
for ent in ("_ENT0.01", "_ENT0.1", "_ENT0.5", ""):
    chk_file = f"state_lat128_clust30{ent}_nchan64_5E-5_1024_PROJ0.5logits_CLUST0.5one_soft1.0_arch12x4_poolmax_flat1_grow1_kern7_sep1_onecycle50_newbig_5M_DATA1_FSDCC.pth"
    calc_accuracy(f"{file_dir}/{chk_file}", 100000)

In [None]:
## Single checkpoint whose file name carries the "MATCH1.0" tag
file_dir = "/pscratch/sd/c/cwilk"
chk_file = "state_lat128_clust30_MATCH1.0_nchan64_5E-5_1024_PROJ0.5logits_CLUST0.5one_soft1.0_arch12x4_poolmax_flat1_grow1_kern7_sep1_onecycle50_newbig2_2M_DATA1_FSDCC.pth"
calc_accuracy("/".join([file_dir, chk_file]), 100000)