In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import RobertaForSequenceClassification, RobertaTokenizer, AdamW,get_linear_schedule_with_warmup
import pandas as pd
import numpy as np
import random
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

train_data = pd.read_csv('train_aug_select.csv')
validation_data = pd.read_csv('validation_aug_select.csv')
test_data = pd.read_csv('test.csv')

class AugmentationDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=128):
        self.sentences = data['sentence'].values
        self.labels_01 = data['aug_select_01'].values
        self.labels_02 = data['aug_select_02'].values
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sentence = str(self.sentences[idx])
        inputs = self.tokenizer(sentence, padding='max_length', truncation=True, max_length=self.max_length, return_tensors="pt")
        input_ids = inputs['input_ids'].squeeze()
        attention_mask = inputs['attention_mask'].squeeze()
        label_01 = torch.tensor(self.labels_01[idx], dtype=torch.long)
        label_02 = torch.tensor(self.labels_02[idx], dtype=torch.long)
        return input_ids, attention_mask, label_01, label_02

class AugmentationSelectorModel(nn.Module):
    def __init__(self):
        super(AugmentationSelectorModel, self).__init__()
        self.roberta = RobertaForSequenceClassification.from_pretrained('roberta-large', num_labels=5)
        for name, param in self.roberta.named_parameters():
            if "encoder.layer." in name and int(name.split(".")[3]) < 20:
                param.requires_grad = False  

        for param in self.roberta.roberta.embeddings.parameters():
            param.requires_grad = False

        self.dropout = nn.Dropout(0.2)

        self.classifier_01 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 5)
        )

        self.classifier_02 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 5)
        )
        
    def forward(self, input_ids, attention_mask):
        outputs = self.roberta.roberta(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        pooled_output = torch.sum(sequence_output,dim=1)
        out_01 = self.classifier_01(pooled_output)
        out_02 = self.classifier_02(pooled_output)
        return out_01, out_02

tokenizer = RobertaTokenizer.from_pretrained('roberta-large')
train_dataset = AugmentationDataset(train_data, tokenizer)
validation_dataset = AugmentationDataset(validation_data, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=32, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = AugmentationSelectorModel().to(device)  

for name, param in model.named_parameters():
    print(f"{name}: {param.requires_grad}")
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=4e-5, correct_bias=False, weight_decay=1e-2)

total_steps = len(train_loader) * 30  

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps = int(0.1 * total_steps),
    num_training_steps=total_steps
)

def train_model(model, train_loader, validation_loader, epochs=30):
    best_val_loss = np.inf

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        correct_01 = 0
        correct_02 = 0
        total = 0

        for input_ids, attention_mask, label_01, label_02 in tqdm(train_loader):
            input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
            label_01, label_02 = label_01.to(device), label_02.to(device)

            optimizer.zero_grad()

            out_01, out_02 = model(input_ids, attention_mask)
            loss_01 = criterion(out_01, label_01)
            loss_02 = criterion(out_02, label_02)

            loss = loss_01 + loss_02  
            loss.backward()  
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) 
            optimizer.step()  

            train_loss += loss.item()
            preds_01 = torch.argmax(out_01, dim=1)
            preds_02 = torch.argmax(out_02, dim=1)
            correct_01 += (preds_01 == label_01).sum().item()
            correct_02 += (preds_02 == label_02).sum().item()
            total += label_01.size(0)
        
        scheduler.step()
        
        train_loss /= len(train_loader)
        train_accuracy_01 = correct_01 / total
        train_accuracy_02 = correct_02 / total
        train_accuracy = (train_accuracy_01 + train_accuracy_02) / 2
       
        model.eval()
        val_loss = 0
        val_correct_01 = 0
        val_correct_02 = 0
        val_total = 0

        with torch.no_grad():
            for input_ids, attention_mask, label_01, label_02 in validation_loader:
                input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
                label_01, label_02 = label_01.to(device), label_02.to(device)

                out_01, out_02 = model(input_ids, attention_mask)
                loss_01 = criterion(out_01, label_01)
                loss_02 = criterion(out_02, label_02)
                val_loss += (loss_01 + loss_02).item()

                val_correct_01 += (torch.argmax(out_01, dim=1) == label_01).sum().item()
                val_correct_02 += (torch.argmax(out_02, dim=1) == label_02).sum().item()
                val_total += label_01.size(0)

        val_loss /= len(validation_loader)
        val_accuracy_01 = val_correct_01 / val_total
        val_accuracy_02 = val_correct_02 / val_total
        val_accuracy = (val_accuracy_01 + val_accuracy_02) / 2
        
        print(f'Epoch {epoch + 1}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, '
              f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'Test_aug_select_model.pth')
           
train_model(model, train_loader, validation_loader)

model.load_state_dict(torch.load('Test_aug_select_model.pth'))
model.eval()

test_sentences = test_data['sentence'].values
test_labels = test_data['label'].values

predictions = []

with torch.no_grad():
    for i, sentence in enumerate(test_sentences):
        inputs = tokenizer(sentence, padding='max_length', truncation=True, max_length=128, return_tensors="pt")
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)

        out_01, out_02 = model(input_ids, attention_mask)
        pred_01 = torch.argmax(out_01, dim=1).item()
        pred_02 = torch.argmax(out_02, dim=1).item()

        predictions.append([sentence, test_labels[i], pred_01, pred_02])

output_df = pd.DataFrame(predictions, columns=['sentence', 'label', 'aug_select_01', 'aug_select_02'])
output_df.to_csv('test_aug_pre_M.csv', index=False)
print("Predictions saved to 'test_aug_pre_M.csv'")


Using device: cuda


Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'lm_head.bias', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.out_proj.bias', 'classi

roberta.roberta.embeddings.word_embeddings.weight: False
roberta.roberta.embeddings.position_embeddings.weight: False
roberta.roberta.embeddings.token_type_embeddings.weight: False
roberta.roberta.embeddings.LayerNorm.weight: False
roberta.roberta.embeddings.LayerNorm.bias: False
roberta.roberta.encoder.layer.0.attention.self.query.weight: False
roberta.roberta.encoder.layer.0.attention.self.query.bias: False
roberta.roberta.encoder.layer.0.attention.self.key.weight: False
roberta.roberta.encoder.layer.0.attention.self.key.bias: False
roberta.roberta.encoder.layer.0.attention.self.value.weight: False
roberta.roberta.encoder.layer.0.attention.self.value.bias: False
roberta.roberta.encoder.layer.0.attention.output.dense.weight: False
roberta.roberta.encoder.layer.0.attention.output.dense.bias: False
roberta.roberta.encoder.layer.0.attention.output.LayerNorm.weight: False
roberta.roberta.encoder.layer.0.attention.output.LayerNorm.bias: False
roberta.roberta.encoder.layer.0.intermediate.de

100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:53<00:00,  2.36it/s]


Epoch 1, Train Loss: 94.1971, Train Accuracy: 0.2316, Validation Loss: 88.7898, Validation Accuracy: 0.2044


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:52<00:00,  2.38it/s]


Epoch 2, Train Loss: 87.4786, Train Accuracy: 0.2310, Validation Loss: 73.9821, Validation Accuracy: 0.2071


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:52<00:00,  2.37it/s]


Epoch 3, Train Loss: 67.9198, Train Accuracy: 0.2313, Validation Loss: 39.5352, Validation Accuracy: 0.2157


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:52<00:00,  2.37it/s]


Epoch 4, Train Loss: 44.6766, Train Accuracy: 0.2124, Validation Loss: 28.1097, Validation Accuracy: 0.2175


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:53<00:00,  2.36it/s]


Epoch 5, Train Loss: 25.7151, Train Accuracy: 0.2134, Validation Loss: 30.4442, Validation Accuracy: 0.2248


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:54<00:00,  2.34it/s]


Epoch 6, Train Loss: 19.9565, Train Accuracy: 0.2151, Validation Loss: 26.0718, Validation Accuracy: 0.2262


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:54<00:00,  2.34it/s]


Epoch 7, Train Loss: 18.0775, Train Accuracy: 0.2181, Validation Loss: 24.5068, Validation Accuracy: 0.2348


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:54<00:00,  2.34it/s]


Epoch 8, Train Loss: 16.2598, Train Accuracy: 0.2223, Validation Loss: 21.4523, Validation Accuracy: 0.2389


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:53<00:00,  2.35it/s]


Epoch 9, Train Loss: 14.4306, Train Accuracy: 0.2206, Validation Loss: 19.4116, Validation Accuracy: 0.2257


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:53<00:00,  2.35it/s]


Epoch 10, Train Loss: 13.0039, Train Accuracy: 0.2169, Validation Loss: 18.0945, Validation Accuracy: 0.2189


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:53<00:00,  2.35it/s]


Epoch 11, Train Loss: 11.4000, Train Accuracy: 0.2145, Validation Loss: 14.9037, Validation Accuracy: 0.2325


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:53<00:00,  2.35it/s]


Epoch 12, Train Loss: 10.2854, Train Accuracy: 0.2157, Validation Loss: 12.3314, Validation Accuracy: 0.2225


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:53<00:00,  2.35it/s]


Epoch 13, Train Loss: 9.1709, Train Accuracy: 0.2173, Validation Loss: 9.9823, Validation Accuracy: 0.2243


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:53<00:00,  2.35it/s]


Epoch 14, Train Loss: 7.9129, Train Accuracy: 0.2185, Validation Loss: 8.8484, Validation Accuracy: 0.2216


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:53<00:00,  2.35it/s]


Epoch 15, Train Loss: 6.8746, Train Accuracy: 0.2197, Validation Loss: 7.3304, Validation Accuracy: 0.2171


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:53<00:00,  2.34it/s]


Epoch 16, Train Loss: 6.0404, Train Accuracy: 0.2238, Validation Loss: 6.0026, Validation Accuracy: 0.2121


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:53<00:00,  2.34it/s]


Epoch 17, Train Loss: 5.2054, Train Accuracy: 0.2166, Validation Loss: 4.9086, Validation Accuracy: 0.2125


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:54<00:00,  2.34it/s]


Epoch 18, Train Loss: 4.5366, Train Accuracy: 0.2257, Validation Loss: 4.3788, Validation Accuracy: 0.2180


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:53<00:00,  2.34it/s]


Epoch 19, Train Loss: 4.0635, Train Accuracy: 0.2235, Validation Loss: 3.8139, Validation Accuracy: 0.2302


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:54<00:00,  2.34it/s]


Epoch 20, Train Loss: 3.7357, Train Accuracy: 0.2251, Validation Loss: 3.5735, Validation Accuracy: 0.2480


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:54<00:00,  2.33it/s]


Epoch 21, Train Loss: 3.5081, Train Accuracy: 0.2330, Validation Loss: 3.3694, Validation Accuracy: 0.2330


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:54<00:00,  2.34it/s]


Epoch 22, Train Loss: 3.3983, Train Accuracy: 0.2302, Validation Loss: 3.3351, Validation Accuracy: 0.2284


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:54<00:00,  2.34it/s]


Epoch 23, Train Loss: 3.3211, Train Accuracy: 0.2353, Validation Loss: 3.2702, Validation Accuracy: 0.2262


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:54<00:00,  2.33it/s]


Epoch 24, Train Loss: 3.2522, Train Accuracy: 0.2405, Validation Loss: 3.2450, Validation Accuracy: 0.2257


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:54<00:00,  2.34it/s]


Epoch 25, Train Loss: 3.2302, Train Accuracy: 0.2383, Validation Loss: 3.2196, Validation Accuracy: 0.2275


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:54<00:00,  2.33it/s]


Epoch 26, Train Loss: 3.2182, Train Accuracy: 0.2412, Validation Loss: 3.2219, Validation Accuracy: 0.2262


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:54<00:00,  2.32it/s]


Epoch 27, Train Loss: 3.2121, Train Accuracy: 0.2377, Validation Loss: 3.1941, Validation Accuracy: 0.2361


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:54<00:00,  2.33it/s]


Epoch 28, Train Loss: 3.2134, Train Accuracy: 0.2445, Validation Loss: 3.1984, Validation Accuracy: 0.2271


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:54<00:00,  2.33it/s]


Epoch 29, Train Loss: 3.1940, Train Accuracy: 0.2408, Validation Loss: 3.2171, Validation Accuracy: 0.2325


100%|████████████████████████████████████████████████████████████████████████████████| 267/267 [01:54<00:00,  2.33it/s]


Epoch 30, Train Loss: 3.1870, Train Accuracy: 0.2446, Validation Loss: 3.1744, Validation Accuracy: 0.2280
Predictions saved to 'test_aug_pre_M.csv'
Classifier 01 Accuracy: 0.2158
Classifier 02 Accuracy: 0.2502
Average Accuracy: 0.2330
