In [1]:
#Apply the TCGA-trained T14 classifier to unseen data.
#The TCGA T14 held-out test set is used here as the example input.

from transformers import AutoTokenizer, AutoModel
from transformers import BigBirdForSequenceClassification
from transformers import TrainingArguments, Trainer
import torch
from sklearn.metrics import roc_auc_score
import pandas as pd
from scipy.special import softmax
import time 
# Wall-clock reference point; the final cell reports time elapsed since here.
start = time.time()

# Run configuration.
num_classes = 4                   # passed as num_labels to the classifier head
max_tokens = 2048                 # tokenizer truncation length (max_length)
output_dir = 'output_directory/'  # TrainingArguments output directory

In [2]:
#Define the evaluation metric, download the tokenizer and model, and instantiate the Trainer used for testing
def compute_metrics(eval_pred):
    """Compute AU-ROC for a Trainer evaluation pass.

    Parameters
    ----------
    eval_pred : tuple-like (predictions, label_ids)
        What ``Trainer.evaluate`` hands to ``compute_metrics``: raw model
        scores (assumed 2-D, ``(n_examples, n_classes)``) and the integer
        class labels.

    Returns
    -------
    dict
        ``{"roc_auc": float}``. Trainer reports this key as ``eval_roc_auc``.
        For more than two classes this is the macro-averaged one-vs-rest
        AU-ROC. For two classes it is the standard binary AU-ROC.
    """
    raw_pred, labels = eval_pred
    # Logits -> per-class probabilities (rows sum to 1), which is what
    # roc_auc_score's multiclass mode expects.
    score_pred = softmax(raw_pred, axis=1)
    if score_pred.shape[1] == 2:
        # Binary task: roc_auc_score needs a 1-D positive-class score.
        # Passing the 2-column matrix raises a ValueError.
        au_roc = roc_auc_score(labels, score_pred[:, 1])
    else:
        au_roc = roc_auc_score(labels, score_pred, multi_class='ovr', average='macro')
    return {"roc_auc": au_roc}

# Load the tokenizer and the fine-tuned sequence classifier by model id.
tokenizer = AutoTokenizer.from_pretrained("yikuan8/Clinical-BigBird")
model = BigBirdForSequenceClassification.from_pretrained(
    "jkefeli/CancerStage_Classifier_T",
    num_labels=num_classes,
)

# Evaluation-only Trainer: no training dataset is attached. compute_metrics
# supplies the AU-ROC reported by evaluate().
best_trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir=output_dir),
    compute_metrics=compute_metrics,
)

In [3]:
#Define the torch Dataset wrapper, then load and tokenize the test data
class Dataset(torch.utils.data.Dataset):
    """Map-style torch dataset wrapping tokenizer output and optional labels.

    Parameters
    ----------
    encodings : mapping of str -> sequence
        Tokenizer output (e.g. ``input_ids``, ``attention_mask``). Each value
        is indexed per example, and ``input_ids`` determines the length.
    labels : sequence of int, optional
        One label per example, e.g. a list or numpy array. When omitted,
        items carry no ``"labels"`` key.

    Raises
    ------
    ValueError
        If ``labels`` is given and its length differs from the number of
        encoded examples.
    """

    def __init__(self, encodings, labels=None):
        n_examples = len(encodings["input_ids"])
        # Catch a text/label misalignment up front. Otherwise it surfaces as
        # an IndexError mid-evaluation, or never, if there are extra labels.
        if labels is not None and len(labels) != n_examples:
            raise ValueError(
                f"got {len(labels)} labels for {n_examples} encoded examples"
            )
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # Test against None rather than truthiness: `if labels:` raises
        # "truth value ... is ambiguous" for numpy arrays.
        if self.labels is not None:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])


test_data = pd.read_csv('T14_test.csv')
# Plain Python lists: the tokenizer takes a list of strings, and Dataset
# indexes the labels positionally.
X_test = test_data['text'].tolist()
y_test = test_data['label'].tolist()
# Pad to the longest text in the set; truncate anything beyond max_tokens.
X_test_tokenized = tokenizer(
    X_test,
    truncation=True,
    padding=True,
    max_length=max_tokens,
)
test_dataset = Dataset(X_test_tokenized, y_test)

In [4]:
#Compute AU-ROC of Model Applied to Test Dataset
# Trainer prefixes metric names with "eval_", so compute_metrics' "roc_auc"
# comes back as "eval_roc_auc". Report it to 4 decimal places.
test_roc = round(
    best_trainer.evaluate(test_dataset)['eval_roc_auc'],
    4,
)
print(test_roc)

***** Running Evaluation *****
  Num examples = 1034
  Batch size = 8


0.9451


In [5]:
#Track Run-Time in MINUTES (the value is divided by 60).
#Wall-clock time since `start` was set in the first cell, so it includes the
#model/tokenizer download, tokenization, and evaluation.
end = time.time()
runtime = round((end - start) / 60, 3)
print('\nElapsed Time: ', runtime, 'Minutes')


Elapsed Time:  120.204 Minutes
