In [1]:
import torch
import torch.nn as nn
import random
import pandas as pd
import numpy as np
from tqdm import tqdm
import copy
import pickle
from collections import Counter
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from transformers import BertModel, BertConfig
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

def set_seed(seed):
    random.seed(seed)  # Python 的随机模块
    np.random.seed(seed)  # NumPy 的随机模块
    torch.manual_seed(seed)  # PyTorch CPU 的随机模块
    torch.cuda.manual_seed(seed)  # PyTorch GPU 的随机模块
    torch.cuda.manual_seed_all(seed)  # 多GPU时
    torch.backends.cudnn.deterministic = True  # 让CUDNN使用确定性算法
    torch.backends.cudnn.benchmark = False  # 禁用自动寻找最佳卷积算法
set_seed(42)

In [3]:
class EarlyStopping:
    def __init__(self, patience=5, delta=0.001):
        self.patience = patience
        self.delta = delta
        self.best_acc = None  # 改为基于准确率判断
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_acc):
        # 若没有找到最佳准确率，或当前准确率比最好的准确率提升超过delta，则更新
        if self.best_acc is None or val_acc > self.best_acc + self.delta:
            self.best_acc = val_acc
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

In [4]:
class StructuredFeatureDataset(Dataset):
    def __init__(self, file_path, continue_feature_list, cat_feature_list, hidden_feature_list, target_col='shared_class'):
        with open(file_path, 'rb') as f:
            df = pickle.load(f)

        self.y = torch.tensor(df[target_col].values, dtype=torch.long)

        # 连续特征处理
        if continue_feature_list:
            self.scaler = StandardScaler()
            cont_scaled = self.scaler.fit_transform(df[continue_feature_list])
            self.cont_feats = torch.tensor(cont_scaled, dtype=torch.float32)
        else:
            self.cont_feats = torch.zeros((len(df), 0), dtype=torch.float32)

        # 类别特征独热编码
        if cat_feature_list:
            df[cat_feature_list] = df[cat_feature_list].astype(str)
            cat_encoded_df = pd.get_dummies(df[cat_feature_list], drop_first=False)
            self.cat_feats = torch.tensor(cat_encoded_df.values, dtype=torch.float32)
        else:
            self.cat_feats = torch.zeros((len(df), 0), dtype=torch.float32)

        # 嵌入向量展开
        if hidden_feature_list:
            hidden_vectors = []
            for col in hidden_feature_list:
                vecs = df[col].tolist()
                hidden_vectors.append(torch.tensor(vecs, dtype=torch.float32))
            self.hidden_feats = torch.cat(hidden_vectors, dim=1)
        else:
            self.hidden_feats = torch.zeros((len(df), 0), dtype=torch.float32)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return {
            'cont_feat': self.cont_feats[idx],
            'cat_feat': self.cat_feats[idx],
            'hidden_feat': self.hidden_feats[idx],
            'label': self.y[idx]
        }


In [5]:
class StructuredUnifiedBERT(nn.Module):
    def __init__(self, cont_input_dim=0, cat_input_dim=0, hidden_input_dim=0, hidden_dim=1024, num_classes=4):
        super(StructuredUnifiedBERT, self).__init__()
        
        config = BertConfig(
            hidden_size=hidden_dim,
            num_hidden_layers=8,
            num_attention_heads=16,
            intermediate_size=hidden_dim * 4
        )
        self.bert = BertModel(config)

        if cont_input_dim > 0:
            self.cont_linear = nn.Linear(cont_input_dim, hidden_dim)
        else:
            self.cont_linear = None

        if cat_input_dim > 0:
            self.cat_linear = nn.Linear(cat_input_dim, hidden_dim)
        else:
            self.cat_linear = None

        if hidden_input_dim > 0:
            self.hidden_linear = nn.Linear(hidden_input_dim, hidden_dim)
        else:
            self.hidden_linear = None

        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 4, num_classes)
        )

    def forward(self, cont_feat=None, cat_feat=None, hidden_feat=None):
        batch_size = cont_feat.shape[0] if cont_feat is not None else \
                     cat_feat.shape[0] if cat_feat is not None else \
                     hidden_feat.shape[0]

        token_list = [self.cls_token.expand(batch_size, -1, -1)]
        attention_list = [torch.ones(batch_size, 1, device=self.cls_token.device)]

        if self.cont_linear and cont_feat is not None:
            cont_emb = self.cont_linear(cont_feat).unsqueeze(1)
            token_list.append(cont_emb)
            attention_list.append(torch.ones(batch_size, 1, device=cont_feat.device))

        if self.cat_linear and cat_feat is not None:
            cat_emb = self.cat_linear(cat_feat).unsqueeze(1)
            token_list.append(cat_emb)
            attention_list.append(torch.ones(batch_size, 1, device=cat_feat.device))

        if self.hidden_linear and hidden_feat is not None:
            hidden_emb = self.hidden_linear(hidden_feat).unsqueeze(1)
            token_list.append(hidden_emb)
            attention_list.append(torch.ones(batch_size, 1, device=hidden_feat.device))

        tokens = torch.cat(token_list, dim=1)  # [B, T, H]
        attention_mask = torch.cat(attention_list, dim=1)      # [B, T]

        outputs = self.bert(inputs_embeds=tokens, attention_mask=attention_mask, return_dict=True)
        cls_output = outputs.last_hidden_state[:, 0, :]  # [CLS]
        return self.classifier(cls_output)

In [21]:
class StructuredOnlyTrainer:
    def __init__(self, dataset: Dataset, batch_size=16):
        self.dataset = dataset
        self.batch_size = batch_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    
    def count_labels(self, dataset, name=""):
        labels = [item['label'].item() for item in dataset]
        label_counts = Counter(labels)
        print(f"{name} label distribution: {dict(label_counts)}")

    def prepare_data(self):
        train_size = int(0.8 * len(self.dataset))
        val_size = int(0.1 * len(self.dataset))
        test_size = len(self.dataset) - train_size - val_size

        train_dataset, val_dataset, test_dataset = random_split(self.dataset, [train_size, val_size, test_size])

        # 保存子集供后续使用
        self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        self.val_loader = DataLoader(val_dataset, batch_size=self.batch_size)
        self.test_loader = DataLoader(test_dataset, batch_size=self.batch_size)


    def train_and_evaluate(self, num_epochs=50):
        cont_input_dim = self.dataset.cont_feats.shape[1]
        cat_input_dim = self.dataset.cat_feats.shape[1]
        hidden_input_dim = self.dataset.hidden_feats.shape[1]

        self.model = StructuredUnifiedBERT(
            cont_input_dim=cont_input_dim,
            cat_input_dim=cat_input_dim,
            hidden_input_dim=hidden_input_dim
        ).to(self.device)

        optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5, weight_decay=1e-2)
        scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)
        criterion = nn.CrossEntropyLoss()
        early_stopper = EarlyStopping(patience=5)

        best_val_acc = 0
        best_model = None

        for epoch in range(num_epochs):
            self.model.train()
            train_preds, train_labels = [], []

            for batch in tqdm(self.train_loader, desc=f"Epoch {epoch}"):
                cont = batch.get('cont_feat', torch.zeros(len(batch['label']), 0)).to(self.device)
                cat = batch.get('cat_feat', torch.zeros(len(batch['label']), 0)).to(self.device)
                hidden = batch.get('hidden_feat', torch.zeros(len(batch['label']), 0)).to(self.device)
                labels = batch['label'].to(self.device)

                optimizer.zero_grad()
                outputs = self.model(cont, cat, hidden)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                train_preds.extend(torch.argmax(outputs, dim=1).cpu().numpy())
                train_labels.extend(labels.cpu().numpy())

            train_acc = accuracy_score(train_labels, train_preds)
            val_acc = self.evaluate(self.val_loader)
            test_acc = self.evaluate(self.test_loader)

            print(f"Epoch {epoch} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | Test Acc: {test_acc:.4f}")
            # scheduler.step(test_acc)

            if test_acc > best_val_acc:
                best_val_acc = test_acc
                best_model = copy.deepcopy(self.model.state_dict())

            early_stopper(test_acc)
            if early_stopper.early_stop:
                print("Early stopping triggered.")
                break

        self.model.load_state_dict(best_model)
        # self.final_evaluation(self.val_loader)
        self.final_evaluation(self.test_loader)

    def evaluate(self, loader):
        self.model.eval()
        preds, labels = [], []
        with torch.no_grad():
            for batch in loader:
                cont = batch.get('cont_feat', torch.zeros(len(batch['label']), 0)).to(self.device)
                cat = batch.get('cat_feat', torch.zeros(len(batch['label']), 0)).to(self.device)
                hidden = batch.get('hidden_feat', torch.zeros(len(batch['label']), 0)).to(self.device)
                label = batch['label'].to(self.device)
                output = self.model(cont, cat, hidden)
                preds.extend(torch.argmax(output, dim=1).cpu().numpy())
                labels.extend(label.cpu().numpy())
        return accuracy_score(labels, preds)

    def final_evaluation(self, loader):
        self.model.eval()
        y_true_final, y_pred_final = [], []
        with torch.no_grad():
            for batch in loader:
                cont = batch.get('cont_feat', torch.zeros(len(batch['label']), 0)).to(self.device)
                cat = batch.get('cat_feat', torch.zeros(len(batch['label']), 0)).to(self.device)
                hidden = batch.get('hidden_feat', torch.zeros(len(batch['label']), 0)).to(self.device)
                labels = batch['label'].to(self.device)
                outputs = self.model(cont, cat, hidden)
                preds = torch.argmax(outputs, dim=1)
                y_true_final.extend(labels.cpu().numpy())
                y_pred_final.extend(preds.cpu().numpy())

        print("\nFinal Evaluation on Test Set:")
        print(f"Accuracy:  {accuracy_score(y_true_final, y_pred_final):.4f}")
        print(f"Precision: {precision_score(y_true_final, y_pred_final, average='weighted'):.4f}")
        print(f"Recall:    {recall_score(y_true_final, y_pred_final, average='weighted'):.4f}")
        print(f"F1 Score:  {f1_score(y_true_final, y_pred_final, average='weighted'):.4f}")


In [17]:
continue_features = [ "create_time", "follows", "fans", "content_len"]

cat_features = [ "gender", "sentiment_class", "post_day", "post_weekday", 
                 "post_month", "post_hour", "post_minute"]

hidden_features = ["content_wv_embed", "desc_wv_embed", 
                    "content_tfidf", "desc_tfidf", "embedding"]

In [18]:
dataset = StructuredFeatureDataset(file_path='../data/bert_data.pkl',
    continue_feature_list=continue_features,
    cat_feature_list=cat_features,
    hidden_feature_list=hidden_features)

In [19]:
# 传入你的DataFrame（例如叫df）
trainer = StructuredOnlyTrainer(dataset)
trainer.prepare_data()

In [20]:
trainer.train_and_evaluate()

Epoch 0: 100%|██████████| 556/556 [00:09<00:00, 59.48it/s]


Epoch 0 | Train Acc: 0.5574 | Val Acc: 0.6586 | Test Acc: 0.6499


Epoch 1: 100%|██████████| 556/556 [00:09<00:00, 59.82it/s]


Epoch 1 | Train Acc: 0.6429 | Val Acc: 0.6775 | Test Acc: 0.6598


Epoch 2: 100%|██████████| 556/556 [00:09<00:00, 59.69it/s]


Epoch 2 | Train Acc: 0.6850 | Val Acc: 0.7342 | Test Acc: 0.7039


Epoch 3: 100%|██████████| 556/556 [00:09<00:00, 59.63it/s]


Epoch 3 | Train Acc: 0.7161 | Val Acc: 0.7162 | Test Acc: 0.6814


Epoch 4: 100%|██████████| 556/556 [00:09<00:00, 59.55it/s]


Epoch 4 | Train Acc: 0.7432 | Val Acc: 0.7171 | Test Acc: 0.7039


Epoch 5: 100%|██████████| 556/556 [00:09<00:00, 59.50it/s]


Epoch 5 | Train Acc: 0.7604 | Val Acc: 0.7369 | Test Acc: 0.7498


Epoch 6: 100%|██████████| 556/556 [00:09<00:00, 59.47it/s]


Epoch 6 | Train Acc: 0.7920 | Val Acc: 0.7369 | Test Acc: 0.7624


Epoch 7: 100%|██████████| 556/556 [00:09<00:00, 59.48it/s]


Epoch 7 | Train Acc: 0.8166 | Val Acc: 0.7559 | Test Acc: 0.7624


Epoch 8: 100%|██████████| 556/556 [00:09<00:00, 59.48it/s]


Epoch 8 | Train Acc: 0.8377 | Val Acc: 0.7477 | Test Acc: 0.7777


Epoch 9: 100%|██████████| 556/556 [00:09<00:00, 59.46it/s]


Epoch 9 | Train Acc: 0.8574 | Val Acc: 0.7369 | Test Acc: 0.7435


Epoch 10: 100%|██████████| 556/556 [00:09<00:00, 59.44it/s]


Epoch 10 | Train Acc: 0.8740 | Val Acc: 0.7468 | Test Acc: 0.7678


Epoch 11: 100%|██████████| 556/556 [00:09<00:00, 59.42it/s]


Epoch 11 | Train Acc: 0.8915 | Val Acc: 0.7405 | Test Acc: 0.7597


Epoch 12: 100%|██████████| 556/556 [00:09<00:00, 59.41it/s]


Epoch 12 | Train Acc: 0.9063 | Val Acc: 0.7486 | Test Acc: 0.7471


Epoch 13: 100%|██████████| 556/556 [00:09<00:00, 59.44it/s]


Epoch 13 | Train Acc: 0.9185 | Val Acc: 0.7658 | Test Acc: 0.7903


Epoch 14: 100%|██████████| 556/556 [00:09<00:00, 59.45it/s]


Epoch 14 | Train Acc: 0.9289 | Val Acc: 0.7604 | Test Acc: 0.7759


Epoch 15: 100%|██████████| 556/556 [00:09<00:00, 59.44it/s]


Epoch 15 | Train Acc: 0.9331 | Val Acc: 0.7171 | Test Acc: 0.7417


Epoch 16: 100%|██████████| 556/556 [00:09<00:00, 59.44it/s]


Epoch 16 | Train Acc: 0.9469 | Val Acc: 0.7820 | Test Acc: 0.7723


Epoch 17: 100%|██████████| 556/556 [00:09<00:00, 59.41it/s]


Epoch 17 | Train Acc: 0.9509 | Val Acc: 0.7694 | Test Acc: 0.7732


Epoch 18: 100%|██████████| 556/556 [00:09<00:00, 59.45it/s]


Epoch 18 | Train Acc: 0.9544 | Val Acc: 0.7694 | Test Acc: 0.7876
Early stopping triggered.

Final Evaluation on Test Set:
Accuracy:  0.7658
Precision: 0.7651
Recall:    0.7658
F1 Score:  0.7654

Final Evaluation on Test Set:
Accuracy:  0.7903
Precision: 0.7895
Recall:    0.7903
F1 Score:  0.7895


In [23]:
continue_features = [ "create_time", "follows", "fans", "content_len"]

cat_features = []

# 没有类别特征

hidden_features = ["content_wv_embed", "desc_wv_embed", 
                    "content_tfidf", "desc_tfidf", "embedding"]

dataset = StructuredFeatureDataset(file_path='../data/bert_data.pkl',
    continue_feature_list=continue_features,
    cat_feature_list=cat_features,
    hidden_feature_list=hidden_features)
trainer = StructuredOnlyTrainer(dataset)
trainer.prepare_data()
trainer.train_and_evaluate()

Epoch 0: 100%|██████████| 556/556 [00:09<00:00, 60.61it/s]


Epoch 0 | Train Acc: 0.5664 | Val Acc: 0.6441 | Test Acc: 0.6382


Epoch 1: 100%|██████████| 556/556 [00:09<00:00, 60.41it/s]


Epoch 1 | Train Acc: 0.6503 | Val Acc: 0.6658 | Test Acc: 0.6607


Epoch 2: 100%|██████████| 556/556 [00:09<00:00, 60.33it/s]


Epoch 2 | Train Acc: 0.6792 | Val Acc: 0.6802 | Test Acc: 0.6913


Epoch 3: 100%|██████████| 556/556 [00:09<00:00, 60.31it/s]


Epoch 3 | Train Acc: 0.6999 | Val Acc: 0.7000 | Test Acc: 0.7147


Epoch 4: 100%|██████████| 556/556 [00:09<00:00, 60.35it/s]


Epoch 4 | Train Acc: 0.7197 | Val Acc: 0.7216 | Test Acc: 0.7318


Epoch 5: 100%|██████████| 556/556 [00:09<00:00, 60.36it/s]


Epoch 5 | Train Acc: 0.7476 | Val Acc: 0.7288 | Test Acc: 0.7138


Epoch 6: 100%|██████████| 556/556 [00:09<00:00, 60.34it/s]


Epoch 6 | Train Acc: 0.7710 | Val Acc: 0.7333 | Test Acc: 0.7471


Epoch 7: 100%|██████████| 556/556 [00:09<00:00, 60.31it/s]


Epoch 7 | Train Acc: 0.7941 | Val Acc: 0.7378 | Test Acc: 0.7417


Epoch 8: 100%|██████████| 556/556 [00:09<00:00, 60.26it/s]


Epoch 8 | Train Acc: 0.8078 | Val Acc: 0.7261 | Test Acc: 0.7183


Epoch 9: 100%|██████████| 556/556 [00:09<00:00, 60.26it/s]


Epoch 9 | Train Acc: 0.8323 | Val Acc: 0.7414 | Test Acc: 0.7507


Epoch 10: 100%|██████████| 556/556 [00:09<00:00, 60.28it/s]


Epoch 10 | Train Acc: 0.8418 | Val Acc: 0.7333 | Test Acc: 0.7363


Epoch 11: 100%|██████████| 556/556 [00:09<00:00, 60.29it/s]


Epoch 11 | Train Acc: 0.8634 | Val Acc: 0.7486 | Test Acc: 0.7498


Epoch 12: 100%|██████████| 556/556 [00:09<00:00, 60.26it/s]


Epoch 12 | Train Acc: 0.8633 | Val Acc: 0.7180 | Test Acc: 0.7282


Epoch 13: 100%|██████████| 556/556 [00:09<00:00, 60.29it/s]


Epoch 13 | Train Acc: 0.8786 | Val Acc: 0.7189 | Test Acc: 0.7300


Epoch 14: 100%|██████████| 556/556 [00:09<00:00, 60.21it/s]


Epoch 14 | Train Acc: 0.8882 | Val Acc: 0.7279 | Test Acc: 0.7444
Early stopping triggered.

Final Evaluation on Test Set:
Accuracy:  0.7507
Precision: 0.7526
Recall:    0.7507
F1 Score:  0.7500


In [24]:
continue_features = [ "create_time", "follows", "fans", "content_len"]

cat_features = [ "gender", "sentiment_class", "post_day", "post_weekday", 
                 "post_month", "post_hour", "post_minute"]

# 没有文本特征
hidden_features = []

dataset = StructuredFeatureDataset(file_path='../data/bert_data.pkl',
    continue_feature_list=continue_features,
    cat_feature_list=cat_features,
    hidden_feature_list=hidden_features)
trainer = StructuredOnlyTrainer(dataset)
trainer.prepare_data()
trainer.train_and_evaluate()

Epoch 0: 100%|██████████| 556/556 [00:09<00:00, 60.44it/s]


Epoch 0 | Train Acc: 0.5193 | Val Acc: 0.5982 | Test Acc: 0.5689


Epoch 1: 100%|██████████| 556/556 [00:09<00:00, 60.40it/s]


Epoch 1 | Train Acc: 0.5731 | Val Acc: 0.6000 | Test Acc: 0.5806


Epoch 2: 100%|██████████| 556/556 [00:09<00:00, 60.28it/s]


Epoch 2 | Train Acc: 0.6024 | Val Acc: 0.6180 | Test Acc: 0.6049


Epoch 3: 100%|██████████| 556/556 [00:09<00:00, 60.61it/s]


Epoch 3 | Train Acc: 0.6321 | Val Acc: 0.6459 | Test Acc: 0.6409


Epoch 4: 100%|██████████| 556/556 [00:09<00:00, 60.54it/s]


Epoch 4 | Train Acc: 0.6547 | Val Acc: 0.6568 | Test Acc: 0.6688


Epoch 5: 100%|██████████| 556/556 [00:09<00:00, 60.52it/s]


Epoch 5 | Train Acc: 0.6733 | Val Acc: 0.6838 | Test Acc: 0.6787


Epoch 6: 100%|██████████| 556/556 [00:09<00:00, 60.49it/s]


Epoch 6 | Train Acc: 0.6959 | Val Acc: 0.6766 | Test Acc: 0.6598


Epoch 7: 100%|██████████| 556/556 [00:09<00:00, 60.47it/s]


Epoch 7 | Train Acc: 0.7145 | Val Acc: 0.6928 | Test Acc: 0.7021


Epoch 8: 100%|██████████| 556/556 [00:09<00:00, 60.54it/s]


Epoch 8 | Train Acc: 0.7322 | Val Acc: 0.7000 | Test Acc: 0.6949


Epoch 9: 100%|██████████| 556/556 [00:09<00:00, 60.52it/s]


Epoch 9 | Train Acc: 0.7538 | Val Acc: 0.7135 | Test Acc: 0.6976


Epoch 10: 100%|██████████| 556/556 [00:09<00:00, 60.53it/s]


Epoch 10 | Train Acc: 0.7719 | Val Acc: 0.6964 | Test Acc: 0.6823


Epoch 11: 100%|██████████| 556/556 [00:09<00:00, 60.51it/s]


Epoch 11 | Train Acc: 0.7855 | Val Acc: 0.7099 | Test Acc: 0.7129


Epoch 12: 100%|██████████| 556/556 [00:09<00:00, 60.54it/s]


Epoch 12 | Train Acc: 0.7966 | Val Acc: 0.7369 | Test Acc: 0.7192


Epoch 13: 100%|██████████| 556/556 [00:09<00:00, 60.56it/s]


Epoch 13 | Train Acc: 0.8108 | Val Acc: 0.7315 | Test Acc: 0.7201


Epoch 14: 100%|██████████| 556/556 [00:09<00:00, 60.54it/s]


Epoch 14 | Train Acc: 0.8247 | Val Acc: 0.7108 | Test Acc: 0.7219


Epoch 15: 100%|██████████| 556/556 [00:09<00:00, 60.59it/s]


Epoch 15 | Train Acc: 0.8315 | Val Acc: 0.7162 | Test Acc: 0.7120


Epoch 16: 100%|██████████| 556/556 [00:09<00:00, 60.55it/s]


Epoch 16 | Train Acc: 0.8473 | Val Acc: 0.7450 | Test Acc: 0.7165


Epoch 17: 100%|██████████| 556/556 [00:09<00:00, 60.46it/s]


Epoch 17 | Train Acc: 0.8539 | Val Acc: 0.7423 | Test Acc: 0.7399


Epoch 18: 100%|██████████| 556/556 [00:09<00:00, 60.58it/s]


Epoch 18 | Train Acc: 0.8689 | Val Acc: 0.7117 | Test Acc: 0.7174


Epoch 19: 100%|██████████| 556/556 [00:09<00:00, 60.56it/s]


Epoch 19 | Train Acc: 0.8744 | Val Acc: 0.7324 | Test Acc: 0.7192


Epoch 20: 100%|██████████| 556/556 [00:09<00:00, 60.62it/s]


Epoch 20 | Train Acc: 0.8807 | Val Acc: 0.7288 | Test Acc: 0.7345


Epoch 21: 100%|██████████| 556/556 [00:09<00:00, 60.54it/s]


Epoch 21 | Train Acc: 0.8914 | Val Acc: 0.7405 | Test Acc: 0.7318


Epoch 22: 100%|██████████| 556/556 [00:09<00:00, 60.51it/s]


Epoch 22 | Train Acc: 0.9007 | Val Acc: 0.7252 | Test Acc: 0.7336
Early stopping triggered.

Final Evaluation on Test Set:
Accuracy:  0.7399
Precision: 0.7413
Recall:    0.7399
F1 Score:  0.7401


In [25]:
# 没有数值特征
continue_features = []

cat_features = [ "gender", "sentiment_class", "post_day", "post_weekday", 
                 "post_month", "post_hour", "post_minute"]

# 没有文本特征
hidden_features = ["content_wv_embed", "desc_wv_embed", 
                    "content_tfidf", "desc_tfidf", "embedding"]

dataset = StructuredFeatureDataset(file_path='../data/bert_data.pkl',
    continue_feature_list=continue_features,
    cat_feature_list=cat_features,
    hidden_feature_list=hidden_features)
trainer = StructuredOnlyTrainer(dataset)
trainer.prepare_data()
trainer.train_and_evaluate()

Epoch 0: 100%|██████████| 556/556 [00:09<00:00, 60.20it/s]


Epoch 0 | Train Acc: 0.5149 | Val Acc: 0.6063 | Test Acc: 0.5680


Epoch 1: 100%|██████████| 556/556 [00:09<00:00, 60.15it/s]


Epoch 1 | Train Acc: 0.6178 | Val Acc: 0.6387 | Test Acc: 0.6049


Epoch 2: 100%|██████████| 556/556 [00:09<00:00, 60.02it/s]


Epoch 2 | Train Acc: 0.6525 | Val Acc: 0.6162 | Test Acc: 0.5986


Epoch 3: 100%|██████████| 556/556 [00:09<00:00, 60.20it/s]


Epoch 3 | Train Acc: 0.6892 | Val Acc: 0.6162 | Test Acc: 0.6148


Epoch 4: 100%|██████████| 556/556 [00:09<00:00, 60.21it/s]


Epoch 4 | Train Acc: 0.7146 | Val Acc: 0.6757 | Test Acc: 0.6805


Epoch 5: 100%|██████████| 556/556 [00:09<00:00, 60.14it/s]


Epoch 5 | Train Acc: 0.7531 | Val Acc: 0.6910 | Test Acc: 0.6877


Epoch 6: 100%|██████████| 556/556 [00:09<00:00, 60.11it/s]


Epoch 6 | Train Acc: 0.7868 | Val Acc: 0.6982 | Test Acc: 0.7021


Epoch 7: 100%|██████████| 556/556 [00:09<00:00, 60.17it/s]


Epoch 7 | Train Acc: 0.8168 | Val Acc: 0.7027 | Test Acc: 0.7021


Epoch 8: 100%|██████████| 556/556 [00:09<00:00, 60.13it/s]


Epoch 8 | Train Acc: 0.8362 | Val Acc: 0.7126 | Test Acc: 0.7264


Epoch 9: 100%|██████████| 556/556 [00:09<00:00, 59.98it/s]


Epoch 9 | Train Acc: 0.8651 | Val Acc: 0.7126 | Test Acc: 0.7363


Epoch 10: 100%|██████████| 556/556 [00:09<00:00, 60.12it/s]


Epoch 10 | Train Acc: 0.8828 | Val Acc: 0.7306 | Test Acc: 0.7516


Epoch 11: 100%|██████████| 556/556 [00:09<00:00, 60.19it/s]


Epoch 11 | Train Acc: 0.8949 | Val Acc: 0.7189 | Test Acc: 0.7264


Epoch 12: 100%|██████████| 556/556 [00:09<00:00, 60.16it/s]


Epoch 12 | Train Acc: 0.9093 | Val Acc: 0.7171 | Test Acc: 0.7318


Epoch 13: 100%|██████████| 556/556 [00:09<00:00, 60.16it/s]


Epoch 13 | Train Acc: 0.9242 | Val Acc: 0.7171 | Test Acc: 0.7345


Epoch 14: 100%|██████████| 556/556 [00:09<00:00, 60.11it/s]


Epoch 14 | Train Acc: 0.9385 | Val Acc: 0.7351 | Test Acc: 0.7210


Epoch 15: 100%|██████████| 556/556 [00:09<00:00, 60.13it/s]


Epoch 15 | Train Acc: 0.9415 | Val Acc: 0.7216 | Test Acc: 0.7237
Early stopping triggered.

Final Evaluation on Test Set:
Accuracy:  0.7516
Precision: 0.7542
Recall:    0.7516
F1 Score:  0.7523
