In [2]:
import json
import os
import re
from typing import Optional, Literal

import numpy as np
import polars as pl
import torch
import transformers
from scipy.stats import spearmanr
from sklearn.metrics import cohen_kappa_score
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

from utils import load_asap_dataset, load_toefl_dataset, get_score_range

In [10]:
# Dataset to evaluate. The loading cell below handles exactly two values:
# "ASAP" and "TOEFL11". Switch by swapping which assignment is commented out.
# TASK = "ASAP"
TASK = "TOEFL11"

In [11]:
# Load the dataset selected by TASK into `df`.
# Fail fast on an unrecognised task name: without the final branch `df` would
# be left undefined (or stale from a previous run) and later cells would fail
# with a confusing error far from the actual cause.
if TASK == "ASAP":
    df = load_asap_dataset('datasets/ASAP', stratify=True)
elif TASK == "TOEFL11":
    df = load_toefl_dataset('datasets/TOEFL11')
else:
    raise ValueError(f"Unsupported TASK: {TASK!r} (expected 'ASAP' or 'TOEFL11')")

In [12]:
# Display the loaded dataset for a quick visual check of its columns and dtypes.
df

essay_id,essay_set,original_score,essay,score
i64,i64,str,str,i32
10229,8,"""low""","""I DO NOT AGREE WITH THIS STATE…",0
10392,1,"""high""","""I am not quite sure about my o…",2
10445,3,"""medium""","""It is often said that young pe…",1
10535,7,"""medium""","""In mordern society, students a…",1
10769,1,"""high""","""I find it productive and rewar…",2
…,…,…,…,…
1175383,2,"""medium""",""" There are heated disscussion…",1
1175412,3,"""high""","""In my personal opinion young p…",2
1175488,8,"""medium""","""I believe that people who take…",1
1175980,8,"""low""","""With the rapid progress of tim…",0


In [47]:
# Load the saved predictions of one model for the current TASK and derive a
# numeric `pred_score` column; -1 marks predictions that could not be parsed.
#
# Alternative checkpoints (uncomment exactly one assignment to switch):
# model_name = "meta-llama/Llama-2-7b-chat-hf"
# model_name = "meta-llama/Llama-3.1-8B-Instruct"
# model_name = "meta-llama/Llama-3.2-3B-Instruct"
model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# The CSV file name uses the part of the model id after the organisation prefix.
results = pl.read_csv(f"outputs/vanilla/{TASK}_{model_name.split('/')[1]}.csv")
# Keep the raw prediction under a distinct name so it cannot be confused with
# the `score` column of the dataset after the join.
results = results.rename({"score": "pred_str"})

if TASK == "TOEFL11":
    # Labels are matched case-insensitively ("Low" and "low" both map to 0).
    pred_label = pl.col("pred_str").str.to_lowercase()
    results = results.with_columns(
        pl.when(pred_label == "low")
        .then(0)
        .when(pred_label == "medium")
        .then(1)
        .when(pred_label == "high")
        .then(2)
        .otherwise(-1)
        .alias("pred_score")
    )
else:
    # Previously no `pred_score` column was created for other tasks, so the
    # downstream filter and metric cells failed with a missing-column error.
    # NOTE(review): assumes non-TOEFL11 predictions are integer scores — confirm
    # against the generation script. Values that cannot be cast become -1.
    results = results.with_columns(
        pl.col("pred_str")
        .cast(pl.Int32, strict=False)
        .fill_null(-1)
        .alias("pred_score")
    )
results

essay_set,essay_id,response,pred_str,pred_score
i64,i64,str,str,i32
8,10229,"""Evaluation: The essay presents…","""Low""",0
1,10392,"""Evaluation: The essay presents…","""medium""",1
3,10445,"""Evaluation: The essay presents…","""low""",0
7,10535,"""Evaluation: The essay presents…","""high""",2
1,10769,"""Evaluation: The essay presents…","""high""",2
…,…,…,…,…
2,1175383,"""Evaluation: The essay presents…","""medium""",1
3,1175412,"""Evaluation: The essay presents…","""medium""",1
8,1175488,"""Evaluation: The essay presents…","""medium""",1
8,1175980,"""Evaluation: The essay presents…","""low""",0


In [48]:
# Summary statistics of the predictions. A `pred_score` minimum of -1 would
# mean some predictions were not mapped to a valid score.
results.describe()

statistic,essay_set,essay_id,response,pred_str,pred_score
str,f64,f64,str,str,f64
"""count""",1099.0,1099.0,"""1099""","""1099""",1099.0
"""null_count""",0.0,0.0,"""0""","""0""",0.0
"""mean""",4.484986,985185.592357,,,1.050045
"""std""",2.325612,307937.154371,,,0.80314
"""min""",1.0,10229.0,"""Evaluation: In Mokhtar Motamed…","""Low""",0.0
"""25%""",2.0,1030800.0,,,0.0
"""50%""",4.0,1080962.0,,,1.0
"""75%""",7.0,1127200.0,,,2.0
"""max""",8.0,1176184.0,"""Evaluation: The essay provides…","""medium""",2.0


In [49]:
# Pair every essay in the dataset with its prediction. The join is an inner
# join on the (essay_set, essay_id) key, so essays without a prediction (and
# predictions without an essay) are dropped.
for_evaluation = df.join(
    results,
    on=["essay_set", "essay_id"],
    how="inner",
)
for_evaluation

essay_id,essay_set,original_score,essay,score,response,pred_str,pred_score
i64,i64,str,str,i32,str,str,i32
10229,8,"""low""","""I DO NOT AGREE WITH THIS STATE…",0,"""Evaluation: The essay presents…","""Low""",0
10392,1,"""high""","""I am not quite sure about my o…",2,"""Evaluation: The essay presents…","""medium""",1
10445,3,"""medium""","""It is often said that young pe…",1,"""Evaluation: The essay presents…","""low""",0
10535,7,"""medium""","""In mordern society, students a…",1,"""Evaluation: The essay presents…","""high""",2
10769,1,"""high""","""I find it productive and rewar…",2,"""Evaluation: The essay presents…","""high""",2
…,…,…,…,…,…,…,…
1175383,2,"""medium""",""" There are heated disscussion…",1,"""Evaluation: The essay presents…","""medium""",1
1175412,3,"""high""","""In my personal opinion young p…",2,"""Evaluation: The essay presents…","""medium""",1
1175488,8,"""medium""","""I believe that people who take…",1,"""Evaluation: The essay presents…","""medium""",1
1175980,8,"""low""","""With the rapid progress of tim…",0,"""Evaluation: The essay presents…","""low""",0


In [50]:
# Keep only rows with a usable prediction; -1 is the sentinel assigned to
# predictions that did not match any known label.
for_evaluation = for_evaluation.filter(pl.col("pred_score").ne(-1))
for_evaluation

essay_id,essay_set,original_score,essay,score,response,pred_str,pred_score
i64,i64,str,str,i32,str,str,i32
10229,8,"""low""","""I DO NOT AGREE WITH THIS STATE…",0,"""Evaluation: The essay presents…","""Low""",0
10392,1,"""high""","""I am not quite sure about my o…",2,"""Evaluation: The essay presents…","""medium""",1
10445,3,"""medium""","""It is often said that young pe…",1,"""Evaluation: The essay presents…","""low""",0
10535,7,"""medium""","""In mordern society, students a…",1,"""Evaluation: The essay presents…","""high""",2
10769,1,"""high""","""I find it productive and rewar…",2,"""Evaluation: The essay presents…","""high""",2
…,…,…,…,…,…,…,…
1175383,2,"""medium""",""" There are heated disscussion…",1,"""Evaluation: The essay presents…","""medium""",1
1175412,3,"""high""","""In my personal opinion young p…",2,"""Evaluation: The essay presents…","""medium""",1
1175488,8,"""medium""","""I believe that people who take…",1,"""Evaluation: The essay presents…","""medium""",1
1175980,8,"""low""","""With the rapid progress of tim…",0,"""Evaluation: The essay presents…","""low""",0


In [51]:
# Compute quadratic weighted kappa (QWK) and Spearman's rank correlation
# between `score` and `pred_score`, separately for each essay_set.
qwk_scores = []
spearman_scores = []
# Series.unique() does not guarantee any particular order; sort the ids so the
# result tables are stable from run to run.
for essay_set in sorted(for_evaluation['essay_set'].unique()):
    subset = for_evaluation.filter(pl.col('essay_set') == essay_set)
    # Convert once and reuse the arrays for both metrics.
    y_true = subset['score'].to_numpy()
    y_pred = subset['pred_score'].to_numpy()

    # Pass the full label range explicitly so the kappa weight matrix covers
    # every possible score, including ones absent from this subset.
    min_score, max_score = get_score_range(TASK, essay_set)
    qwk = cohen_kappa_score(
        y_true,
        y_pred,
        weights='quadratic',
        labels=list(range(min_score, max_score + 1)),
    )

    # spearmanr returns NaN when either input is constant; such a value is
    # stored as-is so that it remains visible in the table below.
    spearman_corr, _ = spearmanr(y_true, y_pred)

    qwk_scores.append({
        'essay_set': essay_set,
        'qwk': qwk
    })
    spearman_scores.append({
        'essay_set': essay_set,
        'spearman_corr': spearman_corr
    })

qwk_df = pl.DataFrame(qwk_scores)
spearman_df = pl.DataFrame(spearman_scores)
print("QWK scores by essay set:")
print(qwk_df)
print("スピアマンの順位相関係数 by essay set:")
print(spearman_df)

QWK scores by essay set:
shape: (8, 2)
┌───────────┬──────────┐
│ essay_set ┆ qwk      │
│ ---       ┆ ---      │
│ i64       ┆ f64      │
╞═══════════╪══════════╡
│ 1         ┆ 0.234668 │
│ 2         ┆ 0.127877 │
│ 3         ┆ 0.174229 │
│ 4         ┆ 0.105615 │
│ 5         ┆ 0.050473 │
│ 6         ┆ 0.045932 │
│ 7         ┆ 0.106004 │
│ 8         ┆ 0.22212  │
└───────────┴──────────┘
スピアマンの順位相関係数 by essay set:
shape: (8, 2)
┌───────────┬───────────────┐
│ essay_set ┆ spearman_corr │
│ ---       ┆ ---           │
│ i64       ┆ f64           │
╞═══════════╪═══════════════╡
│ 1         ┆ 0.271581      │
│ 2         ┆ 0.126347      │
│ 3         ┆ 0.185191      │
│ 4         ┆ 0.144957      │
│ 5         ┆ 0.029982      │
│ 6         ┆ 0.042369      │
│ 7         ┆ 0.140897      │
│ 8         ┆ 0.241386      │
└───────────┴───────────────┘


In [52]:
# Report the unweighted mean of each metric across essay sets (every set
# contributes equally regardless of how many essays it contains).
print(
    f"QWK mean: {qwk_df['qwk'].mean():.3f}",
    f"スピアマンの順位相関係数 mean: {spearman_df['spearman_corr'].mean():.3f}",
    sep="\n",
)

QWK mean: 0.133
スピアマンの順位相関係数 mean: 0.148
