In [7]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

# Custom functions to create visual charts
from chart_utils import TimeSeriesImageDataset, create_area_chart, create_spiral_chart

# Read UCR Dataset
def read_ucr(filename, delimiter=','):
    """Load a UCR-format time-series file.

    Each row holds the class label in the first column, followed by the
    series values.

    Args:
        filename: path to the data file.
        delimiter: column separator (default ',', as before); pass a tab
            character for tab-separated files.

    Returns:
        (X, Y): X is a 2-D array with one series per row, Y the 1-D array of
        raw labels (floats, exactly as read from the file).
    """
    # ndmin=2 keeps a single-row file 2-D so the column slicing below works
    # (without it np.loadtxt returns a 1-D array and data[:, 0] raises).
    data = np.loadtxt(filename, delimiter=delimiter, ndmin=2)
    Y = data[:, 0]
    X = data[:, 1:]
    return X, Y

# File paths
train_file = 'Adiac/Adiac_TRAIN'
test_file = 'Adiac/Adiac_TEST'

# Load dataset
x_train, y_train = read_ucr(train_file)
x_test, y_test = read_ucr(test_file)

# Map raw class labels onto 0..num_classes-1 with ONE mapping shared by both
# splits. Subtracting each split's own minimum misaligns train and test
# whenever their minimum labels differ, and counting classes on the test split
# alone undersizes the output layer if some class appears only in training.
# np.unique also copes with non-contiguous label values.
n_train = len(y_train)
class_values, encoded_labels = np.unique(
    np.concatenate([y_train, y_test]), return_inverse=True
)
nb_classes = len(class_values)
y_train = encoded_labels[:n_train].astype(int)
y_test = encoded_labels[n_train:].astype(int)

# Normalize features with a single global mean/std computed on the training
# split only, and apply the same statistics to the test split.
x_train_mean = x_train.mean()
x_train_std = x_train.std()
x_train = (x_train - x_train_mean) / x_train_std
x_test = (x_test - x_train_mean) / x_train_std

# Image pipeline: resize each rendered chart to 128x128, then convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ]
)

# One project-local TimeSeriesImageDataset per split, built in train-then-test
# order with the shared transform.
train_dataset, test_dataset = (
    TimeSeriesImageDataset(features, targets, transform)
    for features, targets in ((x_train, y_train), (x_test, y_test))
)

# Small batches of 2; the training order is reshuffled, the test order is fixed.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=2, shuffle=False)

# Define SimpleCNN
class SimpleCNN(nn.Module):
    """Feature extractor: two conv -> ReLU -> max-pool stages, then flatten.

    The 3x3 convolutions use padding=1 (spatial size preserved) and each 2x2
    max-pool halves height and width, so an (N, C, H, W) input produces
    (N, 64 * (H // 4) * (W // 4)) features.

    Args:
        input_channels: number of channels in the input images.
    """

    def __init__(self, input_channels):
        super().__init__()
        # Channel progression: input_channels -> 32 -> 64.
        stages = []
        for in_ch, out_ch in ((input_channels, 32), (32, 64)):
            stages += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
        self.layer_stack = nn.Sequential(*stages, nn.Flatten())

    def forward(self, x):
        """Return the flattened feature map for the batch `x`."""
        features = self.layer_stack(x)
        return features

# Define CombinedModel
class CombinedModel(nn.Module):
    """Two-branch classifier: one SimpleCNN per chart type, fused by an MLP head.

    Args:
        input_shape: (channels, height, width) of each chart image; both
            branches are built for, and probed with, this shape.
        num_classes: number of output scores produced by the final layer.
    """

    def __init__(self, input_shape, num_classes):
        super().__init__()
        self.cnn_area = SimpleCNN(input_shape[0])
        self.cnn_spiral = SimpleCNN(input_shape[0])

        # Both branches share architecture and input shape, so probing one
        # branch gives the flattened feature size of each.
        cnn_output_size = self._get_cnn_output_size(input_shape)

        # Fully connected head over the concatenated features of both branches.
        self.fc1 = nn.Sequential(
            nn.Linear(cnn_output_size * 2, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def _get_cnn_output_size(self, input_shape):
        """Return the flattened feature size of one CNN branch for `input_shape`."""
        # Pure shape probe: no gradients are needed, so skip building an
        # autograd graph for the dummy forward pass.
        with torch.no_grad():
            output = self.cnn_area(torch.zeros(1, *input_shape))
        return output.view(1, -1).size(1)

    def forward(self, x_area, x_spiral):
        """Return unnormalized class scores for paired area/spiral image batches."""
        features_area = self.cnn_area(x_area)
        features_spiral = self.cnn_spiral(x_spiral)

        # Fuse the two views along the feature dimension.
        concatenated_features = torch.cat((features_area, features_spiral), dim=1)

        return self.fc1(concatenated_features)

# Build the model. NOTE(review): input_shape must match the tensors the
# dataset yields — 128x128 comes from the Resize transform; 3 channels
# presumably means RGB chart images (confirm against chart_utils).
num_classes = nb_classes  # one output per encoded class
input_shape = (3, 128, 128)  # (channels, height, width)

model = CombinedModel(input_shape, num_classes)
print(model)

# Training loop with accuracy and classified output
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

num_epochs = 5
overall_progress_bar = tqdm(range(num_epochs), desc="Training Progress")
for epoch in overall_progress_bar:
    # tqdm.write prints without garbling the active progress bar (plain print
    # interleaves with the bar output).
    tqdm.write(f"Epoch: {epoch+1}/{num_epochs}\n-------")

    # ---- Training phase ----
    # Set once per epoch; the testing phase below switches to eval mode.
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    all_labels = []
    all_predictions = []

    for images_area, images_spiral, labels in train_dataloader:
        optimizer.zero_grad()

        # CrossEntropyLoss expects int64 class indices.
        labels = labels.long()

        outputs = model(images_area, images_spiral)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by the batch size so the
        # epoch loss is a true per-sample mean even if the last batch is smaller.
        n_in_batch = labels.size(0)
        running_loss += loss.item() * n_in_batch

        # Accuracy: the predicted class is the index of the highest score.
        predicted = outputs.argmax(dim=1)
        total += n_in_batch
        correct += (predicted == labels).sum().item()

        # .tolist() gives plain Python ints, so the printed lists stay readable
        # regardless of NumPy's scalar repr.
        all_labels.extend(labels.tolist())
        all_predictions.extend(predicted.tolist())

    # Average training loss per sample for this epoch.
    epoch_loss = running_loss / total
    epoch_accuracy = 100 * correct / total

    tqdm.write(f'Training Loss: {epoch_loss:.4f} | Training Accuracy: {epoch_accuracy:.2f}%')
    tqdm.write(f'Labels: {all_labels}')
    tqdm.write(f'Predictions: {all_predictions}')

    # ---- Testing phase ----
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.inference_mode():
        for images_area, images_spiral, labels in test_dataloader:
            labels = labels.long()
            outputs = model(images_area, images_spiral)
            loss = criterion(outputs, labels)

            n_in_batch = labels.size(0)
            test_loss += loss.item() * n_in_batch

            predicted = outputs.argmax(dim=1)
            total += n_in_batch
            correct += (predicted == labels).sum().item()

    # Per-sample mean; a per-batch mean would over-weight a short final batch.
    test_loss /= total
    test_accuracy = 100 * correct / total

    tqdm.write(f'Test Loss: {test_loss:.4f} | Test Accuracy: {test_accuracy:.2f}%')


CombinedModel(
  (cnn_area): SimpleCNN(
    (layer_stack): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU()
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Flatten(start_dim=1, end_dim=-1)
    )
  )
  (cnn_spiral): SimpleCNN(
    (layer_stack): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): ReLU()
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Flatten(start_dim=1, end_dim=-1)
    )
  )
  (fc1): Sequential(
    (0): Linear(in_f

Training Progress:   0%|                                  | 0/5 [00:00<?, ?it/s]

Epoch: 1/5
-------
Training Loss: 3.6359 | Training Accuracy: 2.05%
Labels: [6, 16, 29, 28, 25, 14, 9, 7, 15, 13, 11, 28, 5, 27, 12, 18, 31, 20, 16, 20, 19, 26, 31, 1, 24, 5, 23, 17, 27, 30, 15, 23, 30, 17, 27, 1, 24, 12, 33, 29, 13, 20, 27, 8, 22, 3, 29, 8, 28, 31, 14, 13, 27, 3, 29, 18, 5, 1, 3, 18, 16, 13, 14, 26, 35, 24, 9, 34, 36, 35, 11, 33, 13, 14, 33, 9, 9, 3, 22, 24, 36, 29, 29, 0, 32, 0, 11, 11, 0, 6, 24, 3, 14, 23, 3, 22, 27, 5, 19, 34, 29, 5, 31, 19, 32, 31, 11, 13, 1, 15, 5, 13, 20, 15, 26, 30, 14, 24, 6, 7, 18, 27, 3, 36, 30, 26, 31, 31, 29, 11, 29, 22, 35, 21, 0, 17, 7, 30, 28, 13, 27, 28, 18, 28, 0, 7, 12, 27, 25, 30, 23, 2, 25, 31, 6, 10, 7, 28, 12, 33, 13, 0, 28, 8, 6, 5, 0, 10, 18, 9, 24, 17, 32, 23, 5, 19, 9, 16, 11, 14, 17, 10, 25, 26, 15, 8, 33, 16, 14, 32, 0, 22, 11, 29, 9, 15, 3, 33, 35, 7, 9, 35, 4, 5, 16, 17, 6, 26, 11, 35, 34, 27, 1, 34, 26, 33, 17, 32, 16, 7, 30, 9, 0, 23, 9, 4, 32, 36, 33, 16, 19, 14, 0, 8, 23, 3, 19, 16, 34, 35, 31, 14, 32, 14, 3, 7, 34, 1

Training Progress:  20%|█████▏                    | 1/5 [00:23<01:34, 23.55s/it]

Test Loss: 3.6231 | Test Accuracy: 2.56%
Epoch: 2/5
-------
Training Loss: 3.6082 | Training Accuracy: 2.82%
Labels: [33, 11, 15, 10, 11, 35, 26, 3, 32, 14, 24, 22, 31, 16, 18, 7, 27, 18, 35, 29, 6, 20, 1, 24, 32, 36, 23, 30, 26, 20, 23, 31, 8, 33, 32, 14, 17, 24, 4, 26, 5, 34, 29, 26, 27, 17, 24, 0, 6, 29, 23, 11, 0, 9, 36, 11, 34, 3, 27, 29, 24, 7, 29, 35, 14, 24, 33, 5, 21, 23, 5, 18, 27, 13, 7, 11, 11, 2, 35, 7, 6, 28, 18, 4, 34, 4, 24, 29, 27, 7, 27, 3, 26, 21, 9, 31, 3, 23, 3, 35, 4, 6, 8, 27, 33, 34, 30, 26, 19, 23, 34, 17, 15, 12, 29, 19, 0, 31, 35, 1, 18, 11, 7, 21, 0, 27, 10, 15, 6, 26, 3, 1, 31, 16, 27, 36, 11, 8, 5, 24, 17, 36, 0, 0, 6, 36, 12, 20, 19, 24, 29, 17, 28, 14, 1, 30, 27, 27, 21, 13, 13, 31, 8, 16, 15, 2, 8, 13, 31, 21, 0, 18, 15, 10, 23, 9, 26, 31, 20, 35, 36, 12, 2, 33, 32, 35, 18, 7, 19, 7, 28, 14, 7, 31, 14, 14, 22, 30, 35, 9, 1, 23, 26, 21, 24, 12, 1, 24, 3, 16, 19, 36, 17, 30, 5, 32, 9, 10, 25, 28, 5, 12, 22, 32, 28, 8, 18, 32, 33, 27, 28, 15, 0, 5, 9, 3, 3

Training Progress:  40%|██████████▍               | 2/5 [00:46<01:09, 23.31s/it]

Test Loss: 3.6273 | Test Accuracy: 1.79%
Epoch: 3/5
-------
Training Loss: 3.6035 | Training Accuracy: 3.08%
Labels: [34, 33, 10, 31, 32, 26, 28, 22, 32, 30, 23, 0, 1, 23, 18, 3, 36, 27, 24, 2, 21, 19, 5, 34, 25, 18, 34, 7, 6, 23, 19, 35, 4, 26, 14, 6, 8, 16, 29, 33, 12, 23, 0, 8, 11, 6, 30, 7, 7, 7, 11, 31, 1, 18, 25, 34, 31, 29, 21, 18, 18, 14, 13, 16, 18, 9, 14, 1, 23, 21, 3, 32, 26, 19, 4, 29, 15, 15, 26, 33, 16, 21, 17, 32, 23, 30, 5, 7, 9, 36, 14, 3, 19, 26, 23, 12, 0, 31, 16, 8, 24, 7, 20, 3, 8, 24, 31, 7, 23, 34, 36, 11, 34, 31, 1, 27, 1, 0, 24, 11, 30, 5, 10, 27, 4, 15, 27, 31, 17, 5, 23, 26, 12, 15, 15, 20, 13, 10, 11, 31, 2, 14, 22, 31, 36, 14, 36, 16, 19, 0, 0, 20, 18, 13, 18, 13, 14, 24, 22, 20, 14, 22, 5, 5, 20, 35, 13, 3, 22, 5, 26, 17, 19, 1, 24, 24, 32, 7, 16, 13, 2, 30, 18, 34, 17, 15, 33, 27, 32, 25, 24, 25, 20, 29, 14, 12, 27, 22, 25, 28, 24, 24, 14, 13, 35, 8, 9, 17, 33, 6, 35, 35, 8, 32, 16, 22, 0, 23, 28, 32, 32, 23, 5, 20, 21, 15, 3, 9, 15, 7, 10, 22, 10, 30, 27

Training Progress:  60%|███████████████▌          | 3/5 [01:09<00:46, 23.16s/it]

Test Loss: 3.6339 | Test Accuracy: 1.79%
Epoch: 4/5
-------
Training Loss: 3.5995 | Training Accuracy: 3.33%
Labels: [18, 7, 6, 27, 29, 15, 11, 25, 36, 26, 18, 20, 1, 0, 6, 17, 14, 16, 23, 28, 12, 36, 14, 22, 26, 6, 8, 16, 30, 21, 30, 21, 35, 23, 11, 28, 30, 17, 11, 16, 27, 35, 34, 8, 26, 18, 31, 2, 30, 24, 10, 6, 31, 23, 16, 34, 9, 13, 5, 16, 0, 23, 19, 31, 9, 24, 25, 6, 7, 0, 23, 11, 5, 35, 18, 30, 19, 23, 23, 13, 29, 25, 19, 20, 35, 13, 31, 28, 31, 3, 34, 24, 28, 30, 14, 7, 14, 24, 20, 32, 13, 3, 21, 15, 33, 11, 0, 27, 9, 22, 21, 5, 31, 6, 27, 28, 1, 20, 9, 17, 9, 28, 34, 10, 10, 5, 10, 9, 12, 12, 18, 7, 32, 26, 16, 14, 23, 27, 27, 28, 3, 3, 31, 1, 22, 5, 32, 3, 23, 10, 25, 32, 8, 36, 3, 36, 6, 27, 7, 15, 29, 35, 36, 5, 8, 0, 7, 19, 34, 9, 24, 16, 5, 36, 18, 15, 5, 35, 25, 2, 3, 7, 35, 33, 6, 29, 26, 32, 9, 18, 26, 28, 32, 5, 33, 9, 28, 10, 26, 27, 32, 27, 21, 25, 13, 31, 31, 18, 31, 20, 36, 26, 19, 23, 35, 29, 3, 1, 27, 31, 32, 17, 2, 24, 1, 32, 29, 36, 30, 22, 0, 3, 16, 4, 8, 11, 

Training Progress:  80%|████████████████████▊     | 4/5 [01:34<00:23, 23.67s/it]

Test Loss: 3.6429 | Test Accuracy: 1.79%
Epoch: 5/5
-------
Training Loss: 3.5974 | Training Accuracy: 3.08%
Labels: [34, 34, 0, 7, 31, 15, 19, 33, 12, 15, 9, 28, 33, 25, 20, 27, 5, 30, 7, 8, 32, 28, 27, 30, 5, 3, 16, 10, 34, 7, 7, 6, 29, 8, 11, 30, 17, 0, 13, 25, 13, 16, 14, 18, 0, 30, 10, 4, 14, 14, 0, 13, 23, 6, 23, 21, 27, 31, 19, 26, 0, 16, 21, 16, 1, 31, 10, 34, 4, 31, 27, 31, 20, 32, 14, 13, 20, 11, 10, 5, 26, 26, 22, 36, 10, 23, 27, 23, 7, 28, 7, 29, 24, 12, 26, 10, 23, 11, 25, 11, 18, 36, 5, 24, 18, 22, 8, 23, 14, 18, 33, 11, 11, 32, 24, 15, 0, 10, 2, 1, 32, 9, 34, 22, 1, 35, 29, 35, 29, 31, 24, 3, 24, 14, 31, 17, 11, 1, 2, 24, 4, 25, 18, 32, 7, 5, 28, 32, 32, 11, 5, 24, 18, 21, 8, 34, 23, 9, 33, 11, 17, 25, 9, 13, 17, 22, 21, 29, 32, 17, 29, 3, 15, 13, 25, 11, 35, 8, 14, 0, 6, 27, 35, 12, 19, 18, 17, 2, 7, 27, 18, 30, 24, 24, 7, 7, 13, 9, 27, 23, 24, 15, 6, 6, 11, 3, 23, 31, 29, 35, 3, 35, 5, 21, 19, 31, 5, 34, 29, 26, 22, 0, 26, 19, 1, 1, 16, 10, 14, 3, 26, 34, 19, 29, 27, 2

Training Progress: 100%|██████████████████████████| 5/5 [02:07<00:00, 25.51s/it]

Test Loss: 3.6491 | Test Accuracy: 1.53%



