In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import random

# === Config ===
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
VOCAB = ['<PAD>', '<UNK>'] + [
    'party', 'election', 'government', 'youth', 'change', 'freedom', 'nation',
    'leader', 'policy', 'future', 'supporters', 'progressive', 'reform', 'economy',
    'rights', 'development', 'agenda', 'unity', 'growth', 'platform', 'claims',
    'announced', 'speech', 'today', 'social', 'media', 'president', 'prime',
    'minister', 'parliament', 'country', 'people', 'vote', 'public', 'justice',
    'campaign', 'fake', 'truth', 'biased', 'controversial', 'statement', 'agenda',
    'democracy', 'freedom', 'liberty', 'strong', 'support', 'critics', 'rally'
]
VOCAB_SIZE = len(VOCAB)
WORD2IDX = {w: i for i, w in enumerate(VOCAB)}
IDX2WORD = {i: w for w, i in WORD2IDX.items()}
SEQ_LEN = 50
NOISE_DIM = 100
HIDDEN_DIM = 128

# === Tokenization Utilities ===
def tokenize(text):
    return [WORD2IDX.get(w.lower(), WORD2IDX['<UNK>']) for w in text.split()][:SEQ_LEN]

def detokenize(tokens):
    return ' '.join([IDX2WORD.get(tok, '<UNK>') for tok in tokens])

def pad_sequence(seq):
    return seq + [0] * (SEQ_LEN - len(seq))

# === Generator ===
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.rnn = nn.LSTM(NOISE_DIM, HIDDEN_DIM, batch_first=True)
        self.fc = nn.Linear(HIDDEN_DIM, VOCAB_SIZE)

    def forward(self, noise):
        out, _ = self.rnn(noise)
        logits = self.fc(out)
        return logits

# === Discriminator with Attention ===
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB_SIZE, HIDDEN_DIM)
        self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, batch_first=True)
        self.attn = nn.Linear(HIDDEN_DIM, 1)
        self.fc = nn.Linear(HIDDEN_DIM, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        emb = self.embed(x)
        rnn_out, _ = self.rnn(emb)
        weights = torch.softmax(self.attn(rnn_out), dim=1)
        context = torch.sum(weights * rnn_out, dim=1)
        return self.sigmoid(self.fc(context))

# === Real Fake News Examples ===
REAL_TEXTS = [
    "Party A is a progressive party for the nation and good for youth. They believe in future reforms. Their leader spoke at a public rally today. Supporters are excited about the agenda. Critics call it a biased approach.",
    "The new economic policy supports development. Government claims it benefits people. Prime minister announced it at parliament. Social media reacts with mixed opinions. Experts question the plan’s long-term impact.",
    "Leaders of Party B talk about freedom and justice. Campaign speeches emphasize youth empowerment. Support grows across regions. Opposition says it's just a media stunt. Political analysts debate its authenticity.",
    "Public support for new government reforms is rising. Parliament sessions see heated debates. President calls for unity and progress. Fake news spreads confusion among citizens. Experts urge verification of information.",
    "Recent rally promoted a controversial statement. Social media amplifies biased views. Youth are targeted with emotional slogans. Critics warn of misinformation tactics. The campaign faces ethical scrutiny."
]
REAL_TOKENS = [torch.tensor(pad_sequence(tokenize(text)), dtype=torch.long) for text in REAL_TEXTS]

# === Training ===
def train(generator, discriminator, epochs=3000):
    criterion = nn.BCELoss()
    g_opt = optim.Adam(generator.parameters(), lr=0.001)
    d_opt = optim.Adam(discriminator.parameters(), lr=0.001)

    for epoch in range(epochs):
        # === Discriminator ===
        discriminator.zero_grad()
        real_data = torch.stack(random.choices(REAL_TOKENS, k=16)).to(DEVICE)
        real_labels = torch.ones(real_data.size(0), 1).to(DEVICE)

        noise = torch.randn(real_data.size(0), SEQ_LEN, NOISE_DIM).to(DEVICE)
        fake_logits = generator(noise)
        fake_data = torch.argmax(fake_logits, dim=2).detach()
        fake_labels = torch.zeros(real_data.size(0), 1).to(DEVICE)

        d_loss_real = criterion(discriminator(real_data), real_labels)
        d_loss_fake = criterion(discriminator(fake_data), fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_opt.step()

        # === Generator ===
        generator.zero_grad()
        noise = torch.randn(real_data.size(0), SEQ_LEN, NOISE_DIM).to(DEVICE)
        fake_logits = generator(noise)
        fake_data = torch.argmax(fake_logits, dim=2)
        output = discriminator(fake_data)
        g_loss = criterion(output, torch.ones(real_data.size(0), 1).to(DEVICE))
        g_loss.backward()
        g_opt.step()

        # === Output ===
        if epoch % 500 == 0:
            fake_sample = fake_data[0].cpu().tolist()
            fake_text = detokenize(fake_sample)
            lines = [f"{i+1}. {line.strip().capitalize()}" for i, line in enumerate(fake_text.split('.')[:5])]
            print(f"\n[Epoch {epoch}] D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}")
            print("Generated Fake Article:")
            print('\n'.join(lines))

# === Run Training ===
gen = Generator().to(DEVICE)
disc = Discriminator().to(DEVICE)

train(gen, disc)



[Epoch 0] D_loss: 1.3845, G_loss: 0.7102
Generated Fake Article:
1. Today development platform public policy democracy announced growth growth nation nation reform reform nation election minister vote growth growth vote growth vote reform reform today country country country country country social democracy platform platform reform policy vote vote nation vote people minister public growth growth country social party agenda announced

[Epoch 500] D_loss: 0.0003, G_loss: 9.1002
Generated Fake Article:
1. <unk> democracy liberty public minister vote election minister social controversial social reform liberty supporters country reform policy vote country minister minister country justice vote vote vote vote agenda development minister youth country nation nation agenda controversial public country democracy country minister minister controversial controversial controversial country vote youth liberty vote

[Epoch 1000] D_loss: 0.0001, G_loss: 10.2003
Generated Fake Article:
1. People st