# In [None]:
import os
import urllib.request
from urllib.parse import urljoin

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import torch.optim as optim
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense, Dropout, Concatenate, Dot, Multiply, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# Download and preprocess the CNN/Daily Mail dataset
dataset_url = 'http://cs.nyu.edu/~kcho/DMQA/'
# URLs always use '/', so they are built with urljoin instead of os.path.join
# (which follows the local filesystem's separator rules).
train_url = urljoin(dataset_url, 'dm_train_tokenized.npy')
val_url = urljoin(dataset_url, 'dm_validation_tokenized.npy')
test_url = urljoin(dataset_url, 'dm_test_tokenized.npy')


def _load_npy_from_url(url):
    """Return the array stored at *url*, downloading the file on first use.

    ``np.load`` only reads local paths or file objects, so the file is first
    saved in the current working directory under the URL's basename and is
    reused on later runs.

    Args:
        url: HTTP(S) location of a ``.npy`` file.

    Returns:
        Whatever ``np.load(..., allow_pickle=True)`` yields for that file.

    Raises:
        urllib.error.URLError: If the download fails.
    """
    local_path = os.path.basename(url)
    if not os.path.exists(local_path):
        # Download under a temporary name so an interrupted transfer is never
        # mistaken for a complete file on the next run.
        partial_path = local_path + '.part'
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, local_path)
    # NOTE(security): allow_pickle=True lets the file run arbitrary code while
    # loading, and it is fetched over plain HTTP -- only use a trusted source.
    return np.load(local_path, allow_pickle=True)


train_data = _load_npy_from_url(train_url)
val_data = _load_npy_from_url(val_url)
test_data = _load_npy_from_url(test_url)

# Build the vocabulary from column 0 of the training split only
tokenizer = Tokenizer()
tokenizer.fit_on_texts(train_data[:, 0])
# One extra slot on top of the learned words (pad_sequences fills with 0 by default)
vocab_size = 1 + len(tokenizer.word_index)

# Every encoded sequence is cut or padded to exactly this many tokens
max_sequence_length = 500


def _encode_column(split, column):
    """Tokenize one column of *split*, then pad/truncate it at the end."""
    token_ids = tokenizer.texts_to_sequences(split[:, column])
    return pad_sequences(
        token_ids,
        maxlen=max_sequence_length,
        padding='post',
        truncating='post',
    )


# Column 0 becomes the *_x arrays and column 1 the *_y arrays
train_data_x = _encode_column(train_data, 0)
train_data_y = _encode_column(train_data, 1)

val_data_x = _encode_column(val_data, 0)
val_data_y = _encode_column(val_data, 1)

test_data_x = _encode_column(test_data, 0)
test_data_y = _encode_column(test_data, 1)

# Hyperparameters.
# NOTE(review): the names match the builder parameters defined below;
# presumably they are passed to those builders elsewhere -- confirm.

# Model size
embed_dim = 300
hidden_dim = 512
num_layers = 2

# Optimisation schedule
batch_size = 32
epochs = 10
learning_rate = 2e-4  # identical to 0.0002

# Define the discriminator model
def build_discriminator_model(vocab_size, embed_dim, hidden_dim, max_sequence_length):
    """Build the Keras discriminator: Embedding -> LSTM -> per-step sigmoid.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Size of each embedding vector.
        hidden_dim: Number of units in the LSTM.
        max_sequence_length: Length of the id sequences the model accepts.

    Returns:
        An uncompiled ``Model``. Because the LSTM keeps
        ``return_sequences=True``, the final one-unit sigmoid ``Dense`` layer
        is applied at every time step.
    """
    token_ids = Input(shape=(max_sequence_length,))

    embed = Embedding(
        input_dim=vocab_size,
        output_dim=embed_dim,
        input_length=max_sequence_length,
    )
    recurrent = LSTM(units=hidden_dim, return_sequences=True)
    scorer = Dense(units=1, activation='sigmoid')

    # Layers are constructed and applied in the same order: embed, LSTM, dense.
    per_step_score = scorer(recurrent(embed(token_ids)))
    return Model(inputs=token_ids, outputs=per_step_score)

# Define the generator model
def build_generator_model(vocab_size, embed_dim, hidden_dim, num_layers, max_sequence_length):
    """Build the Keras generator: Embedding -> stacked LSTMs -> softmax.

    Args:
        vocab_size: Number of rows in the embedding table and number of
            output classes of the final softmax.
        embed_dim: Size of each embedding vector.
        hidden_dim: Number of units in every LSTM layer.
        num_layers: Number of stacked LSTM layers; values below 1 still
            produce a single LSTM.
        max_sequence_length: Length of the id sequences the model accepts.

    Returns:
        An uncompiled ``Model`` that emits a ``vocab_size``-way softmax at
        every time step.
    """
    token_ids = Input(shape=(max_sequence_length,))
    features = Embedding(
        input_dim=vocab_size,
        output_dim=embed_dim,
        input_length=max_sequence_length,
    )(token_ids)

    # One LSTM is always present, so the stack depth is max(num_layers, 1).
    for _ in range(max(num_layers, 1)):
        features = LSTM(units=hidden_dim, return_sequences=True)(features)

    token_probs = Dense(units=vocab_size, activation='softmax')(features)
    return Model(inputs=token_ids, outputs=token_probs)

# Define the feature alignment discriminator model
class FADiscriminator(nn.Module):
    """Three-layer perceptron that maps a feature vector to a value in (0, 1).

    Two hidden layers of width ``hidden_size`` (each followed by ReLU and
    dropout with p=0.5) feed a single sigmoid output unit.
    """

    def __init__(self, input_size, hidden_size):
        """Create the layers.

        Args:
            input_size: Size of the last dimension of the inputs.
            hidden_size: Width of both hidden layers.
        """
        super().__init__()

        self.hidden_size = hidden_size
        # Sub-modules are registered input-to-output; the order determines
        # both seeded initialisation and the state_dict layout.
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, 1)
        self.dropout = nn.Dropout(p=0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return sigmoid scores whose last dimension has size 1."""
        hidden = x
        for layer in (self.linear1, self.linear2):
            hidden = self.dropout(self.relu(layer(hidden)))
        return torch.sigmoid(self.linear3(hidden))

# Define the feature alignment discriminator optimizer
# NOTE(review): fa_discriminator, model, criterion, optimizer, train_loader and
# device are not created in the visible part of this file -- presumably they
# come from another cell; confirm before running.
# The learning rate and epoch count use the hyperparameters defined above
# (learning_rate / epochs).
fa_optimizer = optim.Adam(fa_discriminator.parameters(), lr=learning_rate)

# Define the feature alignment discriminator loss
fa_criterion = nn.BCELoss()

# Alternate one discriminator update and one model update per batch
for epoch in range(1, epochs + 1):
    epoch_loss = 0.0
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Generate summaries
        outputs, scores = model(inputs)

        # Compute feature alignment loss
        feature_alignment_loss = criterion(scores, targets)

        # ---- Discriminator step ----
        fa_optimizer.zero_grad()

        # Positive samples: model scores, detached so this step cannot
        # push gradients back into the model.
        pos_outputs = fa_discriminator(scores.detach())
        # Negative samples: the targets themselves.
        neg_outputs = fa_discriminator(targets.detach())

        # Labels are shaped like the discriminator outputs, so the loss works
        # for any output shape, dtype and device.
        fa_loss = (
            fa_criterion(pos_outputs, torch.ones_like(pos_outputs))
            + fa_criterion(neg_outputs, torch.zeros_like(neg_outputs))
        )

        # Backward pass and optimization
        fa_loss.backward()
        fa_optimizer.step()

        # ---- Model step ----
        # Compute reconstruction loss
        reconstruction_loss = criterion(outputs, targets[:, 1:])

        # fa_loss has already been back-propagated (its graph is freed and the
        # discriminator weights have been stepped), so it must be detached
        # here or loss.backward() raises. Its inputs were detached from the
        # model anyway, so it only affects the reported value.
        # NOTE(review): a real adversarial signal for the model would need a
        # discriminator pass on non-detached scores -- confirm the intent.
        loss = feature_alignment_loss + reconstruction_loss - fa_loss.detach()

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # max(..., 1) avoids ZeroDivisionError when the loader yields no batches
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch, epochs, epoch_loss / max(len(train_loader), 1)))

# Test the model
def _ids_to_text(token_ids):
    """Turn one 1-D tensor of token ids back into a space-joined string.

    The module-level ``tokenizer`` is a Keras ``Tokenizer``, which provides
    ``sequences_to_texts`` rather than a ``decode`` method. Ids with no
    vocabulary entry (such as the 0 used for padding) are left out of the
    result when the tokenizer has no OOV token.
    """
    return tokenizer.sequences_to_texts([token_ids.tolist()])[0]


model.eval()
with torch.no_grad():
    for inputs, targets in test_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Generate summaries
        outputs, _ = model(inputs)

        # NOTE(review): outputs[0] must hold integer token ids for the lookup
        # below; if the model returns per-token scores instead, an argmax over
        # the vocabulary axis is needed first -- confirm against the model.

        # Print the first example of each batch
        print('Input summary: ', _ids_to_text(inputs[0]))
        print('Output summary: ', _ids_to_text(outputs[0]))
        print('Target summary: ', _ids_to_text(targets[0]))
        print('\n')