In [None]:
# Install training dependencies (datasets / sentencepiece / peft).
!pip3 install datasets sentencepiece peft -q
# NOTE(review): transformers is installed from git HEAD with no pin, so a re-run may
# pick up a different version — consider pinning a release for reproducibility.
!pip install git+https://github.com/huggingface/transformers.git -qq
# torch 2.1.x CPU wheel, version-matched to the torch_xla 2.1.x install on the next line.
!pip install torch~=2.1.0 --index-url https://download.pytorch.org/whl/cpu -q
# torch_xla 2.1.x with TPU extras; libtpu wheels are resolved from the Google storage index.
!pip install torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html -q
!pip uninstall tensorflow -y # If we don't do this, TF will take over TPU and cause permission error for PT
# Copy the SPMD helper module that provides `partition_module`.
!cp /kaggle/input/utils-xla/spmd_util.py . # From this repo: https://github.com/HeegyuKim/torch-xla-SPMD

In [None]:
import os
import pandas as pd
import numpy as np
import datasets
import torch.optim as optim
import torch_xla.debug.profiler as xp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp # We also import mp modules if we wanna use that for some reason
import torch_xla.distributed.parallel_loader as pl
import torch_xla.test.test_utils as test_utils
import torch
import torch.nn as nn
import re
import torch_xla.experimental.xla_sharding as xs
import torch_xla.core.xla_model as xm
from transformers import (
    GPTNeoXConfig, T5Config, LlamaConfig, AutoTokenizer, AutoModelForCausalLM, DataCollatorWithPadding, AutoConfig, AutoModelForSequenceClassification
)

from transformers import logging as hf_logging
import torch.nn.functional as F
import torch_xla.runtime as xr

# Put the torch_xla runtime into SPMD mode; the Mesh / xs.mark_sharding calls later
# in the notebook are used together with it.
xr.use_spmd()

import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh

from peft import LoraConfig, TaskType, get_peft_model
from spmd_util import partition_module
from datasets import Dataset, load_dataset, concatenate_datasets
from dataclasses import dataclass
from tqdm import tqdm

import transformers
import datasets
import pandas as pd
import numpy as np
from datasets import Dataset
from sklearn.metrics import roc_auc_score

# `!export` runs in a child shell and never changed this kernel's environment;
# set the variable on os.environ instead.
# NOTE(review): transformers likely reads USE_TORCH when it is first imported, and it
# was already imported above — confirm whether this needs to move to the top of the cell.
os.environ["USE_TORCH"] = "True"
os.environ["PJRT_DEVICE"] = "TPU"
# Default of None so re-running this cell (after the key is already gone) does not raise KeyError.
os.environ.pop('TPU_PROCESS_ADDRESSES', None)
hf_logging.set_verbosity_error()


MAX_INPUT=512
MODEL = "/kaggle/input/llama-3-1-8b-instruct"

In [None]:
# Misconception lookup table shipped with the competition data.
df_misconception_mapping = pd.read_csv(
    "/kaggle/input/eedi-mining-misconceptions-in-mathematics/misconception_mapping.csv"
)

In [None]:
external_df = pd.read_csv('/kaggle/input/eedi-external-dataset/train_external.csv')
# Preview a few rows instead of rendering the whole frame into the notebook.
external_df.head()

In [None]:
import os
from transformers import AutoTokenizer
import pandas as pd

# -1 stands in for missing values (e.g. options without a labelled misconception).
train_df = pd.read_csv("/kaggle/input/eedi-external-dataset/all_train.csv").fillna(-1)
test_ds = pd.read_csv("/kaggle/input/eedi-mining-misconceptions-in-mathematics/test.csv")
# Reuse the MODEL path from the config cell rather than repeating the literal.
tokenizer = AutoTokenizer.from_pretrained(MODEL)

PROMPT  = """Question: {Question}
Incorrect Answer: {IncorrectAnswer}
Correct Answer: {CorrectAnswer}
Construct Name: {ConstructName}
Subject Name: {SubjectName}

Your task is to identify the misconception behind Incorrect Answer. Answer concisely and generically. Output misconception only.
"""

def apply_template(row, tokenizer, targetCol, add_generation_prompt=False):
    """Render one (question, incorrect option) pair as a complete chat transcript.

    The user turn is PROMPT filled from `row`; the assistant turn is the
    misconception name for option `targetCol`.

    Args:
        row: a train_df row (pandas Series) with ConstructName, SubjectName,
            QuestionText, CorrectAnswer, Answer{X}Text and Misconception{X}Name.
        tokenizer: object providing `apply_chat_template`.
        targetCol: letter of the incorrect option to describe ("A".."D").
        add_generation_prompt: append an empty assistant header after the
            messages. Keep False for training text. The assistant reply is
            already included, so a trailing header would teach the model to
            open another turn after its answer.

    Returns:
        The chat-formatted text (tokenize=False), as returned by the tokenizer.
    """
    messages = [
        {
            "role": "user",
            "content": PROMPT.format(
                 ConstructName=row["ConstructName"],
                 SubjectName=row["SubjectName"],
                 Question=row["QuestionText"],
                 IncorrectAnswer=row[f"Answer{targetCol}Text"],
                 CorrectAnswer=row[f"Answer{row.CorrectAnswer}Text"])
        },
        {
            "role": "assistant",
            "content": row[f"Misconception{targetCol}Name"]
        }
    ]
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=add_generation_prompt
    )

# One training example per (question, incorrect option) pair that has a labelled
# misconception, keyed "<QuestionId>_<option>".
df = {}
df_label = {}
n_skipped = 0
first_error = None
for idx, row in tqdm(train_df.iterrows(), total=len(train_df)):
    for option in ["A", "B", "C", "D"]:
        try:
            # train_df was fillna(-1)'d, so -1 marks an option without a misconception.
            if (row.CorrectAnswer != option) and (row[f"Misconception{option}Name"] != -1):
                key = f"{row.QuestionId}_{option}"
                df[key] = apply_template(row, tokenizer, option)
                df_label[key] = [row[f"Misconception{option}Name"]]
        except Exception as e:
            # Still best-effort, but no longer silent. Count the failures and keep
            # the first one so malformed rows are visible.
            n_skipped += 1
            if first_error is None:
                first_error = e
if n_skipped:
    print(f"Skipped {n_skipped} (row, option) pairs; first error: {first_error!r}")

df_label = pd.DataFrame([df_label]).T.reset_index()
df_label.columns = ["QuestionId_Answer", "MisconceptionName"]

df = pd.DataFrame([df]).T.reset_index()
df.columns = ["QuestionId_Answer", "text"]

In [None]:
# Rough size check of the prompts: counts of space-separated words, not tokenizer tokens.
lengths = [len(text.split(' ')) for text in df['text']]

max(lengths), min(lengths), sum(lengths) / len(lengths)

In [None]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL)

# The DataCollatorForLanguageModeling used below sets labels to -100 wherever
# input_ids == pad_token_id. Using eos as the pad token would also mask the real
# end-of-turn tokens the chat template inserts, so the model would never learn to
# stop. NOTE(review): for Llama 3.1 Instruct the eos_token is typically <|eot_id|>;
# confirm for this checkpoint.
# Prefer Llama 3.1's dedicated padding token when the vocabulary has one, and fall
# back to eos otherwise.
LLAMA31_PAD_TOKEN = "<|finetune_right_pad_id|>"
if LLAMA31_PAD_TOKEN in tokenizer.get_vocab():
    tokenizer.pad_token = LLAMA31_PAD_TOKEN
else:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
def preprocess_function(examples):
    """Tokenize one chat-formatted example into a fixed-length sequence.

    Args:
        examples: a single dataset row with a 'text' field (the map calls use
            batched=False).

    Returns:
        The tokenizer's encoding of the text, truncated/padded to exactly
        MAX_INPUT tokens.
    """
    # Every sequence is padded to MAX_INPUT, which can be more than the longest
    # example needs, but gives one fixed input length.
    # add_special_tokens=False: the text comes from tokenizer.apply_chat_template,
    # and the Llama 3.x chat template already emits <|begin_of_text|>. Letting the
    # tokenizer add its own BOS would double it (confirm if MODEL changes).
    return tokenizer(
        examples['text'],
        max_length=MAX_INPUT,
        padding='max_length',
        truncation=True,
        add_special_tokens=False,
    )

In [None]:
# Fixed seed so the 85/15 train/validation split is identical on every run.
SPLIT_SEED = 42
# Same 96 worker processes as before on a large host, but never more than the
# machine actually has.
NUM_PROC = min(96, os.cpu_count() or 1)

ds = Dataset.from_pandas(df).train_test_split(test_size=0.15, seed=SPLIT_SEED)

# Tokenize both splits identically; dropping 'text' and 'QuestionId_Answer'
# leaves only the tokenizer's output columns.
for split in ('train', 'test'):
    ds[split] = ds[split].map(preprocess_function, batched=False, num_proc=NUM_PROC, remove_columns=['text', 'QuestionId_Answer'])
ds

In [None]:
# Load the causal-LM weights directly in bfloat16.
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
)

In [None]:
# Run configuration read by train(). NUM_STEPS holds the number of training
# *examples*; train() divides it by BATCH_SIZE to get the per-epoch step count.
FLAGS = dict(
    MAX_INPUT=512,
    LOGGING_STEPS=1,
    NUM_EPOCHS=1,
    BATCH_SIZE=4,
    NUM_STEPS=len(ds['train']),
)

In [None]:
from transformers import DataCollatorForLanguageModeling

# One causal-LM collator (mlm=False builds next-token labels from input_ids),
# shared by both loaders.
lm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, pad_to_multiple_of=128)

# drop_last=True keeps every training batch at exactly BATCH_SIZE rows. A smaller
# final batch is a new input shape, which XLA has to compile separately. It also
# makes len(training_loader) equal NUM_STEPS // BATCH_SIZE, the step count that
# train() gives the LR scheduler and the progress bar.
training_loader = torch.utils.data.DataLoader(ds['train'], batch_size=FLAGS['BATCH_SIZE'], collate_fn=lm_collator, drop_last=True)
# Evaluation keeps the last partial batch so every validation example is scored.
testing_loader = torch.utils.data.DataLoader(ds['test'], batch_size=FLAGS['BATCH_SIZE'], collate_fn=lm_collator)

device = xm.xla_device()

In [None]:
config = AutoConfig.from_pretrained(MODEL)

# 3-axis device mesh: every device sits on the 'fsdp' axis, 'dp' and 'mp' have size 1.
num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)
mesh_shape = (1, num_devices, 1)
mesh = Mesh(device_ids, mesh_shape, ('dp', 'fsdp', 'mp'))

# Apply the spmd_util sharding helper to the model over this mesh.
partition_module(model, mesh)

In [None]:
# Number of leading parameter tensors (in model.parameters() order) to freeze.
# 0 means full fine-tuning, which is exactly what the previous counter loop did:
# its `cnt > 0` check was true for every parameter, so all of them ended up trainable.
NUM_FROZEN_PARAMS = 0

for param_idx, param in enumerate(model.parameters()):
    param.requires_grad = param_idx >= NUM_FROZEN_PARAMS

In [None]:
# `!export` runs in a throwaway subshell, so it never reached this Python process;
# set the flag on the kernel's own environment instead.
# NOTE(review): torch_xla may only read XLA_USE_BF16 when it first initializes, and
# the model is already on the XLA device by now. Moving this to the config cell,
# before torch_xla is imported, is probably required for it to take effect.
os.environ["XLA_USE_BF16"] = "1"

def train_loop(training_loader, epoch, optimizer, scheduler, num_actual_steps):
    """Run one training epoch of the global `model` over `training_loader`.

    Each batch is moved to the global XLA `device` and annotated with
    `xs.mark_sharding` over the global `mesh`. It then drives one optimizer step,
    then `xm.mark_step()`, then one scheduler step.

    Args:
        training_loader: iterable of collated batches exposing `input_ids`,
            `attention_mask` and `labels` as attributes.
        epoch: epoch number, used only in the log lines.
        optimizer: optimizer over the model's parameters.
        scheduler: LR scheduler, stepped once per batch.
        num_actual_steps: expected batch count, used for the log line and tqdm total.
    """
    model.train()
    print('Epoch {} train begin {} for {} steps'.format(epoch, test_utils.now(), num_actual_steps))
    for _, batch in tqdm(enumerate(training_loader), total=num_actual_steps):
        optimizer.zero_grad()
        inputs = {
            key: getattr(batch, key).to(device)
            for key in ('input_ids', 'attention_mask', 'labels')
        }
        # Tensor dim 0 -> mesh axis 0, tensor dim 1 -> mesh axis 1.
        for tensor in inputs.values():
            xs.mark_sharding(tensor, mesh, (0, 1))
        loss = model(**inputs).loss
        loss.backward()
        optimizer.step()
        xm.mark_step()
        scheduler.step()
    print('Epoch {} train end {}'.format(epoch, test_utils.now()))
        
def eval_loop(testing_loader, epoch):
    """Compute, print and return the mean per-batch loss of `model` on `testing_loader`.

    Uses the global `model`, `device` and `mesh`; gradients are disabled.

    Args:
        testing_loader: iterable of collated batches exposing `input_ids`,
            `attention_mask` and `labels` as attributes.
        epoch: epoch number, used only in the printed message.

    Returns:
        The mean of the per-batch losses as a float, or NaN if the loader
        yielded no batches.
    """
    model.eval()
    total_loss = 0.0
    total_steps = 0
    with torch.no_grad():
        for batch in testing_loader:
            input_ids, attention_mask, labels = batch.input_ids.to(device), batch.attention_mask.to(device), batch.labels.to(device)
            xs.mark_sharding(input_ids, mesh, (0, 1))
            xs.mark_sharding(attention_mask, mesh, (0, 1))
            xs.mark_sharding(labels, mesh, (0, 1))
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            total_loss += outputs.loss.item()
            total_steps += 1
    # An empty loader previously raised ZeroDivisionError; report NaN instead.
    average_loss = total_loss / total_steps if total_steps else float('nan')
    print('Epoch {} test end {}, TEST LOSS={:.2f}'.format(epoch, test_utils.now(), average_loss))
    return average_loss
    
def save_model(model, tokenizer, dir_name):
    """Move `model` to CPU, then write it and `tokenizer` to `dir_name`.

    Both objects are saved with their own `save_pretrained`, model first.
    """
    cpu_model = model.cpu()
    for artifact in (cpu_model, tokenizer):
        artifact.save_pretrained(dir_name)

def train(FLAGS):
    """Fine-tune the global `model`, evaluate it, and save it to 'trained_model'.

    Evaluation runs before every epoch (the first call is the pre-training
    baseline) and once more after the last epoch.

    Args:
        FLAGS: dict with 'NUM_STEPS' (number of training examples), 'BATCH_SIZE',
            'NUM_EPOCHS', and optionally 'LR' (defaults to 1e-5).
    """
    num_actual_steps = FLAGS['NUM_STEPS'] // FLAGS['BATCH_SIZE']
    num_epochs = FLAGS['NUM_EPOCHS']
    lr = FLAGS.get('LR', 1e-5)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # train_loop steps the scheduler once per batch, so decay over the whole run
    # (every epoch), not just the first one.
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1.0, end_factor=0.1, total_iters=num_actual_steps * num_epochs
    )
    for epoch in range(1, num_epochs + 1):
        eval_loop(testing_loader, epoch)
        train_loop(training_loader, epoch, optimizer, scheduler, num_actual_steps)

    # Use num_epochs rather than the loop variable, which is unbound when num_epochs == 0.
    eval_loop(testing_loader, num_epochs)
    save_model(model, tokenizer, 'trained_model')

In [None]:
# Launch fine-tuning with the run configuration from the FLAGS cell.
train(FLAGS=FLAGS)