# In [1]:
# %% Import necessary libraries
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# %% Load the dataset
df = pd.read_csv("train.csv")
# Only the "text" and "label" columns are used below, so only drop rows missing
# one of those; a bare dropna() would also discard rows whose unused columns
# happen to contain NaN.
df = df.dropna(subset=["text", "label"])

# CrossEntropyLoss requires integer class ids in [0, n_classes). The raw labels
# may be strings or non-contiguous ints (TODO confirm against train.csv), so map
# them onto 0..n-1. With sort=True, labels that are already 0..n-1 map to
# themselves; label_names[i] recovers the original label for class id i.
label_codes, label_names = pd.factorize(df["label"], sort=True)
df["label"] = label_codes

# %% Split the dataset
X_train, X_test, y_train, y_test = train_test_split(
    df["text"], df["label"], test_size=0.2, random_state=42
)

# %% Create a custom Dataset class for BERT
class DevanagariDataset(Dataset):
    """Map-style dataset that tokenizes one text on demand for a BERT classifier.

    Args:
        texts: Sequence of input strings (a pandas Series, a list, ...).
        labels: Sequence of integer class ids, same length as ``texts``.
        tokenizer: Hugging Face-style tokenizer; it is called once per item
            with ``return_tensors='pt'``.
        max_length: Every example is padded/truncated to this many tokens.

    Raises:
        ValueError: if ``texts`` and ``labels`` have different lengths.

    Each item is a dict with 1-D ``input_ids`` and ``attention_mask`` tensors
    of length ``max_length`` and a 0-d ``labels`` tensor of dtype long.
    """

    def __init__(self, texts, labels, tokenizer, max_length=100):
        # list() accepts a Series, a list, or any other iterable, unlike .tolist().
        self.texts = list(texts)
        self.labels = list(labels)
        # Fail fast here rather than with an IndexError deep inside a DataLoader.
        if len(self.texts) != len(self.labels):
            raise ValueError(
                f"texts and labels differ in length: "
                f"{len(self.texts)} != {len(self.labels)}"
            )
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Calling the tokenizer directly is the supported replacement for the
        # deprecated encode_plus(); the keyword arguments are the same.
        encoding = self.tokenizer(
            str(self.texts[idx]),
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=False,
            return_tensors="pt",
        )
        # return_tensors='pt' yields a (1, max_length) batch; drop that leading
        # axis so the DataLoader can stack items into (batch, max_length).
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
        }

# %% Initialize the tokenizer and datasets
# The tokenizer checkpoint must match the encoder the classifier loads.
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
train_dataset = DevanagariDataset(X_train, y_train, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = DevanagariDataset(X_test, y_test, tokenizer)
# Evaluation gains nothing from shuffling; a fixed order keeps predictions
# aligned with X_test / y_test and makes evaluation runs reproducible.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# %% Define the Transformer Model
class BertClassifier(nn.Module):
    """Pretrained BERT encoder followed by dropout and a linear classification head.

    Args:
        n_classes: Number of output logits.
        model_name: Hugging Face checkpoint to load; must match the tokenizer
            that produced the inputs.
        dropout: Dropout probability applied to the pooled output.
    """

    def __init__(self, n_classes, model_name='bert-base-multilingual-cased', dropout=0.3):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits, one row per input sequence."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # outputs[1] is BertModel's pooler_output: the [CLS] hidden state passed
        # through the model's dense + tanh pooling layer.
        pooled_output = outputs[1]
        return self.fc(self.dropout(pooled_output))

# %% Set model parameters
# Prefer the GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# One output logit per distinct label value (dropna=False: a missing label
# would count as its own value).
# NOTE(review): CrossEntropyLoss needs labels in 0..num_classes-1; this count
# only matches if the labels are contiguous integers starting at 0.
num_classes = df['label'].nunique(dropna=False)
# nn.Module.to() moves the parameters in place and returns the module itself.
model = BertClassifier(num_classes).to(device)

# %% Define training parameters
# Fine-tuning a pretrained BERT needs a small step size. The BERT authors
# recommend 2e-5..5e-5 for 2-4 epochs. A rate like 1e-3 typically destroys the
# pretrained weights and leaves the model predicting a single class; many more
# epochs mostly add run time and overfitting.
num_epochs = 3
learning_rate = 2e-5
# AdamW (Adam with decoupled weight decay) is the standard optimizer for BERT
# fine-tuning.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

# %% Training function
def train_model(model, train_loader, criterion, optimizer, num_epochs, *,
                device=None, max_grad_norm=1.0, progress=True):
    """Train ``model`` for ``num_epochs`` passes over ``train_loader``.

    Each batch must be a dict with ``input_ids``, ``attention_mask`` and
    ``labels`` tensors. The model is called as ``model(input_ids, attention_mask)``
    and must return logits of shape (batch, n_classes) accepted by ``criterion``.

    Args:
        model: Module to train; it is put into train mode here.
        train_loader: Iterable (normally a DataLoader) of the batches above.
        criterion: Loss function taking ``(logits, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of full passes over the data.
        device: Where each batch is moved. Defaults to the device of the
            model's first parameter, i.e. wherever the model already lives.
        max_grad_norm: If not None, clip the global gradient norm to this
            value before every optimizer step (stabilizes fine-tuning).
        progress: Show a tqdm progress bar for each epoch.

    Returns:
        A list with one ``{"loss": float, "accuracy": float}`` dict per epoch.
        ``loss`` is the mean batch loss. ``accuracy`` is the fraction of
        examples predicted correctly, using each batch's forward pass before
        that batch's parameter update.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    history = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        correct_predictions = 0
        seen = 0
        num_batches = 0

        batches = train_loader
        if progress:
            batches = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", leave=False)

        for batch in batches:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            # detach() replaces the deprecated .data; predictions need no autograd.
            predicted = outputs.detach().argmax(dim=1)
            correct_predictions += (predicted == labels).sum().item()
            # Count samples actually seen (correct even with drop_last=True).
            seen += labels.size(0)

        # max(1, ...) avoids ZeroDivisionError on an empty loader.
        avg_loss = total_loss / max(1, num_batches)
        accuracy = correct_predictions / max(1, seen)
        history.append({"loss": avg_loss, "accuracy": accuracy})
        print(f"Epoch {epoch + 1}/{num_epochs} - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
    return history

# %% Train the model
# Run the fine-tuning loop on the training split with the configured
# loss, optimizer and epoch count.
train_model(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
)




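# --- Captured notebook output from the training cell above ---
# model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]
#
#   attn_output = torch.nn.functional.scaled_dot_product_attention(
#   (appears to be a truncated warning raised at the attention call)
#
# KeyboardInterrupt:
# (the run was stopped manually before training finished)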