In [None]:
import re

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from fastai.imports import *

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset  # bare `Dataset` is HF `datasets.Dataset` below
from transformers import TrainingArguments,Trainer
from transformers import AutoModelForSequenceClassification,AutoTokenizer
from datasets import Dataset, DatasetDict
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import EvalPrediction
from tqdm.notebook import tqdm

In [None]:
# Root directory of the assignment data; train.csv and test.csv are read from here.
BASE_PATH = "COMP5329S1A2Dataset"

In [None]:
def read_csv(path, n_columns=2):
    """Parse a COMP5329 caption CSV line by line into a DataFrame.

    Only lines starting with ``<digits>.jpg`` are treated as records, so the
    header (and any other non-record line) is skipped. Captions may contain
    commas, so everything after the leading fields is kept verbatim as
    ``Caption`` (quote characters are not interpreted).

    Args:
        path: Path of the CSV file to read (decoded as UTF-8).
        n_columns: ``2`` when records are ``ImageID,Labels,Caption`` (train
            file); any other value when records are ``ImageID,Caption`` (test
            file), in which case ``Labels`` is set to ``''``.

    Returns:
        DataFrame with string columns ``ImageID``, ``Labels`` and ``Caption``,
        one row per record in file order. The columns exist even when no line
        matches. Missing trailing fields become ``''``.
    """
    record_start = re.compile(r'^\d+\.jpg')  # raw string: `\d` is a regex escape
    has_labels = n_columns == 2
    rows = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:  # stream lines instead of materialising readlines()
            if not record_start.match(line):
                continue
            # Strip the line terminator so it does not leak into the last field.
            line = line.rstrip('\r\n')
            if has_labels:
                image_id, _, rest = line.partition(',')
                labels, _, caption = rest.partition(',')
            else:
                image_id, _, caption = line.partition(',')
                labels = ''
            rows.append({'ImageID': image_id, 'Labels': labels, 'Caption': caption})
    return pd.DataFrame(rows, columns=['ImageID', 'Labels', 'Caption'])

In [None]:
# Keep just what the model needs: the caption text and its space-separated label ids.
df = read_csv(f'{BASE_PATH}/train.csv')[['Caption', 'Labels']]
df

In [None]:
# Hold out 20% of the labelled captions for validation (fixed seed for reproducibility).
train_df, valid_df = train_test_split(df, random_state=42, test_size=0.2)

In [None]:
def append_dummies(df, label_columns=None):
    """Replace the space-separated ``Labels`` column with one 0/1 column per label.

    Args:
        df: DataFrame containing a ``Labels`` column of space-separated label ids
            (e.g. ``"1 19"``). Its other columns are kept unchanged.
        label_columns: Optional sequence of label names. When given, the indicator
            columns are reindexed to exactly these names in this order. Labels not
            present in ``df`` become all-zero columns, and labels not listed are
            dropped. Separately encoded splits therefore share one layout.

    Returns:
        A new DataFrame with every column of ``df`` except ``Labels``, followed by
        the indicator columns (named by label, as strings). ``df`` is not modified.
    """
    labels_df = df['Labels'].str.get_dummies(sep=' ')
    labels_df.columns = labels_df.columns.map(str)
    if label_columns is not None:
        labels_df = labels_df.reindex(columns=[str(c) for c in label_columns], fill_value=0)
    # Drop `Labels` by name rather than by position, so its column order does not matter.
    return pd.concat([df.drop(columns='Labels'), labels_df], axis=1)

In [None]:
train_df = append_dummies(train_df)
display(train_df.head(3))
# The validation split is one-hot encoded on its own. Force the train column layout:
# a label missing from this split would otherwise be absent (KeyError later when
# building label vectors). Labels never seen in training are dropped.
valid_df = append_dummies(valid_df).reindex(columns=train_df.columns, fill_value=0)

In [None]:
# Every column other than the caption is a 0/1 label indicator.
labels = [column for column in train_df.columns if column != 'Caption']
id2label = dict(enumerate(labels))
label2id = {name: index for index, name in id2label.items()}
print(*labels)

In [None]:
# Hugging Face checkpoint used for both the tokenizer and the classifier.
lmodel = 'bert-base-cased'

In [None]:
# Convert the pandas splits to Hugging Face datasets. preserve_index=False keeps the
# shuffled pandas index (left by train_test_split) out of the Arrow table, so there is
# no `__index_level_0__` column to remove. The call does not fail when the index
# happens to be a default RangeIndex.
train_ds = Dataset.from_pandas(train_df, preserve_index=False)
valid_ds = Dataset.from_pandas(valid_df, preserve_index=False)

In [None]:
# Bundle the splits; note that "test" holds the 20% validation split, not test.csv.
dds = DatasetDict({
    "train": train_ds,
    "test": valid_ds,
})
dds

In [None]:
tokenizer = AutoTokenizer.from_pretrained(lmodel)


def preprocess_data(examples):
    """Tokenize a batch of captions and attach their multi-hot label vectors.

    Meant for ``Dataset.map(..., batched=True)``: ``examples`` maps column names
    to per-example lists. Captions are padded/truncated to 128 tokens. The
    ``labels`` entry becomes one list of floats per example, with one entry per
    name in the global ``labels`` list, in that order.
    """
    captions = examples["Caption"]
    encoding = tokenizer(captions, padding="max_length", truncation=True, max_length=128)

    # Float matrix (n_examples, n_labels), filled column by column from the indicator columns.
    label_rows = np.zeros((len(captions), len(labels)))
    for col, name in enumerate(labels):
        label_rows[:, col] = examples[name]

    encoding["labels"] = label_rows.tolist()
    return encoding

In [None]:
# Tokenize both splits in batches. Drop the caption and indicator columns so that
# only the tokenizer outputs plus the `labels` vectors remain.
encoded_dataset = dds.map(
    preprocess_data,
    batched=True,
    remove_columns=dds["train"].column_names,
)

In [None]:
# Inspect which fields preprocess_data produced, using the first training example.
example = encoded_dataset['train'][0]
print(example.keys())

In [None]:
# Round-trip the token ids back to text to eyeball tokenization and padding.
tokenizer.decode(example['input_ids'])

In [None]:
# Make indexing return PyTorch tensors (e.g. for `.unsqueeze` in the forward-pass check).
encoded_dataset.set_format("torch")

In [None]:
# One output logit per label. The multi-label problem type treats labels as
# independent (sigmoid is applied to the logits when computing metrics and predictions).
model = AutoModelForSequenceClassification.from_pretrained(
    lmodel,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    problem_type="multi_label_classification",
)

In [None]:
batch_size = 128    # per-device batch size for both training and evaluation
metric_name = 'f1'  # key returned by compute_metrics; used to pick the best checkpoint

In [None]:
args = TrainingArguments(
    output_dir=lmodel,  # checkpoints go to a directory named after the checkpoint
    learning_rate=2e-5,
    num_train_epochs=15,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Evaluate and save once per epoch so the best epoch (by metric_name) can be restored.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
)

In [None]:
def multi_label_metrics(predictions, labels, threshold=0.5):
    """Score multi-label logits against multi-hot ground truth.

    Args:
        predictions: Array-like of raw logits, expected shape (n_samples, n_labels).
        labels: Multi-hot ground truth with the same shape as ``predictions``.
        threshold: Sigmoid probability at or above which a label counts as predicted.

    Returns:
        dict with:
            ``f1``: micro-averaged F1 of the thresholded predictions.
            ``roc_auc``: micro-averaged ROC AUC of the sigmoid probabilities.
            ``accuracy``: ``accuracy_score`` of the thresholded predictions
            (exact match over all labels for multi-hot input).
    """
    probs = torch.sigmoid(torch.as_tensor(predictions, dtype=torch.float32)).numpy()
    y_pred = (probs >= threshold).astype(int)
    y_true = labels
    return {
        'f1': f1_score(y_true=y_true, y_pred=y_pred, average='micro'),
        # ROC AUC ranks continuous scores; feeding 0/1 predictions collapses the curve
        # to a single operating point and misreports the metric.
        'roc_auc': roc_auc_score(y_true, probs, average='micro'),
        'accuracy': accuracy_score(y_true, y_pred),
    }

def compute_metrics(p: EvalPrediction):
    """Trainer hook: extract the logits from ``p`` and score them with multi_label_metrics."""
    # `predictions` may be a tuple of model outputs; the logits are taken as its first element.
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    return multi_label_metrics(predictions=preds, labels=p.label_ids)

In [None]:
# Sanity-check one forward pass (with loss) on the first training example before
# training. Fetch a single row instead of materialising a whole column. Pass every
# tokenizer field (including attention_mask, so padding is masked) plus labels,
# each with a batch dim of 1. No gradients are needed for a smoke test.
sample = encoded_dataset['train'][0]
with torch.no_grad():
    outputs = model(**{key: value.unsqueeze(0) for key, value in sample.items()})
outputs

In [None]:
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],  # the 20% validation split
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune. Evaluates and checkpoints each epoch, then reloads the best checkpoint by metric_name.
trainer.train()

In [None]:
# Metrics of the restored best model on the validation split.
trainer.evaluate()

In [None]:
# Unlabelled test captions: keep the image id (for the submission) and the text.
eval_df = read_csv(f'{BASE_PATH}/test.csv', n_columns=1)[['ImageID', 'Caption']]

In [None]:
# Peek at the parsed test captions.
eval_df.head(4)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class EvalDataset(TorchDataset):
    """Map-style PyTorch dataset that tokenizes one caption per item.

    Subclasses ``torch.utils.data.Dataset`` explicitly. In this notebook the bare
    name ``Dataset`` is Hugging Face's ``datasets.Dataset``, whose internal state
    this class would never initialise.

    Args:
        df: DataFrame with a ``Caption`` column; item ``idx`` is row ``idx`` by position.
        tokenizer: Hugging Face tokenizer used to encode each caption.
        max_length: Pad/truncate length; defaults to the 128 used in preprocess_data.
    """

    def __init__(self, df, tokenizer, max_length=128):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        text = self.df.iloc[idx].Caption
        encoding = self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )
        # The tokenizer returns (1, max_length) tensors. Drop that leading dim so the
        # default collate stacks items into (batch, max_length) for any batch size,
        # including a final batch of 1.
        return {key: value.squeeze(0) for key, value in encoding.items()}


eval_dataset = EvalDataset(eval_df, tokenizer)

# Separate name so the training `batch_size` global is not overwritten.
eval_batch_size = 264
eval_dataloader = DataLoader(eval_dataset, batch_size=eval_batch_size, shuffle=False)

all_probs = []  # one list of per-label sigmoid probabilities per caption, in eval_df order
model.to(device)
model.eval()

with torch.no_grad():
    for batch in tqdm(eval_dataloader, desc="Inference Progress"):
        inputs = {key: value.to(device) for key, value in batch.items()}
        logits = model(**inputs).logits  # stays 2-D: (batch, num_labels)
        all_probs.extend(torch.sigmoid(logits).cpu().tolist())

assert len(all_probs) == len(eval_df), "expected one probability row per test caption"

In [None]:
# One row per test caption (eval_df order, the loader is unshuffled), one probability column per label.
probs_df = pd.DataFrame(data=all_probs, columns=labels)
probs_df.head(3)

In [None]:
def create_labels_df(df, threshold=0.5):
    """Turn per-label probabilities into space-separated predicted label strings.

    Args:
        df: DataFrame whose first column is ``ImageID``. Every other column is a
            label name holding that label's probability.
        threshold: Probability at or above which a label is predicted. ``>=`` is
            used to match the thresholding in ``multi_label_metrics``.

    Returns:
        New DataFrame (same index as ``df``) with columns ``ImageID`` and ``Labels``.
        ``Labels`` joins the predicted label names in column order with spaces, and is
        ``''`` when no label reaches the threshold. ``df`` itself is not modified.
    """
    prob_columns = df.columns[1:]
    # Vectorised comparison instead of a per-row, per-column .iloc lookup.
    hits = df[prob_columns].to_numpy() >= threshold
    label_strings = [" ".join(str(name) for name in prob_columns[row]) for row in hits]
    return pd.DataFrame(
        {"ImageID": df["ImageID"].to_numpy(), "Labels": label_strings},
        index=df.index,
    )

In [None]:
# Pair each test ImageID with its per-label probabilities. Both frames get a fresh
# 0..n-1 index so concat aligns rows by position, not by whatever index eval_df carries.
assert len(eval_df) == len(probs_df), "one probability row is expected per test caption"
final_preds = pd.concat(
    [eval_df[['ImageID']].reset_index(drop=True), probs_df.reset_index(drop=True)],
    axis=1,
)

In [None]:
# Preview: ImageID followed by one probability column per label.
final_preds.head(5)