In [7]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import pickle
import os

# Define the neural network architecture
class DecimalToBinaryNet(nn.Module):
    def __init__(self, input_digits, input_classes, hidden_size, output_bits, output_classes):
        super(DecimalToBinaryNet, self).__init__()
        self.input_size = input_digits * input_classes
        self.output_bits = output_bits
        self.output_classes = output_classes
        
        # Flatten input and linear layers
        self.flatten = nn.Flatten()
        self.layer1 = nn.Linear(self.input_size, hidden_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(hidden_size, output_bits * output_classes)
        
        # For reshaping output
        self.output_shape = (-1, output_bits, output_classes)
        
    def forward(self, x):
        # x shape: (batch_size, decimal_digits, 10)
        x = self.flatten(x)  # Flatten to (batch_size, decimal_digits*10)
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        # Reshape to (batch_size, binary_bits, 2)
        x = x.view(self.output_shape)
        # Apply softmax along the last dimension (class dim)
        x = nn.functional.softmax(x, dim=2)
        return x

# Custom dataset for decimal to binary conversion
class DecimalToBinaryDataset(Dataset):
    def __init__(self, decimal_range=4096):
        self.decimal_range = decimal_range
        self.max_bits = int(np.log2(decimal_range)) + 1  # Number of bits needed
        self.max_decimal_digits = len(str(decimal_range - 1))  # Max number of decimal digits
        self.data = self.generate_data()
        
    def generate_data(self):
        data = []
        for i in range(self.decimal_range):
            # Convert to one-hot encoded tensors
            decimal_str = str(i).zfill(self.max_decimal_digits)
            binary_str = format(i, f'0{self.max_bits}b')
            
            # Create one-hot encoded tensors for decimal digits
            decimal_tensor = torch.zeros(self.max_decimal_digits, 10)  # 10 possible values (0-9)
            for digit_idx, digit in enumerate(decimal_str):
                decimal_tensor[digit_idx, int(digit)] = 1.0
                
            # Create one-hot encoded tensors for binary digits
            binary_tensor = torch.zeros(self.max_bits, 2)  # 2 possible values (0-1)
            for bit_idx, bit in enumerate(binary_str):
                binary_tensor[bit_idx, int(bit)] = 1.0
                
            data.append({"decimal": decimal_tensor, "binary": binary_tensor})
        return data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]
    
    def get_max_bits(self):
        return self.max_bits
    
    def get_max_decimal_digits(self):
        return self.max_decimal_digits

# Data utility functions
def generate_and_split_data(decimal_range=4096, train_ratio=0.8, batch_size=64):
    dataset = DecimalToBinaryDataset(decimal_range=decimal_range)
    
    # Calculate sizes
    train_size = int(train_ratio * len(dataset))
    test_size = len(dataset) - train_size
    
    # Split the dataset
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, test_loader, dataset.get_max_bits(), dataset.get_max_decimal_digits()

# Training function
def train(model, train_loader, test_loader, criterion, optimizer, num_epochs=10, device="cpu"):
    model.to(device)
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        
        for batch in train_loader:
            decimal_data = batch["decimal"].to(device)
            binary_data = batch["binary"].to(device)
            
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(decimal_data)
            loss = criterion(outputs, binary_data)
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        # Print statistics
        # print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}")
        
        # Evaluate the model
        if epoch % 50 == 0:
            # print(outputs.shape, binary_data.shape)
            number_accuracy, digit_accuracy = evaluate(model, test_loader, device)
            epoch_loss = running_loss / len(train_loader)
            print(
                f"Epoch {epoch}: Training loss: {epoch_loss:.4f}, "
                f"Number accuracy: {number_accuracy:.2f}%, "
                f"Digit accuracy: {digit_accuracy:.2f}"
            )
    
    print("Training completed!")

# Evaluation function
def evaluate(model, test_loader, device="cpu"):
    model.eval()
    num_correct_numbers = 0
    num_correct_digits = 0
    total_numbers = 0
    total_digits = 0
    
    with torch.no_grad():
        for batch in test_loader:
            decimal_data = batch["decimal"].to(device)
            binary_data = batch["binary"].to(device)
            
            outputs = model(decimal_data)
            # Get the index of the highest probability (argmax) for each position
            _, predicted_indices = torch.max(outputs, dim=2)
            _, target_indices = torch.max(binary_data, dim=2)
            
            # Count fully correct predictions (all bits must be correct)
            num_correct_numbers += (predicted_indices == target_indices).all(dim=1).sum().item()
            total_numbers += decimal_data.size(0)
            
            # Count correct individual digits
            num_correct_digits += (predicted_indices == target_indices).sum().item()
            total_digits += decimal_data.size(0) * predicted_indices.size(1)
    
    # Calculate accuracies
    number_accuracy = 100 * num_correct_numbers / total_numbers
    digit_accuracy = 100 * num_correct_digits / total_digits
    
    return number_accuracy, digit_accuracy




In [10]:
# Custom weighted loss function
# Define the loss function with class weights
def binary_loss(predictions, targets):
    # Get target class indices (0 or 1)
    target_indices = targets.argmax(dim=2)
    
    # Calculate standard cross entropy loss
    loss_fn = nn.CrossEntropyLoss(reduction='none')
    loss = loss_fn(predictions.view(-1, 2), target_indices.view(-1))
    loss = loss.view(targets.shape[0], targets.shape[1])
    loss = loss.mean()
    return loss
    
def main():
    # Set device
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = "cpu"
    print(f"Using device: {device}")
    
    # Generate and split data
    decimal_range = 4096
    train_loader, test_loader, max_bits, max_decimal_digits = generate_and_split_data(
        decimal_range=decimal_range,
        train_ratio=0.9,
        batch_size=256
    )
    
    # Initialize the model
    input_digits = max_decimal_digits
    input_classes = 10  # 0-9 for decimal digits
    hidden_size = 256  # Can be tuned
    output_bits = max_bits
    output_classes = 2  # 0-1 for binary digits
    
    model = DecimalToBinaryNet(
        input_digits, 
        input_classes, 
        hidden_size, 
        output_bits, 
        output_classes
    )
    
    # Define loss function and optimizer
    # criterion = nn.CrossEntropyLoss()  # Cross-Entropy Loss for classification
    criterion = binary_loss
    # optimizer = optim.SGD(model.parameters(), lr=0.04, momentum=0.0, weight_decay=0.0)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    # Train the model
    train(
        model, 
        train_loader, 
        test_loader, 
        criterion,
        optimizer, 
        num_epochs=2000, 
        device=device
    )
    
    # Test multiple random examples from the test set
    with torch.no_grad():
        # Get a batch from test loader
        batch = next(iter(test_loader))
        decimal_data = batch["decimal"].to(device)
        binary_data = batch["binary"].to(device)
        
        # Get predictions
        predictions = model(decimal_data)
        _, predicted_indices = torch.max(predictions, dim=2)
        _, target_indices = torch.max(binary_data, dim=2)
        
        # Select 5 random samples to display
        num_samples = min(5, decimal_data.size(0))
        sample_indices = torch.randperm(decimal_data.size(0))[:num_samples]
        
        for idx in sample_indices:
            # Convert one-hot decimal back to number
            decimal_digits = decimal_data[idx].argmax(dim=1).cpu().numpy()
            decimal_num = int(''.join([str(d) for d in decimal_digits]).lstrip('0') or '0')
            
            # Get binary prediction and actual binary
            binary_pred = ''.join([str(predicted_indices[idx, i].item()) for i in range(max_bits)])
            binary_actual = ''.join([str(target_indices[idx, i].item()) for i in range(max_bits)])
            
            # Print comparison
            print(f"Decimal: {decimal_num}")
            print(f"Binary prediction: {binary_pred}")
            print(f"Actual binary: {binary_actual}")
            print(f"Correct: {binary_pred == binary_actual}")
            print("-" * 30)

if __name__ == "__main__":
    main()

Using device: cpu
Epoch 0: Training loss: 0.6870, Number accuracy: 0.24%, Digit accuracy: 65.97
Epoch 50: Training loss: 0.4584, Number accuracy: 4.88%, Digit accuracy: 81.95
Epoch 100: Training loss: 0.4234, Number accuracy: 12.68%, Digit accuracy: 86.38
Epoch 150: Training loss: 0.3926, Number accuracy: 39.02%, Digit accuracy: 92.74
Epoch 200: Training loss: 0.3716, Number accuracy: 50.00%, Digit accuracy: 94.45
Epoch 250: Training loss: 0.3610, Number accuracy: 53.66%, Digit accuracy: 95.18
Epoch 300: Training loss: 0.3544, Number accuracy: 56.10%, Digit accuracy: 95.55
Epoch 350: Training loss: 0.3496, Number accuracy: 59.51%, Digit accuracy: 96.06
Epoch 400: Training loss: 0.3465, Number accuracy: 62.68%, Digit accuracy: 96.42
Epoch 450: Training loss: 0.3436, Number accuracy: 65.37%, Digit accuracy: 96.68
Epoch 500: Training loss: 0.3411, Number accuracy: 67.56%, Digit accuracy: 96.90
Epoch 550: Training loss: 0.3399, Number accuracy: 70.00%, Digit accuracy: 97.15
Epoch 600: Trai