In [1]:
#!/usr/bin/env python
# coding: utf-8

# In[1]:


import math
import torch
import gpytorch
from matplotlib import pyplot as plt
import numpy as np
import random
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import csv
from torch.nn.utils import spectral_norm
import torch.nn as nn
from models.nn_model import nnModel_1, nnModel_0
from pathlib import Path
from models.nn_model import train_deep_kernel_gp, predict_deep_kernel_gp

from sklearn.manifold import TSNE
import networkx as nx

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--alpha', type=float, default=2.5)
args = parser.parse_args()

import json

# Read the configuration from the file
with open('experiments/config_sim_{}.json'.format(args.alpha), 'r') as file:
    config = json.load(file)

from torch.utils.data import Dataset

from models.utils import train_test_splitting


class toDataLoader(Dataset):
    def __init__(self, x_train, y_train, t_train):
        # Generate random data for input features (x) and target variable (y)
        self.x = x_train
        self.t = t_train
        self.y = y_train

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        # Return a single sample as a dictionary containing input features and target variable
        inputs = torch.hstack([self.x[idx], self.t[idx]]).float()
        targets = self.y[idx].float()

        return inputs, targets


def trt_ctr(treatment):
    list1, list0 = [], []
    for index, i in enumerate(treatment):
        if i == 1:
            list1.append(index)
        elif i == 0:
            list0.append(index)
        else:
            raise TypeError('Invalid treatment value found')

    return list1, list0


def evaluation_nn(pred_1, pred_0, test_tau, query_step):
    esti_tau = torch.from_numpy(pred_1 - pred_0).float()
    pehe_test = torch.sqrt(torch.mean((esti_tau - test_tau) ** 2))

    print('\n', 'PEHE at query step: {} is {}'.format(query_step, pehe_test), '\n')

    return pehe_test


def pool_updating(idx_remaining, idx_sub_training, querying_idx):
    # Update the training and pool set for the next AL stage
    idx_sub_training = np.concatenate((idx_sub_training, querying_idx), axis=0)  # Update the training pool
    # Update the remaining pool by deleting the selected data
    mask = np.isin(idx_remaining, querying_idx,
                   invert=True)  # Create a mask that selects the elements to delete from array1
    idx_remaining = idx_remaining[mask]  # Update the remaining pool by subtracting the selected samples

    return idx_sub_training, idx_remaining


def one_side_uncertainty(combine_x_train, index, num_of_samples, model):
    model.eval()

    pred = model(combine_x_train[index])
    pred_variance = pred.variance.sqrt()

    uncertainty = pred_variance
    draw_dist = uncertainty.cpu().detach().numpy()
    # quantile_threshold = np.quantile(draw_dist, 1 - percentage)  # taking top 5% of the values

    top_k = num_of_samples
    threshold_top_k = np.partition(draw_dist, -top_k)[
        -top_k]  # Calculate the threshold for the top 5 values instead of top 5%
    print('Uncertainty threshold:', threshold_top_k)

    acquired_idx = []
    for idx, i in enumerate(draw_dist):
        # print(round(i.item(),2), round(uncertainty[idx].item(),2), round(uncertainty[idx].item()/uncertainty_mean.item(),2))
        if draw_dist[idx] >= threshold_top_k:
            acquired_idx.append(idx)

    # print('Top 5 uncertain:', draw_dist[acquired_idx])
    acquired_idx = index[acquired_idx]
    random_idx = np.random.permutation(len(acquired_idx))
    acquired_idx = acquired_idx[random_idx]

    num_elements_to_select = num_of_samples  # Selecting 5 values randomly as the step size

    return acquired_idx[:num_elements_to_select], threshold_top_k


# Function to calculate pairwise Euclidean distance in batches
def pairwise_distances_in_batches(data, batch_size=500):
    n = data.size(0)
    distances = torch.zeros(n, n, device=data.device)

    for i in range(0, n, batch_size):
        for j in range(i, n, batch_size):
            end_i = min(i + batch_size, n)
            end_j = min(j + batch_size, n)
            diff = data[i:end_i].unsqueeze(1) - data[j:end_j].unsqueeze(0)
            dist_batch = torch.sqrt(torch.sum(diff ** 2, dim=-1))
            distances[i:end_i, j:end_j] = dist_batch
            if i != j:
                distances[j:end_j, i:end_i] = dist_batch.T

    return distances


# In[ ]:


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
test_size = 0.1
num_trial = 1
al_step = 1
warm_up = 50
num_of_samples = 25
seed = args.seed

usage: ipykernel_launcher.py [-h] [--seed SEED] [--alpha ALPHA]
ipykernel_launcher.py: error: unrecognized arguments: -f /home/hl506-8850/.local/share/jupyter/runtime/kernel-4022b6de-abaa-4a87-b7ac-3f62343a72de.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
if num_trial == 1:
    # for seed in range(num_trial):

    print('Trial:', seed)

    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    combine_x_train, \
        combine_x_test, \
        combined_y_train, \
        combine_x_valid, \
        combined_y_valid, \
        tau_test, T_train, \
        T_valid, \
        T_test, \
        y_std = train_test_splitting(seed, device=device)

    idx_pool = np.random.permutation(len(combine_x_train))  # Global dataset index
    idx_sub_training = idx_pool[:warm_up]  # Global dataset index
    idx_remaining = idx_pool[warm_up:]  # Global dataset index

    sub_training_1, sub_training_0 = trt_ctr(T_train[idx_sub_training])
    remaining_1, remaining_0 = trt_ctr(T_train[idx_remaining])

    # Initialize the data-limited starting size as 20% of whole treated training set
    idx_sub_training_1 = idx_sub_training[sub_training_1]  # 20% as initial
    idx_remaining_1 = idx_remaining[remaining_1]  # 20% left for querying

    # Initialize the data-limited starting size as 20% of whole control training set
    idx_sub_training_0 = idx_sub_training[sub_training_0]  # 10% as initial
    idx_remaining_0 = idx_remaining[remaining_0]  # 90% left for querying

    acquired_treated, acquired_control = None, None
    error_list = []
    num_of_acquire = [len(idx_sub_training_1) + len(idx_sub_training_0)]

    for query_step in range(al_step):
        train_x_1, train_y_1 = combine_x_train[idx_sub_training_1], combined_y_train[idx_sub_training_1]
        train_x_0, train_y_0 = combine_x_train[idx_sub_training_0], combined_y_train[idx_sub_training_0]
        print("Number of data used for training in treated and control:", len(idx_sub_training_1),
              len(idx_sub_training_0))

        print(combine_x_train.shape)

        # Compute pairwise distances in batches
        batch_size = 500  # Adjust the batch size according to your memory limits
        distances = pairwise_distances_in_batches(combine_x_train, batch_size)

        # Normalize the distances
        max_distance = torch.max(distances)
        normalized_distances = distances / max_distance

        print(normalized_distances.max())

        # Step 1: Create a mask that excludes the diagonal
        mask = torch.eye(normalized_distances.size(0), dtype=torch.bool).to(device)

        # Step 2: Apply the mask to exclude diagonal elements
        masked_distances = normalized_distances.masked_fill(mask, float('inf'))

        # Step 3: Find the minimum value in the masked distance matrix
        min_value = masked_distances.min()

        print("Minimum non-diagonal distance:", min_value.item())

        # Step 1: Flatten the distance matrix
        distances_flat = normalized_distances.flatten()

        # Step 2: Exclude diagonal elements (zeros)
        # Create a mask for non-diagonal elements
        mask = ~torch.eye(normalized_distances.size(0), dtype=torch.bool).flatten().to(normalized_distances.device)
        non_diagonal_distances = distances_flat[mask]

        # Step 3: Convert to numpy for easy plotting
        non_diagonal_distances_np = non_diagonal_distances.cpu().numpy()

        # Step 4: Plot the histogram
        plt.hist(non_diagonal_distances_np, bins=50, edgecolor='black')
        plt.title("Distribution of Non-Diagonal Distances")
        plt.xlabel("Distance")
        plt.ylabel("Frequency")
        #plt.show()

        # Assume 'distances' is your distance matrix of shape [9540, 9540]

        # Step 1: Create an adjacency list or matrix
        n = normalized_distances.size(0)
        adjacency_list = {i: [] for i in range(n)}

        # Step 2: Populate the adjacency list based on the distance threshold
        threshold = 0.1

        # Step 2: Create a mask where distances are less than the threshold and not on the diagonal
        # This mask identifies the valid edges
        valid_edges = (normalized_distances < threshold) & (torch.eye(normalized_distances.size(0), device=device) == 0)

        # Step 3: Get the indices of the valid edges
        rows, cols = torch.where(valid_edges)

        # Step 4: Construct the directed graph using NetworkX
        G = nx.DiGraph()

        # Add nodes (optional, as NetworkX will add them automatically with edges)
        G.add_nodes_from(range(normalized_distances.size(0)))

        # Add edges based on the indices
        edges = list(zip(cols.cpu().numpy(), rows.cpu().numpy()))  # direction from cols to rows
        G.add_edges_from(edges)

        print('Graph construction completed')

        # Compute out-degree for each vertex
        out_degrees = dict(G.out_degree())
        #print(out_degrees)

        # Find the node with the highest out-degree
        max_out_degree_node = max(out_degrees, key=out_degrees.get)
        print(f"Node with the highest out-degree: {max_out_degree_node}")

        # Step 1: Identify the direct neighbors (nodes with an incoming edge from the selected node)
        neighbors = list(G.successors(max_out_degree_node))

        # Step 2: Identify all nodes to remove incoming edges from (include the picked node)
        nodes_to_modify = [max_out_degree_node] + neighbors

        print(len(nodes_to_modify))

        # Step 3: Remove all incoming edges to these nodes
        for node in nodes_to_modify:
            incoming_edges = list(G.in_edges(node))  # Get all incoming edges to the node
            G.remove_edges_from(incoming_edges)  # Remove the incoming edges

        # Optional: Verify by checking in-degrees after removal
        # for node in nodes_to_modify:
        #    print(f"In-degree of node {node} after removal: {G.in_degree(node)}")