# Run MSBEGCL on Kaggle

This notebook sets up the environment, compiles the necessary C++ mining tools, prepares the data, runs the mining algorithm to generate bicliques, and finally trains the MSBEGCL recommender system.

In [None]:
import os, sys, subprocess, time, shutil, struct

# --- Configuration ---
repo_url = 'https://github.com/yangzeha/MSBEGCL.git'
repo_dir = 'MSBEGCL'
model_name = 'MSBEGCL'
dataset_name = 'yelp2018'

# 1. Clean and Clone Repository (Silent)
# Always start from a fresh checkout: remove any previous clone first.
if os.path.exists(repo_dir):
    try:
        shutil.rmtree(repo_dir)
    except OSError:
        # Python could not remove the tree; fall back to the shell (best effort).
        subprocess.run(['rm', '-rf', repo_dir], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

try:
    subprocess.run(['git', 'clone', '-b', 'master', repo_url], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
except (subprocess.CalledProcessError, OSError) as e:
    # git's own output is suppressed above, so state the reason before stopping.
    # sys.exit() with a string prints it to stderr and exits with status 1.
    sys.exit(f"CRITICAL ERROR: could not clone {repo_url}: {e}")

# 2. Setup Directories (Silent)
# Enter the clone unless the working directory is already named after it.
if os.path.basename(os.getcwd()) != repo_dir:
    os.chdir(repo_dir)

# [Robustness Fix]: Auto-detect nested structure
# The project root is the directory that contains both required sub-folders.
roots = os.listdir('.')
possible_subdirs = ['.', 'MSBEGCL', 'msbegcl', repo_dir]
required_entries = ('SELFRec', 'Similar-Biclique-Idx')
target_structure_found = False

for d in possible_subdirs:
    # '.' always exists; any other candidate must be an existing directory.
    if d != '.' and not os.path.isdir(d):
        continue
    contents = os.listdir(d)
    if all(entry in contents for entry in required_entries):
        if d != '.':
            os.chdir(d)
        target_structure_found = True
        break

if not target_structure_found:
    # Fallback: take the first directory in the tree that has a SELFRec child.
    located_root = next(
        (root for root, dirs, _files in os.walk('.') if 'SELFRec' in dirs),
        None,
    )
    found = located_root is not None
    if found:
        os.chdir(located_root)
    else:
        print("CRITICAL ERROR: Could not locate SELFRec directory anywhere.")

# Paths below are relative to the directory selected above.
selfrec_path = 'SELFRec'
msbe_path = 'Similar-Biclique-Idx'

# 3. Install Dependencies (Silent)
# The pinned packages are mandatory: check=True aborts the run if pip fails.
subprocess.run([sys.executable, '-m', 'pip', 'install', 'PyYAML==6.0.2', 'scipy==1.14.1', '-q'], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

# faiss-cpu is installed on a best-effort basis: a failure is reported but does
# not stop the script. Only installer/launch errors are caught, so Ctrl-C and
# genuine programming errors are no longer swallowed.
try:
    subprocess.run([sys.executable, '-m', 'pip', 'install', 'faiss-cpu', '-q'], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
except (subprocess.CalledProcessError, OSError) as e:
    print(f"Warning: could not install faiss-cpu ({e}); continuing without it.")

# 4. Compile C++ Mining Tools (Silent)
# sparsehash is cloned and built once; an existing directory is reused as-is.
sparsez_dir = 'sparsehash'
if not os.path.exists(sparsez_dir):
    subprocess.run(['git', 'clone', 'https://github.com/sparsehash/sparsehash.git'], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    cwd_backup = os.getcwd()
    os.chdir(sparsez_dir)
    try:
        subprocess.run(['chmod', '+x', 'configure'], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        subprocess.run(['./configure'], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        subprocess.run(['make'], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    except (subprocess.CalledProcessError, OSError) as e:
        # Still best effort (the script carries on), but no longer silent:
        # a broken sparsehash build is the likely cause of a later g++ failure.
        print(f"Warning: sparsehash build step failed ({e}); compiling msbe may fail.")
    finally:
        # Always return to the project directory, even if the build failed.
        os.chdir(cwd_backup)

# Compile msbe
msbe_src = os.path.join(msbe_path, 'main.cpp')
msbe_exe = './msbe'
if not os.path.exists(msbe_src):
    print(f"CRITICAL ERROR: Source file {msbe_src} not found!")
else:
    try:
        # stderr is captured instead of discarded so compiler errors can be shown.
        subprocess.run(['g++', '-w', '-O3', msbe_src, '-o', msbe_exe, '-I', msbe_path, '-I', 'sparsehash/src', '-D_PrintResults_', '-D_CheckResults_'], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, text=True)
    except subprocess.CalledProcessError as e:
        print(f"CRITICAL ERROR: g++ failed to build {msbe_src}:\n{e.stderr}")
        raise  # same abort as before, now with the compiler diagnostics visible
    subprocess.run(['chmod', '+x', msbe_exe], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

# 5. Data Preprocessing
# Reads the training interactions, renumbers users and items into one dense id
# space (users first, then items offset by the user count) and writes the
# binary degree/adjacency files next to a placeholder graph.txt.
print(f'\n--- Preprocessing {dataset_name} for Mining ---')
train_file = os.path.join(selfrec_path, 'dataset', dataset_name, 'train.txt')
mining_graph_txt = 'graph.txt'

# Defined up front so the later biclique-formatting step can still read these
# names (and simply finds nothing to map) when the training file is missing.
sorted_users, sorted_items = [], []
n1 = n2 = 0

if not os.path.exists(train_file):
    print(f"CRITICAL ERROR: Data file {train_file} not found!")
else:
    users = set()
    items = set()
    edges = []
    with open(train_file, 'r') as f:
        for line in f:
            parts = line.split()
            # Only the first two columns (user, item) are used; shorter lines are skipped.
            if len(parts) >= 2:
                u, i = parts[0], parts[1]
                users.add(u)
                items.add(i)
                edges.append((u, i))

    try:
        # Prefer numeric order when every id parses as an integer.
        sorted_users = sorted(users, key=int)
        sorted_items = sorted(items, key=int)
    except ValueError:
        # At least one id is not an integer: order both sides lexicographically.
        sorted_users = sorted(users)
        sorted_items = sorted(items)

    u_map = {u: idx for idx, u in enumerate(sorted_users)}
    i_map = {i: idx for idx, i in enumerate(sorted_items)}

    n1 = len(users)
    n2 = len(items)

    print(f'Preprocessing graph with {n1} users, {n2} items, {len(edges)} edges.')

    total_nodes = n1 + n2
    adj = [[] for _ in range(total_nodes)]

    for u, i in edges:
        uid = u_map[u]
        iid = i_map[i] + n1  # item ids follow the user ids
        adj[uid].append(iid)
        adj[iid].append(uid)
    # Each interaction is stored in both directions.
    edge_count = 2 * len(edges)

    for neighbours in adj:
        neighbours.sort()

    degree_file = 'graph_b_degree.bin'
    with open(degree_file, 'wb') as f:
        # Header: four native unsigned ints, then one degree per node.
        # NOTE(review): the leading 4 is presumably the integer width in bytes
        # expected by the msbe reader -- confirm against the C++ source.
        f.write(struct.pack('I', 4))
        f.write(struct.pack('I', n1))
        f.write(struct.pack('I', n2))
        f.write(struct.pack('I', edge_count))
        degrees = [len(neighbours) for neighbours in adj]
        f.write(struct.pack(f'{total_nodes}I', *degrees))

    adj_file = 'graph_b_adj.bin'
    with open(adj_file, 'wb') as f:
        # All sorted neighbour lists concatenated in node order.
        flat_adj = [node for neighbours in adj for node in neighbours]
        f.write(struct.pack(f'{edge_count}I', *flat_adj))

    print(f"Generated binary graph files.")

    # The mining step checks that this file exists and passes its name to msbe;
    # the graph data itself lives in the two .bin files written above.
    with open(mining_graph_txt, 'w') as f:
        f.write("dummy")

# 6. Run Mining
# Two msbe invocations: the first builds the index, the second enumerates
# bicliques and has its stdout redirected into bicliques_raw.txt.
print('\n--- Mining Bicliques (Strict Mode) ---')
sim_threshold = 0.3    # Updated per request
size_threshold = 3     # Updated per request

mining_ready = os.path.exists(msbe_exe) and os.path.exists(mining_graph_txt)
if not mining_ready:
    print("Skipping mining due to compliation or data failure.")
else:
    threshold_arg = str(sim_threshold)

    print('Building Index...')
    build_index_cmd = [msbe_exe, mining_graph_txt, '1', '1', threshold_arg, 'GRL3']
    subprocess.run(build_index_cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

    print('Enumerating...')
    raw_bicliques_file = 'bicliques_raw.txt'
    enumerate_cmd = [msbe_exe, mining_graph_txt]
    enumerate_cmd += ['0', '1', threshold_arg, 'GRL3']
    enumerate_cmd += ['1', 'GRL3']
    enumerate_cmd += ['0', '0', 'heu']
    enumerate_cmd += ['4', threshold_arg, str(size_threshold), '2']
    with open(raw_bicliques_file, 'w') as outfile:
        subprocess.run(enumerate_cmd, stdout=outfile, stderr=subprocess.DEVNULL, check=True)

    # Report how much output the enumeration produced.
    if os.path.exists(raw_bicliques_file):
        size = os.path.getsize(raw_bicliques_file)
        print(f"Mining output file size: {size} bytes")

# 7. Process Bicliques -> Model Format
# Turns each raw output line into "<users> | <items>" using the original ids:
# numbers below n1 are user positions, the rest are item positions offset by n1.
print('\n--- Formatting Bicliques for Model ---')
final_biclique_path = os.path.join(selfrec_path, 'dataset', dataset_name, 'bicliques.txt')
count = 0

if not os.path.exists('bicliques_raw.txt'):
    print("Warning: bicliques_raw.txt not found.")
else:
    with open('bicliques_raw.txt', 'r') as fr, open(final_biclique_path, 'w') as fw:
        for raw_line in fr:
            stripped = raw_line.strip()
            if not stripped:
                continue
            # Treat '|', ',' and ':' as plain whitespace separators.
            spaced = stripped.replace('|', ' ').replace(',', ' ').replace(':', ' ')
            current_users = []
            current_items = []
            for token in spaced.split():
                if not token.isdigit():
                    continue  # ignore anything that is not a plain number
                node_id = int(token)
                if node_id < n1:
                    if node_id < len(sorted_users):
                        current_users.append(sorted_users[node_id])
                    continue
                item_pos = node_id - n1
                if 0 <= item_pos < len(sorted_items):
                    current_items.append(sorted_items[item_pos])
            # Keep only lines that yielded at least one user and one item.
            if current_users and current_items:
                fw.write(f"{' '.join(current_users)} | {' '.join(current_items)}\n")
                count += 1
    print(f"Processed {count} bicliques into {final_biclique_path}")

# 8. Update Configuration
# Rewrites selected `key: value` lines of MSBEGCL.yaml in place (text-level edit).
conf_path = os.path.join(selfrec_path, 'conf', 'MSBEGCL.yaml')

if os.path.exists(conf_path):
    import re

    def _set_conf_value(text, key, value):
        """Rewrite every line of the form `<indent>key: ...` to `<indent>key: value`.

        The key is escaped and anchored to the start of a line, so `tau` does
        not touch a longer key ending in `tau` and the '.' in `batch.size` is
        matched literally. Returns (new_text, number_of_lines_rewritten).
        """
        pattern = re.compile(rf'^([ \t]*){re.escape(key)}:.*$', re.MULTILINE)
        return pattern.subn(lambda m: f'{m.group(1)}{key}: {value}', text)

    with open(conf_path, 'r') as f:
        conf_content = f.read()

    new_path = f'./dataset/{dataset_name}/bicliques.txt'

    # Core Parameters: replaced only where the key already exists.
    core_params = (
        ('biclique.file', new_path),
        ('lambda', '0.3'),
        ('gamma', '0.8'),
        ('n_layer', '2'),
    )
    for key, value in core_params:
        conf_content, _ = _set_conf_value(conf_content, key, value)

    # Ensure all new params exist or replace them
    updates = {
        'tau': '0.15',
        'eps': '0.1',
        'alpha': '0.1',
        'local_k': '3',
        'top_k_neighbors': '5',
        'sim_threshold': '0.3',
        'warmup_epochs': '10',
        'use_dynamic_weight': 'true',
        'batch.size': '4096',
    }

    for key, value in updates.items():
        conf_content, replaced = _set_conf_value(conf_content, key, value)
        if replaced == 0:
            # Decide "missing" from the actual replacement count; the previous
            # substring test could report a key as present while the regex
            # matched nothing, leaving the parameter unset.
            # NOTE(review): appended keys land at column 0 at the end of the
            # file, while the model code written below reads its options from
            # config['MSBEGCL'] -- verify the appended keys end up in that section.
            conf_content += f'\n{key}: {value}'

    with open(conf_path, 'w') as f:
        f.write(conf_content)
    print("Updated MSBEGCL.yaml with Advanced Config.")

# [Patching] Write Helper Files
print('--- Patching Model Code in Notebook ---')

# 1. denoising_helper.py
# Source of SELFRec/util/denoising_helper.py, written to disk below.
denoising_code = r'''
import numpy as np
import random
from collections import defaultdict
from itertools import combinations

def filter_large_bicliques(bicliques, max_size=30):
    """Drop bicliques whose user side or item side exceeds max_size (likely noise)."""
    return [
        (users, items)
        for users, items in bicliques
        if len(users) <= max_size and len(items) <= max_size
    ]

def compute_jaccard_similarity(set1, set2):
    """Return |set1 & set2| / |set1 | set2|, or 0 when both sets are empty."""
    intersection = len(set1.intersection(set2))
    union = len(set1.union(set2))
    return intersection / union if union > 0 else 0

def _interaction_set(interaction_mat, user):
    """Return the set of column indices `user` interacted with, or None if the lookup fails."""
    try:
        # Sparse row: the stored column indices are the interactions.
        return set(interaction_mat[user].indices)
    except Exception:
        # Fallback for non-sparse or different structure
        try:
            return set(np.where(interaction_mat[user] > 0)[0])
        except Exception:
            return None

def prune_low_similarity_clusters(bicliques, interaction_mat, min_sim=0.2):
    """Keep bicliques whose users have a mean pairwise Jaccard similarity >= min_sim.

    Each entry of a biclique's user list is used directly as a row index into
    interaction_mat. Users whose row cannot be read are left out of the pair
    computation; a biclique with no comparable pair scores 0.
    """
    pruned = []
    for users, items in bicliques:
        # Read every user's interaction set once instead of once per pair.
        item_sets = [
            s for s in (_interaction_set(interaction_mat, u) for u in users)
            if s is not None
        ]
        user_sims = [
            compute_jaccard_similarity(a, b)
            for a, b in combinations(item_sets, 2)
        ]

        avg_sim = np.mean(user_sims) if user_sims else 0
        if avg_sim >= min_sim:
            pruned.append((users, items))
    return pruned

def _cap_neighbors(neighbor_lists, top_k):
    """De-duplicate each neighbor list and randomly keep at most top_k entries."""
    capped = {}
    for node, neighbors in neighbor_lists.items():
        unique_neighbors = list(set(neighbors))
        if len(unique_neighbors) > top_k:
            capped[node] = random.sample(unique_neighbors, top_k)
        else:
            capped[node] = unique_neighbors
    return capped

def build_enhanced_neighbor_dict(bicliques, user_map, item_map, top_k=10):
    """Build per-user and per-item neighbor lists from biclique co-membership.

    Ids are translated through user_map / item_map (unknown ids are skipped);
    every other member on the same side of a biclique counts as a neighbor.
    Returns (user_neighbors, item_neighbors), each capped at top_k per node.
    """
    user_neighbors = defaultdict(list)
    item_neighbors = defaultdict(list)

    for users, items in bicliques:
        user_ids = [user_map[u] for u in users if u in user_map]
        item_ids = [item_map[i] for i in items if i in item_map]

        for u in user_ids:
            for other_u in user_ids:
                if u != other_u: user_neighbors[u].append(other_u)

        for i in item_ids:
            for other_i in item_ids:
                if i != other_i: item_neighbors[i].append(other_i)

    # Users are capped before items, so the order of random draws is unchanged.
    enhanced_user = _cap_neighbors(user_neighbors, top_k)
    enhanced_item = _cap_neighbors(item_neighbors, top_k)

    return enhanced_user, enhanced_item
'''
denoise_path = os.path.join(selfrec_path, 'util', 'denoising_helper.py')
with open(denoise_path, 'w') as f:
    f.write(denoising_code)
print("Created util/denoising_helper.py")

# 2. msbe_helper.py
# Source of SELFRec/util/msbe_helper.py, written to disk below.
msbe_helper_code = r'''
import os
import random
import numpy as np
from util.denoising_helper import filter_large_bicliques, prune_low_similarity_clusters, build_enhanced_neighbor_dict

def load_msbe_neighbors(file_path, user_map, item_map, interaction_mat=None, sim_threshold=0.3):
    """Load bicliques from file_path and derive neighbor dictionaries.

    Each line is expected as "<user ids> | <item ids>"; ids missing from
    user_map / item_map are dropped. Bicliques are size-filtered, optionally
    pruned by user similarity (when interaction_mat is given) and converted
    into (user_neighbors, item_neighbors). Both are empty dicts when the file
    does not exist.
    """
    user_neighbors = {}
    item_neighbors = {}

    print(f"Loading bicliques from {file_path}, sim_threshold={sim_threshold}")

    if os.path.exists(file_path):
        bicliques = []
        with open(file_path, 'r') as f:
            for line in f:
                parts = line.strip().split('|')
                if len(parts) < 2: continue

                users = [u.strip() for u in parts[0].split() if u.strip() in user_map]
                items = [i.strip() for i in parts[1].split() if i.strip() in item_map]

                if len(users) > 0 and len(items) > 0:
                    bicliques.append((users, items))

        print(f"Original bicliques: {len(bicliques)}")

        # Step 1: Filter
        filtered = filter_large_bicliques(bicliques, max_size=30)
        print(f"Filtered bicliques: {len(filtered)}")

        # Step 2: Prune
        if interaction_mat is not None:
            # Note: This step can be slow. If interaction_mat is sparse CSR, logic in helper handles it.
            try:
                # The pruning helper indexes interaction_mat with the user
                # entries it is given, so it must receive mapped indices rather
                # than the raw id strings read from the file (string indices
                # fail every lookup and cause every biclique to be dropped).
                # NOTE(review): assumes user_map values are the row indices of
                # interaction_mat -- confirm against the data loader.
                # Each biclique is checked on its own so the raw ids can be
                # kept for build_enhanced_neighbor_dict, which maps them itself.
                pruned = []
                for users, items in filtered:
                    user_rows = [user_map[u] for u in users]
                    if prune_low_similarity_clusters([(user_rows, items)], interaction_mat, min_sim=sim_threshold):
                        pruned.append((users, items))
            except Exception as e:
                print(f"Warning: Pruning failed ({e}), using filtered only.")
                pruned = filtered
            print(f"Pruned bicliques: {len(pruned)}")
        else:
            pruned = filtered

        # Step 3: Global Neighbor Dict
        user_neighbors, item_neighbors = build_enhanced_neighbor_dict(
            pruned, user_map, item_map, top_k=15
        )

        u_c = sum(1 for v in user_neighbors.values() if len(v) > 0)
        i_c = sum(1 for v in item_neighbors.values() if len(v) > 0)
        print(f"Enhanced Neighbors: {u_c} users, {i_c} items.")

    return user_neighbors, item_neighbors
'''
helper_path = os.path.join(selfrec_path, 'util', 'msbe_helper.py')
with open(helper_path, 'w') as f:
    f.write(msbe_helper_code)
print("Patched util/msbe_helper.py")

# 3. MSBEGCL.py
# Source of SELFRec/model/graph/MSBEGCL.py, written to disk below.
model_code = r'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from base.graph_recommender import GraphRecommender
from util.sampler import next_batch_pairwise
from base.torch_interface import TorchGraphInterface
from util.loss_torch import bpr_loss, l2_reg_loss, InfoNCE
import random
import numpy as np
from util.msbe_helper import load_msbe_neighbors

class MSBEGCL(GraphRecommender):
    """Graph recommender combining BPR, view-contrast, biclique-structure and local-similarity losses."""

    def __init__(self, conf, training_set, test_set):
        super(MSBEGCL, self).__init__(conf, training_set, test_set)
        args = self.config['MSBEGCL']

        # Core Params
        self.n_layers = int(args['n_layer'])
        self.eps = float(args.get('eps', 0.1))
        self.tau = float(args.get('tau', 0.2))

        # Loss Weights
        self.base_cl_rate = float(args['lambda'])  # Uniformity
        self.base_msb_rate = float(args['gamma'])  # Structure
        self.local_rate = float(args.get('alpha', 0.1)) # Local Sim
        self.top_k = int(args.get('top_k_neighbors', 5))
        self.local_k = int(args.get('local_k', 3))

        # Dynamic Schedule
        self.warmup_epochs = int(args.get('warmup_epochs', 10))
        self.use_dynamic_weight = args.get('use_dynamic_weight', True)

        print(f"MSBEGCL Config: lambda={self.base_cl_rate}, gamma={self.base_msb_rate}, tau={self.tau}")
        print(f"Local Sim: alpha={self.local_rate}, local_k={self.local_k}")

        self.model = MSBEGCL_Encoder(self.data, self.emb_size, self.eps, self.n_layers)

        # Load Biclique Neighbors
        self.biclique_file = args['biclique.file']
        self.sim_threshold = float(args.get('sim_threshold', 0.3))
        self.user_msb_neighbors, self.item_msb_neighbors = self.load_denoised_neighbors()

        # Cache for Local Similarity
        self.user_sim_cache = {}
        self.item_sim_cache = {}

    def load_denoised_neighbors(self):
        """Return (user_neighbors, item_neighbors) built from the biclique file."""
        # We pass interaction_mat (typically SciPy sparse CSR) to the helper
        user_neighbors, item_neighbors = load_msbe_neighbors(
            self.biclique_file,
            self.data.user,
            self.data.item,
            self.data.interaction_mat,
            self.sim_threshold
        )
        return user_neighbors, item_neighbors

    def compute_local_similarity(self, embeddings, indices, cache_dict):
        """Batch-wise Local Similarity Calculation

        Returns the dot-product similarity between the rows of `embeddings`
        selected by `indices`, with the diagonal masked out. Row/column
        positions of the result refer to positions inside `indices`.
        `cache_dict` is accepted but not used.
        """
        batch_embeddings = embeddings[indices]
        sim_matrix = torch.matmul(batch_embeddings, batch_embeddings.T)

        # Mask self
        mask = torch.eye(len(indices), device=embeddings.device).bool()
        sim_matrix = sim_matrix.masked_fill(mask, -1e9)
        return sim_matrix

    def local_sim_loss(self, view1, view2, indices, neighbor_dict, cache_key='user'):
        """Local-similarity loss over the batch nodes in `indices`.

        Combines a biclique-anchor contrast term (applied for cache_key ==
        'user' only) with InfoNCE alignment between each node and its top-k
        most similar batch nodes in the other view.
        """
        if len(indices) < 2: return torch.tensor(0.0, device=view1.device)

        # Compute similarities within views
        sim_matrix_v1 = self.compute_local_similarity(view1, indices, self.user_sim_cache)
        sim_matrix_v2 = self.compute_local_similarity(view2, indices, self.user_sim_cache)

        # Select top-k similar nodes (mining hard positives/local context)
        k_val = min(self.local_k, len(indices)-1)
        topk_v1 = torch.topk(sim_matrix_v1, k=k_val, dim=1)
        topk_v2 = torch.topk(sim_matrix_v2, k=k_val, dim=1)

        loss = 0.0
        batch_size = len(indices)

        # 1. Structural Anchor Contrast (if available)
        # NOTE(review): this term reads self.user_msb_neighbors and is skipped
        # for cache_key='item'; the neighbor_dict argument is not consulted.
        for i in range(batch_size):
            node_idx = indices[i].item()
            if cache_key == 'user' and node_idx in self.user_msb_neighbors:
                msb_neighbors = self.user_msb_neighbors[node_idx]
                if len(msb_neighbors) > 0:
                    sampled = random.sample(msb_neighbors, min(self.top_k, len(msb_neighbors)))
                    # Structure view from view2
                    struct_emb = torch.mean(view2[sampled], dim=0)
                    anchor = view1[node_idx]

                    # Contrastive Term
                    pos_score = torch.exp(torch.dot(anchor, struct_emb) / self.tau)
                    self_score = torch.exp(torch.dot(anchor, view2[node_idx]) / self.tau)
                    loss += -torch.log(pos_score / (self_score + 1e-8))

        # 2. Local Neighborhood Alignment
        anchors_v1 = view1[indices]
        anchors_v2 = view2[indices]
        for k in range(k_val):
            # topk positions are relative to the batch (rows of the similarity
            # matrix), so translate them back to global ids through `indices`
            # before reading from the full embedding tables.
            pos_v1 = view1[indices[topk_v1.indices[:, k]]]
            pos_v2 = view2[indices[topk_v2.indices[:, k]]]

            loss += 0.5 * InfoNCE(anchors_v1, pos_v2, self.tau)
            loss += 0.5 * InfoNCE(anchors_v2, pos_v1, self.tau)

        return loss / (batch_size * (self.local_k + 1))

    def cal_enhanced_msbe_loss(self, independent_indices, neighbor_dict, view1, view2):
        """InfoNCE between each node (view1) and the mean of sampled biclique neighbors (view2).

        Nodes without neighbors in `neighbor_dict` are skipped; returns 0.0
        when no node in the batch has neighbors.
        """
        loss = 0.0
        nodes_cpu = independent_indices.cpu().tolist()
        anchors, positives = [], []

        # This class does not set a `training` attribute in this file, so fall
        # back to the encoder module's flag when the recommender has none.
        is_training = getattr(self, 'training', self.model.training)

        for idx in nodes_cpu:
            if idx in neighbor_dict and len(neighbor_dict[idx]) > 0:
                neighbors = neighbor_dict[idx]
                k = min(self.top_k, len(neighbors))
                sampled_neighbors = random.sample(neighbors, k)

                # Aggregate Neighbor Representations
                neighbor_embs = view2[sampled_neighbors] # Use other view
                pos_emb = torch.mean(neighbor_embs, dim=0)

                # Robustness: Add noise (Stochastic Structure)
                if is_training and random.random() > 0.5:
                    noise = torch.randn_like(pos_emb) * 0.01
                    pos_emb = F.normalize(pos_emb + noise, dim=-1)

                anchors.append(view1[idx])
                positives.append(pos_emb)

        if len(anchors) > 0:
            anchors = torch.stack(anchors)
            positives = torch.stack(positives)
            loss = InfoNCE(anchors, positives, self.tau)

            # Additional Hard Negative Penalty
            if len(anchors) > 1:
                sim_matrix = torch.matmul(anchors, positives.T)
                mask = torch.eye(len(anchors), device=anchors.device).bool()
                sim_matrix = sim_matrix.masked_fill(mask, -1e9)
                hardest_neg = torch.max(sim_matrix, dim=1)[0]
                loss += 0.1 * torch.mean(hardest_neg)

        return loss

    def train(self):
        """Run the training loop, evaluating after every epoch."""
        model = self.model.cuda()
        optimizer = torch.optim.Adam(model.parameters(), lr=self.lRate)

        for epoch in range(self.maxEpoch):
            # Dynamic Weights
            if self.use_dynamic_weight:
                # max(1, ...) keeps warmup_epochs = 0 from dividing by zero.
                warmup_ratio = min(1.0, (epoch + 1) / max(1, self.warmup_epochs))
                cl_rate = self.base_cl_rate * warmup_ratio
                msb_rate = self.base_msb_rate * warmup_ratio
                local_rate = self.local_rate * warmup_ratio
            else:
                cl_rate, msb_rate, local_rate = self.base_cl_rate, self.base_msb_rate, self.local_rate

            epoch_rec_loss = epoch_cl_loss = epoch_msb_loss = epoch_local_loss = 0
            batch_count = 0

            for n, batch in enumerate(next_batch_pairwise(self.data, self.batch_size)):
                user_idx, pos_idx, neg_idx = batch

                # 1. Main Rec (Clean)
                rec_user_emb, rec_item_emb = model(perturbed=False)
                user_emb = rec_user_emb[user_idx]
                pos_item_emb = rec_item_emb[pos_idx]
                neg_item_emb = rec_item_emb[neg_idx]
                l_rec = bpr_loss(user_emb, pos_item_emb, neg_item_emb)

                # 2. Augmentation Views
                user_v1, item_v1 = model(perturbed=True)
                user_v2, item_v2 = model(perturbed=True)

                u_uniq = torch.unique(torch.tensor(user_idx).cuda())
                i_uniq = torch.unique(torch.tensor(pos_idx).cuda())

                # 3. Uniformity (SimGCL)
                l_uniform = InfoNCE(user_v1[u_uniq], user_v2[u_uniq], self.tau) + \
                            InfoNCE(item_v1[i_uniq], item_v2[i_uniq], self.tau)

                # 4. Structural Loss (MSBE Enhanced)
                l_struct = self.cal_enhanced_msbe_loss(u_uniq, self.user_msb_neighbors, user_v1, user_v2) + \
                           self.cal_enhanced_msbe_loss(i_uniq, self.item_msb_neighbors, item_v1, item_v2)

                # 5. Local Similarity (GLSCL Style)
                l_local = self.local_sim_loss(user_v1, user_v2, u_uniq, self.user_msb_neighbors, 'user') + \
                          self.local_sim_loss(item_v1, item_v2, i_uniq, self.item_msb_neighbors, 'item')

                # Total Loss
                total_loss = l_rec + \
                             cl_rate * l_uniform + \
                             msb_rate * l_struct + \
                             local_rate * l_local + \
                             l2_reg_loss(self.reg, user_emb, pos_item_emb, neg_item_emb)

                optimizer.zero_grad()
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                epoch_rec_loss += l_rec.item()
                epoch_cl_loss += l_uniform.item()
                epoch_msb_loss += l_struct.item()
                epoch_local_loss += l_local.item()
                batch_count += 1

                if n % 100 == 0 and n > 0:
                    print(f'Epoch {epoch} Batch {n}: Rec={l_rec.item():.4f} Uni={l_uniform.item():.4f} Str={l_struct.item():.4f} Loc={l_local.item():.4f} (G={msb_rate:.3f})')

            # Evaluation
            print(f'\nEpoch {epoch} Avg: Rec={epoch_rec_loss/batch_count:.4f} Uni={epoch_cl_loss/batch_count:.4f} Str={epoch_msb_loss/batch_count:.4f} Loc={epoch_local_loss/batch_count:.4f}')
            with torch.no_grad():
                self.user_emb, self.item_emb = self.model(perturbed=False)
            self.fast_evaluation(epoch)
        self.user_emb, self.item_emb = self.best_user_emb, self.best_item_emb

    def save(self):
        """Snapshot the current unperturbed embeddings as the best ones."""
        with torch.no_grad():
            self.best_user_emb, self.best_item_emb = self.model(perturbed=False)

    def predict(self, u):
        """Return the score of every item for raw user id `u` as a NumPy array."""
        u = self.data.get_user_id(u)
        score = torch.matmul(self.user_emb[u], self.item_emb.transpose(0, 1))
        return score.cpu().numpy()

class MSBEGCL_Encoder(nn.Module):
    """Propagation encoder over the normalized adjacency with optional per-layer noise."""

    def __init__(self, data, emb_size, eps, n_layers):
        super(MSBEGCL_Encoder, self).__init__()
        self.data = data
        self.eps = eps
        self.emb_size = emb_size
        self.n_layers = n_layers
        self.sparse_norm_adj = TorchGraphInterface.convert_sparse_mat_to_tensor(data.norm_adj).cuda()
        self.embedding_dict = self._init_model()

    def _init_model(self):
        """Create Xavier-initialised user and item embedding tables."""
        initializer = nn.init.xavier_uniform_
        embedding_dict = nn.ParameterDict({
            'user_emb': nn.Parameter(initializer(torch.empty(self.data.user_num, self.emb_size))),
            'item_emb': nn.Parameter(initializer(torch.empty(self.data.item_num, self.emb_size))),
        })
        return embedding_dict

    def forward(self, perturbed=False):
        """Return (user_embeddings, item_embeddings) averaged over all layers.

        With perturbed=True, sign-aligned normalized random noise scaled by
        eps is added after every propagation step.
        """
        ego_embeddings = torch.cat([self.embedding_dict['user_emb'], self.embedding_dict['item_emb']], 0)
        all_embeddings = []
        for k in range(self.n_layers):
            ego_embeddings = torch.sparse.mm(self.sparse_norm_adj, ego_embeddings)
            if perturbed:
                random_noise = torch.rand_like(ego_embeddings).cuda()
                ego_embeddings += torch.sign(ego_embeddings) * F.normalize(random_noise, dim=-1) * self.eps
            all_embeddings.append(ego_embeddings)
        all_embeddings = torch.stack(all_embeddings, dim=1)
        all_embeddings = torch.mean(all_embeddings, dim=1)
        user_all_embeddings, item_all_embeddings = torch.split(all_embeddings, [self.data.user_num, self.data.item_num])
        return user_all_embeddings, item_all_embeddings
'''
model_path = os.path.join(selfrec_path, 'model', 'graph', 'MSBEGCL.py')
with open(model_path, 'w') as f:
    f.write(model_code)
print("Patched model/graph/MSBEGCL.py")

# 9. Run MSBEGCL
# Launches SELFRec's main.py as a child process, answers its prompt with the
# model name on stdin and echoes only the output lines of interest.
print('\n--- Starting Training ---')

main_py_path = os.path.join(selfrec_path, 'main.py')
if not os.path.exists(main_py_path):
    print(f"CRITICAL: {main_py_path} not found.")

os.chdir(selfrec_path)
print(f"Changed directory to {os.getcwd()} for training.")

process = subprocess.Popen(
    [sys.executable, '-u', 'main.py'],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,  # merge stderr so tracebacks pass through the same filter
    text=True,
    bufsize=1
)

try:
    process.stdin.write(f'{model_name}\n')
    process.stdin.flush()
    process.stdin.close()
except (OSError, ValueError) as e:
    # e.g. the child exited before reading its input
    print(f"Error writing to stdin: {e}")

# Built once instead of on every output line.
keywords = ("training:", "Ranking Performance", "*Current Performance*", "*Best Performance*", "Epoch:", "Hit Ratio", "Traceback", "Error", "File \"")

# Iterating the pipe ends at EOF, so the loop cannot spin if the child closes
# stdout before it exits; wait() then collects the exit status.
for line in process.stdout:
    text = line.strip()
    if any(k in text for k in keywords):
        print(text)

process.wait()

if process.returncode != 0:
    print("Training failed.")
else:
    print("Training finished successfully.")