In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, DataLoader
import string
import random

In [34]:
KEYWORD="CRYPTO"
ALPHABET = string.ascii_uppercase.replace('J', '') 
ALPHABET_SIZE = len(ALPHABET)

In [35]:
def generate_playfair_matrix(keyword):
    seen = set()
    matrix = []
    for char in keyword.upper():
        if char not in seen and char in ALPHABET:
            seen.add(char)
            matrix.append(char)
    for char in ALPHABET:
        if char not in seen:
            matrix.append(char)
    return matrix

In [36]:
def prepare_pairs(plaintext):
    plaintext = plaintext.upper().replace(" ", "").replace("J", "I")
    pairs = []
    i = 0
    while i < len(plaintext):
        a = plaintext[i]
        if i + 1 < len(plaintext):
            b = plaintext[i + 1]
            if a == b:
                pairs.append((a, 'X'))  
                i += 1
            else:
                pairs.append((a, b))
                i += 2
        else:
            pairs.append((a, 'X'))  
            i += 1
    return pairs

In [37]:
def encrypt_playfair(plaintext, keyword):
    matrix = generate_playfair_matrix(keyword)
    pairs = prepare_pairs(plaintext)
    ciphertext = ''

    position = {char: (i // 5, i % 5) for i, char in enumerate(matrix)}

    for a, b in pairs:
        row_a, col_a = position[a]
        row_b, col_b = position[b]
        
        if row_a == row_b:  
            ciphertext += matrix[row_a * 5 + (col_a + 1) % 5]
            ciphertext += matrix[row_b * 5 + (col_b + 1) % 5]
        elif col_a == col_b:  
            ciphertext += matrix[((row_a + 1) % 5) * 5 + col_a]
            ciphertext += matrix[((row_b + 1) % 5) * 5 + col_b]
        else: 
            ciphertext += matrix[row_a * 5 + col_b]
            ciphertext += matrix[row_b * 5 + col_a]
    
    return ciphertext

In [38]:
def generate_data(num_samples=1000, max_length=10):
    data_list = []
    for _ in range(num_samples):
        length = random.randint(1, max_length)
        plaintext = ''.join(random.choices(ALPHABET, k=length))
        keyword = KEYWORD  
        ciphertext = encrypt_playfair(plaintext, keyword)

        x = []
        for char in plaintext:
            one_hot = [0] * ALPHABET_SIZE
            one_hot[ALPHABET.index(char)] = 1
            x.append(one_hot)
        x = torch.tensor(x, dtype=torch.float)

        edge_index = []
        for i in range(len(plaintext) - 1):
            edge_index.append([i, i + 1])
            edge_index.append([i + 1, i])  
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()

        y = []
        pairs = prepare_pairs(plaintext)
        for a, b in pairs:
            encrypted_pair = encrypt_playfair(a + b, keyword)
            y.extend([ALPHABET.index(char) for char in encrypted_pair])

        y = torch.tensor(y, dtype=torch.long)

        if len(y) != len(x):  
            continue  

        data = Data(x=x, edge_index=edge_index, y=y)
        data_list.append(data)
    return data_list


In [39]:
dataset = generate_data(num_samples=2000, max_length=10)
train_size = int(0.8 * len(dataset))
train_dataset = dataset[:train_size]
test_dataset = dataset[train_size:]

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)



In [40]:
class CipherGNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(CipherGNN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.lin = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.lin(x)
        return x

In [41]:
# Hyperparameters
INPUT_DIM = ALPHABET_SIZE
HIDDEN_DIM = 64
OUTPUT_DIM = ALPHABET_SIZE
EPOCHS = 50
LEARNING_RATE = 0.01

In [42]:
model = CipherGNN(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [43]:
for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch)
        out = out.view(-1, OUTPUT_DIM)
        y = batch.y.view(-1)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    avg_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch}, Loss: {avg_loss:.4f}")

Epoch 1, Loss: 3.0266
Epoch 2, Loss: 2.1871
Epoch 3, Loss: 1.8055
Epoch 4, Loss: 1.6322
Epoch 5, Loss: 1.5315
Epoch 6, Loss: 1.4801
Epoch 7, Loss: 1.4288
Epoch 8, Loss: 1.3769
Epoch 9, Loss: 1.3510
Epoch 10, Loss: 1.3145
Epoch 11, Loss: 1.2886
Epoch 12, Loss: 1.2474
Epoch 13, Loss: 1.2047
Epoch 14, Loss: 1.1817
Epoch 15, Loss: 1.1529
Epoch 16, Loss: 1.1293
Epoch 17, Loss: 1.1272
Epoch 18, Loss: 1.0910
Epoch 19, Loss: 1.0761
Epoch 20, Loss: 1.0394
Epoch 21, Loss: 1.0223
Epoch 22, Loss: 1.0076
Epoch 23, Loss: 0.9862
Epoch 24, Loss: 0.9699
Epoch 25, Loss: 0.9571
Epoch 26, Loss: 0.9483
Epoch 27, Loss: 0.9213
Epoch 28, Loss: 0.9169
Epoch 29, Loss: 0.8849
Epoch 30, Loss: 0.8481
Epoch 31, Loss: 0.8386
Epoch 32, Loss: 0.8321
Epoch 33, Loss: 0.8204
Epoch 34, Loss: 0.8049
Epoch 35, Loss: 0.8063
Epoch 36, Loss: 0.7746
Epoch 37, Loss: 0.7712
Epoch 38, Loss: 0.7607
Epoch 39, Loss: 0.7521
Epoch 40, Loss: 0.7415
Epoch 41, Loss: 0.7277
Epoch 42, Loss: 0.7220
Epoch 43, Loss: 0.6986
Epoch 44, Loss: 0.67

In [44]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in test_loader:
        out = model(batch)
        preds = out.argmax(dim=1)
        correct += (preds == batch.y).sum().item()
        total += batch.y.size(0)

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")


Test Accuracy: 54.90%


In [45]:
def edit_distance(s1, s2):
    m, n = len(s1), len(s2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]

    for i in range(m + 1):
        for j in range(n + 1):
            if i == 0:
                dp[i][j] = j  
            elif j == 0:
                dp[i][j] = i 
            elif s1[i - 1] == s2[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]  
            else:
                dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) 

    return dp[m][n]

model.eval()
total_edit_distance = 0
total_samples = 0

with torch.no_grad():
    for batch in test_loader:
        out = model(batch)
        preds = out.argmax(dim=1)

        pred_ciphertexts = []
        actual_ciphertexts = []

        batch_size = batch.num_graphs
        for i in range(batch_size):
            start = batch.ptr[i]  
            end = batch.ptr[i + 1]  
            pred_chars = [ALPHABET[preds[j].item()] for j in range(start, end)]
            actual_chars = [ALPHABET[batch.y[j].item()] for j in range(start, end)]
            pred_ciphertexts.append("".join(pred_chars))
            actual_ciphertexts.append("".join(actual_chars))

        for pred, actual in zip(pred_ciphertexts, actual_ciphertexts):
            total_edit_distance += edit_distance(pred, actual)
            total_samples += 1

avg_edit_distance = total_edit_distance / total_samples
print(f"Average Edit Distance: {avg_edit_distance:.2f}")


Average Edit Distance: 2.63


In [46]:
def encrypt_with_model(model, plaintext, keyword=KEYWORD):
    model.eval()
    x = []
    for char in plaintext.upper():
        one_hot = [0] * ALPHABET_SIZE
        one_hot[ALPHABET.index(char)] = 1
        x.append(one_hot)
    x = torch.tensor(x, dtype=torch.float).unsqueeze(0) 

    edge_index = []
    for i in range(len(plaintext) - 1):
        edge_index.append([i, i + 1])
        edge_index.append([i + 1, i])
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().unsqueeze(0)

    data = Data(x=x.squeeze(0), edge_index=edge_index.squeeze(0))
    out = model(data)
    preds = out.argmax(dim=1)
    ciphertext = ''.join([ALPHABET[p.item()] for p in preds])
    return ciphertext

In [49]:
plaintext = "HEANDKE"
predicted_ciphertext = encrypt_with_model(model,plaintext)
actual_ciphertext = encrypt_playfair(plaintext,KEYWORD)
print(f"Plaintext: {plaintext}")
print(f"Predicted Ciphertext: {predicted_ciphertext}")
print(f"Actual Ciphertext: {actual_ciphertext}")

Plaintext: HEANDKE
Predicted Ciphertext: BBBBESK
Actual Ciphertext: KBBMEIDZ
