In [1]:
import os, sys
import torch
import numpy as np
import pandas as pd
from datasets import load_dataset
from transformers import MBartForConditionalGeneration
from torch import optim

from repo.indobenchmark.toolkit.tokenization_indonlg import IndoNLGTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# NOTE: this whole cell is intentionally commented out (disabled).
# If re-enabled, it checks whether PyTorch's MPS backend is usable; when it is,
# it defines `device = torch.device('mps')` and sets the
# PYTORCH_MPS_HIGH_WATERMARK_RATIO environment variable. `device` is assigned
# only in the `else` branch below, which is why the `bart_input.to(device)`
# lines in the later test cells are commented out as well.
# # Check that MPS is available
# if not torch.backends.mps.is_available():
#     if not torch.backends.mps.is_built():
#         print("MPS not available because the current PyTorch install was not "
#               "built with MPS enabled.")
#     else:
#         print("MPS not available because the current MacOS version is not 12.3+ "
#               "and/or you do not have an MPS-enabled device on this machine.")

# else:
#     print("all good")
#     device = torch.device('mps')
#     os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "1" # This is tracked as pytorch issue #98222


In [2]:
# Load the `maryantocinn/indosum` dataset through Hugging Face `datasets`.
dataset = load_dataset(path='maryantocinn/indosum')

# Preview the first rows of the test split as a DataFrame (rich notebook display).
test_split = dataset['test']
test_split.to_pandas().head()

Unnamed: 0,document,id,summary
0,"Jakarta, CNN Indonesia - - Dilansir AFP, seora...",1494135000-wanita-terberat-di-dunia-jalani-fis...,Eman Ahmed Abd El Aty memiliki berat badan men...
1,Menteri Pertahanan Ryamizard Ryacudu menyambut...,1501222980-kemhan-ingin-beli-drone-dari-china-...,Menteri Pertahanan Ryamizard Ryacudu menyambut...
2,"Jakarta, CNN Indonesia - - Meski sudah hampir ...",1475739008-film-mean-girls-akan-dibuat-musikal,Rumah produksi film yang dibintangi Lindsay Lo...
3,"Usai melaksanakan ibadah haji, Eggi Sudjana ak...",1505785500-eggi-sudjana-sumpah-demi-allah-saya...,Eggi Sudjana akhirnya mendatangi kantor Baresk...
4,Banyak cara untuk memberikan pengajaran kepada...,1497394800-kartu-muslim-optimalkan-teknologi-ar,Game permainan Kartu Muslim. Menggunakan basis...


## Load Model

In [3]:
# The model weights and the tokenizer are both loaded from the same checkpoint.
checkpoint = 'indobenchmark/indobart-v2'
tokenizer = IndoNLGTokenizer.from_pretrained(checkpoint)
bart_model = MBartForConditionalGeneration.from_pretrained(checkpoint)

# `model` is an alias for `bart_model` (same object, not a copy); the last
# expression displays the module structure as the cell output.
model = bart_model
model

MBartForConditionalGeneration(
  (model): MBartModel(
    (shared): MBartScaledWordEmbedding(40004, 768, padding_idx=1)
    (encoder): MBartEncoder(
      (embed_tokens): MBartScaledWordEmbedding(40004, 768, padding_idx=1)
      (embed_positions): MBartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x MBartEncoderLayer(
          (self_attn): MBartSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (f

# Prepare Dataset

In [4]:

# Run configuration. Values are grouped by apparent purpose; within this cell
# only `lr`, `source_lang` and `target_lang` are actually consumed -- the rest
# are presumably read by training/evaluation cells elsewhere (TODO confirm).

# Optimisation-related values.
lr, gamma, step_size = 1e-4, 0.9, 1
max_norm, grad_accumulate, early_stop = 10, 1, 5

# Decoding / validation-related values.
beam_size = 5
valid_criterion = 'SacreBLEU'

# Input-handling flags and limits.
lower, no_special_token, swap_source_target = True, False, True
max_seq_len = 512
model_type = 'indo-bart'

# Hard-coded token ids -- presumably indices into the tokenizer vocabulary
# (verify against the tokenizer before relying on them).
separator_id, speaker_1_id, speaker_2_id = 4, 5, 6

# Every split uses the same batch size.
train_batch_size = valid_batch_size = test_batch_size = 8

# Source and target language tags are identical here; each one is still looked
# up separately so they can be changed independently.
source_lang = target_lang = "[indonesian]"
src_lid = tokenizer.special_tokens_to_ids[source_lang]
tgt_lid = tokenizer.special_tokens_to_ids[target_lang]

# The decoder start token is set to the target-language token id.
model.config.decoder_start_token_id = tgt_lid

# Adam over all model parameters with the learning rate defined above.
optimizer = optim.Adam(model.parameters(), lr=lr)

# Make sure cuda is deterministic
# torch.backends.cudnn.deterministic = True

# Make sure the output directory for saved models exists.
model_dir = './saved_models'
if not os.path.exists(model_dir):
    os.makedirs(model_dir, exist_ok=True)

# Test model to generate sequences

In [8]:
# Sanity check (Indonesian): run a single forward pass on a sentence containing
# a <mask> token and decode the top-1 prediction taken from the output logits.
inputs = ['aku pergi ke toko obat membeli <mask>']
bart_input = tokenizer.prepare_input_for_generation(
    inputs,
    return_tensors='pt',
    lang_token='[indonesian]',
    decoder_lang_token='[indonesian]',
)

# bart_input.to(device)
# Inference only: disable autograd so no computation graph is built or kept
# alive for this forward pass.
with torch.no_grad():
    bart_out = model(**bart_input)

# First line: the encoded input decoded back to text.
# Second line: the highest-scoring token per position (topk(1) over the logits).
print(tokenizer.decode(bart_input['input_ids'][0]))
print(tokenizer.decode(bart_out.logits.topk(1).indices.squeeze()))

<s> aku pergi ke toko obat membeli<mask></s>[indonesian]
<s> aku pergi ke toko obat membeli obat.[indonesian]


In [9]:
# Sanity check (Sundanese): run a single forward pass on a sentence containing
# a <mask> token and decode the top-1 prediction taken from the output logits.
inputs = ['kuring ka pasar senen meuli daging <mask>']
bart_input = tokenizer.prepare_input_for_generation(
    inputs,
    return_tensors='pt',
    lang_token='[sundanese]',
    decoder_lang_token='[sundanese]',
)

# bart_input.to(device)
# Inference only: disable autograd so no computation graph is built or kept
# alive for this forward pass.
with torch.no_grad():
    bart_out = bart_model(**bart_input)

# First line: the encoded input decoded back to text.
# Second line: the highest-scoring token per position (topk(1) over the logits).
print(tokenizer.decode(bart_input['input_ids'][0]))
print(tokenizer.decode(bart_out.logits.topk(1).indices.squeeze()))

<s> kuring ka pasar senen meuli daging<mask></s>[sundanese]
<s> kuring ka pasar senen meuli daging sapi,[sundanese]
