In [1]:
import os, random, math, pickle
import pandas as pd
import numpy as np
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from torch.utils.data import random_split
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Set environment variables for reproducibility and safety
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import warnings
warnings.filterwarnings('ignore')
from sklearn.metrics import precision_score, recall_score, accuracy_score

# 1. Configuration & Seeding
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

In [2]:
name = 'book'
n_clusters = 4

## 1. Learn Embedding

### 1.1 Dataset

In [3]:
class TCKGDataset(Dataset):
    def __init__(self, triplets):
        self.triplets = triplets
    def __len__(self):
        return len(self.triplets)
    def __getitem__(self, idx):
        # Trảmovie về bộ ba (head, relation, tail)
        return self.triplets[idx]

### 1.2 TransE Model

In [4]:
class TransE(pl.LightningModule):
    def __init__(self, num_entities, num_relations, embedding_dim=64, lr=1e-3, weight_decay=1e-4, dropout_rate=0.2):
        super().__init__()
        self.save_hyperparameters()
        
        # Khởi tạo Embeddings
        self.entity_emb = nn.Embedding(num_entities + 1, embedding_dim, padding_idx=0)     # +1 because starting at 1 instead of 0
        self.relation_emb = nn.Embedding(num_relations + 1, embedding_dim, padding_idx=0)
        
        # # Xavier initialization giúp hội tụ tốt hơn
        # nn.init.xavier_uniform_(self.entity_emb.weight)
        # nn.init.xavier_uniform_(self.relation_emb.weight)

        self.dropout = nn.Dropout(dropout_rate)
        
    def forward(self, h, r, t):
        h_e = self.entity_emb(h)
        r_e = self.relation_emb(r)
        t_e = self.entity_emb(t)

        # 2. Embedding Normalization (Rất quan trọng cho TransE)
        # Ép độ dài các vector về 1 (Unit Norm constraint)
        h_e = F.normalize(h_e, p=2, dim=1)
        r_e = F.normalize(r_e, p=2, dim=1)
        t_e = F.normalize(t_e, p=2, dim=1)
        
        # 3. Áp dụng Dropout
        h_e = self.dropout(h_e)
        r_e = self.dropout(r_e)
        t_e = self.dropout(t_e)
        
        # Công thức (6): Khoảng cách bình phương L2
        # g_r(h, t) = ||h + r - t||^2
        score = torch.sum((h_e + r_e - t_e)**2, dim=1)
        return score

    def training_step(self, batch, batch_idx):
        h, r, t = batch[:, 0], batch[:, 1], batch[:, 2]
        
        # Tính score cho bộ ba đúng (Positive) -> Cần giảm thiểu khoảng cách này
        pos_scores = self(h, r, t)
        
        # Negative Sampling: Thay thế tail t bằng t' ngẫu nhiên
        # t' không nhất thiết phải là không đúng thực tế (simplified), nhưng xác suất cao là không đúng.
        rand_t = torch.randint(1, self.hparams.num_entities + 1, t.shape, device=self.device)
        
        # Tính score cho bộ ba sai (Negative) -> Cần tối đa hóa khoảng cách này
        neg_scores = self(h, r, rand_t)
        
        # Công thức (7) Loss: -ln(sigmoid(g_neg - g_pos))
        # Chúng ta muốn g_neg > g_pos (khoảng cách sai lớn hơn đúng)
        # => (g_neg - g_pos) càng lớn càng tốt
        loss = -F.logsigmoid(neg_scores - pos_scores).mean()
        
        # Log loss
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        h, r, t = batch[:, 0], batch[:, 1], batch[:, 2] #h,r,t shape = batch_size
        
        # 1. Tính loss trên valid set
        pos_scores = self(h, r, t)
        
        # Negative sampling (đơn giản hoá để tính loss theo dõi)
        rand_t = torch.randint(1, self.hparams.num_entities + 1, t.shape, device=self.device)

        neg_scores = self(h, r, rand_t)
        
        val_loss = -F.logsigmoid(neg_scores - pos_scores).mean()
        self.log('val_loss', val_loss, prog_bar=True)
        return val_loss

    def configure_optimizers(self):
        # 4. Thêm weight_decay (L2 regularization) vào Adam
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)

### 1.3 Load dataset

In [5]:
file_path = f'./data/{name}_TCKG.csv' 
print(f"Loading data from {file_path}...")

TCKG_df = pd.read_csv(file_path)

# Chuyển đổi dữ liệu sang index
triplets_np = np.stack([
    TCKG_df['head_id'],
    TCKG_df['relation_id'],
    TCKG_df['tail_id']
], axis=1)

# 2. Tìm Offset (Lấy ID lớn nhất của relation hiện tại)
# Ví dụ: nếu relation_id chạy từ 1 đến 10, offset sẽ là 10.
offset = TCKG_df['relation_id'].max()

# 3. Tạo Inverse Connections (Cạnh ngược)
# Đảo vị trí Tail -> Head, Head -> Tail, và cộng offset vào Relation
inverse_triplets_np = np.stack([
    TCKG_df['tail_id'],                 # Tail thành Head
    TCKG_df['relation_id'] + offset,    # Relation mới = Relation cũ + offset
    TCKG_df['head_id']                  # Head thành Tail
], axis=1)

# 4. Gộp cả 2 mảng lại với nhau
# axis=0 nghĩa là nối tiếp theo chiều dọc (thêm dòng)
all_triplets_np = np.concatenate([triplets_np, inverse_triplets_np], axis=0)

# # Lưu all_triplets_np ra file CSV
# df1 = pd.DataFrame(all_triplets_np, columns=['head_id', 'relation_id', 'tail_id'])
# df1 = df1.sort_values(by=['relation_id'])
# df1.to_csv(f'./data/{name}_TCKG_all.csv', index=False)


# Chuyển sang Tensor
triplets_tensor = torch.tensor(all_triplets_np, dtype=torch.long)
print(f'triplets_tensor.shape: {triplets_tensor.shape}')

# Tạo DataLoader
full_dataset = TCKGDataset(triplets_tensor)

# Chia 90% Train - 10% Val
train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_set, val_set = random_split(full_dataset, [train_size, val_size])

# Tạo 2 Loaders
train_loader = DataLoader(train_set, batch_size=1024, shuffle=True, num_workers=0)
val_loader = DataLoader(val_set, batch_size=1024, shuffle=False, num_workers=0)


Loading data from ./data/book_TCKG.csv...
triplets_tensor.shape: torch.Size([236394, 3])


### 1.3 Init and train model

In [6]:
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

num_entites = pd.concat([TCKG_df['head_id'], TCKG_df['tail_id']]).max()

num_relations = TCKG_df['relation_id'].max() * 2    #*2 to double relation for inverse connection

print(f"Total Entities: {num_entites}")
print(f"Total Relations: {num_relations}")

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',       # Theo dõi val_loss
    dirpath=f'./checkpoints/', # Thư mục lưu
    filename=f'{name}-transE-{timestamp}-{{epoch:02d}}-{{val_loss:.4f}}', 
    save_top_k=1,             # Chỉ giữ lại 1 model tốt nhất
    mode='min',               # Lưu khi val_loss nhỏ nhất
)

# 5. Early Stopping Callback
early_stop_callback = EarlyStopping(
    monitor='val_loss', # Theo dõi val_loss
    min_delta=0.001,    # Cải thiện tối thiểu cần thiết
    patience=20,         # Chờ 5 epochs nếu không cải thiện thì dừng
    verbose=True,
    mode='min'
)

model = TransE(
    num_entities=num_entites, 
    num_relations=num_relations, 
    embedding_dim=64, # Có thể chỉnh d-dimension tại đây
    lr=0.001,
    weight_decay=1e-3,  # Tăng lên nếu vẫn overfit (ví dụ: 1e-3)
    dropout_rate=0.3    # Tăng lên nếu vẫn overfit (tối đa 0.5)
)

# Trainer
trainer = pl.Trainer(
    max_epochs=500, 
    accelerator="auto", # Tự động dùng GPU nếu có
    callbacks=[checkpoint_callback, early_stop_callback],
    enable_progress_bar=True
)
# Bắt đầu huấn luyện
trainer.fit(model, train_loader, val_loader)
# Sau khi train, bạn có thể lấy embedding bằng:
# entity_embeddings = model.entity_emb.weight.detach().cpu().numpy()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name         | Type      | Params | Mode  | FLOPs
-----------------------------------------------------------
0 | entity_emb   | Embedding | 2.5 M  | train | 0    
1 | relation_emb | Embedding | 3.1 K  | train | 0    
2 | dropout      | Dropout   | 0      | train | 0    
-----------------------------------------------------------
2.5 M     Trainable params
0         Non-trainable params
2.5 M     Total params
9.967     Total estimated model params size (MB)
3         Modules in train mode
0         Modules in eval mode
0         Total Flops


Total Entities: 38885
Total Relations: 48
Epoch 0: 100%|██████████| 208/208 [00:14<00:00, 14.41it/s, v_num=0, train_loss=0.823, val_loss=0.720]

Metric val_loss improved. New best score: 0.720


Epoch 2: 100%|██████████| 208/208 [00:11<00:00, 18.12it/s, v_num=0, train_loss=0.800, val_loss=0.718]

Metric val_loss improved by 0.002 >= min_delta = 0.001. New best score: 0.718


Epoch 4: 100%|██████████| 208/208 [00:13<00:00, 15.24it/s, v_num=0, train_loss=0.808, val_loss=0.715]

Metric val_loss improved by 0.003 >= min_delta = 0.001. New best score: 0.715


Epoch 5: 100%|██████████| 208/208 [00:19<00:00, 10.47it/s, v_num=0, train_loss=0.820, val_loss=0.714]

Metric val_loss improved by 0.001 >= min_delta = 0.001. New best score: 0.714


Epoch 6: 100%|██████████| 208/208 [00:12<00:00, 17.13it/s, v_num=0, train_loss=0.847, val_loss=0.708]

Metric val_loss improved by 0.006 >= min_delta = 0.001. New best score: 0.708


Epoch 7: 100%|██████████| 208/208 [00:10<00:00, 20.24it/s, v_num=0, train_loss=0.843, val_loss=0.705]

Metric val_loss improved by 0.003 >= min_delta = 0.001. New best score: 0.705


Epoch 8: 100%|██████████| 208/208 [00:09<00:00, 22.58it/s, v_num=0, train_loss=0.840, val_loss=0.696]

Metric val_loss improved by 0.009 >= min_delta = 0.001. New best score: 0.696


Epoch 9: 100%|██████████| 208/208 [00:08<00:00, 24.01it/s, v_num=0, train_loss=0.759, val_loss=0.685]

Metric val_loss improved by 0.011 >= min_delta = 0.001. New best score: 0.685


Epoch 10: 100%|██████████| 208/208 [00:09<00:00, 21.61it/s, v_num=0, train_loss=0.812, val_loss=0.672]

Metric val_loss improved by 0.012 >= min_delta = 0.001. New best score: 0.672


Epoch 11: 100%|██████████| 208/208 [00:12<00:00, 16.77it/s, v_num=0, train_loss=0.808, val_loss=0.647]

Metric val_loss improved by 0.025 >= min_delta = 0.001. New best score: 0.647


Epoch 12: 100%|██████████| 208/208 [00:13<00:00, 15.67it/s, v_num=0, train_loss=0.792, val_loss=0.615]

Metric val_loss improved by 0.032 >= min_delta = 0.001. New best score: 0.615


Epoch 13: 100%|██████████| 208/208 [00:09<00:00, 23.00it/s, v_num=0, train_loss=0.716, val_loss=0.576]

Metric val_loss improved by 0.040 >= min_delta = 0.001. New best score: 0.576


Epoch 14: 100%|██████████| 208/208 [00:12<00:00, 16.28it/s, v_num=0, train_loss=0.709, val_loss=0.525]

Metric val_loss improved by 0.050 >= min_delta = 0.001. New best score: 0.525


Epoch 15: 100%|██████████| 208/208 [00:08<00:00, 24.00it/s, v_num=0, train_loss=0.594, val_loss=0.475]

Metric val_loss improved by 0.050 >= min_delta = 0.001. New best score: 0.475


Epoch 16: 100%|██████████| 208/208 [00:09<00:00, 20.85it/s, v_num=0, train_loss=0.508, val_loss=0.424]

Metric val_loss improved by 0.052 >= min_delta = 0.001. New best score: 0.424


Epoch 17: 100%|██████████| 208/208 [00:10<00:00, 19.55it/s, v_num=0, train_loss=0.453, val_loss=0.383]

Metric val_loss improved by 0.040 >= min_delta = 0.001. New best score: 0.383


Epoch 18: 100%|██████████| 208/208 [00:14<00:00, 14.45it/s, v_num=0, train_loss=0.430, val_loss=0.349]

Metric val_loss improved by 0.034 >= min_delta = 0.001. New best score: 0.349


Epoch 19: 100%|██████████| 208/208 [00:10<00:00, 20.36it/s, v_num=0, train_loss=0.370, val_loss=0.326]

Metric val_loss improved by 0.024 >= min_delta = 0.001. New best score: 0.326


Epoch 20: 100%|██████████| 208/208 [00:08<00:00, 23.41it/s, v_num=0, train_loss=0.346, val_loss=0.309]

Metric val_loss improved by 0.017 >= min_delta = 0.001. New best score: 0.309


Epoch 21: 100%|██████████| 208/208 [00:08<00:00, 24.38it/s, v_num=0, train_loss=0.325, val_loss=0.298]

Metric val_loss improved by 0.011 >= min_delta = 0.001. New best score: 0.298


Epoch 22: 100%|██████████| 208/208 [00:08<00:00, 23.36it/s, v_num=0, train_loss=0.310, val_loss=0.289]

Metric val_loss improved by 0.008 >= min_delta = 0.001. New best score: 0.289


Epoch 23: 100%|██████████| 208/208 [00:09<00:00, 22.62it/s, v_num=0, train_loss=0.296, val_loss=0.285]

Metric val_loss improved by 0.004 >= min_delta = 0.001. New best score: 0.285


Epoch 24: 100%|██████████| 208/208 [00:10<00:00, 19.60it/s, v_num=0, train_loss=0.298, val_loss=0.284]

Metric val_loss improved by 0.001 >= min_delta = 0.001. New best score: 0.284


Epoch 25: 100%|██████████| 208/208 [00:11<00:00, 18.28it/s, v_num=0, train_loss=0.291, val_loss=0.276]

Metric val_loss improved by 0.008 >= min_delta = 0.001. New best score: 0.276


Epoch 27: 100%|██████████| 208/208 [00:09<00:00, 22.14it/s, v_num=0, train_loss=0.277, val_loss=0.273]

Metric val_loss improved by 0.003 >= min_delta = 0.001. New best score: 0.273


Epoch 29: 100%|██████████| 208/208 [00:08<00:00, 24.37it/s, v_num=0, train_loss=0.273, val_loss=0.271]

Metric val_loss improved by 0.002 >= min_delta = 0.001. New best score: 0.271


Epoch 30: 100%|██████████| 208/208 [00:08<00:00, 23.21it/s, v_num=0, train_loss=0.265, val_loss=0.266]

Metric val_loss improved by 0.005 >= min_delta = 0.001. New best score: 0.266


Epoch 33: 100%|██████████| 208/208 [00:11<00:00, 17.60it/s, v_num=0, train_loss=0.278, val_loss=0.264]

Metric val_loss improved by 0.002 >= min_delta = 0.001. New best score: 0.264


Epoch 34: 100%|██████████| 208/208 [00:12<00:00, 16.89it/s, v_num=0, train_loss=0.262, val_loss=0.262]

Metric val_loss improved by 0.002 >= min_delta = 0.001. New best score: 0.262


Epoch 35: 100%|██████████| 208/208 [00:10<00:00, 20.39it/s, v_num=0, train_loss=0.266, val_loss=0.259]

Metric val_loss improved by 0.002 >= min_delta = 0.001. New best score: 0.259


Epoch 45: 100%|██████████| 208/208 [00:09<00:00, 22.18it/s, v_num=0, train_loss=0.274, val_loss=0.257]

Metric val_loss improved by 0.002 >= min_delta = 0.001. New best score: 0.257


Epoch 65: 100%|██████████| 208/208 [00:10<00:00, 20.21it/s, v_num=0, train_loss=0.240, val_loss=0.259]

Monitored metric val_loss did not improve in the last 20 records. Best score: 0.257. Signaling Trainer to stop.


Epoch 65: 100%|██████████| 208/208 [00:10<00:00, 20.21it/s, v_num=0, train_loss=0.240, val_loss=0.259]


### 1.5 Save trained 

In [7]:
# 1. Extract Embeddings from Model (move to CPU and convert to numpy)
entity_embeddings = model.entity_emb.weight.detach().cpu().numpy()
relation_embeddings = model.relation_emb.weight.detach().cpu().numpy()

# 2. Package everything into a dictionary
saved_data = {
    'entity_embeddings': entity_embeddings,      # (Num_Entities, dim)
    'relation_embeddings': relation_embeddings,  # (Num_Relations, dim)
}
# 3. Save to a single file
with open(f'./pickle/{name}_transE_embeddings_{timestamp}.pkl', 'wb') as f:
    pickle.dump(saved_data, f)
print("Embeddings and mappings saved successfully!")

Embeddings and mappings saved successfully!
