In [1]:
!pip install transformers
!pip install transformers[torch]
!pip install accelerate -U
!pip install wandb -qU



In [2]:
from transformers import BloomForSequenceClassification, BloomTokenizerFast, Trainer, TrainingArguments
import torch.nn.functional as F
import torch
from google.colab import auth
import pandas as pd
from torch.utils.data import Dataset
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import accelerate
from transformers import Trainer
from torch.utils.data import DataLoader

In [3]:
auth.authenticate_user()

In [4]:
#load the teacher model
!mkdir -p /content/finetuned_model
!gsutil -m cp -r "gs://finetunned_760m_bloom/bloom_icd_10/finetuned_model/*" "/content/finetuned_model/"

Copying gs://finetunned_760m_bloom/bloom_icd_10/finetuned_model/config.json...
/ [0/5 files][    0.0 B/  4.0 GiB]   0% Done                                    Copying gs://finetunned_760m_bloom/bloom_icd_10/finetuned_model/pytorch_model.bin...
/ [0/5 files][    0.0 B/  4.0 GiB]   0% Done                                    ==> NOTE: You are downloading one or more large file(s), which would
run significantly faster if you enabled sliced object downloads. This
feature is enabled by default but requires that compiled crcmod be
installed (see "gsutil help crcmod").

Copying gs://finetunned_760m_bloom/bloom_icd_10/finetuned_model/special_tokens_map.json...
/ [0/5 files][    0.0 B/  4.0 GiB]   0% Done                                    Copying gs://finetunned_760m_bloom/bloom_icd_10/finetuned_model/tokenizer.json...
Copying gs://finetunned_760m_bloom/bloom_icd_10/finetuned_model/tokenizer_config.json...
| [5/5 files][  4.0 GiB/  4.0 GiB] 100% Done  37.3 MiB/s ETA 00:00:00           
Opera

In [5]:
# Load teacher model and tokenizer
teacher_model_path = "/content/finetuned_model"
tokenizer = BloomTokenizerFast.from_pretrained(teacher_model_path, repo_type="local")
teacher_model = BloomForSequenceClassification.from_pretrained(teacher_model_path).to('cuda')

In [6]:
!gcloud config set project 'steel-climber-398320'
!gsutil cp gs://training_dataset_bloom/merged_symp_icd/symptoms_icd10.json ./training_data.jsonl
training_data = pd.read_json('training_data.jsonl', lines=True)

Updated property [core/project].
Copying gs://training_dataset_bloom/merged_symp_icd/symptoms_icd10.json...
/ [1 files][625.7 KiB/625.7 KiB]                                                
Operation completed over 1 objects/625.7 KiB.                                    


In [7]:
def tokenizeInputs(symptoms, codes):
    tokenized_inputs = tokenizer(symptoms, max_length=512, truncation=True, padding=True)
    return tokenized_inputs, codes

In [8]:
tokenized_symptoms, coded = tokenizeInputs(training_data['symptoms'].tolist(), training_data['icd10_code'].tolist())

In [9]:
# prepare the data
# Convert codes to unique integer labels
label_encoder = LabelEncoder()
training_data['label'] = label_encoder.fit_transform(training_data['icd10_code'])

# Function to tokenize data and labels
def tokenize_data_and_labels(symptoms, labels):
    tokenized_inputs = tokenizer(symptoms, max_length=512, truncation=True, padding=True, return_tensors="pt")
    tokenized_inputs['labels'] = torch.tensor(labels)
    return tokenized_inputs

tokenized_data = tokenize_data_and_labels(training_data['symptoms'].tolist(), training_data['label'].tolist())

In [10]:
class CustomDataset(Dataset):
    def __init__(self, tokenized_data):
        self.input_ids = tokenized_data["input_ids"]
        self.attention_mask = tokenized_data["attention_mask"]
        self.soft_labels = tokenized_data.get("soft_labels", torch.zeros_like(self.input_ids))
        self.labels = tokenized_data.get("labels", torch.zeros_like(self.input_ids))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "soft_labels": self.soft_labels[idx],
            "labels": self.labels[idx]
        }


In [11]:
# Split data into training and validation sets
train_prompts, val_prompts, train_labels, val_labels = train_test_split(
    training_data['symptoms'].tolist(),
    training_data['label'].tolist(),
    test_size=0.1
)

# Tokenize training and validation data
tokenized_train = tokenize_data_and_labels(train_prompts, train_labels)
tokenized_val = tokenize_data_and_labels(val_prompts, val_labels)

train_dataset = CustomDataset(tokenized_train)
val_dataset = CustomDataset(tokenized_val)

for entry in train_dataset:
    if "soft_labels" not in entry:
        print("Entry missing soft_labels:", entry)

for entry in val_dataset:
    if "soft_labels" not in entry:
        print("Entry missing soft_labels:", entry)

In [12]:
# generate soft labels using the teacher
#this requires a huge chunk of ram
#def get_soft_labels(tokenized_inputs):
#    with torch.no_grad():
#        teacher_outputs = teacher_model(**{k: v.to('cuda') for k, v in tokenized_inputs.items()})
#        soft_labels = F.softmax(teacher_outputs.logits, dim=-1)
#    return soft_labels

#soft_labels = get_soft_labels(tokenized_train)

In [13]:
# generate soft labels using the teacher
#this code will use batches to avoid fulling all the GPU ram
def get_soft_labels(tokenized_inputs, batch_size=32):
    all_soft_labels = []

    # Create a DataLoader to handle batching
    dataset = CustomDataset(tokenized_inputs)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    for batch in dataloader:
        # Exclude the "soft_labels" key for forwarding
        model_inputs = {k: v.to('cuda') for k, v in batch.items() if k != "soft_labels"}
        with torch.no_grad():
            teacher_outputs = teacher_model(**model_inputs)
            soft_labels = F.softmax(teacher_outputs.logits, dim=-1)
        all_soft_labels.append(soft_labels)


    # Concatenate all soft labels into one tensor
    all_soft_labels = torch.cat(all_soft_labels, dim=0)

    return all_soft_labels

soft_labels = get_soft_labels(tokenized_train)



In [14]:
# Adding soft_labels to tokenized_train
tokenized_train['soft_labels'] = soft_labels


In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
# Create and initialize a smaller student model
student_model = BloomForSequenceClassification.from_pretrained("bigscience/bloom-560m", num_labels=teacher_model.config.num_labels)
student_model.to(device)

# Define a custom loss function that takes the soft labels into account
def distillation_loss(predictions, soft_labels, temperature=1.0):
    if soft_labels is None:
        raise ValueError("Soft labels are missing. Ensure all batches have soft labels.")
    return -torch.mean(torch.sum(soft_labels * F.log_softmax(predictions / temperature, dim=-1), dim=-1))


Some weights of BloomForSequenceClassification were not initialized from the model checkpoint at bigscience/bloom-560m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
# Log in to W&B account
import wandb
wandb.init(project="bloom_destillation", entity="ac215-the-prescribers")

[34m[1mwandb[0m: Currently logged in as: [33mamorales[0m ([33mac215-the-prescribers[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [18]:
def data_collator_fn(data):
    input_ids = torch.stack([item["input_ids"] for item in data])
    attention_mask = torch.stack([item["attention_mask"] for item in data])

    if "soft_labels" in data[0]: # Check if the first item has "soft_labels"
        soft_labels = torch.stack([item["soft_labels"] for item in data])
    else:
        soft_labels = None

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "soft_labels": soft_labels,
    }



In [19]:
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        logits = model(inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits
        loss = distillation_loss(logits, inputs["soft_labels"])
        return (loss, logits) if return_outputs else loss

# Define training arguments and trainer
training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=8,
    num_train_epochs=5,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    report_to="wandb",
)

trainer = CustomTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator_fn
)

In [20]:
for entry in tokenized_train:
    if "soft_labels" not in entry:
        print("Entry missing soft_labels:", entry)

Entry missing soft_labels: input_ids
Entry missing soft_labels: attention_mask
Entry missing soft_labels: labels


In [21]:
# Train the student model using the soft labels
trainer.train()

ValueError: ignored

In [None]:
# Optionally, save the student model
#student_model.save_pretrained("/path/to/save/student/model")