In [1]:
import random
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, roc_curve, auc, precision_score
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torch.amp import autocast, GradScaler


# Set random seed
seed = 3407
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load downstream task dataset
file_path = 'data3.csv'
data2_df = pd.read_csv(file_path)

# Split training and test sets with 8:2 ratio
train_df, test_df = train_test_split(data2_df, test_size=0.1, random_state=seed)

# Function to generate enhanced representations using upstream model
def generate_enhanced_embeddings(texts, tokenizer, model):
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(device)
    outputs = model(**inputs)
    embeddings = outputs.hidden_states[-1][:, 0, :]
    predictions = torch.sigmoid(outputs.logits)
    enhanced_embeddings = torch.cat((embeddings, predictions), dim=1)
    return enhanced_embeddings

# Custom Dataset class
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, model):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.model = model

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        embeddings = generate_enhanced_embeddings([text], self.tokenizer, self.model)
        return embeddings.squeeze(), label

# Define simple binary classifier
class SimpleClassifier(nn.Module):
    def __init__(self, input_dim):
        super(SimpleClassifier, self).__init__()
        self.linear = nn.Linear(input_dim, 1)
    
    def forward(self, x):
        return self.linear(x)

# Load pre-trained model and tokenizer (load only once)
model_path = "saved_model"
tokenizer = AutoTokenizer.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path)
config.output_hidden_states = True

# Initialize model and classifier
model = AutoModelForSequenceClassification.from_pretrained(model_path, config=config).to(device)
classifier = SimpleClassifier(config.hidden_size + config.num_labels).to(device)

# Create Dataset and DataLoader
train_dataset = TextDataset(train_df['compound'].tolist(), train_df['label'].tolist(), tokenizer, model)
test_dataset = TextDataset(test_df['compound'].tolist(), test_df['label'].tolist(), tokenizer, model)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Put parameters of both model and classifier in the same optimizer, using AdamW
optimizer = optim.Adam(list(model.parameters()) + list(classifier.parameters()), lr=0.000260006006712908)

# Use CosineAnnealingLR scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=64)

# Define loss function
criterion = nn.BCEWithLogitsLoss()

# Use GradScaler for mixed precision training
scaler = GradScaler()

# Track the model with best AUC on test set
best_test_auc = 0.0
best_epoch = 0  # Record the best epoch

# Train model and classifier
num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    classifier.train()
    epoch_loss = 0
    
    for embeddings, labels in train_loader:
        optimizer.zero_grad()
        
        with autocast(device_type='cuda'):
            embeddings, labels = embeddings.to(device), labels.to(device, dtype=torch.float32)
            outputs = classifier(embeddings).squeeze()
            loss = criterion(outputs, labels)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        epoch_loss += loss.item()

    train_loss = epoch_loss / len(train_loader)

    # Calculate loss, AUC and precision on test set
    model.eval()
    classifier.eval()
    test_outputs = []
    test_labels = []
    test_predictions = []
    test_loss = 0
    
    with torch.no_grad():
        for embeddings, labels in test_loader:
            embeddings, labels = embeddings.to(device), labels.to(device, dtype=torch.float32)
            outputs = classifier(embeddings).squeeze()
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            
            # Convert prediction probabilities to binary labels
            predictions = torch.sigmoid(outputs) > 0.5
            test_outputs.extend(outputs.cpu().numpy())
            test_labels.extend(labels.cpu().numpy())
            test_predictions.extend(predictions.cpu().numpy())

    test_loss /= len(test_loader)

    # Calculate AUC
    fpr, tpr, _ = roc_curve(test_labels, test_outputs)
    test_auc = auc(fpr, tpr)

    # Calculate precision
    test_precision = precision_score(test_labels, test_predictions)

    # Save the best performing model on test set, and immediately save to disk
    if test_auc > best_test_auc:
        best_test_auc = test_auc
        best_epoch = epoch + 1  # Record the epoch of the best model

        # Immediately save the best model to file
        torch.save({
            'model_state_dict': model.state_dict(),
            'classifier_state_dict': classifier.state_dict(),
            'test_auc': best_test_auc,
            'epoch': best_epoch
        }, "best_model.pth")

    scheduler.step()

    # Print training and test results for each epoch, including precision
    print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.2f}, Test Loss: {test_loss:.2f}, Test AUC: {test_auc:.2f}, Test Precision: {test_precision:.2f}")

# Print the epoch with the best model
print(f"Best Test AUC was {best_test_auc:.2f} at epoch {best_epoch}")

# Load the best performing model on test set
checkpoint = torch.load("best_model.pth")
model.load_state_dict(checkpoint['model_state_dict'])
classifier.load_state_dict(checkpoint['classifier_state_dict'])

# Perform final evaluation after loading the best model
model.eval()
classifier.eval()
test_accuracy = 0
all_test_labels = []
all_test_predictions = []
all_test_outputs = []

with torch.no_grad():
    for embeddings, labels in test_loader:
        embeddings, labels = embeddings.to(device), labels.to(device, dtype=torch.float32)
        outputs = classifier(embeddings).squeeze()
        test_predictions = torch.sigmoid(outputs) > 0.5
        test_accuracy += accuracy_score(labels.cpu(), test_predictions.cpu())
        all_test_labels.extend(labels.cpu().numpy())
        all_test_predictions.extend(test_predictions.cpu().numpy())
        all_test_outputs.extend(outputs.cpu().numpy())

test_accuracy /= len(test_loader)

# Calculate confusion matrix and other evaluation metrics
conf_matrix = confusion_matrix(all_test_labels, all_test_predictions)
class_report = classification_report(all_test_labels, all_test_predictions, target_names=['Class 0', 'Class 1'])
accuracy = accuracy_score(all_test_labels, all_test_predictions)

# Calculate ROC curve and AUC
fpr, tpr, _ = roc_curve(all_test_labels, all_test_outputs)
roc_auc = auc(fpr, tpr)

print(f"\nBest Test Accuracy: {test_accuracy:.2f}")
print(f"Confusion Matrix:\n{conf_matrix}")
print(f"\nClassification Report:\n{class_report}")
print(f"ROC AUC: {roc_auc:.2f}")

  from .autonotebook import tqdm as notebook_tqdm


Epoch 1/50, Train Loss: 0.60, Test Loss: 0.54, Test AUC: 0.79, Test Precision: 0.77
Epoch 2/50, Train Loss: 0.51, Test Loss: 0.48, Test AUC: 0.84, Test Precision: 0.78
Epoch 3/50, Train Loss: 0.45, Test Loss: 0.46, Test AUC: 0.86, Test Precision: 0.80
Epoch 4/50, Train Loss: 0.39, Test Loss: 0.44, Test AUC: 0.87, Test Precision: 0.84
Epoch 5/50, Train Loss: 0.34, Test Loss: 0.43, Test AUC: 0.88, Test Precision: 0.83
Epoch 6/50, Train Loss: 0.31, Test Loss: 0.43, Test AUC: 0.89, Test Precision: 0.85
Epoch 7/50, Train Loss: 0.28, Test Loss: 0.44, Test AUC: 0.89, Test Precision: 0.85
Epoch 8/50, Train Loss: 0.24, Test Loss: 0.46, Test AUC: 0.89, Test Precision: 0.84
Epoch 9/50, Train Loss: 0.23, Test Loss: 0.45, Test AUC: 0.90, Test Precision: 0.85
Epoch 10/50, Train Loss: 0.21, Test Loss: 0.48, Test AUC: 0.90, Test Precision: 0.86
Epoch 11/50, Train Loss: 0.19, Test Loss: 0.52, Test AUC: 0.90, Test Precision: 0.88
Epoch 12/50, Train Loss: 0.18, Test Loss: 0.48, Test AUC: 0.91, Test Preci

  checkpoint = torch.load("best_model.pth")



Best Test Accuracy: 0.86
Confusion Matrix:
[[349  95]
 [ 71 589]]

Classification Report:
              precision    recall  f1-score   support

     Class 0       0.83      0.79      0.81       444
     Class 1       0.86      0.89      0.88       660

    accuracy                           0.85      1104
   macro avg       0.85      0.84      0.84      1104
weighted avg       0.85      0.85      0.85      1104

ROC AUC: 0.92
