# Predictive model for differential diagnosis

In this notebook, we develop a model that takes a patient's free-text symptom description as input and returns the three most likely classes (diseases), each with a confidence value expressed as a probability.


## Library and Data import

**date:** 2021-07-12  
**author:** "Rubanza Silver - Flexible Functions AI Lab"

In [1]:
#|include: false 

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))


/kaggle/input/symptoms-disease-no-id/symptom_disease_no_id_col.csv
/kaggle/input/symptoms-disease-no-id/symptom_no_id.csv


In [2]:
#|include: false 

%pip install seaborn
%pip install fastkaggle
%pip install -Uqq fastbook
%pip install --upgrade pip
%pip install tqdm
#%pip install catboost
#%pip install optuna
#%pip install optuna_distributed
#%pip install openfe
#%pip install xgboost
#%pip install lightgbm
#%pip install h2o
#%pip install polars
#%pip install -q -U autogluon.tabular
#%pip install autogluon
#%pip install wandb
#%pip install sweetviz


In [3]:
#| code-fold: true
#| output: false
#| code-summary: "Library Import"

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

#import fastbook
#fastbook.setup_book()
#from fastbook import *
from fastai.tabular.all import *
import numpy as np
from numpy import random
from tqdm import tqdm
from ipywidgets import interact
from fastai.imports import *
np.set_printoptions(linewidth=130)
from fastai.text.all import *
from pathlib import Path
import os
import warnings
import gc
import pickle
from joblib import dump, load

# ULMFiT approach

In text transfer learning, we start from a pre-trained model called a language model. The model we use in this example was initially trained on Wikipedia on the task of guessing the next word. We could fine-tune this model directly for our task of classifying diseases from symptoms.

However, Wikipedia English differs from medical language, so we can improve on this. We employ a technique from the [ULMFiT paper](https://arxiv.org/abs/1801.06146) by Jeremy Howard and Sebastian Ruder, which adds an intermediate step: the pre-trained language model is first fine-tuned on the task-specific corpus (here, our medical symptom descriptions) and is then used as the base for the classifier. They found that this extra step gives better results, because the model has a better grasp of the language of the final task.

In [4]:
!ls /kaggle/input/symptoms-disease-no-id

symptom_disease_no_id_col.csv  symptom_no_id.csv


In [5]:
path = Path('/kaggle/input/symptoms-disease-no-id')
path

Path('/kaggle/input/symptoms-disease-no-id')

In [6]:
#symptom_df = pd.read_csv(path_lm/'symptom_synth.csv',index_col=0)
symptom_df = pd.read_csv(path/'symptom_no_id.csv')
sd_df = pd.read_csv(path/'symptom_disease_no_id_col.csv')
symptom_df.head()

Unnamed: 0,text
0,"I have been experiencing a skin rash on my arms, legs, and torso for the past few weeks. It is red, itchy, and covered in dry, scaly patches."
1,"My skin has been peeling, especially on my knees, elbows, and scalp. This peeling is often accompanied by a burning or stinging sensation."
2,"I have been experiencing joint pain in my fingers, wrists, and knees. The pain is often achy and throbbing, and it gets worse when I move my joints."
3,"There is a silver like dusting on my skin, especially on my lower back and scalp. This dusting is made up of small scales that flake off easily when I scratch them."
4,"My nails have small dents or pits in them, and they often feel inflammatory and tender to the touch. Even there are minor rashes on my arms."


In [7]:
symptom_df['text'].nunique(),sd_df['text'].unique()

(1153,
 array(['I have been experiencing a skin rash on my arms, legs, and torso for the past few weeks. It is red, itchy, and covered in dry, scaly patches.',
        'My skin has been peeling, especially on my knees, elbows, and scalp. This peeling is often accompanied by a burning or stinging sensation.',
        'I have been experiencing joint pain in my fingers, wrists, and knees. The pain is often achy and throbbing, and it gets worse when I move my joints.',
        ...,
        "I regularly experience these intense urges and the want to urinate. I frequently feel drowsy and lost. I've also significantly lost my vision.",
        'I have trouble breathing, especially outside. I start to feel hot and start to sweat. I frequently have urinary tract infections and yeast infections.',
        "I constantly sneeze and have a dry cough. My infections don't seem to be healing, and I have palpitations. My throat does ache occasionally, but it usually gets better."],
       dtype=object)

## Finetuning a language model with my medical corpus

Below I define a DataLoaders object. fastai's DataLoaders bundles a training and a validation DataLoader, and fastai's DataLoader extends PyTorch's DataLoader class with more functionality. It takes in our data and prepares it as input for our model, for example by passing it in batches.

The DataLoaders object lets us build the data objects we need for training without changing the raw input data.

The dataloaders then act as input for our model. We also pass in valid_pct=0.2, which randomly sets aside 20% of our data for validation.

In [8]:
#dls_lm = TextDataLoaders.from_df(symptom_df, path=path, is_lm=True, valid_pct=0.2)
dls_lm = TextDataLoaders.from_df(symptom_df, path=path, is_lm=True,text_col='text', valid_pct=0.2)
#dls_lm = TextDataLoaders.from_folder(path=path_lm, is_lm=True, valid_pct=0.1)
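
To make the effect of valid_pct concrete, here is a minimal sketch of a random train/validation split in plain Python. It is an illustration only: `random_split`, the seed and the item count of 10 are assumptions made for this example, not fastai's own code.

```python
import random

def random_split(n_items, valid_pct=0.2, seed=42):
    """Return (train_idx, valid_idx) after shuffling the indices 0..n_items-1."""
    rng = random.Random(seed)          # local RNG so the split is reproducible
    idxs = list(range(n_items))
    rng.shuffle(idxs)
    cut = int(valid_pct * n_items)     # number of items held out for validation
    return idxs[cut:], idxs[:cut]

# Hypothetical dataset of 10 rows: 8 go to training, 2 to validation
train_idx, valid_idx = random_split(10, valid_pct=0.2)
print("train:", sorted(train_idx))
print("valid:", sorted(valid_idx))
```

Every row lands in exactly one of the two sets, so the validation metrics are computed on texts the model did not train on.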

We then use show_batch to have a look at some of our data. Since we are guessing the next word in a sentence, you will notice that the targets in the *text_* column are the inputs shifted one word to the right.

In [9]:
dls_lm.show_batch(max_n=5)

Unnamed: 0,text,text_
0,"xxbos xxmaj i 've had low temps and blood in my urine for the past xxunk days . xxmaj my pee smells terrible , and my head hurts so much . xxmaj urinary urges xxunk xxunk on xxunk , and i almost never have any control over when they do . xxbos i do n't feel like eating , and swallowing is challenging . xxmaj even after eating little meals , i","xxmaj i 've had low temps and blood in my urine for the past xxunk days . xxmaj my pee smells terrible , and my head hurts so much . xxmaj urinary urges xxunk xxunk on xxunk , and i almost never have any control over when they do . xxbos i do n't feel like eating , and swallowing is challenging . xxmaj even after eating little meals , i frequently"
1,"feel worn out and my senses of taste and smell have xxunk . xxmaj sometimes i have palpitations or a xxunk xxunk . xxbos xxmaj i 've had nausea , vomiting , and exhaustion . xxmaj additionally , xxmaj i 've lost weight and have a temperature . xxmaj my urine is black and my skin has turned yellow . xxmaj i 've also been having stomach pain . xxbos i xxunk","worn out and my senses of taste and smell have xxunk . xxmaj sometimes i have palpitations or a xxunk xxunk . xxbos xxmaj i 've had nausea , vomiting , and exhaustion . xxmaj additionally , xxmaj i 've lost weight and have a temperature . xxmaj my urine is black and my skin has turned yellow . xxmaj i 've also been having stomach pain . xxbos i xxunk lose"
2,". xxmaj the rash is spreading to different parts of my body . xxbos xxmaj i 've been itchy and throwing up . xxmaj in addition , i have lost weight and feel really exhausted . xxmaj my skin has become yellow and i have a severe temperature . i have abdominal ache and black urine . xxbos i often feel the want to urinate and experience these intense desires . i","xxmaj the rash is spreading to different parts of my body . xxbos xxmaj i 've been itchy and throwing up . xxmaj in addition , i have lost weight and feel really exhausted . xxmaj my skin has become yellow and i have a severe temperature . i have abdominal ache and black urine . xxbos i often feel the want to urinate and experience these intense desires . i often"
3,"xxbos i always have foul breath and a sour taste in my mouth , and occasionally , the tingling in my throat makes it difficult to swallow meals . xxbos i regularly experience these intense urges and the want to urinate . i frequently feel xxunk and lost . xxmaj i 've also significantly lost my vision . xxbos xxmaj i 've been losing weight and experiencing severe itching , nausea ,","i always have foul breath and a sour taste in my mouth , and occasionally , the tingling in my throat makes it difficult to swallow meals . xxbos i regularly experience these intense urges and the want to urinate . i frequently feel xxunk and lost . xxmaj i 've also significantly lost my vision . xxbos xxmaj i 've been losing weight and experiencing severe itching , nausea , and"
4,"vomiting have also xxunk . xxmaj i 'm quite worried about my health . xxbos i get frequent urges to urinate at night with little output , and a lot of pain during urination . xxmaj the urine is xxunk and bloody and xxunk foul smelling , and i get nauseous xxbos xxmaj while taking a walk , i suddenly started experiencing headache , chest pain , and dizziness after feeling fine","have also xxunk . xxmaj i 'm quite worried about my health . xxbos i get frequent urges to urinate at night with little output , and a lot of pain during urination . xxmaj the urine is xxunk and bloody and xxunk foul smelling , and i get nauseous xxbos xxmaj while taking a walk , i suddenly started experiencing headache , chest pain , and dizziness after feeling fine all"


From the above, we notice that the texts were processed and split into tokens. The tokenizer adds special tokens such as xxbos to indicate the beginning of a text, xxmaj to indicate that the next word was capitalised, and xxunk to stand in for words that are not in the vocabulary.
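
The special tokens and the one-token shift between inputs and targets can be sketched in a few lines of plain Python. This is a rough imitation for illustration: `toy_tokenize` and `lm_pairs` are hypothetical helpers, and fastai's real tokenizer applies many more rules.

```python
def toy_tokenize(text):
    """Very rough imitation of two special-token rules (not fastai's tokenizer)."""
    tokens = ["xxbos"]                      # mark the beginning of a text
    for word in text.split():
        if word[0].isupper():
            tokens.append("xxmaj")          # flag: the next word was capitalised
        tokens.append(word.lower())
    return tokens

def lm_pairs(tokens):
    """Inputs are all tokens but the last; targets are the same stream shifted by one."""
    return tokens[:-1], tokens[1:]

tokens = toy_tokenize("My skin has been peeling")
inputs, targets = lm_pairs(tokens)
for x, y in zip(inputs, targets):
    print(f"{x:>8} -> {y}")
```

At every position the model sees the input token (and everything before it) and is asked to predict the target token, which is simply the next token in the stream.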

We then define a fastai [learner](https://docs.fast.ai/learner.html#learner), which is a fastai class that handles the training loop. It bundles together the essential components needed for training, such as the dataloaders, the model and the loss function.

We use the AWD-LSTM architecture, with accuracy and perplexity (the exponential of the loss) as our metrics. Furthermore, we set a weight decay (wd) of 0.1 and apply mixed precision (.to_fp16()) to the learner, which speeds up training on GPUs with tensor cores.


In [10]:
learn = language_model_learner(dls_lm, AWD_LSTM, metrics=[accuracy, Perplexity()], path=path, wd=0.1).to_fp16()



#### Phased Finetuning

A pre-trained model is one that has already been trained on a large dataset and has learnt general patterns and features from it. It can then be fine-tuned for a specific task.

By default, the body of the model is frozen, meaning we won’t be updating the parameters of the body during training. In this first phase, only the head (the last layer group, closest to the output) of the model is trained.

In [11]:
#| error: false
learn.fit_one_cycle(1, 1e-2)



epoch,train_loss,valid_loss,accuracy,perplexity,time
0,4.300655,3.545176,0.348958,34.645775,00:02
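
The perplexity column is the exponential of the validation loss, which is easy to check by hand. The short sketch below uses only the standard library; the `perplexity` helper is written for this example and is not the fastai metric itself.

```python
import math

def perplexity(cross_entropy_loss):
    """Perplexity is the exponential of the (natural-log) cross-entropy loss."""
    return math.exp(cross_entropy_loss)

valid_loss = 3.545176                      # validation loss from the table above
print(f"perplexity: {perplexity(valid_loss):.3f}")

# A model guessing uniformly over a vocabulary of V tokens has loss log(V),
# so its perplexity is V itself. With the 944-token vocabulary of this notebook:
print(f"uniform baseline: {perplexity(math.log(944)):.1f}")
```

Read this way, a perplexity of about 35 means the model is roughly as uncertain as if it were choosing uniformly among 35 tokens at each step, compared with 944 for blind guessing.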


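As shown below, we can use *learn.save* to save the state of our model to a file in learn.path/models/ named “filename.pth”. Because the Kaggle input directory is read-only, we first point the learner's model_dir at a writable directory. We can then use learn.load('filename') to load the content of this file.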

In [12]:
#| code-fold: show

# Create a directory to save the model
os.makedirs('/kaggle/working/models', exist_ok=True)

# Set the model directory for the learner
learn.model_dir = '/kaggle/working/models'

# Now save the model
learn.save('1epoch')

Path('/kaggle/working/models/1epoch.pth')

In [13]:
#| error: false
learn = learn.load('1epoch')



After training the head of the model, we unfreeze the body and fine-tune the whole model, body and head together, for five more epochs at a lower learning rate.

In [14]:
#| error: false
learn.unfreeze()
learn.fit_one_cycle(5, 1e-3)



epoch,train_loss,valid_loss,accuracy,perplexity,time
0,3.569104,2.964077,0.396701,19.376814,00:02
1,3.248957,2.605404,0.430556,13.536695,00:02
2,3.015349,2.399263,0.468388,11.015057,00:02
3,2.842906,2.315446,0.480758,10.129438,00:02
4,2.722763,2.301781,0.483652,9.991961,00:02


The model without its final layer, which converts activations into probabilities of picking each token in our vocabulary, is called the encoder. We use fastai's *save_encoder* to save it as shown below.

In [15]:
#| code-fold: true
#| output: false
#| code-summary: "Save the model"
# Now save the model
learn.save_encoder('finetuned')

Now that our model has been trained to guess the next word in a sentence, we can use it to generate new text that continues the user input below.

In [16]:
#| output: false
#| error: false
TEXT = "I have running nose, stomach and joint pains"
N_WORDS = 40
N_SENTENCES = 2
preds = [learn.predict(TEXT, N_WORDS, temperature=0.75) 
         for _ in range(N_SENTENCES)]



In [17]:
print("\n".join(preds))

i have running nose , stomach and joint pains . i have a high fever and have a yellow fever . My fever is high , and my skin has also been really itchy . i have been feeling really weak , and my neck hurts .
i have running nose , stomach and joint pains . My knees and neck have been inflamed and my arms have been coughing up . My muscles have also been really irritated . I feel having a lot of trouble getting tired and tired . i
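
The temperature argument passed to learn.predict controls how adventurous the sampling is. A common formulation, assumed here for illustration rather than copied from fastai's source, divides the scores by the temperature before the softmax. The three-token vocabulary and the scores are made up.

```python
import math
import random

def softmax_with_temperature(logits, temperature=1.0):
    """Turn raw scores into probabilities; lower temperature sharpens the distribution."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)                              # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_token(vocab, logits, temperature=0.75, rng=random):
    """Draw one token at random according to the temperature-scaled probabilities."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(vocab, weights=probs, k=1)[0]

vocab = ["fever", "rash", "cough"]     # hypothetical three-token vocabulary
logits = [2.0, 1.0, 0.5]               # hypothetical scores for the next token

for t in (1.0, 0.75, 0.25):
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])
print(sample_token(vocab, logits, temperature=0.75, rng=random.Random(0)))
```

With a temperature below 1 the most likely token is picked more often, which keeps the generated text closer to the training data. Because tokens are still sampled at random, the two generated sentences above differ from each other.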


## Training a text classifier

We now gather and pass in data to train our text classifier.

In [18]:
#symptom_df = pd.read_csv(path_lm/'symptom_synth.csv',index_col=0)
#sd_df = pd.read_csv(path_lm/'symptom_disease_no_id_col.csv')
sd_df.head()

Unnamed: 0,label,text
0,Psoriasis,"I have been experiencing a skin rash on my arms, legs, and torso for the past few weeks. It is red, itchy, and covered in dry, scaly patches."
1,Psoriasis,"My skin has been peeling, especially on my knees, elbows, and scalp. This peeling is often accompanied by a burning or stinging sensation."
2,Psoriasis,"I have been experiencing joint pain in my fingers, wrists, and knees. The pain is often achy and throbbing, and it gets worse when I move my joints."
3,Psoriasis,"There is a silver like dusting on my skin, especially on my lower back and scalp. This dusting is made up of small scales that flake off easily when I scratch them."
4,Psoriasis,"My nails have small dents or pits in them, and they often feel inflammatory and tender to the touch. Even there are minor rashes on my arms."


In [19]:
# Check for NaN values in the label column
print(sd_df['label'].isna().sum())

# If there are NaNs, you can drop those rows
#df = df.dropna(subset=['label'])

0


In [20]:
#| output: false
#| error: false
#dls_clas = TextDataLoaders.from_df(sd_df, path=path,valid='test', text_vocab=dls_lm.vocab)
dls_clas = TextDataLoaders.from_df(sd_df, path=path,valid='test',text_col='text',label_col='label', text_vocab=dls_lm.vocab)

Passing in *text_vocab=dls_lm.vocab* passes in our previously defined vocabulary to our classifier. 

> To quote the fastai documentation, we have to use the exact same vocabulary as when we were fine-tuning our language model, or the weights learned won’t make any sense.

When you train a language model, it learns to associate specific patterns of numbers (weights) with specific tokens (words or subwords) in your vocabulary. 

Each token is assigned a unique index in the vocabulary, and the model's internal representations (the weights in the embedding layers and beyond) are organised according to these indices.

Think of it like a dictionary where each word has a specific page number. The model learns that information about "good" is on page 382, information about "movie" is on page 1593, and so on. These "page numbers" (indices) must remain consistent for the weights to make sense.

If you were to use a different vocabulary when creating your classifier:

- The token "good" might now be on page 746 instead of 382.
- The weights the model learned during language model training were specifically tied to the old index (382).

Now when the classifier sees "good" and looks up page 746, it finds weights that were meant for some completely different word.

> This mismatch would render the carefully fine-tuned language model weights essentially random from the perspective of the classifier.
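
The page-number analogy can be reproduced with a toy embedding table. Everything in this sketch (the four-token vocabularies, the 2-d vectors and the `embed` helper) is hypothetical and only illustrates the index mismatch.

```python
# Vocabulary used while fine-tuning the language model
lm_vocab = ["xxunk", "fever", "rash", "cough"]

# One embedding row per token, stored by position in lm_vocab (toy 2-d vectors)
embeddings = [[0.0, 0.0], [0.9, 0.1], [0.1, 0.9], [0.5, 0.5]]

def embed(token, vocab):
    """Look up a token's vector using the index it has in `vocab`."""
    index = vocab.index(token) if token in vocab else 0   # index 0 is xxunk
    return embeddings[index]

# Same tokens, but in a different order, as a freshly built vocabulary might have
other_vocab = ["xxunk", "cough", "fever", "rash"]

print(embed("fever", lm_vocab))      # the row that was learned for "fever"
print(embed("fever", other_vocab))   # the row that was learned for "rash"
```

The weights themselves are untouched. Only the token-to-index mapping changed, and that alone is enough to hand the classifier the wrong vector.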

In [21]:
#| error: false
learn = text_classifier_learner(dls_clas, AWD_LSTM, drop_mult=0.5, metrics=accuracy)



We then define our text classifier as shown above. Before training it, we load in the previous encoder.

In [22]:
from pathlib import Path
learn.path = Path('/kaggle/working')

In [23]:
#| error: false
learn = learn.load_encoder('finetuned')



#### Discriminative Learning Rates & Gradual Unfreezing

**Discriminative learning rates** means using different learning rates for different layers of the model.

For example, earlier layers (closer to the input) might get smaller learning rates, while the later layers (closer to the output) get larger learning rates.

**Gradual unfreezing** is a technique where layers of the model are unfrozen (made trainable) incrementally during fine-tuning. 
Instead of unfreezing all layers at once, you start by unfreezing only the topmost layers (closest to the output) and train them first.

In computer vision applications we usually unfreeze the whole model at once. For NLP classifiers, unfreezing a few layers at a time has been shown to improve performance.
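
Gradual unfreezing can be pictured with a toy model made of layer groups. The sketch assumes five hypothetical groups and a `freeze_to` helper written for this example, mirroring the idea behind the freeze_to and unfreeze calls on a fastai learner. It does not touch a real model.

```python
from dataclasses import dataclass

@dataclass
class ParamGroup:
    name: str
    trainable: bool = True

def freeze_to(groups, n):
    """Freeze groups[:n] and make groups[n:] trainable (n may be negative)."""
    cut = n if n >= 0 else len(groups) + n
    for i, group in enumerate(groups):
        group.trainable = i >= cut

# Hypothetical layer groups, ordered from input to output
groups = [ParamGroup(n) for n in ("embedding", "lstm_1", "lstm_2", "lstm_3", "head")]

# Phase 1: head only; phase 2: last two groups; phase 3: everything
for n in (-1, -2, 0):
    freeze_to(groups, n)
    print(n, [g.name for g in groups if g.trainable])
```

Each phase trains a slightly larger part of the model, so the layers carrying the most general language knowledge are disturbed last.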




In [24]:
len(dls_lm.vocab)

944

In [25]:
#| error: false
learn.fit_one_cycle(1, 2e-2)

epoch,train_loss,valid_loss,accuracy,time
0,2.281317,2.35146,0.545833,00:01


In [26]:
#| error: false
learn.freeze_to(-2)
learn.fit_one_cycle(1, slice(1e-2/(2.6**4),1e-2))

epoch,train_loss,valid_loss,accuracy,time
0,1.45203,1.559455,0.729167,00:01
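
The slice(1e-2/(2.6**4), 1e-2) passed above gives the lowest and highest learning rates; the layer groups in between receive rates spaced by a constant factor. The sketch below assumes five layer groups and uses a hypothetical `even_multiples` helper, so it approximates the idea and is not a dump of fastai's internals.

```python
def even_multiples(start, stop, n):
    """Return n values from start to stop, each a constant factor above the previous one."""
    if n == 1:
        return [stop]
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]

n_groups = 5                                   # assumption: five layer groups
lrs = even_multiples(1e-2 / (2.6 ** 4), 1e-2, n_groups)
for i, lr in enumerate(lrs):
    print(f"group {i}: lr = {lr:.6f}")
```

With five groups, dividing the top rate by 2.6**4 makes each group's learning rate 2.6 times that of the group before it, which is the ratio suggested in the ULMFiT paper.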


In [27]:
learn.unfreeze()
learn.fit_one_cycle(5, slice(1e-3/(2.6**4),1e-3))

epoch,train_loss,valid_loss,accuracy,time
0,1.064698,1.102136,0.775,00:01
1,0.977717,0.814356,0.820833,00:01
2,0.875823,0.701162,0.841667,00:01
3,0.808089,0.656792,0.841667,00:01
4,0.745942,0.653968,0.8375,00:01


In [28]:
def get_top_3_predictions(text, learn):
    # Get prediction and probabilities
    _, _, probs = learn.predict(text)
    
    # Get the disease labels vocabulary (second list in vocab)
    disease_vocab = learn.dls.vocab[1]  # Access the disease labels
    
    # Get number of classes
    n_classes = len(disease_vocab)
    
    # Get indices of top 3 (or fewer) probabilities
    n_preds = min(3, n_classes)
    top_k_indices = probs.argsort(descending=True)[:n_preds]
    
    # Get the actual labels and their probabilities
    predictions = []
    for idx in top_k_indices:
        label = disease_vocab[int(idx)]
        probability = float(probs[idx])
        predictions.append((label, probability))
    
    return predictions

# Function to format and display the predictions nicely
def display_predictions(predictions):
    for i, (disease, prob) in enumerate(predictions, 1):
        print(f"{i}. {disease}: {prob:.3f}")

In [29]:
test_text = "I've been experiencing a severe headache for the last few days.It's worse in the mornings and associated with nausea and vomiting. I feel a bit lightheaded, and my vision is blurry at times."
predictions = get_top_3_predictions(test_text, learn)
print("\nTop 3 Predictions:")
display_predictions(predictions)


Top 3 Predictions:
1. Malaria: 0.170
2. Typhoid: 0.124
3. Hypertension: 0.101


In [30]:
test_text = "I am having a running stomach, fever, general body weakness and have been getting bitten by mosquitoes often. This has been happening for about 2 days"
predictions = get_top_3_predictions(test_text, learn)
print("\nTop 3 Predictions:")
display_predictions(predictions)


Top 3 Predictions:
1. Typhoid: 0.597
2. Hypertension: 0.103
3. Bronchial Asthma: 0.077
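
The top-3 selection does not depend on fastai and can be tried on plain numbers. In this sketch the `top_k` helper and the probabilities are invented for illustration.

```python
def top_k(probabilities, k=3):
    """Return the k (label, probability) pairs with the highest probability."""
    ranked = sorted(probabilities.items(), key=lambda item: item[1], reverse=True)
    return ranked[:k]

# Hypothetical class probabilities for one symptom description
probs = {"Malaria": 0.40, "Typhoid": 0.30, "Dengue": 0.15,
         "Migraine": 0.10, "Psoriasis": 0.05}

best = top_k(probs)
for rank, (label, p) in enumerate(best, 1):
    print(f"{rank}. {label}: {p:.3f}")

# Probability mass left over for all the classes that are not shown
print(f"remaining classes: {1 - sum(p for _, p in best):.3f}")
```

The three reported probabilities need not sum to 1, because the rest of the mass is spread over the other classes. When the top value is low, as with 0.170 for the first test text above, the model is far from certain.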


In [31]:
#| code-fold: true
#| code-summary: "Click to see full code in one cell"
#| error: false
path = Path('/kaggle/input/symptoms-disease-no-id')
#symptom_df = pd.read_csv(path_lm/'symptom_synth.csv',index_col=0)
symptom_df = pd.read_csv(path/'symptom_no_id.csv')
sd_df = pd.read_csv(path/'symptom_disease_no_id_col.csv')
#dls_lm = TextDataLoaders.from_df(symptom_df, path=path, is_lm=True, valid_pct=0.2)
dls_lm = TextDataLoaders.from_df(symptom_df, path=path, is_lm=True,text_col='text', valid_pct=0.2)
learn = language_model_learner(dls_lm, AWD_LSTM, metrics=[accuracy, Perplexity()], path=path, wd=0.1).to_fp16()
learn.fit_one_cycle(1, 1e-2)

# Create a directory to save the model
os.makedirs('/kaggle/working/models', exist_ok=True)
# Set the model directory for the learner
learn.model_dir = '/kaggle/working/models'
# Now save the model
learn.save('1epoch')
learn = learn.load('1epoch')
learn.unfreeze()
learn.fit_one_cycle(5, 1e-3)
learn.save_encoder('finetuned')
dls_clas = TextDataLoaders.from_df(sd_df, path=path,valid='test',text_col='text',label_col='label', text_vocab=dls_lm.vocab)
learn = text_classifier_learner(dls_clas, AWD_LSTM, drop_mult=0.5, metrics=accuracy)
from pathlib import Path
learn.path = Path('/kaggle/working')
learn = learn.load_encoder('finetuned')
learn.fit_one_cycle(1, 2e-2)
learn.freeze_to(-2)
learn.fit_one_cycle(1, slice(1e-2/(2.6**4),1e-2))
learn.unfreeze()
learn.fit_one_cycle(5, slice(1e-3/(2.6**4),1e-3))



epoch,train_loss,valid_loss,accuracy,perplexity,time
0,4.284082,3.573452,0.341725,35.639408,00:01




epoch,train_loss,valid_loss,accuracy,perplexity,time
0,3.637806,2.997802,0.366392,20.041433,00:02
1,3.303594,2.636628,0.410229,13.966033,00:02
2,3.065337,2.459942,0.449002,11.704134,00:02
3,2.883457,2.38313,0.467954,10.838774,00:02
4,2.774912,2.363149,0.471065,10.624359,00:02




epoch,train_loss,valid_loss,accuracy,time
0,2.236272,2.426904,0.4125,00:01


epoch,train_loss,valid_loss,accuracy,time
0,1.454038,1.574972,0.758333,00:01


epoch,train_loss,valid_loss,accuracy,time
0,1.057274,1.063131,0.833333,00:01
1,0.950661,0.735227,0.875,00:01
2,0.853053,0.617303,0.879167,00:01
3,0.781742,0.587054,0.879167,00:01
4,0.730298,0.579409,0.870833,00:01


In [32]:
def get_top_3_predictions(texts, learn):
    """
    Get top 3 predictions for a single text or list of texts
    
    Args:
        texts: Either a single string or a list of strings
        learn: A trained fastai learner for text classification
        
    Returns:
        For a single text: List of (label, probability) tuples
        For multiple texts: List of lists of (label, probability) tuples
    """
    # Handle both single text and list of texts
    is_single = isinstance(texts, str)
    if is_single:
        texts = [texts]
    
    disease_vocab = learn.dls.vocab[1]
    n_preds = min(3, len(disease_vocab))  # Top 3, or fewer if there are fewer classes
    
    def top_k(probs):
        # Pair each of the highest-probability class indices with its label
        top_k_indices = probs.argsort(descending=True)[:n_preds]
        return [(disease_vocab[int(idx)], float(probs[idx])) for idx in top_k_indices]
    
    # Try to use a DataLoader for batch prediction if the model supports it
    try:
        # This is more efficient but might not work with all models
        preds = learn.get_preds(dl=learn.dls.test_dl(texts))
        probs_list = preds[0]  # Tensor of shape [batch_size, n_classes]
        all_predictions = [top_k(probs) for probs in probs_list]
    except Exception:
        # Fall back to individual prediction if the batch method fails
        all_predictions = [top_k(learn.predict(text)[2]) for text in texts]
    
    return all_predictions[0] if is_single else all_predictions


def display_predictions(predictions, texts=None):
    """
    Display formatted predictions
    
    Args:
        predictions: Either a list of (label, prob) tuples or a list of such lists
        texts: Optional list of input texts to display with predictions
    """
    # If predictions is a list of (label, prob) tuples (single text case)
    if isinstance(predictions[0], tuple):
        for i, (disease, prob) in enumerate(predictions, 1):
            print(f"{i}. {disease}: {prob:.3f}")
    # If predictions is a list of lists (batch case)
    else:
        for i, preds in enumerate(predictions):
            if texts:
                print(f"\nText: {texts[i][:50]}...")
            else:
                print(f"\nSample #{i+1}:")
            for j, (disease, prob) in enumerate(preds, 1):
                print(f"  {j}. {disease}: {prob:.3f}")


In [33]:
# Assuming 'learn' is your trained FastAI model

# Example 1: Single input
single_text = "Patient presents with persistent cough, fever of 101°F for 5 days, and fatigue."
single_result = get_top_3_predictions(single_text, learn)

print("SINGLE TEXT PREDICTION:")
print(f"Input: {single_text}")
print("Top 3 predictions:")
display_predictions(single_result)


# Example 2: Batch input (small batch)
batch_texts = [
    "Patient presents with persistent cough, fever of 101°F for 5 days, and fatigue.",
    "7-year-old with red, itchy rash on face and arms, started 2 days after camping trip.",
    "Adult male with sudden onset of severe headache, described as 'worst headache of my life'.",
    "Patient reports joint pain in fingers and wrists, worse in the morning, accompanied by stiffness."
]
batch_results = get_top_3_predictions(batch_texts, learn)

print("\nBATCH PREDICTION EXAMPLE:")
display_predictions(batch_results, batch_texts)


# Example 3: Processing a medium-sized dataset
medium_dataset = [
    f"Patient {i}: Symptoms include {symptom}" for i, symptom in enumerate([
        "fever and sore throat",
        "chest pain radiating to left arm",
        "swollen lymph nodes and night sweats",
        "difficulty breathing and wheezing",
        "abdominal pain and vomiting",
        "frequent urination and excessive thirst",
        "joint pain and morning stiffness",
        "persistent headache and blurred vision",
        "unexplained weight loss and fatigue",
        "skin rash and itching"
    ] * 3)  # Repeat symptoms to create 30 samples
]

print("\nPROCESSING MEDIUM DATASET:")
medium_results = get_top_3_predictions(medium_dataset, learn)
# Display first 3 results only for brevity
print("First 3 results from medium dataset:")
display_predictions(medium_results[:3], medium_dataset[:3])


# Example 4: Working with DataFrame data
# This example demonstrates how you might use the function with pandas DataFrame
import pandas as pd

# Create a sample DataFrame
df = pd.DataFrame({
    'patient_id': range(1001, 1006),
    'age': [45, 12, 67, 32, 54],
    'gender': ['M', 'F', 'M', 'F', 'M'],
    'symptoms': [
        "Persistent dry cough and fever for 3 days",
        "Skin rash with small fluid-filled blisters, mild fever",
        "Shortness of breath, chest tightness, wheezing when exercising",
        "Severe migraine, sensitivity to light, nausea",
        "Pain and swelling in the right knee, difficulty walking"
    ]
})

print("\nPROCESSING DATAFRAME:")
print("Sample DataFrame:")
print(df[['patient_id', 'symptoms']].head())

# Process the symptoms column
df_results = get_top_3_predictions(df['symptoms'].tolist(), learn)

# Add predictions back to the DataFrame
df['top_prediction'] = [pred[0][0] for pred in df_results]  # First prediction label
df['confidence'] = [pred[0][1] for pred in df_results]      # First prediction probability

print("\nDataFrame with predictions:")
print(df[['patient_id', 'symptoms', 'top_prediction', 'confidence']])

SINGLE TEXT PREDICTION:
Input: Patient presents with persistent cough, fever of 101°F for 5 days, and fatigue.
Top 3 predictions:
1. Migraine: 0.155
2. diabetes: 0.118
3. Dengue: 0.108



BATCH PREDICTION EXAMPLE:

Text: Patient presents with persistent cough, fever of 1...
  1. Migraine: 0.155
  2. diabetes: 0.118
  3. Dengue: 0.108

Text: 7-year-old with red, itchy rash on face and arms, ...
  1. Impetigo: 0.597
  2. Psoriasis: 0.218
  3. Fungal infection: 0.050

Text: Adult male with sudden onset of severe headache, d...
  1. Dengue: 0.338
  2. Psoriasis: 0.103
  3. Migraine: 0.089

Text: Patient reports joint pain in fingers and wrists, ...
  1. Psoriasis: 0.732
  2. Dengue: 0.184
  3. Fungal infection: 0.020

PROCESSING MEDIUM DATASET:


First 3 results from medium dataset:

Text: Patient 0: Symptoms include fever and sore throat...
  1. Jaundice: 0.233
  2. gastroesophageal reflux disease: 0.180
  3. allergy: 0.102

Text: Patient 1: Symptoms include chest pain radiating t...
  1. Jaundice: 0.348
  2. gastroesophageal reflux disease: 0.211
  3. peptic ulcer disease: 0.073

Text: Patient 2: Symptoms include swollen lymph nodes an...
  1. Impetigo: 0.360
  2. allergy: 0.156
  3. Pneumonia: 0.100

PROCESSING DATAFRAME:
Sample DataFrame:
   patient_id                                                        symptoms
0        1001                       Persistent dry cough and fever for 3 days
1        1002          Skin rash with small fluid-filled blisters, mild fever
2        1003  Shortness of breath, chest tightness, wheezing when exercising
3        1004                   Severe migraine, sensitivity to light, nausea
4        1005         Pain and swelling in the right knee, difficulty walking



DataFrame with predictions:
   patient_id                                                        symptoms  \
0        1001                       Persistent dry cough and fever for 3 days   
1        1002          Skin rash with small fluid-filled blisters, mild fever   
2        1003  Shortness of breath, chest tightness, wheezing when exercising   
3        1004                   Severe migraine, sensitivity to light, nausea   
4        1005         Pain and swelling in the right knee, difficulty walking   

  top_prediction  confidence  
0       Jaundice    0.187939  
1       Jaundice    0.247912  
2        allergy    0.447899  
3       diabetes    0.310563  
4      Psoriasis    0.346180  
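
The DataFrame example above keeps only the first prediction for each row. A minimal sketch of how all three predictions could be kept as columns is shown below. It assumes the results have the shape implied by the indexing `pred[0][0]` and `pred[0][1]` used above, i.e. one list of `(label, probability)` pairs per input text. The helper name `expand_top_k`, the `pred_*`/`prob_*` column names, the sample values, and the `REVIEW_THRESHOLD` cut-off are all hypothetical and only serve the illustration.

```python
import pandas as pd


def expand_top_k(results, k=3):
    """Flatten per-sample [(label, prob), ...] lists into one row of columns each.

    Hypothetical helper: assumes each element of `results` is a list of
    (label, probability) pairs sorted from most to least likely.
    """
    rows = []
    for preds in results:
        row = {}
        for rank, (label, prob) in enumerate(preds[:k], start=1):
            row[f"pred_{rank}"] = label
            row[f"prob_{rank}"] = float(prob)
        rows.append(row)
    return pd.DataFrame(rows)


# Illustrative stand-in for the output of get_top_3_predictions (not real model output)
sample_results = [
    [("Impetigo", 0.597), ("Psoriasis", 0.218), ("Fungal infection", 0.050)],
    [("Psoriasis", 0.732), ("Dengue", 0.184), ("Fungal infection", 0.020)],
]

top_k_df = expand_top_k(sample_results)

# Flag rows whose best guess falls below an arbitrary, hypothetical threshold
REVIEW_THRESHOLD = 0.6
top_k_df["needs_review"] = top_k_df["prob_1"] < REVIEW_THRESHOLD

print(top_k_df)
```

With the notebook's own objects, the same helper could be attached to the patient table with `pd.concat([df.reset_index(drop=True), expand_top_k(df_results)], axis=1)`, provided `df_results` really has the assumed structure.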


## References

[Fastai Documentation - Text Transfer Learning](https://docs.fast.ai/tutorial.text.html#the-ulmfit-approach)

The dataset used in this notebook was obtained from the [Symptom2Disease dataset on Kaggle](https://www.kaggle.com/datasets/niyarrbarman/symptom2disease).

## Next Steps

- Use clinical guidelines as a source for the medical corpus.
- Implement a newer architecture, e.g., replace AWD_LSTM with a transformer.
- Try a retrieval-augmented generation (RAG) implementation.
- Fine-tune our own medical model.
- Add reasoning.
