In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import os
import time
import glob
import pathlib
import matplotlib.pyplot as plt
import seaborn as sns
import os
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
import pandas as pd

In [13]:
articles_path = './bbc-news-summary-c/News Articles'
summaries_path = './bbc-news-summary-c/Summaries'
categories_list = ['politics', 'sport', 'tech', 'entertainment', 'business']
# categories_list = ['tech']

In [14]:
def read_files_from_folders(articles_path, summaries_path, categories_list=['tech', 'sport'], encoding="ISO-8859-1"):
    articles = []
    summaries = []
    categories = []
    for category in categories_list:
        article_paths = glob.glob(os.path.join(articles_path, category, '*.txt'), recursive=True)
        summary_paths = glob.glob(os.path.join(summaries_path, category, '*.txt'), recursive=True)

        if len(article_paths) != len(summary_paths):
            print('number of files is not equal')
            return
        for i in range(len(article_paths)):
            categories.append(category)
            with open(article_paths[i], mode='r', encoding=encoding) as file:
                articles.append(file.read())

            with open(summary_paths[i], mode='r', encoding=encoding) as file:
                summaries.append(file.read())
    return articles, summaries, categories

In [15]:
articles, summaries, categories = read_files_from_folders(articles_path, summaries_path, categories_list)
df = pd.DataFrame({'articles': articles, 'summaries': summaries, 'categories': categories})

In [16]:
df = df[['articles', 'summaries']]
df = df.dropna()
train_df, test_df = train_test_split(df, test_size=0.1)

In [17]:
df

Unnamed: 0,articles,summaries
0,labour plans maternity pay rise maternity pay ...,she said her party would boost maternity pay i...
1,watchdog probes e mail deletions the informati...,all e mails are subject to the freedom of info...
2,hewitt decries career sexism plans to extend p...,ms hewitt also announced a new drive to help w...
3,labour chooses manchester the labour party wil...,the labour party will hold its N autumn confer...
4,brown ally rejects budget spree chancellor gor...,but mr balls a prospective labour mp said he w...
...,...,...
2220,trial begins of spain top banker the trial of ...,both executives helped mr botin orchestrate sp...
2221,uk economy ends year with spurt the uk economy...,simon rubinsohn chief economist at gerrard sai...
2222,healthsouth ex boss goes on trial the former h...,several former healthsouth employees have alre...
2223,euro firms miss out on optimism more than N of...,possibly as a result the worry about low cost ...


In [18]:
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize

nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\lumin\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [19]:
from nltk.tokenize import RegexpTokenizer

sentences = []

for article in df['articles']:
    sentences.extend(sent_tokenize(article))

for summary in df['summaries']:
    sentences.extend(sent_tokenize(summary))

len(sentences)

38834

In [20]:
train_sents, test_sents = train_test_split(sentences, test_size=0.05)
print(len(train_sents), len(test_sents))

36892 1942


In [21]:
print(train_sents[0: 10])

['the cellery worm does not spread via e mail like many other viruses.', 'however companies that decide to offshore their operations are driven not just by cost considerations.', 'nonetheless japanese firms have been stepping up capital investment and the survey found the pace is quickening.', 'an spokesman for ssl which makes the famous durex brand of condom would not to comment on market speculation .', 'the court is expected to reach a verdict in the case in the autumn.', 'he said intercept evidence was only a small part of the case against the men and some of it could not be used because it could put sources lives at risk.', 'where will this end.', 'last year the academy formed a committee to tighten the rules after the campaigns spilled over into personal attacks between studios.', 'the commission latest move comes just a few months after national telecoms regulators across europe launched a joint investigation which could lead to people being charged less for using their mobile p

In [22]:
with open("bbc.train.txt", 'w') as file:
    for sent in train_sents:
        file.write(sent)
        file.write('\n')

In [23]:
with open("bbc.valid.txt", 'w') as file:
    for sent in test_sents:
        file.write(sent)
        file.write('\n')

In [24]:
import os
import json
import torch
import argparse

from model import SentenceVAE
from utils import to_var, idx2word, interpolate

In [29]:
with open('bbc_full/bbc.vocab.json', 'r') as file:
    vocab_full = json.load(file)
print(len(vocab_full['i2w']))

with open('bbc_cleaned/bbc.vocab.json', 'r') as file:
    vocab_cleaned = json.load(file)
print(len(vocab_cleaned['i2w']))

26946
22376


In [82]:
vocab = vocab_cleaned

w2i, i2w = vocab['w2i'], vocab['i2w']

model = SentenceVAE(
    vocab_size=len(w2i),
    sos_idx=w2i['<sos>'],
    eos_idx=w2i['<eos>'],
    pad_idx=w2i['<pad>'],
    unk_idx=w2i['<unk>'],
    max_sequence_length=50,
    embedding_size=300,
    rnn_type='gru',
    hidden_size=256,
    word_dropout=0,
    embedding_dropout=0.5,
    latent_size=32,
    num_layers=2,
    bidirectional=False
)

checkpoint = "bin/2024-Jun-13-23-32-00/E9.pytorch"

if not os.path.exists(checkpoint):
    raise FileNotFoundError(checkpoint)

model.load_state_dict(torch.load(checkpoint))
print("Model loaded from %s" % checkpoint)

if torch.cuda.is_available():
    model = model.cuda()

bidirectional= False
Model loaded from bin/2024-Jun-13-23-32-00/E9.pytorch


In [83]:
model.eval()

samples, z = model.inference(n=10)
print('----------SAMPLES----------')
print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')

z1 = torch.randn([16]).numpy()
z2 = torch.randn([16]).numpy()
z = to_var(torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float())
samples, _ = model.inference(z=z)
print('-------INTERPOLATION-------')
print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')

XXXXX
----------SAMPLES----------
the new book is not the only one of the most important game of the world but it is a great opportunity to be the first time in the us presidential election bdo said the uk was not a relatively strong step in the us . <eos>
the n has been chosen by the withdrawals of the n n and n in the n years with the n consecutive nations in the n years the n has been chosen by the end of the season . <eos>
the lib dems say the lib dems would be a parody of the election and the tories have been asked . <eos>
the company said the organizational was not intended to be a huge factor in the us . <eos>
the n is not a fan for the <unk> . <eos>
it is not a good job . <eos>
the company said the company had absorbed in the us and the us bankruptcy in the us and the us and the us subsidiary in the us . <eos>
the government has been agnostic on the issue of the government and the tories who are not allowed to be completed by the law . <eos>
the prime minister said the governme

RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x16 and 32x512)

In [89]:
import os
import io
import json
import torch
import numpy as np
from collections import defaultdict
from torch.utils.data import Dataset
from nltk.tokenize import TweetTokenizer

from utils import OrderedCounter


class MyPTB(Dataset):

    def __init__(self, sents, **kwargs):
        super().__init__()
        self.max_sequence_length = kwargs.get('max_sequence_length', 50)
        self.min_occ = kwargs.get('min_occ', 3)

        self.sents = sents
        self.vocab_file = './bbc_cleaned/bbc.vocab.json'

        self._create_data()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        idx = str(idx)

        return {
            'input': torch.asarray(self.data[idx]['input']),
            'target': torch.asarray(self.data[idx]['target']),
            'length': self.data[idx]['length']
        }

    @property
    def vocab_size(self):
        return len(self.w2i)

    @property
    def pad_idx(self):
        return self.w2i['<pad>']

    @property
    def sos_idx(self):
        return self.w2i['<sos>']

    @property
    def eos_idx(self):
        return self.w2i['<eos>']

    @property
    def unk_idx(self):
        return self.w2i['<unk>']

    def get_w2i(self):
        return self.w2i

    def get_i2w(self):
        return self.i2w

    def _load_vocab(self):
        with open(self.vocab_file, 'r') as vocab_file:
            vocab = json.load(vocab_file)

        self.w2i, self.i2w = vocab['w2i'], vocab['i2w']

    def _create_data(self):
        self._load_vocab()

        tokenizer = TweetTokenizer(preserve_case=False)

        data = defaultdict(dict)

        for i, line in enumerate(self.sents):
            words = tokenizer.tokenize(line)

            input = ['<sos>'] + words
            input = input[:self.max_sequence_length]

            target = words[:self.max_sequence_length - 1]
            target = target + ['<eos>']

            assert len(input) == len(target), "%i, %i" % (len(input), len(target))
            length = len(input)

            input.extend(['<pad>'] * (self.max_sequence_length - length))
            target.extend(['<pad>'] * (self.max_sequence_length - length))

            input = [self.w2i.get(w, self.w2i['<unk>']) for w in input]
            target = [self.w2i.get(w, self.w2i['<unk>']) for w in target]

            id = len(data)
            data[id]['input'] = input
            data[id]['target'] = target
            data[id]['length'] = length

        self.data = data

        # with io.open(os.path.join(self.data_dir, self.data_file), 'wb') as data_file:
        #     data = json.dumps(data, ensure_ascii=False)
        #     data_file.write(data.encode('utf8', 'replace'))

        # self._load_data(vocab=False)

In [110]:
# Reconstruction

for article in df['articles'][:2]:
    sents = [s for s in sent_tokenize(article)]
    lens = [len(s) for s in sents]
    ptb = MyPTB(sents)
    
    gen_sents = []

    for i in range(len(ptb.data)):
        input_seq = torch.asarray(ptb.data[i]['input'])
        input_seq = input_seq.unsqueeze(0).cuda()
        length = torch.asarray(ptb.data[i]['length'])
        length = length.unsqueeze(0).cuda()

        generations, z = model.generate(input_seq, length)

        print(sents[i].lower())
        print('--------------------------------------')
        s = idx2word(generations, i2w=i2w, pad_idx=w2i['<pad>'])[0]
        print(s[:-6])
        print('======================================')
        gen_sents.append(s)

labour plans maternity pay rise maternity pay for new mothers is to rise by n as part of new proposals announced by the trade and industry secretary patricia hewitt.
--------------------------------------
the government has been agnostic on the issue of the government and the country and the us and the country and the us and the country and the us .
it would mean paid leave would be increased to nine months by n ms hewitt told gmtv sunday programme.
--------------------------------------
the n has been chosen by the withdrawals of the n n in the n years with the n consecutive consecutive consecutive year in the last three months in the us .
other plans include letting maternity pay be given to fathers and extending rights to parents of older children.
--------------------------------------
the lib dems say they are not going to be a unifying force in the uk independence party .
the tories dismissed the maternity pay plan as desperate while the liberal democrats said it was misdirected.

In [104]:
from rouge import Rouge

rouge = Rouge()

results = []

for article in df['articles']:
    sents = [s for s in sent_tokenize(article)]
    lens = [len(s) for s in sents]
    ptb = MyPTB(sents)
    
    gen_sents = []

    for i in range(len(ptb.data)):
        input_seq = torch.asarray(ptb.data[i]['input'])
        input_seq = input_seq.unsqueeze(0).cuda()
        length = torch.asarray(ptb.data[i]['length'])
        length = length.unsqueeze(0).cuda()

        generations, z = model.generate(input_seq, length)

        # print(sents[i].lower())
        # print('--------------------------------------')
        # print(*idx2word(generations, i2w=i2w, pad_idx=w2i['<pad>']))
        s = idx2word(generations, i2w=i2w, pad_idx=w2i['<pad>'])[0]
        # print(s[:-6])
        # print('======================================')
        gen_sents.append(s)
        
    gen_article = ' '.join(gen_sents)
    rouge_scores = rouge.get_scores(gen_article, article, avg=True)
    
    results.append({
        'article': article,
        'generated_article': gen_article,
        'rouge1_precision': rouge_scores['rouge-1']['p'],
        'rouge1_recall': rouge_scores['rouge-1']['r'],
        'rouge1_fmeasure': rouge_scores['rouge-1']['f'],
        'rouge2_precision': rouge_scores['rouge-2']['p'],
        'rouge2_recall': rouge_scores['rouge-2']['r'],
        'rouge2_fmeasure': rouge_scores['rouge-2']['f'],
        'rougeL_precision': rouge_scores['rouge-l']['p'],
        'rougeL_recall': rouge_scores['rouge-l']['r'],
        'rougeL_fmeasure': rouge_scores['rouge-l']['f']
    })
    
results_df = pd.DataFrame(results)

In [106]:
results_df.describe()

Unnamed: 0,rouge1_precision,rouge1_recall,rouge1_fmeasure,rouge2_precision,rouge2_recall,rouge2_fmeasure,rougeL_precision,rougeL_recall,rougeL_fmeasure
count,2225.0,2225.0,2225.0,2225.0,2225.0,2225.0,2225.0,2225.0,2225.0
mean,0.318286,0.141892,0.194871,0.074304,0.035211,0.047368,0.292946,0.130619,0.179377
std,0.058411,0.028173,0.033835,0.03032,0.013451,0.017872,0.055136,0.026648,0.032161
min,0.145455,0.057692,0.085106,0.0,0.0,0.0,0.1375,0.057692,0.085106
25%,0.278481,0.122112,0.171429,0.053191,0.025974,0.035242,0.255556,0.11194,0.156863
50%,0.311594,0.140351,0.193133,0.071429,0.034091,0.046205,0.2875,0.128492,0.177515
75%,0.352941,0.15942,0.216561,0.091633,0.042918,0.057315,0.326923,0.147287,0.199336
max,0.565217,0.263736,0.341709,0.21875,0.136585,0.168168,0.565217,0.241758,0.311558
