In [None]:
!unzip abstract.zip

In [None]:
!pip install torch torchvision matplotlib sympy

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tasks import DMSDataset, OneBackDataset, TwoBackDataset, ThreeBackDataset, CtxDMSDataset, InterDMSDataset
from models.RNN_model import CustomRNN
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

def _concat_task_context(inputs, task_index):
    """Tile ``task_index`` along axis 1 of ``inputs`` and append it as extra features.

    ``task_index`` gains a new axis at position 1, is broadcast to
    ``inputs.shape[1]`` entries along it and concatenated onto the last axis of
    ``inputs``. Assumes ``inputs`` is 3-D and ``task_index`` 2-D with the same
    batch size (axis 1 is presumably the sequence/time axis -- confirm against
    tasks.py). The result is cast to float32 for the model.
    """
    n_steps = inputs.shape[1]
    task_context = task_index.unsqueeze(1).expand(-1, n_steps, -1)
    return torch.cat((inputs, task_context), dim=-1).float()


def _action_accuracies(predicted_actions, labels):
    """Return ``(action_accuracy, no_action_accuracy)`` as Python floats.

    Targets labelled 0 or 1 count as "action" steps, label 2 as "no action".
    If a category has no targets in ``labels`` its accuracy is reported as 0.0.
    """
    def accuracy_on(mask):
        total = mask.float().sum()
        if total == 0:
            return 0.0
        correct = (predicted_actions[mask] == labels[mask]).float().sum()
        return (correct / total).item()

    action_mask = (labels == 0) | (labels == 1)
    return accuracy_on(action_mask), accuracy_on(labels == 2)


def train_model(model, train_dataloader, val_dataloader, num_epochs=2000, learning_rate=0.001, verbose=False):
    """Train ``model`` with Adam and cross-entropy, then plot its learning curves.

    Every batch from both loaders is unpacked as ``(inputs, labels, task_index)``.
    ``task_index`` is appended to every step's input features, and ``model(x)``
    must return ``(logits, hidden)`` with class scores on the last axis of ``logits``.

    Training stops early once more than 100 epochs have run and the mean
    training action accuracy over the last 10 epochs is at least 0.99.

    Args:
        model: ``nn.Module``; batches are moved to the device of its first parameter.
        train_dataloader: loader used for gradient updates.
        val_dataloader: loader evaluated (without gradients) after every epoch.
        num_epochs: maximum number of epochs; must be >= 1.
        learning_rate: Adam learning rate.
        verbose: print per-epoch statistics when True.

    Returns:
        dict of per-epoch lists: ``"loss"`` / ``"val_loss"`` (mean per-batch
        cross-entropy) and ``"accuracy"`` / ``"val_accuracy"`` (action accuracy,
        i.e. on labels 0 and 1). The 2x2 figure of these curves is left as the
        current matplotlib figure so the caller can ``plt.savefig`` it.

    Raises:
        ValueError: if ``num_epochs`` < 1 or a loader yields no batches.
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Follow the model's device (CPU in this cell's driver, CUDA elsewhere).
    device = next(model.parameters()).device

    losses, accuracies = [], []
    val_losses, val_accuracies = [], []

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        running_loss = running_action_accuracy = running_no_action_accuracy = 0.0
        total_train_batches = 0
        for inputs, labels, task_index in train_dataloader:
            labels = labels.to(device)
            concatenated = _concat_task_context(inputs.to(device), task_index.to(device))

            optimizer.zero_grad()
            outputs, _ = model(concatenated)
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), labels.reshape(-1).long())
            loss.backward()
            optimizer.step()

            # argmax of the logits equals argmax of their softmax, so no softmax is needed.
            action_accuracy, no_action_accuracy = _action_accuracies(outputs.argmax(dim=-1), labels)
            running_loss += loss.item()
            running_action_accuracy += action_accuracy
            running_no_action_accuracy += no_action_accuracy
            total_train_batches += 1

        # ---- validation pass (also records the loss plotted below) ----
        model.eval()
        running_val_loss = running_val_action_accuracy = running_val_no_action_accuracy = 0.0
        total_val_batches = 0
        with torch.no_grad():
            for inputs, labels, task_index in val_dataloader:
                labels = labels.to(device)
                concatenated = _concat_task_context(inputs.to(device), task_index.to(device))

                outputs, _ = model(concatenated)
                val_loss = criterion(outputs.reshape(-1, outputs.size(-1)), labels.reshape(-1).long())

                action_accuracy, no_action_accuracy = _action_accuracies(outputs.argmax(dim=-1), labels)
                running_val_loss += val_loss.item()
                running_val_action_accuracy += action_accuracy
                running_val_no_action_accuracy += no_action_accuracy
                total_val_batches += 1

        if total_train_batches == 0 or total_val_batches == 0:
            raise ValueError("train_dataloader and val_dataloader must each yield at least one batch")

        # ---- epoch statistics (means over batches) ----
        epoch_loss = running_loss / total_train_batches
        epoch_action_accuracy = running_action_accuracy / total_train_batches
        epoch_no_action_accuracy = running_no_action_accuracy / total_train_batches
        epoch_val_loss = running_val_loss / total_val_batches
        epoch_val_action_accuracy = running_val_action_accuracy / total_val_batches
        epoch_val_no_action_accuracy = running_val_no_action_accuracy / total_val_batches

        losses.append(epoch_loss)
        accuracies.append(epoch_action_accuracy)
        val_losses.append(epoch_val_loss)
        val_accuracies.append(epoch_val_action_accuracy)

        if verbose:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, '
                  f'Action Accuracy: {epoch_action_accuracy * 100:.2f}%, '
                  f'No Action Accuracy: {epoch_no_action_accuracy * 100:.2f}%, '
                  f'Validation Loss: {epoch_val_loss:.4f}, '
                  f'Validation Action Accuracy: {epoch_val_action_accuracy * 100:.2f}%, '
                  f'Validation No Action Accuracy: {epoch_val_no_action_accuracy * 100:.2f}%')

        # Stop training once training action accuracy has saturated.
        if len(accuracies) > 100 and np.mean(accuracies[-10:]) >= 0.99:
            break

    print(f'Final Epoch [{epoch+1}/{num_epochs}], '
          f'Loss: {epoch_loss:.4f}, '
          f'Action Accuracy: {epoch_action_accuracy * 100:.2f}%, '
          f'No Action Accuracy: {epoch_no_action_accuracy * 100:.2f}%, '
          f'Validation Loss: {epoch_val_loss:.4f}, '
          f'Validation Action Accuracy: {epoch_val_action_accuracy * 100:.2f}%, '
          f'Validation No Action Accuracy: {epoch_val_no_action_accuracy * 100:.2f}%')

    # 2x2 grid of learning curves; intentionally left open as the current figure
    # because the driver cell saves it with plt.savefig right after this returns.
    plt.figure(figsize=(5, 5))
    panels = [
        (losses, 'Loss', 'Training Loss'),
        (accuracies, 'Accuracy', 'Training Accuracy'),
        (val_losses, 'Loss', 'Validation Loss'),
        (val_accuracies, 'Accuracy', 'Validation Accuracy'),
    ]
    for position, (series, metric, title) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.plot(series, label=metric)
        plt.xlabel('Epochs')
        plt.ylabel(metric)
        plt.title(title)
        plt.legend()
    plt.tight_layout()

    return {
        "loss": losses,
        "accuracy": accuracies,
        "val_loss": val_losses,
        "val_accuracy": val_accuracies,
    }

# Load the dataset and initialize DataLoader
import os

batch_size = 4096  # dataset_size == batch_size below, so each loader yields exactly one batch
num_epochs = 2     # NOTE(review): very small epoch budget -- looks like a smoke-test value
learning_rate = 0.001

from itertools import product
single_feature_datasets = {
    f"{prefix}_{feature}": (dataset_class, {"feature": feature})
    for prefix, dataset_class in [
        ("dms", DMSDataset),
        ("oneback", OneBackDataset),
        ("twoback", TwoBackDataset),
        ("threeback", ThreeBackDataset),
    ]
    for feature in ["category", "identity", "position"]
}

# Context-dependent DMS: one spec per feature ordering.
ctx_dms_datasets = {
    "ctxdms_" + "_".join(order): (CtxDMSDataset, {"features": list(order)})
    for order in [
        ("category", "identity", "position"),
        ("position", "category", "identity"),
        ("position", "identity", "category"),
        ("identity", "position", "category"),
    ]
}

inter_dms_datasets = {
    f"interdms_{pattern}_{feature1}_{feature2}": (
        InterDMSDataset,
        {"pattern": pattern, "features": [feature1, feature2]},
    )
    for pattern in ["AABB", "ABBA", "ABAB"]
    for feature1 in ["category", "identity", "position"]
    for feature2 in ["category", "identity", "position"]
}

# NOTE: despite the name, values are (dataset_class, kwargs) specs, not DataLoader objects.
dataloaders = {**single_feature_datasets, **ctx_dms_datasets, **inter_dms_datasets}

# Initialize the model with the desired RNN type
# input_size = 16 per-step stimulus features + 43-wide task-index vector.
# 43 equals the number of task specs above (12 + 4 + 27); presumably one slot per task -- confirm in tasks.py.
input_size = 16 + 43
hidden_size = 4  # hidden layer size (tunable)
output_size = 3  # three output classes: labels 0/1 (actions) and 2 (no action)
rnn_types = ['RNN']

os.makedirs('figs', exist_ok=True)  # savefig below fails if the directory is missing

for name, (dataset_class, kwargs) in dataloaders.items():
    # Both splits are generated noise-free (std=0, add_noise=False).
    common_args = dict(dataset_size=batch_size, category_size=4, identity_size=2, std=0, add_noise=False)
    train_dataset = dataset_class(**common_args, **kwargs)
    val_dataset = dataset_class(**common_args, **kwargs)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    for rnn_type in rnn_types:
        print(rnn_type, name)
        model = CustomRNN(input_size, hidden_size, output_size, rnn_type)
        train_model(model, train_dataloader, val_dataloader, num_epochs=num_epochs, learning_rate=learning_rate)
        plt.savefig(f'figs/{rnn_type}_{name}.png')
        # Close the saved figure; otherwise one figure per task (43+) stays open,
        # triggering matplotlib's too-many-figures warning and growing memory.
        plt.close()
    print('-' * 80)


In [5]:
import seaborn as sns

def train_model(model, train_dataloader, val_dataloader, num_epochs=2000, learning_rate=0.001, verbose=False):
    # Remove the weighted loss, using standard CrossEntropyLoss instead
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    losses = []
    accuracies = []
    val_losses = []
    val_accuracies = []

    task_names = [
        f"{prefix}_{feature}"
        for prefix in ["dms", "oneback", "twoback", "threeback"]
        for feature in ["category", "identity", "position"]
    ]
    task_names += ['ctxdms_category_identity_position',
                  'ctxdms_position_category_identity', 'ctxdms_position_identity_category', 'ctx_identity_position_category']
    task_names += [
        f"interdms_{pattern}_{feature1}_{feature2}"
        for pattern in ["AABB", "ABBA", "ABAB"]
        for feature1 in ["category", "identity", "position"]
        for feature2 in ["category", "identity", "position"]
    ]

    # Initialize confusion matrices for each task
    task_confusion_matrices = {task_name: np.zeros((3, 3), dtype=int) for task_name in task_names}
    task_accuracies = {task_name: [] for task_name in task_names}
    dms_confusion = {key: [[[] for _ in range(32)] for _ in range(32)] for key in task_names}
    dms_embs = {key: [[] for _ in range(32)] for key in task_names}

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        running_action_accuracy = 0.0
        running_no_action_accuracy = 0.0
        running_val_loss = 0.0
        running_val_action_accuracy = 0.0
        running_val_no_action_accuracy = 0.0
        total_train_batches = 0
        total_val_batches = 0

        # Training loop
        model.train()
        for i, (inputs, labels, task_index) in enumerate(train_dataloader):
            inputs, labels, task_index = inputs.to('cuda'), labels.to('cuda'), task_index.to('cuda')

            task_index_extended = task_index.unsqueeze(1)
            task_index_repeated = task_index_extended.repeat(1, inputs.shape[1], 1)
            concatenated = torch.cat((inputs, task_index_repeated), dim=-1)
            concatenated = concatenated.float()

            optimizer.zero_grad()
            outputs, _ = model(concatenated)

            softmax_outputs = F.softmax(outputs, dim=-1)
            predicted_actions = torch.argmax(softmax_outputs, dim=-1)

            # Calculate action accuracy (for labels 0 and 1)
            action_mask = (labels == 0) | (labels == 1)
            no_action_mask = (labels == 2)

            action_correct = (predicted_actions[action_mask] == labels[action_mask]).float().sum()
            action_total = action_mask.float().sum()
            action_accuracy = (action_correct / action_total).item() if action_total > 0 else 0.0

            no_action_correct = (predicted_actions[no_action_mask] == labels[no_action_mask]).float().sum()
            no_action_total = no_action_mask.float().sum()
            no_action_accuracy = (no_action_correct / no_action_total).item() if no_action_total > 0 else 0.0

            # Calculate loss
            loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1).long())
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_action_accuracy += action_accuracy
            running_no_action_accuracy += no_action_accuracy
            total_train_batches += 1

        # Validation loop
        model.eval()
        for i, (inputs, labels, task_index) in enumerate(val_dataloader):
            inputs, labels, task_index = inputs.to('cuda'), labels.to('cuda'), task_index.to('cuda')

            task_index_extended = task_index.unsqueeze(1)
            task_index_repeated = task_index_extended.repeat(1, inputs.shape[1], 1)
            concatenated = torch.cat((inputs, task_index_repeated), dim=-1)
            concatenated = concatenated.float()

            with torch.no_grad():
                outputs, _ = model(concatenated)

            softmax_outputs = F.softmax(outputs, dim=-1)
            predicted_actions = torch.argmax(softmax_outputs, dim=-1)

            action_correct = (predicted_actions[action_mask] == labels[action_mask]).float().sum()
            action_total = action_mask.float().sum()
            action_accuracy = (action_correct / action_total).item() if action_total > 0 else 0.0

            no_action_correct = (predicted_actions[no_action_mask] == labels[no_action_mask]).float().sum()
            no_action_total = no_action_mask.float().sum()
            no_action_accuracy = (no_action_correct / no_action_total).item() if no_action_total > 0 else 0.0

            running_val_action_accuracy += action_accuracy
            running_val_no_action_accuracy += no_action_accuracy
            total_val_batches += 1

            loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1).long())
            running_val_loss += loss.item()

            for j in range(labels.shape[0]):
              task_id = torch.argmax(task_index[j]).item()
              task_name_str = task_names[task_id]
              task_correct = (predicted_actions[j] == labels[j]).float().sum().item()
              task_total = labels[j].shape[0]
              task_accuracy = task_correct / task_total if task_total > 0 else 0.0
              task_accuracies[task_name_str].append(task_accuracy)

            print("Action Accuracies for all tasks:")
            for task_name, accuracies in task_accuracies.items():
              avg_accuracy = np.mean(accuracies) if len(accuracies) > 0 else 0.0
              print(f"Task: {task_name}, Accuracy: {avg_accuracy:.4f}")

        # Epoch statistics
        epoch_loss = running_loss / total_train_batches
        epoch_action_accuracy = running_action_accuracy / total_train_batches
        epoch_no_action_accuracy = running_no_action_accuracy / total_train_batches
        epoch_val_loss = running_val_loss / total_val_batches
        epoch_val_action_accuracy = running_val_action_accuracy / total_val_batches
        epoch_val_no_action_accuracy = running_val_no_action_accuracy / total_val_batches

        losses.append(epoch_loss)
        accuracies.append(epoch_action_accuracy)
        val_losses.append(epoch_val_loss)
        val_accuracies.append(epoch_val_action_accuracy)

        if verbose:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, '
                  f'Action Accuracy: {epoch_action_accuracy * 100:.2f}%, '
                  f'No Action Accuracy: {epoch_no_action_accuracy * 100:.2f}%, '
                  f'Validation Loss: {epoch_val_loss:.4f}, '
                  f'Validation Action Accuracy: {epoch_val_action_accuracy * 100:.2f}%, '
                  f'Validation No Action Accuracy: {epoch_val_no_action_accuracy * 100:.2f}%')

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, '
          f'Action Accuracy: {epoch_action_accuracy * 100:.2f}%, '
          f'No Action Accuracy: {epoch_no_action_accuracy * 100:.2f}%, '
          f'Validation Loss: {epoch_val_loss:.4f}, '
          f'Validation Action Accuracy: {epoch_val_action_accuracy * 100:.2f}%, '
          f'Validation No Action Accuracy: {epoch_val_no_action_accuracy * 100:.2f}%')

    avg_task_accuracies = {task: np.mean(acc) for task, acc in task_accuracies.items()}

    # Plot bar chart of per-task accuracies
    tasks = list(avg_task_accuracies.keys())
    accuracies = list(avg_task_accuracies.values())

    plt.figure(figsize=(10, 6))
    plt.barh(tasks, accuracies)
    plt.xlabel('Accuracy')
    plt.ylabel('Task')
    plt.title('Per-task Accuracy')
    plt.tight_layout()
    plt.show()


    # Confusion matrices update for each task
    for i, (inputs, labels, task_index) in enumerate(val_dataloader):
        inputs, labels, task_index = inputs.to('cuda'), labels.to('cuda'), task_index.to('cuda')

        task_index_extended = task_index.unsqueeze(1)
        task_index_repeated = task_index_extended.repeat(1, inputs.shape[1], 1)
        concatenated = torch.cat((inputs, task_index_repeated), dim=-1)
        concatenated = concatenated.float()

        with torch.no_grad():
            outputs, h = model(concatenated)

        softmax_outputs = F.softmax(outputs, dim=-1)
        predicted_actions = torch.argmax(softmax_outputs, dim=-1)

        # Update confusion matrix for each task
        for j in range(labels.shape[0]):
            task_id = torch.argmax(task_index[j]).item()
            task_name_str = task_names[task_id]
            for k in range(labels.shape[1]):
                true_label = labels[j][k].item()
                predicted_label = predicted_actions[j][k].item()
                task_confusion_matrices[task_name_str][int(true_label), int(predicted_label)] += 1

            if task_name_str in [
                'dms_category', 'dms_identity',
                'interdms_AABB_category_category', 'interdms_AABB_category_identity', 'interdms_AABB_category_position',
                'interdms_AABB_identity_category', 'interdms_AABB_identity_identity', 'interdms_AABB_identity_position',
                ]:
                identity1 = torch.argmax(inputs[j][0][4:12]).item()
                identity2 = torch.argmax(inputs[j][1][4:12]).item()
                position1 = torch.argmax(inputs[j][0][12:16]).item()
                position2 = torch.argmax(inputs[j][1][12:16]).item()
                id1 = identity1 * 4 + position1
                id2 = identity2 * 4 + position2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][1][1].item())
                dms_embs[task_name_str][id1].append(h[j][0].detach().cpu().numpy())
            if task_name_str in [
                'dms_position',
                'interdms_AABB_position_category', 'interdms_AABB_position_identity', 'interdms_AABB_position_position',
                ]:
                identity1 = torch.argmax(inputs[j][0][4:12]).item()
                identity2 = torch.argmax(inputs[j][1][4:12]).item()
                position1 = torch.argmax(inputs[j][0][12:16]).item()
                position2 = torch.argmax(inputs[j][1][12:16]).item()
                id1 = position1 * 8 + identity1
                id2 = position2 * 8 + identity2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][1][1].item())
                dms_embs[task_name_str][id1].append(h[j][0].detach().cpu().numpy())
            if task_name_str in ['oneback_category', 'oneback_identity']:
              for l in range(5):
                identity1 = torch.argmax(inputs[j][l][4:12]).item()
                identity2 = torch.argmax(inputs[j][l+1][4:12]).item()
                position1 = torch.argmax(inputs[j][l][12:16]).item()
                position2 = torch.argmax(inputs[j][l+1][12:16]).item()
                id1 = identity1 * 4 + position1
                id2 = identity2 * 4 + position2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][l+1][1].item())
                dms_embs[task_name_str][id1].append(h[j][l].detach().cpu().numpy())
            if task_name_str == 'oneback_position':
              for l in range(5):
                identity1 = torch.argmax(inputs[j][l][4:12]).item()
                identity2 = torch.argmax(inputs[j][l+1][4:12]).item()
                position1 = torch.argmax(inputs[j][l][12:16]).item()
                position2 = torch.argmax(inputs[j][l+1][12:16]).item()
                id1 = position1 * 8 + identity1
                id2 = position2 * 8 + identity2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][l+1][1].item())
                dms_embs[task_name_str][id1].append(h[j][l].detach().cpu().numpy())
            if task_name_str in ['twoback_category', 'twoback_identity']:
              for l in range(4):
                identity1 = torch.argmax(inputs[j][l][4:12]).item()
                identity2 = torch.argmax(inputs[j][l+2][4:12]).item()
                position1 = torch.argmax(inputs[j][l][12:16]).item()
                position2 = torch.argmax(inputs[j][l+2][12:16]).item()
                id1 = identity1 * 4 + position1
                id2 = identity2 * 4 + position2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][l+2][1].item())
                dms_embs[task_name_str][id1].append(h[j][l].detach().cpu().numpy())
            if task_name_str == 'twoback_position':
              for l in range(4):
                identity1 = torch.argmax(inputs[j][l][4:12]).item()
                identity2 = torch.argmax(inputs[j][l+2][4:12]).item()
                position1 = torch.argmax(inputs[j][l][12:16]).item()
                position2 = torch.argmax(inputs[j][l+2][12:16]).item()
                id1 = position1 * 8 + identity1
                id2 = position2 * 8 + identity2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][l+2][1].item())
                dms_embs[task_name_str][id1].append(h[j][l].detach().cpu().numpy())
            if task_name_str in ['threeback_category', 'threeback_identity']:
              for l in range(3):
                identity1 = torch.argmax(inputs[j][l][4:12]).item()
                identity2 = torch.argmax(inputs[j][l+3][4:12]).item()
                position1 = torch.argmax(inputs[j][l][12:16]).item()
                position2 = torch.argmax(inputs[j][l+3][12:16]).item()
                id1 = identity1 * 4 + position1
                id2 = identity2 * 4 + position2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][l+3][1].item())
                dms_embs[task_name_str][id1].append(h[j][l].detach().cpu().numpy())
            if task_name_str == 'threeback_position':
              for l in range(3):
                identity1 = torch.argmax(inputs[j][l][4:12]).item()
                identity2 = torch.argmax(inputs[j][l+3][4:12]).item()
                position1 = torch.argmax(inputs[j][l][12:16]).item()
                position2 = torch.argmax(inputs[j][l+3][12:16]).item()
                id1 = position1 * 8 + identity1
                id2 = position2 * 8 + identity2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][l+3][1].item())
                dms_embs[task_name_str][id1].append(h[j][l].detach().cpu().numpy())
            if task_name_str in [
                'interdms_AABB_category_category', 'interdms_AABB_identity_category', 'interdms_AABB_position_category',
                'interdms_AABB_category_identity', 'interdms_AABB_identity_identity', 'interdms_AABB_position_identity',
                ]:
                identity1 = torch.argmax(inputs[j][2][4:12]).item()
                identity2 = torch.argmax(inputs[j][3][4:12]).item()
                position1 = torch.argmax(inputs[j][2][12:16]).item()
                position2 = torch.argmax(inputs[j][3][12:16]).item()
                id1 = identity1 * 4 + position1
                id2 = identity2 * 4 + position2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][3][1].item())
                dms_embs[task_name_str][id1].append(h[j][2].detach().cpu().numpy())
            if task_name_str in [
                'interdms_AABB_category_position', 'interdms_AABB_identity_position', 'interdms_AABB_position_position',
                ]:
                identity1 = torch.argmax(inputs[j][2][4:12]).item()
                identity2 = torch.argmax(inputs[j][3][4:12]).item()
                position1 = torch.argmax(inputs[j][2][12:16]).item()
                position2 = torch.argmax(inputs[j][3][12:16]).item()
                id1 = position1 * 8 + identity1
                id2 = position2 * 8 + identity2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][3][1].item())
                dms_embs[task_name_str][id1].append(h[j][2].detach().cpu().numpy())
            if task_name_str in [
                'interdms_ABBA_category_category', 'interdms_ABBA_category_identity', 'interdms_ABBA_category_position',
                'interdms_ABBA_identity_category', 'interdms_ABBA_identity_identity', 'interdms_ABBA_identity_position',
                ]:
                identity1 = torch.argmax(inputs[j][0][4:12]).item()
                identity2 = torch.argmax(inputs[j][3][4:12]).item()
                position1 = torch.argmax(inputs[j][0][12:16]).item()
                position2 = torch.argmax(inputs[j][3][12:16]).item()
                id1 = identity1 * 4 + position1
                id2 = identity2 * 4 + position2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][3][1].item())
                dms_embs[task_name_str][id1].append(h[j][0].detach().cpu().numpy())
            if task_name_str in [
                'interdms_ABBA_position_category', 'interdms_ABBA_position_identity', 'interdms_ABBA_position_position',
                ]:
                identity1 = torch.argmax(inputs[j][0][4:12]).item()
                identity2 = torch.argmax(inputs[j][3][4:12]).item()
                position1 = torch.argmax(inputs[j][0][12:16]).item()
                position2 = torch.argmax(inputs[j][3][12:16]).item()
                id1 = position1 * 8 + identity1
                id2 = position2 * 8 + identity2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][3][1].item())
                dms_embs[task_name_str][id1].append(h[j][0].detach().cpu().numpy())
            if task_name_str in [
                'interdms_ABBA_category_category', 'interdms_ABBA_identity_category', 'interdms_ABBA_position_category',
                'interdms_ABBA_category_identity', 'interdms_ABBA_identity_identity', 'interdms_ABBA_position_identity',
                ]:
                identity1 = torch.argmax(inputs[j][1][4:12]).item()
                identity2 = torch.argmax(inputs[j][2][4:12]).item()
                position1 = torch.argmax(inputs[j][1][12:16]).item()
                position2 = torch.argmax(inputs[j][2][12:16]).item()
                id1 = identity1 * 4 + position1
                id2 = identity2 * 4 + position2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][2][1].item())
                dms_embs[task_name_str][id1].append(h[j][1].detach().cpu().numpy())
            if task_name_str in [
                'interdms_ABBA_category_position', 'interdms_ABBA_identity_position', 'interdms_ABBA_position_position',
                ]:
                identity1 = torch.argmax(inputs[j][1][4:12]).item()
                identity2 = torch.argmax(inputs[j][2][4:12]).item()
                position1 = torch.argmax(inputs[j][1][12:16]).item()
                position2 = torch.argmax(inputs[j][2][12:16]).item()
                id1 = position1 * 8 + identity1
                id2 = position2 * 8 + identity2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][2][1].item())
                dms_embs[task_name_str][id1].append(h[j][1].detach().cpu().numpy())
            if task_name_str in [
                'interdms_ABAB_category_category', 'interdms_ABAB_category_identity', 'interdms_ABAB_category_position',
                'interdms_ABAB_identity_category', 'interdms_ABAB_identity_identity', 'interdms_ABAB_identity_position',
                ]:
                identity1 = torch.argmax(inputs[j][0][4:12]).item()
                identity2 = torch.argmax(inputs[j][2][4:12]).item()
                position1 = torch.argmax(inputs[j][0][12:16]).item()
                position2 = torch.argmax(inputs[j][2][12:16]).item()
                id1 = identity1 * 4 + position1
                id2 = identity2 * 4 + position2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][2][1].item())
                dms_embs[task_name_str][id1].append(h[j][0].detach().cpu().numpy())
            if task_name_str in [
                'interdms_ABAB_position_category', 'interdms_ABAB_position_identity', 'interdms_ABAB_position_position',
                ]:
                identity1 = torch.argmax(inputs[j][0][4:12]).item()
                identity2 = torch.argmax(inputs[j][2][4:12]).item()
                position1 = torch.argmax(inputs[j][0][12:16]).item()
                position2 = torch.argmax(inputs[j][2][12:16]).item()
                id1 = position1 * 8 + identity1
                id2 = position2 * 8 + identity2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][2][1].item())
                dms_embs[task_name_str][id1].append(h[j][0].detach().cpu().numpy())
            if task_name_str in [
                'interdms_ABAB_category_category', 'interdms_ABAB_identity_category', 'interdms_ABAB_position_category',
                'interdms_ABAB_category_identity', 'interdms_ABAB_identity_identity', 'interdms_ABAB_position_identity',
                ]:
                identity1 = torch.argmax(inputs[j][1][4:12]).item()
                identity2 = torch.argmax(inputs[j][3][4:12]).item()
                position1 = torch.argmax(inputs[j][1][12:16]).item()
                position2 = torch.argmax(inputs[j][3][12:16]).item()
                id1 = identity1 * 4 + position1
                id2 = identity2 * 4 + position2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][3][1].item())
                dms_embs[task_name_str][id1].append(h[j][1].detach().cpu().numpy())
            if task_name_str in [
                'interdms_ABAB_category_position', 'interdms_ABAB_identity_position', 'interdms_ABAB_position_position',
                ]:
                identity1 = torch.argmax(inputs[j][1][4:12]).item()
                identity2 = torch.argmax(inputs[j][3][4:12]).item()
                position1 = torch.argmax(inputs[j][1][12:16]).item()
                position2 = torch.argmax(inputs[j][3][12:16]).item()
                id1 = position1 * 8 + identity1
                id2 = position2 * 8 + identity2
                dms_confusion[task_name_str][id1][id2].append(softmax_outputs[j][3][1].item())
                dms_embs[task_name_str][id1].append(h[j][1].detach().cpu().numpy())

    # Plot confusion matrices
    for task_id, cm in task_confusion_matrices.items():
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(f'Confusion Matrix for Task {task_id}')
        plt.show()

# Inside the train_model function, after generating the confusion and embedding heatmaps:
    for task in task_names:
        lim = 32
        for i in range(lim):
            for j in range(lim):
                if len(dms_confusion[task][i][j]) > 0:
                    dms_confusion[task][i][j] = np.mean(dms_confusion[task][i][j])
                else:
                    dms_confusion[task][i][j] = 0
        dms_confusion[task] = np.array(dms_confusion[task])

        # Save the probability matrix for the current task in .npz format
        np.savez_compressed(f'probability_matrices/{task}_probability_matrix.npz', matrix=dms_confusion[task])

        plt.figure(figsize=(16, 12))
        sns.heatmap(dms_confusion[task])
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(f'Probability Matrix for Task {task}')
        plt.show()

        emb_heatmap = np.zeros((lim, lim))
        for i in range(lim):
            for j in range(lim):
                emb1 = np.mean(dms_embs[task][i], axis=0)
                emb2 = np.mean(dms_embs[task][j], axis=0)
                emb_heatmap[i, j] = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))

        # Save the embedding similarity heatmap for the current task in .npz format
        np.savez_compressed(f'probability_matrices/{task}_embedding_similarity.npz', matrix=emb_heatmap)

        plt.figure(figsize=(16, 12))
        sns.heatmap(emb_heatmap)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(f'Cosine Similarity of {task} embeddings')
        plt.show()

    plt.figure(figsize=(5, 5))

    plt.subplot(2, 2, 1)
    plt.plot(losses, label='Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.legend()

    plt.subplot(2, 2, 2)
    plt.plot(accuracies, label='Action Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.title('Training Action Accuracy')
    plt.legend()

    plt.subplot(2, 2, 3)
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Validation Loss')
    plt.legend()

    plt.subplot(2, 2, 4)
    plt.plot(val_accuracies, label='Validation Action Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.title('Validation Action Accuracy')
    plt.legend()

    plt.tight_layout()


In [None]:
import os

# Mixed-task training run: training data uses std=0 and the validation data is
# generated with VAL_STD (here the same std as training). One model is trained
# per RNN type and its training-curve summary figure is saved under figs/.
VAL_STD = 0
NUM_EPOCHS = 200
FIG_SUFFIX = f'mixed_{VAL_STD}'  # -> 'mixed_0', matches the previous file names

train_datasets = []
val_datasets = []

# Each per-task dataset is generated with dataset_size equal to one batch.
batch_size = 4096

for dataset_class, kwargs in dataloaders.values():
    task_train_dataset = dataset_class(dataset_size=batch_size, category_size=4, identity_size=2, std=0, pad_to=6, add_noise=True, **kwargs)
    train_datasets.append(task_train_dataset)

    task_val_dataset = dataset_class(dataset_size=batch_size, category_size=4, identity_size=2, std=VAL_STD, pad_to=6, add_noise=True, **kwargs)
    val_datasets.append(task_val_dataset)

# Merge the per-task datasets into one training set and one validation set.
train_dataset = torch.utils.data.ConcatDataset(train_datasets)
val_dataset = torch.utils.data.ConcatDataset(val_datasets)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# train_model writes .npz files into probability_matrices/ and the loop below
# saves figures into figs/; create both so a fresh checkout does not crash.
os.makedirs('figs', exist_ok=True)
os.makedirs('probability_matrices', exist_ok=True)

hidden_states = [256]

for hidden_size in hidden_states:
    for rnn_type in rnn_types:
        print(rnn_type, hidden_size)
        model = CustomRNN(input_size, hidden_size, output_size, rnn_type).to('cuda')
        train_model(model, train_dataloader, val_dataloader, num_epochs=NUM_EPOCHS, learning_rate=0.001, verbose=True)
        # The current figure is the loss/accuracy summary train_model leaves open.
        plt.savefig(f'figs/{rnn_type}_{hidden_size}_{FIG_SUFFIX}.png')
        # Display and release it now so figures don't accumulate across runs.
        plt.show()
        print('-' * 80)

In [None]:
import os

# Mixed-task training run: training data uses std=0 while the validation data
# is generated with VAL_STD=0.15 (presumably to probe robustness to noisier
# stimuli -- TODO confirm). One model is trained per RNN type and its
# training-curve summary figure is saved under figs/.
VAL_STD = 0.15
NUM_EPOCHS = 200
FIG_SUFFIX = f'mixed_{VAL_STD}'  # -> 'mixed_0.15', matches the previous file names

train_datasets = []
val_datasets = []

# Each per-task dataset is generated with dataset_size equal to one batch.
batch_size = 4096

for dataset_class, kwargs in dataloaders.values():
    task_train_dataset = dataset_class(dataset_size=batch_size, category_size=4, identity_size=2, std=0, pad_to=6, add_noise=True, **kwargs)
    train_datasets.append(task_train_dataset)

    task_val_dataset = dataset_class(dataset_size=batch_size, category_size=4, identity_size=2, std=VAL_STD, pad_to=6, add_noise=True, **kwargs)
    val_datasets.append(task_val_dataset)

# Merge the per-task datasets into one training set and one validation set.
train_dataset = torch.utils.data.ConcatDataset(train_datasets)
val_dataset = torch.utils.data.ConcatDataset(val_datasets)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# train_model writes .npz files into probability_matrices/ and the loop below
# saves figures into figs/; create both so a fresh checkout does not crash.
os.makedirs('figs', exist_ok=True)
os.makedirs('probability_matrices', exist_ok=True)

hidden_states = [256]

for hidden_size in hidden_states:
    for rnn_type in rnn_types:
        print(rnn_type, hidden_size)
        model = CustomRNN(input_size, hidden_size, output_size, rnn_type).to('cuda')
        train_model(model, train_dataloader, val_dataloader, num_epochs=NUM_EPOCHS, learning_rate=0.001, verbose=True)
        # The current figure is the loss/accuracy summary train_model leaves open.
        plt.savefig(f'figs/{rnn_type}_{hidden_size}_{FIG_SUFFIX}.png')
        # Display and release it now so figures don't accumulate across runs.
        plt.show()
        print('-' * 80)

In [None]:
import os

# Mixed-task training run: training data uses std=0 while the validation data
# is generated with VAL_STD=0.35 (presumably to probe robustness to noisier
# stimuli -- TODO confirm). This run uses fewer epochs (150) than the std=0 and
# std=0.15 runs. One model is trained per RNN type and its training-curve
# summary figure is saved under figs/.
VAL_STD = 0.35
NUM_EPOCHS = 150
FIG_SUFFIX = f'mixed_{VAL_STD}'  # -> 'mixed_0.35', matches the previous file names

train_datasets = []
val_datasets = []

# Each per-task dataset is generated with dataset_size equal to one batch.
batch_size = 4096

for dataset_class, kwargs in dataloaders.values():
    task_train_dataset = dataset_class(dataset_size=batch_size, category_size=4, identity_size=2, std=0, pad_to=6, add_noise=True, **kwargs)
    train_datasets.append(task_train_dataset)

    task_val_dataset = dataset_class(dataset_size=batch_size, category_size=4, identity_size=2, std=VAL_STD, pad_to=6, add_noise=True, **kwargs)
    val_datasets.append(task_val_dataset)

# Merge the per-task datasets into one training set and one validation set.
train_dataset = torch.utils.data.ConcatDataset(train_datasets)
val_dataset = torch.utils.data.ConcatDataset(val_datasets)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# train_model writes .npz files into probability_matrices/ and the loop below
# saves figures into figs/; create both so a fresh checkout does not crash.
os.makedirs('figs', exist_ok=True)
os.makedirs('probability_matrices', exist_ok=True)

hidden_states = [256]

for hidden_size in hidden_states:
    for rnn_type in rnn_types:
        print(rnn_type, hidden_size)
        model = CustomRNN(input_size, hidden_size, output_size, rnn_type).to('cuda')
        train_model(model, train_dataloader, val_dataloader, num_epochs=NUM_EPOCHS, learning_rate=0.001, verbose=True)
        # The current figure is the loss/accuracy summary train_model leaves open.
        plt.savefig(f'figs/{rnn_type}_{hidden_size}_{FIG_SUFFIX}.png')
        # Display and release it now so figures don't accumulate across runs.
        plt.show()
        print('-' * 80)