In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import pandas as pd
from config import DEVICE, SEED, MODEL_CONFIG, TRAINING_CONFIG, DATASET_CONFIG
from model import JointCausalModel
from utility import compute_class_weights, label_value_counts
from dataset_collator import CausalDataset, CausalDatasetCollator
from config import id2label_cls, id2label_bio, id2label_rel
from evaluate_joint_causal_model import evaluate_model, print_eval_report
from trainer import train_model
import random
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_path = "/home/rnorouzini/JointLearning/src/jointlearning/expert_bert_softmax/expert_bert_softmax_model.pt"
model = JointCausalModel(**MODEL_CONFIG)
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.to(DEVICE)
model.eval()

JointCausalModel(
  (enc): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [3]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_CONFIG["encoder_name"])

In [4]:
test_cases = [
    {
        "name": "Scenario 1: Simple Non-Causal",
        "text": "The sky is blue.",
        "mock_data": {
            "cls_id": 0, # non-causal
            "bio_token_ids": [6, 6, 6, 6, 6] # O O O O O (assuming 5 tokens after CLS/SEP)
        },
        "settings": {"use_heuristic": False, "override_cls_if_spans_found": False},
        "expected_causal": False,
        "expected_relations_count": 0
    },
    {
        "name": "Scenario 2: Simple Causal (C -> E) - Heuristic",
        "text": "Heavy rain caused the flood.", # Tokens: Heavy, rain, caused, the, flood, .
        "mock_data": {
            "cls_id": 1, # causal
            "bio_token_ids": [0, 1, 6, 6, 2, 6] # B-C, I-C, O, O, B-E, O
        },
        "settings": {"use_heuristic": True, "override_cls_if_spans_found": False},
        "expected_causal": True,
        "expected_relations_count": 1,
        "expected_relations_texts": [("Heavy rain", "flood")]
    },
    {
        "name": "Scenario 3: Single CE Span - Heuristic (CE Self-Pair)",
        "text": "The drought was the problem.", # Tokens: The, drought, was, the, problem, .
        "mock_data": {
            "cls_id": 1, # causal
            "bio_token_ids": [4, 5, 6, 6, 6, 6] # B-CE, I-CE, O, O, O, O
        },
        "settings": {"use_heuristic": True, "override_cls_if_spans_found": False},
        "expected_causal": True,
        "expected_relations_count": 1,
        "expected_relations_texts": [("The drought", "The drought")]
    },
    {
        "name": "Scenario 4: Single CE Span - Standard (No CE Self-Pair)",
        "text": "The drought was the problem.",
        "mock_data": { # Same mock as Scenario 3
            "cls_id": 1, "bio_token_ids": [4, 5, 6, 6, 6, 6]
        },
        "settings": {"use_heuristic": False, "override_cls_if_spans_found": False},
        "expected_causal": False, # Becomes false because no relations are formed
        "expected_relations_count": 0
    },
    {
        "name": "Scenario 5: CLS Override - Non-Causal to Causal",
        "text": "Stress leads to burnout.", # Tokens: Stress, leads, to, burnout, .
        "mock_data": {
            "cls_id": 0, # Initially non-causal
            "bio_token_ids": [0, 6, 6, 2, 6] # B-C, O, O, B-E, O
        },
        "settings": {"use_heuristic": False, "override_cls_if_spans_found": True},
        "expected_causal": True, # Should be overridden to True
        "expected_relations_count": 1, # Assumes mock rel_head predicts this
        "expected_relations_texts": [("Stress", "burnout")]
    },
    {
        "name": "Scenario 6: CLS Override - Still Non-Causal (No Spans)",
        "text": "A quiet day.",
        "mock_data": {
            "cls_id": 0, # non-causal
            "bio_token_ids": [6, 6, 6] # O O O
        },
        "settings": {"use_heuristic": False, "override_cls_if_spans_found": True},
        "expected_causal": False,
        "expected_relations_count": 0
    },
    {
        "name": "Scenario 7: Multiple Causes, One Effect - Heuristic",
        "text": "Heat and lack of water caused crops to fail.",
        # Tokens: Heat, and, lack, of, water, caused, crops, to, fail, .
        "mock_data": {
            "cls_id": 1, # causal
            "bio_token_ids": [0, 6, 0, 6, 1, 6, 2, 6, 3, 6] # B-C(Heat), O, B-C(lack), O, I-C(water), O, B-E(crops), O, I-E(fail), O
        },
        "settings": {"use_heuristic": True, "override_cls_if_spans_found": False},
        "expected_causal": True,
        "expected_relations_count": 2, # (Heat, crops to fail), (lack of water, crops to fail)
        "expected_relations_texts": [("Heat", "crops to fail"), ("lack of water", "crops to fail")]
    },
     {
        "name": "Scenario 8: Invalid I-tag correction",
        "text": "Bad food made sick.", # Bad, food, made, sick, .
        "mock_data": {
            "cls_id": 1, # causal
             # O, I-C (invalid), O, B-E, O
            "bio_token_ids": [6, 1, 6, 2, 6]
        },
        "settings": {"use_heuristic": True, "override_cls_if_spans_found": False},
        "expected_causal": True,
        "expected_relations_count": 1,
        "expected_relations_texts": [("food", "sick")] # "food" becomes B-C after correction
    }
]

In [5]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from typing import Dict, Tuple, Optional, Any, List

# --- Configuration (Simplified from config.py for this test script) ---
id2label_bio = {
    0: "B-C", 1: "I-C", 2: "B-E", 3: "I-E",
    4: "B-CE", 5: "I-CE", 6: "O"
}
label2id_bio = {v: k for k, v in id2label_bio.items()}
id2label_rel = {0: "Rel_None", 1: "Rel_CE"}
label2id_rel = {v: k for k, v in id2label_rel.items()}
id2label_cls = {0: "non-causal", 1: "causal"}
label2id_cls = {v: k for k, v in id2label_cls.items()}

MODEL_CONFIG = {
    "encoder_name": "bert-base-uncased",
    "num_cls_labels": 2,
    "num_bio_labels": 7,
    "num_rel_labels": 2,
    "dropout": 0.1,
}

In [6]:

def get_mock_forward_fn(tokenized_inputs_for_mock: Dict[str, torch.Tensor], 
                        test_case_mock_data: Dict[str, Any], 
                        tokenizer_for_mock: AutoTokenizer,
                        device="cpu"):
    """
    Returns a function that simulates the model's forward pass.
    - tokenized_inputs_for_mock: Actual output from tokenizer for the current text (batch size 1).
    - test_case_mock_data["bio_token_ids_for_words"]: List of BIO IDs for *actual word tokens only*,
      EXCLUDING CLS, SEP, and PAD.
    """
    def mock_forward(input_ids_batch, attention_mask_batch, **kwargs):
        batch_size, seq_len = input_ids_batch.shape
        cls_id = test_case_mock_data["cls_id"]
        mock_cls_logits = torch.full((batch_size, MODEL_CONFIG["num_cls_labels"]), -10.0, device=device)
        mock_cls_logits[0, cls_id] = 10.0

        mock_bio_emissions = torch.full((batch_size, seq_len, MODEL_CONFIG["num_bio_labels"]), -10.0, device=device)
        mock_bio_emissions[:, :, label2id_bio["O"]] = 5.0 

        word_bio_ids = test_case_mock_data["bio_token_ids_for_words"]
        current_word_bio_idx = 0
        for token_pos in range(seq_len):
            if attention_mask_batch[0, token_pos] == 0: continue
            current_token_id_val = input_ids_batch[0, token_pos].item()
            if current_token_id_val == tokenizer_for_mock.cls_token_id or \
               current_token_id_val == tokenizer_for_mock.sep_token_id:
                continue
            if current_word_bio_idx < len(word_bio_ids):
                bio_id_for_this_token = word_bio_ids[current_word_bio_idx]
                mock_bio_emissions[0, token_pos, bio_id_for_this_token] = 10.0
                current_word_bio_idx += 1
        
        mock_hidden_states = torch.randn(batch_size, seq_len, 768, device=device)
        return {"cls_logits": mock_cls_logits, "bio_emissions": mock_bio_emissions, "hidden_states": mock_hidden_states}
    return mock_forward

In [7]:
for tc_idx, tc in enumerate(test_cases):
        print(f"\n--- Running Test: {tc['name']} ---")
        print(f"Text: '{tc['text']}'")
        print(f"Settings: {tc['settings']}")

        texts_batch = [tc["text"]]
        tokenized_inputs = tokenizer(
            texts_batch,
            return_tensors="pt",
            padding="max_length", # Pad to max_length for consistent bio_token_ids length
            max_length=32,      # A small max_length for testing
            truncation=True,
            return_offsets_mapping=True
        )
        # Apply the mock for the model's forward pass
        # Store original forward and restore it later if needed, or make mock part of model for test mode
        original_forward = model.forward
        model.forward = get_mock_forward_fn(tc["mock_data"], device=DEVICE)

        try:
            predictions = model.predict_batch(
                texts_batch,
                tokenized_inputs,
                device=DEVICE,
                **tc["settings"]
            )
            result = predictions[0] # We are processing one sentence at a time

            print(f"  Predicted Output: {result}")

            # Basic Assertions (more detailed assertions can be added)
            assert result["causal"] == tc["expected_causal"], \
                f"Causal flag mismatch. Expected {tc['expected_causal']}, Got {result['causal']}"
            assert len(result["relations"]) == tc["expected_relations_count"], \
                f"Relations count mismatch. Expected {tc['expected_relations_count']}, Got {len(result['relations'])}"

            if "expected_relations_texts" in tc:
                extracted_rel_texts = sorted([(r["cause"], r["effect"]) for r in result["relations"]])
                expected_rel_texts = sorted(tc["expected_relations_texts"])
                assert extracted_rel_texts == expected_rel_texts, \
                    f"Relation texts mismatch. Expected {expected_rel_texts}, Got {extracted_rel_texts}"

            print(f"  Test PASSED!")

        except AssertionError as e:
            print(f"  Test FAILED: {e}")
        except Exception as e:
            print(f"  Test ERRORED: {e}")
            import traceback
            traceback.print_exc()
        finally:
            model.forward = original_forward # Restore original forward method


--- Running Test: Scenario 1: Simple Non-Causal ---
Text: 'The sky is blue.'
Settings: {'use_heuristic': False, 'override_cls_if_spans_found': False}


TypeError: get_mock_forward_fn() missing 2 required positional arguments: 'test_case_mock_data' and 'tokenizer_for_mock'