In [1]:
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer, BartModel, BartForSequenceClassification, Trainer, TrainingArguments
import accelerate

import numpy as np
import pandas as pd
import os

from sklearn.metrics import f1_score, accuracy_score, classification_report
from sklearn.cluster import KMeans

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
assert torch.cuda.is_available()
device = torch.device("cuda")

# Data pre-processing

In [3]:
df = pd.read_csv("data/bone_tumor.csv")
df = df.sample(frac=1)
df.head()

Unnamed: 0,Patient ID,Sex,Age,Grade,Histological type,MSKCC type,Site of primary STS,"Status (NED, AWD, D)",Treatment
70,STS_071,Female,64,High,epithelioid sarcoma,MFH,left biceps,D,Radiotherapy + Surgery
342,STS_343,Female,74,Intermediate,pleiomorphic leiomyosarcoma,MFH,left thigh,NED,Radiotherapy + Surgery
345,STS_346,Male,34,Intermediate,undifferentiated pleomorphic liposarcoma,Leiomyosarcoma,parascapusular,NED,Radiotherapy + Surgery
267,STS_268,Female,40,Intermediate,synovial sarcoma,Synovial sarcoma,right parascapusular,NED,Radiotherapy + Surgery
154,STS_155,Female,76,High,pleiomorphic leiomyosarcoma,MFH,left biceps,NED,Radiotherapy + Surgery


In [4]:

splits = [0.9, 0.05, 0.05]
train_len, eval_len, test_len = [int(len(df) * k) for k in splits]

train_stats_df = df[:train_len]
eval_stats_df = df[train_len : train_len + eval_len]
test_stats_df = df[train_len + eval_len : ]

In [5]:
train_stats_df.head()

Unnamed: 0,Patient ID,Sex,Age,Grade,Histological type,MSKCC type,Site of primary STS,"Status (NED, AWD, D)",Treatment
70,STS_071,Female,64,High,epithelioid sarcoma,MFH,left biceps,D,Radiotherapy + Surgery
342,STS_343,Female,74,Intermediate,pleiomorphic leiomyosarcoma,MFH,left thigh,NED,Radiotherapy + Surgery
345,STS_346,Male,34,Intermediate,undifferentiated pleomorphic liposarcoma,Leiomyosarcoma,parascapusular,NED,Radiotherapy + Surgery
267,STS_268,Female,40,Intermediate,synovial sarcoma,Synovial sarcoma,right parascapusular,NED,Radiotherapy + Surgery
154,STS_155,Female,76,High,pleiomorphic leiomyosarcoma,MFH,left biceps,NED,Radiotherapy + Surgery


In [6]:
eval_stats_df.head()

Unnamed: 0,Patient ID,Sex,Age,Grade,Histological type,MSKCC type,Site of primary STS,"Status (NED, AWD, D)",Treatment
197,STS_198,Female,17,High,poorly differentiated synovial sarcoma,Synovial sarcoma,right buttock,D,Surgery + Chemotherapy
261,STS_262,Female,25,Intermediate,synovial sarcoma,Synovial sarcoma,right thigh,NED,Radiotherapy + Surgery
332,STS_333,Male,43,High,pleiomorphic spindle cell undifferentiated,MFH,left buttock,AWD,Radiotherapy + Surgery + Chemotherapy
488,STS_489,Female,66,Intermediate,synovial sarcoma,Synovial sarcoma,right thigh,NED,Radiotherapy + Surgery
445,STS_446,Male,80,High,leiomyosarcoma,MFH,left buttock,AWD,Radiotherapy + Surgery


In [7]:
test_stats_df.head()

Unnamed: 0,Patient ID,Sex,Age,Grade,Histological type,MSKCC type,Site of primary STS,"Status (NED, AWD, D)",Treatment
276,STS_277,Male,38,High,poorly differentiated synovial sarcoma,MFH,right thigh,AWD,Radiotherapy + Surgery + Chemotherapy
2,STS_003,Male,22,Intermediate,synovial sarcoma,MFH,right buttock,D,Radiotherapy + Surgery
462,STS_463,Female,51,High,synovial sarcoma,Synovial sarcoma,right buttock,D,Surgery + Chemotherapy
381,STS_382,Female,46,Intermediate,malignant solitary fibrous tumor,MFH,parascapusular,D,Radiotherapy + Surgery
119,STS_120,Female,68,Intermediate,synovial sarcoma,MFH,right thigh,NED,Radiotherapy + Surgery


In [8]:
def grade_description(grade: str) -> str:
    grade = grade.lower()
    assert grade in ["high", "intermediate", "low"]

    if grade == "high":
        return "The tumor is fast-growing and more likely to spread"
    elif grade == "intermediate":
        return "The tumor is of medium growth rate and has an average risk of spreading"
    else:
        return "The tumor is slow-growing and less likely to spread"

In [9]:
def status_description(status: str) -> str:
    status = status.upper()
    assert status in ["NED", "AWD", "D"]

    if status == "NED":
        return "The patient is cancer-free.", 0
    elif status == "AWD":
        return "The patient has cancer but is not showing any signs of disease progression.", 1
    else:
        return "The patient has died from cancer.", 2

In [10]:
def treatment_description(treatment: str) -> str:
    lst = treatment.split(" + ")
    res = ", ".join(lst[:len(lst) - 1])
    res += f" and {lst[-1]}"
    return res.lower()

In [11]:
def construct_dataset(df):
    data = []

    for i, row in df.iterrows():
        sex = row["Sex"].lower()
        age = row["Age"]
        grade = grade_description(row["Grade"])
        histological_type = row["Histological type"] # type of tumor
        mskcc_type = row["MSKCC type"] # more specific classification of the tumor
        site_of_primary_sts = row["Site of primary STS"] # location of the tumor in the bone
        treatment = treatment_description(row["Treatment"])
        
        # used as the label (y)
        status, label = status_description(row["Status (NED, AWD, D)"])

        story = f"This patient is a {age}-year-old {sex}. The patient has a tumor of histological type {histological_type} and MSKCC type {mskcc_type} on their {site_of_primary_sts}. {grade}. The patient has received the following treatment: {treatment}."

        data.append({"story": story, "status": status, "label": label})
    
    return pd.DataFrame(data)

In [12]:
train_df = construct_dataset(train_stats_df)
eval_df = construct_dataset(eval_stats_df)
test_df = construct_dataset(test_stats_df)

In [13]:
train_df.head()

Unnamed: 0,story,status,label
0,This patient is a 64-year-old female. The pati...,The patient has died from cancer.,2
1,This patient is a 74-year-old female. The pati...,The patient is cancer-free.,0
2,This patient is a 34-year-old male. The patien...,The patient is cancer-free.,0
3,This patient is a 40-year-old female. The pati...,The patient is cancer-free.,0
4,This patient is a 76-year-old female. The pati...,The patient is cancer-free.,0


In [14]:
eval_df.head()

Unnamed: 0,story,status,label
0,This patient is a 17-year-old female. The pati...,The patient has died from cancer.,2
1,This patient is a 25-year-old female. The pati...,The patient is cancer-free.,0
2,This patient is a 43-year-old male. The patien...,The patient has cancer but is not showing any ...,1
3,This patient is a 66-year-old female. The pati...,The patient is cancer-free.,0
4,This patient is a 80-year-old male. The patien...,The patient has cancer but is not showing any ...,1


In [15]:
test_df.head()

Unnamed: 0,story,status,label
0,This patient is a 38-year-old male. The patien...,The patient has cancer but is not showing any ...,1
1,This patient is a 22-year-old male. The patien...,The patient has died from cancer.,2
2,This patient is a 51-year-old female. The pati...,The patient has died from cancer.,2
3,This patient is a 46-year-old female. The pati...,The patient has died from cancer.,2
4,This patient is a 68-year-old female. The pati...,The patient is cancer-free.,0


# BART

## Classification

In [16]:
tokenizer = AutoTokenizer.from_pretrained(
    "facebook/bart-large",
)
tokenizer.padding_side = "right"

model = BartForSequenceClassification.from_pretrained("facebook/bart-large", num_labels=3).to(device)

Some weights of BartForSequenceClassification were not initialized from the model checkpoint at facebook/bart-large and are newly initialized: ['classification_head.dense.bias', 'classification_head.dense.weight', 'classification_head.out_proj.bias', 'classification_head.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Sanity check
Calculate accuracy with base model

In [17]:
predictions = []
actuals = []

for i, row in test_df.iterrows():
    inputs = tokenizer(row["story"], padding="max_length", truncation=True, max_length=128, return_tensors="pt")
    
    with torch.no_grad():  # Don't compute gradients during inference
        prediction = model(**inputs.to(device))
    
    predicted_label = torch.argmax(prediction.logits, dim=-1).item()
    
    predictions.append(predicted_label)
    actuals.append(row["label"])

In [18]:
print(f"Accuracy: {accuracy_score(actuals, predictions)}")
print(classification_report(actuals, predictions))

Accuracy: 0.2
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        13
           1       0.00      0.00      0.00         7
           2       0.20      1.00      0.33         5

    accuracy                           0.20        25
   macro avg       0.07      0.33      0.11        25
weighted avg       0.04      0.20      0.07        25



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


### Training

In [19]:
class TextDataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings['input_ids'])

In [20]:
def encode_data(tokenizer, text, labels):
    inputs = tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt").to(device)
    inputs["labels"] = torch.tensor(labels)
    return inputs

In [21]:
def compute_metrics(eval_pred):
    logits = eval_pred.predictions[0]
    labels = eval_pred.label_ids
    predictions = np.argmax(logits, axis=-1)
    return {"f1": f1_score(labels, predictions, average="weighted")}

In [22]:
def train_model(train_dataset, eval_dataset):
    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=20,
        per_device_train_batch_size=10,
        per_device_eval_batch_size=10,
        warmup_steps=5,
        weight_decay=0.01,
        eval_strategy="steps",
        eval_steps=5,
        save_strategy="best",
        metric_for_best_model="f1",
        greater_is_better=True,
        dataloader_pin_memory=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics
    )

    trainer.train()

    return trainer

In [23]:
train_dataset = TextDataset(encode_data(tokenizer, train_df["story"].tolist(), train_df["label"].tolist()))
eval_dataset = TextDataset(encode_data(tokenizer, eval_df["story"].tolist(), eval_df["label"].tolist()))

In [24]:
trainer = train_model(train_dataset, eval_dataset)

  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


Step,Training Loss,Validation Loss,F1
5,No log,1.085821,0.268889
10,No log,1.034094,0.394286
15,No log,0.97833,0.47873
20,No log,0.964461,0.595508
25,No log,0.974731,0.44
30,No log,0.915516,0.430476
35,No log,0.954589,0.533333
40,No log,0.969573,0.556613
45,No log,0.969042,0.470065
50,No log,1.121298,0.505092


Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0}
  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0}
  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0}
  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0}
  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
  return {key: t

In [25]:
best_checkpoint = sorted(os.listdir("results"), key=lambda x: int(x.split("-")[1]))[-1]
best_checkpoint = f"results/{best_checkpoint}"

In [26]:
ft_model = BartForSequenceClassification.from_pretrained(best_checkpoint, num_labels=3).to(device)

In [27]:
predictions = []
actuals = []

for i, row in test_df.iterrows():
    inputs = tokenizer(row["story"], padding="max_length", truncation=True, max_length=128, return_tensors="pt")
    
    with torch.no_grad():
        prediction = ft_model(**inputs.to(device))
    
    predicted_label = torch.argmax(prediction.logits, dim=-1).item()
    
    predictions.append(predicted_label)
    actuals.append(row["label"])

In [28]:
print(f"Accuracy: {accuracy_score(actuals, predictions)}")
print(classification_report(actuals, predictions))

Accuracy: 0.8
              precision    recall  f1-score   support

           0       0.86      0.92      0.89        13
           1       1.00      0.57      0.73         7
           2       0.57      0.80      0.67         5

    accuracy                           0.80        25
   macro avg       0.81      0.76      0.76        25
weighted avg       0.84      0.80      0.80        25



In [29]:
predictions

[1, 2, 2, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 2, 2, 0, 1, 2, 0, 1, 0, 2, 0, 0, 0]

In [30]:
actuals

[1, 2, 2, 2, 0, 1, 0, 0, 1, 0, 1, 0, 0, 2, 1, 0, 1, 2, 0, 1, 0, 0, 0, 0, 0]

## Clustering
We are using the weights of the finetuned model from the previous part

In [31]:
clustering_model = BartModel.from_pretrained(best_checkpoint, num_labels=3).to(device)

In [32]:


hidden_states_list = []

for i, row in test_df.iterrows():
    inputs = tokenizer(row["story"], padding="max_length", truncation=True, max_length=128, return_tensors="pt")

    with torch.no_grad():
        prediction = clustering_model(**inputs.to(device))

    last_hidden = prediction.last_hidden_state.mean(dim=1) # average pooling
    hidden_states_list.append(last_hidden.cpu().numpy())

hidden_states_array = np.vstack(hidden_states_list)

In [33]:
kmeans = KMeans(n_clusters=3, random_state=67)
cluster_labels = kmeans.fit_predict(hidden_states_array)

In [34]:
clustering_df = pd.DataFrame()
clustering_df["cluster"] = cluster_labels
clustering_df["golden"] = test_df["label"]

clustering_df

Unnamed: 0,cluster,golden
0,2,1
1,0,2
2,0,2
3,1,2
4,0,0
5,2,1
6,1,0
7,1,0
8,1,1
9,1,0


In [35]:
print(f"Accuracy: {accuracy_score(clustering_df["golden"].to_list(), clustering_df["cluster"].to_list())}")
print(classification_report(clustering_df["golden"].to_list(), clustering_df["cluster"].to_list()))

Accuracy: 0.2
              precision    recall  f1-score   support

           0       0.44      0.31      0.36        13
           1       0.09      0.14      0.11         7
           2       0.00      0.00      0.00         5

    accuracy                           0.20        25
   macro avg       0.18      0.15      0.16        25
weighted avg       0.26      0.20      0.22        25

