In [None]:
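# Run once per environment: uncomment the line below to install the Hugging Face packages.
# Note: the notebook also imports torch, scikit-learn (sklearn) and tqdm, which this line does not install.
#%pip install transformers accelerate datasets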

In [None]:
import torch
import json
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from tqdm import tqdm

class ConfigDataset(Dataset):
    """Dataset that renders input/output records as fixed-length token tensors.

    Each record must provide the keys ``'input'`` and ``'output'``. A record is
    formatted as an ``Input: ...`` line followed by an ``Output: ...`` line,
    then tokenized with truncation and padding to ``max_length``.

    Args:
        data: Indexable collection of records (supports ``len`` and ``[]``).
        tokenizer: Callable tokenizer; it is invoked with ``truncation``,
            ``padding``, ``max_length`` and ``return_tensors="pt"`` and must
            return a mapping with ``'input_ids'`` and ``'attention_mask'``.
        max_length: Length every encoded example is padded/truncated to.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{'input_ids', 'attention_mask'}`` tensors for record ``idx``."""
        item = self.data[idx]
        input_text = f"Input: {item['input']}\nOutput: {item['output']}"
        encoding = self.tokenizer(
            input_text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors="pt",
        )
        # The tokenizer is called on a single text, so the result carries a
        # leading batch axis of size 1. squeeze(0) removes exactly that axis;
        # a bare squeeze() would also collapse the sequence axis whenever
        # max_length == 1 and hand the DataLoader 0-d tensors.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
        }

In [None]:
# Read the training records from the JSON file.
# The encoding is given explicitly so the result does not depend on the
# platform's default locale encoding.
with open('training_data.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

In [None]:
# One seed for everything stochastic in the notebook: the train/validation
# split below, and torch's global RNG (DataLoader shuffling, dropout, sampling).
SEED = 42
torch.manual_seed(SEED)

# Split the data (80% train / 20% validation)
train_data, val_data = train_test_split(data, test_size=0.2, random_state=SEED)

# Initialize tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# No dedicated padding token is configured here, so the end-of-text token is reused for padding
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.config.pad_token_id = model.config.eos_token_id

# Create datasets and dataloaders
train_dataset = ConfigDataset(train_data, tokenizer)
val_dataset = ConfigDataset(val_data, tokenizer)

batch_size = 9
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Training settings: use the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def _mask_padding_labels(input_ids, attention_mask):
    """Build language-modelling labels that ignore padding positions.

    Positions where ``attention_mask`` is 0 are set to -100, the label value
    the model's loss skips. The first padded position of each row is kept as a
    target: the tokenizer pads with the end-of-text token, so that position
    still teaches the model where an example ends.

    Assumes right padding (real tokens first, padding after). With left padding
    the restored position already holds a real token, so the call stays
    harmless but no end-of-text target is added.

    Args:
        input_ids: ``(batch, seq_len)`` tensor of token ids.
        attention_mask: ``(batch, seq_len)`` tensor, 1 for real tokens, 0 for padding.

    Returns:
        A new ``(batch, seq_len)`` label tensor; ``input_ids`` is not modified.
    """
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100
    lengths = attention_mask.sum(dim=1)
    # Rows that were truncated to the full length have no padded position to restore.
    rows = torch.nonzero(lengths < input_ids.size(1), as_tuple=True)[0]
    labels[rows, lengths[rows]] = input_ids[rows, lengths[rows]]
    return labels


model.to(device)

# torch.optim.AdamW stands in for the deprecated transformers.AdamW;
# eps and weight_decay are set to the defaults of the class it replaces.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, eps=1e-6, weight_decay=0.0)
num_epochs = 100  # Increased number of epochs
num_training_steps = num_epochs * len(train_loader)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=num_training_steps)

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Every example is padded to a fixed length; without masking, the
        # padding positions would dominate both the loss and the gradient.
        labels = _mask_padding_labels(input_ids, attention_mask)

        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # the schedule advances once per optimizer step

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    print(f"Average train loss: {avg_train_loss:.4f}")

    # Validation
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            # Same masking as in training so both losses are comparable
            labels = _mask_padding_labels(input_ids, attention_mask)

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / len(val_loader)
    print(f"Average validation loss: {avg_val_loss:.4f}")

In [None]:
def generate_yaml(input_text, max_new_tokens=200):
    """Generate the model's completion for a natural-language request.

    The request is wrapped in the same ``Input: ... / Output:`` template used
    for the training examples, and the model samples a continuation.

    Args:
        input_text: The request to complete.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt. The budget applies to the completion only, so a long
            prompt does not eat into it.

    Returns:
        The decoded completion with surrounding whitespace stripped. The
        prompt itself is never part of the result.
    """
    model.eval()
    prompt = f"Input: {input_text}\nOutput:"
    inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=512)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        top_k=10,
        top_p=0.95,
        temperature=0.5,
        do_sample=True,
        pad_token_id=model.config.eos_token_id,
    )

    # The returned sequence starts with the prompt tokens. Slicing them off by
    # position is robust even when the request text itself contains "Output:",
    # which splitting the decoded string on that marker is not.
    prompt_length = input_ids.shape[1]
    completion_ids = output[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

# Sample requests used to spot-check the fine-tuned model
test_inputs = [
    "Configure the eth63 bond device with the static IPv4 address 232.162.200.174/25",
    "Set the eth81 ethernet device with the static IPv4 address 192.168.1.1/24",
    "Assign the eth52 ethernet device with the IPv4 address 10.0.0.1/8",
]

# Echo each request, then show what the model generated for it
for test_input in test_inputs:
    print(f"Input: {test_input}")
    generated_yaml = generate_yaml(test_input)
    print(f"Generated YAML:\n{generated_yaml}")
    print()