# Full TSB-CL Model Training (Single Cell)

This notebook contains the complete code to train the **Full TSB-CL** model (Time-Aware Structural Biclique Contrastive Learning).
It includes:
1.  **DataUtils**: Loads the interaction data, splits it into temporal snapshots, and mines bicliques with the external MSBE executable (`msbe.exe` on Windows, `msbe` on Linux).
2.  **Model Architecture**: The fixed `FullTSBCL` class with correct GRU state handling.
3.  **Training Loop**: A sequential training process that properly warms up the RNN state before evaluation.

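**Note**: Ensure `msbe.exe` (or `msbe` on Linux) is present in the `Similar-Biclique-Idx-main` folder under the current working directory. If the executable is missing, mining is skipped and every snapshot is trained with an empty biclique set.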

In [None]:
import os
import sys
import time
import random
import struct
import subprocess
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Set Environment Variable for OpenMP
# KMP_DUPLICATE_LIB_OK lets the Intel OpenMP runtime continue (instead of
# aborting) when more than one copy of it is loaded into the process.
# NOTE(review): numpy/torch are imported above this line; if the OpenMP
# runtime was already initialised by those imports this may come too late —
# TODO confirm on the affected platform.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# ==========================================
# 1. Configuration & Paths
# ==========================================
# Adjust these paths if running in a different environment
CURRENT_DIR = os.getcwd()
PROJECT_ROOT = CURRENT_DIR  # Assuming notebook is in project root
MSBE_DIR = os.path.join(PROJECT_ROOT, "Similar-Biclique-Idx-main")
DATA_PATH = os.path.join(MSBE_DIR, "datasets", "bi_github.txt")


def _find_msbe_executable(directory):
    """Return the path of the MSBE binary inside *directory*.

    The platform's native name is preferred (``msbe.exe`` on Windows,
    ``msbe`` elsewhere), but whichever file actually exists wins. If neither
    exists, the native-name path is returned so the "not found" warning
    printed later names the expected file.
    """
    names = ("msbe.exe", "msbe") if os.name == "nt" else ("msbe", "msbe.exe")
    for name in names:
        candidate = os.path.join(directory, name)
        if os.path.exists(candidate):
            return candidate
    return os.path.join(directory, names[0])


MSBE_EXE = _find_msbe_executable(MSBE_DIR)

# Hyperparameters
EMBEDDING_DIM = 64   # user/item embedding size (also the GRU state size)
BATCH_SIZE = 2048    # interactions per BPR mini-batch
LR = 0.001           # Adam learning rate
EPOCHS = 10
NUM_SNAPSHOTS = 5    # temporal chunks; the last one is held out for testing
TAU = 2              # `tau` argument forwarded to the MSBE miner
EPSILON = 0.1        # `epsilon` argument forwarded to the MSBE miner
TOP_K = 20           # ranking cutoff for Recall@K / NDCG@K

# Device Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ==========================================
# 2. Data Utilities (DataUtils)
# ==========================================
class DataUtils:
    """Load interaction data, split it into temporal snapshots and mine bicliques.

    Biclique mining shells out to the external MSBE executable; its results are
    cached as text files in ``temp_dir`` so repeated calls are cheap.
    """

    def __init__(self, data_path, msbe_exe_path, temp_dir=None):
        """
        Args:
            data_path: Whitespace-separated edge list (``user item ...`` per
                line) whose first line is a header.
            msbe_exe_path: Path to the MSBE executable.
            temp_dir: Working directory for MSBE inputs/outputs; defaults to
                ``<cwd>/temp`` and is created if missing.
        """
        self.data_path = data_path
        self.msbe_exe_path = msbe_exe_path
        self.temp_dir = temp_dir if temp_dir is not None else os.path.join(os.getcwd(), 'temp')
        os.makedirs(self.temp_dir, exist_ok=True)
        self.user_map = {}   # raw user id -> contiguous user index
        self.item_map = {}   # raw item id -> contiguous item index
        self.num_users = 0
        self.num_items = 0

    def load_data(self):
        """Read the edge list and remap raw ids to contiguous indices.

        The line order serves as a surrogate timestamp (0, 1, 2, ... over the
        successfully parsed lines), so the file is assumed to be chronological.
        Lines with fewer than two fields or non-integer ids are skipped.

        Returns:
            List of ``(user_idx, item_idx, timestamp)`` tuples, or ``[]`` when
            the file does not exist.
        """
        print(f"Loading data from {self.data_path}...")
        if not os.path.exists(self.data_path):
            print(f"Error: Data file not found at {self.data_path}")
            return []

        data = []
        with open(self.data_path, 'r', encoding='utf-8', errors='ignore') as f:
            f.readline()  # Skip header
            for line in f:
                parts = line.split()
                if len(parts) < 2:
                    continue
                try:
                    u, i = int(parts[0]), int(parts[1])
                except ValueError:
                    continue  # e.g. extra comment/metadata lines
                uid = self.user_map.setdefault(u, len(self.user_map))
                iid = self.item_map.setdefault(i, len(self.item_map))
                data.append((uid, iid, len(data)))
        self.num_users = len(self.user_map)
        self.num_items = len(self.item_map)
        print(f"Loaded {len(data)} interactions. Users: {self.num_users}, Items: {self.num_items}")
        return data

    def split_snapshots(self, data, num_snapshots=5):
        """Split *data* into ``num_snapshots`` contiguous, equally sized chunks.

        The last chunk also receives the remainder of the integer division.
        """
        chunk_size = len(data) // num_snapshots
        snapshots = []
        for i in range(num_snapshots):
            start = i * chunk_size
            end = (i + 1) * chunk_size if i < num_snapshots - 1 else len(data)
            snapshots.append(data[start:end])
        return snapshots

    def save_binary_graph(self, snapshot_data, file_prefix):
        """Write a snapshot as MSBE's binary bipartite-graph files.

        Users get local ids ``0..n1-1`` and items ``n1..n1+n2-1``. Repeated
        interactions are collapsed so the written graph is simple (no parallel
        edges in the adjacency lists).

        Writes ``<prefix>_b_degree.bin`` (header ``4, n1, n2, 2*|E|`` followed by
        one degree per vertex) and ``<prefix>_b_adj.bin`` (the concatenated,
        sorted adjacency lists), all as unsigned ints.

        Returns:
            ``(n1, n2, u_rev, v_rev)`` where ``u_rev`` / ``v_rev`` map local ids
            back to the global user / item indices.
        """
        edges = {(u, v) for u, v, _ in snapshot_data}
        sorted_us = sorted({u for u, _ in edges})
        sorted_vs = sorted({v for _, v in edges})
        n1, n2 = len(sorted_us), len(sorted_vs)
        n = n1 + n2
        u_map = {u: k for k, u in enumerate(sorted_us)}
        v_map = {v: k + n1 for k, v in enumerate(sorted_vs)}

        adj = [[] for _ in range(n)]
        for u, v in edges:
            uid, vid = u_map[u], v_map[v]
            adj[uid].append(vid)
            adj[vid].append(uid)
        for nbrs in adj:
            nbrs.sort()
        edges_count = 2 * len(edges)  # each undirected edge is stored twice

        with open(file_prefix + "_b_degree.bin", 'wb') as f:
            f.write(struct.pack('4I', 4, n1, n2, edges_count))
            f.write(struct.pack(f'{n}I', *(len(nbrs) for nbrs in adj)))
        with open(file_prefix + "_b_adj.bin", 'wb') as f:
            f.write(struct.pack(f'{edges_count}I', *(x for nbrs in adj for x in nbrs)))

        u_rev = dict(enumerate(sorted_us))
        v_rev = {k + n1: v for k, v in enumerate(sorted_vs)}
        return n1, n2, u_rev, v_rev

    def run_msbe_mining(self, snapshot_data, snapshot_id, tau=3, epsilon=0.5):
        """Mine bicliques of one snapshot with MSBE and cache them as text.

        Args:
            snapshot_data: ``(user, item, timestamp)`` tuples of the snapshot.
            snapshot_id: Tag used in the temp file names (e.g. ``"train_0"``).
            tau, epsilon: Forwarded verbatim to the MSBE command line.

        Returns:
            Path of the biclique file: three lines per biclique (sizes, user
            indices, item indices, using the global indices of ``load_data``).
            The file may be empty or absent if MSBE is missing or fails;
            ``parse_bicliques`` handles both cases.
        """
        graph_name = f"graph_{snapshot_id}"
        input_prefix = os.path.join(self.temp_dir, graph_name)
        output_file = os.path.join(self.temp_dir, f"bicliques_{snapshot_id}_tau{tau}_eps{epsilon}.txt")

        # Cache hit: no need to rewrite the graph files or re-run MSBE.
        # NOTE: the cache key is (snapshot_id, tau, epsilon), not the data itself.
        if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
            return output_file

        if not os.path.exists(self.msbe_exe_path):
            print(f"Warning: MSBE executable not found at {self.msbe_exe_path}. Returning empty bicliques.")
            with open(output_file, 'w'):
                pass
            return output_file

        # MSBE is invoked with `<graph_name>.txt`; the real graph lives in the
        # binary files written next to it (presumably located via that name —
        # TODO confirm against the MSBE sources).
        with open(input_prefix + ".txt", 'w') as f:
            f.write("dummy")
        _, _, u_rev, v_rev = self.save_binary_graph(snapshot_data, input_prefix)

        try:
            # Preparatory run; its output is discarded.
            subprocess.run([self.msbe_exe_path, f"{graph_name}.txt", "1", "1", "0.3", "GRL3"],
                           cwd=self.temp_dir, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            # Mining run; bicliques are parsed from its stdout below.
            result = subprocess.run([self.msbe_exe_path, f"{graph_name}.txt", "0", "1", "0.3", "GRL3", "1", "GRL3", "0", "0", "heu", "4", str(epsilon), str(tau), "2"],
                                    cwd=self.temp_dir, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, encoding='utf-8', errors='ignore')
        except (OSError, subprocess.SubprocessError) as e:
            print(f"Mining error: {e}")
            return output_file

        # Output format: "CL : a,b,..." and "CR : x,y,..." lines followed by a
        # "---" separator per biclique, using the local ids of save_binary_graph.
        bicliques = []
        current_cl, current_cr = [], []
        for raw_line in result.stdout.splitlines():
            line = raw_line.strip()
            if line.startswith("CL :"):
                current_cl = [int(x) for x in line[4:].split(',') if x.strip()]
            elif line.startswith("CR :"):
                current_cr = [int(x) for x in line[4:].split(',') if x.strip()]
            elif line.startswith("---") and current_cl and current_cr:
                orig_us = [u_rev[k] for k in current_cl if k in u_rev]
                orig_vs = [v_rev[k] for k in current_cr if k in v_rev]
                if orig_us and orig_vs:
                    bicliques.append((orig_us, orig_vs))
                current_cl, current_cr = [], []

        with open(output_file, 'w') as f:
            for us, vs in bicliques:
                f.write(f"{len(us)} {len(vs)}\n")
                f.write(" ".join(map(str, us)) + "\n")
                f.write(" ".join(map(str, vs)) + "\n")
        return output_file

    @staticmethod
    def _incidence(pairs, size):
        """Build a float sparse COO matrix with 1.0 at every ``(row, col)`` in *pairs*."""
        # reshape(-1, 2) keeps the index tensor 2-D even when *pairs* is empty.
        indices = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t()
        return torch.sparse_coo_tensor(indices, torch.ones(indices.shape[1]), size=size)

    def parse_bicliques(self, biclique_file):
        """Load a biclique file written by ``run_msbe_mining``.

        Returns:
            ``(H_v, H_u)``: ``H_v`` is ``(num_bicliques, num_items)`` and ``H_u``
            is ``(num_users, num_bicliques)``, both with 1.0 at memberships.
            With no bicliques (missing/empty file) a single empty biclique is
            used so that both shapes stay valid.
        """
        biclique_users, biclique_items = [], []
        b_idx = 0
        if os.path.exists(biclique_file):
            with open(biclique_file, 'r') as f:
                lines = f.readlines()
            i = 0
            while i < len(lines) and lines[i].split():
                try:
                    us = [int(x) for x in lines[i + 1].split()]
                    vs = [int(x) for x in lines[i + 2].split()]
                except (IndexError, ValueError):
                    break  # truncated or malformed trailing record
                biclique_users.extend((u, b_idx) for u in us)
                biclique_items.extend((b_idx, v) for v in vs)
                b_idx += 1
                i += 3

        n_bicliques = max(b_idx, 1)
        H_u = self._incidence(biclique_users, (self.num_users, n_bicliques))
        H_v = self._incidence(biclique_items, (n_bicliques, self.num_items))
        return H_v, H_u

    def build_adj_matrix(self, snapshot_data):
        """Return the symmetric-normalised user-item adjacency ``D^-1/2 A D^-1/2``.

        ``A`` is the ``(num_users + num_items)`` square bipartite adjacency with
        users first. Repeated interactions count as a single edge. Isolated
        nodes get a zero row instead of an inf.

        Returns:
            A float32 ``torch`` sparse COO tensor.
        """
        users = np.asarray([x[0] for x in snapshot_data], dtype=np.int64)
        items = np.asarray([x[1] for x in snapshot_data], dtype=np.int64)
        R = sp.coo_matrix((np.ones(len(users)), (users, items)),
                          shape=(self.num_users, self.num_items)).tocsr()
        R.sum_duplicates()
        R.data[:] = 1.0  # binary interaction matrix
        A = sp.bmat([[None, R], [R.T, None]], format='coo')

        rowsum = np.asarray(A.sum(axis=1)).ravel()
        with np.errstate(divide='ignore'):
            d_inv_sqrt = np.power(rowsum, -0.5)
        d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
        d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
        norm_A = d_mat_inv_sqrt.dot(A).dot(d_mat_inv_sqrt).tocoo()

        # One contiguous array avoids torch's slow list-of-ndarrays conversion.
        indices = torch.from_numpy(np.vstack((norm_A.row, norm_A.col)).astype(np.int64))
        values = torch.from_numpy(norm_A.data.astype(np.float32))
        return torch.sparse_coo_tensor(indices, values, size=norm_A.shape)

# ==========================================
# 3. Model Architecture (Fixed FullTSBCL)
# ==========================================
class BicliqueEnhancedEncoder(nn.Module):
    """Local (biclique) view of the users.

    Item embeddings are mean-pooled into one prototype per biclique, and each
    user then averages the prototypes of the bicliques it belongs to. Users in
    no biclique get a zero vector.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim

    @staticmethod
    def _safe_row_degree(incidence):
        """Row sums of a sparse incidence matrix as a column; zeros become 1."""
        deg = torch.sparse.sum(incidence, dim=1).to_dense().view(-1, 1)
        return deg.masked_fill(deg == 0, 1.0)

    def forward(self, user_emb, item_emb, biclique_indices):
        H_v, H_u = biclique_indices
        prototypes = torch.sparse.mm(H_v, item_emb) / self._safe_row_degree(H_v)
        pooled = torch.sparse.mm(H_u, prototypes)
        return pooled / self._safe_row_degree(H_u)

class LightGCNEncoder(nn.Module):
    """Parameter-free LightGCN propagation.

    The concatenated user/item embeddings are propagated ``n_layers`` times
    through the normalised adjacency; the final embedding is the mean over the
    input and every propagated layer.
    """

    def __init__(self, num_users, num_items, embedding_dim, n_layers=3):
        super().__init__()
        # num_users / num_items / embedding_dim are accepted for interface
        # compatibility only; propagation has no learnable parameters.
        self.n_layers = n_layers

    def forward(self, user_emb, item_emb, adj_matrix):
        n_users = user_emb.shape[0]
        layer_outputs = [torch.cat([user_emb, item_emb], dim=0)]
        for _ in range(self.n_layers):
            layer_outputs.append(torch.sparse.mm(adj_matrix, layer_outputs[-1]))
        averaged = torch.stack(layer_outputs, dim=1).mean(dim=1)
        return averaged[:n_users], averaged[n_users:]

class FullTSBCL(nn.Module):
    """Time-aware biclique contrastive recommender.

    One ``forward`` call processes one snapshot: LightGCN propagation gives a
    global view, biclique pooling gives a local user view, and a GRUCell folds
    the global user view into the running per-user state.

    Args:
        num_users, num_items: Sizes of the embedding tables.
        embedding_dim: Embedding size, also used as the GRU state size.
        n_layers: LightGCN propagation depth.
        tau: Temperature of the InfoNCE contrastive loss.
        cl_weight: Weight of the contrastive term relative to the BPR term.
    """

    def __init__(self, num_users, num_items, embedding_dim=64, n_layers=3, tau=0.2, cl_weight=0.1):
        super().__init__()
        self.tau = tau
        self.cl_weight = cl_weight
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)
        self.global_encoder = LightGCNEncoder(num_users, num_items, embedding_dim, n_layers)
        self.local_encoder = BicliqueEnhancedEncoder(embedding_dim)
        self.user_gru = nn.GRUCell(embedding_dim, embedding_dim)
        nn.init.xavier_uniform_(self.user_embedding.weight)
        nn.init.xavier_uniform_(self.item_embedding.weight)

    def forward(self, adj_matrix, biclique_matrices, user_history_state=None):
        """Run one snapshot step.

        Args:
            adj_matrix: Normalised user+item adjacency (sparse).
            biclique_matrices: ``(H_v, H_u)`` incidence matrices.
            user_history_state: GRU state from the previous snapshot, or
                ``None`` to start from zeros.

        Returns:
            ``(new_user_state, u_local, new_user_state, i_global)``. The first
            and third entries are the same tensor; both are kept for interface
            compatibility.
        """
        u_emb = self.user_embedding.weight
        i_emb = self.item_embedding.weight
        u_global, i_global = self.global_encoder(u_emb, i_emb, adj_matrix)
        u_local = self.local_encoder(u_emb, i_emb, biclique_matrices)

        if user_history_state is None:
            user_history_state = torch.zeros_like(u_emb)

        # The GRU output is the time-aware user representation.
        new_user_state = self.user_gru(u_global, user_history_state)
        return new_user_state, u_local, new_user_state, i_global

    def calculate_loss(self, u_final, u_local, i_global, users, pos_items, neg_items):
        """BPR loss plus ``cl_weight`` * InfoNCE between time-aware and biclique views.

        The contrastive term is computed over the *unique* users of the batch.
        Repeated users would otherwise appear as negatives of themselves. Users
        whose biclique view is all-zero (member of no biclique) are excluded,
        because a zero vector is not a meaningful positive. If no user remains,
        only the BPR loss is returned.
        """
        u_curr = u_final[users]
        pos_scores = (u_curr * i_global[pos_items]).sum(dim=1)
        neg_scores = (u_curr * i_global[neg_items]).sum(dim=1)
        bpr_loss = -F.logsigmoid(pos_scores - neg_scores).mean()

        unique_users = torch.unique(users)
        local_view = u_local[unique_users]
        has_local = (local_view != 0).any(dim=1)
        unique_users, local_view = unique_users[has_local], local_view[has_local]
        if unique_users.numel() == 0:
            return bpr_loss

        view1 = F.normalize(u_final[unique_users], dim=1)
        view2 = F.normalize(local_view, dim=1)
        logits = view1 @ view2.t() / self.tau
        # Row i's positive is column i, so InfoNCE == cross-entropy with arange labels.
        targets = torch.arange(logits.shape[0], device=logits.device)
        cl_loss = F.cross_entropy(logits, targets)

        return bpr_loss + self.cl_weight * cl_loss

# ==========================================
# 4. Training Loop
# ==========================================
# InfoNCE temperature for the contrastive loss. Kept separate from TAU, which is
# an argument of the MSBE biclique miner rather than a softmax temperature.
CL_TEMPERATURE = 0.2


def group_items_by_user(interactions):
    """Return ``{user: set(items)}`` for ``(user, item, timestamp)`` tuples."""
    items_by_user = {}
    for user, item, _ in interactions:
        items_by_user.setdefault(user, set()).add(item)
    return items_by_user


def evaluate(model, test_data, utils, device, history_state, biclique_matrices, context_data=None):
    """Compute Recall@TOP_K and NDCG@TOP_K on *test_data*.

    Args:
        model: A ``FullTSBCL``-compatible model.
        test_data: ``(user, item, timestamp)`` tuples used as ground truth.
        utils: ``DataUtils`` used to build the propagation graph.
        device: Torch device for the graph and biclique matrices.
        history_state: GRU state fed to the model (may be ``None``).
        biclique_matrices: ``(H_v, H_u)`` for the model's local view.
        context_data: Interactions used to build the propagation graph. It
            defaults to *test_data* for backward compatibility, but that leaks
            the ground-truth edges into the embeddings. Pass observed
            (training) interactions instead.

    Returns:
        ``(recall, ndcg)`` averaged over up to 1000 randomly sampled test users,
        or ``(0.0, 0.0)`` when *test_data* is empty.
    """
    model.eval()
    graph_data = test_data if context_data is None else context_data
    adj = utils.build_adj_matrix(graph_data).to(device)
    H_v, H_u = (mat.to(device) for mat in biclique_matrices)

    # Group once: O(N) instead of rescanning test_data for every user.
    ground_truth = group_items_by_user(test_data)
    test_users = list(ground_truth)
    if len(test_users) > 1000:
        test_users = random.sample(test_users, 1000)
    if not test_users:
        return 0.0, 0.0

    recall_sum, ndcg_sum = 0.0, 0.0
    with torch.no_grad():
        u_final, _, _, i_global = model(adj, (H_v, H_u), history_state)
        k = min(TOP_K, i_global.shape[0])  # topk fails if k exceeds the item count
        # discounts[r] = 1 / log2(r + 2): DCG weight of the 0-based rank r.
        discounts = 1.0 / np.log2(np.arange(k) + 2)

        for u in test_users:
            relevant = ground_truth[u]
            scores = i_global @ u_final[u]
            pred_items = torch.topk(scores, k).indices.tolist()
            hit_ranks = [rank for rank, item in enumerate(pred_items) if item in relevant]
            recall_sum += len(hit_ranks) / len(relevant)
            dcg = discounts[hit_ranks].sum()
            idcg = discounts[:min(len(relevant), k)].sum()
            ndcg_sum += dcg / idcg if idcg > 0 else 0.0

    n_users = len(test_users)
    return float(recall_sum) / n_users, float(ndcg_sum) / n_users


def _prepare_snapshot(utils, snapshot, snapshot_id):
    """Mine bicliques and build the normalised adjacency for one snapshot, on `device`."""
    biclique_file = utils.run_msbe_mining(snapshot, snapshot_id, tau=TAU, epsilon=EPSILON)
    H_v, H_u = utils.parse_bicliques(biclique_file)
    adj = utils.build_adj_matrix(snapshot)
    return adj.to(device), (H_v.to(device), H_u.to(device))


def main():
    """Train FullTSBCL over the training snapshots and evaluate on the last one."""
    print("Initializing DataUtils...")
    utils = DataUtils(DATA_PATH, MSBE_EXE)
    all_data = utils.load_data()
    if not all_data:
        return

    snapshots = utils.split_snapshots(all_data, NUM_SNAPSHOTS)
    train_snapshots = snapshots[:-1]
    test_data = snapshots[-1]
    if not train_snapshots:
        print("Need at least two snapshots (training + test); increase NUM_SNAPSHOTS.")
        return

    # Mining/adjacency do not depend on model parameters: compute them once.
    print("Preparing training snapshots...")
    prepared = [_prepare_snapshot(utils, snap, f"train_{t}") for t, snap in enumerate(train_snapshots)]

    print("Initializing Model...")
    model = FullTSBCL(utils.num_users, utils.num_items, EMBEDDING_DIM, tau=CL_TEMPERATURE).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LR)

    best_recall = 0.0

    print(f"Starting Training for {EPOCHS} epochs...")
    for epoch in range(EPOCHS):
        start_time = time.time()
        epoch_loss, steps = 0.0, 0

        # Replay the snapshot sequence from a fresh GRU state every epoch.
        current_history_state = None
        state_before_last = None

        # --- Training Phase ---
        model.train()
        for snapshot, (adj, bicliques) in zip(train_snapshots, prepared):
            state_before_last = current_history_state
            pos_interactions = [(u, i) for u, i, _ in snapshot]
            random.shuffle(pos_interactions)

            for start in range(0, len(pos_interactions), BATCH_SIZE):
                batch = pos_interactions[start:start + BATCH_SIZE]
                users = torch.tensor([u for u, _ in batch], dtype=torch.long, device=device)
                pos_items = torch.tensor([i for _, i in batch], dtype=torch.long, device=device)
                neg_items = torch.randint(0, utils.num_items, (len(batch),), device=device)

                # Full-graph forward per mini-batch (standard LightGCN practice).
                # A graph built before optimizer.step() cannot be backpropagated
                # again: the in-place parameter update invalidates its saved
                # tensors and gradients would be stale.
                optimizer.zero_grad()
                u_final, u_local, _, i_global = model(adj, bicliques, current_history_state)
                loss = model.calculate_loss(u_final, u_local, i_global, users, pos_items, neg_items)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                steps += 1

            # Advance the GRU state with the updated parameters. No gradient
            # flows across snapshots (truncated BPTT).
            with torch.no_grad():
                _, _, new_state, _ = model(adj, bicliques, current_history_state)
            current_history_state = new_state

        # --- Evaluation Phase ---
        # Replaying the last training snapshot from the state that preceded it
        # yields exactly the GRU state after all training snapshots. Propagation
        # uses only observed (training) edges, so no test edge leaks in.
        recall, ndcg = evaluate(model, test_data, utils, device, state_before_last,
                                prepared[-1][1], context_data=train_snapshots[-1])

        avg_loss = epoch_loss / steps if steps else float("nan")
        print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_loss:.4f} | Recall@{TOP_K}: {recall:.4f} | NDCG@{TOP_K}: {ndcg:.4f} | Time: {time.time()-start_time:.1f}s")

        if recall > best_recall:
            best_recall = recall
            torch.save(model.state_dict(), "full_tsbcl_fixed_best.pth")
            print(f"  >>> Best Model Saved (Recall: {best_recall:.4f})")


if __name__ == "__main__":
    main()
