In [10]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import pandas as pd
import re
from collections import Counter


In [12]:
df = pd.read_csv("C:/Users/Vijay/Apy/A_LSTM-BERT/data/CELA.csv")

labels = [
    "Grammar", "Lexical", "Global Organization",
    "Local Organization", "Supporting Ideas", "Holistic"
]

X_train, X_val, y_train, y_val = train_test_split(
    df["Essays"].tolist(),
    df[labels].values,
    test_size=0.2,
    random_state=42
)


In [13]:
def tokenize(text):
    text = text.lower()
    text = re.sub(r"[^a-z\s]", "", text)
    return text.split()


In [14]:
def build_vocab(texts, min_freq=2):
    counter = Counter()
    for text in texts:
        counter.update(tokenize(text))

    vocab = {"<PAD>": 0, "<UNK>": 1}
    for word, freq in counter.items():
        if freq >= min_freq:
            vocab[word] = len(vocab)

    return vocab


In [15]:
vocab = build_vocab(X_train)


In [16]:
class EssayLSTMDataset(Dataset):
    def __init__(self, texts, labels, vocab):
        self.texts = texts
        self.labels = torch.tensor(labels, dtype=torch.float)
        self.vocab = vocab

    def encode(self, text):
        return torch.tensor(
            [self.vocab.get(w, self.vocab["<UNK>"]) for w in tokenize(text)],
            dtype=torch.long
        )

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return {
            "input_ids": self.encode(self.texts[idx]),
            "labels": self.labels[idx]
        }


In [17]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    input_ids = [b["input_ids"] for b in batch]
    labels = torch.stack([b["labels"] for b in batch])

    input_ids = pad_sequence(
        input_ids,
        batch_first=True,
        padding_value=0
    )

    return {
        "input_ids": input_ids,
        "labels": labels
    }


In [18]:
train_ds = EssayLSTMDataset(X_train, y_train, vocab)
val_ds   = EssayLSTMDataset(X_val, y_val, vocab)

train_loader = DataLoader(
    train_ds,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_ds,
    batch_size=4,
    collate_fn=collate_fn
)


In [19]:
class LSTMMultiOutputRegressor(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_outputs=6):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            batch_first=True,
            bidirectional=True
        )
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_dim * 2, num_outputs)

    def forward(self, x):
        x = self.embedding(x)
        _, (h, _) = self.lstm(x)

        h = torch.cat((h[-2], h[-1]), dim=1)
        h = self.dropout(h)

        return self.fc(h)


In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = LSTMMultiOutputRegressor(len(vocab)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()


In [22]:
EPOCHS = 10
best_rmse = float("inf")

for epoch in range(EPOCHS):
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()

        outputs = model(batch["input_ids"].to(device))
        loss = loss_fn(outputs, batch["labels"].to(device))

        loss.backward()
        optimizer.step()

    model.eval()
    preds, gold = [], []
    with torch.no_grad():
        for batch in val_loader:
            out = model(batch["input_ids"].to(device))
            preds.append(out.cpu())
            gold.append(batch["labels"])

    preds = torch.cat(preds).numpy()
    gold = torch.cat(gold).numpy()

    rmse = mean_squared_error(gold, preds, squared=False)
    print(f"Epoch {epoch+1} | RMSE: {rmse:.4f}")

    if rmse < best_rmse:
        best_rmse = rmse
        torch.save(
            {
                "model_state": model.state_dict(),
                "vocab": vocab
            },
            "C:/Users/Vijay/Apy/A_LSTM-BERT/models/lstm_cela_regressor.pt"
        )


Epoch 1 | RMSE: 0.9965
Epoch 2 | RMSE: 1.0034
Epoch 3 | RMSE: 0.9930
Epoch 4 | RMSE: 1.0959
Epoch 5 | RMSE: 0.9667
Epoch 6 | RMSE: 0.9952
Epoch 7 | RMSE: 0.9674
Epoch 8 | RMSE: 0.9819
Epoch 9 | RMSE: 1.1148
Epoch 10 | RMSE: 1.0672
