In [1]:
# Run once
#%pip install transformers accelerate datasets

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from tqdm import tqdm

class ConfigDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        input_text = f"Input: {item['input']}\nOutput: {item['output']}"
        encoding = self.tokenizer(input_text, truncation=True, padding='max_length', max_length=self.max_length, return_tensors="pt")
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze()
        }

In [3]:
# Your provided dataset
data = [
    {
        "input": "Configure the eth28 ethernet device with the static IPv4 address 232.162.200.174/25",
        "output": "---\ninterfaces:\n- name: eth28\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 232.162.200.174\n      prefix-length: 25"
    },
    {
        "input": "Set the eth1 ethernet device with the IPv4 address 192.168.1.1/24",
        "output": "---\ninterfaces:\n- name: eth1\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 192.168.1.1\n      prefix-length: 24"
    },
    {
        "input": "Configure the eth28 ethernet device with the IPv4 address 232.162.200.174/25",
        "output": "---\ninterfaces:\n- name: eth28\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 232.162.200.174\n      prefix-length: 25"
    },
    {
        "input": "Set the eth1 ethernet device with the static IPv4 address 192.168.1.1/24",
        "output": "---\ninterfaces:\n- name: eth1\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 192.168.1.1\n      prefix-length: 24"
    },
    {
        "input": "Assign the eth0 ethernet device with the IPv4 address 10.0.0.1/8",
        "output": "---\ninterfaces:\n- name: eth0\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 10.0.0.1\n      prefix-length: 8"
    },
    {
        "input": "Configure the eth2 ethernet device with the IPv4 address 172.16.0.1/16",
        "output": "---\ninterfaces:\n- name: eth2\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 172.16.0.1\n      prefix-length: 16"
    },
    {
        "input": "Set the eth3 ethernet device with the IPv4 address 192.168.100.1/24",
        "output": "---\ninterfaces:\n- name: eth3\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 192.168.100.1\n      prefix-length: 24"
    },
    {
        "input": "Assign the eth4 ethernet device with the IPv4 address 10.10.10.1/24",
        "output": "---\ninterfaces:\n- name: eth4\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 10.10.10.1\n      prefix-length: 24"
    },
    {
        "input": "Configure the eth5 ethernet device with the static IPv4 address 172.31.255.255/12",
        "output": "---\ninterfaces:\n- name: eth5\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 172.31.255.255\n      prefix-length: 12"
    },
    {
        "input": "Set the eth6 ethernet device with the IPv4 address 192.0.2.1/24",
        "output": "---\ninterfaces:\n- name: eth6\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 192.0.2.1\n      prefix-length: 24"
    },
    {
        "input": "Assign the eth7 ethernet device with the IPv4 address 203.0.113.1/24",
        "output": "---\ninterfaces:\n- name: eth7\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 203.0.113.1\n      prefix-length: 24"
    },
    {
        "input": "Configure the eth8 ethernet device with the IPv4 address 198.51.100.1/24",
        "output": "---\ninterfaces:\n- name: eth8\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 198.51.100.1\n      prefix-length: 24"
    },
    {
        "input": "Set up the eth9 ethernet device with the IPv4 address 10.1.1.1/24",
        "output": "---\ninterfaces:\n- name: eth9\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 10.1.1.1\n      prefix-length: 24"
    },
    {
        "input": "Configure the eth10 ethernet device with the static IPv4 address 192.168.200.1/29",
        "output": "---\ninterfaces:\n- name: eth10\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 192.168.200.1\n      prefix-length: 29"
    },
    {
        "input": "Assign the eth11 ethernet device with the IPv4 address 172.20.0.1/20",
        "output": "---\ninterfaces:\n- name: eth11\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 172.20.0.1\n      prefix-length: 20"
    },
    {
        "input": "Set the eth12 ethernet device with the IPv4 address 10.5.5.1/16",
        "output": "---\ninterfaces:\n- name: eth12\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 10.5.5.1\n      prefix-length: 16"
    },
    {
        "input": "Configure the eth13 ethernet device with the IPv4 address 192.168.50.1/25",
        "output": "---\ninterfaces:\n- name: eth13\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 192.168.50.1\n      prefix-length: 25"
    },
    {
        "input": "Assign the eth14 ethernet device with the IPv4 address 172.31.10.1/24",
        "output": "---\ninterfaces:\n- name: eth14\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 172.31.10.1\n      prefix-length: 24"
    },
    {
        "input": "Set the eth15 ethernet device with the IPv4 address 10.255.255.1/8",
        "output": "---\ninterfaces:\n- name: eth15\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 10.255.255.1\n      prefix-length: 8"
    },
    {
        "input": "Configure the eth16 ethernet device with the IPv4 address 198.51.100.100/24",
        "output": "---\ninterfaces:\n- name: eth16\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 198.51.100.100\n      prefix-length: 24"
    },
    {
        "input": "Assign the eth17 ethernet device with the IPv4 address 192.168.1.100/24",
        "output": "---\ninterfaces:\n- name: eth17\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 192.168.1.100\n      prefix-length: 24"
    },
    {
        "input": "Set the eth18 ethernet device with the static IPv4 address 10.10.1.1/16",
        "output": "---\ninterfaces:\n- name: eth18\n  type: ethernet\n  state: up\n  ipv4:\n    enabled: true\n    dhcp: false\n    address:\n    - ip: 10.10.1.1\n      prefix-length: 16"
    },
]

In [4]:
# Split the data
train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)

# Initialize tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.config.pad_token_id = model.config.eos_token_id

# Create datasets and dataloaders
train_dataset = ConfigDataset(train_data, tokenizer)
val_dataset = ConfigDataset(val_data, tokenizer)

batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Training settings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
model.to(device)

optimizer = AdamW(model.parameters(), lr=2e-5)
num_epochs = 145  # Increased number of epochs
num_training_steps = num_epochs * len(train_loader)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=num_training_steps)

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    print(f"Average train loss: {avg_train_loss:.4f}")

    # Validation
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, labels=input_ids)
            loss = outputs.loss

            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / len(val_loader)
    print(f"Average validation loss: {avg_val_loss:.4f}")

Epoch 1/145: 100%|██████████| 3/3 [00:00<00:00,  3.62it/s]


Average train loss: 8.0626
Average validation loss: 9.3394


Epoch 2/145: 100%|██████████| 3/3 [00:00<00:00,  4.73it/s]


Average train loss: 7.9450
Average validation loss: 9.0176


Epoch 3/145: 100%|██████████| 3/3 [00:00<00:00,  4.75it/s]


Average train loss: 7.6422
Average validation loss: 8.4304


Epoch 4/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 6.9616
Average validation loss: 7.5629


Epoch 5/145: 100%|██████████| 3/3 [00:00<00:00,  4.68it/s]


Average train loss: 6.2125
Average validation loss: 6.4063


Epoch 6/145: 100%|██████████| 3/3 [00:00<00:00,  4.73it/s]


Average train loss: 5.3486
Average validation loss: 5.0419


Epoch 7/145: 100%|██████████| 3/3 [00:00<00:00,  4.71it/s]


Average train loss: 4.3564
Average validation loss: 3.6478


Epoch 8/145: 100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


Average train loss: 3.2041
Average validation loss: 2.3598


Epoch 9/145: 100%|██████████| 3/3 [00:00<00:00,  4.75it/s]


Average train loss: 2.2121
Average validation loss: 1.3623


Epoch 10/145: 100%|██████████| 3/3 [00:00<00:00,  4.73it/s]


Average train loss: 1.3649
Average validation loss: 0.7984


Epoch 11/145: 100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


Average train loss: 0.9087
Average validation loss: 0.6126


Epoch 12/145: 100%|██████████| 3/3 [00:00<00:00,  4.71it/s]


Average train loss: 0.7126
Average validation loss: 0.5970


Epoch 13/145: 100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


Average train loss: 0.6888
Average validation loss: 0.6283


Epoch 14/145: 100%|██████████| 3/3 [00:00<00:00,  4.70it/s]


Average train loss: 0.7369
Average validation loss: 0.6573


Epoch 15/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 0.7429
Average validation loss: 0.6642


Epoch 16/145: 100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


Average train loss: 0.7448
Average validation loss: 0.6476


Epoch 17/145: 100%|██████████| 3/3 [00:00<00:00,  4.74it/s]


Average train loss: 0.7303
Average validation loss: 0.6132


Epoch 18/145: 100%|██████████| 3/3 [00:00<00:00,  4.73it/s]


Average train loss: 0.6731
Average validation loss: 0.5735


Epoch 19/145: 100%|██████████| 3/3 [00:00<00:00,  4.73it/s]


Average train loss: 0.6555
Average validation loss: 0.5392


Epoch 20/145: 100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


Average train loss: 0.6141
Average validation loss: 0.5128


Epoch 21/145: 100%|██████████| 3/3 [00:00<00:00,  4.71it/s]


Average train loss: 0.5753
Average validation loss: 0.4930


Epoch 22/145: 100%|██████████| 3/3 [00:00<00:00,  4.68it/s]


Average train loss: 0.5588
Average validation loss: 0.4779


Epoch 23/145: 100%|██████████| 3/3 [00:00<00:00,  4.73it/s]


Average train loss: 0.5499
Average validation loss: 0.4655


Epoch 24/145: 100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


Average train loss: 0.5305
Average validation loss: 0.4543


Epoch 25/145: 100%|██████████| 3/3 [00:00<00:00,  4.73it/s]


Average train loss: 0.5228
Average validation loss: 0.4432


Epoch 26/145: 100%|██████████| 3/3 [00:00<00:00,  4.73it/s]


Average train loss: 0.5037
Average validation loss: 0.4320


Epoch 27/145: 100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


Average train loss: 0.4795
Average validation loss: 0.4208


Epoch 28/145: 100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


Average train loss: 0.4779
Average validation loss: 0.4093


Epoch 29/145: 100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


Average train loss: 0.4562
Average validation loss: 0.3974


Epoch 30/145: 100%|██████████| 3/3 [00:00<00:00,  4.73it/s]


Average train loss: 0.4644
Average validation loss: 0.3854


Epoch 31/145: 100%|██████████| 3/3 [00:00<00:00,  4.73it/s]


Average train loss: 0.4267
Average validation loss: 0.3731


Epoch 32/145: 100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


Average train loss: 0.4232
Average validation loss: 0.3605


Epoch 33/145: 100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


Average train loss: 0.4136
Average validation loss: 0.3475


Epoch 34/145: 100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


Average train loss: 0.4069
Average validation loss: 0.3343


Epoch 35/145: 100%|██████████| 3/3 [00:00<00:00,  4.70it/s]


Average train loss: 0.3728
Average validation loss: 0.3210


Epoch 36/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.3684
Average validation loss: 0.3081


Epoch 37/145: 100%|██████████| 3/3 [00:00<00:00,  4.71it/s]


Average train loss: 0.3655
Average validation loss: 0.2955


Epoch 38/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 0.3304
Average validation loss: 0.2833


Epoch 39/145: 100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


Average train loss: 0.3193
Average validation loss: 0.2717


Epoch 40/145: 100%|██████████| 3/3 [00:00<00:00,  4.71it/s]


Average train loss: 0.3060
Average validation loss: 0.2607


Epoch 41/145: 100%|██████████| 3/3 [00:00<00:00,  4.71it/s]


Average train loss: 0.3015
Average validation loss: 0.2502


Epoch 42/145: 100%|██████████| 3/3 [00:00<00:00,  4.71it/s]


Average train loss: 0.2808
Average validation loss: 0.2403


Epoch 43/145: 100%|██████████| 3/3 [00:00<00:00,  4.71it/s]


Average train loss: 0.2710
Average validation loss: 0.2307


Epoch 44/145: 100%|██████████| 3/3 [00:00<00:00,  4.70it/s]


Average train loss: 0.2679
Average validation loss: 0.2213


Epoch 45/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 0.2649
Average validation loss: 0.2119


Epoch 46/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.2477
Average validation loss: 0.2027


Epoch 47/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 0.2501
Average validation loss: 0.1938


Epoch 48/145: 100%|██████████| 3/3 [00:00<00:00,  4.68it/s]


Average train loss: 0.2424
Average validation loss: 0.1853


Epoch 49/145: 100%|██████████| 3/3 [00:00<00:00,  4.70it/s]


Average train loss: 0.2475
Average validation loss: 0.1772


Epoch 50/145: 100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


Average train loss: 0.2208
Average validation loss: 0.1698


Epoch 51/145: 100%|██████████| 3/3 [00:00<00:00,  4.71it/s]


Average train loss: 0.2215
Average validation loss: 0.1629


Epoch 52/145: 100%|██████████| 3/3 [00:00<00:00,  4.71it/s]


Average train loss: 0.2003
Average validation loss: 0.1571


Epoch 53/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 0.1881
Average validation loss: 0.1511


Epoch 54/145: 100%|██████████| 3/3 [00:00<00:00,  4.70it/s]


Average train loss: 0.1825
Average validation loss: 0.1448


Epoch 55/145: 100%|██████████| 3/3 [00:00<00:00,  4.70it/s]


Average train loss: 0.1802
Average validation loss: 0.1386


Epoch 56/145: 100%|██████████| 3/3 [00:00<00:00,  4.71it/s]


Average train loss: 0.1730
Average validation loss: 0.1334


Epoch 57/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 0.1584
Average validation loss: 0.1288


Epoch 58/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 0.2008
Average validation loss: 0.1243


Epoch 59/145: 100%|██████████| 3/3 [00:00<00:00,  4.70it/s]


Average train loss: 0.1601
Average validation loss: 0.1200


Epoch 60/145: 100%|██████████| 3/3 [00:00<00:00,  4.70it/s]


Average train loss: 0.1607
Average validation loss: 0.1165


Epoch 61/145: 100%|██████████| 3/3 [00:00<00:00,  4.70it/s]


Average train loss: 0.1464
Average validation loss: 0.1135


Epoch 62/145: 100%|██████████| 3/3 [00:00<00:00,  4.71it/s]


Average train loss: 0.1556
Average validation loss: 0.1106


Epoch 63/145: 100%|██████████| 3/3 [00:00<00:00,  4.66it/s]


Average train loss: 0.1329
Average validation loss: 0.1068


Epoch 64/145: 100%|██████████| 3/3 [00:00<00:00,  4.65it/s]


Average train loss: 0.1429
Average validation loss: 0.1022


Epoch 65/145: 100%|██████████| 3/3 [00:00<00:00,  4.66it/s]


Average train loss: 0.1321
Average validation loss: 0.0982


Epoch 66/145: 100%|██████████| 3/3 [00:00<00:00,  4.64it/s]


Average train loss: 0.1321
Average validation loss: 0.0949


Epoch 67/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 0.1239
Average validation loss: 0.0920


Epoch 68/145: 100%|██████████| 3/3 [00:00<00:00,  4.70it/s]


Average train loss: 0.1286
Average validation loss: 0.0898


Epoch 69/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 0.1159
Average validation loss: 0.0873


Epoch 70/145: 100%|██████████| 3/3 [00:00<00:00,  4.71it/s]


Average train loss: 0.1100
Average validation loss: 0.0848


Epoch 71/145: 100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


Average train loss: 0.1027
Average validation loss: 0.0828


Epoch 72/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 0.1046
Average validation loss: 0.0810


Epoch 73/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 0.1198
Average validation loss: 0.0791


Epoch 74/145: 100%|██████████| 3/3 [00:00<00:00,  4.68it/s]


Average train loss: 0.1022
Average validation loss: 0.0772


Epoch 75/145: 100%|██████████| 3/3 [00:00<00:00,  4.68it/s]


Average train loss: 0.0981
Average validation loss: 0.0758


Epoch 76/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.1006
Average validation loss: 0.0749


Epoch 77/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0957
Average validation loss: 0.0741


Epoch 78/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 0.0886
Average validation loss: 0.0725


Epoch 79/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0891
Average validation loss: 0.0708


Epoch 80/145: 100%|██████████| 3/3 [00:00<00:00,  4.70it/s]


Average train loss: 0.0875
Average validation loss: 0.0696


Epoch 81/145: 100%|██████████| 3/3 [00:00<00:00,  4.68it/s]


Average train loss: 0.0883
Average validation loss: 0.0686


Epoch 82/145: 100%|██████████| 3/3 [00:00<00:00,  4.70it/s]


Average train loss: 0.0860
Average validation loss: 0.0678


Epoch 83/145: 100%|██████████| 3/3 [00:00<00:00,  4.66it/s]


Average train loss: 0.0789
Average validation loss: 0.0673


Epoch 84/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0814
Average validation loss: 0.0675


Epoch 85/145: 100%|██████████| 3/3 [00:00<00:00,  4.64it/s]


Average train loss: 0.0793
Average validation loss: 0.0668


Epoch 86/145: 100%|██████████| 3/3 [00:00<00:00,  4.68it/s]


Average train loss: 0.0770
Average validation loss: 0.0659


Epoch 87/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0840
Average validation loss: 0.0654


Epoch 88/145: 100%|██████████| 3/3 [00:00<00:00,  4.66it/s]


Average train loss: 0.0777
Average validation loss: 0.0651


Epoch 89/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0799
Average validation loss: 0.0649


Epoch 90/145: 100%|██████████| 3/3 [00:00<00:00,  4.65it/s]


Average train loss: 0.0674
Average validation loss: 0.0643


Epoch 91/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0734
Average validation loss: 0.0639


Epoch 92/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 0.0729
Average validation loss: 0.0633


Epoch 93/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 0.0687
Average validation loss: 0.0629


Epoch 94/145: 100%|██████████| 3/3 [00:00<00:00,  4.66it/s]


Average train loss: 0.0746
Average validation loss: 0.0629


Epoch 95/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0648
Average validation loss: 0.0634


Epoch 96/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0693
Average validation loss: 0.0637


Epoch 97/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0611
Average validation loss: 0.0641


Epoch 98/145: 100%|██████████| 3/3 [00:00<00:00,  4.66it/s]


Average train loss: 0.0655
Average validation loss: 0.0638


Epoch 99/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0637
Average validation loss: 0.0631


Epoch 100/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0785
Average validation loss: 0.0621


Epoch 101/145: 100%|██████████| 3/3 [00:00<00:00,  4.66it/s]


Average train loss: 0.0647
Average validation loss: 0.0610


Epoch 102/145: 100%|██████████| 3/3 [00:00<00:00,  4.63it/s]


Average train loss: 0.0699
Average validation loss: 0.0601


Epoch 103/145: 100%|██████████| 3/3 [00:00<00:00,  4.63it/s]


Average train loss: 0.0580
Average validation loss: 0.0597


Epoch 104/145: 100%|██████████| 3/3 [00:00<00:00,  4.66it/s]


Average train loss: 0.0624
Average validation loss: 0.0595


Epoch 105/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0656
Average validation loss: 0.0594


Epoch 106/145: 100%|██████████| 3/3 [00:00<00:00,  4.65it/s]


Average train loss: 0.0684
Average validation loss: 0.0596


Epoch 107/145: 100%|██████████| 3/3 [00:00<00:00,  4.68it/s]


Average train loss: 0.0594
Average validation loss: 0.0597


Epoch 108/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 0.0675
Average validation loss: 0.0600


Epoch 109/145: 100%|██████████| 3/3 [00:00<00:00,  4.70it/s]


Average train loss: 0.0639
Average validation loss: 0.0603


Epoch 110/145: 100%|██████████| 3/3 [00:00<00:00,  4.70it/s]


Average train loss: 0.0673
Average validation loss: 0.0606


Epoch 111/145: 100%|██████████| 3/3 [00:00<00:00,  4.68it/s]


Average train loss: 0.0620
Average validation loss: 0.0608


Epoch 112/145: 100%|██████████| 3/3 [00:00<00:00,  4.66it/s]


Average train loss: 0.0676
Average validation loss: 0.0608


Epoch 113/145: 100%|██████████| 3/3 [00:00<00:00,  4.65it/s]


Average train loss: 0.0577
Average validation loss: 0.0606


Epoch 114/145: 100%|██████████| 3/3 [00:00<00:00,  4.63it/s]


Average train loss: 0.0649
Average validation loss: 0.0604


Epoch 115/145: 100%|██████████| 3/3 [00:00<00:00,  4.68it/s]


Average train loss: 0.0566
Average validation loss: 0.0603


Epoch 116/145: 100%|██████████| 3/3 [00:00<00:00,  4.63it/s]


Average train loss: 0.0514
Average validation loss: 0.0601


Epoch 117/145: 100%|██████████| 3/3 [00:00<00:00,  4.68it/s]


Average train loss: 0.0570
Average validation loss: 0.0599


Epoch 118/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 0.0661
Average validation loss: 0.0596


Epoch 119/145: 100%|██████████| 3/3 [00:00<00:00,  4.66it/s]


Average train loss: 0.0600
Average validation loss: 0.0595


Epoch 120/145: 100%|██████████| 3/3 [00:00<00:00,  4.64it/s]


Average train loss: 0.0692
Average validation loss: 0.0596


Epoch 121/145: 100%|██████████| 3/3 [00:00<00:00,  4.63it/s]


Average train loss: 0.0582
Average validation loss: 0.0596


Epoch 122/145: 100%|██████████| 3/3 [00:00<00:00,  4.62it/s]


Average train loss: 0.0562
Average validation loss: 0.0596


Epoch 123/145: 100%|██████████| 3/3 [00:00<00:00,  4.66it/s]


Average train loss: 0.0625
Average validation loss: 0.0596


Epoch 124/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0655
Average validation loss: 0.0595


Epoch 125/145: 100%|██████████| 3/3 [00:00<00:00,  4.68it/s]


Average train loss: 0.0617
Average validation loss: 0.0592


Epoch 126/145: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


Average train loss: 0.0626
Average validation loss: 0.0591


Epoch 127/145: 100%|██████████| 3/3 [00:00<00:00,  4.68it/s]


Average train loss: 0.0599
Average validation loss: 0.0591


Epoch 128/145: 100%|██████████| 3/3 [00:00<00:00,  4.68it/s]


Average train loss: 0.0568
Average validation loss: 0.0591


Epoch 129/145: 100%|██████████| 3/3 [00:00<00:00,  4.63it/s]


Average train loss: 0.0562
Average validation loss: 0.0591


Epoch 130/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0529
Average validation loss: 0.0593


Epoch 131/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0588
Average validation loss: 0.0595


Epoch 132/145: 100%|██████████| 3/3 [00:00<00:00,  4.66it/s]


Average train loss: 0.0594
Average validation loss: 0.0597


Epoch 133/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0567
Average validation loss: 0.0598


Epoch 134/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0515
Average validation loss: 0.0598


Epoch 135/145: 100%|██████████| 3/3 [00:00<00:00,  4.68it/s]


Average train loss: 0.0580
Average validation loss: 0.0598


Epoch 136/145: 100%|██████████| 3/3 [00:00<00:00,  4.64it/s]


Average train loss: 0.0565
Average validation loss: 0.0598


Epoch 137/145: 100%|██████████| 3/3 [00:00<00:00,  4.65it/s]


Average train loss: 0.0672
Average validation loss: 0.0597


Epoch 138/145: 100%|██████████| 3/3 [00:00<00:00,  4.62it/s]


Average train loss: 0.0574
Average validation loss: 0.0595


Epoch 139/145: 100%|██████████| 3/3 [00:00<00:00,  4.65it/s]


Average train loss: 0.0625
Average validation loss: 0.0592


Epoch 140/145: 100%|██████████| 3/3 [00:00<00:00,  4.66it/s]


Average train loss: 0.0576
Average validation loss: 0.0590


Epoch 141/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0552
Average validation loss: 0.0589


Epoch 142/145: 100%|██████████| 3/3 [00:00<00:00,  4.67it/s]


Average train loss: 0.0572
Average validation loss: 0.0587


Epoch 143/145: 100%|██████████| 3/3 [00:00<00:00,  4.61it/s]


Average train loss: 0.0553
Average validation loss: 0.0587


Epoch 144/145: 100%|██████████| 3/3 [00:00<00:00,  4.64it/s]


Average train loss: 0.0641
Average validation loss: 0.0586


Epoch 145/145: 100%|██████████| 3/3 [00:00<00:00,  4.65it/s]

Average train loss: 0.0524
Average validation loss: 0.0586





In [6]:
def generate_yaml(input_text):
    model.eval()
    input_text = f"Input: {input_text}\nOutput:"
    inputs = tokenizer.encode_plus(input_text, return_tensors='pt', padding=True, truncation=True, max_length=512)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    output = model.generate(
        input_ids, 
        attention_mask=attention_mask,
        max_length=200,
        num_return_sequences=1,
        no_repeat_ngram_size=2, 
        top_k=50,
        top_p=0.95,
        temperature=0.01,
        do_sample=True,
        pad_token_id=model.config.eos_token_id
    )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    yaml_output = generated_text.split("Output:")[1].strip()
    return yaml_output

# Test examples
test_inputs = [
    "Configure the eth3 ethernet device with the static IPv4 address 232.162.200.174/25",
    "Set the eth1 ethernet device with the static IPv4 address 192.168.1.1/24",
    "Assign the eth2 ethernet device with the IPv4 address 10.0.0.1/8",
]

for test_input in test_inputs:
    print(f"Input: {test_input}")
    print(f"Generated YAML:\n{generate_yaml(test_input)}")
    print()

Input: Configure the eth3 ethernet device with the static IPv4 address 232.162.200.174/25
Generated YAML:
---
interfaces:
- name: eth7
  type: ether
 state: up
 ipv4: false
 dhcp: true
 address:

Input: Set the eth1 ethernet device with the static IPv4 address 192.168.1.1/24
Generated YAML:
---
interfaces:
- name: eth0
  type: ether
 state: up
 ipv4: false
 dhcp: true
 address: 192:168:1:24

Input: Assign the eth2 ethernet device with the IPv4 address 10.0.0.1/8
Generated YAML:
---
interfaces:
- name: eth1
  type: ether
 state: up
 ipv4: false
 dhcp: true
 address: 10
    enabled: True
 prefix-length: 8

