In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# --- 1. DEFINE THE MODEL HERE (Instead of importing it) ---
class SimpleNet(nn.Module):
    """Tiny two-layer perceptron that maps 2 input features to one probability.

    Layout: Linear(2 -> 5) -> ReLU -> Linear(5 -> 1) -> sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are significant: they fix the
        # state_dict keys and the order in which weights are randomly drawn.
        self.hidden_layer = nn.Linear(2, 5)
        self.output_layer = nn.Linear(5, 1)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return the sigmoid of the network output (values in [0, 1])."""
        hidden = self.activation(self.hidden_layer(x))
        logits = self.output_layer(hidden)
        return torch.sigmoid(logits)

# --- 2. TRAINING: DATA, LOOP, AND EVALUATION ---
def train(epochs=1000, lr=0.1, seed=None):
    """Train a fresh SimpleNet to detect the top-right quadrant and report a test prediction.

    The target for a point is 1.0 when BOTH coordinates are positive, else 0.0.
    After training, the model's prediction for the unseen point (5, 5) is printed.

    Args:
        epochs: Number of full-batch gradient steps (default 1000).
        lr: SGD learning rate (default 0.1).
        seed: Optional value passed to ``torch.manual_seed`` before the model is
            built, making the random weight initialization (and therefore the
            result) reproducible. ``None`` leaves the global RNG untouched.

    Returns:
        The trained ``SimpleNet`` instance, so callers can reuse it.
    """
    if seed is not None:
        # Seed before constructing the model so its initial weights are reproducible.
        torch.manual_seed(seed)

    # --- 1. Data: quadrant classification ---
    # 8 hand-picked points; the target is 1.0 only if BOTH x and y are positive.
    X_train = torch.tensor([
        [1.0, 1.0],    # Top Right (Target 1)
        [2.0, 3.0],    # Top Right (Target 1)
        [-1.0, -1.0],  # Bottom Left (Target 0)
        [-2.0, 1.0],   # Top Left (Target 0)
        [1.0, -2.0],   # Bottom Right (Target 0)
        [0.5, 0.5],    # Top Right (Target 1)
        [-0.5, 0.5],   # Top Left (Target 0)
        [-2.0, -2.0],  # Bottom Left (Target 0)
    ])

    # One target per row of X_train, shaped (8, 1); BCELoss requires the
    # predictions to have this same shape.
    y_train = torch.tensor([
        [1.0], [1.0],
        [0.0], [0.0], [0.0],
        [1.0], [0.0], [0.0],
    ])

    # --- 2. Training setup ---
    model = SimpleNet()  # fresh random weights on every call
    optimizer = optim.SGD(model.parameters(), lr=lr)
    # SimpleNet.forward already applies a sigmoid, so BCELoss (which expects
    # probabilities) is the matching loss. BCEWithLogitsLoss would be more
    # numerically stable but requires the model to return raw logits.
    loss_fn = nn.BCELoss()

    print("Training Quadrant Detector...")

    # Full-batch gradient descent: every step uses all 8 training points.
    for _ in range(epochs):
        optimizer.zero_grad()
        y_pred = model(X_train)
        loss = loss_fn(y_pred, y_train)
        loss.backward()
        optimizer.step()

    # --- 3. Test on a point never seen in training: (5, 5) should be class 1 ---
    # eval() is a no-op for SimpleNet's current layers but is the correct
    # habit; no_grad() avoids building an autograd graph for pure inference.
    model.eval()
    with torch.no_grad():
        test_point = torch.tensor([[5.0, 5.0]])  # batch of one, like X_train
        prediction = model(test_point).item()

    print(f"Prediction for (5,5): {prediction:.4f}")
    print("If close to 1.0, it thinks it's Top Right. If close to 0, it's not.")
    return model

# Run the demo only when executed directly (script or notebook cell), not on import.
if __name__ == "__main__":
    train()

Training Quadrant Detector...
Prediction for (5,5): 1.0000
If close to 1.0, it thinks it's Top Right. If close to 0, it's not.
