# Laminet Prototype: Field Evolution Model (Curated 10k Dataset + Color Visualization)

In [None]:
!pip install torch matplotlib numpy

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import random
import json
from tqdm import tqdm
from IPython.display import clear_output
import matplotlib.animation as animation


In [None]:
# Load the curated dataset. JSON is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's default encoding.
with open('/content/laminet_10k_curated.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Distinct values of each item's 'label' field, in sorted order.
themes = sorted({item['label'] for item in data})


In [None]:
class FieldPoint(nn.Module):
    """A single particle of the field, stored as learnable parameters.

    Parameters registered (in this order):
        position:   ``embed_dim`` values drawn from a standard normal.
        velocity:   ``embed_dim`` zeros.
        mass:       a single one.
        charge:     a single standard-normal draw.
        decay_rate: absolute value of a single standard-normal draw.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # The tuple is evaluated left to right, so the random draws happen in
        # the order position -> charge -> decay_rate.
        initial_state = (
            ("position", torch.randn(embed_dim)),
            ("velocity", torch.zeros(embed_dim)),
            ("mass", torch.ones(1)),
            ("charge", torch.randn(1)),
            ("decay_rate", torch.randn(1).abs()),
        )
        for attr_name, initial_value in initial_state:
            self.register_parameter(attr_name, nn.Parameter(initial_value))

class LaminaField(nn.Module):
    """A set of interacting ``FieldPoint`` particles.

    One point is created per row of ``embeddings`` and its position is
    initialised from that row. ``evolve`` then advances all points with
    explicit Euler steps under pairwise inverse-square interactions.

    The evolution mutates the point parameters in place under
    ``torch.no_grad()``; it is a non-differentiable simulation, so no gradient
    flows from the evolved positions back to ``embeddings``.
    """

    def __init__(self, embeddings):
        """Build the field.

        Args:
            embeddings: Tensor whose ``len()`` gives the number of points and
                whose ``shape[1]`` gives the dimensionality of each point.
        """
        super().__init__()
        self.points = nn.ModuleList(
            [FieldPoint(embeddings.shape[1]) for _ in range(len(embeddings))]
        )
        # FieldPoint allocates on the default device; move every parameter to
        # where the embeddings live so the force arithmetic does not mix
        # devices.
        self.to(embeddings.device)
        self.embed_points(embeddings)

    def embed_points(self, embeddings):
        """Set each point's position to the corresponding row of ``embeddings``.

        Rows are paired with points by ``zip``, so extra rows or extra points
        are ignored.
        """
        for point, embed in zip(self.points, embeddings):
            # Store a detached copy: `evolve` updates positions in place, and
            # aliasing the caller's tensor would silently overwrite it.
            point.position.data = embed.detach().clone()

    def evolve(self, dt=0.01, steps=50, record_positions=False):
        """Advance the field by ``steps`` Euler steps of size ``dt``.

        Points are updated one after another within a step, so a point sees
        the already-updated positions of the points that precede it.

        Args:
            dt: Integration step size.
            steps: Number of integration steps.
            record_positions: When true, capture all positions at the start
                of every step.

        Returns:
            A list of ``steps`` NumPy arrays (one ``[point, dim]`` snapshot
            per step, taken before that step's update, so the final state is
            not included) when ``record_positions`` is true; otherwise None.
        """
        trajectory = []
        # position/velocity are registered nn.Parameters: rebinding them to a
        # plain tensor raises TypeError, and in-place edits of a leaf that
        # requires grad are only legal with autograd disabled.
        with torch.no_grad():
            for _ in range(steps):
                if record_positions:
                    snapshot = torch.stack([p.position for p in self.points])
                    trajectory.append(snapshot.detach().cpu().numpy())
                for idx, point in enumerate(self.points):
                    net_force = self.compute_net_force(idx)
                    # NOTE(review): `mass` is not used here (force is applied
                    # as acceleration directly) - confirm this is intended.
                    point.velocity.add_(net_force * dt)
                    point.position.add_(point.velocity * dt)
                    # Velocity damping proportional to |decay_rate|.
                    point.velocity.mul_(1.0 - point.decay_rate.abs() * dt)
        return trajectory if record_positions else None

    def compute_net_force(self, idx):
        """Return the summed force exerted on point ``idx`` by all other points.

        Each pair contributes ``q_i * q_j / d**2`` along the unit vector
        pointing from point ``idx`` to the other point, so a positive charge
        product pulls the point toward the other one.
        NOTE(review): that sign is the opposite of Coulomb's convention -
        confirm it is intended.

        Returns:
            A tensor shaped like the point's position (all zeros when the
            field contains no other point).
        """
        point = self.points[idx]
        net_force = torch.zeros_like(point.position)
        for j, other in enumerate(self.points):
            if j == idx:
                continue
            direction = other.position - point.position
            # Epsilon keeps the division finite for coincident points.
            distance = direction.norm(p=2) + 1e-6
            force_mag = (point.charge * other.charge) / (distance**2)
            net_force = net_force + force_mag * direction / distance
        return net_force


In [None]:
class LaminetMicro(nn.Module):
    """Embed inputs, evolve them as a particle field, and decode the result.

    Args:
        embed_dim: Dimensionality of the field the inputs are projected into.
        output_dim: Size of the decoded output vector.
        input_dim: Feature width of each input row (defaults to 128, the
            previously hard-coded value).
    """

    def __init__(self, embed_dim=64, output_dim=6, input_dim=128):
        super().__init__()
        self.embedding = nn.Linear(input_dim, embed_dim)
        self.decoder = nn.Linear(embed_dim, output_dim)

    def forward(self, x, evolution_steps=50, record_positions=False):
        """Run one embed -> evolve -> decode pass.

        Args:
            x: Input tensor; each row becomes one field point, so it is
                expected to be 2-D with ``input_dim`` columns.
            evolution_steps: Number of field integration steps.
            record_positions: Forwarded to ``LaminaField.evolve``.

        Returns:
            ``(output, trajectory)`` where ``output`` is the decoder applied
            to the mean of all final point positions (a single vector of
            length ``output_dim``) and ``trajectory`` is whatever
            ``LaminaField.evolve`` returned.
        """
        x_embed = self.embedding(x)
        # A fresh field (with freshly sampled point attributes) is built on
        # every call.
        field = LaminaField(x_embed)
        trajectory = field.evolve(steps=evolution_steps, record_positions=record_positions)
        final_pos = torch.stack([p.position for p in field.points], dim=0)
        # Pool over points before decoding.
        output = self.decoder(final_pos.mean(dim=0))
        return output, trajectory


In [None]:
def animate_field(trajectory, labels, label_to_color):
    """Animate recorded point positions as a colour-coded 2-D scatter plot.

    Args:
        trajectory: Sequence of per-frame position arrays indexed as
            ``[point, dim]``; only the first two dimensions are plotted.
        labels: One label per point (assumed to be in the same order as the
            rows of each frame - confirm against the caller).
        label_to_color: Mapping from label to a matplotlib colour.

    Returns:
        The ``FuncAnimation`` object. The caller must keep a reference to it,
        otherwise the animation can be garbage-collected and stop.

    Raises:
        KeyError: If a label has no entry in ``label_to_color``.
    """
    # Colours do not change between frames, so resolve them once here. This
    # also surfaces an unknown label immediately rather than inside the
    # animation callback.
    colors = [label_to_color[lbl] for lbl in labels]

    fig, ax = plt.subplots(figsize=(6, 6))
    scat = ax.scatter([], [], c=[])

    def init():
        # Fixed viewport; points outside [-5, 5] are not visible.
        ax.set_xlim(-5, 5)
        ax.set_ylim(-5, 5)
        return scat,

    def update(frame):
        positions = trajectory[frame]
        # Project higher-dimensional points onto their first two coordinates.
        if positions.shape[1] > 2:
            positions = positions[:, :2]
        scat.set_offsets(positions)
        scat.set_color(colors)
        return scat,

    ani = animation.FuncAnimation(fig, update, frames=len(trajectory), init_func=init, blit=True)
    plt.show()
    return ani


In [None]:
# Demo: push one batch through the model and animate the field evolution.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LaminetMicro(output_dim=len(themes)).to(device)

batch_size = 32
batch = random.sample(data, batch_size)
# NOTE(review): the model inputs are random noise, not features derived from
# `batch`; only the labels (used for colouring) come from the dataset.
inputs = torch.randn(len(batch), 128).to(device)
batch_labels = [item['label'] for item in batch]

theme_colors = {
    'Discovery': 'green',
    'Love Story': 'pink',
    'Mystery': 'purple',
    'Quest': 'blue',
    'Revenge': 'red',
    'Science': 'orange'
}
# Give any dataset theme missing from the table above a neutral colour so the
# per-label colour lookup cannot fail.
for theme in themes:
    theme_colors.setdefault(theme, 'gray')

model.eval()
with torch.no_grad():
    output, trajectory = model(inputs, evolution_steps=50, record_positions=True)
# Keep a reference: matplotlib animations stop once the object is collected.
ani = animate_field(trajectory, batch_labels, theme_colors)
