In [None]:
import math
import torch
from torch.utils.data import DataLoader
from transformers import AdamW
import argparse
import os
import tqdm
import inspect
import logging

from models.teacher import Teacher
from models.configuration_teacher import TeacherConfig
from data import CoTDataset, CoTDataCollator, extract_answer

from utils import get_sep_position

# Let matmuls and cuDNN kernels use TF32 tensor-core math where the GPU supports it.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# Silence WARNING, INFO and DEBUG records from every logger (e.g. library chatter).
logging.disable(logging.WARNING)

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


def save_model(model, tokenizer, model_dir):
    """Write ``model`` and then ``tokenizer`` into ``model_dir`` via ``save_pretrained``.

    The target directory (and any missing parents) is created first; an already
    existing directory is reused rather than treated as an error.
    """
    print('saving', model_dir)
    os.makedirs(model_dir, exist_ok=True)
    # Model first, tokenizer second -- both land in the same directory.
    for artifact in (model, tokenizer):
        artifact.save_pretrained(model_dir)

@torch.no_grad()
def evaluate(dataloader, tokenizer, ctx, teacher, max_new_tokens):
    """Evaluate the teacher: teacher-forced loss/perplexity plus generated-answer accuracy.

    For each batch the teacher is scored with ``teacher.compute_loss`` on the full
    ``input_ids_all`` sequence. It then ``generate``s a continuation from the prefix that
    ends at the first EOS separator. The reference text after that separator and the
    generated text are both passed through ``extract_answer`` and compared for an exact
    match. The first example of every batch is printed for inspection.

    Note: later notebook cells redefine ``evaluate`` with different signatures, so this
    version is only the active one if the cells run in order.

    Args:
        dataloader: yields dicts with ``'input_ids_all'`` and ``'labels_all'`` tensors.
        tokenizer: provides ``eos_token_id`` and ``decode``.
        ctx: context manager (autocast or no-op) wrapped around model calls.
        teacher: model exposing ``compute_loss`` and ``generate``. It is switched to eval
            mode here and left in eval mode on return.
        max_new_tokens: generation budget per example.

    Returns:
        ``(accuracy, token_accuracy, ppl)``:
        - accuracy: fraction of examples whose extracted answers match.
        - token_accuracy: ``total_correct / total_tokens`` accumulated from ``compute_loss``.
        - ppl: ``exp(total_loss / total_tokens)``.
    """
    teacher.eval()
    total_instances = 0
    total_tokens = 0
    total_correct = 0
    total_correct_tokens = 0
    total_loss = 0
    for batch in tqdm.tqdm(dataloader):
        input_ids_all = batch['input_ids_all'].to(device)
        labels = batch['labels_all'].to(device)
        # Remove answer part: keep columns up to the furthest first-EOS separator in the batch.
        # Rows whose separator comes earlier still carry some tokens past it here.
        # NOTE(review): presumably teacher.generate re-truncates each row at its own
        # separator -- confirm in models.teacher.
        sep_positions = get_sep_position(input_ids_all, tokenizer.eos_token_id)
        input_ids = input_ids_all[:, :sep_positions.max()+1]
        batch_size = input_ids.shape[0]
        with ctx:
            outputs = teacher.compute_loss(input_ids=input_ids_all, labels=labels)
        total_loss += outputs.total_loss.item()
        total_correct_tokens += outputs.total_correct.item()
        total_tokens += outputs.total_tokens
        total_instances += batch_size

        # Generate under the same precision context as the loss computation, as the
        # student evaluators in the later cells do.
        with ctx:
            beam_output = teacher.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
            )
        # Evaluate: exact match of the extracted answer, reference vs. generation.
        for i, (input_ids_all_i, beam_output_i) in enumerate(zip(input_ids_all, beam_output)):
            sep_position = sep_positions[i].item()
            tgt = input_ids_all_i[sep_position+1:]
            tgt_text = tokenizer.decode(tgt, skip_special_tokens=True)
            ans = extract_answer(tgt_text)
            # beam_output_i[0]: first returned sequence for this example.
            # Assumes generate() yields one (num_sequences, length) tensor per example -- confirm.
            pred_text = tokenizer.decode(beam_output_i[0][sep_position+1:], skip_special_tokens=True)
            pred_ans = extract_answer(pred_text)
            if ans == pred_ans:
                total_correct += 1
            if i == 0:
                print (f'Input: {tokenizer.decode(input_ids_all_i[:sep_position], skip_special_tokens=True)}')
                print (f'Target: {tgt_text}')
                print (f'Predicted: {pred_text}')
                print ('')
    accuracy = total_correct / total_instances
    token_accuracy = total_correct_tokens / total_tokens
    loss = total_loss / total_tokens
    ppl = math.exp(loss)
    return accuracy, token_accuracy, ppl


    



In [None]:
from types import SimpleNamespace

# Hyper-parameters for training the teacher (the model trained on full chain-of-thought
# sequences). A SimpleNamespace stands in for the argparse.Namespace that the original
# command-line script produced, so the rest of the notebook can keep using `args.<name>`.
teacher_trainer_args=dict(
    debug=False,
    train_path="../data/4_by_4_mult/train.txt",
    val_path="../data/4_by_4_mult/valid.txt",
    save_model="../train_models/4_by_4_mult/gpt2/teacher",
    max_new_tokens=128,
    base_model='gpt2',
    epochs=2,
    batch_size=8,
    lr=5e-5,
    max_grad_norm=1.0,
)

args = SimpleNamespace(**teacher_trainer_args)

# Equivalent command-line interface of the original script, kept for reference.
# (--batch_size previously defaulted to 5e-5, a copy of --lr; argparse does not apply
# `type` to non-string defaults, so that would have produced a float batch size.)
# parser = argparse.ArgumentParser()
# parser.add_argument('--train_path', type=str, required=True)
# parser.add_argument('--val_path', type=str, required=True)
# parser.add_argument('--save_model', type=str, required=True)
# parser.add_argument('--max_new_tokens', type=int, default=128)
# parser.add_argument('--base_model', type=str, default='gpt2')
# parser.add_argument('--epochs', type=int, default=1)
# parser.add_argument('--batch_size', type=int, default=32)
# parser.add_argument('--lr', type=float, default=5e-5)
# parser.add_argument('--max_grad_norm', type=float, default=1.0)
# args = parser.parse_args()

In [None]:
from contextlib import nullcontext

print (args)

# Numeric precision for the weights and the autocast context. Only 'float32' is safe
# as-is: 'float16' would normally also need a gradient scaler, which the loop lacks.
dtype = 'float32'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Autocast for the device actually in use. On CPU, use a no-op context rather than
# requesting CUDA autocast, which PyTorch only warns about and disables when CUDA is absent.
ctx = nullcontext() if device.type == 'cpu' else torch.amp.autocast(device_type=device.type, dtype=ptdtype)
print (ptdtype, dtype, device)

# Create Teacher
config = TeacherConfig(base_model=args.base_model)
teacher = Teacher(config).to(device).to(ptdtype)

# Load data. 1024 is passed as CoTDataset's third argument -- presumably a maximum
# sequence length; NOTE(review): confirm against data.CoTDataset.
tokenizer = teacher.tokenizer
collate_fn = CoTDataCollator(tokenizer)
train_dataset = CoTDataset(tokenizer, args.train_path, 1024)
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True)
val_dataset = CoTDataset(tokenizer, args.val_path, 1024)
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

# Create Optimizer. The fused AdamW kernel is only requested on CUDA: older PyTorch
# releases raise at construction time when fused=True is given CPU parameters.
trainable_params = list(teacher.parameters())
use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters and device.type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, **extra_args)



In [None]:
# Put the teacher in training mode; redundant with the per-epoch call below but harmless.
teacher.train()

# Train
# Full-sequence language-model training of the teacher on input_ids_all / labels_all.
step = 0
for epoch in range(args.epochs):
    print(f"Epoch {epoch}")
    # evaluate() (called at the end of every epoch) switches the teacher to eval mode,
    # so training mode must be restored at the start of each epoch.
    teacher.train()
    for batch in tqdm.tqdm(train_dataloader):
        input_ids = batch['input_ids_all'].to(device)
        labels = batch['labels_all'].to(device)
        with ctx:
            outputs = teacher.compute_loss(input_ids=input_ids, labels=labels)
#         break
        loss = outputs.loss
        token_accuracy = outputs.token_accuracy.item()

        # Backprop, clip the global gradient norm, update, then clear gradients.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        # Perplexity of the current batch only: exp(outputs.loss).
        ppl = loss.exp().item()
        if step % 100 == 0:
            print (f"Step: {step}. PPL: {ppl}. Token Accuracy: {token_accuracy}")
        step += 1
#     break
    # Validate with the teacher evaluate() from the first cell. Later cells redefine
    # `evaluate` with other signatures, so the cells must run in order.
    accuracy, token_accuracy, ppl = evaluate(val_dataloader, tokenizer, ctx, teacher, args.max_new_tokens)
    print (f'Val. PPL: {ppl}; Accuracy: {accuracy}; Token Accuracy: {token_accuracy}.')
    # Only the model is checkpointed here. The save_model() helper, which also saves the
    # tokenizer, is not used.
    teacher.save_pretrained(os.path.join(args.save_model, f'checkpoint_{epoch}'))

In [None]:
# print(tokenizer(" ,").input_ids[0])

In [None]:
import math
import torch
from torch.utils.data import DataLoader
from transformers import AdamW
import argparse
import os
import sys
import inspect
import tqdm
import logging
import random

from data import CoTDataset, CoTDataCollator, extract_answer
from models.teacher import Teacher
from models.student import Student
from models.configuration_student import StudentConfig
from utils import get_sep_position

# Seed Python's `random` module and PyTorch's global generator.
random.seed(1234)
torch.manual_seed(1234)

# Let matmuls and cuDNN kernels use TF32 tensor-core math where the GPU supports it.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# Silence WARNING, INFO and DEBUG records from every logger.
logging.disable(logging.WARNING)

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

@torch.no_grad()
def evaluate(dataloader, tokenizer, ctx, teacher, student, delta, subset, max_new_tokens):
    """Evaluate the student when it is conditioned on states extracted from the teacher.

    For each batch:
    - ``teacher.extract_states`` reads ``input_ids_all`` with the given ``delta`` and
      ``subset``.
    - The student is scored with ``compute_loss`` on ``input_ids_nocot`` / ``labels_nocot``
      using those states.
    - The student then ``generate``s from ``input_ids_nocot``.
    - The extracted answers of the reference (text after the first EOS separator of
      ``input_ids_all``) and of the generation are compared for an exact match.

    The first example of every batch is printed. Models are not switched to eval mode
    here; the setup cell does that once.

    Returns:
        ``(accuracy, token_accuracy, ppl)``:
        - accuracy: exact-answer match rate.
        - token_accuracy: ``total_correct / total_tokens`` accumulated from ``compute_loss``.
        - ppl: ``exp(total_loss / total_tokens)``.
    """
    total_instances = 0
    total_tokens = 0
    total_correct = 0
    total_correct_tokens = 0
    total_loss = 0
    for batch in tqdm.tqdm(dataloader):
        input_ids_all = batch['input_ids_all'].to(device)
        input_ids_nocot = batch['input_ids_nocot'].to(device)
        labels_nocot = batch['labels_nocot'].to(device)
        batch_size = input_ids_nocot.shape[0]
        with ctx:
            # Teacher states come from the full sequence; the student sees only the no-CoT input.
            teacher_states = teacher.extract_states(input_ids=input_ids_all, delta=delta, subset=subset)
            outputs = student.compute_loss(input_ids=input_ids_nocot, labels=labels_nocot, teacher_states=teacher_states)
        total_loss += outputs.total_loss.item()
        total_correct_tokens += outputs.total_correct.item()
        total_tokens += outputs.total_tokens
        total_instances += batch_size

        # Generate
        with ctx:
            beam_output = student.generate(
                input_ids=input_ids_nocot,
                teacher_states=teacher_states,
                max_new_tokens=max_new_tokens,
            )

        # Evaluate. Separator positions are taken from input_ids_all but used to slice the
        # student's generation from input_ids_nocot. NOTE(review): this assumes the prompt
        # before the first EOS is identical in both encodings -- confirm with the collator.
        sep_positions = get_sep_position(input_ids_all, tokenizer.eos_token_id)
        for i, (input_ids_all_i, beam_output_i) in enumerate(zip(input_ids_all, beam_output)):
            sep_position = sep_positions[i].item()
            tgt = input_ids_all_i[sep_position+1:]
            tgt_text = tokenizer.decode(tgt, skip_special_tokens=True)
            ans = extract_answer(tgt_text)
            pred_text = tokenizer.decode(beam_output_i[0][sep_position+1:], skip_special_tokens=True)
            pred_ans = extract_answer(pred_text)
            if ans == pred_ans:
                total_correct += 1
            if i == 0:
                print (f'Input: {tokenizer.decode(input_ids_all_i[:sep_position], skip_special_tokens=True)}')
                print (f'Target: {tgt_text}')
                print (f'Predicted: {pred_text}')
                print ('')
    accuracy = total_correct / total_instances
    token_accuracy = total_correct_tokens / total_tokens
    loss = total_loss / total_tokens
    ppl = math.exp(loss)
    return accuracy, token_accuracy, ppl


    

In [None]:
from types import SimpleNamespace

# Hyper-parameters for the "mind-reading" student. The student is conditioned on hidden
# states extracted from the frozen teacher checkpoint named by 'teacher'.
mind_reading_student_trainer_args = {
    'debug': False,
    'train_path': "../data/4_by_4_mult/train.txt",
    'val_path': "../data/4_by_4_mult/valid.txt",
    'teacher': "../train_models/4_by_4_mult/gpt2/teacher/checkpoint_0",
    'save_model': "../train_models/4_by_4_mult/gpt2/student_initial",
    'base_model': 'gpt2',
    'epochs': 10,
    'batch_size': 8,
    'lr': 5e-5,
    'max_new_tokens': 128,
    'delta': 'dynamic',
    'max_grad_norm': 1.0,
    'subset': 'diagonal',
}

# Attribute-style access, mirroring the argparse.Namespace of the original CLI script.
args = SimpleNamespace(**mind_reading_student_trainer_args)

In [None]:
from contextlib import nullcontext

# Equivalent command-line interface of the original script, kept for reference.
# parser = argparse.ArgumentParser()
# parser.add_argument('--teacher', type=str, required=True)
# parser.add_argument('--delta', type=str, required=True)
# parser.add_argument('--train_path', type=str, required=True)
# parser.add_argument('--val_path', type=str, required=True)
# parser.add_argument('--save_model', type=str, required=True)
# parser.add_argument('--max_new_tokens', type=int, default=128)
# parser.add_argument('--base_model', type=str, default='gpt2')
# parser.add_argument('--epochs', type=int, default=5)
# parser.add_argument('--batch_size', type=int, default=32)
# parser.add_argument('--lr', type=float, default=5e-5)
# parser.add_argument('--max_grad_norm', type=float, default=1.0)
# parser.add_argument('--subset', type=str, choices=['diagonal', 'last_column', 'top_row', 'bottom_row', 'first_column'], default='diagonal')
# args = parser.parse_args()

# print (args)

# Numeric precision for the weights and the autocast context ('float16' would normally
# also need a gradient scaler, which the training loop does not use).
dtype = 'float32'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Autocast for the device actually in use; a no-op context on CPU instead of a CUDA
# autocast that PyTorch would only warn about and disable.
ctx = nullcontext() if device.type == 'cpu' else torch.amp.autocast(device_type=device.type, dtype=ptdtype)
print (ptdtype, dtype, device)

# Create Student
config = StudentConfig(base_model=args.base_model)
student = Student(config).to(device).to(ptdtype)

# Load Teacher
teacher = Teacher.from_pretrained(args.teacher).to(device).to(ptdtype)

# Load data (1024 is CoTDataset's third argument -- presumably a max length; confirm).
tokenizer = teacher.tokenizer
collate_fn = CoTDataCollator(tokenizer)
train_dataset = CoTDataset(tokenizer, args.train_path, 1024)
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True)
val_dataset = CoTDataset(tokenizer, args.val_path, 1024)
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

# Create Optimizer over the student only. The fused AdamW kernel is requested only on
# CUDA, because older PyTorch releases reject fused=True for CPU parameters.
trainable_params = list(student.parameters())
use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters and device.type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, **extra_args)

teacher.eval()
student.eval() # to turn off dropout

# Freeze the teacher: it only supplies hidden states.
for p in teacher.parameters():
    p.requires_grad = False




In [None]:
# from utils import get_sep_position

# for batch in tqdm.tqdm(val_dataloader):
#         input_ids_all = batch['input_ids_all'].to(device)
#         first_sep_positions = get_sep_position(input_ids_all, teacher.tokenizer.eos_token_id, skip=0)
#         second_sep_positions = get_sep_position(input_ids_all, teacher.tokenizer.eos_token_id, skip=1)
#         print(compute_positions_to_extract_per_layer('diagonal', 'dynamic', first_sep_positions, second_sep_positions)[:,2].view(-1, 1, 1).expand(-1, -1, teacher.hidden_size).shape)
#         input_ids_all = input_ids_all[:, :second_sep_positions.max()+1]
#         print(input_ids_all.shape)
#         input_ids_nocot = batch['input_ids_nocot'].to(device)
#         labels_nocot = batch['labels_nocot'].to(device)
#         with ctx:
#             with torch.no_grad():
#                 teacher_states = teacher.base_model(input_ids=input_ids_all, output_hidden_states=True)
#                 hidden_states = teacher_states.hidden_states[:-1]
#                 for i, hidden_state in enumerate(hidden_states):
#                     print(hidden_state.shape) # torch.Size([16, 59, 768])
#                     z = hidden_state.gather(1, compute_positions_to_extract_per_layer('diagonal', 'dynamic', first_sep_positions, second_sep_positions)[:,i].view(-1, 1, 1).expand(-1, -1, teacher.hidden_size)).squeeze(1)
#                     print(z.shape) # torch.Size([16, 768])
                    
#         break

In [None]:
# teacher_states.hidden_states[:-1][0].shape

In [None]:
# Train
# Each step: the frozen teacher provides states extracted from input_ids_all (no grad),
# and the student is trained on input_ids_nocot conditioned on those states. Both models
# were put in eval mode during setup to disable dropout. The student evaluate() never
# calls train()/eval(), so no mode toggling is needed here.
step = 0
for epoch in range(args.epochs):
    print(f"Epoch {epoch}")

    for batch in tqdm.tqdm(train_dataloader):
        input_ids_all = batch['input_ids_all'].to(device)
        input_ids_nocot = batch['input_ids_nocot'].to(device)
        labels_nocot = batch['labels_nocot'].to(device)
        with ctx:
            with torch.no_grad():
                teacher_states = teacher.extract_states(input_ids=input_ids_all, delta=args.delta, subset=args.subset)
            outputs = student.compute_loss(input_ids=input_ids_nocot, labels=labels_nocot, teacher_states=teacher_states)
        loss = outputs.loss
        token_accuracy = outputs.token_accuracy.item()

        # Backprop, clip the global gradient norm, update, then clear gradients.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        # Perplexity of the current batch only: exp(outputs.loss).
        ppl = loss.exp().item()
        if step % 100 == 0:
            print (f"Step: {step}. PPL: {ppl}. Token Accuracy: {token_accuracy}")
            sys.stdout.flush()
        step += 1
    # Validate with the most recently defined evaluate() (the student variant in this
    # section; cells must run in order), then checkpoint the student for this epoch.
    accuracy, token_accuracy, ppl = evaluate(val_dataloader, tokenizer, ctx, teacher, student, args.delta, args.subset, args.max_new_tokens)
    print (f'Val. PPL: {ppl}; Accuracy: {accuracy}; Token Accuracy: {token_accuracy}.')
    student.save_pretrained(os.path.join(args.save_model, f'checkpoint_{epoch}'))

In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import AdamW
import argparse
import os
import inspect
import tqdm
import logging
import random
import torch.nn as nn

from data import CoTDataset, CoTDataCollator
from models.teacher import Teacher
from models.emulator import Emulator
from models.configuration_emulator import EmulatorConfig


# Silence WARNING, INFO and DEBUG records from every logger.
logging.disable(logging.WARNING)

# Fix the seeds of PyTorch's global generator and Python's `random` module.
torch.manual_seed(1234)
random.seed(1234)

# Let matmuls and cuDNN kernels use TF32 tensor-core math where the GPU supports it.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

@torch.no_grad()
def evaluate(dataloader, tokenizer, ctx, teacher, emulator, delta, subset):
    """Return the emulator's validation loss averaged per example.

    For each batch, the teacher's states are extracted from ``input_ids_cot`` (with the
    given ``delta`` / ``subset``). The emulator's ``compute_loss`` against those states is
    accumulated via ``outputs.total_loss``. The summed loss is divided by the number of
    examples seen.

    ``tokenizer`` is unused; it is kept for signature parity with the other evaluators.

    NOTE(review): the training loop feeds the emulator ``input_ids_nocot`` while this
    function feeds it ``input_ids_cot``. That is only equivalent if the emulator reads
    just the prompt before the first separator -- confirm in models.emulator.
    """
    total_instances = 0
    total_loss = 0
    for batch in tqdm.tqdm(dataloader):
        input_ids_cot = batch['input_ids_cot'].to(device)
        batch_size = input_ids_cot.shape[0]
        with ctx:
            teacher_states = teacher.extract_states(input_ids=input_ids_cot, delta=delta, subset=subset)
            outputs = emulator.compute_loss(input_ids=input_ids_cot, teacher_states=teacher_states)
        total_loss += outputs.total_loss.item()
        total_instances += batch_size

    loss = total_loss / total_instances
    return loss


    

In [None]:
from types import SimpleNamespace

# Hyper-parameters for the thought emulator. The emulator is trained to predict the
# teacher's extracted states (see 'delta' / 'subset') from the teacher checkpoint
# named by 'teacher'.
thought_emulator_trainer_args = {
    'debug': False,
    'train_path': "../data/4_by_4_mult/train.txt",
    'val_path': "../data/4_by_4_mult/valid.txt",
    'teacher': "../train_models/4_by_4_mult/gpt2/teacher/checkpoint_0",
    'save_model': "../train_models/4_by_4_mult/gpt2/emulator_initial",
    'base_model': 'gpt2',
    'epochs': 10,
    'batch_size': 8,
    'lr': 5e-5,
    'max_new_tokens': 128,
    'delta': 'dynamic',
    'max_grad_norm': 1.0,
    'subset': 'diagonal',
    'mixture_size': 1,
}

# Attribute-style access, mirroring the argparse.Namespace of the original CLI script.
args = SimpleNamespace(**thought_emulator_trainer_args)

In [None]:

from contextlib import nullcontext

# Equivalent command-line interface of the original script, kept for reference.
# parser = argparse.ArgumentParser()
# parser.add_argument('--teacher', type=str, required=True)
# parser.add_argument('--delta', type=str, required=True)
# parser.add_argument('--train_path', type=str, required=True)
# parser.add_argument('--val_path', type=str, required=True)
# parser.add_argument('--save_model', type=str, required=True)
# parser.add_argument('--base_model', type=str, default='gpt2')
# parser.add_argument('--epochs', type=int, default=5)
# parser.add_argument('--batch_size', type=int, default=32)
# parser.add_argument('--lr', type=float, default=5e-5)
# parser.add_argument('--max_grad_norm', type=float, default=1.0)
# parser.add_argument('--subset', type=str, choices=['diagonal', 'last_column', 'top_row', 'bottom_row', 'first_column'], default='diagonal')
# parser.add_argument('--mixture_size', type=int, default=1)
# args = parser.parse_args()

# print (args)
dtype = 'float32'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Autocast for the device actually in use; a no-op context on CPU instead of a CUDA
# autocast that PyTorch would only warn about and disable.
ctx = nullcontext() if device.type == 'cpu' else torch.amp.autocast(device_type=device.type, dtype=ptdtype)
print (ptdtype, dtype, device)

# Create Emulator
config = EmulatorConfig(base_model=args.base_model, mixture_size=args.mixture_size)
emulator = Emulator(config).to(device).to(ptdtype)

# Load Teacher
teacher = Teacher.from_pretrained(args.teacher).to(device).to(ptdtype)

# Load data
tokenizer = teacher.tokenizer
collate_fn = CoTDataCollator(tokenizer)
train_dataset = CoTDataset(tokenizer, args.train_path, 1024)
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True)
val_dataset = CoTDataset(tokenizer, args.val_path, 1024)
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

# Create Optimizer over the emulator only. The fused kernel is requested only on CUDA,
# because older PyTorch releases reject fused=True for CPU parameters.
trainable_params = list(emulator.parameters())
use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters and device.type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, **extra_args)

teacher.eval()
emulator.eval() # to turn off dropout

# Freeze the teacher: it only supplies target states.
for p in teacher.parameters():
    p.requires_grad = False

# Train
step = 0
for epoch in range(args.epochs):
    print(f"Epoch {epoch}")

    for batch in tqdm.tqdm(train_dataloader):
        input_ids_cot = batch['input_ids_cot'].to(device)
        input_ids_nocot = batch['input_ids_nocot'].to(device)
        with ctx:
            with torch.no_grad():
                # Targets: teacher states extracted from input_ids_cot.
                teacher_states = teacher.extract_states(input_ids=input_ids_cot, delta=args.delta, subset=args.subset)
            # The emulator only sees input_ids_nocot and is trained toward those states.
            outputs = emulator.compute_loss(input_ids=input_ids_nocot, teacher_states=teacher_states)
        loss = outputs.loss

        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        if step % 100 == 0:
            # .item() logs the scalar instead of the full tensor repr (device, grad_fn).
            print (f"Step: {step}. Loss: {loss.item()}.")
        step += 1
    loss = evaluate(val_dataloader, tokenizer, ctx, teacher, emulator, args.delta, args.subset)
    print (f'Val. Loss: {loss}.')
    emulator.save_pretrained(os.path.join(args.save_model, f'checkpoint_{epoch}'))

In [None]:
import math
import torch
from torch.utils.data import DataLoader
from transformers import AdamW
import argparse
import os
import inspect
import tqdm
import logging
import random
from itertools import chain

from data import CoTDataset, CoTDataCollator, extract_answer
from models.student import Student
from models.emulator import Emulator
from utils import get_sep_position

# Reproducibility: seed both Python's `random` module and PyTorch's global generator.
torch.manual_seed(1234)
random.seed(1234)

# Silence WARNING, INFO and DEBUG records from every logger.
logging.disable(logging.WARNING)

# Let matmuls and cuDNN kernels use TF32 tensor-core math where the GPU supports it.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

if not torch.cuda.is_available():
    device = torch.device('cpu')
else:
    device = torch.device('cuda')


@torch.no_grad()
def evaluate(dataloader, tokenizer, ctx, emulator, student, max_new_tokens):
    """Evaluate the student when conditioned on the emulator's predicted states.

    No teacher or chain of thought is involved. For each batch:
    - The emulator predicts states from ``input_ids_nocot``.
    - The student is scored with ``compute_loss`` on ``input_ids_nocot`` /
      ``labels_nocot`` using those states, then ``generate``s from the same input.
    - The extracted answers of the reference (text after the first EOS separator of
      ``input_ids_nocot``) and of the generation are compared for an exact match.

    The first example of every batch is printed. Models are not switched to eval mode
    here; the setup cell does that once.

    Returns:
        ``(accuracy, token_accuracy, ppl)``:
        - accuracy: exact-answer match rate.
        - token_accuracy: ``total_correct / total_tokens`` accumulated from ``compute_loss``.
        - ppl: ``exp(total_loss / total_tokens)``.
    """
    total_instances = 0
    total_tokens = 0
    total_correct = 0
    total_correct_tokens = 0
    total_loss = 0
    for batch in tqdm.tqdm(dataloader):
        input_ids_nocot = batch['input_ids_nocot'].to(device)
        labels_nocot = batch['labels_nocot'].to(device)
        batch_size = input_ids_nocot.shape[0]
        with ctx:
            emulated_teacher_states = emulator(input_ids=input_ids_nocot)
            outputs = student.compute_loss(input_ids=input_ids_nocot, labels=labels_nocot, teacher_states=emulated_teacher_states)
        total_loss += outputs.total_loss.item()
        total_correct_tokens += outputs.total_correct.item()
        total_tokens += outputs.total_tokens
        total_instances += batch_size

        # Generate
        with ctx:
            beam_output = student.generate(
                input_ids=input_ids_nocot,
                teacher_states=emulated_teacher_states,
                max_new_tokens=max_new_tokens,
            )

        # Evaluate: exact match of the extracted answer, reference vs. generation.
        sep_positions = get_sep_position(input_ids_nocot, tokenizer.eos_token_id)
        for i, (input_ids_i, beam_output_i) in enumerate(zip(input_ids_nocot, beam_output)):
            sep_position = sep_positions[i].item()
            tgt = input_ids_i[sep_position+1:]
            tgt_text = tokenizer.decode(tgt, skip_special_tokens=True)
            ans = extract_answer(tgt_text)
            pred_text = tokenizer.decode(beam_output_i[0][sep_position+1:], skip_special_tokens=True)
            pred_ans = extract_answer(pred_text)
            if ans == pred_ans:
                total_correct += 1
            if i == 0:
                print (f'Input: {tokenizer.decode(input_ids_i[:sep_position], skip_special_tokens=True)}')
                print (f'Target: {tgt_text}')
                print (f'Predicted: {pred_text}')
                print ('')
    accuracy = total_correct / total_instances
    token_accuracy = total_correct_tokens / total_tokens
    loss = total_loss / total_tokens
    ppl = math.exp(loss)
    return accuracy, token_accuracy, ppl




In [None]:
from types import SimpleNamespace

# Hyper-parameters for jointly fine-tuning the emulator and the student; the emulator's
# predicted states replace the teacher's. The keys 'teacher', 'base_model' and 'delta'
# are not read by the coupled training cell; they are kept only for compatibility.
coupled_emulator_student_trainer_args=dict(
    debug=False,
    train_path="../data/4_by_4_mult/train.txt",
    val_path="../data/4_by_4_mult/valid.txt",
    teacher="../train_models/4_by_4_mult/gpt2/teacher/checkpoint_0",
    # Checkpoints go to <save_model>/student/checkpoint_i and <save_model>/emulator/checkpoint_i.
    # This is kept outside emulator_initial so the coupled results do not get mixed into
    # the emulator-only run's directory.
    save_model="../train_models/4_by_4_mult/gpt2",
    base_model='gpt2',
    epochs=10,
    batch_size=8,
    # Initialise from the checkpoints written by the emulator and student cells (both run
    # with epochs=10, producing checkpoint_0..checkpoint_9). These must NOT be the teacher
    # checkpoint: Emulator/Student.from_pretrained expect their own model types.
    # Pick the epoch with the best validation numbers.
    emulator="../train_models/4_by_4_mult/gpt2/emulator_initial/checkpoint_9",
    student="../train_models/4_by_4_mult/gpt2/student_initial/checkpoint_9",
    lr=5e-5,
    max_new_tokens=128,
    delta='dynamic',
    max_grad_norm=1.0,
    softmax_temperature=0.05,
    fix_emulator=False,
)

args = SimpleNamespace(**coupled_emulator_student_trainer_args)

In [None]:

from contextlib import nullcontext

# Equivalent command-line interface of the original script, kept for reference.
# parser = argparse.ArgumentParser()
# parser.add_argument('--emulator', type=str, required=True)
# parser.add_argument('--student', type=str, required=True)
# parser.add_argument('--train_path', type=str, required=True)
# parser.add_argument('--val_path', type=str, required=True)
# parser.add_argument('--save_model', type=str, required=True)
# parser.add_argument('--max_new_tokens', type=int, default=128)
# parser.add_argument('--epochs', type=int, default=5)
# parser.add_argument('--batch_size', type=int, default=32)
# parser.add_argument('--lr', type=float, default=5e-5)
# parser.add_argument('--max_grad_norm', type=float, default=1.0)
# parser.add_argument('--softmax_temperature', type=float, default=0.05)
# parser.add_argument('--fix_emulator', dest='fix_emulator', action='store_true')
# parser.set_defaults(fix_emulator=False)
# args = parser.parse_args()

# print (args)

dtype = 'float32'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Autocast for the device actually in use; a no-op context on CPU instead of a CUDA
# autocast that PyTorch would only warn about and disable.
ctx = nullcontext() if device.type == 'cpu' else torch.amp.autocast(device_type=device.type, dtype=ptdtype)
print (ptdtype, dtype, device)

# Load Student
student = Student.from_pretrained(args.student).to(device).to(ptdtype)

# Load Emulator
emulator = Emulator.from_pretrained(args.emulator).to(device).to(ptdtype)

# Load data
tokenizer = emulator.tokenizer
collate_fn = CoTDataCollator(tokenizer)
train_dataset = CoTDataset(tokenizer, args.train_path, 1024)
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True)
val_dataset = CoTDataset(tokenizer, args.val_path, 1024)
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

# Create Optimizer. With fix_emulator only the student is optimised and the emulator's
# parameters are frozen; otherwise both models are trained jointly.
if args.fix_emulator:
    trainable_params = list(student.parameters())
    for p in emulator.parameters():
        p.requires_grad = False
else:
    trainable_params = list(student.parameters()) + list(emulator.parameters())
# Fused AdamW only on CUDA: older PyTorch releases reject fused=True for CPU parameters.
use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters and device.type == 'cuda'
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, **extra_args)

emulator.eval() # to turn off dropout
student.eval() # to turn off dropout


# Train
step = 0
for epoch in range(args.epochs):
    print(f"Epoch {epoch}")

    for batch in tqdm.tqdm(train_dataloader):
        input_ids_nocot = batch['input_ids_nocot'].to(device)
        labels_nocot = batch['labels_nocot'].to(device)
        with ctx:
            # The emulator's predicted states stand in for the teacher's. requires_backward
            # is presumably what lets gradients reach the emulator when it is not fixed.
            emulated_teacher_states = emulator(input_ids_nocot, requires_backward=not args.fix_emulator)
            outputs = student.compute_loss(input_ids=input_ids_nocot, labels=labels_nocot, teacher_states=emulated_teacher_states)
        loss = outputs.loss
        token_accuracy = outputs.token_accuracy.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        ppl = loss.exp().item()
        if step % 100 == 0:
            print (f"Step: {step}. PPL: {ppl}. Token Accuracy: {token_accuracy}")
        step += 1
    accuracy, token_accuracy, ppl = evaluate(val_dataloader, tokenizer, ctx, emulator, student, args.max_new_tokens)
    print (f'Val. PPL: {ppl}; Accuracy: {accuracy}; Token Accuracy: {token_accuracy}.')
    student.save_pretrained(os.path.join(args.save_model, 'student', f'checkpoint_{epoch}'))
    emulator.save_pretrained(os.path.join(args.save_model, 'emulator',  f'checkpoint_{epoch}'))