# Bert-Train

This notebook fine-tunes the `bert-base-uncased` model on the rumour dataset
and saves checkpoints to ./bert-results/checkpoint-xxxx

Author: Xiaofeng Zhang

Email: Xiaofzhang1@student.unimelb.edu.au

SID: 1194655

Reference: https://huggingface.co/docs/transformers/tasks/sequence_classification

In [None]:
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
import torch
import numpy as np

# Load Dataset

In [None]:
from datasets import load_from_disk

# Load the pre-built rumour dataset from disk. It is indexed by split name
# further down ("train", "dev", "test"), so it is expected to be a DatasetDict.
rumour = load_from_disk(
    "./data/rumour"
)

# Preprocess

## Tokenize

In [None]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


def preprocess_function(examples):
    """Tokenize the "text" field of a batch of examples.

    Over-long inputs are truncated. Padding is not applied here; it is left
    to the DataCollatorWithPadding used during training.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, truncation=True)
    return encoded


# batched=True hands preprocess_function many examples at once.
tokenized_rumour = rumour.map(preprocess_function, batched=True)

## Padding

In [None]:
# Pad inputs per batch at collation time (tokenization above only truncates).
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
)

# Train

In [None]:
# Pre-trained BERT encoder with a 2-way sequence-classification head
# (presumably rumour vs. non-rumour — confirm against the dataset labels).
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2,
)

In [None]:
# Fine-tuning hyper-parameters; checkpoints are written under ./bert-results.
training_args = TrainingArguments(
    output_dir="./bert-results",
    num_train_epochs=40,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
)

# Wire the model, tokenizer, padding collator and data splits together.
# NOTE(review): no evaluation strategy is set above, so unless the installed
# transformers version defaults otherwise, the "dev" split is never evaluated
# during training — TODO confirm this is intended.
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=tokenized_rumour["train"],
    eval_dataset=tokenized_rumour["dev"],
)

trainer.train()

# Write prediction results to test.predict.csv

In [None]:
# Run the fine-tuned model over the "test" split and write one predicted
# class index per example to test.predict.csv (header: Id,Predicted).
# Prefer the GPU when available, but fall back to CPU instead of crashing.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)  # move once, not on every loop iteration
model.eval()
# no_grad: inference only, so don't build an autograd graph per example.
with open("test.predict.csv", "w") as f, torch.no_grad():
    f.write("Id,Predicted\n")

    for i, example in enumerate(tokenized_rumour["test"]):
        # Add a leading batch dimension of 1 to each single example.
        _input = torch.tensor(example["input_ids"]).unsqueeze(0).to(device)
        _mask = torch.tensor(example["attention_mask"]).unsqueeze(0).to(device)

        y = model(input_ids=_input, attention_mask=_mask)
        # Index of the highest logit is the predicted class.
        res = y.logits.argmax(dim=-1).item()

        f.write("{0},{1}\n".format(i, res))
