In [1]:
import torch
import sys
sys.path.append('..')
import torch
from torch.utils.data import DataLoader
from core.PC_NET import PCNet
from core.config import punct_label2id, cap_label2id, MODEL_ID


In [2]:
# Location of the saved dataset that the cells below load with torch.load.
# NOTE(review): the variable says "train" but the file is named test.pt — confirm the intended split.
train_dataset_path = "../scripts/test.pt"

In [20]:
def load_sample_batch(train_dataset_path, batch_size=10):
    """Load a saved dataset and return its first batch.

    Args:
        train_dataset_path: Path to a dataset file readable by ``torch.load``.
        batch_size: Number of samples collated into the returned batch.

    Returns:
        The first batch produced by a non-shuffled ``DataLoader`` over the
        loaded dataset.

    Raises:
        ValueError: If the loader yields no batch at all (e.g. the dataset
            is empty).
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects and
    # can therefore execute code; only load dataset files you trust. The flag
    # is passed explicitly so the behaviour does not depend on the default of
    # the installed torch version.
    dataset = torch.load(train_dataset_path, weights_only=False)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    try:
        return next(iter(dataloader))
    except StopIteration:
        # Surface a descriptive error instead of leaking a bare StopIteration.
        raise ValueError(
            f"No batches could be drawn from {train_dataset_path!r}; is the dataset empty?"
        ) from None

In [21]:
# Take the first batch of the saved dataset, using the helper's default batch_size.
batch = load_sample_batch(train_dataset_path)

In [9]:
# NOTE(security): weights_only=False unpickles arbitrary Python objects, so only
# load files you trust. It is passed explicitly so the result does not depend on
# the default this flag has in the installed torch version.
dataset = torch.load(train_dataset_path, weights_only=False)

In [18]:
# Keep the item at index 1 (the second entry) around for manual inspection.
sample = dataset[1]

In [26]:
# Instantiate the network under test; each head gets one output class per
# entry of its label map.
model = PCNet(
    model_name=MODEL_ID,
    learning_rate=1e-4,  # placeholder value, only needed to construct the model
    num_punct_classes=len(punct_label2id),
    num_cap_classes=len(cap_label2id),
    trainable_layers=2,
)


In [27]:
# Pull the model inputs and both label tensors out of the sample batch.
input_ids, attention_mask, punct_labels, cap_labels = (
    batch[field]
    for field in ("input_ids", "attention_mask", "punct_labels", "cap_labels")
)

In [28]:
# Forward pass: the model returns a pair, unpacked here as punctuation and capitalization logits.
punct_logits, cap_logits = model(input_ids, attention_mask)

In [31]:
# Loss from the model's own compute_loss, given both heads' logits, their labels and the attention mask.
loss = model.compute_loss(punct_logits, cap_logits, punct_labels, cap_labels, attention_mask)

In [34]:
# Collapse every leading dimension so each row is one token position: logits
# become (num_positions, num_classes) and labels become (num_positions,).
# The class dimension is read from the tensor itself rather than hard-coded,
# so a change in the label maps cannot silently produce a wrongly shaped view.
# reshape() also accepts non-contiguous tensors, which view() rejects.
punct_logits = punct_logits.reshape(-1, punct_logits.size(-1))
cap_logits = cap_logits.reshape(-1, cap_logits.size(-1))
punct_labels = punct_labels.reshape(-1)
cap_labels = cap_labels.reshape(-1)

In [36]:
punct_logits.shape

torch.Size([1280, 4])

In [38]:
# Positions whose attention mask equals 1 count as valid; the label at every
# other position is replaced by -100.
active_mask = attention_mask.view(-1).eq(1)
active_punct_labels = punct_labels.masked_fill(~active_mask, -100)
active_cap_labels = cap_labels.masked_fill(~active_mask, -100)

In [40]:
active_punct_labels.shape

torch.Size([1280])

In [42]:
active_mask.shape

torch.Size([1280])

In [44]:
import torch.nn as nn

# One independent cross-entropy criterion per task; targets equal to -100 are
# excluded from the loss.
punct_loss_fn, cap_loss_fn = (
    nn.CrossEntropyLoss(ignore_index=-100) for _ in range(2)
)

In [45]:
# Punctuation loss recomputed by hand from the flattened logits and the padding-masked labels.
punct_loss_fn(punct_logits, active_punct_labels)

tensor(1.8110, grad_fn=<NllLossBackward0>)

In [47]:
punct_logits.shape

torch.Size([1280, 4])

In [49]:
active_punct_labels.shape

torch.Size([1280])

In [23]:
batch['punct_labels'].shape, batch['input_ids'].shape



(torch.Size([10, 128]), torch.Size([10, 128]))

In [25]:
len(batch['subword_tokens'])

128

In [14]:
sample

{'input_ids': tensor([50281, 10002,   261,  5681,   936, 12080,  1542,   262, 21808,   609,
            85,   602, 34974,    74,  9903,   626,  9802,   266, 31984,   936,
          3529,  1439, 15160,  2915,  5092,  1257,   290,  5924,   936, 35773,
          3549,  2881, 12157,  5658, 12796,  6448,  1568,   262,   434,  1439,
         24902,   262, 12756,  1257, 40199,  2858,  1059,   434, 50282, 50283,
         50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
         50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
         50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
         50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
         50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
         50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
         50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283, 50283,
         50283, 50283, 50283, 50283, 50

In [4]:
import torch
import sys
sys.path.append('..')
import torch
from torch.utils.data import DataLoader
from core.PC_NET import PCNet
from core.config import punct_label2id, cap_label2id, MODEL_ID

# Load a sample batch from the training dataset
def load_sample_batch(train_dataset_path, batch_size=10):
    """Load a saved dataset and return its first batch.

    Args:
        train_dataset_path: Path to a dataset file readable by ``torch.load``.
        batch_size: Number of samples collated into the returned batch.

    Returns:
        The first batch produced by a non-shuffled ``DataLoader`` over the
        loaded dataset.

    Raises:
        ValueError: If the loader yields no batch at all (e.g. the dataset
            is empty).
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects and
    # can therefore execute code; only load dataset files you trust. The flag
    # is passed explicitly so the behaviour does not depend on the default of
    # the installed torch version.
    dataset = torch.load(train_dataset_path, weights_only=False)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    try:
        return next(iter(dataloader))
    except StopIteration:
        # Surface a descriptive error instead of leaking a bare StopIteration.
        raise ValueError(
            f"No batches could be drawn from {train_dataset_path!r}; is the dataset empty?"
        ) from None

def sanity_check(train_dataset_path, device=None):
    """Run one PCNet forward pass and loss computation on a sample batch.

    Loads the first batch of the dataset, builds a fresh ``PCNet``, reports
    any parameter that contains NaNs, runs the model on the batch and prints
    the tensor shapes, the loss, the logit ranges and the labels.

    Args:
        train_dataset_path: Path handed to ``load_sample_batch``.
        device: Optional ``torch.device`` or device string to run on. When
            omitted, CUDA is used if it is available and the CPU otherwise,
            so the check also runs on machines without a GPU. Pass ``"cuda"``
            to require a GPU.
    """
    # Load a small sample batch
    batch = load_sample_batch(train_dataset_path)

    # Initialize model
    model = PCNet(
        model_name=MODEL_ID,
        learning_rate=1e-4,  # Dummy value for testing
        num_punct_classes=len(punct_label2id),
        num_cap_classes=len(cap_label2id),
        trainable_layers=2
    )

    # Resolve the target device, falling back to the CPU when CUDA is missing
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    model.to(device)

    # Report parameters that already contain NaNs before any forward pass
    for name, param in model.named_parameters():
        if torch.isnan(param).any():
            print(f"NaN detected in parameter: {name}")

    # Move every tensor of the batch to the device; non-tensor entries stay as they are
    for key in batch:
        if isinstance(batch[key], torch.Tensor):
            batch[key] = batch[key].to(device)

    # Forward pass
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    punct_labels = batch["punct_labels"]
    cap_labels = batch["cap_labels"]

    punct_logits, cap_logits = model(input_ids, attention_mask)

    # Compute loss
    loss = model.compute_loss(punct_logits, cap_logits, punct_labels, cap_labels, attention_mask)
    print("_" * 25)
    # Print results
    print(f"Input IDs shape: {input_ids.shape}")
    print(f"Attention mask shape: {attention_mask.shape}")
    print(f"Punctuation logits shape: {punct_logits.shape}")
    print(f"Capitalization logits shape: {cap_logits.shape}")
    print(f"Loss: {loss.item()}")
    print(f"Punctuation logits (min, max): {punct_logits.min().item()}, {punct_logits.max().item()}")
    print(f"Capitalization logits (min, max): {cap_logits.min().item()}, {cap_logits.max().item()}")
    print(f"Punctuation labels: {punct_labels}")
    print(f"Capitalization labels: {cap_labels}")

# Path to the saved dataset fed to the sanity check.
# NOTE(review): the variable says "train" but the file is named test.pt — confirm the intended split.
train_dataset_path = "../scripts/test.pt"

# Run the end-to-end check; it prints tensor shapes, logit ranges, the loss and the labels.
sanity_check(train_dataset_path)


  dataset = torch.load(train_dataset_path)


Running forward pass...
Sequence output (min, max): -24.372268676757812, 45.467384338378906
Punctuation logits (min, max): -1.7480071783065796, 1.8685516119003296
Capitalization logits (min, max): -1.9483040571212769, 1.3317899703979492
tensor(1.5256, device='cuda:0', grad_fn=<NllLossBackward0>) tensor(0.8925, device='cuda:0', grad_fn=<NllLossBackward0>)
_________________________
Input IDs shape: torch.Size([10, 128])
Attention mask shape: torch.Size([10, 128])
Punctuation logits shape: torch.Size([10, 128, 4])
Capitalization logits shape: torch.Size([10, 128, 2])
Loss: 2.4180526733398438
Punctuation logits (min, max): -1.7480071783065796, 1.8685516119003296
Capitalization logits (min, max): -1.9483040571212769, 1.3317899703979492
Punctuation labels: tensor([[-100,    0,    0,  ..., -100, -100, -100],
        [-100,    0,    0,  ..., -100, -100, -100],
        [-100,    0,    0,  ..., -100, -100, -100],
        ...,
        [-100,    0,    0,  ..., -100, -100, -100],
        [-100,    

In [3]:
import torch
from transformers import AutoModel, AutoTokenizer

# Model identifier passed to from_pretrained below.
# NOTE(review): this rebinds MODEL_ID, which other cells import from core.config — confirm both values agree.
MODEL_ID = "answerdotai/ModernBERT-base"

def check_model_on_gpu():
    """Smoke-test the pretrained encoder on one padded sentence.

    Loads the tokenizer and model named by ``MODEL_ID``, places the model on
    CUDA when it is available (on the CPU otherwise), runs a single
    gradient-free forward pass and prints the tensor shapes together with
    the value range of the last hidden state.
    """
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tokenizer and encoder come from the same checkpoint
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModel.from_pretrained(MODEL_ID)
    model.to(target_device)

    # Encode one sentence, padded/truncated to a fixed length of 128 tokens
    encoded = tokenizer(
        "Hello, how are you doing today?",
        return_tensors="pt",
        padding="max_length",
        max_length=128,
        truncation=True,
    )
    input_ids = encoded["input_ids"].to(target_device)
    attention_mask = encoded["attention_mask"].to(target_device)

    # Inference only, so gradient tracking is switched off
    with torch.no_grad():
        sequence_output = model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state

    # Report shapes and the value range of the hidden states
    print(f"Input IDs shape: {input_ids.shape}")
    print(f"Attention mask shape: {attention_mask.shape}")
    print(f"Sequence output shape: {sequence_output.shape}")
    print(f"Sequence output (min, max): {sequence_output.min().item()}, {sequence_output.max().item()}")

# Run the check
check_model_on_gpu()




Input IDs shape: torch.Size([1, 128])
Attention mask shape: torch.Size([1, 128])
Sequence output shape: torch.Size([1, 128, 768])
Sequence output (min, max): -24.3582763671875, 46.94162368774414


In [1]:
import torch
torch.__version__

'2.5.1+cu121'

In [2]:
print(torch.version.cuda)

12.1
