In [None]:
# Use Minari to generate episode data for conditional sequence-modelling tasks


In [None]:
"""Minimal version of S4D with extra options and features stripped out, for pedagogical purposes."""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from src.models.nn import DropoutNd

class S4DKernel(nn.Module):
    """Build a convolution kernel from the parameters of a diagonal SSM.

    Learnable tensors (registered in this order): ``C`` (complex read-out stored
    as a real view), ``log_dt``, ``log_A_real`` and ``A_imag``.
    """

    def __init__(self, d_model, N=64, dt_min=0.001, dt_max=0.1, lr=None):
        """
        d_model: number of independent channels H.
        N: state size; N // 2 complex modes are stored per channel.
        dt_min, dt_max: bounds of the log-uniform range the step sizes are drawn from.
        lr: forwarded to ``register`` for log_dt, log_A_real and A_imag.
        """
        super().__init__()
        n_channels = d_model
        n_modes = N // 2

        # Step sizes: uniform in log-space between dt_min and dt_max
        log_lo = math.log(dt_min)
        log_hi = math.log(dt_max)
        log_dt = torch.rand(n_channels) * (log_hi - log_lo) + log_lo

        # Complex read-out vector, kept as a (H, N/2, 2) real parameter
        readout = torch.randn(n_channels, n_modes, dtype=torch.cfloat)
        self.C = nn.Parameter(torch.view_as_real(readout))
        self.register("log_dt", log_dt, lr)

        # Diagonal state matrix: real part -0.5, imaginary part pi * n
        log_A_real = torch.full((n_channels, n_modes), 0.5).log()
        A_imag = math.pi * repeat(torch.arange(n_modes), 'n -> h n', h=n_channels)
        self.register("log_A_real", log_A_real, lr)
        self.register("A_imag", A_imag, lr)

    def forward(self, L):
        """Return the real-valued kernel of length L, shape (H, L)."""
        # Recover the actual parameters from their stored forms
        step = self.log_dt.exp()  # (H)
        readout = torch.view_as_complex(self.C)  # (H N)
        diag = -torch.exp(self.log_A_real) + 1j * self.A_imag  # (H N)

        # Vandermonde product: sum_n C_n * exp(dt * A_n * l) for l in [0, L)
        scaled = diag * step.unsqueeze(-1)  # (H N)
        positions = torch.arange(L, device=diag.device)
        exponents = scaled.unsqueeze(-1) * positions  # (H N L)
        weights = readout * (torch.exp(scaled)-1.) / diag
        return 2 * torch.einsum('hn, hnl -> hl', weights, torch.exp(exponents)).real

    def register(self, name, tensor, lr=None):
        """Attach ``tensor`` as ``self.<name>``.

        lr == 0.0 stores it as a buffer. Any other value stores it as a parameter
        tagged with an ``_optim`` dict (weight_decay 0.0, plus ``lr`` when given).
        """
        if lr == 0.0:
            self.register_buffer(name, tensor)
            return

        self.register_parameter(name, nn.Parameter(tensor))
        hyperparams = {"weight_decay": 0.0}
        if lr is not None:
            hyperparams["lr"] = lr
        getattr(self, name)._optim = hyperparams


class S4D(nn.Module):
    """S4D layer: SSM convolution, D skip term, GELU, dropout and a gated 1x1 conv."""

    def __init__(self, d_model, d_state=64, dropout=0.0, transposed=True, **kernel_args):
        """
        d_model: channel count H.
        d_state: SSM state size handed to the kernel as N.
        dropout: DropoutNd probability; values <= 0 disable dropout.
        transposed: True if inputs are (B, H, L); False if (B, L, H).
        kernel_args: forwarded unchanged to S4DKernel.
        """
        super().__init__()

        self.h, self.n = d_model, d_state
        self.d_output = self.h
        self.transposed = transposed

        # Per-channel skip weight (the D term of the state space equation)
        self.D = nn.Parameter(torch.randn(self.h))

        # SSM kernel generator
        self.kernel = S4DKernel(self.h, N=self.n, **kernel_args)

        # Pointwise non-linearity and optional dropout
        self.activation = nn.GELU()
        # nn.Dropout2d is not used here: NOTE: bugged in PyTorch 1.11
        if dropout > 0.0:
            self.dropout = DropoutNd(dropout)
        else:
            self.dropout = nn.Identity()

        # Position-wise feature mixing: 1x1 conv to 2H channels, GLU back to H
        self.output_linear = nn.Sequential(
            nn.Conv1d(self.h, 2*self.h, kernel_size=1),
            nn.GLU(dim=-2),
        )

    def forward(self, u, **kwargs): # extra kwargs (return_output, transformer src mask) are ignored
        """Map u to an output of the same layout; returns ``(y, None)``."""
        if not self.transposed:
            u = u.transpose(-1, -2)
        seq_len = u.size(-1)
        fft_len = 2 * seq_len  # zero-padded length used for the FFT convolution

        # SSM kernel for this sequence length
        kernel = self.kernel(L=seq_len) # (H L)

        # Convolution via FFT, truncated back to the first seq_len outputs
        kernel_f = torch.fft.rfft(kernel, n=fft_len)
        input_f = torch.fft.rfft(u, n=fft_len)
        y = torch.fft.irfft(input_f*kernel_f, n=fft_len)[..., :seq_len] # (B H L)

        # D term: per-channel skip connection
        y = y + u * self.D.unsqueeze(-1)

        y = self.output_linear(self.dropout(self.activation(y)))
        if not self.transposed:
            y = y.transpose(-1, -2)
        # The second element is a dummy state kept to satisfy the repo's layer interface
        return y, None

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models.s4.s4 import S4Block as S4  # Can use full version instead of minimal S4D standalone below
from models.s4.s4d import S4D
from tqdm.auto import tqdm

In [None]:
class S4Model(nn.Module):
    """Sequence classifier: linear encoder -> stack of residual S4D blocks -> mean pool -> linear decoder."""

    def __init__(
        self,
        d_input,
        d_output=10,
        d_model=256,
        n_layers=4,
        dropout=0.2,
        prenorm=False,
        lr=None,
    ):
        """
        d_input: feature size of each sequence element.
        d_output: size of the decoder output.
        d_model: hidden width used by every S4D block.
        n_layers: number of residual S4D blocks.
        dropout: dropout probability, used inside each S4D block and after it.
        prenorm: apply LayerNorm before each block (True) or after the residual (False).
        lr: base learning rate used to derive the S4D kernel learning rate,
            ``min(0.001, lr)``. When None, the notebook-level ``args.lr`` is
            read instead (the previous behaviour).
        """
        super().__init__()

        self.prenorm = prenorm

        # Linear encoder (d_input = 1 for grayscale and 3 for RGB)
        self.encoder = nn.Linear(d_input, d_model)

        # `dropout_fn` is not defined at this scope (it is a local of S4D.__init__),
        # so pick the channel-wise dropout here. Dropout1d only exists in newer
        # PyTorch releases; fall back to Dropout2d where it is missing.
        channel_dropout = getattr(nn, "Dropout1d", nn.Dropout2d)

        # Stack S4 layers as residual blocks
        self.s4_layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        for _ in range(n_layers):
            # Resolved lazily so a model with zero layers never touches `args`
            kernel_lr = min(0.001, args.lr if lr is None else lr)
            self.s4_layers.append(
                S4D(d_model, dropout=dropout, transposed=True, lr=kernel_lr)
            )
            self.norms.append(nn.LayerNorm(d_model))
            self.dropouts.append(channel_dropout(dropout))

        # Linear decoder
        self.decoder = nn.Linear(d_model, d_output)

    def forward(self, x):
        """
        Input x is shape (B, L, d_input); output is shape (B, d_output).
        """
        x = self.encoder(x)  # (B, L, d_input) -> (B, L, d_model)

        x = x.transpose(-1, -2)  # (B, L, d_model) -> (B, d_model, L)
        for layer, norm, dropout in zip(self.s4_layers, self.norms, self.dropouts):
            # Each iteration of this loop will map (B, d_model, L) -> (B, d_model, L)

            z = x
            if self.prenorm:
                # Prenorm: LayerNorm acts on the last axis, hence the transposes
                z = norm(z.transpose(-1, -2)).transpose(-1, -2)

            # Apply S4 block: we ignore the state input and output
            z, _ = layer(z)

            # Dropout on the output of the S4 block
            z = dropout(z)

            # Residual connection
            x = z + x

            if not self.prenorm:
                # Postnorm
                x = norm(x.transpose(-1, -2)).transpose(-1, -2)

        x = x.transpose(-1, -2)

        # Pooling: average pooling over the sequence length
        x = x.mean(dim=1)

        # Decode the outputs
        x = self.decoder(x)  # (B, d_model) -> (B, d_output)

        return x

# Model
# Relies on notebook-level `args`, `device`, `d_input` and `d_output`.
print('==> Building model..')
model = S4Model(
    d_input=d_input,
    d_output=d_output,
    d_model=args.d_model,
    n_layers=args.n_layers,
    dropout=args.dropout,
    prenorm=args.prenorm,
)

model = model.to(device)
if device == 'cuda':
    cudnn.benchmark = True

# Defaults for a fresh run. Without these, `best_acc` (read in eval) and
# `start_epoch` (read by the epoch loop) only exist when resuming.
best_acc = 0
start_epoch = 0

if args.resume:
    # Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine
    checkpoint = torch.load('./checkpoint/ckpt.pth', map_location=device)
    model.load_state_dict(checkpoint['model'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

def setup_optimizer(model, lr, weight_decay, epochs):
    """
    Build the AdamW optimizer and cosine scheduler used for S4 models.

    Parameters carrying an ``_optim`` attribute (the S4 layer's A, B, C, dt
    tensors, which typically want a smaller learning rate such as 0.001 and no
    weight decay) are placed in their own param groups using the
    hyperparameters stored in that attribute.

    All other parameters share one group trained with ``lr`` and
    ``weight_decay`` (e.g. lr 0.004 or 0.01, with weight decay if desired).

    Returns ``(optimizer, scheduler)``.
    """
    every_param = list(model.parameters())

    # Split into ordinary parameters and the hyperparameter dicts of tagged ones
    ordinary = []
    tagged_cfgs = []
    for p in every_param:
        if hasattr(p, "_optim"):
            tagged_cfgs.append(getattr(p, "_optim"))
        else:
            ordinary.append(p)

    # The first param group holds every ordinary parameter
    optimizer = optim.AdamW(ordinary, lr=lr, weight_decay=weight_decay)

    # One extra group per distinct hyperparameter dict
    distinct = dict.fromkeys(frozenset(cfg.items()) for cfg in tagged_cfgs)
    hps = [dict(items) for items in sorted(list(distinct))]
    for hp in hps:
        members = [p for p in every_param if getattr(p, "_optim", None) == hp]
        optimizer.add_param_group({"params": members, **hp})

    # Cosine annealing over the full run (ReduceLROnPlateau was the alternative considered)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    # Report every group with the hyperparameters that any tagged group overrides
    keys = sorted({k for hp in hps for k in hp.keys()})
    for i, group in enumerate(optimizer.param_groups):
        fields = [f"Optimizer group {i}", f"{len(group['params'])} tensors"]
        fields += [f"{k} {group.get(k, None)}" for k in keys]
        print(' | '.join(fields))

    return optimizer, scheduler

# Classification loss, plus the optimizer/scheduler pair built from the
# notebook-level hyperparameters in `args`.
criterion = nn.CrossEntropyLoss()
optimizer, scheduler = setup_optimizer(model, args.lr, args.weight_decay, args.epochs)

###############################################################################
# Everything after this point is standard PyTorch training!
###############################################################################

# Training
def train():
    """Run one optimisation pass over ``trainloader``.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion`` and
    ``device``, and shows running loss/accuracy on a tqdm progress bar.
    """
    model.train()
    running_loss, n_correct, n_seen = 0, 0, 0

    progress = tqdm(enumerate(trainloader))
    for step, (inputs, targets) in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Standard forward / backward / update
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Running statistics for the progress bar
        running_loss += batch_loss.item()
        predicted = logits.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

        summary = (
            step, len(trainloader), running_loss/(step+1),
            100.*n_correct/n_seen, n_correct, n_seen,
        )
        progress.set_description(
            'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' % summary
        )


def eval(epoch, dataloader, checkpoint=False):
    """Evaluate the notebook-level ``model`` on ``dataloader``.

    epoch: stored in the checkpoint when one is written.
    dataloader: iterable of ``(inputs, targets)`` batches.
    checkpoint: when True and the accuracy beats the global ``best_acc``, the
        model is saved to ``./checkpoint/ckpt.pth`` and ``best_acc`` is updated.

    Returns the accuracy in percent. It is now returned on every call; before,
    it was only returned when ``checkpoint`` was True (and None otherwise).
    An empty dataloader yields 0.0 instead of a division by zero.

    NOTE: this function shadows the builtin ``eval`` within the notebook.
    """
    global best_acc
    model.eval()
    eval_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        pbar = tqdm(enumerate(dataloader))
        for batch_idx, (inputs, targets) in pbar:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            eval_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            pbar.set_description(
                'Batch Idx: (%d/%d) | Loss: %.3f | Acc: %.3f%% (%d/%d)' %
                (batch_idx, len(dataloader), eval_loss/(batch_idx+1), 100.*correct/total, correct, total)
            )

    acc = 100.*correct/total if total else 0.0

    # Save checkpoint.
    if checkpoint and acc > best_acc:
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # exist_ok avoids the check-then-create race of isdir + mkdir
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc

    return acc

# Main epoch loop: train, validate (with checkpointing), test, then step the scheduler.
# `val_acc` starts as None so the first iteration never reads an unassigned name.
# The old `epoch == 0` test failed with a NameError when resuming from start_epoch > 0.
val_acc = None
pbar = tqdm(range(start_epoch, args.epochs))
for epoch in pbar:
    if val_acc is None:
        pbar.set_description('Epoch: %d' % (epoch))
    else:
        pbar.set_description('Epoch: %d | Val acc: %1.3f' % (epoch, val_acc))
    train()
    val_acc = eval(epoch, valloader, checkpoint=True)
    eval(epoch, testloader)
    scheduler.step()
    # print(f"Epoch {epoch} learning rate: {scheduler.get_last_lr()}")


In [None]:

# Dataloaders: only the training split is shuffled; all three share the batch
# size and worker count taken from `args`.
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
)
valloader = torch.utils.data.DataLoader(
    dataset=valset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
)

In [None]:

# Min GRU and Min LSTM Implementations

class MinGRU(nn.Module):
    def __init__(self, d_model, d_state

In [None]:
import torch
from minGRU_pytorch import minGRU

# Smoke test: a minGRU layer should return a tensor with the same shape as its input.
min_gru = minGRU(512)

x = torch.randn(2, 1024, 512)

out = min_gru(x)

# Fixed: the original line read `out.shapt torch`, which is a syntax error.
assert x.shape == out.shape


In [None]:
import torch
from minGRU_pytorch import minGRU

# Consistency check: feeding the whole sequence at once and feeding it one
# token at a time (carrying the hidden state) should agree on the final step.
min_gru = minGRU(dim = 512, expansion_factor = 1.5)

x = torch.randn(1, 2048, 512)

# parallel: one call on the full sequence, keeping only the last position

parallel_out = min_gru(x)[:, -1:]

# sequential: one call per token; each call receives the hidden state returned
# by the previous call (None for the first token)

prev_hidden = None
for token in x.unbind(dim = 1):
    sequential_out, prev_hidden = min_gru(token[:, None, :], prev_hidden, return_next_prev_hidden = True)

# After the loop `sequential_out` holds the output for the final token only
assert torch.allclose(parallel_out, sequential_out, atol = 1e-4)

In [None]:
import torch
import torch.nn as nn
from collections import deque
import numpy as np
from typing import Tuple, Optional

class SequentialReplayBuffer:
    """FIFO transition store that samples fixed-length, single-episode windows.

    Transitions are kept in bounded deques. ``sample_sequence`` returns windows
    of ``sequence_length`` consecutive transitions together with an encoding of
    their state-action pairs produced by ``self.encoder``.
    """

    def __init__(
        self, 
        state_dim: int,
        action_dim: int,
        hidden_dim: int = 512,
        buffer_size: int = 100000,
        sequence_length: int = 50,
        device: str = "cuda",
        encoder: Optional[nn.Module] = None,
    ):
        """
        state_dim, action_dim: feature sizes; their sum is the encoder input size.
        hidden_dim: hidden size passed to the default encoder.
        buffer_size: maximum number of stored transitions (oldest are evicted).
        sequence_length: number of consecutive transitions per sampled window.
        device: device the encoder and the sampled tensors are moved to.
        encoder: optional module applied to the concatenated state-action
            sequence. When None, ``MinGRU(input_dim=..., hidden_dim=...)`` is
            built as before; that class must then be defined by the notebook.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.buffer_size = buffer_size
        self.sequence_length = sequence_length
        self.device = device

        # Sequence encoder (defaults to MinGRU over concatenated state-action pairs)
        if encoder is None:
            encoder = MinGRU(
                input_dim=state_dim + action_dim,
                hidden_dim=hidden_dim
            )
        self.encoder = encoder.to(device)

        # Storage
        self.states = deque(maxlen=buffer_size)
        self.actions = deque(maxlen=buffer_size)
        self.rewards = deque(maxlen=buffer_size)
        self.next_states = deque(maxlen=buffer_size)
        self.dones = deque(maxlen=buffer_size)

        # Buffer indices of transitions whose `done` flag was set
        self.episode_boundaries = []

    def add(self, state, action, reward, next_state, done):
        """Append one transition, evicting the oldest one when the buffer is full."""
        evicting = (
            self.states.maxlen is not None
            and len(self.states) == self.states.maxlen
        )

        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.next_states.append(next_state)
        self.dones.append(done)

        if evicting:
            # Every stored index moved down by one; drop the boundary that fell off
            self.episode_boundaries = [
                b - 1 for b in self.episode_boundaries if b > 0
            ]

        if done:
            self.episode_boundaries.append(len(self.states) - 1)

    def encode_sequence(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Concatenate states and actions on the last axis and run the encoder."""
        sequence = torch.cat([states, actions], dim=-1)
        encoded = self.encoder(sequence)
        return encoded

    def _valid_start_indices(self) -> np.ndarray:
        """Return every start index whose window stays inside one episode.

        A window ``[s, s + T)`` is valid when no ``done`` flag appears before
        its last position, so a terminal transition may close a window but
        never sit inside one. The flags are read from ``self.dones`` directly,
        so the result stays correct after old transitions have been evicted.
        """
        stored = len(self.dones)
        window = self.sequence_length
        if window <= 0 or stored < window:
            return np.empty(0, dtype=np.int64)

        done_flags = np.fromiter(
            (bool(d) for d in self.dones), dtype=np.int64, count=stored
        )
        # cum[i] = number of done flags among the first i transitions
        cum = np.concatenate(([0], np.cumsum(done_flags)))
        starts = np.arange(stored - window + 1)
        dones_before_last = cum[starts + window - 1] - cum[starts]
        return starts[dones_before_last == 0]

    def _gather(self, storage, start_indices) -> torch.Tensor:
        """Stack the chosen windows of ``storage`` into one float32 tensor on ``self.device``."""
        items = list(storage)  # deque indexing is O(n); slice a list instead
        window = self.sequence_length
        batch = np.asarray(
            [items[s:s + window] for s in start_indices], dtype=np.float32
        )
        return torch.from_numpy(batch).to(self.device)

    def sample_sequence(self, batch_size: int) -> Tuple[torch.Tensor, ...]:
        """Sample ``batch_size`` windows (with replacement) that do not cross an episode end.

        Returns ``(states, actions, rewards, next_states, dones, encoded_states)``.
        The first five are float32 tensors whose leading axes are
        ``(batch_size, sequence_length)``; ``encoded_states`` is whatever the
        encoder returns for the concatenated state-action batch.

        Raises:
            ValueError: if the buffer holds no valid window yet.
        """
        candidates = self._valid_start_indices()
        if candidates.size == 0:
            raise ValueError(
                "Replay buffer holds no episode-contiguous window of length "
                f"{self.sequence_length} (stored transitions: {len(self.states)})"
            )

        # Draw only from valid starts, so the batch always has exactly batch_size rows
        start_indices = np.random.choice(candidates, size=batch_size)

        states = self._gather(self.states, start_indices)
        actions = self._gather(self.actions, start_indices)
        rewards = self._gather(self.rewards, start_indices)
        next_states = self._gather(self.next_states, start_indices)
        dones = self._gather(self.dones, start_indices)

        # Encode sequences
        encoded_states = self.encode_sequence(states, actions)

        return (states, actions, rewards, next_states, dones, encoded_states)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

class SequentialTrainer:
    """Online training loop: act in the environment, store transitions, update the agent."""

    def __init__(
        self,
        env,
        agent,
        replay_buffer,
        batch_size: int = 32,
        updates_per_step: int = 1,
        device: str = "cuda"
    ):
        """
        env: environment exposing ``reset()`` and ``step(action)``.
        agent: object exposing ``select_action(state)`` and ``update(sequences)``.
        replay_buffer: object exposing ``add(...)``, ``sample_sequence(n)`` and ``states``.
        batch_size: number of sequences requested per agent update.
        updates_per_step: agent updates performed after each environment step.
        device: stored on the trainer; not used by the methods below.
        """
        self.env = env
        self.agent = agent
        self.replay_buffer = replay_buffer
        self.batch_size = batch_size
        self.updates_per_step = updates_per_step
        self.device = device

    @staticmethod
    def _unpack_reset(result):
        """Return the observation from either ``obs`` or ``(obs, info_dict)``."""
        if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
            return result[0]
        return result

    @staticmethod
    def _unpack_step(result):
        """Normalise a 4-tuple or 5-tuple step result to ``(next_state, reward, done)``."""
        if len(result) == 5:
            next_state, reward, terminated, truncated, _ = result
            return next_state, reward, bool(terminated or truncated)
        next_state, reward, done, _ = result
        return next_state, reward, done

    def _can_update(self) -> bool:
        """True once the buffer holds more transitions than one batch and one window need."""
        stored = len(self.replay_buffer.states)
        # Sampling needs a full window; the old batch_size-only gate let sampling
        # start while the buffer was still shorter than one sequence.
        window = getattr(self.replay_buffer, "sequence_length", 0)
        return stored > max(self.batch_size, window)

    def train_episode(self, max_steps: int = 1000):
        """Run one episode of at most ``max_steps`` steps and return its total reward."""
        state = self._unpack_reset(self.env.reset())
        episode_reward = 0

        for step in range(max_steps):
            # Select action
            action = self.agent.select_action(state)

            # Execute action
            next_state, reward, done = self._unpack_step(self.env.step(action))

            # Store transition
            self.replay_buffer.add(state, action, reward, next_state, done)

            # Update agent
            if self._can_update():
                for _ in range(self.updates_per_step):
                    sequences = self.replay_buffer.sample_sequence(self.batch_size)
                    self.agent.update(sequences)

            state = next_state
            episode_reward += reward

            if done:
                break

        return episode_reward

    def train(self, num_episodes: int = 1000):
        """Run ``num_episodes`` episodes and return the list of per-episode rewards."""
        rewards = []

        for episode in tqdm(range(num_episodes)):
            episode_reward = self.train_episode()
            rewards.append(episode_reward)

            # Log progress every 10 episodes (the slice then always holds 10 entries)
            if (episode + 1) % 10 == 0:
                avg_reward = sum(rewards[-10:]) / 10
                print(f"Episode {episode+1}: Average Reward = {avg_reward:.2f}")

        return rewards

In [None]:
class MinGRUForRL(nn.Module):
    """Sequence encoder: GRU followed by residual self-attention and LayerNorm.

    NOTE: despite the name, the recurrent core is a standard ``nn.GRU``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_layers: int = 1,
        dropout: float = 0.1
    ):
        """
        input_dim: feature size of each sequence element.
        hidden_dim: GRU hidden size and attention embedding size (4 heads are used).
        num_layers: number of stacked GRU layers.
        dropout: inter-layer GRU dropout; only applied when num_layers > 1.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        
        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0,
            batch_first=True
        )
        
        # Additional layers for RL-specific processing.
        # batch_first=True matches the GRU layout (batch, seq, feature). Without
        # it the attention treats axis 0 as time and mixes information across
        # the samples of a batch instead of across time steps.
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=4, batch_first=True)
        self.norm = nn.LayerNorm(hidden_dim)
        
    def forward(
        self, 
        x: torch.Tensor, 
        hidden: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode ``x`` of shape (batch, seq, input_dim).

        Returns ``(output, hidden)``: output is (batch, seq, hidden_dim) and
        hidden is the final GRU state, (num_layers, batch, hidden_dim).
        """
        # Process sequence through GRU
        output, hidden = self.gru(x, hidden)
        
        # Apply self-attention over the time axis of each sequence
        attn_output, _ = self.attention(output, output, output)
        
        # Residual connection and normalization
        output = self.norm(output + attn_output)
        
        return output, hidden

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Dict, Optional

class OfflineRLAgent(nn.Module):
    """Twin-Q agent with a CQL-style penalty on top of a sequence encoder.

    Holds two Q-networks with frozen target copies, a Gaussian policy head and
    an Adam optimizer over the Q-network parameters. ``update`` only trains the
    Q-networks; the policy head is not optimised by any method of this class.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        sequence_encoder: nn.Module,
        hidden_dim: int = 512,
        cql_alpha: float = 1.0,
        discount: float = 0.99,
        tau: float = 0.005,
        device: str = "cuda",
        lr: float = 3e-4,
    ):
        """
        state_dim: state feature size (kept for interface symmetry; unused here).
        action_dim: action feature size.
        sequence_encoder: module called on (batch, seq, state+action) input and
            expected to return an ``(encoded, hidden)`` pair.
        hidden_dim: encoder output size and width of the MLP heads.
        cql_alpha: weight of the conservative penalty.
        discount: reward discount factor.
        tau: interpolation factor for the soft target update.
        device: device the module is moved to.
        lr: learning rate of the Adam optimizer over both Q-networks.
        """
        super().__init__()
        import copy  # local import: `copy` is not imported at notebook level

        self.device = device
        self.action_dim = action_dim
        self.cql_alpha = cql_alpha
        self.discount = discount
        self.tau = tau

        # Sequence encoder (MinGRU)
        self.sequence_encoder = sequence_encoder

        # Q-networks
        self.q1 = nn.Sequential(
            nn.Linear(hidden_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
        self.q2 = nn.Sequential(
            nn.Linear(hidden_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        # Target networks: independent copies, moved only by _soft_update_target
        self.q1_target = copy.deepcopy(self.q1)
        self.q2_target = copy.deepcopy(self.q2)
        for target_param in list(self.q1_target.parameters()) + list(self.q2_target.parameters()):
            target_param.requires_grad_(False)

        # Policy network
        self.policy = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim * 2)  # Mean and log_std
        )

        # Move to device
        self.to(device)

        # `update` calls self.optimizer, which was never created before.
        # Only the Q-networks receive gradients from the loss computed there.
        self.optimizer = torch.optim.Adam(
            list(self.q1.parameters()) + list(self.q2.parameters()), lr=lr
        )

    def encode_sequence(
        self, 
        states: torch.Tensor, 
        actions: torch.Tensor
    ) -> torch.Tensor:
        """Encode a (batch, seq, ·) state-action sequence; returns the last time step."""
        sequence = torch.cat([states, actions], dim=-1)
        encoded, _ = self.sequence_encoder(sequence)
        return encoded[:, -1]  # Return last hidden state

    def get_action(
        self, 
        encoded_state: torch.Tensor, 
        deterministic: bool = False
    ) -> torch.Tensor:
        """Return a tanh-squashed action in [-1, 1] for each encoded state.

        With ``deterministic=True`` the squashed mean is returned; otherwise a
        reparameterised sample is drawn first. Both paths now apply tanh; the
        deterministic path used to return the unbounded mean.
        """
        mean, log_std = self.policy(encoded_state).chunk(2, dim=-1)
        log_std = torch.clamp(log_std, -20, 2)
        std = log_std.exp()

        if deterministic:
            return torch.tanh(mean)

        normal = torch.distributions.Normal(mean, std)
        action = normal.rsample()
        return torch.tanh(action)

    def compute_q_values(
        self, 
        encoded_states: torch.Tensor, 
        actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute Q-values from both online networks."""
        q_input = torch.cat([encoded_states, actions], dim=-1)
        return self.q1(q_input), self.q2(q_input)

    def compute_target_q_values(
        self,
        encoded_states: torch.Tensor,
        actions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute Q-values from both target networks."""
        q_input = torch.cat([encoded_states, actions], dim=-1)
        return self.q1_target(q_input), self.q2_target(q_input)

    def update(self, batch: Tuple[torch.Tensor, ...]) -> Dict[str, float]:
        """Run one Q-network update with a CQL penalty.

        ``batch`` is ``(states, actions, rewards, next_states, dones,
        encoded_states)``. ``states``, ``actions`` and ``next_states`` must be
        (batch, seq, ·) sequences. The Bellman backup is formed for the final
        transition of every sequence, so sequence-shaped ``rewards``, ``dones``,
        ``actions`` and ``encoded_states`` are reduced to their last time step.

        Returns the scalar values of the Q loss, the CQL loss and their sum.
        """
        states, actions, rewards, next_states, dones, encoded_states = batch
        batch_size = states.shape[0]

        # Reduce everything to one entry per sequence so all Q tensors are (batch, 1)
        if encoded_states.dim() == 3:
            encoded_states = encoded_states[:, -1]
        last_actions = actions[:, -1] if actions.dim() == 3 else actions
        rewards = rewards.reshape(batch_size, -1)[:, -1:].float()
        # float() also makes `1 - dones` valid when the flags arrive as booleans
        dones = dones.reshape(batch_size, -1)[:, -1:].float()

        # Compute target Q-values
        with torch.no_grad():
            next_encoded = self.encode_sequence(next_states, actions)
            next_actions = self.get_action(next_encoded)
            target_q1, target_q2 = self.compute_target_q_values(next_encoded, next_actions)
            target_q = torch.min(target_q1, target_q2)
            target_q = rewards + (1 - dones) * self.discount * target_q

        # Current Q-values
        current_q1, current_q2 = self.compute_q_values(encoded_states, last_actions)

        # CQL loss: push down Q on uniformly random actions, push up Q on data actions
        random_actions = torch.empty(
            batch_size, self.action_dim, device=encoded_states.device
        ).uniform_(-1, 1)
        random_q1, random_q2 = self.compute_q_values(encoded_states, random_actions)
        
        cql_loss_q1 = torch.logsumexp(random_q1, dim=0) - current_q1.mean()
        cql_loss_q2 = torch.logsumexp(random_q2, dim=0) - current_q2.mean()
        cql_loss = self.cql_alpha * (cql_loss_q1 + cql_loss_q2)

        # Standard Q-learning loss
        q_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

        # Total loss
        total_loss = q_loss + cql_loss

        # Optimize
        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        # Update target networks
        self._soft_update_target()

        return {
            'q_loss': q_loss.item(),
            'cql_loss': cql_loss.item(),
            'total_loss': total_loss.item()
        }

    def _soft_update_target(self):
        """Move each target parameter a fraction ``tau`` towards its online counterpart."""
        for param, target_param in zip(self.q1.parameters(), self.q1_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        
        for param, target_param in zip(self.q2.parameters(), self.q2_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def save(self, path: str):
        """Save encoder, Q-network, target-network and policy weights to ``path``."""
        torch.save({
            'sequence_encoder': self.sequence_encoder.state_dict(),
            'q1': self.q1.state_dict(),
            'q2': self.q2.state_dict(),
            'q1_target': self.q1_target.state_dict(),
            'q2_target': self.q2_target.state_dict(),
            'policy': self.policy.state_dict()
        }, path)

    def load(self, path: str):
        """Restore weights written by ``save``.

        Checkpoints without target-network entries (the older format) are
        accepted; the targets are then synchronised with the loaded Q-networks.
        """
        # map_location allows loading a GPU checkpoint onto this agent's device
        checkpoint = torch.load(path, map_location=self.device)
        self.sequence_encoder.load_state_dict(checkpoint['sequence_encoder'])
        self.q1.load_state_dict(checkpoint['q1'])
        self.q2.load_state_dict(checkpoint['q2'])
        self.q1_target.load_state_dict(checkpoint.get('q1_target', checkpoint['q1']))
        self.q2_target.load_state_dict(checkpoint.get('q2_target', checkpoint['q2']))
        self.policy.load_state_dict(checkpoint['policy'])

In [None]:
import gym
import torch
# from src.replay.sequential_buffer import SequentialReplayBuffer
# from src.training.sequential_trainer import SequentialTrainer
# from src.models.mingru_rl import MinGRUForRL

def main():
    """Train a sequence-conditioned agent on HalfCheetah and save the results.

    Builds the environment, replay buffer, agent and trainer, runs 1000
    episodes, then writes rewards and weights to
    ``results/sequential_rl_experiment.pt``.
    """
    import os  # local import: `os` is not imported in this cell

    results_path = 'results/sequential_rl_experiment.pt'
    # Create the output directory up front so a missing folder fails before
    # training starts, not after it has finished
    os.makedirs(os.path.dirname(results_path), exist_ok=True)

    # Fall back to CPU when no GPU is present (buffer and trainer default to "cuda")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Environment setup
    env = gym.make('HalfCheetah-v2')
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    
    # Create replay buffer
    buffer = SequentialReplayBuffer(
        state_dim=state_dim,
        action_dim=action_dim,
        hidden_dim=512,
        sequence_length=50,
        device=device
    )
    
    # Create agent (you'll need to implement this based on your RL algorithm)
    # NOTE(review): `YourRLAgent` is a placeholder and is not defined in this
    # notebook. The trainer needs `select_action(state)` and `update(sequences)`.
    agent = YourRLAgent(
        state_dim=state_dim,
        action_dim=action_dim,
        sequence_encoder=MinGRUForRL(
            input_dim=state_dim + action_dim,
            hidden_dim=512
        )
    )
    
    # Create trainer
    trainer = SequentialTrainer(
        env=env,
        agent=agent,
        replay_buffer=buffer,
        batch_size=32,
        device=device
    )
    
    # Train
    rewards = trainer.train(num_episodes=1000)
    
    # Save results
    torch.save({
        'rewards': rewards,
        'agent_state': agent.state_dict(),
        'buffer_encoder': buffer.encoder.state_dict()
    }, results_path)

if __name__ == "__main__":
    main()