In [8]:
import os
import argparse
import torch
import pandas as pd
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np
import warnings
warnings.filterwarnings('ignore')

In [7]:
def choose_from_top_k_top_n(probs, k=50, p=0.8):
    """Sample one token id using combined top-k / nucleus (top-p) filtering.

    The ``k`` most probable entries of ``probs`` are ranked from most to least
    likely. Entries are then kept, in that order, until their cumulative
    probability reaches ``p`` (the entry that crosses the threshold is kept).
    A token is drawn from the kept entries after renormalising their
    probabilities so they sum to 1.

    Args:
        probs: 1-D array-like of per-token probabilities, indexed by token id.
        k: Maximum number of candidates to consider. Values larger than the
            vocabulary size are clamped to it.
        p: Cumulative probability mass at which the candidate list is cut.

    Returns:
        The sampled token id as a plain Python ``int``.

    Raises:
        ValueError: If ``probs`` is empty or the retained candidates carry no
            probability mass.
    """
    probs = np.asarray(probs, dtype=np.float64)
    if probs.size == 0:
        raise ValueError("probs must contain at least one probability")

    # argpartition raises when asked for more elements than exist, so clamp k.
    k = max(1, min(int(k), probs.size))

    candidate_ids = np.argpartition(probs, -k)[-k:]
    # Rank the k candidates from most to least probable.
    ranking = np.argsort(probs[candidate_ids])[::-1]
    candidate_ids = candidate_ids[ranking]
    candidate_probs = probs[candidate_ids]

    # Keep the shortest prefix whose cumulative mass reaches p; if the mass
    # never reaches p, the slice simply keeps all k candidates.
    cumulative = np.cumsum(candidate_probs)
    keep = int(np.searchsorted(cumulative, p, side="left")) + 1
    nucleus_ids = candidate_ids[:keep]
    nucleus_probs = candidate_probs[:keep]

    total = nucleus_probs.sum()
    if not total > 0:
        raise ValueError("selected candidates have no probability mass")

    # Drawing without a size argument yields a scalar, so int() is well defined.
    token_id = np.random.choice(nucleus_ids, p=nucleus_probs / total)
    return int(token_id)

def generate(tokenizer, model, sentences, label, max_length=100):
    """Sample ``sentences`` text snippets from ``model``, each prompted with ``label``.

    Every snippet starts from the encoded ``label`` and is extended one token
    at a time with ``choose_from_top_k_top_n`` until the end-of-text token is
    sampled or ``max_length`` tokens have been added. Snippets are returned
    whether or not the end-of-text token was reached.

    Args:
        tokenizer: Object providing ``encode(str)`` and ``decode(list_of_ids)``.
        model: Causal language model; called with a ``(1, seq_len)`` tensor of
            token ids, and its first output is used as the logits.
        sentences: Number of snippets to generate. A plain integer or any
            single-element numpy/pandas container is accepted.
        label: Prompt text placed at the start of every snippet.
        max_length: Maximum number of tokens sampled per snippet.

    Returns:
        List of decoded strings (prompt plus generated continuation, including
        the end-of-text marker when it was sampled).
    """
    # Callers may pass a one-element Series/array; range() needs a real int.
    num_sentences = int(np.asarray(sentences).item())

    # Follow the model's device when it exposes one; default to CPU otherwise.
    device = getattr(model, "device", torch.device("cpu"))

    # Loop invariants: the prompt ids and the end-of-text id(s).
    prompt_ids = tokenizer.encode(label)
    eos_ids = set(tokenizer.encode('<|endoftext|>'))

    generated_text = []
    with torch.no_grad():
        for _ in range(num_sentences):
            cur_ids = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)
            for _ in range(max_length):
                # No labels are passed: the loss is not needed for sampling,
                # and without labels the first output holds the logits.
                logits = model(cur_ids)[0]
                next_token_probs = torch.softmax(logits[0, -1], dim=0)

                # top-k / top-p sampling on the distribution for the last position
                next_token_id = choose_from_top_k_top_n(next_token_probs.to('cpu').numpy())
                next_token = torch.tensor([[next_token_id]], dtype=torch.long, device=device)
                cur_ids = torch.cat([cur_ids, next_token], dim=1)

                if next_token_id in eos_ids:
                    break

            generated_text.append(tokenizer.decode(cur_ids[0].tolist()))
    return generated_text

def load_models(model_name):
    """Load the GPT-2 medium tokenizer and the fine-tuned model weights.

    Args:
        model_name: Path to a checkpoint file containing a ``state_dict``
            compatible with ``GPT2LMHeadModel`` ('gpt2-medium').

    Returns:
        Tuple ``(tokenizer, model)`` with the checkpoint weights loaded into
        the model and the model switched to evaluation mode.
    """
    print ('Loading Trained GPT-2 Model')
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
    model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_name, map_location='cpu')
    model.load_state_dict(state_dict)
    # The model is only used for sampling here, so make inference mode explicit.
    model.eval()
    return tokenizer, model

In [4]:
# Path to the fine-tuned GPT-2 checkpoint. Set the HURRICANE_GPT2_CHECKPOINT
# environment variable to run on another machine; the original local path is
# kept as the default so existing setups behave as before.
MODEL_NAME = os.environ.get(
    'HURRICANE_GPT2_CHECKPOINT',
    r'C:\Users\gupta\OneDrive\Desktop\ADT\Projects\Code\hurrican_GPT2.pt',
)

In [5]:
# Load the tokenizer and the fine-tuned model once; later cells reuse them.
TOKENIZER, MODEL = load_models(MODEL_NAME)

Loading Trained GPT-2 Model


In [9]:
# Read the hurricane dataset from the working directory for class-balance inspection.
df = pd.read_csv("hurricane.csv")

In [10]:
# Number of rows per label. size() counts rows directly, so the result does not
# depend on column order or on missing values in some other column.
class_distribution_df = df.groupby('label_text').size().reset_index(name='count')

In [17]:
# Show the per-label counts.
class_distribution_df

Unnamed: 0,label_text,count
0,affected_individuals,328
1,infrastructure_and_utility_damage,907
2,injured_or_dead_people,159
3,missing_or_found_people,15
4,rescue_volunteering_or_donation_effort,2625
5,vehicle_damage,50


In [16]:
# Count for a single label, selected with one .loc lookup (rows by mask, column by name).
class_distribution_df.loc[class_distribution_df['label_text'] == 'affected_individuals', 'count']

0    328
Name: count, dtype: int64

In [13]:
# Largest value in the 'count' column.
class_distribution_df['count'].max()

2625

In [6]:

# For each dataset, generate synthetic sentences for every label until it has
# as many rows as the most frequent label. Results are kept per dataset in
# `generated_by_dataset` (dataset path -> DataFrame of generated rows).
LABELS = ['affected_individuals', 'infrastructure_and_utility_damage', 'injured_or_dead_people', 'missing_or_found_people',
          'rescue_volunteering_or_donation_effort', 'vehicle_damage']
EOS_TOKEN = "<|endoftext|>"
generated_by_dataset = {}

for dataset_path in ["hurricane.csv","flood.csv","earthquake.csv","wildfire.csv"]:
    df = pd.read_csv(dataset_path)
    # Rows per label, independent of column order.
    class_distribution_df = df.groupby('label_text').size().reset_index(name='count')
    generated_frames = []
    if not class_distribution_df.empty:
        MAX_NO_OF_SENTENCES = int(class_distribution_df['count'].max())
        label_counts = class_distribution_df.set_index('label_text')['count']
        for LABEL in LABELS:
            # A label absent from this dataset has zero rows; .get covers that
            # case and yields a scalar, so SENTENCES is a plain int.
            SENTENCES = MAX_NO_OF_SENTENCES - int(label_counts.get(LABEL, 0))
            if SENTENCES <= 0:
                continue
            generated_text = generate(TOKENIZER, MODEL, SENTENCES, LABEL)
            # str.strip(EOS_TOKEN) would remove any of the token's characters
            # from both ends of the text; remove the marker itself instead.
            generated_text = [text.replace(EOS_TOKEN, "").strip() for text in generated_text]
            labels = [LABEL] * len(generated_text)
            tmp_df = pd.DataFrame({'text': generated_text, 'label_text': labels})
            generated_frames.append(tmp_df)
    # Concatenate once per dataset instead of growing a frame inside the loop.
    if generated_frames:
        generated_df = pd.concat(generated_frames, ignore_index=True)
    else:
        generated_df = pd.DataFrame(columns=['text', 'label_text'])
    generated_by_dataset[dataset_path] = generated_df


missing or found people: People wake up to the news of a missing # Harvey # Jax family... more  <|endoftext|>
missing or found people: RT @JL4YTE : Florida man has a 4 feet high 3 foot high rubber raft floated down 1.5 miles to raise money for Harvey victims  <|endoftext|>
missing or found people: RT @WJBF : 5 people found trapped in their van after Harvey  <|endoftext|>
missing or found people: How much debris are we talking about? @JuanCruz8  <|endoftext|>
missing or found people: Florida family faces deportation after Hurricane Irma  <|endoftext|>


['missing or found people: People wake up to the news of a missing # Harvey # Jax family... more  <|endoftext|>',
 'missing or found people: RT @JL4YTE : Florida man has a 4 feet high 3 foot high rubber raft floated down 1.5 miles to raise money for Harvey victims  <|endoftext|>',
 'missing or found people: RT @WJBF : 5 people found trapped in their van after Harvey  <|endoftext|>',
 'missing or found people: How much debris are we talking about? @JuanCruz8  <|endoftext|>',
 'missing or found people: Florida family faces deportation after Hurricane Irma  <|endoftext|>']

In [20]:
# Remove the end-of-text marker. str.strip("<|endoftext|>") treats its argument
# as a set of characters and can also eat letters at the ends of the text, so
# the marker is removed as a whole substring instead.
generated_text = [text.replace("<|endoftext|>", "").strip() for text in generated_text]
label = ['x'] * len(generated_text)
label

['x', 'x', 'x', 'x', 'x']

In [23]:
# Preview the generated texts paired with their labels as a two-column frame.
pd.DataFrame(list(zip(generated_text,label)))

Unnamed: 0,0,1
0,missing or found people: People wake up to the...,x
1,missing or found people: RT @JL4YTE : Florida ...,x
2,missing or found people: RT @WJBF : 5 people f...,x
3,missing or found people: How much debris are w...,x
4,missing or found people: Florida family faces ...,x


In [22]:
# Ad-hoc sampling run; uses SENTENCES and LABEL as set by an earlier cell.
generate(TOKENIZER, MODEL, SENTENCES, LABEL)

missing or found people: Trump signs travel ban... but does nothing to assist stranded children on the Florida coast  <|endoftext|>
missing or found people: Families of missing family members hold a candlelight vigil in Miami  <|endoftext|>
missing or found people: Photos of the devastation by Irma from Miami - Dade, Florida, on Oct. 27  <|endoftext|>
missing or found people: One More Day: Harvey Still Determining Impact on Families  <|endoftext|>
missing or found people: RT @michael_mullins : In search of help after Hurricane Maria leaves islands reeling  <|endoftext|>
