In [26]:
%load_ext autoreload
%autoreload 2
import torch
import json
import numpy as np
from torch.utils.data import DataLoader
import torch.optim as optim
from torch_geometric.loader import NeighborLoader
from torch_geometric.data import Data
from ffm_graph import *
from data import MINDDataset
from data_utils import *
from transformers import BertConfig
from gnn import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
train_data_path = 'data/MINDlarge_train/'

# Seed for the shuffling order so training runs are reproducible.
SEED = 42
TRAIN_BATCH_SIZE = 64

mind_dataset = MINDDataset(train_data_path + 'behaviors.pkl')
# Shuffle training samples every epoch; without shuffle=True the DataLoader walks
# the dataset in index order on every epoch. The seeded generator keeps the
# shuffled order reproducible across runs.
train_dataloader = DataLoader(
    mind_dataset,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
    generator=torch.Generator().manual_seed(SEED),
)

FileNotFoundError: [Errno 2] No such file or directory: 'data/MINDsmall_train/behaviors.pkl'

In [None]:
# Build the three sub-modules of the model: the news encoder (given the
# features stored in bert_features.pt of the training split), the user encoder,
# and the GNN returned by create_sage (presumably GraphSAGE).
news_encoder_config = BertConfig.from_json_file('news_encoder.json')
user_encoder_config = BertConfig.from_json_file('user_encoder.json')
bert_features_path = f'{train_data_path}bert_features.pt'

news_encoder = NewsEncoder(news_encoder_config, bert_features_path)
user_encoder = UserEncoder(user_encoder_config)
gnn = create_sage(nfeat=256, nhid=256, dropout=0.1, nlayer=3)

In [None]:
# Prefer the GPU, but fall back to CPU so the notebook still runs on machines
# without CUDA (a hard-coded torch.device('cuda') fails at .to() time there).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Fastformer_Graph(news_encoder, user_encoder, gnn).to(device)

In [31]:
# `total_params` keeps its original meaning (trainable parameters only) for any
# later cell that reads it; parameters with requires_grad=False are reported
# separately instead of being silently left out of a "total".
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
all_params = sum(p.numel() for p in model.parameters())
print(f'Trainable parameters: {total_params} (all parameters: {all_params})')

Total number of parameters: 1848418


In [None]:
def train_model(model, train_loader, global_graph_data, device, epochs=5,
		lr=1e-4, num_neighbors=(3, 3, 3), checkpoint_path="ffmg_large.pt"):
	"""Train ``model`` with AdamW and cross-entropy, sampling a graph neighbourhood per batch.

	For every batch from ``train_loader``, a subgraph is sampled from
	``global_graph_data`` around ``batch['seed_nodes']``. The batch and that
	subgraph are passed to the model. After each epoch the model's ``state_dict``
	is written to ``checkpoint_path``, overwriting the previous checkpoint.

	Args:
		model: module called as ``model(batch, sub_graph, device)``. Its output is
			scored with ``CrossEntropyLoss`` against ``batch['label']``.
		train_loader: iterable of dicts of tensors containing at least
			``'seed_nodes'`` and ``'label'``.
		global_graph_data: ``torch_geometric.data.Data`` holding the full graph.
		device: device to train on.
		epochs: number of passes over ``train_loader``.
		lr: AdamW learning rate.
		num_neighbors: number of neighbours sampled per hop, one entry per hop.
		checkpoint_path: file the ``state_dict`` is saved to after every epoch.

	Returns:
		list[float]: average training loss of each epoch.
	"""
	# Local import. This is the sampler used internally by NeighborLoader. Note that
	# torch_geometric.loader also has an unrelated, legacy class named NeighborSampler.
	from torch_geometric.sampler import NeighborSampler

	num_neighbors = list(num_neighbors)
	optimizer = optim.AdamW(model.parameters(), lr=lr)
	criterion = torch.nn.CrossEntropyLoss()

	model.to(device)

	# Build the sampler once. A NeighborLoader created without one rebuilds its
	# sampler (including preprocessing of the whole global graph) on every
	# construction, which previously happened once per batch.
	sampler = NeighborSampler(global_graph_data, num_neighbors=num_neighbors)

	epoch_losses = []
	for epoch in range(epochs):
		model.train()
		total_loss = 0.0
		num_batches = 0
		# NOTE(review): `tqdm` is not imported in the visible cells; presumably it
		# comes from one of the star imports. Confirm.
		pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
		for batch in pbar:
			# The sampler needs CPU seed ids. Take them before moving the batch to
			# `device` to avoid a device -> CPU round-trip.
			seed_nodes = batch['seed_nodes'].cpu()
			batch = {k: v.to(device) for k, v in batch.items()}

			# Single-batch loader: batch_size equals the number of seeds, so the
			# first (and only) item is the subgraph around all seeds of this batch.
			sub_loader = NeighborLoader(
				global_graph_data,
				num_neighbors=num_neighbors,
				input_nodes=seed_nodes,
				batch_size=len(seed_nodes),
				shuffle=False,
				neighbor_sampler=sampler,
			)
			sub_graph = next(iter(sub_loader)).to(device)

			optimizer.zero_grad()

			scores = model(batch, sub_graph, device)
			loss = criterion(scores, batch['label'])

			loss.backward()
			optimizer.step()

			total_loss += loss.item()
			num_batches += 1
			pbar.set_postfix({"loss": f"{loss.item():.4f}"})

		torch.save(model.state_dict(), checkpoint_path)
		# Count batches actually seen. This avoids ZeroDivisionError on an empty
		# loader and works for loaders without __len__.
		avg_loss = total_loss / max(num_batches, 1)
		epoch_losses.append(avg_loss)
		# 1-based, matching the progress-bar description above.
		print(f'Epoch {epoch + 1} done! AVG loss: {avg_loss:.4f}')

	return epoch_losses

In [None]:
# Node features and edges of the global news graph. The node table must come from
# the same split as the behaviours and BERT features (train_data_path). The
# batches' seed_nodes and the ids in edge_index both index into it.
news_tokens = np.load(train_data_path + 'news_token.npy')
x = torch.from_numpy(news_tokens).long()

# NOTE(review): edge_index.pt is read from the working directory, not from
# train_data_path. Confirm it was built for the same split as news_token.npy.
edge_index = torch.load('edge_index.pt').contiguous()
# Fail fast on a split mismatch, e.g. a small-split node table paired with a
# large-split graph.
assert edge_index.dim() == 2 and edge_index.size(0) == 2, "edge_index must have shape [2, num_edges]"
assert edge_index.numel() == 0 or int(edge_index.max()) < x.size(0), "edge_index references nodes outside x"
global_graph_data = Data(x=x, edge_index=edge_index)

In [None]:
# batch = next(iter(train_dataloader))
# batch = {k: v.to(device) for k, v in batch.items()}
# sub_loader = NeighborLoader(
#     global_graph_data,
#     num_neighbors=[3, 3, 3], 
#     input_nodes=batch['seed_nodes'].cpu(),
#     batch_size=len(batch['seed_nodes']),
#     shuffle=False
# )
# sub_graph = next(iter(sub_loader)).to(device)

In [None]:
# model.eval() # eval mode so dropout is not applied
# with torch.no_grad():
#     try:
#         output = model(batch, sub_graph, device)
#         print(f"✅ Thành công!")
#         print(f"Shape đầu ra (Scores): {output.shape}") # Expected: [Batch_size, 5]
#         print(f"Dữ liệu mẫu 2 hàng đầu:\n{output[:2]}")
#     except Exception as e:
#         print(f"❌ Có lỗi rồi đại vương ơi!")
#         print(f"Lỗi: {e}")

✅ Thành công!
Shape đầu ra (Scores): torch.Size([2, 5])
Dữ liệu mẫu 2 hàng đầu:
tensor([[ 84.6329,  93.0527,   0.0000,   0.0000,   0.0000],
        [105.3233, 101.0506, 113.3435,  91.9212,  99.6274]], device='cuda:0')


In [None]:
train_model(model, train_dataloader, global_graph_data, device, 1)

Epoch 1:   0%|          | 0/118172 [00:43<?, ?it/s]


KeyboardInterrupt: 