This cell sets random seeds, checks for a CUDA device, loads the spaCy pipeline (`en_core_web_sm`) together with the RoBERTa tokenizer and model used for embeddings, and samples 90,000 sentences from a Yelp restaurant dataset for pre-training.

In [None]:
import spacy
import torch
import pandas as pd
import pickle
import os
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from transformers import AutoTokenizer, AutoModel
import pytorch_lightning as pl
from torch_geometric.nn import GCNConv
import random
import numpy as np
import warnings

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

warnings.filterwarnings("ignore", category=UserWarning, module="torch_geometric.data.collate")
warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.modules.loss")

os.environ["TOKENIZERS_PARALLELISM"] = "false"
try:
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    print("Random seeds set successfully.")
except Exception as e:
    print(f"Error setting random seeds: {e}")
    exit()
torch.set_float32_matmul_precision('medium')

try:
    torch.cuda.empty_cache()
    print("GPU memory cleared successfully.")
except Exception as e:
    print(f"Error clearing GPU memory: {e}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    print(f"PyTorch CUDA version: {torch.version.cuda}")
    print(f"Number of CUDA devices: {torch.cuda.device_count()}")
else:
    print("CUDA not available. Exiting.")
    exit()
os.makedirs('data', exist_ok=True)
print("Ensured data/ directory exists")

print("PyTorch Lightning version:", pl.__version__)

try:
    nlp = spacy.load('en_core_web_sm', disable=['ner', 'lemmatizer'])
    tokenizer = AutoTokenizer.from_pretrained('roberta-base')
    roberta_model = AutoModel.from_pretrained('roberta-base').to(device)
    print(f"RoBERTa model loaded successfully on {device}.")
except Exception as e:
    print(f"Failed to load spaCy or RoBERTa: {e}")
    exit()

try:
    sentences_df = pd.read_csv('data/yelp_restaurant_sentences.csv')
    # Drop missing or whitespace-only sentences before sampling.
    sentences_df = sentences_df[sentences_df['sentence'].notna() & (sentences_df['sentence'].str.strip() != '')]
    if sentences_df.empty:
        raise ValueError("No valid sentences found after cleaning.")
    # Cap the sample size so sample() does not raise when fewer than 90,000 rows remain.
    sentences_df = sentences_df.sample(n=min(90000, len(sentences_df)), random_state=42)
    print(f"Sampled {len(sentences_df)} sentences for pre-training")
except FileNotFoundError:
    print("Dataset file not found. Ensure 'data/yelp_restaurant_sentences.csv' exists.")
    exit()

This cell processes the sampled sentences using spaCy to extract unique dependency types (e.g., nsubj, dobj) and creates a mapping of dependency types to indices. The results are saved to a pickle file for later use in graph construction.

In [None]:

dep_types = set()
for doc in nlp.pipe(sentences_df['sentence'].tolist(), batch_size=32):
    for token in doc:
        if token.head != token:
            dep_types.add(token.dep_)
dep_type_to_idx = {dep: idx for idx, dep in enumerate(sorted(dep_types))}
print(f"Found {len(dep_types)} dependency types")

with open('data/dep_types.pkl', 'wb') as f:
    pickle.dump({'dep_types': dep_types, 'dep_type_to_idx': dep_type_to_idx}, f)

Random seeds set successfully.
GPU memory cleared successfully.
Using GPU: NVIDIA L40S
PyTorch CUDA version: 11.7
Number of CUDA devices: 1
Ensured data/ directory exists
PyTorch Lightning version: 2.0.0





Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


RoBERTa model loaded successfully on cuda.
Sampled 90000 sentences for pre-training
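
The dependency-type mapping built above assigns indices in sorted label order, and `create_dependency_graph` later falls back to index 0 for any label missing from the mapping, which makes an unseen label indistinguishable from the first sorted label. A minimal sketch of one alternative, reserving a dedicated slot for unknown labels, is shown below. The `UNK` name and the `build_label_index` and `lookup` helpers are hypothetical and not part of the notebook.

```python
UNK = "<unk>"  # hypothetical placeholder label, not part of the notebook


def build_label_index(labels):
    """Map each label to an index, reserving index 0 for unknown labels."""
    index = {UNK: 0}
    for label in sorted(set(labels)):
        index.setdefault(label, len(index))
    return index


def lookup(index, label):
    """Return the index for a label, falling back to the UNK slot."""
    return index.get(label, index[UNK])


dep_index = build_label_index(["nsubj", "dobj", "amod", "nsubj"])
print(dep_index)
print(lookup(dep_index, "nsubj"), lookup(dep_index, "never_seen"))
```

With this layout the number of embedding rows needed downstream would be `len(dep_index)`, that is, one more than the number of observed labels.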


This cell defines `create_dependency_graph`, which builds the dependency graphs for the ARGCN model. For each sentence it creates nodes (tokens), edges (dependency relations, directed from each token to its head), and edge attributes (dependency-type indices). It also records noun and adjective indices for later use in the pre-training tasks.

In [4]:
def create_dependency_graph(sentences, num_nodes_list):
    graphs = []
    for doc, sent, num_nodes in zip(nlp.pipe(sentences, batch_size=32), sentences, num_nodes_list):
        nodes = [token.text.lower() for token in doc][:num_nodes]
        edges = []
        edge_types = []
        for token in doc:
            if token.head != token and token.i < num_nodes and token.head.i < num_nodes:
                edges.append([token.i, token.head.i])
                edge_types.append(dep_type_to_idx.get(token.dep_, 0))
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous() if edges else torch.empty((2, 0), dtype=torch.long)
        edge_types = torch.tensor(edge_types, dtype=torch.long) if edge_types else torch.empty((0,), dtype=torch.long)
        noun_indices = [i for i, token in enumerate(doc) if token.pos_ == 'NOUN' and i < num_nodes]
        adj_indices = [i for i, token in enumerate(doc) if token.pos_ == 'ADJ' and i < num_nodes]
        if edges and edge_index.max().item() >= num_nodes:
            print(f"Warning: Invalid edge_index in dependency graph for sentence '{sent}': max index {edge_index.max().item()}, num_nodes {num_nodes}")
            continue
        graph = Data(edge_index=edge_index, edge_attr=edge_types, sentence=sent, nodes=nodes, noun_indices=noun_indices, adj_indices=adj_indices)
        graphs.append(graph)
    return graphs
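
To make the edge construction concrete without running spaCy, the sketch below rebuilds the same token-to-head edge list from a hand-written parse. The `heads` and `deps` lists are made-up illustration data standing in for `token.head.i` and `token.dep_`, and `dependency_edges` is a hypothetical helper; the truncation rule mirrors the `num_nodes` filter in `create_dependency_graph`.

```python
def dependency_edges(heads, deps, num_nodes):
    """Return (edges, labels) for token -> head arcs that stay within num_nodes.

    heads[i] is the index of token i's head; the root points to itself.
    """
    edges, labels = [], []
    for child, (head, dep) in enumerate(zip(heads, deps)):
        if head == child:  # root token: it has no head arc to add
            continue
        if child < num_nodes and head < num_nodes:
            edges.append((child, head))
            labels.append(dep)
    return edges, labels


# "The food was great": illustrative parse written by hand, not produced by spaCy.
heads = [1, 2, 2, 2]
deps = ["det", "nsubj", "ROOT", "acomp"]

full_edges, full_labels = dependency_edges(heads, deps, num_nodes=4)
cut_edges, cut_labels = dependency_edges(heads, deps, num_nodes=2)
print(full_edges, full_labels)
print(cut_edges, cut_labels)
```

The truncated call illustrates a side effect of the filter: an arc is dropped whenever its head lies beyond the cutoff, so a truncated graph can lose the connection to its root.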

This cell defines `create_contextual_graph`, which builds the contextual graphs for the T-GCN model. Each token is connected to every other token within a window of two positions on either side, so edges represent token proximity and carry no edge attributes. The function also records noun and adjective indices.

In [None]:
def create_contextual_graph(sentences, num_nodes_list):
    graphs = []
    for doc, sent, num_nodes in zip(nlp.pipe(sentences, batch_size=32), sentences, num_nodes_list):
        nodes = [token.text.lower() for token in doc][:num_nodes]
        edges = []
        window = 2
        for i in range(num_nodes):
            for j in range(max(0, i - window), min(num_nodes, i + window + 1)):
                if i != j:
                    edges.append([i, j])
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous() if edges else torch.empty((2, 0), dtype=torch.long)
        noun_indices = [i for i, token in enumerate(doc) if token.pos_ == 'NOUN' and i < num_nodes]
        adj_indices = [i for i, token in enumerate(doc) if token.pos_ == 'ADJ' and i < num_nodes]
        if edges and edge_index.max().item() >= num_nodes:
            print(f"Warning: Invalid edge_index in contextual graph for sentence '{sent}': max index {edge_index.max().item()}, num_nodes {num_nodes}")
            continue
        graph = Data(edge_index=edge_index, sentence=sent, nodes=nodes, noun_indices=noun_indices, adj_indices=adj_indices)
        graphs.append(graph)
    return graphs
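
The windowing rule can be checked in isolation. The sketch below repeats the same double loop on bare indices, with `window_edges` as a hypothetical stand-alone helper, so the resulting edge list can be inspected for a small sentence length.

```python
def window_edges(num_nodes, window=2):
    """Connect each token to its neighbours within `window` positions."""
    edges = []
    for i in range(num_nodes):
        for j in range(max(0, i - window), min(num_nodes, i + window + 1)):
            if i != j:
                edges.append((i, j))
    return edges


edges = window_edges(4, window=2)
print(len(edges), edges)
```

Every pair appears in both directions, so the contextual graph is effectively undirected, and a single-token sentence yields no edges at all.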


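This cell defines `get_roberta_embeddings`, which tokenizes a batch of sentences, runs RoBERTa without gradients, and returns per-sentence embeddings truncated to the usable number of nodes. The loop that follows processes the sampled sentences in batches of 32, builds the dependency and contextual graphs for each batch, assigns the embeddings as node features (`x`), attaches metadata (`user_id`, `date`), and prints progress. The graphs are accumulated in `dep_graphs` and `context_graphs`; the code shown here does not itself write them to disk.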

In [None]:
def get_roberta_embeddings(sentences, nodes_list):
    valid_sentences = []
    valid_nodes = []
    valid_indices = []
    for i, (sent, nodes) in enumerate(zip(sentences, nodes_list)):
        if isinstance(sent, str) and sent.strip() and nodes:
            valid_sentences.append(sent)
            valid_nodes.append(nodes)
            valid_indices.append(i)
    print(f"Batch - Total sentences: {len(sentences)}, Valid sentences: {len(valid_sentences)}")
    if not valid_sentences:
        return []
    try:
        inputs = tokenizer(valid_sentences, return_tensors='pt', padding=True, truncation=True, max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = roberta_model(**inputs)
        embeddings = outputs.last_hidden_state.cpu()
        batched_embeddings = []
        for i, nodes in enumerate(valid_nodes):
            num_nodes = min(len(nodes), embeddings[i].shape[0] - 2)  
            if num_nodes == 0:
                print(f"Warning: No valid nodes for sentence {valid_sentences[i]}")
                continue
            emb = embeddings[i, 1:num_nodes+1].clone()
            batched_embeddings.append((valid_indices[i], emb, nodes[:num_nodes], num_nodes))
        return batched_embeddings
    except RuntimeError as e:
        print(f"Error in RoBERTa embeddings: {e}")
        return []


batch_size = 32
dep_graphs = [] 
context_graphs = [] 
for start_idx in range(0, len(sentences_df), batch_size):
    batch_df = sentences_df.iloc[start_idx:min(start_idx + batch_size, len(sentences_df))]
    batch_sentences = batch_df['sentence'].tolist()

    # nlp.pipe parses the batch in one pass instead of calling nlp() per sentence.
    batch_nodes = [[token.text.lower() for token in doc] for doc in nlp.pipe(batch_sentences, batch_size=32)]
    batch_embeddings = get_roberta_embeddings(batch_sentences, batch_nodes)
    embedding_dict = {idx: (emb, nodes, num_nodes) for idx, emb, nodes, num_nodes in batch_embeddings}
    print(f"Batch {start_idx//batch_size + 1} - Embedding dict size: {len(embedding_dict)}")
   
    batch_num_nodes = [embedding_dict.get(i, (None, None, 0))[2] for i in range(len(batch_sentences))]
    batch_dep_graphs = create_dependency_graph(batch_sentences, batch_num_nodes)
    batch_context_graphs = create_contextual_graph(batch_sentences, batch_num_nodes)
    

    for i, (row, dep_graph, context_graph) in enumerate(zip(batch_df.itertuples(), batch_dep_graphs, batch_context_graphs)):
        idx = i
        if idx in embedding_dict:
            emb, adjusted_nodes, num_nodes = embedding_dict[idx]
            dep_graph.x = emb
            context_graph.x = emb
            dep_graph.nodes = adjusted_nodes
            context_graph.nodes = adjusted_nodes
            dep_graph.user_id = row.user_id
            dep_graph.date = row.date
            context_graph.user_id = row.user_id
            context_graph.date = row.date
            dep_graphs.append(dep_graph)
            context_graphs.append(context_graph)
    print(f"Batch {start_idx//batch_size + 1} - Total graphs: {len(dep_graphs)}")

if not dep_graphs or not context_graphs:
    raise ValueError("No graphs created. Check dataset and embedding generation.")
print(f"Total dependency graphs created: {len(dep_graphs)}")
print(f"Total contextual graphs created: {len(context_graphs)}")
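
One point worth flagging: `get_roberta_embeddings` takes the first `num_nodes` RoBERTa positions after the start token as node features, but RoBERTa's byte-level BPE can split a single spaCy token into several subword pieces, so position k in the RoBERTa output is not guaranteed to correspond to spaCy token k. The sketch below shows one common alternative, mean-pooling the subword vectors that fall inside each token's character span. This is not what the notebook does; the offsets and vectors are made-up illustration data, and `pool_by_offsets` is a hypothetical helper.

```python
def pool_by_offsets(token_spans, piece_spans, piece_vectors):
    """Average the vectors of all subword pieces that overlap each token span.

    Spans are (start, end) character offsets with an exclusive end.
    """
    pooled = []
    for t_start, t_end in token_spans:
        members = [
            vec
            for (p_start, p_end), vec in zip(piece_spans, piece_vectors)
            if p_start < t_end and p_end > t_start  # the two spans overlap
        ]
        if not members:
            pooled.append(None)  # no piece covers this token (e.g. truncated input)
            continue
        dim = len(members[0])
        pooled.append([sum(v[d] for v in members) / len(members) for d in range(dim)])
    return pooled


# "tiramisu rocks": two tokens, with the first split into three made-up pieces.
token_spans = [(0, 8), (9, 14)]
piece_spans = [(0, 3), (3, 6), (6, 8), (9, 14)]
piece_vectors = [[1.0, 0.0], [2.0, 0.0], [3.0, 6.0], [5.0, 5.0]]

token_vectors = pool_by_offsets(token_spans, piece_spans, piece_vectors)
print(token_vectors)
```

In practice, character offsets of this kind can be requested from a fast Hugging Face tokenizer with `return_offsets_mapping=True`, and spaCy exposes each token's start offset as `token.idx`.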

This cell defines a function is_valid_graph to check if a graph is valid for training. It ensures node features, edge indices, edge attributes, and labels (if present) are free of NaN, infinite values, or index errors.

In [None]:
def is_valid_graph(graph):
    """Check if a graph is valid for training (no NaN, Inf, or index issues)."""
    try:
        # torch.isfinite is False for NaN as well as +/-Inf, so it covers both cases.
        if graph.x is not None and not torch.isfinite(graph.x).all():
            return False

        if graph.edge_index is not None:
            if not torch.isfinite(graph.edge_index).all():
                return False
            # max() raises on an empty edge_index; the except below then returns False.
            if graph.edge_index.max() >= graph.num_nodes:
                return False

        if hasattr(graph, 'edge_attr') and graph.edge_attr is not None:
            if not torch.isfinite(graph.edge_attr).all():
                return False

        if hasattr(graph, 'y') and graph.y is not None:
            if not torch.isfinite(graph.y).all():
                return False

        return True
    except Exception:
        return False
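
The same checks can be illustrated on plain Python lists. The sketch below is a hypothetical, torch-free analogue (`first_problem` is not part of the notebook) that returns the reason for rejection instead of a bare boolean. It differs from the cell above in two ways that are assumptions of this sketch: it also rejects negative indices, and it accepts a graph with no edges, whereas the cell above appears to reject an empty `edge_index` because `max()` on an empty tensor raises and the exception is caught.

```python
import math


def first_problem(features, edges):
    """Return a short reason for the first problem found, or None if none is found."""
    num_nodes = len(features)
    for row in features:
        # math.isfinite is False for NaN, +Inf and -Inf, so no separate NaN test is needed.
        if not all(math.isfinite(v) for v in row):
            return "non-finite feature"
    for src, dst in edges:
        if not (0 <= src < num_nodes and 0 <= dst < num_nodes):
            return "edge index out of range"
    return None


print(first_problem([[0.0, 1.0], [float("nan"), 2.0]], []))
print(first_problem([[0.0], [1.0]], [(0, 2)]))
print(first_problem([[0.0], [1.0]], [(0, 1)]))
```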


This cell loads previously saved dependency and contextual graphs, validates them using is_valid_graph, and saves the valid graphs to new pickle files. It ensures no mismatch in graph counts and handles errors.

In [None]:
try:
    
    with open('data/yelp_restaurant_dep_graphs_embedded.pkl', 'rb') as f:
        dep_graphs = pickle.load(f)
    
   
    with open('data/yelp_restaurant_context_graphs_embedded.pkl', 'rb') as f:
        context_graphs = pickle.load(f)
    
    total_loaded = len(dep_graphs)
    print(f"Loaded {len(dep_graphs)} dependency graphs and {len(context_graphs)} context graphs.")
    
   
    if len(dep_graphs) != len(context_graphs):
        raise ValueError("Mismatch between number of dependency and context graphs.")
    
 
    valid_dep_graphs = []
    valid_context_graphs = []
    for dep_g, ctx_g in zip(dep_graphs, context_graphs):
        if is_valid_graph(dep_g) and is_valid_graph(ctx_g):
            valid_dep_graphs.append(dep_g)
            valid_context_graphs.append(ctx_g)
    
    num_valid = len(valid_dep_graphs)
    print(f"Retained {num_valid} valid dependency graphs and {len(valid_context_graphs)} valid context graphs after validation.")
    print(f"Valid graphs: {num_valid} out of {total_loaded} loaded ({num_valid/total_loaded*100:.2f}%).")
    

    if len(valid_dep_graphs) != len(valid_context_graphs):
        raise ValueError("Mismatch in valid graphs after filtering.")
    
    if len(valid_dep_graphs) == 0:
        raise ValueError("No valid graphs found after filtering.")
    
    with open('data/yelp_restaurant_dep_graphs_valid.pkl', 'wb') as f:
        pickle.dump(valid_dep_graphs, f)
    with open('data/yelp_restaurant_context_graphs_valid.pkl', 'wb') as f:
        pickle.dump(valid_context_graphs, f)
    print("Valid graphs saved to 'data/yelp_restaurant_dep_graphs_valid.pkl' and 'data/yelp_restaurant_context_graphs_valid.pkl'.")

except FileNotFoundError as e:
    print(f"Graph files not found: {e}. Ensure 'data/yelp_restaurant_dep_graphs_embedded.pkl' and 'data/yelp_restaurant_context_graphs_embedded.pkl' exist.")
    exit()
except Exception as e:
    print(f"An error occurred: {e}")
    exit()

Loaded 90000 dependency graphs and 90000 context graphs.
Retained 89997 valid dependency graphs and 89997 valid context graphs after validation.
Valid graphs: 89997 out of 90000 loaded (100.00%).
Valid graphs saved to 'data/yelp_restaurant_dep_graphs_valid.pkl' and 'data/yelp_restaurant_context_graphs_valid.pkl'.
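
The summary line above reports 100.00% even though three graphs were dropped, because 89997 / 90000 rounds up at two decimal places. A small sketch of a more informative summary follows; `retention_summary` is a hypothetical helper, and the choice of four decimals plus an explicit dropped count is just one option.

```python
def retention_summary(num_valid, total):
    """Format a retention line that stays informative when very few items are dropped."""
    dropped = total - num_valid
    pct = num_valid / total * 100
    return f"Valid graphs: {num_valid} of {total} ({pct:.4f}%), dropped {dropped}"


two_decimals = f"{89997 / 90000 * 100:.2f}%"
summary = retention_summary(89997, 90000)
print(two_decimals)
print(summary)
```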


In [1]:
import pickle
import torch
import numpy as np
from torch_geometric.data import Data

This cell loads the validated dependency and contextual graphs from pickle files and checks for consistency in graph counts. It handles errors if files are missing or loading fails.

In [None]:
try:
   
    with open('data/yelp_restaurant_dep_graphs_valid.pkl', 'rb') as f:
        dep_graphs = pickle.load(f)
    
    with open('data/yelp_restaurant_context_graphs_valid.pkl', 'rb') as f:
        context_graphs = pickle.load(f)
    
    print(f"Loaded {len(dep_graphs)} valid dependency graphs and {len(context_graphs)} valid context graphs.")
    
    if len(dep_graphs) != len(context_graphs):
        raise ValueError("Mismatch between number of valid dependency and context graphs.")

except FileNotFoundError as e:
    print(f"Valid graph files not found: {e}. Ensure 'data/yelp_restaurant_dep_graphs_valid.pkl' and 'data/yelp_restaurant_context_graphs_valid.pkl' exist.")
    exit()
except Exception as e:
    print(f"An error occurred while loading valid graphs: {e}")
    exit()

Loaded 89997 valid dependency graphs and 89997 valid context graphs.


This cell reloads the validated graphs, re-runs `is_valid_graph` on every dependency/context pair, and saves the pairs that pass to `data/yelp_restaurant_dep_graph_valid.pkl` and `data/yelp_restaurant_context_graph_valid.pkl` (singular `graph`, distinct from the `graphs_valid` files it reads).

In [None]:
try:
    with open('data/yelp_restaurant_dep_graphs_valid.pkl', 'rb') as f:
        dep_graphs = pickle.load(f)
    with open('data/yelp_restaurant_context_graphs_valid.pkl', 'rb') as f:
        context_graphs = pickle.load(f)
    
    print(f"Loaded {len(dep_graphs)} dependency graphs and {len(context_graphs)} context graphs.")

    if len(dep_graphs) != len(context_graphs):
        raise ValueError("Mismatch between number of dependency and context graphs.")
   
    valid_dep_graphs = []
    valid_context_graphs = []
    for idx in range(len(dep_graphs)):
        if is_valid_graph(dep_graphs[idx]) and is_valid_graph(context_graphs[idx]):
            valid_dep_graphs.append(dep_graphs[idx])
            valid_context_graphs.append(context_graphs[idx])
            print(f"Graph pair {idx}: Valid")
        else:
            print(f"Skipping graph pair {idx} due to invalid data.")
    
    num_valid = len(valid_dep_graphs)
    print(f"Selected {num_valid} valid dependency graphs and {len(valid_context_graphs)} valid context graphs.")
    
    if num_valid == 0:
        raise ValueError("No valid graph pairs found.")
    
    with open('data/yelp_restaurant_dep_graph_valid.pkl', 'wb') as f:
        pickle.dump(valid_dep_graphs, f)
    with open('data/yelp_restaurant_context_graph_valid.pkl', 'wb') as f:
        pickle.dump(valid_context_graphs, f)
    print("Saved all valid graph pairs to 'data/yelp_restaurant_dep_graph_valid.pkl' and 'data/yelp_restaurant_context_graph_valid.pkl'.")

except FileNotFoundError as e:
    print(f"Graph files not found: {e}. Ensure 'data/yelp_restaurant_dep_graphs_valid.pkl' and 'data/yelp_restaurant_context_graphs_valid.pkl' exist.")
    exit()
except Exception as e:
    print(f"An error occurred during selection: {e}")
    exit()

Loaded 89997 dependency graphs and 89997 context graphs.
Selected 200 valid dependency graphs and 200 valid context graphs.


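This cell redefines `is_valid_graph` with a stricter check that takes the graph index and graph type ("Dependency" or "Context") and prints the reason for each rejection. It checks node features, edge indices, and, for dependency graphs only, edge attributes and noun/adjective indices, requiring at least one in-range integer entry in each of `noun_indices` and `adj_indices`. It then loads the dependency and contextual graphs, validates all pairs, and saves the valid ones to pickle files.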

In [None]:
import pickle
import torch
from torch_geometric.data import Data

def is_valid_graph(graph, graph_idx, graph_type):
    """Validate a graph for training, ensuring flat integer noun_indices and adj_indices."""
    try:

        if graph.x is None or not torch.isfinite(graph.x).all() or torch.isnan(graph.x).any():
            print(f"{graph_type} graph {graph_idx}: Invalid x (NaN/Inf or None)")
            return False
        
        if graph.edge_index is not None and graph.edge_index.numel() > 0:
            if not torch.isfinite(graph.edge_index).all() or torch.isnan(graph.edge_index).any():
                print(f"{graph_type} graph {graph_idx}: Invalid edge_index (NaN/Inf)")
                return False
            num_nodes = graph.x.shape[0]
            max_idx = graph.edge_index.max().item()
            if max_idx >= num_nodes:
                print(f"{graph_type} graph {graph_idx}: Invalid edge_index (max index {max_idx}, num_nodes {num_nodes})")
                return False
        

        if graph_type == "Dependency" and hasattr(graph, 'edge_attr') and graph.edge_attr is not None:
            if not torch.isfinite(graph.edge_attr).all() or torch.isnan(graph.edge_attr).any():
                print(f"{graph_type} graph {graph_idx}: Invalid edge_attr (NaN/Inf)")
                return False
            if graph.edge_attr.numel() > 0:
                # Allowed values: a type index in [0, 50) or the mask marker -1.
                valid_mask = ((graph.edge_attr >= 0) & (graph.edge_attr < 50)) | (graph.edge_attr == -1)
                if not valid_mask.all():
                    print(f"{graph_type} graph {graph_idx}: Invalid edge_attr values: {graph.edge_attr[~valid_mask].unique()}")
                    return False
            if graph.edge_attr.numel() == 0:
                print(f"{graph_type} graph {graph_idx}: Empty edge_attr")
        
       
        if graph_type == "Dependency":
            if not (hasattr(graph, 'noun_indices') and hasattr(graph, 'adj_indices')):
                print(f"{graph_type} graph {graph_idx}: Missing noun_indices or adj_indices")
                return False
            num_nodes = graph.x.shape[0]
        
            try:
                noun_indices = [idx for idx in graph.noun_indices if isinstance(idx, int) and idx < num_nodes and idx >= 0]
                adj_indices = [idx for idx in graph.adj_indices if isinstance(idx, int) and idx < num_nodes and idx >= 0]
            except TypeError:
                print(f"{graph_type} graph {graph_idx}: Invalid noun_indices or adj_indices structure (not flat integers)")
                return False
            if not noun_indices:
                print(f"{graph_type} graph {graph_idx}: Empty or invalid noun_indices")
            if not adj_indices:
                print(f"{graph_type} graph {graph_idx}: Empty or invalid adj_indices")
            return bool(noun_indices and adj_indices)
        
        return True
    except Exception as e:
        print(f"{graph_type} graph {graph_idx}: Error during validation: {e}")
        return False

try:
   
    with open('data/yelp_restaurant_dep_graphs_valid.pkl', 'rb') as f:
        dep_graphs = pickle.load(f)
    with open('data/yelp_restaurant_context_graphs_valid.pkl', 'rb') as f:
        context_graphs = pickle.load(f)
    
    print(f"Loaded {len(dep_graphs)} dependency graphs and {len(context_graphs)} context graphs.")
    

    if len(dep_graphs) != len(context_graphs):
        raise ValueError("Mismatch between number of dependency and context graphs.")
    

    valid_dep_graphs = []
    valid_context_graphs = []
    for idx in range(len(dep_graphs)):
        if is_valid_graph(dep_graphs[idx], idx, "Dependency") and is_valid_graph(context_graphs[idx], idx, "Context"):
            valid_dep_graphs.append(dep_graphs[idx])
            valid_context_graphs.append(context_graphs[idx])
            print(f"Graph pair {idx}: Valid with non-empty integer noun_indices and adj_indices")
        else:
            print(f"Skipping graph pair {idx} due to invalid data.")
    
    num_valid = len(valid_dep_graphs)
    print(f"Selected {num_valid} valid dependency graphs and {len(valid_context_graphs)} valid context graphs.")
    
    if num_valid == 0:
        raise ValueError("No valid graph pairs found.")
    
   
    with open('data/yelp_restaurant_dep_graph_valid.pkl', 'wb') as f:
        pickle.dump(valid_dep_graphs, f)
    with open('data/yelp_restaurant_context_graph_valid.pkl', 'wb') as f:
        pickle.dump(valid_context_graphs, f)
    print("Saved all valid graph pairs to 'data/yelp_restaurant_dep_graph_valid.pkl' and 'data/yelp_restaurant_context_graph_valid.pkl'.")

except FileNotFoundError as e:
    print(f"Graph files not found: {e}. Ensure 'data/yelp_restaurant_dep_graphs_valid.pkl' and 'data/yelp_restaurant_context_graphs_valid.pkl' exist.")
    exit()
except Exception as e:
    print(f"An error occurred during selection: {e}")
    exit()

Loaded 89997 dependency graphs and 89997 context graphs.
Graph pair 0: Valid with non-empty integer noun_indices and adj_indices
Graph pair 1: Valid with non-empty integer noun_indices and adj_indices
Graph pair 2: Valid with non-empty integer noun_indices and adj_indices
Graph pair 3: Valid with non-empty integer noun_indices and adj_indices
Graph pair 4: Valid with non-empty integer noun_indices and adj_indices
Graph pair 5: Valid with non-empty integer noun_indices and adj_indices
Graph pair 6: Valid with non-empty integer noun_indices and adj_indices
Graph pair 7: Valid with non-empty integer noun_indices and adj_indices
Graph pair 8: Valid with non-empty integer noun_indices and adj_indices
Graph pair 9: Valid with non-empty integer noun_indices and adj_indices
Graph pair 10: Valid with non-empty integer noun_indices and adj_indices
Graph pair 11: Valid with non-empty integer noun_indices and adj_indices
Graph pair 12: Valid with non-empty integer noun_indices and adj_indices
...
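
Printing one line per graph pair produces tens of thousands of log lines for a dataset of this size. A small sketch of an alternative is shown below: collect one outcome per pair and print only the aggregate counts. It assumes the validator were changed to return a reason string (or `None` for a pass) instead of printing; that change, the `summarize` helper, and the sample outcomes are all hypothetical.

```python
from collections import Counter


def summarize(reasons):
    """Count rejection reasons; None entries stand for pairs that passed."""
    return Counter("valid" if r is None else r for r in reasons)


# Made-up per-pair outcomes standing in for real validation results.
outcomes = [None, None, "empty adj_indices", None, "empty noun_indices", "empty adj_indices"]
counts = summarize(outcomes)
for reason, n in counts.most_common():
    print(f"{reason}: {n}")
```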

This cell sets up the environment, loads dependency types and all valid graph pairs from pickle files, and trains a `HybridGCN` model combining T-GCN (contextual) and ARGCN (dependency) for pre-training. It performs two tasks: predicting opinion word embeddings and dependency types, logs losses, and saves the trained model checkpoint.

In [None]:
import pickle
import torch
import numpy as np
import random
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import pytorch_lightning as pl
from torch_geometric.nn import GCNConv

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


try:
    with open('data/dep_types.pkl', 'rb') as f:
        dep_data = pickle.load(f)
    dep_types = dep_data['dep_types']
    dep_type_to_idx = dep_data['dep_type_to_idx']
    print(f"Loaded dep_types and dep_type_to_idx (num_dep_types: {len(dep_types)})")
except FileNotFoundError:
    print("Dependency types file not found. Ensure 'data/dep_types.pkl' exists.")
    exit()
except Exception as e:
    print(f"Error loading dependency types: {e}")
    exit()

try:

    with open('data/yelp_restaurant_dep_graph_valid.pkl', 'rb') as f:
        dep_graphs = pickle.load(f)
    with open('data/yelp_restaurant_context_graph_valid.pkl', 'rb') as f:
        context_graphs = pickle.load(f)
    
    print(f"Loaded {len(dep_graphs)} valid dependency graphs and {len(context_graphs)} valid context graphs.")
    
   
    if len(dep_graphs) != len(context_graphs):
        raise ValueError("Mismatch between number of valid dependency and context graphs.")

except FileNotFoundError as e:
    print(f"Graph files not found: {e}. Ensure 'data/yelp_restaurant_dep_graph_valid.pkl' and 'data/yelp_restaurant_context_graph_valid.pkl' exist.")
    exit()
except Exception as e:
    print(f"An error occurred while loading valid graphs: {e}")
    exit()

class HybridGCN(pl.LightningModule):
    def __init__(self, input_dim=768, hidden_dim=256, num_dep_types=50):
        super().__init__()
        self.save_hyperparameters()
        self.tgcn_conv1 = GCNConv(input_dim, hidden_dim)
        self.tgcn_conv2 = GCNConv(hidden_dim, hidden_dim)
        self.argcn_conv1 = GCNConv(input_dim, hidden_dim)
        self.argcn_conv2 = GCNConv(hidden_dim, hidden_dim)
        self.type_emb = torch.nn.Embedding(num_dep_types, hidden_dim)
        self.fc_node = torch.nn.Linear(hidden_dim * 3, input_dim)
        self.fc_edge = torch.nn.Linear(hidden_dim, num_dep_types)
        self.dep_type_to_idx = dep_type_to_idx

    def forward(self, context_data, dep_data, target_idx=None, opinion_idx=None, task='node'):
        x_tgcn = context_data.x.to(device)
        edge_index_tgcn = context_data.edge_index.to(device)
        x_argcn = dep_data.x.to(device)
        edge_index_argcn = dep_data.edge_index.to(device)
        edge_attr = dep_data.edge_attr.to(device) if dep_data.edge_attr is not None else None
        num_nodes_tgcn = x_tgcn.shape[0]
        num_nodes_argcn = x_argcn.shape[0]
        if edge_index_tgcn.numel() > 0 and edge_index_tgcn.max().item() >= num_nodes_tgcn:
            raise ValueError(f"Invalid edge_index_tgcn: max index {edge_index_tgcn.max().item()}, num_nodes {num_nodes_tgcn}")
        if edge_index_argcn.numel() > 0 and edge_index_argcn.max().item() >= num_nodes_argcn:
            raise ValueError(f"Invalid edge_index_argcn: max index {edge_index_argcn.max().item()}, num_nodes {num_nodes_argcn}")
        if edge_attr is not None and edge_attr.numel() > 0:
            valid_mask = ((edge_attr >= 0) & (edge_attr < len(dep_types))) | (edge_attr == -1)
            if not valid_mask.all():
                raise ValueError(f"Invalid edge_attr values: {edge_attr[~valid_mask].unique()}. Must be in [0, {len(dep_types)-1}] or -1.")

        x_tgcn1 = torch.relu(self.tgcn_conv1(x_tgcn, edge_index_tgcn))
        x_tgcn2 = torch.relu(self.tgcn_conv2(x_tgcn1, edge_index_tgcn))
        x_tgcn = (x_tgcn1 + x_tgcn2) / 2

        edge_weights = torch.ones(edge_index_argcn.size(1), device=device)
        if edge_attr is not None:
            valid_mask = edge_attr != -1
            valid_edge_attr = edge_attr[valid_mask]
            if valid_edge_attr.numel() > 0:
                edge_weights[valid_mask] = self.type_emb(valid_edge_attr).mean(dim=1)
        x_argcn1 = torch.relu(self.argcn_conv1(x_argcn, edge_index_argcn, edge_weights))
        x_argcn2 = torch.relu(self.argcn_conv2(x_argcn1, edge_index_argcn, edge_weights))
        x_argcn = (x_argcn1 + x_argcn2) / 2

        if task == 'node' and target_idx is not None and opinion_idx is not None:
            if target_idx >= num_nodes_tgcn or opinion_idx >= num_nodes_tgcn or opinion_idx >= num_nodes_argcn:
                raise ValueError(f"Invalid node indices: target_idx={target_idx}, opinion_idx={opinion_idx}, num_nodes_tgcn={num_nodes_tgcn}, num_nodes_argcn={num_nodes_argcn}")
            x_tgcn_target = x_tgcn[target_idx].unsqueeze(0)
            x_tgcn_opinion = x_tgcn[opinion_idx].unsqueeze(0)
            x_argcn_opinion = x_argcn[opinion_idx].unsqueeze(0)
            x_combined = torch.cat([x_tgcn_target, x_tgcn_opinion, x_argcn_opinion], dim=-1)
            return self.fc_node(x_combined)
        elif task == 'edge' and opinion_idx is not None:
            if opinion_idx >= num_nodes_argcn:
                raise ValueError(f"Invalid opinion_idx for edge task: {opinion_idx}, num_nodes {num_nodes_argcn}")
            return self.fc_edge(x_argcn[opinion_idx].unsqueeze(0))
        else:
            raise ValueError("Invalid task or indices.")

    def training_step(self, batch, batch_idx):
        context_data, dep_data = batch
        context_data = context_data.to(device)
        dep_data = dep_data.to(device)
        loss = torch.tensor(0.0, device=device, requires_grad=True)

        try:
    
            if context_data.x.shape[0] == 0 or dep_data.x.shape[0] == 0:
                print(f"Skipping graph {batch_idx}: Empty graph (context nodes: {context_data.x.shape[0]}, dep nodes: {dep_data.x.shape[0]})")
                return None
            
            if context_data.edge_index.numel() > 0:
                max_idx_context = context_data.edge_index.max().item()
                if max_idx_context >= context_data.x.shape[0]:
                    raise ValueError(f"Invalid edge_index_context in graph {batch_idx}: max index {max_idx_context}, num_nodes {context_data.x.shape[0]}")
            if dep_data.edge_index.numel() > 0:
                max_idx_dep = dep_data.edge_index.max().item()
                if max_idx_dep >= dep_data.x.shape[0]:
                    raise ValueError(f"Invalid edge_index_dep in graph {batch_idx}: max index {max_idx_dep}, num_nodes {dep_data.x.shape[0]}")

            def flatten_indices(indices):
                flat_indices = []
                for item in indices:
                    if isinstance(item, (list, tuple)):
                        flat_indices.extend([idx for idx in item if isinstance(idx, int)])
                    elif isinstance(item, int):
                        flat_indices.append(item)
                return flat_indices

            num_nodes = dep_data.x.shape[0]
            noun_indices = [idx for idx in flatten_indices(dep_data.noun_indices) if idx < num_nodes and idx >= 0]
            adj_indices = [idx for idx in flatten_indices(dep_data.adj_indices) if idx < num_nodes and idx >= 0]
            edge_attr_len = dep_data.edge_attr.numel() if dep_data.edge_attr is not None else 0

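            # Node task: mask one adjective (opinion) node's features in both graphs
            # and regress its original embedding with an MSE loss.
            node_loss = torch.tensor(0.0, device=device, requires_grad=True)
            if noun_indices and adj_indices: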
                target_idx = random.choice(noun_indices)
                opinion_idx = random.choice(adj_indices)
                true_embedding = dep_data.x[opinion_idx].clone()
                dep_data.x[opinion_idx] = torch.zeros_like(true_embedding)
                context_data.x[opinion_idx] = torch.zeros_like(true_embedding)
                pred = self(context_data, dep_data, target_idx=target_idx, opinion_idx=opinion_idx, task='node')
                node_loss = torch.nn.MSELoss()(pred.squeeze(0), true_embedding)
                loss = loss + node_loss
                self.log('node_loss', node_loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=1, sync_dist=False, reduce_fx="mean")
            else:
                print(f"Graph {batch_idx}: Skipping node loss (empty noun_indices or adj_indices)")

            # Edge task: replace one dependency label with -1 (mask value) and classify its type.
            # Note: opinion_idx is sampled independently of the masked edge, so the
            # prediction is made from a node that may not be an endpoint of that edge.
            edge_loss = torch.tensor(0.0, device=device, requires_grad=True)
            if dep_data.edge_attr is not None and dep_data.edge_attr.numel() > 0 and adj_indices:
                edge_idx = random.randint(0, len(dep_data.edge_attr) - 1)
                true_type_idx = int(dep_data.edge_attr[edge_idx].item())
                new_edge_attr = dep_data.edge_attr.clone()
                new_edge_attr[edge_idx] = -1
                dep_data.edge_attr = new_edge_attr
                opinion_idx = random.choice(adj_indices)
                if true_type_idx in self.dep_type_to_idx.values():
                    edge_pred = self(context_data, dep_data, target_idx=None, opinion_idx=opinion_idx, task='edge')
                    # CrossEntropyLoss expects integer (long) class indices as targets.
                    edge_loss = torch.nn.CrossEntropyLoss()(edge_pred, torch.tensor([true_type_idx], dtype=torch.long, device=device))
                    loss = loss + edge_loss
                    self.log('edge_loss', edge_loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=1, sync_dist=False, reduce_fx="mean")
                else:
                    print(f"Graph {batch_idx}: Skipping edge loss (invalid true_type_idx: {true_type_idx})")
            else:
                print(f"Graph {batch_idx}: Skipping edge loss (empty edge_attr or adj_indices)")

            self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, batch_size=1, sync_dist=False, reduce_fx="mean")
            return loss

        except Exception as e:
            print(f"Skipping graph {batch_idx} due to error: {e}")
            return None

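    def train_dataloader(self):
        # Pair graphs by position; assumes the module-level context_graphs and
        # dep_graphs lists are aligned index-for-index.
        dataset = list(zip(context_graphs, dep_graphs))
        return DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=True)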

    def on_train_epoch_end(self):
        train_loss = self.trainer.logged_metrics.get('train_loss_epoch', 'N/A')
        node_loss = self.trainer.logged_metrics.get('node_loss_epoch', 'N/A')
        edge_loss = self.trainer.logged_metrics.get('edge_loss_epoch', 'N/A')
        print(f"Epoch {self.current_epoch} - Train Loss: {train_loss}, Node Loss: {node_loss}, Edge Loss: {edge_loss}")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

try:
    model = HybridGCN(num_dep_types=len(dep_types))
    model.dep_type_to_idx = dep_type_to_idx
    trainer = pl.Trainer(
        max_epochs=10,
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=1,
        log_every_n_steps=5,
        enable_checkpointing=True,
        default_root_dir='checkpoints/'
    )
    trainer.fit(model)
    # Use one variable so the printed path always matches the saved file.
    ckpt_path = 'data/pretrained_gcn_1.ckpt'
    trainer.save_checkpoint(ckpt_path)
    print(f"Pre-training completed. Checkpoint saved at: {ckpt_path}")
except Exception as e:
    print(f"Training failed: {e}")
    raise

Loaded dep_types and dep_type_to_idx (num_dep_types: 44)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA L40S') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


Loaded 200 valid dependency graphs and 200 valid context graphs.


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type      | Params
------------------------------------------
0 | tgcn_conv1  | GCNConv   | 196 K 
1 | tgcn_conv2  | GCNConv   | 65.8 K
2 | argcn_conv1 | GCNConv   | 196 K 
3 | argcn_conv2 | GCNConv   | 65.8 K
4 | type_emb    | Embedding | 11.3 K
5 | fc_node     | Linear    | 590 K 
6 | fc_edge     | Linear    | 11.3 K
------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.554     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Graph 0: context nodes=76, edge_index_context.max()=75, dep nodes=76, edge_index_dep.max()=75, noun_indices=9, adj_indices=5, edge_attr_len=74
Graph 1: context nodes=36, edge_index_context.max()=35, dep nodes=36, edge_index_dep.max()=35, noun_indices=7, adj_indices=5, edge_attr_len=35
Graph 2: context nodes=99, edge_index_context.max()=98, dep nodes=99, edge_index_dep.max()=98, noun_indices=19, adj_indices=10, edge_attr_len=98
Graph 3: context nodes=66, edge_index_context.max()=65, dep nodes=66, edge_index_dep.max()=65, noun_indices=10, adj_indices=10, edge_attr_len=64
Graph 4: context nodes=175, edge_index_context.max()=174, dep nodes=175, edge_index_dep.max()=174, noun_indices=26, adj_indices=18, edge_attr_len=173
Graph 5: context nodes=36, edge_index_context.max()=35, dep nodes=36, edge_index_dep.max()=35, noun_indices=9, adj_indices=3, edge_attr_len=35
Graph 6: context nodes=97, edge_index_context.max()=96, dep nodes=97, edge_index_dep.max()=96, noun_indices=17, adj_indices=5, edge

`Trainer.fit` stopped: `max_epochs=10` reached.


Graph 195: context nodes=198, edge_index_context.max()=197, dep nodes=198, edge_index_dep.max()=197, noun_indices=44, adj_indices=20, edge_attr_len=193
Graph 196: context nodes=103, edge_index_context.max()=102, dep nodes=103, edge_index_dep.max()=102, noun_indices=27, adj_indices=13, edge_attr_len=101
Graph 197: context nodes=94, edge_index_context.max()=93, dep nodes=94, edge_index_dep.max()=93, noun_indices=14, adj_indices=9, edge_attr_len=92
Graph 198: context nodes=61, edge_index_context.max()=60, dep nodes=61, edge_index_dep.max()=60, noun_indices=15, adj_indices=7, edge_attr_len=60
Graph 199: context nodes=64, edge_index_context.max()=63, dep nodes=64, edge_index_dep.max()=63, noun_indices=9, adj_indices=7, edge_attr_len=61
Epoch 9 - Train Loss: 3.2521653175354004, Node Loss: 0.04769982770085335, Edge Loss: 3.2044663429260254
Pre-training completed. Checkpoint saved at: data/pretrained_final.ckpt
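
As a rough sanity check on the final epoch (a sketch, not part of the original run): with 44 dependency types, a classifier that guesses uniformly would have a cross-entropy of ln 44, about 3.78, so the reported edge loss of about 3.20 is only modestly below that baseline. The snippet below recomputes the baseline and checks that the logged node and edge losses roughly add up to the logged train loss. The variable names are illustrative, and the numbers are copied from the output above. The sum is only expected to match approximately, because each metric is averaged over the steps on which it was logged.

```python
import math

# Values copied from the cell output above.
num_dep_types = 44                    # "Loaded dep_types ... (num_dep_types: 44)"
node_loss = 0.04769982770085335       # Epoch 9 node loss
edge_loss = 3.2044663429260254        # Epoch 9 edge loss
train_loss = 3.2521653175354004       # Epoch 9 train loss

# Cross-entropy of a uniform guess over all dependency types: ln(num_dep_types).
uniform_baseline = math.log(num_dep_types)

print(f"Uniform-guess cross-entropy: {uniform_baseline:.4f}")
print(f"Edge loss below baseline by: {uniform_baseline - edge_loss:.4f}")
print(f"node + edge vs. train loss:  {node_loss + edge_loss:.6f} vs. {train_loss:.6f}")
```

An edge loss this close to the uniform baseline suggests the edge head has learned relatively little in 10 epochs on 200 graphs, while the node-reconstruction loss is already small; that reading is an interpretation of the logged numbers, not something the run itself reports.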
