# Fine-tuning pretrained BERT models to predict reasons for cannabis use from lupus patients' EHR notes

Based on Nathan Le's code, extended to compare several fine-tuning strategies.

*yiyu wang 2025/02*

In [None]:
import transformers
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
import torch

import numpy as np
import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from collections import defaultdict
from textwrap import wrap

from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from tqdm import tqdm

%matplotlib inline
%config InlineBackend.figure_format='retina'

# Global seaborn theme applied to every figure drawn after this point.
sns.set(style='whitegrid', palette='muted', font_scale=1.2)

# Custom six-colour palette, installed as the seaborn default just below.
HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]

sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))

# Default matplotlib figure size (width, height).
rcParams['figure.figsize'] = 12, 8

# Seed the NumPy and PyTorch random number generators with one shared value.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device




In [None]:
# Directory that holds the labelled CSV batches and receives saved checkpoints.
path = '/Users/yiyuwang/Projects/CannabisUse/'

# The active-learning batches share one schema ('Sentiment' label column and
# 'text' snippet column), so they are loaded in a loop instead of three
# copy-pasted stanzas.
ACTIVE_LEARNING_FILES = [
    'use_case_active_learning_1.csv',
    'use_case_active_learning_2.csv',
    'use_case_active_learning_3.csv',
]

batches = []
for file_name in ACTIVE_LEARNING_FILES:
    df_new = pd.read_csv(path + file_name)
    df_new['sentiment'] = df_new['Sentiment'].astype(int)
    df_new = df_new.rename(columns={'text': 'Snippets'})
    batches.append(df_new[['sentiment', 'Snippets']])

# Concatenating the batches directly (rather than onto an empty DataFrame)
# keeps 'sentiment' as an integer column instead of degrading it to object.
df_train = pd.concat(batches, ignore_index=True)

df_train, df_test = train_test_split(df_train, test_size=0.2, random_state=50)
# Work on independent copies so the label shift below never writes into a
# view of the pre-split frame (avoids SettingWithCopyWarning).
df_train = df_train.copy()
df_test = df_test.copy()

# convert sentiment from 1 - 7 to 0 - 6
df_train['sentiment'] = df_train['sentiment'] - 1
df_test['sentiment'] = df_test['sentiment'] - 1

# Test accuracy of each fine-tuning strategy, appended section by section.
best_acc = []

In [None]:
# Preview the first rows of the training split.
df_train.head()

In [None]:
# Preview the first rows of the held-out test split.
df_test.head()


In [None]:
# Print the column summary of both splits.
df_train.info(),df_test.info()


In [None]:
# Distribution of the seven classes in the training split.
class_label_names = ['pain', 'nausea', 'sleep', 'anxiety/stress/relaxation', 'unknown', 'not use', 'appetite']
# Reindex onto 0..6 so every class gets a bar (count 0 when absent) and each
# bar is guaranteed to sit under its own label, whatever order value_counts
# returns and whichever classes happen to be present.
class_counts = (
    df_train.sentiment.astype(int)
    .value_counts()
    .reindex(range(len(class_label_names)), fill_value=0)
)
sns.barplot(x=class_label_names, y=class_counts.to_numpy())
plt.xticks(rotation=45)
plt.xlabel('sentiment')
plt.ylabel('count')

In [None]:
# (rows, columns) of the training and test splits.
df_train.shape,df_test.shape

In [None]:
from transformers import AutoTokenizer, AutoModel
# Checkpoint used for the tokenizer here and for the backbones built later;
# the commented-out name is an alternative checkpoint left for reference.
# PRE_TRAINED_MODEL_NAME="roberta-large-mnli"
PRE_TRAINED_MODEL_NAME="emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)

In [None]:
# Separator token and its vocabulary id.
tokenizer.sep_token, tokenizer.sep_token_id


In [None]:
# Classification token and its vocabulary id.
tokenizer.cls_token, tokenizer.cls_token_id


In [None]:
# Padding token and its vocabulary id.
tokenizer.pad_token, tokenizer.pad_token_id


In [None]:
# Unknown-word token and its vocabulary id.
tokenizer.unk_token, tokenizer.unk_token_id


In [None]:
# Token count of every training snippet, used to choose MAX_LEN.
# `truncation=True` is required for `max_length` to actually cap the sequence
# at 512 tokens; `str()` mirrors CannabisClassData so a missing (NaN) snippet
# does not crash the tokenizer.
token_lens = [
    len(tokenizer.encode(str(txt), max_length=512, truncation=True))
    for txt in df_train.Snippets
]

In [None]:
# Histogram of snippet lengths in tokens.
sns.displot(token_lens)


In [None]:
class CannabisClassData(Dataset):
  """Map-style dataset that tokenizes one labelled snippet per item.

  Args:
    text: indexable sequence of snippets; each entry is coerced with str().
    label: indexable sequence of integer class ids, aligned with `text`.
    tokenizer: object exposing a Hugging Face style `encode_plus`.
    max_len: every sequence is truncated/padded to exactly this many tokens.
  """

  def __init__(self, text, label, tokenizer, max_len):
    self.text=text
    self.label=label
    self.tokenizer=tokenizer
    self.max_len=max_len

  def __len__(self):
    return len(self.text)

  def __getitem__(self,item):
    """Return the raw text, 1-D `input_ids`/`attention_mask` and the label tensor."""
    text= str(self.text[item])
    label=self.label[item]
    encoding=self.tokenizer.encode_plus(
      text,
      max_length=self.max_len,
      add_special_tokens=True,
      # `padding='max_length'` replaces the deprecated `pad_to_max_length=True`
      # and matches what UnlabeledDataset passes at inference time.
      padding='max_length',
      truncation=True,
      return_attention_mask=True,
      return_token_type_ids=False,
      return_tensors='pt'
    )
    return{
        'text':text,
        # return_tensors='pt' yields a leading batch axis of size 1; flatten
        # it away so the DataLoader can stack items into (batch, max_len).
        'input_ids': encoding['input_ids'].flatten(),
        'attention_mask': encoding['attention_mask'].flatten(),
        'label':torch.tensor(label,dtype=torch.long)
    }

class MLMDataset(Dataset):
    """Dataset that tokenizes all snippets once, up front, for MLM adaptation.

    Args:
        texts: iterable of snippets; each entry is coerced with str() so a
            missing (NaN) value cannot break tokenization, matching
            CannabisClassData.
        label: indexable sequence of class ids aligned with `texts`. It is
            returned under 'labels'.
            NOTE(review): an MLM data collator is expected to overwrite
            'labels' with the masked-token targets - confirm against the
            collator in use.
        tokenizer: callable tokenizer accepting a list of strings plus
            `truncation`, `padding` and `max_length`.
        max_len: every sequence is truncated/padded to this many tokens.
    """

    def __init__(self, texts, label, tokenizer, max_len):
        self.label = label
        self.encodings = tokenizer(
            [str(t) for t in texts],
            truncation=True,
            padding='max_length',
            max_length=max_len,
        )

    def __getitem__(self, idx):
        """Return `input_ids`, `attention_mask` and `labels` tensors for row `idx`."""
        label = self.label[idx]
        encoding = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long),
        }

    def __len__(self):
        return len(self.encodings['input_ids'])
    


In [None]:
def create_data_loader(df, tokenizer, max_len, batch_size, collate_fn=None, shuffle=False):
  """Wrap a labelled DataFrame in a CannabisClassData DataLoader.

  Args:
    df: DataFrame with a `Snippets` text column and a `sentiment` label column.
    tokenizer: tokenizer handed to CannabisClassData.
    max_len: fixed token length of every encoded snippet.
    batch_size: number of snippets per batch.
    collate_fn: optional custom collate function (DataLoader default if None).
    shuffle: reshuffle the rows every epoch. Defaults to False, which keeps
      the previous behaviour; pass True for training loaders so batches
      differ between epochs.

  Returns:
    A single-process DataLoader over the dataset.
  """
  ds = CannabisClassData(
    text=df.Snippets.to_numpy(),
    label=df.sentiment.to_numpy(),
    tokenizer=tokenizer,
    max_len=max_len
  )
  return DataLoader(
    ds,
    batch_size=batch_size,
    shuffle=shuffle,
    num_workers=0,
    collate_fn=collate_fn
  )
     

In [None]:
class CannabisClassifier(nn.Module):
  """Backbone loaded from PRE_TRAINED_MODEL_NAME with a dropout + linear head.

  `forward` returns log-probabilities of shape (batch, n_classes).
  """

  def __init__(self, n_classes):
    super().__init__()
    encoder = AutoModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
    self.bert = encoder
    self.drop = nn.Dropout(p=0.3)
    self.out = nn.Linear(encoder.config.hidden_size, n_classes)
    # Kept as an attribute although forward() does not call it.
    self.softmax = nn.Softmax(dim=1)

  def forward(self, input_ids, attention_mask):
    # With return_dict=False the backbone yields a tuple; the second element
    # is the pooled output that feeds the head.
    _sequence_output, pooled = self.bert(
      input_ids=input_ids,
      attention_mask=attention_mask,
      return_dict=False,
    )
    logits = self.out(self.drop(pooled))
    return F.log_softmax(logits, dim=1)

In [None]:
class CannabisClassifierFrozenBackbone(nn.Module):
    """Classifier whose pretrained backbone is frozen; only the linear head trains.

    `forward` returns log-probabilities of shape (batch, n_classes).
    """

    def __init__(self, PRE_TRAINED_MODEL_NAME, n_classes):
        super().__init__()
        self.bert = AutoModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
        # Turn off gradients for every backbone parameter.
        self.bert.requires_grad_(False)
        self.drop = nn.Dropout(p=0.1)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)
        # Kept as an attribute although forward() does not call it.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input_ids, attention_mask):
        # With return_dict=False the backbone yields a tuple; the second
        # element is the pooled output that feeds the head.
        _sequence_output, pooled = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        logits = self.out(self.drop(pooled))
        return F.log_softmax(logits, dim=1)

In [None]:
def train_epoch(
    model,
    data_loader,
    loss_fn,
    optimizer,
    device,
    scheduler,
    n_examples
):
    """Run one optimisation pass over `data_loader`.

    Args:
        model: module called as `model(input_ids=..., attention_mask=...)`
            and returning per-class scores of shape (batch, n_classes).
        data_loader: iterable of batches with 'input_ids', 'attention_mask'
            and 'label' tensors.
        loss_fn: callable `(outputs, label) -> scalar loss tensor`.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to.
        scheduler: learning-rate scheduler stepped once per batch.
        n_examples: denominator used for the accuracy.

    Returns:
        `(accuracy, mean_loss)`: accuracy is a 0-dim double tensor, mean_loss
        the average batch loss (NaN when the loader yields no batches).
    """
    model = model.train()

    losses = []
    # A tensor accumulator (instead of the int 0) keeps `.double()` valid
    # even when the loader is empty.
    correct_predictions = torch.zeros((), dtype=torch.long)

    for d in data_loader:
        input_ids = d['input_ids'].to(device)
        attention_mask = d['attention_mask'].to(device)
        label = d['label'].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        loss = loss_fn(outputs, label)

        predicted = torch.argmax(outputs, dim=1)
        correct_predictions += torch.sum(torch.eq(predicted, label)).cpu()
        losses.append(loss.item())

        loss.backward()
        # Clip the global gradient norm to 1.0 before the update.
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    mean_loss = np.mean(losses) if losses else float('nan')
    return correct_predictions.double() / n_examples, mean_loss

In [None]:
def eval_model(model, data_loader, loss_fn, device, n_examples):
  """Evaluate `model` on `data_loader` without updating any weights.

  Args:
    model: module called as `model(input_ids=..., attention_mask=...)` and
      returning per-class scores of shape (batch, n_classes).
    data_loader: iterable of batches with 'input_ids', 'attention_mask' and
      'label' tensors.
    loss_fn: callable `(outputs, label) -> scalar loss tensor`.
    device: device the batch tensors are moved to.
    n_examples: denominator used for the accuracy.

  Returns:
    `(accuracy, mean_loss)`: accuracy is a 0-dim double tensor, mean_loss the
    average batch loss (NaN when the loader yields no batches).
  """
  model = model.eval()

  losses = []
  # A tensor accumulator (instead of the int 0) keeps `.double()` valid even
  # when the loader is empty.
  correct_predictions = torch.zeros((), dtype=torch.long)

  with torch.no_grad():
    for d in data_loader:
      input_ids = d["input_ids"].to(device)
      attention_mask = d["attention_mask"].to(device)
      label = d["label"].to(device)

      outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask
      )
      loss = loss_fn(outputs, label)

      predicted = torch.argmax(outputs, dim=1)
      correct_predictions += torch.sum(torch.eq(predicted, label)).cpu()
      losses.append(loss.item())

  mean_loss = np.mean(losses) if losses else float('nan')
  return correct_predictions.double() / n_examples, mean_loss

In [None]:
def get_predictions(model, data_loader, device=None):
  """Collect texts, predicted classes, class probabilities and true labels.

  Args:
    model: module called as `model(input_ids=..., attention_mask=...)` and
      returning per-class scores of shape (batch, n_classes).
    data_loader: iterable of batches with 'text', 'input_ids',
      'attention_mask' and 'label' entries.
    device: device the batch tensors are moved to. When omitted it is taken
      from the model's first parameter (CPU if the model has none), so the
      function no longer depends on a notebook-level global.

  Returns:
    `(texts, predictions, prediction_probs, real_values)`: `texts` is a list
    covering every batch; the other three are CPU tensors with one row per
    example.
  """
  model = model.eval()
  if device is None:
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

  texts = []
  predictions = []
  prediction_probs = []
  real_values = []

  with torch.no_grad():
    for d in data_loader:
      input_ids = d["input_ids"].to(device)
      attention_mask = d["attention_mask"].to(device)
      label = d["label"].to(device)

      outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask
      )
      _, preds = torch.max(outputs, dim=1)

      # Softmax maps raw logits to probabilities; on log-probabilities it
      # simply exponentiates them, so both kinds of model output are handled.
      probs = F.softmax(outputs, dim=1)

      # Accumulate into `texts`: the previous `text.extend(text)` doubled the
      # batch's own list and only the last batch was ever returned.
      texts.extend(d["text"])
      predictions.extend(preds)
      prediction_probs.extend(probs)
      real_values.extend(label)

  if not predictions:
    # torch.stack rejects an empty list; return empty tensors instead.
    empty_ids = torch.empty(0, dtype=torch.long)
    return texts, empty_ids, torch.empty(0, 0), empty_ids.clone()

  predictions = torch.stack(predictions).cpu()
  prediction_probs = torch.stack(prediction_probs).cpu()
  real_values = torch.stack(real_values).cpu()
  return texts, predictions, prediction_probs, real_values

In [None]:
# Shared hyperparameters: token length per snippet, batch size, training epochs.
MAX_LEN=128
BATCH_SIZE = 32
EPOCHS=10
     

In [None]:
# (rows, columns) of both splits, shown again before the loaders are built.
df_train.shape,  df_test.shape


In [None]:
# Loaders over the two splits; create_data_loader is called without a
# collate_fn, so the DataLoader default collation is used.
train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)

In [None]:
# Pull one training batch to inspect its structure.
data=next(iter(train_data_loader))
data.keys()

In [None]:
# Tensor shapes of the sample batch.
print(data['input_ids'].shape)
print(data['attention_mask'].shape)
print(data['label'].shape)

# 1. Regular Fine Tuning

In [None]:

# Load the bare backbone.
# NOTE(review): `ClinicalBertmodel` is not referenced again in the visible
# notebook (CannabisClassifier loads its own copy) - confirm it is needed.
ClinicalBertmodel = AutoModel.from_pretrained(PRE_TRAINED_MODEL_NAME)

In [None]:
# One output unit per distinct label present in the training split.
model=CannabisClassifier(len(np.unique(df_train.sentiment)))
model=model.to(device)



In [None]:

# Move the sample batch to the training device and confirm its shapes.
input_ids=data['input_ids'].to(device)
attention_mask=data['attention_mask'].to(device)
print(input_ids.shape)
print(attention_mask.shape)

In [None]:
# Smoke test: one forward pass on the sample batch.
model(input_ids,attention_mask)


In [None]:
# NOTE(review): `AdamW` here is the one imported from `transformers`; newer
# transformers releases deprecate it in favour of `torch.optim.AdamW`, which
# has no `correct_bias` argument - confirm against the installed version.
optimizer=AdamW(model.parameters(),lr=2e-5,correct_bias=False)
# One scheduler step is taken per batch, so the schedule spans every batch of
# every epoch.
total_steps=len(train_data_loader)*EPOCHS
scheduler=get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

In [None]:
# Optionally weight the loss per class: index 4 ('unknown' in
# class_label_names) is down-weighted to 0.8, all other classes stay at 1.0.
LOSS_WT=True
if LOSS_WT: 
    weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.8, 1.0, 1.0]).to(device)
    loss_fn = nn.CrossEntropyLoss(weight=weights).to(device)
else:
    loss_fn=nn.CrossEntropyLoss().to(device)



In [None]:
%%time

# Per-epoch training accuracy and loss.
history = defaultdict(list)
# NOTE(review): `best_accuracy` is assigned but never read in this cell.
best_accuracy = 0

for epoch in range(EPOCHS):

  print(f'Epoch {epoch + 1}/{EPOCHS}')
  print('-' * 10)

  # One full pass over the training loader.
  train_acc, train_loss = train_epoch(
    model,
    train_data_loader,    
    loss_fn, 
    optimizer, 
    device, 
    scheduler, 
    len(df_train)
  )

  print(f'Train loss {train_loss} accuracy {train_acc}')


  history['train_acc'].append(train_acc)
  history['train_loss'].append(train_loss)


In [None]:
# Accuracy of the regularly fine-tuned model on the test split.
test_acc, _ = eval_model(
  model,
  test_data_loader,
  loss_fn,
  device,
  len(df_test)
)

# Record this strategy's test accuracy for the final comparison.
best_acc.append(test_acc.item())

test_acc.item()

In [None]:
# Per-example predictions on the test split for the report below.
y_texts, y_pred, y_pred_probs, y_test = get_predictions(
  model,
  test_data_loader
)

In [None]:
# True class frequencies in the test split.
df_test['sentiment'].value_counts()


In [None]:
# How often each class was predicted.
np.unique(y_pred, return_counts=True)

In [None]:
# How often each class occurs among the true test labels.
np.unique(y_test, return_counts=True)

In [None]:
class_names=['0', '1','2','3','4','5', '6']
class_label_names = ['pain', 'nausea', 'sleep', 'anxiety/stress/relaxation', 'unknown', 'not use', 'appetite']
# Passing `labels` pins the report to all seven classes; without it
# classification_report rejects `target_names` whenever a class is missing
# from both the test labels and the predictions.
print(classification_report(
    y_test,
    y_pred,
    labels=list(range(len(class_label_names))),
    target_names=class_label_names,
))


In [None]:
# Save only the weights (state_dict) of the regularly fine-tuned classifier.
torch.save(model.state_dict(), path + 'CannabisClassifier_model-ClinicalBERT_dropout-3_batch-32.pth')

# 2. adapt using MLM

In [None]:
import torch
from torch import nn
from transformers import DataCollatorForLanguageModeling

from transformers import AutoModelForMaskedLM
from transformers import AutoTokenizer

In [None]:
# Checkpoint adapted with MLM in this section; the commented-out name is an
# alternative left for reference.
PRE_TRAINED_MODEL_NAME = 'emilyalsentzer/Bio_ClinicalBERT'
# PRE_TRAINED_MODEL_NAME = 'RoBERTa-base'

In [None]:
def domain_adapt_mlm(model, tokenizer, data_collator, df_train, max_len=MAX_LEN, batch_size=BATCH_SIZE, epochs=None):
    """Continue masked-language-model pretraining of `model` on the snippets.

    Args:
        model: masked-LM model whose forward accepts `input_ids`,
            `attention_mask` and `labels` and returns an object with `.loss`.
        tokenizer: tokenizer used to encode `df_train.Snippets`.
        data_collator: collate function that masks tokens and supplies the
            MLM targets under the 'labels' key of each batch.
        df_train: DataFrame with `Snippets` and `sentiment` columns.
        max_len: token length every snippet is padded/truncated to.
        batch_size: snippets per batch.
        epochs: number of passes; defaults to the notebook-level `EPOCHS`.

    Returns:
        The same `model`, trained in place and left on the selected device.
    """
    if epochs is None:
        epochs = EPOCHS

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)
    ds = MLMDataset(
        texts=df_train.Snippets.to_list(),
        label=df_train.sentiment.to_numpy(),
        tokenizer=tokenizer,
        max_len=max_len,
    )

    train_data_loader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=data_collator)

    for epoch in range(epochs):
        model.train()
        for batch in tqdm(train_data_loader):
            optimizer.zero_grad()
            batch = {k: v.to(device) for k, v in batch.items()}
            # Train against the collator's targets (original token ids at the
            # masked positions). The previous code discarded them and used the
            # already-masked `input_ids` as labels, which taught the model to
            # reproduce its own corrupted input.
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

    return model


In [None]:

# Initialize tokenizer and model
# Tokenizer and masked-LM model for the checkpoint being adapted.
tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
mlm_model = AutoModelForMaskedLM.from_pretrained(PRE_TRAINED_MODEL_NAME)
# Collator configured for MLM with a masking probability of 0.15.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

# Adapt the masked-LM model on the training snippets.
mlm_model = domain_adapt_mlm(mlm_model, tokenizer, data_collator, df_train)


In [None]:

# Create a classification model with MLP
class CannabisUseMLPClassifier(nn.Module):
    """Backbone followed by a two-layer MLP head applied to the first token.

    `forward` returns unnormalised scores of shape (batch, num_classes).
    """

    def __init__(self, base_model, num_classes):
        super().__init__()
        # The attribute is called `roberta` whatever backbone is passed in;
        # the name is part of the saved state_dict keys.
        self.roberta = base_model
        head_layers = [
            nn.Linear(base_model.config.hidden_size, 256),
            nn.GELU(),
            nn.Dropout(0.15),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        hidden_states = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # Hidden state at sequence position 0 is the sentence representation.
        first_token = hidden_states[:, 0, :]
        return self.classifier(first_token)

# Initialize the classification model
base_model = AutoModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
# Copy the MLM-adapted backbone weights into the fresh backbone. Keys that
# exist only in `base_model` keep their pretrained values.
base_model_dict = base_model.state_dict()
pretrained_dict = mlm_model.bert.state_dict()
base_model_dict.update({k: v for k, v in pretrained_dict.items() if k in base_model_dict})
base_model.load_state_dict(base_model_dict)
classification_model = CannabisUseMLPClassifier(base_model, num_classes=7)
classification_model.to(device)

# Build the loader before it is used to size the schedule; previously
# `total_steps` silently read a `train_data_loader` left over from an
# earlier section.
train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)

optimizer=AdamW(classification_model.parameters(),lr=2e-5, correct_bias=False)

total_steps=len(train_data_loader)*EPOCHS
scheduler=get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

# Train the classification model
history = defaultdict(list)
best_accuracy = 0

# Class 4 is down-weighted to 0.8; every other class keeps weight 1.0.
weight_loss=True
if weight_loss:
    weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.8, 1.0, 1.0]).to(device)
    loss_fn = nn.CrossEntropyLoss(weight=weights).to(device)
else:
    loss_fn=nn.CrossEntropyLoss().to(device)


for epoch in range(EPOCHS):

  print(f'Epoch {epoch + 1}/{EPOCHS}')
  print('-' * 10)
  train_acc, train_loss = train_epoch(
    classification_model,
    train_data_loader,
    loss_fn,
    optimizer,
    device,
    scheduler,
    len(df_train)
  )

  print(f'Train loss {train_loss} accuracy {train_acc}')


  history['train_acc'].append(train_acc)
  history['train_loss'].append(train_loss)



In [None]:
# Inspect the weights copied out of the MLM-adapted backbone.
pretrained_dict

In [None]:
# Display the masked-LM model's module structure.
mlm_model

In [None]:
# Test accuracy of the MLM-adapted classifier, recorded for the comparison.
test_acc, _ = eval_model(
  classification_model,
  test_data_loader,
  loss_fn,
  device,
  len(df_test)
)
best_acc.append(test_acc.item())
test_acc.item()

In [None]:
# Per-example test predictions of the MLM-adapted classifier.
y_texts, y_pred, y_pred_probs, y_test = get_predictions(
  classification_model,
  test_data_loader
)

In [None]:
class_names=['0', '1','2','3','4','5', '6']
class_label_names = ['pain', 'nausea', 'sleep', 'anxiety/stress/relaxation', 'unknown', 'not use', 'appetite']
# Passing `labels` pins the report to all seven classes; without it
# classification_report rejects `target_names` whenever a class is missing
# from both the test labels and the predictions.
print(classification_report(
    y_test,
    y_pred,
    labels=list(range(len(class_label_names))),
    target_names=class_label_names,
))

# save model (weights only)
torch.save(classification_model.state_dict(), path + 'CannabisClassifier_model-ClinicalBERT_MLM_dropout-3_batch-32.pth')

# 3. only fine tuning last linear layer

In [None]:
# Frozen-backbone classifier with one output unit per distinct training label.
model3=CannabisClassifierFrozenBackbone(PRE_TRAINED_MODEL_NAME, len(np.unique(df_train.sentiment)))
model3=model3.to(device)

In [None]:
%%time

history = defaultdict(list)
best_accuracy = 0

# Build an optimizer and schedule for model3 itself. The loop previously
# reused the `optimizer`/`scheduler` created for the section 2 model, so none
# of model3's weights were ever updated. Only parameters that still require
# gradients (the linear head) are handed to the optimizer.
optimizer = AdamW(
    [p for p in model3.parameters() if p.requires_grad],
    lr=2e-5,
    correct_bias=False,
)
total_steps = len(train_data_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

for epoch in range(EPOCHS):

  print(f'Epoch {epoch + 1}/{EPOCHS}')
  print('-' * 10)

  train_acc, train_loss = train_epoch(
    model3,
    train_data_loader,
    loss_fn,
    optimizer,
    device,
    scheduler,
    len(df_train)
  )

  print(f'Train loss {train_loss} accuracy {train_acc}')


  history['train_acc'].append(train_acc)
  history['train_loss'].append(train_loss)


In [None]:
# Test accuracy of the frozen-backbone classifier, recorded for the comparison.
test_acc, _ = eval_model(
  model3,
  test_data_loader,
  loss_fn,
  device,
  len(df_test)
)
best_acc.append(test_acc.item())
test_acc.item()

In [None]:
# Per-example test predictions of the frozen-backbone classifier.
y_texts, y_pred, y_pred_probs, y_test = get_predictions(
  model3,
  test_data_loader
)

In [None]:
class_names=['0', '1','2','3','4','5','6']
report_names = ['pain', 'nausea', 'sleep', 'anxiety', 'unknown', 'not current use', 'appetite']
# Passing `labels` pins the report to all seven classes; without it
# classification_report rejects `target_names` whenever a class is missing
# from both the test labels and the predictions.
print(classification_report(
    y_test,
    y_pred,
    labels=list(range(len(report_names))),
    target_names=report_names,
))

# save (weights only)
torch.save(model3.state_dict(), f'{path}/CannabisClassifier_model-ClinicalBERT_FrozenBackbone_dropout-3_batch-32.pth')

# 4. fine tune using ensemble methods

In [None]:

# Ensemble Model
class CannabisUseEnsembleModel(nn.Module):
    """Concatenate the first-token embeddings of several backbones and classify.

    Args:
        models: backbones whose output exposes `last_hidden_state`.
        num_classes: number of output classes.
        hidden_size: embedding width of EACH backbone; all members are
            assumed to share it, since the classifier input is sized
            `hidden_size * len(models)`.
    """

    def __init__(self, models, num_classes, hidden_size=768):
        super().__init__()
        self.models = nn.ModuleList(models)

        # Number of models in the ensemble
        self.num_models = len(models)

        # Linear head over the concatenated per-model embeddings
        self.classifier = nn.Linear(hidden_size * self.num_models, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return ensemble scores of shape (batch, num_classes)."""
        # First-token hidden state of every member. Every member receives
        # the same `input_ids`.
        model_outputs = [
            model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
            for model in self.models
        ]

        # Concatenate along the feature axis
        concatenated_outputs = torch.cat(model_outputs, dim=1)

        return self.classifier(concatenated_outputs)

    def get_logits(self, input_ids, attention_mask):
        """Return `(individual_logits, ensemble_logits)`.

        `individual_logits` has one entry per member: that member's `.logits`
        when its output provides one, otherwise None. Previously this method
        raised AttributeError for members whose output has no `.logits`
        attribute, which is the case for the backbones `forward` relies on.
        """
        individual_logits = [
            getattr(model(input_ids=input_ids, attention_mask=attention_mask), 'logits', None)
            for model in self.models
        ]
        ensemble_logits = self.forward(input_ids, attention_mask)
        return individual_logits, ensemble_logits


In [None]:
# Re-state the hyperparameters and rebuild both loaders for this section.
EPOCHS=10
BATCH_SIZE = 32
MAX_LEN=128

train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)

In [None]:

from transformers import AutoModelForSequenceClassification
PRE_TRAINED_MODEL_NAME = 'emilyalsentzer/Bio_ClinicalBERT'
num_classes = 7
# Three backbones for the ensemble.
# NOTE(review): all three are loaded from the same checkpoint, so they start
# from identical weights - confirm this is the intended ensemble.
model1 = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model2 = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model3 = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

# Create the ensemble
ensemble = CannabisUseEnsembleModel([model1, model2, model3], num_classes=num_classes)

# fine tune the ensemble model
ensemble.to(device)

optimizer=AdamW(ensemble.parameters(),lr=2e-5, correct_bias=False)

# Class 4 is down-weighted to 0.8; every other class keeps weight 1.0.
weight_loss=True
if weight_loss:
    weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.8, 1.0, 1.0]).to(device)
    loss_fn = nn.CrossEntropyLoss(weight=weights).to(device)
else:
    loss_fn=nn.CrossEntropyLoss().to(device)


from tqdm import tqdm
for epoch in range(EPOCHS):
    ensemble.train()
    for batch in tqdm(train_data_loader):
        optimizer.zero_grad()
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'label' and k != 'text'}
        # The targets must live on the same device as the model outputs;
        # previously they stayed on the CPU and the loss call failed on GPU.
        labels = batch['label'].to(device)
        outputs = ensemble(**inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()


In [None]:
# eval

# Test accuracy of the ensemble, recorded for the final comparison.
test_acc, _ = eval_model(
  ensemble,
  test_data_loader,
  loss_fn,
  device,
  len(df_test)
)

best_acc.append(test_acc.item())
print(test_acc.item())

In [None]:
# Per-example test predictions of the ensemble.
y_texts, y_pred, y_pred_probs, y_test = get_predictions(
  ensemble,
  test_data_loader
)

In [None]:
class_names=['0', '1','2','3','4','5','6']
report_names = ['pain', 'nausea', 'sleep', 'anxiety', 'unknown', 'not current use', 'appetite']
# Passing `labels` pins the report to all seven classes; without it
# classification_report rejects `target_names` whenever a class is missing
# from both the test labels and the predictions.
print(classification_report(
    y_test,
    y_pred,
    labels=list(range(len(report_names))),
    target_names=report_names,
))

# save model (weights only)
torch.save(ensemble.state_dict(), path + f'CannabisClassifier_model-ensemble_batch-{BATCH_SIZE}.pth')

# 5. MLM + ensemble

In [None]:
PRE_TRAINED_MODEL_NAME = 'emilyalsentzer/Bio_ClinicalBERT'
tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

# Models to use in ensemble
# NOTE(review): every member is adapted and later fed token ids produced by
# the single tokenizer loaded above, although each of these checkpoints ships
# its own vocabulary - confirm this is intended, or tokenize per model.
model_names = ["emilyalsentzer/Bio_ClinicalBERT", "roberta-base", "distilbert-base-uncased"]

# Perform domain adaptation on each model. A dedicated loop variable is used
# so the notebook-level `model` is not overwritten.
adapted_models = []
for name in model_names:
    mlm_candidate = AutoModelForMaskedLM.from_pretrained(name)
    adapted_model = domain_adapt_mlm(mlm_candidate, tokenizer, data_collator, df_train)
    adapted_models.append(adapted_model)


# Create ensemble of domain-adapted models
num_classes = 7
ensemblemlm = CannabisUseEnsembleModel([AutoModel.from_pretrained(name) for name in model_names], num_classes)

# Update the ensemble with the weights from the adapted models. The adapted
# weights are read from `base_model` (the backbone inside the masked-LM
# wrapper): the wrapper's own state_dict prefixes every key with the backbone
# name, so none of them matched and the transfer was silently a no-op.
for i, model_name in enumerate(model_names):
    adapted_model = adapted_models[i]
    base_model_dict = ensemblemlm.models[i].state_dict()
    pretrained_dict = adapted_model.base_model.state_dict()
    transferred = {k: v for k, v in pretrained_dict.items() if k in base_model_dict}
    if not transferred:
        raise RuntimeError(f'No adapted weights matched the backbone for {model_name}')
    base_model_dict.update(transferred)
    ensemblemlm.models[i].load_state_dict(base_model_dict)

In [None]:
# Update the ensemble with the weights from the adapted models. The adapted
# weights are read from `base_model` (the backbone inside the masked-LM
# wrapper): the wrapper's own state_dict prefixes every key with the backbone
# name, so none of them matched and the transfer was silently a no-op.
for i, model_name in enumerate(model_names):
    adapted_model = adapted_models[i]
    base_model_dict = ensemblemlm.models[i].state_dict()
    pretrained_dict = adapted_model.base_model.state_dict()
    transferred = {k: v for k, v in pretrained_dict.items() if k in base_model_dict}
    if not transferred:
        raise RuntimeError(f'No adapted weights matched the backbone for {model_name}')
    base_model_dict.update(transferred)
    ensemblemlm.models[i].load_state_dict(base_model_dict)

In [None]:
# train the ensemble model for the classification task

ensemblemlm.to(device)

optimizer=AdamW(ensemblemlm.parameters(),lr=2e-5, correct_bias=False)

# Class 4 is down-weighted to 0.8; every other class keeps weight 1.0.
weight_loss=True
if weight_loss:
    weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.8, 1.0, 1.0]).to(device)
    loss_fn = nn.CrossEntropyLoss(weight=weights).to(device)
else:
    loss_fn=nn.CrossEntropyLoss().to(device)


from tqdm import tqdm
for epoch in range(EPOCHS):
    ensemblemlm.train()
    for batch in tqdm(train_data_loader):
        optimizer.zero_grad()
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'label' and k != 'text'}
        # The targets must live on the same device as the model outputs;
        # previously they stayed on the CPU and the loss call failed on GPU.
        labels = batch['label'].to(device)
        outputs = ensemblemlm(**inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

In [None]:
# Evaluate the MLM-adapted ensemble. This cell previously passed `ensemble`
# (the section 4 model), so section 4's accuracy was recorded a second time.
test_acc, _ = eval_model(
  ensemblemlm,
  test_data_loader,
  loss_fn,
  device,
  len(df_test)
)
best_acc.append(test_acc.item())
test_acc.item()

In [None]:
# Per-example test predictions of the MLM-adapted ensemble.
y_texts, y_pred, y_pred_probs, y_test = get_predictions(
  ensemblemlm,
  test_data_loader
)

In [None]:
class_names=['0', '1','2','3','4','5','6']
report_names = ['pain', 'nausea', 'sleep', 'anxiety', 'unknown', 'not current use', 'appetite']
# Passing `labels` pins the report to all seven classes; without it
# classification_report rejects `target_names` whenever a class is missing
# from both the test labels and the predictions.
print(classification_report(
    y_test,
    y_pred,
    labels=list(range(len(report_names))),
    target_names=report_names,
))

# save model (weights only). The `f` prefix was missing, so the file name
# contained the literal text "{BATCH_SIZE}" instead of the batch size.
torch.save(ensemblemlm.state_dict(), path + f'CannabisClassifier_model-ensembleMLM_batch-{BATCH_SIZE}.pth')

# 6. fine tune attention layers + classification layers

In [None]:
class CannabisUsePurposeAttentionClassifier(nn.Module):
    """Classifier that trains only attention parameters plus a linear head.

    Every backbone parameter whose name lacks the substring 'attention' is
    frozen. `forward` returns unnormalised scores of shape (batch, num_labels).
    """

    def __init__(self, pretrained_model, num_labels):
        super().__init__()
        self.bert = pretrained_model
        self.classifier = nn.Linear(pretrained_model.config.hidden_size, num_labels)

        # Leave attention parameters trainable; freeze everything else.
        for param_name, param in self.bert.named_parameters():
            if 'attention' in param_name:
                continue
            param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        hidden_states = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        # Hidden state at sequence position 0 is the sentence representation.
        return self.classifier(hidden_states[:, 0])



In [None]:
# Re-state the hyperparameters and rebuild both loaders for this section.
EPOCHS=10
BATCH_SIZE = 32
MAX_LEN=128

train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)

In [None]:

PRE_TRAINED_MODEL_NAME = 'emilyalsentzer/Bio_ClinicalBERT'
base_model = AutoModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
modelatt=CannabisUsePurposeAttentionClassifier(base_model, len(np.unique(df_train.sentiment)))
# Move the model to the training device: train_epoch/eval_model move every
# batch to `device`, so a model left on the CPU fails as soon as a GPU is used.
modelatt=modelatt.to(device)

In [None]:
# Optimizer over modelatt's parameters, with a linear schedule that takes one
# step per batch across all epochs.
optimizer=AdamW(modelatt.parameters(),lr=2e-5, correct_bias=False)

total_steps=len(train_data_loader)*EPOCHS
scheduler=get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

In [None]:
# Class 4 is down-weighted to 0.8; every other class keeps weight 1.0.
weight_loss=True
if weight_loss:
    weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.8, 1.0, 1.0]).to(device)
    loss_fn = nn.CrossEntropyLoss(weight=weights).to(device)
else:
    loss_fn=nn.CrossEntropyLoss().to(device)


In [None]:
%%time

# Per-epoch training accuracy and loss.
history = defaultdict(list)
# NOTE(review): `best_accuracy` is assigned but never read in this cell.
best_accuracy = 0

for epoch in range(EPOCHS):

  print(f'Epoch {epoch + 1}/{EPOCHS}')
  print('-' * 10)

  # One full pass over the training loader.
  train_acc, train_loss = train_epoch(
    modelatt,
    train_data_loader,    
    loss_fn, 
    optimizer, 
    device, 
    scheduler, 
    len(df_train)
  )

  print(f'Train loss {train_loss} accuracy {train_acc}')


  history['train_acc'].append(train_acc)
  history['train_loss'].append(train_loss)


In [None]:
# Test accuracy of the attention-only model, recorded for the comparison.
test_acc, _ = eval_model(
  modelatt,
  test_data_loader,
  loss_fn,
  device,
  len(df_test)
)
best_acc.append(test_acc.item())
test_acc.item()

In [None]:
# Per-example test predictions of the attention-only model.
y_texts, y_pred, y_pred_probs, y_test = get_predictions(
  modelatt,
  test_data_loader
)

In [None]:
class_names=['0', '1','2','3','4','5','6']
report_names = ['pain', 'nausea', 'sleep', 'anxiety', 'unknown', 'not current use', 'appetite']
# Passing `labels` pins the report to all seven classes; without it
# classification_report rejects `target_names` whenever a class is missing
# from both the test labels and the predictions.
print(classification_report(
    y_test,
    y_pred,
    labels=list(range(len(report_names))),
    target_names=report_names,
))

# save model (weights only). Save `modelatt`, the model trained in this
# section; the cell previously saved the unrelated notebook-level `model`
# under the attention checkpoint name.
torch.save(modelatt.state_dict(), path + f'/CannabisClassifier_model-ClinicalBERT_finetune-attention_batch-{BATCH_SIZE}.pth')

# FINAL STEP: classify the rest of the notes

In [None]:
# Highest test accuracy among the strategies recorded in `best_acc`.
print("best accuracy = ", np.max(best_acc))


In [None]:
# code to use the model to classify all of the notes
# load model from saved state_dict
num_classes = 7
model=CannabisClassifier(len(np.unique(df_train.sentiment)))
# `map_location` lets a checkpoint that was saved from a GPU load on whatever
# device this session uses (including CPU-only machines).
model.load_state_dict(
    torch.load(
        path + 'CannabisClassifier_model-ClinicalBERT_dropout-3_batch-32.pth',
        map_location=device,
    )
)
model.to(device)



unlabeled_notes_path = path + "full_list_of_5.csv"
output_predictions_path = path + "fully_labeled_notes.csv"

# Columns read from the unlabeled export: snippet text, patient id, note date.
unlabeled_df = pd.read_csv(unlabeled_notes_path)
unlabeled_texts = unlabeled_df['Relevant Snippets'].tolist()
patient_ids = unlabeled_df['Patient Id'].tolist()
dates = unlabeled_df['Date'].tolist()

class UnlabeledDataset(Dataset):
    """Map-style dataset that tokenizes one unlabeled snippet per item.

    Args:
        texts: indexable sequence of snippets; each entry is coerced with str().
        tokenizer: object exposing a Hugging Face style `encode_plus`.
        max_len: every sequence is truncated/padded to exactly this many tokens.
    """

    def __init__(self, texts, tokenizer, max_len):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        """Return the raw text plus 1-D `input_ids` and `attention_mask`."""
        text = str(self.texts[item])
        # Use the tokenizer and length given to the constructor; the method
        # previously ignored both in favour of a notebook-level `tokenizer`
        # and a hard-coded 128.
        encoding = self.tokenizer.encode_plus(
            text,
            max_length=self.max_len,
            add_special_tokens=True,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=False,
            return_tensors='pt'
        )
        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten()
        }

unlabeled_dataset = UnlabeledDataset(unlabeled_texts, tokenizer, max_len=128)
# shuffle=False keeps batches in file order so each prediction can be matched
# back to its patient id and date by position.
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=32, shuffle=False)

model.eval()

output_columns = ['Patient Id', 'Date', 'Relevant Snippets', 'Predicted Label', 'Confidence']
predictions = []
with torch.no_grad():
    for batch in tqdm(unlabeled_loader, desc="Classifying Notes"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        texts = batch['text']

        # Get model outputs
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        predicted_labels = torch.argmax(probabilities, dim=1).tolist()
        confidences = probabilities.max(dim=1).values.tolist()

        # Rows are emitted in order, so the number already collected is the
        # position of this batch's first row. This replaces the hard-coded
        # `batch_idx * 32`, which broke as soon as the batch size changed.
        offset = len(predictions)
        for i, text in enumerate(texts):
            row = offset + i
            predictions.append({
                'Patient Id': patient_ids[row],
                'Date': dates[row],
                'Relevant Snippets': text,
                'Predicted Label': predicted_labels[i],
                'Confidence': confidences[i]
            })

# Explicit columns keep the frame well-formed even when there are no rows.
predictions_df = pd.DataFrame(predictions, columns=output_columns)
# shift the label to 1-7
predictions_df['Predicted Label'] = predictions_df['Predicted Label'] + 1
predictions_df.to_csv(output_predictions_path, index=False)
print(f"Predictions saved to {output_predictions_path}")

