In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
import h5py
import os
import sys
import scipy
import damselfly as df
import mayfly as mf
import scipy.signal
import scipy.stats
import scipy.interpolate
import json
import time

# Date: 6/25/2021
# Description: template

# Root directory of the project; every other location below is derived from it.
PATH = '/storage/home/adz6/group/project'

# Result, plot and dataset directories, built in one pass from their sub-paths.
RESULTPATH, PLOTPATH, DATAPATH = (
    os.path.join(PATH, subdir)
    for subdir in ('results/damselfly', 'plots/damselfly', 'datasets/data')
)
#SIMDATAPATH = os.path.join(PATH, 'damselfly/data/sim_data')

# Location of the damselfly package checkout.
damselpath = '/storage/home/adz6/group/project/damselfly'

def CalculateAccuracy(output, labels):
    """Return the fraction of samples whose highest-scoring class equals its label.

    Args:
        output: raw class scores, one row per sample with classes along dim 1.
        labels: class index for every sample.

    Returns:
        Scalar float tensor with the mean accuracy over the batch.
    """
    # Softmax over the class axis, then keep the index of the most probable class.
    predicted = torch.nn.functional.softmax(output, dim=1).argmax(dim=1)

    # Cast the hit/miss mask to float so that its mean is the accuracy.
    hits = torch.as_tensor(predicted == labels, dtype=torch.float)

    return hits.mean()

def AddNoiseToBatch(batch, var, rng=None):
    """Add white Gaussian noise to channels 0 and 1 of ``batch``, in place.

    The two channels are treated as the real and imaginary parts of a complex
    sample: each receives independent zero-mean Gaussian noise of variance
    ``var / 2``, so the complex noise has total variance ``var``.

    Args:
        batch: float tensor indexed as (signal, channel, sample) with at least
            two channels. It is modified in place.
        var: total (complex) noise variance; must be non-negative.
        rng: optional ``numpy.random.Generator``. Pass a seeded generator to
            make the noise reproducible; an unseeded one is created when
            omitted, which matches the previous behaviour.

    Returns:
        The same ``batch`` tensor, with the noise added.

    Raises:
        ValueError: if ``var`` is negative.
    """
    if var < 0:
        raise ValueError(f'noise variance must be non-negative, got {var}')

    if rng is None:
        rng = np.random.default_rng()

    nsignal, nsample = batch.shape[0], batch.shape[2]

    # One independent draw per (channel, signal, sample). Drawing the two
    # components directly is equivalent to sampling a 2-D normal with diagonal
    # covariance (var / 2) * I and splitting it into real and imaginary parts.
    noise = rng.normal(0.0, np.sqrt(var / 2), size=(2, nsignal, nsample))
    noise = torch.as_tensor(noise, dtype=torch.float, device=batch.device)

    batch[:, 0, :] += noise[0]
    batch[:, 1, :] += noise[1]

    return batch
    
def NormBatch(batch):
    """Scale channels 0 and 1 of every signal to unit peak amplitude, in place.

    Each channel of each signal is divided by its own largest absolute value,
    so the two channels are normalised independently of one another.

    Args:
        batch: float tensor indexed as (signal, channel, sample) with at least
            two channels. It is modified in place.

    Returns:
        The same ``batch`` tensor after normalisation. A channel that is all
        zeros is left untouched instead of being turned into NaN.
    """
    for channel in (0, 1):
        peak = torch.max(torch.abs(batch[:, channel, :]), -1, keepdim=True)[0]

        # A zero peak would give 0 * (1 / 0) = NaN for the whole trace;
        # substitute 1 so an all-zero trace simply stays zero.
        peak = torch.where(peak > 0, peak, torch.ones_like(peak))

        batch[:, channel, :] *= 1 / peak

    return batch
    

def LoadDataArrays(datafilepath):
    """Load the test split and append an equal number of all-zero signals.

    Reads the ``test/data`` and ``test/label`` datasets from the HDF5 file,
    then appends as many zero-filled signals (with label 0) as there are real
    signals. No noise is added here; the appended rows are plain zeros.

    Args:
        datafilepath: path of the HDF5 file to read.

    Returns:
        Tuple ``(data, label)``: a float tensor of shape
        ``(2 * N, channels, features)`` and a long tensor of shape ``(2 * N,)``,
        where ``N`` is the number of signals stored in the file.
    """
    # The context manager releases the file handle even if a read fails.
    with h5py.File(datafilepath, 'r') as file:
        test_data_no_noise = file['test']['data'][:]
        test_label = file['test']['label'][:]

    nsignal = test_data_no_noise.shape[0]
    ninput_ch = test_data_no_noise.shape[1]
    nfeatures = test_data_no_noise.shape[2]

    # The ratio of appended zero signals to real signals is hard-coded to 1
    # (flagged as "need to fix these" by the original author).
    Ntest_signals_with_noise = int(nsignal * (1 + 1))
    npad = Ntest_signals_with_noise - nsignal

    test_data = np.concatenate(
        (
            test_data_no_noise,
            np.zeros((npad, ninput_ch, nfeatures), dtype=np.float32)
        ), axis=0, dtype=np.float32)

    # Appended signals get class 0.
    test_label = np.int32(np.concatenate((test_label, np.zeros(npad)), axis=0))

    return (torch.tensor(test_data, dtype=torch.float), torch.tensor(test_label, dtype=torch.long))

def EvalModel(datafilepath, model, noise_var, device, batchsize):
    """Run ``model`` over the noisy, normalised test set and collect its outputs.

    The test set is loaded with ``LoadDataArrays``; every batch gets noise of
    variance ``noise_var`` from ``AddNoiseToBatch`` and is normalised with
    ``NormBatch`` before being passed to the model. The model is left in eval
    mode and on ``device`` afterwards.

    Args:
        datafilepath: path of the HDF5 file holding the test split.
        model: the ``torch.nn.Module`` to evaluate.
        noise_var: noise variance handed to ``AddNoiseToBatch``.
        device: device on which to run the model.
        batchsize: number of signals per forward pass.

    Returns:
        Tuple ``(outputs, targets)`` of flat 1-D numpy arrays. ``outputs`` is
        the model output of every signal flattened in sample-major order, so
        (assuming the model returns one row of class scores per signal)
        ``outputs.reshape(len(targets), -1)`` restores one row per signal.
        ``targets`` holds the corresponding labels.
    """
    # Move unconditionally: the previous check only matched exactly "cuda:0"
    # and silently skipped any other device.
    model.to(device)
    model.eval()

    test_data = LoadDataArrays(datafilepath)

    test_dataloader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(test_data[0], test_data[1]),
        batchsize,
        shuffle=False,
    )

    output_labels = []
    test_labels = []
    with torch.no_grad():
        for batch, labels in test_dataloader:

            # Noise and normalisation are applied on the CPU copy of the batch.
            batch = NormBatch(AddNoiseToBatch(batch, noise_var))

            output = model(batch.to(device))

            output_labels.append(output.cpu().numpy())
            test_labels.append(labels.cpu().numpy())

    if not output_labels:
        # Empty dataset: keep returning empty arrays as before.
        return (np.array(output_labels).flatten(), np.array(test_labels).flatten())

    # Concatenating along the sample axis gives the same flat result as stacking
    # equal-sized batches, and also works when the final batch is shorter.
    return (
        np.concatenate(output_labels, axis=0).flatten(),
        np.concatenate(test_labels, axis=0).flatten(),
    )


def _batch_pairs(outputs, targets):
    """Return a list of (logits, truth) tensor pairs, one per batch."""
    batched = isinstance(targets, (list, tuple)) or np.ndim(targets) > 1
    pairs = zip(outputs, targets) if batched else [(outputs, targets)]

    batches = []
    for logits, truth in pairs:
        truth = torch.as_tensor(truth).reshape(-1)
        # Flattened outputs are sample-major, so one row per label restores
        # the (nsample, nclass) layout; already 2-D outputs are unchanged.
        logits = torch.as_tensor(logits).reshape(truth.shape[0], -1)
        batches.append((logits, truth))
    return batches


def BatchAccuracy(labels):
    """Compute the classification accuracy of every evaluated model, per batch.

    Args:
        labels: nested sequence where ``labels[i][j]`` is an
            ``(outputs, targets)`` pair. Two layouts are accepted:
            per-batch sequences (``outputs[k]`` of shape ``(batch, nclass)``
            and ``targets[k]`` of shape ``(batch,)``), or the flat 1-D arrays
            returned by ``EvalModel``, which are treated as a single batch.
            Every ``labels[i]`` must have the same length, and every pair the
            same number of batches.

    Returns:
        Float array of shape ``(len(labels), len(labels[0]), nbatch)`` holding
        the fraction of correctly classified signals in each batch.
    """
    # Sizes are taken with len() rather than np.asarray(labels).shape: the
    # outputs and targets differ in shape, so the nested structure is ragged.
    per_model = [
        [_batch_pairs(outputs, targets) for outputs, targets in row]
        for row in labels
    ]

    n_eval = len(per_model)
    n_train = len(per_model[0]) if n_eval else 0
    n_batch = len(per_model[0][0]) if n_train else 0

    batch_acc = np.zeros((n_eval, n_train, n_batch))

    for i, row in enumerate(per_model):
        for j, batches in enumerate(row):
            for k, (logits, truth) in enumerate(batches):

                predicted_label = torch.argmax(torch.nn.functional.softmax(logits, dim=-1), -1)
                comparison = torch.eq(predicted_label, truth)

                batch_acc[i, j, k] = comparison.sum().item() / comparison.shape[0]

    return batch_acc


def _roc_batches(outputs, targets):
    """Yield (logits, truth) tensor pairs, one per batch, from either layout."""
    batched = isinstance(targets, (list, tuple)) or np.ndim(targets) > 1
    pairs = zip(outputs, targets) if batched else [(outputs, targets)]

    for logits, truth in pairs:
        truth = torch.as_tensor(truth).reshape(-1)
        # Flattened outputs are sample-major, so one row per label restores
        # the (nsample, nclass) layout; already 2-D outputs are unchanged.
        logits = torch.as_tensor(logits).reshape(truth.shape[0], -1)
        yield logits, truth


def ROC(labels):
    """Compute one ROC curve per ``labels[i][j]`` entry.

    The probability of class 1 (softmax of the model output) is compared with
    201 evenly spaced thresholds in [0, 1]; a signal counts as a detection at
    a threshold when that probability is greater than or equal to it. Counts
    are pooled over all batches of an entry before the rates are formed.

    Args:
        labels: nested sequence where ``labels[i][j]`` is an
            ``(outputs, targets)`` pair. Two layouts are accepted:
            per-batch sequences (``outputs[k]`` of shape ``(batch, nclass)``
            and ``targets[k]`` of shape ``(batch,)``), or the flat 1-D arrays
            returned by ``EvalModel``.

    Returns:
        List of ``(false_positive_rate, true_positive_rate)`` tensor pairs,
        each of length 201, ordered with ``i`` as the outer and ``j`` as the
        inner index. A rate is NaN/inf if its class has no samples.
    """
    t = torch.from_numpy(np.linspace(0, 1, 201))

    roc_curves = []
    # Iterating the nesting directly avoids np.asarray(labels), which is ragged
    # because outputs and targets have different shapes.
    for row in labels:  # eval temperatures
        for outputs, targets in row:  # train temperatures

            tp = torch.zeros(201)
            fp = torch.zeros(201)

            ground_truth_positive = 0
            ground_truth_negative = 0
            for logits, ground_truth in _roc_batches(outputs, targets):  # batch

                prob_positive = torch.nn.functional.softmax(logits, dim=-1)[:, 1]

                # (nsample, 201) mask: does each signal pass each threshold?
                above_threshold = prob_positive.reshape(-1, 1) >= t

                is_positive = ground_truth == 1
                is_negative = ground_truth == 0

                ## True positives ##
                tp += above_threshold[is_positive].sum(axis=0)
                ground_truth_positive += int(is_positive.sum())
                ## False positives ##
                fp += above_threshold[is_negative].sum(axis=0)
                ground_truth_negative += int(is_negative.sum())

            mean_tpr = tp / ground_truth_positive
            mean_fpr = fp / ground_truth_negative
            roc_curves.append((mean_fpr, mean_tpr))
    return roc_curves
            



In [None]:
# List the per-temperature training result directories.
os.listdir(os.path.join(RESULTPATH, 'dl', 'train'))

In [None]:
# Noise temperatures of the sweep: the integers 1 through 14.
temps = np.arange(1, 15)
print(temps)

In [None]:
# --- Evaluation inputs ---
datafilepath = os.path.join(DATAPATH, 'dl', '211215_dl_classification_84_25_2cm_slice1_sample1x8192.h5')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

batchsize = 2000
nch = 2
nslice = 1

input_shape_1d = 8192

# --- Model definition ---
# One entry per convolutional stage, in the layout df.models.DFCNN consumes.
# NOTE(review): only the dilation row is labelled in the original notebook; the
# other rows look like input channels, output channels and kernel sizes, and
# the trailing integer like a pooling factor — confirm against df.models.
conv_list_1d = [
        [
            [nch * nslice, 32, 32],
            [32, 32, 32],
            [8, 8, 8],
            [1, 1, 1], # dilation
            8
        ],
        [
            [32, 64, 64],
            [64, 64, 64],
            [4, 4, 4],
            [1, 1, 1],
            4
        ],
        [
            [64, 128, 128],
            [128, 128, 128],
            [4, 4, 4],
            [1, 1, 1],
            4
        ],
    ]

model_config_1d_cnn = {
'nclass': 2,
'nch': nch * nslice,
'conv': conv_list_1d
}

# The first linear layer's input size is derived from the convolutional stack.
# NOTE(review): the rows presumably are input sizes, output sizes and dropout
# probabilities — confirm against df.models.DFCNN.
linear_list_1d = [
        [df.models.GetConv1DOutputSize(model_config_1d_cnn['conv'], model_config_1d_cnn['nch'], input_shape_1d), 1024, 512],
        [1024, 512, 256],
        [0.0, 0.0, 0.0]
    ]

model = df.models.DFCNN(
        model_config_1d_cnn['nclass'],
        model_config_1d_cnn['nch'],
        model_config_1d_cnn['conv'],
        linear_list_1d
    )

# --- Temperature sweep ---
# labels[i] collects one (outputs, targets) pair from EvalModel for every model
# evaluated at temps[i].
labels = []

for i, eval_temp in enumerate(temps):
    # NOTE(review): looks like k_B (1.38e-23) * bandwidth (200e6) * temperature
    # * 50 (impedance?) / 8192 samples, scaled by 60 — confirm the meaning of
    # each factor before reusing this expression.
    noise_var = 60  * 1.38e-23 * 200e6 * eval_temp * 50 / (1 * 8192)
    labels.append([])
    print(f'Temp: {eval_temp}')
    for j, train_temp in enumerate(temps):

        # Only the model trained at the evaluation temperature is evaluated,
        # so each labels[i] ends up with exactly one entry.
        if eval_temp != train_temp:
            continue

        checkpoint = os.path.join(PATH, 'results', 'damselfly', 'dl', 'train', f'220106_84_25_2cm_slice1_sample1x8192_{train_temp}K', 'model.pth')

        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        model.load_state_dict(torch.load(checkpoint, map_location=device))

        labels[i].append(EvalModel(datafilepath, model, noise_var, device, batchsize))
        
        

In [None]:
# Report the nesting sizes of `labels`: (eval temperatures, models per
# temperature, entries per result). len() is used because the output and target
# arrays in each result differ in length, so np.asarray(labels) is ragged and
# raises ValueError on NumPy >= 1.24.
print((len(labels), len(labels[0]), len(labels[0][0])))

In [None]:
# Each result is the (outputs, targets) pair returned by EvalModel; the targets
# are the second element (the original read index 0 for both names).
output, target = labels[9][0]

In [None]:
# One (false-positive-rate, true-positive-rate) pair per evaluated model.
roc_curves = ROC(labels)

In [None]:
# Quick look at the first ROC curve: false-positive rate on x, true-positive rate on y.
plt.plot(*roc_curves[0], '.')


In [None]:
sns.set_theme(context='poster')
fig = plt.figure(figsize=(13, 8))

ax = fig.add_subplot(1,1,1)

# roc_curves[i] is paired with temps[i]. This holds while the sweep stores
# exactly one curve per evaluation temperature; revisit the indexing if more
# than one model is evaluated per temperature.
for i in [0, 2, 4, 9, 12]:
    false_positive_rate, true_positive_rate = roc_curves[i]
    # Take the label from temps instead of assuming temps[i] == i + 1.
    ax.plot(false_positive_rate, true_positive_rate, '-', label=f'{temps[i]} K')


#plt.plot(np.linspace(0, 1, 10), np.linspace(0, 1, 10), '--')
ax.legend(title='Noise Temp.', loc=4)
# On a log axis, points with a false-positive rate of exactly 0 are not drawn.
ax.set_xscale('log')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('DNN ROC Curves')

#plt.savefig(os.path.join(PATH, 'plots/damselfly', '220107_dnn_roc_curves'))

In [None]:


# Per-batch classification accuracy for every evaluated model.
batch_acc = BatchAccuracy(labels)

In [None]:
# Destination of the accuracy array; np.save appends the ".npy" extension.
# NOTE(review): the file name says "1-16K" while `temps` covers 1-14 — confirm
# which is intended before relying on the name.
name = os.path.join(RESULTPATH, 'dl', '211217_84_25_2cm_slice1_sample1x8192_1-16K_temp_sweep_accuracies')

np.save(name, batch_acc)

In [None]:
sns.set_theme(context='talk', style='ticks')

# Single-axes figure showing the accuracy averaged over batches (last axis).
fig, ax = plt.subplots(figsize=(13,8))

mean_accuracy = batch_acc.mean(axis=-1)
img = ax.imshow(mean_accuracy)
fig.colorbar(img)

In [None]:
# NOTE(review): scratch expression unrelated to the analysis above; candidate for removal.
len((1,2,3))