In [None]:
!pip install transformers evaluate datasets

In [None]:
!pip install pytorch-crf

# **Loading Data**

Loading data from huggingface using datasets module and load_dataset class.

In [None]:
from datasets import load_dataset

data = load_dataset("rajpurkar/squad_v2")
print(data)

In [None]:
train_data = data['train'].shuffle(seed=42).select(range(15000))
val_data = data['validation']


print(train_data)
print("size:", len(train_data))
print(val_data)
print("size:", len(val_data))

# **Pre-Prcoessing and Tokenization**
Here the inputs or the data will be proprocessed and tokenized thouroughly.

In [None]:
print(train_data[1])

In [None]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("SpanBERT/spanbert-base-cased")

In [None]:
import numpy as np

def prepare_features_optimized(examples):
    tokenized = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=384,
        return_offsets_mapping=True,
        padding="max_length"
    )

    # Convert token_type_ids and offsets to NumPy arrays for vectorized operations.
    token_type_ids = np.array(tokenized["token_type_ids"])
    offset_mappings = [np.array(offsets) for offsets in tokenized["offset_mapping"]]

    start_positions = []
    end_positions = []

    for i in range(len(examples["id"])):
        answer = examples["answers"][i]
        offsets = offset_mappings[i]
        token_types = token_type_ids[i]
        context_tokens = np.where(token_types == 1)[0]

        # Default positions (CLS token)
        start_pos = 0
        end_pos = 0

        if answer["text"] and answer["text"][0]:
            answer_text = answer["text"][0]
            answer_start = answer["answer_start"][0]
            answer_end = answer_start + len(answer_text)

            if len(context_tokens) > 0:
                context_start = offsets[context_tokens[0]][0]
                context_end = offsets[context_tokens[-1]][1]

                if answer_start >= context_start and answer_end <= context_end:
                    # Use vectorized search for the start position
                    start_idx = np.searchsorted(offsets[:, 0], answer_start, side="right") - 1
                    end_idx = np.searchsorted(offsets[:, 1], answer_end, side="left")

                    # Optional: Add verification step only if needed
                    # predicted_span = tokenizer.decode(tokenized["input_ids"][i][start_idx: end_idx+1])
                    # if predicted_span.strip() != answer_text.strip():
                    #     start_idx, end_idx = 0, 0

                    start_pos = int(start_idx)
                    end_pos = int(end_idx)

        start_positions.append(start_pos)
        end_positions.append(end_pos)

    tokenized["start_positions"] = start_positions
    tokenized["end_positions"] = end_positions

    # Remove the offset mapping to save space
    del tokenized["offset_mapping"]
    return tokenized

In [None]:
print(train_data_tokenized)
print(eval_data_tokenized)

In [None]:
print(len(train_data_tokenized))
print(len(eval_data_tokenized))

# **Beginning tuning the model**

In [None]:
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer

# Load the model
model_name = "SpanBERT/spanbert-base-cased"  # Change if using SpanBERT-CRF
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

# Move model to GPU
model.to("cuda")

In [None]:

output_dir = "./spanbert-base"

training_args = TrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="epoch",  # Evaluate after every epoch
    save_strategy="epoch",  # Save after every epoch
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=8,
    weight_decay=0.01,
    save_total_limit=6,  
    logging_dir="./logs",
    logging_steps=500,
    report_to="none",
    optim="adafactor",
    fp16=True,
    gradient_accumulation_steps=1,
    dataloader_num_workers=8,
    remove_unused_columns=False,
)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data_tokenized,
    eval_dataset=eval_data_tokenized,
)

In [None]:
trainer.train()

In [None]:
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

In [None]:
# Save as a PyTorch .pth file
import torch
torch.save(model.state_dict(), "huggingface_model.pth")
print("Model saved as huggingface_model.pth")

In [None]:
trainer.evaluate()

In [None]:
import matplotlib.pyplot as plt

# Replace with your actual logged values
train_losses = [2.112200, 1.129800, 1.144000, 0.013100, 0.005400, 0.003300, 0.001800, 0.000500]
val_losses = [1.111936, 1.102107, 1.101450, 0.000045, 0.000210, 0.000000, 0.000000, 0.000000]

epochs = range(1, 9)  # 8 epochs

plt.figure(figsize=(10, 6))
plt.plot(epochs, train_losses, label='Train Loss', marker='o')
plt.plot(epochs, val_losses, label='Validation Loss', marker='o')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss per Epoch')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
predictions_output = trainer.predict(eval_data_tokenized)
# This returns a PredictionOutput object with:
#   predictions (tuple of (start_logits, end_logits))
#   label_ids (tuple of (start_positions, end_positions))
#   metrics (eval_loss, etc.)

In [None]:
import numpy as np
from typing import Tuple
import re

def format_text(text: str) -> str:
    """Lowercases, trims, and removes extra spaces from the text."""
    text = text.lower().strip()
    return re.sub(r'\s+', ' ', text)
    
def process_predictions(predictions_output, tokenizer, eval_data_tokenized) -> Tuple[list, list]:
    """Process model predictions to get normalized answer texts."""
    # Unpack predictions and labels
    start_logits, end_logits = predictions_output.predictions
    start_true, end_true = predictions_output.label_ids

    # Convert logits to numpy arrays if they're not already
    start_pred = np.argmax(start_logits, axis=1)
    end_pred = np.argmax(end_logits, axis=1)

    # Vectorized span validation and correction
    start_pred, end_pred = np.where(start_pred > end_pred, 
                                   np.stack([end_pred, start_pred], axis=1).T,
                                   np.stack([start_pred, end_pred], axis=1).T)

    # Get token spans using numpy advanced indexing
    input_ids = np.array(eval_data_tokenized["input_ids"])
    lengths = np.array([len(seq) for seq in input_ids])  # Get actual sequence lengths
    
    # Create masks for valid spans
    valid_pred_mask = (start_pred < lengths) & (end_pred < lengths)
    valid_true_mask = (start_true < lengths) & (end_true < lengths)

    # Process predictions
    pred_spans = [
        ids[s:e+1] if valid else []
        for ids, s, e, valid in zip(input_ids, start_pred, end_pred, valid_pred_mask)
    ]

    # Process references
    ref_spans = [
        ids[s:e+1] if valid else []
        for ids, s, e, valid in zip(input_ids, start_true, end_true, valid_true_mask)
    ]

    # Batch decode with error handling
    pred_texts = tokenizer.batch_decode(pred_spans, skip_special_tokens=True)
    ref_texts = tokenizer.batch_decode(ref_spans, skip_special_tokens=True)

    # Normalize texts in batches
    return [normalize_text(t) for t in pred_texts], [normalize_text(t) for t in ref_texts]

output = []
def exact_match_score(predictions, references):
    """Compute the exact match (EM) score between predictions and references."""
    assert len(predictions) == len(references), "Lists must have same length"
    
    # Normalize predictions and references before comparison
    formatted_preds = [format_text(p) for p in predictions]
    formatted_refs = [format_text(r) for r in references]
    
    # Count matches
    matches = sum(p == r for p, r in zip(formatted_preds, formatted_refs))
    
    # Compute percentage
    output.append((formatted_preds, formatted_refs))
    return (matches / len(references)) * 100 if references else 0.0

# Usage:
predictions_output = predictions_output
pred_texts, ref_texts = process_predictions(predictions_output, tokenizer, eval_data_tokenized)
em_score = exact_match_score(pred_texts, ref_texts)

print(f"Exact Match Score: {em_score:.2f}%")
print("\nSample Predictions vs References:")
samples = list(zip(pred_texts, ref_texts))[:5]  # Get first 5 samples
for pred, ref in samples:
    print(f"Pred: {pred}")
    print(f"Ref:  {ref}\n")

In [None]:
def process_predictions(predictions_output, tokenizer, eval_data_tokenized) -> Tuple[list, list]:
    start_logits, end_logits = predictions_output.predictions  # Model's predicted logits
    start_true, end_true = predictions_output.label_ids      # True start/end positions

    # Convert logits to predicted start/end positions
    start_pred = np.argmax(start_logits, axis=1)
    end_pred = np.argmax(end_logits, axis=1)

    # Correct cases where start > end
    start_pred, end_pred = np.where(start_pred > end_pred, 
                                   np.stack([end_pred, start_pred], axis=1).T,
                                   np.stack([start_pred, end_pred], axis=1).T)

    # Get input_ids and sequence lengths
    input_ids = np.array(eval_data_tokenized["input_ids"])
    lengths = np.array([len(seq) for seq in input_ids])

    # Validate spans
    valid_pred_mask = (start_pred < lengths) & (end_pred < lengths)
    valid_true_mask = (start_true < lengths) & (end_true < lengths)

    # Extract predicted spans
    pred_spans = [ids[s:e+1] if valid else [] 
                  for ids, s, e, valid in zip(input_ids, start_pred, end_pred, valid_pred_mask)]
    
    # Extract reference spans
    ref_spans = [ids[s:e+1] if valid else [] 
                 for ids, s, e, valid in zip(input_ids, start_true, end_true, valid_true_mask)]

    # Decode spans to text
    pred_texts = tokenizer.batch_decode(pred_spans, skip_special_tokens=True)
    ref_texts = tokenizer.batch_decode(ref_spans, skip_special_tokens=True)

    # Normalize texts (assuming normalize_text is format_text)
    return [normalize_text(t) for t in pred_texts], [normalize_text(t) for t in ref_texts]

In [None]:
def exact_match_score(predictions, references):
    assert len(predictions) == len(references), "Lists must have same length"
    
    formatted_preds = [format_text(p) for p in predictions]
    formatted_refs = [format_text(r) for r in references]
    
    matches = sum(p == r for p, r in zip(formatted_preds, formatted_refs))
    
    output.append((formatted_preds, formatted_refs))
    return (matches / len(references)) * 100 if references else 0.0

In [None]:
predictions_output = trainer.predict(eval_data_tokenized)
pred_texts, ref_texts = process_predictions(predictions_output, tokenizer, eval_data_tokenized)
em_score = exact_match_score(pred_texts, ref_texts)

print(f"Exact Match Score: {em_score:.2f}%")
print("\nSample Predictions vs References:")
samples = list(zip(pred_texts, ref_texts))[:5]
for pred, ref in samples:
    print(f"Pred: {pred}")
    print(f"Ref:  {ref}\n")

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
model = trainer.model
save_path = "fine_tuned_model.pth"
torch.save(model.state_dict(), save_path)

In [None]:
import torch
import numpy as np
from torch import nn
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

# **THE Base variant of Span-BERT preparing and loading**

In [None]:
dataset = load_dataset("squad_v2")
train_data = dataset["train"].select(range(15000))
val_data = dataset["validation"]

#Will be used for both

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel

class SpanBERTForQA(nn.Module):
    def __init__(self):
        super().__init__()
        self.spanbert = AutoModel.from_pretrained("SpanBERT/spanbert-base-cased")
        self.qa_outputs = nn.Linear(self.spanbert.config.hidden_size, 2)  # Start/end logits

    def forward(self, input_ids, attention_mask, start_positions=None, end_positions=None):
        outputs = self.spanbert(input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if start_positions is not None and end_positions is not None:
            loss_fct = nn.CrossEntropyLoss()
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            loss = (start_loss + end_loss) / 2
            return loss, start_logits, end_logits
        return start_logits, end_logits

In [None]:
# Data Preparation\
tokenizer = AutoTokenizer.from_pretrained("SpanBERT/spanbert-base-cased")

In [None]:
def prepare_features(examples):
    tokenized = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=384,
        stride=128,
        return_offsets_mapping=True,
        padding="max_length"
    )
    
    start_positions = []
    end_positions = []

    for i in range(len(examples["id"])):
        answer = examples["answers"][i]
        start_pos, end_pos = 0, 0
        
        if answer["text"]:
            answer_start = answer["answer_start"][0]
            answer_text = answer["text"][0]
            answer_end = answer_start + len(answer_text)

            sequence_ids = tokenized.sequence_ids(i)
            offsets = tokenized["offset_mapping"][i]
            context_start = sequence_ids.index(1) if 1 in sequence_ids else 0
            context_end = len(sequence_ids) - sequence_ids[::-1].index(1) - 1 if 1 in sequence_ids else len(sequence_ids) - 1

            for token_idx, (start, end) in enumerate(offsets[context_start:context_end + 1]):
                token_idx += context_start
                if start <= answer_start < end:
                    start_pos = token_idx
                if start < answer_end <= end:
                    end_pos = token_idx
                    break

        start_positions.append(start_pos)
        end_positions.append(end_pos)

    tokenized["start_positions"] = start_positions
    tokenized["end_positions"] = end_positions
    return tokenized

In [None]:
train_dataset_base = train_data.map(prepare_features, batched=True, batch_size=256, num_proc=4, remove_columns=train_data.column_names)
val_dataset_base = val_data.map(prepare_features, batched=True, batch_size=256, num_proc=4, remove_columns=val_data.column_names)

In [None]:
# Cell: DataLoader Setup for SpanBERT Base Model
from torch.utils.data import DataLoader
import torch
from torch.nn.utils.rnn import pad_sequence

# Collate function for base model
def collate_fn_base(batch):
    # Pad input_ids and attention_mask to the same length
    input_ids = pad_sequence([torch.tensor(ex["input_ids"], dtype=torch.long) for ex in batch], 
                            batch_first=True, padding_value=0)
    attention_mask = pad_sequence([torch.tensor(ex["attention_mask"], dtype=torch.long) for ex in batch], 
                                 batch_first=True, padding_value=0)
    # Start and end positions don’t need padding (they’re single values per example)
    start_positions = torch.tensor([ex["start_positions"] for ex in batch], dtype=torch.long)
    end_positions = torch.tensor([ex["end_positions"] for ex in batch], dtype=torch.long)
    
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "start_positions": start_positions,
        "end_positions": end_positions
    }

# Set up DataLoaders
train_dataloader_base = DataLoader(
    train_dataset_base,  # From your preprocessed base dataset
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn_base,
    num_workers=2,  # Kaggle usually has 4 CPUs, so 2 workers is safe
    pin_memory=True  # Speeds up data transfer to GPU
)

val_dataloader_base = DataLoader(
    val_dataset_base,
    batch_size=16,
    shuffle=False,  # No need to shuffle validation
    collate_fn=collate_fn_base,
    num_workers=2,
    pin_memory=True
)

print("DataLoaders for SpanBERT Base are set up!")

In [None]:
import torch
torch.cuda.empty_cache()

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
import torch

print("Training the base SpanBERT model...")
model_base = SpanBERTForQA()
model_base.to(device)
optimizer_base = torch.optim.AdamW(model_base.parameters(), lr=2e-5)  # Learning rate
train_losses_base = []
val_losses_base = []

for epoch in range(6):  # Train for 6 epochs
    model_base.train()  # Training mode
    total_loss = 0
    for batch in tqdm(train_dataloader_base, desc=f"Epoch {epoch+1}"):
        # Move data to GPU
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        start_positions = batch["start_positions"].to(device)
        end_positions = batch["end_positions"].to(device)
        
        optimizer_base.zero_grad()  # Clear old gradients
        start_logits, end_logits = model_base(input_ids, attention_mask)
        # Calculate loss for start and end
        loss_start = nn.CrossEntropyLoss()(start_logits, start_positions)
        loss_end = nn.CrossEntropyLoss()(end_logits, end_positions)
        loss = (loss_start + loss_end) / 2  # Average the two losses
        loss.backward()  # Compute gradients
        optimizer_base.step()  # Update weights
        total_loss += loss.item()
    
    avg_train_loss = total_loss / len(train_dataloader_base)
    train_losses_base.append(avg_train_loss)
    print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss}")
    
    # Validation
    model_base.eval()  # Evaluation mode
    val_loss = 0
    with torch.no_grad():  # No gradients during validation
        for batch in val_dataloader_base:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            start_positions = batch["start_positions"].to(device)
            end_positions = batch["end_positions"].to(device)
            start_logits, end_logits = model_base(input_ids, attention_mask)
            loss_start = nn.CrossEntropyLoss()(start_logits, start_positions)
            loss_end = nn.CrossEntropyLoss()(end_logits, end_positions)
            loss = (loss_start + loss_end) / 2
            val_loss += loss.item()
    
    avg_val_loss = val_loss / len(val_dataloader_base)
    val_losses_base.append(avg_val_loss)
    print(f"Validation Loss: {avg_val_loss}")

In [None]:
def predict_base(model, dataloader):
    model.eval()
    predictions = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            start_logits, end_logits = model(input_ids, attention_mask)
            start_pred = torch.argmax(start_logits, dim=1)  # Pick highest score
            end_pred = torch.argmax(end_logits, dim=1)
            for i in range(len(input_ids)):
                start = start_pred[i].item()
                end = end_pred[i].item()
                if start > end or start == 0:  # If invalid or no answer
                    pred_text = ""
                else:
                    pred_tokens = input_ids[i][start:end + 1]
                    pred_text = tokenizer.decode(pred_tokens, skip_special_tokens=True)
                predictions.append(pred_text)
    return predictions

In [None]:
def exact_match_score(predictions, references):
    assert len(predictions) == len(references), "Lists must have the same length"
    matches = sum(p == r for p, r in zip(predictions, references))
    return matches / len(references) * 100  # Convert to percentage

In [None]:
references = [ex["answers"]["text"][0] if ex["answers"]["text"] else "" for ex in val_data]

In [None]:
torch.save(model_base.state_dict(), 'best_model_base.pth')

In [None]:
model_base = SpanBERTForQA()
model_base.load_state_dict(torch.load("best_model_base.pth", weights_only=True))
model_base = model.to(device)

In [None]:
!pip install matplotlib

In [None]:
import matplotlib.pyplot as plt

predictions_base = predict_base(model_base, val_dataloader_base)
em_base = exact_match_score(predictions_base, references)
print("Saving loss plots...")
plt.figure()
plt.plot(train_losses_base, label="Train Loss (Base)")
plt.plot(val_losses_base, label="Validation Loss (Base)")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("SpanBERT Base Model Loss")
plt.savefig("spanbert_base_loss.png")
plt.close()
print(f"Exact Match for Base Model: {em_base:.2f}%")

In [51]:
!pip install pytorch-crf



In [52]:
from torchcrf import CRF
import torch
import numpy as np
from torch import nn
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

class SpanBERTCRF(nn.Module):
    def __init__(self):
        
        super().__init__()
        self.spanbert = AutoModel.from_pretrained("SpanBERT/spanbert-base-cased")
        self.dropout = nn.Dropout(0.1)  # Prevent overfitting
        self.classifier = nn.Linear(self.spanbert.config.hidden_size, 4)  # 4 tags: B, I, O, E
        self.crf = CRF(4, batch_first=True)  # CRF layer

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.spanbert(input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)
        if labels is not None:
            # If we have labels, calculate the loss
            loss = -self.crf(logits, labels, mask=attention_mask.bool())
            return loss, logits
        return logits

    def decode(self, input_ids, attention_mask):
        # Get the best tag sequence during prediction
        logits = self.forward(input_ids, attention_mask)
        best_tags = self.crf.decode(logits, mask=attention_mask.bool())
        return best_tags

In [53]:
tokenizer = AutoTokenizer.from_pretrained("SpanBERT/spanbert-base-cased")

In [55]:
def prepare_features_crf(examples):
    # This function prepares data for the SpanBERT-CRF model with BIOE tags
    tokenized = tokenizer(
        examples["question"],  # The question part
        examples["context"],   # The context part
        truncation="only_second",  # Only shorten the context
        max_length=384,        # Max tokens allowed
        stride=128,            # Overlap for long contexts
        return_offsets_mapping=True,  # Need character positions
        padding="max_length"   # Pad to fixed length
        
    )
    
    # List to hold BIOE labels for each example
    labels = []
    
    # Go through each example
    for i in range(len(examples["id"])):
        # Get the answer for this example
        answer = examples["answers"][i]
        sequence_ids = tokenized.sequence_ids(i)  # Which tokens are from context
        offsets = tokenized["offset_mapping"][i]  # Character positions
        # Start with all tokens as O (outside, tag 2)
        label_seq = [2] * len(tokenized["input_ids"][i])
        
        # If there’s an answer, we need to tag it
        if answer["text"]:
            # Get answer start and end in characters
            answer_start = answer["answer_start"][0]
            answer_end = answer_start + len(answer["text"][0])
            
            # Find context boundaries
            context_start = 0
            for j, sid in enumerate(sequence_ids):
                if sid == 1:
                    context_start = j
                    break
            context_end = len(sequence_ids) - 1
            for j in range(len(sequence_ids) - 1, -1, -1):
                if sequence_ids[j] == 1:
                    context_end = j
                    break
            
            # Variables to track the answer span
            start_token = None
            end_token = None
            
            # Find the tokens where the answer starts and ends
            for token_idx in range(context_start, context_end + 1):
                start, end = offsets[token_idx]
                if start <= answer_start < end:
                    start_token = token_idx  # Beginning token
                if start < answer_end <= end:
                    end_token = token_idx   # End token
                    break
            
            # Label the tokens with BIOE
            if start_token is not None and end_token is not None:
                if start_token == end_token:
                    # If it’s just one token, label it B (beginning, tag 0)
                    label_seq[start_token] = 0
                else:
                    # Label the first token B
                    label_seq[start_token] = 0
                    # Label middle tokens I (inside, tag 1)
                    for idx in range(start_token + 1, end_token):
                        label_seq[idx] = 1
                    # Label the last token E (end, tag 3)
                    label_seq[end_token] = 3
        
        # Add this example’s labels to the list
        labels.append(label_seq)
      
    # Add labels to the tokenized data
    tokenized["labels"] = labels
    return tokenized

In [56]:
import torch
from torch.utils.data import DataLoader

# Tokenize datasets (unchanged, assuming prepare_features_crf is memory-efficient)
train_data_tokenized_crf = train_data.map(
    prepare_features_crf,
    batch_size=256,
    batched=True,
    num_proc=4,
    remove_columns=train_data.column_names
)
eval_data_tokenized_crf = val_data.map(
    prepare_features_crf,
    batch_size=256,
    batched=True,
    num_proc=4,
    remove_columns=val_data.column_names
)

# Optimized collate function with memory efficiency
def collate_fn(batch):
    # Convert to tensors only what’s necessary, avoid unnecessary copies
    input_ids = torch.tensor([ex["input_ids"] for ex in batch], dtype=torch.long)
    attention_mask = torch.tensor([ex["attention_mask"] for ex in batch], dtype=torch.long)
    labels = torch.tensor([ex["labels"] for ex in batch], dtype=torch.long)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

# Initialize DataLoaders with memory-efficient settings
train_dataloader_crf = DataLoader(
    train_data_tokenized_crf,
    batch_size=16,  # Reduced from 64 to start conservatively
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,  # Reduced from 8 to lower CPU memory usage
    pin_memory=True,  # Keep this for faster GPU transfer
    persistent_workers=True  # Keep for efficiency with num_workers > 0
)

val_dataloader_crf = DataLoader(
    eval_data_tokenized_crf,
    batch_size=32,  # Reduced from 128 to conserve memory
    collate_fn=collate_fn,
    num_workers=4,  # Reduced from 8
    pin_memory=True
)

#train_dataloader_crf = DataLoader(train_dataset_crf, batch_size=16, shuffle=True)
#val_dataloader_crf = DataLoader(val_dataset_crf, batch_size=16)

In [57]:
# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [58]:
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm


print("Training the SpanBERT-CRF model...")
model_crf = SpanBERTCRF()
model_crf.to(device)
optimizer_crf = torch.optim.AdamW(model_crf.parameters(), lr=2e-5)
train_losses_crf = []
val_losses_crf = []

for epoch in range(8):  # Train for 8 epochs
    model_crf.train()
    total_loss = 0
    for batch in tqdm(train_dataloader_crf, desc=f"Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        
        optimizer_crf.zero_grad()
        loss, _ = model_crf(input_ids, attention_mask, labels)
        loss.backward()
        optimizer_crf.step()
        total_loss += loss.item()
    
    avg_train_loss = total_loss / len(train_dataloader_crf)
    train_losses_crf.append(avg_train_loss)
    print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss}")
    
    model_crf.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_dataloader_crf:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            loss, _ = model_crf(input_ids, attention_mask, labels)
            val_loss += loss.item()
    
    avg_val_loss = val_loss / len(val_dataloader_crf)
    val_losses_crf.append(avg_val_loss)
    print(f"Validation Loss: {avg_val_loss}")

Training the SpanBERT-CRF model...


Some weights of BertModel were not initialized from the model checkpoint at SpanBERT/spanbert-base-cased and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1: 100%|██████████| 938/938 [23:46<00:00,  1.52s/it]

Epoch 1, Train Loss: 371.67641870299383





Validation Loss: 504.61521427605743


Epoch 2: 100%|██████████| 938/938 [23:48<00:00,  1.52s/it]

Epoch 2, Train Loss: 198.41657457168677





Validation Loss: 287.25989753969253


Epoch 3: 100%|██████████| 938/938 [23:48<00:00,  1.52s/it]

Epoch 3, Train Loss: 90.45798527406477





Validation Loss: 268.3063404534453


Epoch 4: 100%|██████████| 938/938 [23:49<00:00,  1.52s/it]

Epoch 4, Train Loss: 60.63223790410739





Validation Loss: 264.64242602932836


Epoch 5: 100%|██████████| 938/938 [23:49<00:00,  1.52s/it]

Epoch 5, Train Loss: 44.87432448137035





Validation Loss: 298.88885354483


Epoch 6: 100%|██████████| 938/938 [23:49<00:00,  1.52s/it]

Epoch 6, Train Loss: 34.002536326329086





Validation Loss: 299.69791551815564


Epoch 7: 100%|██████████| 938/938 [23:48<00:00,  1.52s/it]

Epoch 7, Train Loss: 27.474596296037948





Validation Loss: 332.2626731216267


Epoch 8: 100%|██████████| 938/938 [23:48<00:00,  1.52s/it]

Epoch 8, Train Loss: 23.090707579655433





Validation Loss: 298.8955161392048


In [59]:
torch.save(model_crf.state_dict(), 'best_model_crf.pth')

In [60]:
# Function to predict answers with the CRF model
def predict_crf(model, dataloader):
    model.eval()
    predictions = []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            batch_tags = model.decode(input_ids, attention_mask)
            for i, tags in enumerate(batch_tags):
                start = None
                end = None
                for idx, tag in enumerate(tags):
                    if tag == 0:  # B tag
                        start = idx
                    elif tag == 3:  # E tag
                        end = idx
                        break?
                if start is not None and end is not None:
                    pred_tokens = input_ids[i][start:end + 1]
                    pred_text = tokenizer.decode(pred_tokens, skip_special_tokens=True)
                else:
                    pred_text = ""
                predictions.append(pred_text)
    return predictions

In [61]:
predictions_crf = predict_crf(model_crf, val_dataloader_crf)
em_crf = exact_match_score(predictions_crf, references)
plt.figure()
plt.plot(train_losses_crf, label="Train Loss (CRF)")
plt.plot(val_losses_crf, label="Validation Loss (CRF)")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("SpanBERT-CRF Model Loss")
plt.savefig("spanbert_crf_loss.png")
plt.close()
print(f"Exact Match for CRF Model: {em_crf:.2f}%")

Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `break` not found.
Object `brea

In [62]:
print("Saving models...")
torch.save(model_base.state_dict(), "spanbert_base_model.pt")
torch.save(model_crf.state_dict(), "spanbert_crf_model.pt")
print("All done!")

Saving models...
All done!
