# In [None]:
# %% Import Libraries
import pandas as pd
import re
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from gensim.models import FastText
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# %% Load Pre-trained FastText Model
# `FastText.load_fasttext_format` was deprecated in Gensim 3.x and removed in
# Gensim 4.0. `load_facebook_model` is the supported loader for Facebook's
# native ``.bin`` files and returns a FastText model whose vectors live in `.wv`.
from gensim.models.fasttext import load_facebook_model

fasttext_model = load_facebook_model('path_to_your_fasttext_model.bin')  # Update the path

# %% Load and Preprocess Data
# The cells below read the "text" and "label" columns of this frame.
df = pd.read_csv("datasets/train.csv")

# Simple cleaning and tokenization
# The character class is written with \u escapes so that a wrong file encoding
# can never corrupt it again: U+0900-U+097F is the Unicode Devanagari block.
_NON_DEVANAGARI_RE = re.compile(r'[^\u0900-\u097F\s]')


def preprocess_text(text):
    """Remove everything except Devanagari characters and whitespace, then tokenize.

    Args:
        text: Raw input string.

    Returns:
        List of whitespace-separated tokens. Empty when nothing survives the
        cleaning (e.g. purely Latin or numeric input).

    Raises:
        TypeError: If *text* is not a string (for example a NaN from pandas).

    Note:
        Removed characters are deleted, not replaced by a space, so words that
        were separated only by punctuation are joined. Characters inside the
        Devanagari block (including Devanagari digits and the danda) are kept.
    """
    return _NON_DEVANAGARI_RE.sub('', text).split()

# Drop incomplete rows BEFORE tokenizing: preprocess_text applies a regex
# substitution to every value, which raises TypeError on the float NaN that
# pandas uses for a missing text cell. Dropping afterwards would be too late.
df.dropna(inplace=True)
df['text'] = df['text'].apply(preprocess_text)

# Split the data (80/20, fixed seed for reproducibility)
X_train, X_test, y_train, y_test = train_test_split(df["text"], df["label"], test_size=0.2, random_state=42)

# %% Dataset Class
class DevanagariDataset(Dataset):
    """Map-style dataset that turns tokenized texts into padded embedding matrices.

    Each item is a pair ``(text_tensor, label_tensor)``:

    * ``text_tensor`` is a float tensor of shape ``(max_length, embedding_dim)``
      with one word vector per token position. Positions past the end of the
      text, and tokens the embedding model does not contain, stay all-zero.
    * ``label_tensor`` is the sample's label as a scalar ``torch.long`` tensor.

    Args:
        texts: Object with ``.tolist()`` (e.g. a pandas Series) yielding one
            iterable of tokens per sample.
        labels: Object with ``.tolist()`` yielding one integer label per sample.
        max_length: Number of token positions kept per sample; longer texts
            are truncated.
        embedding_model: Model exposing its word vectors as ``.wv``, which must
            support ``word in wv`` and ``wv[word]``. It is only accessed in
            ``__getitem__``, so a missing model fails there, not here.
        embedding_dim: Width of one word vector. When omitted it is read from
            ``embedding_model.wv.vector_size`` and falls back to 100 if that
            attribute is unavailable.
    """

    # Fallback width, kept for compatibility with the previous hard-coded value.
    DEFAULT_EMBEDDING_DIM = 100

    def __init__(self, texts, labels, max_length=100, embedding_model=None, embedding_dim=None):
        self.texts = texts.tolist()
        self.labels = labels.tolist()
        self.max_length = max_length
        self.embedding_model = embedding_model

        if embedding_dim is None:
            # Follow the actual model so that e.g. 300-d vectors fit the rows.
            vectors = getattr(embedding_model, "wv", None)
            embedding_dim = getattr(vectors, "vector_size", self.DEFAULT_EMBEDDING_DIM)
        self.embedding_dim = int(embedding_dim)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Start from zeros: padding and unknown words need no further work.
        text_tensor = torch.zeros(self.max_length, self.embedding_dim)
        vectors = self.embedding_model.wv

        # zip() with range() stops after max_length tokens instead of walking
        # the rest of a long text for nothing.
        for i, word in zip(range(self.max_length), self.texts[idx]):
            if word in vectors:
                text_tensor[i] = torch.tensor(vectors[word], dtype=torch.float)

        label_tensor = torch.tensor(self.labels[idx], dtype=torch.long)

        return text_tensor, label_tensor

# %% Create Datasets and DataLoaders
# Both splits are embedded with the same FastText model.
train_dataset = DevanagariDataset(texts=X_train, labels=y_train, embedding_model=fasttext_model)
test_dataset = DevanagariDataset(texts=X_test, labels=y_test, embedding_model=fasttext_model)

# Only the training batches are shuffled; evaluation keeps the dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

# %% Define the Model
class Attention(nn.Module):
    """Attention pooling that collapses a sequence of features into one vector.

    Every time step ``h`` is scored as ``Va(tanh(Wa(h) + Ua(h)))``; the scores
    are softmax-normalised along the sequence axis (dim 1) and used as weights
    for a weighted sum of the inputs.

    Args:
        hidden_size: Size of the last dimension of the input features.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, lstm_output):
        """Pool ``lstm_output`` of shape ``(batch, seq_len, hidden_size)``.

        Returns:
            Tuple ``(context, weights)`` where ``context`` has shape
            ``(batch, hidden_size)`` and ``weights`` has shape
            ``(batch, seq_len)`` and sums to 1 along ``seq_len``.
        """
        projected = self.Wa(lstm_output) + self.Ua(lstm_output)
        energies = self.Va(projected.tanh())            # (batch, seq_len, 1)
        attention_weights = energies.softmax(dim=1)     # normalise over time steps

        # (batch, 1, seq_len) x (batch, seq_len, hidden) -> (batch, 1, hidden)
        pooled = torch.bmm(attention_weights.transpose(1, 2), lstm_output)
        return pooled.squeeze(1), attention_weights.squeeze(2)

class BidirectionalLSTMModel(nn.Module):
    """Bidirectional LSTM encoder with attention pooling and a linear classifier.

    Args:
        input_size: Size of each input feature vector (the embedding width).
        hidden_size: Hidden size of the LSTM per direction; the attention and
            classifier layers operate on ``2 * hidden_size`` features.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
                            batch_first=True, bidirectional=True)
        # Forward and backward states are concatenated, hence the factor 2.
        self.attention = Attention(hidden_size * 2)
        self.fc = nn.Linear(hidden_size * 2, num_classes)

    def forward(self, x):
        """Return unnormalised class scores of shape ``(batch, num_classes)``.

        Args:
            x: Input of shape ``(batch, seq_len, input_size)``
                (``batch_first=True``).
        """
        # No explicit (h0, c0) is passed: nn.LSTM already defaults both to
        # zeros on the input's device. The previous hand-built tensors were
        # never handed to the LSTM and only tied this module to a global
        # `device` variable.
        out, _ = self.lstm(x)

        context_vector, _ = self.attention(out)
        return self.fc(context_vector)

# %% Set Device and Hyperparameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Read the embedding width from the loaded model rather than assuming 100, so
# the LSTM input layer always matches the vectors that are actually used.
input_size = int(fasttext_model.wv.vector_size)  # FastText embedding size
hidden_size = 128  # per direction; the LSTM is bidirectional
num_layers = 2
num_classes = df['label'].nunique()  # one output score per distinct label
num_epochs = 10
learning_rate = 0.001

model = BidirectionalLSTMModel(input_size, hidden_size, num_layers, num_classes)
model.to(device)

# %% Define the Training Function
def _run_epoch(model, batches, criterion, device, optimizer=None):
    """Run one pass over *batches*: train if *optimizer* is given, else evaluate.

    Returns:
        Tuple ``(loss_sum, correct, seen, num_batches)``: the sum of the
        per-batch losses, the number of correct predictions, the number of
        samples processed and the number of batches processed.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum, correct, seen, num_batches = 0.0, 0, 0, 0
    with torch.set_grad_enabled(training):
        for texts, labels in batches:
            texts, labels = texts.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(texts)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
            num_batches += 1

    return loss_sum, correct, seen, num_batches


def train_model(model, train_loader, test_loader, criterion, optimizer, num_epochs):
    """Train *model* for *num_epochs* epochs, evaluating on *test_loader* after each.

    Batches are moved to the device the model's parameters live on, so the
    function does not rely on a module-level ``device`` variable. After every
    epoch the mean per-batch loss and the accuracy of both splits are printed.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(inputs, labels)`` evaluation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating the parameters of *model*.
        num_epochs: Number of passes over *train_loader*.

    Returns:
        List with one dict per epoch holding ``epoch``, ``train_loss``,
        ``train_accuracy``, ``test_loss`` and ``test_accuracy``.
    """
    try:
        run_device = next(model.parameters()).device
    except StopIteration:
        # A parameter-free model has no device of its own; stay on the CPU.
        run_device = torch.device("cpu")

    history = []
    for epoch in range(num_epochs):
        train_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} - Training", leave=False)
        loss_sum, correct, seen, num_batches = _run_epoch(
            model, train_bar, criterion, run_device, optimizer=optimizer
        )
        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        train_loss = loss_sum / max(num_batches, 1)
        # Divide by the samples actually seen: equals len(dataset) unless the
        # loader drops incomplete batches, in which case this stays correct.
        train_accuracy = correct / max(seen, 1)

        test_bar = tqdm(test_loader, desc=f"Epoch {epoch + 1}/{num_epochs} - Testing", leave=False)
        loss_sum, correct, seen, num_batches = _run_epoch(
            model, test_bar, criterion, run_device
        )
        test_loss = loss_sum / max(num_batches, 1)
        test_accuracy = correct / max(seen, 1)

        print(f"Epoch {epoch + 1}/{num_epochs}")
        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

        history.append({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_accuracy": train_accuracy,
            "test_loss": test_loss,
            "test_accuracy": test_accuracy,
        })

    return history

# %% Train the Model
# Cross-entropy loss on the model's class scores, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

train_model(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
)

# %% Save the Model
# Only the weights (state_dict) are stored; loading them requires rebuilding
# the model with the same constructor arguments first.
trained_weights = model.state_dict()
torch.save(trained_weights, "lstm_attention_model_fasttext.pth")
