In [1]:
from google.colab import drive
drive.mount("/content/gdrive")

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [2]:
import requests
import torch
import torch.nn.functional as F
import torchtext
import torch.nn as nn
import random
import tarfile

In [3]:
import gc

def free_gpu():
  torch.cuda.empty_cache()
  gc.collect()

In [4]:
from torch import cuda
device = 'cuda'if cuda.is_available() else 'cpu'
!nvidia-smi

Wed Jun 12 03:55:24 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P8               9W /  70W |      3MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [5]:
import torch, os
import numpy as np
import random

seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [6]:
# Constants
# MODELNAME ='iwslt15-en-vi-bilstm.model'
# EPOCH = 110
# BATCHSIZE = 32
# LR = 0.0001
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Download and extract data
def iwslt15(train_test):
    url ='https://github.com/stefan-it/nmt-en-vi/raw/master/data/'
    r = requests.get(url + train_test + '-en-vi.tgz')
    filename = train_test + '-en-vi.tgz'
    with open(filename, 'wb') as f:
        f.write(r.content)
        tarfile.open(filename, "r:gz").extractall("iwslt15")

# Load data
iwslt15('train')
iwslt15('test-2013')

In [7]:
import os

# List all files in the current directory
files = os.listdir()

# Print the list of files
for file in files:
    print(file)

.config
gdrive
iwslt15
tst2013.vi
tst2013.en
train-en-vi.tgz
train.en
train.vi
test-2013-en-vi.tgz
sample_data


In [8]:
import tarfile
import os

# Open the .tgz file
tar = tarfile.open('train-en-vi.tgz', 'r:gz')

# Extract all files into the current directory
tar.extractall()

# Close the .tgz file
tar.close()

In [9]:
import pandas as pd

# Read the 'train.en' file into a DataFrame
df_train_en = pd.read_csv('train.en', delimiter = "\t", header=None)

# Read the 'train.vi' file into a DataFrame
df_train_vi = pd.read_csv('train.vi', delimiter = "\t", header=None)

import pandas as pd
import re

# Rename the columns
df_train_en.columns = ['en']
df_train_vi.columns = ['vi']

# Concatenate the dataframes
df_train = pd.concat([df_train_vi, df_train_en], axis=1)
# Convert 'en' column to string and then create the 'en_token' column
df_train['en'] = df_train['en'].astype(str)
from nltk.tokenize import RegexpTokenizer

# Initialize the tokenizer
tokenizer = RegexpTokenizer(r'\w+')

# Create the 'vi_token' and 'en_token' columns
df_train['vi_token'] = df_train['vi'].apply(lambda x: '_' + '_'.join(tokenizer.tokenize(x)))
df_train['en_token'] = df_train['en'].apply(lambda x: '_' + '_'.join(tokenizer.tokenize(x)))

In [10]:
df_train

Unnamed: 0,vi,en,vi_token,en_token
0,Khoa học đằng sau một tiêu đề về khí hậu,Rachel Pike : The science behind a climate hea...,_Khoa_học_đằng_sau_một_tiêu_đề_về_khí_hậu,_Rachel_Pike_The_science_behind_a_climate_head...
1,"Trong 4 phút , chuyên gia hoá học khí quyển Ra...","In 4 minutes , atmospheric chemist Rachel Pike...",_Trong_4_phút_chuyên_gia_hoá_học_khí_quyển_Rac...,_In_4_minutes_atmospheric_chemist_Rachel_Pike_...
2,Tôi muốn cho các bạn biết về sự to lớn của nhữ...,I &apos;d like to talk to you today about the ...,_Tôi_muốn_cho_các_bạn_biết_về_sự_to_lớn_của_nh...,_I_apos_d_like_to_talk_to_you_today_about_the_...
3,Có những dòng trông như thế này khi bàn về biế...,Headlines that look like this when they have t...,_Có_những_dòng_trông_như_thế_này_khi_bàn_về_bi...,_Headlines_that_look_like_this_when_they_have_...
4,Cả hai đều là một nhánh của cùng một lĩnh vực ...,They are both two branches of the same field o...,_Cả_hai_đều_là_một_nhánh_của_cùng_một_lĩnh_vực...,_They_are_both_two_branches_of_the_same_field_...
...,...,...,...,...
133200,Nó là do con người và có thể ngăn chặn và diệt...,,_Nó_là_do_con_người_và_có_thể_ngăn_chặn_và_diệ...,_nan
133201,Tôi muốn kết luận rằng hành động của hàng ngàn...,,_Tôi_muốn_kết_luận_rằng_hành_động_của_hàng_ngà...,_nan
133202,Rất cảm ơn đã lắng nghe .,,_Rất_cảm_ơn_đã_lắng_nghe,_nan
133203,Paul Pholeros : Làm sao để bớt nghèo khổ ? Hãy...,,_Paul_Pholeros_Làm_sao_để_bớt_nghèo_khổ_Hãy_sử...,_nan


In [11]:
columns_to_drop = ['vi', 'en']
df_train = df_train.drop(columns=columns_to_drop)
df_train = df_train.rename(columns = {
                                        'vi_token': 'vi',
                                        'en_token': 'en'
                                    })

df_train

Unnamed: 0,vi,en
0,_Khoa_học_đằng_sau_một_tiêu_đề_về_khí_hậu,_Rachel_Pike_The_science_behind_a_climate_head...
1,_Trong_4_phút_chuyên_gia_hoá_học_khí_quyển_Rac...,_In_4_minutes_atmospheric_chemist_Rachel_Pike_...
2,_Tôi_muốn_cho_các_bạn_biết_về_sự_to_lớn_của_nh...,_I_apos_d_like_to_talk_to_you_today_about_the_...
3,_Có_những_dòng_trông_như_thế_này_khi_bàn_về_bi...,_Headlines_that_look_like_this_when_they_have_...
4,_Cả_hai_đều_là_một_nhánh_của_cùng_một_lĩnh_vực...,_They_are_both_two_branches_of_the_same_field_...
...,...,...
133200,_Nó_là_do_con_người_và_có_thể_ngăn_chặn_và_diệ...,_nan
133201,_Tôi_muốn_kết_luận_rằng_hành_động_của_hàng_ngà...,_nan
133202,_Rất_cảm_ơn_đã_lắng_nghe,_nan
133203,_Paul_Pholeros_Làm_sao_để_bớt_nghèo_khổ_Hãy_sử...,_nan


In [12]:
import tarfile
import os

# Open the .tgz file
tar = tarfile.open('test-2013-en-vi.tgz', 'r:gz')

# Extract all files into the current directory
tar.extractall()

# Close the .tgz file
tar.close()


In [13]:
import pandas as pd
df_test_en = pd.read_csv('tst2013.en', delimiter = "\t", header=None)

df_test_vi = pd.read_csv('tst2013.vi', delimiter = "\t", header=None)

import pandas as pd
import re

# Rename the columns
df_test_en.columns = ['en']
df_test_vi.columns = ['vi']

# Concatenate the dataframes
df_test = pd.concat([df_test_vi, df_test_en], axis=1)
# Convert 'en' column to string and then create the 'en_token' column
df_test['en'] = df_test['en'].astype(str)
from nltk.tokenize import RegexpTokenizer

# Initialize the tokenizer
tokenizer = RegexpTokenizer(r'\w+')

# Create the 'vi_token' and 'en_token' columns
df_test['vi_token'] = df_test['vi'].apply(lambda x: '_' + '_'.join(tokenizer.tokenize(x)))
df_test['en_token'] = df_test['en'].apply(lambda x: '_' + '_'.join(tokenizer.tokenize(x)))

In [14]:
df_test

Unnamed: 0,vi,en,vi_token,en_token
0,"Khi tôi còn nhỏ , Tôi nghĩ rằng BắcTriều Tiên ...","When I was little , I thought my country was t...",_Khi_tôi_còn_nhỏ_Tôi_nghĩ_rằng_BắcTriều_Tiên_l...,_When_I_was_little_I_thought_my_country_was_th...
1,Tôi đã rất tự hào về đất nước tôi .,And I was very proud .,_Tôi_đã_rất_tự_hào_về_đất_nước_tôi,_And_I_was_very_proud
2,"Ở trường , chúng tôi dành rất nhiều thời gian ...","In school , we spent a lot of time studying th...",_Ở_trường_chúng_tôi_dành_rất_nhiều_thời_gian_đ...,_In_school_we_spent_a_lot_of_time_studying_the...
3,Mặc dù tôi đã từng tự hỏi không biết thế giới ...,Although I often wondered about the outside wo...,_Mặc_dù_tôi_đã_từng_tự_hỏi_không_biết_thế_giới...,_Although_I_often_wondered_about_the_outside_w...
4,"Khi tôi lên 7 , tôi chứng kiến cảnh người ta x...","When I was seven years old , I saw my first pu...",_Khi_tôi_lên_7_tôi_chứng_kiến_cảnh_người_ta_xử...,_When_I_was_seven_years_old_I_saw_my_first_pub...
...,...,...,...,...
1263,"Tôi thực sự tin , nếu ta coi người khác như nh...","I truly believe , if we can see one another as...",_Tôi_thực_sự_tin_nếu_ta_coi_người_khác_như_nhữ...,_I_truly_believe_if_we_can_see_one_another_as_...
1264,Những tấm hình không phải là về bản thân vấnđề...,These images are not of issues . They are of p...,_Những_tấm_hình_không_phải_là_về_bản_thân_vấnđ...,_These_images_are_not_of_issues_They_are_of_pe...
1265,Không có ngày nào mà tôi không nghĩ về những n...,There is not a day that goes by that I don &ap...,_Không_có_ngày_nào_mà_tôi_không_nghĩ_về_những_...,_There_is_not_a_day_that_goes_by_that_I_don_ap...
1266,Tôi hi vọng những tấm hình sẽ đánh thức một ng...,I hope that these images awaken a force in tho...,_Tôi_hi_vọng_những_tấm_hình_sẽ_đánh_thức_một_n...,_I_hope_that_these_images_awaken_a_force_in_th...


In [15]:
columns_to_drop = ['vi', 'en']
df_test = df_test.drop(columns=columns_to_drop)
df_test = df_test.rename(columns = {
                                        'vi_token': 'vi',
                                        'en_token': 'en'
                                    })

df_test

Unnamed: 0,vi,en
0,_Khi_tôi_còn_nhỏ_Tôi_nghĩ_rằng_BắcTriều_Tiên_l...,_When_I_was_little_I_thought_my_country_was_th...
1,_Tôi_đã_rất_tự_hào_về_đất_nước_tôi,_And_I_was_very_proud
2,_Ở_trường_chúng_tôi_dành_rất_nhiều_thời_gian_đ...,_In_school_we_spent_a_lot_of_time_studying_the...
3,_Mặc_dù_tôi_đã_từng_tự_hỏi_không_biết_thế_giới...,_Although_I_often_wondered_about_the_outside_w...
4,_Khi_tôi_lên_7_tôi_chứng_kiến_cảnh_người_ta_xử...,_When_I_was_seven_years_old_I_saw_my_first_pub...
...,...,...
1263,_Tôi_thực_sự_tin_nếu_ta_coi_người_khác_như_nhữ...,_I_truly_believe_if_we_can_see_one_another_as_...
1264,_Những_tấm_hình_không_phải_là_về_bản_thân_vấnđ...,_These_images_are_not_of_issues_They_are_of_pe...
1265,_Không_có_ngày_nào_mà_tôi_không_nghĩ_về_những_...,_There_is_not_a_day_that_goes_by_that_I_don_ap...
1266,_Tôi_hi_vọng_những_tấm_hình_sẽ_đánh_thức_một_n...,_I_hope_that_these_images_awaken_a_force_in_th...


In [16]:
def remove_punc(text):
    return text.replace("_", " ")

In [17]:
df_train["vi"] = df_train["vi"].apply(remove_punc)
df_train["en"] = df_train["en"].apply(remove_punc)
df_train

Unnamed: 0,vi,en
0,Khoa học đằng sau một tiêu đề về khí hậu,Rachel Pike The science behind a climate head...
1,Trong 4 phút chuyên gia hoá học khí quyển Rac...,In 4 minutes atmospheric chemist Rachel Pike ...
2,Tôi muốn cho các bạn biết về sự to lớn của nh...,I apos d like to talk to you today about the ...
3,Có những dòng trông như thế này khi bàn về bi...,Headlines that look like this when they have ...
4,Cả hai đều là một nhánh của cùng một lĩnh vực...,They are both two branches of the same field ...
...,...,...
133200,Nó là do con người và có thể ngăn chặn và diệ...,
133201,Tôi muốn kết luận rằng hành động của hàng ngà...,
133202,Rất cảm ơn đã lắng nghe,
133203,Paul Pholeros Làm sao để bớt nghèo khổ Hãy sử...,


In [21]:
df_test["vi"] = df_test["vi"].apply(remove_punc)
df_test["en"] = df_test["en"].apply(remove_punc)
df_test

Unnamed: 0,vi,en
0,Khi tôi còn nhỏ Tôi nghĩ rằng BắcTriều Tiên l...,When I was little I thought my country was th...
1,Tôi đã rất tự hào về đất nước tôi,And I was very proud
2,Ở trường chúng tôi dành rất nhiều thời gian đ...,In school we spent a lot of time studying the...
3,Mặc dù tôi đã từng tự hỏi không biết thế giới...,Although I often wondered about the outside w...
4,Khi tôi lên 7 tôi chứng kiến cảnh người ta xử...,When I was seven years old I saw my first pub...
...,...,...
1263,Tôi thực sự tin nếu ta coi người khác như nhữ...,I truly believe if we can see one another as ...
1264,Những tấm hình không phải là về bản thân vấnđ...,These images are not of issues They are of pe...
1265,Không có ngày nào mà tôi không nghĩ về những ...,There is not a day that goes by that I don ap...
1266,Tôi hi vọng những tấm hình sẽ đánh thức một n...,I hope that these images awaken a force in th...


In [24]:
df_train['vi']   = df_train['vi'].astype(str)
df_train['en']  = df_train['en'].astype(str)

df_test['vi']   = df_test['vi'].astype(str)
df_test['en']  = df_test['en'].astype(str)

In [25]:
! pip install datasets transformers sacrebleu



In [26]:
from datasets import load_metric

metric = load_metric("sacrebleu")
metric

  metric = load_metric("sacrebleu")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


Metric(name: "sacrebleu", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='sequence'), length=-1, id='references')}, usage: """
Produces BLEU scores along with its sufficient statistics
from a source against one or more references.

Args:
    predictions (`list` of `str`): list of translations to score. Each translation should be tokenized into a list of tokens.
    references (`list` of `list` of `str`): A list of lists of references. The contents of the first sub-list are the references for the first prediction, the contents of the second sub-list are for the second prediction, etc. Note that there must be the same number of references for each prediction (i.e. all sub-lists must be of the same length).
    smooth_method (`str`): The smoothing method to use, defaults to `'exp'`. Possible values are:
        - `'none'`: no smoothing
        - `'floor'`: increment zero counts
        - `'add-k'`: increment num/deno

In [27]:
language_codes = {
    'af': 'Afrikaans',
    'af-ZA': 'Afrikaans (South Africa)',
    'ar': 'Arabic',
    'ar-AE': 'Arabic (U.A.E.)',
    'ar-BH': 'Arabic (Bahrain)',
    'ar-DZ': 'Arabic (Algeria)',
    'ar-EG': 'Arabic (Egypt)',
    'ar-IQ': 'Arabic (Iraq)',
    'ar-JO': 'Arabic (Jordan)',
    'ar-KW': 'Arabic (Kuwait)',
    'ar-LB': 'Arabic (Lebanon)',
    'ar-LY': 'Arabic (Libya)',
    'ar-MA': 'Arabic (Morocco)',
    'ar-OM': 'Arabic (Oman)',
    'ar-QA': 'Arabic (Qatar)',
    'ar-SA': 'Arabic (Saudi Arabia)',
    'ar-SY': 'Arabic (Syria)',
    'ar-TN': 'Arabic (Tunisia)',
    'ar-YE': 'Arabic (Yemen)',
    'az': 'Azeri (Latin)',
    'az-AZ': 'Azeri (Latin) (Azerbaijan)',
    'az-AZ': 'Azeri (Cyrillic) (Azerbaijan)',
    'be': 'Belarusian',
    'be-BY': 'Belarusian (Belarus)',
    'bg': 'Bulgarian',
    'bg-BG': 'Bulgarian (Bulgaria)',
    'bs-BA': 'Bosnian (Bosnia and Herzegovina)',
    'ca': 'Catalan',
    'ca-ES': 'Catalan (Spain)',
    'cs': 'Czech',
    'cs-CZ': 'Czech (Czech Republic)',
    'cy': 'Welsh',
    'cy-GB': 'Welsh (United Kingdom)',
    'da': 'Danish',
    'da-DK': 'Danish (Denmark)',
    'de': 'German',
    'de-AT': 'German (Austria)',
    'de-CH': 'German (Switzerland)',
    'de-DE': 'German (Germany)',
    'de-LI': 'German (Liechtenstein)',
    'de-LU': 'German (Luxembourg)',
    'dv': 'Divehi',
    'dv-MV': 'Divehi (Maldives)',
    'el': 'Greek',
    'el-GR': 'Greek (Greece)',
    'en': 'English',
    'en-AU': 'English (Australia)',
    'en-BZ': 'English (Belize)',
    'en-CA': 'English (Canada)',
    'en-CB': 'English (Caribbean)',
    'en-GB': 'English (United Kingdom)',
    'en-IE': 'English (Ireland)',
    'en-JM': 'English (Jamaica)',
    'en-NZ': 'English (New Zealand)',
    'en-PH': 'English (Republic of the Philippines)',
    'en-TT': 'English (Trinidad and Tobago)',
    'en-US': 'English (United States)',
    'en-ZA': 'English (South Africa)',
    'en-ZW': 'English (Zimbabwe)',
    'eo': 'Esperanto',
    'es': 'Spanish',
    'es-AR': 'Spanish (Argentina)',
    'es-BO': 'Spanish (Bolivia)',
    'es-CL': 'Spanish (Chile)',
    'es-CO': 'Spanish (Colombia)',
    'es-CR': 'Spanish (Costa Rica)',
    'es-DO': 'Spanish (Dominican Republic)',
    'es-EC': 'Spanish (Ecuador)',
    'es-ES': 'Spanish (Castilian)',
    'es-ES': 'Spanish (Spain)',
    'es-GT': 'Spanish (Guatemala)',
    'es-HN': 'Spanish (Honduras)',
    'es-MX': 'Spanish (Mexico)',
    'es-NI': 'Spanish (Nicaragua)',
    'es-PA': 'Spanish (Panama)',
    'es-PE': 'Spanish (Peru)',
    'es-PR': 'Spanish (Puerto Rico)',
    'es-PY': 'Spanish (Paraguay)',
    'es-SV': 'Spanish (El Salvador)',
    'es-UY': 'Spanish (Uruguay)',
    'es-VE': 'Spanish (Venezuela)',
    'et': 'Estonian',
    'et-EE': 'Estonian (Estonia)',
    'eu': 'Basque',
    'eu-ES': 'Basque (Spain)',
    'fa': 'Farsi',
    'fa-IR': 'Farsi (Iran)',
    'fi': 'Finnish',
    'fi-FI': 'Finnish (Finland)',
    'fo': 'Faroese',
    'fo-FO': 'Faroese (Faroe Islands)',
    'fr': 'French',
    'fr-BE': 'French (Belgium)',
    'fr-CA': 'French (Canada)',
    'fr-CH': 'French (Switzerland)',
    'fr-FR': 'French (France)',
    'fr-LU': 'French (Luxembourg)',
    'fr-MC': 'French (Principality of Monaco)',
    'gl': 'Galician',
    'gl-ES': 'Galician (Spain)',
    'gu': 'Gujarati',
    'gu-IN': 'Gujarati (India)',
    'he': 'Hebrew',
    'he-IL': 'Hebrew (Israel)',
    'hi': 'Hindi',
    'hi-IN': 'Hindi (India)',
    'hr': 'Croatian',
    'hr-BA': 'Croatian (Bosnia and Herzegovina)',
    'hr-HR': 'Croatian (Croatia)',
    'hu': 'Hungarian',
    'hu-HU': 'Hungarian (Hungary)',
    'hy': 'Armenian',
    'hy-AM': 'Armenian (Armenia)',
    'id': 'Indonesian',
    'id-ID': 'Indonesian (Indonesia)',
    'is': 'Icelandic',
    'is-IS': 'Icelandic (Iceland)',
    'it': 'Italian',
    'it-CH': 'Italian (Switzerland)',
    'it-IT': 'Italian (Italy)',
    'ja': 'Japanese',
    'ja-JP': 'Japanese (Japan)',
    'ka': 'Georgian',
    'ka-GE': 'Georgian (Georgia)',
    'kk': 'Kazakh',
    'kk-KZ': 'Kazakh (Kazakhstan)',
    'kn': 'Kannada',
    'kn-IN': 'Kannada (India)',
    'ko': 'Korean',
    'ko-KR': 'Korean (Korea)',
    'kok': 'Konkani',
    'kok-IN': 'Konkani (India)',
    'ky': 'Kyrgyz',
    'ky-KG': 'Kyrgyz (Kyrgyzstan)',
    'lt': 'Lithuanian',
    'lt-LT': 'Lithuanian (Lithuania)',
    'lv': 'Latvian',
    'lv-LV': 'Latvian (Latvia)',
    'mi': 'Maori',
    'mi-NZ': 'Maori (New Zealand)',
    'mk': 'FYRO Macedonian',
    'mk-MK': 'FYRO Macedonian (Former Yugoslav Republic of Macedonia)',
    'mn': 'Mongolian',
    'mn-MN': 'Mongolian (Mongolia)',
    'mr': 'Marathi',
    'mr-IN': 'Marathi (India)',
    'ms': 'Malay',
    'ms-BN': 'Malay (Brunei Darussalam)',
    'ms-MY': 'Malay (Malaysia)',
    'mt': 'Maltese',
    'mt-MT': 'Maltese (Malta)',
    'nb': 'Norwegian (Bokmål)',
    'nb-NO': 'Norwegian (Bokmål) (Norway)',
    'nl': 'Dutch',
    'nl-BE': 'Dutch (Belgium)',
    'nl-NL': 'Dutch (Netherlands)',
    'nn-NO': 'Norwegian (Nynorsk) (Norway)',
    'ns': 'Northern Sotho',
    'ns-ZA': 'Northern Sotho (South Africa)',
    'pa': 'Punjabi',
    'pa-IN': 'Punjabi (India)',
    'pl': 'Polish',
    'pl-PL': 'Polish (Poland)',
    'ps': 'Pashto',
    'ps-AR': 'Pashto (Afghanistan)',
    'pt': 'Portuguese',
    'pt-BR': 'Portuguese (Brazil)',
    'pt-PT': 'Portuguese (Portugal)',
    'qu': 'Quechua',
    'qu-BO': 'Quechua (Bolivia)',
    'qu-EC': 'Quechua (Ecuador)',
    'qu-PE': 'Quechua (Peru)',
    'ro': 'Romanian',
    'ro-RO': 'Romanian (Romania)',
    'ru': 'Russian',
    'ru-RU': 'Russian (Russia)',
    'sa': 'Sanskrit',
    'sa-IN': 'Sanskrit (India)',
    'se': 'Sami (Northern)',
    'se-FI': 'Sami (Northern) (Finland)',
    'se-FI': 'Sami (Skolt) (Finland)',
    'se-FI': 'Sami (Inari) (Finland)',
    'se-NO': 'Sami (Northern) (Norway)',
    'se-NO': 'Sami (Lule) (Norway)',
    'se-NO': 'Sami (Southern) (Norway)',
    'se-SE': 'Sami (Northern) (Sweden)',
    'se-SE': 'Sami (Lule) (Sweden)',
    'se-SE': 'Sami (Southern) (Sweden)',
    'sk': 'Slovak',
    'sk-SK': 'Slovak (Slovakia)',
    'sl': 'Slovenian',
    'sl-SI': 'Slovenian (Slovenia)',
    'sq': 'Albanian',
    'sq-AL': 'Albanian (Albania)',
    'sr-BA': 'Serbian (Latin) (Bosnia and Herzegovina)',
    'sr-BA': 'Serbian (Cyrillic) (Bosnia and Herzegovina)',
    'sr-SP': 'Serbian (Latin) (Serbia and Montenegro)',
    'sr-SP': 'Serbian (Cyrillic) (Serbia and Montenegro)',
    'sv': 'Swedish',
    'sv-FI': 'Swedish (Finland)',
    'sv-SE': 'Swedish (Sweden)',
    'sw': 'Swahili',
    'sw-KE': 'Swahili (Kenya)',
    'syr': 'Syriac',
    'syr-SY': 'Syriac (Syria)',
    'ta': 'Tamil',
    'ta-IN': 'Tamil (India)',
    'te': 'Telugu',
    'te-IN': 'Telugu (India)',
    'th': 'Thai',
    'th-TH': 'Thai (Thailand)',
    'tl': 'Tagalog',
    'tl-PH': 'Tagalog (Philippines)',
    'tn': 'Tswana',
    'tn-ZA': 'Tswana (South Africa)',
    'tr': 'Turkish',
    'tr-TR': 'Turkish (Turkey)',
    'tt': 'Tatar',
    'tt-RU': 'Tatar (Russia)',
    'ts': 'Tsonga',
    'uk': 'Ukrainian',
    'uk-UA': 'Ukrainian (Ukraine)',
    'ur': 'Urdu',
    'ur-PK': 'Urdu (Islamic Republic of Pakistan)',
    'uz': 'Uzbek (Latin)',
    'uz-UZ': 'Uzbek (Latin) (Uzbekistan)',
    'uz-UZ': 'Uzbek (Cyrillic) (Uzbekistan)',
    'vi': 'Vietnamese',
    'vi-VN': 'Vietnamese (Viet Nam)',
    'xh': 'Xhosa',
    'xh-ZA': 'Xhosa (South Africa)',
    'zh': 'Chinese',
    'zh-CN': 'Chinese (S)',
    'zh-HK': 'Chinese (Hong Kong)',
    'zh-MO': 'Chinese (Macau)',
    'zh-SG': 'Chinese (Singapore)',
    'zh-TW': 'Chinese (T)',
    'zu': 'Zulu',
    'zu-ZA': 'Zulu (South Africa)'
}

In [28]:
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer


src_lang = "vi"
tgt_lang = "en"
max_length = 128

model_name = "facebook/mbart-large-50-many-to-many-mmt"

def get_model(model_name):
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

  model_type = model.config.model_type


  if "mbart" in model_type:
      langs = "ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN"
      langs = langs.split(",")

      for lang in langs:
          if src_lang in lang:
              tokenizer.src_lang = lang

          if tgt_lang in lang:
              tokenizer.tgt_lang = lang
  elif "marian" in model_type:
    langs = "vie,zho"
    langs = langs.split(",")

    for lang in langs:
        if src_lang in lang:
            tokenizer.source_lang = lang

        if tgt_lang in lang:
            tokenizer.target_lang = lang

    if "mul" in model_name:
        prefix = ">>" + tokenizer.target_lang + "<< "

  if "t5" in model_type:
    prefix = f"translate {language_codes.get(src_lang, src_lang)} to {language_codes.get(tgt_lang, tgt_lang)}: "
  else:
    prefix = ""

  print(f"MODE TYPE\t:\t{model_type}")

  if "mbart" in model_type:
      print(f"SRC_LANG \t:\t{tokenizer.src_lang}")
      print(f"TGT_LANG \t:\t{tokenizer.tgt_lang}")
  elif "marian" in model_type:
      print(f"SRC_LANG \t:\t{tokenizer.source_lang}")
      print(f"TGT_LANG \t:\t{tokenizer.target_lang}")

  if prefix:
      print(f"PREFIX   \t:\t{prefix[:-3]}")
  else:
      print(f"PREFIX   \t:\tNone")

  return model, tokenizer, model_type, prefix

In [29]:
def preprocess_function(tokenizer, prefix, src, tgt):
    inputs = prefix + src
    targets = tgt
    return tokenizer(text = inputs, text_target = targets, max_length=max_length, truncation=True)

In [30]:
!pip install accelerate -U
!pip install sentencepiece



In [31]:

from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__ (self, tokenizer, prefix, X_data, y_data):
        self.tokenizer = tokenizer
        self.X_data = X_data
        self.y_data = y_data
        self.prefix = prefix
    def __len__ (self):
        return len(self.X_data)

    def __getitem__(self, idx):
        inputs = self.prefix + self.X_data[idx]
        targets = self.y_data[idx]

        inputs = self.tokenizer(inputs, max_length=max_length, truncation=True, return_tensors="pt", padding="max_length",)

        labels = self.tokenizer(text_target = targets, max_length=max_length, truncation=True, return_tensors="pt", padding="max_length",)

        return {
            'input_ids' : inputs.input_ids.squeeze(),
            'attention_mask': inputs.attention_mask.squeeze(),
            'labels': labels.input_ids.squeeze(),
        }


In [32]:
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}

    return result

In [33]:
!pip install xformers
!pip install optimum
!pip install bitsandbytes

Collecting xformers
  Using cached xformers-0.0.26.post1-cp310-cp310-manylinux2014_x86_64.whl (222.7 MB)
Installing collected packages: xformers
Successfully installed xformers-0.0.26.post1
Collecting optimum
  Using cached optimum-1.20.0-py3-none-any.whl (418 kB)
Collecting coloredlogs (from optimum)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m701.6 kB/s[0m eta [36m0:00:00[0m
Collecting humanfriendly>=9.1 (from coloredlogs->optimum)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: humanfriendly, coloredlogs, optimum
Successfully installed coloredlogs-15.0.1 humanfriendly-10.0 optimum-1.20.0
Collecting bitsandbytes
  Downloading bitsandbytes-0.43.1-py3-none-manylinux_2_24_x86_64.whl (119.8 MB)
[2K     [90m━━━━━━━━━━━━

In [34]:
!pip install sentencepiece



In [35]:
free_gpu()

from optimum.bettertransformer import BetterTransformer

model, tokenizer, model_type, prefix = get_model(model_name)
# model = BetterTransformer.transform(model)

from datetime import datetime

now = datetime.now()
dt_string = now.strftime("%d-%m-%Y_%H-%M-%S")

max_train = -1
if max_train > 0:
    train_dataset = CustomDataset(tokenizer, prefix, df_train[src_lang][:max_train], df_train[tgt_lang][:max_train])
else:
    train_dataset = CustomDataset(tokenizer, prefix, df_train[src_lang], df_train[tgt_lang])

test_dataset  = CustomDataset(tokenizer, prefix, df_test[src_lang], df_test[tgt_lang])

epochs = 1
learning_rate = 1e-4
weight_decay = 0.01
logging_steps = 100
batch_size = 8

model_name = model_name.replace("/", "_")
training_args = Seq2SeqTrainingArguments(
    output_dir = f'/content/gdrive/MyDrive/Machine_Translation/models/{model_type}',
    overwrite_output_dir = True,

    learning_rate = learning_rate,
    num_train_epochs = epochs,
    weight_decay = weight_decay,
    # optim = "adamw_torch",

    # evaluation_strategy= 'epoch',
    save_strategy = 'no',
    # save_steps = logging_steps * 50,
    # eval_steps = logging_steps * 10,

    # load_best_model_at_end = True,
    # auto_find_batch_size = True,
    per_device_train_batch_size = batch_size,


    fp16=True,
    # bf16=True,
    # tf32=True,
    optim="adafactor",
    # optim="adamw_bnb",
    # fp16_full_eval=True,
    # jit_mode_eval=True,
    gradient_accumulation_steps = 64,
    gradient_checkpointing=True,

    logging_strategy="steps",
    logging_steps = logging_steps,

    seed = seed,

    run_name = model_type + f"_{model_name}_{dt_string}"
)

trainer = Seq2SeqTrainer(
    model = model,
    args = training_args,
    train_dataset = train_dataset,
    # eval_dataset = test_dataset,
    tokenizer = tokenizer,
    compute_metrics = compute_metrics,
    data_collator = lambda data : {
        'input_ids' : torch.stack([item['input_ids'] for item in data]),
        'attention_mask' : torch.stack([item['attention_mask'] for item in data]),
        'labels' : torch.stack([item['labels'] for item in data]),
    }
)

trainer.train()

test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False)

free_gpu()

model.eval()

pred = []
label = []

with torch.no_grad():
    for batch in test_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        outputs = model.generate(input_ids, attention_mask=attention_mask, max_length=max_length, forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang])

        labels = batch['labels'].to(device)


        decoded_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)


        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
        pred += decoded_preds
        label += decoded_labels



result = metric.compute(predictions=pred, references=label)
result = {"bleu": result["score"]}
print(result)


# model.train()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/529 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/1.43k [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/649 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.44G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/261 [00:00<?, ?B/s]

MODE TYPE	:	mbart
SRC_LANG 	:	vi_VN
TGT_LANG 	:	en_XX
PREFIX   	:	None




Step,Training Loss
100,1.6764
200,0.5844


{'bleu': 2.4478104013169273}


In [36]:
model.eval()

pred = []
label = []

with torch.no_grad():
    for batch in test_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        outputs = model.generate(input_ids, attention_mask=attention_mask, max_length=max_length, forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang])

        labels = batch['labels'].to(device)


        decoded_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)


        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
        pred += decoded_preds
        label += decoded_labels



result = metric.compute(predictions=pred, references=label)
result = {"bleu": result["score"]}
print(result)

{'bleu': 2.4478104013169273}
