In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import random

In [4]:
def part1():
    """Build the toy colour dataset and publish it as the globals ``X`` and ``y``.

    ``X`` is a float32 tensor of shape (10, 3) holding RGB triples scaled from
    0-255 down to 0-1. ``y`` is a float32 tensor of shape (10, 1) holding the
    matching 0/1 labels (``part3`` prints these classes as "dark"/"light").
    """
    global X, y

    # Each entry pairs an (R, G, B) triple in 0-255 with its 0/1 label.
    samples = [
        ([0, 0, 0], 0),
        ([255, 255, 255], 1),
        ([255, 0, 0], 1),
        ([0, 255, 0], 1),
        ([0, 0, 255], 1),
        ([128, 128, 128], 0),
        ([200, 200, 200], 1),
        ([50, 50, 50], 0),
        ([255, 255, 0], 1),
        ([0, 255, 255], 1),
    ]

    # Split the pairs into parallel feature / label sequences.
    rgb_rows, labels = zip(*samples)

    # Scale channels to 0-1; labels become a column vector so they line up
    # with a (N, 1) model output.
    X = torch.tensor(list(rgb_rows), dtype=torch.float32).div(255.0)
    y = torch.tensor(list(labels), dtype=torch.float32).reshape(-1, 1)


def part2(epochs=2000, lr=0.01, seed=0, log_every=500):
    """Train a small binary classifier on the globals ``X`` and ``y``.

    Builds a 3 -> 8 -> 1 network (ReLU hidden layer, sigmoid output), fits it
    full-batch with Adam and binary cross-entropy, and stores the trained
    network in the module-level global ``model``.

    Reads the globals ``X`` (N, 3) and ``y`` (N, 1), which ``part1()`` assigns;
    calling this before they exist raises ``NameError``.

    Args:
        epochs: Number of full-batch optimisation steps.
        lr: Learning rate handed to Adam.
        seed: Seed for the weight initialisation, so repeated runs yield the
            same model. ``None`` skips seeding and draws the initial weights
            from the current global RNG state.
        log_every: Print the loss every this many epochs; ``0``/``None``
            disables progress output.
    """
    global model

    class ColorClassifier(nn.Module):
        """Two-layer perceptron mapping an RGB triple to a probability."""

        def __init__(self):
            super().__init__()
            self.model = nn.Sequential(
                nn.Linear(3, 8),
                nn.ReLU(),
                nn.Linear(8, 1),
                nn.Sigmoid()
            )

        def forward(self, x):
            return self.model(x)

    # Layer initialisation is the only random step here. Seed it inside a
    # forked RNG scope so the result is reproducible while the caller's CPU
    # RNG state is restored afterwards.
    with torch.random.fork_rng(devices=[], enabled=seed is not None):
        if seed is not None:
            torch.manual_seed(seed)
        model = ColorClassifier()

    loss_fn = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Train
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = model(X)
        loss = loss_fn(outputs, y)
        loss.backward()
        optimizer.step()
        # The `log_every and ...` guard avoids a modulo-by-zero when logging is off.
        if log_every and (epoch + 1) % log_every == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")

def part3():
    """Print the global ``model``'s verdict for four fixed test colours.

    Each colour is scaled from 0-255 to 0-1 and passed to ``model`` (the
    global assigned by ``part2()``). An output above 0.5 is reported as
    "light", otherwise "dark". Prints one line per colour; returns nothing.
    """
    test_colors = torch.tensor([
        [0, 0, 0],
        [255, 255, 255],
        [100, 50, 200],
        [200, 200, 50],
    ], dtype=torch.float32) / 255.0

    # Inference only: disable autograd so no graph is recorded for outputs
    # that are never back-propagated.
    with torch.no_grad():
        predictions = model(test_colors)

    print("\nPredictions:")
    for color, pred in zip(test_colors, predictions):
        prob = pred.item()
        label = "light" if prob > 0.5 else "dark"
        # Round before the int cast: scaling to 0-1 and back can land just
        # below the original integer, and a bare .int() would truncate it.
        rgb = (color * 255).round().int().tolist()
        print(f"{rgb} -> {label} ({prob:.2f})")

if __name__ == "__main__":
    # Report the PyTorch build first so the output below can be tied to a version.
    print(torch.__version__)
    # Order matters: part1 assigns X/y, part2 trains on them and assigns
    # `model`, part3 reads `model`.
    for stage in (part1, part2, part3):
        stage()

2.8.0+cu126
Epoch 500/2000, Loss: 0.3067
Epoch 1000/2000, Loss: 0.2955
Epoch 1500/2000, Loss: 0.2937
Epoch 2000/2000, Loss: 0.2931

Predictions:
[0, 0, 0] -> dark (0.00)
[255, 255, 255] -> light (0.93)
[100, 50, 200] -> light (0.83)
[200, 200, 50] -> light (0.89)
