In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

class CustomDataset(Dataset):
    """Map-style dataset pairing one feature row with one class label.

    Every sample is returned as a dict with two entries:
    ``"input_values"`` (a float tensor built from ``features[idx]``) and
    ``"labels"`` (a long tensor built from ``labels[idx]``).
    """

    def __init__(self, features, labels):
        """Keep references to the indexable feature and label containers."""
        self.labels = labels
        self.features = features

    def __len__(self):
        """Return the sample count, taken from the label container."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Build the tensor dict for the sample stored at position ``idx``."""
        feature_row = torch.tensor(self.features[idx], dtype=torch.float)
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return {"input_values": feature_row, "labels": target}
    
# Feature matrices: a NumPy array is used as-is; anything else is unwrapped
# through its ``.values`` attribute (presumably a pandas object — TODO confirm).
X_train = X_tr
if not isinstance(X_train, np.ndarray):
    X_train = X_tr.values
if not isinstance(X_val, np.ndarray):
    X_val = X_val.values
X_test = X_te
if not isinstance(X_test, np.ndarray):
    X_test = X_te.values

# Label vectors are always converted with ``np.array``.
y_train, y_vals, y_test = np.array(y_tr), np.array(y_val), np.array(y_te)

# One dataset per split, each pairing a feature matrix with its labels.
train_dataset = CustomDataset(features=X_train, labels=y_train)
val_dataset = CustomDataset(features=X_val, labels=y_vals)
test_dataset = CustomDataset(features=X_test, labels=y_test)

# Only the training batches are shuffled; validation keeps dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)

class SimpleTransformer(nn.Module):
    """Transformer-encoder classifier for flat feature vectors.

    Each input row is projected to ``embed_dim``, passed through a stack of
    ``nn.TransformerEncoderLayer`` blocks as a sequence of length one, and
    mapped to ``output_dim`` logits.

    Treating every row as its own length-1 sequence keeps samples
    independent: a 2-D ``(batch, embed_dim)`` tensor handed directly to a
    sequence-first encoder is interpreted as a single unbatched sequence, so
    rows of the same mini-batch would attend to each other and a sample's
    prediction would depend on its batch-mates.

    Args:
        input_dim: Number of features per sample.
        embed_dim: Width of the embedding / transformer model dimension.
        num_heads: Number of attention heads (must divide ``embed_dim``).
        hidden_dim: Width of the feed-forward sub-layer in each encoder layer.
        output_dim: Number of output classes (logits per sample).
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, input_dim, embed_dim, num_heads, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.embedding = nn.Linear(input_dim, embed_dim)
        # ``hidden_dim`` sizes the feed-forward sub-layer; ``batch_first=True``
        # makes dimension 0 the batch dimension for the encoder.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(embed_dim, output_dim)

    def forward(self, src):
        """Return unnormalised class logits.

        Args:
            src: Float tensor of shape ``(batch, input_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim)``.

        Raises:
            ValueError: If ``src`` is not 2-D.
        """
        if src.dim() != 2:
            raise ValueError(
                f"expected input of shape (batch, input_dim), got {tuple(src.shape)}"
            )
        embedded = self.embedding(src)
        # (batch, embed) -> (batch, 1, embed): one token per sample, so the
        # encoder never mixes information across different samples.
        encoded = self.transformer_encoder(embedded.unsqueeze(1))
        return self.fc(encoded.squeeze(1))

    def predict(self, src):
        """Return class probabilities for ``src``.

        Switches the module to eval mode (it stays in eval mode afterwards),
        runs ``forward`` without gradient tracking and applies a softmax over
        the class dimension.

        Args:
            src: Float tensor of shape ``(batch, input_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim)`` whose rows sum to 1.
        """
        self.eval()
        with torch.no_grad():
            probabilities = torch.softmax(self(src), dim=1)
        return probabilities


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 26 input features -> 2 output classes.
model = SimpleTransformer(
    input_dim=26,
    embed_dim=32,
    num_heads=4,
    hidden_dim=128,
    output_dim=2,
    num_layers=2,
)
model.to(device)

# Cross-entropy on the model's logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 100
accs = []  # one validation accuracy per epoch, appended by the training loop

# Train for ``num_epochs`` epochs. After every epoch the model is scored on
# ``val_loader`` and the validation accuracy is appended to ``accs``.
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0  # sum of per-sample losses over the epoch
    samples_seen = 0
    for batch in train_loader:
        inputs, labels = batch['input_values'].to(device), batch['labels'].to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # ``loss`` is a batch mean, so weight it by the batch size; otherwise
        # a short final batch would count as much as a full one in the
        # reported epoch loss.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/max(samples_seen, 1):.4f}")

    # ---- validation pass (no gradient tracking, eval-mode layers) ----
    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch in val_loader:
            inputs, labels = batch['input_values'].to(device), batch['labels'].to(device)

            outputs = model(inputs)
            # Predicted class = index of the largest logit in each row.
            predicted = outputs.argmax(dim=1)

            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    acc = accuracy_score(y_true, y_pred)
    accs.append(acc)
    # ``average='binary'`` scores the positive class only.
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')

    print(f"Validation Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")


# Final evaluation on the held-out test split, in dataset order.
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for batch in test_loader:
        inputs = batch['input_values'].to(device)
        labels = batch['labels'].to(device)

        # Predicted class = index of the largest logit in each row.
        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)

        y_true.extend(labels.cpu().numpy())
        y_pred.extend(predicted.cpu().numpy())

# Accuracy over all classes; precision/recall/F1 for the positive class only.
acc = accuracy_score(y_true, y_pred)
precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')

print(f"Final Test Evaluation - Accuracy: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")