In [2]:
!pip install -Uq transformers rich[jupyter] sentencepiece gdown flask-ngrok

In [56]:
# --- Inference configuration ------------------------------------------------
# Encoder input budget in tokens (tokenizer max_length in `generate`).
MAX_SOURCE_TEXT_LENGTH: int = 512
# Upper bound on the generated title length (model.generate max_length).
MAX_TARGET_TEXT_LENGTH: int = 17
# NOTE(review): the NEWS_PER_STORY_* values are not referenced in the visible
# notebook — presumably left over from data preparation; confirm before removing.
NEWS_PER_STORY_PUBLIC: int = 5
NEWS_PER_STORY_OTHER: int = 1
# Number of headlines tokenized and decoded per model.generate call.
BATCH_SIZE: int = 8
# NOTE(review): unused at inference time; presumably a training leftover.
TRAIN_EPOCHS: int = 3
# Beam width passed to model.generate.
NUM_BEAMS: int = 2

# Directory created by unpacking model.tar.gz (holds model_files/).
OUTPUT_DIR: str = 'output_dir'

In [57]:
import os
from pathlib import Path
import requests
import numpy as np
import pandas as pd
import torch
import pickle

from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from sklearn.pipeline import Pipeline
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from transformers import T5Tokenizer, T5ForConditionalGeneration

In [74]:
# Fetch the fine-tuned T5 checkpoint archive (model.tar.gz) and the pickled
# TF-IDF vectorizer (tfidf.pickle), then unpack the archive, which creates
# output_dir/model_files/ (config, SentencePiece model, weights, tokenizer files).
# NOTE(review): the saved output of this cell also contains a wget log for
# index.html that these commands do not produce — stale output from an
# out-of-order run; re-run the notebook top to bottom to refresh it.
!gdown https://drive.google.com/uc?id=1ekYuFbcJnpwMT5CcIxgaOc6hprBen9ij
!gdown https://drive.google.com/uc?id=1a4WdZyQ5zdJ_S7oGdJUMLR-zqW8FwkKx
!tar xvf model.tar.gz

Downloading...
From: https://drive.google.com/uc?id=1ekYuFbcJnpwMT5CcIxgaOc6hprBen9ij
To: /content/model.tar.gz
905MB [00:03, 260MB/s]
Downloading...
From: https://drive.google.com/uc?id=1a4WdZyQ5zdJ_S7oGdJUMLR-zqW8FwkKx
To: /content/tfidf.pickle
5.35MB [00:00, 250MB/s]
--2021-08-21 23:49:34--  https://raw.githubusercontent.com/sevskii111/one-hot-gen/main/frontend/index.html
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1095527 (1.0M) [text/plain]
Saving to: ‘index.html’


2021-08-21 23:49:34 (58.1 MB/s) - ‘index.html’ saved [1095527/1095527]

output_dir/model_files/
output_dir/model_files/config.json
output_dir/model_files/spiece.model
output_dir/model_files/pytorch_model.bin
output_dir/model_files/tokenizer_config.json
output_dir/model_files/special_toke

In [59]:
# Load the fine-tuned T5 checkpoint and its tokenizer from the unpacked
# archive. The path is derived from OUTPUT_DIR so the location is configured
# in a single place (the config cell).
PATH = Path(OUTPUT_DIR).absolute() / 'model_files'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = T5ForConditionalGeneration.from_pretrained(PATH)
model.to(device)
tokenizer = T5Tokenizer.from_pretrained(PATH)

def generate(texts, **kwargs):
    """Generate one title per input text with the fine-tuned T5 model.

    Texts are processed in chunks of ``BATCH_SIZE``. Each chunk is truncated
    and padded to ``MAX_SOURCE_TEXT_LENGTH`` tokens and decoded with beam
    search using the module-level ``model`` and ``tokenizer``.

    Args:
        texts: Sequence of input strings supporting ``len`` and slicing
            (a list, or a 1-D array such as ``df["X"].values``).
        **kwargs: Extra keyword arguments forwarded to ``model.generate``;
            they override the defaults below (e.g. ``num_beams=4``).

    Returns:
        List of decoded strings, one per input text, in input order.
    """
    generate_kwargs = dict(
        max_length=MAX_TARGET_TEXT_LENGTH,
        num_beams=NUM_BEAMS,
        repetition_penalty=1.0,
        length_penalty=1.0,
        early_stopping=True,
    )
    # Caller-supplied options win over the notebook defaults.
    generate_kwargs.update(kwargs)

    results = []
    for start in range(0, len(texts), BATCH_SIZE):
        # The tokenizer expects a list of strings; numpy slices are converted.
        texts_batch = list(texts[start:start + BATCH_SIZE])

        source = tokenizer.batch_encode_plus(
            texts_batch,
            max_length=MAX_SOURCE_TEXT_LENGTH,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # Keep the (batch, seq_len) shape: squeezing would collapse a final
        # batch of size 1 into a 1-D tensor that generate cannot batch over.
        ids = source["input_ids"].to(device, dtype=torch.long)
        mask = source["attention_mask"].to(device, dtype=torch.long)
        generated_ids = model.generate(
            input_ids=ids,
            attention_mask=mask,
            **generate_kwargs,
        )
        results.extend(
            tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for g in generated_ids
        )
    return results

In [60]:
# Download the validation split as dataset_valid.csv (quietly, -q).
# NOTE(review): no visible cell reads dataset_valid.csv — confirm whether it
# is still needed by this demo.
!wget https://github.com/sevskii111/one-hot-gen/blob/main/datasets/dataset_valid.csv?raw=true -O dataset_valid.csv -q

In [61]:
def predict_easy(df):
    """Heuristic story naming: reuse a headline prefix shared by a story's news.

    For every story id in ``df['y']``, the most frequent 5-, 4-, ..., 1-word
    prefix of its headlines in ``df['X']`` is tried, longest first. A prefix is
    accepted when its diversity — the number of distinct prefixes divided by
    ``n * ln(n)`` for a story with ``n`` headlines — is at or below a
    per-length threshold. It must also be non-empty and must not end with the
    Russian prepositions "на" / "в". A trailing ":" is stripped from the
    accepted prefix.

    Stories with a single headline (``ln(1) == 0`` makes the diversity score
    infinite) or without an acceptable prefix get no row; callers must fall
    back to another predictor for them.

    Args:
        df: DataFrame with a str headline column ``'X'`` and a story id column
            ``'y'``. Not modified.

    Returns:
        DataFrame with columns ``'title'`` (story id) and ``'result'``
        (predicted name), at most one row per story, in first-appearance order.
    """
    df = df.copy()
    max_words = 5
    for w in range(1, max_words + 1):
        df[f"{w}_words"] = df['X'].apply(lambda h: ' '.join(h.split(' ')[:w]))

    # Acceptance threshold per prefix length (index = word count - 1): a tuned
    # scale factor c times a base profile that is most permissive at 3 words.
    c = 1.6
    thresholds = [0.1 * c, 0.15 * c, 0.2 * c, 0.15 * c, 0.1 * c]
    rejected_last_words = {'на', 'в'}

    result = {'title': [], 'result': []}

    # sort=False keeps stories in first-appearance order, like df['y'].unique().
    for title, samples in df.groupby('y', sort=False):
        n_samples = len(samples)
        if n_samples < 2:
            # ln(1) == 0 -> infinite diversity score: never accepted.
            continue
        for w in range(max_words, 0, -1):
            column = samples[f'{w}_words']
            diversity = column.nunique(dropna=False) / n_samples / np.log(n_samples)
            res = column.value_counts().index[0]
            # Empty prefixes (empty headlines) are never a usable name.
            if not res or diversity > thresholds[w - 1]:
                continue
            if res.split(' ')[-1].lower() in rejected_last_words:
                continue
            if res.endswith(':'):
                res = res[:-1]
            result['title'].append(title)
            result['result'].append(res)
            break
    return pd.DataFrame(result)

In [63]:
def predict_t5(df, tfidf_path='./tfidf.pickle'):
    """Model-based story naming: generate a candidate per headline, pick one per story.

    Every headline in ``df['X']`` becomes a candidate title via ``generate``.
    For each story id in ``df['y']``, the story's candidates are joined with
    ". " and scored with the pickled TF-IDF vectorizer. The candidate whose
    words contain the largest fraction of the story's two top-weighted TF-IDF
    terms wins; ties go to the lexicographically greatest candidate.

    Args:
        df: DataFrame with headline column ``'X'`` and story id column ``'y'``.
        tfidf_path: Path of the pickled, already fitted TF-IDF vectorizer.

    Returns:
        DataFrame with columns ``'result'`` (chosen candidate) and ``'title'``
        (story id), one row per story, in first-appearance order.
    """
    # Generate from this frame's own headlines so candidates line up
    # row-for-row with df['y'].
    preds = generate(df["X"].values)

    # NOTE(review): pickle.load can execute arbitrary code — only load a
    # trusted file (this one is downloaded from Google Drive).
    with open(tfidf_path, 'rb') as handle:
        tfidf = pickle.load(handle)

    # scikit-learn 1.0 renamed get_feature_names -> get_feature_names_out and
    # 1.2 removed the old name; support both.
    get_names = getattr(tfidf, "get_feature_names_out", None) or tfidf.get_feature_names
    feature_names = np.asarray(get_names())

    def get_top_tf_idf_words(response, top_n=2):
        """Return the ``top_n`` highest-weighted terms of a one-row sparse TF-IDF result."""
        sorted_nzs = np.argsort(response.data)[:-(top_n + 1):-1]
        return feature_names[response.indices[sorted_nzs]]

    predictions = pd.DataFrame({"Actual Text": df['y'].values, "Generated Text": preds})

    dl_results = []
    for gt, group in predictions.groupby("Actual Text", sort=False):
        curr_preds = group["Generated Text"]
        t_text = tfidf.transform(['. '.join(curr_preds.values)])
        top_words = set(get_top_tf_idf_words(t_text, 2))

        variants = []
        for pred in curr_preds:
            pred_words = pred.lower().split(' ')
            # Fraction of the candidate's words that are top TF-IDF terms;
            # str.split always yields at least one item, so no zero division.
            score = len(top_words.intersection(pred_words)) / len(pred_words)
            variants.append((score, pred))
        _, best = max(variants)
        dl_results.append((best, gt))

    return pd.DataFrame(dl_results, columns=["result", "title"])

In [64]:
def predict(df):
    """Name every story, preferring the heuristic predictor over T5.

    ``predict_easy`` is consulted first; stories it leaves unresolved fall
    back to ``predict_t5``. The (expensive) T5 pass is skipped entirely when
    the heuristic already covers every story.

    Args:
        df: DataFrame with headline column ``'X'`` and story id column ``'y'``.

    Returns:
        DataFrame with columns ``'story_id'`` and ``'story_name'``, one row per
        story in first-appearance order of ``df['y']``.
    """
    story_ids = df["y"].unique()

    easy_preds = predict_easy(df)
    # Each predictor emits at most one row per story, so a dict is lossless
    # and avoids re-filtering the frame for every story.
    easy_names = dict(zip(easy_preds["title"], easy_preds["result"]))

    t5_names = {}
    if any(story_id not in easy_names for story_id in story_ids):
        t5_preds = predict_t5(df)
        t5_names = dict(zip(t5_preds["title"], t5_preds["result"]))

    result = [
        (story_id, easy_names[story_id] if story_id in easy_names else t5_names[story_id])
        for story_id in story_ids
    ]
    return pd.DataFrame(result, columns=["story_id", "story_name"])

In [130]:
# Fetch the demo frontend (index.html + main.js) that the Flask app serves.
# NOTE(review): the `?y` query string looks like a cache-buster for
# raw.githubusercontent.com — confirm it has no other purpose.
!wget https://raw.githubusercontent.com/sevskii111/one-hot-gen/main/frontend/index.html?y -O index.html
!wget https://raw.githubusercontent.com/sevskii111/one-hot-gen/main/frontend/main.js?y -O main.js

--2021-08-22 01:11:51--  https://raw.githubusercontent.com/sevskii111/one-hot-gen/main/frontend/index.html?y
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9309 (9.1K) [text/plain]
Saving to: ‘index.html’


2021-08-22 01:11:51 (75.0 MB/s) - ‘index.html’ saved [9309/9309]

--2021-08-22 01:11:51--  https://raw.githubusercontent.com/sevskii111/one-hot-gen/main/frontend/main.js?y
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1037372 (1013K) [text/plain]
Saving to: ‘main.js’


2021-08-22 01:11:51 (51.0 MB/s) - ‘main.js’ saved [10373

In [None]:
import sys
from flask import Flask, send_file, request
from flask_ngrok import run_with_ngrok


# Demo web server; run_with_ngrok hooks app.run() so the server is also
# exposed through an ngrok tunnel (public URL shown in the cell output).
app = Flask(__name__)
run_with_ngrok(app)


def home():
    """Serve the single-page frontend."""
    return send_file('index.html')


def main():
    """Serve the frontend's JavaScript bundle."""
    return send_file('main.js')


# Explicit registration; endpoints default to the view names ("home", "main").
app.add_url_rule("/", view_func=home)
app.add_url_rule("/main.js", view_func=main)

@app.route("/get_preds", methods=["POST"])
def get_preds():
    """Return story-name predictions for the posted news.

    Expects a JSON body that ``pd.DataFrame`` can consume (e.g. a list of
    records or a column dict) with headline column ``'X'`` and story id
    column ``'y'``. A missing or invalid body, or missing columns, is answered
    with HTTP 400 instead of crashing into a 500. On success the body is
    ``DataFrame.to_json()`` of ``predict``'s result.
    """
    payload = request.get_json(silent=True)
    if payload is None:
        return "Request body must be JSON.", 400
    try:
        df = pd.DataFrame(payload)
    except (ValueError, TypeError):
        return "JSON body cannot be converted to a table.", 400
    missing = sorted({"X", "y"} - set(df.columns))
    if missing:
        return f"Missing required column(s): {', '.join(missing)}", 400
    preds = predict(df)
    return preds.to_json()


app.run()

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)


 * Running on http://5a7e-35-204-100-20.ngrok.io
 * Traffic stats available on http://127.0.0.1:4040


127.0.0.1 - - [22/Aug/2021 01:11:57] "[37mGET / HTTP/1.1[0m" 200 -
127.0.0.1 - - [22/Aug/2021 01:11:57] "[37mGET /main.js HTTP/1.1[0m" 200 -
127.0.0.1 - - [22/Aug/2021 01:11:59] "[33mGET /favicon.ico HTTP/1.1[0m" 404 -
127.0.0.1 - - [22/Aug/2021 01:12:56] "[37mGET / HTTP/1.1[0m" 200 -
127.0.0.1 - - [22/Aug/2021 01:12:57] "[37mGET /main.js HTTP/1.1[0m" 200 -
  'stop_words.' % sorted(inconsistent))
127.0.0.1 - - [22/Aug/2021 01:13:18] "[37mPOST /get_preds HTTP/1.1[0m" 200 -
127.0.0.1 - - [22/Aug/2021 01:13:54] "[37mGET / HTTP/1.1[0m" 200 -
127.0.0.1 - - [22/Aug/2021 01:13:55] "[37mGET /main.js HTTP/1.1[0m" 200 -
127.0.0.1 - - [22/Aug/2021 01:13:57] "[37mGET / HTTP/1.1[0m" 200 -
127.0.0.1 - - [22/Aug/2021 01:13:58] "[37mGET /main.js HTTP/1.1[0m" 200 -
  'stop_words.' % sorted(inconsistent))
127.0.0.1 - - [22/Aug/2021 01:14:23] "[37mPOST /get_preds HTTP/1.1[0m" 200 -
127.0.0.1 - - [22/Aug/2021 01:15:19] "[37mGET / HTTP/1.1[0m" 200 -
127.0.0.1 - - [22/Aug/2021 01:15: