## Check that a GPU is available

In [None]:
# IPython shell escape: run `nvidia-smi` to list the NVIDIA GPUs visible to this machine.
!nvidia-smi

In [None]:
import torch
# True when PyTorch can use a CUDA device; displayed as the cell's output.
torch.cuda.is_available()

## Install necessary packages

In [None]:
# ! pip install -U adapter-transformers

In [None]:
# ! pip install datasets

## Load and inspect data

In [None]:
import datasets
import torch
from torch.utils.data import DataLoader, TensorDataset

In [None]:
def get_dataset(dataset):
    """Load a GLUE task with the ``datasets`` library.

    Args:
        dataset: GLUE configuration name passed to
            ``datasets.load_dataset('glue', ...)``, e.g. ``'sst2'``.

    Returns:
        A ``(dataset_dict, num_classes)`` tuple, where ``num_classes`` is read
        from the ``label`` feature of the ``train`` split.
    """
    dataset_dict = datasets.load_dataset('glue', dataset)
    label_feature = dataset_dict['train'].features['label']
    return dataset_dict, label_feature.num_classes


def create_dataset_from_text_dataset(ds, tokenizer):
    """Tokenize one text split and pack it into a ``TensorDataset``.

    Args:
        ds: Split indexable by column name; ``ds['sentence']`` is the raw text
            and ``ds['label']`` the integer labels.
        tokenizer: Called once on the whole list of sentences with
            ``return_tensors='pt'``, ``padding=True`` and ``truncation=True``;
            must return a mapping containing ``'input_ids'`` and
            ``'attention_mask'`` tensors.

    Returns:
        ``TensorDataset(input_ids, attention_mask, labels)``.
    """
    enc = tokenizer(ds['sentence'], return_tensors='pt', padding=True, truncation=True)
    label_tensor = torch.tensor(ds['label'])
    return TensorDataset(enc['input_ids'], enc['attention_mask'], label_tensor)


def get_tensor_datasets(dataset_dict, splits, tokenizer):
    """Convert each requested split into a ``TensorDataset``.

    Args:
        dataset_dict: Mapping from split name to text split.
        splits: Iterable of split names to convert, in order.
        tokenizer: Tokenizer forwarded to ``create_dataset_from_text_dataset``.

    Returns:
        Dict mapping each split name to its ``TensorDataset``.
    """
    return {
        split: create_dataset_from_text_dataset(dataset_dict[split], tokenizer)
        for split in splits
    }


def get_data_loaders(split_datasets, batch_size=256, num_workers=0):
    """Wrap the train and validation datasets in ``DataLoader``s.

    Args:
        split_datasets: Mapping from split name to dataset. ``'train'`` is
            required; ``'validation'`` is optional.
        batch_size: Mini-batch size for both loaders (default 256, the value
            that used to be hard-coded).
        num_workers: Worker processes per loader (default 0, i.e. load in the
            main process).

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles every
        epoch; the validation loader keeps dataset order. ``val_loader`` is
        ``None`` when there is no ``'validation'`` split instead of raising
        ``KeyError``, so callers should check it before use.
    """
    train_loader = DataLoader(split_datasets['train'], batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)

    val_loader = None
    val_dataset = split_datasets.get('validation')
    if val_dataset is not None:
        val_loader = DataLoader(val_dataset, batch_size=batch_size,
                                shuffle=False, num_workers=num_workers)

    return train_loader, val_loader

In [None]:
# Load the GLUE SST-2 task and the number of label classes in its train split.
dataset = 'sst2'
print(f'Loading {dataset} dataset...')
dataset_dict, num_classes = get_dataset(dataset)

In [None]:
# Display the loaded splits as the cell output.
dataset_dict

In [None]:
# import datasets
# sst2 = datasets.load_dataset('glue', 'sst2')
# sst2

In [None]:
# for i in range(5):
#     print(sst2['train'][i])

In [None]:
# sst2['train'].features['label'].num_classes

## Load Tokenizer

In [None]:
import torch
import transformers
from transformers import AdapterType
from transformers import BertTokenizerFast, BertForSequenceClassification


def get_tokenizer(model_name):
    """Return the fast tokenizer for ``model_name``.

    Only ``'bert-base-uncased'`` is supported; it is loaded with
    ``BertTokenizerFast.from_pretrained``.

    Raises:
        NotImplementedError: For any other model name.
    """
    if model_name != 'bert-base-uncased':
        raise NotImplementedError
    return BertTokenizerFast.from_pretrained(model_name)


def get_transformer(model_name, num_labels, adapter, dataset):
    """Build a BERT sequence classifier, optionally with a task adapter.

    Args:
        model_name: Pretrained checkpoint; only ``'bert-base-uncased'`` is
            supported.
        num_labels: Number of output classes for the classification head.
        adapter: When truthy, add a ``text_task`` adapter named ``dataset``
            and call ``train_adapter`` on it (adapter-transformers API).
        dataset: Name given to the adapter.

    Returns:
        The ``BertForSequenceClassification`` model.

    Raises:
        NotImplementedError: For any other ``model_name``.
    """
    if model_name != 'bert-base-uncased':
        raise NotImplementedError

    model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
    if not adapter:
        return model

    model.add_adapter(dataset, AdapterType.text_task)
    model.train_adapter(dataset)
    return model


def get_criterion(num_labels):
    """Return the training loss for a ``num_labels``-way classification task.

    Cross-entropy works for any number of classes >= 2, so multi-class tasks
    (e.g. 3-way) are supported as well as binary ones. The previous version
    rejected everything except exactly 2 classes.

    Args:
        num_labels: Number of target classes.

    Returns:
        A ``torch.nn.CrossEntropyLoss`` instance.

    Raises:
        NotImplementedError: For ``num_labels < 2`` (e.g. a single-output
            regression head), which needs a different loss.
    """
    if num_labels < 2:
        raise NotImplementedError(f'No criterion configured for num_labels={num_labels}')
    return torch.nn.CrossEntropyLoss()

In [None]:
# Load the tokenizer that matches the pretrained checkpoint used for the model below.
model_name = 'bert-base-uncased'
print(f'Loading tokenizer for {model_name}...')
tokenizer = get_tokenizer(model_name)

In [None]:
# import transformers
# from transformers import BertTokenizerFast, BertForSequenceClassification

# model_name = 'bert-base-uncased'
# tokenizer = BertTokenizerFast.from_pretrained(model_name)

## Create data loaders for each split

In [None]:
# Tokenize every split in dataset_dict, then wrap the train and validation
# splits in DataLoaders (any other split is tokenized but not wrapped).
splits = list(dataset_dict.keys())
print(f'Creating data loader for {splits} splits...')
split_datasets = get_tensor_datasets(dataset_dict, splits, tokenizer)
train_loader, val_loader = get_data_loaders(split_datasets)

In [None]:
# Sanity-check the training loader: print the label shape and values of the
# first batch, then move the labels to the accelerator. Falls back to CPU so
# the cell no longer crashes on machines without CUDA.
sample_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for i_batch, sample_batched in enumerate(train_loader):
    print(i_batch, sample_batched[2].size())
    print(sample_batched[2])
    sb = sample_batched[2].to(sample_device)
    print(sb)
    # Only the first batch is inspected.
    break

In [None]:
# from torch.utils.data import DataLoader, TensorDataset

# def create_dataset_from_text_dataset(ds):
#     encoding = tokenizer(ds['sentence'], return_tensors='pt', padding=True, truncation=True)
#     input_ids = encoding['input_ids']
#     attn_masks = encoding['attention_mask']
#     labels = torch.tensor(ds['label'])
    
#     return TensorDataset(input_ids, attn_masks, labels)

# splits = ['train',  'validation', 'test']
# split_datasets = {}

# for s in splits:
#     split_datasets[s] = create_dataset_from_text_dataset(sst2[s])

# split_datasets

In [None]:
# Validate data loader
# sample_loader = DataLoader(split_datasets['train'], batch_size=3, shuffle=True)
# for i in sample_loader:
#     input_ids, attn_masks, labels = i
#     decoded = tokenizer.batch_decode(input_ids)
#     for d in decoded:
#         print(d)
#     break

## Create model

In [None]:
# Build the classifier (adapter=False: plain fine-tuning, no adapter) and its loss.
adapter = False
print(f'Loading {model_name} with adapters={adapter}...')
model = get_transformer(model_name,
                        num_labels=num_classes,
                        adapter=adapter,
                        dataset=dataset)
criterion = get_criterion(num_labels=num_classes)

In [None]:
# Print the names of the model's top-level submodules (the modules themselves are unused).
for name, child in model.named_children():
    print(name)

In [None]:
# def create_model(add_adapters=False):
#     model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
#     if add_adapters:
#         from transformers import AdapterType
#         model.add_adapter("sst2", AdapterType.text_task)
#         model.train_adapter("sst2")
#     return model

In [None]:
# train_loader = DataLoader(split_datasets['train'], batch_size=16, shuffle=True)
# val_loader = DataLoader(split_datasets['validation'], batch_size=128, shuffle=False)

In [None]:
# from transformers import logging
# logging.set_verbosity_warning()

# model = create_model(add_adapters=True)

## Get Learning Scheme

In [None]:
import torch


def get_learning_scheme(learning_scheme, model):
    """Create the optimizer for ``model`` according to ``learning_scheme``.

    Only ``'differential'`` is implemented: per-parameter learning rates come
    from ``differential_learning_scheme`` and are optimized with SGD.

    Returns:
        A ``torch.optim.SGD`` optimizer.

    Raises:
        NotImplementedError: For any other scheme name.
    """
    if learning_scheme != 'differential':
        raise NotImplementedError
    param_groups = differential_learning_scheme(model)
    return torch.optim.SGD(param_groups)


def differential_learning_scheme(model, learning_rate=0.1, divisor=2.6):
    param_prefixes = {}
    for n, p in model.named_parameters():
        base = n.partition('.weight')[0].partition('.bias')[0]
        if base not in param_prefixes:
            param_prefixes[base] = 0

    param_prefix_divisors = list(reversed([divisor * i for i in range(1, len(param_prefixes))])) + [1]
    param_learning_rates = [learning_rate / ld for ld in param_prefix_divisors]

    param_prefix_lr_lookup = dict(zip(param_prefixes.keys(), param_learning_rates))

    optimizer_grouped_parameters = [
        {'params': p, 'lr': param_prefix_lr_lookup[n.partition('.weight')[0].partition('.bias')[0]]}
        for n, p in model.named_parameters()
    ]

    return optimizer_grouped_parameters

In [None]:
# Build the SGD optimizer with per-parameter-prefix ("differential") learning rates.
learning_scheme = 'differential'
print(f'Configuring {learning_scheme} learning scheme...')
optimizer = get_learning_scheme(learning_scheme, model)

In [None]:
# prefixes = {}
# for n, p in model.named_parameters():
#     base = n.partition('.weight')[0].partition('.bias')[0]
#     if base not in prefixes:
#         prefixes[base] = 0

In [None]:
# BASE_LR = 0.1
# BASE_DIVISOR = 2.6

# prefix_divisors = list(reversed([BASE_DIVISOR * i for i in range(1, len(prefixes))])) + [1]
# layer_learning_rates = [BASE_LR / ld for ld in prefix_divisors]

# prefix_lr_lookup = dict(zip(prefixes.keys(), layer_learning_rates))

In [None]:
# optimizer_grouped_parameters = [
#     {'params': p, 'lr': prefix_lr_lookup[n.partition('.weight')[0].partition('.bias')[0]]}
#     for n, p in model.named_parameters()
# ]

In [None]:
# optimizer = torch.optim.SGD(optimizer_grouped_parameters)

## Train

In [None]:
import time
import torch
import torch.nn.functional as F

In [None]:
# Select the compute device (GPU when available) and move the model and loss there.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
model = model.to(device)
n_epochs = 10
# Optional LR scheduler; when set, the training loop steps it after every batch.
scheduler = None
criterion = criterion.to(device)

# Report the running training loss every N_MINI_BATCH_CHECK mini-batches.
N_MINI_BATCH_CHECK = 5

def measure_performance(loader, eval_model=None, eval_criterion=None, eval_device=None):
    """Compute the average loss and accuracy of a classifier over ``loader``.

    The model is put in eval mode for the pass, so dropout is disabled.
    Afterwards it is restored to whatever mode it was in before. No gradients
    are tracked.

    Args:
        loader: Iterable of batches indexable as ``(input_ids,
            attention_mask, labels)``.
        eval_model: Model to evaluate; defaults to the notebook-global
            ``model``. Its output's element 0 must be the class scores.
        eval_criterion: Loss function; defaults to the global ``criterion``.
            Assumed to return the batch mean (default ``reduction='mean'``).
        eval_device: Device for the inputs; defaults to the global ``device``.

    Returns:
        ``(loss, accuracy)`` as Python floats. The loss is averaged per
        example, with each batch weighted by its size, so a short final batch
        does not skew it. Both values are ``nan`` for an empty loader.
    """
    net = model if eval_model is None else eval_model
    loss_fn = criterion if eval_criterion is None else eval_criterion
    dev = device if eval_device is None else eval_device

    was_training = net.training
    net.eval()
    total_loss = 0.0
    correct_count = 0
    total_count = 0
    try:
        with torch.no_grad():
            for data in loader:
                input_ids = data[0].to(dev)
                attn_masks = data[1].to(dev)
                labels = data[2].to(dev)

                logits = net(input_ids=input_ids, attention_mask=attn_masks)[0]
                batch_size = labels.size(0)

                # The criterion returns the batch mean; re-weight by batch size.
                total_loss += loss_fn(logits, labels).item() * batch_size
                # argmax over logits equals argmax over their softmax.
                correct_count += (logits.argmax(dim=1) == labels).sum().item()
                total_count += batch_size
    finally:
        # Restore the caller's mode even if evaluation raised.
        net.train(was_training)

    if total_count == 0:
        return float('nan'), float('nan')
    return total_loss / total_count, correct_count / total_count

In [None]:
# Baseline validation metrics before any fine-tuning update ("epoch 0").
if val_loader:
    print('Initial evaluating on validation dataset')
    # Evaluate with dropout disabled, regardless of the mode an earlier cell
    # (e.g. a previous training run) left the model in.
    model.eval()
    val_loss, val_acc = measure_performance(val_loader)
    epoch_summary = f'[Epoch 0] | Val acc: {val_acc:.4f} Val loss: {val_loss:.4f}\n\n'
    print(epoch_summary)

In [None]:
# Fine-tune for n_epochs. Log the running training loss every
# N_MINI_BATCH_CHECK mini-batches and validation metrics after each epoch.
for epoch in range(n_epochs):
    # 1-based, matching the "[E..]" and "[Epoch ..]" labels printed below.
    print(f'--- Epoch: {epoch + 1} ---')
    # Nothing else in this notebook enables training mode. Hugging Face
    # from_pretrained() returns the model in eval mode, and validation below
    # switches to eval mode, so re-enable dropout before each epoch's updates.
    model.train()
    epoch_start_time = time.time()
    batch_start_time = time.time()
    running_loss = 0.0

    for i, data in enumerate(train_loader):
        input_ids = data[0].to(device)
        attn_masks = data[1].to(device)
        labels = data[2].to(device)

        optimizer.zero_grad()

        # Forward pass; element 0 of the model output holds the class scores.
        outputs = model(input_ids=input_ids, attention_mask=attn_masks)[0]
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        if scheduler:
            # The scheduler is stepped per batch, not per epoch.
            scheduler.step()

        # Print statistics periodically
        running_loss += loss.item()
        if i % N_MINI_BATCH_CHECK == N_MINI_BATCH_CHECK - 1:
            batch_end_time = time.time()
            total_batch_time = batch_end_time - batch_start_time

            print(
                f'[E{epoch + 1:d} B{i + 1:d}] ',
                f'Loss: {running_loss / N_MINI_BATCH_CHECK:.5f} ',
                f'Time: {total_batch_time:.2f} ',
                f'LR: {scheduler.get_last_lr()}' if scheduler else '')

            # Reset statistics
            batch_start_time = time.time()
            running_loss = 0.0

    epoch_end_time = time.time()
    total_epoch_time = epoch_end_time - epoch_start_time
    epoch_summary = '[Epoch {}] {} seconds'.format((epoch + 1), total_epoch_time)

    if val_loader:
        # Validate with dropout disabled.
        model.eval()
        val_loss, val_acc = measure_performance(val_loader)
        epoch_summary += f' | Val acc: {val_acc:.4f} | Val loss: {val_loss:.4f}'

    print(epoch_summary)

print('Finished training')

In [None]:
# import time
# import torch.nn.functional as F
# import copy

# class Trainer:
#     def __init__(self, model, n_epochs, optimizer):
#         self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
#         self.model = model.to(self.device)
#         self.n_epochs = n_epochs
#         self.optimizer = optimizer
#         self.criterion = torch.nn.CrossEntropyLoss().to(self.device)

# #         no_decay = ['bias', 'LayerNorm.weight']
# #         optimizer_grouped_parameters = [
# #             {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
# #             {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
# #         ]
# #         self.optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=1e-3, momentum=0.9)
#         # self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer)
#         # self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=10000, T_mult=2)        

#     def measure_performance(self, loader):
#         running_loss = 0.0
#         correct_count = 0
#         total_count = 0
#         for data in loader:
#             input_ids = data[0].to(self.device)
#             attn_masks = data[1].to(self.device)
#             labels = data[2].to(self.device)
#             with torch.no_grad():
#                 outputs = self.model(input_ids=input_ids, attention_mask=attn_masks)[0]
#                 loss = self.criterion(outputs, labels)
#                 probas = F.softmax(outputs, dim=1)
#                 preds = torch.argmax(probas, axis=1)
                
#                 # Track stats
#                 running_loss += loss
#                 correct_count += torch.sum(preds == labels) 
#                 total_count += len(labels) 
        
#         running_loss /= len(loader)
#         acc = correct_count / total_count

#         return running_loss, acc

#     def train_loop(self, train_loader, val_loader=None):
#         print('Starting training loop\n\n')
#         N_MINI_BATCH_CHECK = 200

#         if val_loader:
#             print('Initial evaluating on validation dataset')
#             val_loss, val_acc = self.measure_performance(val_loader)
#             epoch_summary = f'[Epoch 0] | Val acc: {val_acc:.4f} Val loss: {val_loss:.4f}\n\n'
#             print(epoch_summary)

#         for epoch in range(self.n_epochs):
#             print(f'--- Epoch: {epoch} ---')
#             epoch_start_time = time.time()
#             batch_start_time = time.time()
#             running_loss = 0.0

#             for i, data in enumerate(train_loader):
#                 input_ids = data[0].to(self.device)
#                 attn_masks = data[1].to(self.device)
#                 labels = data[2].to(self.device)

#                 self.optimizer.zero_grad()

#                 # Evaluation/optimization step
#                 outputs = self.model(input_ids=input_ids, attention_mask=attn_masks)[0]
#                 loss = self.criterion(outputs, labels)
#                 loss.backward()
#                 self.optimizer.step()
#                 if self.scheduler:
#                     self.scheduler.step()
                
#                 # Print statistics periodically
#                 running_loss += loss.item()
#                 if i % N_MINI_BATCH_CHECK == N_MINI_BATCH_CHECK - 1:
#                     batch_end_time = time.time()
#                     total_batch_time = batch_end_time - batch_start_time

#                     print(
#                         f'[E{epoch + 1:d} B{i + 1:d}] ',
#                         f'Loss: {running_loss / N_MINI_BATCH_CHECK:.5f} ',
#                         f'Time: {total_batch_time:.2f} ',
#                         f'LR: {self.scheduler.get_last_lr()}' if self.scheduler else '')

#                     # Reset statistics
#                     batch_start_time = time.time()
#                     running_loss = 0.0

#             epoch_end_time = time.time()
#             total_epoch_time = epoch_end_time - epoch_start_time
#             epoch_summary = '[Epoch {}] {} seconds'.format((epoch + 1), total_epoch_time)
            
#             if val_loader:
#                 val_loss, val_acc = self.measure_performance(val_loader)
#                 epoch_summary += f' | Val acc: {val_acc:.4f} | Val loss: {loss:.4f}'

#             print(epoch_summary)

#         print('Finished training')

#     def lr_test(self, train_loader, lrs=(-9, 2)):
#         """
#         lrs = (min_lr, max_lr, factor_scale)
#         """

#         # Prepare LR-finder loop
#         model = copy.deepcopy(self.model).to(self.device)
#         min_lr, max_lr = lrs
#         lrs = np.logspace(min_lr, max_lr, num=len(train_loader), endpoint=True)
#         losses = []
#         for i, data in enumerate(iter(train_loader)):
#             curr_lr = lrs[i]
#             optimizer = torch.optim.SGD(model.parameters(), lr=curr_lr)

#             input_ids = data[0].to(self.device)
#             attn_masks = data[1].to(self.device)
#             labels = data[2].to(self.device)

#             # Evaluation/optimization step
#             outputs = model(input_ids=input_ids, attention_mask=attn_masks)[0]
#             loss = self.criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()
#             losses.append(loss)

#             if i % 100 == 0:
#                 print(f'Step [{i}, {len(train_loader)}] | LR: {curr_lr:.4e} | Loss: {loss:.4f}')

#         return losses

In [None]:
# trainer = Trainer(model, 10, optimizer)

In [None]:
# trainer.train_loop(train_loader, val_loader)