$$a_{i}, s_{i} = f(\bar{o}_{:i})$$

In [None]:
import sys
sys.path.append('..')
from LATMOS import *
from gen_augment import state_to_embedding_stats, state_to_embedding_map

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
def gen_augment_batch_trace_fast(original_state, original_embed, num_states, embedding_dim,
                                 random_sample_per_step=5, embedding_stats=None):
    """
    Generate negative examples from a batch of traces using vectorized tensor ops.

    For every prefix length ``step`` in ``[0, num_steps)`` and every one of
    ``random_sample_per_step`` samples, a new trace is built that keeps the
    original states at positions ``< step`` and draws uniformly random states
    at positions ``>= step``. Acceptance labels are 1 on the kept prefix and 0
    elsewhere. Embeddings for every augmented position (including the kept
    prefix) are sampled from per-state Gaussians ``mean + std * N(0, 1)``.

    Args:
        original_state: (batch_size, num_steps) integer tensor of state ids.
        original_embed: (batch_size, num_steps, D) tensor of embeddings for the
            original traces; D must match the width of the stats tensors.
        num_states: int; random state ids are drawn from ``[0, num_states)`` and
            stats are read for ids ``0 .. num_states - 1``.
        embedding_dim: int, not used by this implementation; kept for signature
            parity with ``gen_augment_batch_trace_slow``.
        random_sample_per_step: int, number of random traces per prefix length.
        embedding_stats: optional mapping ``state_id -> {'mean': Tensor, 'std': Tensor}``.
            Defaults to the module-level ``state_to_embedding_stats``.

    Returns:
        tuple: (acceptance_traces, embedding_traces, state_traces), each with
        N = batch_size + num_steps * random_sample_per_step * batch_size rows:
            acceptance_traces: (N, num_steps) torch.long tensor
            embedding_traces:  (N, num_steps, D) tensor
            state_traces:      (N, num_steps) tensor
        The first ``batch_size`` rows are the original, fully accepted traces.
        Augmented row ``batch_size + (step * random_sample_per_step + r) * batch_size + b``
        is random sample ``r`` for prefix length ``step`` of batch element ``b``.
    """
    if embedding_stats is None:
        embedding_stats = state_to_embedding_stats

    batch_size, num_steps = original_state.shape
    device = original_state.device
    augmented_shape = (num_steps, random_sample_per_step, batch_size, num_steps)

    # keep_mask[step, j] is True where position j lies before the prefix cut `step`.
    steps = torch.arange(num_steps, device=device)
    keep_mask = steps.unsqueeze(0) < steps.unsqueeze(1)
    # Broadcast to (num_steps, random_sample_per_step, batch_size, num_steps). The
    # same mask is both the acceptance label and the original/random selector.
    keep_mask = keep_mask[:, None, None, :].expand(augmented_shape)

    original_expanded = original_state[None, None, :, :].expand(augmented_shape)
    random_states = torch.randint(0, num_states, augmented_shape, device=device)
    state_traces = torch.where(keep_mask, original_expanded, random_states).reshape(-1, num_steps)

    # Per-state Gaussian parameters, each of shape (num_states, D).
    embedding_means = torch.stack([embedding_stats[i]['mean'] for i in range(num_states)]).to(device)
    embedding_stds = torch.stack([embedding_stats[i]['std'] for i in range(num_states)]).to(device)

    # Reparameterized sample z = mu + sigma * eps for every augmented position.
    trace_means = embedding_means[state_traces]
    trace_stds = embedding_stds[state_traces]
    embedding_traces = trace_means + trace_stds * torch.randn_like(trace_means)

    acceptance_traces = keep_mask.reshape(-1, num_steps).to(torch.long)

    # Prepend the original traces as positive (fully accepted) examples.
    original_acceptance = torch.ones((batch_size, num_steps), dtype=torch.long, device=device)
    acceptance_traces = torch.cat((original_acceptance, acceptance_traces), dim=0)
    embedding_traces = torch.cat((original_embed, embedding_traces), dim=0)
    state_traces = torch.cat((original_state, state_traces), dim=0)
    return acceptance_traces, embedding_traces, state_traces

In [None]:
def gen_augment_batch_trace_slow(original_state, original_embed, num_states, embedding_dim,
                                 random_sample_per_step=5, embedding_map=None):
    """
    Generate negative examples from a batch of traces (reference, loop-based version).

    For every prefix length ``step`` in ``[0, num_steps)`` and every one of
    ``random_sample_per_step`` samples, a new trace keeps the original states at
    positions ``< step`` and draws uniformly random states at positions ``>= step``.
    Acceptance labels are 1 on the kept prefix and 0 elsewhere. The embedding of
    each augmented position is a random choice among the stored embeddings for
    its state.

    Args:
        original_state: (batch_size, num_steps) integer tensor of state ids.
        original_embed: (batch_size, num_steps, embedding_dim) tensor.
        num_states: int; random state ids are drawn from ``[0, num_states)``.
        embedding_dim: int, width of each embedding vector.
        random_sample_per_step: int, number of random traces per prefix length.
        embedding_map: optional mapping ``state_id -> sequence of embedding tensors``.
            Defaults to the module-level ``state_to_embedding_map``.

    Returns:
        tuple: (acceptance_traces, embedding_traces, state_traces), each with
        N = batch_size + num_steps * random_sample_per_step * batch_size rows:
            acceptance_traces: (N, num_steps) torch.long tensor
            embedding_traces:  (N, num_steps, embedding_dim) tensor
            state_traces:      (N, num_steps) tensor
        The first ``batch_size`` rows are the original, fully accepted traces.
    """
    import random  # local import: the module-level one lives in a later cell

    if embedding_map is None:
        embedding_map = state_to_embedding_map

    batch_size, num_steps = original_state.shape
    device = original_state.device
    augmented_shape = (num_steps, random_sample_per_step, batch_size, num_steps)

    acceptance_traces = torch.zeros(augmented_shape, device=device)
    state_traces = torch.zeros(augmented_shape, device=device, dtype=torch.int)

    for step in range(num_steps):
        random_states = torch.randint(0, num_states, (random_sample_per_step, batch_size, num_steps - step),
                                      device=device)
        acceptance_traces[step, :, :, :step] = 1
        state_traces[step, :, :, :step] = original_state[:, :step]
        state_traces[step, :, :, step:] = random_states

    # Flatten (step, sample, batch) into rows.
    acceptance_traces = acceptance_traces.view(-1, num_steps)
    state_traces = state_traces.view(-1, num_steps)

    original_acceptance = torch.ones((batch_size, num_steps), device=device)
    acceptance_traces = torch.cat((original_acceptance, acceptance_traces), dim=0)

    # Fill EVERY augmented row, starting at row 0: the original traces are only
    # prepended below, so no row here belongs to them (the previous loop began at
    # row 1 and left the first augmented trace with all-zero embeddings).
    embedding_traces = torch.zeros((state_traces.shape[0], num_steps, embedding_dim))
    for i, row_states in enumerate(state_traces.tolist()):
        for j, state in enumerate(row_states):
            embedding_traces[i, j] = random.choice(embedding_map[state])
    embedding_traces = torch.cat((original_embed, embedding_traces.to(device)), dim=0)

    state_traces = torch.cat((original_state, state_traces), dim=0)
    return acceptance_traces.to(torch.long), embedding_traces, state_traces

In [None]:
def prepare_dataloaders(data, batch_size):
    """
    Build one shuffled DataLoader per trace length.

    Args:
        data: mapping ``num_steps -> (embedding_traces, acceptance_traces, state_traces)``.
        batch_size: batch size used by every loader.

    Returns:
        dict ``num_steps -> DataLoader`` over ``(embeddings, acceptance, states)``.
        All tensors are moved to the global ``device`` and the label tensors are
        cast to ``torch.long``. Entries whose key is below 1 are skipped.
    """
    def _make_loader(embeds, accepts, states):
        # Order of the dataset fields matches what the training loop unpacks.
        dataset = TensorDataset(
            embeds.to(device),
            accepts.to(torch.long).to(device),
            states.to(torch.long).to(device),
        )
        return DataLoader(dataset, batch_size, shuffle=True)

    return {
        length: _make_loader(embeds, accepts, states)
        for length, (embeds, accepts, states) in data.items()
        if not length < 1
    }

# Load and prepare data.
# Here `random_sample_per_step` only selects which pre-generated augmentation file
# is loaded (0 presumably means no offline negatives — confirm with the generator
# script); the training and evaluation loops pass their own value explicitly.
random_sample_per_step = 0
data_train = torch.load(f'data_augment/segment_train_{random_sample_per_step}.pt')
data_val = torch.load(f'data_augment/segment_val_{random_sample_per_step}.pt')

# Embedding width: used as the model's input/hidden size and as `embedding_dim`
# for augmentation, so it presumably must equal the width stored in the data files.
# Other widths used previously: 512, 1536.
embed_size = 2304

# Number of distinct state ids; random negative states are drawn from [0, num_states).
num_states = 629
batch_size = 2**12

train_loaders = prepare_dataloaders(data_train, batch_size)
val_loaders = prepare_dataloaders(data_val, batch_size)

---

In [None]:
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import random

def train_model_egoexo4d(model, train_loaders, val_loaders,
                         num_epochs, learning_rate, patience,
                         inner_steps=4, val_every=1):
    """
    Train ``model`` on augmented traces with a joint acceptance + state loss.

    Each epoch visits the per-length training loaders in random order. Every
    batch is expanded with ``gen_augment_batch_trace_fast`` (one random sample
    per prefix length), and the model then takes ``inner_steps`` Adam steps on
    that augmented batch. Loss and metrics are taken from the last of those
    forward passes. Early stopping watches the epoch's summed training loss.

    Args:
        model: called as ``model(embeddings)``; expected to return
            ``(state_logits, acceptance_logits)`` whose last dims are
            ``num_states`` and 2 (they are reshaped with ``.view(-1, num_states)``
            and ``.view(-1, 2)``).
        train_loaders: dict ``num_steps -> DataLoader`` yielding
            ``(embeddings, acceptance, states)``; the loaded acceptance labels
            are replaced by the augmented ones.
        val_loaders: same structure, passed to ``evaluate_model_egoexo4d``.
        num_epochs: maximum number of epochs.
        learning_rate: Adam learning rate.
        patience: stop after this many consecutive epochs without a lower
            summed training loss.
        inner_steps: optimizer steps per augmented batch (>= 1; default 4).
        val_every: run validation every this many epochs (>= 1; default 1).

    Returns:
        (train_loss_history, train_metrics_history, val_epochs, val_metrics_history).
        Metrics are dicts with keys 'TP', 'TN', 'FP', 'FN', 'state_accuracy',
        each normalized by the number of (trace, step) positions seen.

    Raises:
        ValueError: if ``inner_steps`` or ``val_every`` is smaller than 1.
    """
    if inner_steps < 1:
        raise ValueError("inner_steps must be >= 1")
    if val_every < 1:
        raise ValueError("val_every must be >= 1")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    num_steps_list = list(train_loaders.keys())

    # Early stopping state
    best_loss = float('inf')
    patience_counter = 0

    # Metric histories
    train_loss_history = []
    train_metrics_history = []
    val_metrics_history = []
    val_epochs = []

    bar = tqdm(range(num_epochs))
    for epoch in bar:
        model.train()
        total_loss = 0
        acceptance_TP = acceptance_TN = acceptance_FP = acceptance_FN = 0
        state_correct = 0
        total_step = 0

        random.shuffle(num_steps_list)
        for num_step in num_steps_list:
            for batch_embedding, _, batch_state in train_loaders[num_step]:
                batch_acceptance, batch_embedding, batch_state = gen_augment_batch_trace_fast(
                    batch_state, batch_embedding, num_states, embed_size, random_sample_per_step=1
                )
                batch_acceptance = batch_acceptance.to(device)
                batch_embedding = batch_embedding.to(device)
                batch_state = batch_state.to(device)

                # Several optimizer steps on the same augmented batch.
                for _ in range(inner_steps):
                    optimizer.zero_grad()
                    state_output, acceptance_output = model(batch_embedding)
                    acceptance_loss = criterion(acceptance_output.view(-1, 2), batch_acceptance.view(-1))
                    state_loss = criterion(state_output.view(-1, num_states), batch_state.view(-1))
                    loss = acceptance_loss + state_loss
                    loss.backward()
                    optimizer.step()

                # Loss/metrics come from the last inner forward pass (pre-final-step outputs).
                total_loss += loss.item()

                predicted_acceptance = acceptance_output.argmax(dim=-1).bool()
                target_acceptance = batch_acceptance.bool()
                acceptance_TP += (predicted_acceptance & target_acceptance).sum().item()
                acceptance_TN += (~predicted_acceptance & ~target_acceptance).sum().item()
                acceptance_FP += (predicted_acceptance & ~target_acceptance).sum().item()
                acceptance_FN += (~predicted_acceptance & target_acceptance).sum().item()

                predicted_states = state_output.argmax(dim=-1)
                state_correct += (predicted_states == batch_state).sum().item()

                total_step += target_acceptance.size(0) * target_acceptance.size(1)

                # Release large per-batch tensors before the next augmentation.
                del batch_acceptance, target_acceptance, batch_state, batch_embedding
                del state_output, acceptance_output

        # Avoid ZeroDivisionError when an epoch yields no batches.
        denom = total_step if total_step > 0 else 1
        avg_loss = total_loss / denom
        train_metrics = {
            'TP': acceptance_TP / denom,
            'TN': acceptance_TN / denom,
            'FP': acceptance_FP / denom,
            'FN': acceptance_FN / denom,
            'state_accuracy': state_correct / denom,
        }
        train_loss_history.append(avg_loss)
        train_metrics_history.append(train_metrics)

        # Early stopping on the summed training loss.
        if total_loss < best_loss:
            best_loss = total_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"\nEarly stopping triggered after {epoch + 1} epochs")
                break

        # Validation
        if (epoch + 1) % val_every == 0:
            val_TP, val_TN, val_FP, val_FN, val_state_accuracy = evaluate_model_egoexo4d(model, val_loaders)
            val_metrics = {
                'TP': val_TP,
                'TN': val_TN,
                'FP': val_FP,
                'FN': val_FN,
                'state_accuracy': val_state_accuracy,
            }
            val_epochs.append(epoch + 1)
            val_metrics_history.append(val_metrics)
            print(val_metrics)

        bar.set_postfix({"Average Loss": avg_loss,
                         'train_acceptance_TP': train_metrics['TP'],
                         'train_acceptance_TN': train_metrics['TN'],
                         'train_acceptance_FP': train_metrics['FP'],
                         'train_acceptance_FN': train_metrics['FN'],
                         'train_state_accuracy': train_metrics['state_accuracy']})

    return train_loss_history, train_metrics_history, val_epochs, val_metrics_history

def evaluate_model_egoexo4d(model, val_loaders):
    """
    Evaluate ``model`` on augmented validation traces without gradient tracking.

    Every validation batch is expanded with ``gen_augment_batch_trace_slow``
    (one random sample per prefix length): the original traces are positives,
    and random-suffix corruptions of them are (partially) rejected negatives.

    NOTE(review): training uses ``gen_augment_batch_trace_fast`` (Gaussian-sampled
    embeddings) while this uses the slow variant (embeddings drawn from
    ``state_to_embedding_map``); presumably intentional — confirm. The slow
    variant loops in Python over every position and dominates evaluation time.

    Args:
        model: called as ``model(embeddings)``; expected to return
            ``(state_logits, acceptance_logits)``.
        val_loaders: dict ``num_steps -> DataLoader`` yielding
            ``(embeddings, acceptance, states)``; the loaded acceptance labels
            are replaced by the augmented ones.

    Returns:
        tuple ``(TP, TN, FP, FN, state_accuracy)``: acceptance confusion counts
        and correctly predicted states, each divided by the total number of
        (trace, step) positions. All are 0.0 when no batches are seen.

    Side effects: leaves ``model`` in eval mode.
    """
    model.eval()
    acceptance_TP = acceptance_TN = acceptance_FP = acceptance_FN = 0
    state_correct = 0
    total_steps = 0

    with torch.no_grad():
        for loader in val_loaders.values():
            for batch_embedding, _, batch_state in loader:
                batch_acceptance, batch_embedding, batch_state = gen_augment_batch_trace_slow(
                    batch_state, batch_embedding, num_states, embed_size, random_sample_per_step=1
                )
                batch_acceptance = batch_acceptance.to(device).bool()
                batch_embedding = batch_embedding.to(device)
                batch_state = batch_state.to(device)

                state_output, acceptance_output = model(batch_embedding)

                # Acceptance confusion counts
                predicted_acceptance = acceptance_output.argmax(dim=-1).bool()
                acceptance_TP += (predicted_acceptance & batch_acceptance).sum().item()
                acceptance_TN += (~predicted_acceptance & ~batch_acceptance).sum().item()
                acceptance_FP += (predicted_acceptance & ~batch_acceptance).sum().item()
                acceptance_FN += (~predicted_acceptance & batch_acceptance).sum().item()

                # State accuracy
                _, predicted_states = torch.max(state_output, -1)
                state_correct += (predicted_states == batch_state).sum().item()

                total_steps += batch_acceptance.size(0) * batch_acceptance.size(1)

                del batch_acceptance, batch_state, batch_embedding, state_output, acceptance_output

    # Guard every ratio (not only state accuracy) against an empty evaluation.
    denom = total_steps if total_steps > 0 else 1
    return (acceptance_TP / denom,
            acceptance_TN / denom,
            acceptance_FP / denom,
            acceptance_FN / denom,
            state_correct / denom)

---

In [None]:
# Build the attention model: input and hidden widths both equal the embedding
# size, and the output size is the number of states.
model = create_model(
    'attention',
    input_size=embed_size,
    hidden_size=embed_size,
    output_size=num_states,
    device=device,
)
model.get_model_size()

In [None]:
# Train for up to 80 epochs, stopping early after 10 epochs without a lower
# training loss; the returned value holds the loss/metric histories.
losses = train_model_egoexo4d(
    model,
    train_loaders,
    val_loaders,
    num_epochs=80,
    learning_rate=1e-5,
    patience=10,
)