In [5]:
# Imports
import csv

import stanza
import os
import nltk
from nltk.util import ngrams
from nltk.corpus import stopwords
from string import punctuation
import re 
from collections import defaultdict, Counter

# Download necessary resources
nltk.download('punkt')
nltk.download('stopwords')
stanza.download('zh') 
nlp = stanza.Pipeline('zh', processors='tokenize')

# Get the set of stopwords and punctuation
stop_words = set(stopwords.words('english')) 
stop_words.update(
    {'cent', 'href=', 'http', 'says', 'told', 'year', 'ago', 'yesterday', 'since', 'last', 'past', 'next',
     'said', 'almost', 'within', 'would', 'nearly', 'years', 'months', 'according', 'compared', 'go', 'also', 
     "n't"})  
punctuation_set = set(punctuation)
punctuation_set.update({"’", "’", '”', "''", "“", "'s", '--', 'b', '/b', '/strong', '–', '—'})

[nltk_data] Downloading package punkt to /Users/vnnsnnt/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/vnnsnnt/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.9.0.json: 392kB [00:00, 51.1MB/s]                    
2024-11-26 15:55:55 INFO: Downloaded file to /Users/vnnsnnt/stanza_resources/resources.json
2024-11-26 15:55:55 INFO: "zh" is an alias for "zh-hans"
2024-11-26 15:55:55 INFO: Downloading default packages for language: zh-hans (Simplified_Chinese) ...
2024-11-26 15:55:56 INFO: File exists: /Users/vnnsnnt/stanza_resources/zh-hans/default.zip
2024-11-26 15:55:58 INFO: Finished downloading models and saved to /Users/vnnsnnt/stanza_resources
2024-11-26 15:55:58 INFO: Checking for updates to resources.json in case models have been updated.  Note: this behavior can be turned off with downlo

In [2]:
# Data Structures
class ParallelSentence: 
    def __init__(self, en, zh):
        self.en = en
        self.zh = zh
        
class AnchorWord:
    def __init__(self, en, zh):
        self.en = en
        self.zh = zh
        
class ParallelCorpus: 
    def __init__(self):
        self.parallel_sentences = []
        self.multi_grams_to_consider = []
        self.anchor_words = {}
        
    def load_parallel_sentences(self, data_source):
        parallel_sentences = []
        for file in os.listdir(data_source):
            file_path = os.path.join(data_source, file)
            with open(file_path, mode='r', encoding='utf-8') as data_file:
                reader = csv.reader(data_file, delimiter=';')
                for row in reader:
                    if len(row) < 7: continue   # escape bad data
                    english_content = row[5]    # get english sentences
                    chinese_content = row[6]    # get chinese sentences
        
                    # break apart sentence content by @ delimiter
                    english_sentences = english_content.split('@')  
                    chinese_sentences = chinese_content.split('@')
                    
                    for english_sentence, chinese_sentence in zip(english_sentences, chinese_sentences):
                        clean_english_sentence = english_sentence.strip()
                        
                        # Process the Chinese sentence with Stanza
                        doc = nlp(chinese_sentence)  
                        chinese_tokens = [word.text for sentence in doc.sentences for word in sentence.words]
                        clean_chinese_sentence = " ".join(chinese_tokens)
                        
                        parallel_sentences.append(ParallelSentence(clean_english_sentence, clean_chinese_sentence))
                        
        self.parallel_sentences = parallel_sentences
    
    def generate_multi_grams(self):
        bigrams = self.extract_ngram_counts(n=2).most_common()[:5000]
        trigrams = self.extract_ngram_counts(n=3).most_common()[:3000]
        quadgrams = self.extract_ngram_counts(n=4).most_common()[:1000]
        
        multi_grams_to_consider = set()
        # Add multi-word terms from quad_grams_to_consider
        for quad_gram in quadgrams:
            multi_word_term = '_'.join(quad_gram[0])
            multi_grams_to_consider.add(multi_word_term)
        
        # Add multi-word terms from trigrams_to_consider
        for trigram in trigrams:
            multi_word_term = '_'.join(trigram[0])
            multi_grams_to_consider.add(multi_word_term)
        
        # Add multi-word terms from bigrams_to_consider
        for bigram in bigrams:
            multi_word_term = '_'.join(bigram[0])
            multi_grams_to_consider.add(multi_word_term)
        
        self.multi_grams_to_consider = multi_grams_to_consider
        
    @staticmethod
    def refactor_sentence_with_multiword_term(sentence, multi_word_terms):
        # Tokenize the sentence into words
        words = sentence.split(' ')
        modified_sentence = []
        i = 0
        while i < len(words):
            found = False
            
            # Check for quadgrams (4-word sequences)
            for length in range(4, 1, -1):  # Check for quadgram to bigram
                if i + length <= len(words):
                    multi_word_candidate = '_'.join(words[i:i+length]).lower()
                    if multi_word_candidate in multi_word_terms:
                        # If a match is found, replace the words with the multi-word term
                        modified_sentence.append(multi_word_candidate)
                        i += length
                        found = True
                        break
            
            if not found:
                # If no match is found, just add the word as is
                modified_sentence.append(words[i])
                i += 1

        # Return the modified sentence as a string
        return ' '.join(modified_sentence)
    
    def extract_ngram_counts(self, n):
        ngram_counts = Counter()
        for parallel_sentence in self.parallel_sentences:
            tokens = nltk.word_tokenize(parallel_sentence.en)
            # Filter out stopwords, punctuation, and numbers
            filtered_tokens = [token.lower() for token in tokens 
                               if token.lower() not in stop_words 
                               and token not in punctuation_set 
                               and not token.isdigit()] 
    
            # Generate n-grams for the filtered tokens
            ngram_list = ngrams(filtered_tokens, n)
            # Count the frequency of each n-gram
            ngram_counts.update(ngram_list)
        return ngram_counts
    
    def format_parallel_sentences_for_awesome_align(self):
        with open("zhen.src-tgt", "w") as f:
            for parallel_sentence in self.parallel_sentences:
                modified_sentence = self.refactor_sentence_with_multiword_term(parallel_sentence.en, self.multi_grams_to_consider)
                f.write(f"{modified_sentence} ||| {parallel_sentence.zh}\n")
    
    def build_anchor_words_from_awesome_align_output(self, alignments_path):
        anchor_words = []
        with open(alignments_path, 'r', encoding='utf-8') as file:
            for line in file:
                alignment_pairs = line.strip().split(' ')
                for index, pair in enumerate(alignment_pairs):
                    en_entry, zh_entry = pair.split('<sep>')[0], pair.split('<sep>')[1]
                    if en_entry not in self.multi_grams_to_consider: continue
                    # Clean the English entry
                    cleaned_en_entry = re.sub(r'[^a-zA-Z_]', '', en_entry)
                    
                    # Append only if conditions are met
                    if cleaned_en_entry:
                        if anchor_words and anchor_words[len(anchor_words)-1].en == cleaned_en_entry:
                            if zh_entry not in anchor_words[len(anchor_words)-1].zh:
                                anchor_words[len(anchor_words)-1].zh += zh_entry
                        else:
                            anchor_words.append(AnchorWord(cleaned_en_entry, zh_entry))
                            
        unique_anchors = set(AnchorWord(anchor.en, anchor.zh) for anchor in anchor_words)
        
        # Step 1: Count frequencies of `zh` entries for each `en`
        anchor_freq = defaultdict(Counter)
    
        for anchor in unique_anchors:
            anchor_freq[anchor.en][anchor.zh] += 1
        
        # Step 2: Select the most frequent `zh` entry for each `en`
        filtered_alignments = []
        for en, zh_counter in anchor_freq.items():
            most_frequent_zh = zh_counter.most_common(1)[0][0]  # Get the most frequent `zh`
            filtered_alignments.append(AnchorWord(en, most_frequent_zh))
        
        # Step 3: Sort alphabetically by `en`
        sorted_filtered_anchors = sorted(filtered_alignments, key=lambda anchor: anchor.en)
        
        # Step 4: Write to file
        with open('possible-anchors.txt', 'w') as file:
            for alignment in sorted_filtered_anchors:
                file.write(f"{alignment.en} {alignment.zh}\n")
    
    def load_sorted_anchors(self, anchor_path):
        anchors = set()
        with open(anchor_path, 'r') as file:
            for line in file: 
                alignment = line.strip().split(' ')
                en = alignment[0].replace('_', ' ')
                zh = alignment[1] 
                anchors.add((en, zh))  # Store as a tuple for paired lookup
        self.anchor_words = anchors

In [7]:
parallel_corpus = ParallelCorpus()  # Initialize Corpus Object

In [8]:
parallel_corpus.load_parallel_sentences(data_source='./FTIE/')  # Load parallel sentences from data source

In [9]:
parallel_corpus.generate_multi_grams()  # Generate Multi grams e.g Asian Financial Crisis -> asian_financial_crisis

In [10]:
parallel_corpus.format_parallel_sentences_for_awesome_align() # Format English Sentence With Multi Grams 
# Prepare a data source for awesome align 

# DATA_FILE=./zhen.src-tgt
# MODEL_NAME_OR_PATH=./model_without_co
# OUTPUT_FILE=./output.txt
# OUTPUT_WORDS=./alignments.txt
# OUTPUT_PROB=./alignments-prob.txt
# 
# CUDA_VISIBLE_DEVICES=0 awesome-align \
#     --output_file=$OUTPUT_FILE \
#     --model_name_or_path=$MODEL_NAME_OR_PATH \
#     --data_file=$DATA_FILE \
#     --extraction 'softmax' \
#     --batch_size 32 \
#     --num_workers 0 \
#     --output_word_file=$OUTPUT_WORDS \
#     --output_prob_file=$OUTPUT_PROB 

In [20]:
parallel_corpus.build_anchor_words_from_awesome_align_output('./alignments.txt')    # Generate Possible Anchor Words

In [11]:
parallel_corpus.load_sorted_anchors('./final_anchors.txt')  # Load Final and Verified Anchors

In [57]:
# # How would the original model translate these anchor words? 
# def translate_anchor_words(src_lang, tgt_lang, output_file):
#     # Set the source and target languages
#     tokenizer.src_lang = src_lang
#     tokenizer.tgt_lang = tgt_lang
#     forced_bos_token_id = tokenizer.lang_code_to_id[tgt_lang]  # Ensure the target language is correct
# 
#     # Translate and save results
#     with open(output_file, "w", encoding="utf-8") as f:
#         for index, pair in enumerate(parallel_corpus.anchor_words):
#             # Select source and target based on direction
#             source_anchor = pair.zh if src_lang == "zh_CN" else pair.en
#             target_anchor = pair.en if src_lang == "zh_CN" else pair.zh
# 
#             # Tokenize the input text
#             inputs = tokenizer(source_anchor, return_tensors="pt")
#             # Generate translation with forced BOS token for the target language
#             translated_tokens = model.generate(**inputs, forced_bos_token_id=forced_bos_token_id)
#             # Decode the translated tokens
#             translation = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
# 
#             # Save the result in the text file
#             f.write(f"{source_anchor}; {target_anchor}; {translation.lower()}\n")
# 
#             if index % 100 == 0:
#                 print(f"Done translating {index} / {len(parallel_corpus.anchor_words)}")
# 


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'MBart50Tokenizer'. 
The class this function is called from is 'MBartTokenizer'.


In [35]:
# # Translate English to Chinese
# translate_anchor_words(
#     src_lang="en_XX",
#     tgt_lang="zh_CN",
#     output_file="unmodified_en-zh-translated_anchor_words.txt"
# )

Done translating 0 / 3044
Done translating 100 / 3044
Done translating 200 / 3044
Done translating 300 / 3044
Done translating 400 / 3044
Done translating 500 / 3044
Done translating 600 / 3044
Done translating 700 / 3044
Done translating 800 / 3044
Done translating 900 / 3044
Done translating 1000 / 3044
Done translating 1100 / 3044
Done translating 1200 / 3044
Done translating 1300 / 3044
Done translating 1400 / 3044
Done translating 1500 / 3044
Done translating 1600 / 3044
Done translating 1700 / 3044
Done translating 1800 / 3044
Done translating 1900 / 3044
Done translating 2000 / 3044
Done translating 2100 / 3044
Done translating 2200 / 3044
Done translating 2300 / 3044
Done translating 2400 / 3044
Done translating 2500 / 3044
Done translating 2600 / 3044
Done translating 2700 / 3044
Done translating 2800 / 3044
Done translating 2900 / 3044
Done translating 3000 / 3044


In [None]:
# translate_anchor_words(
#     src_lang="zh_CN",
#     tgt_lang="en_XX",
#     output_file="unmodified_zh-en-translated_anchor_words.txt"
# )

In [43]:
# import Levenshtein
# anchor_count = len(parallel_corpus.anchor_words)
# perfect_match_count = 0
# matching_translations = 0
# with open("unmodified_zh-en-translated_anchor_words.txt", "r", encoding="utf-8") as f:
#     for line in f: 
#         items = line.split(';')
#         zh_anchor = items[0].strip()
#         en_anchor = items[1].strip()
#         translation = items[2].strip()
#         
#         if translation == en_anchor:
#             perfect_match_count += 1
#             matching_translations += 1
# 
# print("Unmodified Accuracy on Chinese Anchor Words (zh->en):", matching_translations / anchor_count)
# print("Perfect Match Count:", perfect_match_count, "out of", anchor_count)

Unmodified Accuracy on Chinese Anchor Words (zh->en): 0.32424441524310116
Perfect Match Count: 987 out of 3044


In [42]:
# import Levenshtein
# anchor_count = len(parallel_corpus.anchor_words)
# perfect_match_count = 0
# matching_translations = 0
# with open("unmodified_en-zh-translated_anchor_words.txt", "r", encoding="utf-8") as f:
#     for line in f: 
#         items = line.split(';')
#         en_anchor = items[0].strip()
#         zh_anchor = items[1].strip()
#         translation = items[2].strip()
#             
#         if translation == zh_anchor:
#             perfect_match_count += 1
#             matching_translations += 1
# 
# print("Unmodified Accuracy on English Anchor Words (en->zh):", matching_translations / anchor_count)
# print("Perfect Match Count:", perfect_match_count, "out of", anchor_count)

Unmodified Accuracy on English Anchor Words (en->zh): 0.32490144546649147
Perfect Match Count: 989 out of 3044


In [50]:
anchor_words_dict = {en: zh for en, zh in parallel_corpus.anchor_words}

def refactor_sentence_with_anchors(en_sentence, chinese_sentence, anchor_words):
    # Tokenize the sentence into words
    words = en_sentence.split(' ')
    modified_sentence = []
    i = 0
    refactored_chinese_sentence = chinese_sentence.replace(' ', '')
    while i < len(words):
        found = False
        
        # Check for multi-word anchor terms in English
        for length in range(4, 1, -1):  # Check from 4 words (quadgram) to 2 words (bigram)
            if i + length <= len(words):
                multi_word_candidate = ' '.join(words[i:i+length]).lower()  # Make sure we match underscore-separated terms
                # Iterate over the anchor words and check the English part of the pair
                for en_term, zh_term in anchor_words:
                    if multi_word_candidate == en_term:
                        modified_sentence.append(f"<{multi_word_candidate.replace(' ', '_')}>")  # Replace with English term
                        i += length  # Skip the words that are part of the multi-word term
                        found = True
                        refactored_chinese_sentence = refactored_chinese_sentence.replace(anchor_words_dict[en_term], '<'+zh_term+'>')
                        
                        break
                if found:
                    break
        
        if not found:
            # If no multi-word term is found, just add the current word
            modified_sentence.append(words[i])
            i += 1
    
    # Return the modified sentence as a string
    return ' '.join(modified_sentence), refactored_chinese_sentence


# Example to refactor both English and Chinese sentences
refactored_parallel_sentences = []
for index, parallel_sentence in enumerate(parallel_corpus.parallel_sentences):
    # Refactor the English sentence with anchor words
    modified_english_sentence, modified_chinese_sentence = refactor_sentence_with_anchors(parallel_sentence.en, parallel_sentence.zh, parallel_corpus.anchor_words)

    # Append the refactored sentence pair to the list
    refactored_parallel_sentences.append(ParallelSentence(modified_english_sentence, modified_chinese_sentence))
    if index % 1000 == 0: 
        print("Done refactoring", index, "out of", len(parallel_corpus.parallel_sentences))

Done refactoring 0 out of 255860
Done refactoring 1000 out of 255860
Done refactoring 2000 out of 255860
Done refactoring 3000 out of 255860
Done refactoring 4000 out of 255860
Done refactoring 5000 out of 255860
Done refactoring 6000 out of 255860
Done refactoring 7000 out of 255860
Done refactoring 8000 out of 255860
Done refactoring 9000 out of 255860
Done refactoring 10000 out of 255860
Done refactoring 11000 out of 255860
Done refactoring 12000 out of 255860
Done refactoring 13000 out of 255860
Done refactoring 14000 out of 255860
Done refactoring 15000 out of 255860
Done refactoring 16000 out of 255860
Done refactoring 17000 out of 255860
Done refactoring 18000 out of 255860
Done refactoring 19000 out of 255860
Done refactoring 20000 out of 255860
Done refactoring 21000 out of 255860
Done refactoring 22000 out of 255860
Done refactoring 23000 out of 255860
Done refactoring 24000 out of 255860
Done refactoring 25000 out of 255860
Done refactoring 26000 out of 255860
Done refactori

In [52]:
with open('refactored_parallel_sentences.txt', 'w', encoding='utf-8') as f:
    for pair in refactored_parallel_sentences:
        f.write(f"{pair.en} ; {pair.zh}\n")
        
print("Refactored sentences saved to 'refactored_parallel_sentences.txt'")

Refactored sentences saved to 'refactored_parallel_sentences.txt'


In [3]:
parallel_corpus = ParallelCorpus()
parallel_corpus.load_sorted_anchors('./final_anchors.txt')

In [5]:
refactored_parallel_sentences = []
with open('refactored_parallel_sentences.txt', 'r', encoding='utf-8') as f:
    for line in f:
        items = line.split(';')
        refactored_parallel_sentences.append(ParallelSentence(items[0].strip(), items[1].strip()))
        
print("Refactored sentences loaded")

Refactored sentences loaded


In [8]:
tokens_to_be_added = []
for (en_anchor, zh_anchor) in parallel_corpus.anchor_words:
    tokens_to_be_added.append('<'+en_anchor.replace(' ', '_')+'>')
    tokens_to_be_added.append('<'+zh_anchor+'>')

In [70]:
tokens_to_be_added

['<business_practices>',
 '<商业行为>',
 '<deficit_countries>',
 '<赤字国家>',
 '<rising_us_interest_rates>',
 '<上升美国利率>',
 '<emerging_markets>',
 '<新兴市场>',
 '<corporate_world>',
 '<企业界>',
 '<services_firm>',
 '<服务公司>',
 '<mr_draghi>',
 '<德拉吉>',
 '<exporting_countries>',
 '<出口国>',
 '<european_parliament>',
 '<欧洲议会>',
 '<foreign_investors>',
 '<外国投资者>',
 '<investment_managers>',
 '<投资经理>',
 '<core_inflation>',
 '<核心通胀>',
 '<founder_chief_executive>',
 '<创始首席执行官>',
 '<eastern_china>',
 '<东部中国>',
 '<economic_sanctions>',
 '<经济制裁>',
 '<retail_sales>',
 '<零售>',
 '<big_business>',
 '<大企业>',
 '<challenge_us>',
 '<挑战美国>',
 '<current_crisis>',
 '<当前危机>',
 '<bilateral_trade>',
 '<双边贸易>',
 '<market_economy>',
 '<市场经济>',
 '<shares_rose>',
 '<股价上涨>',
 '<investment_products>',
 '<投资产品>',
 '<southern_guangdong>',
 '<广东>',
 '<us_sanctions>',
 '<美国制裁>',
 '<sovereign_wealth_fund>',
 '<主权财富基金>',
 '<social_media_platform>',
 '<社交媒体平台>',
 '<us_authorities>',
 '<美国当局>',
 '<emerging_market_central_banks>',
 '<新兴市场央行

In [89]:
from transformers import MBartForConditionalGeneration, MBartTokenizer
from datasets import Dataset
import torch
from torch.utils.data import DataLoader

# Load mBART model and tokenizer
model_name = "facebook/mbart-large-50-many-to-many-mmt"
model = MBartForConditionalGeneration.from_pretrained(model_name)
tokenizer = MBartTokenizer.from_pretrained(model_name)

# Add custom tokens and resize model embeddings
tokenizer.add_tokens(tokens_to_be_added)
model.resize_token_embeddings(len(tokenizer))

# Split sentences for training, validation, and test sets
train_sentences = refactored_parallel_sentences[0:10000]
validation_sentences = refactored_parallel_sentences[10000:13000]
test_sentences = refactored_parallel_sentences[13000:15000]

# Define a custom Dataset class
class ParallelDataset(Dataset):
    def __init__(self, inputs, attention_mask, targets):
        self.inputs = inputs
        self.attention_mask = attention_mask
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        # Ensure that we're getting a single example from the lists
        return {
            'input_ids': torch.tensor(self.inputs[idx], dtype=torch.long),
            'attention_mask': torch.tensor(self.attention_mask[idx], dtype=torch.long),
            'labels': torch.tensor(self.targets[idx], dtype=torch.long)
        }



def preprocess_data(parallel_sentences, src_lang="en_XX", tgt_lang="zh_CN"):
    tokenizer.src_lang = src_lang
    tokenizer.tgt_lang = tgt_lang

    inputs = []
    attention_mask = []
    targets = []
    
    for sentence in parallel_sentences:
        input_text = sentence.en  # English sentence
        target_text = sentence.zh # Chinese sentence

        # Tokenize the source (English) and target (Chinese) sentences
        input_tokens = tokenizer(input_text, truncation=True, padding="max_length", max_length=1024, return_tensors="pt")
        target_tokens = tokenizer(target_text, truncation=True, padding="max_length", max_length=1024, return_tensors="pt")
        
        # Convert tensors to lists of integers
        inputs.append(input_tokens['input_ids'].squeeze(0).tolist())  # Remove batch dimension
        attention_mask.append(input_tokens['attention_mask'].squeeze(0).tolist())  # Remove batch dimension
        targets.append(target_tokens['input_ids'].squeeze(0).tolist())  # Remove batch dimension

    # Create a Dataset object for PyTorch DataLoader
    dataset = ParallelDataset(inputs, attention_mask, targets)

    return dataset

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'MBart50Tokenizer'. 
The class this function is called from is 'MBartTokenizer'.


In [90]:
# Preprocess and create datasets
train_dataset = preprocess_data(train_sentences)
# validation_dataset = preprocess_data(validation_sentences)
# test_dataset = preprocess_data(test_sentences)



# # Create DataLoader objects
# train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# validation_dataloader = DataLoader(validation_dataset, batch_size=16, shuffle=False)
# test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

AttributeError: 'ParallelDataset' object has no attribute '_info'

In [87]:
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer and scheduler
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

optimizer = AdamW(model.parameters(), lr=5e-5)
num_epochs = 3
total_steps = len(train_dataloader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

# Training loop
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    epoch_loss = 0
    for batch in train_dataloader:
        # Move input tensors to GPU/CPU
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)  # Don't forget the attention mask
        labels = batch['labels'].to(device)

        # Zero the gradients from the previous step
        optimizer.zero_grad()

        # Forward pass
        outputs = model(input_ids=input_ids, 
                         attention_mask=attention_mask, 
                         labels=labels)  # Make sure to pass attention mask
        loss = outputs.loss

        # Backward pass to compute gradients
        loss.backward()

        # Step optimizer and scheduler
        optimizer.step()
        scheduler.step()

        # Update the loss for this batch
        epoch_loss += loss.item()

    # Print the average loss for this epoch
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss / len(train_dataloader)}")

    # Evaluate on the validation set after each epoch
    model.eval()  # Set model to evaluation mode
    validation_loss = 0
    with torch.no_grad():  # No gradient computation for validation
        for batch in validation_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass on the validation set
            outputs = model(input_ids=input_ids, 
                             attention_mask=attention_mask, 
                             labels=labels)
            loss = outputs.loss

            validation_loss += loss.item()

    print(f"Validation Loss: {validation_loss / len(validation_dataloader)}")

# Save the model and tokenizer after training
model.save_pretrained('./model')
tokenizer.save_pretrained('./tokenizer')

TypeError: list indices must be integers or slices, not list