In [1]:
import pandas as pd

# Load train and test data
train_data = pd.read_csv('train_data.csv')
test_data = pd.read_csv('test_data.csv')

In [2]:
train_data.head()

Unnamed: 0,Date,URL,Title,Source,Country,Label
0,20240815T010000Z,https://borneobulletin.com.bn/explosions-repor...,Explosions reported near two ships off Yemen :...,borneobulletin.com.bn,Brunei,2
1,20240716T194500Z,https://www.hindustantimes.com/india-news/crew...,"Crew , including 13 Indians , still missing af...",hindustantimes.com,India,2
2,20240809T100000Z,https://www.yahoo.com/news/multiple-attacks-ta...,Multiple attacks target merchant ship off Yeme...,yahoo.com,United States,3
3,20240717T041500Z,https://timesofoman.com/article/147862-oil-tan...,Oil tanker with 13 Indians on board sinks off ...,timesofoman.com,Oman,2
4,20240812T201500Z,https://menafn.com/1108546043/Multiple-Attacks...,Multiple Attacks Target Merchant Ship Off Yemen,menafn.com,Qatar,3


In [3]:
test_data.head()

Unnamed: 0,Date,URL,Title,Source,Country
0,20221207T020000Z,https://www.rnz.co.nz/news/national/480280/eng...,Engineer fined over huge fire at Napier Port,rnz.co.nz,
1,20221221T150000Z,https://www.ship-technology.com/news/ictsi-lea...,ICTSI reaches 30 - year lease extension for Ba...,ship-technology.com,United States
2,20221018T084500Z,https://www.malaymail.com/news/money/mediaoutr...,DHL : Ocean freight rate moving towards manage...,malaymail.com,United States
3,20221028T151500Z,https://focustaiwan.tw/society/202210280021,Indonesians stuck on vessel in Kaohsiung set t...,focustaiwan.tw,Taiwan
4,20221018T104500Z,https://bdnews24.com/bangladesh/0ggpvbnije,Body found in container sent from Chattogram t...,bdnews24.com,Bangladesh


In [4]:
# Import necessary libraries for text processing
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import string

# Download necessary resources for NLTK
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')

# Initialize stopwords and lemmatizer
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()

# Function to clean text (remove stop words and lemmatize)
def clean_text(text):
    # Tokenize the text
    tokens = nltk.word_tokenize(text)
    
    # Remove punctuation and stop words, then apply lemmatization
    cleaned_tokens = [lemmatizer.lemmatize(word.lower()) for word in tokens 
                      if word.lower() not in stop_words and word not in string.punctuation]
    
    # Join tokens back to a single string
    cleaned_text = ' '.join(cleaned_tokens)
    return cleaned_text

# Apply the cleaning function to 'title' columns in train and test data
train_data['cleaned_title'] = train_data['Title'].apply(clean_text)
test_data['cleaned_title'] = test_data['Title'].apply(clean_text)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Regin\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Regin\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\Regin\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [5]:
train_data.head()

Unnamed: 0,Date,URL,Title,Source,Country,Label,cleaned_title
0,20240815T010000Z,https://borneobulletin.com.bn/explosions-repor...,Explosions reported near two ships off Yemen :...,borneobulletin.com.bn,Brunei,2,explosion reported near two ship yemen securit...
1,20240716T194500Z,https://www.hindustantimes.com/india-news/crew...,"Crew , including 13 Indians , still missing af...",hindustantimes.com,India,2,crew including 13 indian still missing oil tan...
2,20240809T100000Z,https://www.yahoo.com/news/multiple-attacks-ta...,Multiple attacks target merchant ship off Yeme...,yahoo.com,United States,3,multiple attack target merchant ship yemen uni...
3,20240717T041500Z,https://timesofoman.com/article/147862-oil-tan...,Oil tanker with 13 Indians on board sinks off ...,timesofoman.com,Oman,2,oil tanker 13 indian board sink oman coast
4,20240812T201500Z,https://menafn.com/1108546043/Multiple-Attacks...,Multiple Attacks Target Merchant Ship Off Yemen,menafn.com,Qatar,3,multiple attack target merchant ship yemen


In [6]:
test_data.head()

Unnamed: 0,Date,URL,Title,Source,Country,cleaned_title
0,20221207T020000Z,https://www.rnz.co.nz/news/national/480280/eng...,Engineer fined over huge fire at Napier Port,rnz.co.nz,,engineer fined huge fire napier port
1,20221221T150000Z,https://www.ship-technology.com/news/ictsi-lea...,ICTSI reaches 30 - year lease extension for Ba...,ship-technology.com,United States,ictsi reach 30 year lease extension baltic con...
2,20221018T084500Z,https://www.malaymail.com/news/money/mediaoutr...,DHL : Ocean freight rate moving towards manage...,malaymail.com,United States,dhl ocean freight rate moving towards manageab...
3,20221028T151500Z,https://focustaiwan.tw/society/202210280021,Indonesians stuck on vessel in Kaohsiung set t...,focustaiwan.tw,Taiwan,indonesian stuck vessel kaohsiung set return h...
4,20221018T104500Z,https://bdnews24.com/bangladesh/0ggpvbnije,Body found in container sent from Chattogram t...,bdnews24.com,Bangladesh,body found container sent chattogram malaysia


In [8]:
# Import necessary libraries
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import classification_report, accuracy_score
from sklearn.model_selection import train_test_split
from datasets import Dataset

# Load the training data and updated news titles data
train_df = pd.read_csv('training_dataset.csv')
predict_df = pd.read_csv('updated_news_titles.csv')

# Ensure the columns are correctly labeled in the training data
train_df = train_df[['Title', 'LABEL']].rename(columns={'LABEL': 'labels'})

# Split the training data into train and validation sets
train_df, eval_df = train_test_split(train_df, test_size=0.2, random_state=42)

# Convert DataFrames to Hugging Face Dataset format
train_dataset = Dataset.from_pandas(train_df)
eval_dataset = Dataset.from_pandas(eval_df)

# Load the tokenizer and model with num_labels=14 to accommodate labels from 0 to 13
model_name = "distilbert-base-uncased"  # Replace with "mistralai/mistral" if you have access
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=14)

# Tokenize the data
def preprocess_data(examples):
    return tokenizer(examples["Title"], truncation=True, padding="max_length", max_length=128)

# Tokenize both train and validation datasets
train_dataset = train_dataset.map(preprocess_data, batched=True)
eval_dataset = eval_dataset.map(preprocess_data, batched=True)

# Training setup with evaluation strategy set to "epoch"
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",  # Enable evaluation after each epoch
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
)

# Define a compute_metrics function for evaluation
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc}

# Initialize the Trainer with both train and validation datasets
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

# Train the model
trainer.train()

# Evaluate the model on the validation set
eval_predictions = trainer.predict(eval_dataset)
eval_preds = eval_predictions.predictions.argmax(-1)

# Generate and display the classification report and accuracy score for the validation set
print("Validation Accuracy:", accuracy_score(eval_df['labels'], eval_preds))
print("Classification Report (Validation Set):")
print(classification_report(eval_df['labels'], eval_preds))

# Prepare the predict dataset (updated news titles)
predict_df = predict_df[['Title']]
predict_dataset = Dataset.from_pandas(predict_df)
predict_dataset = predict_dataset.map(preprocess_data, batched=True)

# Predict on the updated news titles
predict_results = trainer.predict(predict_dataset)
predicted_labels = predict_results.predictions.argmax(-1)

# Save the predictions to a new CSV file
predict_df['Predicted_Label'] = predicted_labels
predict_df.to_csv('updated_news_titles_with_predictions.csv', index=False)
print("Predictions saved to 'updated_news_titles_with_predictions.csv'")


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/570 [00:00<?, ? examples/s]

Map:   0%|          | 0/143 [00:00<?, ? examples/s]

  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,1.964588,0.475524
2,No log,1.708649,0.545455
3,No log,1.653235,0.538462


Validation Accuracy: 0.5384615384615384
Classification Report (Validation Set):
              precision    recall  f1-score   support

           1       0.00      0.00      0.00         6
           2       0.45      0.82      0.58        17
           3       0.69      0.97      0.80        34
           4       0.00      0.00      0.00        11
           6       0.00      0.00      0.00         9
           7       0.00      0.00      0.00         3
           8       0.00      0.00      0.00         6
           9       0.00      0.00      0.00         3
          10       0.00      0.00      0.00         4
          11       0.00      0.00      0.00        14
          12       0.49      0.83      0.62        36

    accuracy                           0.54       143
   macro avg       0.15      0.24      0.18       143
weighted avg       0.34      0.54      0.42       143



  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Map:   0%|          | 0/713 [00:00<?, ? examples/s]

Predictions saved to 'updated_news_titles_with_predictions.csv'
