In [None]:
from google.colab import drive
import os

# Mount Google Drive at /content/drive so files stored there can be read by path.
drive.mount('/content/drive')

# IPython shell escape: run nvidia-smi to display the GPU status of this runtime.
!nvidia-smi

In [None]:
# Import necessary libraries
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5ForConditionalGeneration, T5Tokenizer
from sklearn.metrics import classification_report, accuracy_score
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Configuration: dataset location and the pretrained checkpoint shared by
# the tokenizer and the model.
DATA_PATH = '/content/drive/MyDrive/project/Jfleg4-2-4.csv'
MODEL_NAME = "google/flan-t5-large"

# Load the CSV, drop every row that has a missing value in any column, then
# hold out 20% of the remaining rows for validation (fixed seed for a
# repeatable split).
data = pd.read_csv(DATA_PATH)
data = data.dropna()
train_data, val_data = train_test_split(data, random_state=42, test_size=0.2)

# Model and tokenizer loading: use the GPU when one is available, otherwise
# fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
model = model.to(device)

In [None]:
# Dataset definition: tokenizes one sentence pair per item for seq2seq fine-tuning
class GrammarCorrectionDataset(Dataset):
    """Map-style dataset that tokenizes sentence pairs for seq2seq fine-tuning.

    Each item is built from one row of ``data``: the text in ``input_column``
    becomes the encoder input and the text in ``target_column`` becomes the
    labels.

    Args:
        data: pandas DataFrame holding at least ``input_column`` and
            ``target_column``; rows are accessed positionally via ``iloc``.
        tokenizer: callable tokenizer that accepts ``max_length``,
            ``padding``, ``truncation`` and ``return_tensors`` and returns a
            mapping with ``'input_ids'`` and ``'attention_mask'`` tensors of
            shape ``(1, max_length)``.
        max_length: fixed length every sequence is padded/truncated to.
        input_column: name of the column fed to the encoder.
        target_column: name of the column used as the training target.

    NOTE(review): the defaults keep the original mapping (encoder input taken
    from 'output', target taken from 'source'). If 'source' holds the
    erroneous sentence and 'output' the corrected one, this trains the
    reverse of grammar correction -- confirm, and pass
    ``input_column='source', target_column='output'`` if so.
    """

    def __init__(self, data, tokenizer, max_length=128,
                 input_column='output', target_column='source'):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.input_column = input_column
        self.target_column = target_column

    def __len__(self):
        """Return the number of rows (sentence pairs) in the dataset."""
        return len(self.data)

    def _encode(self, text):
        """Tokenize ``text`` into fixed-length tensors of shape (1, max_length)."""
        # Use the tokenizer passed to the constructor, never a notebook global.
        return self.tokenizer(
            str(text),
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for row ``idx``.

        All three tensors are 1-D with length ``max_length``. Padded label
        positions are set to -100 so the loss ignores them.
        """
        row = self.data.iloc[idx]  # single positional lookup per item

        input_encoding = self._encode(row[self.input_column])
        target_encoding = self._encode(row[self.target_column])

        labels = target_encoding['input_ids'].squeeze(0)
        target_mask = target_encoding['attention_mask'].squeeze(0)
        # -100 is the ignore index of the cross-entropy loss; without this
        # the model is also trained to predict the padding tokens.
        labels = labels.masked_fill(target_mask == 0, -100)

        return {
            'input_ids': input_encoding['input_ids'].squeeze(0),
            'attention_mask': input_encoding['attention_mask'].squeeze(0),
            'labels': labels,
        }


In [None]:
# Wrap each split in the tokenizing dataset.
train_dataset = GrammarCorrectionDataset(data=train_data, tokenizer=tokenizer)
val_dataset = GrammarCorrectionDataset(data=val_data, tokenizer=tokenizer)

# DataLoader setup: batches of 8 for both splits; only the training loader
# shuffles, the validation loader keeps a fixed order.
val_loader = DataLoader(dataset=val_dataset, batch_size=8, shuffle=False)
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)


In [None]:
# Model training: fine-tune for a fixed number of epochs and report the
# sample-weighted mean training and validation loss after each one.
epochs = 10
batch_size = 8
learning_rate = 2e-5

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# NOTE: not used by the loop below -- the loss comes from `outputs.loss`,
# which the model computes itself when `labels` is passed. Kept only in case
# another cell references it.
criterion = torch.nn.CrossEntropyLoss()

# Rebuild the loaders so they follow the batch_size configured in this cell.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    train_loss_sum = 0.0
    train_samples = 0
    for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs} [train]', leave=False):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # Weight each batch by its size so a short final batch does not skew the mean.
        train_loss_sum += loss.item() * input_ids.size(0)
        train_samples += input_ids.size(0)

    # ---- validation pass (no gradients) ----
    model.eval()
    total_loss = 0
    total_samples = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f'Epoch {epoch+1}/{epochs} [val]', leave=False):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            total_loss += loss.item() * input_ids.size(0)
            total_samples += input_ids.size(0)

    # An empty loader yields NaN instead of raising ZeroDivisionError.
    avg_train_loss = train_loss_sum / train_samples if train_samples else float('nan')
    avg_loss = total_loss / total_samples if total_samples else float('nan')
    print(f'Epoch {epoch+1}/{epochs}, Training Loss: {avg_train_loss:.4f}')
    print(f'Epoch {epoch+1}/{epochs}, Validation Loss: {avg_loss:.4f}')

# # Plot confusion matrix
# cm = confusion_matrix(list(plot_dict.values()), list(plot_dict.keys()))
# sns.heatmap(cm, annot=True)

# # Step 5: Show the result
# print(classification_report(list(plot_dict.values()), list(plot_dict.keys())))
# print("Accuracy:", accuracy_score(list(plot_dict.values()), list(plot_dict.keys())))