# Data preprocessing

In [2]:
import json

# Mount Google Drive at /content/drive so files stored there can be opened
# from this Colab runtime like ordinary local files
from google.colab import drive
drive.mount('/content/drive')
# Location of the JSONL training data on the mounted drive
file_path = '/content/drive/MyDrive/Colab Notebooks/data_cged/training_data.jsonl'

# Load the data
def load_data(file_path):
    """Load grammatical/ungrammatical sentence pairs from a JSONL file.

    Every non-blank line must be a JSON object with the keys ``"correct"``
    (the corrected sentence), ``"text"`` (the original sentence) and
    ``"error"`` (a list of dicts that each carry a ``"type"`` key). Records
    whose ``"correct"`` is an empty string or whose ``"error"`` list is empty
    are skipped because they do not form a pair.

    Args:
        file_path: Path to the UTF-8 encoded JSONL file.

    Returns:
        A flat list of ``(sentence, label, error_types)`` tuples. Each kept
        record contributes two consecutive entries: the corrected sentence
        with label 1, followed by the original sentence with label 0. Both
        entries reference the same ``error_types`` list object.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the expected keys.
    """
    sentences_data = []

    with open(file_path, 'r', encoding='utf-8') as f:
        # Stream the file line by line instead of materialising every record
        for line in f:
            line = line.strip()
            # Blank lines (e.g. a trailing newline at EOF) are not valid JSON
            if not line:
                continue
            item = json.loads(line)

            # Remove non-pair entries (only a grammatical sentence, no errors)
            if item["correct"] != "" and item["error"] != []:
                error_type = [error["type"] for error in item["error"]]
                sentences_data.append((item["correct"], 1, error_type))
                sentences_data.append((item["text"], 0, error_type))

    return sentences_data

# Parse the JSONL file into a flat list of (sentence, label, error_types) tuples
sentences_data = load_data(file_path)

# Inspect: number of entries and the first ten tuples
print(f"Loaded {len(sentences_data)} sentences.")
print(sentences_data[:10])

Mounted at /content/drive
Loaded 82980 sentences.
[('因为庆祝会的日子是我母亲的生日。', 1, ['R']), ('是因为庆祝会的日子是我母亲的生日。', 0, ['R']), ('那下次见吧。', 1, ['R']), ('那下次见面吧。', 0, ['R']), ('我跟我朋友打算去法国玩儿。', 1, ['R']), ('我跟我朋唷友打算去法国玩儿。', 0, ['R']), ('所以我不能去。', 1, ['M']), ('所以我不能。', 0, ['M']), ('所以我写这一张卡送给你。', 1, ['R']), ('所以我写这一张卡送给你的。', 0, ['R'])]


# Computing log probabilities

In [3]:
import torch

# Prefer the GPU when CUDA is usable; otherwise stay on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [4]:
def get_sentence_logprobs(sentence, tokenizer, model):
    """Return the total log probability a causal LM assigns to ``sentence``.

    The sentence is tokenized, passed through ``model`` once, and the log
    probability of every token given its left context is summed. The first
    token of the encoded sequence is never scored because no position
    precedes it.

    NOTE(review): whether that first token is a start-of-sequence marker
    depends on the tokenizer. If a tokenizer does not prepend one, the first
    real token of the sentence goes unscored -- confirm for each model used.

    Args:
        sentence: Text to score.
        tokenizer: Callable returning an encoding with an ``input_ids``
            tensor when called with ``return_tensors="pt"``.
        model: Causal language model exposing a ``device`` attribute and
            returning an output whose ``logits`` have shape
            (batch_size, seq_len, vocab_size).

    Returns:
        The summed per-token log probability as a Python float (0.0 when the
        encoded sequence has a single token).
    """
    # Place the inputs on the model's own device instead of depending on a
    # notebook-level global that may be stale or undefined
    tokenized_input = tokenizer(sentence, return_tensors="pt").to(model.device)
    input_ids = tokenized_input.input_ids

    # Disable gradient calculation because the model is not used for training
    with torch.no_grad():
        outputs = model(input_ids)

        # logits shape: (batch_size, seq_len, vocab_size)
        logits = outputs.logits

        # Position i-1 predicts token i, so drop the last position's logits
        # and the first token id, then score all remaining tokens at once
        log_probs = torch.log_softmax(logits[0, :-1, :], dim=-1)
        target_ids = input_ids[0, 1:].unsqueeze(-1)
        token_log_probs = log_probs.gather(-1, target_ids).squeeze(-1)

        # Accumulate in float64 and transfer a single scalar to the host;
        # calling .item() per token would force one sync per token
        total_log_prob = token_log_probs.double().sum().item()

    return total_log_prob

# Models

## XGLM-564M

In [5]:
!pip install --upgrade transformers



In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "facebook/xglm-564M"

# Tokenizer and weights come from the same checkpoint name
tokenizer_xglm = AutoTokenizer.from_pretrained(model_name)
model_xglm = AutoModelForCausalLM.from_pretrained(model_name)
model_xglm = model_xglm.to(device)

# Inference only: switch to evaluation mode
model_xglm.eval()

In [8]:
import tqdm

# Extend every (sentence, label, error_type) entry with its XGLM log probability
sentences_data_with_probs = []
for sentence, label, error_type in tqdm.tqdm(sentences_data, desc="Processing"):
    scored_entry = (
        sentence,
        label,
        error_type,
        get_sentence_logprobs(sentence, tokenizer_xglm, model_xglm),
    )
    sentences_data_with_probs.append(scored_entry)

# Inspect the first pair
for text, gold_label, error_types, log_prob in sentences_data_with_probs[:2]:
    print(f"\nSentence: {text}")
    print(f"Label: {gold_label}, Error Type: {error_types}")
    print(f"Total Log Probability: {log_prob:.2f}")

Processing: 100%|██████████| 82980/82980 [37:28<00:00, 36.91it/s]


Sentence: 因为庆祝会的日子是我母亲的生日。
Label: 1, Error Type: ['R']
Total Log Probability: -50.02

Sentence: 是因为庆祝会的日子是我母亲的生日。
Label: 0, Error Type: ['R']
Total Log Probability: -55.52





# Evaluation

In [9]:
def evaluate_model(sentences_data_with_probs):
    """Score a model on sentence pairs by comparing total log probabilities.

    A pair counts as correct when the grammatical sentence received a
    strictly higher log probability than its ungrammatical counterpart; ties
    count as incorrect.

    Args:
        sentences_data_with_probs: Flat sequence of
            ``(sentence, label, error_types, total_log_prob)`` tuples in which
            entries ``2k`` and ``2k + 1`` are the grammatical and the
            ungrammatical member of pair ``k``.

    Returns:
        A tuple ``(overall_accuracy, overall_correct_pairs, total_pairs,
        accuracy_by_error_type)``. ``overall_accuracy`` is a percentage (0
        when there are no pairs). ``accuracy_by_error_type`` is a list of
        dicts with the keys ``'Error Type'``, ``'Accuracy (%)'``,
        ``'Correct'`` and ``'Total'``, ordered by error type, and only
        includes pairs annotated with exactly one error type.

    Raises:
        ValueError: If the input holds an odd number of entries, i.e. the
            last sentence has no partner.
    """
    # Fail before counting anything rather than with an IndexError mid-loop
    if len(sentences_data_with_probs) % 2 != 0:
        raise ValueError(
            "Expected consecutive (grammatical, ungrammatical) pairs, got "
            f"{len(sentences_data_with_probs)} entries."
        )

    correct_by_error_type = {}
    total_by_error_type = {}
    overall_correct_pairs = 0
    total_pairs = 0

    # Assuming correct and ungrammatical sentences are paired consecutively
    for i in range(0, len(sentences_data_with_probs), 2):
        grammatical_entry = sentences_data_with_probs[i]
        ungrammatical_entry = sentences_data_with_probs[i + 1]

        # Decide once per pair; the result feeds both the overall and the
        # per-error-type tallies
        pair_is_correct = grammatical_entry[3] > ungrammatical_entry[3]

        total_pairs += 1
        if pair_is_correct:
            overall_correct_pairs += 1

        # Handle cases with multiple error types by sorting and converting to
        # a tuple so the combination can serve as a dictionary key
        error_types_list = ungrammatical_entry[2]
        error_type_key = tuple(sorted(error_types_list)) if isinstance(error_types_list, list) else error_types_list

        total_by_error_type[error_type_key] = total_by_error_type.get(error_type_key, 0) + 1
        correct_by_error_type[error_type_key] = (
            correct_by_error_type.get(error_type_key, 0) + (1 if pair_is_correct else 0)
        )

    # Calculate overall accuracy
    overall_accuracy = (overall_correct_pairs / total_pairs) * 100 if total_pairs > 0 else 0

    # Calculate accuracy for each error type
    accuracy_by_error_type = []

    # ONLY consider combinations that consist of a single error type
    for error_type_key in sorted(total_by_error_type.keys()):
        if len(error_type_key) == 1:
            total = total_by_error_type[error_type_key]
            correct = correct_by_error_type[error_type_key]
            type_accuracy = (correct / total) * 100 if total > 0 else 0

            display_key = list(error_type_key) if isinstance(error_type_key, tuple) else error_type_key
            accuracy_by_error_type.append({'Error Type': display_key, 'Accuracy (%)': type_accuracy, 'Correct': correct, 'Total': total})

    return overall_accuracy, overall_correct_pairs, total_pairs, accuracy_by_error_type

In [10]:
# Compare the log probabilities within each pair
(
    overall_accuracy,
    overall_correct_pairs,
    total_pairs,
    accuracy_by_error_type,
) = evaluate_model(sentences_data_with_probs)

print(f"Overall Accuracy (XGLM-564M): {overall_accuracy:.2f}% ({overall_correct_pairs}/{total_pairs})")

print("\nAccuracy by Error Type:")
# Hardest error types first (ascending accuracy)
sorted_accuracy_by_error_type = sorted(accuracy_by_error_type, key=lambda row: row['Accuracy (%)'])

for row in sorted_accuracy_by_error_type:
    print(f"  Error Type: {row['Error Type']}: {row['Accuracy (%)']:.2f}% ({row['Correct']}/{row['Total']})")

Overall Accuracy (XGLM-564M): 81.07% (33636/41490)

Accuracy by Error Type:
  Error Type: ['M']: 54.05% (4241/7847)
  Error Type: ['W']: 81.25% (2097/2581)
  Error Type: ['S']: 83.25% (5184/6227)
  Error Type: ['R']: 96.11% (5613/5840)


## BLOOM-560M

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "bigscience/bloom-560m"

# Tokenizer and weights come from the same checkpoint name
tokenizer_bloom = AutoTokenizer.from_pretrained(model_name)
model_bloom = AutoModelForCausalLM.from_pretrained(model_name)
model_bloom = model_bloom.to(device)

# Inference only: switch to evaluation mode
model_bloom.eval()

In [12]:
# Extend every (sentence, label, error_type) entry with its BLOOM log probability
sentences_data_with_probs = []
for sentence, label, error_type in tqdm.tqdm(sentences_data, desc="Processing"):
    scored_entry = (
        sentence,
        label,
        error_type,
        get_sentence_logprobs(sentence, tokenizer_bloom, model_bloom),
    )
    sentences_data_with_probs.append(scored_entry)

# Inspect the first pair
for text, gold_label, error_types, log_prob in sentences_data_with_probs[:2]:
    print(f"\nSentence: {text}")
    print(f"Label: {gold_label}, Error Type: {error_types}")
    print(f"Total Log Probability: {log_prob:.2f}")

Processing: 100%|██████████| 82980/82980 [35:39<00:00, 38.78it/s]


Sentence: 因为庆祝会的日子是我母亲的生日。
Label: 1, Error Type: ['R']
Total Log Probability: -44.71

Sentence: 是因为庆祝会的日子是我母亲的生日。
Label: 0, Error Type: ['R']
Total Log Probability: -44.45





In [13]:
# Compare the log probabilities within each pair
(
    overall_accuracy,
    overall_correct_pairs,
    total_pairs,
    accuracy_by_error_type,
) = evaluate_model(sentences_data_with_probs)

print(f"Overall Accuracy (BLOOM-560M): {overall_accuracy:.2f}% ({overall_correct_pairs}/{total_pairs})")

print("\nAccuracy by Error Type:")
# Hardest error types first (ascending accuracy)
sorted_accuracy_by_error_type = sorted(accuracy_by_error_type, key=lambda row: row['Accuracy (%)'])

for row in sorted_accuracy_by_error_type:
    print(f"  Error Type: {row['Error Type']}: {row['Accuracy (%)']:.2f}% ({row['Correct']}/{row['Total']})")

Overall Accuracy (BLOOM-560M): 82.37% (34177/41490)

Accuracy by Error Type:
  Error Type: ['M']: 55.86% (4383/7847)
  Error Type: ['W']: 80.90% (2088/2581)
  Error Type: ['S']: 84.37% (5254/6227)
  Error Type: ['R']: 95.22% (5561/5840)


## Qwen3-0.6B

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "Qwen/Qwen3-0.6B"

# Tokenizer and weights come from the same checkpoint name
tokenizer_qwen = AutoTokenizer.from_pretrained(model_name)
model_qwen = AutoModelForCausalLM.from_pretrained(model_name)
model_qwen = model_qwen.to(device)

# Inference only: switch to evaluation mode
model_qwen.eval()

In [15]:
# Extend every (sentence, label, error_type) entry with its Qwen log probability
sentences_data_with_probs = []
for sentence, label, error_type in tqdm.tqdm(sentences_data, desc="Processing"):
    scored_entry = (
        sentence,
        label,
        error_type,
        get_sentence_logprobs(sentence, tokenizer_qwen, model_qwen),
    )
    sentences_data_with_probs.append(scored_entry)

# Inspect the first pair
for text, gold_label, error_types, log_prob in sentences_data_with_probs[:2]:
    print(f"\nSentence: {text}")
    print(f"Label: {gold_label}, Error Type: {error_types}")
    print(f"Total Log Probability: {log_prob:.2f}")

Processing: 100%|██████████| 82980/82980 [1:00:46<00:00, 22.76it/s]


Sentence: 因为庆祝会的日子是我母亲的生日。
Label: 1, Error Type: ['R']
Total Log Probability: -43.79

Sentence: 是因为庆祝会的日子是我母亲的生日。
Label: 0, Error Type: ['R']
Total Log Probability: -44.67





In [16]:
# Compare the log probabilities within each pair
overall_accuracy, overall_correct_pairs, total_pairs, accuracy_by_error_type = evaluate_model(sentences_data_with_probs)

# The scores in sentences_data_with_probs were computed with the Qwen model,
# so the headline must name Qwen3-0.6B (it previously said BLOOM-560M)
print(f"Overall Accuracy (Qwen3-0.6B): {overall_accuracy:.2f}% ({overall_correct_pairs}/{total_pairs})")

print("\nAccuracy by Error Type:")
# Sort by accuracy low to high before printing
sorted_accuracy_by_error_type = sorted(accuracy_by_error_type, key=lambda x: x['Accuracy (%)'])

for item in sorted_accuracy_by_error_type:
    print(f"  Error Type: {item['Error Type']}: {item['Accuracy (%)']:.2f}% ({item['Correct']}/{item['Total']})")

Overall Accuracy (Qwen3-0.6B): 79.07% (32807/41490)

Accuracy by Error Type:
  Error Type: ['M']: 51.89% (4072/7847)
  Error Type: ['W']: 72.03% (1859/2581)
  Error Type: ['S']: 81.87% (5098/6227)
  Error Type: ['R']: 93.41% (5455/5840)
