In [1]:
from gensim.models.fasttext import FastText
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
from torchmetrics.classification import MulticlassAccuracy
from collections import Counter

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'



In [2]:
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        
        self.hidden_size = hidden_size
        # RNN layer
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True, nonlinearity='tanh')
        # Output layer
        self.fc = nn.Linear(hidden_size, output_size)
        
    def forward(self, embedded_input):
        # Initialize hidden state with zeros
        hidden_state = torch.zeros(1, 1, self.hidden_size).to(device)
        # Pass the input sequence through the RNN layer
        rnn_output, hidden_state = self.rnn(embedded_input, hidden_state)
        # Reshape the output to be of shape (batch_size * sequence_length, hidden_size)
        rnn_output = rnn_output.contiguous().view(-1, self.hidden_size)
        # Pass the RNN output through the fully connected layer to get the predicted tags
        predicted_tags = self.fc(rnn_output)
        
        return predicted_tags

In [3]:
class NERDataset(Dataset):
    def __init__(self, embedded_sentences):
        self.embedded_sentences = embedded_sentences
    def __len__(self):
        return len(self.embedded_sentences)
    def __getitem__(self, idx):
        return self.embedded_sentences[idx]

In [4]:
# dataset: https://github.com/google-research-datasets/uninum/blob/master/numbers/rus.tsv
data = []
labels = []
with open("rus.tsv", "r", encoding="utf-8") as file:
    for line in file:
        tokens = list(line.strip().split("\t"))
        data.append(tokens[0])
        labels.append(tokens[-1])

In [5]:
label_unique = list(set(labels))
# label_unique.remove('')

In [6]:
model = FastText(sentences=data, window=5, min_count=1, workers=4, sg=1)

In [7]:
embedded_input = [model.wv.get_vector(word) for word in data]

In [8]:
targets = torch.Tensor(np.array(embedded_input))
labels = LabelEncoder.fit_transform(targets, labels)

In [9]:
targets = list(zip(targets, labels))

In [10]:
counts = Counter(labels)
class_weights = []
values = dict(counts).values()

for val in values:
    class_weights.append((val/sum(values)))
class_weights = torch.Tensor(class_weights).to(device)

In [11]:
input_size = 100
hidden_size = 8
output_size = len(label_unique)
num_epochs = 1500
bs = 64
lr = 8e-3
wd = 6e-3

In [12]:
nn_model = RNN(input_size, hidden_size, output_size)
nn_model = nn_model.to(device)
dataloader = DataLoader(targets, batch_size=bs, shuffle=True)
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.AdamW(nn_model.parameters(), lr=lr, weight_decay=wd)
acc = MulticlassAccuracy(num_classes=output_size).to(device)

In [13]:
for epoch in range(num_epochs):
    total_loss = 0
    total_acc = 0
    for batch, gt in dataloader:
        batch, gt = batch.to(device), gt.to(device)
        optimizer.zero_grad()
        batch = batch.reshape(1,-1,100)
        outputs = nn_model(batch)
        loss = criterion(outputs.reshape(-1, output_size), gt.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_acc += acc(outputs, gt).item()
    print(f"Epoch {epoch+1}: Loss {total_loss/len(dataloader):.4f} Accuracy: {total_acc/len(dataloader)}")

Epoch 1: Loss 4.7641 Accuracy: 0.0078125
Epoch 2: Loss 4.7338 Accuracy: 0.010869565419852734
Epoch 3: Loss 4.7342 Accuracy: 0.010869565419852734
Epoch 4: Loss 4.7215 Accuracy: 0.0078125
Epoch 5: Loss 4.7218 Accuracy: 0.0078125
Epoch 6: Loss 4.7179 Accuracy: 0.01845079753547907
Epoch 7: Loss 4.7160 Accuracy: 0.010869565419852734
Epoch 8: Loss 4.7165 Accuracy: 0.0078125
Epoch 9: Loss 4.7037 Accuracy: 0.007692307699471712
Epoch 10: Loss 4.7125 Accuracy: 0.0078125
Epoch 11: Loss 4.7098 Accuracy: 0.0078125
Epoch 12: Loss 4.7112 Accuracy: 0.010638297535479069
Epoch 13: Loss 4.7061 Accuracy: 0.007692307699471712
Epoch 14: Loss 4.7039 Accuracy: 0.0
Epoch 15: Loss 4.7030 Accuracy: 0.010638297535479069
Epoch 16: Loss 4.7032 Accuracy: 0.0078125
Epoch 17: Loss 4.6998 Accuracy: 0.0078125
Epoch 18: Loss 4.6978 Accuracy: 0.0078125
Epoch 19: Loss 4.6964 Accuracy: 0.007692307699471712
Epoch 20: Loss 4.7002 Accuracy: 0.0078125
Epoch 21: Loss 4.6989 Accuracy: 0.007692307699471712
Epoch 22: Loss 4.6975 Ac