In [1]:
# Mount Google Drive so files stored there are reachable from the Colab runtime
from google.colab import drive
drive.mount('/content/drive')

# Work from the project folder; later cells derive their paths from os.getcwd()
%cd /content/drive/MyDrive/News_Classification

Mounted at /content/drive
/content/drive/MyDrive/News_Classification


In [2]:
# Install the runtime dependencies (only tornado is version-pinned)
!pip install torch
!pip install transformers
# --pre allows pre-release builds of gql; [all] pulls in its optional extras
!pip install --pre gql[all]
# NOTE(review): tornado is pinned to 4.5.3 — presumably so that asyncio.run() can be
# called from inside the notebook kernel; confirm before changing or removing the pin
!pip install tornado==4.5.3



In [3]:
import os
import time
import json
import copy
import functools
from tqdm import tqdm

import asyncio
from gql import gql, Client
from gql.transport.aiohttp import AIOHTTPTransport

import torch
from torch import nn 
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from transformers import pipeline
from transformers import AutoTokenizer

In [4]:
# Use the first CUDA device when one is visible; otherwise stay on the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
# Checkpoints and datasets live in sibling folders under the working directory
checkpoint_path, dataset_path = (
    os.path.join(os.getcwd(), folder) for folder in ("checkpoints", "dataset")
)

# Prepare Data

In [6]:
async def query_data(url, gql_query):
    """Run a single GraphQL query against ``url`` and return its result.

    Args:
        url: Endpoint handed to the aiohttp transport.
        gql_query: GraphQL document as a plain string.

    Returns:
        Whatever ``session.execute`` yields for the parsed query.
    """
    client = Client(
        transport=AIOHTTPTransport(url=url),
        fetch_schema_from_transport=True,
    )

    async with client as session:
        # Parse inside the session, then execute the one query
        document = gql(gql_query)
        return await session.execute(document)

In [7]:
# GraphQL document for the article fetch.
# Selects `Analysis` rows of type "5" whose Article timestamp lies between
# 2020-04-01 and 2020-06-30, ordered by article_id ascending.
# Aliases: summary <- content(path: "summary[0]"), source <- journal, time <- timestamp.
query = """
query getArticles {
  Analysis(order_by: {article_id: asc}, where: {type: {_eq: "5"}, Article: {timestamp: {_gte: "2020-04-01", _lte: "2020-06-30"}}}) {
    id
    summary: content(path: "summary[0]")
    Article {
      source: journal
    	time: timestamp
      headline
      url
    }
  }
}
"""

In [8]:
# GraphQL endpoint (left blank in this notebook)
url = ""
query_results = asyncio.run(query_data(url, query))['Analysis']

# Flatten every record: hoist the nested 'Article' fields to the top level
for result in query_results:
    article_info = result.pop('Article')
    result.update(article_info)

    # Empty "tags" slot that later receives the predicted results
    result["tags"] = []

In [9]:
# Sanity check: number of records returned and the shape of one flattened record
print(f"Number of queried data: {len(query_results)}")
print(f"Example: {json.dumps(query_results[0], indent=4)}")

Number of queried data: 7256
Example: {
    "id": 1342099,
    "summary": "A federal detention center in New York City has no in-house ability to test sick or high-risk inmates for COVID-19.<n>\"MCC New York does not have COVID-19 tests,\" wrote the jail's warden.<n>The warden's letter came in response to an order from a federal judge to test an MCC inmate.",
    "source": "ABC News",
    "time": "2020-04-04T10:12:45",
    "headline": "No COVID-19 tests available for prisoners at center of New York outbreak, court documents show",
    "url": "https://abcnews.go.com/Health/covid-19-tests-prisoners-center-york-outbreak-court/story?id=69969077",
    "tags": []
}


# Prepare Labels

In [10]:
# Reference from: https://github.com/yinwenpeng/BenchmarkingZeroShot/blob/master/src/train_yahoo.py

# Top class
# Maps each top-level policy category to three alternative hypothesis phrasings.
# The phrasing actually used is selected by index in a later cell.
# NOTE(review): some strings contain apparent typos ("internel", "compaigns"); they are
# fed to the classifier verbatim, so confirm against the fine-tuned checkpoint before editing.
top_choice_to_hypothesis = {
    "Containment and Closure Policies": [
        'related with containment and closure policy from governments in the pandemic',
        'talked about a government coronavirus policy for containment and closure',
        'described a government policy about school closing, workspace closing, public event cancellation, restrictions on gatherings, public transport closing, stay at home requirement, restrictions on internel movement, and international travel control in the pandemic.'
    ],
    "Economic Policies": [
        'related with economic policy from governments in the pandemic', 
        'talked about a government coronavirus policy for economy', 
        'described a government policy about income support, debt or contract relief, fiscal measurements, and international support in the pandemic.'
    ],
    "Health System Policies": [
        'related with health system policy from governments in the pandemic', 
        'talked about a government coronavirus policy for health system', 
        'described a government policy about public health compaigns, testing policy, contact tracing, emergency investment in health care, investment in vaccines, facial coverings, and vaccination policy in the pandemic.'
    ],
    "Miscellaneous Policies": [
        'related with miscellaneous policy from governments in the pandemic that do not fit anywhere else',  # combine top & sub
        'talked about a government policy irrelevant to coronavirus', 
        'described a government policy that do not fit anywhere else in the pandemic.'
    ]
}

# Sub class
# Maps each top-level category (except "Miscellaneous Policies", which has no entry)
# to one hypothesis per sub-level event. The commented-out strings at the end of each
# list are shorter alternative labels that are currently disabled.
sub_choice_to_hypothesis = {
    "Containment and Closure Policies": [
        'record closing of schools and universities',                                   
        'record closing of workplaces',                                   
        'record cancellation of public events',                                   
        'record limit on gatherings',                                   
        'record closing of public transport',                                   
        'record order to "shelter-in-place" and otherwise confine to the home',                                   
        'record restriction on internal movement between cities or regions',                                   
        'record restriction on international travel for foreign travellers, not citizens',                                   
        # 'school closing', 
        # 'workspace closing', 
        # 'public event cancellation', 
        # 'restrictions on gatherings', 
        # 'public transport closing', 
        # 'stay at home requirement', 
        # 'restrictions on internel movement', 
        # 'international travel control'
    ],
    "Economic Policies": [
        'record if the government is providing direct cash payments to people who lose their jobs or cannot work',
        'record if the government is freezing financial obligations for households, like stopping loan repayments, preventing services like water from stopping, or banning evictions',
        'announced economic stimulus spending',
        'announced offer of Covid-19 related aid spending to other countries',
        # 'income support', 
        # 'debt or contract relief', 
        # 'fiscal measurements', 
        # 'international support in the pandemic',
    ],
    "Health System Policies": [
        'record presence of public info campaigns',
        'record government policy on who has access to PCR testing instead of antibody test',
        'record government policy on contact tracing after a positive diagnosis',
        'announced short term spending on healthcare system, eg hospitals, masks, etc',
        'announced public spending on Covid-19 vaccine development',
        'record policy on the use of facial coverings outside the home',
        'record policy for vaccine delivery for different groups',
        # 'public health compaigns', 
        # 'testing policy',
        # 'contact tracing',
        # 'emergency investment in health care',
        # 'investment in vaccines' 
        # 'facial coverings',
        # 'vaccination policy in the pandemic',
    ],
}

In [11]:
# Which of the alternative phrasings to use for every top-level category
TOP_HYPOTHESIS_IDX = 0

# One hypothesis per top-level category, in dictionary order
top_hypothesis = [
    hypotheses[TOP_HYPOTHESIS_IDX]
    for hypotheses in top_choice_to_hypothesis.values()
]

# Every sub-level hypothesis flattened into one list, category by category.
# A comprehension avoids the repeated list copies of reduce(a + b) and also
# yields [] (instead of raising TypeError) if the mapping is ever empty.
sub_hypothesis = [
    hypothesis
    for hypotheses in sub_choice_to_hypothesis.values()
    for hypothesis in hypotheses
]

In [12]:
# Label list handed to the classifier: top-level hypotheses first, then sub-level ones
candidate_labels = [*top_hypothesis, *sub_hypothesis]
print(json.dumps(candidate_labels, indent=4))

[
    "related with containment and closure policy from governments in the pandemic",
    "related with economic policy from governments in the pandemic",
    "related with health system policy from governments in the pandemic",
    "related with miscellaneous policy from governments in the pandemic that do not fit anywhere else",
    "record closing of schools and universities",
    "record closing of workplaces",
    "record cancellation of public events",
    "record limit on gatherings",
    "record closing of public transport",
    "record order to \"shelter-in-place\" and otherwise confine to the home",
    "record restriction on internal movement between cities or regions",
    "record restriction on international travel for foreign travellers, not citizens",
    "record if the government is providing direct cash payments to people who lose their jobs or cannot work",
    "record if the government is freezing financial obligations for households, like stopping loan repayments, p

In [13]:
# Top-level hypotheses map to a plain integer index
labels_to_indice = { label: idx for idx, label in enumerate(top_hypothesis) }
top_keys = list(top_choice_to_hypothesis.keys())

# Sub-level hypotheses map to [top_idx, sub_idx].
# Only the leading categories have sub hypotheses; how many is derived from
# sub_choice_to_hypothesis instead of a hard-coded 3. A category inside that
# range without an entry fails loudly with KeyError.
num_with_sub = len(sub_choice_to_hypothesis)

for top_idx, top_key in enumerate(top_keys[:num_with_sub]):
    sub_label = sub_choice_to_hypothesis[top_key]
    labels_to_indice.update(
        {
            label: [top_idx, sub_idx] for sub_idx, label in enumerate(sub_label)
        }
    )

print(json.dumps(labels_to_indice, indent=4))

{
    "related with containment and closure policy from governments in the pandemic": 0,
    "related with economic policy from governments in the pandemic": 1,
    "related with health system policy from governments in the pandemic": 2,
    "related with miscellaneous policy from governments in the pandemic that do not fit anywhere else": 3,
    "record closing of schools and universities": [
        0,
        0
    ],
    "record closing of workplaces": [
        0,
        1
    ],
    "record cancellation of public events": [
        0,
        2
    ],
    "record limit on gatherings": [
        0,
        3
    ],
    "record closing of public transport": [
        0,
        4
    ],
    "record order to \"shelter-in-place\" and otherwise confine to the home": [
        0,
        5
    ],
    "record restriction on internal movement between cities or regions": [
        0,
        6
    ],
    "record restriction on international travel for foreign travellers, not citizens": [

In [14]:
# Template passed to the zero-shot classifier together with the candidate labels;
# the "{}" placeholder is where a candidate label is inserted.
hypothesis_template = "This text {}."

# Prepare Output Format

In [15]:
# Display names of the top-level categories, in the same order as the keys of
# top_choice_to_hypothesis (index i here corresponds to top-level label index i).
top_events = [
    "Containment and Closure Policies",
    "Economic Policies",
    "Health System Policies",
    "Misc Policies",
]

# Display names of the sub-level events; sub_events[i] holds the children of top_events[i].
# "Misc Policies" has no children, so there are only three inner lists.
# NOTE(review): several names look misspelled ("internel", "Suport", and "Contract Tracing"
# where the hypothesis text says "contact tracing"). They are written verbatim into the
# output JSON, so confirm with downstream consumers before correcting them.
sub_events = [
    [
        "School Closing",
        "Workplace Closing",
        "Cancel Public Events",
        "Restrictions on Gatherings",
        "Close Public Transport",
        "Stay at Home Requirements",
        'Restrictions on internel movement', 
        'International travel control'
    ],
    [
        "Income Support",
        "Debt/Contract Relief",
        "Fiscal Measures",
        "International Suport",
    ],
    [
        "Public Info Campaigns",
        "Testing Policy",
        "Contract Tracing",
        "Emergency Investment",
        "Vaccine Investment",
        "Facial Coverings",
        "Vaccination Policy",
    ]
]

# NOTE(review): this concatenates a list of strings with a list of lists (nothing is
# flattened), and the name is not used in the rest of the visible notebook — confirm intent.
output_events = top_events + sub_events

In [16]:
# Empty result tree: a single 'root' node whose children are the top-level events.
# The first three top-level events carry one zero-valued child per sub event;
# "Misc Policies" is added last as a leaf without children.
output_template = [
    {
        'event': 'root',
        'children': [
            {
                'event': top_event_name,
                'percentage': 0,
                'children': [
                    {'event': sub_event_name, 'percentage': 0}
                    for sub_event_name in sub_event_names
                ],
            }
            for top_event_name, sub_event_names in zip(top_events[:3], sub_events)
        ] + [
            {
                'event': "Misc Policies",
                'percentage': 0,
            }
        ],
    }
]

In [17]:
# Show the empty tree whose 'percentage' fields are filled in for each prediction
print(json.dumps(output_template, indent=4))

[
    {
        "event": "root",
        "children": [
            {
                "event": "Containment and Closure Policies",
                "percentage": 0,
                "children": [
                    {
                        "event": "School Closing",
                        "percentage": 0
                    },
                    {
                        "event": "Workplace Closing",
                        "percentage": 0
                    },
                    {
                        "event": "Cancel Public Events",
                        "percentage": 0
                    },
                    {
                        "event": "Restrictions on Gatherings",
                        "percentage": 0
                    },
                    {
                        "event": "Close Public Transport",
                        "percentage": 0
                    },
                    {
                        "event": "Stay at Home Requirements",
              

In [18]:
def format_output(prediction, labels_to_indice, output_template, top_events, sub_events, temp: float = 0.1):
    """Convert one zero-shot prediction into a filled-in copy of the output tree.

    Args:
        prediction: Mapping with parallel ``'labels'`` and ``'scores'`` sequences.
        labels_to_indice: Maps every label either to an ``int`` (index of a
            top-level event) or to a ``[top_idx, sub_idx]`` pair (a sub event).
        output_template: Tree of the form ``[{'event': 'root', 'children': [...]}]``.
            It is deep-copied and never modified.
        top_events: Top-level event names; one child of the root per entry.
        sub_events: One list of sub-event names for each top-level event that
            has children. These are taken to be the *first* ``len(sub_events)``
            top-level events; the remaining ones are treated as leaves.
        temp: Softmax temperature applied to the raw scores before normalising.

    Returns:
        A deep copy of ``output_template`` in which each top-level
        ``'percentage'`` is the temperature-softmaxed top-level score, and each
        sub-level ``'percentage'`` is its parent's percentage multiplied by the
        temperature-softmaxed sub-level score within that parent.

    Raises:
        KeyError: If a predicted label is missing from ``labels_to_indice``.
    """
    result = copy.deepcopy(output_template)

    num_top_event = len(top_events)
    # Number of leading top-level events that own sub events. Derived from the
    # data so that fewer (or more) than three groups do not raise IndexError.
    num_with_children = min(num_top_event, len(sub_events))

    top_event_scores = torch.zeros(num_top_event)
    sub_event_scores = [torch.zeros(len(sub)) for sub in sub_events]

    # Scatter the raw scores into their top-level / sub-level slots
    for label, score in zip(prediction['labels'], prediction['scores']):
        idx = labels_to_indice[label]

        # Top-level labels are indexed by a bare integer
        if isinstance(idx, int):
            top_event_scores[idx] = score
        else:
            top_idx, sub_idx = idx
            sub_event_scores[top_idx][sub_idx] = score

    # Score normalization across the top-level events
    top_event_scores = (top_event_scores / temp).softmax(dim=-1)

    # Store scores
    top_nodes = result[0]['children']
    for top_idx in range(num_top_event):
        top_event_score = top_event_scores[top_idx]
        top_event = top_nodes[top_idx]
        top_event['percentage'] = top_event_score.item()

        if top_idx < num_with_children:  # leaf events (e.g. Misc) have no children
            # Distribute the parent's share over its sub events
            sub_scores = top_event_score * (sub_event_scores[top_idx] / temp).softmax(dim=-1)
            for sub_idx in range(len(sub_events[top_idx])):
                top_event['children'][sub_idx]['percentage'] = sub_scores[sub_idx].item()

    return result

# Load the pretrained Model

In [19]:
model_name = 'checkpoint_9386.bin'
model_path = os.path.join(checkpoint_path, model_name)

# pipeline() takes a device ordinal: -1 means CPU, n >= 0 means cuda:n.
# torch.device("cpu").index is None, so map the CPU case to -1 explicitly
# instead of passing None through.
pipeline_device = device.index if device.type == "cuda" else -1

classifier = pipeline(
    'zero-shot-classification',
    tokenizer='facebook/bart-large-mnli',
    model=model_path,
    device=pipeline_device,
    framework='pt',
)

# Make Predictions

In [20]:
news_details = []

for result in tqdm(query_results):
    # Labels and template are passed by keyword: this works with the older
    # explicit signature and with newer zero-shot pipelines, which reject
    # additional positional arguments.
    # NOTE(review): `multi_class` is the older spelling of `multi_label`; it is kept
    # because the installed transformers version is not pinned — confirm and migrate.
    prediction = classifier(
        result['summary'],
        candidate_labels=candidate_labels,
        hypothesis_template=hypothesis_template,
        multi_class=True,
    )

    tags = format_output(prediction, labels_to_indice, output_template, top_events, sub_events)

    # Build a new record instead of extending result['tags'] in place, so that
    # re-running this cell does not append duplicate tags to `query_results`.
    news_details.append({**result, 'tags': result['tags'] + tags})

100%|██████████| 7256/7256 [43:13<00:00,  2.80it/s]


# Output the Predictions and Source Counts as a JSON File

In [21]:
# Top-level container that is serialised to disk at the end of the notebook
output_data = dict(news_details=news_details)

In [22]:
# Calculate the number of news from each news source
output_data['news_source'] = {}
source_counts = output_data['news_source']

for result in query_results:
    source = result['source']
    # Unseen sources start from zero
    source_counts[source] = source_counts.get(source, 0) + 1

In [23]:
output_path = os.path.join(dataset_path, 'news_info_9386.json')

# Create the target folder if needed so the write cannot fail with
# FileNotFoundError after the long prediction run.
os.makedirs(dataset_path, exist_ok=True)

with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(output_data, f, indent=4)