# Preamble

## Drive integration

In [None]:
# Colab-only helper used to attach Google Drive to this runtime.
from google.colab import drive
# Mount Drive at /content/drive; the CSV and model paths used later in this
# notebook all live under this mount point.
drive.mount('/content/drive')

## GPU

In [None]:
import torch
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Free GPU memory

In [None]:
import gc
def free_gpu_memory():
    """Run a garbage-collection pass, then empty PyTorch's CUDA cache."""
    # Collect first so that tensors which just became unreachable are
    # released before the cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

## Imports

In [None]:
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd

# Classifier Main

In [None]:
# Query dataset stored on the mounted Drive; its 'text' and 'label' columns
# are read by later cells.
queries_csv = '/content/drive/MyDrive/data/csv/queries.csv'
q1 = pd.read_csv(queries_csv)

In [None]:
# Preview the first five rows as a quick sanity check of the loaded data.
q1.head(5)

In [None]:
# Class name -> integer id used as the training target.
label_mapping = {'quantitative analysis': 0, 'general information': 1, 'miscellaneous': 2}

# Series.map yields NaN for every value missing from the mapping (a typo, a
# stray space, or this cell being re-run on labels that are already encoded).
# Fail loudly here, before q1 is modified, instead of letting NaN labels reach
# the train/validation split.
encoded_labels = q1['label'].map(label_mapping)
unmapped_mask = encoded_labels.isna()
if unmapped_mask.any():
    unknown_values = q1.loc[unmapped_mask, 'label'].unique().tolist()
    raise ValueError(
        f"q1['label'] contains values not present in label_mapping: {unknown_values}"
    )
q1['label'] = encoded_labels.astype(int)

In [None]:
class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text per item for classification.

    Args:
        texts: Iterable of raw texts; each item is converted with ``str()``.
        labels: Iterable of integer class ids, aligned one-to-one with ``texts``.
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer) that
            accepts the keyword arguments used in ``__getitem__`` and returns a
            mapping with ``input_ids`` and ``attention_mask`` tensors.
        max_len: Length each example is truncated or padded to.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, tokenizer, max_len=256):
        # Materialise as lists so that positional indexing in __getitem__ works
        # for any iterable, including a pandas Series whose index is not 0..n-1.
        self.texts = list(texts)
        self.labels = list(labels)
        if len(self.texts) != len(self.labels):
            raise ValueError(
                f"texts and labels must have the same length "
                f"(got {len(self.texts)} and {len(self.labels)})"
            )
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the ``input_ids``, ``attention_mask`` and ``labels`` tensors for item ``idx``."""
        text = str(self.texts[idx])
        label = self.labels[idx]
        # Call the tokenizer directly; `encode_plus` is the deprecated spelling
        # of the same single-text encoding.
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )
        return {
            # The tokenizer output carries a leading batch axis; flatten it away.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long),
        }

In [None]:
# Load the pre-trained BioBERT tokenizer and a classification model whose
# output size and label names come from label_mapping, so the class count is
# defined in one place and the saved config records the label names.
biobert_checkpoint = 'dmis-lab/biobert-v1.1'
tokenizer = BertTokenizer.from_pretrained(biobert_checkpoint)
model = BertForSequenceClassification.from_pretrained(
    biobert_checkpoint,
    num_labels=len(label_mapping),
    id2label={idx: name for name, idx in label_mapping.items()},
    label2id=label_mapping,
)
model.to(device)

In [None]:
# Prepare data
texts = q1['text'].tolist()
labels = q1['label'].tolist()
# Hold out 10% for validation, stratified by label. The fixed seed makes the
# split (and therefore the eval loss used to pick the best checkpoint)
# reproducible across notebook re-runs.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts,
    labels,
    test_size=0.1,
    stratify=labels,
    random_state=42,
)

In [None]:
# Wrap each split in a TextDataset (default max_len) sharing the same tokenizer.
train_dataset = TextDataset(
    texts=train_texts,
    labels=train_labels,
    tokenizer=tokenizer,
)
val_dataset = TextDataset(
    texts=val_texts,
    labels=val_labels,
    tokenizer=tokenizer,
)

In [None]:
# Define training arguments
training_args = TrainingArguments(
    # Output locations and logging
    output_dir='./results',
    logging_dir='./logs',
    logging_steps=10,
    report_to="none",  # Disable logging to external services (e.g. WandB)
    # Schedule and batch sizes
    num_train_epochs=10,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # Optimizer settings
    learning_rate=2e-5,
    weight_decay=0.01,
    adam_epsilon=1e-8,
    warmup_steps=0,
    # Evaluate and checkpoint once per epoch, then reload the checkpoint with
    # the lowest eval loss when training finishes.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

In [None]:
# Define Trainer: ties together the model, the training arguments and both
# dataset splits.
# NOTE(review): recent transformers releases deprecate `tokenizer=` in favour
# of `processing_class=` — confirm against the installed version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)

In [None]:
# Use the notebook's own helper so unreachable tensors are garbage-collected
# before the CUDA cache is emptied, instead of only emptying the cache.
free_gpu_memory()

In [None]:
# Train the model using the Trainer configured above (datasets, batch sizes,
# epochs and checkpointing all come from `trainer` / `training_args`).
trainer.train()

In [None]:
# Evaluate the model on the Trainer's eval dataset (val_dataset) and print the
# returned metrics.
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")

In [None]:
# Make sure the model sits on the selected device before running inference.
model.to(device)

In [None]:
# Switch the model to evaluation mode before inference.
# NOTE(review): the message below says "loaded", but this cell only changes the
# mode of the in-memory model — confirm the wording is intended.
model.eval()
print("Model loaded successfully.")

In [None]:
# Inference function
def classify(model, tokenizer, text, label_dict, max_len=256):
    """Predict the class name for a single text.

    Args:
        model: Sequence-classification model whose output exposes ``.logits``.
            The caller is expected to have put it in evaluation mode.
        tokenizer: Callable returning ``input_ids`` and ``attention_mask``
            tensors for ``text``.
        text: One input string.
        label_dict: Maps a predicted class index to its class name.
        max_len: Token limit; longer inputs are truncated.

    Returns:
        The ``label_dict`` entry for the highest-scoring class.

    Raises:
        KeyError: If the predicted index is missing from ``label_dict``.
    """
    encoding = tokenizer(
        text,
        return_tensors='pt',
        max_length=max_len,
        truncation=True,
        padding=True,
    )

    # Follow the model's own device rather than a notebook-level global, so the
    # function works wherever the model was placed (and for reloaded models).
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')
    input_ids = encoding['input_ids'].to(target_device)
    attention_mask = encoding['attention_mask'].to(target_device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    # `.item()` assumes a batch of exactly one text.
    prediction_idx = torch.argmax(outputs.logits, dim=1).item()

    return label_dict[prediction_idx]

In [None]:
# Inference
# Class id -> class name, derived from label_mapping so the two tables cannot
# drift apart if a class is added or renumbered.
label_dict = {idx: name for name, idx in label_mapping.items()}

In [None]:
# Hand-written sample queries used to spot-check the classifier in the
# inference loop that follows.
new_texts = [
    "How much is 5 multiplied by 10?",
    "What is the capital of France?",
    "What is the square root of 16?",
    "How many planets are there in the solar system?",
    "What is my risk for cardiovascular disease if my blood pressure goes up to 180?",
    "What causes the buildup of plaque in the arteries?",
    "What are the main causes of atherosclerosis?",
    "Is there a correlation between developing diabetes and the risk of cardiovascular disease?",
    "Will developing diabetes affect my risk of developing cardiovascular disease?",
    "Can I get tickets to the 9:00 showing of Cats?",
    "What will happen to my risk of cardiovascular disease if my blood pressure increases by 50%?"
]

In [None]:
# Classify every sample query and print it next to its predicted class.
for text in new_texts:
    predicted_label = classify(model, tokenizer, text, label_dict)
    print(f"{text} --> {predicted_label}")

## Save model

In [None]:
# Save the tokenizer next to the model weights so the directory is
# self-contained and can be reloaded without fetching the original checkpoint.
classifier_dir = '/content/drive/MyDrive/classifiers/v1'
tokenizer.save_pretrained(classifier_dir)
model.save_pretrained(classifier_dir)

In [None]:
# Reload the saved classifier from Drive, move it to the selected device and
# put it in evaluation mode.
classifier_v1 = BertForSequenceClassification.from_pretrained(
    '/content/drive/MyDrive/classifiers/v1'
).to(device)
classifier_v1.eval()
print("Classifier loaded successfully.")