In [None]:
# Presumably the notebook sits one directory below the project root, which holds
# the `src` package imported in the next cell and the CSV files read below.
# Only move up when `src` is not already visible from the current directory, so
# re-running this cell (or starting the kernel from the root) does not climb one
# level too far and break every relative path that follows.
import os

if not os.path.isdir("src"):
    os.chdir("..")
os.getcwd()

In [None]:
import pandas as pd
import pytorch_lightning as pl
import torch
import wandb
from omegaconf import OmegaConf
from pytorch_lightning.loggers import WandbLogger
from transformers import AutoTokenizer

from src.Mamba.mamba_datamodule import ToxicDataModule
from src.Mamba.mamba_model import ToxicModel

In [None]:
# Build one labelled corpus: the train split plus the labelled test rows.
# Test rows whose `toxic` label is -1 are dropped (presumably the unscored rows
# of the test set). The result is filtered by token length and cached to
# training_data.csv for the label/tokenization cells.
LABEL_COLUMNS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

df_train_raw = pd.read_csv("train.csv")
df_test_labels = pd.read_csv("test_labels.csv")
df_test_comments = pd.read_csv("test.csv")
df_test = df_test_comments.merge(df_test_labels, on="id")
df_test = df_test[df_test["toxic"] != -1].reset_index(drop=True)
df_labelled = pd.concat([df_train_raw, df_test], ignore_index=True)

# Only `toxic` was checked for the -1 placeholder above. Make sure none survived
# in another label column, because the label cell sums these columns directly.
assert (df_labelled[LABEL_COLUMNS] >= 0).all().all(), "unexpected -1 labels left"

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.pad_token = tokenizer.eos_token

# Unpadded token count of every comment, aligned with the frame's index.
# Keeping only comments strictly shorter than 150 tokens means each kept comment
# fits the 150-token encoding used in the tokenization cell without truncation.
token_lengths = pd.Series(
    [len(ids) for ids in tokenizer(df_labelled["comment_text"].tolist())["input_ids"]],
    index=df_labelled.index,
)
df_train = df_labelled[token_lengths < 150].reset_index(drop=True)

# index=False: don't write the RangeIndex as an extra unnamed column, which
# would otherwise reappear as "Unnamed: 0" when the file is read back.
df_train.to_csv("training_data.csv", index=False)

In [None]:
# Collapse the six binary toxicity labels into one score per comment:
# a weighted sum of the labels, min-max scaled to [0, 1] over the dataset.
df = pd.read_csv("training_data.csv")

label_columns = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
# One weight per entry of `label_columns`, in the same order. The weights are
# hand-picked; no rationale for the specific values is recorded here.
labels_weights = torch.tensor([20.0, 18.0, 4.0, 4.0, 1.0, 4.0])

label_matrix = torch.tensor(df[label_columns].to_numpy(), dtype=torch.float32)
raw_scores = label_matrix @ labels_weights

score_min = raw_scores.min()
score_range = raw_scores.max() - score_min
if score_range > 0:
    labels = (raw_scores - score_min) / score_range
else:
    # Every row has the same score: avoid 0/0 (all-NaN targets) and map
    # everything to the bottom of the range instead.
    labels = torch.zeros_like(raw_scores)

In [None]:
# Re-load the GPT-NeoX tokenizer so this cell does not depend on the
# preprocessing cell having run in the current kernel; EOS doubles as padding.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.pad_token = tokenizer.eos_token

# Pad (and, if ever needed, truncate) every comment to exactly 150 tokens.
encoded = tokenizer(
    df["comment_text"].tolist(),
    max_length=150,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)

# tokens: (n_comments, 150) matrix of input ids.
# slicer: number of non-padding tokens per comment (row sum of the attention mask).
tokens = encoded["input_ids"]
slicer = encoded["attention_mask"].sum(dim=1)
del encoded

In [None]:
%%writefile toxic_model_config.yaml
# Hyper-parameters for the Mamba toxicity model, loaded with OmegaConf in the
# training cell. Only model.checkpoint and model.num_epochs are read directly in
# this notebook; every other key is consumed by ToxicDataModule / ToxicModel
# (src/Mamba), so confirm exact semantics there.

data:
    # NOTE(review): the notebook hands pre-tokenized tensors to the data module
    # and caches its filtered data in training_data.csv, not at this path.
    # Confirm whether train_dir is still used.
    train_dir: /kaggle/working/train.csv
    train_split: 0.9      # presumably the train fraction of a train/val split — TODO confirm
    batch_size: 32
    shuffle: true
    num_workers: 3

model:
    # d_model / n_layers / vocab_size presumably must match the pretrained
    # checkpoint at weights_path (mamba-370m) — TODO confirm.
    d_model: 1024
    n_layers: 48
    vocab_size: 50280
    rms_norm: true
    fused_add_norm: true
    d_output: 1           # presumably one output, matching the single score in `labels`
    learning_rate: 0.00005
    checkpoint: "checkpoints"   # ModelCheckpoint output directory (dirpath)
    num_epochs: 25              # passed to pl.Trainer as max_epochs
    pos_weight: 10        # presumably a loss weight for positive examples — TODO confirm in ToxicModel
    # Absolute Kaggle path; adjust when running outside that environment.
    weights_path: "/kaggle/working/mamba-370m/pytorch_model.bin"
    freeze_backbone: false
    dropout: 0.3

In [None]:
# Train the Mamba-based toxicity scorer on the tensors prepared above
# (`tokens`, `slicer`, `labels`), log to Weights & Biases, and keep only the
# checkpoint with the lowest validation loss.
SEED = 42

cfg = OmegaConf.load("toxic_model_config.yaml")

# Seed Python/NumPy/torch RNGs (and dataloader workers) before building the data
# module and model, so weight init and any random split/shuffle are reproducible.
pl.seed_everything(SEED, workers=True)

dataModule = ToxicDataModule(cfg, tokens, slicer, labels)
model = ToxicModel(cfg)

# log_model="all": every checkpoint saved by ModelCheckpoint is also uploaded
# to W&B as an artifact during training.
wandb_logger = WandbLogger(project="toxic_detection", name="mamba_based_model", log_model="all")
# Log plain Python containers rather than the OmegaConf DictConfig object,
# resolving any interpolations first.
wandb_logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True))

checkpoint = pl.callbacks.ModelCheckpoint(
    dirpath=cfg.model.checkpoint,
    monitor="val_loss",  # assumes ToxicModel logs "val_loss" — TODO confirm
    mode="min",
    filename="mamba_model-{val_loss:.2f}",
    save_top_k=1,
)

trainer = pl.Trainer(max_epochs=cfg.model.num_epochs, callbacks=[checkpoint], logger=wandb_logger)

trainer.fit(model=model, datamodule=dataModule)
# NOTE: in a notebook the W&B run stays open after fit; call `wandb.finish()`
# once nothing else needs to log to it.