$$y_i = f(x_i)$$

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
import random

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
torch.cuda.is_available()  # bare expression: the notebook echoes GPU availability as cell output

In [None]:
# Load the serialized train and validation objects from data_augment/.
# extract_data() below iterates their .values(), so both are expected to be dict-like.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if these files
# hold arbitrary pickled objects, loading may need weights_only=False (trusted files only).
data_train, data_val = (
    torch.load('data_augment/segment_train_0.pt'),
    torch.load('data_augment/segment_val_0.pt'),
)

In [None]:
def extract_data(data_dict):
    """Flatten every sequence in ``data_dict`` into one feature matrix and one label vector.

    Each value of ``data_dict`` is indexed positionally: element ``[0]`` supplies the
    embeddings and element ``[2]`` the states (element ``[1]`` is not used here).
    Embeddings are reshaped to ``(-1, D)``, where ``D`` is their last dimension, and
    states are flattened to 1-D. The pieces from all values are then concatenated
    along dim 0 in dict-iteration order.

    Returns:
        ``(embeddings, states)``: the two concatenated tensors.
    """
    flat_pairs = [
        (seq[0].reshape(-1, seq[0].shape[-1]), seq[2].reshape(-1))
        for seq in data_dict.values()
    ]
    embedding_chunks = [emb for emb, _ in flat_pairs]
    state_chunks = [lab for _, lab in flat_pairs]
    return torch.cat(embedding_chunks, dim=0), torch.cat(state_chunks, dim=0)

# Flatten both splits into (embeddings, states) tensor pairs, train first, then val.
(train_embeddings, train_states), (val_embeddings, val_states) = (
    extract_data(split) for split in (data_train, data_val)
)

In [None]:
import torch.nn.functional as F

# Number of classes; used as the classifier's output_dim, so CrossEntropyLoss
# needs every state label to lie in [0, num_states).
num_states = 629
batch_size = 1 << 14  # 16384 samples per mini-batch

train_dataset, val_dataset = (
    TensorDataset(train_embeddings, train_states),
    TensorDataset(val_embeddings, val_states),
)

# Only the training data is shuffled; validation order does not affect accuracy.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
class MLPClassifier(nn.Module):
    """Two-hidden-layer MLP that returns unnormalized class scores (logits).

    Architecture::

        Linear(input_size, hidden_size) -> ReLU -> Dropout
        -> Linear(hidden_size, hidden_size) -> ReLU -> Dropout
        -> Linear(hidden_size, output_dim)

    Args:
        input_size: size of the last dimension of the input.
        hidden_size: width of both hidden layers.
        output_dim: number of output scores (classes).
        dropout: dropout probability after each hidden activation; defaults to
            0.3, the value that was previously hard-coded.
    """

    def __init__(self, input_size, hidden_size, output_dim, dropout=0.3):
        super().__init__()
        # One nn.Sequential named ``layers`` with the same layer order as before,
        # so state_dict keys stay layers.0.*, layers.3.*, layers.6.*.
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, output_dim),
        )

    def forward(self, x):
        """Return raw logits; the last dimension of ``x`` is mapped to ``output_dim``."""
        return self.layers(x)

# Feature width = last dimension of the training embeddings tensor.
input_size = train_dataset.tensors[0].size(-1)
model = MLPClassifier(
    input_size=input_size,
    hidden_size=128,
    output_dim=num_states,
)
model = model.to(device)

In [None]:
def _accuracy(correct, total):
    """Return ``correct / total`` as a float, or NaN when no samples were seen."""
    # Without this guard an empty loader would raise ZeroDivisionError.
    return int(correct) / total if total else float("nan")


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader`` and return its accuracy.

    Accuracy is measured on the same forward pass used for the gradient step,
    i.e. in train mode (dropout active) and with the pre-step weights.
    """
    model.train()
    correct = 0
    total = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        # Accumulate on-device; a single int() per epoch replaces a per-batch .item() sync.
        correct += (outputs.argmax(dim=1) == targets).sum()
        total += targets.size(0)
    return _accuracy(correct, total)


def _evaluate(model, loader, device):
    """Return the accuracy of ``model`` on ``loader`` in eval mode, without gradients."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            correct += (model(inputs).argmax(dim=1) == targets).sum()
            total += targets.size(0)
    return _accuracy(correct, total)


def train_model(model, train_loader, val_loader,
                num_epochs, lr, *, weight_decay=0.0, device=None,
                show_progress=True):
    """Train ``model`` with Adam and cross-entropy, tracking per-epoch accuracy.

    Args:
        model: classifier; expected to return logits of shape (batch, num_classes)
            for integer class labels.
        train_loader: iterable of ``(inputs, labels)`` batches used for training.
        val_loader: iterable of ``(inputs, labels)`` batches evaluated after each epoch.
        num_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        weight_decay: Adam L2 penalty; 0.0 (the default) means no penalty. An explicit
            ``l2_lambda * sum(p.pow(2).sum())`` loss term has the same gradient as
            ``weight_decay=2 * l2_lambda``.
        device: device that batches are moved to. Defaults to the device of the model's
            first parameter.
        show_progress: show a tqdm bar with the latest train/val accuracy.

    Returns:
        ``(train_accs, val_accs)``: lists with one float per epoch. An epoch whose
        loader yields no samples records NaN.
    """
    if device is None:
        device = next(model.parameters()).device
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    train_accs = []
    val_accs = []
    epochs = tqdm(range(num_epochs)) if show_progress else range(num_epochs)
    for _ in epochs:
        train_acc = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_acc = _evaluate(model, val_loader, device)
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        if show_progress:
            epochs.set_postfix({
                'Train Acc': train_acc,
                'Val Acc': val_acc
            })
    return train_accs, val_accs

# 600 epochs of Adam at learning rate 1e-3.
train_accs, val_accs = train_model(
    model, train_loader, val_loader, num_epochs=600, lr=1e-3
)

In [None]:
import matplotlib.pyplot as plt

# Per-epoch accuracy curves: train is drawn first, then val.
for accs, split_name in ((train_accs, 'train'), (val_accs, 'val')):
    plt.plot(range(len(accs)), accs, label=split_name)
plt.legend()
plt.title('video')
plt.show()