In [15]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, BertForSequenceClassification
from torch import nn, optim
from sklearn.model_selection import train_test_split
import pandas as pd
from sklearn.preprocessing import LabelEncoder

In [16]:
df = pd.read_csv('../data/data.csv')
# Fail fast with a readable message if the CSV lacks a column used below.
missing_cols = {'plain_text', 'encrypted_text', 'key'} - set(df.columns)
assert not missing_cols, f"data.csv is missing columns: {sorted(missing_cols)}"
# Join plaintext and ciphertext with the same "[SEP]" separator that the
# training/test texts and predict() calls use, so this column matches what
# the model actually sees.
df['input'] = df['plain_text'] + "[SEP]" + df['encrypted_text']
df['output'] = df['key']  # the target is the key

In [17]:
# Encode every distinct key string as an integer class id (the classifier target);
# label_encoder keeps the mapping so predictions can be decoded back to keys.
label_encoder = LabelEncoder()
encoded_keys = label_encoder.fit_transform(df['key'])
df['key_encoded'] = encoded_keys

In [27]:
class EncryptionDataset(Dataset):
    """Map-style dataset that tokenizes input strings for a BERT classifier.

    Each item is a dict with the raw ``text``, the tokenizer's ``input_ids`` and
    ``attention_mask`` flattened to 1-D tensors, and the class id as a
    ``torch.long`` tensor under ``labels``.

    Args:
        texts: Sequence of input strings (list, tuple, pandas Series, ...).
        labels: Sequence of integer class ids, aligned position-wise with ``texts``.
        tokenizer: Object exposing a Hugging Face style ``encode_plus`` method.
        max_len: Length every example is padded / truncated to.

    Raises:
        ValueError: if ``texts`` and ``labels`` have different lengths.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        # Materialize as lists so __getitem__ indexes by POSITION. A pandas Series
        # taken from a train_test_split frame keeps its shuffled index, and
        # ``series[item]`` would then look up by label and return the wrong row.
        self.texts = list(texts)
        self.labels = list(labels)
        if len(self.texts) != len(self.labels):
            raise ValueError(
                f"texts and labels differ in length: {len(self.texts)} != {len(self.labels)}"
            )
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        text = str(self.texts[item])
        label = self.labels[item]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            # Without truncation, inputs longer than max_len are not cut and the
            # batch tensors stop having a uniform length.
            truncation=True,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            'text': text,
            # return_tensors='pt' yields a (1, length) batch; flatten to 1-D so the
            # DataLoader's default collate stacks examples into (batch, length).
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }


In [28]:
# Model definition
class EncryptionModel(nn.Module):
    """Pretrained 'bert-base-uncased' encoder + dropout + linear head over ``n_classes``.

    The attribute names ``bert``, ``drop`` and ``out`` are the prefixes of the
    state_dict keys written by ``torch.save(model.state_dict(), ...)``; renaming
    them would stop previously saved checkpoints from loading.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.drop = nn.Dropout(p=0.3)
        hidden_size = self.bert.config.hidden_size
        self.out = nn.Linear(hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return raw class logits computed from BERT's pooled output."""
        # With return_dict=False the encoder returns a tuple; the second element
        # is the pooled output that feeds the classification head.
        _sequence_output, pooled = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        dropped = self.drop(pooled)
        return self.out(dropped)

# Tokenizer and dataset preparation
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
max_len = 128
SEED = 42  # fixes the train/test split and the training-batch order across re-runs


def _pair_texts(frame):
    """Return one ``plain_text + "[SEP]" + encrypted_text`` string per row of ``frame``."""
    return (frame['plain_text'] + "[SEP]" + frame['encrypted_text']).tolist()


# random_state makes Restart & Run All reproduce the same split (and test accuracy).
train_df, test_df = train_test_split(df, test_size=0.2, random_state=SEED)
train_texts = _pair_texts(train_df)
train_labels = train_df['key_encoded'].tolist()
test_texts = _pair_texts(test_df)
test_labels = test_df['key_encoded'].tolist()

train_dataset = EncryptionDataset(train_texts, train_labels, tokenizer, max_len)
test_dataset = EncryptionDataset(test_texts, test_labels, tokenizer, max_len)
# Shuffle training examples every epoch instead of feeding them in CSV order;
# the seeded generator keeps that order reproducible. Test order does not
# affect accuracy, so the test loader stays sequential.
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
test_loader = DataLoader(test_dataset, batch_size=16)


In [29]:
# Instantiate the model.
# Seed torch first so the randomly initialized head (self.out) and the dropout
# masks are the same on every Restart & Run All.
torch.manual_seed(42)
model = EncryptionModel(n_classes=len(label_encoder.classes_))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=2e-5)

# Training loop
num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    n_batches = 0
    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1

    # Report the mean training loss so convergence (or the lack of it) is visible;
    # guard against an empty loader instead of dividing by zero.
    mean_loss = running_loss / n_batches if n_batches else float('nan')
    print(f'Epoch {epoch + 1}/{num_epochs} completed. mean train loss: {mean_loss:.4f}')


Epoch 1/3 completed.
Epoch 2/3 completed.
Epoch 3/3 completed.


In [30]:
# Evaluation function
def evaluate(model, data_loader, device=None):
    """Return the classification accuracy of ``model`` over all of ``data_loader``.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that returns
            logits with the class dimension at dim 1.
        data_loader: Iterable of batch dicts with ``input_ids``, ``attention_mask``
            and ``labels`` tensors; must expose ``.dataset`` (the example count is
            ``len(data_loader.dataset)``).
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU if it has none), so the function does not
            depend on a notebook-global ``device``.

    Returns:
        0-d float64 tensor equal to correct / total (callers use ``.item()``).

    Raises:
        ValueError: if the dataset is empty (accuracy is undefined).
    """
    total = len(data_loader.dataset)
    if total == 0:
        raise ValueError("evaluate() called with an empty dataset")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            logits = model(input_ids, attention_mask)
            preds = logits.argmax(dim=1)
            # Accumulate as a Python int; the original int-then-tensor mix broke on empty loaders.
            correct += (preds == labels).sum().item()

    return torch.tensor(correct, dtype=torch.float64) / total

# Evaluate the trained model on the held-out split.
# NOTE(review): an accuracy of 1.0 on a key-recovery task is suspicious; check
# whether the same keys / pairs occur in both train and test before trusting it.
test_accuracy = evaluate(model, test_loader)
accuracy = test_accuracy.item()
print(f'Test Accuracy: {accuracy}')


Test Accuracy: 1.0


In [32]:
# Persist the trained artifacts to the working directory.
import joblib

# Weights only (state_dict), not the pickled module object.
torch.save(model.state_dict(), 'encryption_model.pth')

# The fitted label encoder is needed at prediction time to map class ids back
# to key strings. Kept as the last expression so the cell shows the written path.
joblib.dump(label_encoder, 'label_encoder.joblib')


['label_encoder.joblib']

In [34]:
# Prediction
def predict(text, model, tokenizer, label_encoder, max_len, device=None):
    """Predict the label (key string) for a single input text.

    Args:
        text: Input string, formatted like the training texts
            (``plain_text + "[SEP]" + encrypted_text``).
        model: Module called as ``model(input_ids, attention_mask)`` returning
            logits with the class dimension at dim 1.
        tokenizer: Hugging Face style tokenizer exposing ``encode_plus``.
        label_encoder: Fitted encoder whose ``inverse_transform`` maps class ids
            back to labels.
        max_len: Length the input is padded / truncated to (should match training).
        device: Device for the input tensors. Defaults to the device of the
            model's first parameter (CPU if it has none), removing the dependency
            on a notebook-global ``device``.

    Returns:
        The decoded label of the highest-scoring class.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_len,
        truncation=True,  # same as EncryptionDataset: never exceed max_len
        return_token_type_ids=False,
        padding='max_length',
        return_attention_mask=True,
        return_tensors='pt',
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
    pred_ids = logits.argmax(dim=1)

    return label_encoder.inverse_transform(pred_ids.cpu().numpy())[0]

text = "hello"
encrypted_text = "hfnos"
predicted_key = predict(text + "[SEP]" + encrypted_text, model, tokenizer, label_encoder, max_len)
print(predicted_key)
predicted_key_before_reload = predicted_key

# Reload the label encoder FIRST: the reloaded model's n_classes must come from
# the saved encoder, not from whatever label_encoder is still in kernel memory.
# NOTE: joblib.load unpickles arbitrary objects; only load files you created / trust.
label_encoder = joblib.load('label_encoder.joblib')

# Reload the model. map_location lets a checkpoint saved on a GPU be restored
# on a CPU-only machine (and vice versa).
model = EncryptionModel(n_classes=len(label_encoder.classes_))
model.load_state_dict(torch.load('encryption_model.pth', map_location=device))
model.to(device)

# Predict again with the reloaded artifacts; the round trip must agree.
predicted_key = predict(text + "[SEP]" + encrypted_text, model, tokenizer, label_encoder, max_len)
print(predicted_key)
assert predicted_key == predicted_key_before_reload, "reloaded model/encoder disagree with the in-memory ones"


bcd0b0b1580b3e23786af6f0dfdc0097
bcd0b0b1580b3e23786af6f0dfdc0097


In [82]:
from Crypto.Cipher import AES
from Crypto.Random import get_random_bytes
import base64

# Encrypt one random 16-byte block with AES-CBC and ask the model for its key.
plain_text = get_random_bytes(16)  # 16 bytes = 128 bits
key = get_random_bytes(16)
cipher = AES.new(key, AES.MODE_CBC)
# NOTE(review): AES.new draws a random IV (cipher.iv) that the model never sees.
# For a single CBC block, every key k is consistent with (plain, cipher) under
# IV = D_k(cipher) XOR plain, so the pair alone cannot identify the key.
encrypted_text = cipher.encrypt(plain_text)
print(plain_text.hex(), encrypted_text.hex())

# Predict on THIS pair. Previously the stale prediction from the "hello"/"hfnos"
# cell was printed and compared.
# NOTE(review): assumes data.csv stores plain/cipher texts as lowercase hex. TODO confirm.
predicted_key = predict(plain_text.hex() + "[SEP]" + encrypted_text.hex(), model, tokenizer, label_encoder, max_len)
print("key: ", key.hex())
print("predicted_key: ", predicted_key)
# Keys are hex strings (see the printed predictions), so compare against key.hex().
# A base64 encoding can never equal a hex string.
print(predicted_key == key.hex())

dbf84bc672834fbccbdd0d3614445824 6bee34d8c9e50acac8620282263ac658
key:  61dc9098e95b8c4511996769c991b572
predicted_key:  6ed42cc64ec6fc29c0d839dad713467b
False
