In [1]:
import os
import json
import torch
import numpy as np
import pandas as pd
import pickle as pkl
from torch import nn
from tqdm import tqdm
from typing import List
import multiprocessing as mp
import torch.nn.functional as F
from collections import defaultdict
from random import sample, random, seed
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# Expose only the GPU with index "1" to this process (the index is hard-coded).
# NOTE(review): presumably this must run before CUDA is first initialised in
# this kernel to have any effect -- confirm when moving to another machine.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"

In [2]:
# Seed every random number generator this notebook draws from (PyTorch CPU and
# CUDA, NumPy, and the stdlib `random` module) with the same value.
set_seed = 42
for _seed_rng in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    np.random.seed,
    seed,
):
    _seed_rng(set_seed)

In [3]:
# --- Model ---
DIM = 50             # embedding width passed to TransE_Model
NORM = 2             # p of the p-norm used for the TransE distance

# --- Optimisation ---
MARGIN = 1           # margin handed to nn.MarginRankingLoss
BATCH_SIZE = 512     # batch size of the train / test dataloaders

# --- Evaluation ---
TOP_K = 100          # k of the hits@k metric computed by valid()
CORRUPTED_NUM = 100  # not referenced by any code visible in this notebook

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [4]:
# with open('./data/darpa_datasets/data/all_event_5d_4.4_transE.pkl','rb') as f:
#     ls = pkl.load(f)

In [5]:
# temp = []
# for i,(subject,relation,object_) in tqdm(enumerate(ls),total=len(ls)): 
#     if subject['n_attrbiute']['cmdline'] != 'None' and relation['relation'] != 'None':
#         for key,value in object_['n_attrbiute'].items(): 
#             if value == 'None':
#                 break
#         else:
#             '''
#             The keys 'type', 'pid', 'localAddress', 'localPort' and 'ipProtocol' of object_ are excluded (per Ying-Ren).
#             The values of 'remoteAddress' and 'remotePort' are concatenated (per Ying-Ren).
#             '''
#             filtered_object = list(filter(lambda x: x[0] not in ['type','pid','localAddress', 'localPort', 'ipProtocol'],object_['n_attrbiute'].items()))
#             object_data = ':'.join([str(y) for x,y in filtered_object])
#             temp.append([subject['n_attrbiute']['cmdline'], relation['relation'], object_data, relation['label']])

In [6]:
# df = pd.DataFrame(temp,columns=['Subject','relation','Object','malicious_or_benign'])
# train_df, valid_df = train_test_split(df, test_size=0.2,random_state=42,shuffle=True)
# valid_df, test_df = train_test_split(valid_df, test_size=0.5,random_state=42,shuffle=True)

In [7]:
# train_df.to_csv('./data/darpa_datasets/train_df.csv',index=False)
# valid_df.to_csv('./data/darpa_datasets/valid_df.csv',index=False)
# test_df.to_csv('./data/darpa_datasets/test_df.csv',index=False)

In [8]:
# Read the three pre-computed splits (train / validation / test) from disk.
train_df, valid_df, test_df = map(
    pd.read_csv,
    (
        './data/darpa_datasets/train_df.csv',
        './data/darpa_datasets/valid_df.csv',
        './data/darpa_datasets/test_df.csv',
    ),
)

In [9]:
# Entity vocabulary: every distinct Subject or Object value of the training
# split, sorted, then numbered 0..N-1 in that order.
data = sorted(set(train_df['Subject']).union(train_df['Object']))
subj_obj_dic = {entity: index for index, entity in enumerate(data)}

# Relation vocabulary, built the same way from the training split only.
data = sorted(set(train_df['relation']))
relation_dic = {rel: index for index, rel in enumerate(data)}

In [10]:
# Number of distinct entities in the training vocabulary. The encoding cell
# below reuses this same number as the id for entities missing from it.
print(len(subj_obj_dic))

63285


In [11]:
# Replace the raw Subject / Object / relation values of every split with the
# integer ids from the vocabularies. Values that are not in a vocabulary get
# the id len(vocabulary), i.e. one past the last known id.
#
# The frames are modified in place, so running this cell a second time used to
# look up the *ids* in the vocabularies and silently turn every value into the
# "unknown" id. Each frame is therefore flagged once encoded and skipped on a
# re-run; frames freshly loaded from CSV carry no flag and are encoded again.
_ENCODED_FLAG = 'ids_encoded'
_unknown_entity = len(subj_obj_dic)
_unknown_relation = len(relation_dic)

for _frame in (train_df, valid_df, test_df):
    if _frame.attrs.get(_ENCODED_FLAG):
        continue
    _frame['Subject'] = _frame['Subject'].map(lambda x: subj_obj_dic.get(x, _unknown_entity))
    _frame['Object'] = _frame['Object'].map(lambda x: subj_obj_dic.get(x, _unknown_entity))
    _frame['relation'] = _frame['relation'].map(lambda x: relation_dic.get(x, _unknown_relation))
    _frame.attrs[_ENCODED_FLAG] = True

In [12]:
class TransE_Model(nn.Module):
    """TransE scoring model.

    Every entity and every relation owns one embedding vector. A triple
    (subject, relation, object) is scored by the p-norm of
    ``entity[subject] + relation[relation] - entity[object]``; smaller means
    more plausible.

    Args:
        entity_num: number of known entities. One extra row (index
            ``entity_num``) is allocated as ``padding_idx``.
        relation_num: number of known relations. One extra row (index
            ``relation_num``) is allocated as ``padding_idx``.
        embedding_dim: width of every embedding vector.
        NORM: p of the p-norm used for both normalisation and scoring.
    """
    def __init__(self,entity_num:int,relation_num:int,embedding_dim:int,NORM:int=1):
        super().__init__()
        self.NORM = NORM
        self.entity_num = entity_num
        self.relation_num = relation_num
        self.embedding_dim = embedding_dim
        self.entity_embedding = self.generate_embedding(self.entity_num)
        self.relation_embedding = self.generate_embedding(self.relation_num)
        # Relation vectors are normalised once here; entity vectors are left
        # to the caller (the training loop re-normalises them per batch).
        self.relation_embedding = self.normalize_embedding(self.relation_embedding)

    def generate_embedding(self,num):
        """Create a ``(num + 1, embedding_dim)`` table initialised uniformly in
        ``[-6/sqrt(dim), 6/sqrt(dim)]``; row ``num`` is the padding row."""
        emb = nn.Embedding(num_embeddings=num+1,embedding_dim=self.embedding_dim,padding_idx=num)
        bound = 6/(self.embedding_dim**0.5)
        # no_grad instead of `.data`: same values, but stays visible to autograd's
        # in-place bookkeeping. Note the padding row is overwritten as well.
        with torch.no_grad():
            emb.weight.uniform_(-bound,bound)
        return emb

    def normalize_embedding(self,emb):
        """Scale every row of ``emb`` to unit p-norm in place and return ``emb``.

        Rows whose norm is (numerically) zero are left untouched instead of
        being turned into NaN by a division by zero.
        """
        with torch.no_grad():
            row_norms = emb.weight.norm(p=self.NORM,dim=1,keepdim=True)
            emb.weight.div_(row_norms.clamp_min(1e-12))
        return emb

    def forward(self,subject,relation,object_):
        """Return the TransE distance for index tensors of identical shape.

        The norm is taken over the last (embedding) axis, so the result has
        the same shape as the index tensors, whatever their rank.
        """
        translated = self.entity_embedding(subject) + self.relation_embedding(relation)
        residual = translated - self.entity_embedding(object_)
        return residual.norm(p=self.NORM,dim=-1)

In [13]:
class Custom_Dataset(Dataset):
    """Map-style dataset that serves each row of a DataFrame as a plain tuple."""

    def __init__(self,df:pd.DataFrame):
        super().__init__()
        # Materialise all rows once so that indexing is a simple list lookup.
        self.data = list(map(tuple, df.values))

    def __getitem__(self,idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

class Generate_Corrupted_Triple():
    """Negative sampler producing corrupted triples for TransE training.

    For every object the set of subjects that never co-occur with it in ``df``
    is precomputed (and vice versa for every subject). Calling the instance
    replaces, independently, the subject and the object of every triple in a
    batch by a random entry from the matching pool.

    This class used to be defined twice in this cell, the second definition
    silently shadowing the first; only the precomputing variant (the one that
    was actually in effect) is kept.

    Args:
        df: frame with at least the columns ``'Subject'`` and ``'Object'``.
    """
    def __init__(self,df:pd.DataFrame):
        self.obj_dic, self.subj_dic = self.generate_dic(df)

    def generate_dic(self,df):
        """Build the candidate pools.

        Returns:
            ``(obj_dic, subj_dic)`` where ``obj_dic[o]`` is a tuple of the
            subjects never seen with object ``o`` and ``subj_dic[s]`` a tuple
            of the objects never seen with subject ``s``.
        """
        all_objects = set(df['Object'])
        all_subjects = set(df['Subject'])
        s_2_o_dic = defaultdict(set)
        o_2_s_dic = defaultdict(set)
        # Address the columns by name so extra or re-ordered columns are harmless.
        for s,o in zip(df['Subject'],df['Object']):
            s_2_o_dic[s].add(o)
            o_2_s_dic[o].add(s)
        # Pools are stored as tuples: random.sample() needs a sequence (sets are
        # rejected since Python 3.11), and converting once here avoids a
        # set -> sequence conversion on every single draw.
        obj_dic = defaultdict(tuple)
        subj_dic = defaultdict(tuple)
        for key,value in o_2_s_dic.items():
            obj_dic[key] = tuple(all_subjects - value)
        for key,value in s_2_o_dic.items():
            subj_dic[key] = tuple(all_objects - value)
        return obj_dic,subj_dic

    @staticmethod
    def _draw(candidates,anchor):
        """Pick one element of ``candidates``; raise ValueError if it is empty."""
        if not candidates:
            raise ValueError(f'no corruption candidate available for entity {anchor!r}')
        return sample(candidates,k=1)[0]

    def __call__(self,subject,relation,object_):
        """Corrupt a batch of triples.

        Args:
            subject, relation, object_: 1-D index tensors of equal length.

        Returns:
            ``[new_subject, relation, object_], [subject, relation, new_object]``
            where ``new_subject`` / ``new_object`` are freshly drawn CPU tensors
            and the other entries are the tensors that were passed in.
        """
        new_subject = torch.tensor([self._draw(self.obj_dic[obj],obj) for obj in object_.cpu().numpy()])
        new_object = torch.tensor([self._draw(self.subj_dic[subj],subj) for subj in subject.cpu().numpy()])
        return [new_subject,relation,object_],[subject,relation,new_object]

In [14]:
def To_Device(batch,device):
    """Return a new list with every element of ``batch`` moved via ``.to(device)``."""
    moved = []
    for item in batch:
        moved.append(item.to(device))
    return moved

In [15]:
# Size the embedding tables from the vocabularies rather than from the encoded
# frame: the encoding step uses len(subj_obj_dic) / len(relation_dic) as the id
# of unknown values, and TransE_Model reserves exactly index `entity_num` /
# `relation_num` as its padding row, so the two must always agree.
model = TransE_Model(
    entity_num=len(subj_obj_dic),
    relation_num=len(relation_dic),
    embedding_dim=DIM,NORM=NORM
).to(device)
optim_fn = torch.optim.Adam(model.parameters())
loss_fn = nn.MarginRankingLoss(margin=MARGIN).to(device)
# Negative sampler; train() uses this module-level object.
GCT = Generate_Corrupted_Triple(train_df)

In [16]:
# train, test = train_test_split(df, test_size=0.2)
train_dataset = Custom_Dataset(train_df)
train_dataloader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True)
valid_dataset = Custom_Dataset(valid_df)
valid_dataloader = DataLoader(valid_dataset,batch_size=16,shuffle=False)
test_dataset = Custom_Dataset(test_df)
test_dataloader = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False)

In [17]:
def train(model,train_dataloader,loss_fn,optim_fn,device,corrupt_fn=None):
    """Run one training epoch and return the (in-place updated) model.

    For every batch the entity embeddings are re-normalised, one
    subject-corrupted and one object-corrupted copy of the batch are drawn, and
    the margin ranking loss of the true triples against each corrupted copy is
    averaged.

    Args:
        model: TransE model exposing ``entity_embedding`` and
            ``normalize_embedding``.
        train_dataloader: yields batches whose first three entries are the
            subject, relation and object index tensors.
        loss_fn: ranking loss called as ``loss_fn(positive, negative, target)``.
        optim_fn: optimizer over ``model.parameters()``.
        device: device the batches are moved to.
        corrupt_fn: callable ``(subject, relation, object_) -> (batch_1, batch_2)``
            producing the two corrupted batches. Defaults to the module-level
            ``GCT`` sampler, which is what this function always used before.

    Returns:
        The same ``model`` object.
    """
    if corrupt_fn is None:
        corrupt_fn = GCT
    loss_total = 0.0
    bar = tqdm(enumerate(train_dataloader),total=len(train_dataloader))
    # target = -1 asks the loss to rank the first input (true-triple distance)
    # below the second one (corrupted-triple distance).
    target = torch.tensor([-1], dtype=torch.long, device=device)
    for idx,batch in bar:
        model.entity_embedding = model.normalize_embedding(model.entity_embedding)
        triple = batch[:3]
        corrupted_batch_1,corrupted_batch_2 = corrupt_fn(*triple)
        # Score the true triples once and reuse the result for both loss terms;
        # the gradient is the same as with two separate forward passes.
        positive_score = model(*To_Device(triple,device))
        loss = 0.5*loss_fn(
            positive_score,model(*To_Device(corrupted_batch_1,device)),target
        ) + 0.5*loss_fn(
            positive_score,model(*To_Device(corrupted_batch_2,device)),target
        )
        optim_fn.zero_grad()
        loss.backward()
        optim_fn.step()
        # Running sum instead of np.mean over an ever-growing list.
        loss_total += loss.item()
        bar.set_description(f'loss: {loss_total/(idx+1):8.7f}')
    return model

In [18]:
def valid(model,dataloader,device,TOP_K):
    """Compute hits@TOP_K for object prediction.

    For every triple the query ``entity[subject] + relation[relation]`` is
    compared against every row of the entity table; the triple counts as a hit
    when its true object is among the ``TOP_K`` closest rows.

    Args:
        model: TransE model exposing ``entity_embedding``,
            ``relation_embedding`` and ``NORM``.
        dataloader: yields batches whose first three entries are the subject,
            relation and object index tensors.
        device: device the batches are moved to.
        TOP_K: number of nearest entities inspected per triple (capped at the
            size of the entity table).

    Returns:
        Fraction of triples that are hits, over all samples seen (each sample
        has equal weight, independent of batch size); ``nan`` when the
        dataloader is empty.
    """
    hits = 0.0
    seen = 0
    bar = tqdm(dataloader)
    # Evaluation needs no gradients; without no_grad autograd would keep the
    # (batch, entities, dim) intermediate of every batch alive.
    with torch.no_grad():
        entity_weight = model.entity_embedding.weight.to(device)
        # topk raises when k exceeds the number of candidates.
        k = min(TOP_K,entity_weight.shape[0])
        for batch in bar:
            subject,relation,object_ = To_Device(batch[:3],device)
            query = model.entity_embedding(subject) + model.relation_embedding(relation)
            # Broadcast (batch, 1, dim) against (1, entities, dim) and measure
            # with the same norm the model was trained with.
            distance = (query.unsqueeze(1) - entity_weight.unsqueeze(0)).norm(p=model.NORM,dim=2)
            values,indices = torch.topk(distance,k=k,largest=False)
            hits += torch.eq(indices,object_.unsqueeze(1)).float().sum().item()
            seen += len(subject)
            bar.set_description(f'top_{TOP_K}: {hits/seen:8.7f}')
    return hits/seen if seen else float('nan')

In [19]:
# Create the checkpoint directory up front: a missing directory would otherwise
# only surface in torch.save(), i.e. after a complete training epoch.
os.makedirs('./models/darpa',exist_ok=True)

# Train for five epochs; after each one measure hits@TOP_K on the validation
# split and store a checkpoint whose file name carries that score.
for i in range(5):
    model = train(model,train_dataloader,loss_fn,optim_fn,device)
    top_k = valid(model,valid_dataloader,device,TOP_K)
    # NOTE(review): this pickles the whole module object, so loading the file
    # needs the TransE_Model class to be importable -- consider state_dict().
    torch.save(model,f'./models/darpa/epoch_{i}_top_{TOP_K}_{top_k}.pt')
    print()

loss: 0.0607305: 100%|██████████| 10444/10444 [1:13:24<00:00,  2.37it/s]
top_100: 0.7355063: 100%|██████████| 41774/41774 [03:07<00:00, 222.47it/s]





loss: 0.0252982: 100%|██████████| 10444/10444 [1:14:08<00:00,  2.35it/s]
top_100: 0.7135489: 100%|██████████| 41774/41774 [03:04<00:00, 226.83it/s]





loss: 0.0237640:   1%|          | 112/10444 [00:52<1:21:06,  2.12it/s]


KeyboardInterrupt: 