In [1]:
import pandas as pd
# baic transformer Decoder model
import torch
import torch.nn as nn
import torch.nn.functional as Fun
import torch.optim as optim
import numpy as np
import xformers.ops as xops
import math
from typing import Optional, Union
from torch import Tensor
import random

main_df = pd.read_csv('adult.csv')
main_df.head()
DEVICE = 'cuda'
# DEVICE = 'cpu'

In [2]:
def is_in_interval(interval_str, number):
    # decompose the interval string
    lower_bound = float(interval_str[1:interval_str.index(",")])
    upper_bound = float(interval_str[interval_str.index(",")+1:-1])
    inclusive_lower = interval_str[0] == "("
    inclusive_upper = interval_str[-1] == "]"

    # judge if the number is in the interval
    if inclusive_lower:
        is_in_range = number >= lower_bound
    else:
        is_in_range = number > lower_bound

    if inclusive_upper:
        is_in_range = is_in_range and number <= upper_bound
    else:
        is_in_range = is_in_range and number < upper_bound

In [3]:
# build the dicationary for the categorical features
def build_dict(df, dump = False):
    dict = {}
    for col in df.columns:
        dict_col = {}
        for i, item in enumerate(df[col].unique()):
            dict_col[str(item)] = i
        dict_col['UNSEEN'] = i + 1
        dict[str(col)] = dict_col
    # offset the item index
    offset = 0
    for column in dict.keys():
        for key in dict[column].keys():
            dict[column][key] += offset
        offset += len(dict[column].keys())
        
    if dump:
        import json
        with open('dict.json', 'w') as f:
            f.write(json.dumps(dict, indent=4))
    return dict

# build_dict(train_pool,True)

In [4]:
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import KBinsDiscretizer
def POOL_preprocess(df, N_BINS = 100):
    
    CAT = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'gender', 'native-country', 'income']
    NUM = ['age', 'fnlwgt', 'educational-num', 'capital-gain', 'capital-loss', 'hours-per-week']
    
    num_CAT = len(CAT)
    num_NUM = len(NUM)  
    
    ct = ColumnTransformer([
        ("age", KBinsDiscretizer(n_bins = N_BINS, encode='ordinal', strategy='uniform', subsample=None), ["age"]),
        ("fnlwgt", KBinsDiscretizer(n_bins = N_BINS, encode='ordinal', strategy='uniform', subsample=None), ["fnlwgt"]),
        ("educational-num", KBinsDiscretizer(n_bins = N_BINS, encode='ordinal', strategy='uniform', subsample=None), ["educational-num"]),
        ("capital-gain", KBinsDiscretizer(n_bins = N_BINS, encode='ordinal', strategy='uniform', subsample=None), ["capital-gain"]),
        ("capital-loss", KBinsDiscretizer(n_bins = N_BINS, encode='ordinal', strategy='uniform', subsample=None), ["capital-loss"]),
        ("hours-per-week", KBinsDiscretizer(n_bins = N_BINS, encode='ordinal', strategy='uniform', subsample=None), ["hours-per-week"]),
         ],remainder = 'passthrough', verbose_feature_names_out = False) # make sure columns are unique
    ct.set_output(transform = 'pandas')
    X_trans = ct.fit_transform(df) 
    
    # make catagoy in NUM columns unique
    # each NUM column has N_BINS unique values, that is, each NUM column represents as N_BINS node
    offset = 0
    for column in NUM:
        X_trans[column] = X_trans[column].apply(lambda x: x + offset)
        offset += N_BINS
    
    # apply lable encoding on CAT columns
    from sklearn.preprocessing import LabelEncoder
    lb = LabelEncoder()
    X_trans[CAT] = X_trans[CAT].apply(lambda x: lb.fit_transform(x))
    
    # make catagoy all columns unique
    # each column has it's own number of unique values. '+1' is for unseen values
    offset = len(NUM) * N_BINS
    for column in CAT:
        X_trans[column] = X_trans[column].apply(lambda x: x + offset)
        offset += X_trans[column].nunique() + 1
    
    X_trans = X_trans.astype(int)
    return X_trans, ct, (num_NUM, num_CAT - 1) # -1 is for the income column (label)
X_trans, _, _= POOL_preprocess(main_df)
X_trans.head()

Unnamed: 0,age,fnlwgt,educational-num,capital-gain,capital-loss,hours-per-week,workclass,education,marital-status,occupation,relationship,race,gender,native-country,income
0,10,114,240,300,400,539,604,611,631,642,654,660,665,706,710
1,28,105,253,300,400,550,604,621,629,640,651,662,665,706,710
2,15,121,273,300,400,539,602,617,629,646,651,662,665,706,711
3,36,110,260,307,400,539,604,625,629,642,651,660,665,706,711
4,1,106,260,300,400,529,600,625,631,635,654,662,664,706,710


In [5]:
# def POOL_preprocess(df):
#     '''
#     input the original dataframe, output the dataframe after preprocessing,
#     change the numerical columns to categorical columns by qcut and cut
#     then apply label encoding to all columns
    
#     '''
#     df_ = df.copy()
#     CAT = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'gender', 'native-country', 'income']
#     NUM = ['age', 'fnlwgt', 'educational-num', 'capital-gain', 'capital-loss', 'hours-per-week']
#     # qcut on numerical columns
#     for column in NUM:
#         if column in ['educational-num','capital-gain','capital-loss','hours-per-week']:
#             df_[column] = pd.cut(df_[column], 100)
#         else:
#             df_[column] = pd.cut(df_[column], 100)
#     # make income column binary
#     # df['income'] = df['income'].apply(lambda x: 1 if x == '>50K' else 0)
    
#     build_dict(df_, True)

#     # convert df into unique index
#     import json
#     index_dict = open('dict.json', 'r')
#     index_dict = json.load(index_dict)
#     for column in CAT:
#         df[column] = df[column].replace(index_dict[column])
#     from tqdm import tqdm
#     for column in NUM:
        
#         # 這個時間太久了，要改
#         for row in tqdm(df[column].index):
#             for i, item in enumerate(index_dict[column]):
#                 if item == 'UNSEEN':
#                     break
#                 if is_in_interval(item, df[column][row]):
#                     df[column][row] = item
#                     break
#                 # if not in the interval, assign the unseen index
#                 df[column][row] = index_dict[column]['UNSEEN']

    
#     return df
# tmp = POOL_preprocess(main_df)
# tmp.head()

In [6]:
train_size = 4*48842//5
test_size = 48842//5
train_pool = main_df[test_size:]
test_pool = main_df[:test_size]
print('total data num:' , main_df.shape[0])
print('trian data num:' , train_pool.shape[0])
print('test data num:' , test_pool.shape[0])

total data num: 48842
trian data num: 39074
test data num: 9768


In [7]:
# notations
#   node: number of all nodes = L + S + C + F
#   L: number of lable nodes
#   S: number of sample nodes
#   C: number of catagory nodes
#   F: number of field(column) nodes
#   hidden: number of hidden representation

# data size = (node, hidden)
# mask size = (node, node - L) without lable nodes
#             for each node, real mask = cat[mask,(node,L)] = (node, node)
#             cannot see it's label node

# use nn.transformerDecoder(data,mask) to get the output
# use the above output as input of MLP to predict the lable   

In [8]:
class HGNN_DataSet():
    def __init__(self,
                 data_df : pd.DataFrame,
                 split_ratio : float ,
                 label_column : str,
                 ):
        test_size = math.ceil(data_df.shape[0] * (1-split_ratio))
        train_pool = data_df[test_size:]
        test_pool = data_df[:test_size]
        print('total data num:' , data_df.shape[0])
        print('trian data num:' , train_pool.shape[0])
        print('test data num:' , test_pool.shape[0])
        
        # to-dos:
        # train
        #   
        N_BINS = 100
        TARGET_POOL, CT, NUM_vs_CAT = POOL_preprocess(train_pool, N_BINS = N_BINS)
        # TEST_POOL = POOL_preprocess(test_pool)
        LABEL_COLUMN = label_column

        # cut feature and lable
        FEATURE_POOL = TARGET_POOL.drop(LABEL_COLUMN, axis=1)
        LABEL_POOL = TARGET_POOL[LABEL_COLUMN]

        # trasform label into one-hot
        from sklearn.preprocessing import OneHotEncoder
        enc = OneHotEncoder()
        LABEL_POOL = enc.fit_transform(LABEL_POOL.values.reshape(-1,1)).toarray()

        # L: number of lable nodes
        # the last node of Lable nodes is served as unknown lable node
        L = LABEL_POOL.shape[1] + 1

        # S: number of sample nodes
        S = FEATURE_POOL.shape[0] + 1
        # the last node of sample nodes is served as infering node
        
        # F: number of field(column) nodes
        F = FEATURE_POOL.shape[1]

        # C: number of catagory nodes
        C = FEATURE_POOL.to_numpy().max()  # total_unique_labels, which includes unseen nodes
        C_POOL = range(int(C))

        nodes_num = {'L':L, 'S':S, 'C':C, 'F':F}
        print('node_nums', nodes_num)
        print('total', L+S+C+F, 'nodes')
        
        # get true training sample
        self.labe_to_index = {}
        tmp_pool = TARGET_POOL.copy().reset_index()
        for label in tmp_pool['income'].unique():
            self.labe_to_index[label] = (tmp_pool[tmp_pool['income'] == label].index).tolist()
        
        self.TARGET_POOL = TARGET_POOL
        # self.TEST_POOL = TEST_POOL
        self.LABEL_COLUMN = LABEL_COLUMN
        self.FEATURE_POOL = FEATURE_POOL
        self.LABEL_POOL = LABEL_POOL
        self.C_POOL = C_POOL   
        self.nodes_num = nodes_num
        self.NUM_vs_CAT = NUM_vs_CAT
        self.CT = CT
        self.N_BINS = N_BINS

        
        self.make_input_tensor()
        # self.get_sample(10)        
        self.make_mask_all()
        
        # self.make_mask()
        
        
    def make_mask(self,
                  sample_indices: Optional[list] = None,
                ):
        L, S, C, F = self.nodes_num['L'], self.nodes_num['S'], self.nodes_num['C'], self.nodes_num['F']

        sample_size = len(sample_indices)
        masked_POOL = self.TARGET_POOL.iloc[sample_indices]
        # caculate masking
        masks = {}

        # label to sample 
        tmp = torch.zeros([math.ceil(sample_size/8) * 8, math.ceil(L/8) * 8], dtype=torch.float).to(DEVICE)
        label_ids = self.TARGET_POOL[self.LABEL_COLUMN].unique()
        for i, value_df in enumerate(masked_POOL[self.LABEL_COLUMN]):
            for j, value_label in enumerate(label_ids):
                if value_label == value_df:
                    tmp[i][j] = 1
                    break
        masks['L2S'] = Tensor.contiguous(tmp)

        # sample to catagory
        tmp = torch.zeros([math.ceil(C/8) * 8, math.ceil(sample_size/8) * 8], dtype=torch.float).to(DEVICE)
        tmp_df = masked_POOL.drop(self.LABEL_COLUMN, axis=1)
        for i, value_df in enumerate(tmp_df.values):
            for j, value in enumerate(value_df):
                tmp[value][i] = 1
        masks['S2C'] = Tensor.contiguous(tmp)

        # catagory to field
        masks['C2F'] = Tensor.contiguous(self.MASKS_FULL['C2F'])
        
        self.MASKS = masks
        self.nodes_num['K'] = sample_size
        
    def make_mask_all(self,
                  sample_indices: Optional[torch.tensor] = None,
                ):
        L, S, C, F = self.nodes_num['L'], self.nodes_num['S'], self.nodes_num['C'], self.nodes_num['F']
        # caculate masking
        masks = {}
        
        # label to sample 
        tmp = torch.zeros([math.ceil(S/8) * 8, math.ceil(L/8) * 8], dtype=torch.float)
        label_ids = self.TARGET_POOL[self.LABEL_COLUMN].unique()
        for i, value_df in enumerate(self.TARGET_POOL[self.LABEL_COLUMN]):
            for j, value_label in enumerate(label_ids):
                if value_label == value_df:
                    tmp[i][j] = 1
                    break
        masks['L2S'] = tmp

        # sample to catagory
        tmp = torch.zeros([math.ceil(C/8) * 8, math.ceil(S/8) * 8], dtype=torch.float)
        tmp_df = self.TARGET_POOL.drop(self.LABEL_COLUMN, axis=1)
        for i, value_df in enumerate(tmp_df.values):
            for j, value in enumerate(value_df):
                tmp[value][i] = 1
        masks['S2C'] = tmp

        # catagory to field
        tmp = torch.zeros([math.ceil(F/8) * 8, math.ceil(C/8) * 8], dtype=torch.float)
        unique_items = [(self.TARGET_POOL[column].unique()) for column in (self.TARGET_POOL.columns)]
        for i in range(F):
            for j in (unique_items[i]):
                tmp[i][j] = 1
        masks['C2F'] = tmp
        
        
        self.MASKS = masks
        self.MASKS_FULL = masks
    
    def make_mask_all_infer(self,
                            # index of infering nodes in test_pool
                            infering_node_index: Optional[int] = None , 
                            
                        ):
        '''
        currently, only one infering node is supported
        not implemented yet
        unseen node
        '''
        L, S, C, F = self.nodes_num['L'], self.nodes_num['S'], self.nodes_num['C'], self.nodes_num['F']
        # caculate masking
        masks = {}
        
        # label to sample 
        tmp = torch.zeros([math.ceil(S/8) * 8, math.ceil(L/8) * 8], dtype=torch.float)
        label_ids = self.TARGET_POOL[self.LABEL_COLUMN].unique()
        for i, value_df in enumerate(self.TARGET_POOL[self.LABEL_COLUMN]):
            for j, value_label in enumerate(label_ids):
                if value_label == value_df:
                    tmp[i][j] = 1
                    break
        # infering node is located at the last S 
        tmp[S - 1][label_ids + 1] = 1
        masks['L2S'] = tmp

        # sample to catagory
        tmp = torch.zeros([math.ceil(C/8) * 8, math.ceil(S/8) * 8], dtype=torch.float)
        tmp_df = self.TARGET_POOL.drop(self.LABEL_COLUMN, axis=1)
        for i, value_df in enumerate(tmp_df.values):
            for j, value in enumerate(value_df):
                tmp[value][i] = 1
        # infering node 
        # get the infering node's catagory
        tmp_df = self.TEST_POOL.drop(self.LABEL_COLUMN, axis=1)[infering_node_index]
        for i, value_df in enumerate(tmp_df.values):
            for j, value in enumerate(value_df):
                # # out of range 為解決
                # 目前問題：無法檢測到unseen catagory
                # 因為沒看過的會被lable成有看過的，只有最後一個會是unseen
                if value > C:
                    tmp[value][i] = 1
        
        masks['S2C'] = tmp

        # catagory to field
        tmp = torch.zeros([math.ceil(F/8) * 8, math.ceil(C/8) * 8], dtype=torch.float)
        unique_items = [(self.TARGET_POOL[column].unique()) for column in (self.TARGET_POOL.columns)]
        for i in range(F):
            for j in (unique_items[i]):
                tmp[i][j] = 1
        masks['C2F'] = tmp
        
        
        self.MASKS = masks
        self.MASKS_FULL = masks
    
    def make_input_tensor(self):
        # make input tensor
        L, S, C, F = self.nodes_num['L'], self.nodes_num['S'], self.nodes_num['C'], self.nodes_num['F']
        # L
        L_input = torch.tensor([range(L)]).to(DEVICE).reshape(-1,1)
        print('L_input', L_input.type(), L_input.shape)
        # S (normalized)
        normalized_features = Fun.normalize(torch.tensor(self.FEATURE_POOL.values, dtype = torch.float).to(DEVICE), p=2, dim=0)
        S_input = torch.cat([normalized_features, torch.tensor([[0]*F]).to(DEVICE)],dim = 0).to(DEVICE) # add infering node
        print('S_input', S_input.type(), S_input.shape)
        # C 
        C_input = torch.tensor([self.C_POOL]).to(DEVICE).reshape(-1,1)
        print('C_input', C_input.type(), C_input.shape)
        # F 
        F_input = torch.tensor([range(F)]).to(DEVICE).reshape(-1,1)
        print('F_input', F_input.type(), F_input.shape)
        # 
        self.INPUTS = (L_input, S_input, C_input, F_input)
        self.INPUT_DIMS = (L_input.size(1), S_input.size(1), C_input.size(1), F_input.size(1))
        
    def sample_with_distrubution(self, sample_size):
        # sample with distrubution
        """
        currently, only support binary label
        """
        indices = []
        for i in range(sample_size):
            label = random.choice(list(self.labe_to_index.keys()))
            indices.append(torch.tensor(random.choice(self.labe_to_index[label])))
            # print(label)
        return indices
            
        
    def get_sample(self, sample_size, inculde = []):
        # get K samples from S
        # return sample node mask
        S = self.nodes_num['S']
        
        # inculde specific nodes (e.g. query nodes), while remaining sample_size
        # -1 is infering node
        sample_indices = self.sample_with_distrubution(sample_size - len(inculde))
        if inculde is not []:
            while inculde in sample_indices:
                sample_indices = self.sample_with_distrubution(sample_size - len(inculde))
            # add inculde nodes into sample_indices
            for node in inculde:
                sample_indices.append(torch.tensor(node))
            sample_indices = sorted(sample_indices)
        # update mask
        self.make_mask(sample_indices)
        
        # update input tensor
        L_input, S_input, C_input, F_input = self.INPUTS
        S_input_masked = torch.index_select(S_input, 0, torch.tensor([int(x) for x in sample_indices]).to(DEVICE))
        self.MASKED_INPUTS = (L_input, S_input_masked, C_input, F_input)   
        return sample_indices
            
Train_data = HGNN_DataSet( main_df, 0.8, 'income')


total data num: 48842
trian data num: 39073
test data num: 9769


node_nums {'L': 3, 'S': 39074, 'C': 708, 'F': 14}
total 39799 nodes
L_input torch.cuda.LongTensor torch.Size([3, 1])
S_input torch.cuda.FloatTensor torch.Size([39074, 14])
C_input torch.cuda.LongTensor torch.Size([708, 1])
F_input torch.cuda.LongTensor torch.Size([14, 1])


In [9]:
Train_data.INPUTS[1]

tensor([[0.0123, 0.0050, 0.0043,  ..., 0.0051, 0.0051, 0.0051],
        [0.0031, 0.0056, 0.0058,  ..., 0.0050, 0.0051, 0.0050],
        [0.0038, 0.0051, 0.0045,  ..., 0.0051, 0.0051, 0.0051],
        ...,
        [0.0009, 0.0051, 0.0049,  ..., 0.0051, 0.0051, 0.0051],
        [0.0069, 0.0053, 0.0049,  ..., 0.0051, 0.0051, 0.0051],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')

In [10]:
Train_data.get_sample(10, inculde=[10])


[tensor(10),
 tensor(7742),
 tensor(8797),
 tensor(20439),
 tensor(23942),
 tensor(25196),
 tensor(32721),
 tensor(34924),
 tensor(37391),
 tensor(37395)]

In [11]:
from torch import Tensor
from typing import Optional, Any, Union, Callable

class CustomTransformerDecoderLayer(nn.TransformerDecoderLayer):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='relu'):
        super().__init__(d_model, nhead, dim_feedforward, dropout, activation)
        # remove defined modules
        delattr(self, 'self_attn')
        delattr(self, 'norm1')
        delattr(self, 'dropout1')
    
    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        x = tgt
        if self.norm_first:
            # x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
            x = x + self._mha_block(self.norm2(x), memory, memory_mask)
            x = x + self._ff_block(self.norm3(x))
        else:
            # x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
            x = self.norm2(x + self._mha_block(x, memory, memory_mask))
            # x =  x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask)
            # x = self.norm3(x + self._ff_block(x))

        return x
    def _mha_block(self, x: Tensor, mem: Tensor,
                   attn_mask: Optional[Tensor],) -> Tensor:
        x = xops.memory_efficient_attention(x, mem, mem, attn_mask)
        return self.dropout2(x)


In [12]:
# baic transformer decoder model
import torch
import torch.nn as nn
import torch.nn.functional as Fun
from tqdm import trange

class TransformerDecoderModel(nn.Module):
    def __init__(self, 
                 target_dataset, 
                 num_layers, 
                 embedding_dim,  
                 subgraph_masked: Optional[bool] = False,
                 K : Optional[int] = 10,
                 ):
        super(TransformerDecoderModel, self).__init__()

        L_dim, S_dim, C_dim, F_dim = target_dataset.INPUT_DIMS
        L, S, C, F = target_dataset.nodes_num['L'], target_dataset.nodes_num['S'], target_dataset.nodes_num['C'], target_dataset.nodes_num['F']
        num_NUM , num_CAT = target_dataset.NUM_vs_CAT
        
        # check input dims
        if num_CAT + num_NUM != S_dim:
            raise ValueError('num_CAT + num_NUM != number of columns (S_dim)   {} + {} != {}'.format(num_CAT, num_NUM, S_dim))
        
        # 目前b卡在embedding的怎麼用
        # Catagory_embedding => 數值類Qcut後用linear來做embedding, 類別用nn.Embedding
        catagories = C - num_NUM * target_dataset.N_BINS

        
        # nn.Embedding( number of possible catagories, embedding_dim, )
        # nn.Linear( number of input dimantion, embedding_dim, )

        self.Lable_embedding = nn.Embedding(L, embedding_dim, dtype=torch.float)
        
        self.Sample_embedding_num = nn.Linear(num_NUM, embedding_dim, dtype=torch.float)
        # use MLP projector to project sample feature from 8 dim to 128 dim
        # self.Sample_embedding_cat = nn.Linear(num_CAT, embedding_dim, dtype=torch.float)
        self.Sample_embedding_cat = nn.Sequential(
                                        nn.Linear(num_CAT, 64),  
                                        nn.ReLU(),        
                                        nn.Linear(64, embedding_dim) 
                                    )
                                    
        
        self.Catagory_embedding_num = nn.Linear(C_dim, embedding_dim, dtype=torch.float)
        self.Catagory_embedding_cat = nn.Embedding(catagories, embedding_dim, dtype=torch.float)
        
        self.Field_embedding = nn.Embedding(F, embedding_dim, dtype=torch.float)
        
        self.transformer_decoder = nn.TransformerDecoder(
            CustomTransformerDecoderLayer(embedding_dim,  nhead = 1 ),
            num_layers
        )
        
        # downstream task
        self.MLP = nn.Sequential(
            nn.Linear(embedding_dim, 2),
            # nn.Softmax(dim=2),
            # nn.Sigmoid(),
        )
        
        self.subgraph_masked = subgraph_masked
        if subgraph_masked: 
            self.K = K
        else:
            # init mask
            target_dataset.make_mask_all()
        
        self.tmpmask_L2S = target_dataset.MASKS['L2S'].clone().to(DEVICE)


    def forward(self, 
                target_dataset: HGNN_DataSet, 
                ):
        L, S, C, F = target_dataset.nodes_num['L'], target_dataset.nodes_num['S'], target_dataset.nodes_num['C'], target_dataset.nodes_num['F']
        num_NUM, num_CAT = target_dataset.NUM_vs_CAT
        N_bins = target_dataset.N_BINS
        
        if self.subgraph_masked: 
            target_dataset.get_sample(self.K)
            L_input, S_input, C_input, F_input = target_dataset.MASKED_INPUTS
            masks = target_dataset.MASKS
            K = target_dataset.nodes_num['K']
            S_ = K
        else:
            target_dataset.make_input_tensor()
            L_input, S_input, C_input, F_input = target_dataset.INPUTS
            target_dataset.make_mask_all()
            masks = target_dataset.MASKS
            S_ = S
            
        # for S and C, we use two different embedding methods, for CAT and NUM, respectively
        # Squeeze for making batch dimantion
        L_embedded = self.Lable_embedding(L_input.long()).squeeze(1).unsqueeze(0).float()
        
        S_embedded_num = self.Sample_embedding_num(S_input[:,:num_NUM]).unsqueeze(0).float()
        S_embedded_cat = self.Sample_embedding_cat(S_input[:,num_NUM:].float()).unsqueeze(0).float()
        S_embedded = S_embedded_num + S_embedded_cat

        C_embedded_num = self.Catagory_embedding_num(C_input[:num_NUM * N_bins].float()).unsqueeze(0).float()
        C_embedded_cat = self.Catagory_embedding_cat(C_input[num_NUM * N_bins:].squeeze(1).long() - num_NUM*N_bins).unsqueeze(0).float()
        C_embedded = torch.cat([C_embedded_num, C_embedded_cat], dim = 1)
        
        F_embedded = self.Field_embedding(F_input.long()).squeeze(1).unsqueeze(0).float()
        
        # print(L_embedded.shape, S_embedded.shape, C_embedded.shape, F_embedded.shape)
        

        for mask in masks.keys():
            masks[mask] = masks[mask].to(DEVICE)
        
        # propagate steps: L→S→C→F
        #                  L←S←C←
        # more steps more menory usage
        PROPAGATE_STEPS = 2
        for i in range(PROPAGATE_STEPS):
            S_embedded = self.transformer_decoder(S_embedded,L_embedded, 
                                                memory_mask = self.tmpmask_L2S[:S_,:L]) 
            C_embedded = self.transformer_decoder(C_embedded,S_embedded,
                                                memory_mask = masks['S2C'][:C,:S_])
            F_embedded = self.transformer_decoder(F_embedded,C_embedded,
                                                memory_mask = masks['C2F'][:F,:C])
            C_embedded = self.transformer_decoder(C_embedded,F_embedded,
                                                memory_mask = Tensor.contiguous(masks['C2F'].transpose(0, 1))[:C,:F])
            S_embedded = self.transformer_decoder(S_embedded,C_embedded,
                                                memory_mask = Tensor.contiguous(masks['S2C'].transpose(0, 1))[:S_,:C])
            L_embedded = self.transformer_decoder(L_embedded,S_embedded, 
                                                memory_mask = Tensor.contiguous(self.tmpmask_L2S.transpose(0, 1))[:L,:S_])
        
        
        output = self.MLP(S_embedded)[0]
        return output

# 測試模型
num_layers = 1  # TransformerDecoder 的層數
embedding_dim = 2*128  # 嵌入維度
hidden_dim = 64  

print('input_dims', Train_data.INPUT_DIMS)
model = TransformerDecoderModel(Train_data, num_layers, embedding_dim, subgraph_masked = True, K = 10).to(DEVICE)
outputs = model(Train_data)

print("模型輸出的大小:", outputs.shape)
output_label = torch.argmax(outputs, dim=1)
output_label



input_dims (1, 14, 1, 1)


模型輸出的大小: torch.Size([10, 2])


tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')

In [30]:
# training
from torch import autograd
from torcheval.metrics.aggregation.auc import AUC
from torcheval.metrics import BinaryAUROC


def train(model, datset):
    LABEL_POOL = datset.LABEL_POOL
    
    tmp_log = []
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.00001)
    epochs = 100
    for epoch in range(epochs):
        
        # train
        model.train()
        Original_L2S_mask = datset.MASKS['L2S'].clone()
        # logs
        loss_log = []
        AUC_metric = BinaryAUROC().to(DEVICE)
        
        for index in trange(datset.nodes_num['S']-1): # query through all sample nodes (not infering node)
            # for all query, input = sample K + 1 query smaple
            sample_indices = datset.get_sample(10, inculde = [index])
            
            
            # modify the mask to mask out the queries node's edge to it's label node
            L = datset.nodes_num['L']
            query_index = sample_indices.index(index) # query_index: index of query node in sample_indices
            datset.tmpmask_L2S = Original_L2S_mask.clone()
            datset.tmpmask_L2S[query_index] = 0
            datset.tmpmask_L2S[query_index][L-1] = 1 # make it as unseen label
            
            optimizer.zero_grad()
            # outputs = model(datset)[query_index]
            outputs = model(datset)
            # for trainning, only the query node's output is used

            # caculate loss
            if model.subgraph_masked:
                # get the real label fo query node
                LABEL_POOL_ = LABEL_POOL[index]
            else:
                LABEL_POOL_ = LABEL_POOL
            
            
            # batch_loss = criterion(outputs, torch.tensor(LABEL_POOL_).to(DEVICE))
            batch_loss = criterion(outputs, torch.tensor(LABEL_POOL[sample_indices]).to(DEVICE))
            loss_log.append(batch_loss.item())
            
            # backpropagation
            batch_loss.backward()
            optimizer.step()

            # TRUE = (torch.argmax(torch.tensor(LABEL_POOL_)))
            TRUE = (torch.argmax(torch.tensor(LABEL_POOL[sample_indices]), dim=1))
            # print(TRUE)
            
            outputs = outputs.softmax(dim=0)
            # y_pred_first = outputs[1].to(DEVICE).int()
            y_pred_first = [z[1] for z in outputs]
            
            # tmp_log.append(y_pred_first)
            # AUC_metric.update(torch.Tensor([y_pred_first]),torch.Tensor([TRUE]))
            AUC_metric.update(torch.Tensor(y_pred_first),torch.Tensor(TRUE))
            torch.cuda.empty_cache()
            
            if epoch >= 200:
                break
            
        if sum(tmp_log)==0:
            print('all zero')
        # AUC_metric.update(torch.Tensor(AUC_true_log),torch.Tensor(AUC_pred_log))
        # tmp_log = []
        epoch_loss = sum(loss_log)/(datset.nodes_num['S']-1)
        epoch_AUC = float(AUC_metric.compute())
        AUC_metric.reset()
        # break
        del loss_log, AUC_metric

        print(f"Epoch{epoch+1}/{epochs} | Loss: {epoch_loss} | AUC: {epoch_AUC}")
model = TransformerDecoderModel(Train_data, num_layers, embedding_dim, subgraph_masked = True).to(DEVICE)
train(model, Train_data)

  0%|          | 0/39073 [00:00<?, ?it/s]

  3%|▎         | 1302/39073 [00:12<06:06, 103.02it/s]


KeyboardInterrupt: 

In [None]:
# self.Sample_embedding = nn.Linear(S_dim, embedding_dim, dtype=torch.half)
# self.Catagory_embedding = nn.Linear(C_dim, embedding_dim, dtype=torch.half)
# S_embedded = self.Sample_embedding(S_input.half()).unsqueeze(0)
# C_embedded = self.Catagory_embedding(C_input.half()).unsqueeze(0)
# xops.memory_efficient_attention(C_embedded, S_embedded, S_embedded,attn_mask)[0]