In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import scipy.io

In [None]:
# Load the raw MATLAB file once so its contents can be inspected in the next cell.
mat = scipy.io.loadmat(file_name="mixoutALL_shifted.mat")

In [None]:
# Distinct values of the consts field that CharacterTrajectoriesDataset uses as class labels.
np.unique(ar=mat["consts"][0][0][4][0])

In [None]:
class CharacterTrajectoriesDataset(torch.utils.data.Dataset):
    """Labelled pen trajectories loaded from a MATLAB ``.mat`` file.

    Every trajectory is zero-padded at the end to the length of the longest
    one, so all sources share one shape and batch with the default collate.
    The unpadded length of each trajectory is kept in ``self.lengths`` so that
    callers can mask or pack the padding away.
    """

    def __init__(self, path="mixoutALL_shifted.mat"):
        """Load trajectories and labels from ``path``.

        Args:
            path: ``.mat`` file with a ``mixout`` cell array of (3, T) arrays and
                a ``consts`` entry whose ``[0][0][4][0]`` element holds one label
                per trajectory (presumably a 1x1 MATLAB struct; field 4 is the
                label field).
        """
        self.path = path
        mat = scipy.io.loadmat(path)
        self.classes = torch.from_numpy(mat["consts"][0][0][4][0]).long()
        # Each mixout entry is stored channel-major (3, T); make it time-major (T, 3).
        trajectories = [torch.from_numpy(x).permute(1, 0) for x in mat["mixout"][0]]
        # Number of real (unpadded) time steps of every trajectory.
        self.lengths = torch.tensor([t.shape[0] for t in trajectories], dtype=torch.long)
        # (N, T_max, 3); shorter trajectories are right-padded with zeros.
        self.trajectories = torch.nn.utils.rnn.pad_sequence(trajectories, batch_first=True).float()

    def __len__(self):
        """Number of trajectories (one label each)."""
        return len(self.classes)

    def __getitem__(self, i):
        """Return ``(source, target)`` for sample ``i``.

        ``source`` is the zero-padded (T_max, 3) float32 trajectory; ``target``
        is its 0-d int64 class label.
        """
        return self.trajectories[i], self.classes[i]

In [None]:
from unicodedata import bidirectional

class CharacterRNN(torch.nn.Module):
    """LSTM classifier for batch-first (batch, T, 3) trajectories.

    The final hidden state of every layer and direction is concatenated and
    mapped to class logits by a single linear layer.
    """

    def __init__(self, hidden_size=32, num_layers=2, num_classes = 21, bidirectional = True):
        super().__init__()
        self.rnn = torch.nn.LSTM(3, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
        num_directions = 2 if bidirectional else 1
        # forward() concatenates one hidden_size vector per (layer, direction).
        self.out = torch.nn.Linear(hidden_size * num_layers * num_directions, num_classes)

    def forward(self, x, lengths=None):
        """Return class logits of shape (batch, num_classes).

        Args:
            x: batch-first input of shape (batch, T, 3).
            lengths: optional number of valid time steps per sequence (list or
                tensor). When given, trailing padding is packed away so it does
                not influence the final hidden states. When omitted, all T steps
                are fed to the LSTM, as before.
        """
        if lengths is not None:
            # pack_padded_sequence requires int64 lengths on the CPU;
            # enforce_sorted=False keeps the batch order (h_n comes back unsorted).
            lengths = torch.as_tensor(lengths, dtype=torch.int64).cpu()
            x = torch.nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        _, (hn, _) = self.rnn(x)
        # hn: (num_layers * num_directions, batch, hidden) -> (batch, num_layers * num_directions * hidden)
        return self.out(hn.permute(1, 0, 2).flatten(1))

In [None]:
# Seed torch so weight initialisation and DataLoader shuffling are reproducible.
torch.manual_seed(0)
ds = CharacterTrajectoriesDataset()
dl = torch.utils.data.DataLoader(ds, batch_size = 128, shuffle = True)
num_batches = len(dl)
# CrossEntropyLoss needs one logit per label value up to the largest label,
# so size the output layer from the data instead of hard-coding it.
num_classes = int(ds.classes.max()) + 1
model = CharacterRNN(num_classes = num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fcn = torch.nn.CrossEntropyLoss()

In [None]:
src, tar = ds[0]
# Shape of one padded (T, 3) trajectory, then of its 0-d label.
print(src.shape, tar.shape, sep="\n")

In [None]:
model.train()
for epoch in range(100):
    epoch_loss = 0.0
    num_correct = 0
    for source, target in dl:
        optimizer.zero_grad()
        pred = model(source)
        loss = loss_fcn(pred, target)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        # loss is a per-batch mean; weight it by batch size because the last
        # batch may be smaller, giving a true per-sample epoch mean below.
        epoch_loss += batch_loss * source.size(0)
        # Training accuracy of the pre-step predictions; .item() keeps a plain int
        # so the printed accuracy is a float rather than a tensor.
        num_correct += (pred.argmax(dim=-1) == target).sum().item()
    epoch_loss /= len(ds)
    print(f"\rEpoch Loss: {epoch_loss}")
    print(f"\rEpoch Accuracy: {num_correct / len(ds)}")