## Setup

In [None]:
!pip -q install transformers

In [None]:
import os
import pickle
import gc
import json
import glob
import time
import itertools
import pandas as pd
import numpy as np
import math
from tqdm import tqdm
import torch
from torch import nn
import torch.nn.functional as F

# Visualization libraries
import seaborn as sns
import matplotlib.pyplot as plt

# Hugging Face
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Select the compute device: 'cuda' when torch reports a usable GPU, 'cpu' otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Running on device: {device}')

In [None]:
# Dataset root plus its 'original' and 'summary' sub-directories.
data_dir = '/kaggle/input/vims-dataset/ViMs'
original_dir, summary_dir = (
    os.path.join(data_dir, sub_dir) for sub_dir in ('original', 'summary')
)

In [None]:
# Notebook-wide settings:
#   show_examples - gates the print/plot statements in the cells that check it
#   model_arch    - identifier handed to AutoTokenizer.from_pretrained
CFG = dict(
    show_examples=True,
    model_arch='VietAI/vit5-large-vietnews-summarization',
)

## Helper function

In [None]:
# Hard-coded sample files from Cluster_001: path1 lives under the 'original'
# tree, path2 under the 'summary' tree. path2 feeds the read_txt demo below;
# NOTE(review): path1 is not referenced anywhere in the visible notebook.
path1 = '/kaggle/input/vims-dataset/ViMs/original/Cluster_001/original/10.txt'
path2 = '/kaggle/input/vims-dataset/ViMs/summary/Cluster_001/1.gold.txt'
def read_txt(path, article_type, sent=False):
    """Read the body text of a ViMs article or summary file.

    Args:
        path: Path of the text file to read.
        article_type: ``"original"`` skips every line before the first one
            that starts with "content" (case-insensitive); any other value
            keeps lines from the very start of the file.
        sent: If True, return the kept lines as a list; otherwise return
            them joined by single spaces into one string.

    Returns:
        The kept non-blank lines, right-stripped, minus the first kept line.
        For ``"original"`` the dropped line is the "content" marker itself.
        An empty list / empty string when nothing is left.

    NOTE(review): for non-original files the first non-blank line is dropped
    too -- presumably a header line; confirm against the data files.
    """
    # Non-original files are collected from their first line onwards.
    collecting = article_type != "original"
    kept = []
    # Explicit UTF-8 so decoding does not depend on the platform's locale
    # default (the default encoding of open() is locale-dependent).
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not collecting and line.lower().startswith("content"):
                collecting = True
            if collecting:
                stripped = line.rstrip()
                if stripped:
                    kept.append(stripped)
    body = kept[1:]
    return body if sent else " ".join(body)
# Quick check: print the joined text that read_txt extracts from the sample summary.
if CFG['show_examples']:
    print(read_txt(path2, "summary"))

## CSV file

In [None]:
def create_csv(data_dir):
    """Index the files below ``data_dir``, grouped by cluster folder.

    Args:
        data_dir: Directory whose immediate children are cluster folders.
            When its last path component is ``original`` the files are taken
            from ``data_dir/<cluster>/original/*``; otherwise from
            ``data_dir/<cluster>/*``.

    Returns:
        A pandas DataFrame (not a CSV file, despite the function name) with
        one row per cluster that contains at least one entry, and two
        columns: ``cluster`` (the folder name) and ``path`` (the sorted list
        of paths found for that cluster). Empty, with the same two columns,
        when nothing is found.
    """
    # normpath removes a trailing separator so basename is never empty;
    # computed once because it does not depend on the cluster.
    file_type = os.path.basename(os.path.normpath(data_dir))
    rows = {'cluster': [], 'path': []}
    # os.listdir and glob return entries in arbitrary order; sorting keeps
    # the per-cluster path lists reproducible between runs.
    for cluster in sorted(os.listdir(data_dir)):
        if file_type == "original":
            f_path = os.path.join(data_dir, cluster, file_type)
        else:
            f_path = os.path.join(data_dir, cluster)
        for f in sorted(glob.glob(os.path.join(f_path, '*'))):
            rows['cluster'].append(cluster)
            rows['path'].append(f)

    df = pd.DataFrame(rows)
    if df.empty:
        # Nothing to group; keep the two-column shape callers rename.
        return df
    return df.groupby('cluster')['path'].apply(list).reset_index()

In [None]:
# One row per cluster; the list of files found under original_dir goes in 'original_dir'.
original_df = create_csv(original_dir).set_axis(['cluster', 'original_dir'], axis=1)
if CFG['show_examples']:
    print(original_df.head())

In [None]:
# One row per cluster; the list of files found under summary_dir goes in 'summary_dir'.
summary_df = create_csv(summary_dir).set_axis(['cluster', 'summary_dir'], axis=1)
if CFG['show_examples']:
    print(summary_df.head())

In [None]:
# Inner join on 'cluster': only clusters present in both frames are kept.
df = original_df.merge(summary_df, on='cluster', how='inner')
if CFG['show_examples']:
    # Row count, distinct-cluster count and a preview, each on its own line.
    print(len(df), len(df['cluster'].unique()), df.head(), sep='\n')

## Overview

In [None]:
## Number of articles per cluster
if CFG['show_examples']:
    # Bar chart: x = number of files in a cluster, height = how many clusters have that many.
    articles_per_cluster = df['original_dir'].apply(len).values
    labels, counts = np.unique(articles_per_cluster, return_counts=True)
    plt.figure(figsize=(7, 3))
    plt.title("Number of articles per cluster")
    plt.bar(labels, counts, align='center')
    plt.gca().set_xticks(labels)  # one tick per distinct count
    plt.show()

In [None]:
## Number of summaries per cluster
if CFG['show_examples']:
    # Bar chart: x = number of summary files in a cluster, height = how many clusters have that many.
    summaries_per_cluster = df['summary_dir'].apply(len).values
    labels, counts = np.unique(summaries_per_cluster, return_counts=True)
    plt.figure(figsize=(7, 3))
    plt.title("Number of summaries per cluster")
    plt.bar(labels, counts, align='center')
    plt.gca().set_xticks(labels)  # one tick per distinct count
    plt.show()

In [None]:
## Number of sentences per article
# article_sent_n: one entry per article (number of lines read_txt returns with sent=True).
# cluster_sent_n: one entry per cluster (sum over that cluster's articles).
dir_list = df['original_dir'].values.tolist()
article_sent_n = []
cluster_sent_n = []
for cluster_paths in dir_list:
    counts_in_cluster = []
    for article_path in cluster_paths:
        lines = read_txt(article_path, article_type="original", sent=True)
        counts_in_cluster.append(len(lines))
    article_sent_n.extend(counts_in_cluster)
    cluster_sent_n.append(sum(counts_in_cluster))

if CFG['show_examples']:
    plt.figure(figsize=(7, 3))
    sns.histplot(data=article_sent_n)
    plt.title("Number of sentence per article")
    plt.show()

In [None]:
## Number of sentences per cluster
if CFG['show_examples']:
    # Smallest per-cluster total, followed by the distribution of totals.
    fewest_sentences = np.min(cluster_sent_n)
    print("Min number of sentences: ", fewest_sentences)
    plt.figure(figsize=(7, 3))
    sns.histplot(data=cluster_sent_n)
    plt.title("Number of sentence per cluster")
    plt.show()

## Tokenizer

In [None]:
## Print some tokens
# Load the tokenizer named in CFG and tokenize one sample summary.
tokenizer = AutoTokenizer.from_pretrained(CFG['model_arch'])
path = '/kaggle/input/vims-dataset/ViMs/summary/Cluster_001/0.gold.txt'
sentence = read_txt(path, article_type="summary")
# "vietnews: " prefix and " </s>" suffix wrap the raw text
# (presumably the model's task prefix and end-of-sequence marker -- confirm against the model card).
text = f"vietnews: {sentence} </s>"
encoded = tokenizer(text)
ids = encoded['input_ids']
if CFG['show_examples']:
    first_tokens = tokenizer.convert_ids_to_tokens(ids)[:10]
    print(first_tokens)

In [None]:
## Number of tokens per article
# article_token_lens: one entry per article (length of the tokenizer's input_ids).
# cluster_token_lens: one entry per cluster (sum over that cluster's articles).
tokenizer = AutoTokenizer.from_pretrained(CFG['model_arch'])
dir_list = df['original_dir'].values.tolist()
article_token_lens = []
cluster_token_lens = []
for cluster_paths in dir_list:
    lens_in_cluster = []
    for article_path in cluster_paths:
        body = read_txt(article_path, article_type="original")
        encoded = tokenizer(f"vietnews: {body} </s>")
        lens_in_cluster.append(len(encoded['input_ids']))
    article_token_lens.extend(lens_in_cluster)
    cluster_token_lens.append(sum(lens_in_cluster))

if CFG['show_examples']:
    plt.figure(figsize=(7, 3))
    sns.histplot(data=article_token_lens)
    plt.title("Number of tokens per article")
    plt.show()

In [None]:
## Number of tokens per cluster
if CFG['show_examples']:
    # Distribution of the per-cluster token totals.
    plt.figure(figsize=(7, 3))
    sns.histplot(data=cluster_token_lens)
    plt.title("Number of tokens per cluster")
    plt.show()

In [None]:
## Number of tokens per summary
tokenizer = AutoTokenizer.from_pretrained(CFG['model_arch'])
# NOTE: this rebinds `article_token_lens`; from here on it holds one entry per
# summary file (length of the tokenizer's input_ids), not per original article.
article_token_lens = []
dir_list = df['summary_dir'].values.tolist()
for cluster_paths in dir_list:
    for summary_path in cluster_paths:
        summary_text = read_txt(summary_path, article_type="summary")
        token_ids = tokenizer(f"vietnews: {summary_text} </s>")['input_ids']
        article_token_lens.append(len(token_ids))

if CFG['show_examples']:
    plt.figure(figsize=(7, 3))
    sns.histplot(data=article_token_lens)
    plt.title("Number of tokens per summary article")
    plt.show()