In [1]:
!git clone https://github.com/chaitjo/personalized-dialog.git
!pip install simplet5

Cloning into 'personalized-dialog'...
remote: Enumerating objects: 1823, done.[K
remote: Total 1823 (delta 0), reused 0 (delta 0), pack-reused 1823[K
Receiving objects: 100% (1823/1823), 1.04 GiB | 23.01 MiB/s, done.
Resolving deltas: 100% (1096/1096), done.
Collecting simplet5
  Downloading simplet5-0.1.4.tar.gz (7.3 kB)
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 13.1 MB/s 
Collecting transformers==4.16.2
  Downloading transformers-4.16.2-py3-none-any.whl (3.5 MB)
[K     |████████████████████████████████| 3.5 MB 41.0 MB/s 
[?25hCollecting pytorch-lightning==1.5.10
  Downloading pytorch_lightning-1.5.10-py3-none-any.whl (527 kB)
[K     |████████████████████████████████| 527 kB 48.8 MB/s 
[?25hCollecting PyYAML>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K 

In [1]:
# let's get a dataset
import pandas as pd
from sklearn.model_selection import train_test_split

# Remote CSV holding the news articles ("text") and their headlines ("headlines").
path = "https://raw.githubusercontent.com/Shivanandroy/T5-Finetuning-PyTorch/main/data/news_summary.csv"

# Read the whole file into memory, then preview the first five rows.
df = pd.read_csv(filepath_or_buffer=path)
df.head(n=5)

Unnamed: 0,headlines,text
0,upGrad learner switches to career in ML & Al w...,"Saurav Kant, an alumnus of upGrad and IIIT-B's..."
1,Delhi techie wins free food from Swiggy for on...,Kunal Shah's credit card bill payment platform...
2,New Zealand end Rohit Sharma-led India's 12-ma...,New Zealand defeated India by 8 wickets in the...
3,Aegon life iTerm insurance plan helps customer...,"With Aegon Life iTerm Insurance plan, customer..."
4,"Have known Hirani for yrs, what if MeToo claim...",Speaking about the sexual harassment allegatio...


In [2]:
# simpleT5 expects dataframe to have 2 columns: "source_text" and "target_text"
df = df.rename(columns={"headlines": "target_text", "text": "source_text"})
# Explicit copy so the assignment below never writes into a slice of the raw frame.
df = df[["source_text", "target_text"]].copy()

# T5 model expects a task related prefix: since it is a summarization task, we will add a prefix "summarize: "
TASK_PREFIX = "summarize: "
# Only prefix rows that do not already carry it, so re-running this cell does not
# produce "summarize: summarize: ...". Missing values count as "not prefixed" and
# stay missing, exactly as with plain string concatenation.
already_prefixed = df["source_text"].str.startswith(TASK_PREFIX, na=False)
df["source_text"] = df["source_text"].where(already_prefixed, TASK_PREFIX + df["source_text"])
df

Unnamed: 0,source_text,target_text
0,"summarize: Saurav Kant, an alumnus of upGrad a...",upGrad learner switches to career in ML & Al w...
1,summarize: Kunal Shah's credit card bill payme...,Delhi techie wins free food from Swiggy for on...
2,summarize: New Zealand defeated India by 8 wic...,New Zealand end Rohit Sharma-led India's 12-ma...
3,summarize: With Aegon Life iTerm Insurance pla...,Aegon life iTerm insurance plan helps customer...
4,summarize: Speaking about the sexual harassmen...,"Have known Hirani for yrs, what if MeToo claim..."
...,...,...
98396,summarize: A CRPF jawan was on Tuesday axed to...,CRPF jawan axed to death by Maoists in Chhatti...
98397,"summarize: 'Uff Yeh', the first song from the ...",First song from Sonakshi Sinha's 'Noor' titled...
98398,"summarize: According to reports, a new version...",'The Matrix' film to get a reboot: Reports
98399,summarize: A new music video shows rapper Snoo...,Snoop Dogg aims gun at clown dressed as Trump ...


In [3]:
# Hold out 20% of the rows for evaluation. The split is seeded so that the rows
# later sliced off for training/evaluation are the same on every run
# (42 matches the global seed reported in the training log).
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
train_df.shape, test_df.shape

((78720, 2), (19681, 2))

In [4]:
from simplet5 import SimpleT5

# Create the simpleT5 wrapper and load the pretrained "t5-base" checkpoint into it.
model = SimpleT5()
model.from_pretrained(model_type="t5", model_name="t5-base")

# Fine-tune on a small slice of the data: the first 5000 training rows,
# evaluated against the first 100 held-out rows.
train_subset = train_df.iloc[:5000]
eval_subset = test_df.iloc[:100]

model.train(
    train_df=train_subset,
    eval_df=eval_subset,
    source_max_token_len=128,  # token limit for the article text
    target_max_token_len=50,   # token limit for the headline
    batch_size=8,
    max_epochs=3,
    use_gpu=True,
)

Global seed set to 42


Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.32M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.17k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/850M [00:00<?, ?B/s]

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Missing logger folder: /content/lightning_logs

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 222 M 
-----------------------------------------------------
222 M     Trainable params
0         Non-trainable params
222 M     Total params
891.614   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [None]:
# let's load the trained model for inferencing:
from pathlib import Path

# The checkpoint directory name embeds the epoch and that epoch's training loss
# (e.g. "outputs/SimpleT5-epoch-2-train-loss-0.9478"). The loss differs from run
# to run, so a hard-coded name breaks on re-run; pick the most recently written
# checkpoint directory instead.
# NOTE(review): the character classes also accept a lower-case "simplet5-" prefix,
# which some simplet5 releases appear to use — confirm against the installed version.
checkpoint_dirs = [
    candidate
    for candidate in Path("outputs").glob("[Ss]imple[Tt]5-epoch-*")
    if candidate.is_dir()
]
if not checkpoint_dirs:
    raise FileNotFoundError(
        "No SimpleT5 checkpoint found under 'outputs/'; run the training cell first."
    )
latest_checkpoint = max(checkpoint_dirs, key=lambda candidate: candidate.stat().st_mtime)
model.load_model("t5", str(latest_checkpoint), use_gpu=True)

# The input carries the same "summarize: " task prefix that was added to the training data.
text_to_summarize="""summarize: Rahul Gandhi has replied to Goa CM Manohar Parrikar's letter, 
which accused the Congress President of using his "visit to an ailing man for political gains". 
"He's under immense pressure from the PM after our meeting and needs to demonstrate his loyalty by attacking me," 
Gandhi wrote in his letter. Parrikar had clarified he didn't discuss Rafale deal with Rahul.
"""
model.predict(text_to_summarize)

['Rahul responds to Goa CM accusing him of using visit for political gain']