In [None]:
# Smoke-test cell: prints a fixed greeting and defines nothing used later.
print('hello world')

In [None]:

from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric.nn import HypergraphConv
from transformers import BertModel
import torch.nn.functional as F

import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
import stanza

# Module-level Stanza pipeline for English, configured with the tokenize, pos,
# lemma and depparse processors. CausalDataset.__init__ reads this global
# (`nlp(text)`) to parse each sentence before building its hyperedge index.
nlp = stanza.Pipeline(lang='en', processors='tokenize,pos,lemma,depparse')

In [None]:
import torch
import torch.nn as nn
from torch_geometric.nn import HypergraphConv
from transformers import BertModel

class HGATLayer(nn.Module):
    """Thin ``nn.Module`` wrapper around a single ``HypergraphConv`` layer.

    Args:
        in_channels: Input feature size, forwarded to ``HypergraphConv``.
        out_channels: Output feature size, forwarded to ``HypergraphConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Only the two channel sizes are passed; every other HypergraphConv
        # option is left at the library default.
        self.hgat = HypergraphConv(in_channels, out_channels)

    def forward(self, x, hyperedge_index):
        """Apply the wrapped convolution to ``x`` using ``hyperedge_index``."""
        conv = self.hgat
        return conv(x, hyperedge_index)

class HGATCausalClassifier(nn.Module):
    """BERT encoder -> hypergraph convolution -> mean pooling -> classifier.

    The model processes ONE sequence per call: every token embedding produced
    by BERT becomes a node, ``hyperedge_index`` connects those nodes, and the
    convolved node features are averaged into a single vector that is fed to
    the classification head.

    Args:
        bert_model: Name or path handed to ``BertModel.from_pretrained``.
        hidden_dim: Output width of the hypergraph convolution.
        out_dim: Number of logits produced by the classifier head.
    """

    def __init__(self, bert_model='bert-base-uncased', hidden_dim=256, out_dim=2):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        # Take the encoder width from its config rather than hard-coding 768,
        # so checkpoints with a different hidden size also work
        # (bert-base-uncased still yields 768).
        self.hgat = HGATLayer(self.bert.config.hidden_size, hidden_dim)
        self.classifier = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, input_ids, attention_mask, hyperedge_index):
        """Return the logits for a single sequence.

        Args:
            input_ids: Token ids with a leading batch dimension of exactly 1.
            attention_mask: Mask passed to BERT alongside ``input_ids``.
            hyperedge_index: Incidence index handed unchanged to the
                hypergraph convolution.

        Returns:
            A 1-D tensor of ``out_dim`` logits (no batch dimension).

        Raises:
            ValueError: If the leading batch dimension is not 1.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        token_embeddings = outputs.last_hidden_state

        # The previous `squeeze(0)` silently did nothing for batch sizes other
        # than 1 and the convolution then failed on a 3-D input; fail early
        # with an explicit message instead.
        if token_embeddings.size(0) != 1:
            raise ValueError(
                "HGATCausalClassifier.forward expects a batch of exactly one "
                f"sequence, got batch size {token_embeddings.size(0)}"
            )

        # All token embeddings (not just [CLS]) are used as graph nodes.
        node_features = token_embeddings[0]

        hgat_out = self.hgat(node_features, hyperedge_index)
        # Mean over every token position; attention_mask is only used by BERT,
        # so padded positions take part in this average as well.
        pooled = hgat_out.mean(dim=0)
        return self.classifier(pooled)


In [None]:
def build_hyperedge_index_from_tokens(tokens, doc, max_length=128, max_hyperedges=32):
    """Build a fixed-width hyperedge index from the first sentence's dependency arcs.

    Every non-root word of ``doc.sentences[0]`` contributes one two-node
    hyperedge linking the token position of its head to its own token
    position. Only the first sentence of ``doc`` is used.

    Token positions follow a simplified alignment: word ``k`` (1-based Stanza
    id) is assumed to sit at token position ``k`` (position 0 is ``[CLS]``),
    i.e. one BERT token per word. Words that would fall at or beyond
    ``max_length - 1`` are dropped, together with any arc touching them.

    Args:
        tokens: Accepted for interface compatibility; not consulted by the
            current alignment.
        doc: Stanza document; only ``doc.sentences[0].words`` is read, using
            each word's ``id`` and ``head``.
        max_length: Length of the tokenised sequence the indices refer to.
        max_hyperedges: Maximum number of dependency arcs to keep.

    Returns:
        ``torch.LongTensor`` of shape ``[2, max_hyperedges * 2]``. Row 0 holds
        node (token) indices, row 1 the hyperedge id of each entry. Real
        hyperedges are numbered consecutively from 1; unused columns are
        filled with the pair (node 0, hyperedge 0).
    """
    # Each arc occupies two columns: one for the head, one for the dependent.
    slot_count = max_hyperedges * 2

    if not doc.sentences:
        return torch.zeros((2, slot_count), dtype=torch.long)

    words = doc.sentences[0].words

    # Align words to BERT tokens (simplified: 1 word -> 1 token). Position 0 is
    # [CLS] and the last position is reserved, so at most max_length - 2 words
    # receive a token index.
    aligned_count = max(min(len(words), max_length - 2), 0)
    word_to_token_idx = list(range(1, aligned_count + 1))

    source_nodes = []
    hyperedges = []

    for word in words:
        # Stop before exceeding the column budget, so no truncation is needed.
        if len(hyperedges) >= slot_count:
            break

        head = word.head
        dep = word.id

        if head == 0:  # the root has no incoming arc
            continue

        if head - 1 < len(word_to_token_idx) and dep - 1 < len(word_to_token_idx):
            # Two entries are stored per arc, so the arc count is half the list
            # length. Deriving the id from the raw length produced the sparse
            # ids 1, 3, 5, ... and left empty hyperedges in between.
            edge_id = len(hyperedges) // 2 + 1
            source_nodes.extend([word_to_token_idx[head - 1], word_to_token_idx[dep - 1]])
            hyperedges.extend([edge_id, edge_id])

    # Pad to the fixed width so tensors from different sentences can be stacked.
    pad_len = max(slot_count - len(hyperedges), 0)
    source_nodes += [0] * pad_len
    hyperedges += [0] * pad_len

    return torch.tensor([source_nodes, hyperedges], dtype=torch.long)


In [None]:
class CausalDataset(Dataset):
    """Sentence-classification dataset with a dependency-based hyperedge index.

    The CSV must provide an ``input_text`` column (sentences) and a ``label``
    column. All sentences are tokenised once, padded/truncated to
    ``max_length``, and each sentence is parsed with the module-level Stanza
    pipeline ``nlp`` to build its hyperedge index via
    ``build_hyperedge_index_from_tokens``.

    Args:
        csv_path: Path of the CSV file to load.
        tokenizer: Hugging Face tokenizer; called on the list of sentences and
            used for ``convert_ids_to_tokens``.
        max_length: Sequence length used for tokenisation and for the
            hyperedge index (defaults to 128).
    """

    def __init__(self, csv_path, tokenizer, max_length=128):
        df = pd.read_csv(csv_path)

        self.labels = df["label"].tolist()
        self.sentences = df["input_text"].tolist()

        self.encodings = tokenizer(
            self.sentences,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors="pt",
        )

        # One hyperedge index per sentence, computed eagerly: every sentence is
        # run through the Stanza pipeline here, at construction time.
        self.hyperedge_indices = []
        for i, text in enumerate(self.sentences):
            tokens = tokenizer.convert_ids_to_tokens(self.encodings["input_ids"][i])
            doc = nlp(text)
            hyperedge_index = build_hyperedge_index_from_tokens(tokens, doc, max_length=max_length)
            self.hyperedge_indices.append(hyperedge_index)

    def __len__(self):
        """Number of examples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the tensors of example ``idx`` as a dict.

        Keys: ``input_ids``, ``attention_mask``, ``label`` (all ``torch.long``)
        and ``hyperedge_index`` (the precomputed tensor, not copied).
        """
        # The encodings are already tensors (return_tensors="pt"); wrapping them
        # in torch.tensor(...) triggers a copy-construct UserWarning on every
        # item. as_tensor + detach + clone yields the same independent copy
        # without the warning.
        return {
            "input_ids": torch.as_tensor(self.encodings["input_ids"][idx], dtype=torch.long).detach().clone(),
            "attention_mask": torch.as_tensor(self.encodings["attention_mask"][idx], dtype=torch.long).detach().clone(),
            "label": torch.tensor(self.labels[idx], dtype=torch.long),
            "hyperedge_index": self.hyperedge_indices[idx],
        }


In [None]:
def collate_fn(batch):
    """Collate dataset items into one batch with a single merged hyperedge index.

    Args:
        batch: List of dicts with the keys ``input_ids``, ``attention_mask``,
            ``label`` and ``hyperedge_index`` (a ``[2, N]`` long tensor).

    Returns:
        Dict with
        - ``input_ids`` / ``attention_mask``: sequences zero-padded to the
          longest one in the batch, shape ``[batch, max_len]``;
        - ``label``: stacked labels;
        - ``hyperedge_index``: all per-sample indices concatenated along
          dim 1 into ONE ``[2, total]`` tensor (not one tensor per sample).
          Row 0 (nodes) is shifted by the running sum of the preceding
          samples' ``input_ids`` lengths; row 1 (hyperedge ids) is shifted so
          that ids of different samples never coincide.
    """
    input_ids = torch.nn.utils.rnn.pad_sequence(
        [item['input_ids'] for item in batch], batch_first=True, padding_value=0
    )
    attention_mask = torch.nn.utils.rnn.pad_sequence(
        [item['attention_mask'] for item in batch], batch_first=True, padding_value=0
    )
    labels = torch.stack([item['label'] for item in batch])

    shifted_indices = []
    node_offset = 0
    edge_offset = 0
    for item in batch:
        # Clone so the tensors cached inside the dataset are never modified.
        edge_index = item['hyperedge_index'].clone()
        edge_index[0, :] += node_offset
        # Without this shift, hyperedge k of one sample and hyperedge k of the
        # next would become the same hyperedge after concatenation and connect
        # tokens across unrelated sentences.
        edge_index[1, :] += edge_offset
        shifted_indices.append(edge_index)

        # NOTE(review): the node offset uses the sample's own (unpadded)
        # length; with fixed-length inputs this equals the padded length.
        # Confirm the intended node layout if lengths can differ.
        node_offset += item['input_ids'].size(0)
        if edge_index.size(1) > 0:
            edge_offset = int(edge_index[1].max()) + 1

    hyperedge_index = torch.cat(shifted_indices, dim=1)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "label": labels,
        "hyperedge_index": hyperedge_index,
    }


In [None]:
# Load the uncased BERT tokenizer, then build the full dataset from the CSV.
# Constructing CausalDataset tokenises and dependency-parses every sentence.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
dataset = CausalDataset(
    csv_path="causal_classification_dataset.csv",
    tokenizer=tokenizer,
)




In [None]:
# 70 % train / 10 % validation; the test split takes whatever remains, so the
# three sizes always add up to the dataset length despite integer truncation.
total_size = len(dataset)
train_size, val_size = (int(fraction * total_size) for fraction in (0.7, 0.1))
test_size = total_size - (train_size + val_size)


In [None]:
from torch.utils.data import random_split
# Split with a dedicated, fixed-seed generator so the partition is the same on
# every run.
train_dataset, val_dataset, test_dataset = random_split(
    dataset=dataset,
    lengths=[train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42),
)


In [None]:
# Data loaders use the default collation (no collate_fn is passed), so the
# per-sample hyperedge_index tensors are stacked along a new leading batch
# dimension. Only the training loader shuffles.
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=4)
test_loader = DataLoader(dataset=test_dataset, batch_size=4)

# Model with its default arguments, Adam over all parameters, and
# cross-entropy as the training criterion.
model = HGATCausalClassifier()
optimizer = torch.optim.Adam(params=model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

In [None]:
# Print the DataLoader object itself (its repr); this does not iterate batches.
print(train_loader)

In [None]:
# Inspect the hyperedge index of the example at position 1.
print(dataset[1]['hyperedge_index'])

In [None]:
def train_model(model, dataloader, optimizer, criterion, device, epochs=5):
    """Train ``model`` one sample at a time and return it.

    Each batch from ``dataloader`` is unpacked into individual samples; every
    sample gets its own forward pass, backward pass and optimizer step. The
    batch is expected to carry ``input_ids``, ``attention_mask``, ``label``
    and ``hyperedge_index`` indexed along their first dimension, with each
    per-sample ``hyperedge_index`` of shape ``[2, N]``. Samples whose index
    has any other shape are reported and skipped.

    Args:
        model: Called as ``model(input_ids, attention_mask, hyperedge_index)``;
            its output is given a leading batch dimension before the loss.
        dataloader: Iterable of batch dicts as described above.
        optimizer: Optimizer stepping the model parameters.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the model and every sample tensor are moved to.
        epochs: Number of passes over ``dataloader``.

    Returns:
        The same ``model`` object, left in training mode.
    """
    model.to(device)
    model.train()

    for epoch in range(epochs):
        total_loss = 0
        for batch in dataloader:
            for row in range(len(batch['input_ids'])):
                hyperedge_index = batch['hyperedge_index'][row].to(device)

                # Guard: only a 2-D index with exactly two rows is usable.
                shape_ok = hyperedge_index.ndim == 2 and hyperedge_index.shape[0] == 2
                if not shape_ok:
                    print(f"Skipping bad hyperedge_index shape: {hyperedge_index.shape}")
                    continue

                # Slicing row:row + 1 keeps a leading batch dimension of size 1.
                sample_ids = batch['input_ids'][row:row + 1].to(device)
                sample_mask = batch['attention_mask'][row:row + 1].to(device)
                target = batch['label'][row:row + 1].to(device)

                optimizer.zero_grad()
                logits = model(sample_ids, sample_mask, hyperedge_index)[None]
                loss = criterion(logits, target)
                loss.backward()
                optimizer.step()

                # Summed (not averaged) over all samples of the epoch.
                total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}")

    return model


In [None]:
# Train for 3 epochs on the GPU when one is available, otherwise on the CPU.
model = train_model(
    model=model,
    dataloader=train_loader,
    optimizer=optimizer,
    criterion=criterion,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    epochs=3,
)
