# This section has the rotor selection, positions, text and output hard-coded in the code

In [1]:
class Rotor:
    """A single Enigma rotor: a 26-letter substitution plus its rotation state."""

    def __init__(self, wiring, notch, ring_setting=0, position=0):
        """
        Args:
            wiring: 26-letter string; entry i is the letter that contact i maps to
            notch: Letter(s) at which this rotor lets the next rotor step
            ring_setting: Ring offset as an index (0-25, 0 meaning 'A')
            position: Starting rotation as an index (0-25, 0 meaning 'A')
        """
        self.wiring, self.notch = wiring, notch
        self.ring_setting, self.position = ring_setting, position

    def encode_forward(self, char_index):
        """Map a contact index through the wiring on the way in (right to left)."""
        # Net rotation of the wiring relative to the fixed entry contacts
        offset = self.position - self.ring_setting
        wired_letter = self.wiring[(char_index + offset) % 26]
        # Undo the rotation so the result is again relative to the fixed contacts
        return (ord(wired_letter) - ord('A') - offset) % 26

    def encode_backward(self, char_index):
        """Map a contact index through the inverse wiring (left to right, after the reflector)."""
        offset = self.position - self.ring_setting
        entering_letter = chr((char_index + offset) % 26 + ord('A'))
        # The inverse substitution is the index at which the letter appears in the wiring
        return (self.wiring.index(entering_letter) - offset) % 26

    def is_at_notch(self):
        """Return True when the current position letter is one of the notch letters."""
        current_letter = chr(self.position + ord('A'))
        return current_letter in self.notch

    def step(self):
        """Advance the rotor one position, wrapping from 25 back to 0."""
        self.position += 1
        self.position %= 26


class Reflector:
    """The fixed reflector that sends the signal back through the rotors."""

    def __init__(self, wiring):
        """
        Args:
            wiring: 26-letter string; entry i is the letter paired with contact i
        """
        self.wiring = wiring

    def reflect(self, char_index):
        """Return the contact index (0-25) wired to ``char_index``."""
        paired_letter = self.wiring[char_index]
        return ord(paired_letter) - ord('A')


class EnigmaMachine:
    """Complete Enigma machine with 3 rotors and reflector (no plugboard).

    ``self.rotors`` is ordered right (fast), middle, left (slow); the
    ``rotor_types``, ``positions`` and ``ring_settings`` arguments use the
    same order.
    """

    # Historical rotor wirings (Enigma I)
    ROTOR_I = "EKMFLGDQVZNTOWYHXUSPAIBRCJ"
    ROTOR_II = "AJDKSIRUXBLHWTMCQGZNPYFVOE"
    ROTOR_III = "BDFHJLCPRTXVZNYEIWGAKMUSQO"
    ROTOR_IV = "ESOVPZJAYQUIRHXLNFTGKDCMWB"
    ROTOR_V = "VZBRGITYUPSDNHLXAWMJQOFECK"

    # Reflector B (most commonly used)
    REFLECTOR_B = "YRUHQSLDPXNGOKMIEBFZCWVJAT"

    # Position letter at which each rotor type is "at its notch". What that
    # triggers depends on the slot the rotor occupies (see step_rotors), not
    # on the rotor type itself.
    NOTCHES = {
        'I': 'Q',
        'II': 'E',
        'III': 'V',
        'IV': 'J',
        'V': 'Z'
    }

    def __init__(self, rotor_types=('I', 'II', 'III'),
                 positions=(0, 0, 0), ring_settings=(0, 0, 0)):
        """
        Initialize Enigma machine

        Args:
            rotor_types: Tuple of 3 rotor identifiers (rightmost, middle, leftmost)
            positions: Starting positions for each rotor (0-25 for A-Z)
            ring_settings: Ring settings for each rotor (0-25 for A-Z)

        Raises:
            KeyError: If a rotor identifier is not one of I, II, III, IV, V
            IndexError: If any argument has fewer than 3 entries
        """
        rotor_wirings = {
            'I': self.ROTOR_I,
            'II': self.ROTOR_II,
            'III': self.ROTOR_III,
            'IV': self.ROTOR_IV,
            'V': self.ROTOR_V
        }

        # Create rotors (right, middle, left); only the first three entries are used
        self.rotors = [
            Rotor(rotor_wirings[rotor_types[slot]], self.NOTCHES[rotor_types[slot]],
                  ring_settings[slot], positions[slot])
            for slot in range(3)
        ]

        self.reflector = Reflector(self.REFLECTOR_B)

    def step_rotors(self):
        """
        Step rotors according to Enigma stepping mechanism
        Implements the double-stepping mechanism
        """
        # Middle rotor at its notch: it steps again and carries the left rotor
        if self.rotors[1].is_at_notch():
            self.rotors[1].step()
            self.rotors[2].step()
        # Right rotor at its notch: the middle rotor steps
        elif self.rotors[0].is_at_notch():
            self.rotors[1].step()

        # Always step the rightmost rotor
        self.rotors[0].step()

    def encode_char(self, char):
        """
        Encode a single character through the Enigma machine

        Only the ASCII letters A-Z (either case) are enciphered. Everything
        else, including non-ASCII letters such as 'é', is returned unchanged
        and does not step the rotors, so a message stays decodable by a
        machine with the same settings.

        Args:
            char: Single character

        Returns:
            Encoded uppercase letter A-Z, or ``char`` itself if it is not an
            ASCII letter
        """
        letter = char.upper()
        # str.isalpha() is not enough: it accepts non-ASCII letters, and
        # upper() can even expand one character into two ('ß' -> 'SS').
        if len(letter) != 1 or not 'A' <= letter <= 'Z':
            return char

        char_index = ord(letter) - ord('A')

        # Step rotors before encoding (historical behavior)
        self.step_rotors()

        # Pass through rotors right to left
        for rotor in self.rotors:
            char_index = rotor.encode_forward(char_index)

        # Reflect
        char_index = self.reflector.reflect(char_index)

        # Pass back through rotors left to right
        for rotor in reversed(self.rotors):
            char_index = rotor.encode_backward(char_index)

        return chr(char_index + ord('A'))

    def encode_text(self, text):
        """
        Encode a string of text

        Args:
            text: String to encode

        Returns:
            Encoded string, the same length as ``text``
        """
        return ''.join(self.encode_char(char) for char in text)

    def get_rotor_positions(self):
        """Get current rotor positions as letters (right, middle, left)"""
        return ''.join(chr(r.position + ord('A')) for r in self.rotors)


# Demo: encrypt a fixed message, then decrypt it with an identically configured machine
if __name__ == "__main__":
    # Rotors I, II, III (right, middle, left), all positions and ring settings at A
    enigma = EnigmaMachine(('I', 'II', 'III'), (0, 0, 0), (0, 0, 0))

    print("Enigma Machine Emulator")
    print("=" * 50)
    print(f"Initial rotor positions: {enigma.get_rotor_positions()}\n")

    # Encrypt
    plaintext = "HELLO WORLD"
    print(f"Plaintext:  {plaintext}")

    ciphertext = enigma.encode_text(plaintext)
    print(f"Ciphertext: {ciphertext}")
    print(f"Final rotor positions: {enigma.get_rotor_positions()}\n")

    # Decrypt: a fresh machine with the same settings reverses the cipher
    enigma_decode = EnigmaMachine(('I', 'II', 'III'), (0, 0, 0), (0, 0, 0))

    decoded = enigma_decode.encode_text(ciphertext)
    print(f"Decoded:    {decoded}")
    print(f"Final rotor positions: {enigma_decode.get_rotor_positions()}")
    

Enigma Machine Emulator
Initial rotor positions: AAA

Plaintext:  HELLO WORLD
Ciphertext: MFNCZ BBFZM
Final rotor positions: KAA

Decoded:    HELLO WORLD
Final rotor positions: KAA


# This section asks the user for inputs

In [7]:
class Rotor:
    """A single Enigma rotor: a 26-letter substitution plus its rotation state."""

    def __init__(self, wiring, notch, ring_setting=0, position=0):
        """
        Args:
            wiring: 26-letter string; entry i is the letter that contact i maps to
            notch: Letter(s) at which this rotor lets the next rotor step
            ring_setting: Ring offset as an index (0-25, 0 meaning 'A')
            position: Starting rotation as an index (0-25, 0 meaning 'A')
        """
        self.wiring, self.notch = wiring, notch
        self.ring_setting, self.position = ring_setting, position

    def encode_forward(self, char_index):
        """Map a contact index through the wiring on the way in (right to left)."""
        # Net rotation of the wiring relative to the fixed entry contacts
        offset = self.position - self.ring_setting
        wired_letter = self.wiring[(char_index + offset) % 26]
        # Undo the rotation so the result is again relative to the fixed contacts
        return (ord(wired_letter) - ord('A') - offset) % 26

    def encode_backward(self, char_index):
        """Map a contact index through the inverse wiring (left to right, after the reflector)."""
        offset = self.position - self.ring_setting
        entering_letter = chr((char_index + offset) % 26 + ord('A'))
        # The inverse substitution is the index at which the letter appears in the wiring
        return (self.wiring.index(entering_letter) - offset) % 26

    def is_at_notch(self):
        """Return True when the current position letter is one of the notch letters."""
        current_letter = chr(self.position + ord('A'))
        return current_letter in self.notch

    def step(self):
        """Advance the rotor one position, wrapping from 25 back to 0."""
        self.position += 1
        self.position %= 26


class Reflector:
    """The fixed reflector that sends the signal back through the rotors."""

    def __init__(self, wiring):
        """
        Args:
            wiring: 26-letter string; entry i is the letter paired with contact i
        """
        self.wiring = wiring

    def reflect(self, char_index):
        """Return the contact index (0-25) wired to ``char_index``."""
        paired_letter = self.wiring[char_index]
        return ord(paired_letter) - ord('A')


class EnigmaMachine:
    """Complete Enigma machine with 3 rotors and reflector (no plugboard).

    ``self.rotors`` is ordered right (fast), middle, left (slow); the
    ``rotor_types``, ``positions`` and ``ring_settings`` arguments use the
    same order.
    """

    # Historical rotor wirings (Enigma I)
    ROTOR_I = "EKMFLGDQVZNTOWYHXUSPAIBRCJ"
    ROTOR_II = "AJDKSIRUXBLHWTMCQGZNPYFVOE"
    ROTOR_III = "BDFHJLCPRTXVZNYEIWGAKMUSQO"
    ROTOR_IV = "ESOVPZJAYQUIRHXLNFTGKDCMWB"
    ROTOR_V = "VZBRGITYUPSDNHLXAWMJQOFECK"

    # Reflector B (most commonly used)
    REFLECTOR_B = "YRUHQSLDPXNGOKMIEBFZCWVJAT"

    # Position letter at which each rotor type is "at its notch". What that
    # triggers depends on the slot the rotor occupies (see step_rotors), not
    # on the rotor type itself.
    NOTCHES = {
        'I': 'Q',
        'II': 'E',
        'III': 'V',
        'IV': 'J',
        'V': 'Z'
    }

    def __init__(self, rotor_types=('I', 'II', 'III'),
                 positions=(0, 0, 0), ring_settings=(0, 0, 0)):
        """
        Initialize Enigma machine

        Args:
            rotor_types: Tuple of 3 rotor identifiers (rightmost, middle, leftmost)
            positions: Starting positions for each rotor (0-25 for A-Z)
            ring_settings: Ring settings for each rotor (0-25 for A-Z)

        Raises:
            KeyError: If a rotor identifier is not one of I, II, III, IV, V
            IndexError: If any argument has fewer than 3 entries
        """
        rotor_wirings = {
            'I': self.ROTOR_I,
            'II': self.ROTOR_II,
            'III': self.ROTOR_III,
            'IV': self.ROTOR_IV,
            'V': self.ROTOR_V
        }

        # Create rotors (right, middle, left); only the first three entries are used
        self.rotors = [
            Rotor(rotor_wirings[rotor_types[slot]], self.NOTCHES[rotor_types[slot]],
                  ring_settings[slot], positions[slot])
            for slot in range(3)
        ]

        self.reflector = Reflector(self.REFLECTOR_B)

    def step_rotors(self):
        """
        Step rotors according to Enigma stepping mechanism
        Implements the double-stepping mechanism
        """
        # Middle rotor at its notch: it steps again and carries the left rotor
        if self.rotors[1].is_at_notch():
            self.rotors[1].step()
            self.rotors[2].step()
        # Right rotor at its notch: the middle rotor steps
        elif self.rotors[0].is_at_notch():
            self.rotors[1].step()

        # Always step the rightmost rotor
        self.rotors[0].step()

    def encode_char(self, char):
        """
        Encode a single character through the Enigma machine

        Only the ASCII letters A-Z (either case) are enciphered. Everything
        else, including non-ASCII letters such as 'é', is returned unchanged
        and does not step the rotors, so a message stays decodable by a
        machine with the same settings.

        Args:
            char: Single character

        Returns:
            Encoded uppercase letter A-Z, or ``char`` itself if it is not an
            ASCII letter
        """
        letter = char.upper()
        # str.isalpha() is not enough: it accepts non-ASCII letters, and
        # upper() can even expand one character into two ('ß' -> 'SS').
        if len(letter) != 1 or not 'A' <= letter <= 'Z':
            return char

        char_index = ord(letter) - ord('A')

        # Step rotors before encoding (historical behavior)
        self.step_rotors()

        # Pass through rotors right to left
        for rotor in self.rotors:
            char_index = rotor.encode_forward(char_index)

        # Reflect
        char_index = self.reflector.reflect(char_index)

        # Pass back through rotors left to right
        for rotor in reversed(self.rotors):
            char_index = rotor.encode_backward(char_index)

        return chr(char_index + ord('A'))

    def encode_text(self, text):
        """
        Encode a string of text

        Args:
            text: String to encode

        Returns:
            Encoded string, the same length as ``text``
        """
        return ''.join(self.encode_char(char) for char in text)

    def get_rotor_positions(self):
        """Get current rotor positions as letters (right, middle, left)"""
        return ''.join(chr(r.position + ord('A')) for r in self.rotors)


# Interactive demo: ask for rotors, start positions and a message, then show
# the encryption step by step and verify it by decrypting again.
if __name__ == "__main__":
    print("=" * 70)
    print("ENIGMA MACHINE EMULATOR")
    print("=" * 70)
    print()

    # Get rotor configuration from user
    valid_rotors = ('I', 'II', 'III', 'IV', 'V')
    print("Available rotors: I, II, III, IV, V")
    print("Enter rotor types (right, middle, left) separated by spaces")
    print("Example: I II III")
    rotor_input = input("Rotor types: ").strip().upper().split()

    while len(rotor_input) != 3 or not all(r in valid_rotors for r in rotor_input):
        print("Invalid input. Please enter exactly 3 rotors from: I, II, III, IV, V")
        rotor_input = input("Rotor types: ").strip().upper().split()

    rotor_types = tuple(rotor_input)

    # Get initial positions
    print("\nEnter initial rotor positions (3 letters A-Z)")
    print("Example: AAA or ABC")
    position_input = input("Initial positions: ").strip().upper()

    # Compare against 'A'..'Z' explicitly: str.isalpha() also accepts letters
    # such as 'É', which would produce a position outside 0-25.
    while len(position_input) != 3 or not all('A' <= c <= 'Z' for c in position_input):
        print("Invalid input. Please enter exactly 3 letters A-Z")
        position_input = input("Initial positions: ").strip().upper()

    positions = tuple(ord(c) - ord('A') for c in position_input)

    # Get text to encode
    print("\nEnter text to encode (letters and spaces):")
    plaintext = input("Text: ").strip().upper()

    # Create Enigma machine (once; nothing is encoded before the loop below)
    enigma = EnigmaMachine(
        rotor_types=rotor_types,
        positions=positions,
        ring_settings=(0, 0, 0)
    )

    print("\n" + "=" * 70)
    print("CONFIGURATION")
    print("=" * 70)
    print(f"Rotors (R→M→L):        {rotor_types[0]}, {rotor_types[1]}, {rotor_types[2]}")
    print(f"Initial Positions:     {position_input}")
    print("Ring Settings:         AAA")
    print("Reflector:             B")

    # Encode character by character and track positions
    print("\n" + "=" * 70)
    print("ENCODING PROCESS")
    print("=" * 70)
    print(f"{'Step':<6} {'Input':<8} {'Rotor Pos':<12} {'Output':<8}")
    print("-" * 70)

    ciphertext = []
    step = 0

    for char in plaintext:
        # Only A-Z goes through the machine; everything else is copied as-is
        if 'A' <= char <= 'Z':
            step += 1
            rotor_pos_before = enigma.get_rotor_positions()
            output_char = enigma.encode_char(char)
            ciphertext.append(output_char)
            print(f"{step:<6} {char:<8} {rotor_pos_before:<12} {output_char:<8}")
        else:
            ciphertext.append(char)

    final_positions = enigma.get_rotor_positions()
    ciphertext_str = ''.join(ciphertext)

    # Summary
    print("\n" + "=" * 70)
    print("RESULTS")
    print("=" * 70)
    print(f"Plaintext:             {plaintext}")
    print(f"Ciphertext:            {ciphertext_str}")
    print(f"Final Rotor Positions: {final_positions}")
    print(f"Characters Encoded:    {step}")

    # Verification: decrypt with a fresh machine, applying the same A-Z filter
    # as the encoding loop so both passes step the rotors identically
    print("\n" + "=" * 70)
    print("VERIFICATION (Decoding)")
    print("=" * 70)
    enigma_decode = EnigmaMachine(rotor_types=rotor_types, positions=positions, ring_settings=(0, 0, 0))
    decoded = ''.join(
        enigma_decode.encode_char(c) if 'A' <= c <= 'Z' else c
        for c in ciphertext_str
    )
    print(f"Decoded Text:          {decoded}")
    print(f"Match:                 {'✓ SUCCESS' if decoded == plaintext else '✗ FAILED'}")
    print("=" * 70)

ENIGMA MACHINE EMULATOR

Available rotors: I, II, III, IV, V
Enter rotor types (right, middle, left) separated by spaces
Example: I II III


Rotor types:  I II III



Enter initial rotor positions (3 letters A-Z)
Example: AAA or ABC


Initial positions:  AAA



Enter text to encode (letters and spaces):


Text:  There once was a man who lived In a blue house



CONFIGURATION
Rotors (R→M→L):        I, II, III
Initial Positions:     AAA
Ring Settings:         AAA
Reflector:             B

ENCODING PROCESS
Step   Input    Rotor Pos    Output  
----------------------------------------------------------------------
1      T        AAA          Z       
2      H        BAA          P       
3      E        CAA          T       
4      R        DAA          Q       
5      E        EAA          P       
6      O        FAA          Q       
7      N        GAA          F       
8      C        HAA          N       
9      E        IAA          G       
10     W        JAA          N       
11     A        KAA          J       
12     S        LAA          H       
13     A        MAA          G       
14     M        NAA          V       
15     A        OAA          N       
16     N        PAA          Y       
17     W        QAA          R       
18     H        RBA          X       
19     O        SBA          A       
20     L        TBA    

## This is an attempt at a model made by Claude to predict the rotors

In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import string
import os

# Anchor data paths to the folder containing this script. A notebook kernel
# defines no __file__, so fall back to the current working directory there.
if "__file__" in globals():
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
else:
    BASE_DIR = os.getcwd()

# Plaintext corpus that gets encrypted to build the training set
INPUT_TEXT_FILE = os.path.join(BASE_DIR, 'output_matrix_no_punctuation.txt')

# Training configuration
MESSAGE_LENGTH = 100         # characters per training sample
NUM_TRAINING_SAMPLES = 5000  # encrypted samples to generate

# ============================================================================
# ENIGMA SIMULATOR (Simplified - no plugboard)
# ============================================================================
class SimpleEnigma:
    """Simplified Enigma machine simulator (3 rotors, fixed reflector, no ring settings)

    Index 0 of ``rotors`` / ``positions`` is the rightmost rotor, the one
    that steps on every letter; index 2 is the leftmost.
    """

    # Historical rotor wirings (I-V)
    ROTORS = {
        'I':   'EKMFLGDQVZNTOWYHXUSPAIBRCJ',
        'II':  'AJDKSIRUXBLHWTMCQGZNPYFVOE',
        'III': 'BDFHJLCPRTXVZNYEIWGAKMUSQO',
        'IV':  'ESOVPZJAYQUIRHXLNFTGKDCMWB',
        'V':   'VZBRGITYUPSDNHLXAWMJQOFECK'
    }

    # Notch positions (when rotor steps the next one)
    NOTCHES = {
        'I': 'Q', 'II': 'E', 'III': 'V', 'IV': 'J', 'V': 'Z'
    }

    REFLECTOR = 'YRUHQSLDPXNGOKMIEBFZCWVJAT'

    def __init__(self, rotors, positions):
        """
        rotors: sequence of 3 rotor names, rightmost first, e.g., ['III', 'II', 'I']
        positions: sequence of 3 starting positions (0-25), same order as rotors

        Raises KeyError if a rotor name is not in ROTORS.
        """
        self.rotors = [self.ROTORS[r] for r in rotors]
        self.rotor_names = rotors
        # list(...) rather than .copy() so tuples are accepted too, and so the
        # caller's sequence is never mutated by stepping
        self.positions = list(positions)
        self.initial_positions = list(positions)

    def reset(self):
        """Reset rotor positions to initial state"""
        self.positions = self.initial_positions.copy()

    def step_rotors(self):
        """Advance rotors according to Enigma stepping rules"""
        # Middle rotor at its notch: it steps and carries the left rotor (double-step)
        if self.rotor_at_notch(1):
            self.positions[1] = (self.positions[1] + 1) % 26
            self.positions[2] = (self.positions[2] + 1) % 26
        elif self.rotor_at_notch(0):
            self.positions[1] = (self.positions[1] + 1) % 26

        # Always step the rightmost rotor
        self.positions[0] = (self.positions[0] + 1) % 26

    def rotor_at_notch(self, rotor_index):
        """Check if rotor is at its notch position"""
        notch = self.NOTCHES[self.rotor_names[rotor_index]]
        notch_pos = ord(notch) - ord('A')
        return self.positions[rotor_index] == notch_pos

    def encrypt_char(self, char):
        """Encrypt a single character

        Uppercase A-Z is enciphered (stepping the rotors first); any other
        input is returned unchanged without stepping.
        """
        # `in` on a string is a substring test, so '' and 'AB' would slip
        # through without the explicit length check
        if len(char) != 1 or char not in string.ascii_uppercase:
            return char

        # Step rotors before encryption
        self.step_rotors()

        # Convert char to number (A=0, B=1, ...)
        pos = ord(char) - ord('A')

        # Forward through rotors (right to left)
        for wiring, offset in zip(self.rotors, self.positions):
            pos = (pos + offset) % 26
            pos = ord(wiring[pos]) - ord('A')
            pos = (pos - offset) % 26

        # Through reflector
        pos = ord(self.REFLECTOR[pos]) - ord('A')

        # Backward through rotors (left to right), using the inverse wiring
        for wiring, offset in zip(reversed(self.rotors), reversed(self.positions)):
            pos = (pos + offset) % 26
            pos = wiring.index(chr(pos + ord('A')))
            pos = (pos - offset) % 26

        return chr(pos + ord('A'))

    def encrypt(self, text):
        """Encrypt a full message

        Rotors are reset to their initial positions first, so every call
        starts from the same state. The text is upper-cased before encryption.
        """
        self.reset()
        return ''.join(self.encrypt_char(char) for char in text.upper())

# ============================================================================
# FILE READING AND DATA PREPARATION
# ============================================================================
def read_text_from_file(filename):
    """
    Load a text file and return only its letters, upper-cased, as one string.

    Each line is stripped and upper-cased; any character that is not in A-Z
    afterwards is discarded. The number of letters read is printed.

    Raises:
        FileNotFoundError: if ``filename`` does not exist
    """
    if not os.path.exists(filename):
        raise FileNotFoundError(f"Input file '{filename}' not found!")

    alphabet = string.ascii_uppercase
    with open(filename, 'r') as handle:
        text = ''.join(
            symbol
            for row in handle
            for symbol in row.strip().upper()
            if symbol in alphabet
        )

    print(f"Read {len(text)} letters from '{filename}'")
    return text

def generate_dataset_from_file(filename, rotor_order, initial_positions,
                                num_samples=5000, message_length=100):
    """
    Generate training dataset by encrypting text from file

    The letters of the file are cut into consecutive ``message_length``
    chunks (wrapping around when the file is exhausted) and each chunk is
    encrypted from the same initial rotor positions.

    NOTE(review): every sample is produced with the same ``rotor_order`` and
    ``initial_positions``, so the 'positions' label is identical across the
    whole dataset - confirm that this is what the training code expects.

    Args:
        filename: Path to input text file
        rotor_order: List of 3 rotor names, rightmost rotor first
                     (e.g., ['III', 'II', 'I'])
        initial_positions: List of 3 starting positions in the same order
                           (e.g., [0, 0, 0] for AAA)
        num_samples: Number of training samples to generate
        message_length: Length of each training sample

    Returns:
        List of dicts with keys 'ciphertext', 'plaintext', 'rotors', 'positions'

    Raises:
        FileNotFoundError: If the input file does not exist
        ValueError: If the file holds fewer than ``message_length`` letters
    """
    # Read plaintext from file
    full_plaintext = read_text_from_file(filename)

    if len(full_plaintext) < message_length:
        raise ValueError(f"File contains only {len(full_plaintext)} letters, but need at least {message_length}")

    print(f"\nGenerating {num_samples} training samples...")
    # SimpleEnigma treats index 0 as the rightmost (fastest) rotor
    print(f"Rotor order: {rotor_order} (Right-Middle-Left)")
    print(f"Initial positions: {initial_positions} ({chr(initial_positions[0]+65)}{chr(initial_positions[1]+65)}{chr(initial_positions[2]+65)})")

    # Create enigma machine with fixed configuration
    enigma = SimpleEnigma(rotor_order, initial_positions)

    # Modulus for wrapping the chunk start index. It is 0 when the file holds
    # exactly message_length letters; a bare `%` would then divide by zero,
    # so that case always uses start index 0.
    wrap_span = len(full_plaintext) - message_length

    data = []
    for i in range(num_samples):
        if (i + 1) % 1000 == 0:
            print(f"  Generated {i + 1}/{num_samples} samples")

        # Extract a chunk of text, wrapping around if we need more samples
        # than the file length allows
        start_idx = (i * message_length) % wrap_span if wrap_span else 0
        plaintext = full_plaintext[start_idx:start_idx + message_length]

        # Encrypt the plaintext (encrypt() resets the rotors on every call)
        ciphertext = enigma.encrypt(plaintext)

        data.append({
            'ciphertext': ciphertext,
            'plaintext': plaintext,
            'rotors': rotor_order,
            'positions': initial_positions
        })

    return data

# ============================================================================
# PYTORCH DATASET
# ============================================================================
class EnigmaDataset(Dataset):
    """Dataset wrapper that turns ciphertext records into (one-hot text, positions) pairs"""

    def __init__(self, data):
        # Sequence of dicts; __getitem__ reads the 'ciphertext' and 'positions' keys
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        # Features: one row per ciphertext character, 26 columns
        features = self.encode_text(sample['ciphertext'])

        # Labels: the stored positions as an integer (long) tensor
        target = torch.tensor(sample['positions'], dtype=torch.long)

        return features, target

    @staticmethod
    def encode_text(text):
        """Return a (len(text), 26) float tensor with a 1 in the column of each A-Z letter.

        Characters outside A-Z produce an all-zero row.
        """
        one_hot = torch.zeros(len(text), 26)
        hits = [
            (row, ord(symbol) - ord('A'))
            for row, symbol in enumerate(text)
            if symbol in string.ascii_uppercase
        ]
        if hits:
            rows, cols = map(list, zip(*hits))
            one_hot[rows, cols] = 1
        return one_hot

# ============================================================================
# NEURAL NETWORK MODEL
# ============================================================================
class EnigmaRotorClassifier(nn.Module):
    """
    1-D CNN that maps a one-hot encoded message to three 26-way predictions,
    one per rotor position.

    Layout: three Conv1d + BatchNorm + ReLU + MaxPool(2) stages with
    64, 128 and 256 channels, then two fully connected layers of width
    ``hidden_dim``, then three independent linear heads with 26 outputs each.

    Tuning guide:
    - Conv depth/width: add or remove conv/batch-norm pairs and update the
      stage list in forward(); the flattened-size formula in __init__ must
      then be changed to match the new number of pooling steps.
    - FC width: pass a different ``hidden_dim``.
    - FC depth: add or remove fc layers and update forward().
    """

    def __init__(self, message_length=100, hidden_dim=256):
        """
        Args:
            message_length: Length of input message (sequence length the
                            first fully connected layer is sized for)
            hidden_dim: Number of nodes in each fully connected layer
        """
        super().__init__()

        # Convolutional feature extractor; kernel 3 with padding 1 keeps the
        # sequence length unchanged, so only pooling shortens it
        self.conv1 = nn.Conv1d(26, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)

        self.pool = nn.MaxPool1d(2)
        self.dropout = nn.Dropout(0.3)

        # Three pooling steps halve the length three times (2^3 = 8), and the
        # last conv layer emits 256 channels
        flattened_features = message_length // 8 * 256

        # Fully connected trunk
        self.fc1 = nn.Linear(flattened_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)

        # One 26-class head per rotor position
        self.rotor1_head = nn.Linear(hidden_dim, 26)
        self.rotor2_head = nn.Linear(hidden_dim, 26)
        self.rotor3_head = nn.Linear(hidden_dim, 26)

        self.relu = nn.ReLU()
        self.batch_norm1 = nn.BatchNorm1d(64)
        self.batch_norm2 = nn.BatchNorm1d(128)
        self.batch_norm3 = nn.BatchNorm1d(256)

    def forward(self, x):
        """Return a tuple of three (batch_size, 26) logit tensors, one per rotor.

        ``x`` has shape (batch_size, seq_len, 26).
        """
        # Conv1d expects channels before sequence: (batch_size, 26, seq_len)
        x = x.transpose(1, 2)

        # Conv stages: conv -> batch norm -> ReLU -> pool.
        # Keep this list in sync with the layers created in __init__.
        conv_stages = (
            (self.conv1, self.batch_norm1),
            (self.conv2, self.batch_norm2),
            (self.conv3, self.batch_norm3),
        )
        for conv, norm in conv_stages:
            x = self.pool(self.relu(norm(conv(x))))

        # Flatten to (batch_size, features), with dropout before the FC trunk
        x = self.dropout(x.view(x.size(0), -1))

        # Fully connected trunk; dropout sits between the two layers only
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.relu(self.fc2(x))

        # Independent predictions for each rotor
        return self.rotor1_head(x), self.rotor2_head(x), self.rotor3_head(x)

# ============================================================================
# TRAINING FUNCTIONS
# ============================================================================
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader``, accumulating loss and per-rotor hits.

    The pass updates the weights when ``optimizer`` is supplied; otherwise the
    model is switched to eval mode and gradient tracking is disabled.

    Returns:
        ``(summed_loss, batches, correct, seen)``: the sum of the batch losses,
        the number of batches processed, a 3-item list of correct predictions
        per rotor head, and the number of samples processed.
    """
    training = optimizer is not None
    model.train(training)

    summed_loss = 0.0
    batches = 0
    correct = [0, 0, 0]
    seen = 0

    with torch.set_grad_enabled(training):
        for ciphertext, positions in loader:
            ciphertext = ciphertext.to(device)
            positions = positions.to(device)

            if training:
                optimizer.zero_grad()

            # The model must return exactly three outputs, one per rotor
            out1, out2, out3 = model(ciphertext)
            outputs = (out1, out2, out3)

            # One cross-entropy term per rotor head, summed into a single loss
            loss = sum(
                criterion(logits, positions[:, rotor])
                for rotor, logits in enumerate(outputs)
            )

            if training:
                loss.backward()
                optimizer.step()

            summed_loss += loss.item()
            batches += 1
            for rotor, logits in enumerate(outputs):
                hits = logits.argmax(dim=1) == positions[:, rotor]
                correct[rotor] += hits.sum().item()
            seen += positions.size(0)

    return summed_loss, batches, correct, seen


def train_model(model, train_loader, val_loader, num_epochs=20, device='cuda',
                lr=0.001, checkpoint_path='best_enigma_model.pth'):
    """Train the Enigma rotor classifier

    Args:
        model: Network returning three outputs, one per rotor.
        train_loader: Iterable of ``(ciphertext, positions)`` training batches;
            column ``i`` of ``positions`` is the target for output ``i``.
        val_loader: Iterable of validation batches in the same format.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the model and every batch are moved to.
        lr: Learning rate handed to the Adam optimizer.
        checkpoint_path: File that receives the state dict each time the
            average validation accuracy improves.

    Returns:
        The model as it stands after the final epoch. The best-scoring weights
        are only written to ``checkpoint_path``; they are not reloaded.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint, even when validation accuracy is 0%.
    best_val_acc = float('-inf')

    for epoch in range(num_epochs):
        train_loss, train_batches, train_correct, train_total = _run_epoch(
            model, train_loader, criterion, device, optimizer)
        val_loss, val_batches, val_correct, val_total = _run_epoch(
            model, val_loader, criterion, device)

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError
        train_acc = [(c / max(train_total, 1)) * 100 for c in train_correct]
        val_acc = [(c / max(val_total, 1)) * 100 for c in val_correct]
        avg_val_acc = sum(val_acc) / 3

        # The scheduler watches the summed validation loss
        scheduler.step(val_loss)

        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print(f'Train Loss: {train_loss/max(train_batches, 1):.4f}')
        print(f'Train Acc - R1: {train_acc[0]:.2f}%, R2: {train_acc[1]:.2f}%, R3: {train_acc[2]:.2f}%')
        print(f'Val Loss: {val_loss/max(val_batches, 1):.4f}')
        print(f'Val Acc - R1: {val_acc[0]:.2f}%, R2: {val_acc[1]:.2f}%, R3: {val_acc[2]:.2f}%')

        # Save best model
        if avg_val_acc > best_val_acc:
            best_val_acc = avg_val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print(f'✓ Saved best model with avg accuracy: {avg_val_acc:.2f}%')

    return model

# ============================================================================
# PREDICTION AND TESTING
# ============================================================================
def predict_positions(model, ciphertext, device='cuda'):
    """Run the classifier on one ciphertext and report its best guess per rotor.

    Args:
        model: Network returning three outputs, one per rotor.
        ciphertext: Message text, encoded with ``EnigmaDataset.encode_text``.
        device: Device the encoded batch is moved to before inference.

    Returns:
        ``(predictions, confidences)``: for each rotor the arg-max class index
        and the softmax probability of that class, expressed as a percentage.
    """
    model.eval()

    # Encode the text and add a leading batch axis of size one
    batch = EnigmaDataset.encode_text(ciphertext).unsqueeze(0).to(device)

    predictions = []
    confidences = []

    with torch.no_grad():
        out1, out2, out3 = model(batch)

        for logits in (out1, out2, out3):
            _, winner = torch.max(logits, 1)
            # Softmax over the class axis; [0] selects the only batch entry
            distribution = torch.softmax(logits, dim=1)[0]

            predictions.append(winner.item())
            confidences.append(distribution[winner].item() * 100)

    return predictions, confidences

# ============================================================================
# MAIN EXECUTION
# ============================================================================
if __name__ == '__main__':
    # Set device: use the GPU when one is available, otherwise the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}\n')

    # Generate dataset from file
    print('=' * 60)
    print('GENERATING DATASET FROM FILE')
    print('=' * 60)

    try:
        dataset = generate_dataset_from_file(
            filename=INPUT_TEXT_FILE,
            rotor_order=DEFAULT_ROTOR_ORDER,
            initial_positions=DEFAULT_POSITIONS,
            num_samples=NUM_TRAINING_SAMPLES,
            message_length=MESSAGE_LENGTH
        )
    except FileNotFoundError:
        print(f"\nERROR: Input file '{INPUT_TEXT_FILE}' not found!")
        print("Please create this file with one letter per line.")
        print("Example content:")
        print("A")
        print("B")
        print("C")
        print("...")
        # Raise SystemExit directly: the `exit` helper is injected by the
        # `site` module for interactive use and is not meant for programs.
        raise SystemExit(1)

    # Split into train/val (first 80% for training, remainder for validation)
    split_idx = int(0.8 * len(dataset))
    train_data = dataset[:split_idx]
    val_data = dataset[split_idx:]

    train_dataset = EnigmaDataset(train_data)
    val_dataset = EnigmaDataset(val_data)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    print(f'\nTraining samples: {len(train_data)}')
    print(f'Validation samples: {len(val_data)}')

    # Create model
    print('\n' + '=' * 60)
    print('CREATING MODEL')
    print('=' * 60)

    # ========================================================================
    # MODEL CREATION - CHANGE hidden_dim TO ADJUST FC LAYER SIZE
    # ========================================================================
    model = EnigmaRotorClassifier(
        message_length=MESSAGE_LENGTH,
        hidden_dim=256  # *** CHANGE THIS TO MODIFY FC LAYER SIZE ***
    )

    total_params = sum(p.numel() for p in model.parameters())
    print(f'Total parameters: {total_params:,}')

    # Train model
    print('\n' + '=' * 60)
    print('TRAINING MODEL')
    print('=' * 60)
    model = train_model(model, train_loader, val_loader, num_epochs=20, device=device)

    # Test predictions
    print('\n' + '=' * 60)
    print('TESTING PREDICTIONS')
    print('=' * 60)

    # Create test message using the same configuration.
    # The plaintext comes from the validation split so the demo does not
    # score the model on a sample it was trained on.
    enigma = SimpleEnigma(DEFAULT_ROTOR_ORDER, DEFAULT_POSITIONS)
    test_plaintext = val_data[0]['plaintext']
    test_ciphertext = enigma.encrypt(test_plaintext)

    # 65 is ord('A'): turns a 0-25 position into its letter
    true_letters = ''.join(chr(p + 65) for p in DEFAULT_POSITIONS[:3])

    print('\nTest Configuration:')
    print(f'  Rotors: {DEFAULT_ROTOR_ORDER} (Left-Middle-Right)')
    print(f'  True positions: {DEFAULT_POSITIONS} ({true_letters})')
    print(f'  Plaintext: {test_plaintext[:50]}...')
    print(f'  Ciphertext: {test_ciphertext[:50]}...')

    predictions, confidences = predict_positions(model, test_ciphertext, device)

    print('\nPredictions:')
    for rotor_no in range(3):
        guess = predictions[rotor_no]
        print(f'  Rotor {rotor_no + 1}: {guess} ({chr(guess+65)}) - Confidence: {confidences[rotor_no]:.1f}%')

    print('\nAccuracy:')
    for rotor_no in range(3):
        verdict = "✓ CORRECT" if predictions[rotor_no] == DEFAULT_POSITIONS[rotor_no] else "✗ INCORRECT"
        print(f'  Rotor {rotor_no + 1}: {verdict}')

    print('\n' + '=' * 60)
    print('TRAINING COMPLETE')
    print('=' * 60)
    print('Model saved as: best_enigma_model.pth')

Using device: cuda

GENERATING DATASET FROM FILE
Read 37231 letters from 'C:\Users\acool\1 Capstone\Engima-Capstone-main\output_matrix_no_punctuation.txt'

Generating 5000 training samples...
Rotor order: ['III', 'II', 'I'] (Left-Middle-Right)
Initial positions: [0, 0, 0] (AAA)
  Generated 1000/5000 samples
  Generated 2000/5000 samples
  Generated 3000/5000 samples
  Generated 4000/5000 samples
  Generated 5000/5000 samples

Training samples: 4000
Validation samples: 1000

CREATING MODEL
Total parameters: 1,001,742

TRAINING MODEL

Epoch 1/20
Train Loss: 0.1179
Train Acc - R1: 99.17%, R2: 99.33%, R3: 99.25%
Val Loss: 0.0000
Val Acc - R1: 100.00%, R2: 100.00%, R3: 100.00%
✓ Saved best model with avg accuracy: 100.00%

Epoch 2/20
Train Loss: 0.0000
Train Acc - R1: 100.00%, R2: 100.00%, R3: 100.00%
Val Loss: 0.0000
Val Acc - R1: 100.00%, R2: 100.00%, R3: 100.00%

Epoch 3/20
Train Loss: 0.0000
Train Acc - R1: 100.00%, R2: 100.00%, R3: 100.00%
Val Loss: 0.0000
Val Acc - R1: 100.00%, R2: 10

# Archive

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import string

# ============================================================================
# ENIGMA SIMULATOR (Simplified - no plugboard)
# ============================================================================

class SimpleEnigma:
    """Simplified three-rotor Enigma machine simulator for demonstration.

    There is no plugboard and no ring setting. Index 0 of ``rotors`` and
    ``positions`` is the rotor that steps on every key press and that the
    signal enters first; index 2 is the rotor next to the reflector.
    """

    # Historical rotor wirings (I-V)
    ROTORS = {
        'I':   'EKMFLGDQVZNTOWYHXUSPAIBRCJ',
        'II':  'AJDKSIRUXBLHWTMCQGZNPYFVOE',
        'III': 'BDFHJLCPRTXVZNYEIWGAKMUSQO',
        'IV':  'ESOVPZJAYQUIRHXLNFTGKDCMWB',
        'V':   'VZBRGITYUPSDNHLXAWMJQOFECK'
    }

    # Notch positions (when rotor steps the next one)
    NOTCHES = {
        'I': 'Q', 'II': 'E', 'III': 'V', 'IV': 'J', 'V': 'Z'
    }

    REFLECTOR = 'YRUHQSLDPXNGOKMIEBFZCWVJAT'

    def __init__(self, rotors, positions):
        """
        rotors: sequence of 3 rotor names, e.g., ['I', 'II', 'III']
        positions: sequence of 3 starting positions (0-25)

        Raises KeyError when a rotor name is not a key of ``ROTORS``.
        """
        self.rotors = [self.ROTORS[r] for r in rotors]
        self.rotor_names = rotors
        # list() rather than .copy(): accepts tuples and other sequences and
        # still leaves the caller's object untouched.
        self.positions = list(positions)
        self.initial_positions = list(positions)

    def reset(self):
        """Reset rotor positions to initial state"""
        self.positions = list(self.initial_positions)

    def _advance(self, rotor_index):
        """Rotate one rotor by a single position, wrapping after 25."""
        self.positions[rotor_index] = (self.positions[rotor_index] + 1) % 26

    def step_rotors(self):
        """Advance rotors according to Enigma stepping rules"""
        # Double-step: a middle rotor sitting on its notch moves together
        # with the rotor beyond it. Otherwise the middle rotor moves only
        # when the first rotor is on its notch.
        if self.rotor_at_notch(1):
            self._advance(1)
            self._advance(2)
        elif self.rotor_at_notch(0):
            self._advance(1)

        # The rotor at index 0 steps on every key press
        self._advance(0)

    def rotor_at_notch(self, rotor_index):
        """Check if rotor is at its notch position"""
        notch = self.NOTCHES[self.rotor_names[rotor_index]]
        notch_pos = ord(notch) - ord('A')
        return self.positions[rotor_index] == notch_pos

    def encrypt_char(self, char):
        """Encrypt a single character.

        Anything other than one uppercase A-Z letter is returned unchanged
        and does not step the rotors.
        """
        # The length check matters: '' and multi-letter runs such as 'AB'
        # pass a bare substring test against the alphabet.
        if len(char) != 1 or char not in string.ascii_uppercase:
            return char

        # Step rotors before encryption
        self.step_rotors()

        # Convert char to number (A=0, B=1, ...)
        pos = ord(char) - ord('A')

        # Forward through rotors 0 -> 2, shifting by each rotor's offset
        # on the way in and undoing the shift on the way out
        for i in range(3):
            pos = (pos + self.positions[i]) % 26
            pos = ord(self.rotors[i][pos]) - ord('A')
            pos = (pos - self.positions[i]) % 26

        # Through reflector
        pos = ord(self.REFLECTOR[pos]) - ord('A')

        # Backward through rotors 2 -> 0 using the inverse wiring
        for i in range(2, -1, -1):
            pos = (pos + self.positions[i]) % 26
            pos = self.rotors[i].index(chr(pos + ord('A')))
            pos = (pos - self.positions[i]) % 26

        return chr(pos + ord('A'))

    def encrypt(self, text):
        """Encrypt a full message.

        The rotors are reset to their initial positions first and the text
        is upper-cased before being processed character by character.
        """
        self.reset()
        return ''.join(self.encrypt_char(char) for char in text.upper())


# ============================================================================
# DATA GENERATION
# ============================================================================

def generate_random_text(length=100):
    """Draw ``length`` uppercase letters using a weighted random choice.

    The weights favour some letters over others, so the result is skewed
    rather than uniform over the alphabet.
    """
    # (letter, relative weight) pairs, heaviest first
    letter_weights = (
        ('E', 12.7), ('T', 9.1), ('A', 8.2), ('O', 7.5), ('I', 7.0),
        ('N', 6.7), ('S', 6.3), ('H', 6.1), ('R', 6.0), ('D', 5.9),
        ('L', 4.3), ('C', 4.0), ('U', 2.8), ('M', 2.8), ('W', 2.4),
        ('F', 2.4), ('G', 2.2), ('Y', 2.0), ('P', 2.0), ('B', 1.9),
        ('V', 1.5), ('K', 1.0), ('J', 0.15), ('X', 0.15), ('Q', 0.10),
        ('Z', 0.07),
    )
    alphabet = ''.join(letter for letter, _ in letter_weights)
    weights = [weight for _, weight in letter_weights]

    drawn = random.choices(alphabet, weights=weights, k=length)
    return ''.join(drawn)


def generate_dataset(num_samples=5000, message_length=100):
    """Build a list of randomly keyed Enigma samples.

    Every sample uses three distinct rotors drawn from I-V and three random
    start positions in 0-25. Each entry is a dict with the keys
    'ciphertext', 'plaintext', 'rotors' and 'positions'.
    """
    available_rotors = ['I', 'II', 'III', 'IV', 'V']
    samples = []

    print(f"Generating {num_samples} training samples...")

    for count in range(1, num_samples + 1):
        # Progress line on every 1000th sample
        if count % 1000 == 0:
            print(f"  Generated {count}/{num_samples} samples")

        # Random rotor configuration
        chosen_rotors = random.sample(available_rotors, 3)
        start_positions = [random.randint(0, 25) for _ in range(3)]

        # Generate and encrypt message
        message = generate_random_text(message_length)
        machine = SimpleEnigma(chosen_rotors, start_positions)
        encrypted = machine.encrypt(message)

        samples.append(dict(
            ciphertext=encrypted,
            plaintext=message,
            rotors=chosen_rotors,
            positions=start_positions,
        ))

    return samples


# ============================================================================
# PYTORCH DATASET
# ============================================================================

class EnigmaDataset(Dataset):
    """Expose generated Enigma samples as (one-hot ciphertext, positions) pairs."""

    def __init__(self, data):
        # Indexable collection of dicts with 'ciphertext' and 'positions' keys
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        # Input: one-hot encoded ciphertext
        features = self.encode_text(sample['ciphertext'])

        # Target: the rotor positions as a long tensor
        target = torch.tensor(sample['positions'], dtype=torch.long)

        return features, target

    @staticmethod
    def encode_text(text):
        """One-hot encode ``text`` into a tensor of shape (len(text), 26).

        Rows for characters outside uppercase A-Z are left as all zeros.
        """
        encoded = torch.zeros(len(text), 26)

        # (row, column) of every cell that has to be switched on
        hot_cells = [
            (row, ord(char) - ord('A'))
            for row, char in enumerate(text)
            if char in string.ascii_uppercase
        ]
        if hot_cells:
            rows, cols = zip(*hot_cells)
            encoded[list(rows), list(cols)] = 1

        return encoded


# ============================================================================
# NEURAL NETWORK MODEL
# ============================================================================

class EnigmaRotorClassifier(nn.Module):
    """
    CNN that maps one-hot ciphertext to three 26-way rotor position outputs.

    Three conv/batch-norm/pool stages feed two fully connected layers, which
    in turn feed one linear head per rotor.
    """

    def __init__(self, message_length=100, embedding_dim=128, hidden_dim=256):
        """
        Args:
            message_length: Sequence length used to size the first FC layer.
            embedding_dim: Accepted but not used by this class.
            hidden_dim: Width of both fully connected layers.
        """
        super().__init__()

        # Convolutional feature extractor: 26 -> 64 -> 128 -> 256 channels
        self.conv1 = nn.Conv1d(26, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)

        self.pool = nn.MaxPool1d(2)
        self.dropout = nn.Dropout(0.3)

        # Three poolings of size 2 divide the sequence length by 8, and the
        # last conv layer emits 256 channels
        flattened_features = message_length // 8 * 256

        # Fully connected trunk
        self.fc1 = nn.Linear(flattened_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)

        # One 26-class head per rotor
        self.rotor1_head = nn.Linear(hidden_dim, 26)
        self.rotor2_head = nn.Linear(hidden_dim, 26)
        self.rotor3_head = nn.Linear(hidden_dim, 26)

        self.relu = nn.ReLU()
        self.batch_norm1 = nn.BatchNorm1d(64)
        self.batch_norm2 = nn.BatchNorm1d(128)
        self.batch_norm3 = nn.BatchNorm1d(256)

    def forward(self, x):
        """Return ``(rotor1_out, rotor2_out, rotor3_out)`` for a batch ``x``
        laid out as (batch_size, seq_len, 26)."""
        # Conv1d expects (batch_size, channels, seq_len)
        features = x.transpose(1, 2)

        # Convolutional feature extraction: conv -> batch norm -> ReLU -> pool
        conv_stages = (
            (self.conv1, self.batch_norm1),
            (self.conv2, self.batch_norm2),
            (self.conv3, self.batch_norm3),
        )
        for conv, norm in conv_stages:
            features = self.pool(self.relu(norm(conv(features))))

        # Flatten everything except the batch axis
        flat = self.dropout(features.view(features.size(0), -1))

        # Fully connected layers
        hidden = self.relu(self.fc1(flat))
        hidden = self.dropout(hidden)
        hidden = self.relu(self.fc2(hidden))

        # Output predictions for each rotor
        return (
            self.rotor1_head(hidden),
            self.rotor2_head(hidden),
            self.rotor3_head(hidden),
        )


# ============================================================================
# TRAINING FUNCTIONS
# ============================================================================

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader``, accumulating loss and per-rotor hits.

    The pass updates the weights when ``optimizer`` is supplied; otherwise the
    model is switched to eval mode and gradient tracking is disabled.

    Returns:
        ``(summed_loss, batches, correct, seen)``: the sum of the batch losses,
        the number of batches processed, a 3-item list of correct predictions
        per rotor head, and the number of samples processed.
    """
    training = optimizer is not None
    model.train(training)

    summed_loss = 0.0
    batches = 0
    correct = [0, 0, 0]
    seen = 0

    with torch.set_grad_enabled(training):
        for ciphertext, positions in loader:
            ciphertext = ciphertext.to(device)
            positions = positions.to(device)

            if training:
                optimizer.zero_grad()

            # The model must return exactly three outputs, one per rotor
            out1, out2, out3 = model(ciphertext)
            outputs = (out1, out2, out3)

            # One cross-entropy term per rotor head, summed into a single loss
            loss = sum(
                criterion(logits, positions[:, rotor])
                for rotor, logits in enumerate(outputs)
            )

            if training:
                loss.backward()
                optimizer.step()

            summed_loss += loss.item()
            batches += 1
            for rotor, logits in enumerate(outputs):
                hits = logits.argmax(dim=1) == positions[:, rotor]
                correct[rotor] += hits.sum().item()
            seen += positions.size(0)

    return summed_loss, batches, correct, seen


def train_model(model, train_loader, val_loader, num_epochs=20, device='cuda',
                lr=0.001, checkpoint_path='best_enigma_model.pth'):
    """Train the Enigma rotor classifier

    Args:
        model: Network returning three outputs, one per rotor.
        train_loader: Iterable of ``(ciphertext, positions)`` training batches;
            column ``i`` of ``positions`` is the target for output ``i``.
        val_loader: Iterable of validation batches in the same format.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the model and every batch are moved to.
        lr: Learning rate handed to the Adam optimizer.
        checkpoint_path: File that receives the state dict each time the
            average validation accuracy improves.

    Returns:
        The model as it stands after the final epoch. The best-scoring weights
        are only written to ``checkpoint_path``; they are not reloaded.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint, even when validation accuracy is 0%.
    best_val_acc = float('-inf')

    for epoch in range(num_epochs):
        train_loss, train_batches, train_correct, train_total = _run_epoch(
            model, train_loader, criterion, device, optimizer)
        val_loss, val_batches, val_correct, val_total = _run_epoch(
            model, val_loader, criterion, device)

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError
        train_acc = [(c / max(train_total, 1)) * 100 for c in train_correct]
        val_acc = [(c / max(val_total, 1)) * 100 for c in val_correct]
        avg_val_acc = sum(val_acc) / 3

        # The scheduler watches the summed validation loss
        scheduler.step(val_loss)

        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print(f'Train Loss: {train_loss/max(train_batches, 1):.4f}')
        print(f'Train Acc - R1: {train_acc[0]:.2f}%, R2: {train_acc[1]:.2f}%, R3: {train_acc[2]:.2f}%')
        print(f'Val Loss: {val_loss/max(val_batches, 1):.4f}')
        print(f'Val Acc - R1: {val_acc[0]:.2f}%, R2: {val_acc[1]:.2f}%, R3: {val_acc[2]:.2f}%')

        # Save best model
        if avg_val_acc > best_val_acc:
            best_val_acc = avg_val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print(f'✓ Saved best model with avg accuracy: {avg_val_acc:.2f}%')

    return model


# ============================================================================
# PREDICTION AND TESTING
# ============================================================================

def predict_positions(model, ciphertext, device='cuda'):
    """Run the classifier on one ciphertext and report its best guess per rotor.

    Args:
        model: Network returning three outputs, one per rotor.
        ciphertext: Message text, encoded with ``EnigmaDataset.encode_text``.
        device: Device the encoded batch is moved to before inference.

    Returns:
        ``(predictions, confidences)``: for each rotor the arg-max class index
        and the softmax probability of that class, expressed as a percentage.
    """
    model.eval()

    # Encode the text and add a leading batch axis of size one
    batch = EnigmaDataset.encode_text(ciphertext).unsqueeze(0).to(device)

    predictions = []
    confidences = []

    with torch.no_grad():
        out1, out2, out3 = model(batch)

        for logits in (out1, out2, out3):
            _, winner = torch.max(logits, 1)
            # Softmax over the class axis; [0] selects the only batch entry
            distribution = torch.softmax(logits, dim=1)[0]

            predictions.append(winner.item())
            confidences.append(distribution[winner].item() * 100)

    return predictions, confidences


# ============================================================================
# MAIN EXECUTION
# ============================================================================

if __name__ == '__main__':
    # Horizontal rule reused by every section banner
    rule = '=' * 60

    # Use the GPU when one is available, otherwise the CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}\n')

    # ---- Dataset -----------------------------------------------------------
    print(rule)
    print('GENERATING DATASET')
    print(rule)
    dataset = generate_dataset(num_samples=5000, message_length=100)

    # First 80% for training, the remainder for validation
    split_idx = int(0.8 * len(dataset))
    train_data, val_data = dataset[:split_idx], dataset[split_idx:]

    train_dataset = EnigmaDataset(train_data)
    val_dataset = EnigmaDataset(val_data)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    print(f'\nTraining samples: {len(train_data)}')
    print(f'Validation samples: {len(val_data)}')

    # ---- Model -------------------------------------------------------------
    print('\n' + rule)
    print('CREATING MODEL')
    print(rule)
    model = EnigmaRotorClassifier(message_length=100)

    total_params = sum(p.numel() for p in model.parameters())
    print(f'Total parameters: {total_params:,}')

    # ---- Training ----------------------------------------------------------
    print('\n' + rule)
    print('TRAINING MODEL')
    print(rule)
    model = train_model(model, train_loader, val_loader, num_epochs=20, device=device)

    # ---- Single-message check ----------------------------------------------
    print('\n' + rule)
    print('TESTING PREDICTIONS')
    print(rule)

    # Encrypt a fresh random message under a fixed, known configuration
    test_rotors = ['I', 'II', 'III']
    test_positions = [5, 12, 20]  # F, M, U
    enigma = SimpleEnigma(test_rotors, test_positions)
    test_plaintext = generate_random_text(100)
    test_ciphertext = enigma.encrypt(test_plaintext)

    # 65 is ord('A'): turns a 0-25 position into its letter
    position_letters = ', '.join(chr(p + 65) for p in test_positions)

    print('\nTest Configuration:')
    print(f'  Rotors: {test_rotors}')
    print(f'  True positions: {test_positions} ({position_letters})')
    print(f'  Plaintext: {test_plaintext[:50]}...')
    print(f'  Ciphertext: {test_ciphertext[:50]}...')

    predictions, confidences = predict_positions(model, test_ciphertext, device)

    print('\nPredictions:')
    for rotor_no in range(3):
        guess = predictions[rotor_no]
        print(f'  Rotor {rotor_no + 1}: {guess} ({chr(guess+65)}) - Confidence: {confidences[rotor_no]:.1f}%')

    print('\nAccuracy:')
    for rotor_no in range(3):
        if predictions[rotor_no] == test_positions[rotor_no]:
            verdict = "✓ CORRECT"
        else:
            verdict = "✗ INCORRECT"
        print(f'  Rotor {rotor_no + 1}: {verdict}')

    print('\n' + rule)
    print('TRAINING COMPLETE')
    print(rule)
    print('Model saved as: best_enigma_model.pth')

Using device: cuda

GENERATING DATASET
Generating 5000 training samples...
  Generated 1000/5000 samples
  Generated 2000/5000 samples
  Generated 3000/5000 samples
  Generated 4000/5000 samples
  Generated 5000/5000 samples

Training samples: 4000
Validation samples: 1000

CREATING MODEL
Total parameters: 1,001,742

TRAINING MODEL

Epoch 1/20
Train Loss: 9.8075
Train Acc - R1: 4.15%, R2: 4.23%, R3: 4.17%
Val Loss: 9.7748
Val Acc - R1: 3.40%, R2: 3.40%, R3: 4.00%
✓ Saved best model with avg accuracy: 3.60%

Epoch 2/20
Train Loss: 9.7746
Train Acc - R1: 4.35%, R2: 4.17%, R3: 4.38%
Val Loss: 9.7742
Val Acc - R1: 3.40%, R2: 3.40%, R3: 4.00%

Epoch 3/20
Train Loss: 9.7726
Train Acc - R1: 4.28%, R2: 4.30%, R3: 4.08%
Val Loss: 9.7751
Val Acc - R1: 3.40%, R2: 2.90%, R3: 3.70%

Epoch 4/20
Train Loss: 9.7717
Train Acc - R1: 4.15%, R2: 4.42%, R3: 4.47%
Val Loss: 9.7752
Val Acc - R1: 3.40%, R2: 2.90%, R3: 3.70%

Epoch 5/20
Train Loss: 9.7679
Train Acc - R1: 4.35%, R2: 4.15%, R3: 4.00%
Val Loss: 9