In [16]:
import pandas as pd
# baic transformer Decoder model
import torch
import torch.nn as nn
import torch.nn.functional as Fun
import torch.optim as optim
import numpy as np
import xformers.ops as xops
import math 
from typing import Optional, Union
from torch import Tensor
import random

# Load the raw table from 'adult.csv' (path is relative to the notebook's working directory).
main_df = pd.read_csv('adult.csv')
main_df.head()
# Device on which the tensors built further down are allocated.
# Swap in the commented line below to run without a GPU.
DEVICE = 'cuda'
# DEVICE = 'cpu'

In [17]:
def check_DataFrame_distribution(X_trans):
    '''Print the minimum, maximum and number of distinct values of every column.

    Args:
        X_trans: DataFrame to summarise.

    Return:
        None. A header row followed by one row per column is written to stdout.
    '''
    print('%15s' % '', '%6s' % 'min','%6s' % 'max', '%6s' % 'nunique')

    for column in X_trans.columns:
        values = X_trans[column]
        print('%15s' % column, '%6s' % values.min(), '%6s' % values.max(), '%6s' % values.nunique())

In [18]:
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import KBinsDiscretizer
def POOL_preprocess(df, N_BINS = 100):
    '''
    Fit the preprocessing on `df` and turn every cell into a globally unique node id.

    Numerical columns are first discretised into bins, then every column is
    ordinal-encoded and shifted by a running offset so that each column owns its
    own contiguous range of ids. One id is left free after every column; the
    inference preprocessing maps values that were not seen here onto that id.

    Args:
        df: DataFrame containing the columns listed in NUM and CAT below
        N_BINS: number of bins requested for each numerical column (the number of
            bins that actually occur can be smaller, depending on the distribution)
    Return:
        X_trans: DataFrame of int node ids, index reset to 0..n-1
        inference_package: tuple (ct, OE_list, NUM, CAT, existing_values)
            ct: fitted ColumnTransformer (binning of the numerical columns)
            OE_list: dict, {column name: fitted OrdinalEncoder}
            NUM: list of numerical column names
            CAT: list of categorical column names (includes the label column 'income')
            existing_values: dict, {column name: sorted list of values seen while fitting}
        NUM_vs_CAT: tuple, (number of numerical columns, number of categorical columns - 1),
            i.e. counted over the feature fields only, the label column is excluded
    '''
    # kept local so that the function does not depend on the importing cell
    from sklearn.preprocessing import OrdinalEncoder

    CAT = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'gender', 'native-country', 'income']
    NUM = ['age', 'fnlwgt', 'educational-num', 'capital-gain', 'capital-loss', 'hours-per-week']

    num_CAT = len(CAT)
    num_NUM = len(NUM)

    # binning strategy of every numerical column
    bin_strategy = {
        'age': 'uniform',
        'fnlwgt': 'quantile',
        'educational-num': 'quantile',
        'capital-gain': 'uniform',
        'capital-loss': 'uniform',
        'hours-per-week': 'uniform',
    }
    ct = ColumnTransformer(
        [
            (column,
             KBinsDiscretizer(n_bins = N_BINS, encode='ordinal', strategy=bin_strategy[column], subsample=None),
             [column])
            for column in NUM
        ],
        remainder = 'passthrough', verbose_feature_names_out = False) # make sure columns are unique
    ct.set_output(transform = 'pandas')
    X_trans = ct.fit_transform(df)

    # remember which (binned) values exist, so unseen values can be identified at inference
    existing_values = {}
    for column in NUM:
        existing_values[column] = sorted(X_trans[column].unique().astype(int))
    for column in CAT:
        existing_values[column] = sorted(X_trans[column].unique().astype(str))

    # ordinal-encode every column; values unknown to a fitted encoder become -1 later on
    OE_list = {}
    for column in NUM + CAT:
        OE = OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value = -1)
        X_trans[column] = OE.fit_transform(X_trans[[column]])
        OE_list[column] = OE

    # shift every column into its own id range
    # the trailing '+ 1' reserves the id right after a column's range for its unseen values
    offset = 0
    for column in NUM + CAT:
        X_trans[column] = X_trans[column] + offset
        offset += (X_trans[column].max() - X_trans[column].min() + 1) + 1

    X_trans = X_trans.astype(int).reset_index(drop = True)
    # num_CAT - 1: the income column (label) is not a feature field
    return X_trans, (ct, OE_list, NUM, CAT, existing_values), (num_NUM, num_CAT - 1)
# Shuffle the rows once, then fit the preprocessing on everything after the first fifth.
# NOTE(review): 48842 is hard-coded; it matches the 'total data num' printed later in this
# notebook, but would silently mis-split a different file.
main_df_SHUFFLE = main_df.sample(frac=1).reset_index(drop=True)
X_trans, inference_package , _  = POOL_preprocess(main_df_SHUFFLE[48842//5:])
X_trans.head()



Unnamed: 0,age,fnlwgt,educational-num,capital-gain,capital-loss,hours-per-week,workclass,education,marital-status,occupation,relationship,race,gender,native-country,income
0,28,164,184,191,214,303,367,385,390,403,413,420,425,441,470
1,38,109,183,191,214,288,364,381,387,396,415,422,425,466,470
2,26,105,185,191,214,303,364,378,389,400,411,422,425,466,471
3,33,136,188,191,214,313,364,382,389,399,411,422,425,466,471
4,35,89,183,191,214,308,364,381,389,401,411,422,425,466,470


In [19]:

def POOL_preprocess_inference(df: pd.DataFrame,
                              inference_package: tuple,
                              ):
    '''Preprocess the DataFrame when inference

    The transformers fitted during training are re-used, so a value gets the same
    node id it had in the training data. A value the fitted encoder does not know
    is mapped to the "unseen" node of its column, which is the id right after the
    column's known range.

    Args:
        `df`: DataFrame to be processed.\n
        `inference_package`: tuple, containing the following objects.
            `ct`: ColumnTransformer object required for inference, which makes sure values are in the same range as training data
            `OE_list`: dict, {column name: OrdinalEncoder object}\n
            `NUM`: list of numerical columns \n
            `CAT`: list of categorical columns\n
            `existing_values`: dict, {column name: sorted list of existing values}

    Return:
        `X_trans`: DataFrame of int node ids, index reset to 0..n-1\n
        `unseen_node_indexs`: dict, {column name: node id used for unseen values of that column}
    '''
    (ct, OE_list, NUM, CAT, existing_values) = inference_package
    # copy, so that whatever frame `ct` hands back is never modified in place
    X_trans = ct.transform(df).copy()

    unseen_node_indexs = {}
    offset = 0  # first node id of the column currently processed
    for column in NUM + CAT:
        # the unseen node sits right behind the ids of the known values
        unseen_index = int(len(existing_values[column])) + offset
        unseen_node_indexs[column] = unseen_index

        # the fitted encoder returns -1 for values it has not seen
        encoded = np.asarray(OE_list[column].transform(X_trans[[column]])).ravel()
        is_unseen = encoded == -1
        if is_unseen.any():
            print('[preprocess]: detected unseen values in column', column)
        X_trans[column] = np.where(is_unseen, unseen_index, encoded + offset)

        offset = unseen_index + 1

    X_trans = X_trans.astype(int).reset_index(drop = True)
    return X_trans, unseen_node_indexs
# Encode the held-out first fifth of the shuffled rows with the transformers fitted above.
X_trans_ , unseen_node_indexs= POOL_preprocess_inference(main_df_SHUFFLE[:48842//5], inference_package)
X_trans_.head()

[preprocess]: detected unseen values in column capital-gain
[preprocess]: detected unseen values in column hours-per-week


Unnamed: 0,age,fnlwgt,educational-num,capital-gain,capital-loss,hours-per-week,workclass,education,marital-status,occupation,relationship,race,gender,native-country,income
0,17,98,183,191,214,307,364,381,389,398,411,422,425,466,470
1,42,115,183,191,214,303,364,381,387,403,412,422,424,466,470
2,13,155,183,191,214,303,367,381,391,403,412,422,425,466,470
3,30,128,183,194,214,323,366,381,389,400,411,422,425,466,471
4,3,79,184,191,214,283,364,385,391,403,414,422,425,466,470


In [20]:
# Per-column id ranges of the held-out split (encoded with the training-time transformers).
check_DataFrame_distribution(X_trans_)

                   min    max nunique
            age      0     73     70
         fnlwgt     75    174    100
educational-num    176    189     14
   capital-gain    191    213     20
   capital-loss    214    261     39
 hours-per-week    264    359     84
      workclass    360    368      9
      education    370    385     16
 marital-status    387    393      7
     occupation    395    409     15
   relationship    411    416      6
           race    418    422      5
         gender    424    425      2
 native-country    427    468     41
         income    470    471      2


In [21]:
# Per-column id ranges of the split the preprocessing was fitted on.
check_DataFrame_distribution(X_trans)
# NOTE(review): the string below is only echoed as the cell result; presumably a hand-written
# list of per-column boundary ids kept for reference - confirm or remove.
'[74, 175, 190, 214, 264, 360, 370, 387, 395, 411, 418, 424, 427, 470, 473]'

                   min    max nunique
            age      0     73     74
         fnlwgt     75    174    100
educational-num    176    189     14
   capital-gain    191    212     22
   capital-loss    214    262     49
 hours-per-week    264    358     95
      workclass    360    368      9
      education    370    385     16
 marital-status    387    393      7
     occupation    395    409     15
   relationship    411    416      6
           race    418    422      5
         gender    424    425      2
 native-country    427    468     42
         income    470    471      2


'[74, 175, 190, 214, 264, 360, 370, 387, 395, 411, 418, 424, 427, 470, 473]'

In [22]:
# train_size = 4*48842//5
# test_size = 48842//5
# train_pool = main_df[test_size:]
# test_pool = main_df[:test_size]
# print('total data num:' , main_df.shape[0])
# print('trian data num:' , train_pool.shape[0])
# print('test data num:' , test_pool.shape[0])

In [23]:
'''Notations
  node: number of all nodes = L + S + C + F
  L: number of label nodes + 1 (for the unseen label)
  S: number of sample nodes + 1 (for inference)
  C: number of category nodes + F (one extra "unseen" node for each field (column))
  F: number of field (column) nodes (no unseen field is allowed)
  hidden: number of hidden representations

data size = 
mask size =
use nn.TransformerDecoder(data, mask) to get the output
use the above output as input of the MLP to predict the label
'''

'Notations\n  node: number of all nodes = L + S + C + F\n  L: number of lable nodes + 1 (for unseen lable)\n  S: number of sample nodes + 1 (for inference)\n  C: number of catagory nodes + F (for each field(column)\n  F: number of field(column) nodes (no unseen field is allowed)\n  hidden: number of hidden representation\n\ndata size = \nmask size =\nuse nn.transformerDecoder(data,mask) to get the output\nuse the above output as input of MLP to predict the lable   \n'

In [24]:
class HGNN_():
    '''Hypergraph view of a tabular dataset.

    The table is split into a training pool and a test pool, every cell is turned
    into a node id by `POOL_preprocess`, and four kinds of nodes are derived:

        L: label nodes, the last one stands for the unknown label
        S: sample nodes (one per training row), the last one is the inference node
        C: category nodes, every field has one extra node for unseen values
        F: field (column) nodes

    The instance keeps the input tensors of the nodes (`INPUTS`, `MASKED_INPUTS`)
    and the 0/1 connection masks between neighbouring node kinds (`MASKS`,
    `MASKS_FULL`). Tensors are allocated on the notebook-level `DEVICE`.
    '''
    def __init__(self,
                 data_df : pd.DataFrame,
                 split_ratio : float ,
                 label_column : str,
                 ):
        '''Split, preprocess and build the inputs and masks of the whole graph.

        Args:
            data_df: raw table.
            split_ratio: fraction of the rows used as training pool.
            label_column: name of the label column.
        '''
        # shuffle and cut data
        data_df = data_df.sample(frac=1).reset_index(drop=True)
        test_size = math.ceil(data_df.shape[0] * (1-split_ratio))
        train_pool = data_df[test_size:]
        test_pool = data_df[:test_size]
        print('total data num:' , data_df.shape[0])
        print('trian data num:' , train_pool.shape[0])
        print('test data num:' , test_pool.shape[0])

        N_BINS = 100
        TARGET_POOL, self.inference_package, self.NUM_vs_CAT = POOL_preprocess(train_pool, N_BINS = N_BINS)
        TEST_POOL, self.unseen_node_indexs_C = POOL_preprocess_inference(test_pool, self.inference_package)
        LABEL_COLUMN = label_column

        # cut feature and label
        FEATURE_POOL = TARGET_POOL.drop(LABEL_COLUMN, axis=1)
        LABEL_POOL = TARGET_POOL[LABEL_COLUMN]
        TEST_LABEL_POOL = TEST_POOL[LABEL_COLUMN]

        from sklearn.preprocessing import OneHotEncoder
        # The encoder is fitted on the training labels only and re-used for the test
        # labels, so both one-hot matrices share the same column order. A test label
        # that never occurred in training is encoded as an all-zero row.
        enc = OneHotEncoder(handle_unknown='ignore')
        LABEL_POOL = enc.fit_transform(LABEL_POOL.values.reshape(-1,1)).toarray()
        TEST_LABEL_POOL = enc.transform(TEST_LABEL_POOL.values.reshape(-1,1)).toarray()

        # L: number of label nodes, the last node of label nodes is served as unknown label node
        L = LABEL_POOL.shape[1] + 1

        # S: number of sample nodes, the last node of sample nodes is served as inferring node
        S = FEATURE_POOL.shape[0] + 1

        # F: number of field (column) nodes
        F = FEATURE_POOL.shape[1]

        # C: number of category nodes, each field(column) has its own "unseen" category node
        self.nodes_of_fields = [FEATURE_POOL[column].nunique() + 1 for column in FEATURE_POOL.columns]
        C = sum(self.nodes_of_fields) # the total number of nodes equals to the sum of nodes of each field
        C_POOL = range(int(C))

        nodes_num = {'L':L, 'S':S, 'C':C, 'F':F}
        print('node_nums', nodes_num)
        print('total', L+S+C+F, 'nodes')

        # positional row indices of the training pool, grouped by label id
        self.labe_to_index = {}
        tmp_pool = TARGET_POOL.reset_index(drop=True)
        for label in tmp_pool[LABEL_COLUMN].unique():
            self.labe_to_index[label] = (tmp_pool[tmp_pool[LABEL_COLUMN] == label].index).tolist()

        self.TARGET_POOL = TARGET_POOL
        self.TEST_POOL = TEST_POOL
        self.TEST_LABEL_POOL = TEST_LABEL_POOL
        self.LABEL_COLUMN = LABEL_COLUMN
        self.FEATURE_POOL = FEATURE_POOL
        self.LABEL_POOL = LABEL_POOL
        self.C_POOL = C_POOL
        self.nodes_num = nodes_num
        self.N_BINS = N_BINS

        self.make_input_tensor()
        self.make_mask_all()

    def _label_offset(self):
        '''Smallest label id of the training pool.

        Label ids are contiguous, so `label id - offset` is the column of a label in
        the L2S masks. Using the same offset everywhere keeps the masks aligned with
        each other and with the (sorted) one-hot label columns.
        '''
        return int(self.TARGET_POOL[self.LABEL_COLUMN].min())

    def make_mask_subgraph(self,
                  sample_indices: Optional[list] = None,
                ):
        '''Making masks for subgraph. Mask values are 1 if two nodes are connected, otherwise 0.

        Args:
            sample_indices: list of sample node indices (positions in the training pool)

        Raises:
            ValueError: if `sample_indices` is None.

        for example, with:
            {'L': 3, 'S': 39074, 'C': 470, 'F': 14, 'K': 10}

        the masks will be:
            masks['L2S'] = torch.Size([16, 8]), values in torch.Size([10, 3])\\
            masks['S2C'] = torch.Size([472, 16]), values in torch.Size([470, 10])\\
            masks['C2F'] = torch.Size([16, 472]), values in torch.Size([14, 470])\\
        Notice: xformer require the mask's tensor must align on memory, and should be slice of a tensor if shape cannot be divided by 8

        The result is stored in `self.MASKS`; `self.nodes_num['K']` is set to the sample size.
        '''
        if sample_indices is None:
            raise ValueError('sample_indices is required to build the subgraph masks')
        L, C = self.nodes_num['L'], self.nodes_num['C']

        sample_size = len(sample_indices)
        masked_POOL = self.TARGET_POOL.iloc[sample_indices] # rows of the sampled nodes
        masks = {}
        rows = torch.arange(sample_size, device=DEVICE)

        # label to sample
        # the global label offset is used (not the minimum of the sample), otherwise a
        # sample holding a single label would always be connected to column 0
        tmp = torch.zeros([math.ceil(sample_size/8) * 8, math.ceil(L/8) * 8], dtype=torch.float, device=DEVICE)
        label_columns = torch.tensor(masked_POOL[self.LABEL_COLUMN].values - self._label_offset(),
                                     dtype=torch.long, device=DEVICE)
        tmp[rows, label_columns] = 1
        masks['L2S'] = tmp

        # sample to category: filled through the transposed view, stored as [C, sample]
        tmp = torch.zeros([math.ceil(C/8) * 8, math.ceil(sample_size/8) * 8], dtype=torch.float, device=DEVICE).T
        category_ids = torch.tensor(masked_POOL.drop(self.LABEL_COLUMN, axis=1).values,
                                    dtype=torch.long, device=DEVICE)
        tmp[rows.unsqueeze(-1), category_ids] = 1
        masks['S2C'] = tmp.T.contiguous()

        # category to field does not depend on the sampled nodes
        masks['C2F'] = self.MASKS_FULL['C2F']

        self.MASKS = masks
        self.nodes_num['K'] = sample_size

    def make_mask_all(self):
        '''Making masks for the entire graph. Mask values are 1 if two nodes are connected, otherwise 0.

        for example, with:
            {'L': 3, 'S': 39074, 'C': 470, 'F': 14, 'K': 10}.

        the masks will be:
            masks['L2S']: torch.Size([39080, 8]), values in torch.Size([39074, 3]).\\
            masks['S2C']: torch.Size([472, 39080]), values in torch.Size([470, 39074]).\\
            masks['C2F']: torch.Size([16, 472]), values in torch.Size([14, 470]).\\

        Notice: xformer require the mask's tensor must align on memory, and should be slice of a tensor if shape cannot be divided by 8

        The result is stored in both `self.MASKS` and `self.MASKS_FULL`.
        '''
        L, S, C, F = self.nodes_num['L'], self.nodes_num['S'], self.nodes_num['C'], self.nodes_num['F']
        masks = {}
        rows = torch.arange(len(self.TARGET_POOL), device=DEVICE)

        # label to sample: column = label id - smallest label id (same rule as the subgraph mask)
        tmp = torch.zeros([math.ceil(S/8) * 8, math.ceil(L/8) * 8], dtype=torch.float, device=DEVICE)
        label_columns = torch.tensor(self.TARGET_POOL[self.LABEL_COLUMN].values - self._label_offset(),
                                     dtype=torch.long, device=DEVICE)
        tmp[rows, label_columns] = 1
        masks['L2S'] = tmp

        # sample to category: filled through the transposed view, stored as [C, S]
        tmp = torch.zeros([math.ceil(C/8) * 8, math.ceil(S/8) * 8], dtype=torch.float, device=DEVICE).T
        category_ids = torch.tensor(self.TARGET_POOL.drop(self.LABEL_COLUMN, axis=1).values,
                                    dtype=torch.long, device=DEVICE)
        tmp[rows.unsqueeze(-1), category_ids] = 1
        masks['S2C'] = tmp.T.contiguous()

        # category to field
        # A field owns the contiguous ids [min, max] of its seen values plus the
        # unseen node at max + 1; all of them are connected to the field node.
        tmp = torch.zeros([math.ceil(F/8) * 8, math.ceil(C/8) * 8], dtype=torch.float, device=DEVICE)
        if len(self.FEATURE_POOL) > 0:
            for i, column in enumerate(self.FEATURE_POOL.columns):
                first = int(self.FEATURE_POOL[column].min())
                unseen = int(self.FEATURE_POOL[column].max()) + 1
                tmp[i, first:unseen + 1] = 1
        masks['C2F'] = tmp

        self.MASKS = masks
        self.MASKS_FULL = masks

    def make_mask_test(self, index_in_test_pool ):
        '''Make mask tensor for the testing scenario. \n
        In testing scenario, L, S, C, F remain the same, while all INPUTs are the same (since they are initialized with fixed values)\n
        All we need to do is to update masks(L2S, S2C) for the new inference node

        Args:
            index_in_test_pool: row position in the test pool of the sample to infer.

        The result is stored in `self.MASKS`; `self.MASKS_FULL` is left untouched.
        '''
        L, S = self.nodes_num['L'], self.nodes_num['S']

        masks = {}
        # the inference node is the last sample node (index S-1)
        tmp = self.MASKS_FULL['L2S'].clone().detach()
        tmp[S-1, L-1] = 1 # connect inference node to the unseen label node
        masks['L2S'] = tmp

        # connect the inference node to the category nodes of the test row
        tmp = self.MASKS_FULL['S2C'].T.clone().detach()
        category_ids = self.TEST_POOL.drop(self.LABEL_COLUMN, axis=1).values[index_in_test_pool]
        tmp[S-1, torch.as_tensor(category_ids, dtype=torch.long, device=tmp.device)] = 1
        masks['S2C'] = tmp.T.contiguous()

        # C2F remains the same
        masks['C2F'] = self.MASKS_FULL['C2F']

        self.MASKS = masks

    def make_input_tensor(self):
        '''Making input tensor for the entire graph.

        for example, with:
            {'L': 3, 'S': 39074, 'C': 470, 'F': 14, 'K': 10}.

        the input tensor will be:
            L_input: torch.Size([3, 1]).
            S_input: torch.Size([39074, 128]).
            C_input: torch.Size([470, 1]).
            F_input: torch.Size([14, 1]).

        The result is stored in `self.INPUTS`, the feature widths in `self.INPUT_DIMS`.
        '''
        L, S, F = self.nodes_num['L'], self.nodes_num['S'], self.nodes_num['F']
        # L: one id per label node
        L_input = torch.tensor([range(L)], device=DEVICE).reshape(-1,1)
        print('L_input', L_input.type(), L_input.shape)

        # S: a single random 128-d vector, repeated for every sample node
        # (all sample nodes start from the same representation)
        S_input = torch.rand(128, device=DEVICE).repeat(S,1)
        print('S_input', S_input.type(), S_input.shape)

        # C: one id per category node
        C_input = torch.tensor([self.C_POOL], device=DEVICE).reshape(-1,1)
        print('C_input', C_input.type(), C_input.shape)

        # F: one id per field node
        F_input = torch.tensor([range(F)], device=DEVICE).reshape(-1,1)
        print('F_input', F_input.type(), F_input.shape)

        self.INPUTS = (L_input, S_input, C_input, F_input)
        self.INPUT_DIMS = (L_input.size(1), S_input.size(1), C_input.size(1), F_input.size(1))

    def _draw_index(self, label, excluded):
        '''Draw one random sample index of `label` that is not in the set `excluded`.'''
        pool = self.labe_to_index[label]
        if not excluded:
            return random.choice(pool)
        # a few cheap attempts first; they almost always succeed on a large pool
        for _ in range(16):
            candidate = random.choice(pool)
            if candidate not in excluded:
                return candidate
        remaining = [index for index in pool if index not in excluded]
        if not remaining:
            raise ValueError('no sample of label {} is left after excluding the required nodes'.format(label))
        return random.choice(remaining)

    def sample_with_distrubution(self, sample_size, exclude = None):
        '''
        Sample equally from each label with required sample size\\
        forced to make balanced sample

        Args:
            sample_size: number of indices to draw (drawn with replacement).
            exclude (optional): indices that must not be drawn.

        Return:
            list of sample indices, grouped by label.

        Raises:
            ValueError: if every sample of a required label is excluded.
        '''
        excluded = set(exclude) if exclude else set()
        # decide each label's number of samples (forced to be balanced if possible)
        label_unique = list(self.labe_to_index.keys())
        count = sample_size // len(label_unique)
        remainder = sample_size % len(label_unique)
        label_list = [item for item in label_unique for _ in range(count)]
        label_list.extend(random.sample(label_unique, remainder))
        # sample from indexes
        return [self._draw_index(label, excluded) for label in label_list]

    def get_sample(self, sample_size, inculde = None):
        '''get sample nodes indices, and update the masked input tensor

        Args:
            sample_size: number of sample nodes required.
            inculde (optional): list of nodes indices that must be included in the nodes indices.

        Return:
            sorted list of `sample_size` sample indices.

        Raises:
            ValueError: if more nodes are required than `sample_size` allows.

        The included nodes are never drawn a second time, in case of the label leakage.
        `self.MASKED_INPUTS` is set to the inputs restricted to the returned samples.
        '''
        inculde = [] if inculde is None else list(inculde)
        free_slots = sample_size - len(inculde)
        if free_slots < 0:
            raise ValueError('sample_size ({}) is smaller than the number of included nodes ({})'.format(sample_size, len(inculde)))

        # fill the remaining slots with random nodes, then add the required ones
        sample_indices = self.sample_with_distrubution(free_slots, exclude = inculde)
        sample_indices.extend(inculde)
        sample_indices = sorted(sample_indices)

        # restrict the sample inputs to the selected nodes
        L_input, S_input, C_input, F_input = self.INPUTS
        selector = torch.tensor(sample_indices, dtype=torch.long, device=S_input.device)
        S_input_masked = torch.index_select(S_input, 0, selector)
        self.MASKED_INPUTS = (L_input, S_input_masked, C_input, F_input)

        return sample_indices
            
# Build the graph data from the raw table: 0.8 is the split ratio passed to HGNN_,
# 'income' the label column.
Train_data = HGNN_( main_df, 0.8, 'income')
# Draw 10 sample nodes as a quick check; the returned indices are shown as the cell result.
Train_data.get_sample(10)


total data num: 48842
trian data num: 39073
test data num: 9769
[preprocess]: detected unseen values in column native-country




node_nums {'L': 3, 'S': 39074, 'C': 471, 'F': 14}
total 39562 nodes
L_input torch.cuda.LongTensor torch.Size([3, 1])
S_input torch.cuda.FloatTensor torch.Size([39074, 128])
C_input torch.cuda.LongTensor torch.Size([471, 1])
F_input torch.cuda.LongTensor torch.Size([14, 1])


[3051, 3693, 8285, 12228, 12430, 13444, 15181, 16891, 17035, 31417]

In [25]:
Train_data.INPUTS[0]

tensor([[0],
        [1],
        [2]], device='cuda:0')

In [26]:
Train_data.INPUTS[1]

tensor([[0.8498, 0.7133, 0.9065,  ..., 0.6976, 0.3817, 0.4720],
        [0.8498, 0.7133, 0.9065,  ..., 0.6976, 0.3817, 0.4720],
        [0.8498, 0.7133, 0.9065,  ..., 0.6976, 0.3817, 0.4720],
        ...,
        [0.8498, 0.7133, 0.9065,  ..., 0.6976, 0.3817, 0.4720],
        [0.8498, 0.7133, 0.9065,  ..., 0.6976, 0.3817, 0.4720],
        [0.8498, 0.7133, 0.9065,  ..., 0.6976, 0.3817, 0.4720]],
       device='cuda:0')

In [27]:
# Train_data.get_sample(100, inculde=[10])


In [28]:
from torch import Tensor
from typing import Optional, Any, Union, Callable

class TabHyperformer_Layer(nn.TransformerDecoderLayer):
    '''Decoder layer reduced to a single cross-attention step.

    `forward` computes `norm2(tgt + attention(tgt, memory, memory))`, where the
    attention is xformers' memory efficient attention applied directly to the
    inputs. The self-attention sub-modules of the parent class are removed; the
    remaining parent sub-modules other than `norm2` are not used by `forward`.

    NOTE(review): xformers documents its fourth argument as a bias that is added
    to the attention scores, so a 0/1 mask only favours connected pairs instead
    of hiding unconnected ones - confirm this is intended.
    NOTE(review): recent PyTorch releases appear to read `layers[0].self_attn`
    inside `nn.TransformerDecoder.forward`; verify the removal below against the
    installed version.
    '''
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='relu'):
        super().__init__(d_model, nhead, dim_feedforward, dropout, activation)
        # remove the self-attention sub-modules, forward never uses them
        for unused in ('self_attn', 'norm1', 'dropout1'):
            delattr(self, unused)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None,
                tgt_is_causal=False, memory_is_causal=False):
        '''Update `tgt` with information attended from `memory`.

        Only `tgt`, `memory` and `memory_mask` are used. The other arguments exist
        so that the layer can be called by `nn.TransformerDecoder`, including the
        releases that pass `tgt_is_causal` / `memory_is_causal` to every layer.
        '''
        attended = self._mha_block(tgt, memory, memory_mask)
        return self.norm2(tgt + attended)

    def _mha_block(self, x: Tensor, mem: Tensor,
                   attn_mask: Optional[Tensor],) -> Tensor:
        '''Attention with `x` as query and `mem` as both key and value (no projections, no dropout).'''
        return xops.memory_efficient_attention(x, mem, mem, attn_mask)


In [29]:
# baic transformer decoder model
import torch
import torch.nn as nn
import torch.nn.functional as Fun
from tqdm import trange

class TransformerDecoderModel(nn.Module):
    '''Message passing over label, sample, category and field nodes.

    Every node kind is embedded, then one shared decoder propagates information
    along L->S->C->F and back F->C->S->L. A linear head on the sample nodes gives
    the two class scores of the query nodes.
    '''
    def __init__(self,
                 target_ : 'HGNN_',
                 num_layers,
                 embedding_dim,
                 ):
        '''
        Args:
            target_: graph data; its full masks and input tensors are rebuilt here.
            num_layers: number of layers of the shared decoder.
            embedding_dim: width of all node embeddings. The sample inputs are fed
                to the decoder as they are, so this has to equal their width.
        '''
        super(TransformerDecoderModel, self).__init__()

        L_dim, S_dim, C_dim, F_dim = target_.INPUT_DIMS
        L, F = target_.nodes_num['L'], target_.nodes_num['F']
        num_NUM , num_CAT = target_.NUM_vs_CAT

        self.Lable_embedding = nn.Embedding(L, embedding_dim, dtype=torch.float)

        # every numerical field has its own Linear embedding layer
        # They are kept in a ModuleList so that they are registered as sub-modules:
        # their weights are returned by parameters(), trained by the optimizer, saved
        # in the state dict and moved together with the model by .to(device).
        self.Catagory_embedding_nums = nn.ModuleList(
            [nn.Linear(C_dim, embedding_dim, dtype=torch.float) for _ in range(num_NUM)]
        )
        catagories = target_.nodes_of_fields[-num_CAT:] # number of nodes of every categorical field
        self.Catagory_embedding_cat = nn.Embedding(sum(catagories), embedding_dim, dtype=torch.float)

        self.Field_embedding = nn.Embedding(F, embedding_dim, dtype=torch.float)

        self.transformer_decoder = nn.TransformerDecoder(
            TabHyperformer_Layer(embedding_dim,  nhead = 2 ),
            num_layers
        )

        # downstream task
        self.MLP = nn.Sequential(
            nn.Linear(embedding_dim, 2),
        )

        # initialize MASK_FULL
        target_.make_mask_all()
        target_.make_input_tensor()

        self.tmpmask_L2S = target_.MASKS['L2S'].clone()

    def maskout_lable(self,
                      target_: 'HGNN_',
                      query_indices: list, # must be sorted
                      sample_indices: Optional[list] = None,
                      ):
        '''Hide the labels of the query nodes.

        `self.tmpmask_L2S` is set to a copy of `target_.MASKS['L2S']` in which the
        row of every query node is connected to the unseen label node only.

        Args:
            target_: graph data holding the current masks.
            query_indices: sample indices of the query nodes.
            sample_indices (optional): the sampled nodes of a subgraph. If given,
                the row of a query is its position in this list, otherwise the
                query index itself is the row.
        '''
        L = target_.nodes_num['L']
        # clone once, so that all queries are masked in the same copy
        mask = target_.MASKS['L2S'].clone().detach()
        for query in (query_indices or []):
            row = query if sample_indices is None else sample_indices.index(query)
            mask[row] = 0
            mask[row][L-1] = 1 # make it as unseen label
        self.tmpmask_L2S = mask

    def forward(self,
                target_: 'HGNN_',
                mode : str = 'train',
                query_indices: list = None,  # must be sorted
                K : Optional[int] = 10,
                ):
        '''Score the query nodes.

        Args:
            target_: graph data.
            mode: 'train' works on a random subgraph of K sample nodes that contains
                the query nodes; 'inferring' works on the whole graph plus the
                inference node.
            query_indices: in 'train' mode, indices of training samples; in
                'inferring' mode, a list whose first entry is a row of the test pool.
            K: number of sample nodes of the subgraph ('train' mode only).

        Return:
            tensor of shape [number of query nodes, 2].

        Raises:
            ValueError: in 'inferring' mode without a query index.
            NotImplementedError: for any other mode.
        '''
        L, S, C, F = target_.nodes_num['L'], target_.nodes_num['S'], target_.nodes_num['C'], target_.nodes_num['F']
        num_NUM, num_CAT = target_.NUM_vs_CAT
        query_indices = [] if query_indices is None else list(query_indices)

        # decide scenario
        if mode == 'train':
            # generate subgraph with K nodes, including query_indices
            self.sample_indices = target_.get_sample(K, inculde = query_indices)
            target_.make_mask_subgraph(self.sample_indices)
            # get updated masked input tensor and mask
            L_input, S_input, C_input, F_input = target_.MASKED_INPUTS
            masks = target_.MASKS
            # mask out the queries node's edge to it's label node, prevent label leakage
            self.maskout_lable(target_, query_indices, self.sample_indices)

            # the query node's indexs in sample_indices
            query_indexs = [self.sample_indices.index(query) for query in query_indices]
            S_ = K # the S used in transformer decoder

        elif mode == 'inferring':
            if not query_indices:
                raise ValueError('inferring requires the index of one test sample in query_indices')
            # use all nodes in the graph, the input tensors need no update
            L_input, S_input, C_input, F_input = target_.INPUTS
            # update mask for inference node
            target_.make_mask_test(query_indices[0]) # query node equal to inference node
            # The query is the inference node (row S-1). The test pool index must not
            # be used as a row here, it would hide the label of an unrelated training node.
            self.maskout_lable(target_, [S-1])

            masks = target_.MASKS

            query_indexs = [S-1]
            S_ = S # the S used in transformer decoder
        else:
            raise NotImplementedError

        # embed all nodes, unsqueeze(0) adds the batch dimension
        L_embedded = self.Lable_embedding(L_input.long()).squeeze(1).unsqueeze(0).float()

        S_embedded = S_input.unsqueeze(0).float()

        # for every numerical field, use it's own Linear embedding layer
        C_embedded_nums = []
        field = target_.nodes_of_fields
        start = 0
        for index, nodes in enumerate(field[:num_NUM]): # pick numerical fields
            end = start + nodes
            C_embedded_nums.append(self.Catagory_embedding_nums[index](C_input[start:end].float()).unsqueeze(0))
            start = end
        C_embedded_num = torch.cat(C_embedded_nums, dim = 1)

        catagorical_filed_nodes = sum(field[-num_CAT:]) # pick categorical fields
        # - sum(field[:num_NUM]) because the embedding index should start from 0
        C_embedded_cat = self.Catagory_embedding_cat(C_input[-catagorical_filed_nodes:].squeeze(1).long() - sum(field[:num_NUM])).unsqueeze(0).float()
        C_embedded = torch.cat([C_embedded_num, C_embedded_cat], dim = 1)

        F_embedded = self.Field_embedding(F_input.long()).squeeze(1).unsqueeze(0).float()

        # propagate steps: L→S→C→F
        #                  L←S←C←
        # more steps more memory usage
        PROPAGATE_STEPS = 1
        for i in range(PROPAGATE_STEPS):
            S_embedded = self.transformer_decoder(S_embedded,L_embedded,
                                                memory_mask = self.tmpmask_L2S[:S_,:L])
            C_embedded = self.transformer_decoder(C_embedded,S_embedded,
                                                memory_mask = masks['S2C'][:C,:S_])
            F_embedded = self.transformer_decoder(F_embedded,C_embedded,
                                                memory_mask = masks['C2F'][:F,:C])
            C_embedded = self.transformer_decoder(C_embedded,F_embedded,
                                                memory_mask = Tensor.contiguous(masks['C2F'].transpose(0, 1))[:C,:F])
            S_embedded = self.transformer_decoder(S_embedded,C_embedded,
                                                memory_mask = Tensor.contiguous(masks['S2C'].transpose(0, 1))[:S_,:C])
            L_embedded = self.transformer_decoder(L_embedded,S_embedded,
                                                memory_mask = Tensor.contiguous(self.tmpmask_L2S.transpose(0, 1))[:L,:S_])

        output = self.MLP(S_embedded)

        # only the rows of the query nodes are returned
        return output[:,query_indexs][0]
  

# Smoke test of the model
num_layers = 1  # number of TransformerDecoder layers
embedding_dim = 128  # embedding dimension
# NOTE(review): hidden_dim is not referenced in the visible part of the notebook - confirm it is still needed.
hidden_dim = 64  

print('input_dims', Train_data.INPUT_DIMS)
model = TransformerDecoderModel(Train_data, num_layers, embedding_dim).to(DEVICE)


# One forward pass in training mode: two query nodes inside a subgraph of 50 sample nodes.
outputs = model(Train_data, mode = 'train', query_indices = [2000,9999], K = 50)
# outputs = model(Train_data, mode = 'inferring', query_indices = [10], K = 50)
# The printed label below reads "size of the model output [q,2]:".
print("模型輸出的大小[q,2]:", outputs.shape)
print(outputs)
print(outputs.softmax(dim=1))
# Predicted class of every query node.
output_labels = torch.argmax(outputs.softmax(dim=1), dim=1)
output_labels



input_dims (1, 128, 1, 1)


L_input torch.cuda.LongTensor torch.Size([3, 1])
S_input torch.cuda.FloatTensor torch.Size([39074, 128])
C_input torch.cuda.LongTensor torch.Size([471, 1])
F_input torch.cuda.LongTensor torch.Size([14, 1])
模型輸出的大小[q,2]: torch.Size([2, 2])
tensor([[ 0.0463, -0.5302],
        [ 0.0637, -0.5307]], device='cuda:0', grad_fn=<SelectBackward0>)
tensor([[0.6402, 0.3598],
        [0.6444, 0.3556]], device='cuda:0', grad_fn=<SoftmaxBackward0>)


tensor([0, 0], device='cuda:0')

In [30]:
# training
from torch import autograd
from torcheval.metrics.aggregation.auc import AUC
from torcheval.metrics import BinaryAUROC
from sklearn.metrics import roc_auc_score
# Module-level histories: train() appends one value to each list per epoch,
# and the plotting cell further down draws both lists.
tmp_log = []  # mean training loss of each epoch
tmp__log = []  # training AUC of each epoch
def _update_auc(metric, outputs, one_hot_labels):
    """Feed one batch of query results into a binary AUROC metric.

    Args:
        metric: object exposing ``update(input, target)`` (a ``BinaryAUROC`` in ``train``).
        outputs: raw model scores, one row per query and one column per class
            ([q, 2] according to the comments in ``train``). May still be
            attached to the autograd graph.
        one_hot_labels: labels for the same queries, one row per query; the
            true class is taken as the argmax of each row.
    """
    # Detach first so that computing the metric never extends the autograd graph,
    # then take the class-1 probability of every query row in one vectorised step.
    positive_prob = outputs.detach().softmax(dim=1)[:, 1].float().cpu()
    true_class = torch.as_tensor(np.argmax(one_hot_labels, axis=1), dtype=torch.float32)
    metric.update(positive_prob, true_class)


def train(model, datset, epochs=200, lr=0.0001, K=100,
          class_weight=(0.2, 1), log_path='logs/log.txt'):
    """Train ``model`` one query node at a time and evaluate it after every epoch.

    Each epoch queries every index of ``datset.FEATURE_POOL`` in 'train' mode
    (one optimizer step per index), then queries every index of
    ``datset.TEST_POOL`` in 'inferring' mode under ``torch.no_grad()``.
    The epoch's mean training loss is appended to the module-level ``tmp_log``
    and its training AUC to ``tmp__log``; a summary line is printed and
    appended to ``log_path``.

    Args:
        model: module called as ``model(datset, mode=..., query_indices=[i], K=...)``.
        datset: object providing ``FEATURE_POOL``, ``TEST_POOL``, ``LABEL_POOL``
            and ``TEST_LABEL_POOL``; the label pools are indexed with ``[[i]]``.
        epochs: number of passes over the training pool.
        lr: Adam learning rate.
        K: value forwarded as ``K`` to the model in 'train' mode
            ('inferring' mode always passes ``K=None``).
        class_weight: per-class weights handed to ``nn.CrossEntropyLoss``.
        log_path: text file that receives one summary line per epoch.

    Returns:
        None.
    """
    import os

    # Create the log directory up front: discovering that it is missing only
    # after a whole epoch has been trained would throw that work away.
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    LABEL_POOL = datset.LABEL_POOL
    TEST_LABEL_POOL = datset.TEST_LABEL_POOL
    weight = torch.tensor(class_weight, dtype=torch.float32, device=DEVICE)
    criterion = nn.CrossEntropyLoss(weight)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        # ---- train ----
        model.train()
        loss_log = []
        # Fresh metric objects every epoch, so no reset() is needed afterwards.
        AUC_metric = BinaryAUROC().to(DEVICE)
        AUC_metric_test = BinaryAUROC().to(DEVICE)

        for index in trange(len(datset.FEATURE_POOL)):  # every sample node (not the inferring nodes)
            optimizer.zero_grad()
            # Only the query node's output is returned and used for the loss.
            outputs = model(datset, mode = 'train', query_indices = [index], K = K)
            labels = LABEL_POOL[[index]]

            batch_loss = criterion(outputs, torch.tensor(labels, device=DEVICE))
            loss_log.append(batch_loss.item())
            batch_loss.backward()
            optimizer.step()

            _update_auc(AUC_metric, outputs, labels)
            # NOTE(review): kept from the original; calling this every step is slow — confirm it is still needed.
            torch.cuda.empty_cache()

        # ---- evaluate ----
        model.eval()
        with torch.no_grad():
            for index in trange(len(datset.TEST_POOL)):
                outputs = model(datset, mode = 'inferring', query_indices = [index], K = None)
                _update_auc(AUC_metric_test, outputs, TEST_LABEL_POOL[[index]])
                torch.cuda.empty_cache()

        epoch_loss = sum(loss_log) / len(loss_log)
        epoch_AUC = float(AUC_metric.compute())
        epoch_AUC_test = float(AUC_metric_test.compute())

        tmp_log.append(float(epoch_loss))
        tmp__log.append(float(epoch_AUC))

        print(f"Epoch{epoch+1}/{epochs} | Loss: {epoch_loss} | AUC: {epoch_AUC} | AUC_test: {epoch_AUC_test}")

        with open(log_path, 'a') as f:
            # One record per line; the original wrote "\n " and so prefixed every later line with a space.
            f.write(f"Epoch{epoch+1}/{epochs} | Loss: {epoch_loss} | AUC: {epoch_AUC}| AUC_test: {epoch_AUC_test}\n")

# Rebind `model` to a newly constructed instance and run the full training loop on it.
model = TransformerDecoderModel(Train_data, num_layers, embedding_dim).to(DEVICE)
train(model, Train_data)

L_input torch.cuda.LongTensor torch.Size([3, 1])
S_input torch.cuda.FloatTensor torch.Size([39074, 128])
C_input torch.cuda.LongTensor torch.Size([471, 1])
F_input torch.cuda.LongTensor torch.Size([14, 1])


 89%|████████▉ | 34682/39073 [02:36<00:19, 220.26it/s]

In [None]:
from matplotlib import pyplot as plt
# Per-epoch training curves collected by train(): tmp_log holds the mean
# training loss and tmp__log the training AUC of each completed epoch.
# Labels and a legend are required to tell the two lines apart.
fig, ax = plt.subplots()
ax.plot(tmp_log, label='train loss')
ax.plot(tmp__log, label='train AUC')
ax.set_xlabel('epoch')
ax.set_title('Training loss and AUC per epoch')
ax.legend()
plt.show()

In [None]:
# Sanity check of nn.CrossEntropyLoss with class-probability (soft) targets:
# the scores and the target are both (3, 5) float tensors, and every target
# row is a softmax output.
loss = nn.CrossEntropyLoss()
input = torch.randn((3, 5), requires_grad=True)
target = torch.softmax(torch.randn((3, 5)), dim=1)

print(input.shape, target.shape)
print(input, target, sep='\n')

output = loss(input, target)
print(float(output))