In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random

ModuleNotFoundError: No module named 'torch'

### Hyperparameters

In [5]:
# the number of objects in each world
WORLD_SIZE = 5

# the dimension of the feature vector of each object
# as produced by the encoder
OBJECT_FEATURE_DIMENSION = 6

# the dimension of the speaker's generated vector representation
REPRESENTATION_DIMENSION = 12

### Model Components

In [9]:
class ObjectEncoder(nn.Module):
    """
    Encodes a single (3x3) Object into a feature vector
    Input     : [[001][100][010]] or Purple-Cirle-No-Outline. See metadata.json
    InputSize : (3,3)
    OutputSize: (OBJECT_FEATURE_DIMENSION)
    """
    def __init__(self, output_dimension=OBJECT_FEATURE_DIMENSION):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=2)  # → (batch, 4, 2, 2)
            nn.Flatten()                    # → (batch, 16)
            nn.Linear(16, output_dimension)
        )
        
    def forward(self, x):
        return self.encoder(x)

IndentationError: expected an indented block after class definition on line 1 (754396681.py, line 2)

In [10]:
class Speaker(nn.Module):
    """
    Transforms the encoded objects in world W + boolean inclusion mask for target subset X 
    into a representative vector
    
    InputSize : (WORLD_SIZE * OBJECT_FEATURE_DIMENSION) + WORLD_SIZE)
    OutputSize: (REPRESENTATION_DIMENSION)
    """
    def __init__(self, input_dimension=(WORLD_SIZE * OBJECT_FEATURE_DIMENSION) + WORLD_SIZE, 
                 output_dimension=REPRESENTATION_DIMENSION):
        super().__init__()
        self.speaker_net = nn.Sequential(
            nn.Linear(input_dimension, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.RrLU(),
            nn.Linear(32, output_dimension)
        )
    
    def forward(self, x):
        return self.speaker_net(x)

SyntaxError: incomplete input (2709159329.py, line 6)

In [11]:
class Listener(nn.Module):
    """
    Takes the resultant vector from the speaker, along with W_i (an element of the world W)
    and predicts whether this element belongs to X, the target subset.
    
    InputSize : (REPRESENTATION_DIMENSION + OBJECT_FEATURE_DIMENSION)
    OutputSize: (1) -> logit for binary classification
    """
    def __init__(self, input_dimension=REPRESENTATION_DIMENSION + OBJECT_FEATURE_DIMENSION, 
                 output_dimension=1):
        super().__init__():
            self.listener_net = nn.Sequential(
                nn.Linear(input_dimension, 32),
                nn.ReLU(),
                nn.Linear(32, 8),
                nn.ReLU(),
                nn.Linear(8, output_dimension)
            )
            
    def forward(self, x):
        return self.listener_net(x)
            

NameError: name 'nn' is not defined

In [12]:
def SpeakerListenerSystem(nn.Module):
    """
    The end-to-end system.
    Encoder -> Speaker -> Listener
    """
    def __init__(self, world_size, feature_dimension, representation_dimension):
        super().__init__()
        
        self.world_size               = world_size
        self.feature_dimension        = feature_dimension
        self.representation_dimension = representation_dimension
        
        speaker_input_size  = (self.world_size * self.feature_dimension) + self.world_size
        listener_input_size = self.representation_dimension + self.feature_dimension
        
        self.encoder  = ObjectEncoder(output_dimension=self.feature_dimension)
        self.speaker  = Speaker(input_dimension=speaker_input_size, output_dimemsion=representation_dimension)
        self.listener = Listener(input_dimension=listener_input_size, output_dimension=1)
        
        
    def forward(self, W, X_mask):
        """
        W:      A batch of worlds. 
                Each world is a set of (3x3) objects.
                Tensor of shape (batch_size, world_size, 3, 3).
                
        X_mask: A batch of boolean masks
                Each value indicated whether the object at that index in W
                is included in the target subset X
                Tensor of shape (batch_size, world_size)
        """
        batch_size = W.shape[0]
        
        ### STEP 1: Encode all objects in the world
        # Reshape for batch processing ny the encoder (B, 5, 3, 3) -> (B*5, 1, 3, 3)
        W_flat = W.view(-1, 1, 3, 3)
        # Get features for all objects: (B*5, feature_dim)
        object_features_flat = self.encoder(W_flat)
        # Reshape back to per-batch item:(B, 5, feature_dim)
        object_features = object_features_flat.view(batch_size, self.world_size, self.feature_dimension)
        
        
        ### STEP 2: Assemble the inputs to the speaker model
        # Flatten object features: (B, 5, feature_dim) -> (B, 5*feature_dim)
        V_W =  object_features.view(batch_size, -1)
        # Create the indicator mask
        M_X = X_mask.float()
        # Concatenate features and mask
        speaker_input = torch.cat([V_W, M_X], dim=1)
        
        
        ### STEP 3: Speaker generates neuralese
        # representation has shape (B, representation_dimension)
        representation = self.speaker(speaker_input)
        
        ### STEP 4: Prepare the listener's input
        # The listener needs to pair the speaker's representation with each object feature
        # Expand the representation to match the number of objects
        # (B, rep_dim) -> (B, 1, rep_dim) -> (B, world_size, rep_dim)
        r_expanded = representation.unsqueeze(1).repeat(1, self.world_size, 1)
        # Concatenate with object features: (B, 5, rep_dim) + (B, 5, feature_dim)
        listener_input = torch.cat([r_expanded, object_features], dim=2)
        
        
        ### STEP 5: Shuffle inputs to the listener
        # This is to avoid the speaker simply learning to tell the listener about X_mask
        # without learning anything about the objects in X themselves.
        # Create a random permutation for each item in the batch
        shuffled_indices = [torch.randperm(self.world_size) for _ in range(batch_size)]
        # Apply the shuffle
        shuffled_input   = torch.stack([features[p] for features, p in zip(listener_input, shuffled_indices)])
        shuffled_labels  = torch.stack([labels[p]   for labels,   p in zip(X_mask, shuffled_indices)])
        
        
        ### STEP 6: Listener makes an inclusion prediction for each object
        # Reshape for batch processing by the listener
        # (B, 5, rep_dim + feature_dim) -> (B*5, rep_dim + feature_dim)
        listener_input_flat = shuffled_input.view(-1, self.representation_dimension + self.feature_dimension)
        # Get predictions (logits) -> (B*5, 1)
        predictions_flat = self.listener(listener_input_flat)
        # Reshape back to (B, world_size)
        predictions = predictions_flat.view(batch_size, self.world_size)
        
        return predictions, shuffled_labels

SyntaxError: invalid syntax (1308499078.py, line 1)

### Training 

In [None]:
def generate_batch(batch_size, world_size):
    """Generates a batch of dummy data"""
    
    # Random 3x3 objects for the world
    W_batch = torch.rand(batch_size, world_size, 3, 3)
    
    # For each item, randomly decide what objects are in the target set X
    X_mask_batch = []
    for _ in range(batch_size):
        X_mask = [random.randint(0, 1) for _ in range(world_size)]
        X_mask_batch.append(indices)
        
    return W_batch, X_mask_batch

### Main Function

In [None]:
# Hyperparameters for training
BATCH_SIZE    = 2 # 32
NUM_EPOCHS    = 1 # 2000
LEARNING_RATE = 1e-4

In [None]:
# Instantiate the system
model = SpeakerListenerSystem(
    world_size               = WORLD_SIZE,
    feature_dimension        = OBJECT_FEATURE_DIMENSION,
    representation_dimension = REPRESENTATION_DIMENSION,
)

In [13]:
# Use BCEWithLogitsLoss because our model outputs raw logits (more stable)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print("Starting training...")
for epoch in range(NUM_EPOCHS):
    
    # Generate a new batch of data
    W, X_mask = generate_batch(BATCH_SIZE, WORLD_SIZE)
    
    # Forward pass
    optimizer.zero_grad()
    y_logits = model(W, X_mask)
    
    # Calculate loss - (X_mask serves as the ground truth)
    loss = criterion(y_logits, X_mask)
    
    # Backward pass and optimize
    loss.backward()
    optimizer.step()
    
    if (epoch + 1) % 100 == 0:
        # Calculate accuracy for monitoring
        preds    = torch.sigmoid(y_logits) > 0.5
        accuracy = (preds.float() == X_mask).float().mean() 
        print(f"Epoch [{epoch}/{NUM_EPOCHS}], Loss: {loss.item():.4f}, Accuracy: {accuracy.item():.4f}")
        
print("Training finished.")

SyntaxError: incomplete input (7352943.py, line 23)

### Example Inference

In [14]:
print("\n--- Running a test example ---")

# Set model to evaluation mode
model.eval() 
with torch.no_grad():
    W_test, X_test_mask = generate(1, WORLD_SIZE)
    
    print("Test World W has 5 objects.")
    print(f"Target set X is selected by mask: {X_test_mask[0]}")
    print(f"Ground truth vector: {X_test_mask.numpy().flatten()}")
    
    y_test_logits = model(W_test, X_test_mask)
    y_test_probs  = torch.sigmoid(y_test_logits)
    
    print(f"Model prediction (probabilities): {y_test_probs.numpy().flatten()}")
    print(f"Final prediction (rounded): {[round(p, 2) for p in y_test_probs.numpy().flatten()]}")


--- Running a test example ---
