In [25]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import XLNetTokenizer, XLNetForSequenceClassification, AdamW, AutoTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
import pandas as pd

#!pip install datasets
from datasets import load_dataset

Model and tokenizer definitions.

In [26]:
model = "xlnet-base-cased"
tokenizer = XLNetTokenizer.from_pretrained(model)
model = XLNetForSequenceClassification.from_pretrained(model, num_labels=2)  # Adjust num_labels based on your ground truth

Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['logits_proj.bias', 'logits_proj.weight', 'sequence_summary.summary.bias', 'sequence_summary.summary.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [27]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

XLNetForSequenceClassification(
  (transformer): XLNetModel(
    (word_embedding): Embedding(32000, 768)
    (layer): ModuleList(
      (0-11): 12 x XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layer_1): Linear(in_features=768, out_features=3072, bias=True)
          (layer_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (activation_function): GELUActivation()
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (sequence_summary): SequenceSummary(
    (summary): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
    (first_dropout): Identity()
    (last

Dataset prep


In [28]:
class DementiaBankDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # print(f"Accessing index: {idx}, Length of texts: {len(self.texts)}")
        if idx >= len(self.texts):
            raise IndexError(f"Index {idx} is out of bounds for texts with length {len(self.texts)}")
         # Convert the text to a list to handle potential non-standard indexing
        text = self.texts.tolist()[idx]
        label = self.labels.iloc[idx]  # Use iloc for integer-location based indexing

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

Load data

In [29]:
# load dataset
ds = load_dataset("Ak000/Dementia_Bank_Train")

# remove the useless 'instruction' column
ds = ds.remove_columns('instruction')

# rename the 'input' column to 'text'
ds = ds.rename_column('input', 'text')

# map 'dementia' to 1 and 'control' to 0 so we make it a pandas df
df = ds['train'].to_pandas()
df['ground_truth'] = (df['output'] == 'dementia')
df['ground_truth'] = df['ground_truth'].astype(int)
df.head()


Unnamed: 0,output,text,ground_truth
0,dementia,this boy is getting cookies outof this jar ....,1
1,control,well for one thing this boy's on the stool ge...,0
2,control,oh yes . well the mother is washing the dish...,0
3,dementia,action ? what's going on in the picture ? the...,1
4,control,the little boy is getting a cookie out of the...,0


Preprocess data

In [30]:
# texts and labels are both lists
texts = df['text']
labels = df['ground_truth']

# splitting data into train, val, and test (80, 15, 5% respectively)
# First split: separate out the test set
train_val_texts, test_texts, train_val_labels, test_labels = train_test_split(
    texts, labels, test_size=0.05, random_state=42)

# Second split: divide the remaining data into training and validation sets
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_val_texts, train_val_labels, test_size=0.1875, random_state=42)  # 0.1875 x 0.8 = 0.15

train_dataset = DementiaBankDataset(train_texts, train_labels, tokenizer, max_length=128)
val_dataset = DementiaBankDataset(val_texts, val_labels, tokenizer, max_length=128)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

Training setup:
  Optimizer and Linear Scheduler

In [31]:
from transformers import get_linear_schedule_with_warmup

# set up optimizer
optimizer = AdamW(model.parameters(), lr=2e-5)

# Adjust me!
# 5 -> .78, 10 -> .87+,
num_epochs = 8

# set up linear scheduler
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=total_steps)



Setup early stopping

In [32]:
class EarlyStopping:
    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')

            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

early_stopping = EarlyStopping(patience=3)

Training loop

In [33]:
from tqdm import tqdm
import numpy as np


def train():
  best_val_loss = float('inf')
  for epoch in range(num_epochs):
      model.train()
      total_loss = 0

      for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
          input_ids = batch['input_ids'].to(device)
          attention_mask = batch['attention_mask'].to(device)
          labels = batch['labels'].to(device)

          optimizer.zero_grad()
          outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
          loss = outputs.loss
          total_loss += loss.item()

          loss.backward()
          optimizer.step()

      avg_train_loss = total_loss / len(train_loader)
      print(f"Average training loss: {avg_train_loss:.4f}")

      # Validation from here on!
      model.eval()
      val_loss = 0
      val_preds = []
      val_true = []

      with torch.no_grad():
          for batch in val_loader:
              input_ids = batch['input_ids'].to(device)
              attention_mask = batch['attention_mask'].to(device)
              labels = batch['labels'].to(device)

              outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
              loss = outputs.loss
              val_loss += loss.item()

              logits = outputs.logits
              preds = torch.argmax(logits, dim=1).cpu().numpy()
              val_preds.extend(preds)
              val_true.extend(labels.cpu().numpy())

      avg_val_loss = val_loss / len(val_loader)
      val_accuracy = np.mean(np.array(val_preds) == np.array(val_true))
      print(f"Validation loss: {avg_val_loss:.4f}")
      print(f"Validation accuracy: {val_accuracy:.4f}")

      # Save the best model
      if avg_val_loss < best_val_loss:
          best_val_loss = avg_val_loss
          best_val_accuracy = val_accuracy
          torch.save(model.state_dict(), 'best_xlnet_model.pth')
          print("Saved best model!")

      # Early stopping check (if wanted)
      early_stopping(avg_val_loss)
      if early_stopping.early_stop:
          print("Early stopping!!!")
          print(f"Best validation loss: {best_val_loss:.4f}")
          print(f"Best Validation accuracy: {best_val_accuracy:.4f}")
          break

In [34]:
train()


Epoch 1/8:   0%|          | 0/20 [00:00<?, ?it/s][A
Epoch 1/8:   5%|▌         | 1/20 [00:00<00:03,  4.99it/s][A
Epoch 1/8:  10%|█         | 2/20 [00:00<00:03,  5.05it/s][A
Epoch 1/8:  15%|█▌        | 3/20 [00:00<00:03,  5.05it/s][A
Epoch 1/8:  20%|██        | 4/20 [00:00<00:03,  4.99it/s][A
Epoch 1/8:  25%|██▌       | 5/20 [00:01<00:03,  4.97it/s][A
Epoch 1/8:  30%|███       | 6/20 [00:01<00:02,  4.95it/s][A
Epoch 1/8:  35%|███▌      | 7/20 [00:01<00:02,  4.95it/s][A
Epoch 1/8:  40%|████      | 8/20 [00:01<00:02,  4.97it/s][A
Epoch 1/8:  45%|████▌     | 9/20 [00:01<00:02,  4.98it/s][A
Epoch 1/8:  50%|█████     | 10/20 [00:02<00:02,  4.99it/s][A
Epoch 1/8:  55%|█████▌    | 11/20 [00:02<00:01,  4.99it/s][A
Epoch 1/8:  60%|██████    | 12/20 [00:02<00:01,  4.98it/s][A
Epoch 1/8:  65%|██████▌   | 13/20 [00:02<00:01,  4.98it/s][A
Epoch 1/8:  70%|███████   | 14/20 [00:02<00:01,  4.99it/s][A
Epoch 1/8:  75%|███████▌  | 15/20 [00:03<00:01,  4.99it/s][A
Epoch 1/8:  80%|████████ 

Average training loss: 0.6921
Validation loss: 0.6062
Validation accuracy: 0.7183
Saved best model!



Epoch 2/8:   0%|          | 0/20 [00:00<?, ?it/s][A
Epoch 2/8:   5%|▌         | 1/20 [00:00<00:03,  5.21it/s][A
Epoch 2/8:  10%|█         | 2/20 [00:00<00:03,  5.10it/s][A
Epoch 2/8:  15%|█▌        | 3/20 [00:00<00:03,  5.08it/s][A
Epoch 2/8:  20%|██        | 4/20 [00:00<00:03,  5.03it/s][A
Epoch 2/8:  25%|██▌       | 5/20 [00:00<00:03,  4.97it/s][A
Epoch 2/8:  30%|███       | 6/20 [00:01<00:02,  4.96it/s][A
Epoch 2/8:  35%|███▌      | 7/20 [00:01<00:02,  4.99it/s][A
Epoch 2/8:  40%|████      | 8/20 [00:01<00:02,  4.99it/s][A
Epoch 2/8:  45%|████▌     | 9/20 [00:01<00:02,  5.02it/s][A
Epoch 2/8:  50%|█████     | 10/20 [00:01<00:01,  5.02it/s][A
Epoch 2/8:  55%|█████▌    | 11/20 [00:02<00:01,  5.01it/s][A
Epoch 2/8:  60%|██████    | 12/20 [00:02<00:01,  4.97it/s][A
Epoch 2/8:  65%|██████▌   | 13/20 [00:02<00:01,  4.95it/s][A
Epoch 2/8:  70%|███████   | 14/20 [00:02<00:01,  4.95it/s][A
Epoch 2/8:  75%|███████▌  | 15/20 [00:03<00:01,  4.95it/s][A
Epoch 2/8:  80%|████████ 

Average training loss: 0.5343
Validation loss: 0.4137
Validation accuracy: 0.7746
Saved best model!



Epoch 3/8:   0%|          | 0/20 [00:00<?, ?it/s][A
Epoch 3/8:   5%|▌         | 1/20 [00:00<00:03,  5.16it/s][A
Epoch 3/8:  10%|█         | 2/20 [00:00<00:03,  5.11it/s][A
Epoch 3/8:  15%|█▌        | 3/20 [00:00<00:03,  5.06it/s][A
Epoch 3/8:  20%|██        | 4/20 [00:00<00:03,  5.03it/s][A
Epoch 3/8:  25%|██▌       | 5/20 [00:00<00:03,  4.98it/s][A
Epoch 3/8:  30%|███       | 6/20 [00:01<00:02,  4.96it/s][A
Epoch 3/8:  35%|███▌      | 7/20 [00:01<00:02,  4.92it/s][A
Epoch 3/8:  40%|████      | 8/20 [00:01<00:02,  4.92it/s][A
Epoch 3/8:  45%|████▌     | 9/20 [00:01<00:02,  4.93it/s][A
Epoch 3/8:  50%|█████     | 10/20 [00:02<00:02,  4.89it/s][A
Epoch 3/8:  55%|█████▌    | 11/20 [00:02<00:01,  4.85it/s][A
Epoch 3/8:  60%|██████    | 12/20 [00:02<00:01,  4.87it/s][A
Epoch 3/8:  65%|██████▌   | 13/20 [00:02<00:01,  4.89it/s][A
Epoch 3/8:  70%|███████   | 14/20 [00:02<00:01,  4.89it/s][A
Epoch 3/8:  75%|███████▌  | 15/20 [00:03<00:01,  4.92it/s][A
Epoch 3/8:  80%|████████ 

Average training loss: 0.4248
Validation loss: 0.7476
Validation accuracy: 0.6620
EarlyStopping counter: 1 out of 3



Epoch 4/8:   0%|          | 0/20 [00:00<?, ?it/s][A
Epoch 4/8:   5%|▌         | 1/20 [00:00<00:03,  5.11it/s][A
Epoch 4/8:  10%|█         | 2/20 [00:00<00:03,  5.04it/s][A
Epoch 4/8:  15%|█▌        | 3/20 [00:00<00:03,  4.99it/s][A
Epoch 4/8:  20%|██        | 4/20 [00:00<00:03,  4.94it/s][A
Epoch 4/8:  25%|██▌       | 5/20 [00:01<00:03,  4.92it/s][A
Epoch 4/8:  30%|███       | 6/20 [00:01<00:02,  4.92it/s][A
Epoch 4/8:  35%|███▌      | 7/20 [00:01<00:02,  4.93it/s][A
Epoch 4/8:  40%|████      | 8/20 [00:01<00:02,  4.93it/s][A
Epoch 4/8:  45%|████▌     | 9/20 [00:01<00:02,  4.98it/s][A
Epoch 4/8:  50%|█████     | 10/20 [00:02<00:02,  4.97it/s][A
Epoch 4/8:  55%|█████▌    | 11/20 [00:02<00:01,  4.99it/s][A
Epoch 4/8:  60%|██████    | 12/20 [00:02<00:01,  4.96it/s][A
Epoch 4/8:  65%|██████▌   | 13/20 [00:02<00:01,  4.93it/s][A
Epoch 4/8:  70%|███████   | 14/20 [00:02<00:01,  4.94it/s][A
Epoch 4/8:  75%|███████▌  | 15/20 [00:03<00:01,  4.91it/s][A
Epoch 4/8:  80%|████████ 

Average training loss: 0.3757
Validation loss: 0.3331
Validation accuracy: 0.8310
Saved best model!



Epoch 5/8:   0%|          | 0/20 [00:00<?, ?it/s][A
Epoch 5/8:   5%|▌         | 1/20 [00:00<00:04,  4.44it/s][A
Epoch 5/8:  10%|█         | 2/20 [00:00<00:03,  4.79it/s][A
Epoch 5/8:  15%|█▌        | 3/20 [00:00<00:03,  4.89it/s][A
Epoch 5/8:  20%|██        | 4/20 [00:00<00:03,  4.92it/s][A
Epoch 5/8:  25%|██▌       | 5/20 [00:01<00:03,  4.93it/s][A
Epoch 5/8:  30%|███       | 6/20 [00:01<00:02,  4.94it/s][A
Epoch 5/8:  35%|███▌      | 7/20 [00:01<00:02,  4.93it/s][A
Epoch 5/8:  40%|████      | 8/20 [00:01<00:02,  4.93it/s][A
Epoch 5/8:  45%|████▌     | 9/20 [00:01<00:02,  4.91it/s][A
Epoch 5/8:  50%|█████     | 10/20 [00:02<00:02,  4.92it/s][A
Epoch 5/8:  55%|█████▌    | 11/20 [00:02<00:01,  4.94it/s][A
Epoch 5/8:  60%|██████    | 12/20 [00:02<00:01,  4.93it/s][A
Epoch 5/8:  65%|██████▌   | 13/20 [00:02<00:01,  4.94it/s][A
Epoch 5/8:  70%|███████   | 14/20 [00:02<00:01,  4.97it/s][A
Epoch 5/8:  75%|███████▌  | 15/20 [00:03<00:01,  4.98it/s][A
Epoch 5/8:  80%|████████ 

Average training loss: 0.1837
Validation loss: 0.7540
Validation accuracy: 0.8169
EarlyStopping counter: 1 out of 3



Epoch 6/8:   0%|          | 0/20 [00:00<?, ?it/s][A
Epoch 6/8:   5%|▌         | 1/20 [00:00<00:03,  5.08it/s][A
Epoch 6/8:  10%|█         | 2/20 [00:00<00:03,  5.00it/s][A
Epoch 6/8:  15%|█▌        | 3/20 [00:00<00:03,  4.98it/s][A
Epoch 6/8:  20%|██        | 4/20 [00:00<00:03,  4.97it/s][A
Epoch 6/8:  25%|██▌       | 5/20 [00:01<00:03,  4.92it/s][A
Epoch 6/8:  30%|███       | 6/20 [00:01<00:02,  4.92it/s][A
Epoch 6/8:  35%|███▌      | 7/20 [00:01<00:02,  4.90it/s][A
Epoch 6/8:  40%|████      | 8/20 [00:01<00:02,  4.91it/s][A
Epoch 6/8:  45%|████▌     | 9/20 [00:01<00:02,  4.92it/s][A
Epoch 6/8:  50%|█████     | 10/20 [00:02<00:02,  4.92it/s][A
Epoch 6/8:  55%|█████▌    | 11/20 [00:02<00:01,  4.89it/s][A
Epoch 6/8:  60%|██████    | 12/20 [00:02<00:01,  4.89it/s][A
Epoch 6/8:  65%|██████▌   | 13/20 [00:02<00:01,  4.91it/s][A
Epoch 6/8:  70%|███████   | 14/20 [00:02<00:01,  4.91it/s][A
Epoch 6/8:  75%|███████▌  | 15/20 [00:03<00:01,  4.89it/s][A
Epoch 6/8:  80%|████████ 

Average training loss: 0.1347
Validation loss: 0.9006
Validation accuracy: 0.8028
EarlyStopping counter: 2 out of 3



Epoch 7/8:   0%|          | 0/20 [00:00<?, ?it/s][A
Epoch 7/8:   5%|▌         | 1/20 [00:00<00:03,  5.07it/s][A
Epoch 7/8:  10%|█         | 2/20 [00:00<00:03,  4.98it/s][A
Epoch 7/8:  15%|█▌        | 3/20 [00:00<00:03,  5.01it/s][A
Epoch 7/8:  20%|██        | 4/20 [00:00<00:03,  4.99it/s][A
Epoch 7/8:  25%|██▌       | 5/20 [00:01<00:03,  4.96it/s][A
Epoch 7/8:  30%|███       | 6/20 [00:01<00:02,  4.92it/s][A
Epoch 7/8:  35%|███▌      | 7/20 [00:01<00:02,  4.93it/s][A
Epoch 7/8:  40%|████      | 8/20 [00:01<00:02,  4.94it/s][A
Epoch 7/8:  45%|████▌     | 9/20 [00:01<00:02,  4.92it/s][A
Epoch 7/8:  50%|█████     | 10/20 [00:02<00:02,  4.93it/s][A
Epoch 7/8:  55%|█████▌    | 11/20 [00:02<00:01,  4.92it/s][A
Epoch 7/8:  60%|██████    | 12/20 [00:02<00:01,  4.91it/s][A
Epoch 7/8:  65%|██████▌   | 13/20 [00:02<00:01,  4.92it/s][A
Epoch 7/8:  70%|███████   | 14/20 [00:02<00:01,  4.95it/s][A
Epoch 7/8:  75%|███████▌  | 15/20 [00:03<00:01,  4.94it/s][A
Epoch 7/8:  80%|████████ 

Average training loss: 0.2461
Validation loss: 0.3144
Validation accuracy: 0.8592
Saved best model!



Epoch 8/8:   0%|          | 0/20 [00:00<?, ?it/s][A
Epoch 8/8:   5%|▌         | 1/20 [00:00<00:03,  5.09it/s][A
Epoch 8/8:  10%|█         | 2/20 [00:00<00:03,  5.02it/s][A
Epoch 8/8:  15%|█▌        | 3/20 [00:00<00:03,  4.98it/s][A
Epoch 8/8:  20%|██        | 4/20 [00:00<00:03,  4.98it/s][A
Epoch 8/8:  25%|██▌       | 5/20 [00:01<00:03,  4.91it/s][A
Epoch 8/8:  30%|███       | 6/20 [00:01<00:02,  4.90it/s][A
Epoch 8/8:  35%|███▌      | 7/20 [00:01<00:02,  4.90it/s][A
Epoch 8/8:  40%|████      | 8/20 [00:01<00:02,  4.90it/s][A
Epoch 8/8:  45%|████▌     | 9/20 [00:01<00:02,  4.90it/s][A
Epoch 8/8:  50%|█████     | 10/20 [00:02<00:02,  4.89it/s][A
Epoch 8/8:  55%|█████▌    | 11/20 [00:02<00:01,  4.90it/s][A
Epoch 8/8:  60%|██████    | 12/20 [00:02<00:01,  4.90it/s][A
Epoch 8/8:  65%|██████▌   | 13/20 [00:02<00:01,  4.89it/s][A
Epoch 8/8:  70%|███████   | 14/20 [00:02<00:01,  4.90it/s][A
Epoch 8/8:  75%|███████▌  | 15/20 [00:03<00:01,  4.89it/s][A
Epoch 8/8:  80%|████████ 

Average training loss: 0.0850
Validation loss: 0.4075
Validation accuracy: 0.8873
EarlyStopping counter: 1 out of 3


Setup model evaluation

In [35]:
from sklearn.metrics import classification_report, confusion_matrix

def evaluate_model(model, data_loader, device):
    model.eval()
    predictions = []
    true_labels = []

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=1).cpu().numpy()

            predictions.extend(preds)
            true_labels.extend(labels.cpu().numpy())

    print(classification_report(true_labels, predictions))
    print("Confusion Matrix:")
    print(confusion_matrix(true_labels, predictions))

Evaluate the model

In [36]:
evaluate_model(model, val_loader, device)

              precision    recall  f1-score   support

           0       0.94      0.85      0.89        40
           1       0.83      0.94      0.88        31

    accuracy                           0.89        71
   macro avg       0.89      0.89      0.89        71
weighted avg       0.89      0.89      0.89        71

Confusion Matrix:
[[34  6]
 [ 2 29]]


Inference Function

In [37]:
def predict(text, model, tokenizer, device, max_length=128):
    model.eval()
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        pred = torch.argmax(logits, dim=1).item()

    if pred == 1:
        pred = "dementia"
    else:
        pred = "control"

    return pred

Call inference function on test texts

In [38]:
test_gt = []
test_preds = []
for text, label in zip(test_texts, test_labels):  # Iterate directly over items
    pred = predict(text, model, tokenizer, device)
    gt = "dementia" if label == 1 else "control"
    print(f"Prediction: {pred}, Ground truth: {gt}")
    test_gt.append(gt)  # Build ground truth list
    test_preds.append(pred)  # Build predictions list

Prediction: control, Ground truth: control
Prediction: control, Ground truth: control
Prediction: control, Ground truth: control
Prediction: control, Ground truth: dementia
Prediction: dementia, Ground truth: dementia
Prediction: control, Ground truth: control
Prediction: control, Ground truth: control
Prediction: control, Ground truth: control
Prediction: dementia, Ground truth: dementia
Prediction: dementia, Ground truth: dementia
Prediction: dementia, Ground truth: dementia
Prediction: dementia, Ground truth: control
Prediction: control, Ground truth: control
Prediction: dementia, Ground truth: dementia
Prediction: control, Ground truth: dementia
Prediction: control, Ground truth: dementia
Prediction: dementia, Ground truth: dementia
Prediction: dementia, Ground truth: dementia
Prediction: control, Ground truth: control
Prediction: control, Ground truth: control


In [39]:
print(classification_report(test_gt, test_preds))

              precision    recall  f1-score   support

     control       0.75      0.90      0.82        10
    dementia       0.88      0.70      0.78        10

    accuracy                           0.80        20
   macro avg       0.81      0.80      0.80        20
weighted avg       0.81      0.80      0.80        20



Upload to HuggingFace

In [40]:
# from huggingface_hub import notebook_login
# notebook_login()

In [41]:
# # Push to hub
# model.push_to_hub("rmezapi/dementia-bank-seq-classif-xlnet")
# tokenizer.push_to_hub("rmezapi/dementia-bank-seq-classif-xlnet")

CommitInfo(commit_url='https://huggingface.co/rmezapi/dementia-bank-seq-classif-xlnet/commit/b9bf06695cb6da8fe4039437220a85ed0f57266d', commit_message='Upload tokenizer', commit_description='', oid='b9bf06695cb6da8fe4039437220a85ed0f57266d', pr_url=None, pr_revision=None, pr_num=None)

In [42]:
# import requests

# API_URL = "https://api-inference.huggingface.co/models/rmezapi/dementia-bank-seq-classif-xlnet"
# headers = {"Authorization": "Bearer key_here"}

# def query(payload):
# 	response = requests.post(API_URL, headers=headers, json=payload)
# 	return response.json()

# output = query({
# 	"inputs": "the water's running on the floor . boy's taking cookies out of cookie outof the cookie jar . the stool is falling open over . the girl was asking for a cookie . the wife is wiping the dish . I guess not .",
# })

# print(output)