In [1]:
!pip install -q transformers

In [61]:
import os, re, random, numpy as np, pandas as pd, torch, torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, classification_report, precision_score, recall_score
from torch.utils.data import DataLoader, Dataset, RandomSampler
from tqdm.auto import tqdm
from torch.optim import AdamW
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
import nltk
from nltk.tokenize import sent_tokenize

# Label columns: one binary target per trait; also the order of the model's output logits.
TRAITS = ["O", "C", "E", "A", "N"]
# Name of the DataFrame column holding the raw text.
RAW_COL = "text"
# Tokenizer max_length: every example is padded/truncated to this many tokens.
MAX_LEN = 256
# Cap on training rows; validation and test are capped at MAX_SAMPLES // 5.
MAX_SAMPLES = 400000
# Batch size shared by every DataLoader in this notebook.
BATCH_SIZE = 48

# Seed every RNG in use (Python, NumPy, PyTorch CPU and all CUDA devices).
SEED = 27
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on", DEVICE)

# sent_tokenize needs the punkt sentence-tokenizer data: 'punkt_tab' is always requested,
# and the older 'punkt' package is downloaded only if it is not already on disk.
nltk.download('punkt_tab')
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

Running on cuda


[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [3]:
class AttentionLayer(nn.Module):
    """Additive attention pooling over the sequence axis.

    A two-layer scorer (Linear -> Tanh -> Linear) assigns one score per
    position; the scores are softmax-normalised and used to form a weighted
    sum of the inputs.
    """

    def __init__(self, hidden_size, attention_size=100):
        super().__init__()
        # Stored as a Sequential named `attention` so parameter names remain
        # attention.0.* / attention.2.* in saved state dicts.
        scorer_layers = [
            nn.Linear(hidden_size, attention_size),
            nn.Tanh(),
            nn.Linear(attention_size, 1),
        ]
        self.attention = nn.Sequential(*scorer_layers)

    def forward(self, inputs, mask=None):
        """Pool `inputs` along dim 1.

        Args:
            inputs: tensor whose last dim is `hidden_size`; it is batch-matrix-
                multiplied with the weights, so it must be 3-D.
            mask: optional tensor broadcastable to the score shape; positions
                equal to 0 receive a score of -1e9 before the softmax.

        Returns:
            (pooled, weights): the weighted sum with the sequence dim removed,
            and the softmax weights over the last score dim.
        """
        raw_scores = self.attention(inputs).squeeze(-1)
        if mask is None:
            masked_scores = raw_scores
        else:
            masked_scores = raw_scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(masked_scores, dim=-1)
        # (B, 1, T) @ (B, T, H) -> (B, 1, H) -> (B, H)
        pooled = torch.bmm(weights.unsqueeze(1), inputs).squeeze(1)
        return pooled, weights

In [4]:
class HierarchicalAttentionNetwork(nn.Module):
    """Two-level (word -> sentence -> document) attention encoder.

    Token embeddings go through a bidirectional GRU; each sentence span is
    attention-pooled into one vector; the sentence vectors go through a second
    bidirectional GRU and are attention-pooled into one document vector of
    size `hidden_size * 2`.
    """

    def __init__(self, hidden_size, attention_size=100, dropout_rate=0.3):
        super().__init__()
        bidir_size = hidden_size * 2  # forward + backward GRU states
        # Sub-modules are created in the same order as before so seeded
        # initialisation and parameter ordering are unchanged.
        self.word_attention = AttentionLayer(bidir_size, attention_size)
        self.word_gru = nn.GRU(hidden_size, hidden_size, bidirectional=True, batch_first=True)

        self.sentence_attention = AttentionLayer(bidir_size, attention_size)
        self.sentence_gru = nn.GRU(bidir_size, hidden_size, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(dropout_rate)

    @staticmethod
    def _default_sentence_lengths(seq_len, chunk=32):
        """Split `seq_len` tokens into fixed chunks of `chunk`; the last chunk takes the remainder."""
        n_chunks = (seq_len + chunk - 1) // chunk
        lengths = [chunk] * (n_chunks - 1)
        lengths.append(seq_len - (n_chunks - 1) * chunk)
        return lengths

    def _pool_sentences(self, token_states, token_mask, lengths, seq_len):
        """Attention-pool consecutive spans of one document; returns a list of [1, 2H] tensors."""
        pooled = []
        start = 0
        for length in lengths:
            # Clip a span that would run past the end of the sequence.
            if start + length > seq_len:
                length = seq_len - start
            # A non-positive span ends processing for this document.
            if length <= 0:
                break
            stop = start + length
            sentence_vec, _ = self.word_attention(
                token_states[:, start:stop, :], token_mask[:, start:stop]
            )
            pooled.append(sentence_vec)
            start = stop
        return pooled

    def forward(self, bert_embeddings, attention_mask, sent_lengths=None):
        """Encode a batch into document vectors.

        Args:
            bert_embeddings: 3-D tensor unpacked as (batch, seq_len, hidden).
            attention_mask: per-token mask, sliced with the same spans as the
                embeddings; zeros are excluded from word-level attention.
            sent_lengths: optional per-item lists of span lengths. When it is
                falsy, or an item's list is empty, the sequence is cut into
                fixed 32-token chunks instead.

        Returns:
            Tensor of shape [batch_size, hidden * 2].
        """
        batch_size, seq_len, hidden_size = bert_embeddings.size()
        device = bert_embeddings.device

        word_states, _ = self.word_gru(bert_embeddings)  # [batch_size, seq_len, hidden_size*2]
        word_states = self.dropout(word_states)

        doc_vectors = []
        for item in range(batch_size):
            lengths = sent_lengths[item] if sent_lengths else None
            if lengths is None or len(lengths) == 0:
                lengths = self._default_sentence_lengths(seq_len)

            sentence_vecs = self._pool_sentences(
                word_states[item:item + 1],
                attention_mask[item:item + 1],
                lengths,
                seq_len,
            )

            if not sentence_vecs:
                # No usable span: fall back to an all-zero document vector.
                doc_vectors.append(torch.zeros(1, hidden_size * 2, device=device))
                continue

            stacked = torch.cat(sentence_vecs, dim=0).unsqueeze(0)  # [1, n_sentences, hidden_size*2]
            sentence_states, _ = self.sentence_gru(stacked)
            sentence_states = self.dropout(sentence_states)

            # Every pooled sentence takes part in document-level attention.
            keep_all = torch.ones(1, len(sentence_vecs), device=device)
            doc_vec, _ = self.sentence_attention(sentence_states, keep_all)
            doc_vectors.append(doc_vec)

        return torch.cat(doc_vectors, dim=0)  # [batch_size, hidden_size*2]

In [5]:
class BERTHANClassificationModel(nn.Module):
    """BERT encoder + hierarchical attention network with one binary head per trait.

    The [CLS] embedding (size H) is concatenated with the HAN document vector
    (size 2H) and fed to an independent MLP head for each entry of TRAITS.
    """

    def __init__(self, bert_model_name="bert-base-uncased", dropout_rate=0.3):
        super(BERTHANClassificationModel, self).__init__()

        self.bert = BertModel.from_pretrained(bert_model_name)

        # Freeze every BERT parameter tensor except the last four returned by
        # parameters().
        # NOTE(review): this freezes by tensor, not by layer — confirm which
        # tensors are meant to stay trainable.
        for param in list(self.bert.parameters())[:-4]:
            param.requires_grad = False

        hidden_size = self.bert.config.hidden_size

        self.han = HierarchicalAttentionNetwork(
            hidden_size=hidden_size,
            attention_size=hidden_size // 2,
            dropout_rate=dropout_rate
        )

        # [CLS] (H) + HAN document vector (2H)
        fused_size = hidden_size * 3

        self.trait_heads = nn.ModuleDict({
            trait: nn.Sequential(
                nn.Linear(fused_size, hidden_size),
                nn.LayerNorm(hidden_size),
                nn.Dropout(dropout_rate),
                nn.GELU(),
                nn.Linear(hidden_size, 1)
            ) for trait in TRAITS
        })

    def forward(self, input_ids, attention_mask, sent_lengths=None):
        """Return raw logits with one column per trait, in TRAITS order.

        Args:
            input_ids: token ids passed to BERT.
            attention_mask: token mask passed to both BERT and the HAN.
            sent_lengths: optional per-item sentence span lengths, forwarded
                to the HAN.

        Returns:
            Tensor of shape [batch, len(TRAITS)] (pre-sigmoid logits).
        """
        # Only last_hidden_state is read, so BERT is not asked for all
        # per-layer hidden states (output_hidden_states); that avoids keeping
        # every layer's activations alive for the whole forward pass.
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        token_embeddings = outputs.last_hidden_state

        # Position 0 of each sequence is used as the [CLS] summary vector.
        cls_embedding = token_embeddings[:, 0, :]

        han_embedding = self.han(token_embeddings, attention_mask, sent_lengths)

        fused_embedding = torch.cat([cls_embedding, han_embedding], dim=1)

        logits = torch.cat([
            self.trait_heads[trait](fused_embedding) for trait in TRAITS
        ], dim=1)

        return logits

In [6]:
class FocalLoss(nn.Module):
    """Focal loss on logits for (multi-label) binary targets.

    Per element: alpha * (1 - p_t) ** gamma * BCE, where p_t = exp(-BCE) is
    the probability assigned to the true label. The same `alpha` scales every
    element regardless of its label. The result is the mean over all elements.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the scalar mean focal loss for `inputs` (logits) against `targets`."""
        per_element_bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        prob_true = torch.exp(-per_element_bce)
        # Down-weight elements the model already gets right (prob_true close to 1).
        modulating = (1 - prob_true) ** self.gamma
        focal = self.alpha * modulating * per_element_bce
        return focal.mean()

In [7]:
def preprocess_text(text):
    """Normalise raw text before sentence splitting and tokenisation.

    The input is coerced to `str`, then these rewrites are applied in order:
    URLs become the literal token ``[URL]``; whitespace runs collapse to one
    space; a space is inserted after ``.``, ``!`` or ``?`` when a letter
    follows immediately. The result is stripped of surrounding whitespace.
    """
    url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'

    # (pattern, replacement) pairs, applied top to bottom.
    rewrite_rules = (
        (url_pattern, '[URL]'),
        (r'\s+', ' '),
        # Newlines are already collapsed by the rule above; kept for parity.
        (r'\n+', ' '),
        (r'([.!?])([A-Za-z])', r'\1 \2'),
    )

    cleaned = str(text)
    for pattern, replacement in rewrite_rules:
        cleaned = re.sub(pattern, replacement, cleaned)

    return cleaned.strip()

In [8]:
class PandoraBinaryHANDataset(Dataset):
    """Torch dataset yielding tokenised text, per-sentence lengths and trait labels.

    On construction the text column is cleaned with `preprocess_text`, a
    `sentence_count` column is added, and rows with fewer than three
    sentences are dropped.
    """

    def __init__(self, df, tokenizer, max_len=MAX_LEN, raw_col=RAW_COL):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.raw_col = raw_col

        def count_sentences(doc):
            return len(sent_tokenize(doc))

        # Work on a copy so the caller's frame is not modified.
        frame = df.copy()
        frame[raw_col] = frame[raw_col].apply(preprocess_text)
        frame['sentence_count'] = frame[raw_col].apply(count_sentences)
        self.df = frame[frame['sentence_count'] >= 3]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict with `input_ids`, `attention_mask`, `sent_lengths` and `labels`."""
        row = self.df.iloc[idx]
        text = str(row[self.raw_col])
        label_values = row[TRAITS].values.astype(np.float32)

        encoded = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        # One entry per sentence: its token count plus one, capped at max_len.
        # NOTE(review): the +1 is added to every sentence — presumably for
        # special tokens; confirm the spans line up with `input_ids`.
        sent_lengths = [
            min(len(self.tokenizer.tokenize(sentence)) + 1, self.max_len)
            for sentence in sent_tokenize(text)
        ]

        return {
            'input_ids': encoded['input_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
            'sent_lengths': sent_lengths,
            'labels': torch.tensor(label_values, dtype=torch.float)
        }

In [9]:
def collate_han_batch(batch):
    """Collate dataset items into one batch dict.

    Tensor fields (`input_ids`, `attention_mask`, `labels`) are stacked along
    a new leading dim. `sent_lengths` stays a plain list of per-item lists,
    with `[]` for items that lack the key.
    """
    def stack_field(key):
        return torch.stack([sample[key] for sample in batch])

    return {
        'input_ids': stack_field('input_ids'),
        'attention_mask': stack_field('attention_mask'),
        'sent_lengths': [sample.get('sent_lengths', []) for sample in batch],
        'labels': stack_field('labels'),
    }

def find_optimal_thresholds(probs, labels):
    """Grid-search a decision threshold per trait that maximises binary F1.

    Args:
        probs: 2-D array of predicted probabilities, one column per trait.
        labels: 2-D array of binary ground truth with the same layout.

    Returns:
        (thresholds, mean_f1): array with the chosen threshold per trait and
        the mean of the per-trait best F1 scores. A trait whose F1 never
        exceeds 0 keeps the default threshold 0.5 with F1 0.0.
    """
    candidates = np.arange(0.3, 0.8, 0.01)
    chosen_thresholds = []
    chosen_f1s = []

    for col, trait in enumerate(TRAITS):
        col_probs = probs[:, col]
        col_labels = labels[:, col]

        scores = [
            f1_score(col_labels, (col_probs >= cut).astype(int))
            for cut in candidates
        ]

        # argmax returns the first maximum, i.e. the lowest threshold among ties.
        top = int(np.argmax(scores))
        if scores[top] > 0.0:
            best_threshold, best_f1 = candidates[top], scores[top]
        else:
            best_threshold, best_f1 = 0.5, 0.0

        chosen_thresholds.append(best_threshold)
        chosen_f1s.append(best_f1)
        print(f" {trait}: Best threshold = {best_threshold:.2f}, F1 = {best_f1:.4f}")

    return np.array(chosen_thresholds), np.mean(chosen_f1s)

In [10]:
def train_berthan_model(train_df, val_df, test_df, epochs=3, patience=2):
    """Train the BERT+HAN classifier from scratch, then evaluate on the test split.

    After each epoch the per-trait thresholds are tuned on the validation set;
    the epoch with the highest mean validation F1 is kept (early stopping
    after `patience` epochs without improvement). The best weights are
    restored, evaluated on the test set with the matching thresholds, and
    saved to 'berthan_pandora_model.pt'.

    Args:
        train_df, val_df, test_df: frames with the text column and one binary
            column per trait in TRAITS.
        epochs: maximum number of training epochs.
        patience: number of non-improving epochs tolerated before stopping.

    Returns:
        (model, best_thresholds, test_f1) where `test_f1` is the macro F1
        over traits on the test set.
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    train_dataset = PandoraBinaryHANDataset(train_df, tokenizer)
    val_dataset = PandoraBinaryHANDataset(val_df, tokenizer)
    test_dataset = PandoraBinaryHANDataset(test_df, tokenizer)
    print(f"Dataset sizes after filtering: Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_han_batch
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        collate_fn=collate_han_batch
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        collate_fn=collate_han_batch
    )

    model = BERTHANClassificationModel().to(DEVICE)

    optimizer = AdamW([
        {'params': model.bert.parameters(), 'lr': 5e-5, 'weight_decay': 0.01},
        {'params': model.han.parameters(), 'lr': 1e-4, 'weight_decay': 0.01},
        {'params': model.trait_heads.parameters(), 'lr': 1e-4, 'weight_decay': 0.01}
    ], eps=1e-8)

    total_steps = len(train_dataloader) * epochs

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    loss_fn = FocalLoss(alpha=0.25, gamma=2.0)

    def snapshot_weights():
        # state_dict() returns references to the live tensors and dict.copy()
        # is shallow, so a plain copy would keep changing as training goes on.
        # Clone every tensor (onto the CPU) to freeze the values at this point.
        return {name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()}

    def predict_probabilities(dataloader, desc):
        # Run inference over `dataloader`; returns (probabilities, labels) as stacked arrays.
        model.eval()
        all_probs, all_labels = [], []
        with torch.no_grad():
            for batch in tqdm(dataloader, desc=desc):
                input_ids = batch['input_ids'].to(DEVICE)
                attention_mask = batch['attention_mask'].to(DEVICE)
                sent_lengths = batch.get('sent_lengths', None)

                outputs = model(input_ids, attention_mask, sent_lengths)
                all_probs.append(torch.sigmoid(outputs).cpu().numpy())
                all_labels.append(batch['labels'].numpy())
        return np.vstack(all_probs), np.vstack(all_labels)

    # For early stopping
    best_val_f1 = 0
    best_model_state = None
    best_thresholds = None
    wait = 0

    for epoch in range(1, epochs+1):
        print(f"Epoch {epoch}/{epochs}")

        model.train()
        total_train_loss = 0

        for batch in tqdm(train_dataloader):
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)
            sent_lengths = batch.get('sent_lengths', None)

            optimizer.zero_grad()

            outputs = model(input_ids, attention_mask, sent_lengths)

            loss = loss_fn(outputs, labels)

            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_dataloader)
        print(f"Average training loss: {avg_train_loss:.4f}")

        val_preds, val_labels = predict_probabilities(val_dataloader, "Validation")

        thresholds, val_f1 = find_optimal_thresholds(val_preds, val_labels)
        print(f"Validation F1: {val_f1:.4f}")

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_model_state = snapshot_weights()
            best_thresholds = thresholds
            wait = 0
            print(f"New best model saved! F1: {val_f1:.4f}")
        else:
            wait += 1
            print(f"No improvement. Patience: {wait}/{patience}")

            if wait >= patience:
                print(f"Early stopping triggered after epoch {epoch}")
                break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    if best_thresholds is None:
        # No epoch produced a positive validation F1 (or epochs == 0); use the
        # neutral cut-off rather than failing on a comparison against None.
        best_thresholds = np.full(len(TRAITS), 0.5)

    test_preds, test_labels = predict_probabilities(test_dataloader, "Testing")

    # Broadcasts the per-trait thresholds across rows.
    binary_preds = (test_preds >= best_thresholds).astype(int)

    test_f1 = f1_score(test_labels, binary_preds, average='macro')

    print("Per-trait performance:")
    for i, trait in enumerate(TRAITS):
        precision = precision_score(test_labels[:, i], binary_preds[:, i])
        recall = recall_score(test_labels[:, i], binary_preds[:, i])
        f1 = f1_score(test_labels[:, i], binary_preds[:, i])
        print(f"{trait}: Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}, Threshold={best_thresholds[i]:.2f}")

    print("Classification:")
    print(classification_report(test_labels, binary_preds, target_names=TRAITS, digits=4))

    torch.save({
        'model_state_dict': model.state_dict(),
        'thresholds': best_thresholds,
        'val_f1': best_val_f1,
        'test_f1': test_f1
    }, 'berthan_pandora_model.pt')

    print(f"Model saved with test F1: {test_f1:.4f}")

    return model, best_thresholds, test_f1

In [12]:
def train_berthan_model(train_df, val_df, test_df, epochs=3, patience=2,
                        checkpoint_path='./content/berthan_pandora_model.pt'):
    """Warm-start the BERT+HAN classifier from a checkpoint, train, and evaluate.

    NOTE(review): this notebook defines `train_berthan_model` twice; this
    later definition replaces the earlier one when the cells run in order.

    The weights at `checkpoint_path` are loaded first. Training then proceeds
    with per-epoch threshold tuning on the validation set and early stopping
    on mean validation F1. The best weights are restored, evaluated on the
    test set, and saved to 'berthan_pandora_model.pt'.

    Args:
        train_df, val_df, test_df: frames with the text column and one binary
            column per trait in TRAITS.
        epochs: maximum number of training epochs.
        patience: number of non-improving epochs tolerated before stopping.
        checkpoint_path: checkpoint to warm-start from; must contain
            'model_state_dict' and 'thresholds'.

    Returns:
        (model, best_thresholds, test_f1) where `test_f1` is the macro F1
        over traits on the test set.

    Raises:
        FileNotFoundError: if `checkpoint_path` does not exist.
    """
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    train_dataset = PandoraBinaryHANDataset(train_df, tokenizer)
    val_dataset = PandoraBinaryHANDataset(val_df, tokenizer)
    test_dataset = PandoraBinaryHANDataset(test_df, tokenizer)

    print(f"Dataset sizes after filtering: Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_han_batch
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        collate_fn=collate_han_batch
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        collate_fn=collate_han_batch
    )

    model = BERTHANClassificationModel().to(DEVICE)

    # The checkpoint stores a NumPy array ('thresholds') next to the weights,
    # so it needs full unpickling (weights_only=False). Unpickling can run
    # arbitrary code: only load checkpoints you created yourself.
    # map_location lets a checkpoint written on GPU load on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)

    loaded_thresholds = checkpoint['thresholds']
    loaded_val_f1 = checkpoint.get('val_f1', 0)
    loaded_test_f1 = checkpoint.get('test_f1', 0)

    model.load_state_dict(checkpoint['model_state_dict'])

    print(f"Loaded model with validation F1: {loaded_val_f1:.4f}, test F1: {loaded_test_f1:.4f}")
    print(f"Classification thresholds: {loaded_thresholds}")


    optimizer = AdamW([
        {'params': model.bert.parameters(), 'lr': 5e-5, 'weight_decay': 0.01},
        {'params': model.han.parameters(), 'lr': 1e-4, 'weight_decay': 0.01},
        {'params': model.trait_heads.parameters(), 'lr': 1e-4, 'weight_decay': 0.01}
    ], eps=1e-8)

    total_steps = len(train_dataloader) * epochs

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    loss_fn = FocalLoss(alpha=0.25, gamma=2.0)

    def snapshot_weights():
        # state_dict() returns references to the live tensors and dict.copy()
        # is shallow, so a plain copy would keep changing as training goes on.
        # Clone every tensor (onto the CPU) to freeze the values at this point.
        return {name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()}

    def predict_probabilities(dataloader, desc):
        # Run inference over `dataloader`; returns (probabilities, labels) as stacked arrays.
        model.eval()
        all_probs, all_labels = [], []
        with torch.no_grad():
            for batch in tqdm(dataloader, desc=desc):
                input_ids = batch['input_ids'].to(DEVICE)
                attention_mask = batch['attention_mask'].to(DEVICE)
                sent_lengths = batch.get('sent_lengths', None)

                outputs = model(input_ids, attention_mask, sent_lengths)
                all_probs.append(torch.sigmoid(outputs).cpu().numpy())
                all_labels.append(batch['labels'].numpy())
        return np.vstack(all_probs), np.vstack(all_labels)

    # Early-stopping state. The baseline starts at 0, not at the loaded
    # checkpoint's validation F1, so the first epoch with a positive F1 is
    # always recorded as the best so far.
    best_val_f1 = 0
    best_model_state = None
    best_thresholds = None
    wait = 0

    for epoch in range(1, epochs+1):
        print(f"\nEpoch {epoch}/{epochs}")

        model.train()
        total_train_loss = 0

        for batch in tqdm(train_dataloader, desc="Training"):
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)
            sent_lengths = batch.get('sent_lengths', None)

            optimizer.zero_grad()

            outputs = model(input_ids, attention_mask, sent_lengths)

            loss = loss_fn(outputs, labels)

            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()

            scheduler.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_dataloader)
        print(f"Average training loss: {avg_train_loss:.4f}")

        val_preds, val_labels = predict_probabilities(val_dataloader, "Validation")

        print("Finding optimal thresholds:")
        thresholds, val_f1 = find_optimal_thresholds(val_preds, val_labels)
        print(f"Validation F1: {val_f1:.4f}")

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_model_state = snapshot_weights()
            best_thresholds = thresholds
            wait = 0
            print(f"New best model saved! F1: {val_f1:.4f}")
        else:
            wait += 1
            print(f"No improvement. Patience: {wait}/{patience}")

            if wait >= patience:
                print(f"Early stopping after epoch {epoch}")
                break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    if best_thresholds is None:
        # No epoch improved on the baseline (or epochs == 0): keep using the
        # thresholds that came with the checkpoint instead of failing on None.
        best_thresholds = np.asarray(loaded_thresholds)

    test_preds, test_labels = predict_probabilities(test_dataloader, "Testing")

    # Broadcasts the per-trait thresholds across rows.
    binary_preds = (test_preds >= best_thresholds).astype(int)

    test_f1 = f1_score(test_labels, binary_preds, average='macro')

    print("Per-trait performance:")
    for i, trait in enumerate(TRAITS):
        precision = precision_score(test_labels[:, i], binary_preds[:, i])
        recall = recall_score(test_labels[:, i], binary_preds[:, i])
        f1 = f1_score(test_labels[:, i], binary_preds[:, i])
        print(f"{trait}: Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}, Threshold={best_thresholds[i]:.2f}")

    print("Classification:")
    print(classification_report(test_labels, binary_preds, target_names=TRAITS, digits=4))

    torch.save({
        'model_state_dict': model.state_dict(),
        'thresholds': best_thresholds,
        'val_f1': best_val_f1,
        'test_f1': test_f1
    }, 'berthan_pandora_model.pt')

    print(f"Model saved with test F1: {test_f1:.4f}")

    return model, best_thresholds, test_f1

In [54]:
import os

# Sanity check: confirm the saved checkpoint is on disk and not empty.
file_path = 'berthan_pandora_model.pt'
if not os.path.exists(file_path):
    print(f"Error: File {file_path} does not exist")
else:
    file_size = os.path.getsize(file_path)
    print(f"File exists with size: {file_size} bytes")
    if file_size == 0:
        print("Error: File exists but is empty (0 bytes)")

File exists with size: 549058334 bytes


In [11]:
def load_pandora_binary_data():
    """Load the PANDORA Big-5 splits and binarise each trait at the training median.

    Each split is read from the Hugging Face hub parquet files and randomly
    down-sampled (seeded with SEED) when it exceeds its cap: MAX_SAMPLES for
    train, MAX_SAMPLES // 5 for validation and test. For every trait the
    median of the training scores is the cut-off; scores at or above it
    become 1, the rest 0, in all three splits.

    Returns:
        (train_df, val_df, test_df) with trait columns replaced by 0/1 labels.
    """
    repo_prefix = 'hf://datasets/jingjietan/pandora-big5/'
    splits = {
        'train': 'data/train-00001-of-00002.parquet',
        'validation': 'data/validation-00000-of-00001.parquet',
        'test': 'data/test-00000-of-00001.parquet'
    }

    def read_capped(split_name, cap):
        frame = pd.read_parquet(repo_prefix + splits[split_name])
        if len(frame) > cap:
            frame = frame.sample(cap, random_state=SEED)
        return frame

    eval_cap = MAX_SAMPLES // 5
    train_df = read_capped('train', MAX_SAMPLES)
    val_df = read_capped('validation', eval_cap)
    test_df = read_capped('test', eval_cap)

    print(f"Dataset sizes: Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")

    all_splits = (train_df, val_df, test_df)

    thresholds = {}
    for trait in TRAITS:
        # The median must be taken before the training column is overwritten.
        cutoff = train_df[trait].median()
        thresholds[trait] = cutoff
        for frame in all_splits:
            frame[trait] = (frame[trait] >= cutoff).astype(int)
        print(f"Median threshold for {trait}: {thresholds[trait]:.2f}")

    print("Class distribution:")
    for trait in TRAITS:
        train_pos, val_pos, test_pos = (frame[trait].mean() * 100 for frame in all_splits)
        print(f"{trait}: Train {train_pos:.1f}% positive, Val {val_pos:.1f}% positive, Test {test_pos:.1f}% positive")

    return train_df, val_df, test_df

In [60]:
def continue_training_berthan_model(model_path, train_dataloader, val_dataloader,
                                   continue_epochs=3, patience=2, accumulation_steps=4,
                                   output_path='berthan_pandora_model_continued_2.pt'):
    """Resume training from a checkpoint using gradient accumulation.

    The checkpoint's validation F1 is the baseline: weights and thresholds are
    replaced only when an epoch beats it. If no epoch does, the returned and
    saved model is the loaded checkpoint itself.

    Args:
        model_path: checkpoint containing 'model_state_dict' and 'thresholds'
            (optionally 'val_f1').
        train_dataloader, val_dataloader: loaders yielding batches with
            'input_ids', 'attention_mask', 'labels' and optionally
            'sent_lengths'.
        continue_epochs: maximum number of additional epochs.
        patience: non-improving epochs tolerated before stopping.
        accumulation_steps: number of batches accumulated per optimizer step.
        output_path: where the resulting checkpoint is written.

    Returns:
        (model, thresholds, best_val_f1).
    """
    print(f"Loading model from {model_path}")

    # weights_only=False is needed because the checkpoint also holds a NumPy
    # array. Unpickling can run arbitrary code: only load trusted checkpoints.
    # map_location lets a checkpoint written on GPU load on a CPU-only machine.
    checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)

    model = BERTHANClassificationModel().to(DEVICE)

    model.load_state_dict(checkpoint['model_state_dict'])

    thresholds = checkpoint['thresholds']
    best_val_f1 = checkpoint.get('val_f1', 0)

    print(f"Loaded model with validation F1: {best_val_f1:.4f}")
    print(f"Classification thresholds: {thresholds}")

    optimizer = AdamW([
        {'params': model.bert.parameters(), 'lr': 5e-5, 'weight_decay': 0.01},
        {'params': model.han.parameters(), 'lr': 1e-4, 'weight_decay': 0.01},
        {'params': model.trait_heads.parameters(), 'lr': 1e-4, 'weight_decay': 0.01}
    ], eps=1e-8)

    n_batches = len(train_dataloader)
    # A trailing partial group is also flushed with an optimizer step, so the
    # number of scheduler steps per epoch is the ceiling, not the floor.
    steps_per_epoch = (n_batches + accumulation_steps - 1) // accumulation_steps
    total_steps = steps_per_epoch * continue_epochs
    # Size of the final, incomplete accumulation group (0 when it divides evenly).
    tail_size = n_batches % accumulation_steps

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    loss_fn = FocalLoss(alpha=0.25, gamma=2.0)

    def snapshot_weights():
        # state_dict() returns references to the live tensors and dict.copy()
        # is shallow, so a plain copy would keep changing as training goes on.
        # Clone every tensor (onto the CPU) to freeze the values at this point.
        return {name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()}

    best_model_state = snapshot_weights()
    wait = 0

    for epoch in range(1, continue_epochs+1):
        print(f"Epoch {epoch}/{continue_epochs}")

        model.train()
        total_train_loss = 0
        optimizer.zero_grad()

        for batch_idx, batch in enumerate(tqdm(train_dataloader)):
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)
            sent_lengths = batch.get('sent_lengths', None)

            outputs = model(input_ids, attention_mask, sent_lengths)

            loss = loss_fn(outputs, labels)

            # Average over the batches that actually share this optimizer
            # step; the last group of an epoch may be smaller than
            # accumulation_steps.
            in_tail = tail_size > 0 and batch_idx >= n_batches - tail_size
            group_size = tail_size if in_tail else accumulation_steps

            (loss / group_size).backward()

            total_train_loss += loss.item()

            if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == n_batches:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                optimizer.step()

                scheduler.step()

                optimizer.zero_grad()

        avg_train_loss = total_train_loss / len(train_dataloader)
        print(f"Average training loss: {avg_train_loss:.4f}")

        model.eval()
        val_preds = []
        val_labels = []

        with torch.no_grad():
            for batch in tqdm(val_dataloader, desc="Validation"):
                input_ids = batch['input_ids'].to(DEVICE)
                attention_mask = batch['attention_mask'].to(DEVICE)
                labels = batch['labels']
                sent_lengths = batch.get('sent_lengths', None)

                outputs = model(input_ids, attention_mask, sent_lengths)
                probs = torch.sigmoid(outputs).cpu().numpy()
                val_preds.append(probs)
                val_labels.append(labels.numpy())

        val_preds = np.vstack(val_preds)
        val_labels = np.vstack(val_labels)

        print("Finding optimal thresholds:")
        new_thresholds, val_f1 = find_optimal_thresholds(val_preds, val_labels)
        print(f"Validation F1: {val_f1:.4f}")

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            best_model_state = snapshot_weights()
            thresholds = new_thresholds
            wait = 0
            print(f"New best model saved! F1: {val_f1:.4f}")
        else:
            wait += 1
            print(f"No improvement. Patience: {wait}/{patience}")

            if wait >= patience:
                print(f"Early stopping triggered after epoch {epoch}")
                break

    # Restore the best snapshot (the loaded checkpoint if nothing improved).
    model.load_state_dict(best_model_state)

    torch.save({
        'model_state_dict': model.state_dict(),
        'thresholds': thresholds,
        'val_f1': best_val_f1
    }, output_path)

    # Report the path that was actually written.
    print(f"Improved model saved as {output_path}")

    return model, thresholds, best_val_f1

In [None]:
# Build the train/validation datasets and loaders used for continued training.
# NOTE(review): this cell reads `train_df` and `val_df`, which are assigned by the
# `load_pandora_binary_data()` cell that appears *below* it — run that cell first
# (or move it above) so the notebook survives Restart & Run All.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
train_dataset = PandoraBinaryHANDataset(train_df, tokenizer)
val_dataset = PandoraBinaryHANDataset(val_df, tokenizer)

# Only the training loader shuffles; validation order is kept fixed.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_han_batch
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_han_batch
)



In [None]:
# Load the three splits with trait columns already binarised at the training median.
train_df, val_df, test_df = load_pandora_binary_data()

In [None]:
# NOTE(review): `train_berthan_model` is defined twice in this notebook; whichever
# definition cell ran last is the one called here. The later definition loads a
# checkpoint from './content/berthan_pandora_model.pt' before training.
model, thresholds, test_f1 = train_berthan_model(
  train_df, val_df, test_df, epochs=5, patience=2
)

In [59]:
# Resume training from an earlier checkpoint with gradient accumulation.
# NOTE(review): `model_path` here is 'berthan_pandora_model_continued.pt', while
# continue_training_berthan_model writes 'berthan_pandora_model_continued_2.pt';
# the recorded output below reports loading 'berthan_pandora_model.pt', so it was
# produced by a different invocation — confirm which checkpoint is intended.
model, thresholds, val_f1 = continue_training_berthan_model(
    model_path='berthan_pandora_model_continued.pt',
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    continue_epochs=3,
    patience=2,
)

print(f"Continued training complete. Validation F1: {val_f1:.4f}")

Loading model from berthan_pandora_model.pt
Loaded model with validation F1: 0.7080
Classification thresholds: [0.41 0.43 0.46 0.42 0.4 ]

Epoch 1/3


Training:   0%|          | 0/2767 [00:00<?, ?it/s]

Average training loss: 0.0380


Validation:   0%|          | 0/552 [00:00<?, ?it/s]

Finding optimal thresholds:
 O: Best threshold = 0.44, F1 = 0.7021
 C: Best threshold = 0.42, F1 = 0.7178
 E: Best threshold = 0.44, F1 = 0.7131
 A: Best threshold = 0.43, F1 = 0.7234
 N: Best threshold = 0.44, F1 = 0.6764
Validation F1: 0.7065
× No improvement. Patience: 1/2

Epoch 2/3


Training:   0%|          | 0/2767 [00:00<?, ?it/s]

Average training loss: 0.0373


Validation:   0%|          | 0/552 [00:00<?, ?it/s]

Finding optimal thresholds:
 O: Best threshold = 0.42, F1 = 0.7039
 C: Best threshold = 0.43, F1 = 0.7200
 E: Best threshold = 0.44, F1 = 0.7170
 A: Best threshold = 0.43, F1 = 0.7251
 N: Best threshold = 0.41, F1 = 0.6782
Validation F1: 0.7089
✓ New best model saved! F1: 0.7089

Epoch 3/3


Training:   0%|          | 0/2767 [00:00<?, ?it/s]

Average training loss: 0.0365


Validation:   0%|          | 0/552 [00:00<?, ?it/s]

Finding optimal thresholds:
 O: Best threshold = 0.43, F1 = 0.7057
 C: Best threshold = 0.43, F1 = 0.7212
 E: Best threshold = 0.45, F1 = 0.7180
 A: Best threshold = 0.44, F1 = 0.7269
 N: Best threshold = 0.42, F1 = 0.6804
Validation F1: 0.7104
✓ New best model saved! F1: 0.7104
Improved model saved as berthan_pandora_model_continued.pt
Continued training complete. Validation F1: 0.7104
