# A basic RNN model, designed for developing evaluation metrics

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
from tqdm import tqdm
import iri_dataset_generator as iri
from training_loop import train_model

# Window length handed to the dataset loader; the model below also uses it
# as the RNN input size.
SEQUENCE_LENGTH = 10
IRI_DATA_PATH = "../training_data/IRI-only.parquet"

# Load the train/test splits with one-hot encoded targets.
train, test = iri.load_iri_datasets(
    path=IRI_DATA_PATH,
    seq_length=SEQUENCE_LENGTH,
    one_hot=True,
)

In [None]:
class FNN(nn.Module):
    """Stacked RNN classifier that emits per-class log-probabilities.

    Despite the name, the network is recurrent: an ``nn.RNN`` stack followed
    by a linear layer applied to the output of the last time step.

    Args:
        input_size: Number of features per time step. ``None`` (default)
            falls back to the module-level ``SEQUENCE_LENGTH``, which is what
            the original hard-coded value was.
            NOTE(review): tying the feature size to the sequence length looks
            coincidental -- confirm against the dataset's tensor layout.
        hidden_size: Width of each recurrent layer.
        num_layers: Number of stacked recurrent layers.
        num_classes: Number of output classes.
    """

    def __init__(self, input_size=None, hidden_size=30, num_layers=10,
                 num_classes=3):
        super().__init__()
        if input_size is None:
            # Resolved lazily so the class can be defined (and instantiated
            # with an explicit size) without the module-level constant.
            input_size = SEQUENCE_LENGTH
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.final = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, num_classes)``.

        ``x`` is indexed as ``(batch, time, features)`` (``batch_first=True``).
        """
        # No explicit h0: nn.RNN defaults to an all-zero initial hidden state
        # that matches the input's dtype and device, so the layer count and
        # hidden width no longer have to be repeated here.
        out, _ = self.rnn(x)

        # Decode the hidden state of the last time step
        out = self.final(out[:, -1, :])
        # Log-softmax is idempotent, so feeding these log-probabilities to
        # nn.CrossEntropyLoss (which applies log-softmax itself) yields the
        # same loss as feeding raw logits.
        return torch.log_softmax(out, dim=1)

In [None]:
# Build the network, the objective and the optimiser, then hand everything
# to the shared training loop.
model = FNN()
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

train_model(
    model,
    train,
    test,
    loss,
    optimizer,
    epochs=200,
    test_every_n=10,
    batch_size=512,
)

In [None]:
# test accuracy
from torch.utils.data import DataLoader

def accuracy(output, labels):
    """Return the fraction of rows whose arg-max over dim 1 equals ``labels``.

    ``output`` holds one score per class along dim 1; ``labels`` holds the
    expected class indices. The result is a 0-dim tensor.
    """
    predicted = torch.max(output, dim=1).indices
    hits = predicted.eq(labels).float()
    return hits.sum() / hits.size(0)

def _split_accuracy(net, dataset, device, batch_size=256):
    """Return the accuracy of ``net`` on ``dataset`` as a percentage (float).

    Each batch's accuracy is weighted by its number of samples, so a smaller
    final batch does not skew the result. Returns NaN for an empty dataset
    instead of dividing by zero.
    """
    # Order is irrelevant for a sample-weighted average, so no shuffling.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    correct = 0.0
    seen = 0
    with torch.no_grad():
        for batch in loader:
            inputs, goal = batch[0].to(device), batch[1].to(device)
            outputs = net(inputs)
            # Targets are indexed along dim 1 (one-hot): recover class ids.
            _, target_indices = torch.max(goal, dim=1)
            batch_len = target_indices.size(0)
            correct += accuracy(outputs, target_indices).item() * batch_len
            seen += batch_len
    if seen == 0:
        return float("nan")
    return 100.0 * correct / seen


# Evaluate on whichever device the model's weights already live on, rather
# than assuming a GPU is present and that the model was moved to it.
_first_param = next(model.parameters(), None)
device = _first_param.device if _first_param is not None else torch.device("cpu")

# Not used in this cell; kept in case later cells rely on them.
n1 = 0
n2 = 0
n3 = 0

model.eval()
train_accuracy = _split_accuracy(model, train, device)
print(f'Accuracy of the network on the train data: {train_accuracy:.2f} %')
test_accuracy = _split_accuracy(model, test, device)
print(f'Accuracy of the network on the test data: {test_accuracy:.2f} %')