In [3]:
!pip install simpletransformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
! pip install jedi

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting jedi
  Downloading jedi-0.18.2-py2.py3-none-any.whl (1.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m62.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jedi
Successfully installed jedi-0.18.2


In [4]:
import warnings

import pandas as pd

import os
from datetime import datetime
import logging

import pandas as pd
from sklearn.model_selection import train_test_split
from simpletransformers.seq2seq import Seq2SeqModel, Seq2SeqArgs

#from utils import load_data, clean_unnecessary_spaces


def load_data(
    file_path, input_text_column, target_text_column, label_column, keep_label=1
):
    df = pd.read_csv(file_path, sep="\t", error_bad_lines=False)
    df = df.loc[df[label_column] == keep_label]
    df = df.rename(
        columns={input_text_column: "input_text", target_text_column: "target_text"}
    )
    df = df[["input_text", "target_text"]]
    df["prefix"] = "paraphrase"

    return df


def clean_unnecessary_spaces(out_string):
    if not isinstance(out_string, str):
        warnings.warn(f">>> {out_string} <<< is not a string.")
        out_string = str(out_string)
    out_string = (
        out_string.replace(" .", ".")
        .replace(" ?", "?")
        .replace(" !", "!")
        .replace(" ,", ",")
        .replace(" ' ", "'")
        .replace(" n't", "n't")
        .replace(" 'm", "'m")
        .replace(" 's", "'s")
        .replace(" 've", "'ve")
        .replace(" 're", "'re")
    )
    return out_string


logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.ERROR)

# Google Data
# train_df = pd.read_csv("data/train.tsv", sep="\t").astype(str)
# eval_df = pd.read_csv("data/dev.tsv", sep="\t").astype(str)

# train_df = train_df.loc[train_df["label"] == "1"]
# eval_df = eval_df.loc[eval_df["label"] == "1"]

# train_df = train_df.rename(
#     columns={"sentence1": "input_text", "sentence2": "target_text"}
# )
# eval_df = eval_df.rename(
#     columns={"sentence1": "input_text", "sentence2": "target_text"}
# )

# train_df = train_df[["input_text", "target_text"]]
# eval_df = eval_df[["input_text", "target_text"]]

# train_df["prefix"] = "paraphrase"
# eval_df["prefix"] = "paraphrase"

# # MSRP Data
# train_df = pd.concat(
#     [
#         train_df,
#         load_data("data/msr_paraphrase_train.txt", "#1 String", "#2 String", "Quality"),
#     ]
# )
# eval_df = pd.concat(
#     [
#         eval_df,
#         load_data("data/msr_paraphrase_test.txt", "#1 String", "#2 String", "Quality"),
#     ]
# )

# # Quora Data

# # The Quora Dataset is not separated into train/test, so we do it manually the first time.
# df = load_data(
#     "data/quora_duplicate_questions.tsv", "question1", "question2", "is_duplicate"
# )
# q_train, q_test = train_test_split(df)

# q_train.to_csv("data/quora_train.tsv", sep="\t")
# q_test.to_csv("data/quora_test.tsv", sep="\t")

# # The code block above only needs to be run once.
# # After that, the two lines below are sufficient to load the Quora dataset.

# # q_train = pd.read_csv("data/quora_train.tsv", sep="\t")
# # q_test = pd.read_csv("data/quora_test.tsv", sep="\t")

# train_df = pd.concat([train_df, q_train])
# eval_df = pd.concat([eval_df, q_test])


# MSRP Data
train_df = load_data("data/msr_paraphrase_train.txt", "#1 String", "#2 String", "Quality")
eval_df = load_data("data/msr_paraphrase_test.txt", "#1 String", "#2 String", "Quality")

train_df = train_df[["prefix", "input_text", "target_text"]]
eval_df = eval_df[["prefix", "input_text", "target_text"]]

train_df = train_df.dropna()
eval_df = eval_df.dropna()

train_df["input_text"] = train_df["input_text"].apply(clean_unnecessary_spaces)
train_df["target_text"] = train_df["target_text"].apply(clean_unnecessary_spaces)

eval_df["input_text"] = eval_df["input_text"].apply(clean_unnecessary_spaces)
eval_df["target_text"] = eval_df["target_text"].apply(clean_unnecessary_spaces)

print(train_df)

model_args = Seq2SeqArgs()
model_args.do_sample = True
model_args.eval_batch_size = 64
model_args.evaluate_during_training = True
model_args.evaluate_during_training_steps = 2500
model_args.evaluate_during_training_verbose = True
model_args.fp16 = False
model_args.learning_rate = 5e-5
model_args.max_length = 128
model_args.max_seq_length = 128
model_args.num_beams = 0
model_args.num_return_sequences = 3
model_args.num_train_epochs = 2
model_args.overwrite_output_dir = True
model_args.reprocess_input_data = True
model_args.save_eval_checkpoints = False
model_args.save_steps = -1
model_args.top_k = 50
model_args.top_p = 0.95
model_args.train_batch_size = 8
model_args.use_multiprocessing = False
model_args.wandb_project = "Paraphrasing with BART"


model = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="facebook/bart-large",
    args=model_args,
)

model.train_model(train_df, eval_data=eval_df, output_dir="data/text_gen")

# to_predict = [
#     prefix + ": " + str(input_text)
#     for prefix, input_text in zip(eval_df["prefix"].tolist(), eval_df["input_text"].tolist())
# ]
# truth = eval_df["target_text"].tolist()

# preds = model.predict(to_predict)

# # Saving the predictions if needed
# os.makedirs("predictions", exist_ok=True)

# with open(f"predictions/predictions_{datetime.now()}.txt", "w") as f:
#     for i, text in enumerate(eval_df["input_text"].tolist()):
#         f.write(str(text) + "\n\n")

#         f.write("Truth:\n")
#         f.write(truth[i] + "\n\n")

#         f.write("Prediction:\n")
#         for pred in preds[i]:
#             f.write(str(pred) + "\n")
#         f.write(
#             "________________________________________________________________________________\n"
#         )





  df = pd.read_csv(file_path, sep="\t", error_bad_lines=False)
b'Skipping line 102: expected 5 fields, saw 6\nSkipping line 656: expected 5 fields, saw 6\nSkipping line 867: expected 5 fields, saw 6\nSkipping line 880: expected 5 fields, saw 6\nSkipping line 980: expected 5 fields, saw 6\nSkipping line 1439: expected 5 fields, saw 6\nSkipping line 1473: expected 5 fields, saw 6\nSkipping line 1822: expected 5 fields, saw 6\nSkipping line 1952: expected 5 fields, saw 6\nSkipping line 2009: expected 5 fields, saw 6\nSkipping line 2230: expected 5 fields, saw 6\n'


  df = pd.read_csv(file_path, sep="\t", error_bad_lines=False)
b'Skipping line 34: expected 5 fields, saw 6\nSkipping line 121: expected 5 fields, saw 6\nSkipping line 211: expected 5 fields, saw 6\nSkipping line 263: expected 5 fields, saw 6\nSkipping line 345: expected 5 fields, saw 6\nSkipping line 696: expected 5 fields, saw 6\nSkipping line 733: expected 5 fields, saw 6\nSkipping line 847: expected 5 fields, saw 6\n'
A 

          prefix                                         input_text  \
0     paraphrase  Amrozi accused his brother, whom he called "th...   
2     paraphrase  They had published an advertisement on the Int...   
4     paraphrase  The stock rose $2.11, or about 11 percent, to ...   
5     paraphrase  Revenue in the first quarter of the year dropp...   
7     paraphrase  The DVD-CCA then appealed to the state Supreme...   
...          ...                                                ...   
2425  paraphrase  Attackers detonated a second roadside bomb lat...   
2428  paraphrase  Southwest said it recently exercised remaining...   
2430  paraphrase  My judgment is 95 percent of that information ...   
2431  paraphrase  Following the ATP's notification by Hewitt's l...   
2433  paraphrase  Dotson, 21, admitted to FBI agents that he sho...   

                                            target_text  
0     Referring to him as only "the witness", Amrozi...  
2     On June 10, the ship's ow

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

  0%|          | 0/1610 [00:00<?, ?it/s]

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Running Epoch 0 of 2:   0%|          | 0/202 [00:00<?, ?it/s]

  0%|          | 0/633 [00:00<?, ?it/s]

Running Epoch 1 of 2:   0%|          | 0/202 [00:00<?, ?it/s]

  0%|          | 0/633 [00:00<?, ?it/s]

Generating outputs:   0%|          | 0/10 [00:00<?, ?it/s]

TypeError: ignored

In [5]:
#model.save_model("data/text_gen")

In [None]:
#model.save("data/text_gen_1")

In [6]:
import logging

from simpletransformers.seq2seq import Seq2SeqModel


logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.ERROR)

model = Seq2SeqModel(
    encoder_decoder_type="bart", encoder_decoder_name="outputs"
)


while True:
    #original = input("Enter text to paraphrase: ")
    original = "The world's two largest automakers said their U.S. sales declined more than predicted last month as a late summer sales frenzy caused more of an industry backlash than expected."
    to_predict = [original]

    preds = model.predict(to_predict)

    print("---------------------------------------------------------")
    print(original)

    print()
    print("Predictions >>>")
    for pred in preds[0]:
        print(pred)

    print("---------------------------------------------------------")
    print()

Generating outputs:   0%|          | 0/1 [00:00<?, ?it/s]

TypeError: ignored