In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import glob
from torch.nn.utils.rnn import pad_sequence
import torch_directml
import pandas as pd

# Compute device. The DirectML backend is left disabled, so everything runs on the CPU.
# device = torch_directml.device()
device = 'cpu'
print(device)

# Move table: the 'Value' column supplies the values that labels are matched against.
moves = pd.read_csv("kfm_moves.csv")
all_moves = moves['Value'].to_numpy()

# Data location and sliding-window size (frames per sample).
data_dir = 'data'
frame_window_length = 10
# Count of game-status values per frame; passed to the model as its input dimension.
frame_status_args = 19

# Optimisation settings.
num_epochs = 50
batch_size = 8
learning_rate = 1e-3

# Network size: LSTM hidden width and one output class per entry of the move table.
hidden_dim = 64
num_classes = len(all_moves)

print(f"Total moves: {len(all_moves)}")


class SimpleDataset(Dataset):
    """Sliding-window dataset over per-move frame recordings.

    Each recording is loaded with ``np.load`` and its first axis is treated as
    the frame axis. Every run of ``frame_window_length`` consecutive frames
    becomes one sample, paired with a float indicator vector that marks which
    entries of the move table lie within 0.1 of the move id parsed from the
    recording's file name.
    """

    def __init__(self, data_dir, frame_window_length=10, move_values=None):
        """Locate the recordings and pre-compute every (window, label) sample.

        Args:
            data_dir: Directory that holds the ``.npy`` recordings.
            frame_window_length: Number of consecutive frames per sample.
            move_values: Optional 1-D array-like of move values to build the
                labels against. Defaults to the module-level ``all_moves``.
        """
        # NOTE(review): only '0.npy' is loaded; the glob over every recording
        # is disabled. With a fixed one-element list the assert below can
        # never fire -- confirm whether this restriction is still intended.
        # self.files = sorted(glob.glob(os.path.join(data_dir, "*.npy")))
        self.files = [os.path.join(data_dir, '0.npy')]
        self.frame_window_length = frame_window_length
        self.move_values = torch.as_tensor(all_moves if move_values is None else move_values)
        self.samples = []

        assert len(self.files) > 0, f"No .npy files found in {data_dir}"

        print(f"Found {len(self.files)} data files.")
        self._build_samples()

    def _build_samples(self):
        """Fill ``self.samples`` with one (window, label) pair per window position."""
        window_length = self.frame_window_length

        for filepath in self.files:
            # Convert the recording once; each window below is a view into it.
            frames = torch.tensor(np.load(filepath), dtype=torch.float32)
            num_frames = frames.shape[0]

            # The file stem encodes the move id, shifted by one.
            # NOTE(review): '0.npy' therefore maps to id -1. If no table value
            # is within 0.1 of the id the label is all zeros -- TODO confirm
            # the file-naming convention.
            stem = os.path.splitext(os.path.basename(filepath))[0]
            move = int(stem) - 1

            # The label depends only on the file, so build it once per file.
            active_move = (torch.abs(self.move_values - move) < 0.1).float()

            # "+ 1" keeps the final window, whose last frame is the last
            # frame of the recording. Recordings shorter than one window
            # yield an empty range and contribute no samples.
            for start_idx in range(num_frames - window_length + 1):
                window = frames[start_idx:start_idx + window_length]
                self.samples.append((window, active_move))

        print(f"Total samples generated: {len(self.samples)}")

    def __len__(self):
        """Return the number of pre-computed samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(window, label)`` pair stored at position ``idx``."""
        return self.samples[idx]


class SimpleClassifier(nn.Module):
    """Single-layer, unidirectional LSTM followed by a linear classification head.

    ``forward`` returns raw, unnormalised class scores (logits).
    ``nn.CrossEntropyLoss`` applies log-softmax internally, so feeding it
    already-softmaxed values would normalise twice and flatten the loss and
    its gradients. Use ``predict_proba`` when probabilities are needed.
    """

    def __init__(self, input_dim, hidden_dim, num_classes):
        """
        Args:
            input_dim: Number of features per time step.
            hidden_dim: Width of the LSTM hidden state.
            num_classes: Number of output classes.
        """
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=False)
        self.fc = nn.Linear(hidden_dim, num_classes)
        # Holds no parameters, so checkpoints are unaffected; used by predict_proba.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class logits for a batch-first batch of sequences.

        Args:
            x: Tensor laid out as (batch, time, input_dim), as required by the
                ``batch_first=True`` LSTM.

        Returns:
            Tensor of shape (batch, num_classes) holding unnormalised scores.
        """
        _, (hidden, _) = self.lstm(x)
        # hidden[-1] is the final hidden state of the last LSTM layer.
        return self.fc(hidden[-1])

    def predict_proba(self, x):
        """Return per-class probabilities (softmax over the logits) for ``x``."""
        return self.softmax(self(x))


def collate_fn(batch):
    """Merge a list of ``(data, label)`` samples into two batched tensors.

    Args:
        batch: Sequence of two-element samples, each a pair of tensors.

    Returns:
        Tuple ``(datas, labels)`` where each element stacks the corresponding
        per-sample tensors along a new leading dimension.
    """
    # Transpose samples into columns, then stack each column into one tensor.
    stacked_columns = [torch.stack(column) for column in zip(*batch)]
    stacked_datas, stacked_labels = stacked_columns
    return stacked_datas, stacked_labels

In [None]:
# Build the sliding-window dataset, then wrap it in a shuffling batch loader
# that uses the custom collate function to stack samples.
dataset = SimpleDataset(data_dir, frame_window_length=frame_window_length)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
# Build the network, the loss and the optimiser.
model = SimpleClassifier(frame_status_args, hidden_dim, num_classes)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

model.train()
for epoch in range(num_epochs):
    # Per-epoch running totals: summed batch loss, correct predictions, samples seen.
    total_loss, correct, total = 0, 0, 0

    for batch_datas, batch_labels in dataloader:
        batch_datas = batch_datas.to(device)
        batch_labels = batch_labels.to(device)

        # One optimisation step on this batch.
        optimizer.zero_grad()
        outputs = model(batch_datas)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total += batch_labels.size(0)

        # A sample is correct when the highest-scoring class matches the
        # position of the largest entry in its label row.
        predicted = outputs.max(dim=1).indices
        expected = batch_labels.argmax(dim=1)
        correct += (predicted == expected).sum().item()

    # Accuracy in percent; the reported loss is the sum over all batches of the epoch.
    acc = 100. * correct / total
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}, Acc: {acc:.2f}%")

# Persist only the learned weights.
torch.save(model.state_dict(), "model.pth")
print("Training finished. Model saved to model.pth")