In [None]:
%pip install transformers tokenizers datasets

from tqdm.notebook import tqdm
from IPython.display import clear_output

import torch
import numpy as np
import pandas as pd
from datasets import Dataset
from transformers import AutoModelForMaskedLM, AutoTokenizer, TrainingArguments, Trainer, AutoModelForSequenceClassification, BertModel,ElectraModel
clear_output()

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is read by the model-loading cells and the training loops below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [None]:
from datasets import Dataset, load_dataset

# Load the "sst2" dataset; the 'train', 'validation' and 'test' splits are used below.
# NOTE(review): this variable shares its name with the `datasets` package.
datasets = load_dataset("sst2")

# Hide the download / progress output of this cell.
clear_output()

In [None]:
# Pretrained checkpoint shared by the tokenizer and every encoder loaded below.
# NOTE(review): the encoders are instantiated with ElectraModel further down, so
# switching to the BERT checkpoint would also require changing that class.
checkpoint = "google/electra-base-discriminator"
# checkpoint = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
def tokenize_function(example):
    """Tokenize the ``"sentence"`` field with the notebook-level ``tokenizer``.

    Truncation is enabled; padding is left to the data collator at batch time.
    Intended for ``Dataset.map(..., batched=True)``, so ``example["sentence"]``
    may be a list of sentences.
    """
    sentences = example["sentence"]
    encoded = tokenizer(sentences, truncation=True)
    return encoded

# Tokenize each split in batches; the original columns are kept alongside the new ones.
tokenized_train_dataset = datasets['train'].map(tokenize_function, batched=True)
tokenized_valid_dataset = datasets['validation'].map(tokenize_function, batched=True)
tokenized_test_dataset = datasets['test'].map(tokenize_function, batched=True)

# Hide the map() progress bars.
clear_output()

In [None]:
# Expose only the model inputs and the label as torch tensors (set_format works in place).
# NOTE(review): the column is "label" here but the loops read batch["labels"];
# presumably the data collator performs that rename - confirm.
tokenized_train_dataset.set_format("torch",columns=["input_ids", "attention_mask", "label"])
tokenized_valid_dataset.set_format("torch",columns=["input_ids", "attention_mask", "label"])
tokenized_test_dataset.set_format("torch",columns=["input_ids", "attention_mask", "label"])

In [None]:
from transformers import DataCollatorWithPadding

# Collate function for the DataLoaders below; pads each batch using the tokenizer.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
# Batch size used by the train / validation / test DataLoaders.
classifier_batch = 16
# Number of passes over the training set.
num_epochs = 3
# Temperature: SCL / LCL divide the pairwise dot products by this value.
t = 100
# Loss mixing weight: alpha scales the cross-entropy term(s) and
# (1 - alpha) scales the contrastive term in SCLModel / LCLModel.
alpha = 0.2
# Number of target classes.
num_labels = 2

In [None]:
from torch.utils.data import DataLoader

# Training batches are reshuffled every epoch; validation batches keep dataset order.
# Both are padded per batch by `data_collator`.
train_dataloader = DataLoader(
    tokenized_train_dataset, shuffle=True, batch_size=classifier_batch, collate_fn=data_collator
)
eval_dataloader = DataLoader(
    tokenized_valid_dataset, batch_size=classifier_batch, collate_fn=data_collator
)

# SCL: cross-entropy + supervised contrastive loss


In [None]:
# Encoder for the SCL experiment, loaded from `checkpoint` and moved to `device`.
scl_rep_model = ElectraModel.from_pretrained(checkpoint).to(device)
# Hide the loading messages.
clear_output()

In [None]:
def SCL(reps, labels, temperature=None):
  """Supervised contrastive loss on the first-position vector of each sequence.

  With ``z = reps[:, 0, :]`` and ``P(i)`` the other batch items that share
  item ``i``'s label, the returned value is

      sum_i  -1/|P(i)| * sum_{p in P(i)} log( exp(z_i.z_p / T)
                                              / sum_{j != i} exp(z_i.z_j / T) )

  taken over the items ``i`` that have at least one positive.

  Compared with a per-item Python loop that rebuilds the positives through
  ``.tolist()``, this version keeps every term inside the autograd graph,
  uses ``logsumexp`` instead of an explicit ``exp``/``log`` pair, and creates
  all tensors on the device of ``reps``.

  Args:
    reps: tensor indexed as ``reps[:, 0, :]`` (callers in this notebook pass
      the encoder's first output).
    labels: tensor with one class id per row of ``reps``, or ``None``.
    temperature: divisor applied to the dot products. Defaults to the
      notebook-level ``t``.

  Returns:
    A scalar tensor. When ``labels`` is ``None`` or no item has a positive,
    a zero leaf tensor with ``requires_grad=True`` is returned so that
    ``backward()`` is still callable.
  """
  anchors = reps[:, 0, :]
  if labels is None:
    return anchors.new_zeros((), requires_grad=True)

  temp = t if temperature is None else temperature
  labels = labels.view(-1)
  batch = anchors.shape[0]

  # An item is never its own positive nor part of its own denominator.
  self_mask = torch.eye(batch, dtype=torch.bool, device=anchors.device)
  pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~self_mask
  if not bool(pos_mask.any()):
    # Also covers batch size 1, where the denominator would be empty.
    return anchors.new_zeros((), requires_grad=True)

  sim = (anchors @ anchors.T) / temp
  log_denom = torch.logsumexp(
      sim.masked_fill(self_mask, float("-inf")), dim=1, keepdim=True)
  log_prob = sim - log_denom

  # Rows without positives contribute exactly 0; clamp only avoids 0/0.
  pos_count = pos_mask.sum(dim=1).clamp(min=1)
  per_anchor = -log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1) / pos_count
  return per_anchor.sum()


In [None]:
from transformers.modeling_outputs import TokenClassifierOutput
from torch import nn

class SCLModel(nn.Module):
  """Encoder plus a two-layer classification head, trained with CE + SCL.

  The head reads the first-position vector of the encoder's first output:
  dropout -> Linear -> ReLU -> Linear(num_labels). When labels are supplied
  the loss is ``alpha * cross_entropy + (1 - alpha) * SCL``, using the
  notebook-level ``alpha`` and ``SCL``.

  Args:
    rep_model: encoder module called with ``input_ids`` / ``attention_mask``;
      its output must support ``[0]``, ``.hidden_states`` and ``.attentions``.
    num_labels: number of target classes.
  """

  def __init__(self, rep_model, num_labels):
    super(SCLModel,self).__init__()
    self.num_labels = num_labels

    self.rep_model = rep_model
    # Take the head width from the encoder config when it has one, so other
    # encoder sizes work; 768 stays the fallback.
    hidden_size = getattr(getattr(rep_model, "config", None), "hidden_size", 768)
    self.dropout = torch.nn.Dropout(0.1)
    self.classifier = torch.nn.Linear(hidden_size, num_labels)
    self.layer = torch.nn.Linear(hidden_size, hidden_size)

  def forward(self, input_ids=None, attention_mask=None,labels=None):
    """Run the encoder and head; compute the combined loss if labels are given.

    Returns:
      TokenClassifierOutput with ``logits`` from the head, the encoder's
      ``hidden_states`` / ``attentions``, and ``loss`` (``None`` when
      ``labels`` is ``None``, so the model can be used for inference).
    """
    reps = self.rep_model(input_ids=input_ids, attention_mask=attention_mask)
    hidden = reps[0]

    # Classification head on the first-position vector.
    sequence_output = self.dropout(hidden)
    x = self.layer(sequence_output[:, 0, :])
    logits = self.classifier(nn.functional.relu(x))

    loss = None
    if labels is not None:
      loss_fct = torch.nn.CrossEntropyLoss()
      Le = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
      # The contrastive term is computed on the encoder output before dropout.
      scl = SCL(hidden, labels)
      loss = alpha * Le + (1 - alpha) * scl

    return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=reps.hidden_states,attentions=reps.attentions)

In [None]:
# Wrap the SCL encoder with the classification head and move everything to `device`.
scl_model = SCLModel(rep_model=scl_rep_model,num_labels=num_labels).to(device)

In [None]:
from transformers import get_scheduler

# transformers.AdamW is deprecated in favour of the PyTorch implementation and
# is not importable from recent transformers releases. eps and weight_decay
# are set explicitly to the values transformers.AdamW used by default, so the
# optimisation settings stay the same.
scl_optimizer = torch.optim.AdamW(
    scl_model.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.0
)

# Linear decay of the learning rate to zero over the whole run, no warm-up.
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=scl_optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
print(num_training_steps)

12630




In [None]:
from datasets import load_metric
# F1 metric accumulator for the SCL model.
# NOTE(review): `average='micro'` is passed again at every compute() call below,
# so it is probably redundant here - confirm.
scl_f1_metric = load_metric("f1", average='micro')


  scl_f1_metric = load_metric("f1", average='micro')


In [None]:
from tqdm.auto import tqdm

# One tick per optimisation step / per validation batch across all epochs.
progress_bar_train = tqdm(range(num_training_steps))
progress_bar_eval = tqdm(range(num_epochs * len(eval_dataloader)))


for epoch in range(num_epochs):
  # ---- training pass ----
  scl_model.train()
  for batch in train_dataloader:
      # Move every tensor of the batch to the model's device.
      batch = {k: v.to(device) for k, v in batch.items()}
      outputs = scl_model(**batch)
      loss = outputs.loss
      loss.backward()

      # Update weights, advance the LR schedule, then clear gradients for the next step.
      scl_optimizer.step()
      lr_scheduler.step()
      scl_optimizer.zero_grad()
      progress_bar_train.update(1)

  # ---- validation pass ----
  scl_model.eval()
  for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = scl_model(**batch)

    # Predicted class = arg-max over the head's logits.
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    scl_f1_metric.add_batch(predictions=predictions, references=batch["labels"])
    # acc_metric.add_batch(predictions=predictions, references=batch["labels"])
    progress_bar_eval.update(1)


  # Validation micro-F1 for this epoch.
  print(scl_f1_metric.compute(average='micro'))

  0%|          | 0/12630 [00:00<?, ?it/s]

  0%|          | 0/165 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'f1': 0.9461009174311925}


In [None]:
# Evaluate the trained SCL model on the test split.
# NOTE(review): this assumes the 'test' split carries real labels. If they are
# placeholders, the loss inside the model and this F1 score are not meaningful - confirm.
scl_model.eval()

test_dataloader = DataLoader(
    tokenized_test_dataset, batch_size=classifier_batch, collate_fn=data_collator
)

for batch in test_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = scl_model(**batch)

    # Predicted class = arg-max over the head's logits.
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    scl_f1_metric.add_batch(predictions=predictions, references=batch["labels"])

scl_f1_metric.compute(average='micro')

# LCL: cross-entropy + confidence-weighted contrastive loss

In [None]:
# Fresh encoder for the LCL experiment, plus a separate sequence classifier
# whose outputs are used as the weighting network inside LCLModel.
lcl_rep_model = ElectraModel.from_pretrained(checkpoint).to(device)
weight_model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels).to(device)

# Hide the loading messages.
clear_output()

In [None]:
def LCL(reps, w, labels, temperature=None):
  """Contrastive loss weighted by a second model's confidence in the label.

  With ``z = reps[:, 0, :]``, ``c_j = softmax(w)[j, labels[j]]`` and ``P(i)``
  the other batch items that share item ``i``'s label, the returned value is

      sum_i  -1/|P(i)| * sum_{p in P(i)} c_p * log( exp(z_i.z_p / T)
                                          / sum_{j != i} c_j * exp(z_i.z_j / T) )

  taken over the items ``i`` that have at least one positive.

  Compared with a per-item Python loop that rebuilds the positives and their
  weights through ``.tolist()``, this version keeps every term (including
  the positive weights) inside the autograd graph, works in log space via
  ``log_softmax`` / ``logsumexp``, and creates all tensors on the device of
  ``reps``.

  Args:
    reps: tensor indexed as ``reps[:, 0, :]`` (callers in this notebook pass
      the encoder's first output).
    w: 2-D tensor of per-class scores, one row per batch item; soft-maxed
      over ``dim=1`` here.
    labels: tensor with one class id per row of ``reps``, or ``None``.
    temperature: divisor applied to the dot products. Defaults to the
      notebook-level ``t``.

  Returns:
    A scalar tensor. When ``labels`` is ``None`` or no item has a positive,
    a zero leaf tensor with ``requires_grad=True`` is returned so that
    ``backward()`` is still callable.
  """
  anchors = reps[:, 0, :]
  if labels is None:
    return anchors.new_zeros((), requires_grad=True)

  temp = t if temperature is None else temperature
  labels = labels.view(-1)
  batch = anchors.shape[0]

  # An item is never its own positive nor part of its own denominator.
  self_mask = torch.eye(batch, dtype=torch.bool, device=anchors.device)
  pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~self_mask
  if not bool(pos_mask.any()):
    # Also covers batch size 1, where the denominator would be empty.
    return anchors.new_zeros((), requires_grad=True)

  # log c_j: log-probability that the weighting scores assign to item j's label.
  rows = torch.arange(batch, device=w.device)
  log_conf = torch.nn.functional.log_softmax(w, dim=1)[rows, labels]

  sim = (anchors @ anchors.T) / temp
  # log sum_{j != i} c_j * exp(sim_ij), computed stably in log space.
  log_denom = torch.logsumexp(
      (sim + log_conf.unsqueeze(0)).masked_fill(self_mask, float("-inf")),
      dim=1, keepdim=True)
  log_prob = sim - log_denom

  weighted = (log_conf.exp().unsqueeze(0) * log_prob).masked_fill(~pos_mask, 0.0)
  # Rows without positives contribute exactly 0; clamp only avoids 0/0.
  pos_count = pos_mask.sum(dim=1).clamp(min=1)
  per_anchor = -weighted.sum(dim=1) / pos_count
  return per_anchor.sum()


In [None]:
from transformers.modeling_outputs import TokenClassifierOutput
from torch import nn


class LCLModel(nn.Module):
  """Encoder + weighting network + heads, trained with CE + CE + LCL.

  Two sub-networks see the same inputs:
    * ``rep_model``: its first output feeds the main head
      (dropout -> Linear -> ReLU -> Linear(num_labels)) and the contrastive term.
    * ``weight_model``: its first output (``num_labels`` scores per item)
      weights the contrastive term and also passes through
      ``weight_classifier`` for an auxiliary cross-entropy.

  When labels are supplied the loss is
  ``alpha * (Le + Lw) + (1 - alpha) * LCL``, using the notebook-level
  ``alpha`` and ``LCL``.

  Args:
    rep_model: encoder module called with ``input_ids`` / ``attention_mask``;
      its output must support ``[0]``, ``.hidden_states`` and ``.attentions``.
    weight_model: module called the same way whose ``[0]`` output has
      ``num_labels`` columns.
    num_labels: number of target classes; must match ``weight_model``.
  """

  def __init__(self, rep_model, weight_model, num_labels):
    super(LCLModel,self).__init__()
    self.num_labels = num_labels

    self.rep_model = rep_model
    self.weight_model = weight_model
    # Take the head width from the encoder config when it has one, so other
    # encoder sizes work; 768 stays the fallback.
    hidden_size = getattr(getattr(rep_model, "config", None), "hidden_size", 768)
    self.dropout = torch.nn.Dropout(0.1)
    self.weight_classifier = torch.nn.Linear(num_labels, num_labels)
    self.classifier = torch.nn.Linear(hidden_size, num_labels)
    self.layer = torch.nn.Linear(hidden_size, hidden_size)

  def forward(self, input_ids=None, attention_mask=None,labels=None):
    """Run both sub-networks; compute the combined loss if labels are given.

    Returns:
      TokenClassifierOutput with ``logits`` from the main head, the encoder's
      ``hidden_states`` / ``attentions``, and ``loss`` (``None`` when
      ``labels`` is ``None``, so the model can be used for inference).
    """
    reps = self.rep_model(input_ids=input_ids, attention_mask=attention_mask)
    w = self.weight_model(input_ids=input_ids, attention_mask=attention_mask)
    hidden = reps[0]
    weight_scores = w[0]  # per-class scores of the weighting network

    # Auxiliary head on top of the weighting network's scores.
    weight_logits = self.weight_classifier(
        self.dropout(weight_scores).view(-1, self.num_labels))

    # Main classification head on the first-position vector.
    sequence_output = self.dropout(hidden)
    x = self.layer(sequence_output[:, 0, :])
    logits = self.classifier(nn.functional.relu(x))

    loss = None
    if labels is not None:
      loss_fct = torch.nn.CrossEntropyLoss()
      flat_labels = labels.view(-1)
      Lw = loss_fct(weight_logits.view(-1, self.num_labels), flat_labels)
      Le = loss_fct(logits.view(-1, self.num_labels), flat_labels)
      # The contrastive term uses the raw (pre-dropout) encoder output and scores.
      lcl = LCL(hidden, weight_scores, labels)
      loss = alpha * (Le + Lw) + (1 - alpha) * lcl

    return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=reps.hidden_states,attentions=reps.attentions)

In [None]:
# Use the shared `num_labels` (not a literal) so both heads match the width of
# `weight_model`'s output, which was created with the same value.
lcl_model = LCLModel(rep_model=lcl_rep_model, weight_model=weight_model,num_labels=num_labels).to(device)


In [None]:
from transformers import get_scheduler

# transformers.AdamW is deprecated in favour of the PyTorch implementation and
# is not importable from recent transformers releases. eps and weight_decay
# are set explicitly to the values transformers.AdamW used by default, so the
# optimisation settings stay the same.
optimizer = torch.optim.AdamW(
    lcl_model.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.0
)

# Linear decay of the learning rate to zero over the whole run, no warm-up.
# Note: this rebinds `num_training_steps` and `lr_scheduler` from the SCL section.
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
print(num_training_steps)

In [None]:
from datasets import load_metric
# F1 metric accumulator for the LCL model (`average='micro'` is passed again at compute()).
f1_metric = load_metric("f1", average='micro')
# Only referenced from commented-out lines in the loop below.
acc_metric = load_metric("accuracy")


In [None]:
from tqdm.auto import tqdm

# One tick per optimisation step / per validation batch across all epochs.
progress_bar_train = tqdm(range(num_training_steps))
progress_bar_eval = tqdm(range(num_epochs * len(eval_dataloader)))


for epoch in range(num_epochs):
  # ---- training pass ----
  lcl_model.train()
  for batch in train_dataloader:
      # Move every tensor of the batch to the model's device.
      batch = {k: v.to(device) for k, v in batch.items()}
      outputs = lcl_model(**batch)
      loss = outputs.loss
      loss.backward()

      # Update weights, advance the LR schedule, then clear gradients for the next step.
      optimizer.step()
      lr_scheduler.step()
      optimizer.zero_grad()
      progress_bar_train.update(1)

  # ---- validation pass ----
  lcl_model.eval()
  for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = lcl_model(**batch)

    # Predicted class = arg-max over the main head's logits.
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    f1_metric.add_batch(predictions=predictions, references=batch["labels"])
    # acc_metric.add_batch(predictions=predictions, references=batch["labels"])
    progress_bar_eval.update(1)

  # Validation micro-F1 for this epoch.
  print(f1_metric.compute(average='micro'))


In [None]:
# Evaluate the trained LCL model on the test split.
# NOTE(review): this assumes the 'test' split carries real labels. If they are
# placeholders, the loss inside the model and this F1 score are not meaningful - confirm.
lcl_model.eval()

test_dataloader = DataLoader(
    tokenized_test_dataset, batch_size=classifier_batch, collate_fn=data_collator
)

for batch in test_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = lcl_model(**batch)

    # Predicted class = arg-max over the main head's logits.
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    f1_metric.add_batch(predictions=predictions, references=batch["labels"])

f1_metric.compute(average='micro')