In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, DataLoader
import string
import random

In [17]:
KEYWORD = "CRYPTO"
ALPHABET = string.ascii_uppercase
ALPHABET_SIZE = len(ALPHABET)
CHAR_TO_IDX = {char: idx for idx, char in enumerate(ALPHABET)}
IDX_TO_CHAR = {idx: char for idx, char in enumerate(ALPHABET)}

In [18]:
def vignere_cipher(plaintext, keyword=KEYWORD):
    keyword = (keyword * (len(plaintext) // len(keyword) + 1))[:len(plaintext)].upper()
    cipher_text = ''
    for p_char, k_char in zip(plaintext.upper(), keyword):
        if p_char in ALPHABET:
            idx = (CHAR_TO_IDX[p_char] + CHAR_TO_IDX[k_char]) % ALPHABET_SIZE
            cipher_text += IDX_TO_CHAR[idx]
        else:
            cipher_text += p_char
    return cipher_text

In [19]:
def generate_data(num_samples=1000, max_length=10):
    data_list = []
    for _ in range(num_samples):
        length = random.randint(1, max_length)
        plaintext = ''.join(random.choices(ALPHABET, k=length))
        ciphertext = vignere_cipher(plaintext, KEYWORD)

        x = []
        for char in plaintext:
            one_hot = [0] * ALPHABET_SIZE
            one_hot[CHAR_TO_IDX[char]] = 1
            x.append(one_hot)
        x = torch.tensor(x, dtype=torch.float)

        edge_index = []
        for i in range(len(plaintext) - 1):
            edge_index.append([i, i + 1])
            edge_index.append([i + 1, i])  
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()

        y = torch.tensor([CHAR_TO_IDX[char] for char in ciphertext], dtype=torch.long)

        data = Data(x=x, edge_index=edge_index, y=y, keyword=KEYWORD)
        data_list.append(data)
    return data_list

In [20]:
dataset = generate_data(num_samples=2000, max_length=10)
train_size = int(0.8 * len(dataset))
train_dataset = dataset[:train_size]
test_dataset = dataset[train_size:]

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)



In [21]:
class CipherGNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(CipherGNN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.lin = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.lin(x)
        return x

In [22]:
INPUT_DIM = ALPHABET_SIZE
HIDDEN_DIM = 64
OUTPUT_DIM = ALPHABET_SIZE
EPOCHS = 50
LEARNING_RATE = 0.01

In [23]:
model = CipherGNN(INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [24]:
for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch)
        out = out.view(-1, OUTPUT_DIM)
        y = batch.y.view(-1)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs
    avg_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch}, Loss: {avg_loss:.4f}")

Epoch 1, Loss: 2.9948
Epoch 2, Loss: 2.6062
Epoch 3, Loss: 2.5140
Epoch 4, Loss: 2.4650
Epoch 5, Loss: 2.4174
Epoch 6, Loss: 2.3740
Epoch 7, Loss: 2.3424
Epoch 8, Loss: 2.3189
Epoch 9, Loss: 2.2963
Epoch 10, Loss: 2.2597
Epoch 11, Loss: 2.2355
Epoch 12, Loss: 2.2192
Epoch 13, Loss: 2.1966
Epoch 14, Loss: 2.1787
Epoch 15, Loss: 2.1613
Epoch 16, Loss: 2.1416
Epoch 17, Loss: 2.1273
Epoch 18, Loss: 2.1216
Epoch 19, Loss: 2.0963
Epoch 20, Loss: 2.0895
Epoch 21, Loss: 2.0750
Epoch 22, Loss: 2.0610
Epoch 23, Loss: 2.0534
Epoch 24, Loss: 2.0353
Epoch 25, Loss: 2.0259
Epoch 26, Loss: 2.0076
Epoch 27, Loss: 2.0140
Epoch 28, Loss: 1.9956
Epoch 29, Loss: 1.9942
Epoch 30, Loss: 1.9740
Epoch 31, Loss: 1.9664
Epoch 32, Loss: 1.9550
Epoch 33, Loss: 1.9444
Epoch 34, Loss: 1.9380
Epoch 35, Loss: 1.9404
Epoch 36, Loss: 1.9290
Epoch 37, Loss: 1.9186
Epoch 38, Loss: 1.9036
Epoch 39, Loss: 1.8988
Epoch 40, Loss: 1.8844
Epoch 41, Loss: 1.8830
Epoch 42, Loss: 1.8831
Epoch 43, Loss: 1.8725
Epoch 44, Loss: 1.86

In [25]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in test_loader:
        out = model(batch)
        preds = out.argmax(dim=1)
        correct += (preds == batch.y).sum().item()
        total += batch.y.size(0)

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")


Test Accuracy: 20.73%


In [26]:
def edit_distance(s1, s2):
    m, n = len(s1), len(s2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]

    for i in range(m + 1):
        for j in range(n + 1):
            if i == 0:
                dp[i][j] = j  
            elif j == 0:
                dp[i][j] = i  
            elif s1[i - 1] == s2[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]  
            else:
                dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1])  

    return dp[m][n]

model.eval()
total_edit_distance = 0
total_samples = 0

with torch.no_grad():
    for batch in test_loader:
        out = model(batch)
        preds = out.argmax(dim=1)

        batch_size = batch.num_graphs
        pred_ciphertexts = []
        actual_ciphertexts = []

        for i in range(batch_size):
            start = batch.ptr[i]  
            end = batch.ptr[i + 1] 
            pred_chars = [IDX_TO_CHAR[preds[j].item()] for j in range(start, end)]
            actual_chars = [IDX_TO_CHAR[batch.y[j].item()] for j in range(start, end)]
            pred_ciphertexts.append("".join(pred_chars))
            actual_ciphertexts.append("".join(actual_chars))

        for pred, actual in zip(pred_ciphertexts, actual_ciphertexts):
            total_edit_distance += edit_distance(pred, actual)
            total_samples += 1

avg_edit_distance = total_edit_distance / total_samples
print(f"Average Edit Distance: {avg_edit_distance:.2f}")


Average Edit Distance: 4.38


In [27]:
def encrypt_with_model(model, plaintext):
    model.eval()
    x = []
    for char in plaintext.upper():
        one_hot = [0] * ALPHABET_SIZE
        if char in CHAR_TO_IDX:
            one_hot[CHAR_TO_IDX[char]] = 1
        x.append(one_hot)
    x = torch.tensor(x, dtype=torch.float).unsqueeze(0)  # Batch size 1

    edge_index = []
    for i in range(len(plaintext) - 1):
        edge_index.append([i, i + 1])
        edge_index.append([i + 1, i])
    if edge_index:
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().unsqueeze(0)
    else:
        edge_index = torch.empty((1, 2, 0), dtype=torch.long)

    data = Data(x=x.squeeze(0), edge_index=edge_index.squeeze(0))
    out = model(data)
    preds = out.argmax(dim=1)
    ciphertext = ''.join([IDX_TO_CHAR[p.item()] for p in preds])
    return ciphertext

In [29]:
plaintext = "HEANDKE"
predicted_ciphertext = encrypt_with_model(model,plaintext)
actual_ciphertext = vignere_cipher(plaintext)
print(f"Plaintext: {plaintext}")
print(f"Predicted Ciphertext: {predicted_ciphertext}")
print(f"Actual Ciphertext: {actual_ciphertext}")

Plaintext: HEANDKE
Predicted Ciphertext: JGCBBBS
Actual Ciphertext: JVYCWYG
