#### 1. Imports

In [None]:
import optuna
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoModel
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import accelerate # leave here
import numpy as np
from utils import load_class_code_from_directory, save_embeddings_to_csv, process_files
from embeddings import generate_embeddings_for_java_file

In [None]:
# Pick the compute device once for the whole notebook: the first CUDA GPU when
# PyTorch can see one, otherwise the CPU. `dev` keeps the string form.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

#### 2. Generate examples for fine-tuning CodeBERT (from POS)

In [None]:
# Build the fine-tuning set from the labelled classes of the POS system.
# Labels are 0: Application, 1: Utility, 2: Entity
class_labels = process_files('v_imen', 'pos')

class_code = load_class_code_from_directory('pos')

# Pair each class's source code with its label. Classes without a label are
# skipped. Resulting format:
# examples = [
#     {"text": "<your Java class code here>", "label": 0},
#     {"text": "<another Java class code here>", "label": 1},
#     ...]
examples = [
    {"text": code, "label": class_labels[class_name]}
    for class_name, code in class_code.items()
    if class_name in class_labels
]

# Fail here with a clear message rather than later inside train_test_split.
if not examples:
    raise ValueError(
        "No labelled classes found: none of the classes loaded from 'pos' "
        "has an entry in class_labels."
    )

print(f"{len(examples)} labelled examples")
# Print first 5 examples
print(examples[:5])

#### 3. Fine-tuning and optimization

In [None]:
# Implement a PyTorch Dataset
class CodingDataset(Dataset):
    """Map-style dataset that tokenizes labelled Java class sources on demand.

    Each element of ``examples`` is a dict with a ``'text'`` entry (source code)
    and a ``'label'`` entry (class id). Tokenization happens lazily in
    ``__getitem__``. Every item is padded/truncated to 512 tokens, so the
    per-item tensors have equal length and can be stacked into batches.
    """

    def __init__(self, examples, tokenizer):
        self.tokenizer = tokenizer
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        record = self.examples[idx]
        tokenized = self.tokenizer(
            record['text'],
            padding='max_length',
            truncation=True,
            max_length=512,
            return_tensors='pt',
        )
        # With return_tensors='pt' a single string comes back with a leading
        # batch axis; flatten each tensor so the DataLoader adds its own.
        item = {}
        for field, tensor in tokenized.items():
            item[field] = tensor.flatten()
        # The model computes the loss when it receives a 'labels' entry.
        item['labels'] = torch.tensor(record['label'])
        return item

In [None]:
# Split the labelled examples into training (90%) and validation (10%) sets.
# The fixed seed keeps the split reproducible between runs. Without it, the
# validation losses reported by the Optuna search cannot be reproduced.
train_examples, val_examples = train_test_split(examples, test_size=0.1, random_state=42)

# Pretrained CodeBERT with a freshly initialised 3-way classification head
# (labels 0: Application, 1: Utility, 2: Entity).
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base", force_download=False)
model = AutoModelForSequenceClassification.from_pretrained("microsoft/codebert-base", num_labels=3, force_download=False)
model = model.to(device)

# Set by the Optuna search when it is run. Otherwise the final training step
# falls back to previously found hyperparameters.
best_params = None

In [None]:
# Define objective function for optuna to optimize
def objective(trial):
    """Fine-tune a fresh CodeBERT classifier with the trial's hyperparameters.

    Optuna calls this once per trial. Search space:
      * ``lr``               -- log-uniform float in [1e-5, 1e-1]
      * ``batch_size``       -- one of 8, 16, 32, 64
      * ``num_train_epochs`` -- integer in [1, 10]

    Args:
        trial: the Optuna trial that supplies the hyperparameter suggestions.

    Returns:
        The ``eval_loss`` reported by ``Trainer.evaluate()`` on ``val_examples``.
        Lower is better, so the study should use ``direction="minimize"``.
    """
    # suggest_float(..., log=True) replaces the deprecated suggest_loguniform.
    # The parameter name and distribution are the same.
    lr = trial.suggest_float('lr', 1e-5, 1e-1, log=True)
    batch_size = trial.suggest_categorical('batch_size', [8, 16, 32, 64])
    num_train_epochs = trial.suggest_int('num_train_epochs', 1, 10)

    train_dataset = CodingDataset(train_examples, tokenizer)
    val_dataset = CodingDataset(val_examples, tokenizer)

    # Start every trial from the pretrained checkpoint. Training the shared
    # global `model` would make each trial continue from the weights left by
    # the previous one. That makes trials incomparable and leaves `model`
    # already trained before the final fine-tuning step.
    trial_model = AutoModelForSequenceClassification.from_pretrained(
        "microsoft/codebert-base", num_labels=3
    ).to(device)

    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=num_train_epochs,
        learning_rate=lr,
        per_device_train_batch_size=batch_size,
        logging_dir='./logs',
        # Trial checkpoints are never reloaded, so don't spend disk/time on them.
        save_strategy='no',
    )

    trainer = Trainer(
        model=trial_model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )
    trainer.train()

    # Optuna minimises the returned value, so return the loss unchanged.
    eval_result = trainer.evaluate()
    return eval_result["eval_loss"]

In [None]:
# The hyperparameter search fine-tunes CodeBERT once per trial, which is slow.
# It is therefore off by default; set RUN_OPTUNA = True to re-run it.
# TODO: expose these as command-line arguments once this becomes a .py script.
RUN_OPTUNA = False
N_TRIALS = 10

if RUN_OPTUNA:
    # Lower validation loss is better.
    study = optuna.create_study(direction="minimize")
    study.optimize(objective, n_trials=N_TRIALS)

    best_params = study.best_params
    print(f"Best hyperparameters: {best_params}")

In [None]:
def custom_collate_fn(batch):
    """Collate per-example tensor dicts into a single dict of batched tensors.

    The keys of the first sample decide which fields are collated. For each
    field, the samples' tensors are stacked along a new leading batch axis.
    """
    field_names = batch[0].keys()
    collated = {}
    for name in field_names:
        collated[name] = torch.stack([sample[name] for sample in batch])
    return collated

In [None]:
# Result of the last Optuna run, kept so the notebook can train without re-running the search:
# Trial 4 finished with value: 0.0006419435958378017 and parameters: {'lr': 0.0002536818790618518, 'batch_size': 16, 'num_train_epochs': 8}.
# Best is trial 4 with value: 0.0006419435958378017.

if best_params is None:
    best_params = {'lr': 0.0002536818790618518, 'batch_size': 16, 'num_train_epochs': 8}

# Final fine-tuning on *all* labelled examples (train + validation) with the chosen hyperparameters.
dataset = CodingDataset(examples, tokenizer)
# Shuffle, as Trainer did during the search. Without it every epoch would see
# the batches in the same fixed order of `examples`.
dataloader = DataLoader(
    dataset,
    batch_size=best_params['batch_size'],
    shuffle=True,
    collate_fn=custom_collate_fn,
)

epochs = best_params['num_train_epochs']  # for fine tune, we also use best epochs
num_batches = len(dataloader)
total_steps = max(epochs * num_batches, 1)  # guard against division by zero

# The hyperparameters were selected with Trainer's defaults: Adam(W) without
# weight decay, LR decaying linearly to zero, and gradient-norm clipping at 1.0.
# Reproduce them here so the tuned learning rate means the same thing.
optimizer = torch.optim.Adam(model.parameters(), lr=best_params['lr'])
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: max(0.0, 1.0 - step / total_steps)
)

model.train()

for epoch in range(epochs):
    running_loss = 0.0
    for batch in dataloader:
        batch = {key: value.to(device) for key, value in batch.items()}  # move batch to the device
        optimizer.zero_grad()
        outputs = model(**batch)  # 'labels' in the batch makes the model return .loss
        loss = outputs.loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        running_loss += loss.item()
    print(f"Epoch {epoch + 1}/{epochs} - mean training loss: {running_loss / max(num_batches, 1):.4f}")

# Report completion once, after all epochs (previously printed every epoch).
print("Training completed.")

In [None]:
# Persist the fine-tuned weights and the matching tokenizer in one directory,
# so both can be reloaded with from_pretrained() on that path.
save_dir = "./codebert_finetuned"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

#### 4. Embedding generation

In [None]:
# Load the fine-tuned checkpoint as a plain encoder. AutoModel builds the base
# model without the sequence-classification head saved alongside it.
model = AutoModel.from_pretrained("./codebert_finetuned")
tokenizer = AutoTokenizer.from_pretrained("./codebert_finetuned", force_download=False)  # Use the fine-tuned tokenizer

model = model.to(device)
# Inference only: make sure dropout is disabled.
model.eval()
version = 'v_team'

# TODO : Generate embeddings for all systems, save them to csv file and use them in the next step
systems = ['jforum']
for system in systems:
    class_code = load_class_code_from_directory(system)

    # No gradients are needed to compute embeddings; skipping autograd
    # bookkeeping saves memory and time.
    with torch.no_grad():
        class_embeddings = {
            class_name: generate_embeddings_for_java_file(code, model, tokenizer, device)
            for class_name, code in class_code.items()
        }

    save_embeddings_to_csv(version, system, 'ft_codebert', class_embeddings)