In [None]:
# `%pip` (rather than bare `pip` / `!pip`) installs into the environment of the
# running kernel; pinning keeps a fresh Restart & Run All reproducible.
%pip install wilds==2.0.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import torch

# Reinstall the PyG extension wheels that match the torch build running in this
# kernel: `{torch.__version__}` is expanded into the wheel-index URL below.
# `%pip` targets the kernel's own environment; `-y` skips the confirm prompt.
%pip uninstall -y torch-scatter torch-sparse torch-geometric torch-cluster
%pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
%pip install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
%pip install torch-cluster -f https://data.pyg.org/whl/torch-{torch.__version__}.html
# NOTE(review): installing from git HEAD is not reproducible; consider pinning a release.
%pip install git+https://github.com/pyg-team/pytorch_geometric.git


In [None]:
from torchtext.vocab import GloVe
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import os
import pandas as pd
import numpy as np
from torchtext.data import get_tokenizer
from torchtext.vocab import GloVe
import re
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader
import torchvision.transforms as transforms
from wilds.common.grouper import CombinatorialGrouper
from wilds.common.utils import split_into_groups
from torch.autograd import grad
from wilds.common.data_loaders import get_eval_loader


class ToxicClassifier(nn.Module):
    """Binary toxicity classifier: frozen pretrained embeddings -> GRU -> linear head.

    Args:
        vocab_size: accepted for interface compatibility but not used; the
            vocabulary is defined by ``embeddings_vectors``.
        embedding_dim: width of each embedding vector (the GRU input size).
        embeddings_vectors: 2-D weight matrix loaded into a frozen
            ``nn.Embedding`` (its rows are never updated by the optimizer).
        hidden_dim: GRU hidden size.
        output_dim: number of logits produced per sequence.
    """

    def __init__(self, vocab_size, embedding_dim, embeddings_vectors, hidden_dim = 32, output_dim = 1):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(embeddings_vectors, freeze=True)
        self.rnn = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        self.output = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):
        """Return logits of shape (batch, output_dim) for token ids ``text`` of shape (batch, seq_len)."""
        vectors = self.embedding(text)
        _, final_hidden = self.rnn(vectors)
        # final_hidden is (num_layers, batch, hidden_dim); classify from the top layer.
        top_layer_state = final_hidden[-1]
        return self.output(top_layer_state)

def tokenize(text, max_length = 100):
    """Convert a raw comment into a fixed-length array of GloVe row indices.

    The text is lower-cased, the characters ``. ! ? , ' *`` are removed,
    hyphens become spaces, and the result is split with torchtext's
    ``basic_english`` tokenizer. The word list is right-padded with
    ``'<PAD>'`` (or truncated) to ``max_length`` entries.

    Words missing from ``glove.stoi`` map to ``len(glove.stoi) - 1``. That is
    evaluated at call time; since this notebook appends ``'<PAD>'`` to the
    global ``glove`` vocabulary before tokenizing, unknown words resolve to the
    ``<PAD>`` row (the zero vector appended alongside it).

    Returns:
        np.ndarray of dtype int64 with ``max_length`` entries.
    """
    cleaned = re.sub(r"([-])", r" ", re.sub(r"([.!?,'*])", r"", text.lower()))
    words = get_tokenizer('basic_english')(cleaned)
    shortfall = max_length - len(words)
    if shortfall > 0:
        words = words + ['<PAD>'] * shortfall
    oov_index = len(glove.stoi) - 1
    indices = [glove.stoi.get(word, oov_index) for word in words[:max_length]]
    return np.array(indices, dtype=np.int64)

def compute_irm_penalty(losses, dummy):
    """IRMv1-style gradient penalty estimated from two halves of a batch.

    ``losses`` are per-example losses computed from logits scaled by ``dummy``.
    The even- and odd-indexed examples each give an estimate of
    d(mean loss)/d(dummy); the sum of their elementwise product is used as an
    estimate of the squared gradient. ``create_graph=True`` keeps the result
    differentiable so it can be added to the training objective.
    """
    halves = (losses[0::2], losses[1::2])
    grad_even, grad_odd = (
        grad(half.mean(), dummy, create_graph=True)[0] for half in halves
    )
    return (grad_even * grad_odd).sum()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): batch_size and length are not referenced by the loaders
# (which hard-code batch_size=128) or by tokenize (default max_length=100).
batch_size = 32
token_dim = 50  # GloVe vector width; passed as `dim` to the 6B vectors below
length = 100

# Seed torch (CPU and CUDA) and numpy so weight init and loader sampling are
# repeatable across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load GloVe 6B vectors and append a '<PAD>' token whose vector is all zeros,
# placed at the end of the vocabulary (index len(glove.itos) - 1).
glove = GloVe(name='6B', dim=token_dim)
padding_vector = torch.zeros(token_dim)
padding_token = '<PAD>'
glove.itos.append(padding_token)
glove.stoi[padding_token] = len(glove.itos) - 1
glove.vectors = torch.cat((glove.vectors, padding_vector.unsqueeze(0)), dim=0)

dataset = get_dataset(dataset="civilcomments", download=True)
train_data = dataset.get_subset("train")

# Held-out split for per-epoch monitoring. (Named test_* for continuity, but it
# is the WILDS "val" split; the "test" split is evaluated after training.)
test_data = dataset.get_subset("val")
# Bug fix: this previously wrapped `train_data` in a shuffling *train* loader,
# so the per-epoch "test" metrics were computed on training examples.
test_loader = get_eval_loader("standard", test_data, batch_size=128)

# Group examples by combinations of these identity attributes; the training
# loop computes the IRM penalty separately for each group in a batch.
identities = CombinatorialGrouper(dataset, [
            'male',
            'female',
            'LGBTQ',
            'christian',
            'muslim',
            'other_religions',
            'black',
            'white'
        ])
# The "group" loader draws each batch from n_groups_per_batch groups (the
# earlier, immediately-overwritten "standard" train loader has been removed).
train_loader = get_train_loader(
    "group", train_data, grouper=identities, n_groups_per_batch=4, batch_size=128
)


# Frozen-GloVe GRU classifier, moved to the selected device.
model = ToxicClassifier(
    len(glove),
    token_dim,
    glove.vectors.to(device),
).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Multiply the LR by 0.1 once the monitored value stops decreasing for 10 epochs.
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=10)

# BCE pos_weight = (#label-0 examples) / (#label-1 examples) in the train split,
# which scales the loss on positive (label 1) examples by that ratio.
negative_samples = (train_data.y_array == 0).sum()
positive_samples = (train_data.y_array == 1).sum()
pos_weight = negative_samples / positive_samples

num_epochs = 20
# BCE class weight as a float tensor on the model's device (a cheap no-op if
# pos_weight already is one). Used for both training and validation losses.
pos_weight_t = torch.as_tensor(pos_weight, dtype=torch.float32, device=device)


def _encode_texts(texts):
    """Tokenize a batch of raw strings into a (batch, max_length) int64 tensor."""
    # np.stack keeps the int64 ids; torch.Tensor(...) would round-trip them through float32.
    return torch.from_numpy(np.stack([tokenize(text) for text in texts]))


for epoch in range(num_epochs):
    # ---------------- train ----------------
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0
    n_batches = 0
    # IRMv1 "dummy" scalar multiplier on the logits; only its gradient is
    # penalized and it is not registered with the optimizer.
    dummy_w = torch.nn.Parameter(torch.ones(1, device=device))
    # Penalty weight ramps up with the epoch (zero on the first epoch).
    penalty_multiplier = epoch ** 1.6
    for texts, labels, metadata in tqdm(train_loader):
        optimizer.zero_grad()
        inputs = _encode_texts(texts)
        _, group_indices, _ = split_into_groups(identities.metadata_to_group(metadata))
        error = 0.0
        penalty = 0.0
        for idx in group_indices:
            group_in = inputs[idx].to(device)
            group_label = labels[idx].to(device).float()
            # Flatten (n, 1) logits to (n,): comparing (n, 1) with (n,) labels
            # broadcasts to (n, n), which inflated the reported train accuracy above 1.
            logits = model(group_in).reshape(-1)
            loss_erm = F.binary_cross_entropy_with_logits(
                logits * dummy_w, group_label, reduction='none', pos_weight=pos_weight_t
            )
            penalty = penalty + compute_irm_penalty(loss_erm, dummy_w)
            error = error + loss_erm.mean()
            predicted = (torch.sigmoid(logits) >= 0.5).float()
            train_correct += (predicted == group_label).sum().item()
            train_total += group_label.numel()
        (error + penalty_multiplier * penalty).backward()
        optimizer.step()
        # Track the ERM part only: the penalty weight changes every epoch, so the
        # full objective is not comparable across epochs for the scheduler.
        train_loss += error.item()
        n_batches += 1
    train_loss /= max(n_batches, 1)
    # Bug fix: train_loss used to stay 0.0, so the scheduler cut the LR after
    # `patience` epochs regardless of progress.
    scheduler.step(train_loss)
    train_acc = train_correct / max(train_total, 1)
    print(f"Epoch {epoch+1}/{num_epochs},  Train Acc: {train_acc:.4f}")

    # ---------------- validate ----------------
    model.eval()
    test_loss = 0.0
    test_correct = 0
    test_total = 0
    pbar = tqdm(test_loader)
    with torch.no_grad():
        # Bug fix: previously unpacked `cur_train` (the last *training* batch)
        # on every iteration instead of the batch yielded by the loader.
        for texts, labels, _metadata in pbar:
            inputs = _encode_texts(texts).to(device)
            labels = labels.to(device).float()
            logits = model(inputs).reshape(-1)
            # Replaces the undefined `criterion` with the training loss; summed so
            # that dividing by the number of examples gives a per-example mean.
            loss = F.binary_cross_entropy_with_logits(
                logits, labels, pos_weight=pos_weight_t, reduction='sum'
            )
            predicted = (torch.sigmoid(logits) >= 0.5).float()
            test_total += labels.numel()
            test_correct += (predicted == labels).sum().item()
            test_loss += loss.item()
            pbar.set_postfix(loss=loss.item() / max(labels.numel(), 1))

    test_loss /= max(test_total, 1)
    test_acc = test_correct / max(test_total, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")


# Final evaluation on the WILDS "test" split with the dataset's own metrics.
test_data = dataset.get_subset("test")
test_loader = get_eval_loader("standard", test_data, batch_size=32)

# Make inference mode explicit instead of relying on an earlier cell having
# left the model in eval() mode.
model.eval()
trues = []
preds = []
metadatas = []
with torch.no_grad():
    for texts, true, metadata in test_loader:
        # np.stack keeps the int64 token ids; torch.Tensor(...) would round-trip
        # them through float32.
        inputs = torch.from_numpy(np.stack([tokenize(text) for text in texts])).to(device)
        logits = model(inputs).reshape(-1)
        # Hard 0/1 predictions (sigmoid >= 0.5) are what dataset.eval receives.
        output = (torch.sigmoid(logits) >= 0.5).long()
        trues.append(true.to('cpu'))
        preds.append(output.to('cpu'))
        metadatas.append(metadata.to('cpu'))
all_preds = torch.cat(preds, dim=0)
all_trues = torch.cat(trues, dim=0)
all_metas = torch.cat(metadatas, dim=0)
print(dataset.eval(all_preds, all_trues, all_metas))

100%|██████████| 2101/2101 [00:57<00:00, 36.38it/s]


Epoch 1/100,  Train Acc: 14.6537


100%|██████████| 2102/2102 [00:33<00:00, 63.61it/s, MSE=0.412]


Epoch 1/100, Test Loss: 0.0032, Test Acc: 0.8203


100%|██████████| 2101/2101 [00:57<00:00, 36.47it/s]


Epoch 2/100,  Train Acc: 20.2337


100%|██████████| 2102/2102 [00:33<00:00, 62.67it/s, MSE=0.784]


Epoch 2/100, Test Loss: 0.0061, Test Acc: 0.5312


100%|██████████| 2101/2101 [00:57<00:00, 36.33it/s]


Epoch 3/100,  Train Acc: 20.5740


100%|██████████| 2102/2102 [00:35<00:00, 59.69it/s, MSE=0.375]


Epoch 3/100, Test Loss: 0.0029, Test Acc: 0.7578


100%|██████████| 2101/2101 [00:57<00:00, 36.46it/s]


Epoch 4/100,  Train Acc: 20.3428


100%|██████████| 2102/2102 [00:36<00:00, 57.25it/s, MSE=0.555]


Epoch 4/100, Test Loss: 0.0043, Test Acc: 0.7422


100%|██████████| 2101/2101 [00:57<00:00, 36.49it/s]


Epoch 5/100,  Train Acc: 20.3249


100%|██████████| 2102/2102 [00:34<00:00, 60.46it/s, MSE=0.674]


Epoch 5/100, Test Loss: 0.0053, Test Acc: 0.7109


100%|██████████| 2101/2101 [00:57<00:00, 36.44it/s]


Epoch 6/100,  Train Acc: 20.6306


100%|██████████| 2102/2102 [00:37<00:00, 56.13it/s, MSE=0.465]


Epoch 6/100, Test Loss: 0.0036, Test Acc: 0.8203


100%|██████████| 2101/2101 [00:57<00:00, 36.44it/s]


Epoch 7/100,  Train Acc: 20.3096


100%|██████████| 2102/2102 [00:33<00:00, 62.26it/s, MSE=0.436]


Epoch 7/100, Test Loss: 0.0034, Test Acc: 0.8359


100%|██████████| 2101/2101 [00:57<00:00, 36.34it/s]


Epoch 8/100,  Train Acc: 20.5490


100%|██████████| 2102/2102 [00:35<00:00, 58.97it/s, MSE=0.126]


Epoch 8/100, Test Loss: 0.0010, Test Acc: 0.9688


100%|██████████| 2101/2101 [00:57<00:00, 36.48it/s]


Epoch 9/100,  Train Acc: 20.4111


100%|██████████| 2102/2102 [00:37<00:00, 56.62it/s, MSE=0.465]


Epoch 9/100, Test Loss: 0.0036, Test Acc: 0.7812
({'acc_avg': 0.5523015260696411, 'acc_y:0_male:1': 0.498511403799057, 'count_y:0_male:1': 12092.0, 'acc_y:1_male:1': 0.7852928042411804, 'count_y:1_male:1': 2203.0, 'acc_y:0_female:1': 0.5364976525306702, 'count_y:0_female:1': 14179.0, 'acc_y:1_female:1': 0.7484581470489502, 'count_y:1_female:1': 2270.0, 'acc_y:0_LGBTQ:1': 0.4261682331562042, 'count_y:0_LGBTQ:1': 3210.0, 'acc_y:1_LGBTQ:1': 0.8034539222717285, 'count_y:1_LGBTQ:1': 1216.0, 'acc_y:0_christian:1': 0.6769688725471497, 'count_y:0_christian:1': 12101.0, 'acc_y:1_christian:1': 0.6817460060119629, 'count_y:1_christian:1': 1260.0, 'acc_y:0_muslim:1': 0.4461251199245453, 'count_y:0_muslim:1': 5355.0, 'acc_y:1_muslim:1': 0.8008604645729065, 'count_y:1_muslim:1': 1627.0, 'acc_y:0_other_religions:1': 0.5664429664611816, 'count_y:0_other_religions:1': 2980.0, 'acc_y:1_other_religions:1': 0.7384615540504456, 'count_y:1_other_religions:1': 520.0, 'acc_y:0_black:1': 0.4023987948894501, 'c