In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split

In [9]:
# Seed the global NumPy RNG (used by np.random.randint when sampling actions
# and the demo state) and the global PyTorch RNG (used for layer weight
# initialisation) so that a fresh Restart & Run All reproduces the same run.
np.random.seed(42)
torch.manual_seed(42)

class GridWorldEnv:
    """Deterministic square grid world with a single goal cell.

    The agent's state is its ``(x, y)`` position. Every episode starts at
    ``(0, 0)`` and ends when the agent's position equals ``goal``.

    Actions:
        0 -> x - 1, 1 -> x + 1, 2 -> y - 1, 3 -> y + 1.
        A move that would leave the ``size x size`` grid, or an action outside
        0..3, leaves the position unchanged (the step still counts and is
        still rewarded with the step penalty).

    Args:
        size: Side length of the grid; valid coordinates are ``0 .. size - 1``.
        goal: ``(x, y)`` of the terminal cell. Any 2-element sequence is
            accepted and stored as a tuple.
    """

    def __init__(self, size=5, goal=(4, 4)):
        self.size = size
        # Positions are tuples; a list-valued goal would never compare equal
        # to them, so the episode could never terminate. Normalise once here.
        self.goal = tuple(goal)
        self.action_space = 4
        # Define the start position up front so step() works even when the
        # caller has not called reset() yet.
        self.position = (0, 0)

    def reset(self):
        """Move the agent back to ``(0, 0)`` and return that state as float32."""
        self.position = (0, 0)
        return np.array(self.position, dtype=np.float32)

    def step(self, action):
        """Apply ``action`` and return ``(next_state, reward, done)``.

        Returns:
            next_state: float32 array ``[x, y]`` after the move.
            reward: ``1.0`` when the new position is the goal, else ``-0.1``.
            done: ``True`` exactly when the new position is the goal.
        """
        x, y = self.position
        if action == 0 and x > 0:
            x -= 1
        elif action == 1 and x < self.size - 1:
            x += 1
        elif action == 2 and y > 0:
            y -= 1
        elif action == 3 and y < self.size - 1:
            y += 1

        self.position = (x, y)
        done = self.position == self.goal
        reward = 1.0 if done else -0.1
        return np.array(self.position, dtype=np.float32), reward, done

def generate_data(env, episodes=200, policy=None, max_steps=None):
    """Roll out episodes in ``env`` and collect every transition.

    Args:
        env: Object exposing ``reset() -> state``,
            ``step(action) -> (next_state, reward, done)`` and an integer
            ``action_space`` (number of discrete actions).
        episodes: Number of episodes to roll out.
        policy: Optional callable ``policy(state) -> action``. When ``None``
            (the default) actions are drawn uniformly at random with
            ``np.random.randint(env.action_space)``.
        max_steps: Optional cap on the number of steps per episode. ``None``
            (the default) keeps rolling until the environment reports
            ``done``, which never returns if the terminal state is
            unreachable.

    Returns:
        A list of ``(state, action, reward, next_state)`` tuples in the order
        they were experienced, concatenated over all episodes.
    """
    data = []
    for _ in range(episodes):
        state = env.reset()
        done = False
        steps_taken = 0
        while not done and (max_steps is None or steps_taken < max_steps):
            if policy is None:
                action = np.random.randint(env.action_space)
            else:
                action = policy(state)
            next_state, reward, done = env.step(action)
            data.append((state, action, reward, next_state))
            state = next_state
            steps_taken += 1
    return data

def split_data(data, test_size=0.2, random_state=42):
    """Split collected transitions into train/test (state, action) pairs.

    Only the state and the action of each transition are used; reward and
    next state are ignored.

    Args:
        data: Iterable of ``(state, action, reward, next_state)`` tuples.
        test_size: Forwarded to ``train_test_split``.
        random_state: Forwarded to ``train_test_split``.

    Returns:
        The result of ``train_test_split(states, actions, ...)``, i.e.
        ``train_states, test_states, train_actions, test_actions``.

    Raises:
        ValueError: If ``data`` contains no transitions.
    """
    transitions = list(data)
    if not transitions:
        raise ValueError("split_data() requires at least one transition, got none")

    states = np.stack([transition[0] for transition in transitions])
    actions = np.array([transition[1] for transition in transitions])
    return train_test_split(
        states, actions, test_size=test_size, random_state=random_state
    )

class VIN(nn.Module):
    """Small convolutional classifier mapping a grid image to action logits.

    Architecture: 3x3 conv (``input_channels`` -> 64) + ReLU, 3x3 conv
    (64 -> 16) + ReLU, flatten, linear layer to ``action_dim`` logits. Both
    convolutions use ``padding=1`` so the ``state_dim x state_dim`` spatial
    size is preserved.

    Note: as written this is a plain feed-forward CNN; it contains no
    iterated value-propagation step.

    Args:
        input_channels: Number of channels of the input grid image.
        state_dim: Side length of the (square) input grid.
        action_dim: Number of output logits, one per action.
    """

    def __init__(self, input_channels, state_dim, action_dim):
        super().__init__()
        # 16 feature maps of state_dim x state_dim cells feed the linear head.
        flat_features = 16 * state_dim * state_dim
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 16, kernel_size=3, padding=1)
        self.fc = nn.Linear(flat_features, action_dim)

    def forward(self, x):
        """Return logits of shape ``(batch, action_dim)`` for input ``x``."""
        hidden = x
        for conv in (self.conv1, self.conv2):
            hidden = torch.relu(conv(hidden))
        # Keep the batch axis, collapse channels and both spatial axes.
        return self.fc(torch.flatten(hidden, start_dim=1))

def preprocess_states(states, state_dim):
    """One-hot encode ``(x, y)`` positions as single-channel grid images.

    Args:
        states: Sequence of ``(x, y)`` pairs. Coordinates are converted with
            ``int()`` (so ``2.0`` becomes ``2``).
        state_dim: Side length of the square grid.

    Returns:
        A float32 tensor of shape ``(len(states), 1, state_dim, state_dim)``
        that is zero everywhere except for a single ``1.0`` at index
        ``[i, 0, x, y]`` for the i-th state.

    Raises:
        IndexError: If a coordinate lies outside ``0 .. state_dim - 1``.
            Negative coordinates are rejected explicitly; plain NumPy
            indexing would silently wrap them to the opposite edge.
    """
    processed = np.zeros((len(states), 1, state_dim, state_dim), dtype=np.float32)
    for i, (x, y) in enumerate(states):
        row, col = int(x), int(y)
        if not (0 <= row < state_dim and 0 <= col < state_dim):
            raise IndexError(
                f"state {i} = ({row}, {col}) is outside a "
                f"{state_dim}x{state_dim} grid"
            )
        processed[i, 0, row, col] = 1.0
    # The array is local and already float32, so it can back the tensor
    # directly without an extra copy.
    return torch.from_numpy(processed)

def train_vin(model, states, actions, epochs=10, lr=0.001, state_dim=5):
    """Train ``model`` to predict ``actions`` from ``states`` (full batch).

    A fresh Adam optimizer is created on every call, and every epoch is a
    single gradient step on the whole data set with cross-entropy loss. The
    loss of each epoch is printed.

    Args:
        model: Module mapping the output of ``preprocess_states`` to one
            logit per action.
        states: Sequence of ``(x, y)`` positions.
        actions: Integer action labels aligned with ``states``.
        epochs: Number of full-batch gradient steps.
        lr: Adam learning rate.
        state_dim: Grid side length used to encode ``states``. It must match
            the grid size the model was built for; defaults to 5.

    Returns:
        None. ``model`` is updated in place and left in training mode.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Encoding is independent of the weights, so do it once outside the loop.
    inputs = preprocess_states(states, state_dim=state_dim)
    labels = torch.tensor(actions, dtype=torch.long)

    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item():.4f}")

def fine_tune_vin(model, states, actions, fine_tune_epochs=5, lr=0.0005):
    """Continue training ``model`` for a few epochs at a lower learning rate.

    This is a thin wrapper around ``train_vin``; it forwards ``states`` and
    ``actions`` unchanged.

    Args:
        model: Model to keep training (updated in place).
        states: Sequence of ``(x, y)`` positions.
        actions: Integer action labels aligned with ``states``.
        fine_tune_epochs: Number of additional epochs.
        lr: Learning rate for the fine-tuning phase; defaults to 0.0005.
    """
    train_vin(model, states, actions, epochs=fine_tune_epochs, lr=lr)

def evaluate_model(model, states, actions, state_dim=5):
    """Print and return the classification accuracy of ``model``.

    Args:
        model: Module mapping the output of ``preprocess_states`` to one
            logit per action.
        states: Sequence of ``(x, y)`` positions.
        actions: Integer action labels aligned with ``states``.
        state_dim: Grid side length used to encode ``states``; defaults to 5.

    Returns:
        The fraction of states (0.0 - 1.0) whose arg-max logit equals the
        label. ``model`` is switched to eval mode and left that way.
    """
    model.eval()
    with torch.no_grad():
        inputs = preprocess_states(states, state_dim=state_dim)
        labels = torch.tensor(actions, dtype=torch.long)
        outputs = model(inputs)
        predictions = torch.argmax(outputs, dim=1)
        accuracy = (predictions == labels).float().mean().item()
    print(f"Test Accuracy: {accuracy * 100:.2f}%")
    return accuracy

def deploy_policy(model, state, state_dim=5):
    """Return the greedy action of ``model`` for a single state.

    Args:
        model: Module mapping the output of ``preprocess_states`` to one
            logit per action.
        state: One ``(x, y)`` position.
        state_dim: Grid side length used to encode ``state``; defaults to 5.

    Returns:
        The index of the largest logit as a Python ``int``. ``model`` is
        switched to eval mode and left that way.
    """
    model.eval()
    with torch.no_grad():
        # preprocess_states works on batches, so wrap the state in a list.
        input_state = preprocess_states([state], state_dim=state_dim)
        output = model(input_state)
        return torch.argmax(output, dim=1).item()

In [10]:
# End-to-end demo: collect transitions, train, fine-tune, evaluate, then query
# the trained model for one state.
if __name__ == "__main__":
    # Collect (state, action) pairs. generate_data picks actions uniformly at
    # random by default, so the labels carry no information about the state.
    env = GridWorldEnv(size=5)
    data = generate_data(env)
    train_states, test_states, train_actions, test_actions = split_data(data)

    # state_dim must equal the environment's grid size (5 here).
    vin_model = VIN(input_channels=1, state_dim=5, action_dim=4)
    train_vin(vin_model, train_states, train_actions)

    # Fine-tuning reuses the same training split, then accuracy is measured on
    # the held-out split.
    # NOTE(review): with uniformly random labels over 4 actions, a loss near
    # ln(4) ~= 1.386 and an accuracy near 25% are the expected outcome.
    fine_tune_vin(vin_model, train_states, train_actions)
    evaluate_model(vin_model, test_states, test_actions)

    # Sample one grid cell with both coordinates in 0..4 and ask the model for
    # its greedy action there.
    random_state = (np.random.randint(0, 5), np.random.randint(0, 5))
    action = deploy_policy(vin_model, random_state)
    print(f"Deployed Action: {action}")

Epoch 1/10, Loss: 1.3864
Epoch 2/10, Loss: 1.3861
Epoch 3/10, Loss: 1.3861
Epoch 4/10, Loss: 1.3860
Epoch 5/10, Loss: 1.3859
Epoch 6/10, Loss: 1.3858
Epoch 7/10, Loss: 1.3858
Epoch 8/10, Loss: 1.3857
Epoch 9/10, Loss: 1.3856
Epoch 10/10, Loss: 1.3855
Epoch 1/5, Loss: 1.3855
Epoch 2/5, Loss: 1.3854
Epoch 3/5, Loss: 1.3853
Epoch 4/5, Loss: 1.3853
Epoch 5/5, Loss: 1.3852
Test Accuracy: 23.98%
Deployed Action: 0
