In [20]:
import tqdm
import torch
import pandas as pd
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

In [4]:
# Check if GPU is available, otherwise use CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [5]:
# Define the MLP model with one weight matrix and one projection matrix
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.weight_matrix = nn.Linear(2, 2)  # First weight matrix (2x2)
        self.projection_matrix = nn.Linear(2, 2)  # Projection layer (2x2)

    def forward(self, x):
        x = self.weight_matrix(x)  # Apply first weight matrix
        x = torch.relu(x)  # Apply ReLU activation
        x = self.projection_matrix(x)  # Apply projection matrix
        return x

# Initialize the model, loss function, and optimizer
model = SimpleMLP().to(device)
criterion = nn.MSELoss()  # Using Mean Squared Error for regression
optimizer = optim.SGD(model.parameters(), lr=0.01)  # Stochastic Gradient Descent

In [None]:
# Example training data (2D input -> 2D output)
inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [-1.0, -2.0], [-3.0, -4.0]], device=device)
targets = torch.tensor([[2.0, 4.0], [6.0, 8.0], [-2.0, -4.0], [-6.0, -8.0]], device=device)

In [14]:
class CustomCSVSequenceDataset(Dataset):
    def __init__(self, file_paths):
        """
        Args:
            file_paths (list of str): List of paths to CSV files to load.
        """
        self.data = []

        # Load data from all CSV files
        for path in file_paths:
            df = pd.read_csv(path, usecols=[1, 2])  # Load only columns 1 and 2
            # Convert to list of tuples [(col1_val, col2_val), ...]
            self.data.extend(df.values.tolist())

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return len(self.data) - 1  # One less because we look at (i, i+1)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index of the sample to retrieve.

        Returns:
            tuple: (input_tensor, label_tensor)
        """
        # Input is row i, Label is row i + 1
        input_data = torch.tensor(self.data[index], dtype=torch.float32)
        label_data = torch.tensor(self.data[index + 1], dtype=torch.float32)

        return input_data, label_data

In [26]:
# Load the training data
file_paths = ["../data/predator_prey_solution_14_11.csv", 
              "../data/predator_prey_solution_17_57.csv",
              "../data/predator_prey_solution_24_15.csv"]  # Replace with your actual file paths
dataset = CustomCSVSequenceDataset(file_paths)

# Create a Training DataLoader to iterate over the dataset
train_loader = DataLoader(dataset, batch_size=128, shuffle=True)


# Load the validation data
file_paths = ["../data/predator_prey_solution_25_72.csv", 
              "../data/predator_prey_solution_40_40.csv",
              "../data/predator_prey_solution_44_52.csv"]  # Replace with your actual file paths
val_dataset = CustomCSVSequenceDataset(file_paths)

# Create a Training DataLoader to iterate over the dataset
val_loader = DataLoader(dataset, batch_size=128, shuffle=True)

# # Iterate through the dataset and print some samples
# for inputs, labels in dataloader:
#     print(f"Inputs: {inputs}\nLabels: {labels}")
#     break  # Print only the first batch for inspection

In [24]:
# # Load training data from data files
# reader = csv.reader()
# header = next(reader)

# inputs = list()
# targets = list()
# for i, row in enumerate(reader):
#     inputs.append(row[1:])
#     if i != 0:
#         targets.append(inputs[-1])

In [28]:
# Training loop
## FIX THIS
num_epochs = 1000  # Number of epochs to train the model
for epoch in tqdm.tqdm(range(num_epochs)):

    
    for batch in train_loader:
        inputs, targets = batch
        # Forward pass
        outputs = model(inputs.to(device))
        loss = criterion(outputs, targets.to(device))

        # Backward pass and optimization
        optimizer.zero_grad()  # Zero out gradients
        loss.backward()  # Backpropagate the error
        optimizer.step()  # Update model parameters

    # Print the loss every 10 epochs
    if (epoch + 1) % 10 == 0:
        tqdm.tqdm.write(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

# Test the model with new inputs
test_input = torch.tensor([[5.0, 6.0]], device=device)
predicted_output = model(test_input)
print(f"Predicted output for {test_input.cpu().numpy()}: {predicted_output.cpu().detach().numpy()}")

  2%|▏         | 15/1000 [00:00<00:35, 27.99it/s]

Epoch [10/1000], Loss: 207.3623


  2%|▏         | 24/1000 [00:00<00:34, 27.90it/s]

Epoch [20/1000], Loss: 165.2606


  3%|▎         | 33/1000 [00:01<00:34, 27.77it/s]

Epoch [30/1000], Loss: 129.9519


  4%|▍         | 45/1000 [00:01<00:34, 27.94it/s]

Epoch [40/1000], Loss: 202.0682


  5%|▌         | 54/1000 [00:01<00:33, 27.92it/s]

Epoch [50/1000], Loss: 151.6202


  6%|▋         | 63/1000 [00:02<00:33, 27.86it/s]

Epoch [60/1000], Loss: 138.2430


  8%|▊         | 75/1000 [00:02<00:33, 27.91it/s]

Epoch [70/1000], Loss: 196.0514


  8%|▊         | 84/1000 [00:03<00:32, 27.94it/s]

Epoch [80/1000], Loss: 210.2364


  9%|▉         | 93/1000 [00:03<00:32, 27.85it/s]

Epoch [90/1000], Loss: 198.4434


 10%|█         | 105/1000 [00:03<00:32, 27.92it/s]

Epoch [100/1000], Loss: 119.0652


 11%|█▏        | 114/1000 [00:04<00:31, 27.91it/s]

Epoch [110/1000], Loss: 112.5309


 12%|█▏        | 123/1000 [00:04<00:31, 27.81it/s]

Epoch [120/1000], Loss: 172.3802


 14%|█▎        | 135/1000 [00:04<00:31, 27.88it/s]

Epoch [130/1000], Loss: 236.5351


 14%|█▍        | 144/1000 [00:05<00:30, 27.88it/s]

Epoch [140/1000], Loss: 166.1363


 15%|█▌        | 153/1000 [00:05<00:30, 27.84it/s]

Epoch [150/1000], Loss: 224.1477


 16%|█▋        | 165/1000 [00:05<00:29, 27.89it/s]

Epoch [160/1000], Loss: 148.8680


 17%|█▋        | 174/1000 [00:06<00:29, 27.91it/s]

Epoch [170/1000], Loss: 167.9887


 18%|█▊        | 183/1000 [00:06<00:29, 27.76it/s]

Epoch [180/1000], Loss: 262.6633


 20%|█▉        | 195/1000 [00:06<00:28, 28.04it/s]

Epoch [190/1000], Loss: 220.2812


 20%|██        | 204/1000 [00:07<00:28, 27.96it/s]

Epoch [200/1000], Loss: 280.6383


 21%|██▏       | 213/1000 [00:07<00:28, 27.69it/s]

Epoch [210/1000], Loss: 209.8702


 22%|██▎       | 225/1000 [00:08<00:27, 27.80it/s]

Epoch [220/1000], Loss: 245.4557


 23%|██▎       | 234/1000 [00:08<00:27, 27.83it/s]

Epoch [230/1000], Loss: 116.9210


 24%|██▍       | 243/1000 [00:08<00:27, 27.71it/s]

Epoch [240/1000], Loss: 198.0028


 26%|██▌       | 255/1000 [00:09<00:26, 27.73it/s]

Epoch [250/1000], Loss: 146.3141


 26%|██▋       | 264/1000 [00:09<00:26, 27.89it/s]

Epoch [260/1000], Loss: 218.3464


 27%|██▋       | 273/1000 [00:09<00:26, 27.84it/s]

Epoch [270/1000], Loss: 101.4124


 28%|██▊       | 285/1000 [00:10<00:25, 27.92it/s]

Epoch [280/1000], Loss: 182.4817


 29%|██▉       | 294/1000 [00:10<00:25, 27.91it/s]

Epoch [290/1000], Loss: 116.0206


 30%|███       | 303/1000 [00:10<00:25, 27.86it/s]

Epoch [300/1000], Loss: 84.7051


 32%|███▏      | 315/1000 [00:11<00:24, 27.96it/s]

Epoch [310/1000], Loss: 123.0679


 32%|███▏      | 324/1000 [00:11<00:24, 27.83it/s]

Epoch [320/1000], Loss: 135.8700


 33%|███▎      | 333/1000 [00:11<00:24, 27.75it/s]

Epoch [330/1000], Loss: 96.0726


 34%|███▍      | 345/1000 [00:12<00:23, 27.83it/s]

Epoch [340/1000], Loss: 134.2233


 35%|███▌      | 354/1000 [00:12<00:23, 27.79it/s]

Epoch [350/1000], Loss: 108.6647


 36%|███▋      | 363/1000 [00:13<00:22, 27.81it/s]

Epoch [360/1000], Loss: 176.5513


 38%|███▊      | 375/1000 [00:13<00:22, 27.79it/s]

Epoch [370/1000], Loss: 151.8065


 38%|███▊      | 384/1000 [00:13<00:22, 27.73it/s]

Epoch [380/1000], Loss: 115.8150


 39%|███▉      | 393/1000 [00:14<00:21, 27.77it/s]

Epoch [390/1000], Loss: 176.1308


 40%|████      | 405/1000 [00:14<00:21, 27.90it/s]

Epoch [400/1000], Loss: 239.2999


 41%|████▏     | 414/1000 [00:14<00:20, 27.92it/s]

Epoch [410/1000], Loss: 223.5792


 42%|████▏     | 423/1000 [00:15<00:20, 27.86it/s]

Epoch [420/1000], Loss: 96.0452


 44%|████▎     | 435/1000 [00:15<00:20, 27.99it/s]

Epoch [430/1000], Loss: 165.8462


 44%|████▍     | 444/1000 [00:15<00:19, 27.94it/s]

Epoch [440/1000], Loss: 190.8784


 45%|████▌     | 453/1000 [00:16<00:19, 27.90it/s]

Epoch [450/1000], Loss: 248.9500


 46%|████▋     | 465/1000 [00:16<00:19, 27.73it/s]

Epoch [460/1000], Loss: 187.2151


 47%|████▋     | 474/1000 [00:16<00:18, 28.06it/s]

Epoch [470/1000], Loss: 179.4674


 48%|████▊     | 483/1000 [00:17<00:18, 28.16it/s]

Epoch [480/1000], Loss: 273.7003


 50%|████▉     | 495/1000 [00:17<00:17, 28.09it/s]

Epoch [490/1000], Loss: 152.7874


 50%|█████     | 504/1000 [00:18<00:17, 27.90it/s]

Epoch [500/1000], Loss: 211.2566


 51%|█████▏    | 513/1000 [00:18<00:17, 27.91it/s]

Epoch [510/1000], Loss: 155.6696


 52%|█████▎    | 525/1000 [00:18<00:17, 27.60it/s]

Epoch [520/1000], Loss: 269.5630


 53%|█████▎    | 534/1000 [00:19<00:16, 27.74it/s]

Epoch [530/1000], Loss: 274.2568


 54%|█████▍    | 543/1000 [00:19<00:16, 27.73it/s]

Epoch [540/1000], Loss: 177.7873


 56%|█████▌    | 555/1000 [00:19<00:15, 27.89it/s]

Epoch [550/1000], Loss: 127.5771


 56%|█████▋    | 564/1000 [00:20<00:15, 27.80it/s]

Epoch [560/1000], Loss: 98.5704


 57%|█████▋    | 573/1000 [00:20<00:15, 27.74it/s]

Epoch [570/1000], Loss: 163.6689


 58%|█████▊    | 585/1000 [00:20<00:14, 27.99it/s]

Epoch [580/1000], Loss: 245.3110


 59%|█████▉    | 594/1000 [00:21<00:14, 27.81it/s]

Epoch [590/1000], Loss: 157.7603


 60%|██████    | 603/1000 [00:21<00:14, 27.72it/s]

Epoch [600/1000], Loss: 222.7516


 62%|██████▏   | 615/1000 [00:22<00:13, 27.83it/s]

Epoch [610/1000], Loss: 307.2194


 62%|██████▏   | 624/1000 [00:22<00:13, 27.70it/s]

Epoch [620/1000], Loss: 98.8292


 63%|██████▎   | 633/1000 [00:22<00:13, 27.73it/s]

Epoch [630/1000], Loss: 177.9777


 64%|██████▍   | 642/1000 [00:23<00:12, 27.85it/s]

Epoch [640/1000], Loss: 118.6276





KeyboardInterrupt: 