In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import KFold
import pandas as pd
from sklearn.metrics import cohen_kappa_score
import time
from datetime import datetime
from tqdm.auto import tqdm
import math


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"\nUsing device: {device}")


def quadratic_weighted_kappa(y_true, y_pred):
    """Return the quadratic weighted kappa between true and predicted scores.

    Predictions are rounded to the nearest integer before the agreement is
    computed; ``y_true`` is passed through unchanged.
    """
    rounded_pred = np.round(y_pred)
    return cohen_kappa_score(y_true, rounded_pred, weights='quadratic')


# Valid (min, max) score range of every trait, keyed by prompt id.
# A trait that is absent from a prompt's entry is not scored for that prompt.
_WRITING_TRAITS = ('sentence_fluency', 'word_choice', 'conventions', 'organization', 'content')
_RESPONSE_TRAITS = ('narrativity', 'language', 'prompt_adherence', 'content')

SCORE_RANGES = {
    1: {**dict.fromkeys(_WRITING_TRAITS, (1, 6)), 'holistic': (2, 12)},
    2: {**dict.fromkeys(_WRITING_TRAITS, (1, 6)), 'holistic': (1, 6)},
    3: {**dict.fromkeys(_RESPONSE_TRAITS, (0, 3)), 'holistic': (0, 3)},
    4: {**dict.fromkeys(_RESPONSE_TRAITS, (0, 3)), 'holistic': (0, 3)},
    5: {**dict.fromkeys(_RESPONSE_TRAITS, (0, 4)), 'holistic': (0, 4)},
    6: {**dict.fromkeys(_RESPONSE_TRAITS, (0, 4)), 'holistic': (0, 4)},
    7: {
        'conventions': (0, 6),
        'organization': (0, 6),
        'content': (0, 6),
        'holistic': (0, 30),
    },
    8: {**dict.fromkeys(_WRITING_TRAITS, (2, 12)), 'holistic': (0, 60)},
}

# Trait names in a fixed order; position i here is output column i of the model.
traits = [
    'holistic',
    'content',
    'organization',
    'word_choice',
    'sentence_fluency',
    'conventions',
    'prompt_adherence',
    'language',
    'narrativity',
]

def read_data(path):
    """Load the essay CSV at ``path`` into a dictionary of parallel arrays.

    The identifier, text and score columns are looked up by name; every
    column from position 12 onwards is returned together as the 2-D
    'features' array.
    """
    frame = pd.read_csv(path)

    # Output key -> CSV column name for the identifying columns.
    id_columns = {
        'essay_ids': 'essay_id',
        'prompt_ids': 'prompt_id',
        'essay_text': 'essay_text',
    }
    result = {key: frame[column].values for key, column in id_columns.items()}

    result['features'] = frame.iloc[:, 12:].values

    # Score columns keep their CSV name as the dictionary key.
    score_columns = (
        'holistic', 'content', 'organization', 'word_choice', 'sentence_fluency',
        'conventions', 'prompt_adherence', 'language', 'narrativity',
    )
    for column in score_columns:
        result[column] = frame[column].values

    return result
# Map a raw score onto [0, 1] relative to its (min, max) range.
def normalize_score(score, score_range):
    low, high = score_range[0], score_range[1]
    return (score - low) / (high - low)

# Inverse of normalize_score: map a [0, 1] value back onto the (min, max) range.
def denormalize_score(norm_score, score_range):
    low, high = score_range[0], score_range[1]
    return norm_score * (high - low) + low

# Binary mask over `traits` marking which traits are scored for a prompt.
# Example: the mask for prompt 1 is [1,1,1,1,1,1,0,0,0].
def get_trait_mask(prompt_id):
    scored = SCORE_RANGES[prompt_id]
    return torch.tensor([1.0 if name in scored else 0.0 for name in traits])

class NeuralNetwork(nn.Module):
    """Fully connected ReLU network with one output per entry of ``traits``.

    Args:
        input_size: Number of input features.
        hidden_unit: Width of every hidden layer.
        num_layers: Number of hidden layers; the first maps the input to
            ``hidden_unit`` and the remaining ``num_layers - 1`` keep that width.
    """

    def __init__(self, input_size, hidden_unit, num_layers):
        super().__init__()

        # Input projection followed by its activation.
        stack = [nn.Linear(input_size, hidden_unit), nn.ReLU()]

        # Remaining hidden layers, all of the same width.
        for _ in range(num_layers - 1):
            stack += [nn.Linear(hidden_unit, hidden_unit), nn.ReLU()]

        # Output head: one unit per trait (holistic included), no activation.
        stack.append(nn.Linear(hidden_unit, len(traits)))

        self.model = nn.Sequential(*stack)

        # He (Kaiming) initialization for the weights, zeros for the biases.
        for layer in self.model:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.kaiming_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        return self.model(x)

class MultivariateLossFn(nn.Module):
    """Masked Gaussian negative log-likelihood over per-trait predictions.

    Each prediction is scored by the density of a normal distribution with
    standard deviation ``sigma`` centred on the target. Only entries whose
    mask is non-zero contribute; the per-row loss is the mask-weighted sum of
    negative log-densities divided by the row's mask total, and the batch
    loss is the mean over rows.

    Args:
        sigma: Standard deviation of the Gaussian used to score predictions.
    """

    def __init__(self, sigma=1.0):
        super().__init__()
        self.sigma = sigma
        # Normalization constant of the Gaussian density.
        self.const = 1.0 / (math.sqrt(2 * math.pi) * sigma)

    def forward(self, outputs, targets, masks):
        """Return the mean masked negative log-likelihood of the batch.

        Args:
            outputs: Predictions, indexed as (row, trait).
            targets: Targets with the same shape as ``outputs``.
            masks: Same shape; non-zero where the trait is valid for the row.

        Returns:
            Scalar tensor with the batch loss.
        """
        squared_diff = (targets - outputs) ** 2
        exp_term = torch.exp(-squared_diff / (2 * self.sigma ** 2))
        prob_term = self.const * exp_term

        # Masked-out entries are replaced by 1 so their log is ~0; the epsilon
        # keeps log() finite when the density underflows to 0.
        valid_probs = prob_term * masks + (1 - masks) + 1e-7
        log_prob = torch.log(valid_probs)

        # A row without any valid trait has a zero numerator and a zero
        # denominator; dividing would give NaN and poison the batch mean.
        # Use 1 as the denominator there so the row contributes exactly 0.
        valid_counts = torch.sum(masks, dim=1)
        safe_counts = torch.where(
            valid_counts > 0, valid_counts, torch.ones_like(valid_counts)
        )

        loss = -torch.sum(log_prob * masks, dim=1) / safe_counts
        return loss.mean()

def normalize_all_scores(data_dict, prompt_id):
    """
    Normalize one essay's scores and build its trait mask.
    -traits scored for the prompt are rescaled to [0,1] using SCORE_RANGES
    -the mask holds 1 for scored traits and 0 for all others
    Returns the pair (normalized_scores, mask), both ordered like `traits`.
    """
    trait_count = len(traits)
    scaled = torch.zeros(trait_count)
    present = torch.zeros(trait_count)

    for position, name in enumerate(traits):
        prompt_ranges = SCORE_RANGES[prompt_id]
        if name not in prompt_ranges:
            # trait is not scored for this prompt: score and mask stay 0
            continue
        scaled[position] = normalize_score(data_dict[name], prompt_ranges[name])
        present[position] = 1.0

    return scaled, present

def denormalize_all_scores(normalized_scores, prompt_id):
    """
    Map normalized scores back to the prompt's original score ranges.
    -only traits scored for the prompt are converted
    -positions of all other traits are left at 0
    """
    restored = torch.zeros_like(normalized_scores)

    for position, name in enumerate(traits):
        prompt_ranges = SCORE_RANGES[prompt_id]
        if name not in prompt_ranges:
            continue
        restored[position] = denormalize_score(normalized_scores[position], prompt_ranges[name])

    return restored

def _build_targets(data, row_filter):
    """Build normalized score targets and trait masks for the selected essays.

    Args:
        data: Dictionary of parallel arrays holding 'prompt_ids' and one
            array per entry of ``traits``.
        row_filter: Boolean mask selecting the essays to use.

    Returns:
        Tuple ``(scores, masks)`` of stacked tensors with one row per
        selected essay and one column per trait.
    """
    prompt_ids = data['prompt_ids'][row_filter]
    # Apply the filter to every trait column once, instead of once per essay.
    trait_columns = {trait: data[trait][row_filter] for trait in traits}

    scores_list = []
    masks_list = []
    for idx, prompt_id in enumerate(prompt_ids):
        essay_data = {trait: column[idx] for trait, column in trait_columns.items()}
        normalized, mask = normalize_all_scores(essay_data, prompt_id)
        scores_list.append(normalized)
        masks_list.append(mask)

    return torch.stack(scores_list), torch.stack(masks_list)


def _make_model_and_optimizer(input_size, hidden_unit, num_layers, lr):
    """Create a fresh network on ``device`` together with its AdamW optimizer."""
    model = NeuralNetwork(input_size, hidden_unit, num_layers).to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=0.1
    )
    return model, optimizer


def _train_one_epoch(model, optimizer, criterion, X, y, mask, batch_size):
    """Run one shuffled pass over the training data and return the mean batch loss.

    The caller's tensors are not modified; a fresh random permutation is
    drawn for this pass.
    """
    model.train()
    order = torch.randperm(len(X))
    X, y, mask = X[order], y[order], mask[order]

    epoch_loss = 0
    num_batches = 0
    for start in range(0, len(X), batch_size):
        stop = start + batch_size

        optimizer.zero_grad()
        outputs = model(X[start:stop])
        loss = criterion(outputs, y[start:stop], mask[start:stop])
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    return epoch_loss / num_batches


def _trait_qwks(pred, target, prompt_id):
    """Return {trait: QWK} for every trait scored on ``prompt_id``.

    Both predictions and targets are expected in normalized form and are
    denormalized with the prompt's score ranges before the QWK is computed.
    """
    qwks = {}
    for i, trait in enumerate(traits):
        if trait not in SCORE_RANGES[prompt_id]:
            continue
        score_range = SCORE_RANGES[prompt_id][trait]
        pred_scores = denormalize_score(pred[:, i].detach().cpu().numpy(), score_range)
        true_scores = denormalize_score(target[:, i].detach().cpu().numpy(), score_range)
        # round() removes float error introduced by the normalize/denormalize round trip
        qwks[trait] = quadratic_weighted_kappa(true_scores.round(), pred_scores)
    return qwks


#train and evaluate model for a given target prompt id.
#prompts are processed one after another, each call with its own id
def train_and_evaluate_prompt(test_prompt_id, data):
    """Tune, train and test a model with ``test_prompt_id`` held out.

    Steps:
      1. Grid search over hidden width, depth and learning rate using
         leave-one-prompt-out cross-validation on the remaining prompts.
      2. Batch-size search with the best grid-search parameters, validated
         on prompt 8 (prompt 7 when prompt 8 is the target).
      3. Final training on all non-target prompts, evaluation on the target
         prompt, and saving of the scripted model as ``model-B-<id>.pt``.

    Args:
        test_prompt_id: Prompt id (1-8) that is held out for testing.
        data: Dictionary of parallel arrays with 'prompt_ids', 'features'
            and one array per trait (``main`` passes the result of ``read_data``).

    Returns:
        Tuple ``(final_params, trait_qwks)``: the selected hyperparameters
        (including 'batch_size') and the per-trait QWK on the target prompt.

    Raises:
        RuntimeError: If no hyperparameter combination or batch size yields
            a comparable (non-NaN) QWK.
    """
    from itertools import product

    max_epochs = 15
    patience = 3

    print(f"\n{'-'*60}")
    print(f"model starts training process for target prompt {test_prompt_id}")
    print(f"{'-'*60}\n")

    #define train prompts to be all except test prompt
    train_prompt_range = [i for i in range(1, 9) if i != test_prompt_id]

    #prep data: every essay that does not belong to the target prompt
    non_test_rows = data['prompt_ids'] != test_prompt_id
    prompt_ids = data['prompt_ids'][non_test_rows]
    features = torch.FloatTensor(data['features'][non_test_rows]).to(device)
    #the network input width follows the data instead of a hard-coded constant
    input_size = features.shape[1]

    #NORMALIZATION + MASKING
    normalized_scores, score_masks = _build_targets(data, non_test_rows)
    normalized_scores = normalized_scores.to(device)
    score_masks = score_masks.to(device)

    print(f"Feature shape: {features.shape}")
    print(f"Scores shape: {normalized_scores.shape}")

    #grid search params
    k_fold = KFold(n_splits=7, shuffle=False) #k fold not shuffled, bc we need one prompt per fold cv
    hidden_units = [8, 16, 32]
    num_layers_options = [1, 2, 4, 8]
    learning_rates = [0.001, 0.01, 0.1]
    batch_size = 4

    #try all possible combinations of those params
    total_combinations = len(hidden_units) * len(num_layers_options) * len(learning_rates)
    print(f"Model does grid search with {total_combinations} combinations")
    print(f"- hidden units per layer (D): {hidden_units}")
    print(f"- num of layers (k): {num_layers_options}")
    print(f"- learning rates: {learning_rates}")
    print(f"- batch size (fixed): {batch_size}\n")

    criterion = MultivariateLossFn()
    best_avg_qwk = -1
    best_params = None
    # ------------------------------------------------------- GRID SEARCH -------------------------------------------------------
    grid = product(hidden_units, num_layers_options, learning_rates)
    for combination_count, (hidden_unit, num_layers, lr) in enumerate(grid, start=1):
        print(f"\n--------------- testing combo {combination_count}/{total_combinations}--------------")
        print(f"hyperparams: units per layer D={hidden_unit}, num of layers k={num_layers}, lr={lr}\n")

        fold_qwks_by_trait = {trait: [] for trait in traits}

        #7-fold cross-validation, one prompt per fold (heldout still heldout!)
        for train_prompt_idx, val_prompt_idx in k_fold.split(train_prompt_range):
            train_prompts = [train_prompt_range[i] for i in train_prompt_idx]
            val_prompt = train_prompt_range[val_prompt_idx[0]]

            train_filter = np.isin(prompt_ids, train_prompts)
            val_filter = (prompt_ids == val_prompt)

            X_train, X_val = features[train_filter], features[val_filter]
            y_train, y_val = normalized_scores[train_filter], normalized_scores[val_filter]
            mask_train, mask_val = score_masks[train_filter], score_masks[val_filter]

            model, optimizer = _make_model_and_optimizer(input_size, hidden_unit, num_layers, lr)

            best_val_loss = float('inf')
            epochs_no_improve = 0
            best_epoch_qwks = {}

            for epoch in range(max_epochs):
                _train_one_epoch(model, optimizer, criterion, X_train, y_train, mask_train, batch_size)

                #validation
                model.eval()
                with torch.no_grad():
                    val_pred = model(X_val)
                    val_loss = criterion(val_pred, y_val, mask_val).item()

                #early stopping and best-epoch selection use the loss on the
                #validation prompt, matching the batch size search below
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_epoch_qwks = _trait_qwks(val_pred, y_val, val_prompt)
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1

                if epochs_no_improve >= patience:
                    break

            #store best QWKs for this fold
            for trait, qwk in best_epoch_qwks.items():
                fold_qwks_by_trait[trait].append(qwk)

            #print avg QWK for this validation prompt
            prompt_avg_qwk = np.mean(list(best_epoch_qwks.values()))
            print(f"validation prompt no. {val_prompt}: Avg QWK of all scores: {prompt_avg_qwk:.4f}")

        #average QWK across all folds, per trait and then over traits
        avg_trait_qwks = {
            trait: np.mean(fold_qwks)
            for trait, fold_qwks in fold_qwks_by_trait.items()
            if fold_qwks
        }
        avg_qwk = np.mean(list(avg_trait_qwks.values()))
        print(f"\navg QWK across all prompts: {avg_qwk:.4f}")

        #tracking best params
        if avg_qwk > best_avg_qwk:
            best_avg_qwk = avg_qwk
            best_params = {
                'hidden_unit': hidden_unit,
                'num_layers': num_layers,
                'learning_rate': lr
            }
            print("BEST COMBO FOUND =)!")

    if best_params is None:
        #a NaN average never compares greater than the running best
        raise RuntimeError("grid search produced no valid average QWK; cannot select hyperparameters")
    # -------------------------------------------------------------------------------------------------------------------------------

    # ------------------------------------------------------- BATCH SIZE OPTIM -------------------------------------------------------
    #the validation prompt used for batch size optimization is always 8, unless the target prompt is 8, then it is 7
    validation_prompt_id = 8 if test_prompt_id != 8 else 7
    ranges = [p for p in train_prompt_range if p != validation_prompt_id] #all prompts except the test and validation prompt

    print(f"\n{'-'*50}")
    print("starting the batch size optim with best params from grid search")
    print(f"{'-'*50}\n")

    #split the already prepared non-target data into validation prompt vs. the rest
    val_filter = prompt_ids == validation_prompt_id
    train_filter = ~val_filter

    X_train, X_val = features[train_filter], features[val_filter]
    y_train, y_val = normalized_scores[train_filter], normalized_scores[val_filter]
    mask_train, mask_val = score_masks[train_filter], score_masks[val_filter]

    print("Dataset after separation:")
    print(f"Training: {len(X_train)} essays (prompts {ranges})")
    print(f"Validation: {len(X_val)} essays (prompt {validation_prompt_id})")
    print(f"Held out: Prompt {test_prompt_id} (for final testing)\n")

    batch_sizes = [4, 8, 16, 32]
    print(f"testing batch sizes: {batch_sizes}")

    best_batch_qwk = -1
    best_batch_size = None
    batch_start_time = time.time()

    for batch_size in batch_sizes:
        print(f"\n------------ batch size: {batch_size} --------------")

        model, optimizer = _make_model_and_optimizer(
            input_size,
            best_params['hidden_unit'],
            best_params['num_layers'],
            best_params['learning_rate'],
        )

        best_val_loss = float('inf')
        epochs_without_improvement = 0

        for epoch in range(max_epochs):
            avg_epoch_loss = _train_one_epoch(
                model, optimizer, criterion, X_train, y_train, mask_train, batch_size
            )

            # Validation step
            model.eval()
            with torch.no_grad():
                val_loss = criterion(model(X_val), y_val, mask_val).item()

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1

            if epochs_without_improvement >= patience:
                print(f"early stopping at epoch {epoch + 1}")
                break

            if (epoch + 1) % 5 == 0:
                print(f"Epoch {epoch + 1}/{max_epochs} - Train Loss: {avg_epoch_loss:.6f}, Val Loss: {val_loss:.6f}")

        #evaluate QWK on validation set for all traits
        model.eval()
        with torch.no_grad():
            val_pred = model(X_val)
        val_qwks = _trait_qwks(val_pred, y_val, validation_prompt_id)

        avg_qwk = np.mean(list(val_qwks.values()))
        print(f"avg validation QWK across all traits: {avg_qwk:.4f}")
        print("trait QWKs:")
        for trait, qwk in val_qwks.items():
            print(f"{trait}: {qwk:.4f}")

        #tracking best batch size
        if avg_qwk > best_batch_qwk:
            best_batch_qwk = avg_qwk
            best_batch_size = batch_size
            print("New best batch size found!!!!! :)")

    batch_time = time.time() - batch_start_time
    print(f"\n{'-'*50}")
    print("done with batch size optimization :D")
    print(f"total time taken: {batch_time:.2f} seconds ({batch_time/60:.2f} minutes)")
    print(f"best batch size: {best_batch_size}")
    print(f"best validation average QWK: {best_batch_qwk:.4f}")
    print(f"{'-'*50}\n")

    if best_batch_size is None:
        raise RuntimeError("batch size search produced no valid average QWK; cannot select a batch size")
    # -------------------------------------------------------------------------------------------------------------------------------

    # ------------------------------------------------------- FINAL MODEL TRAINING -------------------------------------------------------
    final_params = best_params.copy()
    final_params['batch_size'] = best_batch_size

    #final training data is all data except the test prompt, i.e. exactly the
    #features / normalized_scores / score_masks tensors prepared above
    final_model, optimizer = _make_model_and_optimizer(
        input_size,
        final_params['hidden_unit'],
        final_params['num_layers'],
        final_params['learning_rate'],
    )

    print("starting training the final model")
    for epoch in range(max_epochs):
        avg_epoch_loss = _train_one_epoch(
            final_model, optimizer, criterion,
            features, normalized_scores, score_masks,
            final_params['batch_size'],
        )
        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch + 1}/{max_epochs} - Training Loss: {avg_epoch_loss:.6f}")
    # -------------------------------------------------------------------------------------------------------------------------------

    #testing on the target held out prompt
    test_filter = data['prompt_ids'] == test_prompt_id
    X_test = torch.FloatTensor(data['features'][test_filter]).to(device)

    #NORMALIZATION + MASKING (the mask is not needed for QWK evaluation)
    y_test, _ = _build_targets(data, test_filter)

    #evaluate on test set, available traits only
    final_model.eval()
    with torch.no_grad():
        test_pred = final_model(X_test)
    trait_qwks = _trait_qwks(test_pred, y_test, test_prompt_id)

    print(f"\n{'-'*50}")
    print(f"Test set results on target prompt {test_prompt_id}:")
    for trait, qwk in trait_qwks.items():
        print(f"{trait.capitalize()} QWK: {qwk:.4f}")
    print(f"final best params: {final_params}")
    print(f"{'-'*50}\n")

    # Save the model (scripted on the CPU)
    final_model = final_model.cpu()
    scripted_model = torch.jit.script(final_model)
    model_path = f"model-B-{test_prompt_id}.pt"
    scripted_model.save(model_path)
    print(f"model was saved as {model_path}")

    return final_params, trait_qwks

#Train the deployment model on the complete dataset and save it
def train_deployment_model(data, best_overall_params):
    """Fit a model on every essay in ``data`` and save it as model-B-deploy.pt.

    ``best_overall_params`` supplies 'hidden_unit', 'num_layers',
    'learning_rate' and 'batch_size'. Nothing is returned; the scripted
    model is written to disk.
    """
    print(f"\n{'-'*50}")
    print(f"Training deployment model using best parameters:")
    print(f"Parameters: {best_overall_params}")
    print(f"{'-'*50}")

    features = torch.FloatTensor(data['features']).to(device)

    #normalized scores and masks for every essay, in dataset order
    per_essay = [
        normalize_all_scores({trait: data[trait][row] for trait in traits}, essay_prompt)
        for row, essay_prompt in enumerate(data['prompt_ids'])
    ]
    all_scores = torch.stack([scores for scores, _ in per_essay]).to(device)
    all_masks = torch.stack([mask for _, mask in per_essay]).to(device)

    #model for deployment
    deploy_model = NeuralNetwork(
        input_size=86,
        hidden_unit=best_overall_params['hidden_unit'],
        num_layers=best_overall_params['num_layers']
    ).to(device)

    optimizer = torch.optim.AdamW(
        deploy_model.parameters(),
        lr=best_overall_params['learning_rate'],
        betas=(0.9, 0.999),
        weight_decay=0.1
    )
    criterion = MultivariateLossFn()
    batch_size = best_overall_params['batch_size']

    #train on the whole dataset
    print("starting training of the deployment model")
    for epoch in range(15):
        deploy_model.train()
        running_loss = 0
        steps = 0

        #reshuffle features, targets and masks together for this epoch
        shuffle = torch.randperm(len(features))
        features = features[shuffle]
        all_scores = all_scores[shuffle]
        all_masks = all_masks[shuffle]

        for start in range(0, len(features), batch_size):
            window = slice(start, start + batch_size)

            optimizer.zero_grad()
            predictions = deploy_model(features[window])
            loss = criterion(predictions, all_scores[window], all_masks[window])
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            steps += 1

        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch + 1}/15 - Training Loss: {running_loss / steps:.6f}")

    #save the deployment model (scripted on the CPU)
    print("\nSaving deployment model...")
    deploy_model = deploy_model.cpu()
    model_path = "model-B-deploy.pt"
    torch.jit.script(deploy_model).save(model_path)
    print(f"Deployment model saved as {model_path}")

def main():
    """Run the cross-prompt evaluation for prompts 1-8, then train the deployment model.

    Each prompt is held out in turn via train_and_evaluate_prompt. The
    parameters of the run with the highest mean test QWK are reused for
    the deployment model.
    """
    rule = '-' * 50

    print(f"\n{rule}")
    print(f"starting full cross-prompt evaluation at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{rule}\n")

    data = read_data('/kaggle/input/dataset/dataset.csv')

    results = {}
    best_overall_params = None
    best_overall_avg_qwk = -1

    #run train and evaluate once per target prompt
    for target_prompt in range(1, 9):
        params, trait_qwks = train_and_evaluate_prompt(target_prompt, data)
        #keep every result; the best params are reused for deployment later
        results[target_prompt] = {'parameters': params, 'trait_qwks': trait_qwks}

        #best parameters = highest average QWK across the prompt's traits
        mean_qwk = np.mean(list(trait_qwks.values()))
        if mean_qwk > best_overall_avg_qwk:
            best_overall_avg_qwk = mean_qwk
            best_overall_params = params.copy()

    #final summary
    print("\nResults:")
    print(rule)
    for prompt, result in results.items():
        print(f"\nPrompt {prompt}:")
        for trait, qwk in result['trait_qwks'].items():
            print(f"  {trait.capitalize()} QWK: {qwk:.4f}")
        print(f"  Parameters: {result['parameters']}")
    print(rule)

    #collect the QWKs of every trait across all target prompts
    all_trait_qwks = {
        trait: [
            result['trait_qwks'][trait]
            for result in results.values()
            if trait in result['trait_qwks']
        ]
        for trait in traits
    }

    print("\navg QWK per trait:")
    for trait, qwks in all_trait_qwks.items():
        if not qwks:  #skip traits that were never scored
            continue
        print(f"{trait.capitalize()}: {np.mean(qwks):.4f}")

    #train the deployment model
    train_deployment_model(data, best_overall_params)

    print("\nFinal saved models:")
    print(rule)
    print("Individual prompt models:")
    for prompt_number in range(1, 9):
        print(f"  - model-B-{prompt_number}.pt")
    print("Deployment model:")
    print("  - model-B-deploy.pt")


if __name__ == "__main__":
    main()


Using device: cuda

--------------------------------------------------
Starting full cross-prompt evaluation at 2024-12-06 12:41:50
--------------------------------------------------

Reading the dataset


------------------------------------------------------------
model starts training process for target prompt 1
------------------------------------------------------------

Feature shape: torch.Size([11193, 86])
Scores shape: torch.Size([11193, 9])
Model does grid search with 36 combinations
- hidden units per layer (D): [8, 16, 32]
- num of layers (k): [1, 2, 4, 8]
- learning rates: [0.001, 0.01, 0.1]
- batch size (fixed): 4


--------------- testing combo 1/36--------------
hyperparams: units per layer D=8, num of layers k=1, lr=0.001

validation prompt no. 2: Avg QWK of all scores: 0.4904
validation prompt no. 3: Avg QWK of all scores: 0.5687
validation prompt no. 4: Avg QWK of all scores: 0.5412
validation prompt no. 5: Avg QWK of all scores: 0.5943
validation prompt no. 6: Avg 