# QCHAT Inference Workflow

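This notebook demonstrates the inference workflow for evaluating a fine-tuned RoBERTa sequence-classification model on Q-CHAT (Quantitative Checklist for Autism in Toddlers) data, classifying autism spectrum disorder (ASD) vs. typically developing (TD) subjects. It covers loading the model, tokenizing the test set, generating predictions, and reporting a confusion matrix, classification report, and ROC curve. **Prerequisite:** a pandas DataFrame named `test` with `text`, `label`, and `SubjectId` columns must be defined before running Step 1.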

In [None]:

import torch
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score, roc_curve, roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import RobertaForSequenceClassification, RobertaTokenizerFast, Trainer, TrainingArguments
from torch.utils.data import DataLoader, Dataset
from scipy.special import softmax

# Location of the fine-tuned checkpoint. The same path is handed to
# `from_pretrained` for both the model and the tokenizer below, so both are
# loaded from this one directory.
fine_tuned_model_path = "./Roberta_QCHAT_Model"

def load_fine_tuned_model(model_path):
    """Load the sequence-classification model and its tokenizer from one path.

    Args:
        model_path: Path (or identifier) passed unchanged to both
            ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple: a ``RobertaForSequenceClassification``
        and a ``RobertaTokenizerFast`` loaded from ``model_path``.
    """
    # The model is loaded first, then the tokenizer, from the same location.
    return (
        RobertaForSequenceClassification.from_pretrained(model_path),
        RobertaTokenizerFast.from_pretrained(model_path),
    )

# Load the fine-tuned model and tokenizer once at notebook level. Later cells
# rely on these globals: `tokenizer` is used inside `tokenize_and_pad`, and
# `model` is passed to the Trainer.
model, tokenizer = load_fine_tuned_model(fine_tuned_model_path)


## Step 1: Prepare Test Data

In [None]:

def tokenize_and_pad(dataset, max_length=567):
    """Tokenize the ``text`` column of ``dataset`` with the notebook-level ``tokenizer``.

    Args:
        dataset: Object whose ``'text'`` entry supports ``.tolist()`` (the
            notebook passes a pandas DataFrame).
        max_length: Requested upper bound on tokens per example. Defaults to
            567, the value this notebook originally hard-coded.

    Returns:
        The tokenizer output for the whole batch, truncated to the effective
        maximum length and padded to the longest sequence in the batch.
    """
    # Cap the requested length at what the tokenizer reports as the model
    # limit. Standard RoBERTa checkpoints support 512 tokens, so an uncapped
    # 567 could overflow the position embeddings on a long text.
    # NOTE(review): if this checkpoint really was trained with more than
    # `tokenizer.model_max_length` positions, update the tokenizer config
    # rather than removing this cap.
    model_limit = getattr(tokenizer, 'model_max_length', max_length)
    effective_max_length = min(max_length, model_limit)

    return tokenizer(
        dataset['text'].tolist(),  # Use 'text' column for tokenization
        truncation=True,
        padding=True,
        max_length=effective_max_length,
    )

class QCHATDataset(Dataset):
    """Map-style dataset pairing tokenizer encodings with labels and subject IDs.

    Args:
        encodings: Mapping from feature name to a sequence indexable by
            example position (e.g. the output of ``tokenize_and_pad``).
        labels: Per-example labels, indexable by position.
        subject_ids: Per-example subject identifiers, indexable by position.
    """

    def __init__(self, encodings, labels, subject_ids):
        self.encodings = encodings
        self.labels = labels
        self.subject_ids = subject_ids

    def __len__(self):
        # The dataset size is defined by the number of labels.
        return len(self.labels)

    def __getitem__(self, idx):
        """Return one example as a dict of tensors plus its raw ``SubjectId``."""
        sample = {}
        for feature_name, column in self.encodings.items():
            sample[feature_name] = torch.tensor(column[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        # The subject ID is passed through as-is (not converted to a tensor).
        sample['SubjectId'] = self.subject_ids[idx]
        return sample

# Build the evaluation dataset from the `test` DataFrame.
# NOTE(review): `test` is not defined anywhere in the visible notebook; it must
# already exist with 'text', 'label' and 'SubjectId' columns before this cell runs.
test_subject_ids = test['SubjectId'].tolist()
test_encodings = tokenize_and_pad(test)
test_dataset = QCHATDataset(
    encodings=test_encodings,
    labels=test['label'].tolist(),
    subject_ids=test_subject_ids,
)


## Step 2: Generate Predictions

In [None]:

# Use a Trainer with default arguments purely as an inference driver
trainer = Trainer(model=model)

# Generate predictions on the test dataset
predictions_output = trainer.predict(test_dataset)

# Derive hard predictions and class probabilities from the raw model outputs
logits = predictions_output.predictions
predictions = np.argmax(logits, axis=-1)
probabilities = softmax(logits, axis=-1)
true_labels = predictions_output.label_ids

# One row per test example, in the same order as `test_dataset`
df_results = pd.DataFrame({
    'True_Labels': true_labels,
    'Predictions': predictions,
    'Probabilities': list(probabilities),  # one per-class probability vector per row
    'SubjectId': test_subject_ids
})

# Display the DataFrame
print(df_results.head())

# Attach the results to the original test rows by position rather than by
# merging on 'SubjectId'. `test_dataset` was built row by row from `test`, so
# the rows already line up. A key-based merge would multiply rows whenever a
# SubjectId appears more than once in `test`.
test_final = pd.concat(
    [test.reset_index(drop=True), df_results.drop(columns='SubjectId')],
    axis=1,
)


## Step 3: Evaluate Model Performance

In [None]:

from sklearn.metrics import confusion_matrix, classification_report

# Ground truth and predicted labels for the test set
eval_true = df_results['True_Labels']
eval_pred = df_results['Predictions']

# Confusion matrix and text classification report
conf_matrix = confusion_matrix(eval_true, eval_pred)
class_report = classification_report(eval_true, eval_pred)
print(class_report)

# Draw the confusion matrix as an annotated heatmap; `sns.heatmap` returns the
# Axes it drew on, which is then labelled directly.
heatmap_ax = sns.heatmap(conf_matrix, annot=True, fmt='g', cmap='Blues')
heatmap_ax.set_xlabel('Predicted Labels')
heatmap_ax.set_ylabel('True Labels')
heatmap_ax.set_title('Confusion Matrix')
plt.show()


## Step 4: Plot ROC Curve

In [None]:

from sklearn.metrics import roc_curve, auc

# Ground truth and the predicted probability of class index 1 for every example
# NOTE(review): this treats label 1 as the positive class; confirm that matches the label encoding.
y_true = df_results['True_Labels'].to_numpy()
y_proba = np.array([class_probs[1] for class_probs in df_results['Probabilities']])

# ROC operating points and the area under the curve
fpr, tpr, _thresholds = roc_curve(y_true, y_proba)
roc_auc = roc_auc_score(y_true, y_proba)

# Plot the ROC curve against the chance diagonal
roc_fig = plt.figure()
roc_ax = roc_fig.gca()
roc_ax.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
roc_ax.set_xlim([0.0, 1.0])
roc_ax.set_ylim([0.0, 1.05])
roc_ax.set_xlabel('False Positive Rate')
roc_ax.set_ylabel('True Positive Rate')
roc_ax.set_title('Receiver Operating Characteristic (ROC) Curve')
roc_ax.legend(loc="lower right")
plt.show()
