In [7]:
# Shell escape: run the NVIDIA CLI to show the current state of the GPUs on this host.
!nvidia-smi

Mon Jul 10 14:34:54 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    On   | 00000000:4F:00.0 Off |                    0 |
| 30%   49C    P2   149W / 300W |  43748MiB / 45631MiB |     38%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000    On   | 00000000:52:00.0 Off |                    0 |
| 30%   58C    P2   162W / 300W |  41964MiB / 45631MiB |     46%      Default |
|       

In [None]:
# Set CUDA_VISIBLE_DEVICES to an empty value for this kernel.
# NOTE(review): an empty value presumably hides every GPU from CUDA, i.e. the
# 7.5B model below would run on CPU — confirm this is intended.
%env CUDA_VISIBLE_DEVICES=
# Set TOKENIZERS_PARALLELISM to "false" for this kernel
# (presumably to silence the tokenizers fork/parallelism warning — verify).
%env TOKENIZERS_PARALLELISM=false

In [1]:
import os
from functools import partial

import evaluate
import numpy as np
from datasets import load_dataset, concatenate_datasets
from torch import tensor 
from transformers import DataCollatorForLanguageModeling
from transformers import XGLMTokenizer, XGLMForCausalLM, TrainingArguments, Trainer

In [None]:
# Directory holding the BSD corpus JSON splits. Set the BSD_DATA_DIR environment
# variable to run the notebook on another machine; without it the original
# location is used unchanged.
file_path = os.environ.get("BSD_DATA_DIR", "/home/sumire/discourse_context_mt/data/BSD-master/")

# Split name used downstream -> JSON file inside the corpus directory.
data_files = {
    "train": os.path.join(file_path, "train.json"),
    "validation": os.path.join(file_path, "dev.json"),
    "test": os.path.join(file_path, "test.json"),
}
dataset = load_dataset("json", data_files=data_files)
dataset

In [None]:
# Preview the train inputs and targets.
# The prompt prefix is spelled exactly as in preprocess_function so this preview
# shows the same text that is tokenized for training.

# Flatten documents -> sentences once; both lists are derived from the same pass
# so inputs[i] and targets[i] always belong to the same sentence.
train_sentences = [sent for doc in dataset["train"]["conversation"] for sent in doc]
inputs = ["Translate English to Japanese: " + sent["en_sentence"] for sent in train_sentences]
targets = [sent["ja_sentence"] for sent in train_sentences]

print(inputs[:5], targets[:5])

In [None]:
# Checkpoint name shared by the tokenizer and the model.
model_checkpoint = "facebook/xglm-7.5B"
tokenizer = XGLMTokenizer.from_pretrained(model_checkpoint)
model = XGLMForCausalLM.from_pretrained(model_checkpoint)

# Maximum number of tokens kept per sequence when tokenizing.
max_length = 128

def preprocess_function(data, max_examples=50):
    """Tokenize the sentence pairs of a batch of documents for English→Japanese.

    Intended for `dataset.map(..., batched=True)`; the train / dev / test split
    is handled by the caller, this function only sees one batch.

    Args:
        data: Batch mapping whose "conversation" entry is a list of documents,
            each document a list of dicts with "en_sentence" and "ja_sentence".
        max_examples: Maximum number of sentence pairs kept per call. The
            default of 50 matches the previous hard-coded cap; pass None to
            keep every pair.

    Returns:
        Whatever the module-level `tokenizer` returns for the prompts, with the
        Japanese sentences passed as `text_target`. Sequences are truncated to
        the module-level `max_length`.
    """
    # Single pass over the batch so each prompt stays paired with its target.
    pairs = [
        ("Translate English to Japanese: " + sent["en_sentence"], sent["ja_sentence"])
        for doc in data["conversation"]
        for sent in doc
    ]
    if max_examples is not None:
        pairs = pairs[:max_examples]

    inputs = [prompt for prompt, _ in pairs]
    targets = [target for _, target in pairs]

    # NOTE(review): for a decoder-only model the target text ends up only in
    # "labels", not in "input_ids" — confirm the collator/Trainer actually
    # trains on the target side.
    model_inputs = tokenizer(
        inputs, text_target=targets, max_length=max_length, truncation=True
    )
    return model_inputs

In [None]:
# Tokenize every split batch-wise. The raw columns are dropped so that only the
# tokenizer's output columns remain in the mapped dataset.
columns_to_drop = dataset["train"].column_names
tokenized_datasets = dataset.map(preprocess_function, batched=True, remove_columns=columns_to_drop)

In [None]:
# Batch collator for language modelling; mlm=False turns masked-LM masking off.
# NOTE(review): with mlm=False this collator presumably rebuilds `labels` from
# `input_ids` — confirm how that interacts with the `labels` column produced by
# preprocess_function.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [None]:
# Smoke-test the collator on the second and third training examples and show
# which fields the resulting batch contains.
sample_rows = []
for row_index in (1, 2):
    sample_rows.append(tokenized_datasets["train"][row_index])
batch = data_collator(sample_rows)
batch.keys()

In [None]:
# Re-instantiate the model from the shared checkpoint name (same checkpoint the
# tokenizer was loaded from) instead of repeating the literal.
model = XGLMForCausalLM.from_pretrained(model_checkpoint)


In [None]:
# Evaluation metrics, loaded in order: metric1 is sacreBLEU, metric2 is COMET.
metric1, metric2 = (
    evaluate.load(metric_name) for metric_name in ("sacrebleu", "comet")
)

In [None]:
def postprocess_text(preds, labels, input_ids):
    """Strip surrounding whitespace from decoded texts and shape them for the metrics.

    Args:
        preds: Decoded prediction strings.
        labels: Decoded reference strings.
        input_ids: Decoded source strings (already text, despite the name).

    Returns:
        A tuple `(preds, labels, input_ids)`: predictions as a flat list of
        stripped strings; labels and sources each as a list of one-element
        lists holding the stripped string.
    """
    def _wrap_stripped(texts):
        # One single-item list per text.
        return [[text.strip()] for text in texts]

    cleaned_preds = []
    for text in preds:
        cleaned_preds.append(text.strip())

    wrapped_labels = _wrap_stripped(labels)
    wrapped_sources = _wrap_stripped(input_ids)
    return cleaned_preds, wrapped_labels, wrapped_sources


In [None]:
def _decode_token_ids(tokenizer, token_ids):
    """Replace the -100 ignore index with the pad id, then decode to text.

    Returns the cleaned id array together with the decoded strings.
    """
    token_ids = np.asarray(token_ids)
    # -100 is not a valid vocabulary id, so it must be swapped out before decoding.
    token_ids = np.where(token_ids != -100, token_ids, tokenizer.pad_token_id)
    return token_ids, tokenizer.batch_decode(token_ids, skip_special_tokens=True)


def compute_metrics(output_dir, tgt_lang, tokenizer, eval_preds):
    """Score predictions with BLEU and COMET and write the results to disk.

    The first three parameters are meant to be bound with `functools.partial`
    so that the remaining callable takes only `eval_preds`.

    Args:
        output_dir: Directory receiving `translations.txt` and
            `test_score.txt`; created if it does not exist.
        tgt_lang: Target language code. "ja" makes BLEU use the `ja-mecab`
            tokenizer; any other value uses the metric's default.
        tokenizer: Object exposing `pad_token_id` and `batch_decode`.
        eval_preds: Iterable unpacking to `(preds, labels, input_ids)`.
            `preds` may be a tuple (first element is used) and may be either
            token ids or 3-D logits, which are reduced with arg-max.
            NOTE(review): receiving `input_ids` here presumably requires the
            Trainer to be configured to include inputs — confirm.

    Returns:
        Dict with "bleu", "comet" and "gen_len", each rounded to 4 decimals.
    """
    preds, labels, input_ids = eval_preds

    # Preds
    if isinstance(preds, tuple):
        preds = preds[0]
    preds = np.asarray(preds)
    if preds.ndim == 3:
        # (batch, seq, vocab) logits -> (batch, seq) token ids.
        preds = preds.argmax(axis=-1)
    preds, decoded_preds = _decode_token_ids(tokenizer, preds)
    print ("decoded_preds: ", decoded_preds[:5])

    # Store inference
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, 'translations.txt'), 'w', encoding='utf8') as wf:
        for translation in decoded_preds:
            wf.write(translation.strip()+'\n')

    # Labels
    _, decoded_labels = _decode_token_ids(tokenizer, labels)
    print ("decoded_labels:", decoded_labels[:5])

    # Input_ids (source side, needed by COMET)
    _, decoded_input_ids = _decode_token_ids(tokenizer, input_ids)

    decoded_preds, decoded_labels, decoded_input_ids = postprocess_text(decoded_preds, decoded_labels, decoded_input_ids)

    # bleu
    if tgt_lang == "ja":
        bleu = metric1.compute(predictions=decoded_preds, references=decoded_labels, tokenize='ja-mecab')
    else:
        bleu = metric1.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": bleu["score"]}

    # comet takes flat lists, whereas postprocess_text wraps each reference and
    # source in its own single-item list.
    flat_references = [item for decoded_label in decoded_labels for item in decoded_label]
    flat_sources = [item for decoded_input_id in decoded_input_ids for item in decoded_input_id]
    print ("decoded_input_ids:",  flat_sources[:5], "\ndecoded_preds", decoded_preds[:5], "\ndecoded_labels", flat_references[:5])

    comet = metric2.compute(predictions=decoded_preds, references=flat_references, sources=flat_sources)
    result["comet"] = np.mean(comet["scores"])

    # Padding (including former -100 entries) does not count towards the length.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)

    # Plain Python floats keep the written file and any logged values clean.
    result = {k: round(float(v), 4) for k, v in result.items()}
    print(result)

    # Store the score
    with open(os.path.join(output_dir, 'test_score.txt'), 'w', encoding='utf8') as wf:
        for key, value in result.items():
            wf.write(f"{key}: {value}\n")

    return result


In [None]:
# Run configuration. output_dir receives the Trainer's outputs and is also the
# directory compute_metrics writes translations.txt / test_score.txt into.
output_dir = "./results/playingaround"
tgt_lang = "ja"


def _logits_to_token_ids(logits, labels):
    """Reduce model logits to arg-max token ids before they are accumulated.

    Keeps evaluation from storing one score per vocabulary entry for every
    position. `labels` is unused but part of the hook's call signature.
    """
    if isinstance(logits, tuple):
        # Some models return extra outputs next to the logits; logits come first.
        logits = logits[0]
    return logits.argmax(dim=-1)


training_args = TrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    push_to_hub=True,
    # compute_metrics unpacks (preds, labels, input_ids), so the inputs must be
    # forwarded to it.
    include_inputs_for_metrics=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    # Bind the three leading arguments; the Trainer supplies eval_preds.
    compute_metrics=partial(compute_metrics, output_dir, tgt_lang, tokenizer),
    data_collator=data_collator,
    preprocess_logits_for_metrics=_logits_to_token_ids,
)

trainer.train()