In [1]:
import comet_ml
import numpy as np
import torch

In [None]:
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer

student_id = "distilbert/distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(student_id)

dataset = load_dataset("imdb")

def pre_process(examples):
    return tokenizer(examples["text"], truncation = True, max_length = 512)

tokenized_data = dataset.map(pre_process, batched = True)

#test_valid = tokenized_data['test'].train_test_split(test_size=0.5)
#tokenized_data = DatasetDict({
#    'train': tokenized_data['train'],
#    'test': test_valid['train'],
#    'validation': test_valid['test']
#})


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#comet_ml.init(project_name="distilbert_dotprod")

In [None]:
from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class distillTrainer(Trainer):
    def __init__(self, *args, teacher_model = None, temperature = None, alpha_ce = None, alpha_cos = None, **kwargs):
        super().__init__(*args,**kwargs)
        self.teacher = teacher_model
        self.teacher.eval()
        
    def compute_loss(self, model, inputs, return_outputs = False):
        student_outputs = model(**inputs, output_attentions = True)
        student_context = student_outputs.contexts  
        print("2" , student_context[0].shape)
        with torch.no_grad():
            teacher_outputs = self.teacher(**inputs)
        
        
        
        return (loss, student_outputs) if return_outputs else loss
         
        

In [None]:
labels = tokenized_data['train'].features['label'].names
num_labels = len(labels)
label2id, id2label = {}, {}

for idx, lbl in enumerate(labels):
    label2id[lbl] = idx
    id2label[idx] = lbl
print(label2id)
print(id2label)

In [None]:
from transformers import DistilBertForSequenceClassification, AutoModelForSequenceClassification, DistilBertConfig, DataCollatorWithPadding
from iDistilbert_weights_init import iDistilBertForMaskedLM

#Load Models
teacher_id = "textattack/bert-base-uncased-imdb"
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_id,
    num_labels = num_labels,
    id2label = id2label,
    label2id = label2id,
    output_hidden_states=True,
)

student_config = DistilBertConfig(
    output_hidden_states = True,
    distance_metric = "manhattan_distance",
    activation_function = "relu",
    signed_inhibitor =  True,
    alpha = 0,
    center = True,
    num_labels = num_labels,
    output_contexts = True,
    )

student_model = iDistilBertForMaskedLM(student_config)

initialized_weights = torch.load('/mnt/tony/MSc2024/distilbert_init/models/iDistilbert_init.pth')
student_model.load_state_dict(initialized_weights, strict=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
teacher_model.to(device)
student_model.to(device)

In [None]:
import evaluate
import numpy as np

#experiment = comet_ml.get_global_experiment()

accuracy = evaluate.load("accuracy")

def preprocess_logits_for_metrics(logits, labels):
    """
    Preprocess the logits to ensure they are in the correct format for metric computation.
    This function will be called during the evaluation process.
    """
    if isinstance(logits, tuple):  
        logits = logits[0]  # get logit tensors

    pred_ids = torch.argmax(logits, dim=-1)
    
    return pred_ids, labels
    
def compute_metrics(eval_pred):
    
    predictions, labels = eval_pred

    return accuracy.compute(predictions=predictions[0], references=labels)



In [None]:

%env COMET_MODE=ONLINE
%env COMET_LOG_ASSETS=TRUE

EPOCHS = 2
BATCH_SIZE = 4
LEARNING_RATE = 5e-5

training_args = TrainingArguments(
    output_dir = './results',
    num_train_epochs = EPOCHS,
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size = BATCH_SIZE,
    learning_rate = LEARNING_RATE,
    logging_dir = './logs',
    load_best_model_at_end= True,
    metric_for_best_model="accuracy",
    eval_strategy="steps",
    eval_steps = 500,
    save_strategy="steps",
    save_total_limit=2,
    seed = 42,
    #report_to=['comet_ml', 'tensorboard'],
    report_to=['tensorboard'],
)

In [None]:
from torchinfo import summary

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
trainer = distillTrainer(
    teacher_model=teacher_model,
    model=student_model,                         
    args=training_args,                  
    train_dataset=tokenized_data['train'],         
    eval_dataset=tokenized_data['test'],
    compute_metrics = compute_metrics,
    preprocess_logits_for_metrics = preprocess_logits_for_metrics,
    temperature = 5,
    alpha_ce = 0.3,
    alpha_cos = 0.2,
    tokenizer = tokenizer,
    data_collator = data_collator,
)

student_model.freeze_weights_except_q_k()

# Sample input tensor with LongTensor type
input_ids = torch.randint(0, 30522, (4, 512)).long().to(device)  # Assuming vocab size 30522

# Attention mask (optional, but typically used)
attention_mask = torch.ones((4, 512)).long().to(device)

# Generate summary, note that input size should match what the model expects
summary(student_model, input_data={'input_ids': input_ids, 'attention_mask': attention_mask})



In [18]:
trainer.train()



2 torch.Size([16, 512, 768])


NameError: name 'loss' is not defined

In [14]:
import os

folder = 'models/'
os.makedirs(folder, exist_ok=True)
torch.save(student_model.state_dict(), os.path.join(folder, 'weight_opt_iDistilbert.pth'))

In [23]:
try:
    output = student_model(input_ids=input_ids, attention_mask=attention_mask)
    print("Forward pass successful")
except Exception as e:
    print(f"Error during forward pass: {e}")

print(output.contexts[5].shape)

Forward pass successful
torch.Size([4, 512, 768])
