In [65]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
import numpy as np
import time

#use GPU on mac if enabled
if torch.backends.mps.is_available():
    print(f"GPU is available.")
else:
    print("No GPU available. Training will run on CPU.")

GPU is available.


In [66]:
labels = [
    "drug reaction",
    "allergy",
    "chicken pox",
    "diabetes",
    "psoriasis",
    "hypertension",
    "cervical spondylosis",
    "bronchial asthma",
    "varicose veins",
    "malaria",
    "dengue",
    "arthritis",
    "impetigo",
    "fungal infection",
    "common cold",
    "gastroesophageal reflux disease",
    "urinary tract infection",
    "typhoid",
    "pneumonia",
    "peptic ulcer disease",
    "jaundice",
    "migraine"
]

In [67]:
# Custom torch dataset
class CustomDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(labels.index(label), dtype=torch.long)
        }

In [68]:
def freeze_bert_layers(model):
    """
    Freezes all bert layers apart from the output layer
    """
    
    for param in model.bert.embeddings.parameters():
        param.requires_grad = False

    for layer in model.bert.encoder.layer[:-1]:
        for param in layer.parameters():
            param.requires_grad = False

    #show information about trainable parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Percentage of trainable parameters: {100 * trainable_params / total_params:.2f}%")

In [76]:

def train_model(model, train_loader, val_loader, device, num_epochs=3):
    #initialise optimiser only with parameters that require gradients
    optimizer = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=1e-5
    )

    total_steps = len(train_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )
    
    for epoch in range(num_epochs):
        start_time = time.time()
        model.train()
        total_train_loss = 0
        right_predictions = 0
        for batch in train_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
        
            model.zero_grad()
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            #calculate accuracy
            preds = torch.argmax(outputs.logits, dim=1)
            right_predictions += torch.sum(preds == labels).item()

            loss = outputs.loss
            total_train_loss += loss.item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

        # validation
        model.eval()
        total_val_loss = 0
        predictions = []
        true_labels = []

        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

                loss = outputs.loss
                total_val_loss += loss.item()

                preds = torch.argmax(outputs.logits, dim=1)
                predictions.extend(preds.cpu().numpy())
                true_labels.extend(labels.cpu().numpy())

        avg_train_loss = total_train_loss / len(train_loader)
        avg_val_loss = total_val_loss / len(val_loader)
        accuracy = np.mean(np.array(predictions) == np.array(true_labels))
        end_time = time.time()

        print(f'Epoch {epoch + 1}:')
        print(f'Average training loss: {avg_train_loss:.4f}')
        print(f'Average validation loss: {avg_val_loss:.4f}')
        print(f'Right predictions: {right_predictions} out of {len(train_loader) * 32}')
        print(f'Validation Accuracy: {accuracy:.4f}')
        print(f'Time taken for epoch: {end_time - start_time:.2f} seconds')
        print('-' * 60)

In [50]:
import tempfile, onnx

def convert_pytorch_to_onnx_with_tokenizer(model, tokenizer, max_length=128, onnx_file_path=None):
    """
    Converts a PyTorch model to ONNX format, using tokenizer output as input.

    Args:
    model (torch.nn.Module): The PyTorch model to be converted.
    tokenizer: The tokenizer used to preprocess the input.
    onnx_file_path (str): The file path where the ONNX model will be saved.
    max_length (int): Maximum sequence length for the tokenizer.

    Returns:
    None
    """
    model.eval()

    # Prepare dummy input using the tokenizer
    dummy_input = "This is a sample input text for ONNX conversion."
    inputs = tokenizer(
        dummy_input,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # # Get the input names
    input_names = list(inputs.keys())
    input_names = ["input_ids", "attention_mask"]
    print(f"Input names: {input_names}")

    # # Create dummy inputs for ONNX export
    # dummy_inputs = tuple(encoded_input[name] for name in input_names)
    if onnx_file_path is None:
      onnx_file_path = tempfile.mktemp(suffix=".onnx")
    dynamic_axes = {name: {0: "batch_size"} for name in input_names}
    dynamic_axes.update({f"logits": {0: "batch_size"}})
    print(f"dynamic_axes: {dynamic_axes}")
    # Export the model
    torch.onnx.export(
        model,  # model being run
        tuple(inputs[k] for k in input_names),  # model inputs
        onnx_file_path,  # where to save the model
        export_params=True,  # store the trained parameter weights inside the model file
        opset_version=20,  # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=input_names,  # the model's input names
        output_names=["logits"],  # the model's output names
        dynamic_axes=dynamic_axes,
    )  # variable length axes

    print(f"Model exported to {onnx_file_path}")

    # Verify the exported model
    onnx_model = onnx.load(onnx_file_path)
    onnx.checker.check_model(onnx_model)
    print("ONNX model is valid.")
    return onnx_file_path, input_names


In [41]:
# load tiny bert model and tokenizer
model_name = "prajjwal1/bert-tiny"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(labels)
)
freeze_bert_layers(model)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total parameters: 4,388,758
Trainable parameters: 217,622
Percentage of trainable parameters: 4.96%


In [77]:
from datasets import load_dataset


def main(model, tokenizer):
    # load symptom to diagnosis dataset
    dataset = load_dataset("gretelai/symptom_to_diagnosis")

    # Prepare train and validation datasets
    train_texts = dataset['train']['input_text']
    train_labels = dataset['train']['output_text']
    val_texts = dataset['test']['input_text']
    val_labels = dataset['test']['output_text']

    # Create datasets
    train_dataset = CustomDataset(train_texts, train_labels, tokenizer)
    val_dataset = CustomDataset(val_texts, val_labels, tokenizer)

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset, batch_size=32, shuffle=True, num_workers=0
    )
    val_loader = DataLoader(
        val_dataset, batch_size=32, num_workers=0
    )

    # Set device and move data to GPU if available
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model.to(device)

    # Train the model
    train_model(model, train_loader, val_loader, device, num_epochs=50)

    # Make model tensors contiguous and move to CPU before saving
    model = model.cpu()

    # Save the fine-tuned model as an ONNX file
    onnx_file_path, input_names = convert_pytorch_to_onnx_with_tokenizer(
        model, tokenizer, max_length=128, onnx_file_path="./saved-models/diagnosis-classifier.onnx"
    )
    print(f"ONNX file path: {onnx_file_path}")
    print(f"Input names: {input_names}")

    # Test the model on a few examples
    model.eval()
    test_texts = [
        "My head hurts and I have a high temperature.",
        "I have a red rash and shortness of breath"
    ]

    with torch.no_grad():
        inputs = tokenizer(
            test_texts,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt"
        ).to(device)
        model.to(device)
        outputs = model(**inputs)
        predictions = torch.nn.functional.softmax(outputs.logits, dim=1)

        for text, pred in zip(test_texts, predictions):
            print(pred.shape)
            print(f"Text: {text}")
            for i in range(0, len(pred)):
                print(f'{labels[i]}: {pred[i]}')
            print("======")

main(model, tokenizer)



Epoch 1:
Average training loss: 1.6183
Average validation loss: 1.5438
Right predictions: 592 out of 864
Validation Accuracy: 0.6557
Time taken for epoch: 1.64 seconds
------------------------------------------------------------
Epoch 2:
Average training loss: 1.5952
Average validation loss: 1.5378
Right predictions: 592 out of 864
Validation Accuracy: 0.6651
Time taken for epoch: 1.65 seconds
------------------------------------------------------------
Epoch 3:
Average training loss: 1.6063
Average validation loss: 1.5303
Right predictions: 594 out of 864
Validation Accuracy: 0.6745
Time taken for epoch: 1.68 seconds
------------------------------------------------------------
Epoch 4:
Average training loss: 1.5951
Average validation loss: 1.5249
Right predictions: 600 out of 864
Validation Accuracy: 0.6651
Time taken for epoch: 1.72 seconds
------------------------------------------------------------
Epoch 5:
Average training loss: 1.5971
Average validation loss: 1.5211
Right predict

In [60]:
#upload trained model to the nillion blockchain
import aivm_client as aic

MODEL_NAME = "DIAGNOSIS_CLASSIFIER" 
aic.upload_bert_tiny_model("./saved-models/diagnosis-classifier.onnx", MODEL_NAME) 

In [64]:
#perform secure inference
inp_text = "I have a sore throat and the back of my head hurts"
tokens = aic.tokenize(inp_text,)
encrypted_tokens = aic.BertTinyCryptensor(*tokens)
result = aic.get_prediction(encrypted_tokens, MODEL_NAME)
probs = torch.nn.functional.softmax(result[0])
probs

  probs = torch.nn.functional.softmax(result[0])


tensor([0.0389, 0.0536, 0.0508, 0.0425, 0.0582, 0.0366, 0.0443, 0.0407, 0.0429,
        0.0461, 0.0432, 0.0422, 0.0450, 0.0468, 0.0459, 0.0599, 0.0433, 0.0451,
        0.0467, 0.0389, 0.0497, 0.0387])