In [1]:
import itertools
import os

import cellpylib as cpl
import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

In [2]:
# Pick the fastest available backend: CUDA first, then Apple-silicon MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


## Data Generation

In [10]:
def generate_1D_sequences_rand(num_samples, rule_number, ca_size, time_steps):
    """Build a dataset of random elementary-CA trajectories with distinct initial states.

    Each sample starts from a random configuration of ``ca_size`` cells that has
    not been used by an earlier sample. It is evolved under NKS rule
    ``rule_number`` for ``time_steps + 1`` steps with ``cpl.evolve``. The first
    ``time_steps`` rows of the evolution become the input sequence and the
    following row becomes the target.

    Args:
        num_samples: number of trajectories to generate.
        rule_number: rule number passed to ``cpl.nks_rule``.
        ca_size: number of cells per configuration.
        time_steps: number of rows in each input sequence.

    Returns:
        ``(input_data, output_data)`` float32 tensors of shapes
        ``(num_samples, time_steps, ca_size)`` and ``(num_samples, ca_size)``.

    Raises:
        ValueError: if more distinct initial states are requested than the
            ``2**ca_size`` configurations that exist (rejection sampling would
            otherwise loop forever).
    """
    if num_samples > 2 ** ca_size:
        raise ValueError(
            f"cannot draw {num_samples} distinct initial states from 2**{ca_size} configurations"
        )

    input_data = np.zeros((num_samples, time_steps, ca_size))
    output_data = np.zeros((num_samples, ca_size))

    def apply_rule(neighbourhood, c, t):
        return cpl.nks_rule(neighbourhood, rule_number)

    # Byte strings of initial states already used. Comparing each candidate row
    # against a fixed array cannot detect repeats; a set membership test can.
    seen_ics = set()

    for n in range(num_samples):
        # Rejection-sample until we draw an initial state not used before.
        while True:
            ca = cpl.init_random(ca_size)
            key = ca[0].tobytes()  # row 0 of the init_random output is the initial state
            if key not in seen_ics:
                seen_ics.add(key)
                break

        ca = cpl.evolve(ca, timesteps=time_steps + 1, apply_rule=apply_rule, memoize=True)

        input_data[n] = ca[:time_steps]
        output_data[n] = ca[time_steps]  # the row right after the input window

    input_data = torch.from_numpy(input_data).to(torch.float32)
    output_data = torch.from_numpy(output_data).to(torch.float32)

    return input_data, output_data

# generate the data
# Seed both RNGs so Restart & Run All reproduces the same dataset and training run.
# NumPy: cpl.init_random presumably draws from NumPy's global RNG (TODO confirm for
# the installed cellpylib version). torch: model weight init and DataLoader shuffling.
np.random.seed(0)
torch.manual_seed(0)

ca_size = 70
rule_number = 225
sample_size = ca_size * 320  # 22,400 trajectories
x_data, y_data = generate_1D_sequences_rand(sample_size, rule_number, ca_size, 2)
print("x_data", x_data.size())
print(x_data)
print("y_data", y_data.size())
print(y_data)

x_data torch.Size([22400, 2, 70])
tensor([[[0., 1., 0.,  ..., 1., 0., 1.],
         [1., 0., 1.,  ..., 1., 1., 0.]],

        [[0., 1., 0.,  ..., 0., 0., 1.],
         [1., 0., 0.,  ..., 0., 0., 0.]],

        [[1., 1., 1.,  ..., 1., 1., 0.],
         [0., 1., 1.,  ..., 1., 1., 1.]],

        ...,

        [[0., 0., 1.,  ..., 1., 0., 1.],
         [0., 0., 0.,  ..., 1., 1., 0.]],

        [[1., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 1.],
         [0., 1., 1.,  ..., 1., 0., 0.]]])
y_data torch.Size([22400, 70])
tensor([[0., 1., 0.,  ..., 0., 1., 1.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [1., 0., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 0.,  ..., 0., 1., 0.],
        [1., 1., 1.,  ..., 1., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 1.]])


In [None]:
def generate_data(rule_number, ca_size, num_configs):
    """Sample ``num_configs`` distinct random configurations and their one-step successors.

    NOTE(review): the draft allocated inputs as ``(1, num_configs, ca_size)``.
    This version assumes the intent was the batch-first layout ``RecurrentNet``
    consumes, i.e. ``(num_configs, 1, ca_size)`` (a length-1 sequence per
    sample). Confirm which layout is wanted.

    Args:
        rule_number: rule number passed to ``cpl.nks_rule``.
        ca_size: number of cells per configuration.
        num_configs: number of distinct configurations to sample.

    Returns:
        ``(input_data, output_data)`` float32 tensors of shapes
        ``(num_configs, 1, ca_size)`` and ``(num_configs, ca_size)``.

    Raises:
        ValueError: if ``num_configs`` exceeds the ``2**ca_size`` configurations that exist.
    """
    if num_configs > 2 ** ca_size:
        raise ValueError(
            f"cannot draw {num_configs} distinct configurations from 2**{ca_size}"
        )

    input_data = np.zeros((num_configs, 1, ca_size))
    output_data = np.zeros((num_configs, ca_size))

    def apply_rule(neighbourhood, c, t):
        return cpl.nks_rule(neighbourhood, rule_number)

    seen_ics = set()  # byte strings of configurations already sampled
    n = 0
    while n < num_configs:
        ca = cpl.init_random(ca_size)
        key = ca[0].tobytes()
        if key in seen_ics:
            continue  # duplicate: draw again
        seen_ics.add(key)

        # Two rows: the configuration and its image under one application of the rule.
        ca = cpl.evolve(ca, timesteps=2, apply_rule=apply_rule, memoize=True)
        input_data[n, 0] = ca[0]
        output_data[n] = ca[1]
        n += 1

    input_data = torch.from_numpy(input_data).to(torch.float32)
    output_data = torch.from_numpy(output_data).to(torch.float32)

    return input_data, output_data
                


In [None]:
def generate_all_1D_sequences(rule_number, ca_size, time_steps):
    """Evolve every one of the ``2**ca_size`` binary configurations under one rule.

    Configurations are enumerated in lexicographic order (all zeros first, all
    ones last, first cell most significant). The returned rows are therefore
    ordered, not shuffled: shuffle before taking a train/test split.

    Args:
        rule_number: rule number passed to ``cpl.nks_rule``.
        ca_size: number of cells per configuration (memory grows as ``2**ca_size``).
        time_steps: number of rows in each input sequence.

    Returns:
        ``(input_data, output_data)`` float32 tensors of shapes
        ``(2**ca_size, time_steps, ca_size)`` and ``(2**ca_size, ca_size)``.
        Row ``i`` holds the first ``time_steps`` rows of configuration ``i``'s
        evolution (input) and the row that follows them (target).

    Raises:
        ValueError: if ``ca_size < 1``.
    """
    if ca_size < 1:
        raise ValueError("ca_size must be at least 1")

    num_configs = 2 ** ca_size
    input_data = np.zeros((num_configs, time_steps, ca_size))
    output_data = np.zeros((num_configs, ca_size))

    def apply_rule(neighbourhood, c, t):
        return cpl.nks_rule(neighbourhood, rule_number)

    # itertools.product((0, 1), repeat=ca_size) yields all configurations in
    # lexicographic order, replacing a hand-written recursive enumerator.
    for i, config in enumerate(itertools.product((0, 1), repeat=ca_size)):
        ic = np.array([config])  # shape (1, ca_size): a single-row initial state for cpl.evolve

        # evolve state based on update rule
        ca = cpl.evolve(ic, timesteps=time_steps + 1, apply_rule=apply_rule, memoize=True)

        input_data[i] = ca[:time_steps]
        output_data[i] = ca[time_steps]

    input_data = torch.from_numpy(input_data).to(torch.float32)
    output_data = torch.from_numpy(output_data).to(torch.float32)

    return input_data, output_data

## Split Data Into Training Set and Test Set

In [11]:
# 85 / 15 split by position; x_data comes from random initial states, so order carries no bias.
split = int(0.85 * len(x_data))  # index of the first test sample
print(split)
x_train, x_test = x_data[:split], x_data[split:]
y_train, y_test = y_data[:split], y_data[split:]

19040


In [12]:
batch_size = 90

# Shuffle training batches each epoch; keep evaluation order fixed, since the
# metrics are sums over the whole test set and shuffling only adds nondeterminism.
train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = TensorDataset(x_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## Build Neural Network

In [16]:
class RecurrentNet(nn.Module):
    """Single-layer LSTM followed by a linear readout of the last time step.

    Args:
        input_size: features per time step (cells per CA row in this notebook).
        hidden_size: width of the LSTM hidden and cell states.
        output_size: number of outputs (cells in the predicted row).

    ``forward`` takes ``x`` of shape ``(batch, seq_len, input_size)`` (the LSTM
    is built with ``batch_first=True``). It returns sigmoid outputs of shape
    ``(batch, output_size)``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Zero initial hidden/cell states of shape (num_layers=1, batch, hidden_size).
        # They are allocated from x so they share its device and dtype, rather than
        # depending on a module-level `device` variable that may not match.
        h0 = x.new_zeros((1, x.size(0), self.hidden_size))
        c0 = x.new_zeros((1, x.size(0), self.hidden_size))

        out, _ = self.rnn(x, (h0, c0))
        # Only the hidden state after the final time step feeds the readout.
        out = self.fc(out[:, -1, :])

        return torch.sigmoid(out)

In [17]:
# Hidden width equals the CA width; the sigmoid outputs are regressed onto 0/1 targets with MSE.
model = RecurrentNet(
    input_size=ca_size,
    hidden_size=ca_size,
    output_size=ca_size,
).to(device)

learnable_params = model.parameters()
optimizer = optim.Adam(learnable_params, lr=1e-3)
loss_fn = nn.MSELoss()

print(model)

RecurrentNet(
  (rnn): LSTM(70, 70, batch_first=True)
  (fc): Linear(in_features=70, out_features=70, bias=True)
)


## Optimization Loop

In [18]:
n_epochs = 900

for epoch in range(n_epochs):

    print("\n-------------------------------")

    # train loop
    model.train()
    for inputs, outputs in train_loader:
        inputs, outputs = inputs.to(device), outputs.to(device)

        # compute predictions and loss
        predictions = model(inputs)
        loss = loss_fn(predictions, outputs)

        # backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # test loop
    model.eval()
    test_loss, correct = 0.0, 0
    with torch.no_grad():
        for inputs, outputs in test_loader:
            inputs, outputs = inputs.to(device), outputs.to(device)

            predictions = model(inputs)
            # loss_fn (MSELoss, mean reduction) returns a per-batch mean. Weight it by
            # batch size so the smaller final batch does not skew the epoch average.
            test_loss += loss_fn(predictions, outputs).item() * inputs.size(0)

            # A sample is correct only if every cell rounds to its target value.
            correct += (torch.round(predictions) == outputs).all(dim=1).sum().item()

    total = len(test_loader.dataset)
    average_test_loss = test_loss / total
    accuracy = correct / total

    print(f'Epoch [{epoch+1}/{n_epochs}], Test Loss: {average_test_loss:.4f}, Test Accuracy: {accuracy:.2%}')
    print('correct:', correct)
    print('total:', total)

print('Training and Testing finished.')
    


-------------------------------
Epoch [1/900], Test Loss: 0.2148, Test Accuracy: 0.00%
correct: 0
total: 3360

-------------------------------
Epoch [2/900], Test Loss: 0.1994, Test Accuracy: 0.00%
correct: 0
total: 3360

-------------------------------
Epoch [3/900], Test Loss: 0.1912, Test Accuracy: 0.00%
correct: 0
total: 3360

-------------------------------
Epoch [4/900], Test Loss: 0.1860, Test Accuracy: 0.00%
correct: 0
total: 3360

-------------------------------
Epoch [5/900], Test Loss: 0.1825, Test Accuracy: 0.00%
correct: 0
total: 3360

-------------------------------
Epoch [6/900], Test Loss: 0.1799, Test Accuracy: 0.00%
correct: 0
total: 3360

-------------------------------
Epoch [7/900], Test Loss: 0.1778, Test Accuracy: 0.00%
correct: 0
total: 3360

-------------------------------
Epoch [8/900], Test Loss: 0.1761, Test Accuracy: 0.00%
correct: 0
total: 3360

-------------------------------
Epoch [9/900], Test Loss: 0.1750, Test Accuracy: 0.00%
correct: 0
total: 3360



Epoch [74/900], Test Loss: 0.0343, Test Accuracy: 8.51%
correct: 286
total: 3360

-------------------------------
Epoch [75/900], Test Loss: 0.0330, Test Accuracy: 8.90%
correct: 299
total: 3360

-------------------------------
Epoch [76/900], Test Loss: 0.0317, Test Accuracy: 9.85%
correct: 331
total: 3360

-------------------------------
Epoch [77/900], Test Loss: 0.0308, Test Accuracy: 11.25%
correct: 378
total: 3360

-------------------------------
Epoch [78/900], Test Loss: 0.0295, Test Accuracy: 12.98%
correct: 436
total: 3360

-------------------------------
Epoch [79/900], Test Loss: 0.0286, Test Accuracy: 13.45%
correct: 452
total: 3360

-------------------------------
Epoch [80/900], Test Loss: 0.0273, Test Accuracy: 14.82%
correct: 498
total: 3360

-------------------------------
Epoch [81/900], Test Loss: 0.0264, Test Accuracy: 15.80%
correct: 531
total: 3360

-------------------------------
Epoch [82/900], Test Loss: 0.0254, Test Accuracy: 17.26%
correct: 580
total: 3360



Epoch [145/900], Test Loss: 0.0049, Test Accuracy: 77.11%
correct: 2591
total: 3360

-------------------------------
Epoch [146/900], Test Loss: 0.0048, Test Accuracy: 76.85%
correct: 2582
total: 3360

-------------------------------
Epoch [147/900], Test Loss: 0.0047, Test Accuracy: 77.26%
correct: 2596
total: 3360

-------------------------------
Epoch [148/900], Test Loss: 0.0046, Test Accuracy: 78.18%
correct: 2627
total: 3360

-------------------------------
Epoch [149/900], Test Loss: 0.0046, Test Accuracy: 77.56%
correct: 2606
total: 3360

-------------------------------
Epoch [150/900], Test Loss: 0.0045, Test Accuracy: 78.04%
correct: 2622
total: 3360

-------------------------------
Epoch [151/900], Test Loss: 0.0044, Test Accuracy: 78.96%
correct: 2653
total: 3360

-------------------------------
Epoch [152/900], Test Loss: 0.0042, Test Accuracy: 79.35%
correct: 2666
total: 3360

-------------------------------
Epoch [153/900], Test Loss: 0.0043, Test Accuracy: 79.02%
correc

Epoch [216/900], Test Loss: 0.0026, Test Accuracy: 83.81%
correct: 2816
total: 3360

-------------------------------
Epoch [217/900], Test Loss: 0.0026, Test Accuracy: 83.99%
correct: 2822
total: 3360

-------------------------------
Epoch [218/900], Test Loss: 0.0028, Test Accuracy: 82.89%
correct: 2785
total: 3360

-------------------------------
Epoch [219/900], Test Loss: 0.0026, Test Accuracy: 83.90%
correct: 2819
total: 3360

-------------------------------
Epoch [220/900], Test Loss: 0.0027, Test Accuracy: 83.72%
correct: 2813
total: 3360

-------------------------------
Epoch [221/900], Test Loss: 0.0029, Test Accuracy: 81.93%
correct: 2753
total: 3360

-------------------------------
Epoch [222/900], Test Loss: 0.0026, Test Accuracy: 83.99%
correct: 2822
total: 3360

-------------------------------
Epoch [223/900], Test Loss: 0.0026, Test Accuracy: 84.26%
correct: 2831
total: 3360

-------------------------------
Epoch [224/900], Test Loss: 0.0026, Test Accuracy: 84.32%
correc

Epoch [287/900], Test Loss: 0.0024, Test Accuracy: 84.73%
correct: 2847
total: 3360

-------------------------------
Epoch [288/900], Test Loss: 0.0025, Test Accuracy: 84.55%
correct: 2841
total: 3360

-------------------------------
Epoch [289/900], Test Loss: 0.0024, Test Accuracy: 84.67%
correct: 2845
total: 3360

-------------------------------
Epoch [290/900], Test Loss: 0.0024, Test Accuracy: 84.64%
correct: 2844
total: 3360

-------------------------------
Epoch [291/900], Test Loss: 0.0024, Test Accuracy: 84.73%
correct: 2847
total: 3360

-------------------------------
Epoch [292/900], Test Loss: 0.0024, Test Accuracy: 84.67%
correct: 2845
total: 3360

-------------------------------
Epoch [293/900], Test Loss: 0.0025, Test Accuracy: 84.58%
correct: 2842
total: 3360

-------------------------------
Epoch [294/900], Test Loss: 0.0034, Test Accuracy: 78.66%
correct: 2643
total: 3360

-------------------------------
Epoch [295/900], Test Loss: 0.0026, Test Accuracy: 83.93%
correc

Epoch [358/900], Test Loss: 0.0025, Test Accuracy: 84.14%
correct: 2827
total: 3360

-------------------------------
Epoch [359/900], Test Loss: 0.0024, Test Accuracy: 84.40%
correct: 2836
total: 3360

-------------------------------
Epoch [360/900], Test Loss: 0.0024, Test Accuracy: 84.58%
correct: 2842
total: 3360

-------------------------------
Epoch [361/900], Test Loss: 0.0024, Test Accuracy: 84.64%
correct: 2844
total: 3360

-------------------------------
Epoch [362/900], Test Loss: 0.0024, Test Accuracy: 84.70%
correct: 2846
total: 3360

-------------------------------
Epoch [363/900], Test Loss: 0.0024, Test Accuracy: 84.70%
correct: 2846
total: 3360

-------------------------------
Epoch [364/900], Test Loss: 0.0024, Test Accuracy: 84.73%
correct: 2847
total: 3360

-------------------------------
Epoch [365/900], Test Loss: 0.0024, Test Accuracy: 84.79%
correct: 2849
total: 3360

-------------------------------
Epoch [366/900], Test Loss: 0.0024, Test Accuracy: 84.76%
correc

Epoch [429/900], Test Loss: 0.0020, Test Accuracy: 87.29%
correct: 2933
total: 3360

-------------------------------
Epoch [430/900], Test Loss: 0.0020, Test Accuracy: 87.35%
correct: 2935
total: 3360

-------------------------------
Epoch [431/900], Test Loss: 0.0020, Test Accuracy: 87.35%
correct: 2935
total: 3360

-------------------------------
Epoch [432/900], Test Loss: 0.0020, Test Accuracy: 87.26%
correct: 2932
total: 3360

-------------------------------
Epoch [433/900], Test Loss: 0.0020, Test Accuracy: 87.29%
correct: 2933
total: 3360

-------------------------------
Epoch [434/900], Test Loss: 0.0023, Test Accuracy: 85.12%
correct: 2860
total: 3360

-------------------------------
Epoch [435/900], Test Loss: 0.0020, Test Accuracy: 86.96%
correct: 2922
total: 3360

-------------------------------
Epoch [436/900], Test Loss: 0.0020, Test Accuracy: 87.23%
correct: 2931
total: 3360

-------------------------------
Epoch [437/900], Test Loss: 0.0020, Test Accuracy: 87.38%
correc

Epoch [500/900], Test Loss: 0.0017, Test Accuracy: 88.93%
correct: 2988
total: 3360

-------------------------------
Epoch [501/900], Test Loss: 0.0017, Test Accuracy: 88.96%
correct: 2989
total: 3360

-------------------------------
Epoch [502/900], Test Loss: 0.0017, Test Accuracy: 89.08%
correct: 2993
total: 3360

-------------------------------
Epoch [503/900], Test Loss: 0.0017, Test Accuracy: 89.02%
correct: 2991
total: 3360

-------------------------------
Epoch [504/900], Test Loss: 0.0017, Test Accuracy: 89.11%
correct: 2994
total: 3360

-------------------------------
Epoch [505/900], Test Loss: 0.0017, Test Accuracy: 89.11%
correct: 2994
total: 3360

-------------------------------
Epoch [506/900], Test Loss: 0.0017, Test Accuracy: 89.14%
correct: 2995
total: 3360

-------------------------------
Epoch [507/900], Test Loss: 0.0017, Test Accuracy: 89.23%
correct: 2998
total: 3360

-------------------------------
Epoch [508/900], Test Loss: 0.0017, Test Accuracy: 89.29%
correc

Epoch [571/900], Test Loss: 0.0017, Test Accuracy: 89.08%
correct: 2993
total: 3360

-------------------------------
Epoch [572/900], Test Loss: 0.0017, Test Accuracy: 89.20%
correct: 2997
total: 3360

-------------------------------
Epoch [573/900], Test Loss: 0.0017, Test Accuracy: 89.14%
correct: 2995
total: 3360

-------------------------------
Epoch [574/900], Test Loss: 0.0017, Test Accuracy: 89.11%
correct: 2994
total: 3360

-------------------------------
Epoch [575/900], Test Loss: 0.0017, Test Accuracy: 89.17%
correct: 2996
total: 3360

-------------------------------
Epoch [576/900], Test Loss: 0.0017, Test Accuracy: 89.11%
correct: 2994
total: 3360

-------------------------------
Epoch [577/900], Test Loss: 0.0017, Test Accuracy: 89.20%
correct: 2997
total: 3360

-------------------------------
Epoch [578/900], Test Loss: 0.0017, Test Accuracy: 89.17%
correct: 2996
total: 3360

-------------------------------
Epoch [579/900], Test Loss: 0.0017, Test Accuracy: 89.14%
correc

Epoch [642/900], Test Loss: 0.0017, Test Accuracy: 88.96%
correct: 2989
total: 3360

-------------------------------
Epoch [643/900], Test Loss: 0.0017, Test Accuracy: 88.96%
correct: 2989
total: 3360

-------------------------------
Epoch [644/900], Test Loss: 0.0017, Test Accuracy: 89.05%
correct: 2992
total: 3360

-------------------------------
Epoch [645/900], Test Loss: 0.0017, Test Accuracy: 89.08%
correct: 2993
total: 3360

-------------------------------
Epoch [646/900], Test Loss: 0.0017, Test Accuracy: 89.05%
correct: 2992
total: 3360

-------------------------------
Epoch [647/900], Test Loss: 0.0017, Test Accuracy: 89.05%
correct: 2992
total: 3360

-------------------------------
Epoch [648/900], Test Loss: 0.0016, Test Accuracy: 89.20%
correct: 2997
total: 3360

-------------------------------
Epoch [649/900], Test Loss: 0.0017, Test Accuracy: 89.14%
correct: 2995
total: 3360

-------------------------------
Epoch [650/900], Test Loss: 0.0016, Test Accuracy: 89.17%
correc

Epoch [713/900], Test Loss: 0.0016, Test Accuracy: 89.40%
correct: 3004
total: 3360

-------------------------------
Epoch [714/900], Test Loss: 0.0016, Test Accuracy: 89.43%
correct: 3005
total: 3360

-------------------------------
Epoch [715/900], Test Loss: 0.0016, Test Accuracy: 89.43%
correct: 3005
total: 3360

-------------------------------
Epoch [716/900], Test Loss: 0.0016, Test Accuracy: 89.46%
correct: 3006
total: 3360

-------------------------------
Epoch [717/900], Test Loss: 0.0016, Test Accuracy: 89.46%
correct: 3006
total: 3360

-------------------------------
Epoch [718/900], Test Loss: 0.0016, Test Accuracy: 89.49%
correct: 3007
total: 3360

-------------------------------
Epoch [719/900], Test Loss: 0.0016, Test Accuracy: 89.49%
correct: 3007
total: 3360

-------------------------------
Epoch [720/900], Test Loss: 0.0016, Test Accuracy: 89.52%
correct: 3008
total: 3360

-------------------------------
Epoch [721/900], Test Loss: 0.0016, Test Accuracy: 89.55%
correc

Epoch [784/900], Test Loss: 0.0016, Test Accuracy: 89.38%
correct: 3003
total: 3360

-------------------------------
Epoch [785/900], Test Loss: 0.0016, Test Accuracy: 89.38%
correct: 3003
total: 3360

-------------------------------
Epoch [786/900], Test Loss: 0.0016, Test Accuracy: 89.40%
correct: 3004
total: 3360

-------------------------------
Epoch [787/900], Test Loss: 0.0016, Test Accuracy: 89.46%
correct: 3006
total: 3360

-------------------------------
Epoch [788/900], Test Loss: 0.0016, Test Accuracy: 89.46%
correct: 3006
total: 3360

-------------------------------
Epoch [789/900], Test Loss: 0.0016, Test Accuracy: 89.49%
correct: 3007
total: 3360

-------------------------------
Epoch [790/900], Test Loss: 0.0016, Test Accuracy: 89.46%
correct: 3006
total: 3360

-------------------------------
Epoch [791/900], Test Loss: 0.0038, Test Accuracy: 72.50%
correct: 2436
total: 3360

-------------------------------
Epoch [792/900], Test Loss: 0.0018, Test Accuracy: 88.10%
correc

Epoch [855/900], Test Loss: 0.0016, Test Accuracy: 89.43%
correct: 3005
total: 3360

-------------------------------
Epoch [856/900], Test Loss: 0.0016, Test Accuracy: 89.46%
correct: 3006
total: 3360

-------------------------------
Epoch [857/900], Test Loss: 0.0016, Test Accuracy: 89.46%
correct: 3006
total: 3360

-------------------------------
Epoch [858/900], Test Loss: 0.0016, Test Accuracy: 89.49%
correct: 3007
total: 3360

-------------------------------
Epoch [859/900], Test Loss: 0.0016, Test Accuracy: 89.49%
correct: 3007
total: 3360

-------------------------------
Epoch [860/900], Test Loss: 0.0016, Test Accuracy: 89.49%
correct: 3007
total: 3360

-------------------------------
Epoch [861/900], Test Loss: 0.0016, Test Accuracy: 89.49%
correct: 3007
total: 3360

-------------------------------
Epoch [862/900], Test Loss: 0.0016, Test Accuracy: 89.49%
correct: 3007
total: 3360

-------------------------------
Epoch [863/900], Test Loss: 0.0016, Test Accuracy: 89.43%
correc

In [None]:
# Sweep over input-sequence length: for each t_steps, train a fresh LSTM on every
# ca_size-cell configuration under the rule and report the best test accuracy reached.
ca_size = 12
rule_number = 225
max_time_steps = 20
n_epochs = 1000

for t_steps in range(max_time_steps, 0, -1):  # 20, 19, ..., 1

    x_data, y_data = generate_all_1D_sequences(rule_number, ca_size, time_steps=t_steps)

    # generate_all_1D_sequences returns configurations in lexicographic order, so an
    # unshuffled 85/15 split would test only on the lexicographic tail (for
    # ca_size=12, only configurations beginning with 1, 1). A fixed-seed
    # permutation gives the same random split for every t_steps.
    perm = torch.randperm(len(x_data), generator=torch.Generator().manual_seed(0))
    x_data, y_data = x_data[perm], y_data[perm]

    split = int(0.85 * len(x_data))
    x_train, y_train = x_data[:split], y_data[:split]
    x_test, y_test = x_data[split:], y_data[split:]

    # batch_size is defined in the data-loader cell above.
    train_dataset = TensorDataset(x_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = TensorDataset(x_test, y_test)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = RecurrentNet(input_size=ca_size, hidden_size=ca_size, output_size=ca_size).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()

    total = len(test_dataset)
    max_accuracy = 0
    max_correct = 0

    for epoch in range(n_epochs):

        # train loop
        model.train()
        for inputs, outputs in train_loader:
            inputs, outputs = inputs.to(device), outputs.to(device)
            loss = loss_fn(model(inputs), outputs)

            # backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # test loop: a sample is correct only if every cell rounds to its target
        model.eval()
        correct = 0
        with torch.no_grad():
            for inputs, outputs in test_loader:
                inputs, outputs = inputs.to(device), outputs.to(device)
                predictions = model(inputs)
                correct += (torch.round(predictions) == outputs).all(dim=1).sum().item()

        accuracy = correct / total
        if accuracy > max_accuracy:
            max_accuracy = accuracy
            max_correct = correct

    print('Number of Time Steps:', t_steps)
    print('Max Model Accuracy:', max_accuracy)
    # Report the count from the best epoch, which is the epoch max_accuracy describes.
    print('Num Correct:', max_correct)
    print('Total:', total)